package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.Random;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.math.vector.MutableVector;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/functions/policydistributions/structures/BoltzmannDistribution.class */
/**
 * Boltzmann (softmax) policy over a discrete action set with linear action preferences.
 *
 * <p>For a state vector {@code x}, each action {@code a} receives the linear score
 * {@code u . phi(x, a)}. Here {@code phi} is the supplied {@link StateToStateAction} and {@code u}
 * is the single parameter vector allocated by {@link #createParameters(int)}. The action
 * probabilities are the softmax of these scores. {@link #computeGradLog(Action)} returns
 * {@code phi(x, a) - sum_b pi(b) phi(x, b)}.
 *
 * <p>{@link #createParameters(int)} must be called before {@link #update(RealVector)}. The
 * per-action feature buffers are allocated lazily on the first update.
 */
public class BoltzmannDistribution extends StochasticPolicy implements PolicyDistribution {
    private static final long serialVersionUID = 7036360201611314726L;

    /** Feature vectors phi(x, a), one per action, filled by the last {@link #update}. */
    private final MutableVector[] xa;

    /** Policy parameters; only assigned by {@link #createParameters(int)}. */
    @Monitor(level = 4)
    private PVector u;
    /** Probability-weighted mean of {@link #xa}, i.e. sum_a pi(a) phi(x, a). */
    private MutableVector xaBar;
    /** Scratch vector in which {@link #computeGradLog(Action)} builds its result. */
    private MutableVector gradBuffer;
    private final StateToStateAction toStateAction;
    /** Action probabilities from the last update, indexed like {@code actions}. */
    private final double[] distribution;

    /** Updated with every linear score computed; never reset in this class. */
    @Monitor
    private final Range linearRangeOverall = new Range(1.0d, 1.0d);

    /** Reset at the start of every update, so it only reflects the latest state's scores. */
    @Monitor
    private final Range linearRangeAveraged = new Range(1.0d, 1.0d);

    /**
     * Creates a Boltzmann distribution over {@code actions}.
     *
     * @param random        random source passed to the parent policy (used for sampling)
     * @param actions       the discrete actions this policy chooses among
     * @param toStateAction maps a state vector and an action to the feature vector phi(x, a);
     *                      must not be null (checked with {@code assert})
     */
    public BoltzmannDistribution(Random random, Action[] actions, StateToStateAction toStateAction) {
        super(random, actions);
        assert toStateAction != null;
        this.toStateAction = toStateAction;
        this.distribution = new double[actions.length];
        this.xa = new MutableVector[actions.length];
    }

    /**
     * Returns the probability of {@code action} as computed by the most recent {@link #update}.
     * Before any update every probability is 0.
     */
    @Override
    public double pi(Action action) {
        return distribution[atoi(action)];
    }

    /**
     * Recomputes the features, the softmax distribution and the weighted feature mean for the
     * state {@code x}.
     *
     * @param x the current state vector
     */
    @Override
    public void update(RealVector x) {
        linearRangeAveraged.reset();
        clearBuffers(x);

        // First pass: features and raw linear scores. distribution[] temporarily holds the scores.
        double maxScore = Double.NEGATIVE_INFINITY;
        for (int a = 0; a < actions.length; a++) {
            xa[a].set(toStateAction.stateAction(x, actions[a]));
            double score = u.dotProduct(xa[a]);
            linearRangeOverall.update(score);
            linearRangeAveraged.update(score);
            distribution[a] = score;
            maxScore = Math.max(maxScore, score);
        }

        // Second pass: exponentiate relative to the max score. The shift cancels out after
        // normalization. It prevents exp() from overflowing on large scores, and it keeps the
        // sum >= 1 so it cannot underflow to 0 on very negative scores.
        double sum = 0.0d;
        for (int a = 0; a < actions.length; a++) {
            double weight = Math.exp(distribution[a] - maxScore);
            assert Utils.checkValue(weight);
            distribution[a] = weight;
            sum += weight;
            xaBar.addToSelf(weight, xa[a]);
        }

        for (int a = 0; a < distribution.length; a++) {
            distribution[a] /= sum;
            assert Utils.checkValue(distribution[a]);
        }
        xaBar.mapMultiplyToSelf(1.0d / sum);
    }

    /**
     * Zeroes {@link #xaBar} on subsequent calls. On the first call, allocates xaBar, gradBuffer
     * and the per-action feature vectors, all of size {@code u.size}. It uses the vector type
     * produced by {@code toStateAction} for the first action.
     */
    private void clearBuffers(RealVector x) {
        if (xaBar != null) {
            xaBar.clear();
            return;
        }
        xaBar = toStateAction.stateAction(x, actions[0]).newInstance(u.size);
        gradBuffer = xaBar.newInstance(u.size);
        for (int a = 0; a < xa.length; a++) {
            xa[a] = xaBar.newInstance(u.size);
        }
    }

    /** Samples an action from the distribution of the last update via the inherited chooseAction. */
    @Override
    public Action sampleAction() {
        return chooseAction(distribution);
    }

    /**
     * Allocates the parameter vector {@code u}, sized by {@code toStateAction.vectorSize()}.
     *
     * @param nbFeatures unused here; the size comes from the state-action mapping
     * @return a single-element array containing {@code u}
     */
    @Override
    public PVector[] createParameters(int nbFeatures) {
        u = new PVector(toStateAction.vectorSize());
        return new PVector[] { u };
    }

    /**
     * Returns the gradient of log pi(action) with respect to {@code u} for the last updated state,
     * i.e. {@code phi(x, action) - xaBar}.
     *
     * <p>The result is built in an internal buffer that is reused on every call. Copy it if it must
     * outlive the next call. (Assumes {@code subtractToSelf} returns its receiver.)
     */
    @Override
    public RealVector[] computeGradLog(Action action) {
        gradBuffer.clear();
        gradBuffer.set(xa[atoi(action)]);
        return new RealVector[] { gradBuffer.subtractToSelf(xaBar) };
    }

    @Override
    public int nbParameterVectors() {
        return 1;
    }

    /** Returns the live internal probability array (not a copy); it changes on every update. */
    @Override
    public double[] distribution() {
        return distribution;
    }

    /**
     * Inverts the softmax for the case of one distinguished action.
     *
     * <p>Returns the linear score {@code s} such that, when one action out of {@code nbActions} has
     * score {@code s} and all others have score 0, the softmax gives that action probability
     * {@code proba}. That is, {@code s = log(proba * (nbActions - 1) / (1 - proba))}.
     *
     * @param nbActions total number of actions (expected >= 2)
     * @param proba     target probability (expected strictly between 0 and 1)
     * @return the corresponding linear score
     */
    public static double probaToLinearValue(int nbActions, double proba) {
        double linearValue = Math.log(proba * (nbActions - 1)) - Math.log(1.0d - proba);
        // The score is positive exactly when proba exceeds the uniform probability 1/nbActions.
        // It is negative when proba is below it and zero when they are equal.
        double uniformProba = 1.0d / nbActions;
        assert (proba <= uniformProba || linearValue > 0.0d)
                && (proba >= uniformProba || linearValue < 0.0d);
        return linearValue;
    }
}
