/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.random;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.stream.Collectors;
import org.opensha.commons.calc.WeightedSampler;
import org.opensha.commons.logicTree.Affects;
import org.opensha.commons.logicTree.DoesNotAffect;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.logicTree.LogicTreeLevel;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.modules.ClusterRuptures;
import org.opensha.sha.earthquake.faultSysSolution.modules.RuptureSubSetMappings;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.ClusterRupture;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.FaultSubsectionCluster;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.Jump;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.plausibility.impl.prob.JumpProbabilityCalc;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.plausibility.impl.prob.RuptureProbabilityCalc;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.NSHM23_ConstraintBuilder;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.NSHM23_SegmentationModels;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.SegmentationModelBranchNode;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.SupraSeisBValues;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.random.AbstractSamplingNode;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.random.BranchDependentSampler;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.random.BranchSamplingManager;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.random.RandomBValSampler;
import org.opensha.sha.faultSurface.FaultSection;

/**
 * Randomly samples one NSHM23 segmentation model for each unique jump in a rupture set and exposes
 * the result both as a {@link JumpProbabilityCalc} and as a rupture exclusion model.
 *
 * <p>Jumps touching the creeping section's parent are sampled according to a
 * {@link CreepingTreatment}. Whether ruptures may pass through the creeping section, and the
 * maximum rupture length, are each drawn once per sampler.
 */
public class RandomSegModelSampler
implements BranchDependentSampler<RandomSegModelSampler> {
    /** Sampled segmentation model for every unique jump; not modified after construction. */
    private Map<UniqueJump, NSHM23_SegmentationModels> jumpsToModels;
    /** Whether ruptures through the creeping section are allowed (drawn once per sampler). */
    private boolean allowThroughCreeping;
    /** Maximum rupture length; only applied as an exclusion when positive and finite. */
    private double maxRupLength;
    /** Lazily built by {@link #getModel}; the rupture set and branch of the first call are used. */
    private transient JumpProbabilityCalc probCalc;
    /** Lazily built by {@link #getExclusionModel}; may legitimately be null (no exclusions). */
    private transient RuptureProbabilityCalc.BinaryRuptureProbabilityCalc exclusionModel;
    /** True once {@link #exclusionModel} has been built, so that a null result is cached as well. */
    private transient boolean exclusionModelBuilt;
    private static final CreepingTreatment CREEPING_TREATMENT_DEFAULT = CreepingTreatment.SAMPLE_WITHIN_INCLUSION_CHOICE;

    /**
     * Draws a segmentation model for every unique jump found in the rupture set's
     * {@link ClusterRuptures} module.
     *
     * <p>Draw order: the creeping inclusion choice, then the max rupture length, then one draw
     * per unique jump in sorted order. Keeping this order fixed makes a given seed reproducible.
     *
     * @param rupSet rupture set; must provide the {@link ClusterRuptures} module
     * @param enumSampler sampler for non-creeping jumps, the creeping inclusion choice and the max rupture length
     * @param creepingTreatment how jumps that touch the creeping section's parent are sampled
     * @param creepingAllowedSampler sampler for creeping jumps when through-creeping ruptures are allowed
     *        (required for {@link CreepingTreatment#SAMPLE_WITHIN_INCLUSION_CHOICE})
     * @param creepingExcludedSampler sampler for creeping jumps when through-creeping ruptures are excluded
     *        (required for {@link CreepingTreatment#SAMPLE_WITHIN_INCLUSION_CHOICE})
     * @throws NullPointerException if a creeping sampler required by the treatment is null
     */
    public RandomSegModelSampler(FaultSystemRupSet rupSet, WeightedSampler<NSHM23_SegmentationModels> enumSampler, CreepingTreatment creepingTreatment, WeightedSampler<NSHM23_SegmentationModels> creepingAllowedSampler, WeightedSampler<NSHM23_SegmentationModels> creepingExcludedSampler) {
        HashSet<UniqueJump> allJumps = new HashSet<UniqueJump>();
        int creepingParentID = -1;
        if (creepingTreatment != CreepingTreatment.FULLY_UNCORRELATED) {
            creepingParentID = NSHM23_ConstraintBuilder.findCreepingSection(rupSet);
            if (creepingTreatment == CreepingTreatment.SAMPLE_WITHIN_INCLUSION_CHOICE) {
                Preconditions.checkNotNull(creepingAllowedSampler);
                Preconditions.checkNotNull(creepingExcludedSampler);
            }
        }
        for (ClusterRupture rup : rupSet.requireModule(ClusterRuptures.class)) {
            for (Jump jump : rup.getJumpsIterable()) {
                allJumps.add(new UniqueJump(jump));
            }
        }
        // sort so that the sequence of random draws is independent of HashSet iteration order
        List<UniqueJump> allJumpsList = new ArrayList<UniqueJump>(allJumps);
        Collections.sort(allJumpsList);
        this.jumpsToModels = new HashMap<UniqueJump, NSHM23_SegmentationModels>(allJumpsList.size());
        NSHM23_SegmentationModels creepingModel = enumSampler.nextItem();
        this.allowThroughCreeping = creepingModel.isIncludeRupturesThroughCreepingSect();
        this.maxRupLength = enumSampler.nextItem().getMaxRuptureLength();
        for (UniqueJump uniqueJump : allJumpsList) {
            // a negative ID means no creeping section was found (or it is not being tracked)
            boolean touchesCreeping = creepingParentID >= 0
                    && (uniqueJump.parent1 == creepingParentID || uniqueJump.parent2 == creepingParentID);
            NSHM23_SegmentationModels model;
            if (touchesCreeping) {
                model = sampleCreepingJumpModel(creepingTreatment, this.allowThroughCreeping, creepingModel,
                        enumSampler, creepingAllowedSampler, creepingExcludedSampler);
            } else {
                model = enumSampler.nextItem();
            }
            this.jumpsToModels.put(uniqueJump, model);
        }
        Preconditions.checkState(this.jumpsToModels.size() == allJumpsList.size());
    }

    /** Chooses the segmentation model for a single jump that touches the creeping section's parent. */
    private static NSHM23_SegmentationModels sampleCreepingJumpModel(CreepingTreatment creepingTreatment,
            boolean allowThroughCreeping, NSHM23_SegmentationModels creepingModel,
            WeightedSampler<NSHM23_SegmentationModels> enumSampler,
            WeightedSampler<NSHM23_SegmentationModels> creepingAllowedSampler,
            WeightedSampler<NSHM23_SegmentationModels> creepingExcludedSampler) {
        switch (creepingTreatment) {
            case FULLY_CORRELATED:
                // every creeping jump reuses the model that determined the inclusion choice
                return creepingModel;
            case FULLY_UNCORRELATED:
                return enumSampler.nextItem();
            case SAMPLE_WITHIN_INCLUSION_CHOICE:
                // sample only among models consistent with the already-drawn inclusion choice
                return allowThroughCreeping ? creepingAllowedSampler.nextItem() : creepingExcludedSampler.nextItem();
            default:
                throw new IllegalStateException("Unexpected creeping treatment: " + creepingTreatment);
        }
    }

    /** Copy constructor used for rupture subsets; shares the (never mutated) jump-to-model map. */
    private RandomSegModelSampler(Map<UniqueJump, NSHM23_SegmentationModels> jumpsToModels, boolean allowThroughCreeping, double maxRupLength) {
        this.jumpsToModels = jumpsToModels;
        this.allowThroughCreeping = allowThroughCreeping;
        this.maxRupLength = maxRupLength;
    }

    /**
     * Returns the jump probability calculator that defers each jump to its sampled model. It is
     * built on the first call, and later calls return the cached instance.
     */
    public synchronized JumpProbabilityCalc getModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
        if (this.probCalc == null) {
            this.probCalc = new DeferringRandomJumpProbCalc(rupSet, branch, this.jumpsToModels);
        }
        return this.probCalc;
    }

    /**
     * Returns the combined rupture exclusion model, or null if nothing is excluded. The result,
     * including null, is built on the first call and cached.
     */
    public synchronized RuptureProbabilityCalc.BinaryRuptureProbabilityCalc getExclusionModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
        if (!this.exclusionModelBuilt) {
            this.exclusionModel = this.buildExclusionModel(rupSet, branch);
            this.exclusionModelBuilt = true;
        }
        return this.exclusionModel;
    }

    /** Assembles all applicable exclusion models and combines them (see {@link #combineExclusions}). */
    private RuptureProbabilityCalc.BinaryRuptureProbabilityCalc buildExclusionModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
        List<RuptureProbabilityCalc.BinaryRuptureProbabilityCalc> exclusions = new ArrayList<RuptureProbabilityCalc.BinaryRuptureProbabilityCalc>();
        JumpProbabilityCalc.BinaryJumpProbabilityCalc primaryExclusion = SegmentationModelBranchNode.buildJumpExclusionModel(rupSet, this.getModel(rupSet, branch));
        if (primaryExclusion != null) {
            exclusions.add(primaryExclusion);
        }
        if (!this.allowThroughCreeping) {
            int creepingSectID = NSHM23_ConstraintBuilder.findCreepingSection(rupSet);
            if (creepingSectID >= 0) {
                exclusions.add(new NSHM23_SegmentationModels.ExcludeRupsThroughCreepingSegmentationModel(creepingSectID));
            }
        }
        if (this.maxRupLength > 0.0 && Double.isFinite(this.maxRupLength)) {
            exclusions.add(new NSHM23_SegmentationModels.MaxLengthSegmentationModel(this.maxRupLength));
        }
        if (branch.hasValue(SupraSeisBValues.B_0p0) || branch.hasValue(RandomBValSampler.Node.class)) {
            // the b=0 model needs to know which parents stay connected after all other exclusions
            RuptureProbabilityCalc.BinaryRuptureProbabilityCalc upstreamExclusion = combineExclusions(exclusions);
            double[] bValues;
            if (branch.hasValue(RandomBValSampler.Node.class)) {
                RandomBValSampler bSampler = (RandomBValSampler)rupSet.requireModule(BranchSamplingManager.class).getSampler(branch.requireValue(RandomBValSampler.Node.class));
                bValues = bSampler.getBValues();
            } else {
                // B_0p0 branch: every section has b=0
                bValues = new double[rupSet.getNumSections()];
            }
            exclusions.add(new B0_UnconnectedFullSectionsSegmentationModel(rupSet, bValues, upstreamExclusion));
        }
        if (exclusions.size() > 1) {
            System.out.println("Combining " + exclusions.size() + " exclusion models");
        }
        return combineExclusions(exclusions);
    }

    /** Returns null for no exclusions, the sole exclusion for one, or a logical AND of all of them. */
    private static RuptureProbabilityCalc.BinaryRuptureProbabilityCalc combineExclusions(List<RuptureProbabilityCalc.BinaryRuptureProbabilityCalc> exclusions) {
        if (exclusions.isEmpty()) {
            return null;
        }
        if (exclusions.size() == 1) {
            return exclusions.get(0);
        }
        return new RuptureProbabilityCalc.LogicalAnd(exclusions.toArray(new RuptureProbabilityCalc.BinaryRuptureProbabilityCalc[0]));
    }

    @Override
    public RandomSegModelSampler getForRuptureSubSet(FaultSystemRupSet rupSubSet, RuptureSubSetMappings mappings) {
        return new RandomSegModelSampler(this.jumpsToModels, this.allowThroughCreeping, this.maxRupLength);
    }

    /** How segmentation models are sampled for jumps that touch the creeping section's parent. */
    public static enum CreepingTreatment {
        /** Creeping jumps are sampled exactly like every other jump. */
        FULLY_UNCORRELATED,
        /** Every creeping jump uses the model that determined the through-creeping inclusion choice. */
        FULLY_CORRELATED,
        /** Creeping jumps are sampled only among models consistent with the inclusion choice. */
        SAMPLE_WITHIN_INCLUSION_CHOICE;

    }

    /**
     * Direction-independent identity of a jump. The two parent IDs are ordered ascending, and the
     * subsection names and jump distance distinguish multiple jumps between the same parents.
     */
    private static class UniqueJump
    implements Comparable<UniqueJump> {
        private final int parent1;
        private final String name1;
        private final int parent2;
        private final String name2;
        private final double distance;

        /**
         * @param jump jump to normalize
         * @throws IllegalStateException if both ends of the jump are on the same parent section
         */
        public UniqueJump(Jump jump) {
            if (jump.fromCluster.parentSectionID < jump.toCluster.parentSectionID) {
                this.parent1 = jump.fromCluster.parentSectionID;
                this.parent2 = jump.toCluster.parentSectionID;
                this.name1 = jump.fromSection.getSectionName();
                this.name2 = jump.toSection.getSectionName();
            } else {
                this.parent2 = jump.fromCluster.parentSectionID;
                this.parent1 = jump.toCluster.parentSectionID;
                this.name2 = jump.fromSection.getSectionName();
                this.name1 = jump.toSection.getSectionName();
            }
            Preconditions.checkState(this.parent1 != this.parent2, "Can't jump to yourself");
            this.distance = jump.distance;
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.distance, this.name1, this.name2, this.parent1, this.parent2);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            UniqueJump other = (UniqueJump)obj;
            return Double.doubleToLongBits(this.distance) == Double.doubleToLongBits(other.distance) && Objects.equals(this.name1, other.name1) && Objects.equals(this.name2, other.name2) && this.parent1 == other.parent1 && this.parent2 == other.parent2;
        }

        /** Orders by parent IDs, then distance, then subsection names (consistent with equals). */
        @Override
        public int compareTo(UniqueJump o) {
            int cmp = Integer.compare(this.parent1, o.parent1);
            if (cmp == 0) {
                cmp = Integer.compare(this.parent2, o.parent2);
            }
            if (cmp == 0) {
                cmp = Double.compare(this.distance, o.distance);
            }
            if (cmp == 0) {
                cmp = this.name1.compareTo(o.name1);
            }
            if (cmp == 0) {
                cmp = this.name2.compareTo(o.name2);
            }
            return cmp;
        }
    }

    /**
     * Jump probability calculator that looks up each jump's sampled segmentation model and
     * delegates to that model's calculator. A model without a calculator yields probability 1.
     */
    private static class DeferringRandomJumpProbCalc
    implements JumpProbabilityCalc {
        /** Calculator per sampled model; values may be null when a model provides no calculator. */
        private EnumMap<NSHM23_SegmentationModels, JumpProbabilityCalc> modelCalcs;
        private Map<UniqueJump, NSHM23_SegmentationModels> jumpsToModels;

        public DeferringRandomJumpProbCalc(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch, Map<UniqueJump, NSHM23_SegmentationModels> jumpsToModels) {
            this.jumpsToModels = jumpsToModels;
            this.modelCalcs = new EnumMap<NSHM23_SegmentationModels, JumpProbabilityCalc>(NSHM23_SegmentationModels.class);
            for (NSHM23_SegmentationModels model : jumpsToModels.values()) {
                if (!this.modelCalcs.containsKey(model)) {
                    // may store null; calcJumpProbability treats that as "no penalty"
                    this.modelCalcs.put(model, model.getModel(rupSet, branch));
                }
            }
        }

        /** True if any non-null underlying calculator is directional. */
        @Override
        public boolean isDirectional(boolean splayed) {
            for (JumpProbabilityCalc calc : this.modelCalcs.values()) {
                // skip models without a calculator (null values are expected here)
                if (calc != null && calc.isDirectional(splayed)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String getName() {
            return "Deferring Randomly Sampled Jump Prob Calc";
        }

        /**
         * @throws NullPointerException if the jump was not present when the models were sampled
         */
        @Override
        public double calcJumpProbability(ClusterRupture fullRupture, Jump jump, boolean verbose) {
            UniqueJump unique = new UniqueJump(jump);
            NSHM23_SegmentationModels model = this.jumpsToModels.get(unique);
            if (verbose) {
                System.out.println("Mapped jump " + jump + " to seg model " + model);
            }
            Preconditions.checkNotNull(model, "Jump not found: %s", jump);
            JumpProbabilityCalc calc = this.modelCalcs.get(model);
            if (calc == null) {
                Preconditions.checkState(this.modelCalcs.containsKey(model), "Calc never instantiated for model %s (jump %s)?", model, jump);
                return 1.0;
            }
            return calc.calcJumpProbability(fullRupture, jump, verbose);
        }
    }

    /**
     * Exclusion model for sections with b=0. A rupture confined to one parent section is allowed
     * when any of these hold:
     * <ul>
     * <li>that parent takes part in an allowed multi-fault rupture;</li>
     * <li>any of the rupture's subsections has a non-zero b-value;</li>
     * <li>the rupture spans every subsection of the parent.</li>
     * </ul>
     * Multi-fault ruptures are always allowed by this model.
     */
    public static class B0_UnconnectedFullSectionsSegmentationModel
    implements RuptureProbabilityCalc.BinaryRuptureProbabilityCalc {
        private final Map<Integer, List<FaultSection>> parentSectsMap;
        /** Parent IDs that participate in at least one allowed multi-cluster rupture. */
        private final HashSet<Integer> connectedSects;
        /** b-value per section, indexed by section ID. */
        private final double[] bValues;

        /**
         * @param rupSet rupture set; must provide the {@link ClusterRuptures} module
         * @param bValues b-value per section, indexed by section ID
         * @param exclusion upstream exclusion deciding which multi-fault ruptures count as connecting;
         *        null means every multi-fault rupture is allowed
         */
        public B0_UnconnectedFullSectionsSegmentationModel(FaultSystemRupSet rupSet, double[] bValues, RuptureProbabilityCalc.BinaryRuptureProbabilityCalc exclusion) {
            this.bValues = bValues;
            this.parentSectsMap = rupSet.getFaultSectionDataList().stream().collect(Collectors.groupingBy(sect -> sect.getParentSectionId()));
            this.connectedSects = new HashSet<Integer>();
            for (ClusterRupture rup : rupSet.requireModule(ClusterRuptures.class)) {
                if (rup.getTotalNumClusters() == 1) {
                    continue;
                }
                // with no upstream exclusion every multi-fault rupture is allowed, so its parents
                // are connected; skipping the scan would wrongly mark every parent as unconnected
                if (exclusion != null && !exclusion.isRupAllowed(rup, false)) {
                    continue;
                }
                for (FaultSubsectionCluster cluster : rup.getClustersIterable()) {
                    this.connectedSects.add(cluster.parentSectionID);
                }
            }
        }

        @Override
        public boolean isDirectional(boolean splayed) {
            return false;
        }

        @Override
        public String getName() {
            return "Full Section Segmentation";
        }

        @Override
        public boolean isRupAllowed(ClusterRupture fullRupture, boolean verbose) {
            if (fullRupture.getTotalNumClusters() > 1) {
                return true;
            }
            FaultSubsectionCluster cluster = fullRupture.clusters[0];
            int parentID = cluster.parentSectionID;
            if (this.connectedSects.contains(parentID)) {
                return true;
            }
            for (FaultSection sect : cluster.subSects) {
                if (this.bValues[sect.getSectionId()] != 0.0) {
                    return true;
                }
            }
            // unconnected and all b=0: only a rupture of the full parent section is allowed
            List<FaultSection> fullCluster = this.parentSectsMap.get(parentID);
            return fullCluster.size() == cluster.subSects.size();
        }
    }

    /** Logic tree level whose nodes are random segmentation model samples. */
    public static class Level
    extends LogicTreeLevel.RandomlySampledLevel<Node> {
        public Level() {
        }

        public Level(int numSamples) {
            this(numSamples, new Random());
        }

        public Level(int numSamples, Random rand) {
            this.buildNodes(rand, numSamples);
        }

        @Override
        public String getShortName() {
            return "SegSamples";
        }

        @Override
        public String getName() {
            return "Segmentation Model Samples";
        }

        @Override
        public Node buildNodeInstance(int index, long seed, double weight) {
            return new Node(index, seed, weight);
        }

        @Override
        public Class<? extends Node> getType() {
            return Node.class;
        }
    }

    /** A single random segmentation model sample; the actual sampler comes from the {@link BranchSamplingManager}. */
    @DoesNotAffect.NotAffected(value={@DoesNotAffect(value="fault_sections.geojson"), @DoesNotAffect(value="indices.csv"), @DoesNotAffect(value="properties.csv")})
    @Affects(value="rates.csv")
    public static class Node
    extends AbstractSamplingNode<RandomSegModelSampler>
    implements SegmentationModelBranchNode {
        private Node() {
        }

        public Node(int index, long seed, double weight) {
            super("Segmentation Model Sample " + index, "SegSample" + index, "SegSample" + index, weight, seed);
        }

        @Override
        public RandomSegModelSampler buildSampler(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch, long branchNodeSamplingSeed) {
            System.out.println("Building seg model sampler for " + this.getShortName() + " with seed: " + branchNodeSamplingSeed);
            Random random = new Random(branchNodeSamplingSeed);
            WeightedSampler<NSHM23_SegmentationModels> enumSampler = Node.weightedNodeValueSampler(random, NSHM23_SegmentationModels.class);
            CreepingTreatment creepingTreatment = CREEPING_TREATMENT_DEFAULT;
            WeightedSampler<NSHM23_SegmentationModels> creepingAllowedSampler = null;
            WeightedSampler<NSHM23_SegmentationModels> creepingExcludedSampler = null;
            if (creepingTreatment == CreepingTreatment.SAMPLE_WITHIN_INCLUSION_CHOICE) {
                // split the positively weighted models by their through-creeping choice
                ArrayList<NSHM23_SegmentationModels> allowedModels = new ArrayList<NSHM23_SegmentationModels>();
                ArrayList<Double> allowedWeights = new ArrayList<Double>();
                ArrayList<NSHM23_SegmentationModels> excludedModels = new ArrayList<NSHM23_SegmentationModels>();
                ArrayList<Double> excludedWeights = new ArrayList<Double>();
                for (NSHM23_SegmentationModels model : NSHM23_SegmentationModels.values()) {
                    double weight = model.getNodeWeight(null);
                    if (!(weight > 0.0)) {
                        continue; // also skips NaN weights
                    }
                    if (model.isIncludeRupturesThroughCreepingSect()) {
                        allowedModels.add(model);
                        allowedWeights.add(weight);
                    } else {
                        excludedModels.add(model);
                        excludedWeights.add(weight);
                    }
                }
                creepingAllowedSampler = new WeightedSampler<NSHM23_SegmentationModels>(allowedModels, allowedWeights, random);
                creepingExcludedSampler = new WeightedSampler<NSHM23_SegmentationModels>(excludedModels, excludedWeights, random);
            }
            return new RandomSegModelSampler(rupSet, enumSampler, creepingTreatment, creepingAllowedSampler, creepingExcludedSampler);
        }

        @Override
        public JumpProbabilityCalc getModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
            BranchSamplingManager manager = rupSet.requireModule(BranchSamplingManager.class);
            RandomSegModelSampler sampler = manager.getSampler(this);
            return sampler.getModel(rupSet, branch);
        }

        @Override
        public RuptureProbabilityCalc.BinaryRuptureProbabilityCalc getExclusionModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
            BranchSamplingManager manager = rupSet.requireModule(BranchSamplingManager.class);
            RandomSegModelSampler sampler = manager.getSampler(this);
            return sampler.getExclusionModel(rupSet, branch);
        }
    }
}

