package rlpark.plugin.rltoys.problems.puddleworld;

import java.util.Arrays;
import java.util.Random;
import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.envio.observations.Legend;
import rlpark.plugin.rltoys.envio.rl.TRStep;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.ProblemBounded;
import rlpark.plugin.rltoys.problems.ProblemContinuousAction;
import rlpark.plugin.rltoys.problems.ProblemDiscreteAction;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.abstracts.DataMonitor;
import zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/* loaded from: input_file:rlpark/plugin/rltoys/problems/puddleworld/PuddleWorld.class */
/**
 * N-dimensional bounded continuous world exposing both a discrete and a continuous action interface.
 *
 * <p>The observation is a position vector with {@code nbDimensions} components. At each step the
 * (bounded, noisy) action vector is added to the current position and every component of the result
 * is bounded by the observation range. The reward and the termination condition are pluggable through
 * {@link #setRewardFunction(ContinuousFunction)} and {@link #setTermination(TerminationFunction)};
 * when unset, the reward is {@code 0} and the episode never terminates on its own.
 */
public class PuddleWorld implements ProblemBounded, ProblemDiscreteAction, ProblemContinuousAction, MonitorContainer {
    /** Number of components of the position and of the action vectors. */
    private final int nbDimensions;
    /** Source of randomness for the random start position and for the action noise. */
    private final Random random;
    /** Range applied to every component of the position. */
    private final Range observationRange;
    /** Range applied to every component of the action. */
    private final Range actionRange;
    /** Width of the noise interval added to every action component. */
    private final double absoluteNoise;

    /** Copy of the last raw (unbounded, noise-free) action received by {@link #step(Action)}. */
    @Monitor
    private final double[] lastActions;
    /** Last step produced by this problem; {@code null} until {@link #initialize()} is called. */
    protected TRStep step = null;
    /** Fixed start position, or {@code null} to draw a random one at each initialization. */
    private double[] start = null;

    @Monitor
    private ContinuousFunction rewardFunction = null;
    private TerminationFunction terminationFunction = null;
    /** One label per dimension: "x0", "x1", ... */
    private final Legend legend;
    /** Discrete action set: -1 and +1 along each axis, plus the null action. */
    private final Action[] actions;

    /**
     * Creates a new world.
     *
     * @param random source of randomness
     * @param nbDimensions number of dimensions of the position and action vectors
     * @param observationRange range bounding every position component
     * @param actionRange range bounding every action component
     * @param relativeNoise noise level relative to half the length of {@code actionRange}
     */
    public PuddleWorld(Random random, int nbDimensions, Range observationRange, Range actionRange,
            double relativeNoise) {
        this.random = random;
        this.observationRange = observationRange;
        this.actionRange = actionRange;
        this.nbDimensions = nbDimensions;
        this.absoluteNoise = (actionRange.length() / 2.0d) * relativeNoise;
        this.lastActions = new double[nbDimensions];
        // The legend and the action set depend on nbDimensions, so they must be built here, after
        // it has been assigned. Building them in field initializers would run before the
        // constructor body and observe nbDimensions == 0.
        this.legend = createLegend();
        this.actions = createActions();
    }

    /**
     * Builds the discrete action set: for dimension {@code d}, index {@code 2 * d} is a -1 move and
     * index {@code 2 * d + 1} is a +1 move along that dimension; the last entry is the all-zero action.
     */
    private Action[] createActions() {
        Action[] result = new Action[(2 * this.nbDimensions) + 1];
        for (int i = 0; i < result.length - 1; i++) {
            int dimension = i / 2;
            boolean negative = i % 2 == 0;
            double[] move = Utils.newFilledArray(this.nbDimensions, 0.0d);
            move[dimension] = negative ? -1.0d : 1.0d;
            result[i] = new ActionArray(move);
        }
        result[result.length - 1] = new ActionArray(Utils.newFilledArray(this.nbDimensions, 0.0d));
        return result;
    }

    /**
     * Sets the start position used by {@link #initialize()}. The array is stored as is (no copy).
     *
     * @param start the start position, or {@code null} to draw a random start position
     */
    public void setStart(double[] start) {
        this.start = start;
    }

    /**
     * Sets the function computing the reward from a position.
     *
     * @param rewardFunction the reward function, or {@code null} for a constant reward of 0
     */
    public void setRewardFunction(ContinuousFunction rewardFunction) {
        this.rewardFunction = rewardFunction;
    }

    /**
     * Sets the function deciding whether a position is terminal.
     *
     * @param terminationFunction the termination function, or {@code null} for no terminal position
     */
    public void setTermination(TerminationFunction terminationFunction) {
        this.terminationFunction = terminationFunction;
    }

    /** Builds the observation legend with one label "x" + index per dimension. */
    private Legend createLegend() {
        String[] labels = new String[this.nbDimensions];
        for (int i = 0; i < this.nbDimensions; i++) {
            labels[i] = "x" + i;
        }
        return new Legend(labels);
    }

    /**
     * Starts a new episode.
     *
     * <p>Uses the configured start position if there is one; otherwise draws every component from the
     * observation range, redrawing the whole position for as long as it is terminal.
     *
     * @return the first step of the episode
     */
    @Override // rlpark.plugin.rltoys.problems.RLProblem
    public TRStep initialize() {
        double[] position = this.start;
        if (position == null) {
            position = new double[this.nbDimensions];
            do {
                for (int i = 0; i < position.length; i++) {
                    position[i] = this.observationRange.choose(this.random);
                }
            } while (isTerminated(position));
        }
        this.step = new TRStep(position, reward(position));
        return this.step;
    }

    /**
     * Applies an action to the current position. {@link #initialize()} must have been called first.
     *
     * <p>If the current position is terminal, the action is ignored and the ending step is returned.
     *
     * @param action the action to apply; must be an {@link ActionArray}
     * @return the resulting step
     */
    @Override // rlpark.plugin.rltoys.problems.RLProblem
    public TRStep step(Action action) {
        if (isTerminated(this.step.o_tp1)) {
            this.step = this.step.createEndingStep();
            return this.step;
        }
        double[] envAction = computeEnvironmentAction(action);
        double[] position = new double[this.nbDimensions];
        for (int i = 0; i < position.length; i++) {
            position[i] = this.observationRange.bound(this.step.o_tp1[i] + envAction[i]);
        }
        this.step = new TRStep(this.step, action, position, reward(position));
        return this.step;
    }

    /** Returns the reward for the given position, or 0 when no reward function is set. */
    private double reward(double[] position) {
        if (this.rewardFunction == null) {
            return 0.0d;
        }
        return this.rewardFunction.value(position);
    }

    /** Tells whether the given position is terminal; always false when no termination function is set. */
    private boolean isTerminated(double[] position) {
        if (this.terminationFunction == null) {
            return false;
        }
        return this.terminationFunction.isTerminated(position);
    }

    /**
     * Converts an agent action into the displacement actually applied: every component is bounded by
     * the action range, then shifted by a noise term of {@code nextDouble() * absoluteNoise -
     * absoluteNoise / 2}. The raw action is recorded in {@code lastActions} for monitoring.
     */
    private double[] computeEnvironmentAction(Action action) {
        double[] agentAction = ((ActionArray) action).actions;
        System.arraycopy(agentAction, 0, this.lastActions, 0, this.nbDimensions);
        double[] envAction = new double[agentAction.length];
        for (int i = 0; i < envAction.length; i++) {
            double noise = (this.random.nextDouble() * this.absoluteNoise) - (this.absoluteNoise / 2.0d);
            envAction[i] = this.actionRange.bound(agentAction[i]) + noise;
        }
        return envAction;
    }

    @Override // rlpark.plugin.rltoys.problems.RLProblem
    public Legend legend() {
        return this.legend;
    }

    /** Returns a new array holding the action range once per dimension. */
    @Override // rlpark.plugin.rltoys.problems.ProblemContinuousAction
    public Range[] actionRanges() {
        Range[] ranges = new Range[this.nbDimensions];
        Arrays.fill(ranges, this.actionRange);
        return ranges;
    }

    /** Returns a new array holding the observation range once per dimension. */
    @Override // rlpark.plugin.rltoys.problems.ProblemBounded
    public Range[] getObservationRanges() {
        Range[] ranges = new Range[this.nbDimensions];
        Arrays.fill(ranges, this.observationRange);
        return ranges;
    }

    /**
     * Registers the last reward under "Reward" and every position component under its legend label.
     * Each monitored value reads 0 while no step (or no observation) is available.
     */
    @Override // zephyr.plugin.core.api.monitoring.abstracts.MonitorContainer
    public void addToMonitor(DataMonitor dataMonitor) {
        dataMonitor.add("Reward", new Monitored() { // from class: rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld.1
            @Override // zephyr.plugin.core.api.monitoring.abstracts.Monitored
            public double monitoredValue() {
                if (PuddleWorld.this.step != null) {
                    return PuddleWorld.this.step.r_tp1;
                }
                return 0.0d;
            }
        });
        for (int i = 0; i < this.legend.nbLabels(); i++) {
            final int index = i;
            dataMonitor.add(this.legend.label(i), new Monitored() { // from class: rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld.2
                @Override // zephyr.plugin.core.api.monitoring.abstracts.Monitored
                public double monitoredValue() {
                    if (PuddleWorld.this.step == null || PuddleWorld.this.step.o_tp1 == null) {
                        return 0.0d;
                    }
                    return PuddleWorld.this.step.o_tp1[index];
                }
            });
        }
    }

    public int nbDimensions() {
        return this.nbDimensions;
    }

    public ContinuousFunction rewardFunction() {
        return this.rewardFunction;
    }

    /** Returns the configured start position (the stored array itself), or {@code null} if none is set. */
    public double[] start() {
        return this.start;
    }

    /** Returns the discrete action set (the internal array itself, not a copy). */
    @Override // rlpark.plugin.rltoys.problems.ProblemDiscreteAction
    public Action[] actions() {
        return this.actions;
    }

    @Override // rlpark.plugin.rltoys.problems.RLProblem
    public TRStep lastStep() {
        return this.step;
    }

    /** Replaces the current step by its ending step. {@link #initialize()} must have been called first. */
    @Override // rlpark.plugin.rltoys.problems.RLProblem
    public TRStep forceEndEpisode() {
        this.step = this.step.createEndingStep();
        return this.step;
    }
}
