U E D R , A S I H C RSS

데블스캠프2011/둘째날/Machine-Learning/Naive Bayes Classifier/변형진

package org.zeropage.ml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

/**
 * Entry point for the naive Bayes experiment: trains one model per category,
 * then classifies every line of each category's test file and reports accuracy.
 */
public class App {
	/**
	 * Trains the economy and politics models, then evaluates each test file.
	 *
	 * <p>For test file {@code i}, each line is scored with {@code tester.getWeight(i, line)};
	 * a positive weight counts as a correct classification into category {@code i}.
	 * One tab-separated line is printed per file: correct count, wrong count, accuracy.
	 *
	 * @param args unused
	 * @throws IOException if a training or test file cannot be opened or read
	 */
	public static void main(String[] args) throws IOException {
		Tester tester = new Tester(new Trainer[] {
				new Trainer("package/train/economy/index.economy.db").load(),
				new Trainer("package/train/politics/index.politics.db").load() });

		// Test file i holds documents of category i (same order as the trainers above).
		String[] testFiles = new String[] {
			"package/test/economy/economy.txt",
			"package/test/politics/politics.txt" };

		for (int i = 0; i < testFiles.length; i++) {
			// Opened per iteration so a failure to open one file cannot leak a previously opened one.
			// NOTE(review): FileReader uses the platform default charset; confirm it matches the data files.
			BufferedReader reader = new BufferedReader(new FileReader(testFiles[i]));
			int right = 0, wrong = 0;
			try {
				for (String line; (line = reader.readLine()) != null;) {
					if (tester.getWeight(i, line) > 0) {
						right++;
					} else {
						wrong++;
					}
				}
			} finally {
				// Close on both normal completion and read failure; read errors propagate to the caller.
				reader.close();
			}
			// An empty test file yields 0/0, printed as NaN accuracy.
			System.out.println(right + "\t" + wrong + "\t" + ((double) right / (right + wrong)));
		}
	}
}

package org.zeropage.ml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Word-frequency model for one document category, built from a training file.
 *
 * <p>Each line of the training file is treated as one document and is split on runs of
 * whitespace into words.
 */
public class Trainer {
	/** Path of the training file read by {@link #load()}. */
	private String fileName;
	/** Number of documents (lines) read so far. */
	private int docsCount;
	/** Total number of word tokens read so far, across all documents. */
	private int wordsCount;
	/** Occurrence count for each distinct word. */
	private Map<String, Integer> wordCount = new HashMap<String, Integer>();

	/**
	 * Creates an empty model; call {@link #load()} to read the training data.
	 *
	 * @param fileName path of the training file
	 */
	public Trainer(String fileName) {
		this.fileName = fileName;
	}

	/**
	 * Reads the training file and accumulates document and word counts.
	 * Counts are not reset, so calling this twice adds the file's counts again.
	 *
	 * @return this trainer, for chaining
	 * @throws IOException if the file cannot be opened or a read fails
	 */
	public Trainer load() throws IOException {
		BufferedReader reader = new BufferedReader(new FileReader(fileName));
		try {
			for (String line; (line = reader.readLine()) != null;) {
				// NOTE(review): a line with leading whitespace (or an empty line) yields an
				// empty-string token from split, which is counted as a word.
				for (String word : line.split("\\s+")) {
					Integer count = wordCount.get(word);
					wordCount.put(word, count == null ? 1 : count + 1);
					wordsCount++;
				}
				docsCount++;
			}
		} finally {
			// Release the file handle on success as well as on failure; read errors propagate.
			reader.close();
		}
		return this;
	}

	/** @return number of documents (lines) loaded */
	public int getDocsCount() {
		return docsCount;
	}

	/**
	 * Returns how many times {@code word} occurred in the training data.
	 * Unseen words report 1 rather than 0 (presumably to keep downstream ratios and
	 * logarithms finite).
	 *
	 * @param word the word to look up
	 * @return occurrence count, or 1 if the word was never seen
	 */
	public int getWordCount(String word) {
		Integer count = wordCount.get(word);
		return count == null ? 1 : count;
	}

	/** @return total number of word tokens loaded */
	public int getWordsCount() {
		return wordsCount;
	}
}

package org.zeropage.ml;

/**
 * One-vs-rest naive Bayes scorer over a fixed set of per-category trainers.
 *
 * <p>The position of a trainer in the array is its category id. A document is scored
 * against one category by comparing that category's statistics with the pooled
 * statistics of every other category.
 */
public class Tester {
	/** Per-category models, indexed by category id; stored as given (not copied). */
	private Trainer[] trainers;

	/**
	 * @param trainers one trained model per category
	 */
	public Tester(Trainer[] trainers) {
		this.trainers = trainers;
	}

	/**
	 * Computes the log-odds score of {@code doc} belonging to category {@code index}
	 * rather than to any other category: the log prior ratio plus, for every
	 * whitespace-separated token, the log likelihood ratio of that token.
	 *
	 * @param index category id to score against
	 * @param doc   document text
	 * @return log-odds score; positive favours category {@code index}
	 */
	public double getWeight(int index, String doc) {
		double score = getLnPsPns(index);
		String[] tokens = doc.split("\\s+");
		for (int t = 0; t < tokens.length; t++) {
			score += getLnPwsPwns(index, tokens[t]);
		}
		return score;
	}

	/**
	 * Log prior ratio: ln(docs in category {@code index} / docs in all other categories).
	 * NOTE(review): with a single trainer the denominator is 0, giving an infinite result.
	 */
	private double getLnPsPns(int index) {
		int otherDocs = 0;
		for (int c = 0; c < trainers.length; c++) {
			otherDocs += (c == index) ? 0 : trainers[c].getDocsCount();
		}
		return Math.log((double) trainers[index].getDocsCount() / otherDocs);
	}

	/**
	 * Log likelihood ratio of one word: ln(P(word | category) / P(word | other categories)),
	 * each probability estimated as word count divided by total word count.
	 * NOTE(review): with a single trainer the pooled totals are 0, giving NaN.
	 */
	private double getLnPwsPwns(int index, String word) {
		int otherHits = 0;
		int otherTotal = 0;
		for (int c = 0; c < trainers.length; c++) {
			if (c == index) {
				continue;
			}
			Trainer other = trainers[c];
			otherHits += other.getWordCount(word);
			otherTotal += other.getWordsCount();
		}
		Trainer target = trainers[index];
		double inCategory = (double) target.getWordCount(word) / target.getWordsCount();
		double outOfCategory = (double) otherHits / otherTotal;
		return Math.log(inCategory / outOfCategory);
	}
}
Valid XHTML 1.0! Valid CSS! powered by MoniWiki
last modified 2021-02-07 05:29:12
Processing time 0.0058 sec