/*
* $RCSfile: KnnVis.java,v $
* $Revision: 1.3 $
* $Date: 2007/08/18 10:44:06 $
* $Author: wojna $
*
* Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
*
* This file is part of Rseslib.
*
* Rseslib is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*
* Rseslib is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package rseslib.processing.classification.parameterised.knn;
import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.event.MouseMotionListener;
import java.util.Hashtable;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import javax.swing.JButton;
import javax.swing.JCheckBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import javax.swing.JSplitPane;
import javax.swing.JTextField;
import rseslib.processing.classification.VisualClassifier;
import rseslib.structure.attribute.Attribute;
import rseslib.structure.attribute.Header;
import rseslib.structure.attribute.NominalAttribute;
import rseslib.structure.data.DoubleData;
import rseslib.structure.metric.Neighbour;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.PropertyConfigurationException;
import rseslib.system.progress.Progress;
/**
 * K-nn classifier extended with an interactive Swing visualisation.
 * The displayed objects are placed on a plane and a background thread
 * iteratively moves them so that the distances on the plane approach
 * the distances measured by the classifier metric.
 */
public class KnnVis extends KnnClassifier implements VisualClassifier
{
/** Serialization version. */
private static final long serialVersionUID = 1L;
/** Current position on the plane of every displayed data object. */
private Hashtable<DoubleData, DPoint> placement = new Hashtable<DoubleData, DPoint>();
/** Random source (seeded with the current time) for start positions and extra colours. */
private Random rnd = new Random(System.currentTimeMillis());
/** Canvas most recently passed to draw(). */
private JPanel pnl;
/** Component that renders the points and handles the mouse. */
private Painter painter;
/** Average metric distance between pairs of training objects, computed in the constructor. */
private double avg;
// Bounding box of the current positions; recomputed by every findPlacement() pass
// and used to scale positions to the panel size. Infinite until the first pass.
private double xmin = Double.NEGATIVE_INFINITY;
private double ymin = Double.NEGATIVE_INFINITY;
private double xmax = Double.POSITIVE_INFINITY;
private double ymax = Double.POSITIVE_INFINITY;
/** Maximal distance in pixels at which the mouse picks up a point. */
private int FIND_THRES = 10;
/** Diameter in pixels of a drawn point. */
private int POINT_SIZE = 5;
/** Initial value of the step multiplier mult. */
private final double START_MULT = 0.02;
/** Limit of a single move per coordinate, expressed as a fraction of avg. */
private final double MAX_JUMP = 0.1;
// private final double DECAY_MULT = 0.99; // for the quadratic weight
/** Factor applied to mult at the start of every pass. */
private final double DECAY_MULT = 0.995; // for the linear weight
/** Lower bound of mult. */
private final double DECAY_MIN = 0.01;
/** Vector lengths and weights below this value are treated as zero. */
private final double EPSILON = 0.00000000000001;
/** Cap applied to the row count computed in draw(). */
private final int START_MAX_ROWS = 150;
/** Current step multiplier; decays from START_MULT towards DECAY_MIN. */
double mult = START_MULT;
/** HTML fragment listing the decisions in their colours. */
private String strLegend = "";
/** NOTE(review): not referenced anywhere else in this file. */
double currDev = 0.00;
/** Number of placement passes done since the last reset. */
double iter = 0;
/** Fraction of the current pass already done; 0 hides the progress bar. */
double fProg = 0;
/** Colour assigned to each decision value, packed as red + 256 * green. */
private Hashtable<Double, Integer> htCols = new Hashtable<Double, Integer>();
/** Colours given to the first decisions encountered, in this order. */
private int[] startcolors = new int[] { 0, 255, 255*256, 128, 128*256, 128*256+128, 128*256+255, 255*256+128 };
/** Background thread running the placement passes; null when stopped. */
private Thread calcThread;
/** Training table as passed to the constructor. */
private DoubleDataTable table;
/** Label showing iteration info, object details and the legend. */
private JLabel lblInfo;
/** Whether attribute values of the selected and hovered objects are listed. */
private boolean showDetails = true;
/** Button toggling the background thread; its text is Start or Stop. */
private JButton btnRun;
/**
 * Builds the classifier through the {@code KnnClassifier} constructor and
 * prepares the visualisation data: a colour for every decision value, the
 * average metric distance between training objects, the HTML legend and a
 * random start placement of all training objects.
 *
 * @param prop       Properties passed to the superclass constructor.
 * @param trainTable Training table, also kept for displaying original values.
 * @param prog       Progress object passed to the superclass constructor.
 * @throws PropertyConfigurationException declared for the superclass constructor.
 * @throws InterruptedException           declared for the superclass constructor.
 */
public KnnVis(Properties prop, DoubleDataTable trainTable, Progress prog) throws PropertyConfigurationException, InterruptedException
{
    super(prop, trainTable, prog);
    this.table = trainTable;
    int cnt = m_TransformedTrainTable.noOfObjects();
    avg = 0.00;
    int dec = -1;
    for (DoubleData v1 : m_TransformedTrainTable.getDataObjects())
    {
        if (dec == -1) dec = v1.attributes().decision();
        double decision = v1.get(dec);
        if (!htCols.containsKey(decision))
        {
            int p = htCols.size();
            // Only the red and green bytes are ever rendered, so extra colours
            // are drawn from 0..0xFFFF. An unrestricted nextInt() may be negative,
            // which made the legend's hex conversion index outside its digit table.
            int color = p < startcolors.length ? startcolors[p] : rnd.nextInt(0x10000);
            htCols.put(decision, color);
        }
        double partsum = 0.00;
        for (DoubleData v2 : m_TransformedTrainTable.getDataObjects())
        {
            partsum += m_Metric.dist(v1, v2);
        }
        avg += partsum / cnt;
    }
    // An empty table leaves avg at 0 instead of turning it into NaN.
    if (cnt > 0) avg /= cnt;
    StringBuilder legend = new StringBuilder("<br><b>Decisions</b>:<br>");
    for (Double key : htCols.keySet())
    {
        String name = NominalAttribute.stringValue(key);
        int color = htCols.get(key);
        int cr = color & 0xFF;
        int cg = (color >> 8) & 0xFF;
        int cb = (color >> 16) & 0xFF;
        String hexColor = toHex(cr) + toHex(cg) + toHex(cb);
        legend.append("<font color=#").append(hexColor).append(">").append(name).append("</font><br>");
    }
    strLegend = legend.toString();
    findRandomPlacement(Integer.MAX_VALUE);
}
/**
 * Converts the lowest byte of a value to two upper-case hexadecimal digits.
 * Bits above the lowest byte are ignored, so negative or too large values
 * cannot index outside the digit table.
 *
 * @param val Value whose lowest 8 bits are converted.
 * @return Two-character hexadecimal string, e.g. "0A" for 10.
 */
private String toHex(int val)
{
    final String digits = "0123456789ABCDEF";
    return "" + digits.charAt((val >> 4) & 0xF) + digits.charAt(val & 0xF);
}
/**
 * Resets the placement process and assigns random start positions.
 * When the training table has fewer objects than {@code maxcnt} every object
 * is placed; otherwise each object is accepted at random with probability
 * {@code maxcnt / (number of objects)}.
 *
 * @param maxcnt Requested number of objects on the graph.
 */
private void findRandomPlacement(int maxcnt)
{
    // Start the placement process from scratch.
    iter = 0;
    mult = START_MULT;
    placement.clear();

    final int total = m_TransformedTrainTable.noOfObjects();
    for (DoubleData candidate : m_TransformedTrainTable.getDataObjects())
    {
        boolean accepted = total < maxcnt || rnd.nextInt(total) < maxcnt;
        if (!accepted)
        {
            continue;
        }
        placement.put(candidate, new DPoint(avg));
    }
}
/**
 * Finds the training object from which a transformed object was obtained.
 * The object is looked up by identity in the transformed training table and
 * the object at the same position of the original table is returned.
 *
 * @param dat Object to be looked up.
 * @return Object at the matching position of the original table,
 *         or {@code dat} itself when no match is found.
 */
private DoubleData findOriginal(DoubleData dat)
{
    // Position of dat in the transformed table, -1 when absent.
    int position = -1;
    int index = 0;
    for (DoubleData transformed : m_TransformedTrainTable.getDataObjects())
    {
        if (transformed == dat)
        {
            position = index;
            break;
        }
        index++;
    }
    if (position < 0)
    {
        return dat;
    }

    int current = 0;
    for (DoubleData original : table.getDataObjects())
    {
        if (current == position)
        {
            return original;
        }
        current++;
    }
    return dat;
}
/**
 * Formats the attribute values of an object as HTML lines
 * {@code name: <i>value</i><br>}. Values are taken from the original
 * training object when it can be found. Nominal values are shown by name;
 * numeric values are shortened to about 5 characters, but only by dropping
 * fractional digits, so integer digits and exponents are never cut off.
 *
 * @param dat Object to be described.
 * @return HTML fragment with one line per attribute.
 */
private String formatData(DoubleData dat)
{
    dat = findOriginal(dat);
    StringBuilder out = new StringBuilder();
    int cnt = dat.attributes().noOfAttr();
    for (int i = 0; i < cnt; i++)
    {
        Attribute attr = dat.attributes().attribute(i);
        String val;
        if (attr.isNominal())
        {
            val = NominalAttribute.stringValue(dat.get(i));
        }
        else
        {
            val = "" + dat.get(i);
            int dot = val.indexOf('.');
            // Plain decimal notation only: truncating "1.0E-5" or "Infinity"
            // would display a different value.
            if (dot >= 0 && val.indexOf('E') < 0)
            {
                // Keep the whole integer part and at least one fractional digit.
                int end = Math.max(5, dot + 2);
                if (val.length() > end) val = val.substring(0, end);
            }
        }
        out.append(attr.name()).append(": <i>").append(val).append("</i><br>");
    }
    return out.toString();
}
/**
 * Builds the visualisation inside a canvas: an info panel on the left, the
 * scrollable point painter on the right and a row of controls at the bottom,
 * then starts the background placement thread. Calling it again with the
 * canvas that is already in use does nothing.
 *
 * @param canvas Panel to be filled with the visualisation.
 */
public void draw(JPanel canvas)
{
    if (canvas.equals(pnl)) return;
    pnl = canvas;
    painter = new Painter();
    JScrollPane scroll = new JScrollPane(painter);
    scroll.setVisible(true);
    JSplitPane jsp = new JSplitPane(JSplitPane.HORIZONTAL_SPLIT);
    canvas.add(jsp);
    jsp.setRightComponent(scroll);
    jsp.setDividerLocation(-1);
    final JTextField jtMax = new JTextField("" + m_TransformedTrainTable.noOfObjects());
    jtMax.setPreferredSize(new Dimension(100, 24));
    JButton btnRestart = new JButton("Restart");
    btnRestart.addActionListener(new ActionListener()
    {
        public void actionPerformed(ActionEvent e)
        {
            String txt = jtMax.getText();
            int maxcnt = Integer.MAX_VALUE;
            try
            {
                maxcnt = Integer.parseInt(txt);
            }
            catch (NumberFormatException ignored)
            {
                // Unparsable input deliberately means "no limit".
            }
            jtMax.setText("" + maxcnt);
            // Stop the running worker before the placement is replaced; otherwise
            // the old thread is never stopped and keeps moving points alongside
            // the new one.
            if (calcThread != null) stopThread();
            findRandomPlacement(maxcnt);
            painter.classified = null;
            painter.repaint();
            startThread();
        }
    });
    JButton btnClassify = new JButton("Classify selected");
    btnClassify.addActionListener(new ActionListener()
    {
        public void actionPerformed(ActionEvent e)
        {
            if (painter.selected != null) drawClassify(pnl, findOriginal(painter.selected));
        }
    });
    btnRun = new JButton("Start");
    btnRun.addActionListener(new ActionListener()
    {
        public void actionPerformed(ActionEvent e)
        {
            if (calcThread == null)
            {
                startThread();
            }
            else
            {
                stopThread();
            }
        }
    });
    final JCheckBox jcbDetails = new JCheckBox("Show object details");
    jcbDetails.setSelected(showDetails);
    jcbDetails.addActionListener(new ActionListener()
    {
        public void actionPerformed(ActionEvent arg0)
        {
            showDetails = jcbDetails.isSelected();
            painter.invalidate();
            painter.repaint();
        }
    });
    JPanel pnlBtns = new JPanel();
    pnlBtns.setLayout(new FlowLayout());
    pnlBtns.add(new JLabel("Number of points on graph: "));
    pnlBtns.add(jtMax);
    pnlBtns.add(btnRestart);
    pnlBtns.add(btnRun);
    pnlBtns.add(btnClassify);
    pnlBtns.add(jcbDetails);
    canvas.add(pnlBtns, BorderLayout.SOUTH);
    lblInfo = new JLabel("");
    JPanel pnlInfo = new JPanel();
    pnlInfo.setLayout(new BorderLayout());
    pnlInfo.add(lblInfo, BorderLayout.CENTER);
    pnlInfo.setPreferredSize(new Dimension(200, 100));
    pnlInfo.setAlignmentX(0);
    lblInfo.setAlignmentX(0);
    jsp.setLeftComponent(pnlInfo);
    startThread();
}
/**
 * Asks the background placement thread to finish and resets the controls.
 * The thread is interrupted rather than killed: it ends after the placement
 * pass it is currently running, so it may still be active for a moment after
 * this method returns.
 */
private void stopThread()
{
    fProg = 0;
    btnRun.setText("Start");
    Thread worker = calcThread;
    calcThread = null;
    // Thread.stop() is deprecated and throws UnsupportedOperationException
    // on current JDKs; interruption is the supported way to end the loop.
    if (worker != null) worker.interrupt();
}

/**
 * Starts a background thread that repeats placement passes, repainting the
 * painter and pausing 50 ms after each one, until it is interrupted.
 * A worker that is still registered is interrupted first, so at most one
 * worker remains referenced by {@code calcThread}.
 */
private void startThread()
{
    Thread previous = calcThread;
    if (previous != null) previous.interrupt();
    btnRun.setText("Stop");
    Thread worker = new Thread(new Runnable()
    {
        public void run()
        {
            try
            {
                while (!Thread.currentThread().isInterrupted())
                {
                    findPlacement();
                    painter.invalidate();
                    painter.revalidate();
                    painter.repaint();
                    Thread.sleep(50);
                }
            }
            catch (InterruptedException ie)
            {
                // Requested stop: keep the interrupt flag set and leave the loop.
                Thread.currentThread().interrupt();
            }
            catch (Throwable thr)
            {
                thr.printStackTrace();
            }
        }
    });
    calcThread = worker;
    worker.start();
}
/**
 * Shows the classification of an object: the placement is replaced by the
 * neighbours returned for the transformed object plus the object itself,
 * the object is marked as classified and selected, and the placement thread
 * is (re)started.
 *
 * @param canvas Panel used for drawing.
 * @param obj    Object to be classified.
 * @return Result of {@code classify} for the given object,
 *         or 0 when any exception occurs (its stack trace is printed).
 */
public double drawClassify(JPanel canvas, DoubleData obj)
{
    final DoubleData original = obj;
    try
    {
        // Restart the placement process.
        iter = 0;
        mult = START_MULT;

        obj = m_Transformer.transformToNew(obj);
        Neighbour[] vicinity = m_VicinityProvider.getVicinity(obj, getIntProperty(K_PROPERTY_NAME));
        placement.clear();
        for (Neighbour neighbour : vicinity)
        {
            DPoint start = new DPoint(avg);
            if (neighbour != null)
            {
                placement.put(neighbour.neighbour(), start);
            }
        }
        draw(canvas);
        placement.put(obj, new DPoint(avg));

        // Pick the placed object with the smallest sum of absolute
        // attribute differences to the transformed object.
        final int attrCount = m_TransformedTrainTable.attributes().noOfAttr();
        DoubleData closest = obj;
        double closestDist = Double.MAX_VALUE;
        for (DoubleData candidate : placement.keySet())
        {
            double sum = 0;
            for (int a = 0; a < attrCount; a++)
            {
                sum += Math.abs(candidate.get(a) - obj.get(a));
            }
            if (sum < closestDist)
            {
                closestDist = sum;
                closest = candidate;
            }
        }
        obj = closest;
        painter.classified = obj;
        painter.selected = obj;
        painter.repaint();

        if (calcThread != null)
        {
            stopThread();
        }
        startThread();
        return classify(original);
    }
    catch (Exception exc)
    {
        exc.printStackTrace();
    }
    return 0;
}
/**
 * This implementation provides no attribute header.
 *
 * @return Always {@code null}.
 */
public Header attributes()
{
return null;
}
/**
 * Performs one pass of the iterative placement. For every ordered pair of
 * placed objects with a non-zero metric distance, the second object is pulled
 * towards or pushed away from the first one in proportion to the difference
 * between their distance on the plane and their metric distance. The summed
 * moves are limited to {@code MAX_JUMP * avg} per coordinate and applied, and
 * the bounding box of the positions is recomputed. While the pass runs,
 * {@code fProg} is refreshed about every 200 ms so the painter can show progress.
 */
private void findPlacement()
{
    mult *= DECAY_MULT;
    if (mult < DECAY_MIN) mult = DECAY_MIN;
    fProg = 0;
    Hashtable<DoubleData, DPoint> htDelta = new Hashtable<DoubleData, DPoint>();
    long reftime = System.currentTimeMillis();
    DoubleData[] arr = placement.keySet().toArray(new DoubleData[0]);
    int cnt = arr.length;
    for (int e1 = 0; e1 < cnt; e1++)
    {
        DoubleData elem1 = arr[e1];
        int done = e1 + 1;
        // The clock is read only on every 10th object to keep the loop cheap.
        long currtime = done % 10 == 0 ? System.currentTimeMillis() : reftime;
        if (currtime - reftime > 200)
        {
            reftime = currtime;
            fProg = (double) done / cnt;
            painter.invalidate();
            painter.repaint();
        }
        DPoint guess = placement.get(elem1);
        for (int e2 = 0; e2 < cnt; e2++)
        {
            DoubleData elem2 = arr[e2];
            double dist = m_Metric.dist(elem1, elem2);
            if (dist == 0) continue;
            DPoint delta = htDelta.get(elem2);
            if (delta == null)
            {
                delta = new DPoint(0, 0);
                htDelta.put(elem2, delta);
            }
            DPoint vec = guess.vect(placement.get(elem2));
            double veclen = vec.len();
            // Coinciding points have no direction to move along.
            if (veclen < EPSILON) continue;
            double lenerr = veclen - dist;
            double weight = mult * lenerr;
            //double weight = mult * Math.signum(lenerr) * (lenerr*lenerr);
            if (Math.abs(weight) < EPSILON) continue;
            delta.x -= vec.x / veclen * weight;
            delta.y -= vec.y / veclen * weight;
        }
    }
    xmax = Double.NEGATIVE_INFINITY;
    ymax = Double.NEGATIVE_INFINITY;
    xmin = Double.POSITIVE_INFINITY;
    ymin = Double.POSITIVE_INFINITY;
    double limit = MAX_JUMP * avg;
    for (DoubleData elem : placement.keySet())
    {
        DPoint opos = placement.get(elem);
        DPoint delta = htDelta.get(elem);
        // An object at zero metric distance from every placed object (e.g. the
        // only one placed) collects no delta; it simply stays where it is.
        if (delta != null)
        {
            if (delta.x > limit) delta.x = limit;
            if (delta.x < -limit) delta.x = -limit;
            if (delta.y > limit) delta.y = limit;
            if (delta.y < -limit) delta.y = -limit;
            // The stored point is updated in place, no re-insertion is needed.
            opos.x += delta.x;
            opos.y += delta.y;
        }
        if (opos.x < xmin) xmin = opos.x;
        if (opos.x > xmax) xmax = opos.x;
        if (opos.y < ymin) ymin = opos.y;
        if (opos.y > ymax) ymax = opos.y;
    }
    iter++;
    fProg = 0;
    painter.invalidate();
    painter.repaint();
}
/**
 * Panel drawing the placed objects as coloured points and tracking the
 * object under the mouse cursor and the object selected with a click.
 * Positions are scaled from the current bounding box to the panel area
 * reduced by a margin of {@code POINT_SIZE} pixels on every side.
 */
class Painter extends JPanel implements MouseMotionListener, MouseListener
{
    /** Serialization version. */
    private static final long serialVersionUID = 1L;
    /** Object currently under the mouse cursor, or null. */
    private DoubleData hovered;
    /** Object selected with the last click, or null. */
    private DoubleData selected;
    /** Object marked with a cross as the classified one, or null. */
    private DoubleData classified;

    /** Creates the panel and registers it as its own mouse listener. */
    public Painter()
    {
        addMouseListener(this);
        addMouseMotionListener(this);
    }

    /** Scales a horizontal position to the range 0..w of the drawing area. */
    private int scaleX(DPoint p, int w)
    {
        return (int)((p.x - xmin) / (xmax - xmin) * w);
    }

    /** Scales a vertical position to the range 0..h of the drawing area. */
    private int scaleY(DPoint p, int h)
    {
        return (int)((p.y - ymin) / (ymax - ymin) * h);
    }

    /**
     * Draws all placed points, the selection and hover markers, the line
     * between the selected and hovered objects and the progress bar, and
     * refreshes the text of the info label.
     *
     * @param g Graphics context to draw on.
     */
    public void paint(Graphics g)
    {
        StringBuilder info = new StringBuilder("<html>");
        int w = getWidth() - POINT_SIZE * 2;
        int h = getHeight() - POINT_SIZE * 2;
        g.clearRect(0, 0, getWidth(), getHeight());
        g.translate(POINT_SIZE, POINT_SIZE);
        int dec = -1;
        for (DoubleData elem1 : placement.keySet())
        {
            if (dec == -1) dec = elem1.attributes().decision();
            double col = elem1.get(dec);
            // A decision without an assigned colour is drawn in black
            // instead of failing on unboxing a missing entry.
            Integer stored = htCols.get(col);
            int val = stored == null ? 0 : stored.intValue();
            DPoint guess = placement.get(elem1);
            int x = scaleX(guess, w);
            int y = scaleY(guess, h);
            if (val < 0) val = -val;
            if (elem1 == selected)
            {
                g.setColor(new Color(0, 0, 0));
                g.fillOval(x - POINT_SIZE / 2, y - POINT_SIZE / 2, POINT_SIZE * 2, POINT_SIZE * 2);
            }
            g.setColor(new Color(val % 256, (val / 256) % 256, 0));
            g.fillOval(x, y, POINT_SIZE, POINT_SIZE);
            if (elem1 == classified)
            {
                g.drawLine(x - POINT_SIZE * 2, y + POINT_SIZE / 2, x + POINT_SIZE * 3, y + POINT_SIZE / 2);
                g.drawLine(x + POINT_SIZE / 2, y - POINT_SIZE * 2, x + POINT_SIZE / 2, y + POINT_SIZE * 3);
            }
            if (elem1 == hovered)
            {
                g.setColor(new Color(0, 0, 0));
                g.fillOval(x + 1, y + 1, POINT_SIZE / 2, POINT_SIZE / 2);
            }
        }
        g.setColor(Color.BLACK);
        info.append("<b>Iteration:</b> ").append((int)iter).append("<br>");
        info.append("<b>Mult:</b> ").append(mult).append("<br>");
        if (selected != null && showDetails)
        {
            info.append("<b>Selected:</b><br>").append(formatData(selected)).append("<br>");
        }
        if (hovered != null && showDetails)
        {
            info.append("<b>Hovered:</b><br>").append(formatData(hovered)).append("<br>");
        }
        if (selected != null && hovered != null)
        {
            DPoint p_sel = placement.get(selected);
            DPoint p_hov = placement.get(hovered);
            // After a restart the remembered objects may no longer be placed.
            if (p_sel != null && p_hov != null)
            {
                double len_met = m_Metric.dist(selected, hovered);
                double len_vis = p_sel.dist(p_hov);
                int x1 = scaleX(p_sel, w) + POINT_SIZE / 2;
                int y1 = scaleY(p_sel, h) + POINT_SIZE / 2;
                int x2 = scaleX(p_hov, w) + POINT_SIZE / 2;
                int y2 = scaleY(p_hov, h) + POINT_SIZE / 2;
                g.drawLine(x1, y1, x2, y2);
                info.append("<b>Distance:</b><br>");
                info.append("Metric: <i>").append(len_met).append("</i><br>");
                info.append("Visible: <i>").append(len_vis).append("</i><br>");
            }
        }
        info.append(strLegend);
        lblInfo.setText(info.toString());
        if (fProg > 0)
        {
            g.setColor(Color.BLACK);
            g.fillRect(0, 0, (int)(w * fProg), 2);
        }
    }

    public void mouseDragged(MouseEvent e)
    {
    }

    /** Updates the hovered object to the one under the cursor. */
    public void mouseMoved(MouseEvent e)
    {
        hovered = findObject(e.getX(), e.getY());
        repaint();
    }

    /**
     * Finds the placed object whose drawn point is nearest to a panel
     * position, using the same scaling and margin as {@code paint}.
     *
     * @param x Horizontal panel coordinate.
     * @param y Vertical panel coordinate.
     * @return Nearest object closer than {@code FIND_THRES} pixels, or null.
     */
    private DoubleData findObject(int x, int y)
    {
        int w = getWidth() - POINT_SIZE * 2;
        int h = getHeight() - POINT_SIZE * 2;
        int min = FIND_THRES * FIND_THRES;
        DoubleData ret = null;
        for (DoubleData elem1 : placement.keySet())
        {
            DPoint guess = placement.get(elem1);
            // Centre of the drawn point: margin + scaled position + half of the point size.
            int dx = POINT_SIZE + scaleX(guess, w) + POINT_SIZE / 2 - x;
            int dy = POINT_SIZE + scaleY(guess, h) + POINT_SIZE / 2 - y;
            int sq = dx * dx + dy * dy;
            if (sq < min)
            {
                min = sq;
                ret = elem1;
            }
        }
        return ret;
    }

    /** Selects the object under the cursor (or clears the selection). */
    public void mouseClicked(MouseEvent e)
    {
        selected = findObject(e.getX(), e.getY());
        repaint();
    }

    public void mouseEntered(MouseEvent e)
    {
    }

    public void mouseExited(MouseEvent e)
    {
    }

    public void mousePressed(MouseEvent e)
    {
    }

    public void mouseReleased(MouseEvent e)
    {
    }
}
/**
 * Mutable point (or vector) on the plane.
 */
class DPoint
{
    /** Horizontal coordinate. */
    public double x;
    /** Vertical coordinate. */
    public double y;

    /**
     * Creates a point with the given coordinates.
     *
     * @param x Horizontal coordinate.
     * @param y Vertical coordinate.
     */
    public DPoint(double x, double y)
    {
        this.x = x;
        this.y = y;
    }

    /**
     * Creates a random point with both coordinates drawn
     * from the interval between -range and range.
     *
     * @param range Half of the width of the interval.
     */
    public DPoint(double range)
    {
        this.x = rnd.nextDouble()*range*2 - range;
        this.y = rnd.nextDouble()*range*2 - range;
    }

    /**
     * Returns the vector leading from this point to another one.
     *
     * @param p Target point.
     * @return New point holding the coordinate differences.
     */
    public DPoint vect(DPoint p)
    {
        double dx = p.x - x;
        double dy = p.y - y;
        return new DPoint(dx, dy);
    }

    /**
     * Returns the Euclidean length of this point treated as a vector.
     *
     * @return Length of the vector.
     */
    public double len()
    {
        double squares = x*x + y*y;
        return Math.sqrt(squares);
    }

    /**
     * Returns the Euclidean distance between this point and another one.
     *
     * @param p The other point.
     * @return Distance between the points.
     */
    public double dist(DPoint p)
    {
        return vect(p).len();
    }

    /**
     * Returns a hash code computed from the current coordinates.
     *
     * @return Hash code of the point.
     */
    public int hashCode()
    {
        double mixed = x*1000000+y*1000;
        return (int)mixed;
    }
}
/**
 * Ordered pair of data objects with a hash code combining both elements.
 */
class HashEntry
{
    /** First object of the pair. */
    public DoubleData p1;
    /** Second object of the pair. */
    public DoubleData p2;

    /**
     * Creates a pair of objects.
     *
     * @param p1 First object.
     * @param p2 Second object.
     */
    public HashEntry(DoubleData p1, DoubleData p2)
    {
        this.p1 = p1;
        this.p2 = p2;
    }

    /**
     * Returns a hash code mixing the hash codes of both objects
     * with different multipliers, so the order of the pair matters.
     *
     * @return Hash code of the pair.
     */
    public int hashCode()
    {
        int first = p1.hashCode();
        int second = p2.hashCode();
        return 33221 * first + 71 * second;
    }
}
}