/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.rupForecastImpl.nshm23.util;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.cli.CommandLine;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemSolution;
import org.opensha.sha.earthquake.faultSysSolution.inversion.ClusterSpecificInversionSolver;
import org.opensha.sha.earthquake.faultSysSolution.inversion.InversionConfiguration;
import org.opensha.sha.earthquake.faultSysSolution.inversion.InversionConfigurationFactory;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.InversionConstraint;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl.MFDInversionConstraint;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl.SectionTotalRateConstraint;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl.SlipRateInversionConstraint;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl.SubSectMFDInversionConstraint;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl.UncertainDataConstraint;
import org.opensha.sha.earthquake.faultSysSolution.modules.ConnectivityClusters;
import org.opensha.sha.earthquake.faultSysSolution.modules.InitialSolution;
import org.opensha.sha.earthquake.faultSysSolution.modules.InversionMisfitProgress;
import org.opensha.sha.earthquake.faultSysSolution.modules.InversionMisfitStats;
import org.opensha.sha.earthquake.faultSysSolution.modules.InversionMisfits;
import org.opensha.sha.earthquake.faultSysSolution.modules.PaleoseismicConstraintData;
import org.opensha.sha.earthquake.faultSysSolution.modules.RuptureSubSetMappings;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionLogicTree;
import org.opensha.sha.earthquake.faultSysSolution.modules.WaterLevelRates;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.plausibility.impl.prob.RuptureProbabilityCalc;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.util.ConnectivityCluster;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.NSHM23_ConstraintBuilder;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.NSHM23_InvConfigFactory;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.logicTree.SectionSupraSeisBValues;
import org.opensha.sha.earthquake.rupForecastImpl.nshm23.util.AnalyticalSingleFaultInversionSolver;

/**
 * Cluster-specific inversion solver that only runs the full (simulated-annealing style) inversion
 * for connectivity clusters that need it, and solves every other cluster with
 * {@link AnalyticalSingleFaultInversionSolver}.
 *
 * <p>A cluster is inverted when it spans more than one parent section, contains the ID returned by
 * {@link NSHM23_ConstraintBuilder#findParkfieldSection(FaultSystemRupSet)}, or contains a parent
 * section that carries a paleoseismic rate or slip constraint. The analytical and inverted results
 * are then merged back into a single solution over the original rupture set.
 */
public class ClassicModelInversionSolver extends ClusterSpecificInversionSolver {
    /** Exclusion model used when building sub-rupture-sets and (optionally) by the analytical solver. */
    private final RuptureProbabilityCalc.BinaryRuptureProbabilityCalc exclusionModel;
    /** Parent section IDs that have at least one mapped paleo rate or slip constraint. */
    private final HashSet<Integer> paleoParents;
    /** Value from findParkfieldSection; compared against cluster parent-section IDs. */
    private final int parkfieldID;
    /** Branch supplied at construction; only consulted for the analytical solver's b-value. */
    private final LogicTreeBranch<?> branch;

    /**
     * @param rupSet rupture set used to locate Parkfield and to map paleo-constraint section
     *        indexes to their parent section IDs
     * @param branch logic tree branch; if it holds a {@code SectionSupraSeisBValues.Constant}
     *        value, that b-value is passed to the analytical solver
     * @param exclusionModel rupture exclusion model (may be null — NOTE(review): confirm callers)
     */
    public ClassicModelInversionSolver(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch,
            RuptureProbabilityCalc.BinaryRuptureProbabilityCalc exclusionModel) {
        this.branch = branch;
        this.exclusionModel = exclusionModel;
        this.parkfieldID = NSHM23_ConstraintBuilder.findParkfieldSection(rupSet);
        this.paleoParents = new HashSet<>();
        PaleoseismicConstraintData paleoData = rupSet.getModule(PaleoseismicConstraintData.class);
        if (paleoData != null) {
            if (paleoData.hasPaleoRateConstraints()) {
                addPaleoParents(rupSet, paleoData.getPaleoRateConstraints());
            }
            if (paleoData.hasPaleoSlipConstraints()) {
                addPaleoParents(rupSet, paleoData.getPaleoSlipConstraints());
            }
        }
    }

    /**
     * Records the parent section ID of every constraint that is mapped to a section
     * (constraints with a negative section index are skipped).
     */
    private void addPaleoParents(FaultSystemRupSet rupSet,
            Iterable<? extends UncertainDataConstraint.SectMappedUncertainDataConstraint> constraints) {
        for (UncertainDataConstraint.SectMappedUncertainDataConstraint constraint : constraints) {
            if (constraint.sectionIndex >= 0) {
                this.paleoParents.add(rupSet.getFaultSectionData(constraint.sectionIndex).getParentSectionId());
            }
        }
    }

    /**
     * Builds an analytical single-fault solver.
     *
     * @param includeExclusionModel whether to hand this solver's exclusion model to the analytical
     *        solver (null is passed otherwise)
     * @return a solver using the branch's constant supra-seismogenic b-value if one is set, or the
     *         analytical solver's default constructor otherwise
     */
    private AnalyticalSingleFaultInversionSolver getAnalyticalSolver(boolean includeExclusionModel) {
        RuptureProbabilityCalc.BinaryRuptureProbabilityCalc solverExclusionModel =
                includeExclusionModel ? this.exclusionModel : null;
        if (this.branch.hasValue(SectionSupraSeisBValues.Constant.class)) {
            return new AnalyticalSingleFaultInversionSolver(
                    this.branch.requireValue(SectionSupraSeisBValues.class).getB(), solverExclusionModel);
        }
        return new AnalyticalSingleFaultInversionSolver(solverExclusionModel);
    }

    /** Returns the exclusion model supplied at construction; both arguments are ignored. */
    @Override
    protected RuptureProbabilityCalc.BinaryRuptureProbabilityCalc getRuptureExclusionModel(FaultSystemRupSet rupSet, LogicTreeBranch<?> branch) {
        return this.exclusionModel;
    }

    /**
     * A cluster is inverted (rather than solved analytically) when it spans multiple parent
     * sections, contains the Parkfield ID, or contains any parent with a paleo constraint.
     */
    @Override
    protected boolean shouldInvert(ConnectivityCluster cluster) {
        if (cluster.getParentSectIDs().size() > 1 || cluster.getParentSectIDs().contains(this.parkfieldID)) {
            return true;
        }
        for (int paleoParentID : this.paleoParents) {
            if (cluster.getParentSectIDs().contains(paleoParentID)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Decides whether a (sub-)rupture-set/config pair needs the full inversion: true if the rupture
     * set has jumps, or if the config contains any constraint other than slip-rate, MFD,
     * section-total-rate, or sub-section MFD constraints.
     */
    private boolean shouldInvert(FaultSystemRupSet rupSet, InversionConfiguration config) {
        if (NSHM23_InvConfigFactory.hasJumps(rupSet)) {
            return true;
        }
        for (InversionConstraint constraint : config.getConstraints()) {
            boolean analyticallySupported = constraint instanceof SlipRateInversionConstraint
                    || constraint instanceof MFDInversionConstraint
                    || constraint instanceof SectionTotalRateConstraint
                    || constraint instanceof SubSectMFDInversionConstraint;
            if (!analyticallySupported) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs the cluster-specific inversion, then solves every non-inverted cluster analytically and
     * merges both results into one solution over {@code rupSet}.
     *
     * <p>If the parent run returns null, the whole rupture set is solved analytically. If no
     * section is left for the analytical pass, the inversion solution is returned unchanged.
     *
     * @throws IOException propagated from the parent solver or the analytical fallback
     */
    @Override
    public FaultSystemSolution run(FaultSystemRupSet rupSet, InversionConfigurationFactory factory, LogicTreeBranch<?> branch, int threads, CommandLine cmd) throws IOException {
        FaultSystemSolution inversionSol = super.run(rupSet, factory, branch, threads, cmd);
        if (inversionSol == null) {
            return this.getAnalyticalSolver(true).run(rupSet, factory, branch, threads, cmd);
        }
        ConnectivityClusters clusters = rupSet.requireModule(ConnectivityClusters.class);
        HashSet<Integer> analyticalSects = this.collectAnalyticalSections(clusters);
        System.out.println("Calculating analytical solution for " + analyticalSects.size() + "/" + rupSet.getNumSections() + " sections");
        if (analyticalSects.isEmpty()) {
            return inversionSol;
        }

        // solve the remaining sections on their own rupture subset (exclusion model not re-applied
        // by the analytical solver here, only when building the subset)
        FaultSystemRupSet analyticalRupSet = rupSet.getForSectionSubSet(analyticalSects, this.exclusionModel);
        InversionConfiguration config = factory.buildInversionConfig(analyticalRupSet, branch, threads);
        AnalyticalSingleFaultInversionSolver analytical = this.getAnalyticalSolver(false);
        FaultSystemSolution analyticalSol = analytical.run(analyticalRupSet, config);
        double[] analyticalRates = analyticalSol.getRateForAllRups();

        // splice the analytical rates into full-size copies of the inversion (and initial) rates;
        // analytical ruptures use their final rate as their initial rate
        int origNumRups = rupSet.getNumRuptures();
        double[] allInitialRates = inversionSol.hasModule(InitialSolution.class)
                ? Arrays.copyOf(inversionSol.getModule(InitialSolution.class).get(), origNumRups)
                : new double[origNumRups];
        double[] allRates = Arrays.copyOf(inversionSol.getRateForAllRups(), origNumRups);
        RuptureSubSetMappings mappings = analyticalRupSet.requireModule(RuptureSubSetMappings.class);
        int numSubsetRups = analyticalRupSet.getNumRuptures();
        for (int subsetRupIndex = 0; subsetRupIndex < numSubsetRups; ++subsetRupIndex) {
            int origRupIndex = mappings.getOrigRupID(subsetRupIndex);
            allRates[origRupIndex] = analyticalRates[subsetRupIndex];
            allInitialRates[origRupIndex] = analyticalRates[subsetRupIndex];
        }

        InversionMisfits inversionMisfits = inversionSol.requireModule(InversionMisfits.class);
        InversionMisfits analyticalMisfits = analyticalSol.requireModule(InversionMisfits.class);
        FaultSystemSolution combinedSol = new FaultSystemSolution(rupSet, allRates);
        addCombinedWaterLevels(combinedSol, inversionSol, analyticalSol, mappings, numSubsetRups, origNumRups);
        combinedSol.addModule(new InitialSolution(allInitialRates));
        InversionMisfits combMisfits = InversionMisfits.appendSeparate(List.of(inversionMisfits, analyticalMisfits));
        combinedSol.addModule(combMisfits);
        combinedSol.addModule(combMisfits.getMisfitStats());
        addClusterMisfits(combinedSol, inversionSol, clusters);

        SolutionLogicTree.SolutionProcessor processor = factory.getSolutionLogicTreeProcessor();
        if (processor != null) {
            processor.processSolution(combinedSol, branch);
        }
        return combinedSol;
    }

    /** Returns the IDs of all sections belonging to clusters that are not inverted. */
    private HashSet<Integer> collectAnalyticalSections(ConnectivityClusters clusters) {
        HashSet<Integer> analyticalSects = new HashSet<>();
        for (ConnectivityCluster cluster : clusters) {
            if (!this.shouldInvert(cluster)) {
                analyticalSects.addAll(cluster.getSectIDs());
            }
        }
        return analyticalSects;
    }

    /**
     * Adds a water-level module to {@code combinedSol} if either input solution has one. When only
     * the inversion has water levels they are reused as-is; otherwise the analytical subset values
     * are mapped onto (a zero-filled or copied) full-size array.
     */
    private static void addCombinedWaterLevels(FaultSystemSolution combinedSol, FaultSystemSolution inversionSol,
            FaultSystemSolution analyticalSol, RuptureSubSetMappings mappings, int numSubsetRups, int origNumRups) {
        if (!analyticalSol.hasModule(WaterLevelRates.class) && !inversionSol.hasModule(WaterLevelRates.class)) {
            return;
        }
        WaterLevelRates inversionWaterLevel = inversionSol.getModule(WaterLevelRates.class);
        WaterLevelRates analyticalWaterLevel = analyticalSol.getModule(WaterLevelRates.class);
        if (analyticalWaterLevel == null) {
            combinedSol.addModule(inversionWaterLevel);
            return;
        }
        double[] newWaterLevel = inversionWaterLevel == null
                ? new double[origNumRups]
                : Arrays.copyOf(inversionWaterLevel.get(), origNumRups);
        for (int subsetRupIndex = 0; subsetRupIndex < numSubsetRups; ++subsetRupIndex) {
            newWaterLevel[mappings.getOrigRupID(subsetRupIndex)] = analyticalWaterLevel.get(subsetRupIndex);
        }
        combinedSol.addModule(new WaterLevelRates(newWaterLevel));
    }

    /**
     * Re-attaches the inversion's per-cluster misfit stats (indexed by cluster position) to
     * {@code combinedSol}, if the inversion solution has them.
     */
    private static void addClusterMisfits(FaultSystemSolution combinedSol, FaultSystemSolution inversionSol,
            ConnectivityClusters clusters) {
        if (!inversionSol.hasModule(ConnectivityClusters.ConnectivityClusterSolutionMisfits.class)) {
            return;
        }
        ConnectivityClusters.ConnectivityClusterSolutionMisfits clusterMisfits =
                inversionSol.requireModule(ConnectivityClusters.ConnectivityClusterSolutionMisfits.class);
        HashMap<ConnectivityCluster, InversionMisfitStats> clusterMisfitsMap = new HashMap<>();
        InversionMisfitProgress largestProgress = clusterMisfits.getLargestClusterMisfitProgress();
        for (int i = 0; i < clusters.size(); ++i) {
            clusterMisfitsMap.put(clusters.get(i), clusterMisfits.getMisfitStats(i));
        }
        combinedSol.addModule(new ConnectivityClusters.ConnectivityClusterSolutionMisfits(
                combinedSol, clusterMisfitsMap, largestProgress));
    }

    /**
     * Solves {@code rupSet} analytically (with the exclusion model) when it has no jumps and only
     * analytically-supported constraints; otherwise defers to the parent solver.
     */
    @Override
    public FaultSystemSolution run(FaultSystemRupSet rupSet, InversionConfiguration config, String info) {
        if (!this.shouldInvert(rupSet, config)) {
            return this.getAnalyticalSolver(true).run(rupSet, config, info);
        }
        return super.run(rupSet, config, info);
    }
}

