Hauptseite   Packages   Klassenhierarchie   Übersicht   Auflistung der Dateien   Datenstruktur-Elemente  

SON.java

gehe zur Dokumentation dieser Datei
00001 /*
00002  * $Source: /shared/cvsroot/diplom/app/src/java/de/picana/clusterer/SON.java,v $
00003  * $Author: mstolpe $
00004  * $Date: 2003/04/22 09:51:28 $
00005  * $Revision: 1.5 $
00006  * $Release$ 
00007  *
00008  * Created on 5. März 2003, 21:50
00009  *
00010  * Copyright 2002 by Marco Stolpe
00011  */
00012 
00013 package de.picana.clusterer;
00014 
00015 import de.picana.control.*;
00016 import de.picana.logging.*;
00017 import de.picana.math.*;
00018 
00019 import java.io.*;
00020 import java.util.*;
00021 
00022 import weka.core.*;
00023 
00024 
00031 public class SON extends Clusterer {
00032     
00033     protected static final int METHOD_KOHONEN = 0;
00034     protected static final int METHOD_WTAN = 1;
00035     protected static final int NEIGH_GAUSSIAN = 0;
00036     protected static final int NEIGH_WTAN = 1;
00037     protected static final int NEIGH_KOHONEN = 2;
00038     protected static final int TOPO_SQUARE = 0;
00039     protected static final int TOPO_HEXA = 1;
00040     protected static final int ADPT_STEP = 0;
00041     protected static final int ADPT_CONT = 1;
00042     protected static final int ADPT_SNNS = 2;
00043     protected static final int ADPT_PICANA = 3;
00044 
00046     protected int cycles;
00048     protected double radius;
00050     protected String preset;
00051     
00052     protected int method;
00053     protected int neigh;
00054     protected int topo;
00055     protected int adpt;
00056      
00057     protected int width;
00058     protected int height;
00059     
00060     protected double adpt_height;
00061     protected double adpt_radius;
00062     protected double mult_height;
00063     protected double mult_radius;
00064     protected double dist_winner;
00065 
00066     protected int micro_step;
00067     protected int step;
00068     protected int nr_examples;
00069     
00070     protected double b;
00071     protected int dim;
00072     
00073     protected List centroids;
00074     protected List freq_table;
00075     protected Map freq_htable;
00076     
00077     protected long time_total;
00078     
00079     
00080     
00082     public SON() {
00083     }
00084     
00085     
00086     public void init(ParameterSet params, Logger logger) {
00087         
00088         super.init(params, logger);
00089         
00090         cycles = 100;
00091         try {
00092             cycles = Integer.parseInt((String)params.getParameter("cycles"));
00093         } catch (NumberFormatException nfe) {}
00094         
00095         radius = 0.0;
00096         try {
00097             radius = Double.parseDouble((String)params.getParameter("radius"));
00098         } catch (NumberFormatException nfe) {}
00099         
00100         preset = "WTAN";
00101         
00102         if ((String)params.getParameter("preset") != null) {
00103             
00104             preset = ((String)params.getParameter("preset")).toUpperCase();
00105             
00106             if (preset.equals("WTAN")) {
00107                 method = METHOD_WTAN;
00108                 topo = TOPO_SQUARE;
00109                 neigh = NEIGH_WTAN;
00110                 adpt = ADPT_PICANA;
00111             }
00112             if (preset.equals("SOM_SQUARE")) {
00113                 method = METHOD_KOHONEN;
00114                 topo = TOPO_SQUARE;
00115                 neigh = NEIGH_GAUSSIAN;
00116                 adpt = ADPT_STEP;
00117             }
00118             if (preset.equals("SOM_HEXA")) {
00119                 method = METHOD_KOHONEN;
00120                 topo = TOPO_HEXA;
00121                 neigh = NEIGH_GAUSSIAN;
00122                 adpt = ADPT_STEP;
00123             }
00124             if (preset.equals("QSOM_SQUARE")) {
00125                 method = METHOD_KOHONEN;
00126                 topo = TOPO_SQUARE;
00127                 neigh = NEIGH_KOHONEN;
00128                 adpt = ADPT_PICANA;
00129             }
00130             if (preset.equals("QSOM_HEXA")) {
00131                 method = METHOD_KOHONEN;
00132                 topo = TOPO_HEXA;
00133                 neigh = NEIGH_KOHONEN;
00134                 adpt = ADPT_PICANA;
00135             }
00136         }
00137     }
00138     
00139     
00140     public void buildClusterer(Instances set) throws TaskException {
00141         
00142         statwriter.println("algo_clusterer_name:SON");
00143         statwriter.println("algo_clusterer_cycles:" + cycles);
00144         statwriter.println("algo_clusterer_radius:" + radius);
00145         statwriter.println("algo_clusterer_preset:" + preset);
00146         statwriter.println("algo_clusterer_model:" + modelfile);
00147         statwriter.println("algo_clusterer_training_set:" + infile);
00148         
00149         time_total = System.currentTimeMillis();
00150         
00151         // Initialisierung
00152         
00153         dim = set.numAttributes();
00154         
00155         width  = (int)Math.ceil(Math.sqrt((double)num_clusters));
00156         height = num_clusters / width;
00157         if ((num_clusters % width) != 0)
00158             height++;
00159 
00160         micro_step = 1;
00161               
00162         freq_htable = new HashMap();
00163         
00164         int i, j;
00165 
00166         MLVector vec;
00167         Instance inst;
00168         Integer freq;
00169             
00170         // build frequency hash table
00171         
00172         logger.info(LOGSRC, "Build frequency table from " + training_set.numInstances() + " instances ...");
00173         for (i=0; i < training_set.numInstances(); i++) {
00174             vec = new MLVector(training_set.instance(i));
00175             freq = (Integer)freq_htable.get(vec);
00176             
00177             freq_htable.put(vec, (freq==null) ? ONE :
00178                 new Integer(freq.intValue() + 1));
00179         }
00180         logger.info(LOGSRC, "frequency table containing " + freq_htable.size() + " instances built.");
00181             
00182             
00183         // schreibe die Instanzen in neue Liste
00184         // Instanzen in neue Liste
00185             
00186         freq_table = new ArrayList();
00187             
00188         Iterator iter = freq_htable.entrySet().iterator();
00189         while (iter.hasNext()) {
00190             Map.Entry entry = (Map.Entry)iter.next();
00191             vec = (MLVector)entry.getKey();
00192             freq = (Integer)entry.getValue();
00193             freq_table.add(vec);
00194         }
00195         
00196         
00197         // Generiere Startwerte f?r die Vektoren des Netzes durch zuf?llige
00198         // Auswahl von Instanzen aus der Trainingsmenge und zuf?lliger Ver?nderung
00199         // der Instanzen
00200         
00201         centroids = new ArrayList();
00202         
00203         MLVector rand_vec;
00204         int rand_index;
00205         
00206         for (int c=0; c < num_clusters; c++) {
00207             
00208             rand_index = rand.nextInt(freq_table.size());
00209             rand_vec = MLVector.getRandom(
00210                 (MLVector)freq_table.get(rand_index), radius);
00211             
00212             centroids.add(rand_vec);
00213         }
00214         
00215         for (i=0; i < centroids.size(); i++) {
00216             logger.debug(LOGSRC, "weight[" + i + "] = " + (MLVector)centroids.get(i));
00217         }
00218         
00219         // Parameter b
00220         
00221         b = Math.pow (0.005, (double)(1.0 /
00222             ((double)set.numInstances() * (double)cycles)));
00223         
00224         // Wiederhole das Training f?r cycles Lernschritte
00225         
00226         for (int c=0; c < cycles; c++) {
00227         
00228             logger.verbose(LOGSRC, "Cycle: " + (c+1));
00229             
00230             // Trainiere das Netz durch Auswahl einer Instanz aus der
00231             // Indextabelle und aktualisiere die Tabelle
00232         
00233             freq_table = new ArrayList();
00234             
00235             iter = freq_htable.entrySet().iterator();
00236             while (iter.hasNext()) {
00237                 Map.Entry entry = (Map.Entry)iter.next();
00238                 vec = (MLVector)entry.getKey();
00239                 freq = (Integer)entry.getValue();
00240                 vec.freq = freq.intValue();
00241                 freq_table.add(vec);
00242             }
00243             
00244             while (freq_table.size() > 0) {
00245                 
00246                 rand_index = rand.nextInt(freq_table.size());
00247                 vec = (MLVector)freq_table.get(rand_index);
00248                 vec.freq--;
00249                 
00250                 if (vec.freq == 0)
00251                     freq_table.remove(rand_index);
00252                                
00253                 adpt_height = b;
00254                 adpt_radius = b;
00255                 mult_height = 0.9999;
00256                 mult_radius = 0.9999;
00257                 nr_examples = set.numInstances();
00258                 step = c;
00259                 
00260                 if (preset.startsWith("SOM")) {
00261                     adpt_height = 1.0;
00262                     adpt_radius = width;
00263                     nr_examples = cycles;
00264                 }
00265                 
00266                 if ((adpt == ADPT_STEP) || (adpt == ADPT_PICANA))
00267                     update ();
00268                 
00269                 // trainiere das Netz mit diesem Vektor
00270                 train(vec);
00271             }
00272             
00273             //for (i=0; i < centroids.size(); i++) {
00274             //    logger.debug(LOGSRC, "weight[" + i + "] = " + (MLVector)centroids.get(i));
00275             //}
00276         }
00277         
00278         time_total = System.currentTimeMillis() - time_total;
00279         
00280         logger.info(LOGSRC, "Algorithm took " + getTimeString(time_total));
00281         
00282         for (i=0; i < centroids.size(); i++) {
00283             logger.verbose(LOGSRC, "centroid[" + i + "] = " + (MLVector)centroids.get(i));
00284         }
00285         
00286         statwriter.println("time_clusterer_total:" + time_total);
00287     }
00288 
00289     
00290     public void saveModel(String filename) throws TaskException {
00291         
00292         try {
00293             File outfile = new File(filename);
00294             FileOutputStream out = new FileOutputStream(outfile);
00295             PrintWriter pw = new PrintWriter(out);
00296         
00297             for (int att=0; att < training_set.numAttributes(); att++) {
00298                 pw.print(training_set.attribute(att).name());
00299                 if (att != training_set.numAttributes()-1)
00300                     pw.print(",");
00301             }
00302             pw.println();
00303             
00304             for (int i=0; i < centroids.size(); i++) {
00305                 MLVector vec = (MLVector)centroids.get(i);    
00306             
00307                 for (int j=0; j < vec.value.length; j++) {
00308              
00309                     pw.print(vec.value[j]);
00310                     if (j != vec.value.length-1) {
00311                         pw.print(",");    
00312                     }
00313                 }
00314                 pw.println();
00315             }
00316             
00317             pw.close();
00318             
00319         } catch (Exception e) {
00320             throw new TaskException(e.toString());
00321         }
00322     }
00323     
00325     protected double dist_euklidian(int x1, int y1, int x2, int y2) {
00326         double x = x1 - x2;
00327         double y = y1 - y2;
00328         return Math.sqrt ( x*x + y*y );
00329     }
00330     
00332     protected double dist_hexa(int x1, int y1, int x2, int y2) {
00333         double dist;
00334         double val = x1 - x2;
00335         if (((y1 - y2) % 2) != 0)
00336             val = Math.abs(val) - 0.5;
00337          dist = val * val;
00338         val = ((y1-y2)*(y1-y2))*0.866025404;
00339         dist += val;
00340         dist = Math.sqrt(dist);
00341         return dist;
00342     }
00343     
00345     protected double neigh_gaussian() {
00346         
00347         //logger.debug(LOGSRC, "dist_winner = " + dist_winner +
00348         //                   ", mult_height = " + mult_height +
00349         //                   ", mult_radius = " + mult_radius);
00350         
00351         return mult_height * Math.pow (Math.E,
00352             -( (dist_winner / mult_radius) * (dist_winner / mult_radius) ));
00353     }
00354     
00356     protected double neigh_wtan() {
00357         // Achtung! mult_height ist unser Eta
00358 
00359         if (dist_winner >= 1e-10 )
00360         {
00361             //if (dist_winner < Math.sqrt(dim * (15.0/255.0*255.0*3.0)))
00362             //    return mult_height * -0.1;
00363             //else
00364                 //return pow(mult_height,2.0) * (1.0 / (pow(dist_winner, 2.99) + 1.0)); ***
00365                 return Math.pow(mult_height,2.0) * (1.0 / (Math.pow(dist_winner * 441.6729, 1.5) + 1.0));
00366         }
00367         //return pow((0.2 * mult_height),3.0) * 1.0;
00368         //return 0.5 * mult_height; ***
00369         return 0.5 * mult_height;
00370         
00371         // *** Bei 10 und 20 Neuronen, gute Ergebnisse schon bei 1 Zyklus
00372         //     sehr gute Ergebnisse ab 5 Zyklen
00373         //     ab 5 Zyklen insbesondere auch konstant gleichm??ige Ergebnisse
00374         //     bei noch mehr Zyklen kaum noch Ergebnisverbesserung
00375     }
00376     
00377     
00379     protected double neigh_kohonen() {
00380         // return 0.5 * mult_height * pow (2.718281828, -( (dist_winner / mult_radius) *
00381         //                                          (dist_winner / mult_radius) ));
00382 
00383         // Wichtige Variablen:
00384         // micro_step
00385 
00386         // Varianten, die funktioniert!
00387         // return 0.1 * pow(mult_height,2) * (1.0 /(1000 * pow(dist_winner, 3) + 1.0));
00388         // return 0.01 * pow(mult_height,2) * (1.0 /(1000 * pow(dist_winner, 2) + 1.0));
00389         // return 0.01 * pow(mult_height,2) * (1.0 /((100+micro_step) * pow(dist_winner, 5) + 1.0));
00390         // return 0.5 * mult_height * (1.0 /((100+micro_step) * pow(dist_winner, 3) + 1.0));
00391         // return 0.3 * mult_height * (1.0 /(pow(micro_step, 9) * dist_winner + 1.0)); ***
00392 
00393         return 0.3 * mult_height * (1.0 /(Math.pow(micro_step, 9) * dist_winner + 1.0));
00394         
00395         // *** ansatzweise brauchbarer Ansatz bei 20 bis 40 Zyklen
00396     }
00397     
00398     
00400     protected int train(MLVector input) {
00401         int i;
00402         int winner = -1;
00403         double n = 0.0;
00404         MLVector weight;
00405         MLVector sub;
00406         
00407         if (method == METHOD_WTAN)
00408         {
00409             winner = classify (input);
00410 
00411             for (i=0; i < num_clusters; i++)
00412             {
00413                 weight = (MLVector)centroids.get(i);
00414                 
00415                 dist_winner = Distance.euklidian(
00416                     ((MLVector)centroids.get(winner)).value, weight.value);
00417 
00418                 if (neigh == NEIGH_GAUSSIAN)
00419                     n = neigh_gaussian();
00420                 if (neigh == NEIGH_WTAN)
00421                     n = neigh_wtan();
00422                 if (neigh == NEIGH_KOHONEN)
00423                     n = neigh_kohonen();
00424                 
00425                 // Gewicht anpassen gem?? weight = weight + n * (input - weight)
00426                 
00427                 
00428                 sub = MLVector.sub(input, weight);
00429                 sub.mult(n);
00430                 //logger.debug(LOGSRC, "input = " + input + ", weight = " + weight + ", sub*n = " + sub + ", n = " + n);
00431                 weight.add(sub);
00432                 
00433                 //logger.debug(LOGSRC, "newweight = " + weight);
00434             }
00435         }
00436 
00437         if (method == METHOD_KOHONEN)
00438         {
00439             winner = classify (input);
00440 
00441             int x, y, xw, yw;
00442 
00443             xw = winner % width;
00444             yw = winner / width;
00445 
00446             for (i=0; i < num_clusters; i++)
00447             {
00448                 weight = (MLVector)centroids.get(i);
00449                 
00450                 x = i % width;
00451                 y = i / width;
00452 
00453                 if (topo == TOPO_SQUARE)
00454                     dist_winner = dist_euklidian (x, y, xw, yw);
00455 
00456                 if (topo == TOPO_HEXA)
00457                     dist_winner = dist_hexa (x, y, xw, yw);
00458                 
00459                 if (neigh == NEIGH_GAUSSIAN)
00460                     n = neigh_gaussian();
00461                 if (neigh == NEIGH_WTAN)
00462                     n = neigh_wtan();
00463                 if (neigh == NEIGH_KOHONEN)
00464                     n = neigh_kohonen();
00465 
00466                 //if ((micro_step % training_set.numInstances()) == 0) {
00467                 //    logger.debug(LOGSRC, "e^x = " + Math.pow(Math.E, -((dist_winner/mult_radius) * (dist_winner/mult_radius))));
00468                 //    logger.debug(LOGSRC, "n = " + n + ", mh = " + mult_height);
00469                 //}
00470                 
00471                 // Gewicht anpassen gem?? weight = weight + n * (input - weight)
00472                 
00473                 //logger.debug(LOGSRC, "n = " + n);
00474                 sub = MLVector.sub(input, weight);
00475                 
00476                 //logger.debug(LOGSRC, "sub = " + sub);
00477                 
00478                 sub.mult(n);
00479                 
00480                 //logger.debug(LOGSRC, "input = " + input + ", weight = " + weight + ", sub*n = " + sub);
00481                 weight.add(sub);
00482                 
00483                 //logger.debug(LOGSRC, "newweight = " + weight);
00484             }
00485             
00486             //logger.debug(LOGSRC, "");
00487         }
00488         
00489         micro_step++;
00490         update ();
00491 
00492         return winner;
00493     }
00494     
00495     
00497     protected int classify(MLVector input) {
00498     
00499         int winner = -1;
00500         double akt_dist;
00501         double min_dist;
00502         
00503         // Search Winner
00504 
00505         min_dist = Double.MAX_VALUE;
00506     
00507         for (int i=0; i < num_clusters; i++)
00508         {
00509             akt_dist = Distance.euklidian(
00510                 input.value, ((MLVector)centroids.get(i)).value);
00511             
00512             if (akt_dist < min_dist)
00513             {
00514                 min_dist = akt_dist;
00515                 winner = i;
00516             }
00517         }
00518 
00519         return winner;    
00520     }
00521     
00522     
00524     protected void update() {
00525         
00526         switch (adpt)
00527         {
00528             case ADPT_STEP:
00529                 mult_radius = (adpt_radius / (double)nr_examples) *
00530                                (double)(nr_examples - step);
00531                 mult_height = (adpt_height / (double)nr_examples) *
00532                                (double)(nr_examples - step);
00533                 break;
00534 
00535             case ADPT_CONT:
00536                 mult_radius = (adpt_radius / (double)nr_examples) *
00537                                (double)(nr_examples - step);
00538                 mult_height = (adpt_height / (double)nr_examples) *
00539                                (double)(nr_examples - step);
00540                 break;
00541 
00542             case ADPT_SNNS:
00543                 adpt_height *= mult_height;
00544                 adpt_radius *= mult_radius;
00545                 break;
00546 
00547             case ADPT_PICANA:
00548                 mult_radius = Math.pow (adpt_radius, micro_step);
00549                 mult_height = Math.pow (adpt_height, micro_step);
00550                 break;
00551         }
00552     }
00553 }

Erzeugt am Tue Apr 22 11:22:56 2003 für Picana von doxygen1.2.18