
Training on MNIST data with the deepnet package

Training a deep network to read handwritten digits with the deepnet package (MNIST).


Model:

nn <- dbn.dnn.train(image.train, outnode.train,
                    hidden = c(500, 500, 250, 125),
                    output = "softmax",
                    batchsize = 100, numepochs = 100, learningrate = 0.1)


Rationale: intuition!
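To get a feel for the size of this model, the sketch below counts the weights and biases of the final feed-forward network, assuming 784 inputs (28 × 28 pixels) and 10 softmax outputs. `count_params` is a hypothetical helper, not part of deepnet, and it ignores the extra visible-unit biases that RBMs use only during pretraining.

```r
# Hypothetical helper (not part of deepnet): weights + biases of a fully
# connected network whose layer widths are listed from input to output.
count_params <- function(sizes) {
  n_in  <- sizes[-length(sizes)]  # width of the layer feeding each weight matrix
  n_out <- sizes[-1]              # width of the layer it feeds into
  sum(n_in * n_out + n_out)       # weight matrix plus one bias per receiving unit
}

# 784 input pixels, hidden = c(500, 500, 250, 125), 10 output digits
layers <- c(784, 500, 500, 250, 125, 10)
count_params(layers)
# [1] 800885
```

That is roughly 0.8 million parameters, about half of which (392,500) sit in the first layer.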

Training time: 7.8 hours (t.end - t.start)
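Subtracting two `Sys.time()` values gives a `difftime` whose unit R chooses automatically; `difftime(..., units = "hours")` fixes the unit. A sketch with hypothetical timestamps (an arbitrary start plus 7.8 hours) that also spreads the total over `numepochs = 100`. Because the timing wraps the whole `dbn.dnn.train()` call, DBN pretraining included, the per-epoch figure is only a rough average.

```r
# Hypothetical timestamps standing in for t.start / t.end in Source 1
t.start <- as.POSIXct(0, origin = "1970-01-01", tz = "UTC")
t.end   <- t.start + 7.8 * 3600        # POSIXct arithmetic is in seconds

elapsed.h     <- as.numeric(difftime(t.end, t.start, units = "hours"))
per.epoch.min <- elapsed.h * 60 / 100  # numepochs = 100

elapsed.h      # about 7.8
per.epoch.min  # about 4.68
```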

Training results

 - Trained network: https://www.dropbox.com/s/6crgw9d9okjimwp/deepnet01.RData?dl=0

 - Error on the training set: 0.004%; error on the test set: 2% (the gap looks like overfitting!)
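The scripts in Source 1 print accuracy, so the error rates above are 1 minus those values. A small sketch of reporting error in percent together with the train/test gap; `error_pct` and the toy label vectors are hypothetical. As in the script, the true labels are factors while the predictions are plain numbers; R's `!=` compares the factor's level labels with the numbers.

```r
# Hypothetical helper: percentage of predictions that differ from the true labels
error_pct <- function(actual, predicted) 100 * mean(actual != predicted)

# Toy data: true digits as factors (like y.train / y.test), predictions as numbers
train.y    <- factor(c(3, 1, 4, 1, 5))
train.yhat <- c(3, 1, 4, 1, 5)          # all correct
test.y     <- factor(c(9, 2, 6, 5, 3))
test.yhat  <- c(9, 2, 6, 5, 8)          # one mistake out of five

train.err <- error_pct(train.y, train.yhat)
test.err  <- error_pct(test.y, test.yhat)
c(train = train.err, test = test.err, gap = test.err - train.err)
```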

[Figure: error plot over the course of training]

[Figure: screenshot of a run]

Source:

trainMNIST_deepnet.R

trainMNIST_deepnet_interactive.R

(To run the interactive script without training, download deepnet01.RData from the link above.)
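`deepnet01.RData` is an ordinary R data file written by `save(nn, file = ...)`, so `load()` recreates the object under its saved name `nn` and returns that name invisibly. A sketch of the round trip, with a dummy list in a temporary file standing in for the real network:

```r
# A dummy object stands in for the trained network returned by dbn.dnn.train()
nn <- list(size = c(784, 500, 500, 250, 125, 10))

path <- tempfile(fileext = ".RData")
save(nn, file = path)      # same call as the commented-out save() in Source 1

rm(nn)                     # simulate a fresh session
loaded <- load(path)       # restores `nn` into the current environment
loaded                     # names of the restored objects
```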


Source 1 : training

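library(RCurl)

# MNIST as CSV: column 1 is the digit label, columns 2-785 are pixel values (0-255).
# The files have no header row, so read them with header = FALSE.
URL <- "http://www.pjreddie.com/media/files/mnist_train.csv"
x <- getURL(URL)
## Or, if SSL certificate verification fails:
## x <- getURL(URL, ssl.verifypeer = FALSE)
train <- read.csv(textConnection(x), header = FALSE)

URL <- "http://www.pjreddie.com/media/files/mnist_test.csv"
x <- getURL(URL)
## Or
## x <- getURL(URL, ssl.verifypeer = FALSE)
test <- read.csv(textConnection(x), header = FALSE)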


# Convert to numeric matrices and split labels from pixels
train.mat <- data.matrix(train)
test.mat <- data.matrix(test)

y.train <- as.factor(train.mat[, 1])
image.train <- train.mat[, -1]
y.test <- as.factor(test.mat[, 1])
image.test <- test.mat[, -1]

# Scale pixel intensities from 0-255 to 0-1
image.train <- image.train / 255
image.test <- image.test / 255

# One-hot encode the labels: one 0/1 indicator column per digit 0-9
outnode.train <- model.matrix(~ y.train - 1)
outnode.test <- model.matrix(~ y.test - 1)
#http://stackoverflow.com/questions/5048638/automatically-expanding-an-r-factor-into-a-collection-of-1-0-indicator-variables


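library(deepnet)

# Free the raw data; only the scaled matrices are needed from here on
rm(test, test.mat, train, train.mat)

t.start <- Sys.time()
# Initialize the weights with a DBN, then train the deep network with a softmax output
nn <- dbn.dnn.train(image.train, outnode.train,
                    hidden = c(500, 500, 250, 125),
                    output = "softmax",
                    batchsize = 100, numepochs = 100, learningrate = 0.1)
#deepnet01.RData
t.end <- Sys.time()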


# Predicted digit = index of the largest output, shifted to 0-9
# (which.max also avoids a list result when two outputs tie)
train.pred <- nn.predict(nn, image.train)
train.pred.num <- apply(train.pred, 1, which.max) - 1

# Training-set accuracy (error = 1 - accuracy)
mean(y.train == train.pred.num)

test.pred <- nn.predict(nn, image.test)
test.pred.num <- apply(test.pred, 1, which.max) - 1

# Test-set accuracy
mean(y.test == test.pred.num)

#save(nn, file="deepnet01.RData")
#  test error 2%, train 0.004%
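Overall accuracy hides which digits get confused with each other. A sketch of a confusion table on a toy output matrix (hypothetical values, 4 images × 3 classes; the real `test.pred` has one row per test image and 10 columns): `max.col()` picks each row's largest output, as the `apply()` lines in Source 1 do, and `table()` cross-tabulates true against predicted labels.

```r
# Toy softmax outputs: 4 images x 3 classes (the real matrix has 10 digit columns)
prob <- matrix(c(0.7, 0.2, 0.1,
                 0.1, 0.8, 0.1,
                 0.3, 0.3, 0.4,
                 0.6, 0.3, 0.1), ncol = 3, byrow = TRUE)

# Column of the largest output in each row, shifted so classes start at 0
pred <- max.col(prob, ties.method = "first") - 1
actual <- c(0, 1, 2, 1)

# Rows: true labels, columns: predicted labels; off-diagonal cells are errors
conf <- table(actual = factor(actual, levels = 0:2),
              predicted = factor(pred, levels = 0:2))
conf
```

With the real data, `table(actual = y.test, predicted = factor(test.pred.num, levels = 0:9))` would give the full 10 × 10 table.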



Source 2: interactive testing

# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/digitsInteractiveStart.R

# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/solutionsGoodExamples/assignment6Muriel.pdf


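load("deepnet01.RData")   # restores the trained network as `nn`

library(cairoDevice)

# Draw a 784-pixel MNIST vector as an upright 28x28 grayscale image
drawMatrix <- function(x) {
  image(matrix(x, 28, 28)[, 28:1], col = rev(gray((0:100)/100)),
        xaxt = "n", yaxt = "n", xlab = "", ylab = "", bty = "n")
}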


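# Rasterize the recorded mouse coordinates into a 28x28 image with values in [0, 1]
drawImage <- function() {
  img <- matrix(0, 28, 28)
  if (!is.null(coords)) {
    # Map user coordinates (0-28) to pixel indices 1-28; flip y so index 1 is the top
    coordsI <- ceiling(coords)
    coordsI[coordsI > 28] <- 28
    coordsI[coordsI < 1] <- 1
    coordsI[, 2] <- 29 - coordsI[, 2]
    # Count how many recorded points fall into each pixel
    counts <- table(coordsI[, 1], coordsI[, 2])
    #img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- 1
    img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- counts
    # Rescale the counts, amplify by 5 and clip at 1 so strokes come out solid
    mx <- max(img)
    mn <- min(img)
    img <- (img - mn) * 5 / (mx - mn)
    img[img >= 1] <- 1
  }
  drawMatrix(img)
  img
}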


# Open an empty, axis-free plot with user coordinates a..b on both axes
plotEmpty <- function(a = 0, b = 1) {
  plot(c(a, b), c(a, b), type = "n", bty = "n", xaxt = "n", yaxt = "n", xlab = "", ylab = "")
}



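# Display device for the rasterized digit (windows() exists only on Windows;
# on other systems use the commented x11() call instead)
#x11(type="Xlib",width=3,height=3)
windows()

Cairo(width = 3, height = 3)  ## for drawing
par(mar = c(0, 0, 0, 0))

coords <- NULL     # recorded stroke points in user coordinates
drawingG <- FALSE  # TRUE while the left mouse button is held down
plotEmpty(0, 28)

x <- c()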

getGraphicsEvent("Hold left down to draw. Click right to restart. press <Return/Enter> to read.",
                 onMouseDown = function(buttons, x, y) {
                   if (buttons == 0) {          # left button: start drawing
                     drawingG <<- TRUE
                     NULL
                   } else if (buttons == 2) {   # right button: clear the canvas
                     coords <<- NULL
                     plotEmpty(0, 28)
                     NULL
                   } else if (buttons == 1) {   # middle button: leave the event loop
                     return(TRUE)
                   }
                 },
                 onMouseUp = function(buttons, x, y) {
                   drawingG <<- FALSE
                   # processAndClassify()
                   NULL
                 },
                 onMouseMove = function(buttons, x, y) {
                   if (drawingG) {
                     # Event coordinates arrive in "ndc" units; convert to plot coordinates
                     px <- grconvertX(x, "ndc", "user")
                     py <- grconvertY(y, "ndc", "user")
                     coords <<- rbind(coords, c(px, py))
                     points(px, py, pch = 19, cex = 3)
                   }
                   NULL
                 },

                 onKeybd = function(key) {
                   if (key == "Return") {
                     # Switch to the display device and rasterize the drawing
                     dev.set(which = dev.next())
                     # Flatten to a 1 x 784 row in MNIST pixel order. The network was trained
                     # on pixels scaled to [0, 1] and drawImage() already returns values in
                     # [0, 1], so the input is not multiplied by 255 here.
                     x <- matrix(c(drawImage()), nrow = 1)
                     x.pred.v <- nn.predict(nn, x)
                     #barplot(x.pred.v)
                     x.pred <- which.max(x.pred.v) - 1
                     print(paste("The Number is : ", x.pred, sep = ""))
                     # Back to the drawing device
                     dev.set(which = dev.prev())
                   }
                   NULL
                 }
)

dev.off(); dev.off()
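The interactive script only works if the drawn image is flattened in the same pixel order as the MNIST CSV files, which store the 28 × 28 pixels row by row starting at the top-left. `drawImage()` fills `img[x, y]` with `y = 1` at the top, so R's column-major `c(img)` walks along the top row first. A sketch that checks this ordering; `rasterize` is a hypothetical, plot-free copy of the binning in `drawImage()` that counts hits with a loop instead of `table()`:

```r
# Hypothetical helper: the binning step of drawImage(), without plotting.
# coords: two-column matrix of user coordinates in [0, 28], y pointing up.
rasterize <- function(coords) {
  img <- matrix(0, 28, 28)              # img[x, y], y = 1 is the top row
  ij <- ceiling(coords)
  ij[ij > 28] <- 28
  ij[ij < 1] <- 1
  ij[, 2] <- 29 - ij[, 2]               # flip y: top of the canvas -> index 1
  for (k in seq_len(nrow(ij))) {        # count hits per pixel
    img[ij[k, 1], ij[k, 2]] <- img[ij[k, 1], ij[k, 2]] + 1
  }
  img
}

# One point near the top-left corner, one near the top-right corner
pts <- rbind(c(0.5, 27.5), c(27.5, 27.5))
v <- c(rasterize(pts))                  # column-major flatten, as in the script
which(v > 0)
# [1]  1 28
```

A point near the bottom-left corner, such as `c(0.5, 0.5)`, lands at position 757 (= 27 × 28 + 1), the first pixel of the last row, which is where MNIST puts it too.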


