Repository /Debellor/debellor-1.0.jar:org.debellor.base.evaluator.CrossValidation


Back

No file description

Source code

/*
 *  Debellor
 *
 *  Copyright (C) 2008 by Marcin Wojnarski
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package org.debellor.base.evaluator;

import java.util.ArrayList;

import org.debellor.core.Cell;
import org.debellor.core.Parameters;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.cell.BatchOfSamples;
import org.debellor.core.cell.OperatorFromCell;
import org.debellor.core.exception.DebellorException;
import org.debellor.core.exception.data.NoSamplesException;
import org.debellor.core.util.Permute;

/**
 * <p>Implements evaluation of a cell (decision system) through the procedure of cross-validation (CV).
 * See {@link EvaluatorCell} for more information about usage.
 * The evaluated cell must be <i>erasable</i> (it must correctly override {@link Cell#onErase}),
 * because it is erased after every fold.
 * </p>
 * 
 * <p>All input samples are first loaded into memory. In every repetition the samples
 * (except the first <i>fixTrain</i> ones) are randomly permuted and the permuted sequence is cut
 * into <i>folds</i> consecutive parts of (nearly) equal size; each part serves once as the test set.</p>
 * 
 * <p>Parameters:</p>
 * <ul>
 * <li><i>folds</i>: number of folds of CV. Must be at least 1. Default: 10.</li>
 * 
 * <li><i>repetitions</i>: number of times the entire procedure (cross-validation) will be repeated.
 * Repetitions are independent and data are split randomly in each of them.
 * Final results are summed over all repetitions
 * (the same Score instance is used for all of them).
 * Must be at least 1. Default: 1.</li>
 * 
 * <li><i>fixTrain</i>: number of first samples of data that will always go to the training set,
 * in every fold of CV. Only the remaining samples undergo splitting 
 * and can be used as a test set. Must be between 0 and the number of input samples. Default: 0.</li>
 * 
 * <li><i>reversed</i>: if <code>true</code>, the smaller part of each data split 
 * becomes a training set (normally it is the larger part that is used for training).
 * Default: <code>false</code>.</li>
 * 
 * <li><i>score</i>: name of the Score class to be used to measure quality of the evaluated decision system.
 * See docs for {@link EvaluatorCell}.</li>
 * </ul>
 * 
 * <p>A parameter that cannot be read is silently replaced by its default value.
 * A parameter that was read but is out of range causes an {@link IllegalArgumentException}
 * during learning.</p>
 * 
 * @author Marcin Wojnarski
 *
 */
public class CrossValidation extends EvaluatorCell {

	/** Values of the parameters of this cell, as used by a single learning run. */
	private static class Knobs {
		/** Number of folds of cross-validation. */
		public int folds = 10;
		
		/** Number of times the entire procedure (cross-validation) will be repeated. */
		public int repetitions = 1;

		/** Number of first samples of data that will always go to the training set, in every fold of CV.
		 * Only the remaining samples undergo splitting and can be used as a test set. */
		public int fixTrain = 0;
		
		/** If true, the smaller part of each data split becomes training set */
		public boolean reversed = false;
	}
	private Knobs knobs = new Knobs();
	
	/** Type of input samples, taken from the input stream. */
	private SampleType sampleType;
	
	/** All input samples, in their original order. Held only for the duration of learning. */
	private ArrayList<Sample> samples;
	
	/** Stream that the input samples are read from; it is closed again as soon as all samples are loaded. */
	protected Stream input;

	
	public CrossValidation() {}
	public CrossValidation(Cell cell) {
		super(cell);
	}

	/**
	 * Runs the whole cross-validation: loads the samples and then, for every repetition and every fold,
	 * trains the evaluated cell on the training part, scores it on the test part and erases it.
	 * 
	 * @throws IllegalArgumentException if <i>folds</i>, <i>repetitions</i> or <i>fixTrain</i> is out of range
	 * @throws NoSamplesException if the input contains no samples
	 */
	@Override
	protected void onLearn() throws Exception {
		if(learner.state() != Cell.State.EMPTY)
			learner.erase();
		
		readKnobs();
		try {
			readSamples();
			checkKnobs(samples.size());
			initScore();

			int size = samples.size();
			for(int rep = 0; rep < knobs.repetitions; rep++) {
				// a fresh random order of the splittable samples for every repetition
				int[] permutation = Permute.indices(size - knobs.fixTrain, random);
				
				for(int fold = 0; fold < knobs.folds; fold++) {
					ArrayList<Sample> train = new ArrayList<Sample>(size);
					ArrayList<Sample> test = new ArrayList<Sample>(size);

					splitSamplesInto(train, test, permutation, fold);
					trainAndTestLearner(train, test);
					learner.erase();
				}
			}
			
			releaseData();
		}
		finally {
			// do not keep the (possibly large) data set referenced when learning has failed;
			// after a successful run these fields are already null
			samples = null;
			sampleType = null;
		}
	}

	/**
	 * Reads values of the knobs from the parameters of this cell.
	 * Every run starts from the default values, so a parameter which is missing now
	 * does not inherit a value from a previous run, and a Knobs object is available
	 * even after {@link #releaseData()} has dropped the previous one.
	 */
	private void readKnobs() {
		knobs = new Knobs();
		Parameters par = parameters;
		try { 
			knobs.folds = par.getAsInt("folds");
		}
		catch(Exception ignored) {}	// if can't read the parameter, leave default value
		try { 
			knobs.repetitions = par.getAsInt("repetitions");
		}
		catch(Exception ignored) {}	// if can't read the parameter, leave default value
		try { 
			knobs.fixTrain  = par.getAsInt("fixTrain");
		}
		catch(Exception ignored) {}	// if can't read the parameter, leave default value
		try { 
			knobs.reversed = par.getAsBool("reversed");
		}
		catch(Exception ignored) {}	// if can't read the parameter, leave default value
	}
	
	/**
	 * Verifies that the knobs can be used with a data set of the given size.
	 * Reports a bad value up front, with a readable message, instead of letting it surface later
	 * as an index error (fixTrain) or as a silently empty evaluation (folds, repetitions).
	 * 
	 * @param size number of input samples
	 * @throws IllegalArgumentException if any knob is out of range
	 */
	private void checkKnobs(int size) {
		if(knobs.folds < 1)
			throw new IllegalArgumentException(
					"Parameter 'folds' must be at least 1, but is " + knobs.folds);
		if(knobs.repetitions < 1)
			throw new IllegalArgumentException(
					"Parameter 'repetitions' must be at least 1, but is " + knobs.repetitions);
		if(knobs.fixTrain < 0 || knobs.fixTrain > size)
			throw new IllegalArgumentException(
					"Parameter 'fixTrain' must be between 0 and the number of samples (" + size 
					+ "), but is " + knobs.fixTrain);
	}
	
	/**
	 * Loads all input samples into memory and remembers their type.
	 * The input stream is closed also when reading fails.
	 * 
	 * @throws NoSamplesException if the input stream delivered no samples
	 */
	private void readSamples() throws DebellorException {
		samples = new ArrayList<Sample>();
		input = openInputStream();
		try {
			sampleType = input.sampleType;
			Sample s;
			while((s = input.next()) != null)
				samples.add(s);
		}
		finally {
			input.close();
		}
		if(samples.isEmpty()) throw new NoSamplesException();
	}

	/**
	 * Distributes the samples between training and test set for the given fold.
	 * 
	 * @param train receives the training samples: the first <i>fixTrain</i> samples 
	 * 				plus the training part of the split
	 * @param test receives the test samples
	 * @param permutation random order of the splittable samples, as indices relative to <i>fixTrain</i>
	 * @param fold zero-based number of the current fold
	 */
	private void splitSamplesInto(ArrayList<Sample> train, ArrayList<Sample> test, int[] permutation, int fold) {
		// firstly, insert samples that always go to the training set
		train.addAll(samples.subList(0, knobs.fixTrain));
		
		// then, split remaining samples into train/test;
		// products are computed as long, because size*folds exceeds the int range
		// when the number of folds is large (e.g. leave-one-out on ~50,000 samples)
		int size = samples.size() - knobs.fixTrain;
		int p = (int) ((long) size * fold / knobs.folds);		// 1st index of test set
		int q = (int) ((long) size * (fold + 1) / knobs.folds);	// 1st index after test set
		for(int i = 0; i < size; i++) {
			Sample s = samples.get(permutation[i] + knobs.fixTrain);
			// in reversed mode the roles are swapped: the fold [p,q) is used for training
			if(((i >= p) && (i < q)) ^ knobs.reversed)
				test.add(s);
			else
				train.add(s);
		}			
	}

	/**
	 * Trains the evaluated cell on <code>train</code>, then applies it to every sample of <code>test</code>
	 * and adds the pair (true decision, decision produced by the cell) to the score.
	 * The operator wrapping the cell is closed also when testing fails.
	 */
	private void trainAndTestLearner(ArrayList<Sample> train, ArrayList<Sample> test) throws DebellorException {
		BatchOfSamples trainCell = new BatchOfSamples();
		trainCell.setType(sampleType);
		trainCell.add(train);
		
		learner.setSource(trainCell);
		learner.learn();
		
		OperatorFromCell oper = new OperatorFromCell();
		oper.setType(sampleType);
		oper.setCell(learner);
		try {
			// the true decision is removed from the sample before it is passed to the evaluated cell
			for(Sample in : test)
				score.add(in.decision, oper.applyTo(in.setDecision(null)).decision);
		}
		finally {
			oper.closeCell();
		}
	}

	/**
	 * Drops references to the data and to the evaluated cell, then lets the superclass release its part.
	 * NOTE(review): <code>learner</code> is cleared here, so presumably it has to be set again
	 * before this cell can run another evaluation - confirm against EvaluatorCell.
	 */
	protected void releaseData() {
		samples = null;
		sampleType = null;
		knobs = null;
		learner = null;
		super.releaseData();
	}
	
}

Copyright © 2008-2011 by TunedIT
Design by luksite