Repository /Rseslib/rseslib-3.0.2.jar:rseslib.processing.classification.ClassifierSet


Back

No file description

Source code

/*
 * $RCSfile: ClassifierSet.java,v $
 * $Revision: 1.8 $
 * $Date: 2007/06/30 17:30:33 $
 * $Author: wojna $
 * 
 * Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
 * 
 *  This file is part of Rseslib.
 *
 *  Rseslib is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Rseslib is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


package rseslib.processing.classification;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;

import rseslib.structure.attribute.BadHeaderException;
import rseslib.structure.attribute.NominalAttribute;
import rseslib.structure.data.DoubleData;
import rseslib.structure.data.DoubleDataWithDecision;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.ConfigurationWithStatistics;
import rseslib.system.PropertyConfigurationException;
import rseslib.system.Report;
import rseslib.system.progress.MultiProgress;
import rseslib.system.progress.Progress;

/**
 * A named set of classifiers that are trained and tested together.
 * Classifiers can be added either as a class with properties, in which case
 * they are constructed and trained by {@link #train(DoubleDataTable, Progress)},
 * or as ready-made classifier instances. It can be used by different testing methods.
 *
 * @author Arkadiusz Wojna
 *
 */
public class ClassifierSet implements Serializable
{
    /** Serialization version. */
    private static final long serialVersionUID = 1L;

    /** Map between classifier names and their classes. */
    private Map<String,Class> m_ClassifierTypes = new HashMap<String,Class>();
    /** Map between classifier names and their properties (a null value means default properties). */
    private Map<String,Properties> m_ClassifierProperties = new HashMap<String,Properties>();
    /**
     * Map between classifier names and classifiers ready for classification.
     * It is not written by serialization and is empty after deserialization.
     */
    private Map<String,Classifier> m_Classifiers = new HashMap<String,Classifier>();

    /**
     * Adds a classifier type to this set of classifiers.
     * The classifier itself is constructed and trained later,
     * by {@link #train(DoubleDataTable, Progress)}.
     *
     * @param name              Name of a classifier to be added; must be unique within this set.
     * @param classifierType    Class of a classifier to be added.
     * @param prop              Properties of a classifier to be added.
     *                          If null, default properties are loaded.
     * @throws PropertyConfigurationException not thrown by the current implementation;
     *                          kept in the signature for callers that handle it.
     * @throws IllegalArgumentException if a classifier with this name has already been added.
     */
    public void addClassifier(String name, Class classifierType, Properties prop) throws PropertyConfigurationException
    {
        if (isNameUsed(name))
            throw new IllegalArgumentException("A classifier named "+name+" already added.");
        m_ClassifierTypes.put(name, classifierType);
        m_ClassifierProperties.put(name, prop);
    }

    /**
     * Adds a classifier type with default properties to this set of classifiers.
     *
     * @param name              Name of a classifier to be added; must be unique within this set.
     * @param classifierType    Class of a classifier to be added.
     * @throws PropertyConfigurationException not thrown by the current implementation;
     *                          kept in the signature for callers that handle it.
     * @throws IllegalArgumentException if a classifier with this name has already been added.
     */
    public void addClassifier(String name, Class classifierType) throws PropertyConfigurationException
    {
        addClassifier(name, classifierType, null);
    }

    /**
     * Adds a ready-made classifier to this set of classifiers.
     * Such a classifier is not (re)trained by {@link #train(DoubleDataTable, Progress)}
     * and is not written when this set is serialized.
     *
     * @param name Name of a classifier to be added; must be unique within this set.
     * @param cl   Classifier to be added.
     * @throws IllegalArgumentException if a classifier with this name has already been added.
     */
    public void addClassifier(String name, Classifier cl)
    {
        if (isNameUsed(name))
            throw new IllegalArgumentException("A classifier named "+name+" already added.");
        m_Classifiers.put(name, cl);
    }

    /**
     * Constructs and trains every classifier added by type, replacing classifiers
     * previously trained under the same names. A classifier whose construction
     * or training fails is reported and is left out of the set of trained classifiers.
     *
     * @param trainTable Training data set.
     * @param prog       Progress object to report training progress.
     * @throws InterruptedException when the user interrupts the execution.
     */
    public void train(DoubleDataTable trainTable, Progress prog) throws InterruptedException
    {
        if (m_ClassifierTypes.isEmpty())
        {
            prog.set("Training classifiers", 1);
            prog.step();
            return;
        }
        Progress trainingProgress = new MultiProgress("Training classifiers", prog, equalProgressVolumes(m_ClassifierTypes.size()));
        for (Map.Entry<String,Class> cl : m_ClassifierTypes.entrySet())
        {
            String name = cl.getKey();
            // a failed (re)training must not leave a stale classifier under this name
            m_Classifiers.remove(name);
            try
            {
                Properties prop = m_ClassifierProperties.get(name);
                Classifier classifier = ClassifierFactory.createClassifier(cl.getValue(), prop, trainTable, trainingProgress);
                m_Classifiers.put(name, classifier);
            }
            catch (InvocationTargetException e)
            {
                reportTrainingFailure(name, e);
            }
            catch (Exception e)
            {
                // the user's interruption must reach the caller instead of being reported and ignored
                if (e instanceof InterruptedException)
                    throw (InterruptedException)e;
                Report.exception(e);
            }
        }
    }

    /**
     * Classifies a test data set.
     *
     * @param tstTable  Test data set.
     * @param prog      Progress object for classification process.
     * @return          Map of entries: name of a classifier
     *                  and the object TestResult with a classification result.
     * @throws InterruptedException when the user interrupts the execution.
     */
    public Map<String,TestResult> classify(DoubleDataTable tstTable, Progress prog) throws InterruptedException
    {
        // classification of the test table
        if (tstTable.noOfObjects()<=0) throw new RuntimeException("Classification of an empty table");
        NominalAttribute decAttr = tstTable.attributes().nominalDecisionAttribute();
        int noOfDecisions = decAttr.noOfValues();
        // one confusion matrix per classifier: rows = actual decision, columns = predicted decision
        Map<String,int[][]> mapOfConfusionMatrices = new HashMap<String,int[][]>();
        for (String name : m_Classifiers.keySet())
            mapOfConfusionMatrices.put(name, new int[noOfDecisions][noOfDecisions]);
        prog.set("Classifing test table", tstTable.noOfObjects());
        for (DoubleData dObj : tstTable.getDataObjects())
        {
            int objDecLocalCode = decAttr.localValueCode(((DoubleDataWithDecision)dObj).getDecision());
            for (Map.Entry<String,Classifier> cl : m_Classifiers.entrySet())
            {
                int[][] confusionMatrix = mapOfConfusionMatrices.get(cl.getKey());
                try
                {
                    double dec = cl.getValue().classify(dObj);
                    // a NaN decision is not counted in the confusion matrix
                    if (!Double.isNaN(dec))
                        confusionMatrix[objDecLocalCode][decAttr.localValueCode(dec)]++;
                }
                catch (RuntimeException e)
                {
                    Report.exception(e);
                }
                catch (PropertyConfigurationException e)
                {
                    Report.exception(e);
                }
            }
            prog.step();
        }
        // preparation of the classification results
        Map<String,TestResult> resultMap = new HashMap<String,TestResult>();
        for (Map.Entry<String,Classifier> cl : m_Classifiers.entrySet())
        {
            Classifier classifier = cl.getValue();
            classifier.calculateStatistics();
            // NOTE(review): assumes every classifier in the set is a ConfigurationWithStatistics
            TestResult results = new TestResult(decAttr, tstTable.getDecisionDistribution(),
                    mapOfConfusionMatrices.get(cl.getKey()), ((ConfigurationWithStatistics)classifier).getStatistics());
            resultMap.put(cl.getKey(), results);
        }
        return resultMap;
    }

    /**
     * Checks whether a classifier (by type or by instance) is already registered under a name.
     *
     * @param name  Name to check.
     * @return      True if the name is already used in this set.
     */
    private boolean isNameUsed(String name)
    {
        return m_ClassifierTypes.containsKey(name) || m_Classifiers.containsKey(name);
    }

    /**
     * Splits 100 progress units into n nearly equal integer parts summing to exactly 100.
     * Part i covers the range from 100*i/n to 100*(i+1)/n, so the parts telescope to 100.
     *
     * @param n Number of parts, at least 1.
     * @return  Array of n progress volumes.
     */
    private static int[] equalProgressVolumes(int n)
    {
        int[] volumes = new int[n];
        for (int i = 0; i < n; i++)
            volumes[i] = 100*(i+1)/n - 100*i/n;
        return volumes;
    }

    /**
     * Handles an exception wrapped in an InvocationTargetException by the reflective
     * construction of a classifier during training.
     *
     * @param name  Name of the classifier that failed.
     * @param e     The wrapping exception.
     * @throws InterruptedException if the wrapped exception is an interruption.
     */
    private static void reportTrainingFailure(String name, InvocationTargetException e) throws InterruptedException
    {
        Throwable cause = e.getTargetException();
        if (cause instanceof BadHeaderException)
            Report.displaynl(name+" not trained: "+cause.getMessage());
        else if (cause instanceof InterruptedException)
            throw (InterruptedException)cause;
        else if (cause instanceof Exception)
            Report.exception((Exception)cause);
        else if (cause instanceof Error)
            // errors are propagated as they are instead of failing on a cast to Exception
            throw (Error)cause;
        else
            Report.exception(e);
    }

    /**
     * Writes this object. Only the classifier types and their properties are written;
     * classifiers trained or added as instances are not.
     *
     * @param out           Output for writing.
     * @throws IOException  if an I/O error has occured.
     */
    private void writeObject(ObjectOutputStream out) throws IOException
    {
        out.writeObject(m_ClassifierTypes);
        out.writeObject(m_ClassifierProperties);
    }

    /**
     * Reads this object. The classifier types and their properties are restored;
     * the set of ready classifiers starts empty, so {@link #train(DoubleDataTable, Progress)}
     * must be called before classification.
     *
     * @param in            Input for reading.
     * @throws IOException  if an I/O error has occured.
     * @throws ClassNotFoundException if a class of a serialized object cannot be found.
     */
    @SuppressWarnings("unchecked")
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
    {
        m_ClassifierTypes = (Map<String,Class>)in.readObject();
        m_ClassifierProperties = (Map<String,Properties>)in.readObject();
        m_Classifiers = new HashMap<String,Classifier>();
    }
}

Copyright © 2008-2011 by TunedIT
Design by luksite