package rlpark.plugin.rltoys.algorithms.representations.ltu.networks;

import java.io.Serializable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUArray;
import rlpark.plugin.rltoys.algorithms.representations.ltu.internal.LTUUpdated;
import rlpark.plugin.rltoys.algorithms.representations.ltu.units.LTU;
import rlpark.plugin.rltoys.math.vector.BinaryVector;
import rlpark.plugin.rltoys.utils.Scheduling;

/* loaded from: input_file:rlpark/plugin/rltoys/algorithms/representations/ltu/networks/RandomNetworkScheduler.class */
/**
 * Updates a {@link RandomNetwork} in parallel on a fixed-size thread pool.
 *
 * <p>An update runs in two phases. Each phase is split into {@code nbThread} tasks, and each task
 * handles the indexes {@code offset, offset + nbThread, offset + 2 * nbThread, ...}:
 * <ol>
 * <li>the sums of the LTUs connected to the active input indexes are updated;</li>
 * <li>every LTU's activation is updated, and active LTUs are switched on in the network output.</li>
 * </ol>
 * The second phase starts only after every task of the first phase has completed.
 *
 * <p>The worker objects and the thread pool are created lazily on the first call to
 * {@link #update}. They are built from the network passed to that first call. These fields are
 * transient, so a deserialized scheduler recreates them on its next update. Call
 * {@link #dispose()} to shut the thread pool down.
 */
public class RandomNetworkScheduler implements Serializable {
    private static final long serialVersionUID = -2515509378000478726L;
    private transient ExecutorService executor;
    private transient LTUSumUpdater[] sumUpdaters;
    private transient LTUActivationUpdater[] activationUpdaters;
    private transient Future<?>[] futures;
    protected final int nbThread;
    /** Input vector of the update in progress; read by the sum workers. */
    BinaryVector obs;
    /** Output vector of the network being updated; written through {@link #setOutputOn(int)}. */
    BinaryVector output;

    /** Second-phase worker: updates the activation of every {@code nbThread}-th LTU. */
    public class LTUActivationUpdater implements Runnable {
        private final int offset;
        private final LTU[] ltus;

        LTUActivationUpdater(RandomNetwork randomNetwork, int i) {
            this.offset = i;
            this.ltus = randomNetwork.ltus;
        }

        /**
         * Visits the LTUs at indexes {@code offset, offset + nbThread, ...} and skips {@code null}
         * slots. Each LTU whose {@code updateActivation()} returns true is switched on in the
         * output.
         */
        @Override
        public void run() {
            final int stride = RandomNetworkScheduler.this.nbThread;
            for (int index = this.offset; index < this.ltus.length; index += stride) {
                LTU ltu = this.ltus[index];
                if (ltu != null && ltu.updateActivation()) {
                    RandomNetworkScheduler.this.setOutputOn(index);
                }
            }
        }
    }

    /** First-phase worker: updates the sums of the LTUs connected to a share of the active inputs. */
    public class LTUSumUpdater implements Runnable {
        private final int offset;
        private final LTUArray[] connectedLTUs;
        private final LTUUpdated updatedLTUs;
        private final double[] denseInputVector;
        private final boolean[] updated;

        LTUSumUpdater(RandomNetwork randomNetwork, int i) {
            this.offset = i;
            this.connectedLTUs = randomNetwork.connectedLTUs;
            this.updatedLTUs = randomNetwork.updatedLTUs;
            this.updated = this.updatedLTUs.updated;
            this.denseInputVector = randomNetwork.denseInputVector;
        }

        /**
         * Processes the active input indexes at positions {@code offset, offset + nbThread, ...}.
         * For each one, it updates the sums of the LTUs connected to that input.
         */
        @Override
        public void run() {
            int[] activeIndexes = RandomNetworkScheduler.this.obs.getActiveIndexes();
            final int stride = RandomNetworkScheduler.this.nbThread;
            for (int i = this.offset; i < activeIndexes.length; i += stride) {
                updateConnectedLTU(this.connectedLTUs[activeIndexes[i]].array());
            }
        }

        /**
         * Updates the sum of each LTU whose {@code updated} flag is not yet set.
         * NOTE(review): presumably {@code LTUUpdated.updateLTUSum} sets that flag and copes with
         * several workers reaching the same LTU concurrently. This is not visible here; confirm.
         */
        private void updateConnectedLTU(LTU[] ltuArr) {
            for (LTU ltu : ltuArr) {
                int index = ltu.index();
                if (!this.updated[index]) {
                    this.updatedLTUs.updateLTUSum(index, ltu, this.denseInputVector);
                }
            }
        }
    }

    /** Creates a scheduler using {@code Scheduling.getDefaultNbThreads()} threads. */
    public RandomNetworkScheduler() {
        this(Scheduling.getDefaultNbThreads());
    }

    /**
     * Creates a scheduler that splits each update phase into {@code i} tasks on a pool of
     * {@code i} threads.
     *
     * @param i number of worker threads (and tasks per phase)
     */
    public RandomNetworkScheduler(int i) {
        this.executor = null;
        this.nbThread = i;
    }

    /** Builds one sum worker and one activation worker per thread, then starts the thread pool. */
    private void initialize(RandomNetwork randomNetwork) {
        this.sumUpdaters = new LTUSumUpdater[this.nbThread];
        this.activationUpdaters = new LTUActivationUpdater[this.nbThread];
        for (int i = 0; i < this.nbThread; i++) {
            this.sumUpdaters[i] = new LTUSumUpdater(randomNetwork, i);
            this.activationUpdaters[i] = new LTUActivationUpdater(randomNetwork, i);
        }
        this.futures = new Future<?>[this.nbThread];
        this.executor = Scheduling.newFixedThreadPool("randomnetwork", this.nbThread);
    }

    /**
     * Runs both update phases for {@code randomNetwork} on input {@code binaryVector}. The call
     * blocks until both phases have completed.
     *
     * <p>NOTE(review): the workers are created from the network passed on the first call, or the
     * first call after {@link #dispose()}. Passing a different network later still updates the
     * first network's LTUs, while output bits are set in the new network's output. Confirm that
     * callers use one scheduler per network.
     *
     * @param randomNetwork network to update; its {@code output} vector receives the active LTUs
     * @param binaryVector input whose active indexes drive the sum phase
     * @throws RuntimeException wrapping the cause if a worker task throws
     */
    public void update(RandomNetwork randomNetwork, BinaryVector binaryVector) {
        if (this.executor == null) {
            initialize(randomNetwork);
        }
        this.obs = binaryVector;
        this.output = randomNetwork.output;
        runPhase(this.sumUpdaters);
        runPhase(this.activationUpdaters);
    }

    /** Submits one task per worker of a phase and blocks until all of them have completed. */
    private void runPhase(Runnable[] tasks) {
        for (int i = 0; i < tasks.length; i++) {
            this.futures[i] = this.executor.submit(tasks[i]);
        }
        waitWorkingThread();
    }

    /**
     * Waits for every submitted task of the current phase to complete.
     *
     * <p>An interrupt does not cut the wait short. Returning early would let the next phase start
     * while tasks of this phase are still running. Instead, the interrupt is recorded, and the
     * thread's interrupt status is restored before returning.
     *
     * @throws RuntimeException wrapping the cause if a task threw
     */
    private void waitWorkingThread() {
        boolean interrupted = false;
        try {
            for (Future<?> future : this.futures) {
                while (true) {
                    try {
                        future.get();
                        break;
                    } catch (InterruptedException e) {
                        interrupted = true;
                    }
                }
            }
        } catch (ExecutionException e) {
            throw new RuntimeException(e.getCause());
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Switches output bit {@code i} on. Synchronized because several activation workers call it. */
    final synchronized void setOutputOn(int i) {
        this.output.setOn(i);
    }

    /**
     * Shuts down the thread pool, if one was started. Safe to call before any update and safe to
     * call more than once. A later {@link #update} recreates the workers and the pool.
     */
    public void dispose() {
        if (this.executor == null) {
            return;
        }
        this.executor.shutdown();
        this.executor = null;
    }
}
