package rlpark.plugin.rltoys.algorithms.predictions.supervised;

import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Linear supervised learner with per-feature step sizes adapted online (the "K1" gain-adaptation
 * scheme, cf. Sutton 1992, "Gain adaptation beats least squares?").
 *
 * <p>For an input {@code x} and target {@code y}, one call to {@link #learn} performs:
 *
 * <pre>
 *   prediction = w . x
 *   delta      = y - prediction
 *   beta_i    += theta * delta * x_i * h_i          (using h from the previous step)
 *   alpha_i    = exp(beta_i)
 *   k_i        = alpha_i * x_i / (1 + sum_j alpha_j * x_j^2)
 *   w_i       += k_i * delta
 *   h_i        = (h_i + k_i * delta) * max(0, 1 - k_i * x_i)
 * </pre>
 *
 * <p>The input vector passed to {@link #learn} must be a {@link PVector}. The class casts it
 * unconditionally.
 */
@Monitor
public class K1 implements LearningAlgorithm {
    private static final long serialVersionUID = 2943574757813500087L;
    /** Meta step size used to adapt the log step sizes {@code betas}. */
    private final double theta;
    private final PVector weights;
    /** Per-feature step sizes; always equal to {@code exp(betas)}. */
    private final PVector alphas;
    /** Per-feature log step sizes. */
    private final PVector betas;
    /** Per-feature memory trace used in the {@code betas} update. */
    private final PVector hs;
    /** Error (target minus prediction) from the most recent {@link #learn} call. */
    private double delta;
    /** Prediction made at the start of the most recent {@link #learn} call. */
    private double prediction;

    /**
     * Creates a learner whose per-feature step sizes all start at {@code 0.1}.
     *
     * @param nbFeatures number of features (size of weight and step-size vectors)
     * @param theta meta step size for the log step-size update
     */
    public K1(int nbFeatures, double theta) {
        this(nbFeatures, theta, 0.1d);
    }

    /**
     * Creates a learner whose per-feature step sizes all start at {@code initialAlpha}.
     *
     * @param nbFeatures number of features (size of weight and step-size vectors)
     * @param theta meta step size for the log step-size update
     * @param initialAlpha initial step size for every feature; must be finite and strictly positive
     *     because the learner stores {@code log(initialAlpha)}
     * @throws IllegalArgumentException if {@code initialAlpha} is not finite and positive
     */
    public K1(int nbFeatures, double theta, double initialAlpha) {
        // The negated comparison also rejects NaN.
        if (!(initialAlpha > 0.0d) || Double.isInfinite(initialAlpha)) {
            throw new IllegalArgumentException(
                    "initialAlpha must be finite and > 0, got " + initialAlpha);
        }
        this.theta = theta;
        this.weights = new PVector(nbFeatures);
        double initialBeta = Math.log(initialAlpha);
        this.betas = new PVector(nbFeatures);
        this.betas.set(initialBeta);
        // Keep alphas == exp(betas) from the start so monitored/exposed values are consistent
        // before the first learning step.
        this.alphas = new PVector(nbFeatures);
        this.alphas.set(Math.exp(initialBeta));
        this.hs = new PVector(nbFeatures);
    }

    /**
     * Performs one K1 learning step on the given example.
     *
     * @param realVector input features; must be a {@link PVector} with at least as many entries as
     *     the learner has features
     * @param target desired output for this input
     * @return the prediction error {@code target - prediction} computed before the update
     */
    @Override
    public double learn(RealVector realVector, double target) {
        PVector input = (PVector) realVector;
        prediction = predict(input);
        delta = target - prediction;

        double[] x = input.data;
        double[] w = weights.data;
        double[] beta = betas.data;
        double[] alpha = alphas.data;
        double[] h = hs.data;
        int size = weights.size;

        // Pass 1: adapt log step sizes with the previous traces and accumulate x' diag(alpha) x.
        double weightedSquaredNorm = 0.0d;
        for (int i = 0; i < size; i++) {
            beta[i] += theta * delta * x[i] * h[i];
            alpha[i] = Math.exp(beta[i]);
            weightedSquaredNorm += alpha[i] * x[i] * x[i];
        }

        // Pass 2: apply the gain k_i = gain * x_i to the weights and update the traces.
        double denominator = 1.0d + weightedSquaredNorm;
        for (int i = 0; i < size; i++) {
            double gain = alpha[i] / denominator;
            double step = gain * delta * x[i]; // k_i * delta
            w[i] += step;
            h[i] = (h[i] + step) * Math.max(0.0d, 1.0d - gain * x[i] * x[i]);
        }
        return delta;
    }

    /**
     * Returns the linear prediction {@code w . x}.
     *
     * @param realVector input features
     * @return dot product of the weight vector with {@code realVector}
     */
    @Override
    public double predict(RealVector realVector) {
        return weights.dotProduct(realVector);
    }

    /**
     * Returns the per-feature step sizes.
     *
     * @return the live internal vector (not a copy); it is updated in place by {@link #learn}
     */
    public RealVector alphas() {
        return alphas;
    }

    /**
     * Returns the per-feature memory traces used in the step-size update.
     *
     * @return the live internal vector (not a copy); it is updated in place by {@link #learn}
     */
    public RealVector h() {
        return hs;
    }
}
