package rlpark.plugin.rltoys.algorithms.control.sarsa;

import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policies;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;

/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/control/sarsa/ExpectedSarsaControl.class */
/**
 * Sarsa control whose update target is the policy-weighted sum of the
 * state-action vectors of every known action in the next state, instead of the
 * state-action vector of the single action that was sampled.
 *
 * <p>The sampled action's own state-action vector is still remembered in
 * {@code xa_t} so that it can serve as the "previous" vector on the following
 * call to {@link #step}.
 */
public class ExpectedSarsaControl extends SarsaControl {
    private static final long serialVersionUID = 738626133717186128L;

    /** Actions over which the expectation is accumulated in {@link #step}. */
    private final Action[] actions;

    /**
     * @param actionArr          actions iterated over when building the expected
     *                           state-action vector; the array is stored as is
     *                           (not copied)
     * @param policy             acting policy, forwarded to the superclass
     * @param stateToStateAction state/action to vector mapping, forwarded to the
     *                           superclass
     * @param sarsa              Sarsa learner, forwarded to the superclass
     */
    public ExpectedSarsaControl(Action[] actionArr, Policy policy, StateToStateAction stateToStateAction, Sarsa sarsa) {
        super(policy, stateToStateAction, sarsa);
        this.actions = actionArr;
    }

    /**
     * Samples the next action from the acting policy and updates the Sarsa
     * learner towards the expected next state-action vector.
     *
     * @param realVector  state at the start of the transition; when {@code null}
     *                    the learner receives {@code null} instead of the stored
     *                    previous state-action vector
     * @param action      not read by this implementation
     * @param realVector2 state at the end of the transition; when {@code null}
     *                    no expectation is built and the learner receives
     *                    {@code null} as its next vector
     * @param d           reward handed to the learner for this transition
     * @return the action sampled from the acting policy after it was updated
     *         with {@code realVector2}
     */
    @Override
    public Action step(RealVector realVector, Action action, RealVector realVector2, double d) {
        this.acting.update(realVector2);
        Action nextAction = this.acting.sampleAction();
        VectorPool pool = VectorPools.pool(realVector2, this.sarsa.q.size);
        try {
            RealVector nextStateAction = null;
            MutableVector expectedStateAction = pool.newVector();
            if (realVector2 != null) {
                for (Action candidate : this.actions) {
                    double probability = this.acting.pi(candidate);
                    if (probability == 0.0d) {
                        // The policy must never sample an action it assigns no probability to.
                        assert candidate != nextAction : "sampled action has zero probability under the acting policy";
                        continue;
                    }
                    RealVector stateAction = this.toStateAction.stateAction(realVector2, candidate);
                    if (candidate == nextAction) {
                        // Copied because it is kept in xa_t across calls; presumably the
                        // vector returned by toStateAction may be reused -- TODO confirm.
                        nextStateAction = stateAction.copy();
                    }
                    expectedStateAction.addToSelf(probability, stateAction);
                }
            }
            // Learn from the expectation over next actions rather than from the sampled
            // one; the null next vector is preserved when there is no next state.
            RealVector target = realVector2 != null ? expectedStateAction : null;
            this.sarsa.update(realVector != null ? this.xa_t : null, target, d);
            this.xa_t = nextStateAction;
            return nextAction;
        } finally {
            // Hand every pooled vector back even when the update or the assertion fails.
            pool.releaseAll();
        }
    }

    /**
     * @return the acting policy used by this controller
     */
    @Override
    public Policy acting() {
        return this.acting;
    }

    /**
     * Delegates the choice of an action for the given state to
     * {@link Policies#decide} with the acting policy.
     *
     * @param realVector state for which an action is requested
     * @return the action returned by {@code Policies.decide}
     */
    @Override
    public Action proposeAction(RealVector realVector) {
        return Policies.decide(this.acting, realVector);
    }
}
