package rlpark.plugin.rltoys.algorithms.functions.policydistributions.structures;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import rlpark.plugin.rltoys.algorithms.functions.policydistributions.BoundedPdf;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyDistribution;
import rlpark.plugin.rltoys.algorithms.functions.policydistributions.PolicyParameterized;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.actions.ActionArray;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import zephyr.plugin.core.api.monitoring.annotations.IgnoreMonitor;
import zephyr.plugin.core.api.monitoring.annotations.Monitor;

/**
 * Joint policy distribution over a multi-dimensional action, built from an array of component
 * {@link PolicyDistribution}s that are treated as independent.
 *
 * <p>The probability of a joint action is the product of the component probabilities. The
 * log-gradient is the concatenation of the component log-gradients. The parameter vectors are
 * the concatenation of the component parameter vectors. All three follow component order.
 */
@Monitor
public class JointDistribution implements PolicyParameterized, BoundedPdf {
    private static final long serialVersionUID = -7545331400083047916L;

    /** Component distributions, in the order their samples are concatenated into a joint action. */
    protected final PolicyDistribution[] distributions;

    /**
     * For each parameter vector returned by {@link #createParameters(int)}, the index of the
     * component that owns it. This is {@code null} until {@link #createParameters(int)} has run.
     */
    @IgnoreMonitor
    private int[] weightsToAction;

    /**
     * @param policyDistributionArr the component distributions. The array is stored as-is; it is
     *     not copied.
     */
    public JointDistribution(PolicyDistribution[] policyDistributionArr) {
        this.distributions = policyDistributionArr;
    }

    /**
     * Returns the product of each component's probability for its slice of {@code action}.
     *
     * <p>NOTE(review): component {@code i} receives {@code ActionArray.getDim(action, i)}. That
     * looks like it assumes every component is one-dimensional, while {@link #sampleAction()}
     * accepts multi-dimensional component samples. Confirm against {@code ActionArray.getDim}.
     */
    @Override
    public double pi(Action action) {
        double probability = 1.0;
        for (int i = 0; i < distributions.length; i++) {
            probability *= distributions[i].pi(ActionArray.getDim(action, i));
        }
        return probability;
    }

    /**
     * Samples every component and concatenates the results, in component order, into one action.
     * Each component's {@code sampleAction()} result is cast to {@link ActionArray}.
     */
    @Override
    public ActionArray sampleAction() {
        ActionArray[] samples = new ActionArray[distributions.length];
        int totalDimension = 0;
        for (int i = 0; i < distributions.length; i++) {
            samples[i] = (ActionArray) distributions[i].sampleAction();
            totalDimension += samples[i].actions.length;
        }
        double[] joint = new double[totalDimension];
        int offset = 0;
        for (ActionArray sample : samples) {
            System.arraycopy(sample.actions, 0, joint, offset, sample.actions.length);
            offset += sample.actions.length;
        }
        return new ActionArray(joint);
    }

    /**
     * Creates the parameter vectors of every component and concatenates them in component order.
     * It also records which component owns each vector; see
     * {@link #weightsIndexToActionIndex(int)}.
     *
     * @param n value forwarded unchanged to each component's {@code createParameters}
     * @return the concatenated parameter vectors
     */
    @Override
    public PVector[] createParameters(int n) {
        List<PVector> parameters = new ArrayList<PVector>();
        int[] componentSizes = new int[distributions.length];
        for (int i = 0; i < distributions.length; i++) {
            PVector[] componentParameters = distributions[i].createParameters(n);
            Collections.addAll(parameters, componentParameters);
            componentSizes[i] = componentParameters.length;
        }
        // Build the owner table directly as int[] to avoid boxing each index into an Integer.
        int[] owners = new int[parameters.size()];
        int offset = 0;
        for (int i = 0; i < componentSizes.length; i++) {
            Arrays.fill(owners, offset, offset + componentSizes[i], i);
            offset += componentSizes[i];
        }
        weightsToAction = owners;
        return parameters.toArray(new PVector[parameters.size()]);
    }

    /**
     * Concatenates, in component order, each component's log-gradient for its slice of
     * {@code action}. The slicing assumption is the same as in {@link #pi(Action)}.
     */
    @Override
    public RealVector[] computeGradLog(Action action) {
        List<RealVector> gradients = new ArrayList<RealVector>();
        for (int i = 0; i < distributions.length; i++) {
            Collections.addAll(gradients, distributions[i].computeGradLog(ActionArray.getDim(action, i)));
        }
        return gradients.toArray(new RealVector[gradients.size()]);
    }

    /**
     * Returns the index of the component that owns parameter vector {@code i}.
     *
     * @param i index into the array returned by {@link #createParameters(int)}
     * @return the owning component index
     * @throws IllegalStateException if {@link #createParameters(int)} has not been called yet
     */
    public int weightsIndexToActionIndex(int i) {
        if (weightsToAction == null) {
            throw new IllegalStateException(
                    "createParameters() must be called before weightsIndexToActionIndex()");
        }
        return weightsToAction[i];
    }

    /** Returns component distribution {@code i}. */
    public PolicyDistribution policy(int i) {
        return distributions[i];
    }

    /** Returns the total number of parameter vectors across all components. */
    @Override
    public int nbParameterVectors() {
        int total = 0;
        for (PolicyDistribution distribution : distributions) {
            total += distribution.nbParameterVectors();
        }
        return total;
    }

    /**
     * Returns the product of the components' {@code piMax()}. Every component must implement
     * {@link BoundedPdf}; otherwise a {@link ClassCastException} is thrown.
     */
    @Override
    public double piMax() {
        double bound = 1.0;
        for (PolicyDistribution distribution : distributions) {
            bound *= ((BoundedPdf) distribution).piMax();
        }
        return bound;
    }

    /** Forwards {@code realVector} to every component's {@code update}, in component order. */
    @Override
    public void update(RealVector realVector) {
        for (PolicyDistribution distribution : distributions) {
            distribution.update(realVector);
        }
    }

    /** Returns the backing array of component distributions. It is not a copy. */
    public PolicyDistribution[] policies() {
        return distributions;
    }

    /**
     * Splits {@code pVectorArr} into consecutive slices and hands each slice to one component.
     * Slice sizes are the components' {@code nbParameterVectors()}, in component order. Every
     * component must implement {@link PolicyParameterized}.
     *
     * @throws IndexOutOfBoundsException if {@code pVectorArr} holds fewer vectors than the
     *     components need in total
     */
    @Override
    public void setParameters(PVector... pVectorArr) {
        int offset = 0;
        for (PolicyDistribution distribution : distributions) {
            PVector[] slice = new PVector[distribution.nbParameterVectors()];
            System.arraycopy(pVectorArr, offset, slice, 0, slice.length);
            ((PolicyParameterized) distribution).setParameters(slice);
            offset += slice.length;
        }
    }

    /**
     * Returns the components' parameters concatenated in component order. From each component,
     * the first {@code nbParameterVectors()} entries of its {@code parameters()} are taken. Every
     * component must implement {@link PolicyParameterized}.
     */
    @Override
    public PVector[] parameters() {
        PVector[] all = new PVector[nbParameterVectors()];
        int offset = 0;
        for (PolicyDistribution distribution : distributions) {
            int count = distribution.nbParameterVectors();
            System.arraycopy(((PolicyParameterized) distribution).parameters(), 0, all, offset, count);
            offset += count;
        }
        return all;
    }
}
