Hauptseite   Packages   Klassenhierarchie   Übersicht   Auflistung der Dateien   Datenstruktur-Elemente  

FastAML.java

gehe zur Dokumentation dieser Datei
00001 /*
00002  * $Source: /shared/cvsroot/diplom/app/src/java/de/picana/clusterer/FastAML.java,v $
00003  * $Author: mstolpe $
00004  * $Date: 2003/04/22 09:51:27 $
00005  * $Revision: 1.6 $
00006  * $Release$
00007  *
00008  * Created on 21. November 2002, 14:05
00009  *
00010  * Copyright 2002 by Marco Stolpe
00011  */
00012 
00013 package de.picana.clusterer;
00014 
00015 import de.picana.control.*;
00016 import de.picana.logging.*;
00017 import de.picana.math.*;
00018 
00019 import java.io.*;
00020 import java.util.*;
00021 
00022 import weka.core.*;
00023 
00024 
00031 public class FastAML extends GenericML {
00032     
00033     private double rho;
00034     private String strategy;
00035     
00036     private List index_list;
00037     private List set_u;
00038     
00039     private MLVector corner;
00040     private double[] mean;
00041     private double[] max;
00042     private double[] min;
00043     private MLVector vec_a;
00044     private MLVector vec_b;
00045     private double max_dist;
00046     private int sum_pruned = 0;
00047     
00048     private long time_sort;
00049     
00050     //-----------------------------
00051     
00052     
00054     class IndexComparator implements Comparator {
00055      
00056         private int component;
00057         
00058         public IndexComparator(int component) {
00059             this.component = component;
00060         }
00061         
00062         public int compare(Object o1, Object o2) {
00063             MLVector vec1 = (MLVector)o1;
00064             MLVector vec2 = (MLVector)o2;
00065             if (vec1.value[component] > vec2.value[component])
00066                 return 1;
00067             if (vec1.value[component] < vec2.value[component])
00068                 return -1;
00069             return 0; 
00070         }
00071     }
00072     
00074     class MidComparator implements Comparator {
00075         
00076         public MidComparator() {
00077         }
00078         
00079         public int compare(Object o1, Object o2) {
00080             MLVector vec1 = (MLVector)o1;
00081             MLVector vec2 = (MLVector)o2;
00082             if (Distance.euklidian(vec1.value,mean) > Distance.euklidian(vec2.value,mean))
00083                 return -1;
00084             if (Distance.euklidian(vec1.value,mean) < Distance.euklidian(vec2.value,mean))
00085                 return 1;
00086             return 0; 
00087         }
00088     }
00089     
00090 
00091     
00093     public FastAML() {
00094     }
00095     
00096     public void init(ParameterSet params, Logger logger) {
00097         
00098         super.init(params, logger);
00099         
00100         rho = 0.0;
00101         try {
00102             rho = Double.parseDouble((String)params.getParameter("rho"));
00103         } catch (NumberFormatException nfe) {}
00104                 
00105         strategy = (String)params.getParameter("pruning");
00106     }
00107     
00108     
00109     protected void buildFirst() {
00110         
00111         statwriter.println("algo_clusterer_name:FastAML");
00112         statwriter.println("algo_clusterer_rho:" + rho);
00113         statwriter.println("algo_clusterer_pruning:" + strategy);
00114         
00115         int i, j;
00116         MLVector vec;
00117         int index;
00118         Integer freq;
00119         
00120         logger.info(LOGSRC, "Build sorted index table ...");
00121         
00122         time_sort = System.currentTimeMillis();
00123         
00124         set_u = new LinkedList();
00125         min = new double[dim];
00126         for (i=0; i < dim; i++)
00127             min[i] = Double.MAX_VALUE;
00128         max = new double[dim];
00129         for (i=0; i < dim; i++)
00130             max[i] = 0.0;
00131             
00132         for (i=0; i < freq_table.size(); i++) {
00133             
00134             vec = (MLVector)freq_table.get(i);
00135             
00136             for (j=0; j < vec.dim; j++) {
00137                 vec.value[j] *= Math.pow(
00138                     rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00139             
00140                 if (vec.value[j] > max[j])
00141                     max[j] = vec.value[j];
00142                 if (vec.value[j] < min[j])
00143                     min[j] = vec.value[j];
00144             }
00145             
00146             set_u.add(vec);
00147         }
00148         
00149         // get mid(HC(U))
00150         
00151         corner = new MLVector(dim);
00152         
00153         mean = new double[dim];
00154         
00155         for (i=0; i < dim; i++) {
00156             mean[i] = (max[i] + min[i]) / 2.0;
00157         }
00158         
00159         // get sorted index table
00160         
00161         if (strategy.equals("rand")) {
00162             index_list = getRandIndexList();
00163         } else if (strategy.equals("lexi")) {
00164             index_list = getLexiIndexList();
00165         } else if (strategy.equals("comp")) {
00166             index_list = getCompIndexList();
00167         } else if (strategy.equals("mid")) {
00168             index_list = getMidIndexList();
00169         }
00170         
00171         time_sort = System.currentTimeMillis() - time_sort;
00172         
00173         logger.info(LOGSRC, "sorted index table built.");
00174         
00175         logger.info(LOGSRC, "Sorting took " + getTimeString(time_sort)); 
00176         
00177         // start algorithm
00178         
00179         max_dist = 0.0;
00180         
00181         for (i=0; i < index_list.size(); i++) {
00182             vec = (MLVector)index_list.get(i);
00183             if (!vec.pruned)
00184                 TryToPrune(vec);
00185         }
00186     
00187         logger.debug(LOGSRC, vec_a.toString() + " - " + vec_b.toString() + " = " + max_dist);
00188         
00189         statwriter.println("stat_Pruned:" + sum_pruned);
00190         
00191         logger.info(LOGSRC, "Pruned " + sum_pruned + " of " + freq_table.size() + " (" +
00192         (freq_table.size()-sum_pruned) + " left) instances in whole.");
00193         
00194        
00195         centroids.add(vec_a);
00196         logger.info(LOGSRC, "centroid[0] = " + vec_a.toString());
00197         centroids.add(vec_b);
00198         logger.info(LOGSRC, "centroid[1] = " + vec_b.toString());
00199         
00200         statwriter.println("time_clusterer_sort:" + time_sort);
00201     }
00202     
00203      
00204     protected void TryToPrune(MLVector p) {
00205         
00206         MLVector q = null;
00207         MLVector temp;
00208         double dist = 0.0;
00209         double max_qdist = 0.0;
00210         boolean prune = false;
00211         int pruned = 0;
00212         
00213         // get point q furthest from p
00214         
00215         Iterator set_u_iter = set_u.iterator();
00216         while (set_u_iter.hasNext()) {
00217             temp = (MLVector)set_u_iter.next();
00218             if (p.equals(temp)) {
00219                 temp.pruned = true;
00220                 set_u_iter.remove();  // remove p in each case
00221             } else {
00222                 dist = Distance.euklidian(p.value,temp.value);
00223                 if (dist > max_qdist) {
00224                     max_qdist = dist;
00225                     q = temp;
00226                     if (max_qdist == max_dist)
00227                         logger.info(LOGSRC, "found new max_qdist = " + max_qdist);
00228                 }
00229             }
00230         }
00231         
00232         if (max_qdist > max_dist) {  // if possible, prune points in set_u
00233             
00234             logger.info(LOGSRC, "found new maximum distance = " + max_qdist);
00235             
00236             max_dist = max_qdist;
00237             vec_a = p;
00238             vec_b = q;
00239         
00240             set_u_iter = set_u.iterator();
00241             while (set_u_iter.hasNext()) {
00242                 temp = (MLVector)set_u_iter.next();
00243 
00244                 if (Distance.euklidian(temp.value, calculateCorner(temp).value) > max_dist) {
00245                     prune = false;
00246                 } else {
00247                     temp.pruned = true;
00248                     set_u_iter.remove();
00249                     pruned++;
00250                 }
00251             }
00252             
00253             logger.info(LOGSRC, "Pruned " + pruned + " instances.");
00254             sum_pruned += pruned;
00255         }
00256     }
00257     
00258     
00259     protected void buildRest() {
00260         
00261         int i, j, k;
00262         MLVector vec = null;
00263         MLVector vec_a;
00264         MLVector vec_b;
00265         
00266         double max_dist;
00267         double min_dist;
00268         double act_dist;
00269         
00270         for (i=0; i < num_clusters-2; i++) {
00271             
00272             max_dist = 0.0;
00273             
00274             for (j=0; j < freq_table.size(); j++) {
00275                 
00276                 vec_a = (MLVector)freq_table.get(j);
00277                 
00278                 min_dist = Double.MAX_VALUE;
00279                 
00280                 for (k=0; k < centroids.size(); k++) {
00281                     
00282                     vec_b = (MLVector)centroids.get(k);
00283                     
00284                     act_dist = Distance.euklidian(vec_a.value, vec_b.value);
00285                     
00286                     if (act_dist < min_dist)
00287                         min_dist = act_dist;
00288                 }
00289                 
00290                 if (min_dist > max_dist) {
00291                     vec = vec_a;
00292                     max_dist = min_dist;
00293                 }
00294             }
00295 
00296             centroids.add(vec);
00297             logger.info(LOGSRC, "centroid[" + (i+2) + "] = " + vec.toString());
00298         }
00299         
00300         for (i=0; i < centroids.size(); i++) {
00301             
00302             vec = (MLVector)centroids.get(i);
00303             
00304             for (j=0; j < vec.dim; j++) {
00305                 vec.value[j] /= Math.pow(
00306                 rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00307             }
00308         }
00309     }
00310     
00311     
00312     protected MLVector calculateCorner(MLVector vec) {
00313         
00314         for (int i=0; i < dim; i++) {
00315             
00316             if (vec.value[i] <= mean[i])
00317                 
00318                 corner.value[i] = max[i];
00319             else
00320                 corner.value[i] = min[i];
00321         }
00322         return corner;
00323     }
00324     
00325     
00326     protected List getRandIndexList() {
00327         return freq_table;
00328     }
00329     
00330     protected List getLexiIndexList() {
00331         Collections.sort(freq_table);
00332         return freq_table;
00333     }
00334     
00335     protected List getCompIndexList() {
00336         
00337         MLVector vec = null;
00338         List[] sorted_indexl = new ArrayList[dim];
00339         List sorted = new ArrayList();
00340         int i;
00341         
00342         for (i=0; i < dim; i++) {
00343             sorted_indexl[i] = new ArrayList();
00344             for (int j=0; j < freq_table.size(); j++) {
00345                 vec = (MLVector)freq_table.get(j);
00346                 sorted_indexl[i].add(vec);
00347             }
00348             Collections.sort(sorted_indexl[i], new IndexComparator(i));
00349         }
00350         
00351         i=0;
00352         int min = 0;
00353         int max = freq_table.size()-1;
00354         while (i < (dim*freq_table.size())) {
00355             for (int c=0; c < dim; c++) {
00356                 sorted.add(sorted_indexl[c].get(min));
00357                 i++;
00358             }
00359             min++;
00360             if (i < (dim*freq_table.size())) {
00361                 for (int c=0; c < dim; c++) {
00362                     sorted.add(sorted_indexl[c].get(max));
00363                     i++;
00364                 }
00365                 max--;
00366             }
00367         }
00368         System.out.println(dim*freq_table.size() + " = " + sorted.size());
00369         return sorted;
00370     }
00371     
00372     protected List getMidIndexList() {
00373         Collections.sort(freq_table, new MidComparator());
00374         return freq_table;
00375     }
00376 }
00377 

Erzeugt am Tue Apr 22 11:22:55 2003 für Picana von doxygen1.2.18