/*
* Debellor
*
* Copyright (C) 2008-2009 by Marcin Wojnarski
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package org.debellor.base;
import java.io.PrintStream;
import java.util.Arrays;
import org.debellor.core.Cell;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.data.DataVector;
import org.debellor.core.data.NumericFeature;
import org.debellor.core.data.SymbolicFeature;
import org.debellor.core.data.DataVector.DataVectorType;
import org.debellor.core.data.SymbolicFeature.SymbolicFeatureType;
import org.debellor.core.util.Permute;
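/**
 * Stream-based implementation of the standard k-means clustering algorithm.
 * Training re-reads the input stream on every pass instead of buffering the samples,
 * so it can be trained on large data without risk of memory overflow.
 * The number of clusters is read from the <code>numClusters</code> parameter.
 *
 * @author Marcin Wojnarski
 */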
public class KMeans extends Cell {

    /** Current cluster centers, indexed as [cluster][dimension]; allocated during learning. */
    private double[][] centers;

    /** Temporary buffer that stores the current sample as an array of doubles. */
    private double[] sample;

    /** Destination of progress and debugging messages. */
    private PrintStream msg = System.err;

    /** Number of features (dimensions) of every sample. */
    private int ndim;

    /** Number of clusters, read from the "numClusters" parameter. */
    private int nclust;

    /** Threshold for the stopping criterion: learning stops once centers move by no more than this. */
    private double delta = 0;

    /** Number of samples assigned to each cluster during the most recent cycle. */
    private int[] clusterSizes;

    /** Type of the decision attached to output samples; set while the cell is open. */
    private SymbolicFeatureType decisionType;

    /** Input stream that is currently (or was most recently) being read. */
    private Stream input;

    /**
     * Trains the model: picks initial centers (see {@link #init()}) and then repeats
     * assignment/update cycles, each being one full pass over the input stream,
     * until the centers move by no more than {@link #delta}.
     *
     * @throws Exception if the input has an unsupported type, the number of clusters
     *                   is invalid, or reading the input fails
     */
    @Override
    protected void onLearn() throws Exception {
        init();
        int cycle = 0;
        while (true) {
            msg.println("K-means, starting cycle: " + ++cycle);
            double[][] newCenters = new double[nclust][ndim];
            clusterSizes = new int[nclust];

            // Assign samples to clusters and accumulate per-cluster sums of vectors,
            // to be used for cluster center calculation.
            input = openInputStream();
            try {
                Sample s;
                while ((s = input.next()) != null) {
                    sampleToArray(s, sample);
                    int c = findCluster();
                    for (int j = 0; j < ndim; j++) {
                        newCenters[c][j] += sample[j];
                    }
                    clusterSizes[c]++;
                }
            } finally {
                // close the stream even if reading or conversion of a sample failed
                input.close();
            }

            msg.print(" Sizes of clusters: ");
            // Calculate new cluster centers as means of the accumulated vectors.
            for (int c = 0; c < nclust; c++) {
                msg.print(" " + clusterSizes[c]);
                double[] center = newCenters[c];
                int size = clusterSizes[c];
                if (size == 0) {
                    // An empty cluster has nothing to average: keep its previous center
                    // rather than letting it collapse onto the all-zero accumulator.
                    System.arraycopy(centers[c], 0, center, 0, ndim);
                    continue;
                }
                for (int i = 0; i < ndim; i++) {
                    center[i] /= size;
                }
            }
            msg.println();

            boolean finish = areSame(centers, newCenters);
            centers = newCenters;
            if (finish) {
                break;
            }
        }
    }

    /**
     * Prepares learning in two passes over the input: the first validates the sample
     * type and counts samples, the second copies randomly drawn samples into the
     * initial cluster centers.
     *
     * @throws Exception if the sample type is unsupported, "numClusters" is not
     *                   in the range 1..number of samples, or reading the input fails
     */
    private void init() throws Exception {
        msg.println("K-means, 1st initialization pass (count no. of samples)");
        int nsamp = 0;
        input = openInputStream();
        try {
            checkType(input.sampleType);
            while (input.next() != null) {
                nsamp++;
            }
        } finally {
            input.close();
        }
        msg.println(" Number of samples: " + nsamp);

        msg.println("K-means, 2nd initialization pass (choose centers)");
        nclust = parameters.getAsInt("numClusters");
        if (nclust <= 0) {
            throw new Exception("K-means: numClusters must be positive, but is " + nclust);
        }
        if (nclust > nsamp) {
            throw new Exception("K-means: numClusters (" + nclust
                    + ") exceeds the number of input samples (" + nsamp + ")");
        }
        centers = new double[nclust][ndim];
        sample = new double[ndim];

        // pick randomly 'nclust' indices of samples and sort them,
        // so that the drawn samples can be collected in a single sequential pass
        int[] ind = Permute.indices(nsamp, nclust, random);
        Arrays.sort(ind);

        // Initialize centers with the drawn samples.
        // This step is a bit risky, as we cannot be sure that
        // after source reopening the data stream will have the same size.
        // In a future release of Debellor there will be a better control of reopening.
        int filled = 0;
        input = openInputStream();
        try {
            Sample s;
            // stop reading as soon as the last center has been filled
            for (int i = 0; filled < nclust && (s = input.next()) != null; i++) {
                if (ind[filled] == i) {
                    sampleToArray(s, centers[filled++]);
                }
            }
        } finally {
            input.close();
        }
        if (filled < nclust) {
            // the reopened stream was shorter than in the counting pass;
            // the remaining centers stay at their initial all-zero value
            msg.println(" WARNING: input ended after initializing only " + filled
                    + " of " + nclust + " centers");
        }

        // print samples, for debugging
        for (int c = 0; c < nclust; c++) {
            msg.println("Center #" + c + " = sample #" + ind[c] + ": " + preview(centers[c]));
        }
    }

    /**
     * Formats the leading coordinates of a center for debugging output.
     * Only the first two coordinates are shown; a one-dimensional center is
     * printed as its single coordinate.
     */
    private String preview(double[] center) {
        if (ndim < 2) {
            return String.valueOf(center[0]);
        }
        return center[0] + "," + center[1] + ",...";
    }

    /**
     * Converts sample data to array of doubles.
     * Caution: array <code>a</code> will be modified!
     *
     * @param s sample whose data is a {@link DataVector} of numeric features
     * @param a destination array; its first <code>ndim</code> elements are overwritten
     */
    private void sampleToArray(Sample s, double[] a) {
        DataVector data = (DataVector) s.data;
        for (int i = 0; i < ndim; i++) {
            a[i] = ((NumericFeature) data.get(i)).value;
        }
    }

    /**
     * Validates the type of training samples and records the number of
     * dimensions in {@link #ndim}.
     *
     * @param type sample type of the training stream
     * @throws Exception if the data vector is empty or contains a non-numeric feature
     */
    private void checkType(SampleType type) throws Exception {
        DataVectorType dataType = (DataVectorType) type.data;
        ndim = dataType.size();
        if (ndim <= 0) {
            throw new Exception("K-means: input data vectors must have at least one feature");
        }
        requireNumericFeatures(dataType);
    }

    /**
     * Checks that each of the first <code>ndim</code> features of the given
     * vector type is numeric.
     *
     * @throws Exception if any of these features is not a {@link NumericFeature}
     */
    private void requireNumericFeatures(DataVectorType dataType) throws Exception {
        for (int i = 0; i < ndim; i++) {
            if (dataType.get(i).dataClass != NumericFeature.class) {
                throw new Exception("K-means: feature #" + i + " is not numeric");
            }
        }
    }

    /**
     * Finds the center closest to the sample currently held in {@link #sample}.
     * Ties are resolved in favour of the lowest cluster index.
     *
     * @return index of the nearest cluster center
     */
    private int findCluster() {
        int minc = 0;
        double mind = distToCenter(0);
        for (int c = 1; c < nclust; c++) {
            double d = distToCenter(c);
            if (d < mind) {
                mind = d;
                minc = c;
            }
        }
        return minc;
    }

    /**
     * Computes the squared Euclidean distance between the sample currently held
     * in {@link #sample} and the center of cluster <code>c</code>.
     */
    private double distToCenter(int c) {
        double[] center = centers[c];
        double dist = 0;
        for (int i = 0; i < ndim; i++) {
            double d = sample[i] - center[i];
            dist += d * d;
        }
        return dist;
    }

    /**
     * Measures how far the centers moved between two cycles and prints the value.
     * The move is the square root of the summed squared coordinate differences
     * over all clusters.
     *
     * @return <code>true</code> if the move does not exceed {@link #delta}
     */
    private boolean areSame(double[][] c1, double[][] c2) {
        double move = 0;
        for (int i = 0; i < nclust; i++) {
            for (int j = 0; j < ndim; j++) {
                // identical coordinates are skipped and contribute nothing
                if (c1[i][j] != c2[i][j]) {
                    double d = c1[i][j] - c2[i][j];
                    move += d * d;
                }
            }
        }
        move = Math.sqrt(move);
        msg.println(" centerMove = " + move);
        return !(move > delta);
    }

    /**
     * Opens the input stream for clustering and checks that its samples match
     * the dimensionality and feature types seen during learning.
     *
     * @return the input sample type with the decision replaced by a symbolic
     *         type created for <code>nclust</code> values
     * @throws Exception if the cell has not been trained or the input type
     *                   does not match the training data
     */
    @Override
    protected SampleType onOpen() throws Exception {
        if (centers == null) {
            throw new Exception("K-means: the cell must be trained before it is opened");
        }
        input = openInputStream();
        SampleType type = input.sampleType;
        DataVectorType dataType = (DataVectorType) type.data;
        if (ndim != dataType.size()) {
            throw new Exception("K-means: expected " + ndim + " features, but input has "
                    + dataType.size());
        }
        requireNumericFeatures(dataType);
        decisionType = new SymbolicFeatureType(nclust);
        return type.setDecision(decisionType);
    }

    /**
     * Reads the next input sample and returns it with its decision set to the
     * symbolic value obtained from {@link #decisionType} for the index of the
     * nearest cluster center.
     *
     * @return the labelled sample, or <code>null</code> when the input is exhausted
     */
    @Override
    protected Sample onNext() throws Exception {
        Sample s = input.next();
        if (s == null) {
            return null;
        }
        sampleToArray(s, sample);
        return s.setDecision(new SymbolicFeature(decisionType.get(findCluster())));
    }

    /** Closes the input stream and discards the decision type created in {@link #onOpen()}. */
    @Override
    protected void onClose() throws Exception {
        input.close();
        decisionType = null;
    }
}