00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013 package de.picana.clusterer;
00014
00015 import de.picana.control.*;
00016 import de.picana.logging.*;
00017 import de.picana.math.*;
00018
00019 import java.io.*;
00020 import java.util.*;
00021
00022 import weka.core.*;
00023
00024
00031 public class FastAML extends GenericML {
00032
00033 private double rho;
00034 private String strategy;
00035
00036 private List index_list;
00037 private List set_u;
00038
00039 private MLVector corner;
00040 private double[] mean;
00041 private double[] max;
00042 private double[] min;
00043 private MLVector vec_a;
00044 private MLVector vec_b;
00045 private double max_dist;
00046 private int sum_pruned = 0;
00047
00048 private long time_sort;
00049
00050
00051
00052
00054 class IndexComparator implements Comparator {
00055
00056 private int component;
00057
00058 public IndexComparator(int component) {
00059 this.component = component;
00060 }
00061
00062 public int compare(Object o1, Object o2) {
00063 MLVector vec1 = (MLVector)o1;
00064 MLVector vec2 = (MLVector)o2;
00065 if (vec1.value[component] > vec2.value[component])
00066 return 1;
00067 if (vec1.value[component] < vec2.value[component])
00068 return -1;
00069 return 0;
00070 }
00071 }
00072
00074 class MidComparator implements Comparator {
00075
00076 public MidComparator() {
00077 }
00078
00079 public int compare(Object o1, Object o2) {
00080 MLVector vec1 = (MLVector)o1;
00081 MLVector vec2 = (MLVector)o2;
00082 if (Distance.euklidian(vec1.value,mean) > Distance.euklidian(vec2.value,mean))
00083 return -1;
00084 if (Distance.euklidian(vec1.value,mean) < Distance.euklidian(vec2.value,mean))
00085 return 1;
00086 return 0;
00087 }
00088 }
00089
00090
00091
00093 public FastAML() {
00094 }
00095
00096 public void init(ParameterSet params, Logger logger) {
00097
00098 super.init(params, logger);
00099
00100 rho = 0.0;
00101 try {
00102 rho = Double.parseDouble((String)params.getParameter("rho"));
00103 } catch (NumberFormatException nfe) {}
00104
00105 strategy = (String)params.getParameter("pruning");
00106 }
00107
00108
00109 protected void buildFirst() {
00110
00111 statwriter.println("algo_clusterer_name:FastAML");
00112 statwriter.println("algo_clusterer_rho:" + rho);
00113 statwriter.println("algo_clusterer_pruning:" + strategy);
00114
00115 int i, j;
00116 MLVector vec;
00117 int index;
00118 Integer freq;
00119
00120 logger.info(LOGSRC, "Build sorted index table ...");
00121
00122 time_sort = System.currentTimeMillis();
00123
00124 set_u = new LinkedList();
00125 min = new double[dim];
00126 for (i=0; i < dim; i++)
00127 min[i] = Double.MAX_VALUE;
00128 max = new double[dim];
00129 for (i=0; i < dim; i++)
00130 max[i] = 0.0;
00131
00132 for (i=0; i < freq_table.size(); i++) {
00133
00134 vec = (MLVector)freq_table.get(i);
00135
00136 for (j=0; j < vec.dim; j++) {
00137 vec.value[j] *= Math.pow(
00138 rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00139
00140 if (vec.value[j] > max[j])
00141 max[j] = vec.value[j];
00142 if (vec.value[j] < min[j])
00143 min[j] = vec.value[j];
00144 }
00145
00146 set_u.add(vec);
00147 }
00148
00149
00150
00151 corner = new MLVector(dim);
00152
00153 mean = new double[dim];
00154
00155 for (i=0; i < dim; i++) {
00156 mean[i] = (max[i] + min[i]) / 2.0;
00157 }
00158
00159
00160
00161 if (strategy.equals("rand")) {
00162 index_list = getRandIndexList();
00163 } else if (strategy.equals("lexi")) {
00164 index_list = getLexiIndexList();
00165 } else if (strategy.equals("comp")) {
00166 index_list = getCompIndexList();
00167 } else if (strategy.equals("mid")) {
00168 index_list = getMidIndexList();
00169 }
00170
00171 time_sort = System.currentTimeMillis() - time_sort;
00172
00173 logger.info(LOGSRC, "sorted index table built.");
00174
00175 logger.info(LOGSRC, "Sorting took " + getTimeString(time_sort));
00176
00177
00178
00179 max_dist = 0.0;
00180
00181 for (i=0; i < index_list.size(); i++) {
00182 vec = (MLVector)index_list.get(i);
00183 if (!vec.pruned)
00184 TryToPrune(vec);
00185 }
00186
00187 logger.debug(LOGSRC, vec_a.toString() + " - " + vec_b.toString() + " = " + max_dist);
00188
00189 statwriter.println("stat_Pruned:" + sum_pruned);
00190
00191 logger.info(LOGSRC, "Pruned " + sum_pruned + " of " + freq_table.size() + " (" +
00192 (freq_table.size()-sum_pruned) + " left) instances in whole.");
00193
00194
00195 centroids.add(vec_a);
00196 logger.info(LOGSRC, "centroid[0] = " + vec_a.toString());
00197 centroids.add(vec_b);
00198 logger.info(LOGSRC, "centroid[1] = " + vec_b.toString());
00199
00200 statwriter.println("time_clusterer_sort:" + time_sort);
00201 }
00202
00203
00204 protected void TryToPrune(MLVector p) {
00205
00206 MLVector q = null;
00207 MLVector temp;
00208 double dist = 0.0;
00209 double max_qdist = 0.0;
00210 boolean prune = false;
00211 int pruned = 0;
00212
00213
00214
00215 Iterator set_u_iter = set_u.iterator();
00216 while (set_u_iter.hasNext()) {
00217 temp = (MLVector)set_u_iter.next();
00218 if (p.equals(temp)) {
00219 temp.pruned = true;
00220 set_u_iter.remove();
00221 } else {
00222 dist = Distance.euklidian(p.value,temp.value);
00223 if (dist > max_qdist) {
00224 max_qdist = dist;
00225 q = temp;
00226 if (max_qdist == max_dist)
00227 logger.info(LOGSRC, "found new max_qdist = " + max_qdist);
00228 }
00229 }
00230 }
00231
00232 if (max_qdist > max_dist) {
00233
00234 logger.info(LOGSRC, "found new maximum distance = " + max_qdist);
00235
00236 max_dist = max_qdist;
00237 vec_a = p;
00238 vec_b = q;
00239
00240 set_u_iter = set_u.iterator();
00241 while (set_u_iter.hasNext()) {
00242 temp = (MLVector)set_u_iter.next();
00243
00244 if (Distance.euklidian(temp.value, calculateCorner(temp).value) > max_dist) {
00245 prune = false;
00246 } else {
00247 temp.pruned = true;
00248 set_u_iter.remove();
00249 pruned++;
00250 }
00251 }
00252
00253 logger.info(LOGSRC, "Pruned " + pruned + " instances.");
00254 sum_pruned += pruned;
00255 }
00256 }
00257
00258
00259 protected void buildRest() {
00260
00261 int i, j, k;
00262 MLVector vec = null;
00263 MLVector vec_a;
00264 MLVector vec_b;
00265
00266 double max_dist;
00267 double min_dist;
00268 double act_dist;
00269
00270 for (i=0; i < num_clusters-2; i++) {
00271
00272 max_dist = 0.0;
00273
00274 for (j=0; j < freq_table.size(); j++) {
00275
00276 vec_a = (MLVector)freq_table.get(j);
00277
00278 min_dist = Double.MAX_VALUE;
00279
00280 for (k=0; k < centroids.size(); k++) {
00281
00282 vec_b = (MLVector)centroids.get(k);
00283
00284 act_dist = Distance.euklidian(vec_a.value, vec_b.value);
00285
00286 if (act_dist < min_dist)
00287 min_dist = act_dist;
00288 }
00289
00290 if (min_dist > max_dist) {
00291 vec = vec_a;
00292 max_dist = min_dist;
00293 }
00294 }
00295
00296 centroids.add(vec);
00297 logger.info(LOGSRC, "centroid[" + (i+2) + "] = " + vec.toString());
00298 }
00299
00300 for (i=0; i < centroids.size(); i++) {
00301
00302 vec = (MLVector)centroids.get(i);
00303
00304 for (j=0; j < vec.dim; j++) {
00305 vec.value[j] /= Math.pow(
00306 rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00307 }
00308 }
00309 }
00310
00311
00312 protected MLVector calculateCorner(MLVector vec) {
00313
00314 for (int i=0; i < dim; i++) {
00315
00316 if (vec.value[i] <= mean[i])
00317
00318 corner.value[i] = max[i];
00319 else
00320 corner.value[i] = min[i];
00321 }
00322 return corner;
00323 }
00324
00325
00326 protected List getRandIndexList() {
00327 return freq_table;
00328 }
00329
00330 protected List getLexiIndexList() {
00331 Collections.sort(freq_table);
00332 return freq_table;
00333 }
00334
00335 protected List getCompIndexList() {
00336
00337 MLVector vec = null;
00338 List[] sorted_indexl = new ArrayList[dim];
00339 List sorted = new ArrayList();
00340 int i;
00341
00342 for (i=0; i < dim; i++) {
00343 sorted_indexl[i] = new ArrayList();
00344 for (int j=0; j < freq_table.size(); j++) {
00345 vec = (MLVector)freq_table.get(j);
00346 sorted_indexl[i].add(vec);
00347 }
00348 Collections.sort(sorted_indexl[i], new IndexComparator(i));
00349 }
00350
00351 i=0;
00352 int min = 0;
00353 int max = freq_table.size()-1;
00354 while (i < (dim*freq_table.size())) {
00355 for (int c=0; c < dim; c++) {
00356 sorted.add(sorted_indexl[c].get(min));
00357 i++;
00358 }
00359 min++;
00360 if (i < (dim*freq_table.size())) {
00361 for (int c=0; c < dim; c++) {
00362 sorted.add(sorted_indexl[c].get(max));
00363 i++;
00364 }
00365 max--;
00366 }
00367 }
00368 System.out.println(dim*freq_table.size() + " = " + sorted.size());
00369 return sorted;
00370 }
00371
00372 protected List getMidIndexList() {
00373 Collections.sort(freq_table, new MidComparator());
00374 return freq_table;
00375 }
00376 }
00377