package rlpark.plugin.rltoys.algorithms.predictions.td;

import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * GTD(lambda): gradient temporal-difference prediction with eligibility traces, usable on-policy
 * (importance ratio 1) or off-policy.
 *
 * <p>Each non-initial call to the main {@code update} performs, in this order:
 *
 * <pre>
 *   delta_t = r_tp1 + (1 - gamma_tp1) * z_tp1 + gamma_tp1 * (v . x_tp1) - (v . x_t)
 *   e_t     = rho_t * traces.update(gamma_t * lambda, x_t)   (exact form depends on the Traces implementation)
 *   v      += alpha_v * (delta_t * e_t - gamma_tp1 * (1 - lambda) * (e_t . w) * x_tp1)
 *   w      += alpha_w * (delta_t * e_t - (w . x_t) * x_t)
 * </pre>
 *
 * <p>where {@code rho_t = pi_t / b_t}. A {@code null} next state {@code x_tp1} contributes no
 * bootstrapped value and no correction term.
 */
@Monitor
public class GTDLambda implements OnPolicyTD, GVF, EligibilityTraceAlgorithm {
    private static final long serialVersionUID = 8687476023177671278L;
    /** Default discount used by the update overloads that do not take an explicit discount. */
    protected double gamma;
    /** Step size for the primary weights {@code v}. */
    public final double alpha_v;
    /** Step size for the secondary weights {@code w}. */
    public final double alpha_w;
    /** Trace-decay parameter. */
    protected double lambda;
    /** Discount of the previous transition; decays the traces on the next update. */
    private double gamma_t;

    /** Primary weights: the parameters of the value estimate. */
    @Monitor(level = 4)
    public final PVector v;

    /** Secondary weights used by the gradient-correction term. */
    @Monitor(level = 4)
    protected final PVector w;
    /** Eligibility traces. */
    private final Traces e;
    /** Prediction {@code v . x_t} computed during the last update (0 after an episode start). */
    protected double v_t;

    /** TD error of the last update (0 after an episode start). */
    @Monitor(wrappers = {Squared.ID, Abs.ID})
    protected double delta_t;
    /** Last value of {@code e_t . w} used in the correction term (0 when none was applied). */
    private double correction;
    /** Last importance-sampling ratio {@code pi_t / b_t}. */
    private double rho_t;

    /**
     * Creates a learner whose eligibility traces are built from a new {@link ATraces} prototype.
     *
     * @param lambda trace-decay parameter
     * @param gamma default discount used by the update overloads without an explicit discount
     * @param alpha_v step size for the primary weights
     * @param alpha_w step size for the secondary weights
     * @param nbFeatures size of the weight and trace vectors
     */
    public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures) {
        this(lambda, gamma, alpha_v, alpha_w, nbFeatures, new ATraces());
    }

    /**
     * Creates a learner with zero-initialized weights.
     *
     * @param lambda trace-decay parameter
     * @param gamma default discount used by the update overloads without an explicit discount
     * @param alpha_v step size for the primary weights
     * @param alpha_w step size for the secondary weights
     * @param nbFeatures size of the weight and trace vectors
     * @param prototype traces prototype; {@code prototype.newTraces(nbFeatures)} supplies this
     *     learner's traces
     */
    public GTDLambda(double lambda, double gamma, double alpha_v, double alpha_w, int nbFeatures, Traces prototype) {
        this.lambda = lambda;
        this.gamma = gamma;
        this.alpha_v = alpha_v;
        this.alpha_w = alpha_w;
        this.v = new PVector(nbFeatures);
        this.w = new PVector(nbFeatures);
        this.e = prototype.newTraces(nbFeatures);
    }

    /**
     * Performs one GTD(lambda) update for the transition {@code x_t -> x_tp1}.
     *
     * @param pi_t numerator of the importance ratio {@code rho_t = pi_t / b_t} (expected to be the
     *     target-policy probability of the action taken)
     * @param b_t denominator of the importance ratio (expected to be the behaviour-policy
     *     probability of the action taken)
     * @param x_t features of the current state; {@code null} starts a new episode
     * @param x_tp1 features of the next state; {@code null} contributes no bootstrapped value and
     *     no correction term
     * @param r_tp1 reward received on the transition
     * @param gamma_tp1 discount for this transition; it also decays the traces on the next update
     * @param z_tp1 outcome value blended into the target with weight {@code 1 - gamma_tp1}
     * @return the TD error {@code delta_t} (0 when {@code x_t} is {@code null})
     */
    @Override
    public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1, double z_tp1) {
        if (x_t == null) {
            return initEpisode(gamma_tp1);
        }
        VectorPool pool = VectorPools.pool(this.e.vect());
        try {
            this.v_t = this.v.dotProduct(x_t);
            // Guard the next-state value the same way the correction term below is guarded.
            double v_tp1 = x_tp1 == null ? 0.0d : this.v.dotProduct(x_tp1);
            this.delta_t = r_tp1 + (1.0d - gamma_tp1) * z_tp1 + gamma_tp1 * v_tp1 - this.v_t;

            // The traces decay with the previous transition's discount, then are reweighted by rho_t.
            this.e.update(this.gamma_t * this.lambda, x_t);
            this.rho_t = pi_t / b_t;
            this.e.vect().mapMultiplyToSelf(this.rho_t);

            // Gradient-correction vector: gamma_tp1 * (1 - lambda) * (e_t . w) * x_tp1, or zero.
            MutableVector correctionVector = pool.newVector();
            if (x_tp1 != null) {
                this.correction = this.e.vect().dotProduct(this.w);
                correctionVector.addToSelf(this.correction * gamma_tp1 * (1.0d - this.lambda), x_tp1);
            } else {
                this.correction = 0.0d;
            }

            MutableVector deltaE = pool.newVector(this.e.vect()).mapMultiplyToSelf(this.delta_t);
            // The v update works on a copy so deltaE is still intact for the w update.
            this.v.addToSelf(this.alpha_v, (RealVector) pool.newVector(deltaE).subtractToSelf(correctionVector));
            // deltaE is modified in place here; w . x_t is evaluated with the pre-update w.
            this.w.addToSelf(this.alpha_w, (RealVector) deltaE.addToSelf(-this.w.dotProduct(x_t), x_t));
            this.gamma_t = gamma_tp1;
            return this.delta_t;
        } finally {
            // Always hand the temporary vectors back, even if an update step throws.
            pool.releaseAll();
        }
    }

    /**
     * Starts a new episode: clears the traces and resets the per-step prediction and TD error.
     *
     * @param gamma_tp1 discount that decays the (cleared) traces on the episode's first update
     * @return the TD error reported for this step, always 0
     */
    protected double initEpisode(double gamma_tp1) {
        this.gamma_t = gamma_tp1;
        this.e.clear();
        this.v_t = 0.0d;
        // Keep error() consistent with the value returned to the caller.
        this.delta_t = 0.0d;
        return 0.0d;
    }

    /**
     * Zeroes the primary weight and the trace entry of one feature.
     *
     * <p>NOTE(review): the secondary weight {@code w[index]} is left untouched; confirm that this
     * is intended when features are replaced.
     *
     * @param index feature index to reset
     */
    @Override
    public void resetWeight(int index) {
        this.v.data[index] = 0.0d;
        this.e.vect().setEntry(index, 0.0d);
    }

    /**
     * On-policy update: importance ratio 1, the default discount {@code gamma}, and no outcome term.
     *
     * @return the TD error of the transition
     */
    @Override
    public double update(RealVector x_t, RealVector x_tp1, double r_tp1) {
        return update(1.0d, 1.0d, x_t, x_tp1, r_tp1, this.gamma, 0.0d);
    }

    /**
     * Off-policy update with the default discount {@code gamma} and no outcome term.
     *
     * @return the TD error of the transition
     */
    @Override
    public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1) {
        return update(pi_t, b_t, x_t, x_tp1, r_tp1, this.gamma, 0.0d);
    }

    /**
     * Off-policy update with an explicit discount and no outcome term.
     *
     * @return the TD error of the transition
     */
    public double update(double pi_t, double b_t, RealVector x_t, RealVector x_tp1, double r_tp1, double gamma_tp1) {
        return update(pi_t, b_t, x_t, x_tp1, r_tp1, gamma_tp1, 0.0d);
    }

    /** Returns the current estimate {@code v . x} for the given features. */
    @Override
    public double predict(RealVector x) {
        return this.v.dotProduct(x);
    }

    /** Returns the default discount used by the update overloads without an explicit discount. */
    public double gamma() {
        return this.gamma;
    }

    /** Returns the primary weights (the live vector, not a copy). */
    @Override
    public PVector weights() {
        return this.v;
    }

    /** Returns the secondary weights (the live vector, not a copy). */
    @Override
    public PVector secondaryWeights() {
        return this.w;
    }

    /** Returns the eligibility traces (the live object, not a copy). */
    @Override
    public Traces traces() {
        return this.e;
    }

    /** Returns the TD error of the last update. */
    @Override
    public double error() {
        return this.delta_t;
    }

    /** Returns the prediction {@code v . x_t} computed during the last update. */
    @Override
    public double prediction() {
        return this.v_t;
    }

    /** Returns the last {@code e_t . w} used in the correction term (0 when none was applied). */
    public double correction() {
        return this.correction;
    }
}
