Repository /Rseslib/rseslib-3.0.2.jar:rseslib.processing.classification.tree.c45.DecisionTreeNode


Back

No file description

Source code

/*
 * $RCSfile: DecisionTreeNode.java,v $
 * $Revision: 1.19 $
 * $Date: 2007/12/31 12:53:57 $
 * $Author: wojna $
 * 
 * Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
 * 
 *  This file is part of Rseslib.
 *
 *  Rseslib is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Rseslib is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


package rseslib.processing.classification.tree.c45;

import java.io.*;
import java.util.ArrayList;
import java.util.Collection;

import rseslib.structure.attribute.*;
import rseslib.structure.data.DoubleData;
import rseslib.structure.data.DoubleDataWithDecision;
import rseslib.structure.function.intval.Discrimination;
import rseslib.structure.vector.Vector;
import rseslib.system.Report;
import rseslib.processing.classification.tree.c45.DiscriminationProvider;

/**
 * The node in the decision tree.
 *
 * @author      Arkadiusz Wojna
 */
public class DecisionTreeNode implements Serializable
{
    /** Serialization version. */
	private static final long serialVersionUID = 1L;
	
	/** Information about the decision. */
    NominalAttribute m_DecisionAttribute = null;
    /** Statistics of the decision distribution for this node. */
    Vector m_DecisionVector = null;
    /** Label of the decision for this node. */
    double m_Decision;
	/** Childer (sub-nodes) of this node. If node is a leaf than null. */
	DecisionTreeNode m_Children[] = null;
    /** Determines to which sub-node the object should be sent. */
    Discrimination m_BranchSelector = null;

    public DecisionTreeNode(Collection<DoubleData> dataObjects, Header hdr, DiscriminationProvider discrProv, double defaultDec)
    {
        m_DecisionAttribute = hdr.nominalDecisionAttribute();
        m_DecisionVector = new Vector(m_DecisionAttribute.noOfValues());
        for (DoubleData dObj : dataObjects)
            m_DecisionVector.increment(m_DecisionAttribute.localValueCode(((DoubleDataWithDecision)dObj).getDecision()));
        int best = 0;
        int different = 0;
        for (int dec = 0; dec < m_DecisionVector.dimension(); dec++)
        {
            if (m_DecisionVector.get(dec)>m_DecisionVector.get(best)) best = dec;
            if (m_DecisionVector.get(dec)>0) different++;
        }
        if (different==0) m_Decision = defaultDec;
        else m_Decision = m_DecisionAttribute.globalValueCode(best);
        if (different > 1) m_BranchSelector = discrProv.getDiscrimination(dataObjects, hdr);
        if (m_BranchSelector!=null)
        {
            ArrayList<DoubleData>[] childrenSets = new ArrayList[m_BranchSelector.noOfValues()];
            for (int child = 0; child < childrenSets.length; child++)
                childrenSets[child] = new ArrayList<DoubleData>();
            for (DoubleData dObj : dataObjects)
            {
                int branch = m_BranchSelector.intValue(dObj);
                if (branch>=0) childrenSets[branch].add(dObj);
            }
            m_Children = new DecisionTreeNode[childrenSets.length];
            for (int child = 0; child < m_Children.length; child++)
                m_Children[child] = new DecisionTreeNode(childrenSets[child], hdr, discrProv, m_Decision);
        }
    }

    /**
     * Writes this object.
     *
     * @param out			Output for writing.
     * @throws IOException	if an I/O error has occured.
     */
    private void writeObject(ObjectOutputStream out) throws IOException
    {
    	out.writeObject(m_BranchSelector);
    	out.writeObject(m_DecisionAttribute);
    	out.writeObject(m_DecisionVector);
    	out.writeObject(NominalAttribute.stringValue(m_Decision));
    	out.writeObject(m_Children);
    }

    /**
     * Reads this object.
     *
     * @param out			Output for writing.
     * @throws IOException	if an I/O error has occured.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
    {
    	m_BranchSelector = (Discrimination)in.readObject();
    	m_DecisionAttribute = (NominalAttribute)in.readObject();
    	m_DecisionVector = (Vector)in.readObject();
    	m_Decision = m_DecisionAttribute.globalValueCode((String)in.readObject());
    	m_Children = (DecisionTreeNode[])in.readObject();
    }

    /**
     * Prunes this subtree and returns the number
     * of correctly classified objects from a given set.
     *
     * @param dataObject	Objects used to validate and prune this subtree.
     * @return    			Number of correctly classified objects from dataObjects. 
     */
    public int prune(Collection<DoubleData> dataObjects)
    {
    	int correct = 0;
    	if (isLeaf())
    	{
    		for (DoubleData dObj : dataObjects)
    			if (((DoubleDataWithDecision)dObj).getDecision()==m_Decision)
    				correct++;
    	}
    	else
    	{
	        ArrayList<DoubleData> childrenSets[] = new ArrayList[m_Children.length];
	        for (int child = 0; child < childrenSets.length; child++)
	        	childrenSets[child] = new ArrayList<DoubleData>();
	        for (DoubleData dObj : dataObjects)
	        {
	            int branch = m_BranchSelector.intValue(dObj);
	            if (branch>=0) childrenSets[branch].add(dObj);
	        }
	    	int leafsCorrect = 0;	
	        for (int child = 0; child < m_Children.length; child++)
	        	leafsCorrect += m_Children[child].prune(childrenSets[child]);
	        for (DecisionTreeNode child : m_Children)
	            if (!child.isLeaf()) return leafsCorrect;
    		for (DoubleData dObj : dataObjects)
    			if (((DoubleDataWithDecision)dObj).getDecision()==m_Decision)
    				correct++;
    		if (correct >= leafsCorrect) cutTree();
    		else correct = leafsCorrect;
    	}
    	return correct;
    }

    /**
     * Changes the decision in this node.
     * 
     * @param newDec New decision to be set.
     */
    public void changeDecision(double newDec)
    {
    	m_Decision = newDec;
    }
    
    /**
     * Returns classification for object <code>obj</code>.
     * Classification is represented by decision distribution.
     * @param obj object for which classification is done.
     * @return classification for object <code>obj</code>.
     */
    public double classify(DoubleData obj)
    {
        // If the node is a leaf than its decision distribution is returned
        if (isLeaf()) return m_Decision;
        else
        {
            int val = m_BranchSelector.intValue(obj);
            /* If the test determines one node
             * then its decision distribution is returned */
            if (val>=0) return m_Children[val].classify(obj);
            /* If the test can not determine one node
             * than sum of all decision distributions is returned */
            else
            {
                Vector decVec = new Vector(m_Children[0].classifyWithDecDistribution(obj));
                for (int i=1; i<m_Children.length; i++)
                decVec.add(m_Children[i].classifyWithDecDistribution(obj));
                int bestDec = 0;
                for (int dec = 1; dec < decVec.dimension(); dec++)
                    if (decVec.get(dec) > decVec.get(bestDec)) bestDec = dec;
                return m_DecisionAttribute.globalValueCode(bestDec);
            }
        }
    }

    /**
     * Returns classification for object <code>obj</code>.
     * Classification is represented by decision distribution.
     * @param obj object for which classification is done.
     * @return classification for object <code>obj</code>.
     */
    public Vector classifyWithDecDistribution(DoubleData obj)
    {
        // If the node is a leaf than its decision distribution is returned
        if (isLeaf()) return m_DecisionVector;
        else
        {
            int val = m_BranchSelector.intValue(obj);
            /* If the test determines one node
             * then its decision distribution is returned */
            if (val>=0)
            {
            	Vector res = m_Children[val].classifyWithDecDistribution(obj);
            	for (int i=0; i<res.dimension(); i++)
            		if (res.get(i)!=0) return res;
            	return m_DecisionVector;
            }
            /* If the test can not determine one node
             * than sum of all decision distributions is returned */
            else
            {
                Vector decVec = new Vector(m_Children[0].classifyWithDecDistribution(obj));
                for (int i=1; i<m_Children.length; i++)
                decVec.add(m_Children[i].classifyWithDecDistribution(obj));
                return decVec;
            }
        }
    }

    /**
     * Returns true if this node has no childern.
     * @return true if this node has no childern.
     */
    public boolean isLeaf()
    {
        return (m_Children==null) || (m_Children.length==0);
    }

    /**
     * Returns depth of the tree beginning from this node.
     * @return depth of the tree beginning from this node.
     */
    public int getDepth()
    {
        if (isLeaf()) return 0;
        else
        {
            int max=0;
            for (int i=0;i<m_Children.length;i++)
            {
                int depth=m_Children[i].getDepth();
                if (depth>max) max=depth;
            }
            return 1+max;
        }
    }

    /**
     * Returns number of leaves in subtree starting from this node.
     * @return number of leaves in subtree starting from this node.
     */
    public int getLeaves()
    {
        if (isLeaf()) return 1;
        else
        {
            int sum=0;
            for (int i=0;i<m_Children.length;i++)
            sum+=m_Children[i].getLeaves();
            return sum;
        }
    }

    /**
     * Returns the index of the branch matching the object.
     *
     * @param obj Object to be tested.
     * @return    The index of the branch matching the object.
     */
    public int branch(DoubleData obj)
    {
    	return m_BranchSelector.intValue(obj);
    }

    /**
     * Returns the number of branches.
     *
     * @return    The number of branches.
     */
    public int noOfBranches()
    {
    	return m_Children.length;
    }

    /**
     * Returns the subnode from a given branch.
     *
     * @param branchIndex The index of a branch.
     * @return    		  The subnode from the branch.
     */
    public DecisionTreeNode subnode(int branchIndex)
    {
      return m_Children[branchIndex];
    }

    /**
     * Cuts the subtree at this node and converts it to a leaf node.
     */
    public void cutTree()
    {
    	m_Children = null;
    	m_BranchSelector = null;
    }

    /**
     * Returns the decision distribution in this node.
     *
     * @return The decision distribution in this node.
     */
    public Vector getDecisionVector()
    {
        return m_DecisionVector;
    }

    /**
     * Returns the decision for this node.
     *
     * @return The decision for this node.
     */
    public String getDecisionLabel()
    {
        return NominalAttribute.stringValue(m_Decision);
    }

    /**
     * Outputs a tree at a given level.
     *
     * @param level The level at which the tree is to be printed.
     * @return      The description of a tree at a given level.
     */
    public String toString(int level)
    {
        StringBuffer text = new StringBuffer();
        if (isLeaf())
        {
            int bestDec = 0;
            for (int dec = 1; dec < m_DecisionVector.dimension(); dec++)
                if (m_DecisionVector.get(dec) > m_DecisionVector.get(bestDec))
                    bestDec = dec;
            text.append(": "+NominalAttribute.stringValue(m_DecisionAttribute.globalValueCode(bestDec)));
        }
        else
        {
            for (int child = 0; child < m_Children.length; child++)
            {
                text.append(Report.lineSeparator);
                for (int i = 0; i < level; i++) text.append("|  ");
                text.append(m_BranchSelector.toString(child));
                text.append(m_Children[child].toString(level + 1));
            }
      }
      return text.toString();
    }
}

Copyright © 2008-2011 by TunedIT
Design by luksite