package rlpark.plugin.rltoys.algorithms.control.sarsa;

import rlpark.plugin.rltoys.algorithms.functions.ParameterizedFunction;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear Sarsa(&lambda;) learner: the value estimate is the dot product of the weight
 * vector {@code q} with a feature vector.
 *
 * <p>Each {@link #update} call computes the TD error
 * {@code delta = r + gamma * q.phi_tp1 - q.phi_t}, updates the eligibility traces with decay
 * {@code gamma * lambda} and the features {@code phi_t}, then adds {@code alpha * delta} times
 * the traces to {@code q}. The exact trace rule (accumulating, replacing, ...) is whatever
 * the supplied {@link Traces} implementation does.
 */
@Monitor
public class Sarsa implements Predictor, ParameterizedFunction {
    private static final long serialVersionUID = 9030254074554565900L;

    /** Eligibility traces, created from the prototype passed to the constructor. */
    @Monitor(level = 4)
    protected final Traces e;

    /** Weight vector of the linear estimate; updated in place by {@link #update}. */
    @Monitor(level = 4)
    protected final PVector q;
    /** Trace-decay parameter; traces are decayed by {@code gamma * lambda} each step. */
    protected final double lambda;
    /** Discount factor applied to the next-step estimate. */
    protected final double gamma;
    /** Step size; non-final so subclasses may adapt it. */
    protected double alpha;
    /** TD error from the most recent non-terminal-start {@link #update}. */
    protected double delta;
    /** Estimate {@code q.phi_t} from the most recent update. */
    protected double v_t;
    /** Estimate {@code q.phi_tp1} from the most recent update (0 when {@code phi_tp1} was null). */
    protected double v_tp1;

    /**
     * Creates a learner with a new {@code PVector} of {@code nbFeatures} weights and
     * {@link ATraces} eligibility traces.
     *
     * @param alpha step size
     * @param gamma discount factor
     * @param lambda trace-decay parameter
     * @param nbFeatures dimension of the weight (and feature) vectors
     */
    public Sarsa(double alpha, double gamma, double lambda, int nbFeatures) {
        this(alpha, gamma, lambda, nbFeatures, new ATraces());
    }

    /**
     * Creates a learner with a new {@code PVector} of {@code nbFeatures} weights.
     *
     * @param alpha step size
     * @param gamma discount factor
     * @param lambda trace-decay parameter
     * @param nbFeatures dimension of the weight (and feature) vectors
     * @param prototype traces prototype; {@code newTraces(nbFeatures)} is called on it
     */
    public Sarsa(double alpha, double gamma, double lambda, int nbFeatures, Traces prototype) {
        this(alpha, gamma, lambda, new PVector(nbFeatures), prototype);
    }

    /**
     * Creates a learner around an existing weight vector.
     *
     * @param alpha step size
     * @param gamma discount factor
     * @param lambda trace-decay parameter
     * @param initialWeights weight vector, stored by reference (not copied) and mutated by
     *     {@link #update}
     * @param prototype traces prototype; {@code newTraces} is called on it with the weight
     *     dimension
     */
    public Sarsa(double alpha, double gamma, double lambda, PVector initialWeights, Traces prototype) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.lambda = lambda;
        this.q = initialWeights;
        this.e = prototype.newTraces(initialWeights.getDimension());
    }

    /**
     * Performs one Sarsa(&lambda;) step.
     *
     * @param phi_t features of the current state-action; {@code null} marks the start of an
     *     episode, in which case the traces are cleared and nothing else changes
     * @param phi_tp1 features of the next state-action; {@code null} means terminal, so the
     *     next-step estimate is taken as 0
     * @param r_tp1 reward received on the transition
     * @return the TD error, or 0 when {@code phi_t} is null
     */
    public double update(RealVector phi_t, RealVector phi_tp1, double r_tp1) {
        if (phi_t == null) {
            return initEpisode();
        }
        v_tp1 = phi_tp1 != null ? q.dotProduct(phi_tp1) : 0.0;
        v_t = q.dotProduct(phi_t);
        delta = r_tp1 + gamma * v_tp1 - v_t;
        e.update(gamma * lambda, phi_t);
        q.addToSelf(alpha * delta, (RealVector) e.vect());
        return delta;
    }

    /**
     * Resets per-episode state by clearing the eligibility traces.
     *
     * <p>Kept public for compatibility, although it was protected before decompilation.
     *
     * @return always 0
     */
    public double initEpisode() {
        e.clear();
        return 0.0;
    }

    /**
     * Returns the linear estimate {@code q.phi}. With assertions enabled ({@code -ea}),
     * checks that the feature dimension matches the weight dimension.
     *
     * @param phi feature vector
     * @return the dot product of the weights with {@code phi}
     */
    @Override
    public double predict(RealVector phi) {
        assert q.getDimension() == phi.getDimension()
                : "dimension mismatch: weights " + q.getDimension() + " vs features " + phi.getDimension();
        return q.dotProduct(phi);
    }

    /**
     * Returns the live weight vector (not a copy).
     *
     * @return the weights used by this learner
     */
    @Override
    public PVector weights() {
        return q;
    }
}
