package edu.udo.cs.ls8.mllib.rapidminer.operators;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.Tools;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.ports.InputPort;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.metadata.AttributeMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetMetaData;
import com.rapidminer.operator.ports.metadata.ExampleSetPassThroughRule;
import com.rapidminer.operator.ports.metadata.SetRelation;
import com.rapidminer.parameter.ParameterType;
import com.rapidminer.parameter.ParameterTypeDouble;
import com.rapidminer.parameter.ParameterTypeInt;
import com.rapidminer.tools.RandomGenerator;
import edu.udo.cs.ls8.mllib.rapidminer.helpers.OperatorDebugger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:edu/udo/cs/ls8/mllib/rapidminer/operators/SampleProportionGroups.class */
public class SampleProportionGroups extends Operator {
    public static final String PARAMETER_LABELED = "labeled";
    public static final String PARAMETER_GROUP_SIZE = "group size";
    public static final String PARAMETER_DEVIATION = "deviation";
    public static final String PARAMETER_LOCAL_RANDOM_SEED = "local_random_seed";
    private InputPort exampleSetInput;
    private OutputPort exampleSetOutput;
    private OperatorDebugger debug;
    private double labeled;
    private int groupSize;
    private double deviation;
    private int numOfGroups;
    private int seed;
    private ExampleSet examples;
    private ArrayList<Double> exampleIds;
    private ArrayList<ArrayList<Double>> exampleIdsByLabel;

    public SampleProportionGroups(OperatorDescription operatorDescription) {
        super(operatorDescription);
        this.exampleSetInput = getInputPorts().createPort("example set", ExampleSet.class);
        this.exampleSetOutput = getOutputPorts().createPort("example set");
        this.debug = null;
        this.labeled = 0.0d;
        this.groupSize = 0;
        this.deviation = 0.0d;
        this.numOfGroups = 0;
        this.seed = 0;
        this.examples = null;
        this.exampleIds = null;
        this.exampleIdsByLabel = null;
        this.debug = new OperatorDebugger(this);
        getTransformer().addRule(new ExampleSetPassThroughRule(this.exampleSetInput, this.exampleSetOutput, SetRelation.EQUAL) { // from class: edu.udo.cs.ls8.mllib.rapidminer.operators.SampleProportionGroups.1
            public ExampleSetMetaData modifyExampleSet(ExampleSetMetaData exampleSetMetaData) {
                exampleSetMetaData.addAttribute(new AttributeMetaData("id", 3, "id"));
                exampleSetMetaData.addAttribute(new AttributeMetaData("group", 1, "group"));
                exampleSetMetaData.addAttribute(new AttributeMetaData(SampleProportionGroups.PARAMETER_LABELED, 1, SampleProportionGroups.PARAMETER_LABELED));
                return exampleSetMetaData;
            }
        });
    }

    public void doWork() throws OperatorException {
        super.doWork();
        this.debug.logMsg("Begin SampleProportionGroups");
        this.labeled = getParameterAsDouble(PARAMETER_LABELED);
        this.groupSize = getParameterAsInt(PARAMETER_GROUP_SIZE);
        this.deviation = getParameterAsDouble(PARAMETER_DEVIATION);
        this.seed = getParameterAsInt(PARAMETER_LOCAL_RANDOM_SEED);
        this.examples = this.exampleSetInput.getData();
        this.numOfGroups = (int) (this.examples.size() / this.groupSize);
        if (this.examples.size() % this.groupSize > 0) {
            this.numOfGroups++;
        }
        Tools.onlyNonMissingValues(this.examples, "SampleProportionGroups");
        Tools.checkAndCreateIds(this.examples);
        this.exampleIds = new ArrayList<>();
        Iterator it = this.examples.iterator();
        while (it.hasNext()) {
            this.exampleIds.add(new Double(((Example) it.next()).getId()));
        }
        if (this.debug.print()) {
            int size = this.exampleIds.size();
            for (int i = 0; i < size; i++) {
                double doubleValue = this.exampleIds.get(i).doubleValue();
                int i2 = 0;
                for (int i3 = 0; i3 < size; i3++) {
                    if (this.exampleIds.get(i3).doubleValue() == doubleValue) {
                        i2++;
                    }
                }
                if (i2 != 1) {
                    System.out.println("v = " + doubleValue + ", " + i2);
                }
            }
        }
        Attribute label = this.examples.getAttributes().getLabel();
        int size2 = label.getMapping().size();
        this.exampleIdsByLabel = new ArrayList<>();
        for (int i4 = 0; i4 < size2; i4++) {
            this.exampleIdsByLabel.add(new ArrayList<>());
        }
        for (int i5 = 0; i5 < this.examples.size(); i5++) {
            Example example = this.examples.getExample(i5);
            int value = (int) example.getValue(label);
            ArrayList<Double> arrayList = this.exampleIdsByLabel.get(value);
            arrayList.add(Double.valueOf(example.getId()));
            this.exampleIdsByLabel.set(value, arrayList);
        }
        RandomGenerator randomGenerator = RandomGenerator.getRandomGenerator(this);
        int size3 = ((int) (this.examples.size() * this.labeled)) - (((int) (this.examples.size() * this.labeled)) % size2);
        Attribute attribute = this.examples.getAttributes().get("group");
        if (attribute != null) {
            attribute.getMapping().clear();
            this.examples.getExampleTable().removeAttribute(attribute);
            this.examples.getAttributes().remove(attribute);
        }
        Attribute createAttribute = AttributeFactory.createAttribute("group", 1);
        this.examples.getExampleTable().addAttribute(createAttribute);
        this.examples.getAttributes().addRegular(createAttribute);
        this.examples.getAttributes().setSpecialAttribute(createAttribute, "group");
        for (int i6 = 0; i6 < this.examples.size(); i6++) {
            this.examples.getExample(i6).setValue(createAttribute, -1.0d);
        }
        Attribute attribute2 = this.examples.getAttributes().get(PARAMETER_LABELED);
        if (attribute2 == null) {
            attribute2 = AttributeFactory.createAttribute(PARAMETER_LABELED, 1);
            attribute2.getMapping().mapString("no");
            attribute2.getMapping().mapString("yes");
            this.examples.getExampleTable().addAttribute(attribute2);
            this.examples.getAttributes().addRegular(attribute2);
            this.examples.getAttributes().setSpecialAttribute(attribute2, PARAMETER_LABELED);
        }
        if (this.deviation == 0.0d) {
            for (int i7 = 0; i7 < this.numOfGroups; i7++) {
                createAttribute.getMapping().mapString("group_" + i7);
            }
            int i8 = 0;
            while (this.exampleIds.size() > 0) {
                for (int i9 = 0; i9 < this.numOfGroups && this.exampleIds.size() > 0; i9++) {
                    Example findExampleById = findExampleById(this.exampleIds.remove(randomGenerator.nextInt(this.exampleIds.size())).doubleValue());
                    findExampleById.setValue(createAttribute, "group_" + i9);
                    findExampleById.setValue(attribute2, 0.0d);
                }
                i8++;
            }
            if (this.debug.print()) {
                for (int i10 = -1; i10 < this.numOfGroups; i10++) {
                    int i11 = 0;
                    Iterator it2 = this.examples.iterator();
                    while (it2.hasNext()) {
                        if (((Example) it2.next()).getValue(createAttribute) == i10) {
                            i11++;
                        }
                    }
                    System.out.println("group_" + i10 + ": " + i11);
                }
            }
        } else {
            Random random = new Random();
            int i12 = 0;
            int i13 = 0;
            while (this.exampleIds.size() > 0) {
                createAttribute.getMapping().mapString("group_" + i12);
                int nextGaussian = (int) ((this.deviation * random.nextGaussian()) + this.groupSize);
                if (nextGaussian <= 0) {
                    nextGaussian = 1;
                }
                if (nextGaussian > this.exampleIds.size()) {
                    nextGaussian = this.exampleIds.size();
                }
                int i14 = 0;
                for (int i15 = 0; i15 < nextGaussian && this.exampleIds.size() > 0; i15++) {
                    double doubleValue2 = this.exampleIds.get(randomGenerator.nextInt(this.exampleIds.size())).doubleValue();
                    this.exampleIds.remove(Double.valueOf(doubleValue2));
                    Example findExampleById2 = findExampleById(doubleValue2);
                    findExampleById2.setValue(createAttribute, "group_" + i12);
                    findExampleById2.setValue(attribute2, 0.0d);
                    i14++;
                }
                System.out.println("Sampled " + i14 + " / " + nextGaussian);
                i13 += i14;
                i12++;
            }
            System.out.println("Total: " + i13);
        }
        for (int i16 = 0; i16 < size3 / size2; i16++) {
            for (int i17 = 0; i17 < size2; i17++) {
                ArrayList<Double> arrayList2 = this.exampleIdsByLabel.get(i17);
                double doubleValue3 = arrayList2.get(randomGenerator.nextInt(arrayList2.size())).doubleValue();
                arrayList2.remove(Double.valueOf(doubleValue3));
                findExampleById(doubleValue3).setValue(attribute2, 1.0d);
            }
        }
        this.debug.logMsg("End SampleProportionGroups");
        this.exampleSetOutput.deliver(this.examples);
    }

    public Example findExampleById(double d) {
        for (Example example : this.examples) {
            if (example.getId() == d) {
                return example;
            }
        }
        return null;
    }

    public List<ParameterType> getParameterTypes() {
        List<ParameterType> parameterTypes = super.getParameterTypes();
        parameterTypes.add(OperatorDebugger.DEBUG_PRINT);
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_LABELED, "Total number of labeled examples in percent / 100", 0.0d, 1.0d, 0.0d));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_GROUP_SIZE, "Size of groups to sample", 0, Integer.MAX_VALUE, 10));
        parameterTypes.add(new ParameterTypeDouble(PARAMETER_DEVIATION, "Deviation", 0.0d, Double.MAX_VALUE, 0.0d));
        parameterTypes.add(new ParameterTypeInt(PARAMETER_LOCAL_RANDOM_SEED, "Use the given random seed instead of global random numbers (-1: use global).", -1, Integer.MAX_VALUE, -1));
        return parameterTypes;
    }
}
