/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.util;

import com.google.common.base.Preconditions;
import com.google.common.math.DoubleMath;
import java.io.File;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.opensha.commons.data.Named;
import org.opensha.commons.data.function.ArbitrarilyDiscretizedFunc;
import org.opensha.commons.data.function.DiscretizedFunc;
import org.opensha.commons.logicTree.BranchWeightProvider;
import org.opensha.commons.logicTree.LogicTree;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.logicTree.LogicTreeLevel;
import org.opensha.commons.logicTree.LogicTreeNode;
import org.opensha.commons.util.DataUtils;
import org.opensha.commons.util.ExceptionUtils;
import org.opensha.commons.util.FaultUtils;
import org.opensha.commons.util.modules.AverageableModule;
import org.opensha.commons.util.modules.ModuleContainer;
import org.opensha.commons.util.modules.OpenSHA_Module;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemSolution;
import org.opensha.sha.earthquake.faultSysSolution.modules.AveSlipModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchAverageableModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchAveragingOrder;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchModuleBuilder;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchParentSectParticMFDs;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchRegionalMFDs;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchSectBVals;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchSectNuclMFDs;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchSectParticMFDs;
import org.opensha.sha.earthquake.faultSysSolution.modules.InfoModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.LogicTreeRateStatistics;
import org.opensha.sha.earthquake.faultSysSolution.modules.ModSectMinMags;
import org.opensha.sha.earthquake.faultSysSolution.modules.RupMFDsModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.SlipAlongRuptureModel;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionLogicTree;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionSlipRates;
import org.opensha.sha.earthquake.faultSysSolution.util.FaultSectionBranchAverager;
import org.opensha.sha.earthquake.faultSysSolution.util.FaultSysTools;
import org.opensha.sha.earthquake.faultSysSolution.util.SlipAlongRuptureModelBranchNode;
import org.opensha.sha.faultSurface.FaultSection;
import scratch.UCERF3.enumTreeBranches.DeformationModels;
import scratch.UCERF3.enumTreeBranches.ScalingRelationships;
import scratch.UCERF3.enumTreeBranches.SlipAlongRuptureModels;

public class BranchAverageSolutionCreator {
    private double totWeight = 0.0;
    private double[] avgRates = null;
    private double[] avgMags = null;
    private double[] avgAreas = null;
    private double[] avgLengths = null;
    private List<FaultUtils.AngleAverager> avgRakes = null;
    private List<List<Integer>> sectIndices = null;
    private List<DiscretizedFunc> rupMFDs = null;
    private FaultSystemRupSet refRupSet = null;
    private FaultSectionBranchAverager sectAverager = null;
    private LogicTreeBranch<LogicTreeNode> combBranch = null;
    private List<Double> weights = new ArrayList<Double>();
    private Map<LogicTreeNode, Integer> nodeCounts = new HashMap<LogicTreeNode, Integer>();
    private Map<LogicTreeNode, Double> nodeWeights = new HashMap<LogicTreeNode, Double>();
    private boolean skipRupturesBelowSectMin = true;
    private boolean accumulatingSlipRates = true;
    private List<TypedAccumulator<?>> rupSetAvgAccumulators;
    private List<TypedAccumulator<?>> solAvgAccumulators;
    private List<Class<? extends OpenSHA_Module>> skipModules = new ArrayList<Class<? extends OpenSHA_Module>>();
    private BranchWeightProvider weightProv;
    private LogicTreeRateStatistics.Builder rateStatsBuilder;
    private List<BranchModuleBuilder<FaultSystemSolution, ?>> solBranchModuleBuilders;
    private ExecutorService exec;

    public BranchAverageSolutionCreator(BranchWeightProvider weightProv) {
        this.weightProv = weightProv;
    }

    public void setSkipRupturesBelowSectMin(boolean skipRupturesBelowSectMin) {
        this.skipRupturesBelowSectMin = skipRupturesBelowSectMin;
    }

    public void skipModule(Class<? extends BranchAverageableModule<?>> clazz) {
        this.skipModules.add(clazz);
        if (SolutionSlipRates.class.isAssignableFrom(clazz)) {
            this.accumulatingSlipRates = false;
        }
    }

    public synchronized void addSolution(FaultSystemSolution sol, LogicTreeBranch<?> branch) {
        double weight = this.weightProv.getWeight(branch);
        Preconditions.checkState((weight > 0.0 ? 1 : 0) != 0, (String)"Can't average in branch with weight=%s: %s", (Object)weight, branch);
        this.weights.add(weight);
        this.totWeight += weight;
        FaultSystemRupSet rupSet = sol.getRupSet();
        ModSectMinMags modMinMags = rupSet.getModule(ModSectMinMags.class);
        if (this.accumulatingSlipRates && !sol.hasModule(SolutionSlipRates.class)) {
            if (rupSet.hasModule(AveSlipModule.class) && rupSet.hasModule(SlipAlongRuptureModel.class)) {
                sol.addModule(SolutionSlipRates.calc(sol, rupSet.requireModule(AveSlipModule.class), rupSet.requireModule(SlipAlongRuptureModel.class)));
            } else {
                this.accumulatingSlipRates = false;
            }
        }
        if (this.avgRates == null) {
            int r;
            this.avgRates = new double[rupSet.getNumRuptures()];
            this.avgMags = new double[this.avgRates.length];
            this.avgAreas = new double[this.avgRates.length];
            this.avgLengths = new double[this.avgRates.length];
            this.avgRakes = new ArrayList<FaultUtils.AngleAverager>();
            for (r = 0; r < rupSet.getNumRuptures(); ++r) {
                this.avgRakes.add(new FaultUtils.AngleAverager());
            }
            this.refRupSet = rupSet;
            this.sectAverager = new FaultSectionBranchAverager(rupSet.getFaultSectionDataList());
            this.rupSetAvgAccumulators = this.initAccumulators(rupSet);
            this.solAvgAccumulators = this.initAccumulators(sol);
            this.combBranch = branch.copy();
            this.sectIndices = rupSet.getSectionIndicesForAllRups();
            this.rupMFDs = new ArrayList<DiscretizedFunc>();
            for (r = 0; r < this.avgRates.length; ++r) {
                this.rupMFDs.add(new ArbitrarilyDiscretizedFunc());
            }
            this.rateStatsBuilder = new LogicTreeRateStatistics.Builder();
            this.solBranchModuleBuilders = new ArrayList();
            this.solBranchModuleBuilders.add(new BranchAveragingOrder.Builder());
            this.solBranchModuleBuilders.add(new BranchRegionalMFDs.Builder());
            this.solBranchModuleBuilders.add(new BranchSectNuclMFDs.Builder());
            this.solBranchModuleBuilders.add(new BranchSectParticMFDs.Builder());
            this.solBranchModuleBuilders.add(new BranchParentSectParticMFDs.Builder());
            this.solBranchModuleBuilders.add(new BranchSectBVals.Builder());
        } else {
            Preconditions.checkState((boolean)this.refRupSet.isEquivalentTo(rupSet), (Object)"Rupture sets are not equivalent");
        }
        this.rateStatsBuilder.process(branch, sol.getRateForAllRups());
        ArrayList futures = new ArrayList();
        if (this.exec == null) {
            this.exec = Executors.newFixedThreadPool(Integer.min(8, FaultSysTools.defaultNumThreads()));
        }
        try {
            futures.addAll(this.processBuilders(this.solBranchModuleBuilders, sol, branch, weight));
            futures.addAll(this.processAccumulators(this.rupSetAvgAccumulators, rupSet, branch, weight));
            futures.addAll(this.processAccumulators(this.solAvgAccumulators, sol, branch, weight));
            for (int i = 0; i < this.combBranch.size(); ++i) {
                LogicTreeNode logicTreeNode = this.combBranch.getValue(i);
                Object branchVal = branch.getValue(i);
                if (logicTreeNode != null && !logicTreeNode.equals(branchVal)) {
                    this.combBranch.clearValue(i);
                }
                int prevCount = this.nodeCounts.containsKey(branchVal) ? this.nodeCounts.get(branchVal) : 0;
                this.nodeCounts.put((LogicTreeNode)branchVal, prevCount + 1);
                double prevWeight = this.nodeWeights.containsKey(branchVal) ? this.nodeWeights.get(branchVal) : 0.0;
                this.nodeWeights.put((LogicTreeNode)branchVal, prevWeight + weight);
            }
            for (int r = 0; r < this.avgRates.length; ++r) {
                this.avgRakes.get(r).add(rupSet.getAveRakeForRup(r), weight);
                double d = sol.getRateForRup(r);
                Preconditions.checkState((d >= 0.0 ? 1 : 0) != 0, (String)"bad rate: %s", (Object)d);
                if (d == 0.0) continue;
                double mag = rupSet.getMagForRup(r);
                if (this.skipRupturesBelowSectMin && modMinMags != null && modMinMags.isRupBelowSectMinMag(r)) continue;
                int n = r;
                this.avgRates[n] = this.avgRates[n] + d * weight;
                DiscretizedFunc rupMFD = this.rupMFDs.get(r);
                double y = d * weight;
                if (rupMFD.hasX(mag)) {
                    y += rupMFD.getY(mag);
                }
                rupMFD.set(mag, y);
            }
            BranchAverageSolutionCreator.addWeighted(this.avgMags, rupSet.getMagForAllRups(), weight);
            BranchAverageSolutionCreator.addWeighted(this.avgAreas, rupSet.getAreaForAllRups(), weight);
            BranchAverageSolutionCreator.addWeighted(this.avgLengths, rupSet.getLengthForAllRups(), weight);
            this.sectAverager.addWeighted(rupSet.getFaultSectionDataList(), weight);
            for (Future future : futures) {
                future.get();
            }
        }
        catch (InterruptedException | RuntimeException | ExecutionException e) {
            if (this.exec != null) {
                this.exec.shutdown();
                this.exec = null;
            }
            throw ExceptionUtils.asRuntimeException(e);
        }
    }

    private <T extends BranchAverageableModule<T>> TypedAccumulator<T> getTypedAccumulator(BranchAverageableModule<T> module) {
        AverageableModule.AveragingAccumulator accumulator = module.averagingAccumulator();
        if (accumulator == null) {
            return null;
        }
        return new TypedAccumulator(accumulator);
    }

    private List<TypedAccumulator<?>> initAccumulators(ModuleContainer<OpenSHA_Module> container) {
        ArrayList accumulators = new ArrayList();
        for (OpenSHA_Module module : container.getModulesAssignableTo(BranchAverageableModule.class, true, this.skipModules)) {
            if (this.skipRupturesBelowSectMin && module instanceof ModSectMinMags) {
                System.out.println("Won't accumulate branch-averaged ModSectMinMags, we're skipping ruptures below those magnitudes in the rate calculations");
                continue;
            }
            Preconditions.checkState((boolean)(module instanceof BranchAverageableModule));
            System.out.println("Building branch-averaging accumulator for: " + module.getName());
            TypedAccumulator accumulator = this.getTypedAccumulator((BranchAverageableModule)module);
            if (accumulator == null) {
                System.err.println("WARNING: accumulator is null for module " + module.getName() + ", skipping averaging");
                continue;
            }
            accumulators.add(accumulator);
        }
        return accumulators;
    }

    private List<Future<?>> processAccumulators(List<TypedAccumulator<?>> accumulators, ModuleContainer<OpenSHA_Module> container, LogicTreeBranch<?> branch, double weight) {
        ArrayList runs = new ArrayList(accumulators.size());
        for (TypedAccumulator<?> accumulator : new ArrayList(accumulators)) {
            Runnable run = this.initAccumulatorRunnable(accumulators, accumulator, container, branch, weight);
            if (run == null) continue;
            runs.add((AccumulateRunnable<?>)run);
        }
        if (runs.isEmpty()) {
            return List.of();
        }
        ArrayList futures = new ArrayList(runs.size());
        for (Runnable run : runs) {
            futures.add(this.exec.submit(run));
        }
        return futures;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private <E extends BranchAverageableModule<E>> AccumulateRunnable<E> initAccumulatorRunnable(List<TypedAccumulator<?>> accumulators, TypedAccumulator<E> accumulator, ModuleContainer<OpenSHA_Module> container, LogicTreeBranch<?> branch, double weight) {
        BranchAverageableModule module = (BranchAverageableModule)container.getModule(accumulator.accumulator.getType());
        if (module == null) {
            List<TypedAccumulator<?>> list = accumulators;
            synchronized (list) {
                this.stopTrackingAccumulator(accumulators, accumulator, branch, "Module not loaded (null)");
            }
            return null;
        }
        return new AccumulateRunnable(this, accumulators, accumulator, module, branch, weight);
    }

    private void stopTrackingAccumulator(List<TypedAccumulator<?>> accumulators, TypedAccumulator<?> accumulator, LogicTreeBranch<?> branch, String error) {
        System.err.println("Error processing accumulator, will no longer average " + accumulator.accumulator.getType().getName() + ".\n\tFailed on branch: " + String.valueOf(branch) + "\n\tError message: " + error);
        System.err.flush();
        accumulators.remove(accumulator);
        if (this.accumulatingSlipRates && SolutionSlipRates.class.isAssignableFrom(accumulator.accumulator.getType())) {
            this.accumulatingSlipRates = false;
        }
    }

    private static void buildAverageModules(List<TypedAccumulator<?>> accumulators, ModuleContainer<OpenSHA_Module> container) {
        for (TypedAccumulator<?> accumulator : accumulators) {
            try {
                container.addModule((OpenSHA_Module)accumulator.accumulator.getAverage());
            }
            catch (Exception e) {
                e.printStackTrace();
                System.err.println("Error building average module of type " + accumulator.accumulator.getType().getName());
                System.err.flush();
            }
        }
    }

    private <E extends ModuleContainer<OpenSHA_Module>> List<Future<?>> processBuilders(List<BranchModuleBuilder<E, ?>> builders, E source, LogicTreeBranch<?> branch, double weight) {
        ArrayList<BuilderRunnable> runs = new ArrayList<BuilderRunnable>(builders.size());
        for (BranchModuleBuilder<E, ?> builder : builders) {
            runs.add(new BuilderRunnable(this, builders, builder, source, branch, weight));
        }
        ArrayList futures = new ArrayList(runs.size());
        for (Runnable runnable : runs) {
            futures.add(this.exec.submit(runnable));
        }
        return futures;
    }

    private static <E extends ModuleContainer<OpenSHA_Module>> void buildBranchModules(List<BranchModuleBuilder<E, ?>> builders, E container) {
        for (BranchModuleBuilder<E, ?> builder : builders) {
            try {
                ((ModuleContainer)container).addModule(builder.build());
            }
            catch (Exception e) {
                e.printStackTrace();
                System.err.println("Error building branch module of type " + builder.getClass().getName());
                System.err.flush();
            }
        }
    }

    public synchronized FaultSystemSolution build() {
        Preconditions.checkState((!this.weights.isEmpty() ? 1 : 0) != 0, (Object)"No solutions added!");
        Preconditions.checkState((this.totWeight > 0.0 ? 1 : 0) != 0, (String)"Total weight is not positive: %s", (Object)this.totWeight);
        if (this.exec != null) {
            this.exec.shutdown();
            this.exec = null;
        }
        System.out.println("Common branches: " + String.valueOf(this.combBranch));
        System.out.println("Normalizing by total weight: " + this.totWeight);
        double[] rakes = new double[this.avgRates.length];
        for (int r = 0; r < this.avgRates.length; ++r) {
            int n = r;
            this.avgRates[n] = this.avgRates[n] / this.totWeight;
            int n2 = r;
            this.avgMags[n2] = this.avgMags[n2] / this.totWeight;
            int n3 = r;
            this.avgAreas[n3] = this.avgAreas[n3] / this.totWeight;
            int n4 = r;
            this.avgLengths[n4] = this.avgLengths[n4] / this.totWeight;
            if (this.avgRates[r] == 0.0) {
                int size = this.rupMFDs.get(r).size();
                Preconditions.checkState((size == 0 ? 1 : 0) != 0, (String)"rate=0 but mfd has %s values?", (int)size);
                this.rupMFDs.set(r, null);
            } else {
                DiscretizedFunc rupMFD = this.rupMFDs.get(r);
                rupMFD.scale(1.0 / this.totWeight);
                double calcRate = rupMFD.calcSumOfY_Vals();
                Preconditions.checkState((DoubleMath.fuzzyEquals((double)calcRate, (double)this.avgRates[r], (double)1.0E-12) || DataUtils.getPercentDiff(calcRate, this.avgRates[r]) < 1.0E-5 ? 1 : 0) != 0, (String)"Rupture MFD rate=%s, avgRate=%s", (Object)calcRate, (Object)this.avgRates[r]);
            }
            rakes[r] = FaultUtils.getInRakeRange(this.avgRakes.get(r).getAverage());
        }
        List<FaultSection> subSects = this.sectAverager.buildAverageSects();
        FaultSystemRupSet avgRupSet = FaultSystemRupSet.builder(subSects, this.sectIndices).rupRakes(rakes).rupAreas(this.avgAreas).rupLengths(this.avgLengths).rupMags(this.avgMags).build();
        BranchAverageSolutionCreator.buildAverageModules(this.rupSetAvgAccumulators, avgRupSet);
        int numNonNull = 0;
        boolean haveSlipAlong = false;
        for (int i = 0; i < this.combBranch.size(); ++i) {
            LogicTreeNode value = this.combBranch.getValue(i);
            if (value == null) continue;
            ++numNonNull;
            if (value instanceof SlipAlongRuptureModel) {
                avgRupSet.addModule((SlipAlongRuptureModel)((Object)value));
                haveSlipAlong = true;
                continue;
            }
            if (!(value instanceof SlipAlongRuptureModelBranchNode)) continue;
            avgRupSet.addModule(((SlipAlongRuptureModelBranchNode)value).getModel());
            haveSlipAlong = true;
        }
        if (!haveSlipAlong && this.hasAllEqually(this.nodeCounts, SlipAlongRuptureModels.UNIFORM, SlipAlongRuptureModels.TAPERED)) {
            this.combBranch.setValue(SlipAlongRuptureModels.MEAN_UCERF3);
            avgRupSet.addModule(SlipAlongRuptureModels.MEAN_UCERF3.getModel());
            ++numNonNull;
        }
        if (this.combBranch.getValue(ScalingRelationships.class) == null && this.hasAllEqually(this.nodeCounts, ScalingRelationships.ELLB_SQRT_LENGTH, ScalingRelationships.ELLSWORTH_B, ScalingRelationships.HANKS_BAKUN_08, ScalingRelationships.SHAW_2009_MOD, ScalingRelationships.SHAW_CONST_STRESS_DROP)) {
            this.combBranch.setValue(ScalingRelationships.MEAN_UCERF3);
            if (!avgRupSet.hasModule(AveSlipModule.class)) {
                avgRupSet.addModule(AveSlipModule.forModel(avgRupSet, ScalingRelationships.MEAN_UCERF3));
            }
            ++numNonNull;
        }
        if (this.combBranch.getValue(DeformationModels.class) == null && this.hasAllEqually(this.nodeCounts, DeformationModels.GEOLOGIC, DeformationModels.ABM, DeformationModels.NEOKINEMA, DeformationModels.ZENGBB)) {
            this.combBranch.setValue(DeformationModels.MEAN_UCERF3);
            ++numNonNull;
        }
        if (numNonNull > 0) {
            avgRupSet.addModule(this.combBranch);
            System.out.println("Combined logic tree branch: " + String.valueOf(this.combBranch));
        }
        FaultSystemSolution sol = new FaultSystemSolution(avgRupSet, this.avgRates);
        BranchAverageSolutionCreator.buildAverageModules(this.solAvgAccumulators, sol);
        sol.addModule(this.combBranch);
        sol.addModule(new RupMFDsModule(sol, this.rupMFDs.toArray(new DiscretizedFunc[0])));
        sol.addModule(this.rateStatsBuilder.build());
        BranchAverageSolutionCreator.buildBranchModules(this.solBranchModuleBuilders, sol);
        DecimalFormat weightDF = new DecimalFormat("0.###%");
        String info = "Branch Averaged Fault System Solution across " + this.weights.size() + " branches.\n\nThe utilized branches at each level are (counts and total relative weights in parenthesis):\n\n";
        for (int i = 0; i < this.combBranch.size(); ++i) {
            LogicTreeLevel<LogicTreeNode> level = this.combBranch.getLevel(i);
            info = info + level.getName() + ":\n";
            int numIncluded = 0;
            int numSkipped = 0;
            int lastSkippedCount = -1;
            int totalSkippedCount = 0;
            Named lastSkipped = null;
            for (LogicTreeNode choice : level.getNodes()) {
                Integer count = this.nodeCounts.get(choice);
                if (count == null) continue;
                if (numIncluded < 15) {
                    double weight = this.nodeWeights.get(choice);
                    info = info + "\t" + choice.getName() + " (" + count + "; " + weightDF.format(weight / this.totWeight) + ")\n";
                    ++numIncluded;
                    continue;
                }
                lastSkipped = choice;
                lastSkippedCount = count;
                totalSkippedCount += count.intValue();
                ++numSkipped;
            }
            if (lastSkipped == null) continue;
            if (numSkipped > 1) {
                info = info + "\t...(skipping " + (numSkipped - 1) + " branches used " + (totalSkippedCount - lastSkippedCount) + " times)...\n";
            }
            double weight = this.nodeWeights.get(lastSkipped);
            info = info + "\t" + lastSkipped.getName() + " (" + lastSkippedCount + "; " + weightDF.format(weight / this.totWeight) + ")\n";
        }
        sol.addModule(new InfoModule(info));
        this.avgRates = null;
        this.weights = new ArrayList<Double>();
        this.totWeight = 0.0;
        return sol;
    }

    private boolean hasAllEqually(Map<LogicTreeNode, Integer> nodeCounts, LogicTreeNode ... nodes) {
        Integer commonCount = null;
        for (LogicTreeNode node : nodes) {
            Integer count = nodeCounts.get(node);
            if (count == null) {
                return false;
            }
            if (commonCount == null) {
                commonCount = count;
                continue;
            }
            if (commonCount.intValue() == count.intValue()) continue;
            return false;
        }
        return true;
    }

    private static void addWeighted(double[] running, double[] vals, double weight) {
        Preconditions.checkState((running.length == vals.length ? 1 : 0) != 0);
        for (int i = 0; i < running.length; ++i) {
            int n = i;
            running[n] = running[n] + vals[i] * weight;
        }
    }

    private static Options createOptions() {
        Options ops = new Options();
        ops.addRequiredOption("if", "input-file", true, "Input solution logic tree zip file.");
        ops.addRequiredOption("of", "output-file", true, "Output branch averaged solution file.");
        ops.addOption("rt", "restrict-tree", true, "Restrict the logic tree to the given value. Specify values by either their short name or file prefix. If such a name is ambiguous (applies to multiple branch levels), excplicitly set the level as <level-short-name>=<value>. Repeat this argument to specify multiple values (within a single level and/or across multiple levels).");
        ops.addOption("rw", "reweight", false, "Flag to use current branch weights rather than those when the simulation was originally run");
        return ops;
    }

    public static void main(String[] args) throws IOException {
        CommandLine cmd = FaultSysTools.parseOptions(BranchAverageSolutionCreator.createOptions(), args, BranchAverageSolutionCreator.class);
        File inputFile = new File(cmd.getOptionValue("input-file"));
        File outputFile = new File(cmd.getOptionValue("output-file"));
        SolutionLogicTree slt = SolutionLogicTree.load(inputFile);
        LogicTree<?> tree = slt.getLogicTree();
        BranchWeightProvider weightProv = cmd.hasOption("reweight") ? new BranchWeightProvider.CurrentWeights() : tree.getWeightProvider();
        BranchAverageSolutionCreator ba = new BranchAverageSolutionCreator(weightProv);
        ArrayList restrictTos = null;
        String[] restrictNames = cmd.getOptionValues("restrict-tree");
        if (restrictNames != null && restrictNames.length > 0) {
            restrictTos = new ArrayList();
            for (int i = 0; i < tree.getLevels().size(); ++i) {
                restrictTos.add(new ArrayList());
            }
            String[] i = restrictNames;
            int n = i.length;
            for (int j = 0; j < n; ++j) {
                String op;
                String valName = op = i[j];
                String levelName = null;
                if (op.contains("=")) {
                    levelName = op.substring(0, op.indexOf(61)).trim();
                    valName = op.substring(op.indexOf(61) + 1).trim();
                    System.out.println("Looking for logic tree level with name '" + levelName + "', value of '" + valName + "'");
                } else {
                    System.out.println("Looking for logic tree value of '" + valName + "'");
                }
                LogicTreeNode match = null;
                for (int i2 = 0; i2 < tree.getLevels().size(); ++i2) {
                    LogicTreeLevel level = (LogicTreeLevel)tree.getLevels().get(i2);
                    if (levelName != null && !level.getShortName().equals(levelName)) continue;
                    for (LogicTreeNode value : level.getNodes()) {
                        if (!value.getShortName().equals(valName) && !value.getFilePrefix().equals(valName)) continue;
                        if (match == null) {
                            System.out.println("Found matching node: " + String.valueOf(value));
                            match = value;
                            ((List)restrictTos.get(i2)).add(value);
                            continue;
                        }
                        throw new IllegalStateException("Logic tree value '" + valName + "' is ambiguous (matches multiple values across multiple levels). Specify the appropriate level as --restrict-tree <level-short-name>=<value>.");
                    }
                }
                Preconditions.checkNotNull(match, (String)"Didn't find matching logic tree value='%s' (level='%s')", (Object)valName, (Object)levelName);
            }
        }
        for (LogicTreeBranch<?> branch : slt.getLogicTree()) {
            if (restrictTos != null) {
                Preconditions.checkState((branch.size() == tree.getLevels().size() ? 1 : 0) != 0);
                boolean hasAll = true;
                for (int i = 0; i < branch.size(); ++i) {
                    List restrictTo = (List)restrictTos.get(i);
                    if (restrictTo.isEmpty()) continue;
                    boolean match = false;
                    for (LogicTreeNode restrict : restrictTo) {
                        if (!branch.hasValue(restrict)) continue;
                        match = true;
                    }
                    if (match) continue;
                    hasAll = false;
                    break;
                }
                if (!hasAll) {
                    System.out.println("Skipping branch: " + String.valueOf(branch));
                    continue;
                }
            }
            FaultSystemSolution sol = slt.forBranch(branch);
            ba.addSolution(sol, branch);
        }
        System.out.println("Building final branch-averaged solution.");
        FaultSystemSolution baSol = ba.build();
        baSol.write(outputFile);
    }

    class TypedAccumulator<T extends BranchAverageableModule<T>> {
        private final AverageableModule.AveragingAccumulator<T> accumulator;

        public TypedAccumulator(AverageableModule.AveragingAccumulator<T> accumulator) {
            this.accumulator = accumulator;
        }

        public AverageableModule.AveragingAccumulator<T> getAccumulator() {
            return this.accumulator;
        }
    }

    private class AccumulateRunnable<E extends BranchAverageableModule<E>>
    implements Runnable {
        private List<TypedAccumulator<?>> accumulators;
        private TypedAccumulator<E> accumulator;
        private E module;
        private LogicTreeBranch<?> branch;
        private double weight;
        final /* synthetic */ BranchAverageSolutionCreator this$0;

        /*
         * WARNING - Possible parameter corruption
         * WARNING - void declaration
         */
        public AccumulateRunnable(List<TypedAccumulator<?>> accumulator, TypedAccumulator<E> module, E branch, LogicTreeBranch<?> weight, double d2) {
            void accumulators;
            this.this$0 = (BranchAverageSolutionCreator)d;
            this.accumulators = accumulators;
            this.accumulator = accumulator;
            this.module = module;
            this.branch = branch;
            this.weight = (double)weight;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public void run() {
            try {
                this.accumulator.accumulator.process(this.module, this.weight);
            }
            catch (Exception e) {
                List<TypedAccumulator<?>> list = this.accumulators;
                synchronized (list) {
                    this.this$0.stopTrackingAccumulator(this.accumulators, this.accumulator, this.branch, e.getMessage());
                }
            }
        }
    }

    private class BuilderRunnable<E extends ModuleContainer<OpenSHA_Module>>
    implements Runnable {
        private List<BranchModuleBuilder<E, ?>> builders;
        private BranchModuleBuilder<E, ?> builder;
        private E source;
        private LogicTreeBranch<?> branch;
        private double weight;
        final /* synthetic */ BranchAverageSolutionCreator this$0;

        /*
         * WARNING - Possible parameter corruption
         * WARNING - void declaration
         */
        public BuilderRunnable(List<BranchModuleBuilder<E, ?>> builder, BranchModuleBuilder<E, ?> source, E branch, LogicTreeBranch<?> weight, double d2) {
            void builders;
            this.this$0 = (BranchAverageSolutionCreator)d;
            this.builders = builders;
            this.builder = builder;
            this.source = source;
            this.branch = branch;
            this.weight = (double)weight;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        public void run() {
            try {
                this.builder.process(this.source, this.branch, this.weight);
            }
            catch (Exception e) {
                List<BranchModuleBuilder<E, ?>> list = this.builders;
                synchronized (list) {
                    System.err.println("Error processing branch module builder, will no longer average " + this.builder.getClass().getName() + "\n\tError message: " + e.getMessage());
                    System.err.flush();
                    this.builders.remove(this.builder);
                }
            }
        }
    }
}

