00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 package de.picana.clusterer;
00014
00015 import de.picana.control.*;
00016 import de.picana.logging.*;
00017 import de.picana.math.*;
00018
00019 import java.io.*;
00020 import java.util.*;
00021
00022 import weka.core.*;
00023
00024
00031 public abstract class GenericML extends Clusterer {
00032
00033 protected List freq_table;
00034 protected List centroids;
00035 protected Map freq_htable;
00036
00037 protected int dim;
00038
00039 protected long time_freq;
00040 protected long time_step1;
00041 protected long time_step2;
00042 protected long time_total;
00043
00044
00046 public GenericML() {
00047 }
00048
00049
00050 public void buildClusterer(Instances set) throws TaskException {
00051
00052 statwriter.println("algo_clusterer_model:" + modelfile);
00053 statwriter.println("algo_clusterer_training_set:" + infile);
00054
00055 int i, j;
00056 MLVector vec;
00057
00058 centroids = new ArrayList();
00059
00060 Integer freq;
00061 Instance inst;
00062
00063 time_total = System.currentTimeMillis();
00064
00065 time_freq = System.currentTimeMillis();
00066
00067 freq_htable = new HashMap();
00068
00069
00070
00071 logger.info(LOGSRC, "Build frequency table from " + set.numInstances() + " instances ...");
00072 for (i=0; i < set.numInstances(); i++) {
00073 vec = new MLVector(set.instance(i));
00074 dim = vec.dim;
00075 freq = (Integer)freq_htable.get(vec);
00076
00077 freq_htable.put(vec, (freq==null) ? ONE :
00078 new Integer(freq.intValue() + 1));
00079 }
00080 logger.info(LOGSRC, "frequency table containing " + freq_htable.size() + " instances built.");
00081
00082
00083
00084 freq_table = new ArrayList();
00085 Iterator entries = freq_htable.entrySet().iterator();
00086 Map.Entry entry;
00087 while (entries.hasNext()) {
00088 entry = (Map.Entry)entries.next();
00089 vec = (MLVector)entry.getKey();
00090 freq = (Integer)entry.getValue();
00091 vec.freq = freq.intValue();
00092 freq_table.add(vec);
00093 }
00094 freq_htable = null;
00095
00096 time_freq = System.currentTimeMillis() - time_freq;
00097
00098 logger.info(LOGSRC, "Frequency table built took " + getTimeString(time_freq));
00099
00100 statwriter.println("stat_Instances:" + set.numInstances());
00101 statwriter.println("stat_UniqueInst:" + freq_table.size());
00102
00103 logger.info(LOGSRC, "Find farthest pair (= first two centroids) ...");
00104 time_step1 = System.currentTimeMillis();
00105 buildFirst();
00106 time_step1 = System.currentTimeMillis() - time_step1;
00107 logger.info(LOGSRC, "found farthest pair.");
00108
00109 logger.info(LOGSRC, "Step 1 took " + getTimeString(time_step1));
00110
00111 logger.info(LOGSRC, "Find the last " + (num_clusters-2) + " centroids ...");
00112 time_step2 = System.currentTimeMillis();
00113 buildRest();
00114 time_step2 = System.currentTimeMillis() - time_step2;
00115 logger.info(LOGSRC, "found the last " + (num_clusters-2) + " centroids.");
00116
00117 logger.info(LOGSRC, "Step 2 took " + getTimeString(time_step2));
00118
00119 time_total = System.currentTimeMillis() - time_total;
00120
00121 logger.info(LOGSRC, "Algorithm took " + getTimeString(time_total));
00122
00123 statwriter.println("time_clusterer_freq:" + time_freq);
00124 statwriter.println("time_clusterer_step1:" + time_step1);
00125 statwriter.println("time_clusterer_step2:" + time_step2);
00126 statwriter.println("time_clusterer_total:" + time_total);
00127 }
00128
00129
00130 protected abstract void buildFirst();
00131
00132 protected abstract void buildRest();
00133
00134
00135 protected Object getRandomElement(Collection c) {
00136 int index = rand.nextInt(c.size());
00137 int i=0;
00138 Object obj = null;
00139 Iterator elems = c.iterator();
00140 while (elems.hasNext()) {
00141 if (i==index)
00142 obj = elems.next();
00143 else
00144 elems.next();
00145 i++;
00146 }
00147 return obj;
00148 }
00149
00150
00151 public void saveModel(String filename) throws TaskException {
00152
00153 try {
00154 File outfile = new File(filename);
00155 FileOutputStream out = new FileOutputStream(outfile);
00156 PrintWriter pw = new PrintWriter(out);
00157
00158 for (int att=0; att < training_set.numAttributes(); att++) {
00159 pw.print(training_set.attribute(att).name());
00160 if (att != training_set.numAttributes()-1)
00161 pw.print(",");
00162 }
00163 pw.println();
00164
00165 for (int i=0; i < centroids.size(); i++) {
00166 MLVector vec = (MLVector)centroids.get(i);
00167
00168 for (int j=0; j < vec.value.length; j++) {
00169
00170 pw.print(vec.value[j]);
00171 if (j != vec.value.length-1) {
00172 pw.print(",");
00173 }
00174 }
00175 pw.println();
00176 }
00177
00178 pw.close();
00179
00180 } catch (Exception e) {
00181 throw new TaskException(e.toString());
00182 }
00183 }
00184 }