package rlpark.plugin.rltoys.experiments.testing.predictions;

import org.junit.Assert;
import org.junit.Test;
import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.experiments.testing.predictions.FiniteStateGraphOnPolicy;
import rlpark.plugin.rltoys.experiments.testing.predictions.RandomWalkOffPolicy;
import rlpark.plugin.rltoys.experiments.testing.results.TestingResult;

/* loaded from: input_file:rlpark/plugin/rltoys/experiments/testing/predictions/OffPolicyTests.class */
/**
 * Abstract JUnit test base that runs the off-policy checks implemented by
 * {@link RandomWalkOffPolicy#testOffPolicyGTD} against every factory returned by
 * {@link #offPolicyTDFactory()}.
 *
 * <p>By default the off-policy factories are adapters around the factories returned by
 * {@code onPolicyFactories()}; subclasses may override {@link #offPolicyTDFactory()} to supply
 * their own.
 */
public abstract class OffPolicyTests extends OnPolicyTests {
    /**
     * Value passed as the fourth argument of {@code RandomWalkOffPolicy.testOffPolicyGTD}.
     * NOTE(review): presumably the discount factor, judging by the name only - confirm
     * against the signature of {@code testOffPolicyGTD}.
     */
    private static final double GAMMA = 0.9d;

    /** Runs the off-policy check with a first argument of {@code 0.0} for every factory. */
    @Test
    public void testOffPolicy() {
        for (RandomWalkOffPolicy.OffPolicyTDFactory factory : offPolicyTDFactory()) {
            checkBothSettings(0.0d, factory);
        }
    }

    /** Runs the off-policy check for every factory and every value of {@code lambdaValues()}. */
    @Test
    public void testOffPolicyWithLambda() {
        for (RandomWalkOffPolicy.OffPolicyTDFactory factory : offPolicyTDFactory()) {
            for (double lambda : lambdaValues()) {
                checkBothSettings(lambda, factory);
            }
        }
    }

    /**
     * Builds one off-policy factory per factory returned by {@code onPolicyFactories()},
     * in the same order.
     *
     * @return a new array with the same length as {@code onPolicyFactories()}
     */
    protected RandomWalkOffPolicy.OffPolicyTDFactory[] offPolicyTDFactory() {
        FiniteStateGraphOnPolicy.OnPolicyTDFactory[] sources = onPolicyFactories();
        RandomWalkOffPolicy.OffPolicyTDFactory[] adapted =
                new RandomWalkOffPolicy.OffPolicyTDFactory[sources.length];
        for (int index = 0; index < adapted.length; index++) {
            adapted[index] = asOffPolicyFactory(sources[index]);
        }
        return adapted;
    }

    /**
     * Wraps an on-policy factory so that it can be used where an off-policy factory is
     * expected. The four arguments of {@code newTD} are forwarded unchanged, in order, to
     * {@code create}; the result is cast to {@link OffPolicyTD}, so a
     * {@link ClassCastException} is thrown at call time if the wrapped factory produces
     * anything else.
     */
    private static RandomWalkOffPolicy.OffPolicyTDFactory asOffPolicyFactory(
            final FiniteStateGraphOnPolicy.OnPolicyTDFactory delegate) {
        return new RandomWalkOffPolicy.OffPolicyTDFactory() {
            @Override
            public OffPolicyTD newTD(double arg0, double arg1, double arg2, int arg3) {
                return (OffPolicyTD) delegate.create(arg0, arg1, arg2, arg3);
            }
        };
    }

    /**
     * Runs the check twice for the given lambda: first with the trailing pair
     * {@code (0.2, 0.5)}, then with the swapped pair {@code (0.5, 0.2)}.
     */
    private void checkBothSettings(double lambda, RandomWalkOffPolicy.OffPolicyTDFactory factory) {
        testOffPolicy(lambda, 0.2d, 0.5d, factory);
        testOffPolicy(lambda, 0.5d, 0.2d, factory);
    }

    /**
     * Delegates to {@code RandomWalkOffPolicy.testOffPolicyGTD} and fails the running test,
     * using the result's message, when the result is not marked as passed.
     *
     * @param lambda  third argument of {@code testOffPolicyGTD}
     * @param paramA  fifth argument of {@code testOffPolicyGTD}
     *                (NOTE(review): meaning not visible here - confirm against callee)
     * @param paramB  sixth argument of {@code testOffPolicyGTD}
     *                (NOTE(review): meaning not visible here - confirm against callee)
     * @param factory factory handed to {@code testOffPolicyGTD}
     */
    private void testOffPolicy(double lambda, double paramA, double paramB,
            RandomWalkOffPolicy.OffPolicyTDFactory factory) {
        TestingResult<OffPolicyTD> result = RandomWalkOffPolicy.testOffPolicyGTD(
                nbEpisodeMax(), precision(), lambda, GAMMA, paramA, paramB, factory);
        Assert.assertTrue(result.message, result.passed);
    }
}
