package rlpark.plugin.rltoys.problems.stategraph;

import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.linear.Array2DRowRealMatrix;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.RealMatrix;
import rlpark.plugin.rltoys.algorithms.functions.stateactions.StateToStateAction;
import rlpark.plugin.rltoys.envio.actions.Action;
import rlpark.plugin.rltoys.envio.policy.Policy;
import rlpark.plugin.rltoys.math.vector.RealVector;
import rlpark.plugin.rltoys.math.vector.implementations.PVector;
import rlpark.plugin.rltoys.problems.stategraph.FiniteStateGraph;

/**
 * Tabular (one-hot) feature representation of the states of a {@link FiniteStateGraph}.
 * <p>
 * Every non-absorbing state (one for which {@code hasNextState()} is true) gets a feature index,
 * assigned in the order the states appear in {@code graph.states()}. Absorbing states map to the
 * all-zero feature vector. The class can also compute, for a given policy, the linear TD(lambda)
 * solution of the graph's value function ({@link #computeSolution(Policy, double, double)}).
 */
public class FSGAgentState implements StateToStateAction {
    private static final long serialVersionUID = -6312948577339609928L;
    /** Feature index of every non-absorbing state, in {@code graph.states()} order. */
    private final Map<GraphState, Integer> stateIndexes;
    private final FiniteStateGraph graph;
    /** Number of state features, i.e. the number of non-absorbing states. */
    public final int size;
    /** One-hot feature vector of the current state, updated incrementally by {@link #step()}. */
    private final PVector featureState;

    public FSGAgentState(FiniteStateGraph finiteStateGraph) {
        this.graph = finiteStateGraph;
        this.stateIndexes = indexStates(finiteStateGraph.states());
        // These depend on stateIndexes. Field initializers run before the constructor body,
        // so initializing them at their declaration would dereference a null map.
        this.size = nbNonAbsorbingState();
        this.featureState = new PVector(this.size);
    }

    /** Assigns consecutive indexes, in array order, to the states that have a next state. */
    private static Map<GraphState, Integer> indexStates(GraphState[] graphStates) {
        Map<GraphState, Integer> indexes = new LinkedHashMap<GraphState, Integer>();
        int next = 0;
        for (GraphState graphState : graphStates) {
            if (graphState.hasNextState()) {
                indexes.put(graphState, Integer.valueOf(next));
                next++;
            }
        }
        return indexes;
    }

    /**
     * Advances the underlying graph by one step and updates the current feature vector.
     * The bit of the previous state is cleared and the bit of the new state is set; absorbing
     * (or null) states have no bit.
     *
     * @return the step data returned by the graph
     */
    public FiniteStateGraph.StepData step() {
        FiniteStateGraph.StepData step = this.graph.step();
        if (step.s_t != null && step.s_t.hasNextState()) {
            this.featureState.data[this.stateIndexes.get(step.s_t).intValue()] = 0.0d;
        }
        if (step.s_tp1 != null && step.s_tp1.hasNextState()) {
            this.featureState.data[this.stateIndexes.get(step.s_tp1).intValue()] = 1.0d;
        }
        return step;
    }

    /**
     * Returns the feature vector of the graph's current state. This is a fresh zero vector when
     * the graph has no current state. Otherwise it is the internal vector maintained by
     * {@link #step()}; callers must not modify it.
     */
    public PVector currentFeatureState() {
        return this.graph.currentState() == null ? new PVector(this.size) : this.featureState;
    }

    /** Builds an {@code n x n} identity matrix. */
    private RealMatrix createIdentityMatrix(int n) {
        RealMatrix identity = new Array2DRowRealMatrix(n, n);
        for (int i = 0; i < n; i++) {
            identity.setEntry(i, i, 1.0d);
        }
        return identity;
    }

    /**
     * Builds the feature matrix Phi. It has one row per graph state (in {@code graph.states()}
     * order) and one column per non-absorbing state. Rows of absorbing states are zero.
     */
    public RealMatrix createPhi() {
        GraphState[] states = states();
        RealMatrix phi = new Array2DRowRealMatrix(nbStates(), nbNonAbsorbingState());
        for (int i = 0; i < nbStates(); i++) {
            // featureState() uses the precomputed index instead of rescanning all states.
            phi.setRow(i, featureState(states[i]).data);
        }
        return phi;
    }

    /**
     * Computes the weight vector theta = -A^-1 b of the linear TD(lambda) solution for
     * {@code policy}, where:
     * <ul>
     * <li>R = (I - gamma*lambda*P)^-1</li>
     * <li>P_lambda = (1 - lambda) R P</li>
     * <li>A = Phi^T D (gamma P_lambda - I) Phi</li>
     * <li>b = Phi^T D R r</li>
     * </ul>
     * Here P is the policy's state transition matrix, r the expected one-step reward of each
     * state, and D the diagonal matrix of normalized expected visits to non-absorbing states,
     * starting from the graph's initial state.
     *
     * @param policy the policy whose transition probabilities are used
     * @param gamma  the discount factor
     * @param lambda the trace-decay parameter
     * @return theta, with one entry per non-absorbing state
     * @throws IllegalStateException if the graph's initial state is absorbing
     */
    public double[] computeSolution(Policy policy, double gamma, double lambda) {
        RealMatrix phi = createPhi();
        RealMatrix p = createTransitionProbablityMatrix(policy);
        RealMatrix d = createStateDistributionMatrix(createStateDistribution(p));
        // (I - gamma*lambda*P)^-1 is needed for both P_lambda and b, so invert it only once.
        RealMatrix resolvent = invertIdMinusGammaLambdaP(p, gamma, lambda);
        RealMatrix pLambda = computePLambda(resolvent, p, lambda);
        ArrayRealVector r = computeAverageReward(p);
        RealMatrix a = computeA(phi, d, gamma, pLambda);
        double[] b = computeB(phi, d, resolvent, r);
        // Solving A x = b is cheaper and numerically safer than forming A^-1 explicitly.
        double[] theta = new LUDecompositionImpl(a).getSolver().solve(b);
        for (int i = 0; i < theta.length; i++) {
            theta[i] = -theta[i];
        }
        return theta;
    }

    /** Computes b = Phi^T D R r, where R = (I - gamma*lambda*P)^-1. */
    private double[] computeB(RealMatrix phi, RealMatrix d, RealMatrix resolvent, ArrayRealVector r) {
        return phi.transpose().operate(d.operate(resolvent.operate(r.getData())));
    }

    /** Computes A = Phi^T D (gamma P_lambda - I) Phi. */
    private RealMatrix computeA(RealMatrix phi, RealMatrix d, double gamma, RealMatrix pLambda) {
        RealMatrix inner = pLambda.scalarMultiply(gamma).subtract(createIdentityMatrix(phi.getRowDimension()));
        return phi.transpose().multiply(d.multiply(inner.multiply(phi)));
    }

    /**
     * Computes, for every non-absorbing state i, sum_j P[i][j] * reward(j): the reward of the
     * state being moved into, weighted by the transition probability. Entries of absorbing
     * states are left at 0.
     */
    private ArrayRealVector computeAverageReward(RealMatrix p) {
        GraphState[] states = states();
        ArrayRealVector expectedReward = new ArrayRealVector(p.getColumnDimension());
        for (int i = 0; i < nbStates(); i++) {
            if (!states[i].hasNextState()) {
                continue;
            }
            double sum = 0.0d;
            for (int j = 0; j < nbStates(); j++) {
                sum += p.getEntry(i, j) * states[j].reward.doubleValue();
            }
            expectedReward.setEntry(i, sum);
        }
        return expectedReward;
    }

    /** Computes P_lambda = (1 - lambda) R P, given R = (I - gamma*lambda*P)^-1. */
    private RealMatrix computePLambda(RealMatrix resolvent, RealMatrix p, double lambda) {
        return resolvent.multiply(p).scalarMultiply(1.0d - lambda);
    }

    /** Returns the inverse of (I - gamma*lambda*P). */
    private RealMatrix invertIdMinusGammaLambdaP(RealMatrix p, double gamma, double lambda) {
        RealMatrix idMinusGammaLambdaP =
                createIdentityMatrix(p.getColumnDimension()).subtract(p.scalarMultiply(lambda * gamma));
        return new LUDecompositionImpl(idMinusGammaLambdaP).getSolver().getInverse();
    }

    /**
     * Spreads a distribution over non-absorbing states onto the diagonal of an
     * {@code nbStates x nbStates} matrix. Diagonal entries of absorbing states stay 0.
     */
    private RealMatrix createStateDistributionMatrix(ArrayRealVector distribution) {
        GraphState[] states = states();
        RealMatrix d = new Array2DRowRealMatrix(nbStates(), nbStates());
        int next = 0;
        for (int i = 0; i < nbStates(); i++) {
            if (states[i].hasNextState()) {
                d.setEntry(i, i, distribution.getEntry(next));
                next++;
            }
        }
        return d;
    }

    /**
     * Computes the normalized expected number of visits to each non-absorbing state for an
     * episode that starts in the graph's initial state: mu0^T (I - Q)^-1, rescaled to sum to 1.
     * Q is P restricted to non-absorbing states.
     *
     * @throws IllegalStateException if no non-absorbing state is ever visited (e.g. the initial
     *                               state is absorbing), since the result cannot be normalized
     */
    private ArrayRealVector createStateDistribution(RealMatrix p) {
        int[] nonAbsorbing = nonAbsorbingStateIndexes();
        RealMatrix q = p.getSubMatrix(nonAbsorbing, nonAbsorbing);
        RealMatrix idMinusQ = createIdentityMatrix(q.getRowDimension()).subtract(q);
        // The row vector mu0^T (I - Q)^-1 is the solution x of (I - Q)^T x = mu0.
        double[] visits = new LUDecompositionImpl(idMinusQ.transpose()).getSolver()
                .solve(createInitialStateDistribution());
        double total = 0.0d;
        for (double visit : visits) {
            total += visit;
        }
        if (total == 0.0d) {
            throw new IllegalStateException(
                    "no non-absorbing state is visited from the initial state; cannot normalize");
        }
        for (int i = 0; i < visits.length; i++) {
            visits[i] /= total;
        }
        return new ArrayRealVector(visits, false);
    }

    /** Returns the positions, in {@code graph.states()}, of the non-absorbing states, in order. */
    private int[] nonAbsorbingStateIndexes() {
        GraphState[] states = states();
        int[] indexes = new int[nbNonAbsorbingState()];
        int next = 0;
        for (int i = 0; i < nbStates(); i++) {
            if (states[i].hasNextState()) {
                indexes[next] = i;
                next++;
            }
        }
        return indexes;
    }

    private int nbNonAbsorbingState() {
        return this.stateIndexes.size();
    }

    /**
     * Builds the initial distribution over non-absorbing states. It has 1 at the initial state's
     * feature index, or is all zeros if the initial state is absorbing.
     */
    private double[] createInitialStateDistribution() {
        double[] mu0 = new double[nbNonAbsorbingState()];
        Integer initialIndex = this.stateIndexes.get(this.graph.initialState());
        if (initialIndex != null) {
            mu0[initialIndex.intValue()] = 1.0d;
        }
        return mu0;
    }

    /**
     * Builds the {@code nbStates x nbStates} transition matrix induced by {@code policy}.
     * P[i][j] is the total probability of the actions that move state i to state j. Absorbing
     * states get a self-loop of probability 1.
     */
    private RealMatrix createTransitionProbablityMatrix(Policy policy) {
        GraphState[] states = states();
        Action[] actions = this.graph.actions();
        RealMatrix p = new Array2DRowRealMatrix(nbStates(), nbStates());
        for (int i = 0; i < nbStates(); i++) {
            GraphState state = states[i];
            policy.update(state.v());
            for (Action action : actions) {
                double pi = policy.pi(action);
                GraphState nextState = state.nextState(action);
                if (nextState != null) {
                    // Accumulate: several actions may lead to the same successor. setEntry would
                    // drop probability mass so that the row no longer sums to 1.
                    p.addToEntry(i, this.graph.indexOf(nextState), pi);
                }
            }
            if (!state.hasNextState()) {
                p.setEntry(i, i, 1.0d);
            }
        }
        return p;
    }

    private int nbStates() {
        return this.graph.nbStates();
    }

    private GraphState[] states() {
        return this.graph.states();
    }

    /** Returns the feature index of every non-absorbing state (the internal map, not a copy). */
    public Map<GraphState, Integer> stateIndexes() {
        return this.stateIndexes;
    }

    public FiniteStateGraph graph() {
        return this.graph;
    }

    /**
     * Returns a new one-hot feature vector for {@code graphState}. The vector is all zeros when
     * the state is null or absorbing.
     */
    public PVector featureState(GraphState graphState) {
        PVector pVector = new PVector(this.size);
        if (graphState != null && graphState.hasNextState()) {
            pVector.data[this.stateIndexes.get(graphState).intValue()] = 1.0d;
        }
        return pVector;
    }

    /**
     * Builds a tabular state-action vector of length {@code nbNonAbsorbingStates * nbActions}.
     * It has a single 1 at {@code actionIndex * nbNonAbsorbingStates + stateIndex}.
     *
     * @return the all-zero vector if {@code realVector} is null or maps to a state that has no
     *         feature index (e.g. an absorbing state); {@code null} if {@code action} is not
     *         one of the graph's actions (compared by identity)
     */
    @Override
    public PVector stateAction(RealVector realVector, Action action) {
        int nbFeatures = nbNonAbsorbingState();
        Action[] actions = this.graph.actions();
        PVector stateActionVector = new PVector(nbFeatures * actions.length);
        if (realVector == null) {
            return stateActionVector;
        }
        GraphState state = this.graph.state(realVector);
        for (int i = 0; i < actions.length; i++) {
            if (actions[i] == action) {
                Integer stateIndex = this.stateIndexes.get(state);
                // Absorbing states have no feature index. Return zeros, as featureState() does,
                // instead of failing on unboxing a null.
                if (stateIndex != null) {
                    stateActionVector.setEntry(i * nbFeatures + stateIndex.intValue(), 1.0d);
                }
                return stateActionVector;
            }
        }
        return null;
    }

    /**
     * Returns {@link #size}, the number of state features.
     * NOTE(review): stateAction() returns vectors of length size * graph.actions().length.
     * Confirm whether callers of vectorSize() expect the state-action length instead.
     */
    @Override
    public int vectorSize() {
        return this.size;
    }

    /** One-hot vectors always have a single non-zero entry equal to 1. */
    @Override
    public double vectorNorm() {
        return 1.0d;
    }
}
