Repository /Debellor/debellor-1.0.jar:org.debellor.base.evaluator.TrainAndTest


Back

No file description

Source code

/*
 *  Debellor
 *
 *  Copyright (C) 2008-2009 by Marcin Wojnarski
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package org.debellor.base.evaluator;

import java.util.ArrayList;

import org.debellor.core.Cell;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.cell.BatchOfSamples;
import org.debellor.core.cell.OperatorFromCell;
import org.debellor.core.exception.DebellorException;
import org.debellor.core.exception.data.NoSamplesException;

/**
 * <p>Implements evaluation of a cell (decision system) with Train+Test procedure.
 * Samples are randomly shuffled before splitting into train and test sets.
 * See {@link EvaluatorCell} for more information about usage.
 * The evaluated cell must be <i>erasable</i> (it must correctly override {@link Cell#onErase}).
 * </p>
 * 
 * <p>Parameters:</p>
 * <ul>
 * <li><i>trainPercent</i>: percentage of all samples to be used as training set. 
 * Should be between 1 and 100. Default: 70.
 * 
 * <li><i>repetitions</i>: number of independent repetitions of the train+test procedure.
 * Repetitions are independent and data are split randomly in each of them.
 * Final results are summed over all repetitions
 * (the same Score instance is used for all of them).
 * Default: 1.
 * 
 * <li><i>score</i>: name of the Score class to be used to measure quality of the evaluated decision system.
 * See docs for {@link EvaluatorCell}.
 * </ul>
 * 
 * @author Marcin Wojnarski
 *
 */
public class TrainAndTest extends EvaluatorCell {

	/** Tunable settings of the procedure; defaults apply when a parameter cannot be read. */
	private static class Knobs {
		/** Percentage of all samples to be used as training set. 
		 * Should be between 1 and 100. */
		public int trainPercent = 70;
		
		/** Number of times the whole procedure (train+test) will be repeated.
		 * New random split is generated each time. */
		public int repetitions = 1;
	}
	private Knobs knobs;
	
	/** All samples read from the input stream; reordered in place by every split. */
	private ArrayList<Sample> samples;
	/** Samples held out for testing in the current repetition. */
	private ArrayList<Sample> testSet;
	/** Samples used for training in the current repetition. */
	private BatchOfSamples trainSet;
	/** Sample type reported by the input stream. */
	private SampleType type;
	
	/** Input stream the samples are read from; opened and closed in {@link #readSamples()}. */
	protected Stream input;
	
	
	public TrainAndTest() {}
	public TrainAndTest(Cell cell) {
		super(cell);
	}

	/**
	 * Runs the whole evaluation: erases the learner if it is not empty, reads
	 * parameters and samples, initializes the score, and then repeats
	 * split / train / test / erase <i>repetitions</i> times. Buffered data
	 * are released after the last repetition.
	 */
	@Override
	protected void onLearn() throws Exception {
		if(learner.state() != Cell.State.EMPTY)
			learner.erase();
		
		readKnobs();
		readSamples();
		initScore();
		
		for(int rep = 0; rep < knobs.repetitions; rep++) {
			splitSamples();
			trainLearner();
			testLearner();
			learner.erase();		// leave the learner empty for the next repetition
		}
		
		releaseData();
	}

	/**
	 * Reads <i>trainPercent</i> and <i>repetitions</i> from the cell parameters.
	 * A parameter that cannot be read keeps its default value from {@link Knobs}.
	 */
	private void readKnobs() {
		knobs = new Knobs();
		try { 
			knobs.trainPercent = parameters.getAsInt("trainPercent");
		}
		catch(Exception ignored) {}	// deliberate: if can't read the parameter, leave default value
		try { 
			knobs.repetitions = parameters.getAsInt("repetitions");
		}
		catch(Exception ignored) {}	// deliberate: if can't read the parameter, leave default value
	}
	
	/**
	 * Buffers all samples of the input stream in memory and remembers their type.
	 * The stream is closed even if reading fails.
	 * 
	 * @throws NoSamplesException if the input stream delivered no samples
	 * @throws DebellorException if the input stream fails
	 */
	private void readSamples() throws DebellorException {
		samples = new ArrayList<Sample>();
		input = openInputStream();
		try {
			type = input.sampleType;
			Sample s;
			while((s = input.next()) != null)
				samples.add(s);
		}
		finally {
			input.close();
		}
		if(samples.isEmpty()) throw new NoSamplesException();
	}

	/**
	 * Randomly splits the buffered samples into {@link #trainSet} and {@link #testSet}.
	 * The training set gets <i>trainPercent</i> percent of the samples (rounded down),
	 * but never less than one sample and never more than all of them.
	 */
	private void splitSamples() {
		trainSet = new BatchOfSamples(type);
		
		int trainSize = (int) (samples.size() * (knobs.trainPercent / 100.0));
		// Rounding down may give 0 for small data sets; always train on at least one sample.
		if(trainSize < 1) trainSize = 1;
		if(trainSize > samples.size()) trainSize = samples.size();
		
		// Draw without replacement: each drawn sample is swapped to the end of the
		// not-yet-drawn prefix [0, left), so the prefix keeps only undrawn samples.
		int left = samples.size();
		for(int i = 0; i < trainSize; i++) {
			int pos = random.nextInt(left);
			Sample temp = samples.get(pos);
			trainSet.add(temp);
			int last = left - 1;
			if(pos < last) {
				samples.set(pos, samples.get(last));
				samples.set(last, temp);
			}
			left--;
		}
		
		// Whatever was not drawn forms the test set (copied, as subList is only a view).
		testSet = new ArrayList<Sample>( samples.subList(0, left) );
	}

	/** Trains the evaluated cell on the current training set. */
	private void trainLearner() throws DebellorException {
		learner.setSource(trainSet);
		learner.learn();
	}

	/**
	 * Applies the trained cell to every test sample with its decision removed
	 * and adds the (original decision, produced decision) pair to the score.
	 * The operator's cell is closed even if testing fails.
	 */
	private void testLearner() throws DebellorException {
		OperatorFromCell oper = new OperatorFromCell();
		oper.setType(type);
		oper.setCell(learner);
		try {
			for(Sample in : testSet)
				score.add(in.decision, oper.applyTo(in.setDecision(null)).decision);
		}
		finally {
			oper.closeCell();
		}
	}

	/** Drops references to buffered samples and settings, then delegates to the superclass. */
	@Override
	protected void releaseData() {
		samples = null;
		testSet = null;
		trainSet = null;
		knobs = null;
		super.releaseData();
	}
	
}

Copyright © 2008-2011 by TunedIT
Design by luksite