U E D R , A S I H C RSS

데블스캠프2011/둘째날/Machine-Learning/SVM/namsangboy

  • SVM Train 파일 생성하기

#-*-coding:utf8-*-
import re, sys, math
# Class labels handled by the scripts; each label is also used as a directory
# and file name component when building data paths.
classlist = ["economy", "politics"]


def maketestdir(i):
    """Return the path of the test-document file for class label ``i``."""
    test_root = "/home/newmoni/workspace/svm/package/test/"
    return test_root + i + "/" + i + ".txt"
def readtrain():
    """Load the training index files and collect per-class statistics.

    For every label in ``classlist`` the file
    ``/home/newmoni/workspace/svm/package/train/<label>/index.<label>.db``
    is read. Each line is treated as one document and each space-separated
    token as one word.

    Returns:
        A 5-tuple ``(classfreqdic, wordfreqdic, prob1, classprob1,
        classprob2)`` where

        * ``classfreqdic`` maps label -> number of lines in its file,
        * ``wordfreqdic`` maps label -> ``{word: occurrence count}``,
        * ``prob1`` is ``log(classprob1 / classprob2)``,
        * ``classprob1`` / ``classprob2`` are the fractions of all lines
          that belong to "economy" and "politics" respectively.

    Raises:
        IOError / OSError: if a training file cannot be opened.
        KeyError: if "economy" or "politics" is not in ``classlist``.
    """
    train_root = "/home/newmoni/workspace/svm/package/train/"
    classfreqdic = {}
    wordfreqdic = {}
    totalct = 0
    for eachclass in classlist:
        path = train_root + eachclass + "/index." + eachclass + ".db"
        # Context manager so the handle is closed even if reading fails.
        with open(path) as infile:
            # NOTE: split("\n") yields a trailing empty string when the file
            # ends with a newline, so that empty line is counted as a
            # document (and "" as a word) as well.
            doclist = infile.read().split("\n")
        classfreqdic[eachclass] = len(doclist)
        totalct += len(doclist)
        counts = {}
        for line in doclist:
            for word in line.split(" "):
                counts[word] = counts.get(word, 0) + 1
        wordfreqdic[eachclass] = counts
    # Float denominator so the ratios below are true divisions on Python 2.
    totalct = float(totalct)
    classprob1 = classfreqdic["economy"] / totalct
    classprob2 = classfreqdic["politics"] / totalct
    prob1 = math.log(classprob1 / classprob2)
    return classfreqdic, wordfreqdic, prob1, classprob1, classprob2
def classifydocument(document):
    """Score ``document`` as a sum of per-token log ratios.

    Newlines are replaced by spaces and the text is split on single spaces.
    For each token the add-one-smoothed count in each class is divided by
    that class's document fraction, and the log of the economy/politics
    ratio is accumulated.

    Reads the module-level names ``wordfreqdic``, ``classprob1`` and
    ``classprob2``, which must be set before this function is called.

    Returns:
        The accumulated log ratio (larger values favour "economy",
        smaller values favour "politics").
    """
    economy_counts = wordfreqdic["economy"]
    politics_counts = wordfreqdic["politics"]
    score = 0
    tokens = document.replace("\n", " ").split(" ")
    for token in tokens:
        # +1 keeps unseen words from producing log(0) or a zero division.
        economy_weight = (economy_counts.get(token, 0) + 1) / classprob1
        politics_weight = (politics_counts.get(token, 0) + 1) / classprob2
        score += math.log(economy_weight / politics_weight)
    return score
if __name__ == "__main__":
    # Train on the index files, score every test document, and report the
    # accuracy.
    # These names must stay module-level: classifydocument() reads
    # wordfreqdic, classprob1 and classprob2 as globals.
    classfreqdic, wordfreqdic, prob1, classprob1, classprob2 = readtrain()
    correctct = 0
    totalct = 0
    for eachclass in classlist:
        # Close the test file deterministically instead of leaking the handle.
        with open(maketestdir(eachclass)) as testfile:
            doclist = testfile.read().split("\n")
        for line in doclist:
            totalprob = classifydocument(line)
            # A single pre-formatted argument prints the same text with the
            # Python 2 print statement and the Python 3 print function.
            print("%s %s" % (eachclass, totalprob))
            # Positive score -> "economy", negative -> "politics"; a score
            # of exactly 0 is counted as a miss for either class.
            if eachclass == "economy" and totalprob > 0:
                correctct += 1
            elif eachclass == "politics" and totalprob < 0:
                correctct += 1
            totalct += 1
    # Output: number correct, number scored, accuracy.
    print("%s %s %s" % (correctct, totalct, correctct / float(totalct)))
  • SVM Test 파일 생성

#-*-coding:utf8-*-
import re, sys, math
from NaiveBayesian import *
if __name__ == "__main__":
    # Convert the test documents into SVM-light style lines:
    #   "<label> <feature index>:1 <feature index>:1 ..."
    # The vocabulary comes from the training data loaded by readtrain().
    classfreqdic, wordfreqdic, prob1, classprob1, classprob2 = readtrain()
    print("read end ")
    wordlist = set()
    for eachclass in classlist:
        wordlist.update(wordfreqdic[eachclass])
    print("end")
    wordlist = list(wordlist)

    # Feature indices are 1-based.
    # NOTE(review): the indices follow set iteration order, which is not
    # guaranteed to be the same across interpreter versions or runs with
    # hash randomisation; confirm the training file was built with the
    # same word -> index mapping.
    wordindexdic = {}
    for idx, word in enumerate(wordlist):
        wordindexdic[word] = idx + 1

    svmttest = "../data/test2.svm_light"
    # Append mode is kept from the original: re-running the script adds to
    # the existing file instead of replacing it.
    with open(svmttest, "a") as fout:
        for idx, eachclass in enumerate(classlist):
            with open(maketestdir(eachclass)) as testfile:
                doclist = testfile.read().split("\n")
            print(idx)
            # Label is the class position + 1, i.e. 1 and 2 for the two
            # classes.
            # NOTE(review): SVM-light's binary classifier documents +1/-1
            # targets; confirm which tool consumes this file.
            label = str(idx + 1)
            for line in doclist:
                # Binary presence features: only whether a known word occurs
                # in the document is written, not how often. Words outside
                # the training vocabulary are skipped.
                present = set()
                for word in line.split(" "):
                    if word in wordindexdic:
                        present.add(wordindexdic[word])
                outlist = [label]
                # Feature indices are written in increasing order.
                for wordidx in sorted(present):
                    outlist.append(str(wordidx) + ":" + str(1))
                fout.write(" ".join(outlist) + "\n")
                        
Valid XHTML 1.0! Valid CSS! powered by MoniWiki
last modified 2011-06-28 12:35:25
Processing time 0.0109 sec