Repository /Debellor/debellor-1.0.jar:org.debellor.base.KMeans


Back

No file description

Source code

/*
 *  Debellor
 *
 *  Copyright (C) 2008-2009 by Marcin Wojnarski
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package org.debellor.base;

import java.io.PrintStream;
import java.util.Arrays;

import org.debellor.core.Cell;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.data.DataVector;
import org.debellor.core.data.NumericFeature;
import org.debellor.core.data.SymbolicFeature;
import org.debellor.core.data.DataVector.DataVectorType;
import org.debellor.core.data.SymbolicFeature.SymbolicFeatureType;
import org.debellor.core.util.Permute;

/**
 * Stream implementation of standard k-means clustering algorithm.
 * Can be trained on large data without risk of memory overflow: every pass
 * re-opens the input stream and processes one sample at a time, so the data set
 * is never held in memory as a whole.
 * <p>
 * Learning alternates two steps until the centers stop moving: assign every sample
 * to the nearest center (squared Euclidean distance), then move every center to the
 * mean of its assigned samples. All features of the input data vector must be numeric.
 * After learning, each output sample gets a symbolic decision naming its nearest cluster.
 * <p>
 * Parameters: <code>numClusters</code> - the number of clusters (k); must be between
 * 1 and the number of training samples.
 * 
 * @author Marcin Wojnarski
 *
 */
public class KMeans extends Cell {

	/** Cluster centers: <code>nclust</code> rows, each of length <code>ndim</code>. */
	private double[][] centers;
	
	/** Temporary buffer holding the current sample as a double array;
	 * read by findCluster() and distToCenter(). */
	private double[] sample;
	
	/** Destination of progress and debugging messages. */
	private PrintStream msg = System.err;

	/** Number of features (dimensions) of each sample. */
	private int ndim;
	/** Number of clusters. */
	private int nclust;
	
	/** Threshold for stopping criterion: learning ends when the Euclidean norm
	 * of the change of all center coordinates in a cycle does not exceed it. */
	private double delta = 0;

	/** Number of samples assigned to each cluster in the most recent learning cycle. */
	private int[] clusterSizes;

	/** Decision type attached to output samples: one symbol per cluster. */
	private SymbolicFeatureType decisionType;

	/** Currently open input stream (used both while learning and while applying). */
	private Stream input;
	
	
	/**
	 * Trains the model: initializes centers with randomly drawn samples, then repeats
	 * assignment/update cycles until the centers move by no more than <code>delta</code>.
	 */
	@Override
	protected void onLearn() throws Exception {
		init();
		
		int cycle = 0;
		while(true) {
			msg.println("K-means, starting cycle: " + ++cycle);
			
			double[][] newCenters = new double[nclust][ndim];
			clusterSizes = new int[nclust];
			
			// assign samples to clusters, store accumulated vectors,
			// to be used for cluster center calculation
			input = openInputStream();
			try {
				Sample s;
				while((s = input.next()) != null) {
					sampleToArray(s, sample);
					int c = findCluster();
					double[] sum = newCenters[c];
					for(int j = 0; j < ndim; j++)
						sum[j] += sample[j];
					clusterSizes[c]++;
				}
			}
			finally {
				input.close();
			}
			
			msg.print(" Sizes of clusters: ");
			
			// calculate new cluster centers
			for(int c = 0; c < nclust; c++) {
				msg.print(" " + clusterSizes[c]);
				int size = clusterSizes[c];
				double[] center = newCenters[c];
				if(size == 0) {
					// Empty cluster: keep its previous center. Leaving the zero-filled
					// accumulator would silently move the center to the origin.
					System.arraycopy(centers[c], 0, center, 0, ndim);
					continue;
				}
				for(int i = 0; i < ndim; i++)
					center[i] /= size;
			}
			msg.println();
			
			boolean finish = areSame(centers, newCenters);
			centers = newCenters;

			if(finish) break;
		}
	}

	/**
	 * Prepares learning. The 1st pass checks the input type and counts samples. The 2nd
	 * pass initializes the centers with <code>nclust</code> samples at random indices.
	 *
	 * @throws IllegalArgumentException if <code>numClusters</code> is not between 1 and
	 *         the number of samples
	 * @throws Exception if the input is not a non-empty vector of numeric features, or if
	 *         the re-opened stream is too short to provide all initial centers
	 */
	private void init() throws Exception {
		msg.println("K-means, 1st initialization pass (count no. of samples)");
		
		int nsamp = 0;
		input = openInputStream();
		try {
			checkType(input.sampleType);
			while(input.next() != null)
				nsamp++;
		}
		finally {
			input.close();
		}
		
		msg.println(" Number of samples: " + nsamp);
		msg.println("K-means, 2nd initialization pass (choose centers)");

		nclust = parameters.getAsInt("numClusters");
		if(nclust < 1 || nclust > nsamp)
			throw new IllegalArgumentException("K-means: numClusters must be between 1 and "
					+ "the number of samples (" + nsamp + "), but is " + nclust);
		centers = new double[nclust][ndim];
		sample = new double[ndim];
		
		// pick randomly 'nclust' indices of samples and sort them,
		// so that the chosen samples can be collected in a single pass
		int[] ind = Permute.indices(nsamp, nclust, random);
		Arrays.sort(ind);
		
		// initialize centers with the drawn samples.
		// This step is a bit risky, as we cannot be sure that after source reopening
		// the data stream will have the same size - hence the check below.
		int filled = 0;
		input = openInputStream();
		try {
			Sample s;
			for(int i = 0; filled < nclust && (s = input.next()) != null; i++)
				if(ind[filled] == i)
					sampleToArray(s, centers[filled++]);
		}
		finally {
			input.close();
		}
		if(filled < nclust)
			throw new Exception("K-means: after reopening, the input stream provided only "
					+ filled + " of " + nclust + " initial centers");
		
		// print samples, for debugging
		for(int c = 0; c < nclust; c++)
			msg.println("Center #" + c + " = sample #" + ind[c] + ":  " + preview(centers[c]));
	}

	/** Formats at most the first two coordinates of <code>v</code>, followed by ",...";
	 * safe for vectors of any length (also 1-dimensional ones). */
	private static String preview(double[] v) {
		StringBuilder sb = new StringBuilder();
		for(int i = 0; i < Math.min(2, v.length); i++)
			sb.append(v[i]).append(',');
		return sb.append("...").toString();
	}

	/** Converts sample data to array of doubles: copies the first <code>ndim</code>
	 * numeric feature values of <code>s</code> into <code>a</code>.
	 * Caution: array <code>a</code> will be modified! */
	private void sampleToArray(Sample s, double[] a) {
		DataVector data = (DataVector) s.data;
		for(int i = 0; i < ndim; i++)
			a[i] = ((NumericFeature) data.get(i)).value;
	}

	/** Checks that sample data is a non-empty vector of numeric features
	 * and stores its length in <code>ndim</code>. */
	private void checkType(SampleType type) throws Exception {
		DataVectorType dataType = (DataVectorType)type.data;
		ndim = dataType.size();
		if(ndim <= 0)
			throw new Exception("K-means: input data vector must contain at least one feature");
		checkNumeric(dataType);
	}

	/** Throws if any feature of <code>dataType</code> is not numeric. */
	private void checkNumeric(DataVectorType dataType) throws Exception {
		int size = dataType.size();
		for(int i = 0; i < size; i++)
			if(dataType.get(i).dataClass != NumericFeature.class)
				throw new Exception("K-means: feature #" + i + " is not numeric");
	}

	/** Returns the index of the center nearest to <code>sample</code>
	 * (squared Euclidean distance); ties go to the lowest index. */
	private int findCluster() {
		int minc = 0;
		double mind = distToCenter(0);
		for(int c = 1; c < nclust; c++) {
			double d = distToCenter(c);
			if(d < mind) {
				mind = d;
				minc = c;
			}
		}
		return minc;
	}

	/** Squared Euclidean distance between <code>sample</code> and center <code>c</code>. */
	private double distToCenter(int c) {
		double[] center = centers[c];
		double dist = 0;
		for(int i = 0; i < ndim; i++) {
			double d = sample[i] - center[i];
			dist += d * d;
		}
		return dist;
	}

	/**
	 * Stopping criterion. Computes the Euclidean norm of the difference between
	 * center sets <code>c1</code> and <code>c2</code>, prints it, and returns true
	 * unless it exceeds <code>delta</code>. A NaN movement also yields true.
	 */
	private boolean areSame(double[][] c1, double[][] c2) {
		double move = 0;
		for(int i = 0; i < nclust; i++) {
			for(int j = 0; j < ndim; j++) {
				// skipping equal coordinates also avoids (inf - inf) = NaN
				if(c1[i][j] != c2[i][j]) {
					double d = c1[i][j] - c2[i][j];
					move += d * d;
				}
			}
		}
		move = Math.sqrt(move);
		msg.println(" centerMove =  " + move);
		return !(move > delta);
	}

	/** Opens the input for labeling. The input must have the same numeric layout
	 * as during learning. Adds a decision with one symbol per cluster. */
	@Override
	protected SampleType onOpen() throws Exception {
		input = openInputStream();
		SampleType type = input.sampleType;
		try {
			DataVectorType dataType = (DataVectorType)type.data;
			if(ndim != dataType.size())
				throw new Exception("K-means: expected " + ndim + " features, but input has "
						+ dataType.size());
			checkNumeric(dataType);
		}
		catch(Exception e) {
			// do not leak the stream when the input type is rejected
			input.close();
			throw e;
		}
		
		decisionType = new SymbolicFeatureType(nclust);
		return type.setDecision(decisionType);
	}

	/** Returns the next input sample labeled with the symbol of its nearest
	 * cluster, or null when the input is exhausted. */
	@Override
	protected Sample onNext() throws Exception {
		Sample s = input.next();
		if(s == null) return null;
		sampleToArray(s, sample);
		return s.setDecision(new SymbolicFeature(decisionType.get(findCluster())));
	}

	/** Closes the input stream and forgets the decision type. */
	@Override
	protected void onClose() throws Exception {
		input.close();
		decisionType = null;
	}

}

Copyright © 2008-2011 by TunedIT
Design by luksite