/*
* Copyright (C) 2009 by TunedIT
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package rsctc2010;
import java.util.HashMap;
import java.util.Map;
import org.debellor.base.evaluator.score.Score;
import org.debellor.core.DataObject;
/**
 * Calculates balanced accuracy of a decision system (typically a classifier),
 * i.e. the fraction of samples for which the predicted decision
 * was exactly the same as the target one,
 * calculated separately for each target decision observed so far
 * and then averaged over those target decisions, without weighting by their frequencies
 * (so rarely occurring decisions have the same impact on the overall result
 * as the most frequent one).
 * <p>
 * Samples whose target is {@code null} are ignored. A prediction counts as correct
 * exactly when {@code target.equals(prediction)} is true. When no samples have been
 * added since construction or the last {@link #reset()}, {@link #result()} returns 1.0.
 *
 * @author Marcin Wojnarski
 *
 */
public class BalancedAccuracy extends Score {
	/** Tally of correct and incorrect predictions for a single target decision. */
	private static class Counter {
		/** Number of correct predictions for a given target decision */
		public long good = 0;
		/** Number of incorrect predictions for a given target decision */
		public long bad = 0;

		/**
		 * Returns the fraction of correct predictions for this target decision.
		 * Only called for counters created in add(), which always increments
		 * one of the fields, so the denominator is at least 1.
		 */
		double accuracy() {
			// long counters keep good + bad from overflowing the way int arithmetic could
			return good / (double) (good + bad);
		}
	}

	/**
	 * Per-target-decision statistics. Keys are target decisions; this relies on
	 * DataObject's equals/hashCode being value-based (NOTE(review): confirm in Debellor).
	 */
	private Map<DataObject, Counter> stats;

	/** Creates an empty score. */
	public BalancedAccuracy() {
		// NOTE(review): reset() is overridable and called from the constructor; kept because
		// the Score base-class contract (which may itself rely on reset()) is not visible here.
		reset();
	}

	/** Discards all statistics collected so far. */
	@Override
	public void reset() {
		stats = new HashMap<DataObject, Counter>();
	}

	/**
	 * Records one (target, prediction) pair.
	 *
	 * @param target the true decision; if {@code null} the sample is ignored
	 * @param prediction the decision produced by the system; counted as correct
	 *        exactly when {@code target.equals(prediction)}
	 */
	@Override
	public void add(DataObject target, DataObject prediction) {
		if(target == null) return;
		Counter counter = stats.get(target);
		if(counter == null) {
			counter = new Counter();
			stats.put(target, counter);
		}
		if(target.equals(prediction))
			counter.good++;
		else
			counter.bad++;
	}

	/** Returns a human-readable summary with the balanced accuracy as a percentage. */
	@Override
	public String report() {
		return "Balanced accuracy: " + result() * 100 + "%";
	}

	/**
	 * Returns the unweighted mean of per-target-decision accuracies, in [0, 1].
	 *
	 * @return balanced accuracy, or 1.0 if no samples have been added
	 */
	@Override
	public double result() {
		if(stats.isEmpty()) return 1.0;
		double sum = 0.0;
		for(Counter c : stats.values())
			sum += c.accuracy();
		return sum / stats.size();
	}
}