Hauptseite   Packages   Klassenhierarchie   Übersicht   Auflistung der Dateien   Datenstruktur-Elemente  

FastAMLJ.java

gehe zur Dokumentation dieser Datei
00001 /*
00002  * $Source: /shared/cvsroot/diplom/app/src/java/de/picana/clusterer/FastAMLJ.java,v $
00003  * $Author: mstolpe $
00004  * $Date: 2003/04/22 09:51:27 $
00005  * $Revision: 1.11 $
00006  * $Release$
00007  *
00008  * Created on 21. November 2002, 14:05
00009  *
00010  * Copyright 2002 by Marco Stolpe
00011  */
00012 
00013 package de.picana.clusterer;
00014 
00015 import de.picana.control.*;
00016 import de.picana.logging.*;
00017 import de.picana.math.*;
00018 
00019 import java.io.*;
00020 import java.util.*;
00021 
00022 import weka.core.*;
00023 
00024 
00034 public class FastAMLJ extends GenericML {
00035     
00036     private double rho;
00037     private boolean pruning = true;
00038     
00039     private IndexList index_list;
00040     private IndexListIterator[] outer_iter_min;
00041     private IndexListIterator[] outer_iter_max;
00042     private IndexListIterator[] inner_iter_min;
00043     private IndexListIterator[] inner_iter_max;
00044     
00045     private MLVector[] corners;
00046     private MLVector corner;
00047     private double[] mean;
00048     private int numCorners;
00049     private boolean outer_finished;
00050     private boolean inner_finished;
00051     private HashMap max_index;
00052     private double max_dist;
00053     private int sum_pruned = 0;
00054     
00055     private long time_sort;
00056     
00057     
00059     public FastAMLJ() {
00060     }
00061     
00062     public void init(ParameterSet params, Logger logger) {
00063         
00064         super.init(params, logger);
00065         
00066         rho = 0.0;
00067         try {
00068             rho = Double.parseDouble((String)params.getParameter("rho"));
00069         } catch (NumberFormatException nfe) {}
00070         
00071         String p = (String)params.getParameter("pruning");
00072         if (p.equals("true"))
00073             pruning = true;
00074         else
00075             pruning = false;
00076     }
00077     
00078     
00079     protected void buildFirst() {
00080         
00081         statwriter.println("algo_clusterer_name:FastAMLJ");
00082         statwriter.println("algo_clusterer_rho:" + rho);
00083         statwriter.println("algo_clusterer_pruning:" + pruning);
00084         
00085         int i, j;
00086         MLVector vec;
00087         int index;
00088         Integer freq;
00089         
00090         logger.info(LOGSRC, "Build sorted index table ...");
00091         
00092         time_sort = System.currentTimeMillis();
00093         
00094         for (i=0; i < freq_table.size(); i++) {
00095             
00096             vec = (MLVector)freq_table.get(i);
00097             dim = vec.dim;
00098             
00099             for (j=0; j < vec.dim; j++) {
00100                 vec.value[j] *= Math.pow(
00101                 rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00102             }
00103             
00104             //System.out.println(vec);
00105         }
00106         
00107         index_list = new IndexList(freq_table);
00108         
00109         time_sort = System.currentTimeMillis() - time_sort;
00110         
00111         logger.info(LOGSRC, "sorted index table built.");
00112         
00113         logger.info(LOGSRC, "Sorting took " + getTimeString(time_sort));
00114         
00115         // get corners
00116         
00117         mean = new double[dim];
00118         
00119         double min;
00120         double max;
00121         
00122         for (i=0; i < dim; i++) {
00123             vec = (MLVector)freq_table.get(index_list.getFirst(i));
00124             min = vec.value[i];
00125             vec = (MLVector)freq_table.get(index_list.getLast(i));
00126             max = vec.value[i];
00127             mean[i] = (max + min) / 2.0;
00128         }
00129         
00130         corner = new MLVector(dim);
00131         
00132         numCorners = (int)Math.pow(2, (double)dim);
00133         corners = new MLVector[numCorners];
00134         int bit;
00135         
00136         logger.debug(LOGSRC, "Find " + numCorners + " corners ...");
00137         
00138         for (i=0; i < numCorners; i++) {
00139             
00140             corners[i] = new MLVector(dim);
00141             
00142             for (j=0; j < dim; j++) {
00143                 
00144                 bit = (i & (1 << j)) >> j;
00145                 
00146                 if (bit == 0) {
00147                     index = index_list.getFirst(j);
00148                     vec = (MLVector)freq_table.get(index);
00149                     corners[i].value[j] = vec.value[j];
00150                 } else {
00151                     index = index_list.getLast(j);
00152                     vec = (MLVector)freq_table.get(index);
00153                     corners[i].value[j] = vec.value[j];
00154                 }
00155             }
00156             
00157             logger.debug(LOGSRC, "corner[" + i + "] = " + corners[i].toString());
00158         }
00159         
00160         logger.debug(LOGSRC, "corners found.");
00161         
00162         // outer loop
00163         
00164         max_dist = 0.0;
00165         max_index = new HashMap();
00166         
00167         outer_iter_min = new IndexListIterator[dim];
00168         outer_iter_max = new IndexListIterator[dim];
00169         inner_iter_min = new IndexListIterator[dim];
00170         inner_iter_max = new IndexListIterator[dim];
00171         
00172         for (i=0; i < dim; i++) {
00173             outer_iter_min[i] = index_list.iterator(i);
00174             outer_iter_max[i] = index_list.iteratorReverse(i);
00175         }
00176         
00177         outer_finished = false;
00178         while (!outer_finished) {
00179             
00180             //System.out.println(sum_pruned + " " + freq_table.size() + " " + i);
00181             
00182             for (i=0; i < dim; i++) {
00183                 if (!(sum_pruned > (freq_table.size()-2))) {
00184                     if (outer_iter_min[i].hasNext()) {
00185                         index = outer_iter_min[i].next();
00186                         find_new_max_dist(index);
00187                         //System.out.println(sum_pruned + " " + freq_table.size() + " " + i);
00188                     }
00189                 }
00190                 if (!(sum_pruned > (freq_table.size()-2))) {
00191                     if (outer_iter_max[i].hasPrevious()) {
00192                         index = outer_iter_max[i].previous();
00193                         find_new_max_dist(index);
00194                         //System.out.println(sum_pruned + " " + freq_table.size() + " " + i);
00195                     }
00196                 }
00197             }
00198             
00199             if (sum_pruned > (freq_table.size()-2))
00200                 outer_finished = true;
00201             else if (outer_iter_min[0].nextIndex() > outer_iter_max[0].previousIndex())
00202                 outer_finished = true;
00203         }
00204         
00205         for (i=0; i < dim; i++) {
00206             index_list.freeIterator(i, outer_iter_min[i]);
00207             index_list.freeIterator(i, outer_iter_max[i]);
00208         }
00209         
00210         
00211         Iterator keys = max_index.keySet().iterator();
00212         while (keys.hasNext()) {
00213             IntegerPair pair = (IntegerPair)keys.next();
00214             MLVector vec_a = (MLVector)freq_table.get(pair.a);
00215             MLVector vec_b = (MLVector)freq_table.get(pair.b);
00216             logger.debug(LOGSRC, "(" + pair.a + ") " + vec_a.toString() + " - " +
00217             "(" + pair.b + ") " + vec_b.toString() + " = " + max_dist);
00218         }
00219         
00220         statwriter.println("stat_Pruned:" + sum_pruned);
00221         
00222         logger.info(LOGSRC, "Pruned " + sum_pruned + " of " + freq_table.size() + " (" +
00223         (freq_table.size()-sum_pruned) + " left) instances in whole.");
00224         
00225         IntegerPair pair = (IntegerPair)getRandomElement(max_index.keySet());
00226         MLVector vec_a = (MLVector)freq_table.get(pair.a);
00227         MLVector vec_b = (MLVector)freq_table.get(pair.b);
00228         centroids.add(vec_a);
00229         logger.info(LOGSRC, "centroid[0] = " + vec_a.toString());
00230         centroids.add(vec_b);
00231         logger.info(LOGSRC, "centroid[1] = " + vec_b.toString());
00232         
00233         statwriter.println("time_clusterer_sort:" + time_sort);
00234     }
00235     
00236     
00237     private void find_new_max_dist(int index1) {
00238         
00239         MLVector vec1 = (MLVector)freq_table.get(index1);
00240         int index2;
00241         
00242         if (vec1.visited == false) {
00243             //System.out.println("Find max_dist for index " + index1);
00244             
00245             int i, j;
00246             
00247             for (i=0; i < dim; i++) {
00248                 inner_iter_min[i] = index_list.iterator(i);
00249                 inner_iter_max[i] = index_list.iteratorReverse(i);
00250             }
00251             
00252             inner_finished = false;
00253             while (!inner_finished) {
00254                 
00255                 boolean prune = false;
00256                 
00257                 for (i=0; i < dim; i++) {
00258                     
00259                     if (inner_iter_min[i].hasNext()) {
00260                         index2 = inner_iter_min[i].next();
00261                         if (find_new_max_dist(index1, index2))
00262                             prune = true;
00263                     }
00264                     if (inner_iter_max[i].hasPrevious()) {
00265                         index2 = inner_iter_max[i].previous();
00266                         if (find_new_max_dist(index1, index2))
00267                             prune = true;
00268                     }
00269                 }
00270                 
00271                 if (prune && pruning)
00272                     prune();
00273                 
00274                 if (sum_pruned > freq_table.size()-2)
00275                     inner_finished = true;
00276                 else if (inner_iter_min[0].nextIndex() > inner_iter_max[0].previousIndex())
00277                     inner_finished = true;
00278                 
00279                 //if (!inner_iter_min[0].hasNext() && !inner_iter_max[0].hasPrevious())
00280                 //    inner_finished = true;
00281             }
00282             
00283             for (i=0; i < dim; i++) {
00284                 index_list.freeIterator(i, inner_iter_min[i]);
00285                 index_list.freeIterator(i, inner_iter_max[i]);
00286             }
00287             
00288             vec1.visited = true;
00289         }
00290     }
00291     
00292     
00293     private boolean find_new_max_dist(int index1, int index2) {
00294         
00295         MLVector vec1 = (MLVector)freq_table.get(index1);
00296         MLVector vec2 = (MLVector)freq_table.get(index2);
00297         
00298         double dist = Distance.euklidian(vec1.value, vec2.value);
00299         
00300         if (dist > max_dist) {
00301             max_index.clear();
00302             max_dist = dist;
00303             
00304             max_index.put(new IntegerPair(index1, index2), new Integer(1));
00305             logger.info(LOGSRC, "Found new maximum distance " + max_dist);
00306             return true;
00307             
00308         } else if (dist == max_dist) {
00309             max_index.put(new IntegerPair(index1, index2), new Integer(1));
00310         }
00311         return false;
00312     }
00313     
00314     
00315     private void prune() {
00316         int pruned = 0;
00317         int i, j;
00318         IndexListIterator iter;
00319         int pindex;
00320         MLVector pvec;
00321         boolean prune;
00322         
00323         for (i=0; i < dim; i++) {
00324             
00325             iter = index_list.iterator(i);
00326             while(iter.hasNext()) {
00327                 pindex = iter.next();
00328                 pvec = (MLVector)freq_table.get(pindex);
00329                 
00330                 if (pvec.pruned)
00331                     prune = true;
00332                 else {
00333                     prune = true;
00334                     /*
00335                     for (j=0; j < numCorners; j++) {
00336                         if (Distance.euklidian(pvec.value, corners[j].value) > max_dist)
00337                             prune = false;
00338                     }
00339                      */
00340                     
00341                     if (Distance.euklidian(pvec.value, calculateCorner(pvec).value)
00342                         > max_dist)     prune = false;
00343                     
00344                     if (prune)
00345                         pruned++;
00346                 }
00347                 
00348                 if (prune) {
00349                     pvec.pruned = true;
00350                     
00351                     iter.remove();
00352                 }
00353             }
00354         }
00355         
00356         logger.info(LOGSRC, "Pruned " + pruned + " instances.");
00357         
00358         sum_pruned += pruned;
00359     }
00360     
00361     
00362     protected void buildRest() {
00363         
00364         int i, j, k;
00365         MLVector vec;
00366         MLVector vec_a;
00367         MLVector vec_b;
00368         List max_index = new ArrayList();
00369         
00370         double max_dist;
00371         double min_dist;
00372         double act_dist;
00373         
00374         for (i=0; i < num_clusters-2; i++) {
00375             
00376             max_dist = 0.0;
00377             max_index.clear();
00378             
00379             for (j=0; j < freq_table.size(); j++) {
00380                 
00381                 vec_a = (MLVector)freq_table.get(j);
00382                 
00383                 min_dist = Double.MAX_VALUE;
00384                 
00385                 for (k=0; k < centroids.size(); k++) {
00386                     
00387                     vec_b = (MLVector)centroids.get(k);
00388                     
00389                     act_dist = Distance.euklidian(vec_a.value, vec_b.value);
00390                     
00391                     if (act_dist < min_dist)
00392                         min_dist = act_dist;
00393                 }
00394                 
00395                 if (min_dist > max_dist) {
00396                     
00397                     max_index.clear();
00398                     max_index.add(new Integer(j));
00399                     max_dist = min_dist;
00400                     
00401                 } else if (min_dist == max_dist) {
00402                     
00403                     max_index.add(new Integer(j));
00404                 }
00405             }
00406             
00407             Integer index = (Integer)getRandomElement(max_index);
00408             vec = (MLVector)freq_table.get(index.intValue());
00409             centroids.add(vec);
00410             logger.info(LOGSRC, "centroid[" + (i+2) + "] = " + vec.toString());
00411         }
00412         
00413         for (i=0; i < centroids.size(); i++) {
00414             
00415             vec = (MLVector)centroids.get(i);
00416             
00417             for (j=0; j < vec.dim; j++) {
00418                 vec.value[j] /= Math.pow(
00419                 rho * ((double)vec.freq / (double)training_set.numInstances()) + (1-rho), rho);
00420             }
00421         }
00422     }
00423     
00424     
00425     protected MLVector calculateCorner(MLVector vec) {
00426         
00427         for (int i=0; i < dim; i++) {
00428             
00429             if (vec.value[i] <= mean[i])
00430                 
00431                 corner.value[i] = ((MLVector)freq_table.get(index_list.getLast(i))).value[i];
00432             else
00433                 corner.value[i] = ((MLVector)freq_table.get(index_list.getFirst(i))).value[i];
00434         }
00435         return corner;
00436     }
00437 }
00438 

Erzeugt am Tue Apr 22 11:22:55 2003 für Picana von doxygen 1.2.18