package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

/**
 * Feed-forward network with one cube-activated hidden layer, used by the neural-network
 * dependency parser in this package.
 *
 * <p>Every input is a list of {@code NUM_TOKENS} token ids, one per feature position. The token
 * {@code tok} at position {@code pos} is embedded as row {@code E[tok]} and multiplied by the
 * columns {@code [pos * embeddingSize, (pos + 1) * embeddingSize)} of {@code W1}. The hidden layer
 * is {@code h = (W1 x + b1)^3} and the label scores are {@code W2 h} (there is no output bias).
 *
 * <p>For the (token, position) pairs passed to the constructor, the product of that {@code W1}
 * slice with {@code E[tok]} can be cached by {@link #preCompute(Set)}. Such a pair is encoded as
 * the integer {@code tok * NUM_TOKENS + pos}.
 *
 * <p>Training methods (cost/gradient computation and AdaGrad updates) are only available when a
 * non-null {@link Dataset} was supplied and {@link #finalizeTraining()} has not been called yet.
 */
public class Classifier {
    private static final Redwood.RedwoodChannels log = Redwood.channels(Classifier.class);

    /**
     * Number of feature positions per example.
     *
     * <p>NOTE(review): the decompiled bytecode read {@code config} right before every use of the
     * literal 48. This suggests it was an inlined compile-time constant on {@link Config}
     * (presumably its token count) — TODO confirm and reference that constant directly.
     */
    private static final int NUM_TOKENS = 48;

    // Model parameters. They are the caller's arrays (not copies) and are updated in place by
    // takeAdaGradientStep.
    private final double[][] W1;
    private final double[][] W2;
    private final double[][] E;
    private final double[] b1;

    // AdaGrad accumulators of squared gradients, one per parameter array.
    private double[][] eg2W1;
    private double[][] eg2W2;
    private double[][] eg2E;
    private double[] eg2b1;

    // Cached hidden-layer contributions of pre-computed features, indexed by the row id in preMap.
    private double[][] saved;

    // Encoded (token, position) feature -> row id in saved. Row ids are dense: 0 .. size() - 1.
    private final Map<Integer, Integer> preMap;

    private boolean isTraining;
    private final Dataset dataset;
    private final MulticoreWrapper<Pair<Collection<Example>, FeedforwardParams>, Cost> jobHandler;
    private final Config config;
    private final int numLabels;

    /**
     * Training objective evaluated on a mini-batch (or on one shard of it). It holds the cost, the
     * accumulated accuracy, and the gradient of every parameter.
     */
    public class Cost {
        private double cost;
        private double percentCorrect;
        private final double[][] gradW1;
        private final double[] gradb1;
        private final double[][] gradW2;
        private final double[][] gradE;

        /**
         * Gradient with respect to the cached hidden-layer contributions, indexed like
         * {@code saved}. A row is {@code null} when no example in this shard used that feature,
         * meaning its gradient is zero. Each Cost owns its own array so that concurrent workers
         * never write to shared memory.
         */
        private final double[][] gradSaved;

        private Cost(double cost, double percentCorrect, double[][] gradW1, double[] gradb1,
                     double[][] gradW2, double[][] gradE, double[][] gradSaved) {
            this.cost = cost;
            this.percentCorrect = percentCorrect;
            this.gradW1 = gradW1;
            this.gradb1 = gradb1;
            this.gradW2 = gradW2;
            this.gradE = gradE;
            this.gradSaved = gradSaved;
        }

        /**
         * Adds another partial result into this one in place: costs, accuracies and all gradients
         * are summed.
         *
         * @param other the partial result to add; it is not modified
         */
        public void merge(Cost other) {
            cost += other.getCost();
            percentCorrect += other.getPercentCorrect();
            addInPlace(gradW1, other.getGradW1());
            addInPlace(gradb1, other.getGradb1());
            addInPlace(gradW2, other.getGradW2());
            addInPlace(gradE, other.getGradE());
            for (int row = 0; row < other.gradSaved.length; row++) {
                double[] src = other.gradSaved[row];
                if (src == null) {
                    continue;
                }
                if (gradSaved[row] == null) {
                    // Copy rather than alias, so later merges never write into other's arrays.
                    gradSaved[row] = src.clone();
                } else {
                    addInPlace(gradSaved[row], src);
                }
            }
        }

        /**
         * Pushes the gradient accumulated for cached features back into {@code gradW1} and
         * {@code gradE}. It applies the chain rule through {@code saved = W1_slice * E[tok]}.
         *
         * @param featuresSeen encoded features whose cached rows were used; every element must
         *                     be a key of the pre-computation map
         */
        public void backpropSaved(Set<Integer> featuresSeen) {
            int embeddingSize = config.embeddingSize;
            for (int feature : featuresSeen) {
                double[] delta = gradSaved[preMap.get(feature)];
                if (delta == null) {
                    continue; // zero gradient: no example in the batch reached this row
                }
                int tok = feature / NUM_TOKENS;
                int offset = (feature % NUM_TOKENS) * embeddingSize;
                for (int j = 0; j < config.hiddenSize; j++) {
                    double d = delta[j];
                    for (int k = 0; k < embeddingSize; k++) {
                        gradW1[j][offset + k] += d * E[tok][k];
                        gradE[tok][k] += d * W1[j][offset + k];
                    }
                }
            }
        }

        /**
         * Adds the L2 penalty {@code reg / 2 * ||theta||^2} over W1, b1, W2 and E to the cost. It
         * also adds the corresponding gradient {@code reg * theta} to each parameter's gradient.
         *
         * @param regParameter the regularization weight
         */
        public void addL2Regularization(double regParameter) {
            addL2Term(W1, gradW1, regParameter);
            addL2Term(b1, gradb1, regParameter);
            addL2Term(W2, gradW2, regParameter);
            addL2Term(E, gradE, regParameter);
        }

        // Row-wise L2 penalty and gradient for a matrix parameter.
        private void addL2Term(double[][] params, double[][] grads, double regParameter) {
            for (int i = 0; i < params.length; i++) {
                addL2Term(params[i], grads[i], regParameter);
            }
        }

        // L2 penalty and gradient for a vector parameter.
        private void addL2Term(double[] params, double[] grads, double regParameter) {
            for (int j = 0; j < params.length; j++) {
                double w = params[j];
                cost += regParameter * w * w / 2.0;
                grads[j] += regParameter * w;
            }
        }

        public double getCost() {
            return cost;
        }

        /**
         * Accuracy accumulated as {@code 1 / batchSize} per example whose highest-scoring valid
         * label is a gold label.
         */
        public double getPercentCorrect() {
            return percentCorrect;
        }

        public double[][] getGradW1() {
            return gradW1;
        }

        public double[] getGradb1() {
            return gradb1;
        }

        public double[][] getGradW2() {
            return gradW2;
        }

        public double[][] getGradE() {
            return gradE;
        }
    }

    /**
     * Worker that computes the cost and gradients for one shard of a mini-batch. Every buffer it
     * writes is freshly allocated per call. The model parameters and {@code saved} are only read.
     */
    private class CostFunction implements ThreadsafeProcessor<Pair<Collection<Example>, FeedforwardParams>, Cost> {

        /**
         * Runs forward and backward passes with hidden-unit dropout over every example in the
         * shard.
         *
         * <p>Label convention used below: entries {@code >= 0} are valid labels and {@code == 1}
         * marks gold. NOTE(review): assumes each example has at least one valid and one gold label;
         * otherwise indexing with {@code optLabel = -1} fails or the cost becomes infinite.
         */
        @Override
        public Cost process(Pair<Collection<Example>, FeedforwardParams> input) {
            Collection<Example> examples = input.first();
            FeedforwardParams params = input.second();
            ThreadLocalRandom random = ThreadLocalRandom.current();
            int hiddenSize = config.hiddenSize;
            int embeddingSize = config.embeddingSize;

            double[][] gradW1 = new double[W1.length][W1[0].length];
            double[] gradb1 = new double[b1.length];
            double[][] gradW2 = new double[W2.length][W2[0].length];
            double[][] gradE = new double[E.length][E[0].length];
            // Rows are allocated lazily; only features actually used in this shard get one.
            double[][] gradSaved = new double[preMap.size()][];

            double cost = 0.0;
            double correct = 0.0;

            for (Example example : examples) {
                List<Integer> feature = example.getFeature();
                List<Integer> label = example.getLabel();

                // Dropout: keep each hidden unit independently with probability 1 - dropOutProb.
                int[] active = IntStream.range(0, hiddenSize)
                        .filter(unit -> random.nextDouble() > params.getDropOutProb())
                        .toArray();

                // Forward pass: hidden pre-activations of the active units.
                double[] hidden = new double[hiddenSize];
                for (int pos = 0, offset = 0; pos < NUM_TOKENS; pos++, offset += embeddingSize) {
                    int tok = feature.get(pos);
                    Integer savedId = preMap.get(tok * NUM_TOKENS + pos);
                    if (savedId != null) {
                        double[] cached = saved[savedId];
                        for (int unit : active) {
                            hidden[unit] += cached[unit];
                        }
                    } else {
                        for (int unit : active) {
                            for (int k = 0; k < embeddingSize; k++) {
                                hidden[unit] += W1[unit][offset + k] * E[tok][k];
                            }
                        }
                    }
                }

                double[] hidden3 = new double[hiddenSize];
                for (int unit : active) {
                    hidden[unit] += b1[unit];
                    hidden3[unit] = Math.pow(hidden[unit], 3.0);
                }

                // Output scores for valid labels; track the highest-scoring one.
                double[] scores = new double[numLabels];
                int optLabel = -1;
                for (int i = 0; i < numLabels; i++) {
                    if (label.get(i) >= 0) {
                        for (int unit : active) {
                            scores[i] += W2[i][unit] * hidden3[unit];
                        }
                        if (optLabel < 0 || scores[i] > scores[optLabel]) {
                            optLabel = i;
                        }
                    }
                }

                // Softmax over valid labels, shifted by the max score for numerical stability.
                double sumGold = 0.0;
                double sumAll = 0.0;
                double maxScore = scores[optLabel];
                for (int i = 0; i < numLabels; i++) {
                    if (label.get(i) >= 0) {
                        scores[i] = Math.exp(scores[i] - maxScore);
                        if (label.get(i) == 1) {
                            sumGold += scores[i];
                        }
                        sumAll += scores[i];
                    }
                }

                cost += (Math.log(sumAll) - Math.log(sumGold)) / params.getBatchSize();
                if (label.get(optLabel) == 1) {
                    correct += 1.0 / params.getBatchSize();
                }

                // Backward pass through the output layer.
                double[] gradHidden3 = new double[hiddenSize];
                for (int i = 0; i < numLabels; i++) {
                    if (label.get(i) >= 0) {
                        double delta = -(label.get(i) - scores[i] / sumAll) / params.getBatchSize();
                        for (int unit : active) {
                            gradW2[i][unit] += delta * hidden3[unit];
                            gradHidden3[unit] += delta * W2[i][unit];
                        }
                    }
                }

                // Backward pass through the cube activation: d(x^3)/dx = 3x^2.
                double[] gradHidden = new double[hiddenSize];
                for (int unit : active) {
                    gradHidden[unit] = gradHidden3[unit] * 3.0 * hidden[unit] * hidden[unit];
                    gradb1[unit] += gradHidden[unit];
                }

                // Backward pass into W1/E. For cached features the gradient is deferred to
                // gradSaved and applied later by Cost.backpropSaved.
                for (int pos = 0, offset = 0; pos < NUM_TOKENS; pos++, offset += embeddingSize) {
                    int tok = feature.get(pos);
                    Integer savedId = preMap.get(tok * NUM_TOKENS + pos);
                    if (savedId != null) {
                        double[] row = gradSaved[savedId];
                        if (row == null) {
                            row = new double[hiddenSize];
                            gradSaved[savedId] = row;
                        }
                        for (int unit : active) {
                            row[unit] += gradHidden[unit];
                        }
                    } else {
                        for (int unit : active) {
                            for (int k = 0; k < embeddingSize; k++) {
                                gradW1[unit][offset + k] += gradHidden[unit] * E[tok][k];
                                gradE[tok][k] += gradHidden[unit] * W1[unit][offset + k];
                            }
                        }
                    }
                }
            }
            return new Cost(cost, correct, gradW1, gradb1, gradW2, gradE, gradSaved);
        }

        @Override
        public ThreadsafeProcessor<Pair<Collection<Example>, FeedforwardParams>, Cost> newInstance() {
            return new CostFunction();
        }
    }

    /** Per-batch settings shared by all workers: the full batch size and the dropout rate. */
    public static class FeedforwardParams {
        private final int batchSize;
        private final double dropOutProb;

        private FeedforwardParams(int batchSize, double dropOutProb) {
            this.batchSize = batchSize;
            this.dropOutProb = dropOutProb;
        }

        public int getBatchSize() {
            return batchSize;
        }

        public double getDropOutProb() {
            return dropOutProb;
        }
    }

    /**
     * Creates a classifier for inference only; training methods will throw.
     *
     * @see #Classifier(Config, Dataset, double[][], double[][], double[], double[][], List)
     */
    public Classifier(Config config, double[][] E, double[][] W1, double[] b1, double[][] W2, List<Integer> preComputed) {
        this(config, null, E, W1, b1, W2, preComputed);
    }

    /**
     * Creates a classifier over the given parameter arrays. The arrays are used directly (not
     * copied) and are updated in place during training.
     *
     * @param config      model/training configuration
     * @param dataset     training data, or {@code null} for an inference-only classifier
     * @param E           embedding matrix, one row per token id
     * @param W1          hidden-layer weights
     * @param b1          hidden-layer bias
     * @param W2          output weights; its row count defines the number of labels
     * @param preComputed encoded (token, position) features eligible for caching; only the first
     *                    {@code config.numPreComputed} entries are considered
     */
    public Classifier(Config config, Dataset dataset, double[][] E, double[][] W1, double[] b1, double[][] W2, List<Integer> preComputed) {
        this.config = config;
        this.dataset = dataset;
        this.E = E;
        this.W1 = W1;
        this.b1 = b1;
        this.W2 = W2;
        initGradientHistories();
        this.numLabels = W2.length;
        this.preMap = new HashMap<>();
        for (int i = 0; i < preComputed.size() && i < config.numPreComputed; i++) {
            // putIfAbsent with size() keeps row ids dense even if the list contains duplicates;
            // for a duplicate-free list this assigns exactly row i.
            preMap.putIfAbsent(preComputed.get(i), preMap.size());
        }
        this.isTraining = dataset != null;
        if (isTraining) {
            this.jobHandler = new MulticoreWrapper<>(config.trainingThreads, new CostFunction(), false);
        } else {
            this.jobHandler = null;
        }
    }

    /**
     * Collects the cached features actually used by the given examples and logs what fraction of
     * the cache capacity that is.
     */
    private Set<Integer> getToPreCompute(List<Example> examples) {
        Set<Integer> featureIds = new HashSet<>();
        for (Example example : examples) {
            List<Integer> feature = example.getFeature();
            for (int pos = 0; pos < NUM_TOKENS; pos++) {
                int index = feature.get(pos) * NUM_TOKENS + pos;
                if (preMap.containsKey(index)) {
                    featureIds.add(index);
                }
            }
        }
        // Floating-point division; integer division would always print 0% or 100%.
        double percent = config.numPreComputed > 0
                ? (double) featureIds.size() / config.numPreComputed * 100.0
                : 0.0;
        System.err.printf("Percent actually necessary to pre-compute: %f%%%n", percent);
        return featureIds;
    }

    /**
     * Samples a mini-batch and evaluates the regularized cost and gradients on it, using
     * {@code config.trainingThreads} workers.
     *
     * @param batchSize    number of examples to sample
     * @param regParameter L2 regularization weight
     * @param dropOutProb  probability of dropping each hidden unit
     * @return the merged cost, or {@code null} if no worker produced a result
     * @throws IllegalStateException if this classifier is not in training mode
     */
    public Cost computeCostFunction(int batchSize, double regParameter, double dropOutProb) {
        validateTraining();
        List<Example> examples = Util.getRandomSubList(dataset.examples, batchSize);
        Set<Integer> featuresSeen = getToPreCompute(examples);
        preCompute(featuresSeen);

        FeedforwardParams params = new FeedforwardParams(batchSize, dropOutProb);
        for (Collection<Example> chunk : CollectionUtils.partitionIntoFolds(examples, config.trainingThreads)) {
            jobHandler.put(new Pair<>(chunk, params));
        }
        jobHandler.join(false);

        Cost cost = null;
        while (jobHandler.peek()) {
            Cost partial = jobHandler.poll();
            if (cost == null) {
                cost = partial;
            } else {
                cost.merge(partial);
            }
        }
        if (cost == null) {
            return null;
        }
        cost.backpropSaved(featuresSeen);
        cost.addL2Regularization(regParameter);
        return cost;
    }

    /**
     * Applies one AdaGrad step to W1, b1, W2 and, if {@code config.doWordEmbeddingGradUpdate} is
     * set, to E: {@code G += g^2; theta -= adaAlpha * g / sqrt(G + adaEps)}.
     *
     * @throws IllegalStateException if this classifier is not in training mode
     */
    public void takeAdaGradientStep(Cost cost, double adaAlpha, double adaEps) {
        validateTraining();
        adaGradUpdate(W1, cost.getGradW1(), eg2W1, adaAlpha, adaEps);
        adaGradUpdate(b1, cost.getGradb1(), eg2b1, adaAlpha, adaEps);
        adaGradUpdate(W2, cost.getGradW2(), eg2W2, adaAlpha, adaEps);
        if (config.doWordEmbeddingGradUpdate) {
            adaGradUpdate(E, cost.getGradE(), eg2E, adaAlpha, adaEps);
        }
    }

    // Row-wise AdaGrad update for a matrix parameter.
    private static void adaGradUpdate(double[][] params, double[][] grads, double[][] history, double adaAlpha, double adaEps) {
        for (int i = 0; i < params.length; i++) {
            adaGradUpdate(params[i], grads[i], history[i], adaAlpha, adaEps);
        }
    }

    // AdaGrad update for a vector parameter.
    private static void adaGradUpdate(double[] params, double[] grads, double[] history, double adaAlpha, double adaEps) {
        for (int j = 0; j < params.length; j++) {
            history[j] += grads[j] * grads[j];
            params[j] -= adaAlpha * grads[j] / Math.sqrt(history[j] + adaEps);
        }
    }

    // (Re)allocates zeroed AdaGrad accumulators shaped like the parameters.
    private void initGradientHistories() {
        eg2E = new double[E.length][E[0].length];
        eg2W1 = new double[W1.length][W1[0].length];
        eg2b1 = new double[b1.length];
        eg2W2 = new double[W2.length][W2[0].length];
    }

    /**
     * Resets the AdaGrad accumulators to zero.
     *
     * @throws IllegalStateException if this classifier is not in training mode
     */
    public void clearGradientHistories() {
        validateTraining();
        initGradientHistories();
    }

    private void validateTraining() {
        if (!isTraining) {
            throw new IllegalStateException("Not training, or training was already finalized");
        }
    }

    /**
     * Waits for the worker pool to finish and leaves training mode. After this call every training
     * method throws.
     *
     * @throws IllegalStateException if this classifier is not in training mode
     */
    public void finalizeTraining() {
        validateTraining();
        jobHandler.join(true);
        isTraining = false;
    }

    /** Fills the cache for every feature in the pre-computation map. */
    public void preCompute() {
        preCompute(preMap.keySet());
    }

    /**
     * Rebuilds the cache, filling only the rows for {@code toPreCompute}; all other rows are left
     * at zero.
     *
     * @param toPreCompute encoded features to cache; each must be a key of the pre-computation map
     */
    public void preCompute(Set<Integer> toPreCompute) {
        long startTime = System.currentTimeMillis();
        saved = new double[preMap.size()][config.hiddenSize];
        for (int feature : toPreCompute) {
            double[] row = saved[preMap.get(feature)];
            int tok = feature / NUM_TOKENS;
            int offset = (feature % NUM_TOKENS) * config.embeddingSize;
            for (int j = 0; j < config.hiddenSize; j++) {
                for (int k = 0; k < config.embeddingSize; k++) {
                    row[j] += W1[j][offset + k] * E[tok][k];
                }
            }
        }
        log.info("PreComputed " + toPreCompute.size() + ", Elapsed Time: " + ((System.currentTimeMillis() - startTime) / 1000.0) + " (s)");
    }

    /**
     * Computes unnormalized label scores {@code W2 * (W1 x + b1)^3} for one feature vector,
     * without dropout.
     *
     * <p>Cached features read their hidden-layer contribution from the most recent
     * {@code preCompute} call. Call {@link #preCompute()} first so every cached row is populated;
     * otherwise unfilled rows contribute zero, and a missing cache throws.
     */
    public double[] computeScores(int[] feature) {
        return computeScores(feature, preMap);
    }

    private double[] computeScores(int[] feature, Map<Integer, Integer> cacheIndex) {
        double[] hidden = new double[config.hiddenSize];
        int offset = 0;
        for (int pos = 0; pos < feature.length; pos++, offset += config.embeddingSize) {
            int tok = feature[pos];
            Integer savedId = cacheIndex.get(tok * NUM_TOKENS + pos);
            if (savedId != null) {
                double[] cached = saved[savedId];
                for (int j = 0; j < config.hiddenSize; j++) {
                    hidden[j] += cached[j];
                }
            } else {
                for (int j = 0; j < config.hiddenSize; j++) {
                    for (int k = 0; k < config.embeddingSize; k++) {
                        hidden[j] += W1[j][offset + k] * E[tok][k];
                    }
                }
            }
        }
        for (int j = 0; j < config.hiddenSize; j++) {
            hidden[j] += b1[j];
            hidden[j] = hidden[j] * hidden[j] * hidden[j];
        }
        double[] scores = new double[numLabels];
        for (int i = 0; i < numLabels; i++) {
            for (int j = 0; j < config.hiddenSize; j++) {
                scores[i] += W2[i][j] * hidden[j];
            }
        }
        return scores;
    }

    public double[][] getW1() {
        return W1;
    }

    public double[] getb1() {
        return b1;
    }

    public double[][] getW2() {
        return W2;
    }

    public double[][] getE() {
        return E;
    }

    /** Adds {@code src} into {@code dest} element-wise; each {@code src} row must be at least as long as the matching {@code dest} row. */
    public static void addInPlace(double[][] dest, double[][] src) {
        for (int i = 0; i < dest.length; i++) {
            addInPlace(dest[i], src[i]);
        }
    }

    /** Adds {@code src} into {@code dest} element-wise; {@code src} must be at least as long as {@code dest}. */
    public static void addInPlace(double[] dest, double[] src) {
        for (int i = 0; i < dest.length; i++) {
            dest[i] += src[i];
        }
    }
}
