package rlpark.plugin.rltoys.problems.noisyinputsum;

import java.util.Random;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.PredictionProblem;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/* loaded from: input_file:rlpark/plugin/rltoys/problems/noisyinputsum/NoisyInputSum.class */
/**
 * Synthetic prediction problem whose target is the dot product of a weight
 * vector with freshly sampled Gaussian inputs.
 *
 * <p>At construction, the first {@code nbNonZeroWeights} weights are drawn
 * uniformly from {@code {+1, -1}} and all remaining weights are set to zero.
 * On every call to {@link #update()} all inputs are re-sampled with
 * {@link Random#nextGaussian()}. Every {@code changePeriod} steps (20 by
 * default), the sign of one weight, picked uniformly among the first
 * {@code nbChangingWeights} weights, is flipped before the inputs are sampled.
 */
public class NoisyInputSum implements PredictionProblem {
    /** Number of steps between two sign flips unless {@link #setChangePeriod(int)} is called. */
    private static final int DEFAULT_CHANGE_PERIOD = 20;

    private final Random random;
    private int nbSteps = 0;

    @Monitor(level = 4)
    private final PVector weights;

    @Monitor(level = 4)
    private final PVector inputs;

    @Monitor
    private double target;
    private final int nbChangingWeights;
    private int changePeriod = DEFAULT_CHANGE_PERIOD;

    /**
     * Creates a problem in which every non-zero weight is also a candidate for sign flips.
     *
     * @param random source of randomness for the weights, the flips and the inputs
     * @param nbNonZeroWeights number of leading weights initialised to +1 or -1; the same
     *        value is used as the number of weights eligible for sign flips
     * @param nbInputs dimension of the input and weight vectors
     * @throws IllegalArgumentException see {@link #NoisyInputSum(Random, int, int, int)}
     */
    public NoisyInputSum(Random random, int nbNonZeroWeights, int nbInputs) {
        this(random, nbNonZeroWeights, nbNonZeroWeights, nbInputs);
    }

    /**
     * Creates a problem with independent counts of changing and non-zero weights.
     *
     * @param random source of randomness for the weights, the flips and the inputs
     * @param nbChangingWeights number of leading weights eligible for sign flips; 0 yields a
     *        problem whose weights never change
     * @param nbNonZeroWeights number of leading weights initialised to +1 or -1
     * @param nbInputs dimension of the input and weight vectors
     * @throws NullPointerException if {@code random} is null
     * @throws IllegalArgumentException if {@code nbInputs} is negative, or if
     *         {@code nbChangingWeights} is negative or larger than {@code nbInputs}
     */
    public NoisyInputSum(Random random, int nbChangingWeights, int nbNonZeroWeights, int nbInputs) {
        if (random == null) {
            throw new NullPointerException("random");
        }
        if (nbInputs < 0) {
            throw new IllegalArgumentException("nbInputs must not be negative: " + nbInputs);
        }
        // Fail fast: an out-of-range value would otherwise only surface at some later
        // update() as an exception thrown from changeWeight().
        if (nbChangingWeights < 0 || nbChangingWeights > nbInputs) {
            throw new IllegalArgumentException(
                    "nbChangingWeights must be in [0, " + nbInputs + "]: " + nbChangingWeights);
        }
        this.random = random;
        this.nbChangingWeights = nbChangingWeights;
        this.weights = createWeights(random, nbNonZeroWeights, nbInputs);
        this.inputs = new PVector(nbInputs);
    }

    /**
     * Builds the initial weight vector: the first {@code nbNonZeroWeights} entries are +1 or -1
     * (one {@link Random#nextBoolean()} draw each, in index order), the others are 0.
     */
    private static PVector createWeights(Random random, int nbNonZeroWeights, int nbInputs) {
        PVector result = new PVector(nbInputs);
        for (int index = 0; index < result.size; index++) {
            if (index < nbNonZeroWeights) {
                result.data[index] = random.nextBoolean() ? 1.0d : -1.0d;
            } else {
                result.data[index] = 0.0d;
            }
        }
        return result;
    }

    /** Flips the sign of one weight picked uniformly among the first {@code nbChangingWeights}. */
    private void changeWeight() {
        if (this.nbChangingWeights == 0) {
            // Random.nextInt(0) would throw; with no changing weight there is nothing to flip.
            return;
        }
        double[] values = this.weights.data;
        int flipped = this.random.nextInt(this.nbChangingWeights);
        values[flipped] = values[flipped] * (-1.0d);
    }

    /**
     * Advances the problem by one step: possibly flips one weight, samples new Gaussian
     * inputs and recomputes the target as the dot product of weights and inputs.
     *
     * @return always {@code true}
     */
    @Override // rlpark.plugin.rltoys.problems.PredictionProblem
    public boolean update() {
        this.nbSteps++;
        // changePeriod is guaranteed positive (see setChangePeriod), so the remainder is defined.
        if (this.nbSteps % this.changePeriod == 0) {
            changeWeight();
        }
        for (int index = 0; index < this.inputs.size; index++) {
            this.inputs.data[index] = this.random.nextGaussian();
        }
        this.target = this.weights.dotProduct(this.inputs);
        return true;
    }

    /**
     * @return the input vector sampled by the last {@link #update()}; the same instance is
     *         returned on every call and is overwritten by the next update
     */
    @Override // rlpark.plugin.rltoys.problems.PredictionProblem
    public RealVector input() {
        return this.inputs;
    }

    /** @return the target computed by the last {@link #update()} */
    @Override // rlpark.plugin.rltoys.problems.PredictionProblem
    public double target() {
        return this.target;
    }

    /**
     * Sets the number of steps between two weight sign flips.
     *
     * @param changePeriod strictly positive number of steps
     * @throws IllegalArgumentException if {@code changePeriod} is not strictly positive; a
     *         period of 0 would make {@link #update()} fail with a division by zero
     */
    public void setChangePeriod(int changePeriod) {
        if (changePeriod <= 0) {
            throw new IllegalArgumentException("changePeriod must be positive: " + changePeriod);
        }
        this.changePeriod = changePeriod;
    }

    /** @return the dimension reported by the input vector */
    @Override // rlpark.plugin.rltoys.problems.PredictionProblem
    public int inputDimension() {
        return this.inputs.getDimension();
    }

    /**
     * @return the internal weight vector (not a copy); its content changes whenever a
     *         weight sign is flipped
     */
    public PVector weights() {
        return this.weights;
    }
}
