/*
 * Decompiled with CFR 0.152.
 */
package org.opensha.commons.calc.cholesky;

import Jama.EigenvalueDecomposition;
import Jama.Matrix;
import com.google.common.base.Stopwatch;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.StringTokenizer;
import java.util.concurrent.TimeUnit;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.opensha.commons.calc.cholesky.CholeskyDecomposition;

/**
 * Iteratively computes a matrix close to a given square matrix whose eigenvalues
 * have been floored at a small positive fraction of the largest eigenvalue.
 *
 * <p>Each iteration eigen-decomposes the current iterate (optionally corrected by a
 * Dykstra-style term), floors the eigenvalues at {@code eigTol * maxEigenvalue},
 * rebuilds the matrix and optionally restores the original diagonal. Iteration stops
 * once the relative infinity-norm change drops to {@code convTol} or after
 * {@code maxit} iterations. NOTE(review): the structure resembles R's
 * {@code Matrix::nearPD} (Higham 2002) - confirm before citing it as a port.
 *
 * <p>Defaults: {@code keepDiag=false}, {@code doDykstra=true}, {@code eigTol=1e-6},
 * {@code convTol=1e-7}, {@code maxit=100}, Jama eigen solver.
 *
 * <p>Instances hold the result of the most recent calculation as mutable state.
 */
public class NearPD {
    /** Result matrix of the most recent calculation; {@code null} before the first one. */
    private Matrix X;
    /** If true, the input's diagonal is written back onto every iterate. */
    private boolean keepDiag;
    /** If true, the Dykstra-style correction term is applied each iteration. */
    private boolean doDykstra;
    /** Eigenvalues are floored at {@code eigTol} times the largest eigenvalue. */
    private double eigTol;
    /** Convergence threshold on the relative infinity-norm change between iterates. */
    private double convTol;
    /** Maximum number of iterations. */
    private int maxit;
    /** If true, Apache Commons Math is tried for the eigen decomposition before Jama. */
    private boolean apache;
    /** Convergence value of the stored result. */
    private double conv;
    /** Frobenius norm of (input - result). */
    private double normF;
    /** Number of iterations executed by the most recent calculation. */
    private int iter;
    /** Eigenvalues computed in the iteration that produced the stored result. */
    private double[] eigVals;
    private static final DecimalFormat tDF = new DecimalFormat("0.00");
    private static final DecimalFormat pDF = new DecimalFormat("0.000%");

    /** Creates a calculator configured with the default settings. */
    public NearPD() {
        this.setDefaults();
    }

    /**
     * Runs the calculation without progress output.
     *
     * @param x square input matrix; it is not modified
     * @return {@code true} if the convergence tolerance was reached
     * @see #calcNearPD(Matrix, boolean)
     */
    public boolean calcNearPD(Matrix x) {
        return this.calcNearPD(x, false);
    }

    /**
     * Runs the iterative calculation and stores its outcome in this instance
     * (see {@link #getX()}, {@link #getConvergedTolerence()}, {@link #getFrobNorm()},
     * {@link #getIter()} and {@link #getEigVals()}).
     *
     * <p>The stored result is the iterate with the smallest convergence value seen,
     * which is not necessarily the last one. If no iterate ever qualifies (no iteration
     * ran because {@code maxit <= 0}, or every convergence value was NaN), the last
     * iterate is stored instead and the reported convergence value stays at positive
     * infinity.
     *
     * @param x square input matrix; it is not modified
     * @param verbose if true, progress and timing are printed to standard output
     * @return {@code true} if the convergence tolerance was reached
     * @throws IllegalArgumentException if {@code x} is not square
     */
    public boolean calcNearPD(Matrix x, boolean verbose) {
        int n = x.getRowDimension();
        if (x.getColumnDimension() != n) {
            throw new IllegalArgumentException("NearPD requires a square matrix but got " + n + "x" + x.getColumnDimension());
        }
        double[] diagX0 = new double[n];
        if (this.keepDiag) {
            for (int i = 0; i < n; ++i) {
                diagX0[i] = x.get(i, i);
            }
        }
        // Dykstra correction term; starts at zero so the first iteration decomposes x itself.
        Matrix D_S = new Matrix(n, n);
        // Only the diagonal is ever written, so the off-diagonal entries stay zero.
        Matrix D_plus = new Matrix(n, n);
        Matrix R = null;
        Matrix current = x.copy();
        double[] d = null;
        int iterations = 0;
        boolean converged = false;
        Stopwatch totWatch = null;
        if (verbose) {
            System.out.println("Beginning NearPD calculation with up to " + this.maxit + " iterations and convTol=" + (float)this.convTol);
            totWatch = Stopwatch.createStarted();
        }
        // Local copy: a failed Apache decomposition disables Apache for the rest of this
        // run only, without changing the configured setting.
        boolean useApache = this.apache;
        double prevConv = Double.POSITIVE_INFINITY;
        double bestConv = Double.POSITIVE_INFINITY;
        Matrix bestX = null;
        double[] bestD = null;
        while (iterations < this.maxit && !converged) {
            Stopwatch iterWatch = null;
            if (verbose) {
                System.out.println("NearPD iteration " + iterations);
                iterWatch = Stopwatch.createStarted();
            }
            Matrix Y = current.copy();
            // Matrix that gets decomposed: the corrected iterate when Dykstra is on.
            Matrix target = Y;
            if (this.doDykstra) {
                R = Y.minus(D_S);
                target = R;
            }
            Matrix Q;
            if (useApache) {
                try {
                    RealMatrix mat = new Array2DRowRealMatrix(target.getArray(), false);
                    EigenDecomposition eigen = new EigenDecomposition(mat);
                    d = eigen.getRealEigenvalues();
                    Q = new Matrix(eigen.getV().getData());
                }
                catch (Exception e) {
                    if (verbose) {
                        System.err.println("WARNING: failed via apache, reverting to Jama: " + e.getMessage());
                    }
                    useApache = false;
                    EigenvalueDecomposition eig = target.eig();
                    d = eig.getRealEigenvalues();
                    Q = eig.getV();
                }
            } else {
                EigenvalueDecomposition eig = target.eig();
                d = eig.getRealEigenvalues();
                Q = eig.getV();
            }
            double eigMax = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < n; ++i) {
                if (d[i] > eigMax) {
                    eigMax = d[i];
                }
            }
            // Floor every eigenvalue at a fraction of the largest one.
            double floor = this.eigTol * eigMax;
            for (int i = 0; i < n; ++i) {
                D_plus.set(i, i, Math.max(d[i], floor));
            }
            current = Q.times(D_plus).times(Q.transpose());
            if (this.doDykstra) {
                // Must be taken before the diagonal is restored below.
                D_S = current.minus(R);
            }
            if (this.keepDiag) {
                for (int i = 0; i < n; ++i) {
                    current.set(i, i, diagX0[i]);
                }
            }
            double lastConv = Y.minus(current).normInf() / Y.normInf();
            ++iterations;
            if (lastConv <= this.convTol) {
                converged = true;
            }
            boolean newBest = lastConv < bestConv;
            if (verbose) {
                iterWatch.stop();
                System.out.println("\tTook " + NearPD.elapsed(iterWatch) + "; convergence=" + (float)lastConv + ", converged=" + converged + ", improvement: " + (float)(prevConv - lastConv) + " (" + pDF.format((prevConv - lastConv) / prevConv) + ")" + (newBest ? " (new best)" : ""));
            }
            if (newBest) {
                bestX = current.copy();
                bestD = Arrays.copyOf(d, d.length);
                bestConv = lastConv;
            }
            prevConv = lastConv;
        }
        if (verbose) {
            totWatch.stop();
            System.out.println("Done NearPD after " + NearPD.elapsed(totWatch) + " (" + NearPD.elapsed(totWatch, iterations) + " each), " + iterations + " iterations, conv=" + (float)bestConv + ", convTol=" + (float)this.convTol + ", and converged=" + converged);
        }
        if (bestX == null) {
            // No iterate was ever recorded as best (no iteration ran, or every convergence
            // value was NaN). Fall back to the last iterate instead of failing with a
            // NullPointerException when computing the Frobenius norm below.
            bestX = current;
            // With no decomposition at all the result equals the input, so report the
            // input's own eigenvalues.
            bestD = d != null ? Arrays.copyOf(d, d.length) : x.eig().getRealEigenvalues();
        }
        this.X = bestX;
        this.conv = bestConv;
        this.normF = x.minus(bestX).normF();
        this.iter = iterations;
        this.eigVals = bestD;
        return converged;
    }

    /** Formats the stopwatch's total elapsed time. */
    private static String elapsed(Stopwatch watch) {
        return NearPD.elapsed(watch, 1);
    }

    /**
     * Formats the stopwatch's elapsed time divided by {@code instances} (values of 1 or
     * less leave the time undivided): seconds below 90 s, minutes otherwise.
     */
    private static String elapsed(Stopwatch watch, int instances) {
        double millis = watch.elapsed(TimeUnit.MILLISECONDS);
        if (instances > 1) {
            millis /= (double)instances;
        }
        double secs = millis / 1000.0;
        if (secs < 90.0) {
            return tDF.format(secs) + " s";
        }
        return tDF.format(secs / 60.0) + " m";
    }

    /**
     * Returns the stored result matrix itself (not a copy).
     *
     * @return the result of the most recent calculation, or {@code null} if none has run
     */
    public Matrix getX() {
        return this.X;
    }

    /**
     * @return the convergence value (relative infinity-norm change) of the stored result
     */
    public double getConvergedTolerence() {
        return this.conv;
    }

    /**
     * @return the Frobenius norm of the difference between the input and the stored result
     */
    public double getFrobNorm() {
        return this.normF;
    }

    /**
     * @return the number of iterations executed by the most recent calculation
     */
    public double getIter() {
        return this.iter;
    }

    /**
     * Returns a copy of the eigenvalues computed in the iteration that produced the
     * stored result. These come from the matrix that was decomposed in that iteration,
     * before the eigenvalue floor was applied.
     *
     * @return a fresh array of eigenvalues
     * @throws IllegalStateException if no calculation has been run yet
     */
    public double[] getEigVals() {
        if (this.eigVals == null) {
            throw new IllegalStateException("calcNearPD must be called before getEigVals");
        }
        return this.eigVals.clone();
    }

    /**
     * Returns a defensive copy of the stored result matrix.
     *
     * @return a copy of the result of the most recent calculation
     * @throws IllegalStateException if no calculation has been run yet
     */
    public Matrix getNearPD() {
        if (this.X == null) {
            throw new IllegalStateException("calcNearPD must be called before getNearPD");
        }
        return this.X.copy();
    }

    /** Resets every setting to its default value. */
    private void setDefaults() {
        this.keepDiag = false;
        this.doDykstra = true;
        this.eigTol = 1.0E-6;
        this.convTol = 1.0E-7;
        this.maxit = 100;
        this.apache = false;
    }

    /** @param e if true, the input's diagonal is restored on every iterate */
    public void setKeepDiag(boolean e) {
        this.keepDiag = e;
    }

    /** @param e if true, the Dykstra-style correction is applied each iteration */
    public void setDoDykstra(boolean e) {
        this.doDykstra = e;
    }

    /** @param val convergence threshold on the relative infinity-norm change */
    public void setConvTol(double val) {
        this.convTol = val;
    }

    /** @param val eigenvalue floor, as a fraction of the largest eigenvalue */
    public void setEigTol(double val) {
        this.eigTol = val;
    }

    /** @param val maximum number of iterations */
    public void setMaxit(int val) {
        this.maxit = val;
    }

    /**
     * @param apache if true, Apache Commons Math is tried for the eigen decomposition,
     *        falling back to Jama if it throws
     */
    public void setUseApache(boolean apache) {
        this.apache = apache;
    }

    /**
     * Demo / smoke test: runs the calculation on a fixed 10x10 matrix, prints the
     * result and fails if the project's Cholesky check still rejects it.
     */
    public static void main(String[] args) {
        int n = 10;
        double[][] rho = new double[n][n];
        String rhoString = "1.0000      0.8646      0.5412     -0.0070     -0.4291     -0.2727      0.5448      0.4981     -0.4579     -0.4481 0.8646      1.0000      0.5343     -0.0141     -0.4832     -0.4705      0.4392      0.4716     -0.3771     -0.3660 0.5412      0.5343      1.0000      0.0427     -0.7999     -0.8240      0.2005      0.1926     -0.2985     -0.2684 -0.0070     -0.0141      0.0427      1.0000     -0.4071     -0.8087     -0.1396     -0.1775     -0.0304     -0.0340 -0.4291     -0.4832     -0.7999     -0.4071      1.0000      0.0594     -0.1627     -0.2604      0.2474      0.2541 -0.2727     -0.4705     -0.8240     -0.8087      0.0594      1.0000      0.2672      0.0173      0.3205      0.3749 0.5448      0.4392      0.2005     -0.1396     -0.1627      0.2672      1.0000      0.2581     -0.3109     -0.3051 0.4981      0.4716      0.1926     -0.1775     -0.2604      0.0173      0.2581      1.0000      0.2289      0.2408 -0.4579     -0.3771     -0.2985     -0.0304      0.2474      0.3205     -0.3109      0.2289      1.0000      0.8425 -0.4481     -0.3660     -0.2684     -0.0340      0.2541      0.3749     -0.3051      0.2408      0.8425      1.0000";
        StringTokenizer st = new StringTokenizer(rhoString);
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                rho[i][j] = Double.parseDouble(st.nextToken().trim());
            }
        }
        Matrix rho_matrix = new Matrix(rho);
        CholeskyDecomposition chol = new CholeskyDecomposition(rho_matrix);
        if (!chol.isSPD()) {
            System.out.println("The rho matrix is not SPD");
        }
        EigenvalueDecomposition eig = rho_matrix.eig();
        double[] eigVals = eig.getRealEigenvalues();
        StringBuilder eigValsString = new StringBuilder();
        for (int i = 0; i < n; ++i) {
            // Rounded to three decimals for display only.
            eigValsString.append(' ').append((double)Math.round(eigVals[i] * 1000.0) / 1000.0);
        }
        System.out.println("The eigVals are: " + eigValsString);
        NearPD nearPd = new NearPD();
        nearPd.setKeepDiag(true);
        nearPd.setUseApache(true);
        nearPd.calcNearPD(rho_matrix);
        Matrix rho_PDmatrix = nearPd.getX();
        System.out.println("Testing NearPD: nearest PD matrix \n Rho_PD = \n");
        rho_PDmatrix.print(12, 8);
        CholeskyDecomposition cholDecompPD = new CholeskyDecomposition(rho_PDmatrix);
        if (!cholDecompPD.isSPD()) {
            throw new RuntimeException("Error: Even after NearPD the matrix is not PD");
        }
    }
}

