package org.zeropage.ml;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
/**
 * Command-line entry point: loads one {@link Trainer} per category, then scores every
 * line of each category's test file with a {@link Tester} and prints how many lines
 * received a positive weight for their own category.
 */
public class App {
    /**
     * Runs the evaluation over the hard-coded training and test files.
     *
     * <p>For each test file one line is printed to standard output, tab-separated:
     * the number of lines with a weight greater than zero, the number of remaining
     * lines, and the fraction of the former. An empty test file prints {@code NaN}
     * as the fraction because the division is {@code 0.0 / 0}.
     *
     * @param args ignored
     * @throws IOException if a training or test file cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        Tester tester = new Tester(new Trainer[] {
            new Trainer("package/train/economy/index.economy.db").load(),
            new Trainer("package/train/politics/index.politics.db").load() });

        // Entry i is scored against trainer i, so the order must match the array above.
        String[] testFiles = {
            "package/test/economy/economy.txt",
            "package/test/politics/politics.txt" };

        for (int i = 0; i < testFiles.length; i++) {
            int right = 0, wrong = 0;
            // Open each file only when it is needed and always close it, so that neither
            // a normal run nor a failure leaves a file handle behind.
            BufferedReader reader = new BufferedReader(new FileReader(testFiles[i]));
            try {
                for (String line; (line = reader.readLine()) != null;) {
                    if (tester.getWeight(i, line) > 0) {
                        right++;
                    } else {
                        wrong++;
                    }
                }
            } finally {
                reader.close();
            }
            System.out.println(right + "\t" + wrong + "\t" + ((double)right/(right+wrong)));
        }
    }
}
package org.zeropage.ml;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
/**
 * Collects word statistics from a text file: the number of lines read, the total
 * number of whitespace-separated tokens, and how often each token occurs.
 */
public class Trainer {
    private final String fileName;
    private int docsCount;
    private int wordsCount;
    private final Map<String, Integer> wordCount = new HashMap<String, Integer>();

    /**
     * Creates a trainer for the given file. Nothing is read until {@link #load()} is called.
     *
     * @param fileName path of the file to read
     */
    public Trainer(String fileName) {
        this.fileName = fileName;
    }

    /**
     * Reads the file line by line and updates the counters. Every line counts as one
     * document and is split on runs of whitespace into words. The counters are not
     * reset, so calling this method again adds to the existing totals.
     *
     * @return this trainer, to allow chaining
     * @throws IOException if the file cannot be opened or read
     */
    public Trainer load() throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(fileName));
        try {
            for (String line; (line = reader.readLine()) != null;) {
                for (String word : line.split("\\s+")) {
                    Integer count = wordCount.get(word);
                    wordCount.put(word, count == null ? 1 : count + 1);
                    wordsCount++;
                }
                docsCount++;
            }
        } finally {
            // Close on success as well as on failure; a read error is propagated to the
            // caller instead of silently yielding a partially loaded trainer.
            reader.close();
        }
        return this;
    }

    /**
     * @return the number of lines read so far
     */
    public int getDocsCount() {
        return docsCount;
    }

    /**
     * Returns how often the given word was seen.
     *
     * @param word the token to look up
     * @return the recorded count, or 1 if the word was never seen, so the result is
     *         never zero
     */
    public int getWordCount(String word) {
        Integer count = wordCount.get(word);
        return count == null ? 1 : count;
    }

    /**
     * @return the total number of tokens read so far
     */
    public int getWordsCount() {
        return wordsCount;
    }
}
package org.zeropage.ml;
/**
 * Scores a document for one category relative to all other categories, using the
 * document and word counts held by one {@link Trainer} per category.
 */
public class Tester {
    private Trainer[] trainers;

    /**
     * @param trainers one trainer per category; the array index is the category index
     *                 used by {@link #getWeight(int, String)}
     */
    public Tester(Trainer[] trainers) {
        this.trainers = trainers;
    }

    /**
     * Computes the weight of a document for the category at {@code index}: the log of
     * the document-count ratio plus, for every whitespace-separated token of the
     * document, the log of the word-frequency ratio.
     *
     * @param index category index into the trainer array
     * @param doc   document text
     * @return the summed log ratios
     */
    public double getWeight(int index, String doc) {
        double score = getLnPsPns(index);
        String[] tokens = doc.split("\\s+");
        for (int t = 0; t < tokens.length; t++) {
            score += getLnPwsPwns(index, tokens[t]);
        }
        return score;
    }

    /**
     * Natural log of (documents in the chosen category) divided by (documents in all
     * other categories combined).
     */
    private double getLnPsPns(int index) {
        int otherDocs = 0;
        for (int k = 0; k < trainers.length; k++) {
            if (k == index) {
                continue;
            }
            otherDocs += trainers[k].getDocsCount();
        }
        double ownDocs = trainers[index].getDocsCount();
        return Math.log(ownDocs / otherDocs);
    }

    /**
     * Natural log of the word's relative frequency in the chosen category divided by
     * its relative frequency across all other categories combined.
     */
    private double getLnPwsPwns(int index, String word) {
        int otherHits = 0;
        int otherWords = 0;
        for (int k = 0; k < trainers.length; k++) {
            if (k == index) {
                continue;
            }
            Trainer other = trainers[k];
            otherHits += other.getWordCount(word);
            otherWords += other.getWordsCount();
        }
        Trainer own = trainers[index];
        double ownRate = (double) own.getWordCount(word) / own.getWordsCount();
        double otherRate = (double) otherHits / otherWords;
        return Math.log(ownRate / otherRate);
    }
}