/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.modules;

import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.opensha.commons.data.CSVFile;
import org.opensha.commons.data.CSVWriter;
import org.opensha.commons.data.function.ArbDiscrEmpiricalDistFunc;
import org.opensha.commons.logicTree.LogicTree;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.util.ExceptionUtils;
import org.opensha.commons.util.io.archive.ArchiveInput;
import org.opensha.commons.util.io.archive.ArchiveOutput;
import org.opensha.commons.util.modules.ArchivableModule;
import org.opensha.commons.util.modules.helpers.CSV_BackedModule;
import org.opensha.commons.util.modules.helpers.FileBackedModule;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemSolution;
import org.opensha.sha.earthquake.faultSysSolution.modules.BranchModuleBuilder;
import org.opensha.sha.earthquake.faultSysSolution.modules.InversionTargetMFDs;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionLogicTree;
import org.opensha.sha.earthquake.faultSysSolution.reports.plots.SectBValuePlot;
import org.opensha.sha.faultSurface.FaultSection;
import org.opensha.sha.magdist.IncrementalMagFreqDist;

public class BranchSectBVals
implements ArchivableModule {
    private double[] weights;
    private int[] parentIDs;
    private boolean parentsSorted = false;
    private float[][] sectBVals;
    private float[][] parentBVals;
    private float[][] sectTargetBVals;
    private float[][] parentTargetBVals;
    private static final String SECT_FILE_NAME = "branch_sect_b_vals.csv";
    private static final String PARENT_FILE_NAME = "branch_parent_b_vals.csv";
    private static final String SECT_TARGET_FILE_NAME = "branch_sect_target_b_vals.csv";
    private static final String PARENT_TARGET_FILE_NAME = "branch_parent_target_b_vals.csv";
    private static final DecimalFormat bDF = new DecimalFormat("0.##");

    @Override
    public String getName() {
        return "Branch Section b-values";
    }

    @Override
    public void writeToArchive(ArchiveOutput output, String entryPrefix) throws IOException {
        this.writeBValCSV(FileBackedModule.initOutputStream(output, entryPrefix, SECT_FILE_NAME), this.sectBVals, false);
        output.closeEntry();
        if (this.parentBVals != null) {
            this.writeBValCSV(FileBackedModule.initOutputStream(output, entryPrefix, PARENT_FILE_NAME), this.parentBVals, true);
            output.closeEntry();
        }
        if (this.sectTargetBVals != null) {
            this.writeBValCSV(FileBackedModule.initOutputStream(output, entryPrefix, SECT_TARGET_FILE_NAME), this.sectTargetBVals, false);
            output.closeEntry();
        }
        if (this.parentTargetBVals != null) {
            this.writeBValCSV(FileBackedModule.initOutputStream(output, entryPrefix, PARENT_TARGET_FILE_NAME), this.parentTargetBVals, true);
            output.closeEntry();
        }
    }

    private void writeBValCSV(OutputStream out, float[][] sectBVals, boolean parents) throws IOException {
        CSVWriter csv = new CSVWriter(out, true);
        ArrayList<String> header = new ArrayList<String>(sectBVals.length + 2);
        header.add("Branch Index");
        header.add("Branch Weight");
        if (parents) {
            for (int id : this.parentIDs) {
                header.add("" + id);
            }
        } else {
            for (int s = 0; s < sectBVals[0].length; ++s) {
                header.add("" + s);
            }
        }
        csv.write(header);
        Preconditions.checkState((sectBVals.length == this.weights.length ? 1 : 0) != 0);
        for (int b = 0; b < this.weights.length; ++b) {
            ArrayList<String> line = new ArrayList<String>(header.size());
            line.add("" + b);
            line.add("" + this.weights[b]);
            float[] fArray = sectBVals[b];
            int n = fArray.length;
            for (int i = 0; i < n; ++i) {
                double bVal = fArray[i];
                line.add(bDF.format(bVal));
            }
            csv.write(line);
        }
        csv.flush();
    }

    @Override
    public void initFromArchive(ArchiveInput input, String entryPrefix) throws IOException {
        this.sectBVals = this.loadBValCSV(CSV_BackedModule.loadFromArchive(input, entryPrefix, SECT_FILE_NAME), false);
        if (FileBackedModule.hasEntry(input, entryPrefix, PARENT_FILE_NAME)) {
            this.parentBVals = this.loadBValCSV(CSV_BackedModule.loadFromArchive(input, entryPrefix, PARENT_FILE_NAME), true);
        }
        if (FileBackedModule.hasEntry(input, entryPrefix, SECT_TARGET_FILE_NAME)) {
            this.sectTargetBVals = this.loadBValCSV(CSV_BackedModule.loadFromArchive(input, entryPrefix, SECT_TARGET_FILE_NAME), false);
        }
        if (FileBackedModule.hasEntry(input, entryPrefix, PARENT_TARGET_FILE_NAME)) {
            this.parentTargetBVals = this.loadBValCSV(CSV_BackedModule.loadFromArchive(input, entryPrefix, PARENT_TARGET_FILE_NAME), true);
        }
    }

    private float[][] loadBValCSV(CSVFile<String> csv, boolean parents) {
        int numSects = csv.getLine(0).size() - 2;
        if (parents) {
            int i;
            int[] parentIDs = new int[numSects];
            for (i = 0; i < parentIDs.length; ++i) {
                parentIDs[i] = csv.getInt(0, i + 2);
            }
            if (this.parentIDs == null) {
                this.parentIDs = parentIDs;
                this.parentsSorted = true;
                for (i = 1; this.parentsSorted && i < parentIDs.length; ++i) {
                    this.parentsSorted = parentIDs[i] > parentIDs[i - 1];
                }
            } else {
                Preconditions.checkState((parentIDs.length == this.parentIDs.length ? 1 : 0) != 0);
                for (i = 0; i < parentIDs.length; ++i) {
                    Preconditions.checkState((parentIDs[i] == this.parentIDs[i] ? 1 : 0) != 0);
                }
            }
        }
        float[][] ret = new float[csv.getNumRows() - 1][numSects];
        double[] weights = new double[ret.length];
        for (int b = 0; b < ret.length; ++b) {
            int row = b + 1;
            weights[b] = csv.getDouble(row, 1);
            for (int s = 0; s < numSects; ++s) {
                ret[b][s] = csv.getFloat(row, s + 2);
            }
        }
        if (this.weights == null) {
            this.weights = weights;
        } else {
            Preconditions.checkState((weights.length == this.weights.length ? 1 : 0) != 0);
            for (int i = 0; i < weights.length; ++i) {
                Preconditions.checkState(((float)weights[i] == (float)this.weights[i] ? 1 : 0) != 0);
            }
        }
        return ret;
    }

    public int getNumBranches() {
        return this.weights.length;
    }

    public boolean hasParentBVals() {
        return this.parentBVals != null;
    }

    public boolean hasTargetBVals() {
        return this.sectTargetBVals != null;
    }

    public ArbDiscrEmpiricalDistFunc getSectBValDist(int sectIndex) {
        return this.buildBValDist(this.sectBVals, sectIndex);
    }

    public ArbDiscrEmpiricalDistFunc getParentBValDist(int parentID) {
        return this.buildBValDist(this.parentBVals, this.parentIndex(parentID));
    }

    public ArbDiscrEmpiricalDistFunc getSectTargetBValDist(int sectIndex) {
        return this.buildBValDist(this.sectTargetBVals, sectIndex);
    }

    public ArbDiscrEmpiricalDistFunc getParentTargetBValDist(int parentID) {
        return this.buildBValDist(this.parentTargetBVals, this.parentIndex(parentID));
    }

    private int parentIndex(int parentID) {
        if (this.parentsSorted) {
            int index = Arrays.binarySearch(this.parentIDs, parentID);
            Preconditions.checkState((index >= 0 ? 1 : 0) != 0, (String)"Parent not found: %s", (int)parentID);
            return index;
        }
        for (int i = 0; i < this.parentIDs.length; ++i) {
            if (this.parentIDs[i] != parentID) continue;
            return i;
        }
        throw new IllegalStateException("Parent not found: " + parentID);
    }

    private ArbDiscrEmpiricalDistFunc buildBValDist(float[][] bVals, int sectIndex) {
        ArbDiscrEmpiricalDistFunc dist = new ArbDiscrEmpiricalDistFunc();
        for (int i = 0; i < bVals.length; ++i) {
            dist.set(bVals[i][sectIndex], this.weights[i]);
        }
        return dist;
    }

    public double getSectMeanBVal(int sectIndex) {
        return this.calcMeanBVal(this.sectBVals, sectIndex);
    }

    public double getParentMeanBVal(int parentID) {
        return this.calcMeanBVal(this.parentBVals, this.parentIndex(parentID));
    }

    public double getSectTargetMeanBVal(int sectIndex) {
        return this.calcMeanBVal(this.sectTargetBVals, sectIndex);
    }

    public double getParentTargetMeanBVal(int parentID) {
        return this.calcMeanBVal(this.parentTargetBVals, this.parentIndex(parentID));
    }

    private double calcMeanBVal(float[][] bVals, int sectIndex) {
        double ret = 0.0;
        double sumWeight = 0.0;
        for (int i = 0; i < bVals.length; ++i) {
            ret += (double)bVals[i][sectIndex] * this.weights[i];
            sumWeight += this.weights[i];
        }
        return ret / sumWeight;
    }

    public static void main(String[] args) throws IOException {
        File dir = new File("/home/kevin/OpenSHA/UCERF4/batch_inversions/2022_09_28-nshm23_branches-NSHM23_v2-CoulombRupSet-TotNuclRate-NoRed-ThreshAvgIterRelGR");
        File resultsFile = new File(dir, "results.zip");
        SolutionLogicTree slt = SolutionLogicTree.load(resultsFile);
        final LogicTree<?> tree = slt.getLogicTree();
        File inputBA = new File(dir, "results_NSHM23_v2_CoulombRupSet_branch_averaged.zip");
        File outputFile = new File(dir, "results_NSHM23_v2_CoulombRupSet_branch_averaged_sect_b_vals.zip");
        final Builder builder = new Builder();
        int count = 0;
        CompletableFuture<Void> processingLoadedFuture = null;
        for (final LogicTreeBranch<?> branch : tree) {
            Stopwatch watch = Stopwatch.createStarted();
            System.out.println("Loading solution for branch " + count);
            final FaultSystemSolution sol = slt.forBranch(branch, false);
            if (processingLoadedFuture != null) {
                try {
                    processingLoadedFuture.get();
                }
                catch (InterruptedException | ExecutionException e) {
                    throw ExceptionUtils.asRuntimeException(e);
                }
            }
            processingLoadedFuture = CompletableFuture.runAsync(new Runnable(){

                @Override
                public void run() {
                    builder.process(sol, branch, tree.getBranchWeight(branch));
                }
            });
            watch.stop();
            double secs = (double)watch.elapsed(TimeUnit.MILLISECONDS) / 1000.0;
            System.out.println("DONE branch " + count + " in " + bDF.format(secs) + " s");
            ++count;
        }
        try {
            processingLoadedFuture.get();
        }
        catch (InterruptedException | ExecutionException e) {
            throw ExceptionUtils.asRuntimeException(e);
        }
        BranchSectBVals bVals = builder.build();
        FaultSystemSolution ba = FaultSystemSolution.load(inputBA);
        ba.addModule(bVals);
        ba.write(outputFile);
    }

    public static class Builder
    implements BranchModuleBuilder<FaultSystemSolution, BranchSectBVals> {
        private List<Double> weights;
        private int[] parentIDs;
        private List<float[]> sectBVals;
        private List<float[]> parentBVals;
        private List<float[]> sectTargetBVals;
        private List<float[]> parentTargetBVals;

        @Override
        public synchronized void process(FaultSystemSolution sol, LogicTreeBranch<?> branch, double weight) {
            int numSects = sol.getRupSet().getNumSections();
            if (this.sectBVals == null) {
                InversionTargetMFDs targetMFDs;
                this.sectBVals = new ArrayList<float[]>();
                HashSet<Integer> parentIDs = new HashSet<Integer>();
                for (FaultSection faultSection : sol.getRupSet().getFaultSectionDataList()) {
                    int parentID = faultSection.getParentSectionId();
                    if (parentID < 0) {
                        parentIDs = null;
                        break;
                    }
                    parentIDs.add(parentID);
                }
                if (parentIDs != null) {
                    this.parentBVals = new ArrayList<float[]>();
                    ArrayList sorted = new ArrayList(parentIDs);
                    Collections.sort(sorted);
                    this.parentIDs = Ints.toArray(sorted);
                }
                if ((targetMFDs = sol.getRupSet().getModule(InversionTargetMFDs.class)) != null && targetMFDs.getOnFaultSupraSeisNucleationMFDs() != null) {
                    this.sectTargetBVals = new ArrayList<float[]>();
                    if (parentIDs != null) {
                        this.parentTargetBVals = new ArrayList<float[]>();
                    }
                }
                this.weights = new ArrayList<Double>();
            } else {
                Preconditions.checkState((numSects == this.sectBVals.get(0).length ? 1 : 0) != 0);
            }
            this.weights.add(weight);
            this.sectBVals.add(this.bValArray(SectBValuePlot.estSectBValues(sol)));
            if (this.parentIDs != null) {
                this.parentBVals.add(this.bValArray(SectBValuePlot.estParentSectBValues(sol)));
            }
            if (this.sectTargetBVals != null) {
                InversionTargetMFDs targetMFDs = sol.getRupSet().getModule(InversionTargetMFDs.class);
                if (targetMFDs != null && targetMFDs.getOnFaultSupraSeisNucleationMFDs() != null) {
                    List<? extends IncrementalMagFreqDist> supraTargets = targetMFDs.getOnFaultSupraSeisNucleationMFDs();
                    this.sectTargetBVals.add(this.bValArray(SectBValuePlot.estSectTargetBValues(supraTargets)));
                    if (this.parentIDs != null) {
                        this.parentTargetBVals.add(this.bValArray(SectBValuePlot.estParentSectTargetBValues(sol, supraTargets)));
                    }
                } else {
                    this.sectTargetBVals = null;
                    this.parentTargetBVals = null;
                }
            }
        }

        private float[] bValArray(SectBValuePlot.BValEstimate[] bVals) {
            float[] ret = new float[bVals.length];
            for (int i = 0; i < ret.length; ++i) {
                ret[i] = (float)bVals[i].b;
            }
            return ret;
        }

        private float[] bValArray(Map<Integer, SectBValuePlot.BValEstimate> bVals) {
            float[] ret = new float[this.parentIDs.length];
            for (int i = 0; i < ret.length; ++i) {
                SectBValuePlot.BValEstimate bVal = bVals.get(this.parentIDs[i]);
                Preconditions.checkNotNull((Object)bVal, (String)"no b-value for section %s", (int)this.parentIDs[i]);
                ret[i] = (float)bVal.b;
            }
            return ret;
        }

        @Override
        public BranchSectBVals build() {
            BranchSectBVals ret = new BranchSectBVals();
            int numBranches = this.weights.size();
            Preconditions.checkState((numBranches > 0 ? 1 : 0) != 0);
            ret.weights = Doubles.toArray(this.weights);
            int numSects = this.sectBVals.get(0).length;
            ret.sectBVals = new float[numBranches][];
            if (this.parentBVals != null) {
                ret.parentIDs = this.parentIDs;
                ret.parentsSorted = true;
                ret.parentBVals = new float[numBranches][];
            }
            if (this.sectTargetBVals != null) {
                ret.sectTargetBVals = new float[numBranches][];
            }
            if (this.parentTargetBVals != null) {
                ret.parentTargetBVals = new float[numBranches][];
            }
            for (int b = 0; b < numBranches; ++b) {
                ret.sectBVals[b] = this.sectBVals.get(b);
                Preconditions.checkArgument((ret.sectBVals[b].length == numSects ? 1 : 0) != 0);
                if (this.parentBVals != null) {
                    ret.parentBVals[b] = this.parentBVals.get(b);
                }
                if (this.sectTargetBVals != null) {
                    ret.sectTargetBVals[b] = this.sectTargetBVals.get(b);
                }
                if (this.parentTargetBVals == null) continue;
                ret.parentTargetBVals[b] = this.parentTargetBVals.get(b);
            }
            return ret;
        }
    }
}

