/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.inversion.mpj;

import com.google.common.base.Preconditions;
import edu.usc.kmilner.mpj.taskDispatch.MPJTaskCalculator;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.opensha.commons.logicTree.LogicTree;
import org.opensha.commons.logicTree.LogicTreeBranch;
import org.opensha.commons.logicTree.LogicTreeNode;
import org.opensha.commons.util.ExceptionUtils;
import org.opensha.commons.util.io.archive.ArchiveInput;
import org.opensha.commons.util.modules.OpenSHA_Module;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemSolution;
import org.opensha.sha.earthquake.faultSysSolution.inversion.InversionConfiguration;
import org.opensha.sha.earthquake.faultSysSolution.inversion.InversionConfigurationFactory;
import org.opensha.sha.earthquake.faultSysSolution.inversion.Inversions;
import org.opensha.sha.earthquake.faultSysSolution.inversion.mpj.AbstractAsyncLogicTreeWriter;
import org.opensha.sha.earthquake.faultSysSolution.modules.SolutionLogicTree;
import org.opensha.sha.earthquake.faultSysSolution.util.AverageSolutionCreator;
import org.opensha.sha.earthquake.faultSysSolution.util.FaultSysTools;

public class MPJ_LogicTreeInversionRunner
extends MPJTaskCalculator {
    // Root directory under which per-branch solution directories are resolved (--output-dir).
    private File outputDir;
    // Number of inversion runs per logic tree branch (--runs-per-branch); defaults to 1.
    private int runsPerBranch = 1;
    // Inversion configuration factory, instantiated reflectively from --inversion-factory.
    private InversionConfigurationFactory factory;
    // Logic tree read from the file given by --logic-tree.
    private LogicTree<LogicTreeNode> tree;
    // Parsed command line, retained so it can be handed to Inversions.run for each task.
    private CommandLine cmd;
    // Thread count passed to Inversions.run (--annealing-threads); checked to be >= 1.
    private int annealingThreads;
    // When > 1, calculateBatch runs tasks on a thread pool of this size (--runs-per-bundle); defaults to 1.
    private int runsPerBundle = 1;
    // True when already-existing solutions should be re-run through the factory's SolutionProcessor.
    private boolean reprocess = false;
    // True when --reprocess-only was supplied: every solution file must already exist.
    private boolean reprocessOnly = false;

    public MPJ_LogicTreeInversionRunner(CommandLine cmd) throws IOException {
        super(cmd);
        this.cmd = cmd;
        this.annealingThreads = Integer.parseInt(cmd.getOptionValue("annealing-threads"));
        Preconditions.checkState((this.annealingThreads >= 1 ? 1 : 0) != 0);
        if (cmd.hasOption("runs-per-bundle")) {
            this.runsPerBundle = Integer.parseInt(cmd.getOptionValue("runs-per-bundle"));
        }
        this.shuffle = false;
        this.tree = LogicTree.read(new File(cmd.getOptionValue("logic-tree")));
        if (this.rank == 0) {
            this.debug("Loaded " + this.tree.size() + " tree nodes");
        }
        this.outputDir = new File(cmd.getOptionValue("output-dir"));
        if (this.rank == 0) {
            MPJ_LogicTreeInversionRunner.waitOnDir(this.outputDir, 5, 1000L);
        }
        if (cmd.hasOption("runs-per-branch")) {
            this.runsPerBranch = Integer.parseInt(cmd.getOptionValue("runs-per-branch"));
        }
        try {
            Class<?> factoryClass = Class.forName(cmd.getOptionValue("inversion-factory"));
            this.factory = (InversionConfigurationFactory)factoryClass.getDeclaredConstructor(new Class[0]).newInstance(new Object[0]);
        }
        catch (Exception e) {
            throw ExceptionUtils.asRuntimeException(e);
        }
        if (cmd.hasOption("reprocess-only")) {
            Preconditions.checkState((this.factory.getSolutionLogicTreeProcessor() != null ? 1 : 0) != 0, (Object)"Can't reprocess if we don't have a solution processor");
            this.reprocess = true;
            this.reprocessOnly = true;
        } else {
            this.reprocess = this.factory.getSolutionLogicTreeProcessor() != null && cmd.hasOption("reprocess-existing");
        }
        File cacheDir = FaultSysTools.getCacheDir(cmd);
        if (cacheDir != null) {
            if (this.rank == 0) {
                MPJ_LogicTreeInversionRunner.waitOnDir(cacheDir, 3, 1000L);
            }
            this.factory.setCacheDir(cacheDir);
        }
        this.factory.setAutoCache(this.rank == 0);
        this.debug("Factory type: " + this.factory.getClass().getName());
        if (this.rank == 0) {
            this.postBatchHook = new AsyncLogicTreeWriter(this.factory.getSolutionLogicTreeProcessor());
        }
    }

    /**
     * Emits a debug message made of the given info text followed by the current
     * memory summary from {@link #memoryString()}.
     *
     * @param info optional prefix; when null or blank only the memory summary is logged
     */
    private void memoryDebug(String info) {
        String prefix;
        if (info != null && !info.isBlank()) {
            prefix = info + "; ";
        } else {
            prefix = "";
        }
        debug(prefix + memoryString());
    }

    /**
     * Builds a one-line JVM heap summary in whole megabytes.
     * <p>
     * {@link System#gc()} is invoked first, so this call is not cheap.
     *
     * @return a string of the form {@code "mem t/u/f: <total>/<used>/<free>"}, where
     *         used is computed as total minus free
     */
    static String memoryString() {
        System.gc();
        final long bytesPerMB = 1024L * 1024L;
        Runtime runtime = Runtime.getRuntime();
        long total = runtime.totalMemory() / bytesPerMB;
        long free = runtime.freeMemory() / bytesPerMB;
        long used = total - free;
        StringBuilder summary = new StringBuilder("mem t/u/f: ");
        summary.append(total).append('/').append(used).append('/').append(free);
        return summary.toString();
    }

    /**
     * Ensures that the given directory exists, creating it if necessary and retrying
     * when creation fails.
     * <p>
     * One initial attempt is made, followed by at most {@code maxRetries} further
     * attempts, sleeping {@code sleepMillis} before each retry. Only the directory
     * itself is created ({@link File#mkdir()}); missing parents are not.
     *
     * @param dir directory that must exist when this method returns
     * @param maxRetries maximum number of retries after the first failed attempt
     * @param sleepMillis pause between attempts, in milliseconds
     * @throws IllegalStateException if the directory still does not exist and could not
     *         be created once all retries are used up
     * @throws RuntimeException if the thread is interrupted while sleeping; the interrupt
     *         flag is restored before the exception is thrown
     */
    private static void waitOnDir(File dir, int maxRetries, long sleepMillis) {
        int retry = 0;
        while (!dir.exists() && !dir.mkdir()) {
            // Give up before sleeping again, so the attempt count matches the message.
            if (retry++ >= maxRetries) {
                throw new IllegalStateException("Directory doesn't exist and couldn't be created after " + maxRetries + " retries: " + dir.getAbsolutePath());
            }
            try {
                Thread.sleep(sleepMillis);
            }
            catch (InterruptedException e) {
                // Restore the flag so code further up the stack can still see the interruption.
                Thread.currentThread().interrupt();
                throw ExceptionUtils.asRuntimeException(e);
            }
        }
    }

    /**
     * @return total task count: one task per (logic tree branch, run) pair
     */
    protected int getNumTasks() {
        int branchCount = tree.size();
        return branchCount * runsPerBranch;
    }

    /**
     * Resolves the solution archive for one run of a branch.
     * <p>
     * With more than one run per branch the branch directory gets a {@code _run<N>}
     * suffix; otherwise an empty suffix is passed.
     *
     * @param branch logic tree branch
     * @param run zero-based run index within the branch
     * @return {@code solution.zip} inside the branch (run) directory under the output directory
     */
    protected File getSolFile(LogicTreeBranch<?> branch, int run) {
        String suffix = runsPerBranch > 1 ? "_run" + run : "";
        File runDir = branch.getBranchDirectory(outputDir, true, suffix);
        return new File(runDir, "solution.zip");
    }

    /**
     * Maps a flat task index to its branch index (tasks are laid out branch-major,
     * with {@code runsPerBranch} consecutive tasks per branch).
     */
    private int branchForCalcIndex(int index) {
        final int runs = runsPerBranch;
        return index / runs;
    }

    /**
     * Maps a flat task index to its run index within the branch
     * (the remainder after dividing by {@code runsPerBranch}).
     */
    private int runForCalcIndex(int index) {
        final int runs = runsPerBranch;
        return index % runs;
    }

    /**
     * Computes the flat task index for a (branch, run) pair: the first task of the
     * branch plus the run offset.
     */
    private int indexForBranchRun(int branchIndex, int runIndex) {
        int firstIndexOfBranch = branchIndex * runsPerBranch;
        return firstIndexOfBranch + runIndex;
    }

    /**
     * Runs every task index in the batch, then averages any branch whose final run was
     * part of this batch and whose run solutions are all present on disk.
     * <p>
     * When {@code runsPerBundle > 1} the tasks are submitted to a fixed thread pool of
     * that size and awaited; otherwise they run sequentially on the calling thread.
     * The pool is shut down on every exit path.
     *
     * @param batch flat task indexes to process
     * @throws Exception if a task fails (as an {@link ExecutionException} from the pool,
     *         or whatever the task itself throws when run sequentially), if waiting is
     *         interrupted, or if branch averaging fails
     */
    protected void calculateBatch(int[] batch) throws Exception {
        if (this.runsPerBundle > 1) {
            ExecutorService exec = Executors.newFixedThreadPool(this.runsPerBundle);
            try {
                ArrayList<Future<?>> futures = new ArrayList<>(batch.length);
                for (int index : batch) {
                    futures.add(exec.submit(new CalcRunnable(index)));
                }
                // Wait for each task in turn; the first failure propagates to the caller.
                for (Future<?> future : futures) {
                    future.get();
                }
            } finally {
                // Always release the pool, including when submit() or get() throws.
                exec.shutdown();
            }
        } else {
            for (int index : batch) {
                new CalcRunnable(index).run();
            }
        }

        if (this.runsPerBranch > 1) {
            for (int index : batch) {
                // Averaging is only attempted by whoever processed the last run of a branch.
                if (this.runForCalcIndex(index) != this.runsPerBranch - 1) {
                    continue;
                }
                int branchIndex = this.branchForCalcIndex(index);
                LogicTreeBranch<LogicTreeNode> branch = this.tree.getBranch(branchIndex);

                // Collect every run's solution file; stop at the first one that is missing.
                ArrayList<File> solFiles = new ArrayList<>(this.runsPerBranch);
                boolean allDone = true;
                for (int oRun = 0; oRun < this.runsPerBranch; ++oRun) {
                    File solFile = this.getSolFile(branch, oRun);
                    if (!solFile.exists()) {
                        allDone = false;
                        break;
                    }
                    solFiles.add(solFile);
                }
                if (!allDone) {
                    continue;
                }

                Preconditions.checkState(solFiles.size() == this.runsPerBranch);
                this.debug("Branch index " + branchIndex + " is all done, doing a compute node average for " + String.valueOf(branch));
                File runDir = branch.getBranchDirectory(this.outputDir, true);
                File outputFile = new File(runDir, "average_solution.zip");
                AverageSolutionCreator.average(outputFile, solFiles);
            }
        }
    }

    /**
     * Final step after all batches: on rank 0, shuts down the asynchronous logic tree
     * writer installed as the post-batch hook, logging memory usage before and after.
     * Other ranks do nothing.
     *
     * @throws Exception declared by the original signature
     */
    protected void doFinalAssembly() throws Exception {
        if (this.rank != 0) {
            return;
        }
        memoryDebug("waiting for any post batch hook operations to finish");
        AsyncLogicTreeWriter writer = (AsyncLogicTreeWriter) this.postBatchHook;
        writer.shutdown();
        memoryDebug("post batch hook done");
    }

    /**
     * Builds the command line options for this runner: the base
     * {@link MPJTaskCalculator} options, the options read by the constructor, the
     * cache directory option, and every option from
     * {@link InversionConfiguration#createSAOptions()}.
     * <p>
     * {@code --runs-per-branch} is registered without a short name. It previously shared
     * the short name {@code rpb} with {@code --runs-per-bundle}; because the later
     * registration replaced the earlier one under that key, {@code -rpb} already resolved
     * to {@code --runs-per-bundle} and continues to do so.
     *
     * @return the full option set
     */
    public static Options createOptions() {
        Options ops = MPJTaskCalculator.createOptions();
        ops.addRequiredOption("lt", "logic-tree", true, "Path to logic tree JSON file");
        ops.addRequiredOption("od", "output-dir", true, "Path to output directory");
        ops.addOption(FaultSysTools.cacheDirOption());
        ops.addRequiredOption("at", "annealing-threads", true, "Number of annealing threads per inversion");
        // Long name only: a second "rpb" registration would hide one of the two options.
        ops.addOption(null, "runs-per-branch", true, "Runs per branch (default is 1)");
        ops.addOption("rpb", "runs-per-bundle", true, "Simultaneous runs to execute (default is 1)");
        ops.addRequiredOption("ifc", "inversion-factory", true, "Inversion configuration factory classname");
        ops.addOption("rpe", "reprocess-existing", false, "Flag to enable re-processing of already completed solutions with the factory's SolutionProcessor before branch averaging");
        ops.addOption(null, "reprocess-only", false, "Flag to only re-process already completed solutions with the factory's SolutionProcessor before branch averaging, ensuring that all inversions are already completed");
        for (Option op : InversionConfiguration.createSAOptions().getOptions()) {
            ops.addOption(op);
        }
        return ops;
    }

    /**
     * Program entry point: initializes MPJ, parses the command line, runs the
     * calculator, finalizes MPJ and exits with status 0. Any {@link Throwable} is
     * handed to {@code abortAndExit}.
     *
     * @param args raw command line arguments (passed through {@code initMPJ} before parsing)
     */
    public static void main(String[] args) {
        System.setProperty("java.awt.headless", "true");
        try {
            String[] remainingArgs = MPJTaskCalculator.initMPJ(args);
            Options options = createOptions();
            CommandLine cmd = parse(options, remainingArgs, MPJ_LogicTreeInversionRunner.class);
            new MPJ_LogicTreeInversionRunner(cmd).run();
            finalizeMPJ();
            System.exit(0);
        }
        catch (Throwable t) {
            abortAndExit(t);
        }
    }

    /**
     * {@link AbstractAsyncLogicTreeWriter} bound to the enclosing runner's output
     * directory, logic tree and task indexing. The runner installs an instance as its
     * post-batch hook on rank 0.
     * <p>
     * Tracks which task indexes have been reported and only yields a solution for a
     * branch once every run of that branch has been reported.
     */
    private class AsyncLogicTreeWriter
    extends AbstractAsyncLogicTreeWriter {
        /** Per-task flag, indexed by flat task index; set when the task is passed to getSolution. */
        private boolean[] dones;

        public AsyncLogicTreeWriter(SolutionLogicTree.SolutionProcessor processor) {
            super(MPJ_LogicTreeInversionRunner.this.outputDir, processor, MPJ_LogicTreeInversionRunner.this.tree);
            int taskCount = this.getNumTasks();
            this.dones = new boolean[taskCount];
        }

        @Override
        public int getNumTasks() {
            return MPJ_LogicTreeInversionRunner.this.getNumTasks();
        }

        @Override
        public void debug(String message) {
            MPJ_LogicTreeInversionRunner.this.debug(message);
        }

        @Override
        public LogicTreeBranch<?> getBranch(int calcIndex) {
            final MPJ_LogicTreeInversionRunner runner = MPJ_LogicTreeInversionRunner.this;
            int branchIndex = runner.branchForCalcIndex(calcIndex);
            return runner.tree.getBranch(branchIndex);
        }

        /**
         * Marks {@code calcIndex} as done and, if all runs of its branch are done,
         * returns the branch solution.
         * <p>
         * With a single run per branch, that run's solution file is loaded. With several
         * runs, an existing {@code average_solution.zip} in the branch directory is loaded
         * if present; otherwise the run solutions are loaded, averaged, and the average is
         * written to that file.
         *
         * @param branch branch that the task belongs to
         * @param calcIndex flat task index that just completed
         * @return the branch solution, or {@code null} while other runs of the branch are pending
         * @throws IOException if loading or writing a solution fails
         */
        @Override
        public FaultSystemSolution getSolution(LogicTreeBranch<?> branch, int calcIndex) throws IOException {
            final MPJ_LogicTreeInversionRunner runner = MPJ_LogicTreeInversionRunner.this;
            this.dones[calcIndex] = true;

            final int branchIndex = runner.branchForCalcIndex(calcIndex);
            final int numRuns = runner.runsPerBranch;

            ArrayList<File> runFiles = new ArrayList<File>();
            for (int run = 0; run < numRuns; ++run) {
                int doneIndex = runner.indexForBranchRun(branchIndex, run);
                if (this.dones[doneIndex]) {
                    runFiles.add(runner.getSolFile(branch, run));
                    continue;
                }
                debug("AsyncLogicTree: not ready, waiting on run " + run + " for branch " + branchIndex + " (origIndex=" + calcIndex + ", checkIndex=" + doneIndex + "): " + String.valueOf(branch));
                return null;
            }

            if (numRuns <= 1) {
                // Single run: the run solution is the branch solution.
                return FaultSystemSolution.load(new ArchiveInput.ApacheZipFileInput(runFiles.get(0)));
            }

            File avgDir = new File(runner.outputDir, branch.buildFileName());
            Preconditions.checkState(avgDir.exists() || avgDir.mkdir());
            File avgFile = new File(avgDir, "average_solution.zip");

            if (avgFile.exists()) {
                // An average is already on disk (calculateBatch may have written one); reuse it.
                debug("AsyncLogicTree: loading external average from " + avgFile.getAbsolutePath());
                return FaultSystemSolution.load(new ArchiveInput.ApacheZipFileInput(avgFile));
            }

            debug("AsyncLogicTree: building average for " + String.valueOf(branch));
            FaultSystemSolution[] inputs = new FaultSystemSolution[runFiles.size()];
            int slot = 0;
            for (File runFile : runFiles) {
                inputs[slot++] = FaultSystemSolution.load(runFile);
            }
            FaultSystemSolution average = AverageSolutionCreator.buildAverage(inputs);
            average.write(avgFile);
            return average;
        }

        @Override
        public void abortAndExit(Throwable t, int exitCode) {
            MPJ_LogicTreeInversionRunner.abortAndExit(t, exitCode);
        }
    }

    /**
     * Task body for a single flat task index: one inversion run of one logic tree branch.
     * <p>
     * If the solution file already exists and loads, the inversion is skipped (and the
     * solution is optionally reprocessed). Otherwise the inversion is run and its
     * solution written to the file.
     */
    private class CalcRunnable
    implements Runnable {
        private int index;

        public CalcRunnable(int index) {
            this.index = index;
        }

        @Override
        public void run() {
            final MPJ_LogicTreeInversionRunner runner = MPJ_LogicTreeInversionRunner.this;
            final int branchIndex = runner.branchForCalcIndex(this.index);
            final int run = runner.runForCalcIndex(this.index);
            final LogicTreeBranch<LogicTreeNode> branch = runner.tree.getBranch(branchIndex);
            runner.debug("index " + this.index + " is branch " + branchIndex + " run " + run + ": " + String.valueOf(branch));

            final File solFile = runner.getSolFile(branch, run);
            final boolean exists = solFile.exists();
            // In reprocess-only mode a missing solution is a hard error.
            Preconditions.checkState(exists || !runner.reprocessOnly, "--reprocess-only was supplied but no solution exists for breanch %s: %s", branch, solFile.getAbsolutePath());

            if (exists) {
                runner.debug(solFile.getAbsolutePath() + " exists, testing loading...");
                try {
                    FaultSystemSolution existing = FaultSystemSolution.load(solFile);
                    runner.debug("skipping " + this.index + " (already done)");
                    if (runner.reprocess) {
                        reprocessExisting(existing, branch, solFile);
                    }
                    return;
                }
                catch (Exception e) {
                    if (runner.reprocessOnly) {
                        runner.debug("Failed to reprocess " + this.index + ", and --reprocess-only is enabled: " + e.getMessage());
                        MPJTaskCalculator.abortAndExit(e);
                    }
                    // Any load or reprocess failure falls through to a fresh inversion.
                    runner.debug("Failed to load, re-inverting: " + e.getMessage());
                }
            }

            runner.memoryDebug("Beginning config for " + this.index);
            FaultSystemSolution sol;
            try {
                sol = Inversions.run(runner.factory, branch, runner.annealingThreads, runner.cmd);
                sol.write(solFile);
            }
            catch (IOException e) {
                throw ExceptionUtils.asRuntimeException(e);
            }
            // Drop the reference before memoryDebug, which triggers System.gc() via memoryString().
            sol = null;
            runner.memoryDebug("DONE " + this.index);
        }

        /**
         * Re-runs an already-loaded solution through the factory's solution processor
         * and overwrites its file.
         * <p>
         * The rupture set is first passed through the factory's branch update and an
         * inversion configuration is built for it (result unused here; presumably done for
         * its effect on the rupture set, confirm against the factory). If the update
         * returned a different rupture set, modules from the original whose class the new
         * one lacks are carried over and the solution is copied onto the new archive.
         */
        private void reprocessExisting(FaultSystemSolution loaded, LogicTreeBranch<LogicTreeNode> branch, File solFile) throws Exception {
            final MPJ_LogicTreeInversionRunner runner = MPJ_LogicTreeInversionRunner.this;
            FaultSystemSolution current = loaded;
            FaultSystemRupSet origRupSet = current.getRupSet();
            SolutionLogicTree.SolutionProcessor processor = runner.factory.getSolutionLogicTreeProcessor();
            FaultSystemRupSet rpRupSet = runner.factory.updateRuptureSetForBranch(origRupSet, branch);
            runner.factory.buildInversionConfig(rpRupSet, branch, 8);
            if (rpRupSet != origRupSet) {
                for (OpenSHA_Module module : origRupSet.getModules()) {
                    if (!rpRupSet.hasModuleSuperclass(module.getClass())) {
                        rpRupSet.addModule(module);
                    }
                }
                current = current.copy(rpRupSet.getArchive());
            }
            current = processor.processSolution(current, branch);
            current.write(solFile);
        }
    }
}

