package rlpark.plugin.rltoys.algorithms.control.acting;

import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.DiscreteActionPolicy;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.envio.policy.PolicyPrototype;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/control/acting/Greedy.class */
/**
 * Deterministic greedy policy over a finite array of discrete actions.
 *
 * <p>Each call to {@link #update(RealVector)} scores every action. It builds the action's
 * state-action representation with {@code toStateAction} and passes it to
 * {@code predictor.predict}. The highest-scoring action is then returned by
 * {@link #sampleAction()} until the next update. Ties go to the action that appears first in
 * the action array. A NaN score never wins over a non-NaN score.
 *
 * <p>Before the first {@code update} no action has been selected, so {@link #sampleAction()}
 * and {@link #bestAction()} return {@code null}.
 */
public class Greedy implements DiscreteActionPolicy, PolicyPrototype {
    private static final long serialVersionUID = 1675962692054005355L;
    /** Builds the state-action input scored by {@link #predictor} for each candidate action. */
    protected final StateToStateAction toStateAction;
    /** Scores each state-action input; the highest score determines the greedy action. */
    protected final Predictor predictor;
    /** Candidate actions; must be non-empty before {@link #update(RealVector)} is called. */
    protected final Action[] actions;

    /** Latest score for each action, index-aligned with {@link #actions}; rewritten by each update. */
    @Monitor
    protected final double[] actionValues;
    /** Action chosen by the most recent update, or {@code null} before the first update. */
    protected Action bestAction;

    /** Score of {@link #bestAction} from the most recent update. */
    @Monitor
    private double bestValue;

    /**
     * Creates a greedy policy. No action is selected until {@link #update(RealVector)} runs.
     *
     * @param predictor scores state-action inputs
     * @param actionArr candidate actions (stored by reference, not copied)
     * @param stateToStateAction builds the state-action input for a given state and action
     */
    public Greedy(Predictor predictor, Action[] actionArr, StateToStateAction stateToStateAction) {
        this.toStateAction = stateToStateAction;
        this.predictor = predictor;
        this.actions = actionArr;
        this.actionValues = new double[actionArr.length];
    }

    /**
     * Returns the action selected by the most recent update. The selection is deterministic,
     * and the result is {@code null} before the first update.
     */
    @Override // rlpark.plugin.rltoys.envio.policy.Policy
    public Action sampleAction() {
        return this.bestAction;
    }

    /**
     * Re-scores every action for the given state and selects the greedy one.
     *
     * @param realVector current state representation, passed to {@code toStateAction}
     * @throws ArrayIndexOutOfBoundsException if the action array is empty
     */
    @Override // rlpark.plugin.rltoys.envio.policy.Policy
    public void update(RealVector realVector) {
        updateActionValues(realVector);
        findBestAction();
    }

    /**
     * Selects the index with the largest score in {@link #actionValues}.
     *
     * <p>The strict {@code >} keeps the first of several tied maxima. Any comparison involving
     * NaN is false, so without the explicit NaN check a NaN score at index 0 would stay
     * "best" even when later actions have finite scores. A NaN score elsewhere would be
     * skipped, which makes the handling asymmetric. With the check, a non-NaN score always
     * displaces a NaN incumbent. If every score is NaN, the first action is chosen.
     */
    private void findBestAction() {
        int bestIndex = 0;
        for (int i = 1; i < this.actions.length; i++) {
            double candidate = this.actionValues[i];
            double incumbent = this.actionValues[bestIndex];
            if (candidate > incumbent || (Double.isNaN(incumbent) && !Double.isNaN(candidate))) {
                bestIndex = i;
            }
        }
        // Throws ArrayIndexOutOfBoundsException for an empty action array, as before.
        this.bestValue = this.actionValues[bestIndex];
        this.bestAction = this.actions[bestIndex];
    }

    /** Writes {@code predictor}'s score for every action, in the given state, into {@link #actionValues}. */
    private void updateActionValues(RealVector realVector) {
        for (int i = 0; i < this.actions.length; i++) {
            this.actionValues[i] = this.predictor.predict(this.toStateAction.stateAction(realVector, this.actions[i]));
        }
    }

    /**
     * Returns the probability of choosing {@code action}: 1.0 if it is the same instance as
     * the currently selected action, otherwise 0.0. The comparison uses reference equality
     * ({@code ==}), not {@code equals}. Before the first update the selected action is
     * {@code null}, so {@code pi(null)} returns 1.0 at that point.
     */
    @Override // rlpark.plugin.rltoys.envio.policy.Policy
    public double pi(Action action) {
        return action == this.bestAction ? 1.0d : 0.0d;
    }

    /** Returns the mapping used to build state-action inputs for the predictor. */
    public StateToStateAction toStateAction() {
        return this.toStateAction;
    }

    /** Returns the action selected by the most recent update, or {@code null} before any update. */
    public Action bestAction() {
        return this.bestAction;
    }

    /** Returns the score of the selected action from the most recent update (0.0 before any update). */
    public double bestActionValue() {
        return this.bestValue;
    }

    /**
     * Returns the per-action scores from the most recent update. The returned array is the
     * internal buffer, not a copy, and is overwritten by every update.
     */
    @Override // rlpark.plugin.rltoys.envio.policy.DiscreteActionPolicy
    public double[] values() {
        return this.actionValues;
    }

    /** Returns the candidate action array itself (not a copy). */
    @Override // rlpark.plugin.rltoys.envio.policy.DiscreteActionPolicy
    public Action[] actions() {
        return this.actions;
    }

    /**
     * Creates a fresh greedy policy. It shares this policy's predictor and action array, and
     * uses a copy of {@code toStateAction} produced by {@code Utils.clone}. The new policy has
     * no selected action until it is updated.
     */
    public Policy duplicate() {
        return new Greedy(this.predictor, this.actions, (StateToStateAction) Utils.clone(this.toStateAction));
    }
}
