package rlpark.plugin.rltoys.algorithms.representations.ltu.networks;

import java.util.LinkedHashSet;
import java.util.Random;
import java.util.Set;
import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTUAdaptiveDensity;
import rlpark.plugin.rltoys.math.vector.BinaryVector;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * A {@link RandomNetwork} that, after each projection, nudges the density of its
 * {@link LTUAdaptiveDensity} units so that the number of active output units tends to stay within
 * [{@link #minUnitActive}, {@link #maxUnitActive}].
 *
 * <p>If too many output units are active, each active adaptive unit is asked to decrease its
 * density with probability {@code overUnit / activeCount}. If too few are active, each inactive
 * adaptive unit found via {@code parents(i)} for the active indexes of the projected vector is
 * asked to increase its density with probability {@code min(1, missingUnit / candidateCount)}.
 */
@Monitor
/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/representations/ltu/networks/AutoRegulatedNetwork.class */
public class AutoRegulatedNetwork extends RandomNetwork {
    private static final long serialVersionUID = 1847556584654367004L;
    /** Lower bound on active units: {@code (int) (minDensity * outputSize)}. */
    public final int minUnitActive;
    /** Upper bound on active units: {@code (int) (maxDensity * outputSize)}. */
    public final int maxUnitActive;
    /** Minimum target fraction of active units. */
    public final double minDensity;
    /** Maximum target fraction of active units. */
    public final double maxDensity;
    /** Source of randomness for choosing which units to adjust. */
    private final Random random;
    /** Number of units below {@link #minUnitActive} at the last projection (0 if none). */
    private int missingUnit;
    /** Number of units above {@link #maxUnitActive} at the last projection (0 if none). */
    private int overUnit;

    /**
     * Creates an auto-regulated network.
     *
     * @param random random generator used to pick which units get their density adjusted
     * @param inputSize first size argument forwarded to {@code RandomNetwork(int, int)}
     *     (presumably the input size — confirm against RandomNetwork)
     * @param outputSize second size argument forwarded to {@code RandomNetwork(int, int)}; the
     *     density bounds are multiplied by it to obtain the active-unit bounds
     * @param minDensity minimum target fraction of active units
     * @param maxDensity maximum target fraction of active units
     */
    public AutoRegulatedNetwork(Random random, int inputSize, int outputSize, double minDensity,
            double maxDensity) {
        super(inputSize, outputSize);
        this.random = random;
        this.minDensity = minDensity;
        this.maxDensity = maxDensity;
        this.minUnitActive = (int) (minDensity * outputSize);
        this.maxUnitActive = (int) (maxDensity * outputSize);
    }

    /**
     * Checks the number of active output units against the bounds and adjusts unit densities
     * accordingly, then delegates to {@link RandomNetwork#postProjection(BinaryVector)}.
     *
     * @param binaryVector the vector that was just projected
     */
    @Override
    public void postProjection(BinaryVector binaryVector) {
        this.missingUnit = 0;
        this.overUnit = 0;
        int activeCount = this.output.nonZeroElements();
        if (activeCount > this.maxUnitActive) {
            decreaseDensity();
        }
        if (activeCount < this.minUnitActive) {
            increaseDensity(binaryVector);
        }
        super.postProjection(binaryVector);
    }

    /**
     * Asks some currently inactive adaptive units to increase their density.
     *
     * @param binaryVector the projected vector whose active indexes are used to find candidates
     */
    private void increaseDensity(BinaryVector binaryVector) {
        this.missingUnit = this.minUnitActive - this.output.nonZeroElements();
        Set<LTUAdaptiveDensity> candidates = buildCouldHaveAgreeUnits(binaryVector);
        assert this.missingUnit > 0;
        if (candidates.isEmpty()) {
            // Nothing can be adjusted; also avoids dividing by zero below.
            return;
        }
        // Floating-point division: integer division would truncate the probability to 0 whenever
        // there are more candidates than missing units.
        double probability = Math.min(1.0d, (double) this.missingUnit / candidates.size());
        for (LTUAdaptiveDensity unit : candidates) {
            if (this.random.nextFloat() <= probability) {
                unit.increaseDensity(this.random, this.denseInputVector);
            }
        }
    }

    /**
     * Collects the inactive {@link LTUAdaptiveDensity} units returned by {@code parents(i)} for
     * each active index {@code i} of the given vector.
     *
     * @param binaryVector the projected vector
     * @return the distinct candidate units, in first-encountered order
     */
    private Set<LTUAdaptiveDensity> buildCouldHaveAgreeUnits(BinaryVector binaryVector) {
        Set<LTUAdaptiveDensity> candidates = new LinkedHashSet<LTUAdaptiveDensity>();
        for (int index : binaryVector.getActiveIndexes()) {
            for (LTU ltu : parents(index)) {
                if (ltu != null && !ltu.isActive() && ltu instanceof LTUAdaptiveDensity) {
                    candidates.add((LTUAdaptiveDensity) ltu);
                }
            }
        }
        return candidates;
    }

    /**
     * Asks each currently active adaptive output unit to decrease its density, with probability
     * {@code overUnit / activeCount}.
     */
    private void decreaseDensity() {
        int activeCount = this.output.nonZeroElements();
        this.overUnit = activeCount - this.maxUnitActive;
        assert this.overUnit > 0;
        // Floating-point division: overUnit < activeCount, so integer division always yields 0
        // and would effectively disable density decrease.
        double probability = (double) this.overUnit / activeCount;
        for (int index : this.output.getActiveIndexes()) {
            // Draw first, as before, so the random sequence consumed per active index is unchanged.
            if (this.random.nextFloat() > probability) {
                continue;
            }
            LTU ltu = this.ltus[index];
            if (ltu instanceof LTUAdaptiveDensity) {
                ((LTUAdaptiveDensity) ltu).decreaseDensity(this.random, this.denseInputVector);
            }
        }
    }
}
