/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.impl;

import cern.colt.matrix.tdouble.DoubleMatrix2D;
import com.google.common.base.Preconditions;
import com.google.gson.Gson;
import com.google.gson.TypeAdapter;
import com.google.gson.annotations.JsonAdapter;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.opensha.commons.data.function.EvenlyDiscretizedFunc;
import org.opensha.commons.data.function.HistogramFunction;
import org.opensha.commons.util.ExceptionUtils;
import org.opensha.sha.earthquake.faultSysSolution.FaultSystemRupSet;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.ConstraintWeightingType;
import org.opensha.sha.earthquake.faultSysSolution.inversion.constraints.InversionConstraint;
import org.opensha.sha.earthquake.faultSysSolution.modules.AveSlipModule;
import org.opensha.sha.earthquake.faultSysSolution.modules.ClusterRuptures;
import org.opensha.sha.earthquake.faultSysSolution.modules.SectSlipRates;
import org.opensha.sha.earthquake.faultSysSolution.modules.SlipAlongRuptureModel;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.ClusterRupture;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.Jump;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.plausibility.PlausibilityConfiguration;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.plausibility.impl.prob.Shaw07JumpDistProb;
import org.opensha.sha.earthquake.faultSysSolution.ruptures.strategies.ClusterConnectionStrategy;

/**
 * Inversion constraint on the slip rate accommodated across each jump between two fault sections.
 * <p>
 * Unique jumps are detected from the rupture set's {@link ClusterRuptures} module and stored with the
 * lower section ID first, so A-&gt;B and B-&gt;A are treated as the same jump. For each jump, the slip
 * rates of its two sections are combined with a {@link RateCombiner} and scaled by the fraction returned
 * by the {@link SegmentationModel}; the ruptures using the jump form a row whose target is that value.
 * By default there is one row per unique jump. In "net" mode (which requires normalization and a
 * {@link ScalarSegmentationModel}) jumps are grouped into scalar bins and there is one row per bin.
 */
public class SlipRateSegmentationConstraint
extends InversionConstraint {
    // gives the target fraction for each jump
    private SegmentationModel segModel;
    // combines the two sections' values (slip rates, and rupture slips) across a jump
    private RateCombiner combiner;
    // if true, one row per scalar bin rather than one row per jump
    private boolean netConstraint;
    // net mode only: also include jumps from the connection strategy that no rupture uses
    private boolean includeUnusedInNet;
    private transient FaultSystemRupSet rupSet;
    // lazily built cache (see checkInitJumpRups), invalidated whenever the rupture set changes
    private transient Map<Jump, List<Integer>> jumpRupturesMap;
    // lazily built in net mode only; invalidated together with jumpRupturesMap
    private transient EvenlyDiscretizedFunc scalarBins;

    /**
     * Creates a per-jump (non-net) constraint.
     *
     * @param rupSet rupture set; must have {@link AveSlipModule}, {@link SlipAlongRuptureModel} and
     *        {@link SectSlipRates} modules
     * @param segModel segmentation model giving the target fraction for each jump
     * @param combiner how to combine the values on the two sections of a jump
     * @param weight constraint weight
     * @param normalized if true, rows are normalized by the combined slip rate of the jump
     * @param inequality passed through to {@link InversionConstraint} as the inequality flag
     */
    public SlipRateSegmentationConstraint(FaultSystemRupSet rupSet, SegmentationModel segModel, RateCombiner combiner, double weight, boolean normalized, boolean inequality) {
        this(rupSet, segModel, combiner, weight, normalized, inequality, false, false);
    }

    /**
     * Creates a constraint, optionally in net (scalar-binned) mode.
     *
     * @param rupSet rupture set; must have {@link AveSlipModule}, {@link SlipAlongRuptureModel} and
     *        {@link SectSlipRates} modules
     * @param segModel segmentation model giving the target fraction for each jump; must be a
     *        {@link ScalarSegmentationModel} when {@code netConstraint} is true
     * @param combiner how to combine the values on the two sections of a jump
     * @param weight constraint weight
     * @param normalized if true, rows are normalized by the combined slip rate of the jump; required
     *        when {@code netConstraint} is true
     * @param inequality passed through to {@link InversionConstraint} as the inequality flag
     * @param netConstraint if true, build one row per scalar bin instead of one row per jump
     * @param includeUnusedInNet in net mode, also count jumps from the connection strategy that no
     *        rupture uses
     * @throws IllegalStateException if net mode is requested without normalization or without a
     *         scalar segmentation model, or if the rupture set lacks a required module
     */
    public SlipRateSegmentationConstraint(FaultSystemRupSet rupSet, SegmentationModel segModel, RateCombiner combiner, double weight, boolean normalized, boolean inequality, boolean netConstraint, boolean includeUnusedInNet) {
        super(getName(normalized, netConstraint), getShortName(normalized, netConstraint), weight, inequality,
                normalized ? ConstraintWeightingType.NORMALIZED : ConstraintWeightingType.UNNORMALIZED);
        this.segModel = segModel;
        this.combiner = combiner;
        this.netConstraint = netConstraint;
        this.includeUnusedInNet = includeUnusedInNet;
        if (netConstraint) {
            Preconditions.checkState(normalized, "Net constraints must be normalized");
            Preconditions.checkState(segModel instanceof ScalarSegmentationModel,
                    "Must be a scalar segmentation model for net constraint");
        }
        this.setRuptureSet(rupSet);
    }

    /** Short constraint name for the given mode. */
    private static String getShortName(boolean normalized, boolean netConstraint) {
        if (netConstraint) {
            return "NetSlipSeg";
        }
        return normalized ? "NormSlipSeg" : "SlipSeg";
    }

    /** Human-readable constraint name for the given mode. */
    private static String getName(boolean normalized, boolean netConstraint) {
        if (netConstraint) {
            return "Net (Distance-Binned) Slip Rate Segmentation";
        }
        return normalized ? "Normalized Slip Rate Segmentation" : "Slip Rate Segmentation";
    }

    /**
     * Returns the number of rows: one per scalar bin in net mode, otherwise one per unique jump.
     */
    @Override
    public int getNumRows() {
        checkInitJumpRups();
        return netConstraint ? scalarBins.size() : jumpRupturesMap.size();
    }

    /**
     * Encodes this constraint into the A matrix and data vector, starting at {@code startRow}.
     * <p>
     * For each jump (rows ordered by {@link Jump#id_comparator}), the column of every rupture using the
     * jump is set to {@code weight * slip} (un-normalized) or {@code weight * slip / combinedRate}
     * (normalized). Here {@code slip} combines the rupture's slip on the two jump sections. The target is
     * {@code weight * combinedRate * segFract} (un-normalized) or {@code weight * segFract} (normalized).
     * In net mode, each jump's normalized values are scaled by 1/(number of jumps in its bin) and summed
     * into row {@code startRow + bin}. That row's target is {@code weight} times the mean segmentation
     * fraction of the jumps in the bin; bins with no jumps are left untouched.
     *
     * @return the number of A-matrix entries populated
     * @throws IllegalStateException if the weighting type is unsupported, if a combined slip rate, slip
     *         or segmentation fraction is invalid, or if a normalized jump used by at least one rupture
     *         has a non-positive combined slip rate
     */
    @Override
    public long encode(DoubleMatrix2D A, double[] d, int startRow) {
        checkInitJumpRups();
        List<Jump> jumps = new ArrayList<>(jumpRupturesMap.keySet());
        // deterministic row order
        Collections.sort(jumps, Jump.id_comparator);
        AveSlipModule aveSlips = rupSet.requireModule(AveSlipModule.class);
        SlipAlongRuptureModel slipAlongModel = rupSet.requireModule(SlipAlongRuptureModel.class);
        SectSlipRates slipRates = rupSet.requireModule(SectSlipRates.class);
        Preconditions.checkState(weightingType == ConstraintWeightingType.NORMALIZED
                        || weightingType == ConstraintWeightingType.UNNORMALIZED,
                "Only normalized and un-normalized weighting types are supported");
        boolean normalized = weightingType == ConstraintWeightingType.NORMALIZED;

        // net mode only: the bin index of each jump, and the 1/(jumps in bin) multiplier of each bin
        Map<Jump, Integer> jumpScalarBinIndexes = null;
        double[] scalarBinAMults = null;
        if (netConstraint) {
            Preconditions.checkState(normalized, "Net constraint must be normalized");
            Preconditions.checkState(segModel instanceof ScalarSegmentationModel,
                    "Must be a scalar segmentation model for net constraint");
            ScalarSegmentationModel scalarSeg = (ScalarSegmentationModel) segModel;
            int numBins = scalarBins.size();
            jumpScalarBinIndexes = new HashMap<>();
            int[] binCounts = new int[numBins];
            double[] scalarBinTargets = new double[numBins];
            for (Jump jump : jumps) {
                double scalar = scalarSeg.getScalar(jump);
                int bin = scalarBins.getClosestXIndex(scalar);
                binCounts[bin]++;
                scalarBinTargets[bin] += scalarSeg.calcReductionForScalar(scalar);
                jumpScalarBinIndexes.put(jump, bin);
            }
            scalarBinAMults = new double[numBins];
            for (int bin = 0; bin < numBins; bin++) {
                if (binCounts[bin] == 0) {
                    // no jumps in this bin: nothing references it, and 1/0 would yield Inf/NaN
                    continue;
                }
                scalarBinAMults[bin] = 1d / binCounts[bin];
                // target is the mean segmentation fraction of the jumps in this bin
                scalarBinTargets[bin] *= scalarBinAMults[bin];
                d[startRow + bin] = weight * scalarBinTargets[bin];
            }
        }

        long count = 0L;
        int row = startRow;
        for (Jump jump : jumps) {
            int fromID = jump.fromSection.getSectionId();
            int toID = jump.toSection.getSectionId();
            double combRate = combiner.combine(slipRates.getSlipRate(fromID), slipRates.getSlipRate(toID));
            Preconditions.checkState(Double.isFinite(combRate), "Non-finite combined slip-rate: %s", combRate);
            double segFract = segModel.calcReductionBetween(jump);
            Preconditions.checkState(Double.isFinite(segFract) && segFract >= 0d && segFract <= 1d,
                    "Bad segmentation fraction: %s", segFract);
            double segRate = combRate * segFract;
            Preconditions.checkState(Double.isFinite(segRate), "Non-finite segmentation rate: %s", segRate);

            List<Integer> jumpRups = jumpRupturesMap.get(jump);
            if (normalized && !jumpRups.isEmpty()) {
                // normalized A values are divided by combRate; zero would write infinities into A
                Preconditions.checkState(combRate > 0d,
                        "Combined slip-rate must be positive for a normalized constraint: %s", combRate);
            }

            if (netConstraint) {
                int bin = jumpScalarBinIndexes.get(jump);
                int binRow = startRow + bin;
                double aMult = weight * scalarBinAMults[bin];
                for (int rup : jumpRups) {
                    double avgSlip = calcCombinedSlipAcrossJump(slipAlongModel, aveSlips, rup, fromID, toID);
                    // multiple jumps share a bin row, so contributions accumulate
                    double prev = getA(A, binRow, rup);
                    setA(A, binRow, rup, prev + aMult * avgSlip / combRate);
                    if (prev == 0d) {
                        count++;
                    }
                }
            } else {
                for (int rup : jumpRups) {
                    double avgSlip = calcCombinedSlipAcrossJump(slipAlongModel, aveSlips, rup, fromID, toID);
                    setA(A, row, rup, normalized ? weight * avgSlip / combRate : weight * avgSlip);
                    count++;
                }
                d[row] = normalized ? weight * segFract : weight * segRate;
                row++;
            }
        }
        return count;
    }

    /**
     * Combines (with the configured {@link RateCombiner}) the slip of the given rupture on the two
     * sections of a jump.
     *
     * @throws IllegalStateException if the combined slip is not finite (e.g. a section was not found
     *         in the rupture)
     */
    private double calcCombinedSlipAcrossJump(SlipAlongRuptureModel slipAlongModel, AveSlipModule aveSlips,
            int rup, int fromID, int toID) {
        double[] slips = slipAlongModel.calcSlipOnSectionsForRup(rupSet, aveSlips, rup);
        List<Integer> sects = rupSet.getSectionsIndicesForRup(rup);
        double slip1 = Double.NaN;
        double slip2 = Double.NaN;
        for (int i = 0; i < slips.length; i++) {
            int sect = sects.get(i);
            if (sect == fromID) {
                slip1 = slips[i];
            } else if (sect == toID) {
                slip2 = slips[i];
            }
        }
        double avgSlip = combiner.combine(slip1, slip2);
        Preconditions.checkState(Double.isFinite(avgSlip),
                "Non-finite average slip across jump: %s (from %s and %s)", avgSlip, slip1, slip2);
        return avgSlip;
    }

    /** Returns the jump oriented so that its from-section has the lower (or equal) section ID. */
    private static Jump lowerIDFirst(Jump jump) {
        return jump.fromSection.getSectionId() > jump.toSection.getSectionId() ? jump.reverse() : jump;
    }

    /**
     * Lazily builds the jump-to-ruptures map (and, in net mode, the scalar bins) for the current
     * rupture set.
     * <p>
     * Results are built into locals and published only at the end, so a failure part way through does
     * not leave a partially populated cache behind.
     */
    private synchronized void checkInitJumpRups() {
        if (jumpRupturesMap != null) {
            return;
        }
        System.out.println("Detecting jumps for segmentation constraint");
        Map<Jump, List<Integer>> jumpRups = new HashMap<>();
        ClusterRuptures cRups = rupSet.requireModule(ClusterRuptures.class);
        int jumpingRups = 0;
        for (int r = 0; r < cRups.size(); r++) {
            boolean hasJumps = false;
            for (Jump jump : cRups.get(r).getJumpsIterable()) {
                jumpRups.computeIfAbsent(lowerIDFirst(jump), key -> new ArrayList<>()).add(r);
                hasJumps = true;
            }
            if (hasJumps) {
                jumpingRups++;
            }
        }
        System.out.println("Found " + jumpRups.size() + " unique jumps, involving " + jumpingRups + " ruptures");
        EvenlyDiscretizedFunc bins = null;
        if (netConstraint) {
            if (includeUnusedInNet) {
                ClusterConnectionStrategy connStrat = rupSet.requireModule(PlausibilityConfiguration.class).getConnectionStrategy();
                int unusedCount = 0;
                for (Jump jump : connStrat.getAllPossibleJumps()) {
                    Jump key = lowerIDFirst(jump);
                    if (!jumpRups.containsKey(key)) {
                        jumpRups.put(key, new ArrayList<>());
                        unusedCount++;
                    }
                }
                System.out.println("Added " + unusedCount + " additional jumps from connection strategy that are never used");
            }
            ScalarSegmentationModel scalarSeg = (ScalarSegmentationModel) segModel;
            double minVal = Double.POSITIVE_INFINITY;
            double maxVal = Double.NEGATIVE_INFINITY;
            for (Jump jump : jumpRups.keySet()) {
                double val = scalarSeg.getScalar(jump);
                minVal = Math.min(minVal, val);
                maxVal = Math.max(maxVal, val);
            }
            bins = scalarSeg.getScalarBins(minVal, maxVal);
        }
        // publish bins before the map: a non-null map implies initialization completed
        this.scalarBins = bins;
        this.jumpRupturesMap = jumpRups;
    }

    /**
     * Sets the rupture set used by this constraint. The set is validated first and then stored, and any
     * cached jump data derived from a previous rupture set is discarded. A {@link ClusterRuptures} module
     * is added (single-stranded) if absent.
     *
     * @throws IllegalStateException if the rupture set lacks average slip, slip-along-rupture or slip
     *         rate data; in that case the previous rupture set is retained
     */
    @Override
    public synchronized void setRuptureSet(FaultSystemRupSet rupSet) {
        if (rupSet == this.rupSet) {
            return;
        }
        Preconditions.checkState(rupSet.hasModule(AveSlipModule.class), "Rupture set does not have average slip data");
        Preconditions.checkState(rupSet.hasModule(SlipAlongRuptureModel.class), "Rupture set does not have slip along rupture data");
        Preconditions.checkState(rupSet.hasModule(SectSlipRates.class), "Rupture set does not have slip rate data");
        if (!rupSet.hasModule(ClusterRuptures.class)) {
            rupSet.addModule(ClusterRuptures.singleStranged(rupSet));
        }
        this.rupSet = rupSet;
        // cached jumps/rupture indices belonged to the previous rupture set
        this.jumpRupturesMap = null;
        this.scalarBins = null;
    }

    /**
     * Gives the segmentation fraction for a jump. {@code encode} requires values to be finite and
     * within [0, 1].
     */
    @JsonAdapter(value=SegmentationModelAdapter.class)
    public interface SegmentationModel {
        public double calcReductionBetween(Jump jump);
    }

    /** Ways of combining a value from each of the two sections of a jump into a single value. */
    public enum RateCombiner {
        MIN("Min Rate"){

            @Override
            public double combine(double rate1, double rate2) {
                return Math.min(rate1, rate2);
            }
        }
        ,
        MAX("Max Rate"){

            @Override
            public double combine(double rate1, double rate2) {
                return Math.max(rate1, rate2);
            }
        }
        ,
        AVERAGE("Avg Rate"){

            @Override
            public double combine(double rate1, double rate2) {
                return 0.5 * (rate1 + rate2);
            }
        };

        private final String label;

        private RateCombiner(String label) {
            this.label = label;
        }

        @Override
        public String toString() {
            return this.label;
        }

        /** Combines the two values into one. */
        public abstract double combine(double rate1, double rate2);
    }

    /**
     * Segmentation model whose fraction is a function of a single scalar computed from each jump. This
     * is required for net (binned) mode.
     */
    public interface ScalarSegmentationModel
    extends SegmentationModel {
        @Override
        default public double calcReductionBetween(Jump jump) {
            return this.calcReductionForScalar(this.getScalar(jump));
        }

        /** Segmentation fraction for the given scalar value. */
        public double calcReductionForScalar(double scalar);

        /** Scalar value for the given jump. */
        public double getScalar(Jump jump);

        /**
         * Bins used in net mode. This is called with the minimum and maximum scalar over all constrained
         * jumps (+/- infinity if there are none).
         */
        public EvenlyDiscretizedFunc getScalarBins(double minVal, double maxVal);
    }

    /**
     * Gson adapter that writes a {@link SegmentationModel} as {@code {"type": <class name>, "data": ...}}
     * and reads it back.
     */
    public static class SegmentationModelAdapter
    extends TypeAdapter<SegmentationModel> {
        Gson gson = new Gson();

        @Override
        public void write(JsonWriter out, SegmentationModel value) throws IOException {
            out.beginObject();
            out.name("type").value(value.getClass().getName());
            out.name("data");
            this.gson.toJson(value, value.getClass(), out);
            out.endObject();
        }

        /**
         * Reads a model written by {@link #write}. The named class is loaded without initialization and
         * must implement {@link SegmentationModel} before any data is deserialized into it.
         *
         * @throws IllegalStateException if the JSON layout is wrong or the type is not a
         *         {@link SegmentationModel}
         */
        @Override
        public SegmentationModel read(JsonReader in) throws IOException {
            in.beginObject();
            Preconditions.checkState(in.nextName().equals("type"), "JSON 'type' object must be first");
            String typeName = in.nextString();
            Class<?> type;
            try {
                // don't run static initializers of arbitrary classes named in the JSON
                type = Class.forName(typeName, false, SegmentationModelAdapter.class.getClassLoader());
            }
            catch (ClassNotFoundException e) {
                throw ExceptionUtils.asRuntimeException(e);
            }
            Preconditions.checkState(SegmentationModel.class.isAssignableFrom(type),
                    "JSON 'type' is not a SegmentationModel: %s", typeName);
            Preconditions.checkState(in.nextName().equals("data"), "JSON 'data' object must be second");
            SegmentationModel model = (SegmentationModel) this.gson.fromJson(in, type);
            in.endObject();
            return model;
        }
    }

    /**
     * Scalar segmentation model using jump distance as the scalar and
     * {@link Shaw07JumpDistProb#calcJumpProbability} with parameters {@code a} and {@code r0} as the
     * fraction.
     */
    public static class Shaw07JumpDistSegModel
    implements ScalarSegmentationModel {
        private double a;
        private double r0;

        public Shaw07JumpDistSegModel(double a, double r0) {
            this.a = a;
            this.r0 = r0;
        }

        @Override
        public double calcReductionForScalar(double distance) {
            return Shaw07JumpDistProb.calcJumpProbability(distance, this.a, this.r0);
        }

        @Override
        public double getScalar(Jump jump) {
            return jump.distance;
        }

        /**
         * Ignores {@code minVal}. Returns an encompassing histogram from 0.01 to max(maxVal, 1.0) with
         * delta 0.5.
         */
        @Override
        public EvenlyDiscretizedFunc getScalarBins(double minVal, double maxVal) {
            return HistogramFunction.getEncompassingHistogram(0.01, Math.max(maxVal, 1.0), 0.5);
        }
    }
}

