Repository /Rseslib/rseslib-3.0.2.jar:rseslib.processing.classification.neural.NeuronNetworkEngine


Back

No file description

Source code

/*
 * $RCSfile: NeuronNetworkEngine.java,v $
 * $Revision: 1.10 $
 * $Date: 2007/08/18 09:49:56 $
 * $Author: wojna $
 * 
 * Copyright (C) 2002 - 2007 Logic Group, Institute of Mathematics, Warsaw University
 * 
 *  This file is part of Rseslib.
 *
 *  Rseslib is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  Rseslib is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */


package rseslib.processing.classification.neural;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.Set;

import rseslib.structure.data.DoubleData;
import rseslib.structure.data.DoubleDataWithDecision;
import rseslib.structure.table.DoubleDataTable;
import rseslib.system.Report;

/**
 * @author Jakub Sakowicz
 *
 * Silnik sieci neuronowej
 *
 */
public class NeuronNetworkEngine {
	
	/**
	 * Konstruktor
	 * @param fullTable    - cala tabelka z danymi
	 * @param trainData    - czesc treningowa
	 * @param validateData - czesc walidacyjna
	 */
	public NeuronNetworkEngine(DoubleDataTable fullTable, ArrayList trainData, ArrayList validateData) {
	
		this.fullTable = fullTable;
		this.trainData = (ArrayList)trainData.clone(); // kopia, gdyz pozniej randomizowana jest kolejnosc
		this.validateData = validateData;
		this.setup();
	}
	
	/**
	 * Konstruktor
	 * @param fullTable    - cala tabelka z danymi
	 * @param trainData    - czesc treningowa
	 * @param validateData - czesc walidacyjna
	 * @param hidden_no     - liczba ukrytych warstw
	 * @param structure     - lista licznosci kolejnych warstw
	 */
	public NeuronNetworkEngine(DoubleDataTable fullTable, ArrayList trainData, ArrayList validateData, int hidden_no, int[] structure) {
	
		this.fullTable = fullTable;
		this.trainData = (ArrayList)trainData.clone(); // kopia, gdyz pozniej randomizowana jest kolejnosc
		this.validateData = validateData;
		this.setup(hidden_no, structure);
		//raz to jestem (przychodzac z learn) juz z tablica uwzgledniajaca input'a i zakonczenie ale poprawnym hidden_no
		//czy jestem tu a innym razem?
	}	
	
	/**
	 * Warstwa wejsiowa (bez perceptronow, tylko tlumaczenie wejscia)
	 */
	public DoubleDataProvider input;
	/**
	 * Ilosci perceptronow w warstwach
	 */
	public int[] noOfPerceptronsInLayer;
	/**
	 * Warstwy sieci
	 */
	public List<Layer> layers;
	
	/**
	 * Tabela z danymi 
	 */	
	private DoubleDataTable fullTable;
	/**
	 * Czesc treningowa
	 */
	private ArrayList trainData;
	/**
	 * Czesc walidacyjna
	 */
	private ArrayList validateData;
	/**
	 * Iterator po danych trenignowych
	 */
	private Iterator dataIterator;
	
	/**
	 * Skutecznosc tego silnika
	 */
	private double myRatio = 0;
	/**
	 * Parametr ALFA tego silnika
	 */
	private double ALFA = Global.INITIAL_ALFA;
	
	/**
	 * Lista dostepnych wartosci parametru decyzyknego
	 */
	public List availableResults;
	/**
	 * Obiekt przetwarzajacy wyjscia ostatniej warstwy na wartosci parametru decyzyjnego
	 */
	private ResultCombiner resultCombiner;
	
	/**
	 * Pojedyncza faza uczenia sieci. Wykonanie Global.ITER_COUNT algorytmu 
	 * back-prop-update.
	 */
	public void learn() {		
		if (myRatio > Global.DEST_TARGET_RATIO) return; // po co wiecej sie uczyc...
		if (myRatio > 50) ALFA *= Global.MULT_ALFA;
		if (myRatio > 99) ALFA *= Global.MULT_ALFA;
		shuffle();
		Global.CURRENT_ALFA = Math.max(ALFA, Global.MIN_ALFA);
		for (int i = 0; i < Global.ITER_COUNT; i++) {
			backPropUpdate();						
		}
		
	}

	/**
	 * Przelicza i zwraca wynik dla wejscia
	 * @param dd - rekord z danymi
	 * @return przewidywana wartosc decyzyjna
	 */
	public double classify(DoubleData dd) {
		count(dd);
		return resultCombiner.getResult();
	}
	
	/**
	 * Zwraca celnosc sieci na danych walidacyjnych
	 * @return liczba zmiennoprzecinkowa z zakresu 0..100
	 */
	public double targetRatio() {
		int hit = 0;
		for (Iterator i = validateData.iterator(); i.hasNext(); ) {
			DoubleDataWithDecision dd = (DoubleDataWithDecision)i.next();
			if (this.classify(dd) == dd.getDecision())
				hit++;
		}
		myRatio = (double)hit/(double)validateData.size() * 100; 
		return myRatio;
	}
	
	/**
	 * Inicjalizacja sieci
	 *
	 */
	private void setup() {			
		shuffle(); 								// pomieszanie danych treningowych
		noOfPerceptronsInLayer = new int[Global.NO_OF_LAYERS];
		createInputProvider(); 					// utworzenie obiektu odpowiadajacego za wejscie
		createResultCombiner();					// utworzenie obiektu odpowiadajacego za wyjscie
		// ilosci perceptronow w postepie geometrycznym
		int av = availableResults.size();
		int in = input.noOfInputs();	
		//System.out.print("AV: " + av + ", in: " + in + " \n");
		double mult = Math.pow((double)av / (double)in, 1/(double)(Global.NO_OF_LAYERS));
		double current = in * mult;
		for (int i=0; i < Global.NO_OF_LAYERS - 1; i++) {
			noOfPerceptronsInLayer[i] = (int)current;
			current *= mult;
			Report.debugnl("Warstwa " + i + " perceptronow: " + noOfPerceptronsInLayer[i]);
		}
		noOfPerceptronsInLayer[Global.NO_OF_LAYERS - 1] = av; 
		createLayers();
	}
	
	/**
	 * Inicjalizacja sieci
	 * @param hidden_no     - liczba ukrytych warstw
	 * @param structure     - lista licznosci kolejnych warstw
	 */
	private void setup(int hidden_no, int[] structure) {			
		shuffle(); 								// pomieszanie danych treningowych
		noOfPerceptronsInLayer = new int[hidden_no + 1];
		createInputProvider(); 					// utworzenie obiektu odpowiadajacego za wejscie
		createResultCombiner();					// utworzenie obiektu odpowiadajacego za wyjscie
		int av = availableResults.size();
        int i;
		for (i=0; i<hidden_no; i++) {
			noOfPerceptronsInLayer[i] = structure[i+1];	
			}	
		noOfPerceptronsInLayer[hidden_no] = av; 
		createLayers(hidden_no);
	}
	
	/**
	 * Miesza dane treningowe
	 */
	private void shuffle() {
		Collections.shuffle(trainData);
		dataIterator = trainData.iterator();
	}
	
	/**
	 * Tworzy obiekt zarzadzajacy wejsciem
	 */
	private void createInputProvider() {
		input = new DoubleDataProvider(fullTable);
	}
	
	/**
	 * Oblicza ilosc mozliwych rezultatow i tworzy obiekt odpowiadajacy za wyjscie.
	 */
	private void createResultCombiner() {
		Set<Double> availableResultsSet = new HashSet<Double>();
		for (DoubleData dd : fullTable.getDataObjects()) {
			availableResultsSet.add(new Double(((DoubleDataWithDecision)dd).getDecision()));
		}
		availableResults = new ArrayList<Double>(availableResultsSet);
		resultCombiner = new ResultCombiner(availableResults);
	}
	
	/**
	 * Tworzy poszczegolne warstwy zgodnie z wartosciami w tablicy
	 * noOfPerceptronsInLayer. Podlacza wejscia i wyjscia
	 */
	private void createLayers() {
		layers = new ArrayList<Layer>();
		IInputProvider currentLayerInput = input;
		for (int i = 0; i < Global.NO_OF_LAYERS; i++ ) {
			Layer layer = new Layer(noOfPerceptronsInLayer[i]);
			layers.add(layer);
			layer.setInput(currentLayerInput);
			currentLayerInput = layer.getOutput();
		}
		resultCombiner.setInput(currentLayerInput);
	}
	
	/**
	 * Tworzy poszczegolne warstwy zgodnie z wartosciami w tablicy
	 * noOfPerceptronsInLayer. Podlacza wejscia i wyjscia
	 */
	private void createLayers(int hidden_no) {
		layers = new ArrayList<Layer>();
		IInputProvider currentLayerInput = input;
		for (int i = 0; i < (hidden_no + 1); i++ ) {
			Layer layer = new Layer(noOfPerceptronsInLayer[i]);
			layers.add(layer);
			layer.setInput(currentLayerInput);
			currentLayerInput = layer.getOutput();
		}
		resultCombiner.setInput(currentLayerInput);
	}	
		
	/**
	 * Dodanie jednego perceptronu do sieci i poprawienie wejsc, wyjsc 
	 * 
	 * @param extended_no numer warstwy ukrytej
	 */
	/*public void changeNetwork(int extended_no) {
		Layer layer;
		int i;
		
		Iterator it = layers.iterator();
		IInputProvider currentLayerInput = input;
		i = 0;
		while(it.hasNext()) {
			layer = (Layer) it.next();
			if (i == extended_no) {
				layer.grow();  //zwieksz licznosc warstwy o jeden perceptron
				noOfPerceptronsInLayer[i]++;
			}
			if ((i-1) == extended_no) {
				//podmiana wejscia wraz z powiekszeniem tablicy wag dla perceptronow
				layer.recombine(currentLayerInput); 
			}
			layer.updateInput(currentLayerInput);
			currentLayerInput = layer.getOutput();
			i++;		
		}
		resultCombiner.setInput(currentLayerInput);	
	}*/
	
	/**
	 * Zwraca nastepny obiekt danych treningowych
	 * @return DoubleDataWithDecision
	 */
	private DoubleDataWithDecision nextData() {
		if (! dataIterator.hasNext())
			dataIterator = trainData.iterator();
		return (DoubleDataWithDecision)dataIterator.next();
	}
	
	/**
	 * Algorytm propagacji wstecznej. Zmodyfikowany o limitowanie ilosci trenowanych
	 * rekordow z danymi do Global.MAX_ITER_SIZE (dla lepszego zarzadzania czasem dzialania).
	 *
	 */
	private void backPropUpdate() {
		for (int i =0; i < Math.min(Global.MAX_ITER_SIZE, trainData.size()); i++) {
			DoubleDataWithDecision dd = this.nextData();
			
			count(dd); // liczymy wynik
			
			resultCombiner.setExpectedResult(dd.getDecision()); // okreslamy pozadany rezultat
			
			// kazemy obliczyc bledy poszczegolnym warstwom
			IDeviationCounter dc = resultCombiner;
			for (ListIterator l = layers.listIterator(layers.size()); l.hasPrevious(); ) {
				Layer layer = (Layer)l.previous();
				dc.countDeviations(layer.perceptrons());
				dc = layer;
			}
	
			// kazemy poprawic wagi
			for (ListIterator l = layers.listIterator(layers.size()); l.hasPrevious(); ) {				
				((Layer)l.previous()).improveWeights();
			}
		}
	}

	
	/**
	 * Przelicza wynik dla wejscia
	 * @param dd - rekord z danymi
	 */
	private void count(DoubleData dd) {
		input.setDoubleData(dd);		
		
		for (ListIterator l = layers.listIterator(); l.hasNext(); ) {
			((Layer)l.next()).count();
		}

		resultCombiner.count();		
	}
	

	/**
	 * Serializuje stan sieci do obiektu
	 * @return Object
	 */
	public Object storeData() {
		Object[] data = new Object[layers.size()];
		for (int i=0; i < layers.size(); i++)
			data[i] = (layers.get(i)).storeData();
		return data;
	}

	/**
	 * Deserializuje stan sieci z obiektu (do ktorego byla wczesniej zserializowana)
	 * @param data - Object
	 */
	public void restoreData(Object data) {
		for (int i=0; i < layers.size(); i++)
			(layers.get(i)).restoreData(((Object[])data)[i]);
	}
	
}

Copyright © 2008-2011 by TunedIT
Design by luksite