package org.zeropage.ml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

/**
 * Command-line entry point: loads one trainer per category, then scores every line of each
 * category's test file and prints {@code right<TAB>wrong<TAB>accuracy} for that file.
 */
public class App {
    public static void main(String[] args) throws IOException {
        Tester tester = new Tester(new Trainer[] {
            new Trainer("package/train/economy/index.economy.db").load(),
            new Trainer("package/train/politics/index.politics.db").load()
        });
        // Test file i is scored against trainer i, so this order must match the array above.
        String[] testFiles = new String[] {
            "package/test/economy/economy.txt",
            "package/test/politics/politics.txt"
        };
        for (int i = 0; i < testFiles.length; i++) {
            report(tester, i, testFiles[i]);
        }
    }

    /**
     * Scores each line of {@code path} as a document expected to belong to category
     * {@code index} and prints the right/wrong counts and the accuracy ratio.
     *
     * <p>A line counts as right when its weight for {@code index} is positive. An empty file
     * prints {@code NaN} as the ratio (0.0 / 0).
     *
     * @param tester scorer holding all trainers
     * @param index  index of the trainer the lines in {@code path} should match
     * @param path   test file, one document per line
     * @throws IOException if the file cannot be opened or read
     */
    private static void report(Tester tester, int index, String path) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            int right = 0;
            int wrong = 0;
            for (String line; (line = reader.readLine()) != null;) {
                if (tester.getWeight(index, line) > 0) {
                    right++;
                } else {
                    wrong++;
                }
            }
            System.out.println(right + "\t" + wrong + "\t" + ((double) right / (right + wrong)));
        } finally {
            // Close on every path; the previous version closed only when a read failed,
            // leaking the handle on success and hiding the error.
            reader.close();
        }
    }
}
package org.zeropage.ml;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Document and word counts gathered from a single training file.
 *
 * <p>Each line of the file is one document. Each document is split on runs of whitespace,
 * and every resulting token is counted as a word.
 */
public class Trainer {
    private final String fileName;
    private int docsCount;
    private int wordsCount;
    private final Map<String, Integer> wordCount = new HashMap<String, Integer>();

    /**
     * Creates a trainer for the given file. Nothing is read until {@link #load()} is called.
     *
     * @param fileName path of the training file, read with the platform default charset
     */
    public Trainer(String fileName) {
        this.fileName = fileName;
    }

    /**
     * Reads the training file and adds its documents and words to this trainer's counts.
     *
     * <p>Counts are not reset first, so calling this twice counts the file twice.
     * NOTE: {@code split("\\s+")} yields an empty first token for a line that starts with
     * whitespace, and a single empty token for an empty line. Both are counted as words.
     *
     * @return this trainer, for call chaining
     * @throws IOException if the file cannot be opened or read
     */
    public Trainer load() throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(fileName));
        try {
            for (String line; (line = reader.readLine()) != null;) {
                for (String word : line.split("\\s+")) {
                    Integer count = wordCount.get(word);
                    wordCount.put(word, count == null ? 1 : count + 1);
                    wordsCount++;
                }
                docsCount++;
            }
        } finally {
            // Always release the file; previously it was closed only when a read failed,
            // and that failure was swallowed, leaving silently partial counts.
            reader.close();
        }
        return this;
    }

    /** @return number of documents (lines) loaded so far */
    public int getDocsCount() {
        return docsCount;
    }

    /**
     * Returns how many times {@code word} occurred in the loaded documents.
     *
     * <p>Never returns 0: an unseen word reports 1, the same as a word seen exactly once.
     *
     * @param word token to look up
     * @return occurrence count, at least 1
     */
    public int getWordCount(String word) {
        Integer count = wordCount.get(word);
        return count == null ? 1 : count;
    }

    /** @return total number of word tokens loaded so far, duplicates included */
    public int getWordsCount() {
        return wordsCount;
    }
}
package org.zeropage.ml;

/**
 * Scores documents for one trainer (category) against all other trainers combined.
 */
public class Tester {
    private Trainer[] trainers;

    /**
     * @param trainers one trainer per category; the array is stored as-is, not copied
     */
    public Tester(Trainer[] trainers) {
        this.trainers = trainers;
    }

    /**
     * Computes the score of {@code doc} for the trainer at {@code index}.
     *
     * <p>The score is {@code ln(docs ratio)} plus, for every whitespace-separated token of
     * {@code doc} in order, {@code ln(word-frequency ratio)}. Each ratio compares trainer
     * {@code index} with the sum over all other trainers.
     *
     * @param index position of the trainer being scored
     * @param doc   document text, split on runs of whitespace
     * @return the summed log ratios
     */
    public double getWeight(int index, String doc) {
        String[] tokens = doc.split("\\s+");
        double score = getLnPsPns(index);
        for (int t = 0; t < tokens.length; t++) {
            score += getLnPwsPwns(index, tokens[t]);
        }
        return score;
    }

    /** ln(docs of trainer {@code index} / docs of all other trainers). */
    private double getLnPsPns(int index) {
        int otherDocs = 0;
        for (int k = 0; k < trainers.length; k++) {
            if (k == index) {
                continue;
            }
            otherDocs += trainers[k].getDocsCount();
        }
        double docsRatio = (double) trainers[index].getDocsCount() / otherDocs;
        return Math.log(docsRatio);
    }

    /**
     * ln of (relative frequency of {@code word} in trainer {@code index}) divided by
     * (its pooled relative frequency across all other trainers).
     */
    private double getLnPwsPwns(int index, String word) {
        int otherHits = 0;
        int otherTotal = 0;
        for (int k = 0; k < trainers.length; k++) {
            if (k == index) {
                continue;
            }
            otherHits += trainers[k].getWordCount(word);
            otherTotal += trainers[k].getWordsCount();
        }
        Trainer own = trainers[index];
        double inside = (double) own.getWordCount(word) / own.getWordsCount();
        double outside = (double) otherHits / otherTotal;
        return Math.log(inside / outside);
    }
}