Hauptseite   Packages   Klassenhierarchie   Übersicht   Auflistung der Dateien   Datenstruktur-Elemente  

Classifier.java

gehe zur Dokumentation dieser Datei
00001 /*
00002  * $Source: /shared/cvsroot/diplom/app/src/java/de/picana/classifier/Classifier.java,v $
00003  * $Author: mstolpe $
00004  * $Date: 2003/04/22 09:51:27 $
00005  * $Revision: 1.8 $
00006  * $Release$ 
00007  *
00008  * Created on 01. November 2002, 18:00
00009  *
00010  * Copyright 2002 by Marco Stolpe
00011  */
00012 
00013 package de.picana.classifier;
00014 
00015 import de.picana.clusterer.MLVector;
00016 import de.picana.control.*;
00017 import de.picana.logging.*;
00018 import de.picana.math.*;
00019 
00020 import weka.core.*;
00021 import java.io.*;
00022 import java.util.*;
00023 import java.text.*;
00024 
00025 
00032 public abstract class Classifier extends Task {
00033  
00035     protected String infile;
00037     protected String modelfile;
00039     protected String centroidfile;
00041     protected String statfile;
00044     protected String outfile;
00046     protected double param_a;
00047    
00049     protected int num_clusters;
00050     
00051     private int num_attributes;
00052     private Instances training_set;
00053     private Instances result_set;
00054     private Instances[] clusters;
00055     private double[][] means;
00056 
00058     protected DecimalFormat decf;
00059     
00060     
00062     public Classifier() {
00063     }
00064 
00078     public void init(ParameterSet params, Logger logger) {
00079         
00080         super.init(params, logger);
00081         
00082         decf = (DecimalFormat)DecimalFormat.getInstance(Locale.ENGLISH);
00083         decf.applyPattern("####0.000000");
00084         
00085         infile = (String)params.getParameter("in");
00086         modelfile = (String)params.getParameter("model");
00087         outfile = (String)params.getParameter("out");
00088         centroidfile = (String)params.getParameter("centroids");
00089         statfile = (String)params.getParameter("statistics");
00090         
00091         /*
00092         param_a = 0.0;
00093         try {
00094             param_a = Double.parseDouble((String)params.getParameter("a"));
00095         } catch (NumberFormatException nfe) {}
00096          */
00097     }
00098     
00100     public void start() throws TaskException {
00101     
00102         try {
00103             logger.info(LOGSRC, "Started.");
00104             
00105             logger.info(LOGSRC, "Read infile '" + infile + "' ...");
00106             File input = new File(infile);
00107             FileInputStream fis = new FileInputStream(input);
00108             InputStreamReader reader = new InputStreamReader(fis);
00109             training_set = new Instances(reader); 
00110             num_attributes = training_set.numAttributes();
00111             logger.info(LOGSRC, "reading '" + infile + "' done.");
00112             
00113             logger.info(LOGSRC, "Read modelfile '" + modelfile + "' ...");
00114             loadModel(modelfile);
00115             logger.info(LOGSRC, "reading '" + modelfile + "' done.");
00116             
00117             logger.info(LOGSRC, "Classify instances ...");
00118                         
00119             clusters = new Instances[num_clusters];
00120             means = new double[num_clusters][num_attributes];
00121             
00122             for (int i=0; i < training_set.numInstances(); i++) {
00123             
00124                 Instance ex = training_set.instance(i);
00125                 int cl = classify(ex);
00126 
00127                 if (clusters[cl] == null)
00128                     clusters[cl] = new Instances(training_set, 0);
00129                 
00130                 clusters[cl].add(ex);
00131             }
00132             
00133             logger.info(LOGSRC, "classification done.");
00134             
00135             
00136             logger.info(LOGSRC, "Calculate meta data ...");
00137         
00138             MLVector vec;
00139             Integer freq;
00140             Instance inst;
00141             int dim;
00142         
00143             Map freq_htable = new HashMap();
00144         
00145             // build frequency hash table
00146         
00147             for (int i=0; i < training_set.numInstances(); i++) {
00148                 inst = training_set.instance(i);
00149                 dim = inst.numAttributes();
00150             
00151                 // convert instance into vector
00152                 vec = new MLVector(dim);
00153                 for (int j=0; j < dim; j++)
00154                     vec.value[j] = inst.value(j);   
00155             
00156                 freq = (Integer)freq_htable.get(vec);
00157             
00158                 freq_htable.put(vec, (freq==null) ? ONE :
00159                     new Integer(freq.intValue() + 1));
00160             }
00161                         
00162             double[] min_dists = new double[num_clusters];
00163             int[] min_cls = new int[num_clusters];
00164             double[] stats = new double[2];
00165             
00166             double empvar = Stats.getEmpVar(training_set) /
00167                             (training_set.numInstances()-1);
00168             double sst = Stats.getSST(training_set);
00169             double ssb = Stats.getSSB(training_set, clusters);
00170             double ssw = Stats.getSSW(clusters);
00171             double mmd = Distance.getMMD(clusters, min_dists, min_cls);
00172             double gmmd_a1 = Distance.getGMMD(clusters, 0.1, min_dists, min_cls, stats);
00173             double gmmd_a5 = Distance.getGMMD(clusters, 0.5, min_dists, min_cls, stats);
00174             double gmmd_a9 = Distance.getGMMD(clusters, 0.9, min_dists, min_cls, stats);
00175             
00176             logger.info(LOGSRC, "Instances:  " + training_set.numInstances());
00177             logger.info(LOGSRC, "UniqueInst: " + freq_htable.size());
00178             logger.info(LOGSRC, "Clusters:   " + num_clusters);
00179             logger.info(LOGSRC, "EmpVar:     " + decf.format(empvar));
00180             logger.info(LOGSRC, "SST:        " + decf.format(sst));
00181             logger.info(LOGSRC, "SSB:        " + decf.format(ssb));
00182             logger.info(LOGSRC, "SSW:        " + decf.format(ssw));
00183             logger.info(LOGSRC, "SSB + SSW:  " + decf.format(ssb + ssw));
00184             logger.info(LOGSRC, "MMD:        " + decf.format(mmd));
00185             //logger.info(LOGSRC, "GMMD_di:    " + decf.format(stats[0]));
00186             //logger.info(LOGSRC, "GINI:       " + decf.format(stats[1]));
00187             logger.info(LOGSRC, "GMMD_a1:    " + decf.format(gmmd_a1));
00188             logger.info(LOGSRC, "GMMD_a5:    " + decf.format(gmmd_a5));
00189             logger.info(LOGSRC, "GMMD_a9:    " + decf.format(gmmd_a9));
00190             logger.info(LOGSRC, "calculations for meta data done.");
00191                         
00192             if (statfile != null) {
00193                 
00194                 logger.info(LOGSRC, "Write statistics.");
00195                 
00196                 PrintWriter statw = new PrintWriter(
00197                     new FileOutputStream(new File(statfile), true));
00198             
00199                 statw.println("stat_Instances:" + training_set.numInstances());
00200                 statw.println("stat_UniqueInst:" + freq_htable.size());
00201                 statw.println("stat_Clusters:" + num_clusters);
00202                 statw.println("stat_EMPVAR: " + empvar);
00203                 statw.println("stat_SST:" + sst);
00204                 statw.println("stat_SSB:" + ssb);
00205                 statw.println("stat_SSW:" + ssw);
00206                 statw.println("stat_MMD:" + mmd);
00207                 statw.println("stat_GMMD_a1:" + gmmd_a1);
00208                 statw.println("stat_GMMD_a5:" + gmmd_a5);
00209                 statw.println("stat_GMMD_a9:" + gmmd_a9);
00210             
00211                 statw.close();
00212             }
00213             
00214             logger.info(LOGSRC, "Calculate cluster means ...");
00215             for (int cl=0; cl < num_clusters; cl++) {
00216                 
00217                 if (clusters[cl] != null) {
00218                     
00219                     String output = "cluster[" + cl + "] = (";
00220                     
00221                     for (int att=0; att < num_attributes; att++) { 
00222                     
00223                         means[cl][att] = clusters[cl].meanOrMode(att);
00224                         output += decf.format(means[cl][att]);
00225                         if (att != num_attributes-1)
00226                             output += ", ";
00227                     }
00228                     
00229                     output += ") -> " + clusters[cl].numInstances();
00230                 
00231                     logger.info(LOGSRC, output);
00232                 }
00233             }
00234             logger.info(LOGSRC, "calculations for cluster means done.");
00235 
00236             if (centroidfile != null) {
00237                 
00238                 logger.info(LOGSRC, "Write centroids.");    
00239             
00240                 PrintWriter centroidsw = new PrintWriter(
00241                     new FileOutputStream(new File(centroidfile)));
00242                 
00243                 centroidsw.println("Clusters:" + num_clusters);
00244                 for (int att=0; att < num_attributes; att++) {
00245                     centroidsw.print(training_set.attribute(att).name());
00246                     if (att != num_attributes-1)
00247                         centroidsw.print(",");
00248                 }
00249                 centroidsw.println(":instances");
00250                 
00251                 for (int cl=0; cl < num_clusters; cl++) {
00252                     
00253                     if (clusters[cl] != null) {
00254                         for (int att=0; att < num_attributes; att++) {
00255                             centroidsw.print(means[cl][att]);
00256                             if (att != num_attributes-1)
00257                                 centroidsw.print(",");
00258                         }
00259                         centroidsw.println(":" + clusters[cl].numInstances());
00260                     } else {
00261                         for (int att=0; att < num_attributes; att++) {
00262                             centroidsw.print(0.0);
00263                             if (att != num_attributes-1)
00264                                 centroidsw.print(",");
00265                         }
00266                         centroidsw.println(":0");
00267                     }
00268                 }
00269                 
00270                 centroidsw.close();
00271             }
00272             
00273             logger.info(LOGSRC, "Classify instances and replace them by cluster means ...");
00274             
00275             result_set = new Instances(training_set, 0);
00276             
00277             for (int i=0; i < training_set.numInstances(); i++) {
00278             
00279                 Instance ex = training_set.instance(i);
00280                 Instance res = new Instance(ex);
00281                 int cl = classify(ex);
00282 
00283                 for (int att=0; att < num_attributes; att++)                
00284                     res.setValue(att, means[cl][att]);
00285 
00286                 result_set.add(res);
00287             }
00288             
00289             for (int cl=0; cl < num_clusters; cl++) {
00290                 if (clusters[cl] != null)
00291                     clusters[cl] = null;
00292             }
00293             clusters = null;
00294             training_set = null;
00295             
00296             logger.info(LOGSRC, "all instances replaced.");
00297             
00298             logger.info(LOGSRC, "Write classified instances to outfile '" + outfile + "'");
00299             
00300             File output = new File(outfile);
00301             PrintWriter pw = new PrintWriter(new FileOutputStream(output));
00302             
00303             pw.println("% ARFF file");
00304             pw.println("%");
00305             pw.println("@relation " + result_set.relationName());
00306             pw.println();
00307             for (int att=0; att < num_attributes; att++)
00308                 pw.println("@attribute " + result_set.attribute(att).name() + " numeric");
00309             pw.println();
00310             pw.println("@data");
00311             pw.println("%");
00312             pw.println("% x instances");
00313             pw.println("%");
00314             pw.flush();
00315             
00316             for (int i=0; i < result_set.numInstances(); i++) {
00317             
00318                 Instance res = result_set.instance(i);
00319                 String line = "";
00320                 for (int att=0; att < num_attributes; att++)
00321                     line += res.value(att) + ",";
00322                 line = line.substring(0, line.length()-1);                
00323                 pw.println(line);
00324                 pw.flush();
00325                 //System.out.println(i);
00326                 System.out.flush();
00327             }
00328             
00329             pw.close();
00330             
00331             logger.info(LOGSRC, "classified instances written to '" + outfile + "'");
00332             
00333             logger.info(LOGSRC, "Stopped.");
00334         
00335         } catch (Exception e) {
00336             logger.error(LOGSRC, e.toString());
00337             throw new TaskException(e.toString());
00338         }    
00339     }
00340     
00341     public void stop() {
00342     }
00343     
00344     public void pause() {
00345     }
00346     
00347     public void resume() {
00348     }
00349     
00354     protected abstract void loadModel(String filename) throws TaskException;
00355    
00361     protected abstract int classify(Instance i);
00362 }

Erzeugt am Tue Apr 22 11:22:55 2003 für Picana von doxygen1.2.18