/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.util;

import com.google.common.base.Preconditions;
import com.google.common.collect.HashBasedTable;
import com.google.common.primitives.Ints;
import java.awt.geom.Point2D;
import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.zip.ZipFile;
import org.opensha.commons.data.function.ArbitrarilyDiscretizedFunc;
import org.opensha.commons.data.function.DiscretizedFunc;
import org.opensha.commons.data.function.LightFixedXFunc;
import org.opensha.commons.logicTree.BranchWeightProvider;
import org.opensha.commons.logicTree.LogicTree;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.util.modules.AverageableModule;
import org.opensha.commons.util.modules.ModuleArchive;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemSolution;
import org.opensha.sha.earthquake.faultSysSolution.modules.ClusterRuptures;
import org.opensha.sha.earthquake.faultSysSolution.modules.GridSourceList;
import org.opensha.sha.earthquake.faultSysSolution.modules.GridSourceProvider;
import org.opensha.sha.earthquake.faultSysSolution.modules.RupMFDsModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.RupSetTectonicRegimes;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionLogicTree;
import org.opensha.sha.earthquake.faultSysSolution.modules.TrueMeanRuptureMappings;
import org.opensha.sha.faultSurface.FaultSection;
import org.opensha.sha.util.TectonicRegionType;

public class TrueMeanSolutionCreator {
    // Unique sections seen so far, mapped to their global ID (their index in uniqueSectsList).
    private HashMap<UniqueSection, Integer> uniqueSectsMap;
    private List<UniqueSection> uniqueSectsList;
    // Unique ruptures seen so far, mapped to their global ID (their index in uniqueRupsList).
    private HashMap<UniqueRupture, Integer> uniqueRupsMap;
    private List<UniqueRupture> uniqueRupsList;
    // Branches passed to addSolution, in call order; the three lists below are parallel to it.
    private List<LogicTreeBranch<?>> branches;
    // Per branch: branch-local section index -> global section ID.
    private List<int[]> branchSectMappings;
    // Per branch: branch-local rupture index -> global rupture ID.
    private List<int[]> branchRupMappings;
    // Per branch: the array returned by that branch's rupSet.getMagForAllRups().
    private List<double[]> branchRupMags;
    // Caches keyed by Arrays.hashCode(array), used so branches with identical arrays share one instance.
    private HashMap<Integer, int[]> sectMappingsCache;
    private HashMap<Integer, int[]> rupMappingsCache;
    private HashMap<Integer, double[]> rupMagsCache;
    private LogicTree<?> tree;
    // Supplies the weight applied to each added branch.
    private BranchWeightProvider weightProv;
    // Running sum of the weights of all added branches.
    private double totWeight = 0.0;
    // When true, addSolution fetches each solution's grid source provider for averaging.
    private boolean doGridProv;
    // Grid source provider accumulator; null when not averaging (or after a solution lacked a provider).
    private AverageableModule.AveragingAccumulator<GridSourceProvider> gridProvAvg;
    // Tectonic region type of each unique rupture (parallel to uniqueRupsList); null when not tracked.
    private List<TectonicRegionType> rupTRTs = null;
    // Stays true only while every added rupture set has ClusterRuptures whose entries are all single-strand.
    private boolean allSingleStranded = true;

    /**
     * Creates a true-mean builder for the given logic tree that weights branches with the
     * tree's own weight provider, i.e. {@code tree.getWeightProvider()}.
     *
     * @param tree logic tree that the added branches must belong to
     */
    public TrueMeanSolutionCreator(LogicTree<?> tree) {
        this(tree, tree.getWeightProvider());
    }

    /**
     * Creates a true-mean builder for the given logic tree using an explicit weight provider.
     *
     * @param tree logic tree that the added branches must belong to
     * @param weightProv provider queried for the weight of every branch passed to
     *        {@code addSolution}
     */
    public TrueMeanSolutionCreator(LogicTree<?> tree, BranchWeightProvider weightProv) {
        this.weightProv = weightProv;
        this.tree = tree;
    }

    /**
     * Enables or disables use of grid source providers from added solutions.
     * <p>
     * When enabled, {@code addSolution} asks each solution for its grid source provider so it
     * can be averaged; when disabled, grid source providers are not requested at all.
     *
     * @param doGridProv {@code true} to fetch grid source providers from added solutions
     */
    public void setDoGridProv(boolean doGridProv) {
        this.doGridProv = doGridProv;
    }

    /**
     * Folds one branch solution into the running true-mean state.
     * <p>
     * Each section and rupture of {@code sol} is matched against the unique sections/ruptures
     * collected so far (new unique entries are created as needed) and accumulated with the
     * branch weight obtained from the configured weight provider. The branch-local to global
     * index mappings and the branch magnitudes are retained for {@link #build()}.
     *
     * @param sol solution for {@code branch}
     * @param branch logic tree branch of {@code sol}; must be contained in the logic tree
     * @throws IllegalStateException if the logic tree does not contain {@code branch}
     */
    public synchronized void addSolution(FaultSystemSolution sol, LogicTreeBranch<?> branch) {
        Preconditions.checkState(this.tree.contains(branch), "Branch not contained by tree");
        GridSourceProvider gridProv = this.doGridProv ? sol.getGridSourceProvider() : null;
        FaultSystemRupSet rupSet = sol.getRupSet();
        RupSetTectonicRegimes myRupTRTs = rupSet.getModule(RupSetTectonicRegimes.class);
        initOrReconcileState(gridProv, myRupTRTs);

        double weight = this.weightProv.getWeight(branch);
        System.out.println("Processing branch: " + branch + ", weight=" + weight);
        this.totWeight += weight;

        // Once any rupture set lacks ClusterRuptures, or has a multi-strand rupture, this stays false.
        if (this.allSingleStranded) {
            if (rupSet.hasModule(ClusterRuptures.class)) {
                ClusterRuptures cRups = rupSet.requireModule(ClusterRuptures.class);
                for (int r = 0; this.allSingleStranded && r < cRups.size(); r++) {
                    this.allSingleStranded = cRups.get(r).singleStrand;
                }
            } else {
                this.allSingleStranded = false;
            }
        }

        int[] sectMappings = mapSections(rupSet, weight);
        int[] rupMappings = mapRuptures(sol, rupSet, myRupTRTs, sectMappings, weight);

        if (this.gridProvAvg != null) {
            if (gridProv instanceof GridSourceList) {
                // re-express section associations in terms of global section IDs before averaging
                gridProv = GridSourceList.remapAssociations((GridSourceList) gridProv, sectMappings);
            }
            this.gridProvAvg.process(gridProv, weight);
        }

        this.branches.add(branch);
        this.branchSectMappings.add(shareIdentical(this.sectMappingsCache, sectMappings));
        this.branchRupMappings.add(shareIdentical(this.rupMappingsCache, rupMappings));
        this.branchRupMags.add(shareIdentical(this.rupMagsCache, rupSet.getMagForAllRups()));
    }

    /**
     * Allocates all tracking state on the first added solution; on later solutions, disables
     * grid-provider averaging / tectonic-regime tracking if this solution lacks the module.
     */
    private void initOrReconcileState(GridSourceProvider gridProv, RupSetTectonicRegimes myRupTRTs) {
        if (this.uniqueSectsMap == null) {
            this.uniqueSectsMap = new HashMap<>();
            this.uniqueSectsList = new ArrayList<>();
            this.uniqueRupsMap = new HashMap<>();
            this.uniqueRupsList = new ArrayList<>();
            this.branches = new ArrayList<>();
            this.branchSectMappings = new ArrayList<>();
            this.branchRupMappings = new ArrayList<>();
            this.branchRupMags = new ArrayList<>();
            this.sectMappingsCache = new HashMap<>();
            this.rupMappingsCache = new HashMap<>();
            this.rupMagsCache = new HashMap<>();
            if (gridProv != null) {
                this.gridProvAvg = gridProv.averagingAccumulator();
            }
            if (myRupTRTs != null) {
                System.out.println("Will track rupture tectonic region types");
                this.rupTRTs = new ArrayList<>();
            }
            return;
        }
        if (this.gridProvAvg != null && gridProv == null) {
            System.err.println("WARNING: not all solutions contain grid source providers, disabling averaging");
            this.gridProvAvg = null;
        }
        if (this.rupTRTs != null && myRupTRTs == null) {
            System.err.println("WARNING: not all rupture sets contain tectonic region types, no longer tracking");
            this.rupTRTs = null;
        }
    }

    /**
     * Maps every section of {@code rupSet} to a global unique-section ID, registering new unique
     * sections as needed and accumulating {@code weight} on the matched unique section.
     */
    private int[] mapSections(FaultSystemRupSet rupSet, double weight) {
        int[] sectMappings = new int[rupSet.getNumSections()];
        int numNewSects = 0;
        for (int s = 0; s < sectMappings.length; s++) {
            FaultSection sect = rupSet.getFaultSectionData(s);
            UniqueSection unique = new UniqueSection(sect);
            Integer globalID = this.uniqueSectsMap.get(unique);
            if (globalID == null) {
                numNewSects++;
                globalID = this.uniqueSectsList.size();
                unique.setGlobalID(globalID);
                this.uniqueSectsList.add(unique);
                this.uniqueSectsMap.put(unique, globalID);
            } else {
                // accumulate on the previously registered instance, not the temporary lookup key
                unique = this.uniqueSectsList.get(globalID);
            }
            sectMappings[s] = globalID;
            unique.addInstance(sect, weight);
        }
        System.out.println("\t" + numNewSects + "/" + sectMappings.length + " new unique sections");
        return sectMappings;
    }

    /**
     * Maps every rupture of {@code rupSet} to a global unique-rupture ID, registering new unique
     * ruptures as needed and accumulating this branch's magnitudes and weighted rates on them.
     */
    private int[] mapRuptures(FaultSystemSolution sol, FaultSystemRupSet rupSet, RupSetTectonicRegimes myRupTRTs,
            int[] sectMappings, double weight) {
        int[] rupMappings = new int[rupSet.getNumRuptures()];
        int numNewRups = 0;
        int numNewNonzeroRups = 0;
        int numNewMags = 0;
        RupMFDsModule rupMFDs = sol.getModule(RupMFDsModule.class);
        for (int r = 0; r < rupMappings.length; r++) {
            List<Integer> origSectIDs = rupSet.getSectionsIndicesForRup(r);
            double rate = sol.getRateForRup(r);
            int[] globalSectIDs = new int[origSectIDs.size()];
            for (int i = 0; i < globalSectIDs.length; i++) {
                globalSectIDs[i] = sectMappings[origSectIDs.get(i)];
            }
            UniqueRupture unique = new UniqueRupture(globalSectIDs, rupSet.getAveRakeForRup(r),
                    rupSet.getAreaForRup(r), rupSet.getLengthForRup(r));
            Integer globalID = this.uniqueRupsMap.get(unique);
            if (globalID == null) {
                numNewRups++;
                // list and map always grow together, so the list size is the next global ID
                globalID = this.uniqueRupsList.size();
                unique.setGlobalID(globalID);
                this.uniqueRupsList.add(unique);
                this.uniqueRupsMap.put(unique, globalID);
                if (this.rupTRTs != null) {
                    this.rupTRTs.add(myRupTRTs.get(r));
                }
            } else {
                unique = this.uniqueRupsList.get(globalID);
            }
            // must be checked before this branch's rate is added to the unique rupture's MFD
            if (rate > 0.0 && unique.rupMFD.size() == 0) {
                numNewNonzeroRups++;
            }
            rupMappings[r] = globalID;
            double rupSetMag = rupSet.getMagForRup(r);
            unique.addWeightMag(rupSetMag, weight);
            if (rate > 0.0) {
                // prefer the branch's per-rupture MFD when available, else a single (mag, rate) point
                DiscretizedFunc rupMFD = rupMFDs == null ? null : rupMFDs.getRuptureMFD(r);
                boolean newMag;
                if (rupMFD == null) {
                    newMag = unique.addForBranch(rupSetMag, rate, weight);
                } else {
                    newMag = unique.addForBranch(rupMFD, weight);
                }
                if (newMag) {
                    numNewMags++;
                }
            }
        }
        System.out.println("\t" + numNewRups + "/" + rupMappings.length + " new unique ruptures");
        System.out.println("\t" + numNewNonzeroRups + "/" + rupMappings.length + " new unique ruptures with nonzero rates");
        System.out.println("\t" + numNewMags + "/" + rupMappings.length + " new unique rupture magnitudes");
        return rupMappings;
    }

    /**
     * Returns the cached array equal to {@code values} if one is stored under the same hash;
     * otherwise stores {@code values} under its hash (replacing any collision) and returns it.
     */
    private static int[] shareIdentical(HashMap<Integer, int[]> cache, int[] values) {
        int hash = Arrays.hashCode(values);
        int[] cached = cache.get(hash);
        if (cached != null && Arrays.equals(values, cached)) {
            return cached;
        }
        cache.put(hash, values);
        return values;
    }

    /**
     * Returns the cached array equal to {@code values} if one is stored under the same hash;
     * otherwise stores {@code values} under its hash (replacing any collision) and returns it.
     */
    private static double[] shareIdentical(HashMap<Integer, double[]> cache, double[] values) {
        int hash = Arrays.hashCode(values);
        double[] cached = cache.get(hash);
        if (cached != null && Arrays.equals(values, cached)) {
            return cached;
        }
        cache.put(hash, values);
        return values;
    }

    /**
     * Builds the true-mean solution from everything added so far.
     * <p>
     * The returned solution's rupture set holds one section per unique section and one rupture
     * per unique rupture. Rupture rates are the sums of each unique rupture's final MFD (zero
     * when it has none), and a {@link RupMFDsModule} with those MFDs is attached. Cluster
     * ruptures, tectonic regimes and the averaged grid source provider are attached when they
     * were tracked for all added solutions, and {@link TrueMeanRuptureMappings} are attached to
     * the rupture set.
     *
     * @return the true-mean solution
     * @throws IllegalStateException if no solution has been added, or if more branches were
     *         added than exist in the logic tree
     */
    public synchronized FaultSystemSolution build() {
        // Fail with a clear message rather than a NullPointerException on the uninitialized lists.
        Preconditions.checkState(this.uniqueSectsList != null,
                "No solutions have been added, cannot build a true mean solution");
        System.out.println("Building true mean with " + this.uniqueSectsList.size() + " unique sections and " + this.uniqueRupsList.size() + " unique ruptures, totalWeight=" + this.totWeight);

        // Counts unique-section instances per (parent ID, name) so repeated ones get an instance number.
        HashBasedTable<Integer, String, Integer> instanceCounts = HashBasedTable.create();
        List<FaultSection> uniqueSects = new ArrayList<>();
        for (UniqueSection unique : this.uniqueSectsList) {
            Integer prevCount = instanceCounts.get(unique.parentID, unique.name);
            int instanceNum = prevCount == null ? 0 : prevCount;
            instanceCounts.put(unique.parentID, unique.name, instanceNum + 1);
            uniqueSects.add(unique.buildGlobalInstance(instanceNum, this.totWeight));
        }

        int numRups = this.uniqueRupsList.size();
        List<List<Integer>> globalSectsForRups = new ArrayList<>();
        double[] avgMags = new double[numRups];
        double[] rakes = new double[numRups];
        double[] areas = new double[numRups];
        double[] lengths = new double[numRups];
        double[] rates = new double[numRups];
        DiscretizedFunc[] rupMFDs = new DiscretizedFunc[numRups];
        for (int r = 0; r < numRups; r++) {
            UniqueRupture unique = this.uniqueRupsList.get(r);
            avgMags[r] = unique.getMeanMag();
            rakes[r] = unique.rake;
            areas[r] = unique.area;
            lengths[r] = unique.length;
            DiscretizedFunc mfd = unique.getFinalMFD(this.totWeight);
            // a null MFD means no added branch gave this rupture a positive rate
            rates[r] = mfd == null ? 0.0 : mfd.calcSumOfY_Vals();
            rupMFDs[r] = mfd;
            globalSectsForRups.add(Ints.asList(unique.sectIDs));
        }

        FaultSystemRupSet avgRupSet = new FaultSystemRupSet(uniqueSects, globalSectsForRups, avgMags, rakes, areas, lengths);
        if (this.allSingleStranded) {
            avgRupSet.addModule(ClusterRuptures.singleStranged(avgRupSet));
        }
        if (this.rupTRTs != null) {
            avgRupSet.addModule(new RupSetTectonicRegimes(avgRupSet, this.rupTRTs.toArray(new TectonicRegionType[0])));
        }

        FaultSystemSolution avgSol = new FaultSystemSolution(avgRupSet, rates);
        avgSol.addModule(new RupMFDsModule(avgSol, rupMFDs));
        if (this.gridProvAvg != null) {
            avgSol.addModule(this.gridProvAvg.getAverage());
        }

        LogicTree<?> mappedTree = this.tree;
        Preconditions.checkState(this.branches.size() <= mappedTree.size(),
                "More branches added than exist in tree, must be duplicates");
        if (this.branches.size() < mappedTree.size()) {
            // only a subset of branches was added: store mappings against the matching sub-tree
            mappedTree = mappedTree.subset(this.branches);
        }
        avgRupSet.addModule(TrueMeanRuptureMappings.build(mappedTree, this.branchSectMappings, this.branchRupMappings, this.branchRupMags));
        return avgSol;
    }

    /**
     * Command-line entry point: builds a true-mean solution from a solution logic tree archive
     * and writes it to an output file.
     * <p>
     * Arguments are either {@code --hardcoded} (use the built-in paths) or
     * {@code <slt-file.zip> <output-file.zip> [<BA sol or grid prov.zip>]}. When the optional
     * third file is supplied, its grid source provider is attached to the result and per-branch
     * grid source providers are not averaged.
     *
     * @param args command-line arguments as described above
     * @throws IOException if reading an input archive or writing the output fails
     * @throws IllegalArgumentException if the argument count is not recognized
     * @throws IllegalStateException if the solution logic tree file does not exist
     */
    public static void main(String[] args) throws IOException {
        System.setProperty("java.awt.headless", "true");
        File sltFile;
        File gridProvFile;
        File outputFile;
        if (args.length == 1 && args[0].equals("--hardcoded")) {
            File dir = new File("/data/kevin/nshm23/batch_inversions/2023_06_23-nshm23_branches-NSHM23_v2-CoulombRupSet-TotNuclRate-NoRed-ThreshAvgIterRelGR");
            sltFile = new File(dir, "results.zip");
            gridProvFile = new File(dir, "results_NSHM23_v2_CoulombRupSet_branch_averaged_gridded.zip");
            outputFile = new File(dir, "true_mean_solution.zip");
        } else if (args.length == 2 || args.length == 3) {
            sltFile = new File(args[0]);
            outputFile = new File(args[1]);
            gridProvFile = args.length == 3 ? new File(args[2]) : null;
        } else {
            throw new IllegalArgumentException("USAGE: <slt-file.zip> <output-file.zip> [<BA sol or grid prov.zip>]");
        }
        Preconditions.checkState(sltFile.exists(), "Solution logic tree file doesn't exist: %s", sltFile.getAbsolutePath());
        SolutionLogicTree slt = SolutionLogicTree.load(sltFile);
        LogicTree<?> tree = slt.getLogicTree();

        // The archive stays open until the output is written, in case modules loaded from it are
        // read lazily. try-with-resources skips close() for a null resource.
        try (ZipFile gridProvZip = gridProvFile == null ? null : new ZipFile(gridProvFile)) {
            GridSourceProvider avgGridProv = null;
            if (gridProvZip != null) {
                if (FaultSystemSolution.isSolution(gridProvZip)) {
                    avgGridProv = FaultSystemSolution.load(gridProvZip).requireModule(GridSourceProvider.class);
                } else {
                    // Raw type retained from the existing usage; the cast keeps the assignment
                    // valid however the raw call's return type is erased.
                    ModuleArchive avgArchive = new ModuleArchive(gridProvZip);
                    avgGridProv = (GridSourceProvider) avgArchive.requireModule(GridSourceProvider.class);
                }
            }

            TrueMeanSolutionCreator creator = new TrueMeanSolutionCreator(tree);
            // only average per-branch grid source providers if none was supplied externally
            creator.setDoGridProv(avgGridProv == null);

            ClusterRuptures cRups = null;
            for (int i = 0; i < tree.size(); i++) {
                LogicTreeBranch<?> branch = tree.getBranch(i);
                System.out.println("Processing branch " + i + "/" + tree.size() + ": " + branch);
                FaultSystemSolution sol = slt.forBranch(branch);
                // drop the previous branch's cluster ruptures when the rupture count differs
                if (cRups != null && cRups.size() != sol.getRupSet().getNumRuptures()) {
                    cRups = null;
                }
                if (cRups == null) {
                    cRups = sol.getRupSet().getModule(ClusterRuptures.class);
                    if (cRups == null) {
                        System.err.println("WARNING: Building ClusterRuptures and assuming single-stranded");
                        cRups = ClusterRuptures.singleStranged(sol.getRupSet());
                        sol.getRupSet().addModule(cRups);
                    }
                } else {
                    // reuse the instance carried over from the previous branch
                    sol.getRupSet().addModule(cRups);
                }
                creator.addSolution(sol, branch);
            }

            FaultSystemSolution avgSol = creator.build();
            if (avgGridProv != null) {
                avgSol.setGridSourceProvider(avgGridProv);
            }
            avgSol.write(outputFile);
        }
    }

    /**
     * Identity and weighted accumulator for a fault section shared across branches.
     * <p>
     * Two instances are equal when their parent section ID, name, average dip, reduced average
     * upper depth and average lower depth are all equal. Slip rate, slip-rate standard deviation
     * and coupling coefficient are accumulated weight-multiplied via {@link #addInstance}.
     */
    private static class UniqueSection {
        private static final DecimalFormat weightDF = new DecimalFormat("0.##%");

        // identity (these make up the equality key)
        private final int parentID;
        private final String name;
        private final double aveDip;
        private final double aveUpperDepth;
        private final double aveLowerDepth;
        private final Object[] comp;

        // section this key was first built from; cloned as the template for the global instance
        private final FaultSection sect;
        private int globalID;

        // weighted running sums
        private double weightUsed;
        private double weightedSlip;
        private double weightedSlipStdDev;
        private double weightedCoupling;

        /**
         * @param sect section supplying the identity values and serving as the clone template
         */
        public UniqueSection(FaultSection sect) {
            this.sect = sect;
            this.parentID = sect.getParentSectionId();
            this.name = sect.getName();
            this.aveDip = sect.getAveDip();
            this.aveUpperDepth = sect.getReducedAveUpperDepth();
            this.aveLowerDepth = sect.getAveLowerDepth();
            this.comp = new Object[] {parentID, name, aveDip, aveUpperDepth, aveLowerDepth};
        }

        /** Sets the ID assigned to the section produced by {@link #buildGlobalInstance}. */
        public void setGlobalID(int globalID) {
            this.globalID = globalID;
        }

        /**
         * Accumulates one branch's occurrence of this section.
         *
         * @param sect the branch's section instance
         * @param weight the branch weight
         */
        public void addInstance(FaultSection sect, double weight) {
            weightedCoupling += sect.getCouplingCoeff() * weight;
            weightedSlipStdDev += sect.getOrigSlipRateStdDev() * weight;
            weightedSlip += sect.getOrigAveSlipRate() * weight;
            weightUsed += weight;
        }

        /**
         * Clones the template section and applies the global ID and averaged properties.
         * The name gets an instance/weight suffix when this is not the first instance for its
         * (parent, name) pair or when it did not receive the full total weight.
         *
         * @param instanceNum zero-based instance number among sections sharing parent and name
         * @param totWeight total weight of all added branches
         * @return the new section
         */
        public FaultSection buildGlobalInstance(int instanceNum, double totWeight) {
            FaultSection global = this.sect.clone();
            boolean needsSuffix = instanceNum > 0 || weightUsed != totWeight;
            if (needsSuffix) {
                String fraction = weightDF.format(weightUsed / totWeight);
                global.setSectionName(global.getSectionName() + " (Instance " + instanceNum + ", " + fraction + " weight)");
            }
            global.setSectionId(globalID);
            global.setAveSlipRate(weightedSlip / totWeight);
            global.setSlipRateStdDev(weightedSlipStdDev / totWeight);
            // NOTE(review): coupling is normalized by weightUsed while slip uses totWeight -
            // presumably intentional (slip is scaled by participation, coupling is not); confirm.
            global.setCouplingCoeff(weightedCoupling / weightUsed);
            return global;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(comp);
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (obj == null || obj.getClass() != getClass()) {
                return false;
            }
            UniqueSection other = (UniqueSection) obj;
            return Arrays.equals(comp, other.comp);
        }
    }

    /**
     * Identity and weighted accumulator for a rupture shared across branches.
     * <p>
     * Two instances are equal when their rake, area and ordered global section IDs are all
     * equal; length is stored but is not part of the equality key. Weighted rates are
     * accumulated per magnitude in {@code rupMFD}.
     */
    private static class UniqueRupture {
        private final int[] sectIDs;
        private final double rake;
        private final double area;
        private final double length;
        private final Object[] comp;
        // hash of comp, computed once at construction
        private final int cachedHash;
        private int globalID;
        // magnitude -> sum over branches of (rate * branch weight); empty until a positive rate is added
        private final DiscretizedFunc rupMFD;
        // branch-weighted magnitude sums over all branches containing this rupture, regardless of rate
        private double weightMagSum = 0.0;
        private double weightSum = 0.0;

        /**
         * @param sectIDs global section IDs of the rupture, in order
         * @param rake rupture rake
         * @param area rupture area
         * @param length rupture length (not part of the equality key)
         */
        public UniqueRupture(int[] sectIDs, double rake, double area, double length) {
            this.sectIDs = sectIDs;
            this.rake = rake;
            this.area = area;
            this.length = length;
            this.comp = new Object[sectIDs.length + 2];
            this.comp[0] = rake;
            this.comp[1] = area;
            for (int i = 0; i < sectIDs.length; i++) {
                this.comp[i + 2] = sectIDs[i];
            }
            this.cachedHash = Arrays.hashCode(this.comp);
            this.rupMFD = new ArbitrarilyDiscretizedFunc();
        }

        /** Records the global rupture ID assigned to this unique rupture. */
        public void setGlobalID(int globalID) {
            this.globalID = globalID;
        }

        /**
         * Accumulates a branch magnitude for the branch-weighted mean, independent of rate.
         *
         * @param mag the branch's magnitude for this rupture
         * @param weight the branch weight
         */
        public void addWeightMag(double mag, double weight) {
            this.weightMagSum += mag * weight;
            this.weightSum += weight;
        }

        /**
         * Adds {@code rate * weight} at {@code mag} to the accumulated MFD.
         *
         * @return {@code true} if {@code mag} was not previously present in the MFD
         */
        public boolean addForBranch(double mag, double rate, double weight) {
            boolean newMag = !this.rupMFD.hasX(mag);
            double weightedRate = rate * weight;
            if (newMag) {
                this.rupMFD.set(mag, weightedRate);
            } else {
                this.rupMFD.set(mag, this.rupMFD.getY(mag) + weightedRate);
            }
            return newMag;
        }

        /**
         * Adds every point of {@code branchMFD}, scaled by {@code weight}, to the accumulated MFD.
         *
         * @return {@code true} if at least one magnitude was not previously present
         */
        public boolean addForBranch(DiscretizedFunc branchMFD, double weight) {
            boolean anyNew = false;
            for (Point2D pt : branchMFD) {
                if (this.addForBranch(pt.getX(), pt.getY(), weight)) {
                    anyNew = true;
                }
            }
            return anyNew;
        }

        /**
         * Returns the mean magnitude: weighted by accumulated rate when the MFD carries a
         * nonzero total, otherwise weighted by branch weight.
         * <p>
         * The fallback also covers an MFD whose weighted rates sum to exactly zero (for example
         * contributions only from zero-weight branches), which would otherwise yield NaN (0/0).
         */
        public double getMeanMag() {
            double rateMagSum = 0.0;
            double rateSum = 0.0;
            for (Point2D pt : this.rupMFD) {
                rateMagSum += pt.getX() * pt.getY();
                rateSum += pt.getY();
            }
            if (this.rupMFD.size() == 0 || rateSum == 0.0) {
                return this.weightMagSum / this.weightSum;
            }
            return rateMagSum / rateSum;
        }

        /**
         * Returns the accumulated MFD with every rate divided by {@code totWeight}, or
         * {@code null} when no rates were ever added.
         *
         * @param totWeight total weight of all added branches
         */
        public DiscretizedFunc getFinalMFD(double totWeight) {
            int numPoints = this.rupMFD.size();
            if (numPoints == 0) {
                return null;
            }
            double[] xVals = new double[numPoints];
            double[] yVals = new double[numPoints];
            for (int i = 0; i < numPoints; i++) {
                xVals[i] = this.rupMFD.getX(i);
                yVals[i] = this.rupMFD.getY(i) / totWeight;
            }
            return new LightFixedXFunc(xVals, yVals);
        }

        @Override
        public int hashCode() {
            return this.cachedHash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || this.getClass() != obj.getClass()) {
                return false;
            }
            return Arrays.equals(this.comp, ((UniqueRupture) obj).comp);
        }
    }
}

