/*
* Debellor
*
* Copyright (C) 2008 by Marcin Wojnarski
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package org.debellor.base.evaluator;
import java.util.ArrayList;
import org.debellor.core.Cell;
import org.debellor.core.Parameters;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.cell.BatchOfSamples;
import org.debellor.core.cell.OperatorFromCell;
import org.debellor.core.exception.DebellorException;
import org.debellor.core.exception.data.NoSamplesException;
import org.debellor.core.util.Permute;
/**
* <p>Implements evaluation of a cell (decision system) through the procedure of cross-validation (CV).
* See {@link EvaluatorCell} for more information about usage.
* The evaluated cell must be <i>erasable</i> (it must correctly override {@link Cell#onErase}).
* </p>
*
* <p>Parameters:</p>
* <ul>
* <li><i>folds</i>: number of folds of CV. Default: 10.</li>
*
* <li><i>repetitions</i>: number of times the entire procedure (cross-validation) will be repeated.
* Repetitions are independent and data are split randomly in each of them.
* Final results are summed over all repetitions
* (the same Score instance is used for all of them).
* Default: 1.</li>
*
* <li><i>fixTrain</i>: number of first samples of data that will always go to the training set,
* in every fold of CV. Only the remaining samples undergo splitting
* and can be used as a test set. Default: 0.</li>
*
* <li><i>reversed</i>: if <code>true</code>, the smaller part of each data split
* becomes a training set (normally it is the larger part that is used for training).
* Default: <code>false</code>.</li>
*
* <li><i>score</i>: name of the Score class to be used to measure quality of the evaluated decision system.
* See docs for {@link EvaluatorCell}.</li>
* </ul>
*
* @author Marcin Wojnarski
*
*/
public class CrossValidation extends EvaluatorCell {

    /** Settings of the procedure; filled from the cell's parameters by {@link #readKnobs()}. */
    private static class Knobs {
        /** Number of folds of cross-validation. */
        public int folds = 10;
        /** Number of times the entire procedure (cross-validation) will be repeated. */
        public int repetitions = 1;
        /** Number of first samples of data that will always go to the training set, in every fold of CV.
         * Only the remaining samples undergo splitting and can be used as a test set. */
        public int fixTrain = 0;
        /** If true, the smaller part of each data split becomes training set */
        public boolean reversed = false;
    }

    /** Current settings; replaced by a fresh instance at the start of every {@link #onLearn()}. */
    private Knobs knobs = new Knobs();
    /** Type of input samples, taken from the input stream in {@link #readSamples()}. */
    private SampleType sampleType;
    /** All input samples, buffered in memory for the duration of the evaluation. */
    private ArrayList<Sample> samples;
    /** Input stream that the samples are read from. */
    protected Stream input;

    public CrossValidation() {}

    public CrossValidation(Cell cell) {
        super(cell);
    }

    /**
     * Runs the whole evaluation: erases the learner if it is not empty, reads parameters
     * and input samples, initializes the score and then, for every repetition, draws
     * a new random permutation of the non-fixed samples and trains, tests and erases
     * the learner once per fold. Buffered data are released at the end.
     *
     * @throws Exception if reading of the samples, training or testing fails
     */
    @Override
    protected void onLearn() throws Exception {
        if (learner.state() != Cell.State.EMPTY)
            learner.erase();

        readKnobs();
        readSamples();
        initScore();

        int size = samples.size();
        for (int rep = 0; rep < knobs.repetitions; rep++) {
            // only samples past the fixed training prefix are shuffled
            int[] permutation = Permute.indices(size - knobs.fixTrain, random);

            for (int fold = 0; fold < knobs.folds; fold++) {
                ArrayList<Sample> train = new ArrayList<Sample>(size);
                ArrayList<Sample> test = new ArrayList<Sample>(size);
                splitSamplesInto(train, test, permutation, fold);
                trainAndTestLearner(train, test);
                learner.erase();
            }
        }
        releaseData();
    }

    /**
     * Reads the settings from the cell's parameters. Every setting that cannot be read
     * keeps its default value.
     *
     * <p>A fresh {@link Knobs} instance is always built, so that values from a previous
     * evaluation never leak into the current one and so that this method keeps working
     * after {@link #releaseData()} has cleared the <code>knobs</code> field.</p>
     */
    private void readKnobs() {
        Knobs fresh = new Knobs();
        Parameters par = parameters;

        try {
            fresh.folds = par.getAsInt("folds");
        }
        catch (Exception ignored) {}    // if can't read the parameter, leave default value

        try {
            fresh.repetitions = par.getAsInt("repetitions");
        }
        catch (Exception ignored) {}    // if can't read the parameter, leave default value

        try {
            fresh.fixTrain = par.getAsInt("fixTrain");
        }
        catch (Exception ignored) {}    // if can't read the parameter, leave default value

        try {
            fresh.reversed = par.getAsBool("reversed");
        }
        catch (Exception ignored) {}    // if can't read the parameter, leave default value

        knobs = fresh;
    }

    /**
     * Opens the input stream, remembers its sample type and buffers all samples
     * (until the stream returns <code>null</code>) in {@link #samples}.
     * The stream is closed even if reading fails.
     *
     * @throws NoSamplesException if the input stream delivered no samples
     * @throws DebellorException if the stream cannot be opened, read or closed
     */
    private void readSamples() throws DebellorException {
        samples = new ArrayList<Sample>();
        input = openInputStream();
        try {
            sampleType = input.sampleType;
            Sample s;
            while ((s = input.next()) != null)
                samples.add(s);
        }
        finally {
            input.close();
        }
        if (samples.isEmpty()) throw new NoSamplesException();
    }

    /**
     * Distributes the buffered samples between training and test set for a given fold.
     * The first <code>fixTrain</code> samples always go to <code>train</code>. The remaining
     * ones are visited in the order given by <code>permutation</code>; those at positions
     * <code>[p,q)</code> of the fold go to <code>test</code> and the rest to <code>train</code>,
     * or the other way round when <code>reversed</code> is set.
     *
     * @param train        list that receives training samples
     * @param test         list that receives test samples
     * @param permutation  permutation of indices <code>0 .. samples.size()-fixTrain-1</code>
     * @param fold         zero-based number of the current fold
     */
    private void splitSamplesInto(ArrayList<Sample> train, ArrayList<Sample> test, int[] permutation, int fold) {
        // firstly, insert samples that always go to the training set
        train.addAll(samples.subList(0, knobs.fixTrain));

        // then, split remaining samples into train/test
        int size = samples.size() - knobs.fixTrain;
        // products are computed in long: size*fold may not fit in int for large data sets
        int p = (int) ((long) size * fold / knobs.folds);           // 1st index of test set
        int q = (int) ((long) size * (fold + 1) / knobs.folds);     // 1st index after test set

        for (int i = 0; i < size; i++) {
            Sample s = samples.get(permutation[i] + knobs.fixTrain);
            boolean insideFold = (i >= p) && (i < q);
            if (insideFold ^ knobs.reversed)
                test.add(s);
            else
                train.add(s);
        }
    }

    /**
     * Trains the learner on <code>train</code> and then applies it to every sample of
     * <code>test</code>, passing to the score the original decision of the sample together
     * with the decision of the sample returned by the learner.
     * The operator wrapping the learner is closed even if testing fails.
     *
     * @param train  samples to train the learner on
     * @param test   samples to test the trained learner on
     * @throws DebellorException if training or application of the learner fails
     */
    private void trainAndTestLearner(ArrayList<Sample> train, ArrayList<Sample> test) throws DebellorException {
        BatchOfSamples trainCell = new BatchOfSamples();
        trainCell.setType(sampleType);
        trainCell.add(train);

        learner.setSource(trainCell);
        learner.learn();

        OperatorFromCell oper = new OperatorFromCell();
        oper.setType(sampleType);
        oper.setCell(learner);
        try {
            for (Sample in : test) {
                // the learner is given the sample with its decision set to null
                Sample out = oper.applyTo(in.setDecision(null));
                score.add(in.decision, out.decision);
            }
        }
        finally {
            oper.closeCell();
        }
    }

    /**
     * Drops references to buffered samples, sample type, settings and the learner,
     * then lets the superclass release its own data.
     */
    @Override
    protected void releaseData() {
        samples = null;
        sampleType = null;
        knobs = null;
        learner = null;
        super.releaseData();
    }
}