package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.LabelProvider;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Actor that keeps one {@link Traces} instance per policy parameter vector and, on every update,
 * moves each parameter vector along the vector held by its trace.
 *
 * <p>NOTE(review): the exact decay/accumulate behaviour of {@code Traces.update} is not visible
 * in this file; the trace is presumably an eligibility trace over the log-policy gradient.
 */
@Monitor
public class ActorLambda extends Actor {
    private static final long serialVersionUID = -1601184295976574511L;

    /** One trace per parameter vector; index {@code k} pairs with {@code u[k]}. */
    public final Traces[] e_u;

    /** Multiplied with {@link #gamma} to form the first argument of every trace update. */
    private final double lambda;

    /** Multiplied with {@link #lambda} to form the first argument of every trace update. */
    private final double gamma;

    /**
     * Builds an actor whose traces are created from a fresh {@link ATraces} prototype.
     *
     * @param d value stored as {@code lambda}
     * @param d2 value stored as {@code gamma}
     * @param policyDistribution policy whose parameter vectors are adapted
     * @param d3 scalar replicated once per parameter vector before being handed to the superclass
     *     (presumably the step size - confirm against {@code Actor})
     * @param i forwarded unchanged to the superclass constructor
     */
    public ActorLambda(double d, double d2, PolicyDistribution policyDistribution, double d3, int i) {
        this(d, d2, policyDistribution, d3, i, new ATraces());
    }

    /**
     * Builds an actor using the same scalar {@code d3} for every parameter vector.
     *
     * @param d value stored as {@code lambda}
     * @param d2 value stored as {@code gamma}
     * @param policyDistribution policy whose parameter vectors are adapted
     * @param d3 scalar replicated {@code policyDistribution.nbParameterVectors()} times
     * @param i forwarded unchanged to the superclass constructor
     * @param traces prototype used to create one trace per parameter vector
     */
    public ActorLambda(double d, double d2, PolicyDistribution policyDistribution, double d3, int i, Traces traces) {
        this(d, d2, policyDistribution, Utils.newFilledArray(policyDistribution.nbParameterVectors(), d3), i, traces);
    }

    /**
     * Builds an actor with one entry of {@code dArr} per parameter vector.
     *
     * @param d value stored as {@code lambda}
     * @param d2 value stored as {@code gamma}
     * @param policyDistribution policy whose parameter vectors are adapted
     * @param dArr per-vector values forwarded to the superclass constructor
     *     (presumably the step sizes later read as {@code alpha_u} - confirm against {@code Actor})
     * @param i forwarded unchanged to the superclass constructor
     * @param traces prototype used to create one trace per parameter vector
     */
    public ActorLambda(double d, double d2, PolicyDistribution policyDistribution, double[] dArr, int i, Traces traces) {
        super(policyDistribution, dArr, i);
        lambda = d;
        gamma = d2;
        final int nbVectors = policyDistribution.nbParameterVectors();
        e_u = new Traces[nbVectors];
        for (int k = 0; k < nbVectors; k++) {
            // Each trace is sized after the parameter vector it accompanies.
            e_u[k] = traces.newTraces(u[k].size);
        }
    }

    /**
     * Performs one actor step: refreshes every trace with the current log-gradient, then adjusts
     * the policy parameters.
     *
     * @param realVector when {@code null}, all traces are cleared and nothing else happens
     *     (presumably marks an episode boundary); otherwise only tested for {@code null} here
     * @param action action handed to {@code policyDistribution.computeGradLog}
     * @param d scalar multiplied with each step size in the parameter update
     *     (presumably the critic's TD error - confirm against callers)
     */
    @Override
    public void update(RealVector realVector, Action action, double d) {
        if (realVector == null) {
            initEpisode();
            return;
        }
        final RealVector[] gradLog = policyDistribution.computeGradLog(action);
        final double decay = gamma * lambda;
        for (int k = 0; k < u.length; k++) {
            e_u[k].update(decay, gradLog[k]);
        }
        updatePolicyParameters(gradLog, d);
    }

    /**
     * Adds, for every parameter vector, the trace vector scaled by {@code alpha_u[k] * d}.
     *
     * @param realVectorArr log-gradient vectors; not read by this implementation
     * @param d scalar multiplied with each step size
     */
    protected void updatePolicyParameters(RealVector[] realVectorArr, double d) {
        for (int k = 0; k < u.length; k++) {
            final double scaledStep = alpha_u[k] * d;
            final RealVector traceVector = (RealVector) e_u[k].vect();
            u[k].addToSelf(scaledStep, traceVector);
        }
    }

    /** Clears every trace. */
    private void initEpisode() {
        for (int k = 0; k < e_u.length; k++) {
            e_u[k].clear();
        }
    }

    /**
     * Supplies the monitoring label for element {@code i} of {@code e_u} by reusing the
     * superclass label for the same index.
     */
    @LabelProvider(ids = {"e_u"})
    String eligiblityLabelOf(int i) {
        return super.labelOf(i);
    }
}
