U E D R , A S I H C RSS

데블스캠프2011/둘째날/Machine-Learning/Naive Bayes Classifier/강성현

Difference between r1.2 and the current

@@ -28,12 +28,8 @@
/** Returns the next whitespace-delimited token; delegates to {@code scan.next()}. */
public String next() { return scan.next(); }

/** Returns the rest of the current line; delegates to {@code scan.nextLine()}. */
public String nextLine() { return scan.nextLine(); }

}
 
 
}}}

=== Train File Analysis ===


설명

  • HashMap을 사용하여 단어와 빈도수를 저장함. 저장할 빈도수가 두 개(politics, economy)라서 int 값 두 개를 저장하는 Int2 클래스를 만듦.
  • 파일 입력은 java.util.Scanner를 감싼 FileData 클래스를 만들어 사용함.
  • train 데이터를 읽어 단어와 빈도수를 먼저 csv 파일(result.csv)로 저장함. 이후 Analyze 클래스에서 이 csv 파일을 읽어 test 데이터를 판별함.

Source Code


FileData Class

package org.zeropage.devils.machine;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.Scanner;

/**
 * Thin wrapper around a {@link Scanner} that reads a file as UTF-8 text.
 * Call {@link #close()} when finished so the underlying file handle is released.
 */
public class FileData implements java.io.Closeable {
	
	/** Scanner over the UTF-8 decoded contents of the file. */
	Scanner scan;
	
	/**
	 * Opens {@code filename} for reading as UTF-8 text.
	 *
	 * @param filename path of the file to read
	 * @throws FileNotFoundException if the file cannot be opened
	 * @throws UnsupportedEncodingException if the UTF-8 charset is unavailable
	 */
	public FileData(String filename) throws FileNotFoundException, UnsupportedEncodingException {
		scan = new Scanner(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
	}
	
	/** Returns true if another whitespace-delimited token remains. */
	public boolean hasNext() { return scan.hasNext(); }
	
	/**
	 * Returns true if another line (possibly empty) remains.
	 * This is the natural guard for {@link #nextLine()}; {@link #hasNext()} tests for a token instead.
	 */
	public boolean hasNextLine() { return scan.hasNextLine(); }
	
	/** Returns the next whitespace-delimited token. */
	public String next() { return scan.next(); }
	
	/** Returns the rest of the current line, excluding the line separator. */
	public String nextLine() { return scan.nextLine(); }
	
	/** Closes the scanner, which also closes the underlying file stream. */
	@Override
	public void close() { scan.close(); }

}

Train File Analysis

package org.zeropage.devils.machine;

import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;

public class Main {
	
	/**
	 * Word -> training counts: get1() is the number of politics articles containing the word,
	 * get2() the number of economy articles containing it.
	 */
	public static HashMap<String, Int2> words = new HashMap<String, Int2>();
	
	/** Characters that separate words inside an article. */
	private static final String DELIMITERS = " \t\n\r\f\'\"";
	
	/**
	 * Reads the politics and economy training files (one article per line), counts in how
	 * many articles of each class every word appears, and writes the counts to result.csv.
	 */
	public static void main(String[] args) {
		
		try {
			FileData politics = new FileData("train/politics/index.politics.db");
			FileData economy = new FileData("train/economy/index.economy.db");
			
			countArticles(politics, true);
			countArticles(economy, false);
			
			writeCsv("result.csv");
			
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * Treats every remaining line of {@code source} as one article and increments the
	 * politics (get1) or economy (get2) count of each distinct word in it.
	 *
	 * @param source   training data, one article per line
	 * @param politics true to count into the politics column, false for economy
	 */
	private static void countArticles(FileData source, boolean politics) {
		while (source.hasNext()) {
			StringTokenizer st = new StringTokenizer(source.nextLine(), DELIMITERS);
			// Created once per article (not per token) so each word counts at most once per article.
			Set<String> seen = new HashSet<String>();
			
			while (st.hasMoreTokens()) {
				String word = st.nextToken();
				if (!seen.add(word)) {
					continue;
				}
				
				Int2 count = words.get(word);
				if (count == null) {
					count = new Int2();
					words.put(word, count);
				}
				if (politics) {
					count.increase1();
				} else {
					count.increase2();
				}
			}
		}
	}
	
	/**
	 * Writes {@link #words} as a CSV with header "word,politics,economy".
	 * The file is encoded as UTF-8 because Analyze reads result.csv as UTF-8; the platform
	 * default charset would garble non-ASCII words on some systems.
	 *
	 * @param filename output path
	 * @throws FileNotFoundException if the file cannot be created
	 * @throws UnsupportedEncodingException if the UTF-8 charset is unavailable
	 */
	private static void writeCsv(String filename) throws FileNotFoundException, UnsupportedEncodingException {
		PrintWriter pw = new PrintWriter(new OutputStreamWriter(new FileOutputStream(filename), "UTF-8"));
		try {
			pw.println("word,politics,economy");
			
			// Iterate entries so each word is written next to its own counts. Words never
			// contain '"' (it is a token delimiter), so quoting keeps embedded commas safe.
			for (Map.Entry<String, Int2> entry : words.entrySet()) {
				Int2 f = entry.getValue();
				pw.println('"' + entry.getKey() + '"' + ',' + f.get1() + ',' + f.get2());
			}
		} finally {
			pw.close();
		}
	}

}

class Int2 {
	private int _1, _2;
	
	public Int2() {
		_1 = 0; _2 = 0;
	}
	
	public Int2(int __1, int __2) {
		_1 = __1; _2 = __2;
	}
	
	public void increase1() { _1++; }
	public void increase2() { _2++; }
	
	public int get1() { return _1; }
	public int get2() { return _2; }
}

Test File Analysis

package org.zeropage.devils.machine;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.StringTokenizer;

import au.com.bytecode.opencsv.CSVReader;

public class Analyze {
	
	/**
	 * Word -> training counts loaded from result.csv: get1() is the politics column,
	 * get2() the economy column.
	 */
	public static HashMap<String, Int2> words = new HashMap<String, Int2>();
	
	/** Token delimiters; the same set the training step (Main) uses. */
	private static final String DELIMITERS = " \t\n\r\f\'\"";
	
	/**
	 * Loads word counts from result.csv, classifies every line (article) of the politics and
	 * economy test files, and prints the fraction of articles classified correctly.
	 */
	public static void main(String[] args) {
		
		try {
			loadCounts("result.csv");
			FileData politics = new FileData("test/politics/politics.txt");
			FileData economy = new FileData("test/economy/economy.txt");
			
			int correct = 0, count = 0;
			
			while (politics.hasNext()) {
				if (score(politics.nextLine(), true) > 0) correct++;
				count++;
			}
			
			while (economy.hasNext()) {
				if (score(economy.nextLine(), false) > 0) correct++;
				count++;
			}
			
			// Avoid printing NaN when both test files are empty.
			if (count == 0) {
				System.err.println("No test articles found.");
				return;
			}
			System.out.println(correct / (double) count);
			
		} catch (UnsupportedEncodingException e) {
			e.printStackTrace();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
	
	/**
	 * Reads "word,politics,economy" rows from a UTF-8 CSV file into {@link #words}.
	 * Row 0 is the header; every following row is data. Rows with fewer than three columns
	 * are skipped.
	 *
	 * @param filename path of the CSV produced by the training step
	 * @throws IOException if the file cannot be opened or read
	 */
	private static void loadCounts(String filename) throws IOException {
		CSVReader csv = new CSVReader(new InputStreamReader(new FileInputStream(filename), "UTF-8"));
		try {
			String[][] data = csv.readAll().toArray(new String[0][0]);
			
			// Start at 1: only row 0 is the header (starting at 2 silently dropped the first word).
			for (int i = 1; i < data.length; i++) {
				String[] row = data[i];
				if (row.length < 3) continue;
				words.put(row[0], new Int2(Integer.parseInt(row[1]), Integer.parseInt(row[2])));
			}
		} finally {
			csv.close();
		}
	}
	
	/**
	 * Scores an article against its known true class. The score is the sum, over the article's
	 * distinct words that appear in {@link #words}, of log((own + 1) / (other + 1)). Here "own"
	 * is the count for the true class and "other" the count for the opposite class, with add-one
	 * smoothing. A positive score means the article is classified correctly.
	 *
	 * NOTE(review): counts are not normalised by the number of training articles per class, and
	 * no class prior is added. This matches a Naive Bayes log-odds only if both training sets
	 * are the same size. Confirm this is intended.
	 *
	 * @param article  one article's text
	 * @param politics true if the article's true class is politics, false if economy
	 * @return the accumulated log ratio (0.0 if no known words)
	 */
	private static double score(String article, boolean politics) {
		StringTokenizer st = new StringTokenizer(article, DELIMITERS);
		// One set per article so each distinct word contributes at most once.
		java.util.Set<String> seen = new java.util.HashSet<String>();
		double p = 0.0;
		
		while (st.hasMoreTokens()) {
			String word = st.nextToken();
			if (!seen.add(word)) continue;
			
			Int2 f = words.get(word);
			if (f == null) continue;  // word never seen in training: no evidence either way
			
			int own = politics ? f.get1() : f.get2();
			int other = politics ? f.get2() : f.get1();
			p += Math.log((own + 1) / (double) (other + 1));
		}
		return p;
	}
}

Valid XHTML 1.0! Valid CSS! powered by MoniWiki
last modified 2021-02-07 05:29:13
Processing time 0.0286 sec