package es.ehu.si.ixa.pipe.pos.train;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import opennlp.tools.postag.POSEvaluator;
import opennlp.tools.postag.POSModel;
import opennlp.tools.postag.POSSample;
import opennlp.tools.postag.POSTaggerEvaluationMonitor;
import opennlp.tools.postag.POSTaggerFactory;
import opennlp.tools.postag.POSTaggerME;
import opennlp.tools.postag.WordTagSampleStream;
import opennlp.tools.util.ObjectStream;
import opennlp.tools.util.TrainingParameters;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:es/ehu/si/ixa/pipe/pos/train/AbstractMorphoTaggerTrainer.class */
public abstract class AbstractMorphoTaggerTrainer implements MorphoTaggerTrainer {

    /** File to which every cross-evaluation result line is appended. */
    private static final String RESULTS_FILE = "pos-results.txt";

    /** Language code passed to {@link POSTaggerME#train}. */
    protected String lang;
    /** Location of the training data, as passed to {@code InputOutputUtils.readInputData}. */
    protected String trainData;
    /** Location of the test data, as passed to {@code InputOutputUtils.readInputData}. */
    protected String testData;
    /** Training samples read in the constructor; consumed by {@link #train}. */
    protected ObjectStream<POSSample> trainSamples;
    /** Test samples read in the constructor; consumed by {@link #train}. */
    protected ObjectStream<POSSample> testSamples;
    /** Beam size of the {@link POSTaggerME} built by {@link #evaluate}. */
    protected int beamSize;
    /**
     * Factory used by {@link #train}. Subclasses must assign it; otherwise
     * {@link #train} throws {@link IllegalStateException}.
     */
    protected POSTaggerFactory posTaggerFactory;

    /**
     * Reads the training and test sets into word/tag sample streams.
     *
     * @param lang      language code used when training models
     * @param trainData training data location, read via {@code InputOutputUtils.readInputData}
     * @param testData  test data location, read via {@code InputOutputUtils.readInputData}
     * @param beamSize  beam size used by the tagger during evaluation
     * @throws IOException if reading either data set fails
     */
    public AbstractMorphoTaggerTrainer(final String lang, final String trainData,
            final String testData, final int beamSize) throws IOException {
        this.lang = lang;
        this.trainData = trainData;
        this.testData = testData;
        this.trainSamples = new WordTagSampleStream(InputOutputUtils.readInputData(trainData));
        this.testSamples = new WordTagSampleStream(InputOutputUtils.readInputData(testData));
        this.beamSize = beamSize;
    }

    /**
     * Trains a model on the constructor's training set using {@link #posTaggerFactory}.
     * It then prints the model's word accuracy on the constructor's test set and returns the model.
     *
     * <p>NOTE(review): the sample streams are not reset here, so a second call
     * presumably sees them already consumed — confirm whether repeated calls are expected.
     *
     * @param params training parameters handed to {@link POSTaggerME#train}
     * @return the trained model
     * @throws IllegalStateException if no {@link #posTaggerFactory} has been set
     */
    @Override
    public POSModel train(final TrainingParameters params) {
        if (this.posTaggerFactory == null) {
            throw new IllegalStateException(
                "Classes derived from AbstractMorphoTrainer must create a POSTaggerFactory features!");
        }
        try {
            final POSModel model =
                POSTaggerME.train(this.lang, this.trainSamples, params, this.posTaggerFactory);
            final POSEvaluator evaluator = evaluate(model, this.testSamples);
            System.out.println("Final result: " + evaluator.getWordAccuracy());
            return model;
        } catch (IOException e) {
            throw abort("IO error while loading training and test sets!", e);
        }
    }

    /**
     * Grid-searches the cutoff and iteration count, then trains the final model.
     *
     * <p>The search trains on {@code trainFile} and evaluates on {@code testFile}.
     * The final model is trained via {@link #train} on the constructor's data, using only
     * the algorithm of {@code params} plus the best iteration count and cutoff.
     * The {@code "Iterations"} and {@code "Cutoff"} entries of {@code params} are
     * overwritten during the search.
     *
     * @param trainFile training data location used during the search
     * @param testFile  test data location used during the search
     * @param params    must contain {@code "Cutoff"} and {@code "Iterations"} upper bounds
     * @param evalRange {@code [start, step]}; iteration counts tried are
     *                  {@code start + 10, start + 10 + step, ...} while below {@code Iterations + 10}
     * @return the model trained with the best parameters found
     */
    @Override
    public POSModel trainCrossEval(final String trainFile, final String testFile,
            final TrainingParameters params, final String[] evalRange) {
        final List<Integer> best;
        try {
            best = crossEval(trainFile, testFile, params, evalRange);
        } catch (IOException e) {
            throw abort("IO error while loading training and test sets!", e);
        }
        final TrainingParameters finalParams = new TrainingParameters();
        finalParams.put("Algorithm", params.algorithm());
        finalParams.put("Iterations", Integer.toString(best.get(0)));
        finalParams.put("Cutoff", Integer.toString(best.get(1)));
        return train(finalParams);
    }

    /**
     * Evaluates every (iterations, cutoff) pair and returns the best one.
     *
     * <p>Each result is printed and appended to {@value #RESULTS_FILE}.
     *
     * @return a two-element list {@code [iterations, cutoff]} chosen by
     *         {@code InputOutputUtils.getBestIterations}
     * @throws IOException if reading data or writing results fails
     */
    private List<Integer> crossEval(final String trainFile, final String testFile,
            final TrainingParameters params, final String[] evalRange) throws IOException {
        System.out.println("Cross Evaluation:");
        final List<List<Integer>> bestParams = new ArrayList<List<Integer>>();
        final Map<List<Integer>, Double> results = new LinkedHashMap<List<Integer>, Double>();
        // NOTE(review): candidates are trained with a default factory, not this.posTaggerFactory,
        // which train() uses for the final model — confirm this is intended.
        final POSTaggerFactory factory = new POSTaggerFactory();
        final int maxCutoff = Integer.parseInt(params.getSettings().get("Cutoff"));
        final int maxIterations = Integer.parseInt(params.getSettings().get("Iterations"));
        final int start = Integer.parseInt(evalRange[0]);
        final int step = Integer.parseInt(evalRange[1]);
        if (step <= 0) {
            // A non-positive step would never reach the upper bound.
            throw new IllegalArgumentException("Iteration step must be positive, got " + step);
        }
        for (int cutoff = 0; cutoff <= maxCutoff; cutoff++) {
            for (int iterations = start + 10; iterations < maxIterations + 10; iterations += step) {
                params.put("Iterations", Integer.toString(iterations));
                params.put("Cutoff", Integer.toString(cutoff));
                System.out.println("Trying with " + iterations + " iterations...");
                final double accuracy = evaluateCandidate(trainFile, testFile, params, factory);
                final String line = "Iterations: " + iterations + " cutoff: " + cutoff + " "
                    + "Accuracy: " + accuracy + IOUtils.LINE_SEPARATOR_UNIX;
                FileUtils.write(new File(RESULTS_FILE), line, true);
                final List<Integer> key = new ArrayList<Integer>(2);
                key.add(iterations);
                key.add(cutoff);
                results.put(key, accuracy);
                System.out.println();
                System.out.println("Iterations: " + iterations + " cutoff: " + cutoff);
                System.out.println(accuracy);
            }
        }
        System.out.println();
        InputOutputUtils.printIterationResults(results);
        InputOutputUtils.getBestIterations(results, bestParams);
        if (bestParams.isEmpty()) {
            throw new IllegalStateException(
                "Cross evaluation produced no results; check the iteration range " + start + "/" + step);
        }
        final List<Integer> best = bestParams.get(0);
        System.out.println("Final Params " + best.get(0) + " " + best.get(1));
        return best;
    }

    /**
     * Trains one candidate model on {@code trainFile} and returns its word accuracy on
     * {@code testFile}. Both sample streams are closed afterwards.
     */
    private double evaluateCandidate(final String trainFile, final String testFile,
            final TrainingParameters params, final POSTaggerFactory factory) throws IOException {
        final ObjectStream<POSSample> trainStream =
            new WordTagSampleStream(InputOutputUtils.readInputData(trainFile));
        try {
            final ObjectStream<POSSample> testStream =
                new WordTagSampleStream(InputOutputUtils.readInputData(testFile));
            try {
                final POSModel model = POSTaggerME.train(this.lang, trainStream, params, factory);
                return evaluate(model, testStream).getWordAccuracy();
            } finally {
                testStream.close();
            }
        } finally {
            trainStream.close();
        }
    }

    /**
     * Tags {@code samples} with {@code model} and collects accuracy statistics.
     *
     * @param model   the model to evaluate; used with {@link #beamSize} and a cache size of 0
     * @param samples the gold-standard samples to evaluate against
     * @return the populated evaluator
     */
    @Override
    public POSEvaluator evaluate(final POSModel model, final ObjectStream<POSSample> samples) {
        final POSEvaluator evaluator = new POSEvaluator(new POSTaggerME(model, this.beamSize, 0));
        try {
            evaluator.evaluate(samples);
        } catch (IOException e) {
            throw abort("IO error while loading test set for evaluation!", e);
        }
        return evaluator;
    }

    /**
     * Reports a fatal I/O error on stderr and terminates the JVM with status 1.
     *
     * <p>The return value exists only so callers can write {@code throw abort(...)}.
     * That makes the compiler's flow analysis treat the call site as terminal.
     * {@link System#exit} does not return normally.
     */
    private static IllegalStateException abort(final String message, final IOException cause) {
        System.err.println(message);
        cause.printStackTrace();
        System.exit(1);
        return new IllegalStateException(message, cause);
    }
}
