package de.picana.classifier;

import de.picana.clusterer.MLVector;
import de.picana.control.*;
import de.picana.logging.*;
import de.picana.math.*;

import weka.core.*;
import java.io.*;
import java.util.*;
import java.text.*;

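/**
 * Abstract base task that classifies the instances of an ARFF input file
 * against a previously trained cluster model, computes cluster statistics
 * (EmpVar, SST, SSB, SSW, MMD, GMMD), optionally writes statistics and
 * centroid files, and writes the instances back out with their attribute
 * values replaced by the means of their clusters.
 *
 * Subclasses define the model format and the cluster assignment by
 * implementing {@link #loadModel(String)} and {@link #classify(Instance)}.
 * As an illustration only (the class name and centroid layout below are
 * hypothetical, not part of this package), a nearest-centroid subclass
 * might look like:
 *
 * <pre>
 * public class NearestCentroidClassifier extends Classifier {
 *
 *     private double[][] centroids;   // one row per cluster
 *
 *     protected void loadModel(String filename) throws TaskException {
 *         // ... read the centroid rows from the model file ...
 *         // start() relies on num_clusters being set here.
 *         num_clusters = centroids.length;
 *     }
 *
 *     protected int classify(Instance inst) {
 *         // Return the cluster whose centroid has the smallest
 *         // squared Euclidean distance to the instance.
 *         int best = 0;
 *         double bestDist = Double.MAX_VALUE;
 *         for (int c = 0; c < centroids.length; c++) {
 *             double d = 0.0;
 *             for (int a = 0; a < inst.numAttributes(); a++) {
 *                 double diff = inst.value(a) - centroids[c][a];
 *                 d += diff * diff;
 *             }
 *             if (d < bestDist) {
 *                 bestDist = d;
 *                 best = c;
 *             }
 *         }
 *         return best;
 *     }
 * }
 * </pre>
 */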
public abstract class Classifier extends Task {

    /** Source tag used in log output (value assumed; not in the original source). */
    protected static final String LOGSRC = "Classifier";

    /** Frequency count for first-seen instance vectors. */
    private static final Integer ONE = new Integer(1);

    /** Path of the input ARFF file. */
    protected String infile;

    /** Path of the model file to load. */
    protected String modelfile;

    /** Optional path of the centroid output file. */
    protected String centroidfile;

    /** Optional path of the statistics output file. */
    protected String statfile;

    /** Path of the output ARFF file. */
    protected String outfile;

    /** Numeric model parameter available to subclasses. */
    protected double param_a;

    /** Number of clusters; set by loadModel(). */
    protected int num_clusters;

    private int num_attributes;
    private Instances training_set;
    private Instances result_set;
    private Instances[] clusters;
    private double[][] means;

    /** Number formatter for log and file output. */
    protected DecimalFormat decf;

    public Classifier() {
    }

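    /**
     * Initializes the task: configures the decimal formatter and reads
     * the "in", "model", "out", "centroids" and "statistics" parameters
     * from the given parameter set.
     *
     * @param params the task parameters
     * @param logger the logger to report progress to
     */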
    public void init(ParameterSet params, Logger logger) {

        super.init(params, logger);

        decf = (DecimalFormat)DecimalFormat.getInstance(Locale.ENGLISH);
        decf.applyPattern("####0.000000");

        infile = (String)params.getParameter("in");
        modelfile = (String)params.getParameter("model");
        outfile = (String)params.getParameter("out");
        centroidfile = (String)params.getParameter("centroids");
        statfile = (String)params.getParameter("statistics");
    }

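    /**
     * Runs the classification: reads the input ARFF file, loads the
     * model, assigns every instance to a cluster, computes and logs the
     * cluster statistics, optionally writes the statistics and centroid
     * files, and finally writes all instances, replaced by their cluster
     * means, to the output ARFF file.
     *
     * @throws TaskException if reading, classifying or writing fails
     */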
    public void start() throws TaskException {

        try {
            logger.info(LOGSRC, "Started.");

            logger.info(LOGSRC, "Read infile '" + infile + "' ...");
            File input = new File(infile);
            FileInputStream fis = new FileInputStream(input);
            InputStreamReader reader = new InputStreamReader(fis);
            training_set = new Instances(reader);
            reader.close();
            num_attributes = training_set.numAttributes();
            logger.info(LOGSRC, "reading '" + infile + "' done.");

            logger.info(LOGSRC, "Read modelfile '" + modelfile + "' ...");
            loadModel(modelfile);
            logger.info(LOGSRC, "reading '" + modelfile + "' done.");

00117 logger.info(LOGSRC, "Classify instances ...");
00118
00119 clusters = new Instances[num_clusters];
00120 means = new double[num_clusters][num_attributes];
00121
00122 for (int i=0; i < training_set.numInstances(); i++) {
00123
00124 Instance ex = training_set.instance(i);
00125 int cl = classify(ex);
00126
00127 if (clusters[cl] == null)
00128 clusters[cl] = new Instances(training_set, 0);
00129
00130 clusters[cl].add(ex);
00131 }
00132
00133 logger.info(LOGSRC, "classification done.");
00134
00135
00136 logger.info(LOGSRC, "Calculate meta data ...");
00137
00138 MLVector vec;
00139 Integer freq;
00140 Instance inst;
00141 int dim;
00142
00143 Map freq_htable = new HashMap();
00144
00145
00146
00147 for (int i=0; i < training_set.numInstances(); i++) {
00148 inst = training_set.instance(i);
00149 dim = inst.numAttributes();
00150
00151
00152 vec = new MLVector(dim);
00153 for (int j=0; j < dim; j++)
00154 vec.value[j] = inst.value(j);
00155
00156 freq = (Integer)freq_htable.get(vec);
00157
00158 freq_htable.put(vec, (freq==null) ? ONE :
00159 new Integer(freq.intValue() + 1));
00160 }
00161
            double[] min_dists = new double[num_clusters];
            int[] min_cls = new int[num_clusters];
            double[] stats = new double[2];

            // Empirical variance plus the standard sums of squares:
            // SST (total), SSB (between clusters), SSW (within clusters),
            // and the MMD/GMMD cluster distance measures at
            // alpha = 0.1, 0.5 and 0.9.
            double empvar = Stats.getEmpVar(training_set) /
                            (training_set.numInstances() - 1);
            double sst = Stats.getSST(training_set);
            double ssb = Stats.getSSB(training_set, clusters);
            double ssw = Stats.getSSW(clusters);
            double mmd = Distance.getMMD(clusters, min_dists, min_cls);
            double gmmd_a1 = Distance.getGMMD(clusters, 0.1, min_dists, min_cls, stats);
            double gmmd_a5 = Distance.getGMMD(clusters, 0.5, min_dists, min_cls, stats);
            double gmmd_a9 = Distance.getGMMD(clusters, 0.9, min_dists, min_cls, stats);

            logger.info(LOGSRC, "Instances: " + training_set.numInstances());
            logger.info(LOGSRC, "UniqueInst: " + freq_htable.size());
            logger.info(LOGSRC, "Clusters: " + num_clusters);
            logger.info(LOGSRC, "EmpVar: " + decf.format(empvar));
            logger.info(LOGSRC, "SST: " + decf.format(sst));
            logger.info(LOGSRC, "SSB: " + decf.format(ssb));
            logger.info(LOGSRC, "SSW: " + decf.format(ssw));
            // SSB + SSW should equal SST up to rounding.
            logger.info(LOGSRC, "SSB + SSW: " + decf.format(ssb + ssw));
            logger.info(LOGSRC, "MMD: " + decf.format(mmd));
            logger.info(LOGSRC, "GMMD_a1: " + decf.format(gmmd_a1));
            logger.info(LOGSRC, "GMMD_a5: " + decf.format(gmmd_a5));
            logger.info(LOGSRC, "GMMD_a9: " + decf.format(gmmd_a9));
            logger.info(LOGSRC, "calculations for meta data done.");

            if (statfile != null) {

                logger.info(LOGSRC, "Write statistics.");

                // Append so repeated runs accumulate in one statistics file.
                PrintWriter statw = new PrintWriter(
                        new FileOutputStream(new File(statfile), true));

                statw.println("stat_Instances:" + training_set.numInstances());
                statw.println("stat_UniqueInst:" + freq_htable.size());
                statw.println("stat_Clusters:" + num_clusters);
                statw.println("stat_EMPVAR:" + empvar);
                statw.println("stat_SST:" + sst);
                statw.println("stat_SSB:" + ssb);
                statw.println("stat_SSW:" + ssw);
                statw.println("stat_MMD:" + mmd);
                statw.println("stat_GMMD_a1:" + gmmd_a1);
                statw.println("stat_GMMD_a5:" + gmmd_a5);
                statw.println("stat_GMMD_a9:" + gmmd_a9);

                statw.close();
            }

00214 logger.info(LOGSRC, "Calculate cluster means ...");
00215 for (int cl=0; cl < num_clusters; cl++) {
00216
00217 if (clusters[cl] != null) {
00218
00219 String output = "cluster[" + cl + "] = (";
00220
00221 for (int att=0; att < num_attributes; att++) {
00222
00223 means[cl][att] = clusters[cl].meanOrMode(att);
00224 output += decf.format(means[cl][att]);
00225 if (att != num_attributes-1)
00226 output += ", ";
00227 }
00228
00229 output += ") -> " + clusters[cl].numInstances();
00230
00231 logger.info(LOGSRC, output);
00232 }
00233 }
00234 logger.info(LOGSRC, "calculations for cluster means done.");
00235
            if (centroidfile != null) {

                logger.info(LOGSRC, "Write centroids.");

                PrintWriter centroidsw = new PrintWriter(
                        new FileOutputStream(new File(centroidfile)));

                // Header: cluster count, then the attribute names.
                centroidsw.println("Clusters:" + num_clusters);
                for (int att = 0; att < num_attributes; att++) {
                    centroidsw.print(training_set.attribute(att).name());
                    if (att != num_attributes - 1)
                        centroidsw.print(",");
                }
                centroidsw.println(":instances");

                // One line per cluster; empty clusters are written as
                // all-zero centroids with an instance count of 0.
                for (int cl = 0; cl < num_clusters; cl++) {

                    if (clusters[cl] != null) {
                        for (int att = 0; att < num_attributes; att++) {
                            centroidsw.print(means[cl][att]);
                            if (att != num_attributes - 1)
                                centroidsw.print(",");
                        }
                        centroidsw.println(":" + clusters[cl].numInstances());
                    } else {
                        for (int att = 0; att < num_attributes; att++) {
                            centroidsw.print(0.0);
                            if (att != num_attributes - 1)
                                centroidsw.print(",");
                        }
                        centroidsw.println(":0");
                    }
                }

                centroidsw.close();
            }

00273 logger.info(LOGSRC, "Classify instances and replace them by cluster means ...");
00274
00275 result_set = new Instances(training_set, 0);
00276
00277 for (int i=0; i < training_set.numInstances(); i++) {
00278
00279 Instance ex = training_set.instance(i);
00280 Instance res = new Instance(ex);
00281 int cl = classify(ex);
00282
00283 for (int att=0; att < num_attributes; att++)
00284 res.setValue(att, means[cl][att]);
00285
00286 result_set.add(res);
00287 }
00288
00289 for (int cl=0; cl < num_clusters; cl++) {
00290 if (clusters[cl] != null)
00291 clusters[cl] = null;
00292 }
00293 clusters = null;
00294 training_set = null;
00295
00296 logger.info(LOGSRC, "all instances replaced.");
00297
00298 logger.info(LOGSRC, "Write classified instances to outfile '" + outfile + "'");
00299
00300 File output = new File(outfile);
00301 PrintWriter pw = new PrintWriter(new FileOutputStream(output));
00302
00303 pw.println("% ARFF file");
00304 pw.println("%");
00305 pw.println("@relation " + result_set.relationName());
00306 pw.println();
00307 for (int att=0; att < num_attributes; att++)
00308 pw.println("@attribute " + result_set.attribute(att).name() + " numeric");
00309 pw.println();
00310 pw.println("@data");
00311 pw.println("%");
00312 pw.println("% x instances");
00313 pw.println("%");
00314 pw.flush();
00315
00316 for (int i=0; i < result_set.numInstances(); i++) {
00317
00318 Instance res = result_set.instance(i);
00319 String line = "";
00320 for (int att=0; att < num_attributes; att++)
00321 line += res.value(att) + ",";
00322 line = line.substring(0, line.length()-1);
00323 pw.println(line);
00324 pw.flush();
00325
00326 System.out.flush();
00327 }
00328
00329 pw.close();
00330
00331 logger.info(LOGSRC, "classified instances written to '" + outfile + "'");
00332
00333 logger.info(LOGSRC, "Stopped.");
00334
00335 } catch (Exception e) {
00336 logger.error(LOGSRC, e.toString());
00337 throw new TaskException(e.toString());
00338 }
00339 }
00340
    // Life-cycle hooks required by Task; classification runs to
    // completion in start(), so these are no-ops.
    public void stop() {
    }

    public void pause() {
    }

    public void resume() {
    }

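    /**
     * Loads the cluster model from the given file. Implementations must
     * set {@link #num_clusters}, since {@link #start()} relies on it.
     *
     * @param filename the path of the model file
     * @throws TaskException if the model cannot be read
     */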
    protected abstract void loadModel(String filename) throws TaskException;

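    /**
     * Assigns the given instance to a cluster of the loaded model.
     *
     * @param i the instance to classify
     * @return the cluster index, in the range 0 to num_clusters - 1
     */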
    protected abstract int classify(Instance i);
}