데블스캠프2011/둘째날/Machine-Learning/NaiveBayesClassifier/강성현 (rev. 1.3)
설명 ¶
- HashMap을 사용하여 단어별 빈도수를 저장함. 분류(정치/경제)마다 빈도수가 하나씩, 총 2개가 필요하므로 int형 2개를 저장하는 Int2 클래스를 만듦.
- 파일 입력은 java.util.Scanner를 감싼 FileData 클래스를 만들어 사용함 (UTF-8로 읽음).
- Main 클래스에서 train 데이터를 읽어 단어와 빈도수를 csv 파일(result.csv)로 저장하고, Analyze 클래스에서 이 csv 파일을 읽어 test 데이터를 판별함.
FileData Class ¶
package org.zeropage.devils.machine;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.Scanner;
/**
 * Thin wrapper around {@link Scanner} that reads a text file as UTF-8.
 * Call {@link #close()} when done so the underlying file is released.
 */
public class FileData implements java.io.Closeable {
    Scanner scan;

    /**
     * Opens {@code filename} for reading as UTF-8.
     *
     * @param filename path of the file to read
     * @throws FileNotFoundException if the file cannot be opened
     * @throws UnsupportedEncodingException declared by {@link InputStreamReader};
     *         UTF-8 is a charset every Java platform is required to support
     */
    public FileData(String filename) throws FileNotFoundException, UnsupportedEncodingException {
        scan = new Scanner(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
    }

    /** @return true if another whitespace-delimited token remains */
    public boolean hasNext() { return scan.hasNext(); }

    /** @return true if another line (possibly empty) remains */
    public boolean hasNextLine() { return scan.hasNextLine(); }

    /** @return the next whitespace-delimited token */
    public String next() { return scan.next(); }

    /** @return the rest of the current line, without the line terminator */
    public String nextLine() { return scan.nextLine(); }

    /** Closes the scanner and, with it, the underlying file stream. */
    @Override
    public void close() { scan.close(); }
}
Train File Analysis ¶
package org.zeropage.devils.machine;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;
/**
 * Training step: for every word, counts how many politics articles and how many
 * economy articles contain it (each line of a training file is one article), then
 * writes the table to result.csv.
 */
public class Main {
    /** word -> (politics article count, economy article count). */
    public static HashMap<String, Int2> words = new HashMap<String, Int2>();

    /** Token delimiters; must stay identical to the ones used by Analyze. */
    private static final String DELIMITERS = " \t\n\r\f\'\"";

    public static void main(String[] args) {
        try {
            FileData politics = new FileData("train/politics/index.politics.db");
            FileData economy = new FileData("train/economy/index.economy.db");
            countDocuments(politics, true);
            countDocuments(economy, false);
            writeCsv("result.csv");
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads {@code data} line by line and bumps the politics (first) or economy
     * (second) counter of every distinct word in each line.
     */
    private static void countDocuments(FileData data, boolean isPolitics) {
        while (data.hasNext()) {
            for (String word : uniqueWords(data.nextLine())) {
                Int2 counts = words.get(word);
                if (counts == null) {
                    counts = new Int2();
                    words.put(word, counts);
                }
                if (isPolitics) {
                    counts.increase1();
                } else {
                    counts.increase2();
                }
            }
        }
    }

    /**
     * Returns the distinct tokens of one article. The set is built per article
     * so a repeated word is counted only once for that article.
     */
    private static Set<String> uniqueWords(String article) {
        Set<String> unique = new HashSet<String>();
        StringTokenizer st = new StringTokenizer(article, DELIMITERS);
        while (st.hasMoreTokens()) {
            unique.add(st.nextToken());
        }
        return unique;
    }

    /**
     * Writes the header "word,politics,economy" followed by one quoted-word row per entry.
     * Words never contain '"' because '"' is one of the tokenizer delimiters.
     */
    private static void writeCsv(String filename) throws FileNotFoundException, UnsupportedEncodingException {
        // Write UTF-8 explicitly. FileData and Analyze read UTF-8, but a PrintWriter
        // over a bare FileOutputStream would use the platform default charset.
        PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream(filename), "UTF-8"));
        try {
            pw.println("word,politics,economy");
            // Iterate entries so each word is always paired with its own counts.
            for (Map.Entry<String, Int2> entry : words.entrySet()) {
                Int2 c = entry.getValue();
                pw.println('"' + entry.getKey() + '"' + ',' + c.get1() + ',' + c.get2());
            }
            // PrintWriter swallows IOExceptions; surface them instead of failing silently.
            if (pw.checkError()) {
                System.err.println("Error while writing " + filename);
            }
        } finally {
            pw.close();
        }
    }
}
class Int2 {
private int _1, _2;
public Int2() {
_1 = 0; _2 = 0;
}
public Int2(int __1, int __2) {
_1 = __1; _2 = __2;
}
public void increase1() { _1++; }
public void increase2() { _2++; }
public int get1() { return _1; }
public int get2() { return _2; }
}
Test File Analysis ¶
package org.zeropage.devils.machine;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.StringTokenizer;
import au.com.bytecode.opencsv.CSVReader;
/**
 * Test step: loads the word table written by Main (result.csv), classifies every
 * line (article) of the politics and economy test files, and prints the fraction
 * of articles that were classified into their own category.
 */
public class Analyze {
    /** word -> (politics count, economy count), loaded from result.csv. */
    public static HashMap<String, Int2> words = new HashMap<String, Int2>();

    /** Token delimiters; must match the ones used when result.csv was produced. */
    private static final String DELIMITERS = " \t\n\r\f\'\"";

    public static void main(String[] args) {
        try {
            loadWords("result.csv");
            FileData politics = new FileData("test/politics/politics.txt");
            FileData economy = new FileData("test/economy/economy.txt");
            int isP = 0, isE = 0, count = 0;
            while (politics.hasNext()) {
                if (logOdds(politics.nextLine(), true) > 0) {
                    isP++;
                }
                count++;
            }
            while (economy.hasNext()) {
                if (logOdds(economy.nextLine(), false) > 0) {
                    isE++;
                }
                count++;
            }
            if (count == 0) {
                System.err.println("No test articles found.");
                return;
            }
            System.out.println((isP + isE) / (double) count);
        } catch (IOException e) {
            // Also covers FileNotFoundException and UnsupportedEncodingException.
            e.printStackTrace();
        }
    }

    /**
     * Fills {@link #words} from a UTF-8 CSV whose first row is a header and whose
     * remaining rows are {@code word,politicsCount,economyCount}.
     *
     * @throws IOException if the file cannot be opened or read
     */
    private static void loadWords(String filename) throws IOException {
        CSVReader csv = new CSVReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
        try {
            String[][] data = csv.readAll().toArray(new String[0][0]);
            // Row 0 is the header; data starts at row 1.
            for (int i = 1; i < data.length; i++) {
                String[] row = data[i];
                if (row.length < 3) {
                    continue; // blank or truncated line
                }
                words.put(row[0], new Int2(Integer.parseInt(row[1]), Integer.parseInt(row[2])));
            }
        } finally {
            csv.close();
        }
    }

    /**
     * Sums log((target + 1) / (other + 1)) over the distinct known words of an
     * article, where "target" is the count for the article's own category.
     * A positive result means the article leans towards that category.
     * The +1 keeps both the ratio and the logarithm finite for zero counts.
     *
     * NOTE(review): counts are not normalized by the number of training articles
     * per class, and no class prior is added. If the two training sets differ in
     * size, the larger class may be favored. Confirm whether this is intended.
     *
     * @param article one line of a test file
     * @param forPolitics true to score towards politics, false towards economy
     */
    private static double logOdds(String article, boolean forPolitics) {
        java.util.HashSet<String> seen = new java.util.HashSet<String>();
        StringTokenizer st = new StringTokenizer(article, DELIMITERS);
        double p = 0.0;
        while (st.hasMoreTokens()) {
            String word = st.nextToken();
            if (!seen.add(word)) {
                continue; // count each word once per article
            }
            Int2 f = words.get(word);
            if (f == null) {
                continue; // word never seen in training
            }
            int target = forPolitics ? f.get1() : f.get2();
            int other = forPolitics ? f.get2() : f.get1();
            p += Math.log((target + 1) / (double) (other + 1));
        }
        return p;
    }
}