/*
* $RCSfile: KnnClassifier.java,v $
* $Revision: 1.41 $
* $Date: 2008/01/08 14:33:15 $
* $Author: wojna $
*
* Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
*
* This file is part of Rseslib.
*
* Rseslib is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Rseslib is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package rseslib.processing.classification.parameterised.knn;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Properties;
import rseslib.processing.classification.ClassifierWithDistributedDecision;
import rseslib.processing.classification.parameterised.AbstractParameterisedClassifier;
import rseslib.processing.indexing.metric.TreeIndexer;
import rseslib.processing.metrics.MetricFactory;
import rseslib.processing.searching.metric.IndexingTreeVicinityProvider;
import rseslib.processing.searching.metric.VicinityProvider;
import rseslib.processing.transformation.AttributeTransformer;
import rseslib.processing.transformation.TableTransformer;
import rseslib.structure.attribute.NominalAttribute;
import rseslib.structure.data.DoubleData;
import rseslib.structure.index.metric.IndexingTreeNode;
import rseslib.structure.metric.Metric;
import rseslib.structure.metric.Neighbour;
import rseslib.structure.metric.AbstractWeightedMetric;
import rseslib.structure.table.ArrayListDoubleDataTable;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.PropertyConfigurationException;
import rseslib.system.progress.EmptyProgress;
import rseslib.system.progress.MultiProgress;
import rseslib.system.progress.Progress;
/**
* K nearest neighbours classifier.
* Consistency checking among nearest neighbours may be used.
* Two types of vote weighting is possible:
* uniform and inversely proportional to distance.
*
* @author Grzegorz Góra, Arkadiusz Wojna, Lukasz Ligowski
*/
public class KnnClassifier extends AbstractParameterisedClassifier implements ClassifierWithDistributedDecision, Serializable
{
/** Attribute weighting methods. */
private enum Voting { Equal, InverseDistance, InverseSquareDistance; }
/** Serialization version. */
private static final long serialVersionUID = 1L;
/** Property name for weighting method. */
public static final String WEIGHTING_METHOD_PROPERTY_NAME = "weightingMethod";
/** Name of property indicating whether the classifier learns the optimal number k. */
public static final String LEARN_OPTIMAL_K_PROPERTY_NAME = "learnOptimalK";
/** Name of property defining the maximal number of k while learning the optimal value. */
public static final String MAXIMAL_K_PROPERTY_NAME = "maxK";
/** Parameter name. */
public static final String K_PROPERTY_NAME = "k";
/** Name of property indicating whether consistency checking is considered. */
public static final String FILTER_NEIGHBOURS_PROPERTY_NAME = "filterNeighboursUsingRules";
/** Name of property indicating whether neighbour voting is weighted with distance. */
public static final String VOTING_PROPERTY_NAME = "voting";
/** Collection of the original training data objects. */
private ArrayList<DoubleData> m_OriginalData;
/** Data transoformer used in the induced metric. */
AttributeTransformer m_Transformer;
/** Transformed training data. */
DoubleDataTable m_TransformedTrainTable;
/** The induced metric. */
Metric m_Metric;
/** Provider of vicinity for test data objects. */
VicinityProvider m_VicinityProvider;
/** Filter for neigbours using cubes on objects and consistency. */
private CubeBasedNeighboursFilter m_NeighboursFilter;
/** Switch to recognize whether searching for optimal k is going on. */
private boolean m_bSelfLearning = false;
/** Maximal value k in parameterised classification. */
private int m_nMaxK;
/** Decision attribute. */
private NominalAttribute m_DecisionAttribute;
/** The default decision defined by the largest support in a training data set. */
private int m_nDefaultDec;
/**
* Constructor that induces a metric
* from a given training set trainTable
* and constructs an indexing tree.
* It transforms data objects inside the constructor.
*
* @param prop Properties of this knn clasifier.
* @param trainTable Table used to build vicinity provider and to learn the optimal value of the classifier parameter.
* @param prog Progress object to report training progress.
* @throws InterruptedException when the user interrupts the execution.
*/
public KnnClassifier(Properties prop, DoubleDataTable trainTable, Progress prog) throws PropertyConfigurationException, InterruptedException
{
super(prop, K_PROPERTY_NAME);
// prepare progress information
int[] progressVolumes = null;
if (getBoolProperty(LEARN_OPTIMAL_K_PROPERTY_NAME))
{
progressVolumes = new int[3];
progressVolumes[0] = 40;
progressVolumes[1] = 10;
progressVolumes[2] = 50;
}
else
{
progressVolumes = new int[2];
progressVolumes[0] = 80;
progressVolumes[1] = 20;
}
prog = new MultiProgress("Learning the k-nn classifier", prog, progressVolumes);
// induce a metric and transform training objects for optimization of distance computations
m_OriginalData = trainTable.getDataObjects();
m_Metric = MetricFactory.getMetric(getProperties(), trainTable);
m_Transformer = m_Metric.transformationOutside();
m_TransformedTrainTable = trainTable;
if (m_Transformer!=null)
m_TransformedTrainTable = TableTransformer.transform(trainTable, m_Transformer);
if (m_Metric instanceof AbstractWeightedMetric)
MetricFactory.adjustWeights(getProperty(WEIGHTING_METHOD_PROPERTY_NAME), (AbstractWeightedMetric)m_Metric, m_TransformedTrainTable, prog);
// index the training objects
IndexingTreeNode indexingTree = new TreeIndexer(null).indexing(m_TransformedTrainTable.getDataObjects(), m_Metric, prog);
m_VicinityProvider = new IndexingTreeVicinityProvider(null, m_Metric, indexingTree);
// store information required in classification
if (m_Metric instanceof AbstractWeightedMetric)
m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric)m_Metric);
m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
m_DecisionAttribute = trainTable.attributes().nominalDecisionAttribute();
m_nDefaultDec = 0;
for (int dec = 1; dec < trainTable.getDecisionDistribution().length; dec++)
if (trainTable.getDecisionDistribution()[dec] > trainTable.getDecisionDistribution()[m_nDefaultDec])
m_nDefaultDec = dec;
if (getBoolProperty(LEARN_OPTIMAL_K_PROPERTY_NAME))
{
m_bSelfLearning = true;
learnOptimalParameterValue(trainTable, prog);
m_bSelfLearning = false;
}
makePropertyModifiable(K_PROPERTY_NAME);
makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
makePropertyModifiable(VOTING_PROPERTY_NAME);
}
/**
* Constructor that builds an indexing tree.
* It uses the metric given as the parameter.
* It assumes that objects are transformed outside the classifier.
*
* @param prop Properties of this knn clasifier.
* @param metric Metric used in this classifier.
* @param trainTable Table used to build vicinity provider and to learn the optimal value of the classifier parameter.
* @param prog Progress object to report training progress.
* @throws InterruptedException when the user interrupts the execution.
*/
public KnnClassifier(Properties prop, Metric metric, DoubleDataTable trainTable, Progress prog) throws PropertyConfigurationException, InterruptedException
{
super(prop, K_PROPERTY_NAME);
IndexingTreeNode indexingTree = new TreeIndexer(null).indexing(trainTable.getDataObjects(), metric, prog);
m_VicinityProvider = new IndexingTreeVicinityProvider(null, metric, indexingTree);
m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
if (metric instanceof AbstractWeightedMetric)
m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric)metric);
m_DecisionAttribute = trainTable.attributes().nominalDecisionAttribute();
m_nDefaultDec = 0;
for (int dec = 1; dec < trainTable.getDecisionDistribution().length; dec++)
if (trainTable.getDecisionDistribution()[dec] > trainTable.getDecisionDistribution()[m_nDefaultDec]) m_nDefaultDec = dec;
makePropertyModifiable(K_PROPERTY_NAME);
makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
makePropertyModifiable(VOTING_PROPERTY_NAME);
}
/**
* Constructor.
* It assumes that objects are transformed outside the classifier.
*
* @param prop Map between property names and property values.
* @param decAttr Decision attribute.
* @param vicinProv Provider of vicninities for test data objects.
* @param decDistribution Distribution of decision in a training data set.
*/
public KnnClassifier(Properties prop, NominalAttribute decAttr, VicinityProvider vicinProv, CubeBasedNeighboursFilter neighbourFilter, int[] decDistribution) throws PropertyConfigurationException
{
super(prop, K_PROPERTY_NAME);
m_VicinityProvider = vicinProv;
m_nMaxK = getIntProperty(MAXIMAL_K_PROPERTY_NAME);
m_NeighboursFilter = neighbourFilter;
m_DecisionAttribute = decAttr;
m_nDefaultDec = 0;
for (int dec = 1; dec < decDistribution.length; dec++)
if (decDistribution[dec] > decDistribution[m_nDefaultDec]) m_nDefaultDec = dec;
makePropertyModifiable(K_PROPERTY_NAME);
makePropertyModifiable(FILTER_NEIGHBOURS_PROPERTY_NAME);
makePropertyModifiable(VOTING_PROPERTY_NAME);
}
/**
* Writes this object.
*
* @param out Output for writing.
* @throws IOException if an I/O error has occured.
*/
private void writeObject(ObjectOutputStream out) throws IOException
{
writeAbstractParameterisedClassifier(out);
out.writeObject(m_OriginalData);
out.writeObject(m_Transformer);
out.writeObject(m_Metric);
out.writeInt(m_nMaxK);
out.writeObject(m_DecisionAttribute);
out.writeInt(m_nDefaultDec);
}
/**
* Reads this object.
*
* @param out Output for writing.
* @throws IOException if an I/O error has occured.
*/
private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
{
readAbstractParameterisedClassifier(in);
m_OriginalData = (ArrayList<DoubleData>)in.readObject();
ArrayList<DoubleData> transformedObjects = m_OriginalData;
m_Transformer = (AttributeTransformer)in.readObject();
if (m_Transformer!=null)
{
transformedObjects = new ArrayList<DoubleData>(m_OriginalData.size());
for (DoubleData dObj : m_OriginalData)
transformedObjects.add(m_Transformer.transformToNew(dObj));
}
m_TransformedTrainTable = new ArrayListDoubleDataTable(transformedObjects);
m_Metric = (Metric)in.readObject();
try
{
IndexingTreeNode indexingTree = new TreeIndexer(null).indexing(m_TransformedTrainTable.getDataObjects(), m_Metric, new EmptyProgress());
m_VicinityProvider = new IndexingTreeVicinityProvider(null, m_Metric, indexingTree);
}
catch (InterruptedException e)
{
throw new NotSerializableException(e.getMessage());
}
catch (PropertyConfigurationException e)
{
throw new NotSerializableException(e.getMessage());
}
m_bSelfLearning = false;
m_nMaxK = in.readInt();
if (m_Metric instanceof AbstractWeightedMetric)
m_NeighboursFilter = new CubeBasedNeighboursFilter((AbstractWeightedMetric)m_Metric);
m_DecisionAttribute = (NominalAttribute)in.readObject();
m_nDefaultDec = in.readInt();
}
/**
* Sets the self-learning switch, required to set,
* if k optimization is done outside the classifier.
*
* @param selfLearning The value to be set.
*/
public void setSelfLearning(boolean selfLearning)
{
m_bSelfLearning = selfLearning;
}
/**
* Returns a decision distribution vector
* for a single test object.
* The weight of each decision value is given
* at the position of the vector
* identifed by the local code of this decision value.
*
* @param dObj Test object.
* @return Assigned decision distribution.
*/
public double[] classifyWithDistributedDecision(DoubleData dObj) throws PropertyConfigurationException
{
if (m_Transformer!=null) dObj = m_Transformer.transformToNew(dObj);
Neighbour[] neighbours = m_VicinityProvider.getVicinity(dObj, getIntProperty(K_PROPERTY_NAME));
boolean checkConsistency = getBoolProperty(FILTER_NEIGHBOURS_PROPERTY_NAME);
if (checkConsistency && m_NeighboursFilter!=null)
m_NeighboursFilter.markConsistency(dObj, neighbours);
double[] decDistr = new double[m_DecisionAttribute.noOfValues()];
Voting votingType;
try
{
votingType = Voting.valueOf(getProperty(VOTING_PROPERTY_NAME));
}
catch (IllegalArgumentException e)
{
throw new PropertyConfigurationException("Unknown voting method: "+getProperty(VOTING_PROPERTY_NAME));
}
for (int n = 1; n < neighbours.length; n++)
{
int curDec = m_DecisionAttribute.localValueCode(neighbours[n].neighbour().getDecision());
if (!checkConsistency || neighbours[n].m_bConsistent)
switch (votingType)
{
case Equal:
decDistr[curDec] += 1.0;
break;
case InverseDistance:
decDistr[curDec] += 1.0 / neighbours[n].dist();
break;
case InverseSquareDistance:
decDistr[curDec] += 1.0 / (neighbours[n].dist()*neighbours[n].dist());
break;
}
}
return decDistr;
}
/**
* Assigns a decision to a single test object.
*
* @param dObj Test object.
* @return Assigned decision.
*/
public double classify(DoubleData dObj) throws PropertyConfigurationException
{
double[] decDistr = classifyWithDistributedDecision(dObj);
int bestDec = 0;
for (int dec = 1; dec < decDistr.length; dec++)
if (decDistr[dec] > decDistr[bestDec]) bestDec = dec;
return m_DecisionAttribute.globalValueCode(bestDec);
}
/**
* Classifies a test object on the basis of nearest neighbours.
*
* @param dObj Test object.
* @return Array of assigned decisions, indices correspond to parameter values.
*/
public double[] classifyWithParameter(DoubleData dObj) throws PropertyConfigurationException
{
if (m_Transformer!=null) dObj = m_Transformer.transformToNew(dObj);
Neighbour[] neighbours = null;
if (m_bSelfLearning)
{
Neighbour[] neighboursOneMore = m_VicinityProvider.getVicinity(dObj, m_nMaxK+1);
neighbours = new Neighbour[neighboursOneMore.length-1];
int i = 1;
for (; i < neighbours.length && !dObj.equals(neighboursOneMore[i].neighbour()); i++)
neighbours[i] = neighboursOneMore[i];
for (; i < neighbours.length; i++) neighbours[i] = neighboursOneMore[i+1];
}
else neighbours = m_VicinityProvider.getVicinity(dObj, m_nMaxK);
boolean checkConsistency = getBoolProperty(FILTER_NEIGHBOURS_PROPERTY_NAME);
if (checkConsistency && m_NeighboursFilter!=null)
m_NeighboursFilter.markConsistency(dObj, neighbours);
double[] decisions = new double[m_nMaxK+1];
double[] decDistr = new double[m_DecisionAttribute.noOfValues()];
int bestDec = m_nDefaultDec;
decisions[0] = m_DecisionAttribute.globalValueCode(bestDec);
Voting votingType;
try
{
votingType = Voting.valueOf(getProperty(VOTING_PROPERTY_NAME));
}
catch (IllegalArgumentException e)
{
throw new PropertyConfigurationException("Unknown voting method: "+getProperty(VOTING_PROPERTY_NAME));
}
for (int n = 1; n < decisions.length; n++)
{
if (n < neighbours.length)
{
int curDec = m_DecisionAttribute.localValueCode(neighbours[n].neighbour().getDecision());
if (!checkConsistency || neighbours[n].m_bConsistent)
switch (votingType)
{
case Equal:
decDistr[curDec] += 1.0;
break;
case InverseDistance:
decDistr[curDec] += 1.0 / neighbours[n].dist();
break;
case InverseSquareDistance:
decDistr[curDec] += 1.0 / (neighbours[n].dist()*neighbours[n].dist());
break;
}
if (decDistr[curDec] > decDistr[bestDec]) bestDec = curDec;
}
decisions[n] = m_DecisionAttribute.globalValueCode(bestDec);
}
return decisions;
}
/**
* Calculates statistics.
*/
public void calculateStatistics()
{
try
{
if (getBoolProperty(LEARN_OPTIMAL_K_PROPERTY_NAME))
addToStatistics("Optimal "+K_PROPERTY_NAME, getProperty(K_PROPERTY_NAME));
}
catch (PropertyConfigurationException e)
{
}
//addToStatistics("Average number of distance calculations", Double.toString(m_VicinityProvider.getAverageNoOfDistCalculations()));
//addToStatistics("Std. dev. of the number of distance calculations", Double.toString(m_VicinityProvider.getStdDevNoOfDistCalculations()));
}
/**
* Resets statistics.
*/
public void resetStatistics()
{
}
}