package rlpark.plugin.rltoys.algorithms.predictions.td;

import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.math.vector.DenseVector;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.math.vector.implementations.SVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Abs;
import zephyr.plugin.core.api.internal.monitoring.wrappers.Squared;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear TD(lambda) predictor whose per-weight step sizes are adapted on every update.
 *
 * <p>Each weight owns a step size ({@link #alpha}), a memory term ({@link #h}) and a running
 * normalizer ({@link #normalizer}). On each call to {@link #update(RealVector, RealVector, double)}
 * the step sizes are first rescaled multiplicatively, then bounded so that the summed
 * {@code alpha[i] * |e[i] * x[i]|} does not exceed one, and only then are the weights and the
 * memory term updated.
 */
@Monitor
public class TDLambdaAutostep implements OnPolicyTD {
    private static final long serialVersionUID = 1567652945995637498L;

    /** Initial step size used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_INIT_ALPHA = 0.1d;

    /** Meta step size: scales the exponent of the multiplicative step-size update. */
    protected double mu = 0.01d;

    /** Divisor applied to the feature norm when the normalizer moves toward its new target. */
    protected double tau = 1000.0d;

    /** Weight vector of the linear predictor. */
    @Monitor(level = 4)
    private final PVector v;

    /** Prediction computed for the current features during the last update. */
    private double v_t;

    /** TD error computed during the last update. */
    @Monitor(wrappers = {Squared.ID, Abs.ID})
    private double delta_t;

    /** Eligibility traces. */
    protected final Traces e;

    /** Per-weight step sizes. */
    @Monitor(level = 4)
    protected final PVector alpha;

    /** Per-weight memory term used in the step-size update. */
    @Monitor(level = 4)
    protected final PVector h;

    /** Per-weight running normalizer of {@code |delta * e * h|}. */
    @Monitor(level = 4)
    protected final PVector normalizer;

    /** {@code max(1, m)} from the last update; divisor applied to the step sizes. */
    protected double maxOneM2;

    private final double gamma;
    private final double lambda;

    /** Sum over features of {@code alpha[i] * |e[i] * x[i]|} from the last update. */
    private double m;

    /** Not used by this class; kept because it is visible to the rest of the package. */
    double tempM = 0.0d;

    /** Floor applied to normalizers and step sizes. */
    private final double lowerNumericalBound;

    /**
     * Step size given at construction, restored by {@link #resetWeight(int)}.
     *
     * <p>NOTE(review): if instances of this class are serialized, streams written before this
     * field existed would deserialize it as 0.0 - confirm whether that matters.
     */
    private final double initAlpha;

    /**
     * Creates a learner with the default initial step size and accumulating traces.
     *
     * @param d lambda, the trace decay parameter
     * @param d2 gamma, the discount factor
     * @param i size of the weight vector
     */
    public TDLambdaAutostep(double d, double d2, int i) {
        this(d, d2, DEFAULT_INIT_ALPHA, i, new ATraces());
    }

    /**
     * Creates a learner with the default initial step size.
     *
     * @param d lambda, the trace decay parameter
     * @param d2 gamma, the discount factor
     * @param i size of the weight vector
     * @param traces prototype from which the eligibility traces are created
     */
    public TDLambdaAutostep(double d, double d2, int i, Traces traces) {
        this(d, d2, DEFAULT_INIT_ALPHA, i, traces);
    }

    /**
     * Creates a learner with accumulating traces.
     *
     * @param d lambda, the trace decay parameter
     * @param d2 gamma, the discount factor
     * @param d3 initial value of every step size
     * @param i size of the weight vector
     */
    public TDLambdaAutostep(double d, double d2, double d3, int i) {
        this(d, d2, d3, i, new ATraces());
    }

    /**
     * Creates a learner.
     *
     * @param d lambda, the trace decay parameter
     * @param d2 gamma, the discount factor
     * @param d3 initial value of every step size
     * @param i size of the weight vector
     * @param traces prototype from which the eligibility traces are created
     */
    public TDLambdaAutostep(double d, double d2, double d3, int i, Traces traces) {
        this.lambda = d;
        this.gamma = d2;
        this.initAlpha = d3;
        this.e = traces.newTraces(i);
        this.v = new PVector(i);
        this.alpha = new PVector(i);
        this.alpha.set(d3);
        this.h = new PVector(i);
        this.normalizer = new PVector(i);
        this.normalizer.set(1.0d);
        this.lowerNumericalBound = Math.pow(10.0d, -10.0d) / i;
    }

    /**
     * Sets the meta step size.
     *
     * @param d new value of {@link #mu}
     */
    public void setMu(double d) {
        this.mu = d;
    }

    /**
     * Sets the normalizer time constant.
     *
     * @param d new value of {@link #tau}
     */
    public void setTau(double d) {
        this.tau = d;
    }

    /**
     * Clears the eligibility traces at the start of an episode.
     *
     * @return always 0.0
     */
    protected double initEpisode() {
        this.e.clear();
        return 0.0d;
    }

    /**
     * Performs one learning step.
     *
     * <p>NOTE(review): {@code realVector2} is handed to {@code dotProduct} without a null check;
     * confirm how callers represent a terminal transition.
     *
     * @param realVector current features; {@code null} starts a new episode
     * @param realVector2 next features
     * @param d scalar added to the discounted next prediction when forming the TD error
     * @return the TD error, or 0.0 when {@code realVector} is {@code null}
     * @throws RuntimeException if the trace vector is neither an {@link SVector} nor a
     *     {@link DenseVector}
     */
    @Override
    public double update(RealVector realVector, RealVector realVector2, double d) {
        if (realVector == null) {
            return initEpisode();
        }
        this.v_t = this.v.dotProduct(realVector);
        this.delta_t = (d + (this.gamma * this.v.dotProduct(realVector2))) - this.v_t;
        this.e.update(this.lambda * this.gamma, realVector);
        PVector x = new PVector(realVector.accessData());

        // Step sizes are adapted before they are used for the weight update below.
        if (this.e.vect() instanceof SVector) {
            updateNormalizationAndStepSizeSparse(this.delta_t, x.data);
        } else if (this.e.vect() instanceof DenseVector) {
            updateNormalizationAndStepSizeDense(this.delta_t, x.data);
        } else {
            throw new RuntimeException("Not implemented");
        }

        MutableVector alphaE = this.e.vect().ebeMultiply(this.alpha);
        MutableVector alphaDeltaE = alphaE.mapMultiply(this.delta_t);
        this.v.addToSelf(alphaDeltaE);
        // h += alpha * delta * e - |alpha * e * x| * h
        // The operand vectors are mutated in place, so the statement order matters.
        this.h.addToSelf(
                alphaDeltaE.subtractToSelf(
                        Vectors.absToSelf(alphaE.ebeMultiplyToSelf(x)).ebeMultiplyToSelf(this.h)));
        return this.delta_t;
    }

    /**
     * Adapts normalizers and step sizes when the traces are stored densely.
     *
     * @param delta current TD error
     * @param x current feature values
     */
    private void updateNormalizationAndStepSizeDense(double delta, double[] x) {
        double[] normalizers = this.normalizer.data;
        double[] alphas = this.alpha.data;
        double[] traceValues = ((DenseVector) this.e.vect()).accessData();
        for (int index = 0; index < traceValues.length; index++) {
            updateStepSizeNormalizers(x, normalizers, alphas, index, traceValues[index], delta);
        }
        this.m = 0.0d;
        for (int index = 0; index < traceValues.length; index++) {
            this.m += featureNorm(x, alphas, index, traceValues[index]);
        }
        this.maxOneM2 = Math.max(1.0d, this.m);
        for (int index = 0; index < traceValues.length; index++) {
            if (x[index] != 0.0d) {
                alphas[index] /= this.maxOneM2;
            }
        }
    }

    /**
     * Adapts normalizers and step sizes when the traces are stored sparsely.
     *
     * @param delta current TD error
     * @param x current feature values
     */
    private void updateNormalizationAndStepSizeSparse(double delta, double[] x) {
        double[] normalizers = this.normalizer.data;
        double[] alphas = this.alpha.data;
        SVector trace = (SVector) this.e.vect();
        // Nothing below changes the trace, so the active count is read once.
        int nbActive = trace.nonZeroElements();
        for (int k = 0; k < nbActive; k++) {
            updateStepSizeNormalizers(x, normalizers, alphas, trace.activeIndexes[k], trace.values[k], delta);
        }
        this.m = 0.0d;
        for (int k = 0; k < nbActive; k++) {
            this.m += featureNorm(x, alphas, trace.activeIndexes[k], trace.values[k]);
        }
        this.maxOneM2 = Math.max(1.0d, this.m);
        // Only the first nbActive slots are visited, exactly as in the two loops above. Walking
        // the whole activeIndexes array would also visit slots beyond the active count and could
        // divide the same step size more than once.
        for (int k = 0; k < nbActive; k++) {
            int index = trace.activeIndexes[k];
            if (x[index] != 0.0d) {
                alphas[index] /= this.maxOneM2;
            }
        }
    }

    /**
     * Updates the normalizer and the step size of one weight.
     *
     * @param x current feature values
     * @param normalizers backing array of {@link #normalizer}
     * @param alphas backing array of {@link #alpha}
     * @param index weight being updated
     * @param traceValue eligibility trace of that weight
     * @param delta current TD error
     */
    private void updateStepSizeNormalizers(double[] x, double[] normalizers, double[] alphas, int index,
            double traceValue, double delta) {
        double absDeltaEH = computeAbsDeltaEH(index, traceValue, delta);
        // The normalizer jumps up to a larger |delta * e * h| at once, otherwise it moves toward it.
        double tracked = normalizers[index]
                + ((featureNorm(x, alphas, index, traceValue) / this.tau) * (absDeltaEH - normalizers[index]));
        normalizers[index] = Math.max(absDeltaEH, tracked);
        normalizers[index] = Math.max(this.lowerNumericalBound, normalizers[index]);
        alphas[index] = alphas[index]
                * Math.exp((((this.mu * delta) * traceValue) * this.h.data[index]) / normalizers[index]);
        alphas[index] = Math.max(this.lowerNumericalBound, alphas[index]);
    }

    /**
     * Computes {@code alpha[index] * |traceValue * x[index]|}.
     */
    private double featureNorm(double[] x, double[] alphas, int index, double traceValue) {
        return alphas[index] * Math.abs(traceValue * x[index]);
    }

    /**
     * Computes {@code |traceValue * h[index] * delta|}.
     */
    private double computeAbsDeltaEH(int index, double traceValue, double delta) {
        return Math.abs(traceValue * this.h.data[index] * delta);
    }

    /**
     * @return the eligibility traces used by this learner
     */
    public Traces eligibility() {
        return this.e;
    }

    /**
     * @param realVector features to evaluate
     * @return the dot product of the weights with {@code realVector}
     */
    @Override
    public double predict(RealVector realVector) {
        return this.v.dotProduct(realVector);
    }

    /**
     * @return the weight vector itself, not a copy
     */
    @Override
    public PVector weights() {
        return this.v;
    }

    /**
     * Resets the learning state attached to one weight.
     *
     * @param i index of the weight to reset
     */
    @Override
    public void resetWeight(int i) {
        this.v.data[i] = 0.0d;
        // Restore the step size chosen at construction, which may differ from the 0.1 default.
        this.alpha.data[i] = this.initAlpha;
        this.h.data[i] = 0.0d;
        // NOTE(review): the constructor starts the normalizer at 1.0 while this resets it to 0.0;
        // confirm which value is intended.
        this.normalizer.data[i] = 0.0d;
        this.e.vect().setEntry(i, 0.0d);
    }

    /**
     * @return the TD error of the last update
     */
    @Override
    public double error() {
        return this.delta_t;
    }

    /**
     * @return the prediction computed for the current features during the last update
     */
    @Override
    public double prediction() {
        return this.v_t;
    }

    /**
     * @return the discount factor
     */
    public double gamma() {
        return this.gamma;
    }
}
