package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.Random;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.Vectors;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Univariate normal (Gaussian) policy distribution whose mean and standard deviation are
 * functions of the current feature vector {@code x} and of the weight vectors {@code u_mean}
 * and {@code u_stddev} held by the superclass.
 *
 * <p>As computed in {@link #updateDistribution()}:
 * <ul>
 *   <li>{@code mean = u_mean . x + initialMean}</li>
 *   <li>{@code stddev = exp(u_stddev . x) * initialStddev + 1e-7}</li>
 * </ul>
 *
 * <p>Sanity checks on computed values are expressed with {@code assert}, so they only run when
 * assertions are enabled for this class (for example with {@code java -ea}).
 */
@Monitor
public class NormalDistribution extends AbstractNormalDistribution {
    private static final long serialVersionUID = -4074721193363280217L;

    /** Small positive offset added to the standard deviation so that it is never exactly zero. */
    private static final double STDDEV_OFFSET = 1.0E-7d;

    /** Lower bound applied to the value returned by {@link #piMax()}. */
    private static final double PI_MAX_FLOOR = 1.0E-7d;

    /** Variance of the distribution, i.e. the square of the current standard deviation. */
    protected double sigma2;

    /** Constant offset added to the linear mean estimate. */
    private final double initialMean;

    /** Scale factor applied to the exponential standard deviation estimate. */
    private final double initialStddev;

    /**
     * Creates a normal distribution.
     *
     * @param random        source of randomness, forwarded to the superclass and used when sampling
     * @param initialMean   offset added to {@code u_mean . x} when computing the mean
     * @param initialStddev factor multiplying {@code exp(u_stddev . x)} when computing the
     *                      standard deviation
     */
    public NormalDistribution(Random random, double initialMean, double initialStddev) {
        super(random);
        this.initialMean = initialMean;
        this.initialStddev = initialStddev;
    }

    /**
     * Computes the two gradient vectors for the given action by scaling the current feature
     * vector {@code x} by the mean step and by the standard deviation step respectively.
     *
     * <p>The returned vectors are the internal {@code gradMean} and {@code gradStddev} buffers,
     * which are overwritten on every call.
     *
     * @param action the action, converted to a scalar with {@link ActionArray#toDouble(Action)}
     * @return a two-element array {@code {gradMean, gradStddev}}
     */
    @Override
    public RealVector[] computeGradLog(Action action) {
        updateSteps(ActionArray.toDouble(action));
        gradMean.set(x).mapMultiplyToSelf(meanStep);
        gradStddev.set(x).mapMultiplyToSelf(stddevStep);
        assert Vectors.checkValues(gradMean) && Vectors.checkValues(gradStddev);
        return new RealVector[] { gradMean, gradStddev };
    }

    /**
     * Updates {@code meanStep} and {@code stddevStep} for the given action value:
     * {@code meanStep = (a - mean) / sigma2} and
     * {@code stddevStep = (a - mean)^2 / sigma2 - 1}.
     *
     * @param actionValue the scalar action value
     */
    protected void updateSteps(double actionValue) {
        final double deviation = actionValue - mean;
        meanStep = deviation / sigma2;
        stddevStep = Utils.square(deviation) / sigma2 - 1.0d;
    }

    /**
     * Draws an action from the current distribution and stores the sampled value in
     * {@code a_t}.
     *
     * @return the sampled action, or {@code null} when the sampled value is rejected by
     *         {@link Utils#checkValue(double)}
     */
    @Override
    public Action sampleAction() {
        a_t = random.nextGaussian() * stddev + mean;
        // Callers receive null rather than an exception when the sample is not a valid value.
        if (!Utils.checkValue(a_t)) {
            return null;
        }
        return new ActionArray(a_t);
    }

    /**
     * Recomputes {@code mean}, {@code stddev} and {@code sigma2} from the current feature
     * vector and weight vectors.
     */
    @Override
    protected void updateDistribution() {
        mean = u_mean.dotProduct(x) + initialMean;
        // The exponential keeps the standard deviation positive; the offset keeps it non-zero.
        stddev = Math.exp(u_stddev.dotProduct(x)) * initialStddev + STDDEV_OFFSET;
        sigma2 = Utils.square(stddev);
        assert Utils.checkValue(mean) && Utils.checkValue(sigma2);
    }

    /**
     * Evaluates the Gaussian probability density with the current mean and variance.
     *
     * @param a the point at which the density is evaluated
     * @return {@code exp(-(a - mean)^2 / (2 * sigma2)) / sqrt(2 * pi * sigma2)}
     */
    @Override
    public double pi_s(double a) {
        final double deviation = a - mean;
        final double exponent = -(deviation * deviation) / (2.0d * sigma2);
        return Math.exp(exponent) / Math.sqrt(2.0d * Math.PI * sigma2);
    }

    /**
     * Builds a joint distribution made of independent {@code NormalDistribution} instances that
     * all share the same random generator and initial parameters.
     *
     * @param random          random generator passed to every component
     * @param nbDistributions number of component distributions to create
     * @param initialMean     initial mean passed to every component
     * @param initialStddev   initial standard deviation passed to every component
     * @return a {@link JointDistribution} wrapping the created components
     */
    public static JointDistribution newJointDistribution(Random random, int nbDistributions,
            double initialMean, double initialStddev) {
        PolicyDistribution[] distributions = new PolicyDistribution[nbDistributions];
        for (int index = 0; index < distributions.length; index++) {
            distributions[index] = new NormalDistribution(random, initialMean, initialStddev);
        }
        return new JointDistribution(distributions);
    }

    /**
     * Returns the value of {@code pi} evaluated at the current mean, bounded below by
     * {@code 1e-7}.
     *
     * @return {@code max(pi(mean), 1e-7)}
     */
    @Override
    public double piMax() {
        return Math.max(pi(new ActionArray(mean)), PI_MAX_FLOOR);
    }
}
