/*
* Copyright (C) 2009 by Joanna Swietlicka
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package mnist;
import java.io.BufferedInputStream;
import java.io.EOFException;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import org.debellor.core.Cell;
import org.debellor.core.DataObject;
import org.debellor.core.DataType;
import org.debellor.core.Sample;
import org.debellor.core.Sample.SampleType;
import org.debellor.core.data.DataVector;
import org.debellor.core.data.NumericFeature;
import org.debellor.core.data.SymbolicFeature;
import org.debellor.core.data.DataVector.DataVectorType;
import org.debellor.core.data.NumericFeature.NumericFeatureType;
import org.debellor.core.data.SymbolicFeature.SymbolicFeatureType;
/**
 * A Debellor {@link Cell} that streams samples from a pair of MNIST (IDX format) files:
 * an image file and a label file, both loaded as class-path resources via
 * {@link Class#getResourceAsStream(String)}.
 * <p>
 * Each produced {@link Sample} holds a {@link DataVector} with one {@link NumericFeature}
 * per pixel (the unsigned byte value, 0-255) and a {@link SymbolicFeature} label taken
 * from a {@link SymbolicFeatureType} with {@value #LABEL_COUNT} symbols.
 *
 * @author Joanna Swietlicka
 */
public class Reader extends Cell {
    /** IDX magic number of an image file (unsigned-byte data, 3 dimensions). */
    private static final int IMAGE_MAGIC = 0x00000803;
    /** IDX magic number of a label file (unsigned-byte data, 1 dimension). */
    private static final int LABEL_MAGIC = 0x00000801;
    /** Number of label symbols; every label byte must lie in [0, LABEL_COUNT). */
    private static final int LABEL_COUNT = 10;

    /**
     * Creates a reader for the given pair of files.
     *
     * @param iName resource name of the image file, resolved by {@code getClass().getResourceAsStream}
     * @param lName resource name of the label file, resolved the same way
     */
    public Reader(String iName, String lName) {
        super(false);
        labelFileName = lName;
        imageFileName = iName;
        readItems = 0;
    }

    /**
     * Opens both files, validates their headers and builds the sample type.
     *
     * @return the sample type: {@code rows*cols} numeric features plus a 10-symbol label
     * @throws FileNotFoundException if either resource cannot be found
     * @throws Exception if a header is malformed, truncated, or the item counts differ
     */
    @Override
    protected SampleType onOpen() throws FileNotFoundException, Exception {
        try {
            imageDataStream = openResource(imageFileName);
            labelDataStream = openResource(labelFileName);

            int[] dims = new int[2]; // [rows, cols]
            itemsNum = readIHeader(dims);
            // Header fields are unsigned 32-bit; reject values that do not fit a sane int layout.
            long pixels = (long) dims[0] * dims[1];
            if (itemsNum < 0 || dims[0] <= 0 || dims[1] <= 0 || pixels > Integer.MAX_VALUE) {
                throw new Exception("MNIST Reader: Wrong format of the image file.");
            }
            attrNum = (int) pixels;

            if (readLHeader() != itemsNum) {
                throw new Exception("MNIST Reader: Inconsistent file headers.");
            }
            readItems = 0;
            meta = createSampleType();
            return meta;
        } catch (Exception e) {
            // Do not leak whatever was opened before the failure.
            closeQuietly(imageDataStream);
            closeQuietly(labelDataStream);
            imageDataStream = null;
            labelDataStream = null;
            throw e;
        }
    }

    /**
     * Opens a class-path resource as a buffered stream.
     *
     * @throws FileNotFoundException if the resource does not exist
     */
    private BufferedInputStream openResource(String name) throws FileNotFoundException {
        InputStream in = getClass().getResourceAsStream(name);
        if (in == null) {
            throw new FileNotFoundException("MNIST Reader: Resource not found: " + name);
        }
        return new BufferedInputStream(in);
    }

    /**
     * Reads exactly {@code bytesNum} bytes and returns them as unsigned values (0-255).
     * Loops over {@code read}, because a single call may legitimately return fewer bytes.
     *
     * @throws EOFException if the stream is already at its end (no byte could be read)
     * @throws Exception if the stream ends after some, but not all, requested bytes
     */
    private static int[] readBytes(BufferedInputStream ds, int bytesNum) throws IOException, Exception {
        byte[] bytes = new byte[bytesNum];
        int total = 0;
        while (total < bytesNum) {
            int n = ds.read(bytes, total, bytesNum - total);
            if (n < 0) {
                break;
            }
            total += n;
        }
        if (total == 0 && bytesNum > 0) {
            throw new EOFException();
        }
        if (total != bytesNum) {
            throw new Exception("MNIST Reader: Wrong number of bytes read: " + total + ", instead of: " + bytesNum);
        }
        int[] ret = new int[bytesNum];
        for (int i = 0; i < bytesNum; i++) {
            ret[i] = bytes[i] & 0xFF;
        }
        return ret;
    }

    /** Reads a big-endian 32-bit integer (the byte order the header fields use). */
    private static int readInt32(BufferedInputStream ds) throws IOException, Exception {
        int[] b = readBytes(ds, 4);
        return (b[0] << 24) | (b[1] << 16) | (b[2] << 8) | b[3];
    }

    /**
     * Reads the image-file header: magic number, item count, row count, column count.
     *
     * @param dims output array; receives rows in {@code dims[0]} and columns in {@code dims[1]}
     * @return the item count
     */
    private int readIHeader(int[] dims) throws IOException, Exception {
        try {
            if (readInt32(imageDataStream) != IMAGE_MAGIC) {
                throw new Exception("MNIST Reader: Wrong format of the image file.");
            }
            int items = readInt32(imageDataStream);
            dims[0] = readInt32(imageDataStream);
            dims[1] = readInt32(imageDataStream);
            return items;
        } catch (EOFException eofe) {
            throw new Exception("MNIST Reader: Image file too short.", eofe);
        }
    }

    /**
     * Reads the label-file header: magic number and item count.
     *
     * @return the item count
     */
    private int readLHeader() throws IOException, Exception {
        try {
            if (readInt32(labelDataStream) != LABEL_MAGIC) {
                throw new Exception("MNIST Reader: Wrong format of the label file.");
            }
            return readInt32(labelDataStream);
        } catch (EOFException eofe) {
            throw new Exception("MNIST Reader: Label file too short.", eofe);
        }
    }

    /** Builds the sample type: {@code attrNum} numeric attributes and a 10-symbol decision. */
    private SampleType createSampleType() {
        NumericFeatureType[] metaDataArr = new NumericFeatureType[attrNum];
        for (int i = 0; i < attrNum; i++) {
            metaDataArr[i] = new NumericFeatureType();
        }
        DataVectorType metaData = new DataVectorType((DataType[]) metaDataArr);
        SymbolicFeatureType metaLabel = new SymbolicFeatureType(LABEL_COUNT);
        return new SampleType(metaData, metaLabel);
    }

    /**
     * Reads the next image/label pair.
     *
     * @return the next sample, or {@code null} once both files are exhausted
     * @throws Exception if the files end at different points, or hold more or fewer items
     *         than the headers declared
     */
    @Override
    protected Sample onNext() throws Exception {
        DataVector dv = readImageFile();
        SymbolicFeature label = readLabelFile();
        if (dv == null && label == null) {
            if (readItems < itemsNum) {
                throw new Exception("MNIST Reader: Files too short.");
            }
            return null;
        }
        if (dv == null || label == null) {
            throw new Exception("MNIST Reader: Different length of input files.");
        }
        readItems++;
        if (readItems > itemsNum) {
            throw new Exception("MNIST Reader: Files too long.");
        }
        return createSample(dv, label);
    }

    /** Reads one image; returns {@code null} at end of file. */
    private DataVector readImageFile() throws IOException, Exception {
        DataObject[] data = new DataObject[attrNum];
        try {
            int[] bytes = readBytes(imageDataStream, attrNum);
            for (int i = 0; i < attrNum; i++) {
                data[i] = new NumericFeature(bytes[i]);
            }
        } catch (EOFException eofe) {
            return null;
        }
        return new DataVector(data);
    }

    /** Reads one label; returns {@code null} at end of file. */
    private SymbolicFeature readLabelFile() throws IOException, Exception {
        int[] val;
        try {
            val = readBytes(labelDataStream, 1);
        } catch (EOFException eofe) {
            return null;
        }
        if (val[0] >= LABEL_COUNT) {
            throw new Exception("MNIST Reader: Wrong label value: " + val[0]);
        }
        return new SymbolicFeature(val[0], meta.decision);
    }

    private Sample createSample(DataVector dv, SymbolicFeature label) {
        return new Sample(dv, label);
    }

    /**
     * Closes both streams and resets the reader state.
     * The label stream is closed, and the state is reset, even if closing the image stream fails.
     */
    @Override
    protected void onClose() throws Exception {
        try {
            if (imageDataStream != null) {
                imageDataStream.close();
            }
        } finally {
            try {
                if (labelDataStream != null) {
                    labelDataStream.close();
                }
            } finally {
                imageDataStream = null;
                labelDataStream = null;
                readItems = 0;
                itemsNum = 0;
                attrNum = 0;
                meta = null;
            }
        }
    }

    /** Closes a stream on an error path; a secondary close failure must not mask the original error. */
    private static void closeQuietly(InputStream in) {
        if (in == null) {
            return;
        }
        try {
            in.close();
        } catch (IOException ignored) {
            // The exception that triggered the cleanup is the one worth reporting.
        }
    }

    private final String imageFileName, labelFileName;
    private BufferedInputStream imageDataStream, labelDataStream;
    /** Item count declared by the headers, pixels per image, and items read so far. */
    private int itemsNum, attrNum, readItems;
    private SampleType meta;
}