package rlpark.example.demos.learning;

import java.util.Random;
import rlpark.plugin.rltoys.agents.functions.FunctionProjected2D;
import rlpark.plugin.rltoys.agents.functions.ValueFunction2D;
import rlpark.plugin.rltoys.agents.rl.LearnerAgentFA;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.Actor;
import rlpark.plugin.rltoys.algorithms.control.actorcritic.onpolicy.AverageRewardActorCritic;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.helpers.ScaledPolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures.NormalDistributionScaled;
import rlpark.plugin.rltoys.algorithms.predictions.td.TDLambda;
import rlpark.plugin.rltoys.algorithms.representations.discretizer.partitions.AbstractPartitionFactory;
import rlpark.plugin.rltoys.algorithms.representations.tilescoding.TileCodersNoHashing;
import rlpark.plugin.rltoys.experiments.runners.AbstractRunner;
import rlpark.plugin.rltoys.experiments.runners.Runner;
import rlpark.plugin.rltoys.math.ranges.Range;
import rlpark.plugin.rltoys.problems.pendulum.SwingPendulum;
import zephyr.plugin.core.api.Zephyr;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;
import zephyr.plugin.core.api.signals.Listener;
import zephyr.plugin.core.api.synchronization.Clock;

/**
 * Runnable demo that wires an {@link AverageRewardActorCritic} learner to the
 * {@link SwingPendulum} problem, using a {@link TileCodersNoHashing} projector
 * for the observations.
 *
 * <p>The instance advertises itself to Zephyr together with its clock, and
 * prints one line to standard output each time the runner signals the end of
 * an episode.
 */
@Monitor
public class ActorCriticPendulum implements Runnable {
    final FunctionProjected2D valueFunction;
    double reward;
    private final SwingPendulum problem;
    private final Clock clock = new Clock("ActorCriticPendulum");
    private final LearnerAgentFA agent;
    private final Runner runner;

    /**
     * Builds the problem, the tile-coding projector, the TD(lambda) predictor
     * and the actor, assembles them into a learning agent, and registers the
     * demo with Zephyr.
     */
    public ActorCriticPendulum() {
        // Single seeded generator shared by the tile-coder partitions and the policy.
        final Random rng = new Random(0L);
        this.problem = new SwingPendulum(null, false);

        final TileCodersNoHashing projector = createTileCoders(this.problem, rng);
        final double featureNorm = projector.vectorNorm();
        final int featureCount = projector.vectorSize();

        // NOTE(review): the third argument here and the second argument of Actor
        // below are presumably step sizes scaled by the feature-vector norm;
        // confirm against the TDLambda / Actor constructors.
        final TDLambda critic = new TDLambda(0.5d, 1.0d, 0.1d / featureNorm, featureCount);

        // Built step by step, in the same order the nested expression would
        // evaluate its arguments.
        final NormalDistributionScaled gaussian = new NormalDistributionScaled(rng, 0.0d, 1.0d);
        final Range policyRange = new Range(-2.0d, 2.0d);
        final ScaledPolicyDistribution policy =
                new ScaledPolicyDistribution(gaussian, policyRange, this.problem.actionRanges()[0]);
        final Actor actor = new Actor(policy, 0.001d / featureNorm, featureCount);
        final AverageRewardActorCritic control = new AverageRewardActorCritic(1.0E-4d, critic, actor);
        this.agent = new LearnerAgentFA(control, projector);

        this.valueFunction = new ValueFunction2D(projector, this.problem, critic);

        // NOTE(review): 1000 and -1 are presumably the episode count and an
        // "unbounded" episode length; verify against the Runner constructor.
        this.runner = new Runner(this.problem, this.agent, 1000, -1);
        this.runner.onEpisodeEnd.connect(createEpisodeReporter());

        Zephyr.advertise(this.clock, this);
    }

    /**
     * Creates the tile coders over the problem's observation ranges, hands the
     * random generator to the partition factory, and adds the full tilings.
     */
    private static TileCodersNoHashing createTileCoders(SwingPendulum pendulum, Random rng) {
        final TileCodersNoHashing coders = new TileCodersNoHashing(pendulum.getObservationRanges());
        final AbstractPartitionFactory partitions = (AbstractPartitionFactory) coders.discretizerFactory();
        partitions.setRandom(rng, 0.2d);
        coders.addFullTilings(10, 10);
        return coders;
    }

    /** Creates the listener that prints the episode count and episode reward. */
    private static Listener<AbstractRunner.RunnerEvent> createEpisodeReporter() {
        return new Listener<AbstractRunner.RunnerEvent>() {
            @Override
            public void listen(AbstractRunner.RunnerEvent event) {
                final int episode = event.nbEpisodeDone;
                final double episodeReward = event.episodeReward;
                System.out.println(String.format("Episode %d: %f", episode, episodeReward));
            }
        };
    }

    /** Steps the runner once per clock tick until the clock stops ticking. */
    @Override
    public void run() {
        while (true) {
            if (!this.clock.tick()) {
                return;
            }
            this.runner.step();
        }
    }

    /** Command-line entry point: builds the demo and runs it on the calling thread. */
    public static void main(String[] strArr) {
        final ActorCriticPendulum demo = new ActorCriticPendulum();
        demo.run();
    }
}
