package rlpark.plugin.rltoys.algorithms.control.gq;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear learner with eligibility traces that maintains two weight vectors of
 * the same size: {@link #v}, used for predictions, and {@code w}, an auxiliary
 * vector used to correct the update applied to {@code v}.
 *
 * <p>NOTE(review): the class name and update rule look like the GQ(lambda)
 * gradient temporal-difference algorithm; confirm against the project docs.
 */
@Monitor
public class GQ implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
    private static final long serialVersionUID = -4971665888576276439L;

    /** Weights used by {@link #predict(RealVector)} and returned by {@link #weights()}. */
    @Monitor(level = 4)
    public final PVector v;
    /** Step size applied to the update of {@link #v}. */
    protected double alpha_v;
    /** Step size applied to the update of {@link #w}. */
    protected double alpha_w;
    /** Mixing factor: weights the last argument of {@code update} by {@code beta_tp1} and the next-step estimate by {@code 1 - beta_tp1}. */
    protected double beta_tp1;
    /** Factor entering the trace decay and, as {@code 1 - lambda_t}, the correction term. */
    protected double lambda_t;

    /** Auxiliary weights, updated alongside {@link #v}. */
    @Monitor(level = 4)
    protected final PVector w;
    /** Eligibility trace, created from the prototype given to the constructor. */
    protected final Traces e;
    /** Error computed by the most recent non-initialising call to {@code update}. */
    protected double delta_t;

    /**
     * Creates a learner using a fresh {@link ATraces} as trace prototype.
     *
     * @param d  initial value of {@link #alpha_v}
     * @param d2 initial value of {@link #alpha_w}
     * @param d3 initial value of {@link #beta_tp1}
     * @param d4 initial value of {@link #lambda_t}
     * @param i  size of the weight vectors and of the trace
     */
    public GQ(double d, double d2, double d3, double d4, int i) {
        this(d, d2, d3, d4, i, new ATraces());
    }

    /**
     * Creates a learner whose trace is obtained from {@code traces.newTraces(i)}.
     *
     * @param d      initial value of {@link #alpha_v}
     * @param d2     initial value of {@link #alpha_w}
     * @param d3     initial value of {@link #beta_tp1}
     * @param d4     initial value of {@link #lambda_t}
     * @param i      size of the weight vectors and of the trace
     * @param traces prototype used to create this learner's own trace
     */
    public GQ(double d, double d2, double d3, double d4, int i, Traces traces) {
        this.alpha_v = d;
        this.alpha_w = d2;
        this.beta_tp1 = d3;
        this.lambda_t = d4;
        this.e = traces.newTraces(i);
        this.v = new PVector(i);
        this.w = new PVector(i);
    }

    /**
     * Clears the eligibility trace.
     *
     * @return always {@code 0.0}
     */
    protected double initEpisode() {
        this.e.clear();
        return 0.0d;
    }

    /**
     * Performs one learning step, updating the trace, {@link #v} and {@link #w}.
     *
     * <p>The error is computed as
     * {@code d2 + beta_tp1 * d3 + (1 - beta_tp1) * v.realVector2 - v.realVector},
     * where the {@code v.realVector2} term is taken as {@code 0.0} when
     * {@code Vectors.isNull(realVector2)} holds.
     *
     * @param realVector  vector for the current step; when {@code null} the call
     *                    only resets the trace via {@link #initEpisode()}
     * @param d           factor multiplied into the trace decay
     *                    ({@code (1 - beta_tp1) * lambda_t * d})
     * @param d2          term added directly to the error (presumably the
     *                    reward; confirm against callers)
     * @param realVector2 vector for the next step; may be null in the sense of
     *                    {@code Vectors.isNull}
     * @param d3          term weighted by {@link #beta_tp1} in the error
     * @return the error {@link #delta_t}, or the result of
     *         {@link #initEpisode()} when {@code realVector} is {@code null}
     */
    public double update(RealVector realVector, double d, double d2, RealVector realVector2, double d3) {
        if (realVector == null) {
            return initEpisode();
        }
        VectorPool pool = VectorPools.pool(realVector);
        try {
            // The next-step vector is optional: evaluate the null check once and use it
            // for both the error and the correction term, so it is never used unguarded.
            final boolean hasNext = !Vectors.isNull(realVector2);
            final double nextValue = hasNext ? this.v.dotProduct(realVector2) : 0.0d;
            this.delta_t = ((d2 + (this.beta_tp1 * d3)) + ((1.0d - this.beta_tp1) * nextValue)) - this.v.dotProduct(realVector);
            this.e.update((1.0d - this.beta_tp1) * this.lambda_t * d, realVector);

            // delta_t * e, shared by the updates of v and w.
            MutableVector deltaTrace = pool.newVector(this.e.vect()).mapMultiplyToSelf(this.delta_t);

            // Correction subtracted from the v update; left as returned by the pool
            // when there is no next-step vector.
            MutableVector correction = pool.newVector();
            if (hasNext) {
                correction.set(realVector2).mapMultiplyToSelf((1.0d - this.beta_tp1) * (1.0d - this.lambda_t) * this.e.vect().dotProduct(this.w));
            }

            // v is updated from a copy because deltaTrace is modified in place for w below.
            this.v.addToSelf(this.alpha_v, (RealVector) pool.newVector(deltaTrace).subtractToSelf(correction));
            this.w.addToSelf(this.alpha_w, (RealVector) deltaTrace.subtractToSelf(pool.newVector(realVector).mapMultiplyToSelf(this.w.dotProduct(realVector))));
            return this.delta_t;
        } finally {
            // Release the pooled temporaries even if a vector operation throws.
            pool.releaseAll();
        }
    }

    /**
     * Returns the dot product of {@link #v} with the given vector.
     *
     * @param realVector vector to evaluate
     * @return {@code v.dotProduct(realVector)}
     */
    @Override // rlpark.plugin.rltoys.algorithms.functions.Predictor
    public double predict(RealVector realVector) {
        return this.v.dotProduct(realVector);
    }

    /**
     * Returns the prediction weights.
     *
     * @return the {@link #v} instance itself, not a copy
     */
    @Override // rlpark.plugin.rltoys.algorithms.functions.ParameterizedFunction
    public PVector weights() {
        return this.v;
    }

    /**
     * Sets entry {@code i} of {@link #v}, of the trace vector and of {@link #w} to zero.
     *
     * @param i index of the entry to reset
     */
    @Override // rlpark.plugin.rltoys.algorithms.LinearLearner
    public void resetWeight(int i) {
        this.v.data[i] = 0.0d;
        this.e.vect().setEntry(i, 0.0d);
        this.w.data[i] = 0.0d;
    }

    /**
     * Returns the error stored by the last learning step.
     *
     * @return the current value of {@link #delta_t}
     */
    @Override // rlpark.plugin.rltoys.algorithms.LinearLearner
    public double error() {
        return this.delta_t;
    }

    /**
     * Returns the eligibility trace used by this learner.
     *
     * @return the {@link #e} instance itself
     */
    @Override // rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm
    public Traces traces() {
        return this.e;
    }
}
