/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.hazard;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.opensha.commons.data.xyz.GriddedGeoDataSet;
import org.opensha.commons.geo.GriddedRegion;
import org.opensha.commons.logicTree.LogicTree;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.logicTree.LogicTreeLevel;
import org.opensha.commons.logicTree.LogicTreeNode;
import org.opensha.commons.util.MarkdownUtils;
import org.opensha.sha.earthquake.faultSysSolution.hazard.AbstractLTVarianceDecomposition;
import org.opensha.sha.earthquake.faultSysSolution.hazard.LogicTreeHazardCompare;

/**
 * Variance decomposition that estimates the contribution of each logic tree level by
 * collapsing the tree across that level: maps of all branches that differ only in their
 * choice at the level (and at any unique sampling levels) are weight-averaged together,
 * and the standard deviation / COV of the collapsed maps is compared against a reference.
 */
public class MarginalAveragingLTVarianceDecomposition
extends AbstractLTVarianceDecomposition {
    /**
     * Per-node variance of the maps obtained after averaging out all unique sampling levels;
     * {@code null} when the tree has no unique sampling levels.
     */
    private GriddedGeoDataSet varAveragingOverSampling;
    /** Set once any level yields a negative summed variance or COV contribution. */
    private boolean anyNegativeLTCOVs;

    /**
     * @param tree logic tree being decomposed
     * @param uniqueSamplingLevels levels treated as random sampling levels (may be empty)
     * @param exec executor forwarded to the superclass and used for SD/COV calculations
     */
    public MarginalAveragingLTVarianceDecomposition(LogicTree<?> tree, List<LogicTreeLevel<?>> uniqueSamplingLevels, ExecutorService exec) {
        super(tree, uniqueSamplingLevels, exec);
    }

    /**
     * Initializes the superclass state and, when sampling levels exist, pre-computes the
     * variance that remains after averaging across those sampling levels.
     *
     * @throws NullPointerException if sampling levels exist but averaging across them
     *         does not produce a reduced set of maps
     */
    @Override
    public void initForMaps(GriddedGeoDataSet meanMap, GriddedGeoDataSet fullVariance, GriddedGeoDataSet[] allMaps, List<Double> allWeights) {
        super.initForMaps(meanMap, fullVariance, allMaps, allWeights);
        if (this.uniqueSamplingLevels.isEmpty()) {
            this.varAveragingOverSampling = null;
            return;
        }
        List<Double> weightsWithoutSampling = new ArrayList<Double>();
        GriddedRegion region = meanMap.getRegion();
        GriddedGeoDataSet sdAveragingOverSampling = new GriddedGeoDataSet(region);
        GriddedGeoDataSet covAveragingOverSampling = new GriddedGeoDataSet(region);
        // levelIndex -1 / level null: only the sampling levels are cleared
        GriddedGeoDataSet[] mapsAveragingOverSampling = this.doAverageAcrossLevels(-1, null, weightsWithoutSampling, covAveragingOverSampling, sdAveragingOverSampling);
        Preconditions.checkNotNull(mapsAveragingOverSampling, "Averaging across sampling levels did not reduce the tree");
        // NOTE(review): doAverageAcrossLevels already calls calcSD_COV against this.meanMap;
        // this second call looks redundant if super.initForMaps stores meanMap there -- confirm.
        LogicTreeHazardCompare.calcSD_COV(mapsAveragingOverSampling, weightsWithoutSampling, meanMap, sdAveragingOverSampling, covAveragingOverSampling, this.exec);
        this.varAveragingOverSampling = new GriddedGeoDataSet(region);
        int nodeCount = region.getNodeCount();
        for (int i = 0; i < nodeCount; ++i) {
            double sd = sdAveragingOverSampling.get(i);
            this.varAveragingOverSampling.set(i, sd * sd);
        }
    }

    /**
     * Computes the variance and COV contribution of a single level by collapsing the tree
     * across it and differencing against a reference variance. The reference is the full
     * variance when there are no sampling levels or {@code level} is itself a sampling
     * level; otherwise it is the variance after averaging across sampling levels.
     * <p>
     * {@code choiceMaps} and {@code choiceMapWeights} are not used by this implementation.
     *
     * @return the contribution summary, or {@code null} if the tree cannot be collapsed
     *         across this level or no node has a finite collapsed standard deviation
     */
    @Override
    public AbstractLTVarianceDecomposition.VarianceContributionResult calcMapVarianceContributionForLevel(int levelIndex, LogicTreeLevel<?> level, Map<LogicTreeNode, List<GriddedGeoDataSet>> choiceMaps, Map<LogicTreeNode, List<Double>> choiceMapWeights) {
        List<Double> weightsExcludingLevel = new ArrayList<Double>();
        GriddedRegion region = this.meanMap.getRegion();
        GriddedGeoDataSet covExcluding = new GriddedGeoDataSet(region);
        GriddedGeoDataSet sdExcluding = new GriddedGeoDataSet(region);
        if (this.doAverageAcrossLevels(levelIndex, level, weightsExcludingLevel, covExcluding, sdExcluding) == null) {
            return null;
        }
        boolean compareToFull = this.uniqueSamplingLevels.isEmpty() || this.uniqueSamplingLevels.contains(level);
        GriddedGeoDataSet refVariance = compareToFull ? this.fullVariance : this.varAveragingOverSampling;
        double sumVarContrib = 0.0;
        double maxVarContrib = 0.0;
        double maxFractVarContrib = 0.0;
        double sumCOVContrib = 0.0;
        double maxCOVContrib = 0.0;
        double maxFractCOVContrib = 0.0;
        int numFinite = 0;
        for (int i = 0; i < sdExcluding.size(); ++i) {
            double sd = sdExcluding.get(i);
            if (!Double.isFinite(sd)) {
                // skip nodes without a usable collapsed standard deviation
                continue;
            }
            double mean = this.meanMap.get(i);
            double refVar = refVariance.get(i);
            double fullVar = this.fullVariance.get(i);
            double varContrib = refVar - sd * sd;
            double refCOV = Math.sqrt(refVar) / mean;
            double fullCOV = Math.sqrt(fullVar) / mean;
            double covContrib = refCOV - covExcluding.get(i);
            sumVarContrib += varContrib;
            maxVarContrib = Math.max(maxVarContrib, varContrib);
            // fractional contributions are always expressed relative to the full model
            maxFractVarContrib = Math.max(maxFractVarContrib, varContrib / fullVar);
            sumCOVContrib += covContrib;
            maxCOVContrib = Math.max(maxCOVContrib, covContrib);
            maxFractCOVContrib = Math.max(maxFractCOVContrib, covContrib / fullCOV);
            ++numFinite;
        }
        // remembered so buildLines can explain negative contributions in the report
        this.anyNegativeLTCOVs |= sumVarContrib < 0.0 || sumCOVContrib < 0.0;
        if (numFinite == 0) {
            return null;
        }
        return new AbstractLTVarianceDecomposition.VarianceContributionResult(sumVarContrib / (double)numFinite, maxVarContrib, maxFractVarContrib, sumCOVContrib / (double)numFinite, maxCOVContrib, maxFractCOVContrib);
    }

    /**
     * Collapses the tree by clearing the value at {@code levelIndex} (if non-negative) and at
     * every unique sampling level, grouping branches that become identical, and building one
     * weight-averaged map per group. The SD/COV across the collapsed maps is then computed
     * via {@link LogicTreeHazardCompare#calcSD_COV} using the supplied data sets.
     *
     * @param levelIndex index of the level to collapse, or negative to collapse only sampling levels
     * @param level the level being collapsed (used for logging only); may be {@code null}
     * @param weightsExcludingLevel must be empty; receives the summed weight of each collapsed group
     * @param covExcluding data set handed to {@code calcSD_COV} for the COV values
     * @param sdExcluding data set handed to {@code calcSD_COV} for the standard deviations
     * @return the collapsed maps, or {@code null} if there is nothing to clear, if collapsing
     *         merges no branches, or if it merges everything into a single branch
     * @throws IllegalStateException if {@code weightsExcludingLevel} is not empty
     */
    private GriddedGeoDataSet[] doAverageAcrossLevels(int levelIndex, LogicTreeLevel<?> level, List<Double> weightsExcludingLevel, GriddedGeoDataSet covExcluding, GriddedGeoDataSet sdExcluding) {
        Preconditions.checkState(weightsExcludingLevel.isEmpty());
        if (level == null) {
            System.out.println("Averaging out any sampling levels");
        } else {
            System.out.println("Averaging out level '" + level.getName() + "' by finding all unique branches excluding that level");
        }

        // indexes of every level whose value is cleared before grouping branches
        List<Integer> clearIndexes = new ArrayList<Integer>();
        if (levelIndex >= 0) {
            clearIndexes.add(levelIndex);
        }
        int numLevels = this.tree.getLevels().size();
        for (int l = 0; l < numLevels; ++l) {
            if (l == levelIndex) {
                continue;
            }
            LogicTreeLevel<?> testLevel = (LogicTreeLevel<?>)this.tree.getLevels().get(l);
            if (!this.uniqueSamplingLevels.contains(testLevel)) {
                continue;
            }
            System.out.println("\tWill also clear out sampling values from level '" + testLevel.getName() + "'");
            clearIndexes.add(l);
        }
        if (clearIndexes.isEmpty()) {
            return null;
        }

        // group original branch indexes by their collapsed branch, keeping first-seen order
        Map<LogicTreeBranch<?>, List<Integer>> otherBranchIndexes = new HashMap<LogicTreeBranch<?>, List<Integer>>();
        List<LogicTreeBranch<?>> uniqueOtherBranches = new ArrayList<LogicTreeBranch<?>>();
        for (int b = 0; b < this.tree.size(); ++b) {
            LogicTreeBranch<?> branch = this.tree.getBranch(b).copy();
            for (int clearIndex : clearIndexes) {
                branch.clearValue(clearIndex);
            }
            List<Integer> indexes = otherBranchIndexes.get(branch);
            if (indexes == null) {
                indexes = new ArrayList<Integer>();
                otherBranchIndexes.put(branch, indexes);
                uniqueOtherBranches.add(branch);
            }
            indexes.add(b);
        }
        System.out.println("\treduced from " + this.tree.size() + " to " + uniqueOtherBranches.size() + " branches");
        if (this.tree.size() == uniqueOtherBranches.size() || uniqueOtherBranches.size() == 1) {
            return null;
        }

        GriddedGeoDataSet[] mapsExcludingLevel = new GriddedGeoDataSet[uniqueOtherBranches.size()];
        for (int b = 0; b < mapsExcludingLevel.length; ++b) {
            List<Integer> indexes = otherBranchIndexes.get(uniqueOtherBranches.get(b));
            double weightSum = 0.0;
            GriddedGeoDataSet map = new GriddedGeoDataSet(this.allMaps[0].getRegion());
            for (int index : indexes) {
                double subWeight = this.tree.getBranchWeight(index);
                weightSum += subWeight;
                GriddedGeoDataSet subMap = this.allMaps[index];
                for (int i = 0; i < map.size(); ++i) {
                    map.add(i, subMap.get(i) * subWeight);
                }
            }
            // convert the weighted sum into a weighted average for this group
            map.scale(1.0 / weightSum);
            mapsExcludingLevel[b] = map;
            weightsExcludingLevel.add(weightSum);
        }
        System.out.println("\tCalculating variance excluding");
        LogicTreeHazardCompare.calcSD_COV(mapsExcludingLevel, weightsExcludingLevel, this.meanMap, sdExcluding, covExcluding, this.exec);
        return mapsExcludingLevel;
    }

    @Override
    public String getHeading() {
        return "Logic Tree Variance and COV Contributions";
    }

    /**
     * Builds the report lines (explanatory text followed by a table) summarizing each
     * level's contribution.
     *
     * @param results one entry per tree level, in level order; {@code null} entries are skipped
     * @return the report lines, or {@code null} if fewer than two levels have results
     * @throws IllegalStateException if {@code results} does not have one entry per level
     */
    @Override
    public List<String> buildLines(List<AbstractLTVarianceDecomposition.VarianceContributionResult> results) {
        List<String> lines = new ArrayList<String>();
        lines.add("This table summarizes how each logic tree branching level contributes to the overall variance and coefficient of variation (COV) in the model.");
        lines.add("");
        lines.add("For each level, its contribution is estimated by collapsing the logic tree across that level; i.e., by averaging values across all branches that differ only in their choice at that level.");
        if (this.uniqueSamplingLevels.size() > 1) {
            lines.add("");
            lines.add("This logic tree contains multiple random sampling levels and variance cannot be decomposed between them. They are bundled into a single line of the table.");
        }
        if (this.anyNegativeLTCOVs) {
            lines.add("");
            lines.add("In some cases, the variance (and therefore COV) may slightly increase after removing a level. This is due to the behavior of weighted averaging and is interpreted as statistical noise, indicating that the level does not significantly contribute to model variability.");
        }
        lines.add("");
        lines.add("Both spatially averaged and maximum contributions are reported for each level.");
        lines.add("");

        // spatial average / maximum of the full-model variance and COV over finite nodes
        double fullVarSum = 0.0;
        double fullCOVsum = 0.0;
        double fullCOVmax = 0.0;
        int numValid = 0;
        for (int i = 0; i < this.fullVariance.size(); ++i) {
            double var = this.fullVariance.get(i);
            double mean = this.meanMap.get(i);
            if (!Double.isFinite(var) || !Double.isFinite(mean)) {
                continue;
            }
            double cov = Math.sqrt(var) / mean;
            fullVarSum += var;
            fullCOVsum += cov;
            fullCOVmax = Math.max(fullCOVmax, cov);
            ++numValid;
        }
        double fullVar = fullVarSum / (double)numValid;
        double fullCOV = fullCOVsum / (double)numValid;

        MarkdownUtils.TableBuilder table = MarkdownUtils.tableBuilder();
        table.addLine("Branch Level", "Average Variance Contribution", "Average COV Contribution", "Maximum Variance Contribution", "Maximum COV Contribution");
        table.addLine("Full model", "100%", LogicTreeHazardCompare.threeDigits.format(fullCOV), "100%", LogicTreeHazardCompare.threeDigits.format(fullCOVmax));
        int numLevels = this.tree.getLevels().size();
        Preconditions.checkState(results.size() == numLevels);
        int numWithResults = 0;
        for (int l = 0; l < numLevels; ++l) {
            AbstractLTVarianceDecomposition.VarianceContributionResult varResult = results.get(l);
            if (varResult == null) {
                continue;
            }
            ++numWithResults;
            LogicTreeLevel<?> level = (LogicTreeLevel<?>)this.tree.getLevels().get(l);
            String levelName = level.getName();
            if (this.uniqueSamplingLevels.size() > 1 && this.uniqueSamplingLevels.get(0) == level) {
                // multiple sampling levels are reported together on the first one's row
                StringBuilder bundled = new StringBuilder();
                bundled.append(this.uniqueSamplingLevels.size()).append(" sampling levels: ");
                for (int i = 0; i < this.uniqueSamplingLevels.size(); ++i) {
                    if (i > 0) {
                        bundled.append(", ");
                    }
                    bundled.append(((LogicTreeLevel<?>)this.uniqueSamplingLevels.get(i)).getShortName());
                }
                levelName = bundled.toString();
            }
            double fractVar = varResult.meanVarianceContribution / fullVar;
            double fractCOV = varResult.meanCOVContribution / fullCOV;
            String meanVarCell = LogicTreeHazardCompare.pDF.format(fractVar);
            String meanCOVCell = LogicTreeHazardCompare.threeDigits.format(varResult.meanCOVContribution) + " (" + LogicTreeHazardCompare.pDF.format(fractCOV) + ")";
            String maxVarCell = LogicTreeHazardCompare.pDF.format(varResult.maxFractionalVarianceContribution);
            String maxCOVCell = LogicTreeHazardCompare.threeDigits.format(varResult.maxCOVContribution) + " (" + LogicTreeHazardCompare.pDF.format(varResult.maxFractionalCOVContribution) + ")";
            table.addLine(levelName, meanVarCell, meanCOVCell, maxVarCell, maxCOVCell);
        }
        if (numWithResults < 2) {
            // a decomposition table needs at least two levels to compare
            return null;
        }
        lines.addAll(table.build());
        lines.add("");
        return lines;
    }
}

