package rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy;

import rlpark.plugin.rltoys.algorithms.control.ControlLearner;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.math.vector.RealVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Skeleton for on-policy actor-critic control learners.
 *
 * <p>Holds an {@link Actor} and an {@link OnPolicyTD} critic and fixes the order in which the
 * learning components are invoked on every {@link #step}; subclasses provide the concrete critic
 * and actor update rules.
 */
@Monitor
public abstract class AbstractActorCritic implements ControlLearner {
    private static final long serialVersionUID = -6085810735822394602L;

    /** Actor whose policy is used to select actions. */
    public final Actor actor;
    /** On-policy TD critic supplied at construction. */
    public final OnPolicyTD critic;
    /** Reward most recently passed to {@link #step}. */
    protected double reward = 0.0;
    /**
     * Whether {@link #step} must re-condition the policy on its current state before the actor
     * update. True initially and after {@link #proposeAction}; cleared at the end of {@link #step}.
     */
    protected boolean policyRequireUpdate = true;

    /**
     * @param critic the on-policy TD critic
     * @param actor the actor owning the policy
     */
    public AbstractActorCritic(OnPolicyTD critic, Actor actor) {
        this.actor = actor;
        this.critic = critic;
    }

    /**
     * Updates the critic for one transition.
     *
     * @param state state before the transition (may be null; see {@link #step})
     * @param nextState state after the transition
     * @param observedReward reward received on the transition
     * @return value forwarded to {@link #updateActor} (presumably the TD error — confirm in
     *     subclasses)
     */
    protected abstract double updateCritic(RealVector state, RealVector nextState, double observedReward);

    /**
     * Updates the actor for one transition. When called from {@link #step}, the policy has already
     * been conditioned on {@code state} if that was required and {@code state} is non-null.
     *
     * @param state state in which {@code action} was taken (may be null)
     * @param action action taken in {@code state}
     * @param criticOutput value returned by {@link #updateCritic} for this transition
     */
    protected abstract void updateActor(RealVector state, Action action, double criticOutput);

    /**
     * Conditions the policy on {@code state} and samples an action from it, without learning.
     *
     * @param state state to act in
     * @return the sampled action
     */
    @Override
    public Action proposeAction(RealVector state) {
        // The policy is about to be conditioned on an arbitrary state, so the next step() must
        // re-condition it on its own current state before updating the actor.
        policyRequireUpdate = true;
        policy().update(state);
        return policy().sampleAction();
    }

    /** @return the actor's policy distribution */
    protected PolicyDistribution policy() {
        return actor.policy();
    }

    /** @return the actor */
    public Actor actor() {
        return actor;
    }

    /** @return the critic */
    public OnPolicyTD critic() {
        return critic;
    }

    /**
     * Performs one learning step and proposes the next action.
     *
     * <p>Order: record reward, update critic, re-condition the policy on {@code state} if required,
     * update actor, condition the policy on {@code nextState}, then sample from it.
     *
     * @param state state before the transition (may be null)
     * @param action action taken in {@code state}
     * @param nextState state after the transition
     * @param observedReward reward received on the transition
     * @return an action sampled from the policy conditioned on {@code nextState}
     */
    @Override
    public Action step(RealVector state, Action action, RealVector nextState, double observedReward) {
        reward = observedReward;
        final double criticOutput = updateCritic(state, nextState, observedReward);
        if (state == null) {
            // No state to condition on: leave the flag raised while the actor update runs.
            policyRequireUpdate = true;
        } else if (policyRequireUpdate) {
            policy().update(state);
            policyRequireUpdate = false;
        }
        updateActor(state, action, criticOutput);
        policy().update(nextState);
        // The policy now reflects nextState, which becomes the next step's current state.
        policyRequireUpdate = false;
        return policy().sampleAction();
    }
}
