StateMachineImpl.java
package io.github.jonloucks.concurrency.impl;
import io.github.jonloucks.concurrency.api.ConcurrencyException;
import io.github.jonloucks.concurrency.api.StateMachine;
import io.github.jonloucks.concurrency.api.Waitable;
import java.time.Duration;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;
import static io.github.jonloucks.concurrency.impl.Internal.*;
import static io.github.jonloucks.contracts.api.Checks.*;
import static java.util.Optional.ofNullable;
final class StateMachineImpl<S> implements StateMachine<S> {
@Override
public boolean setState(String event, S state) {
if (isTransitionAllowed(event, existsCheck(state))) {
currentState.accept(state);
return true;
}
return false;
}
@Override
public S get() {
return currentState.get();
}
@Override
public Optional<S> getIf(Predicate<S> predicate) {
return currentState.getIf(predicate);
}
@Override
public Optional<S> getWhen(Predicate<S> predicate) {
return currentState.getWhen(predicate);
}
@Override
public Optional<S> getWhen(Predicate<S> predicate, Duration timeout) {
return currentState.getWhen(predicate, timeout);
}
@Override
public <B extends Transition.Builder<B, S, R>, R> R transition(Consumer<Transition.Builder<B, S, R>> builderConsumer) {
final TransitionBuilderImpl<B,S,R> builder = new TransitionBuilderImpl<>();
builderConsumerCheck(builderConsumer).accept(builder);
return transition(builder);
}
@Override
public <R> R transition(Transition<S, R> transition) {
final Transition<S,R> t = transitionCheck(transition);
if (isTransitionAllowed(t.getEvent(), t.getSuccessState())) {
try {
return handleSuccess(t);
} catch (Throwable thrown) {
return handleError(t, thrown);
}
} else {
return handleDenied(t);
}
}
@Override
public boolean hasState(S state) {
return stateToRulesLookup.containsKey(stateCheck(state));
}
@Override
public boolean isTransitionAllowed(String event, S state) {
final String validEvent = Internal.eventCheck(event);
final S toState = stateCheck(state);
final S fromState = getState();
if (hasState(toState) && !fromState.equals(toState)) {
final Set<Rule<S>> rules = stateToRulesLookup.get(fromState);
if (ofNullable(rules).isPresent() && !rules.isEmpty()) {
return rules.stream().allMatch(r -> r.canTransition(validEvent, toState));
}
return true;
}
return false;
}
StateMachineImpl(Config<S> config) {
final Config<S> validConfig = configCheck(config);
final S validInitialState = validConfig.getInitial().orElseThrow(this::getInitialStateNotPresentException);
this.currentState = new WaitableImpl<>(validInitialState);
addStateAndRules(validInitialState, Collections.emptyList() );
validConfig.getStates().forEach(state -> addStateAndRules(state, validConfig.getStateRules(state)));
}
private <R> Transition<S,R> transitionCheck(Transition<S,R> transition) {
final Transition<S,R> validTransition = nullCheck(transition, "Transition must be present.");
existsCheck(validTransition.getSuccessState());
ofNullable(transition.getEvent()).orElseThrow(this::getEventNotPresentException);
return validTransition;
}
private <R> R handleSuccess(Transition<S, R> t) {
final R value = orNull(t.getSuccessValue());
setState(t.getEvent(), t.getSuccessState());
return value;
}
private <R> R handleDenied(Transition<S, R> transition) {
setOptionalState(transition.getFailedState(), transition.getEvent());
if (transition.getFailedValue().isPresent()) {
return transition.getFailedValue().get().get();
}
throw new ConcurrencyException("Illegal state transition from " + getState() +
" to " + transition.getSuccessState() + ".");
}
private <R> R handleError(Transition<S, R> t, Throwable thrown) throws Error, ConcurrencyException, RuntimeException {
setOptionalState(t.getErrorState(), t.getEvent());
if (t.getErrorValue().isPresent()) {
return t.getErrorValue().get().get();
} else {
throwUnchecked(thrown, "State machine error.");
return null;
}
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private static <X> X orNull(Optional<Supplier<X>> optional) {
return optional.map(Supplier::get).orElse(null);
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
private <R> void setOptionalState(Optional<S> optional, String event) {
optional.ifPresent(s -> setState(event, s));
}
private S existsCheck(S state) {
final S validState = stateCheck(state);
return illegalCheck(validState, !hasState(validState), "Rule does not exist.");
}
private IllegalArgumentException getInitialStateNotPresentException() {
return new IllegalArgumentException("Initial state must be present.");
}
private IllegalArgumentException getEventNotPresentException() {
return new IllegalArgumentException("Event must be present.");
}
private void addStateAndRules(S state, List<Rule<S>> rules) {
final S validState = stateCheck(state);
final List<Rule<S>> validRules = nullCheck(rules, "Rules must be present.");
final Set<Rule<S>> knownRules = stateToRulesLookup(validState);
validRules.forEach(rule -> knownRules.add(ruleCheck(rule)));
}
private Set<Rule<S>> stateToRulesLookup(S state) {
return stateToRulesLookup.computeIfAbsent(state, k -> new HashSet<>());
}
private final HashMap<S, Set<Rule<S>>> stateToRulesLookup = new HashMap<>();
private final Waitable<S> currentState;
}