package rlpark.plugin.rltoys.horde.demons;

import java.io.Serializable;
import rlpark.plugin.rltoys.algorithms.predictions.td.OnPolicyTD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TD;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDErrorMonitor;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambdaAutostep;
import rlpark.plugin.rltoys.horde.functions.RewardFunction;
import rlpark.plugin.rltoys.utils.NotImplemented;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/* loaded from: input_file:rlpark/plugin/rltoys/horde/demons/PredictionDemonVerifier.class */
/**
 * Pairs a {@link PredictionDemon} with a {@link TDErrorMonitor}.
 * <p>
 * On every {@link #update(boolean)} call, the demon's current prediction and the current reward
 * from the demon's {@link RewardFunction} are handed to the monitor. The monitor is exposed
 * through {@link #errorMonitor()} and is annotated with {@code @Monitor} so Zephyr monitoring can
 * pick it up.
 */
public class PredictionDemonVerifier implements Serializable {
    private static final long serialVersionUID = 6127406364376542150L;

    /** Second {@link TDErrorMonitor} constructor argument used when the caller does not supply one. */
    private static final double DEFAULT_PRECISION = 0.01d;

    private final PredictionDemon predictionDemon;

    // Captured once at construction; update() always queries this same reward function instance.
    private final RewardFunction rewardFunction;

    @Monitor
    private final TDErrorMonitor errorMonitor;

    /**
     * Creates a verifier whose gamma is read from the demon's predicter via
     * {@link #extractGamma(OnPolicyTD)}, using the default precision.
     *
     * @param predictionDemon the demon whose predictions are checked
     * @throws NotImplemented if the predicter is neither a {@link TD} nor a {@link TDLambdaAutostep}
     */
    public PredictionDemonVerifier(PredictionDemon predictionDemon) {
        this(extractGamma(predictionDemon.predicter()), predictionDemon, DEFAULT_PRECISION);
    }

    /**
     * Creates a verifier with an explicit gamma and the default precision.
     *
     * @param gamma first argument forwarded to the {@link TDErrorMonitor} constructor
     * @param predictionDemon the demon whose predictions are checked
     */
    public PredictionDemonVerifier(double gamma, PredictionDemon predictionDemon) {
        this(gamma, predictionDemon, DEFAULT_PRECISION);
    }

    /**
     * Creates a verifier with an explicit gamma and precision.
     *
     * @param gamma first argument forwarded to the {@link TDErrorMonitor} constructor
     * @param predictionDemon the demon whose predictions are checked
     * @param precision second argument forwarded to the {@link TDErrorMonitor} constructor
     *        (NOTE(review): meaning defined by TDErrorMonitor, not visible here — confirm)
     */
    public PredictionDemonVerifier(double gamma, PredictionDemon predictionDemon, double precision) {
        this.predictionDemon = predictionDemon;
        this.rewardFunction = this.predictionDemon.rewardFunction();
        this.errorMonitor = new TDErrorMonitor(gamma, precision);
    }

    /**
     * Reads gamma from a predicter. {@link TD} is tested before {@link TDLambdaAutostep}.
     *
     * @param onPolicyTD the predicter to inspect
     * @return the predicter's {@code gamma()}
     * @throws NotImplemented for any other predicter type, including {@code null}
     */
    public static double extractGamma(OnPolicyTD onPolicyTD) {
        if (onPolicyTD instanceof TD) {
            TD td = (TD) onPolicyTD;
            return td.gamma();
        }
        if (onPolicyTD instanceof TDLambdaAutostep) {
            TDLambdaAutostep autostep = (TDLambdaAutostep) onPolicyTD;
            return autostep.gamma();
        }
        throw new NotImplemented();
    }

    /**
     * Returns the monitor that accumulates this verifier's errors.
     *
     * @return the monitor created at construction (never replaced)
     */
    public TDErrorMonitor errorMonitor() {
        return errorMonitor;
    }

    /**
     * Feeds the demon's current prediction and the current reward into the error monitor.
     *
     * @param endOfEpisode forwarded unchanged as the third argument of
     *        {@link TDErrorMonitor#update}; presumably flags episode termination — TODO confirm
     * @return the value returned by {@link TDErrorMonitor#update}
     */
    public double update(boolean endOfEpisode) {
        return errorMonitor.update(predictionDemon.prediction(), rewardFunction.reward(), endOfEpisode);
    }
}
