package rlpark.plugin.rltoys.algorithms.control.qlearning;

import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.control.acting.Greedy;
import rlpark.plugin.rltoys.algorithms.functions.Predictor;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm;
import rlpark.plugin.rltoys.algorithms.traces.Traces;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear action-value learner with eligibility traces whose learning target is taken from a
 * {@link Greedy} policy built over its own predictions.
 *
 * <p>Each call to {@link #update} computes a temporal-difference error whose bootstrap term is
 * {@code greedy.bestActionValue()} after the greedy policy has been updated with the next state,
 * then moves {@code theta} along the trace vector scaled by {@code alpha * delta}. Traces are
 * updated with {@code gamma * lambda} only when the action taken is the same object as the
 * greedy action for the current state; otherwise they are cleared first.
 *
 * <p>NOTE(review): this has the shape of Watkins's Q(lambda); confirm against the intended
 * algorithm before relying on that name.
 */
@Monitor
public class QLearning implements Predictor, LinearLearner, EligibilityTraceAlgorithm {
    private static final long serialVersionUID = -404558746167490755L;

    /** Weight vector; predictions are its dot product with a state-action feature vector. */
    @Monitor(level = 4)
    protected final PVector theta;
    /** Eligibility traces, sized like {@code theta}. */
    private final Traces e;
    private final double lambda;
    private final double gamma;
    private final double alpha;
    /** Maps a (state, action) pair to the feature vector that {@code theta} is applied to. */
    private final StateToStateAction toStateAction;
    /** Temporal-difference error of the most recent update; 0 after an episode reset. */
    private double delta;
    private final Greedy greedy;

    /**
     * Creates a learner with zero-argument-constructed weights and fresh traces.
     *
     * @param actions       actions handed to the internal {@link Greedy} policy
     * @param alpha         factor applied to {@code delta} when updating the weights
     * @param gamma         factor applied to the greedy next-state value in the TD error
     * @param lambda        multiplied by {@code gamma} to form the trace update factor
     * @param toStateAction state-action feature mapping; its {@code vectorSize()} sizes both the
     *                      weight vector and the traces
     * @param prototype     trace object used only to create the learner's own traces via
     *                      {@code newTraces}
     */
    public QLearning(Action[] actions, double alpha, double gamma, double lambda, StateToStateAction toStateAction, Traces prototype) {
        this.alpha = alpha;
        this.gamma = gamma;
        this.lambda = lambda;
        this.toStateAction = toStateAction;
        this.theta = new PVector(toStateAction.vectorSize());
        this.e = prototype.newTraces(toStateAction.vectorSize());
        // Built last: Greedy receives `this` as its predictor, so every field it could reach
        // through predict() must already be assigned before the reference is handed out.
        this.greedy = new Greedy(this, actions, toStateAction);
    }

    /**
     * Performs one learning step for a transition and returns the resulting TD error.
     *
     * @param x          features of the current state; {@code null} resets the traces and the
     *                   stored error, and 0 is returned without learning
     * @param action     action taken in {@code x}
     * @param xNext      features of the next state, passed to the greedy policy to obtain the
     *                   bootstrap value
     * @param actionNext not read by this method
     * @param reward     scalar added to the discounted greedy value in the TD error
     * @return the TD error computed for this step (also available through {@link #error()})
     */
    public double update(RealVector x, Action action, RealVector xNext, Action actionNext, double reward) {
        if (x == null) {
            return initEpisode();
        }
        // The greedy policy is stateful: query it for the current state first, remember the
        // answer, then re-point it at the next state so bestActionValue() refers to xNext.
        greedy.update(x);
        Action greedyAction = greedy.bestAction();
        greedy.update(xNext);
        RealVector phi = toStateAction.stateAction(x, action);
        delta = (reward + (gamma * greedy.bestActionValue())) - theta.dotProduct(phi);
        // Reference comparison, as in the original: the taken action must be the very object
        // the greedy policy returned.
        if (action == greedyAction) {
            e.update(gamma * lambda, phi);
        } else {
            // Non-greedy action: drop the accumulated traces and restart them from phi.
            e.clear();
            e.update(0.0d, phi);
        }
        // Cast retained from the original; it may be selecting an addToSelf overload, so
        // verify PVector's signatures before removing it.
        theta.addToSelf(alpha * delta, (RealVector) e.vect());
        return delta;
    }

    /**
     * Clears the traces (when present) and resets the stored TD error.
     *
     * @return the reset error value, 0
     */
    private double initEpisode() {
        if (e != null) {
            e.clear();
        }
        delta = 0.0d;
        return delta;
    }

    /**
     * Returns the dot product of the weight vector with the given vector.
     *
     * @param realVector vector to evaluate; callers inside this class pass state-action features
     */
    @Override // rlpark.plugin.rltoys.algorithms.functions.Predictor
    public double predict(RealVector realVector) {
        return theta.dotProduct(realVector);
    }

    /** Returns the live weight vector (not a copy). */
    public PVector theta() {
        return theta;
    }

    /** Sets the weight at {@code i} to zero. */
    @Override // rlpark.plugin.rltoys.algorithms.LinearLearner
    public void resetWeight(int i) {
        theta.setEntry(i, 0.0d);
    }

    /** Returns the live weight vector (not a copy); same object as {@link #theta()}. */
    @Override // rlpark.plugin.rltoys.algorithms.functions.ParameterizedFunction
    public PVector weights() {
        return theta;
    }

    /** Returns the TD error stored by the most recent {@link #update} call. */
    @Override // rlpark.plugin.rltoys.algorithms.LinearLearner
    public double error() {
        return delta;
    }

    /**
     * Returns the internal greedy policy. It is the same instance that {@link #update} mutates,
     * so its state reflects the last state passed to it.
     */
    public Policy greedy() {
        return greedy;
    }

    /** Returns the learner's own trace object (not a copy). */
    @Override // rlpark.plugin.rltoys.algorithms.traces.EligibilityTraceAlgorithm
    public Traces traces() {
        return e;
    }
}
