/*
 * Decompiled with CFR 0.152.
 */
package opennlp.tools.ml.maxent.quasinewton;

import opennlp.tools.ml.ArrayMath;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.ml.model.Context;

/**
 * Maximum entropy model of type {@link AbstractModel.ModelType#MaxentQn}.
 *
 * <p>An outcome's score for a context is the sum, over the context's known predicates, of
 * (predicate value &times; the predicate's parameter for that outcome). Scores are then turned
 * into a probability distribution by exponentiating each score minus the log-sum-exp of all
 * scores.
 */
public class QNModel
extends AbstractModel {

    /**
     * Creates a model from trained parameters.
     *
     * @param params       per-predicate parameter contexts, passed through to {@link AbstractModel}
     * @param predLabels   predicate (feature) labels, passed through to {@link AbstractModel}
     * @param outcomeNames names of the outcomes this model can predict
     */
    public QNModel(Context[] params, String[] predLabels, String[] outcomeNames) {
        super(params, predLabels, outcomeNames);
        this.modelType = AbstractModel.ModelType.MaxentQn;
    }

    /**
     * @return the number of outcomes, i.e. the length of {@code outcomeNames}
     */
    @Override
    public int getNumOutcomes() {
        return this.outcomeNames.length;
    }

    /**
     * Looks up the parameter context of a predicate.
     *
     * @param predicate the predicate (feature) label
     * @return the predicate's {@link Context}, or {@code null} when the model has no entry for it
     */
    private Context getPredContext(String predicate) {
        return (Context) this.pmap.get(predicate);
    }

    /**
     * Evaluates a context whose predicates all have value 1.0.
     *
     * @param context predicate labels; labels unknown to the model are ignored
     * @return a newly allocated array holding the probability of each outcome
     */
    @Override
    public double[] eval(String[] context) {
        return this.eval(context, new double[this.evalParams.getNumOutcomes()]);
    }

    /**
     * Evaluates a context whose predicates all have value 1.0, writing into {@code probs}.
     *
     * <p>Predicate contributions are added onto the values already present in {@code probs},
     * so the caller's initial contents act as starting scores (pass zeros for none). On return
     * the array holds the normalized distribution.
     *
     * @param context predicate labels; labels unknown to the model are ignored
     * @param probs   score accumulator, overwritten with probabilities
     * @return {@code probs}
     */
    @Override
    public double[] eval(String[] context, double[] probs) {
        return this.eval(context, null, probs);
    }

    /**
     * Evaluates a context with explicit real-valued predicates.
     *
     * @param context predicate labels; labels unknown to the model are ignored
     * @param values  {@code values[i]} is the value of {@code context[i]}; must be at least as
     *                long as {@code context}
     * @return a newly allocated array holding the probability of each outcome
     */
    @Override
    public double[] eval(String[] context, float[] values) {
        return this.eval(context, values, new double[this.evalParams.getNumOutcomes()]);
    }

    /**
     * Shared implementation for the string-keyed {@code eval} overloads.
     *
     * @param context predicate labels
     * @param values  per-predicate values, or {@code null} to treat every predicate as 1.0
     * @param probs   score accumulator (initial contents are added onto), normalized in place
     * @return {@code probs}
     */
    private double[] eval(String[] context, float[] values, double[] probs) {
        for (int ci = 0; ci < context.length; ++ci) {
            Context pred = this.getPredContext(context[ci]);
            if (pred == null) {
                // No parameters for this predicate: it contributes nothing to any outcome.
                continue;
            }
            double predValue = values != null ? values[ci] : 1.0;
            double[] parameters = pred.getParameters();
            int[] outcomes = pred.getOutcomes();
            // parameters[i] is the weight of this predicate for outcome outcomes[i].
            for (int i = 0; i < outcomes.length; ++i) {
                probs[outcomes[i]] += predValue * parameters[i];
            }
        }
        return normalizeInPlace(probs, this.outcomeNames.length);
    }

    /**
     * Evaluates an index-encoded context against a flat parameter vector.
     *
     * <p>{@code parameters} is laid out outcome-major: the weight of predicate {@code p} for
     * outcome {@code o} is {@code parameters[o * nPredLabels + p]}.
     *
     * @param context     predicate indices into the parameter layout
     * @param values      {@code values[i]} is the value of {@code context[i]}, or {@code null}
     *                    to treat every predicate as 1.0
     * @param probs       score accumulator (initial contents are added onto), normalized in place
     * @param nOutcomes   number of outcomes to score
     * @param nPredLabels number of predicates, i.e. the stride between outcome rows
     * @param parameters  flat weight vector in the layout described above
     * @return {@code probs}
     */
    static double[] eval(int[] context, float[] values, double[] probs, int nOutcomes, int nPredLabels, double[] parameters) {
        for (int i = 0; i < context.length; ++i) {
            int predIdx = context[i];
            double predValue = values != null ? values[i] : 1.0;
            for (int oi = 0; oi < nOutcomes; ++oi) {
                probs[oi] += predValue * parameters[oi * nPredLabels + predIdx];
            }
        }
        return normalizeInPlace(probs, nOutcomes);
    }

    /**
     * Rewrites the first {@code nOutcomes} scores as {@code exp(score - logSumExp)}, where
     * {@code logSumExp} is computed by {@link ArrayMath#logSumOfExps} over the WHOLE array
     * (presumably {@code log(sum(exp(scores)))} — confirm against ArrayMath). Entries at
     * index {@code >= nOutcomes}, if any, are left untouched.
     *
     * @param scores    raw outcome scores, overwritten in place
     * @param nOutcomes number of leading entries to rewrite
     * @return {@code scores}
     */
    private static double[] normalizeInPlace(double[] scores, int nOutcomes) {
        double logSumExp = ArrayMath.logSumOfExps(scores);
        for (int oi = 0; oi < nOutcomes; ++oi) {
            scores[oi] = StrictMath.exp(scores[oi] - logSumExp);
        }
        return scores;
    }
}

