package rlpark.plugin.rltoys.algorithms.control.acting;

import java.util.Arrays;
import java.util.Random;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.StochasticPolicy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.utils.Utils;

/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/control/acting/SoftMax.class */
/**
 * Stochastic policy that turns the values predicted for each action into a
 * Boltzmann (soft-max) probability distribution.
 *
 * <p>For every action {@code a} the weight is
 * {@code exp(predict(stateAction(x, a)) / temperature)}; the weights are then
 * normalised so that they sum to one. The distribution is recomputed on each
 * call to {@link #update(RealVector)} and stored in a single array that is
 * reused between calls.
 */
public class SoftMax extends StochasticPolicy {
    private static final long serialVersionUID = -2129719316562814077L;

    /** Builds the state-action feature vector handed to the predictor. */
    private final StateToStateAction toStateAction;
    /** Divisor applied to every predicted value before exponentiation. */
    private final double temperature;
    /** Supplies the value of a state-action feature vector. */
    private final Predictor predictor;
    /** Probabilities computed by the last update, indexed like {@code actions}. */
    private final double[] distribution;

    /**
     * Creates a soft-max policy with an explicit temperature.
     *
     * @param random             random source forwarded to the parent policy
     * @param predictor          value predictor queried for each state-action vector
     * @param actionArr          available actions; fixes the size of the distribution
     * @param stateToStateAction maps a state vector and an action to a state-action vector
     * @param d                  temperature dividing the predicted values; it is not
     *                           validated here, so a zero value yields non-finite
     *                           scaled values during {@link #update(RealVector)}
     */
    public SoftMax(Random random, Predictor predictor, Action[] actionArr, StateToStateAction stateToStateAction, double d) {
        super(random, actionArr);
        this.toStateAction = stateToStateAction;
        this.temperature = d;
        this.predictor = predictor;
        this.distribution = new double[actionArr.length];
    }

    /**
     * Creates a soft-max policy with a temperature of {@code 1.0}.
     *
     * @param random             random source forwarded to the parent policy
     * @param predictor          value predictor queried for each state-action vector
     * @param actionArr          available actions; fixes the size of the distribution
     * @param stateToStateAction maps a state vector and an action to a state-action vector
     */
    public SoftMax(Random random, Predictor predictor, Action[] actionArr, StateToStateAction stateToStateAction) {
        this(random, predictor, actionArr, stateToStateAction, 1.0d);
    }

    /**
     * Selects an action by delegating to the inherited {@code chooseAction}
     * with the distribution computed by the last {@link #update(RealVector)}.
     *
     * @return the action chosen by {@code chooseAction}
     */
    @Override
    public Action sampleAction() {
        return chooseAction(this.distribution);
    }

    /**
     * Recomputes the action probabilities for the given state vector.
     *
     * <p>The largest scaled value is subtracted from every scaled value before
     * exponentiation. This leaves the normalised result mathematically
     * unchanged but keeps {@code Math.exp} from overflowing to infinity (which
     * would produce NaN probabilities) and from underflowing every weight to
     * zero when all values are strongly negative.
     *
     * @param realVector state vector passed to the state-action mapping
     */
    @Override
    public void update(RealVector realVector) {
        final int nbActions = this.actions.length;

        // Pass 1: store value / temperature for each action and track the maximum.
        double maxScaled = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < nbActions; i++) {
            double scaled = this.predictor.predict(this.toStateAction.stateAction(realVector, this.actions[i])) / this.temperature;
            this.distribution[i] = scaled;
            if (scaled > maxScaled) {
                maxScaled = scaled;
            }
        }
        // An infinite maximum cannot be subtracted meaningfully (inf - inf is NaN);
        // in that degenerate case fall back to the unshifted computation.
        final double shift = Double.isInfinite(maxScaled) ? 0.0d : maxScaled;

        // Pass 2: exponentiate the shifted values and accumulate the normaliser.
        double sum = 0.0d;
        for (int i = 0; i < nbActions; i++) {
            double weight = Math.exp(this.distribution[i] - shift);
            assert Utils.checkValue(weight);
            sum += weight;
            this.distribution[i] = weight;
        }

        // No weight survived: use equal weights so that normalisation is defined.
        if (sum == 0.0d) {
            Arrays.fill(this.distribution, 1.0d);
            sum = this.distribution.length;
        }

        for (int i = 0; i < this.distribution.length; i++) {
            this.distribution[i] /= sum;
            assert Utils.checkValue(this.distribution[i]);
        }
        assert checkDistribution(this.distribution);
    }

    /**
     * Returns the probability stored for the given action by the last update.
     *
     * @param action action whose index is resolved through the inherited {@code atoi}
     * @return the entry of the current distribution at that index
     */
    @Override
    public double pi(Action action) {
        return this.distribution[atoi(action)];
    }

    /**
     * Returns the internal probability array itself, not a copy; its content
     * is overwritten by the next call to {@link #update(RealVector)}.
     *
     * @return the current distribution array
     */
    @Override
    public double[] distribution() {
        return this.distribution;
    }
}
