/*
* Debellor
*
* Copyright (C) 2008-2009 by Marcin Wojnarski
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package org.debellor.base.evaluator;
import java.util.ArrayList;
import org.debellor.core.Cell;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.cell.BatchOfSamples;
import org.debellor.core.cell.OperatorFromCell;
import org.debellor.core.exception.DebellorException;
import org.debellor.core.exception.data.NoSamplesException;
/**
 * <p>Implements evaluation of a cell (decision system) with Train+Test procedure.
 * Samples are randomly shuffled before splitting into train and test sets.
 * See {@link EvaluatorCell} for more information about usage.
 * The evaluated cell must be <i>erasable</i> (it must correctly override {@link Cell#onErase}).
 * </p>
 *
 * <p>Parameters:</p>
 * <ul>
 * <li><i>trainPercent</i>: percentage of all samples to be used as training set.
 *     Should be between 1 and 100. Default: 70.
 *     The training set always contains at least one sample and at most all samples.
 *
 * <li><i>repetitions</i>: number of independent repetitions of the train+test procedure.
 *     Repetitions are independent and data are split randomly in each of them.
 *     Final results are summed over all repetitions
 *     (the same Score instance is used for all of them).
 *     Default: 1.
 *
 * <li><i>score</i>: name of the Score class to be used to measure quality of the evaluated decision system.
 *     See docs for {@link EvaluatorCell}.
 * </ul>
 *
 * @author Marcin Wojnarski
 *
 */
public class TrainAndTest extends EvaluatorCell {

    /** Values of the numeric parameters, read anew at the beginning of every learning run. */
    private static class Knobs {
        /** Percentage of all samples to be used as training set.
         * Should be between 1 and 100. */
        public int trainPercent = 70;
        /** Number of times the whole procedure (train+test) will be repeated.
         * New random split is generated each time. */
        public int repetitions = 1;
    }

    /** Parameter values used by the current learning run. */
    private Knobs knobs;
    /** All samples read from the input stream; reordered in place by every split. */
    private ArrayList<Sample> samples;
    /** Samples left for testing in the current repetition. */
    private ArrayList<Sample> testSet;
    /** Samples drawn for training in the current repetition. */
    private BatchOfSamples trainSet;
    /** Sample type reported by the input stream. */
    private SampleType type;
    /** Input stream the samples are read from. */
    protected Stream input;

    public TrainAndTest() {}

    public TrainAndTest(Cell cell) {
        super(cell);
    }

    /**
     * Runs the whole evaluation: erases the learner if it is not empty, reads parameters
     * and samples, initializes the score and then repeats split / train / test / erase
     * <i>repetitions</i> times. Buffered data are released at the end.
     */
    @Override
    protected void onLearn() throws Exception {
        if (learner.state() != Cell.State.EMPTY) {
            learner.erase();
        }

        readKnobs();
        readSamples();
        initScore();

        for (int rep = 0; rep < knobs.repetitions; rep++) {
            splitSamples();
            trainLearner();
            testLearner();
            learner.erase();        // the learner must be empty before it can be trained again
        }
        releaseData();
    }

    /**
     * Reads <i>trainPercent</i> and <i>repetitions</i> from the cell parameters.
     * A parameter that cannot be read keeps its default value.
     */
    private void readKnobs() {
        knobs = new Knobs();
        try {
            knobs.trainPercent = parameters.getAsInt("trainPercent");
        }
        catch (Exception ignored) {
            // deliberate best effort: if can't read the parameter, leave default value
        }
        try {
            knobs.repetitions = parameters.getAsInt("repetitions");
        }
        catch (Exception ignored) {
            // deliberate best effort: if can't read the parameter, leave default value
        }
    }

    /**
     * Buffers all input samples in memory and remembers their type.
     * The input stream is closed even if reading fails.
     *
     * @throws NoSamplesException if the input stream delivered no samples
     * @throws DebellorException if opening, reading or closing the input stream fails
     */
    private void readSamples() throws DebellorException {
        samples = new ArrayList<Sample>();
        input = openInputStream();
        try {
            type = input.sampleType;
            Sample s;
            while ((s = input.next()) != null) {
                samples.add(s);
            }
        }
        finally {
            input.close();
        }
        if (samples.isEmpty()) {
            throw new NoSamplesException();
        }
    }

    /**
     * Randomly splits the buffered samples into {@link #trainSet} and {@link #testSet}.
     * Training samples are drawn without repetition: each drawn sample is swapped
     * to the end of the not-yet-drawn prefix of {@link #samples}, so after the loop
     * the first <code>left</code> elements are exactly the samples that were not drawn.
     */
    private void splitSamples() {
        trainSet = new BatchOfSamples(type);

        int total = samples.size();
        int trainSize = (int) (total * (knobs.trainPercent / 100.0));
        // Clamp to [1, total]. The lower bound must be checked against 1, not 0:
        // truncation yields 0 for small data sets or small percentages and the
        // learner would then be trained on an empty set.
        if (trainSize < 1) {
            trainSize = 1;
        }
        if (trainSize > total) {
            trainSize = total;
        }

        int left = total;       // number of samples not drawn yet; they occupy indices [0, left)
        for (int i = 0; i < trainSize; i++) {
            int pos = random.nextInt(left);
            Sample drawn = samples.get(pos);
            trainSet.add(drawn);

            // move the drawn sample out of the [0, left-1) range
            int last = left - 1;
            if (pos < last) {
                samples.set(pos, samples.get(last));
                samples.set(last, drawn);
            }
            left--;
        }
        testSet = new ArrayList<Sample>(samples.subList(0, left));
    }

    /** Trains the evaluated cell on the current training set. */
    private void trainLearner() throws DebellorException {
        learner.setSource(trainSet);
        learner.learn();
    }

    /**
     * Applies the trained cell to every test sample, passed with its decision set to null,
     * and adds the pair (original decision, produced decision) to the score.
     * The operator's cell is closed even if applying the cell or scoring fails.
     */
    private void testLearner() throws DebellorException {
        OperatorFromCell oper = new OperatorFromCell();
        oper.setType(type);
        oper.setCell(learner);
        try {
            for (Sample in : testSet) {
                score.add(in.decision, oper.applyTo(in.setDecision(null)).decision);
            }
        }
        finally {
            oper.closeCell();
        }
    }

    /** Drops references to buffered samples and parameters, then delegates to the superclass. */
    @Override
    protected void releaseData() {
        samples = null;
        testSet = null;
        trainSet = null;
        knobs = null;
        super.releaseData();
    }
}