package edu.udo.cs.ls8.mllib.rapidminer.operators;

import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.ValueString;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.ParameterTypeStringCategory;
import edu.udo.cs.ls8.mllib.helpers.IntegerModuloCounter;
import edu.udo.cs.ls8.mllib.rapidminer.entities.ClusterPredictionModel;
import edu.udo.cs.ls8.mllib.rapidminer.helpers.ProportionExtractor;
import edu.udo.cs.ls8.mllib.rapidminer.helpers.ProportionsPerformanceMeasure;
import java.util.List;

/* loaded from: input_file:edu/udo/cs/ls8/mllib/rapidminer/operators/ClusterLabeler.class */
public class ClusterLabeler extends ClusterOperator {

    /**
     * Operator that assigns a class label to every cluster of a {@link ClusterPredictionModel}.
     *
     * <p>Two strategies are available (parameter {@code strategy}):
     * <ul>
     *   <li>{@code BruteForce}: evaluates the candidate labelings produced by an
     *       {@link IntegerModuloCounter}. It keeps the labeling with the highest
     *       {@code ProportionsPerformanceMeasure.getPerformance}. As a reference, it also tracks the
     *       labeling with the highest {@code getRealAccuracy}.</li>
     *   <li>{@code GreedyBySize}: fixes the label of one cluster at a time. For each cluster it picks
     *       the label that maximizes the performance measure, given the labels chosen so far.</li>
     * </ul>
     *
     * <p>The labeling chosen by performance is written into the cluster model via
     * {@code setClusterLabel}. The example set and the model are then passed through unchanged
     * otherwise. Run statistics are exposed as operator values: {@code strategy},
     * {@code performance}, {@code realAccuracy}, {@code foundAccuracy}, {@code diffAccuracy},
     * {@code proportionSimilarity}, {@code weightedProportionSimilarity} and
     * {@code labelSizeSimilarity}.
     */

    /** Parameter key: clusters above this count skip brute force and get an all-zero labeling. */
    public static final String PARAMETER_K = "k";
    /** Parameter key selecting the labeling strategy (one of {@link #strategies}). */
    public static final String PARAMETER_STRATEGY = "strategy";
    /** Parameter key: with GreedyBySize, additionally run brute force as a reference. */
    public static final String PARAMETER_BFCMP = "brute force comparison";

    private static final String STRATEGY_BRUTE_FORCE = "BruteForce";
    private static final String STRATEGY_GREEDY_BY_SIZE = "GreedyBySize";

    /** Selectable values for {@link #PARAMETER_STRATEGY}; the first entry is the default. */
    public static final String[] strategies = {STRATEGY_BRUTE_FORCE, STRATEGY_GREEDY_BY_SIZE};

    /** Number format used for performance / accuracy values in the log. */
    private static final String SCORE_FORMAT = "%1.5f";

    private final InputPort exampleSetInput;
    private final InputPort clusterModelInput;
    private final OutputPort exampleSetOutput;
    private final OutputPort clusterModelOutput;

    // Problem dimensions, taken from the ProportionExtractor / cluster model in doWork().
    private int numOfGroups;
    private int numOfLabels;
    private int numOfClusters;

    private ExampleSet examples;
    private ClusterPredictionModel clusterModel;
    private ProportionExtractor extractor;
    private ProportionsPerformanceMeasure measure;

    // Parameter values of the current run.
    private int max_k;
    private String strategy = "";
    private boolean bruteforce_cmp;

    // Labeling with the best real accuracy, its accuracy, and the performance it achieves.
    private double maxRealAccuracy = Double.NEGATIVE_INFINITY;
    private double foundPerformance;
    int[] optRealLabeling;

    // Labeling with the best performance, its performance, and the real accuracy it achieves.
    private double maxPerformance = Double.NEGATIVE_INFINITY;
    private double foundAccuracy;
    int[] optPerformanceLabeling;

    // Similarity statistics of the labeling that is finally applied (optPerformanceLabeling).
    private double proportionSimilarity;
    private double weightedProportionSimilarity;
    private double labelSizeSimilarity;

    /**
     * Creates the operator, its ports and the logged operator values.
     *
     * @param operatorDescription description passed through to {@link ClusterOperator}
     */
    public ClusterLabeler(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class);
        this.clusterModelInput = getInputPorts().createPort("cluster model", ClusterPredictionModel.class);
        this.exampleSetOutput = getOutputPorts().createPort("example set");
        this.clusterModelOutput = getOutputPorts().createPort("prediction model");
        addValue(new ValueString(PARAMETER_STRATEGY, "Strategy") {
            @Override
            public String getStringValue() {
                return ClusterLabeler.this.strategy;
            }
        });
        addValue(new ValueDouble("performance", "Performance") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.maxPerformance;
            }
        });
        addValue(new ValueDouble("realAccuracy", "Real label accuracy.") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.maxRealAccuracy;
            }
        });
        addValue(new ValueDouble("foundAccuracy", "Found label accuracy.") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.foundAccuracy;
            }
        });
        addValue(new ValueDouble("diffAccuracy", "Difference between found and real") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.foundAccuracy - ClusterLabeler.this.maxRealAccuracy;
            }
        });
        addValue(new ValueDouble("proportionSimilarity", "Proportion similarity") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.proportionSimilarity;
            }
        });
        addValue(new ValueDouble("weightedProportionSimilarity", "Weighted proportion similarity") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.weightedProportionSimilarity;
            }
        });
        addValue(new ValueDouble("labelSizeSimilarity", "Label size similarity.") {
            @Override
            public double getDoubleValue() {
                return ClusterLabeler.this.labelSizeSimilarity;
            }
        });
    }

    /**
     * Reads the inputs, computes a cluster labeling with the configured strategy, applies it to the
     * cluster model and delivers the example set and the model.
     *
     * @throws OperatorException if reading inputs or parameters fails
     */
    @Override // edu.udo.cs.ls8.mllib.rapidminer.operators.ClusterOperator
    public void doWork() throws OperatorException {
        super.doWork();
        logMsg("Begin ClusterLabeler");
        // Result fields are instance state: clear anything left over from a previous call so a
        // stale labeling can never be applied to this run's model.
        resetResults();
        this.examples = this.exampleSetInput.getData();
        this.clusterModel = this.clusterModelInput.getData();
        this.max_k = getParameterAsInt(PARAMETER_K);
        this.strategy = getParameterAsString(PARAMETER_STRATEGY);
        this.bruteforce_cmp = getParameterAsBoolean(PARAMETER_BFCMP);
        this.numOfClusters = this.clusterModel.getNumberOfClusters();
        this.extractor = new ProportionExtractor(this.examples, this.numOfClusters);
        this.measure = new ProportionsPerformanceMeasure(this.extractor);
        logExampleSetProperties();
        this.numOfGroups = this.extractor.getNumberOfGroups();
        this.numOfLabels = this.extractor.getNumberOfLabels();

        if (STRATEGY_BRUTE_FORCE.equals(this.strategy)) {
            calculateLabelingBruteForce();
        } else if (STRATEGY_GREEDY_BY_SIZE.equals(this.strategy)) {
            if (this.bruteforce_cmp) {
                // Brute force runs first to provide the accuracy-optimal reference labeling; the
                // greedy run below then overwrites the performance-related fields.
                calculateLabelingBruteForce();
            }
            calculateLabelingGreedyBySize();
            if (!this.bruteforce_cmp) {
                // No reference available: report the greedy labeling for both roles, with fixed
                // placeholder values of 1.0.
                this.optRealLabeling = this.optPerformanceLabeling.clone();
                this.maxRealAccuracy = 1.0d;
                this.foundPerformance = 1.0d;
            }
        } else {
            logMsg("Unknown labeling strategy: " + this.strategy);
        }

        if (this.optPerformanceLabeling == null || this.optRealLabeling == null) {
            logMsg("Opt. labeling is null!");
        } else {
            logMsg("Best labeling by performance: " + getLabelingAsString(this.optPerformanceLabeling) + " (" + this.maxPerformance + ", " + this.foundAccuracy + ")");
            logMsg("   class proportions: " + getDoubleVectorAsString(this.measure.getPredictedClassProportions(this.optPerformanceLabeling)));
            logMsg("   label similarity: " + this.measure.getLabelSizeSimilarity(this.optPerformanceLabeling));
            logMsg("   w. prop. sim.:    " + this.measure.getWeightedProportionSimilarity(this.optPerformanceLabeling));
            logMsg("Best labeling by accuracy:    " + getLabelingAsString(this.optRealLabeling) + "( " + this.foundPerformance + ", " + this.maxRealAccuracy + ")");
            logMsg("   class proportions: " + getDoubleVectorAsString(this.measure.getPredictedClassProportions(this.optRealLabeling)));
            logMsg("   label similarity: " + this.measure.getLabelSizeSimilarity(this.optRealLabeling));
            logMsg("   w. prop. sim.:    " + this.measure.getWeightedProportionSimilarity(this.optRealLabeling));
            logMsg("Real class proportions: " + getDoubleVectorAsString(this.measure.getRealClassProportions()));
            logLabelProportions();
            logPredictedProportions("performance labeling", this.optPerformanceLabeling);
            logPredictedProportions("real labeling", this.optRealLabeling);
            this.proportionSimilarity = this.measure.getProportionSimilarity(this.optPerformanceLabeling);
            this.weightedProportionSimilarity = this.measure.getWeightedProportionSimilarity(this.optPerformanceLabeling);
            this.labelSizeSimilarity = this.measure.getLabelSizeSimilarity(this.optPerformanceLabeling);
            assignClusterLabels(this.optPerformanceLabeling);
        }
        logMsg("End ClusterLabeler");
        this.clusterModelOutput.deliver(this.clusterModel);
        this.exampleSetOutput.deliver(this.examples);
    }

    /** Restores all result fields to their initial (constructor-time) values. */
    private void resetResults() {
        this.maxRealAccuracy = Double.NEGATIVE_INFINITY;
        this.foundPerformance = 0.0d;
        this.optRealLabeling = null;
        this.maxPerformance = Double.NEGATIVE_INFINITY;
        this.foundAccuracy = 0.0d;
        this.optPerformanceLabeling = null;
        this.proportionSimilarity = 0.0d;
        this.weightedProportionSimilarity = 0.0d;
        this.labelSizeSimilarity = 0.0d;
    }

    /**
     * Evaluates candidate labelings exhaustively and stores both the performance-optimal and the
     * accuracy-optimal labeling.
     *
     * <p>If there are more labels than clusters, or more clusters than {@code k}, no enumeration
     * takes place and every cluster gets label 0 instead.
     */
    private void calculateLabelingBruteForce() {
        this.maxRealAccuracy = Double.NEGATIVE_INFINITY;
        this.maxPerformance = Double.NEGATIVE_INFINITY;
        if (this.numOfClusters < this.numOfLabels) {
            logMsg("More labels than clusters!");
            logMsg("--------------------------");
            useAllZeroLabeling();
            return;
        }
        if (this.numOfClusters > this.max_k) {
            // Presumably the counter enumerates numOfLabels^numOfClusters candidates, so k caps
            // the search size.
            logMsg("More clusters than k (" + this.max_k + ")!");
            logMsg("--------------------------");
            useAllZeroLabeling();
            return;
        }
        logMsg("Labelings by brute force:");
        logMsg("-------------------------");
        IntegerModuloCounter counter = new IntegerModuloCounter(this.numOfClusters, this.numOfLabels);
        while (!counter.hasOverflow()) {
            // NOTE(review): presumably this keeps only labelings that use every label; confirm
            // against IntegerModuloCounter.containsAllMaxValues().
            if (counter.containsAllMaxValues()) {
                evaluateBruteForceCandidate(counter.getArray());
            }
            counter.count();
        }
    }

    /**
     * Scores one candidate labeling, updates the best-so-far fields and logs the candidate.
     *
     * <p>The candidate is tagged {@code (*)} when it is a new best by performance and {@code (!)}
     * when it is a new best by real accuracy.
     *
     * @param labeling candidate labeling; it is copied before being stored, because the array
     *     may be the counter's internal state
     */
    private void evaluateBruteForceCandidate(int[] labeling) {
        double performance = this.measure.getPerformance(labeling);
        double realAccuracy = this.measure.getRealAccuracy(labeling);
        StringBuilder line = new StringBuilder("Labeling: ")
                .append(getLabelingAsString(labeling))
                .append(" -> ")
                .append(formatScores(performance, realAccuracy));
        if (performance > this.maxPerformance) {
            this.maxPerformance = performance;
            this.foundAccuracy = realAccuracy;
            this.optPerformanceLabeling = labeling.clone();
            line.append(" (*)");
        }
        if (realAccuracy > this.maxRealAccuracy) {
            this.maxRealAccuracy = realAccuracy;
            this.foundPerformance = performance;
            this.optRealLabeling = labeling.clone();
            line.append(" (!)");
        }
        logMsg(line.toString());
    }

    /**
     * Fallback used when brute force is skipped: every cluster gets label 0.
     *
     * <p>This labeling serves as both the performance-optimal and the accuracy-optimal result, so
     * the "found" values equal the "max" values.
     */
    private void useAllZeroLabeling() {
        int[] allZero = new int[this.numOfClusters]; // Java zero-initializes int arrays
        this.maxRealAccuracy = this.measure.getRealAccuracy(allZero);
        this.maxPerformance = this.measure.getPerformance(allZero);
        this.foundAccuracy = this.maxRealAccuracy;
        this.foundPerformance = this.maxPerformance;
        this.optRealLabeling = allZero;
        this.optPerformanceLabeling = allZero.clone();
    }

    /**
     * Builds a labeling greedily, fixing one cluster at a time.
     *
     * <p>For each cluster, every label is tried while the clusters not yet visited keep label 0.
     * The label with the highest performance is kept. The performance of the final step and the
     * resulting real accuracy are stored in {@code maxPerformance} and {@code foundAccuracy}.
     */
    private void calculateLabelingGreedyBySize() {
        this.maxPerformance = Double.NEGATIVE_INFINITY;
        logMsg("Labelings by greedy (size)");
        logMsg("--------------------------");
        // NOTE(review): the size ordering computed here is discarded, so clusters are visited in
        // index order, not by size. Confirm the intended order (ascending or descending size)
        // before iterating over the returned permutation.
        bubbleSort(this.extractor.getClusterSizes());
        int[] labeling = new int[this.numOfClusters]; // all clusters start with label 0
        for (int cluster = 0; cluster < this.numOfClusters; cluster++) {
            // Start from -infinity, not 0, so the best label is still found when every
            // performance value is <= 0 (this matches the brute-force search).
            double bestPerformance = Double.NEGATIVE_INFINITY;
            int bestLabel = 0;
            for (int label = 0; label < this.numOfLabels; label++) {
                labeling[cluster] = label;
                double performance = this.measure.getPerformance(labeling);
                if (performance > bestPerformance) {
                    bestPerformance = performance;
                    bestLabel = label;
                }
            }
            labeling[cluster] = bestLabel;
            double realAccuracy = this.measure.getRealAccuracy(labeling);
            logMsg("Labeling: " + getLabelingAsString(labeling) + " -> " + formatScores(bestPerformance, realAccuracy));
            this.maxPerformance = bestPerformance;
            this.foundAccuracy = realAccuracy;
        }
        this.optPerformanceLabeling = labeling.clone();
    }

    /**
     * Writes the labeling into the cluster model.
     *
     * @param labeling label for each cluster, indexed by cluster number
     */
    private void assignClusterLabels(int[] labeling) {
        for (int cluster = 0; cluster < labeling.length; cluster++) {
            this.clusterModel.setClusterLabel(cluster, labeling[cluster]);
        }
    }

    /** Formats a performance / accuracy pair as {@code "p, a"} with five decimals each. */
    private static String formatScores(double performance, double accuracy) {
        return String.format(SCORE_FORMAT, performance) + ", " + String.format(SCORE_FORMAT, accuracy);
    }

    /** Renders a labeling as its labels separated (and terminated) by single spaces. */
    private String getLabelingAsString(int[] labeling) {
        StringBuilder out = new StringBuilder();
        for (int label : labeling) {
            out.append(label).append(' ');
        }
        return out.toString();
    }

    /** Renders a vector as {@code "( v1, v2, ... )"} with five decimals per entry. */
    private String getDoubleVectorAsString(double[] values) {
        StringBuilder out = new StringBuilder("( ");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                out.append(", ");
            }
            out.append(String.format(SCORE_FORMAT, values[i]));
        }
        return out.append(" )").toString();
    }

    /** Logs the basic dimensions reported by the ProportionExtractor. */
    private void logExampleSetProperties() {
        logMsg("Num. of groups: " + this.extractor.getNumberOfGroups());
        logMsg("Num. of labels: " + this.extractor.getNumberOfLabels());
        logMsg("Num. of clusters: " + this.extractor.getNumberOfClusters());
        logMsg("Num. of labeled examples: " + this.extractor.getNumberOfLabeledExamples());
        logMsg("Num. of unlabeled examples: " + this.extractor.getNumberOfUnlabeledExamples());
    }

    /** Logs, for each group, its label proportions and label counts. */
    private void logLabelProportions() {
        logMsg("Label proportions");
        int groups = this.extractor.getNumberOfGroups();
        for (int group = 0; group < groups; group++) {
            logMsg(formatGroupRow(group, this.extractor.getLabelProportions(group), this.extractor.getLabelNumbers(group)));
        }
    }

    /**
     * Logs, for each group, the label proportions predicted by a labeling.
     *
     * @param name description of the labeling, used in the header line
     * @param labeling label for each cluster
     */
    private void logPredictedProportions(String name, int[] labeling) {
        logMsg("Predicted proportions for " + name);
        int groups = this.extractor.getNumberOfGroups();
        for (int group = 0; group < groups; group++) {
            // NOTE(review): the counts shown are the real label counts (getLabelNumbers), not
            // predicted counts; confirm that this is intended.
            logMsg(formatGroupRow(group, this.extractor.getPredictionProportions(group, labeling), this.extractor.getLabelNumbers(group)));
        }
    }

    /**
     * Formats one log row: group index and size, then a "proportion (count)" pair per label.
     *
     * @param group group index
     * @param proportions per-label proportions for the group
     * @param counts per-label counts for the group
     * @return the formatted row
     */
    private String formatGroupRow(int group, double[] proportions, int[] counts) {
        StringBuilder row = new StringBuilder(String.format(" group %2d (%4d):  ", group, this.extractor.getGroupSize(group)));
        int labels = this.extractor.getNumberOfLabels();
        for (int label = 0; label < labels; label++) {
            row.append(String.format("%1.3f (%3d)  ", proportions[label], counts[label]));
        }
        return row.toString();
    }

    /**
     * Declares this operator's parameters on top of those of {@link ClusterOperator}.
     *
     * @return the parameter list returned by the superclass, extended by {@code k},
     *     {@code strategy} and {@code brute force comparison}
     */
    @Override // edu.udo.cs.ls8.mllib.rapidminer.operators.ClusterOperator
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(new ParameterTypeInt(PARAMETER_K, "Max. number of clusters", 1, Integer.MAX_VALUE, 2));
        parameterTypes.add(new ParameterTypeStringCategory(PARAMETER_STRATEGY, "Specifies the labeling strategy.", strategies, strategies[0]));
        parameterTypes.add(new ParameterTypeBoolean(PARAMETER_BFCMP, "Brute force for comparison", false));
        return parameterTypes;
    }

    /**
     * Returns the indices of {@code values} ordered by ascending value. The sort is stable, so
     * equal values keep their original index order. The input array is not modified.
     *
     * @param values values to order
     * @return a permutation of {@code 0..values.length-1}
     */
    private int[] bubbleSort(int[] values) {
        int[] order = new int[values.length];
        for (int i = 0; i < values.length; i++) {
            order[i] = i;
        }
        boolean swapped = true;
        while (swapped) {
            swapped = false;
            for (int i = 0; i < values.length - 1; i++) {
                if (values[order[i]] > values[order[i + 1]]) {
                    int tmp = order[i];
                    order[i] = order[i + 1];
                    order[i + 1] = tmp;
                    swapped = true;
                }
            }
        }
        return order;
    }
}
