package rlpark.example.demos.learning;

import java.util.Random;
import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentDirect;
import rlpark.plugin.rltoys.agents.offpolicy.OffPolicyAgentEvaluable;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.ActorLambdaOffPolicy;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.CriticAdapterFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.offpolicy.OffPAC;
import rlpark.plugin.rltoys.algorithms.functions.ContinuousFunction;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.RandomPolicy;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.BoltzmannDistribution;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.algorithms.functions.states.Projector;
import rlpark.plugin.rltoys.algorithms.predictions.td.GTDLambda;
import rlpark.plugin.rltoys.algorithms.predictions.td.OffPolicyTD;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.TabularActionDiscretizer;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.BoundedSmallPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.StateActionCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCoders;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersHashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.Hashing;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.hashing.MurmurHashing;
import rlpark.plugin.rltoys.algorithms.traces.ATraces;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.experiments.runners.AbstractRunner;
import rlpark.plugin.rltoys.experiments.runners.Runner;
import rlpark.plugin.rltoys.experiments.scheduling.network.ServerScheduler;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.puddleworld.ConstantFunction;
import rlpark.plugin.rltoys.problems.puddleworld.LocalFeatureSumFunction;
import rlpark.plugin.rltoys.problems.puddleworld.PuddleWorld;
import rlpark.plugin.rltoys.problems.puddleworld.SmoothPuddle;
import rlpark.plugin.rltoys.problems.puddleworld.TargetReachedL1NormTermination;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.abstracts.Monitored;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

/**
 * Demo that trains an Off-PAC agent on a puddle world while evaluating it on a second one.
 *
 * <p>Two {@link Runner}s are stepped alternately from {@link #run()}: the learning runner drives
 * the agent built by {@code createOffPACAgent} on the behaviour environment, and the evaluation
 * runner drives the agent returned by {@code createEvaluatedAgent()} on a separate environment.
 * Both environments and every component of the agent draw from one {@link Random} seeded with 0.
 *
 * <p>The end of each evaluation episode is printed to standard output and published to Zephyr
 * under the labels {@code "length"} and {@code "reward"}.
 */
@Monitor
public class OffPACPuddleWorld implements Runnable {
    // NOTE(review): the class-level @Monitor presumably exposes these fields by name, so
    // their names and declaration order are left untouched - confirm before renaming any.
    private final Runner learningRunner;
    private final Runner evaluationRunner;
    final FunctionProjected2D valueFunction;
    private final Random random = new Random(0);
    // Both environments are created from the same generator, behaviour one first.
    private final PuddleWorld behaviourEnvironment = createEnvironment(random);
    private final PuddleWorld evaluationEnvironment = createEnvironment(random);
    final Clock clock = new Clock("Off-PAC Demo");
    final Clock episodeClock = new Clock("Episodes");

    /**
     * Builds the agent, both runners and the 2D value function, then registers the demo with
     * Zephyr on the main clock.
     */
    public OffPACPuddleWorld() {
        Policy behaviourPolicy = new RandomPolicy(random, behaviourEnvironment.actions());
        OffPolicyAgentEvaluable agent = createOffPACAgent(random, behaviourEnvironment, behaviourPolicy, 0.99d);

        // NOTE(review): ServerScheduler.DefaultPort is a scheduler constant used here as a plain
        // Runner argument; this looks like a decompiler replacing a numeric literal with a
        // same-valued constant - confirm against Runner's constructor.
        learningRunner = new Runner(behaviourEnvironment, agent, ServerScheduler.DefaultPort, -1);
        evaluationRunner = new Runner(evaluationEnvironment, agent.createEvaluatedAgent(), ServerScheduler.DefaultPort, -1);

        // Reach into the agent to recover the critic so its predictions can be displayed.
        OffPolicyAgentDirect directAgent = (OffPolicyAgentDirect) agent;
        CriticAdapterFA critic = (CriticAdapterFA) directAgent.learner().predictor();
        valueFunction = new ValueFunction2D(critic.projector(), behaviourEnvironment, critic.predictor());

        connectEpisodesEventsForZephyr();
        Zephyr.advertise(clock, this);
    }

    /**
     * Hooks the end-of-episode signal of the evaluation runner: records the episode's step time
     * and reward, ticks the episode clock, prints a summary line, and advertises the two
     * recorded values to Zephyr on the episode clock.
     */
    private void connectEpisodesEventsForZephyr() {
        // Slot 0 holds the last episode's step time, slot 1 its reward.
        final double[] lastEpisode = new double[2];
        evaluationRunner.onEpisodeEnd.connect(new Listener<AbstractRunner.RunnerEvent>() {
            @Override
            public void listen(AbstractRunner.RunnerEvent event) {
                lastEpisode[0] = event.step.time;
                lastEpisode[1] = event.episodeReward;
                OffPACPuddleWorld.this.episodeClock.tick();
                String summary = String.format("Episodes %d: %d, %f", event.nbEpisodeDone, event.step.time, event.episodeReward);
                System.out.println(summary);
            }
        });
        Zephyr.advertise(episodeClock, monitoredEntry(lastEpisode, 0), "length");
        Zephyr.advertise(episodeClock, monitoredEntry(lastEpisode, 1), "reward");
    }

    /**
     * Wraps one slot of an array as a {@link Monitored} value.
     *
     * @param values array that is read each time the value is requested
     * @param index slot of {@code values} to expose
     * @return a monitor returning {@code values[index]}
     */
    private static Monitored monitoredEntry(final double[] values, final int index) {
        return new Monitored() {
            @Override
            public double monitoredValue() {
                return values[index];
            }
        };
    }

    /** Creates a {@link MurmurHashing} built from the given generator and the value 1000000. */
    private static Hashing createHashing(Random random) {
        Hashing hashing = new MurmurHashing(random, 1000000);
        return hashing;
    }

    /** Adds full tilings (10, 10) to the given tile coders and includes the active feature. */
    private static void setTileCoders(TileCoders coders) {
        coders.addFullTilings(10, 10);
        coders.includeActiveFeature();
    }

    /**
     * Creates a bounded small-partition factory over the given ranges and passes it the
     * generator together with the value 0.2 through {@code setRandom}.
     */
    private static AbstractPartitionFactory createPartitionFactory(Random random, Range[] ranges) {
        BoundedSmallPartitionFactory factory = new BoundedSmallPartitionFactory(ranges);
        factory.setRandom(random, 0.2d);
        return factory;
    }

    /**
     * Creates the hashed tile-coding projector over the observation ranges of the problem.
     *
     * @param random generator handed to the hashing and to the partition factory, in that order
     * @param puddleWorld problem whose observation ranges are tiled
     * @return the configured tile coders
     */
    public static Projector createProjector(Random random, PuddleWorld puddleWorld) {
        Range[] ranges = puddleWorld.getObservationRanges();
        // The hashing is created before the partition factory; both receive the same generator.
        Hashing hashing = createHashing(random);
        AbstractPartitionFactory partitions = createPartitionFactory(random, ranges);
        TileCodersHashing coders = new TileCodersHashing(hashing, partitions, ranges.length);
        setTileCoders(coders);
        return coders;
    }

    /**
     * Creates the state-action coder over the observation ranges and the actions of the problem.
     *
     * @param random generator handed to the hashing and to the partition factory, in that order
     * @param puddleWorld problem providing the actions and the observation ranges
     * @return the configured state-action coders
     */
    public static StateToStateAction createToStateAction(Random random, PuddleWorld puddleWorld) {
        Range[] ranges = puddleWorld.getObservationRanges();
        TabularActionDiscretizer actionDiscretizer = new TabularActionDiscretizer(puddleWorld.actions());
        Hashing hashing = createHashing(random);
        AbstractPartitionFactory partitions = createPartitionFactory(random, ranges);
        StateActionCoders coders = new StateActionCoders(actionDiscretizer, hashing, partitions, ranges.length);
        setTileCoders(coders.tileCoders());
        return coders;
    }

    /**
     * Creates the critic: a {@link GTDLambda} learner wrapped with the given projector.
     *
     * @param projector supplies the vector norm and size used to configure the learner
     * @param gamma second argument of the learner (presumably the discount factor - confirm
     *     against GTDLambda's constructor)
     */
    private OffPolicyTD createCritic(Projector projector, double gamma) {
        GTDLambda learner = new GTDLambda(0.4d, gamma, 0.1d / projector.vectorNorm(), 0.0d, projector.vectorSize(), new ATraces());
        return new CriticAdapterFA(projector, learner);
    }

    /**
     * Assembles the Off-PAC agent: critic, Boltzmann target distribution, actor, and the
     * direct off-policy agent acting with the given behaviour policy.
     *
     * @param random generator shared by the projector, the state-action coder and the
     *     target distribution
     * @param puddleWorld problem the agent is built for
     * @param policy behaviour policy handed to both the agent and the Off-PAC learner
     * @param gamma value forwarded to the critic and the actor (presumably the discount
     *     factor - confirm)
     */
    private OffPolicyAgentEvaluable createOffPACAgent(Random random, PuddleWorld puddleWorld, Policy policy, double gamma) {
        // Creation order matters: every step below may draw from the shared generator.
        Projector criticProjector = createProjector(random, puddleWorld);
        OffPolicyTD critic = createCritic(criticProjector, gamma);
        StateToStateAction toStateAction = createToStateAction(random, puddleWorld);
        BoltzmannDistribution target = new BoltzmannDistribution(random, puddleWorld.actions(), toStateAction);
        // The actor's step size is scaled by the critic projector's norm.
        ActorLambdaOffPolicy actor = new ActorLambdaOffPolicy(0.4d, gamma, target, 0.001d / criticProjector.vectorNorm(), toStateAction.vectorSize(), new ATraces());
        OffPAC offPAC = new OffPAC(policy, critic, actor);
        return new OffPolicyAgentDirect(policy, offPAC);
    }

    /**
     * Creates a puddle world with its reward function (one constant function and three
     * smooth puddles), a termination around (1, 1) and a start position of (0.2, 0.4).
     */
    private static PuddleWorld createEnvironment(Random random) {
        PuddleWorld world = new PuddleWorld(random, 2, new Range(0.0d, 1.0d), new Range(-0.05d, 0.05d), 0.1d);

        int[] puddleIndexes = {0, 1};
        // Presumably one weight per function below - confirm against LocalFeatureSumFunction.
        double[] weights = {-1.0d, -2.0d, -2.0d, -2.0d};
        ContinuousFunction[] functions = {
            new ConstantFunction(1.0d),
            new SmoothPuddle(puddleIndexes, new double[]{0.3d, 0.6d}, new double[]{0.1d, 0.03d}),
            new SmoothPuddle(puddleIndexes, new double[]{0.4d, 0.5d}, new double[]{0.03d, 0.1d}),
            new SmoothPuddle(puddleIndexes, new double[]{0.8d, 0.9d}, new double[]{0.03d, 0.1d})
        };
        world.setRewardFunction(new LocalFeatureSumFunction(weights, functions, 0.0d));

        world.setTermination(new TargetReachedL1NormTermination(new double[]{1.0d, 1.0d}, 0.1d));
        world.setStart(new double[]{0.2d, 0.4d});
        return world;
    }

    /**
     * Steps the learning runner and then the evaluation runner once per successful tick of
     * the main clock, returning as soon as a tick reports {@code false}.
     */
    @Override
    public void run() {
        while (true) {
            if (!clock.tick()) {
                return;
            }
            learningRunner.step();
            evaluationRunner.step();
        }
    }

    /** Creates the demo and runs it on the calling thread. */
    public static void main(String[] args) {
        OffPACPuddleWorld demo = new OffPACPuddleWorld();
        demo.run();
    }
}
