package rlpark.plugin.rltoys.algorithms.discovery.ltu;

import java.util.Comparator;
import java.util.HashSet;
import java.util.Set;
import rlpark.plugin.rltoys.algorithms.LinearLearner;
import rlpark.plugin.rltoys.algorithms.discovery.sorting.WeightSorter;
import rlpark.plugin.rltoys.algorithms.representations.ltu.networks.RandomNetwork;
import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.utils.Utils;
import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * {@link WeightSorter} that scores each output unit of a {@link RandomNetwork}
 * with a recursive sum: the unit's own entry in {@code sums} plus the
 * discounted entries of the units returned by {@code network.parents(...)},
 * followed transitively.
 *
 * <p>The scores are stored in {@link #recursiveSum}, which is the vector the
 * comparator created by {@link #createComparator()} is built on.
 */
@Monitor
public class RecursiveWeightSorter extends WeightSorter {
    private static final long serialVersionUID = -654469131883608071L;

    @IgnoreMonitor
    protected final RandomNetwork network;

    /** Upper bound on the number of distinct units visited by one recursive evaluation. */
    private final int nbMaxParents;

    /** Recursive score per unit; only indices below {@code network.outputSize} are written. */
    @Monitor(level = 4)
    private final PVector recursiveSum;

    /** Factor applied to the contribution weight at each step up the parent chain. */
    private final double discount;

    /**
     * Creates a sorter over the given network.
     *
     * @param randomNetwork    network whose units are evaluated; when assertions
     *                         are enabled its {@code inputSize} must be strictly
     *                         greater than its {@code outputSize}
     * @param linearLearnerArr learners forwarded to the {@link WeightSorter} constructor
     * @param i                maximum number of units visited per recursive
     *                         evaluation; also passed to
     *                         {@code Utils.timeStepsToDiscount} to obtain the discount
     */
    public RecursiveWeightSorter(RandomNetwork randomNetwork, LinearLearner[] linearLearnerArr, int i) {
        super(linearLearnerArr);
        assert randomNetwork.inputSize > randomNetwork.outputSize;
        this.network = randomNetwork;
        this.nbMaxParents = i;
        this.recursiveSum = new PVector(this.sums.size);
        this.discount = Utils.timeStepsToDiscount(i);
    }

    /**
     * Builds a comparator over {@link #recursiveSum} that places every index
     * below {@code network.outputSize} before every index at or above it, and
     * defers to the parent comparator when both indices fall on the same side.
     */
    @Override
    protected Comparator<Integer> createComparator() {
        return new WeightSorter.PVectorBasedComparator(this.recursiveSum) {
            private static final long serialVersionUID = 4220495235775683757L;

            /** Indices at or above this bound always sort after indices below it. */
            private final int maxSort = RecursiveWeightSorter.this.network.outputSize;

            @Override
            public int compare(Integer left, Integer right) {
                boolean leftSortable = left.intValue() < this.maxSort;
                boolean rightSortable = right.intValue() < this.maxSort;
                if (leftSortable == rightSortable) {
                    // Same partition: let the vector-based ordering decide.
                    return super.compare(left, right);
                }
                // Different partitions: the index below maxSort comes first.
                return leftSortable ? -1 : 1;
            }
        };
    }

    /**
     * Refreshes the parent's evaluation, then recomputes the recursive score of
     * every unit with an index below {@code network.outputSize}. All other
     * entries of {@link #recursiveSum} are reset to zero.
     */
    @Override
    public void updateUnitEvaluation() {
        super.updateUnitEvaluation();
        this.recursiveSum.set(0.0d);
        for (int unitIndex = 0; unitIndex < this.network.outputSize; unitIndex++) {
            // A fresh visited set per unit: each score is computed independently.
            this.recursiveSum.data[unitIndex] =
                    computeRecursiveWeights(new HashSet<LTU>(), this.network.ltu(unitIndex), 1.0d);
        }
    }

    /**
     * Accumulates the weighted sum of {@code ltu} and, recursively, of its parents.
     *
     * @param visited units already counted in this evaluation; a unit seen
     *                twice contributes nothing the second time
     * @param ltu     unit whose contribution is being added
     * @param weight  multiplier applied to this unit's entry in {@code sums}
     * @return this unit's weighted sum plus the contributions of the parents
     *         reached before {@code visited} grows to {@code nbMaxParents}
     */
    private double computeRecursiveWeights(Set<LTU> visited, LTU ltu, double weight) {
        if (!visited.add(ltu)) {
            return 0.0d;
        }
        double total = this.sums.data[ltu.index()] * weight;
        // The visited set is shared across branches, so this bounds the total
        // number of units counted, not just the recursion depth.
        if (visited.size() >= this.nbMaxParents) {
            return total;
        }
        double parentWeight = weight * this.discount;
        for (LTU parent : this.network.parents(ltu.index())) {
            total += computeRecursiveWeights(visited, parent, parentWeight);
        }
        return total;
    }
}
