package rlpark.plugin.rltoys.algorithms.predictions.td;

import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.pool.VectorPool;
import rlpark.plugin.rltoys.math.vector.pool.VectorPools;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear temporal-difference learner that maintains two weight vectors: the
 * primary weights {@code v}, used for predictions, and the secondary weights
 * {@code w}, used to build a correction term in the update.
 *
 * <p>When both policy arguments of the update are equal (as in the three-argument
 * {@link #update(RealVector, RealVector, double)}), {@code ratio} is zero and
 * {@code rho_t} is one, so the primary update reduces to
 * {@code v += alpha_v * delta_t * x_t}.
 *
 * <p>NOTE(review): the class name suggests the "Hybrid TD" algorithm; confirm the
 * update rules against the intended reference before relying on that reading.
 */
@Monitor
public class HTD implements OnPolicyTD, GVF {
    private static final long serialVersionUID = 8687476023177671278L;

    /** Default discount used by the overloads that do not take an explicit discount. */
    protected double gamma;
    /** Step size applied to the primary weights {@link #v}. */
    public double alpha_v;
    /** Step size applied to the secondary weights {@link #w}. */
    public double alpha_w;

    /** Primary weights; predictions are the dot product of this vector with the input. */
    @Monitor(level = 4)
    public PVector v;

    /** Secondary weights; their dot product with the next-step vector is the correction term. */
    @Monitor(level = 4)
    protected final PVector w;

    /** Prediction for the current-step vector computed during the last update. */
    public double v_t;
    /** Error term computed during the last update (also the value returned by update). */
    protected double delta_t;
    /** Last value of {@code w . x_tp1}. */
    private double correction;
    /** Last value of {@code (d - d2) / d2} for the two policy arguments of update. */
    private double ratio;
    /** Last value of {@code d / d2} for the two policy arguments of update. */
    private double rho_t;

    /**
     * Creates a learner with zero-argument-constructed weight vectors of the given size.
     *
     * @param d  stored as {@link #gamma}, the default discount
     * @param d2 stored as {@link #alpha_v}, the step size of the primary weights
     * @param d3 stored as {@link #alpha_w}, the step size of the secondary weights
     * @param i  size passed to the {@code PVector} constructor for both {@code v} and {@code w}
     */
    public HTD(double d, double d2, double d3, int i) {
        this.gamma = d;
        this.alpha_v = d2;
        this.alpha_w = d3;
        this.v = new PVector(i);
        this.w = new PVector(i);
    }

    /**
     * Update with both policy arguments fixed to 1, so that {@code rho_t == 1} and
     * {@code ratio == 0}.
     *
     * @param realVector  current-step vector, or {@code null} to start a new episode
     * @param realVector2 next-step vector
     * @param d           reward term of the error
     * @return the error term {@code delta_t}
     */
    @Override
    public double update(RealVector realVector, RealVector realVector2, double d) {
        return update(1.0d, 1.0d, realVector, realVector2, d);
    }

    /**
     * Update using the stored {@link #gamma} as discount and 0 as the extra outcome term.
     *
     * @param d           numerator of {@code rho_t} (presumably the target-policy
     *                    probability of the action taken; confirm with callers)
     * @param d2          denominator of {@code rho_t} (presumably the behaviour-policy
     *                    probability of the action taken; confirm with callers)
     * @param realVector  current-step vector, or {@code null} to start a new episode
     * @param realVector2 next-step vector
     * @param d3          reward term of the error
     * @return the error term {@code delta_t}
     */
    @Override
    public double update(double d, double d2, RealVector realVector, RealVector realVector2, double d3) {
        return update(d, d2, realVector, realVector2, d3, this.gamma, 0.0d);
    }

    /**
     * Performs one learning step on both weight vectors.
     *
     * <p>With {@code x_t = realVector} and {@code x_tp1 = realVector2}, the step computes:
     * <pre>
     *   v_t        = v . x_t
     *   delta_t    = d3 + (1 - d4) * d5 + d4 * (v . x_tp1) - v_t
     *   correction = w . x_tp1
     *   ratio      = (d - d2) / d2
     *   rho_t      = d / d2
     *   v += alpha_v * (rho_t * delta_t * x_t + d4 * ratio * correction * x_tp1)
     *   w += alpha_w * (rho_t * (delta_t - correction) * x_t - d4 * correction * x_tp1)
     * </pre>
     *
     * <p>No guard is applied to {@code d2}; a zero value yields non-finite
     * {@code ratio}/{@code rho_t} under IEEE double arithmetic.
     *
     * @param d           numerator of {@code rho_t} (presumably the target-policy
     *                    probability; confirm with callers)
     * @param d2          denominator of {@code rho_t} (presumably the behaviour-policy
     *                    probability; confirm with callers)
     * @param realVector  current-step vector; when {@code null} the call only resets the
     *                    per-step state through {@link #initEpisode()}
     * @param realVector2 next-step vector
     * @param d3          reward term of the error
     * @param d4          discount applied to the next-step prediction
     * @param d5          extra outcome term, weighted by {@code 1 - d4}
     * @return the error term {@code delta_t}
     */
    @Override
    public double update(double d, double d2, RealVector realVector, RealVector realVector2, double d3, double d4, double d5) {
        if (realVector == null) {
            return initEpisode();
        }
        VectorPool pool = VectorPools.pool(realVector);
        try {
            this.v_t = this.v.dotProduct(realVector);
            this.delta_t = d3 + (1.0d - d4) * d5 + d4 * this.v.dotProduct(realVector2) - this.v_t;
            this.correction = this.w.dotProduct(realVector2);
            this.ratio = (d - d2) / d2;
            this.rho_t = d / d2;

            // Primary weights: v += alpha_v * (vCurrentCoeff * x_t + vNextCoeff * x_tp1).
            double vCurrentCoeff = this.rho_t * this.delta_t;
            double vNextCoeff = d4 * this.ratio * this.correction;
            this.v.addToSelf(
                    this.alpha_v,
                    (RealVector) pool.newVector(realVector)
                            .mapMultiplyToSelf(vCurrentCoeff)
                            .addToSelf(pool.newVector(realVector2).mapMultiplyToSelf(vNextCoeff)));

            // Secondary weights: w += alpha_w * (wCurrentCoeff * x_t + wNextCoeff * x_tp1).
            double wCurrentCoeff = this.rho_t * (this.delta_t - this.correction);
            double wNextCoeff = (-d4) * this.correction;
            this.w.addToSelf(
                    this.alpha_w,
                    (RealVector) pool.newVector(realVector)
                            .mapMultiplyToSelf(wCurrentCoeff)
                            .addToSelf(pool.newVector(realVector2).mapMultiplyToSelf(wNextCoeff)));
        } finally {
            // Always hand the borrowed vectors back, even when a vector operation throws.
            pool.releaseAll();
        }
        return this.delta_t;
    }

    /**
     * Resets the per-step prediction and error to zero; the weight vectors are not modified.
     *
     * @return the reset error term, i.e. {@code 0.0}
     */
    protected double initEpisode() {
        this.v_t = 0.0d;
        this.delta_t = 0.0d;
        return this.delta_t;
    }

    /**
     * Zeroes one component of the primary weights. The secondary weights are left untouched.
     *
     * @param i index of the component of {@code v} to reset
     */
    @Override
    public void resetWeight(int i) {
        this.v.data[i] = 0.0d;
    }

    /**
     * Computes the prediction for the given vector without modifying any state.
     *
     * @param realVector vector to evaluate
     * @return {@code v . realVector}
     */
    @Override
    public double predict(RealVector realVector) {
        return this.v.dotProduct(realVector);
    }

    /**
     * @return the default discount stored at construction
     */
    public double gamma() {
        return this.gamma;
    }

    /**
     * @return the primary weight vector {@code v} itself (not a copy)
     */
    @Override
    public PVector weights() {
        return this.v;
    }

    /**
     * @return the secondary weight vector {@code w} itself (not a copy)
     */
    @Override
    public PVector secondaryWeights() {
        return this.w;
    }

    /**
     * @return the error term {@code delta_t} from the most recent update or episode reset
     */
    @Override
    public double error() {
        return this.delta_t;
    }

    /**
     * @return the prediction {@code v_t} from the most recent update or episode reset
     */
    @Override
    public double prediction() {
        return this.v_t;
    }
}
