package com.rapidminer.operator.dtw.clustering;

import com.rapidminer.operator.IOObjectCollection;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.TimeSeriesOperator;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.ValueDouble;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.CollectionMetaData;
import com.rapidminer.operator.ports.metadata.MDTransformationRule;
import com.rapidminer.operator.valueseries.Feature;
import com.rapidminer.operator.valueseries.ValueSeries;
import com.rapidminer.operator.valueseries.ValueSeriesMetaData;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeBoolean;
import com.rapidminer.parameter.ParameterTypeCategory;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.parameter.UndefinedParameterError;
import com.rapidminer.parameter.conditions.BooleanParameterCondition;
import com.rapidminer.tools.RandomGenerator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import timeseriesclustering.DTW;
import timeseriesclustering.DegeneratedClusteringException;
import timeseriesclustering.NoConstraint;
import timeseriesclustering.SakoeChibaBand;
import timeseriesclustering.TimeSeriesClustering;
import timeseriesclustering.averaging.AveragingTechnique;
import timeseriesclustering.averaging.EuclidianAveraging;
import timeseriesclustering.averaging.FastShapeBasedAveraging;
import timeseriesclustering.averaging.FixpointAveraging;
import timeseriesclustering.averaging.MedoidAveraging;
import timeseriesclustering.averaging.ProjectionAveraging;
import timeseriesclustering.averaging.ShapeBasedAveraging;

/* loaded from: input_file:com/rapidminer/operator/dtw/clustering/KMeansClustering.class */
/**
 * RapidMiner operator that partitions a collection of value series into {@code k} clusters using
 * k-means with dynamic time warping (DTW) as the distance.
 *
 * <p>Ports: {@code series-in} receives the collection to cluster; {@code series-out} delivers the
 * same collection with a {@code cluster} feature added to every series (plus {@code d0..d(c-1)}
 * DTW distance features when the {@code distances} parameter is set); {@code centroids-out}
 * delivers one series per cluster center, tagged with {@code label} (the cluster index) and
 * {@code size} (the number of members of that cluster).
 */
public class KMeansClustering extends TimeSeriesOperator {

    /**
     * Maximum number of clustering attempts when the library reports a degenerated clustering.
     * Bounding this turns a configuration that always degenerates into an error instead of
     * endless retrying.
     */
    private static final int MAX_DEGENERATE_RETRIES = 100;

    private final InputPort timeSeriesInput;
    private final OutputPort timeSeriesOutput;
    private final OutputPort centroidsOutput;

    /** Clusterer of the most recent attempt; {@code null} until {@link #doWork()} has run. */
    private TimeSeriesClustering clusterer;

    public KMeansClustering(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.timeSeriesInput = getInputPorts().createPort("series-in", IOObjectCollection.class);
        this.timeSeriesOutput = getOutputPorts().createPort("series-out");
        this.centroidsOutput = getOutputPorts().createPort("centroids-out");
        getTransformer().addRule(new MDTransformationRule() {
            public void transformMD() {
                KMeansClustering.this.timeSeriesOutput.deliverMD(new CollectionMetaData(new ValueSeriesMetaData()));
                KMeansClustering.this.centroidsOutput.deliverMD(new CollectionMetaData(new ValueSeriesMetaData()));
            }
        });
        addValue(new ValueDouble("DTW-Count", "The number of the computed DTW distances.") {
            public double getDoubleValue() {
                return DTW.count;
            }
        });
        addValue(new ValueDouble("DTW-QueryRatio", "The ratio of the computed DTW distances per Query.") {
            public double getDoubleValue() {
                // The value may be polled before the operator has executed; report NaN then.
                TimeSeriesClustering current = KMeansClustering.this.clusterer;
                return current == null ? Double.NaN : current.getDTWRatio();
            }
        });
        addValue(new ValueDouble("costs", "The total costs of the computed clustering.") {
            public double getDoubleValue() {
                TimeSeriesClustering current = KMeansClustering.this.clusterer;
                return current == null ? Double.NaN : current.getCosts();
            }
        });
    }

    /**
     * Clusters the input collection and delivers the annotated series and the centroids.
     *
     * <p>If the clustering library throws {@link DegeneratedClusteringException}, clustering is
     * restarted from scratch, up to {@link #MAX_DEGENERATE_RETRIES} attempts in total.
     *
     * @throws UserError if input data or parameters are missing or invalid
     * @throws IllegalStateException if every attempt produced a degenerated clustering
     */
    public void doWork() throws UserError {
        IOObjectCollection<ValueSeries> data = this.timeSeriesInput.getData(IOObjectCollection.class);
        int k = getParameterAsInt("k");
        boolean saveDistances = getParameterAsBoolean("distances");
        // Obtained once and shared by all attempts: with a fixed local seed a freshly created
        // generator would replay the identical (degenerate) initialization on every retry.
        RandomGenerator random = RandomGenerator.getRandomGenerator(this);
        for (int attempt = 1; ; attempt++) {
            try {
                clusterAndDeliver(data, k, saveDistances, random);
                return;
            } catch (DegeneratedClusteringException e) {
                if (attempt >= MAX_DEGENERATE_RETRIES) {
                    throw new IllegalStateException("k-means (k=" + k + ") produced a degenerated clustering in "
                            + attempt + " consecutive attempts", e);
                }
            }
        }
    }

    /**
     * Runs one complete clustering attempt: builds a fresh clusterer, runs k-means, annotates the
     * input series and delivers both output ports.
     */
    private void clusterAndDeliver(IOObjectCollection<ValueSeries> data, int k, boolean saveDistances,
            RandomGenerator random) throws UserError, DegeneratedClusteringException {
        // Arrays do not override equals/hashCode, so lookups must be by identity: two input series
        // with equal values still map back to their own ValueSeries.
        Map<Double[], ValueSeries> seriesByArray = new IdentityHashMap<Double[], ValueSeries>();
        this.clusterer = new TimeSeriesClustering();
        this.clusterer.setRandom(random);
        for (ValueSeries valueSeries : data.getObjects()) {
            Double[] values = convert(valueSeries);
            seriesByArray.put(values, valueSeries);
            this.clusterer.addTimeSeries(values);
        }

        DTW.count = 0.0d;
        this.clusterer.kMeans(k,
                getParameterAsBoolean("Sakoe-Chiba-Band")
                        ? new SakoeChibaBand(getParameterAsInt("Sakoe-Chiba-Band width"))
                        : new NoConstraint(),
                getParameterAsInt("max iterations"),
                getParameterAsBoolean("kmeans++"),
                getAveragingTechnique(),
                null);

        List<ArrayList<Double[]>> clusters = this.clusterer.getClusters();
        List<Double[]> centers = this.clusterer.getCenters();
        IOObjectCollection<ValueSeries> centroids = buildCentroids(centers, clusters);
        annotateMembers(clusters, centers, seriesByArray, saveDistances);

        this.timeSeriesOutput.deliver(data);
        this.centroidsOutput.deliver(centroids);
    }

    /** Converts each center to a series tagged with its cluster index ("label") and member count ("size"). */
    private IOObjectCollection<ValueSeries> buildCentroids(List<Double[]> centers, List<ArrayList<Double[]>> clusters)
            throws UserError, DegeneratedClusteringException {
        IOObjectCollection<ValueSeries> centroids = new IOObjectCollection<ValueSeries>();
        for (int index = 0; index < centers.size(); index++) {
            ValueSeries centroid = convert(centers.get(index));
            centroid.addFeature(new Feature("label", (double) index));
            centroid.addFeature(new Feature("size", clusters.get(index).size()));
            centroids.add(centroid);
        }
        return centroids;
    }

    /**
     * Adds a "cluster" feature to every input series and, if requested, one "d&lt;i&gt;" feature per
     * center holding the DTW distance to that center.
     *
     * @throws IllegalStateException if a cluster contains an array that did not come from the input
     */
    private void annotateMembers(List<ArrayList<Double[]>> clusters, List<Double[]> centers,
            Map<Double[], ValueSeries> seriesByArray, boolean saveDistances)
            throws UserError, DegeneratedClusteringException {
        for (int clusterIndex = 0; clusterIndex < clusters.size(); clusterIndex++) {
            for (Double[] member : clusters.get(clusterIndex)) {
                ValueSeries series = seriesByArray.get(member);
                if (series == null) {
                    throw new IllegalStateException("Cluster " + clusterIndex
                            + " contains a time series that is not part of the input collection");
                }
                series.addFeature(new Feature("cluster", clusterIndex));
                if (saveDistances) {
                    for (int centerIndex = 0; centerIndex < centers.size(); centerIndex++) {
                        // NOTE(review): distances are always computed without a global constraint,
                        // even when clustering used the Sakoe-Chiba band — confirm this is intended.
                        double distance = DTW.dtw(member, centers.get(centerIndex), Double.POSITIVE_INFINITY,
                                0.0d, null, null, new NoConstraint()).distance;
                        series.addFeature(new Feature("d" + centerIndex, distance));
                    }
                }
            }
        }
    }

    /**
     * Maps the "averaging technique" parameter to an averaging implementation.
     *
     * <p>The category list offers {@code "Euclidian "} with a trailing space (kept so saved
     * processes still load), so the value is trimmed before matching. Unknown or missing values
     * fall back to Euclidian averaging.
     */
    private AveragingTechnique getAveragingTechnique() throws UndefinedParameterError {
        String selected = getParameterAsString("averaging technique");
        String name = selected == null ? "" : selected.trim();
        if (name.equals("Projection")) {
            return new ProjectionAveraging();
        }
        if (name.equals("Fixpoint")) {
            return new FixpointAveraging();
        }
        if (name.equals("Shapebased")) {
            return new ShapeBasedAveraging();
        }
        if (name.equals("Fast Shapebased")) {
            return new FastShapeBasedAveraging();
        }
        if (name.equals("Medoids")) {
            return new MedoidAveraging();
        }
        return new EuclidianAveraging();
    }

    /** Declares the operator parameters plus the standard random-generator parameters. */
    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        // At least one cluster is required for k-means to produce any output.
        parameterTypes.add(new ParameterTypeInt("k", "the number of clusters k-means should compute", 1, Integer.MAX_VALUE, 3));
        parameterTypes.add(new ParameterTypeBoolean("kmeans++", "use k-means++ initialization", true));
        parameterTypes.add(new ParameterTypeBoolean("distances", "Save the distances to the calculated centroids", false));
        parameterTypes.add(new ParameterTypeInt("max iterations", "maximum number of iterations", 0, Integer.MAX_VALUE, 20));
        // "Euclidian " keeps its historical trailing space so existing process files still match.
        parameterTypes.add(new ParameterTypeCategory("averaging technique", "the algorithm used to average a set of time series", new String[]{"Euclidian ", "Projection", "Fixpoint", "Shapebased", "Fast Shapebased", "Medoids"}, 1));
        parameterTypes.add(new ParameterTypeBoolean("Sakoe-Chiba-Band", "Use a global constraint?", true));
        ParameterTypeInt bandWidth = new ParameterTypeInt("Sakoe-Chiba-Band width", "the width of the global constraint. 10% of time series length is a commonly used value", 0, Integer.MAX_VALUE, 16);
        bandWidth.registerDependencyCondition(new BooleanParameterCondition(this, "Sakoe-Chiba-Band", true, true));
        parameterTypes.add(bandWidth);
        parameterTypes.addAll(RandomGenerator.getRandomGeneratorParameters(this));
        return parameterTypes;
    }
}
