Repository /Rseslib/rseslib-3.0.1.jar:rseslib.processing.classification.parameterised.knn.KnnVis


Back

No file description

Source code

/*
 * $RCSfile: KnnVis.java,v $
 * $Revision: 1.3 $
 * $Date: 2007/08/18 10:44:06 $
 * $Author: wojna $
 * 
 * Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
 * 
 *  This file is part of Rseslib.
 *
 *  Rseslib is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Rseslib is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


package rseslib.processing.classification.parameterised.knn;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.event.MouseMotionListener;
import java.util.Hashtable;
import java.util.Properties;
import java.util.Random;
import java.util.Set;

import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JSplitPane;
import javax.swing.JTextField;

import rseslib.processing.classification.VisualClassifier;
import rseslib.structure.attribute.Attribute;
import rseslib.structure.attribute.Header;
import rseslib.structure.attribute.NominalAttribute;
import rseslib.structure.data.DoubleData;
import rseslib.structure.metric.Neighbour;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.PropertyConfigurationException;
import rseslib.system.progress.Progress;

public class KnnVis extends KnnClassifier implements VisualClassifier
{
    /** Serialization version. */
	private static final long serialVersionUID = 1L;

	/** Current 2D position assigned to every data object shown on the graph. */
	private Hashtable<DoubleData, DPoint> placement = new Hashtable<DoubleData, DPoint>();
    /** Random source used for initial positions and for colors beyond the predefined palette. */
    private Random rnd = new Random(System.currentTimeMillis());
    /** Canvas passed to the last effective call of draw(); used to skip rebuilding the UI for the same canvas. */
    private JPanel pnl;
    /** Component that renders the points and handles mouse picking. */
    private Painter painter;
    /** Average metric distance over all pairs of transformed training objects (computed in the constructor). */
    private double avg;
	// Bounding box of the current placement; recomputed at the end of every findPlacement() pass.
	private double xmin = Double.NEGATIVE_INFINITY;
	private double ymin = Double.NEGATIVE_INFINITY;
	private double xmax = Double.POSITIVE_INFINITY;
	private double ymax = Double.POSITIVE_INFINITY;
	/** Picking radius in pixels: findObject() compares squared pixel distances against FIND_THRES squared. */
	private int FIND_THRES = 10;
	/** Diameter in pixels of a drawn point; also used as the drawing margin. */
	private int POINT_SIZE = 5;
	/** Initial value of the step multiplier 'mult'. */
	private final double START_MULT = 0.02;
	/** Upper bound of a single per-axis move in one iteration, as a fraction of 'avg'. */
	private final double MAX_JUMP = 0.1;
	// private final double DECAY_MULT = 0.99; // for the squared-error weighting
	private final double DECAY_MULT = 0.995; // for the linear weighting
	/** Lower bound below which 'mult' is not decayed any further. */
	private final double DECAY_MIN = 0.01;
	/** Threshold under which vector lengths and weights are treated as zero. */
	private final double EPSILON = 0.00000000000001;
	/** Cap applied in draw() to a local row count; NOTE(review): that local is not used afterwards. */
	private final int START_MAX_ROWS = 150;
    /** Current step multiplier; multiplied by DECAY_MULT at the start of every findPlacement() pass. */
    double mult = START_MULT;
    /** HTML fragment listing decision values in their colors; appended to the info label. */
    private String strLegend = "";
    
    // NOTE(review): currDev is not read or written anywhere else in this file.
    double currDev = 0.00;
    /** Number of completed findPlacement() passes since the last reset. */
    double iter = 0;
    /** Progress of the current pass (0..1); values above 0 are painted as a thin bar. */
    double fProg = 0;  
	/** Color assigned to every decision value; paint() uses the two lowest bytes as red and green. */
	private Hashtable<Double, Integer> htCols = new Hashtable<Double, Integer>();
	/** Predefined colors handed out to the first decision values encountered. */
	private int[] startcolors = new int[] { 0, 255, 255*256, 128, 128*256, 128*256+128, 128*256+255, 255*256+128 };
	/** Background thread running findPlacement() in a loop, or null when stopped. */
	private Thread calcThread;
	/** Training table as passed to the constructor; used to look up original objects by position. */
	private DoubleDataTable table;
	/** Label showing iteration info, object details and the legend. */
	private JLabel lblInfo;
	/** Whether attribute details of the selected/hovered object are shown in the info label. */
	private boolean showDetails = true;
	/** Start/Stop toggle button for the background thread. */
	private JButton btnRun;
	
	/**
	 * Builds the classifier through the superclass constructor and prepares
	 * the data needed by the visualisation: a color for every decision value,
	 * the average pairwise metric distance, the HTML legend and an initial
	 * random placement of all training objects.
	 *
	 * @param prop       forwarded to the superclass constructor
	 * @param trainTable training table; also kept to look up original objects
	 * @param prog       forwarded to the superclass constructor
	 * @throws PropertyConfigurationException declared by the superclass constructor
	 * @throws InterruptedException           declared by the superclass constructor
	 */
	public KnnVis(Properties prop, DoubleDataTable trainTable, Progress prog) throws PropertyConfigurationException, InterruptedException
	{
		super(prop, trainTable, prog);

		this.table = trainTable;
		int cnt = m_TransformedTrainTable.noOfObjects();
		avg = 0.00;
		int dec = -1;

		for (DoubleData v1 : m_TransformedTrainTable.getDataObjects())
		{
			if (dec == -1) dec = v1.attributes().decision();
			if (!htCols.containsKey(v1.get(dec)))
			{
				int p = htCols.size();
				if (p < startcolors.length)
					htCols.put(v1.get(dec), startcolors[p]);
				else
					// Only the two lowest bytes are ever painted (red, green), so
					// draw the overflow colors from that non-negative 16-bit range.
					// A raw nextInt() may be negative, which broke the legend's
					// hex conversion and made it disagree with the painted color.
					htCols.put(v1.get(dec), rnd.nextInt(0x10000));
			}

			double partsum = 0.00;
			for (DoubleData v2 : m_TransformedTrainTable.getDataObjects())
			{
				partsum += m_Metric.dist(v1, v2);
			}
			avg += partsum / cnt;
		}
		// Guard the empty table: 0/0 would turn avg into NaN.
		if (cnt > 0) avg /= cnt;

		StringBuilder legend = new StringBuilder("<br><b>Decisions</b>:<br>");
		for (Double key : htCols.keySet())
		{
			String name = NominalAttribute.stringValue(key);
			int color = htCols.get(key);
			// Masking keeps every component inside 0..255.
			int cr = color & 0xFF;
			int cg = (color >> 8) & 0xFF;
			int cb = (color >> 16) & 0xFF;
			String hexColor = toHex(cr) + toHex(cg) + toHex(cb);
			legend.append("<font color=#").append(hexColor).append(">").append(name).append("</font><br>");
		}
		strLegend = legend.toString();

		findRandomPlacement(Integer.MAX_VALUE);
	}

	/**
	 * Formats the lowest byte of a value as two upper-case hexadecimal digits.
	 * Values outside 0..255 (including negative ones) are reduced to their
	 * lowest eight bits instead of causing an index error.
	 *
	 * @param val color component; only its lowest eight bits are used
	 * @return two-character hexadecimal string, e.g. "0F"
	 */
	private String toHex(int val)
	{
		final String digits = "0123456789ABCDEF";
		int octet = val & 0xFF;
		return "" + digits.charAt(octet >> 4) + digits.charAt(octet & 0x0F);
	}
	
	/**
	 * Resets the layout state and assigns a fresh random position to the
	 * transformed training objects. When the table holds at least
	 * {@code maxcnt} objects, each object is kept only if a random draw from
	 * {@code [0, total)} falls below {@code maxcnt}.
	 *
	 * @param maxcnt limit used by the random selection of objects
	 */
	private void findRandomPlacement(int maxcnt)
	{
		mult = START_MULT;
		iter = 0;
		placement.clear();

		final int total = m_TransformedTrainTable.noOfObjects();
		for (DoubleData candidate : m_TransformedTrainTable.getDataObjects())
		{
			boolean keep = total < maxcnt || rnd.nextInt(total) < maxcnt;
			if (!keep)
			{
				continue;
			}
			placement.put(candidate, new DPoint(avg));
		}
	}
	
	/**
	 * Maps an object of the transformed training table back to the object at
	 * the same position in the table given to the constructor.
	 *
	 * @param dat object to look up (compared by identity)
	 * @return object at the matching position of the original table, or
	 *         {@code dat} itself when it is not found or the position does
	 *         not exist in the original table
	 */
	private DoubleData findOriginal(DoubleData dat)
	{
		// Locate the position of dat among the transformed objects.
		int position = -1;
		int scanned = 0;
		for (DoubleData candidate : m_TransformedTrainTable.getDataObjects())
		{
			if (candidate == dat)
			{
				position = scanned;
				break;
			}
			scanned++;
		}
		if (position < 0)
		{
			return dat;
		}

		// Walk the original table up to the same position.
		int index = 0;
		for (DoubleData original : table.getDataObjects())
		{
			if (index == position)
			{
				return original;
			}
			index++;
		}
		return dat;
	}
	
	/**
	 * Renders the attribute values of an object as HTML lines of the form
	 * {@code name: <i>value</i><br>}. The object is first mapped back with
	 * findOriginal(); non-nominal values are cut to five characters.
	 *
	 * @param dat object to describe
	 * @return HTML fragment with one line per attribute
	 */
	private String formatData(DoubleData dat)
	{
		final DoubleData source = findOriginal(dat);
		final int attrCount = source.attributes().noOfAttr();

		StringBuilder html = new StringBuilder();
		for (int idx = 0; idx < attrCount; idx++)
		{
			Attribute attribute = source.attributes().attribute(idx);
			String text;
			if (!attribute.isNominal())
			{
				text = "" + source.get(idx);
				if (text.length() > 5)
				{
					text = text.substring(0, 5);
				}
			}
			else
			{
				text = NominalAttribute.stringValue(source.get(idx));
			}
			html.append(attribute.name()).append(": <i>").append(text).append("</i><br>");
		}
		return html.toString();
	}

	/**
	 * Builds the visualisation inside the given canvas: a split pane with the
	 * info label on the left and the point painter on the right, plus a row
	 * of controls at the bottom. Finally (re)starts the background layout
	 * thread. Calling it again with the canvas used last time does nothing.
	 *
	 * @param canvas panel that receives the visualisation components
	 */
	public void draw(JPanel canvas)
	{
		if (canvas.equals(pnl)) return;

		pnl = canvas;
		painter = new Painter();
		JScrollPane scroll = new JScrollPane(painter);
		scroll.setVisible(true);

		JSplitPane jsp = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT);
		canvas.add(jsp);
		jsp.setRightComponent(scroll);
		jsp.setDividerLocation(-1);

		final JTextField jtMax = new JTextField("" + m_TransformedTrainTable.noOfObjects());
		jtMax.setPreferredSize(new Dimension(100, 24));
		JButton btnRestart = new JButton("Restart");
		btnRestart.addActionListener(new ActionListener()
		{
			public void actionPerformed(ActionEvent e)
			{
				int maxcnt = Integer.MAX_VALUE;
				try
				{
					maxcnt = Integer.parseInt(jtMax.getText().trim());
				}
				catch (NumberFormatException ignored)
				{
					// Unparsable input deliberately falls back to "no limit".
				}
				jtMax.setText("" + maxcnt);

				// Stop the running worker first: otherwise the placement is
				// replaced underneath it and a second worker is started on top
				// of the first one, which then can no longer be stopped.
				if (calcThread != null) stopThread();
				findRandomPlacement(maxcnt);

				painter.classified = null;
				// Forget marked objects that were not sampled into the new placement.
				if (painter.selected != null && !placement.containsKey(painter.selected)) painter.selected = null;
				if (painter.hovered != null && !placement.containsKey(painter.hovered)) painter.hovered = null;
				painter.repaint();
				startThread();
			}
		});

		JButton btnClassify = new JButton("Classify selected");
		btnClassify.addActionListener(new ActionListener()
		{
			public void actionPerformed(ActionEvent e)
			{
				if (painter.selected != null) drawClassify(pnl, findOriginal(painter.selected));
			}
		});

		btnRun = new JButton("Start");
		btnRun.addActionListener(new ActionListener()
		{
			public void actionPerformed(ActionEvent e)
			{
				// Toggle the background layout thread.
				if (calcThread == null)
				{
					startThread();
				}
				else
				{
					stopThread();
				}
			}
		});

		final JCheckBox jcbDetails = new JCheckBox("Show object details");
		jcbDetails.setSelected(showDetails);
		jcbDetails.addActionListener(new ActionListener()
		{
			public void actionPerformed(ActionEvent arg0)
			{
				showDetails = jcbDetails.isSelected();
				painter.invalidate();
				painter.repaint();
			}
		});

		JPanel pnlBtns = new JPanel();
		pnlBtns.setLayout(new FlowLayout());
		pnlBtns.add(new JLabel("Number of points on graph: "));
		pnlBtns.add(jtMax);
		pnlBtns.add(btnRestart);
		pnlBtns.add(btnRun);
		pnlBtns.add(btnClassify);
		pnlBtns.add(jcbDetails);
		canvas.add(pnlBtns, BorderLayout.SOUTH);

		lblInfo = new JLabel("");
		JPanel pnlInfo = new JPanel();
		pnlInfo.setLayout(new BorderLayout());
		pnlInfo.add(lblInfo, BorderLayout.CENTER);
		pnlInfo.setPreferredSize(new Dimension(200, 100));
		pnlInfo.setAlignmentX(0);
		lblInfo.setAlignmentX(0);
		jsp.setLeftComponent(pnlInfo);

		// A worker may still be running for a previously used canvas.
		if (calcThread != null) stopThread();
		startThread();
	}
	
	/**
	 * Asks the background layout thread, if any, to finish and resets the
	 * progress indicator and the Start/Stop button. The thread is stopped
	 * cooperatively through interruption; Thread.stop() is deprecated and
	 * not supported by current Java runtimes.
	 */
	private void stopThread()
	{
		fProg = 0;
		btnRun.setText("Start");
		Thread worker = calcThread;
		calcThread = null;
		if (worker != null) worker.interrupt();
	}

	/**
	 * Starts a background thread that repeatedly runs findPlacement(),
	 * refreshes the painter and pauses for 50 ms, until it is interrupted.
	 * A worker that is still registered is interrupted first, so that at
	 * most one worker keeps running.
	 */
	private void startThread()
	{
		if (calcThread != null) calcThread.interrupt();

		btnRun.setText("Stop");
		calcThread = new Thread(new Runnable()
		{
			public void run()
			{
				try
				{
					while (!Thread.currentThread().isInterrupted())
					{
						findPlacement();
						painter.invalidate();
						painter.revalidate();
						painter.repaint();
						Thread.sleep(50);
					}
				}
				catch (InterruptedException ie)
				{
					// Interruption is the stop request: restore the flag and leave.
					Thread.currentThread().interrupt();
				}
				catch (Throwable thr)
				{
					thr.printStackTrace();
				}
			}
		});
		calcThread.start();
	}

	/**
	 * Shows the neighbourhood used to classify an object: the placement is
	 * replaced by the neighbours returned by the vicinity provider plus the
	 * (transformed) object itself, the object in the placement that is closest
	 * to it (sum of absolute attribute differences) is marked as classified
	 * and selected, and the layout thread is started.
	 *
	 * @param canvas panel to draw on; passed to draw()
	 * @param obj    object to classify, in its original (untransformed) form
	 * @return result of classify() for the original object, or 0 when any
	 *         exception occurred (its stack trace is printed)
	 */
	public double drawClassify(JPanel canvas, DoubleData obj)
	{
		DoubleData orgobj = obj;
		try
		{
			obj = m_Transformer.transformToNew(obj);
			Neighbour[] n = m_VicinityProvider.getVicinity(obj, getIntProperty(K_PROPERTY_NAME));

			// Stop the worker before the placement is rebuilt, so that it never
			// iterates over a placement that is being cleared and refilled.
			if (calcThread != null) stopThread();

			mult = START_MULT;
			iter = 0;
			placement.clear();
			for (int i=0;i<n.length;i++)
			{
				if (n[i] == null) continue;
				placement.put(n[i].neighbour(), new DPoint(avg));
			}
			placement.put(obj, new DPoint(avg));

			// Pick the placed object nearest to the classified one; strict '<'
			// keeps the first of equally distant candidates.
			int attr = m_TransformedTrainTable.attributes().noOfAttr();
			double best = Double.MAX_VALUE;
			DoubleData newobj = obj;
			for (DoubleData next : placement.keySet())
			{
				double dist = 0;
				for (int i=0;i<attr;i++)
				{
					dist += Math.abs(next.get(i) - obj.get(i));
				}
				if (dist < best)
				{
					best = dist;
					newobj = next;
				}
			}

			// Makes sure the painter and the controls exist (no-op for the same canvas).
			draw(canvas);

			painter.classified = newobj;
			painter.selected = newobj;
			// The previously hovered object is usually gone from the rebuilt placement.
			painter.hovered = null;
			painter.repaint();

			// draw() may already have started a worker for a new canvas.
			if (calcThread == null) startThread();

			return classify(orgobj);
		}
		catch (Exception exc)
		{
			exc.printStackTrace();
		}
		return 0;
	}

	/**
	 * Returns the attribute header of this classifier view; this
	 * implementation provides none.
	 *
	 * @return always {@code null}
	 */
	public Header attributes() 
	{
		return null;
	}

    /**
     * Performs one iteration of the layout: for every ordered pair of placed
     * objects the second one is pushed along the line joining the two points,
     * proportionally to the difference between their distance on the plane
     * and their metric distance. The accumulated moves are limited to
     * MAX_JUMP * avg per axis, applied, and the bounding box of the placement
     * is recomputed. While the pairs are processed, the progress fraction is
     * published roughly every 200 ms.
     *
     * The pass is abandoned without moving any point when the current thread
     * has been interrupted.
     */
    private void findPlacement()
    {
        mult *= DECAY_MULT;
        if (mult < DECAY_MIN) mult = DECAY_MIN;
        fProg = 0;

        Hashtable<DoubleData, DPoint> htDelta = new Hashtable<DoubleData, DPoint>();
        long reftime = System.currentTimeMillis();
        // Work on a snapshot of the keys so that the loops do not depend on the live key set.
        DoubleData[] arr = placement.keySet().toArray(new DoubleData[0]);
        int cnt = arr.length;
        for (int e1 = 0; e1 < cnt; e1++)
        {
            if (Thread.currentThread().isInterrupted()) return;

            DoubleData elem1 = arr[e1];
            int done = e1 + 1;
            // The clock is consulted only every tenth element.
            long currtime = done % 10 == 0 ? System.currentTimeMillis() : reftime;
            if (currtime - reftime > 200)
            {
                reftime = currtime;
                fProg = (double)done / cnt;
                painter.invalidate();
                painter.repaint();
            }

            DPoint guess = placement.get(elem1);
            if (guess == null) continue; // removed from the placement meanwhile

            for (int e2 = 0; e2 < cnt; e2++)
            {
                DoubleData elem2 = arr[e2];
                double dist = m_Metric.dist(elem1, elem2);
                if (dist == 0) continue;

                DPoint pp = placement.get(elem2);
                if (pp == null) continue; // removed from the placement meanwhile

                DPoint delta = htDelta.get(elem2);
                if (delta == null)
                {
                    delta = new DPoint(0, 0);
                    htDelta.put(elem2, delta);
                }

                DPoint vec = guess.vect(pp);
                double veclen = vec.len();
                // Coinciding points give no direction to move along.
                if (veclen < EPSILON) continue;
                vec.x /= veclen;
                vec.y /= veclen;

                double lenerr = veclen - dist;
                double weight = mult * lenerr;
                if (Math.abs(weight) < EPSILON) continue;

                delta.x -= vec.x * weight;
                delta.y -= vec.y * weight;
            }
        }

        // Collect the bounds locally and publish them once, so that painting
        // never sees the half-updated (infinite) intermediate values.
        double nxmax = Double.NEGATIVE_INFINITY;
        double nymax = Double.NEGATIVE_INFINITY;
        double nxmin = Double.POSITIVE_INFINITY;
        double nymin = Double.POSITIVE_INFINITY;

        for (int e = 0; e < cnt; e++)
        {
            DPoint opos = placement.get(arr[e]);
            if (opos == null) continue;

            // An element at metric distance 0 from every element (a single
            // point, exact duplicates) has no accumulated move; it stays put.
            DPoint delta = htDelta.get(arr[e]);
            if (delta != null)
            {
                if (delta.x > MAX_JUMP * avg) delta.x = MAX_JUMP*avg;
                if (delta.x < -MAX_JUMP * avg) delta.x = -MAX_JUMP*avg;
                if (delta.y > MAX_JUMP * avg) delta.y = MAX_JUMP*avg;
                if (delta.y < -MAX_JUMP * avg) delta.y = -MAX_JUMP*avg;
                // The stored point is updated in place; no re-insertion is needed.
                opos.x += delta.x;
                opos.y += delta.y;
            }

            if (opos.x < nxmin) nxmin = opos.x;
            if (opos.x > nxmax) nxmax = opos.x;
            if (opos.y < nymin) nymin = opos.y;
            if (opos.y > nymax) nymax = opos.y;
        }

        xmax = nxmax;
        ymax = nymax;
        xmin = nxmin;
        ymin = nymin;

        iter ++;
        fProg = 0;
        painter.invalidate();
        painter.repaint();
    }
	
	/**
	 * Panel that draws the current placement and lets the user hover over and
	 * select points with the mouse. Painting also refreshes the info label
	 * with the iteration state, object details, distances and the legend.
	 */
	class Painter extends JPanel implements MouseMotionListener, MouseListener
	{
	    /** Serialization version. */
		private static final long serialVersionUID = 1L;

		/** Object currently under the mouse pointer, or null. */
		private DoubleData hovered;
		/** Object selected by the last mouse click, or null. */
		private DoubleData selected;
		/** Object marked with a cross as the classified one, or null. */
		private DoubleData classified;

		/** Creates the panel and registers it as its own mouse listener. */
		public Painter()
		{
			addMouseListener(this);
			addMouseMotionListener(this);
		}

		/** Scales a coordinate from the range [lo, hi] to a pixel offset within the given extent. */
		private int toPixel(double value, double lo, double hi, int extent)
		{
			return (int)((value - lo) / (hi - lo) * extent);
		}

		/**
		 * Draws all placed points, the markers of the selected, hovered and
		 * classified objects, the line between selected and hovered object
		 * and the progress bar, and updates the info label.
		 *
		 * @param g graphics context to draw on
		 */
		public void paint(Graphics g)
		{
			StringBuilder info = new StringBuilder("<html>");
			int w = getWidth()-POINT_SIZE*2;
			int h = getHeight()-POINT_SIZE*2;
			g.clearRect(0, 0, getWidth(), getHeight());
			g.translate(POINT_SIZE, POINT_SIZE);

			int dec = -1;
			for (DoubleData elem1 : placement.keySet())
			{
				DPoint guess = placement.get(elem1);
				if (guess == null) continue;

				if (dec == -1) dec = elem1.attributes().decision();
				double col = elem1.get(dec);
				// An object being classified may carry a decision value that
				// has no assigned color; such points are drawn in black.
				Integer known = htCols.get(col);
				int val = known == null ? 0 : known.intValue();
				if (val < 0) val = -val;

				int x = toPixel(guess.x, xmin, xmax, w);
				int y = toPixel(guess.y, ymin, ymax, h);

				if (elem1 == selected)
				{
					g.setColor(new Color(0, 0, 0));
					g.fillOval(x-POINT_SIZE/2, y-POINT_SIZE/2, POINT_SIZE*2, POINT_SIZE*2);
				}

				g.setColor(new Color(val%256, (val/256)%256, 0));
				g.fillOval(x, y, POINT_SIZE, POINT_SIZE);

				if (elem1 == classified)
				{
					g.drawLine(x-POINT_SIZE*2, y+POINT_SIZE/2, x+POINT_SIZE*3, y+POINT_SIZE/2);
					g.drawLine(x+POINT_SIZE/2, y-POINT_SIZE*2, x+POINT_SIZE/2, y+POINT_SIZE*3);
				}

				if (elem1 == hovered)
				{
					g.setColor(new Color(0, 0, 0));
					g.fillOval(x+1, y+1, POINT_SIZE/2, POINT_SIZE/2);
				}
			}
			g.setColor(Color.BLACK);

			info.append("<b>Iteration:</b> ").append((int)iter).append("<br>");
			info.append("<b>Mult:</b> ").append(mult).append("<br>");

			if (selected != null && showDetails)
			{
				info.append("<b>Selected:</b><br>").append(formatData(selected)).append("<br>");
			}

			if (hovered != null && showDetails)
			{
				info.append("<b>Hovered:</b><br>").append(formatData(hovered)).append("<br>");
			}

			if (selected != null && hovered != null)
			{
				// Either object may no longer be part of the placement after
				// it has been rebuilt; the distance section is then skipped.
				DPoint p_sel = placement.get(selected);
				DPoint p_hov = placement.get(hovered);
				if (p_sel != null && p_hov != null)
				{
					double len_met = m_Metric.dist(selected, hovered);
					double len_vis = p_sel.dist(p_hov);

					int x1 = toPixel(p_sel.x, xmin, xmax, w)+POINT_SIZE/2;
					int y1 = toPixel(p_sel.y, ymin, ymax, h)+POINT_SIZE/2;
					int x2 = toPixel(p_hov.x, xmin, xmax, w)+POINT_SIZE/2;
					int y2 = toPixel(p_hov.y, ymin, ymax, h)+POINT_SIZE/2;

					g.drawLine(x1, y1, x2, y2);

					info.append("<b>Distance:</b><br>");
					info.append("Metric: <i>").append(len_met).append("</i><br>");
					info.append("Visible: <i>").append(len_vis).append("</i><br>");
				}
			}

			info.append(strLegend);
			lblInfo.setText(info.toString());
			if (fProg > 0)
			{
				g.setColor(Color.BLACK);
				g.fillRect(0, 0, (int)(w*fProg), 2);
			}
		}

		public void mouseDragged(MouseEvent e)
		{
		}

		/** Updates the hovered object to the one nearest to the pointer. */
		public void mouseMoved(MouseEvent e)
		{
			hovered = findObject(e.getX(), e.getY());
			repaint();
		}

		/**
		 * Finds the placed object whose drawn point is nearest to the given
		 * panel coordinates, using the same coordinate mapping as paint().
		 *
		 * @param x horizontal mouse position within the panel
		 * @param y vertical mouse position within the panel
		 * @return nearest object closer than FIND_THRES pixels, or null
		 */
		private DoubleData findObject(int x, int y)
		{
			int w = getWidth()-POINT_SIZE*2;
			int h = getHeight()-POINT_SIZE*2;
			// Margin added by paint() plus half a point to reach its centre.
			int shift = POINT_SIZE + POINT_SIZE/2;
			int min = FIND_THRES * FIND_THRES;
			DoubleData ret = null;
			for (DoubleData elem1 : placement.keySet())
			{
				DPoint guess = placement.get(elem1);
				if (guess == null) continue;
				int dx = toPixel(guess.x, xmin, xmax, w) + shift - x;
				int dy = toPixel(guess.y, ymin, ymax, h) + shift - y;
				if (dx*dx+dy*dy < min)
				{
					min = dx*dx+dy*dy;
					ret = elem1;
				}
			}
			return ret;
		}

		/** Selects the object nearest to the click position (or clears the selection). */
		public void mouseClicked(MouseEvent e)
		{
			selected = findObject(e.getX(), e.getY());
			repaint();
		}

		public void mouseEntered(MouseEvent e)
		{
		}

		public void mouseExited(MouseEvent e)
		{
		}

		public void mousePressed(MouseEvent e)
		{
		}

		public void mouseReleased(MouseEvent e)
		{
		}
	}

    /**
     * Mutable point (or vector) on the plane used by the layout.
     */
    class DPoint
    {
        public double x;
        public double y;

        /** Creates a point with the given coordinates. */
        public DPoint(double x, double y)
        {
            this.x = x;
            this.y = y;
        }

        /**
         * Creates a random point; each coordinate is drawn from the
         * enclosing object's random source and scaled to [-range, range).
         */
        public DPoint(double range)
        {
            this.x = rnd.nextDouble()*range*2 - range;
            this.y = rnd.nextDouble()*range*2 - range;
        }

        /** Returns the vector leading from this point to {@code p}. */
        public DPoint vect(DPoint p)
        {
            double towardsX = p.x - x;
            double towardsY = p.y - y;
            return new DPoint(towardsX, towardsY);
        }

        /** Returns the Euclidean length of this point taken as a vector. */
        public double len()
        {
            double squared = x*x + y*y;
            return Math.sqrt(squared);
        }

        /** Returns the Euclidean distance between this point and {@code p}. */
        public double dist(DPoint p)
        {
            return vect(p).len();
        }

        /** Hash code derived from the current coordinates. */
        public int hashCode()
        {
            double mixed = x*1000000+y*1000;
            return (int)mixed;
        }
    }
    
    /**
     * Ordered pair of data objects usable as a key of hash-based collections.
     */
    class HashEntry
    {
    	public DoubleData p1;
    	public DoubleData p2;

    	/** Creates the pair (p1, p2). */
    	public HashEntry(DoubleData p1, DoubleData p2)
    	{
    		this.p1 = p1;
    		this.p2 = p2;
    	}

    	/** Hash code combining the hash codes of both elements. */
    	public int hashCode()
    	{
    		return p1.hashCode() * 33221 + p2.hashCode()*71;
    	}

    	/**
    	 * Two entries are equal when both of their elements are equal, in the
    	 * same order. Without this, hashCode() alone could never make two
    	 * entries for the same pair match as keys.
    	 */
    	public boolean equals(Object other)
    	{
    		if (this == other) return true;
    		if (!(other instanceof HashEntry)) return false;
    		HashEntry that = (HashEntry)other;
    		return p1.equals(that.p1) && p2.equals(that.p2);
    	}
    }
}

Copyright © 2008-2011 by TunedIT
Design by luksite