
Training on MNIST data with the deepnet package

Training a deep network to read handwritten digits with the deepnet package (MNIST).


Model:

nn <- dbn.dnn.train(image.train, outnode.train,
                    hidden = c(500, 500, 250, 125),
                    output = "softmax",
                    batchsize = 100, numepochs = 100, learningrate = 0.1)


Rationale for this architecture: intuition!
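For a sense of scale: with 784 input pixels and 10 softmax outputs, hidden = c(500, 500, 250, 125) gives a 784-500-500-250-125-10 network. The back-of-the-envelope sketch below is plain R arithmetic, not a deepnet call. It assumes each fully connected layer has an n_in x n_out weight matrix plus one bias per output unit, and it ignores any extra parameters that exist only during pre-training.

```r
# Layer sizes: 784 input pixels, the four hidden layers from
# hidden = c(500, 500, 250, 125), and 10 softmax outputs (one per digit)
layers <- c(784, 500, 500, 250, 125, 10)

# Assumption: weights (n_in x n_out) plus one bias per output unit per layer
n_in  <- head(layers, -1)
n_out <- tail(layers, -1)
params_per_layer <- n_in * n_out + n_out

params_per_layer        # 392500 250500 125250 31375 1260
sum(params_per_layer)   # 800885
```

Most of the roughly 0.8 million parameters sit in the first layer, which connects every pixel to 500 hidden units.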

Training time: 7.8 hours (t.end - t.start)
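Subtracting two Sys.time() values gives a difftime whose unit R chooses automatically (secs, mins, hours, ...). A minimal sketch using base R's difftime() to report the duration explicitly in hours; the timestamps are made up for illustration and stand in for t.start and t.end.

```r
# Hypothetical timestamps standing in for Sys.time() before and after training
t.start <- as.POSIXct("2000-01-01 00:00:00", tz = "UTC")
t.end   <- t.start + 7.8 * 3600          # 7.8 hours later (POSIXct adds seconds)

t.end - t.start                          # unit picked automatically
elapsed <- difftime(t.end, t.start, units = "hours")
round(as.numeric(elapsed), 1)            # 7.8
```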

Training results

 - network: https://www.dropbox.com/s/6crgw9d9okjimwp/deepnet01.RData?dl=0

 - error on the training set: 0.004%,
   error on the test set: 2% (this looks like overfitting!)
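The script in Source 1 prints accuracy, the fraction of correctly classified images; the error figures above are 1 minus that accuracy, in percent. A minimal sketch of the conversion with a hypothetical helper error_pct() and made-up toy labels, followed by the train/test gap computed from the reported numbers.

```r
# Hypothetical helper: misclassification rate in percent
error_pct <- function(pred, truth) 100 * mean(pred != truth)

# Toy labels, made up for illustration
truth <- c(7, 2, 1, 0, 4, 1, 4, 9, 5, 9)
pred  <- c(7, 2, 1, 0, 4, 1, 4, 9, 6, 9)   # one mistake out of ten

accuracy <- mean(pred == truth)   # the kind of number Source 1 prints: 0.9
error_pct(pred, truth)            # 10, i.e. 100 * (1 - accuracy)

# Gap between the reported test and training errors (percentage points)
train_err <- 0.004
test_err  <- 2
test_err - train_err              # 1.996
```

A test error far above a near-zero training error is the usual sign of overfitting.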

[Figure: error over the course of training]

[Figure: screenshot of the interactive run]

SOURCE:

trainMNIST_deepnet.R

trainMNIST_deepnet_interactive.R

(To run the interactive part without training, download deepnet01.RData linked above.)
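The two CSV files read in Source 1 hold one image per row: the digit label first, then the pixel intensities (0 to 255) of the 28x28 image in row-major order, 784 values in all. A minimal sketch on a made-up miniature file with only 4 pixels per image, showing why header = FALSE matters for a headerless file and how the label and the scaled pixels are split off.

```r
# Made-up miniature CSV: label followed by 4 pixel values (a real row has 784)
csv <- "5,0,128,255,0\n0,255,255,0,0"

bad  <- read.csv(textConnection(csv))                  # first row is eaten as a header
good <- read.csv(textConnection(csv), header = FALSE)  # every row stays data
nrow(bad)    # 1
nrow(good)   # 2

mat    <- data.matrix(good)
labels <- as.factor(mat[, 1])    # digit labels
pixels <- mat[, -1] / 255        # intensities scaled to [0, 1]
range(pixels)                    # 0 1
```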


Source 1: training

library(RCurl)

# The pjreddie CSVs have no header row: each row is the label followed by 784 pixel values
URL <- "http://www.pjreddie.com/media/files/mnist_train.csv"
x <- getURL(URL)
## Or, if SSL certificate verification fails:
## x <- getURL(URL, ssl.verifypeer = FALSE)
train <- read.csv(textConnection(x), header = FALSE)

URL <- "http://www.pjreddie.com/media/files/mnist_test.csv"
x <- getURL(URL)
## Or
## x <- getURL(URL, ssl.verifypeer = FALSE)
test <- read.csv(textConnection(x), header = FALSE)


train.mat <- data.matrix(train)
test.mat <- data.matrix(test)

# Column 1 is the digit label, columns 2..785 are the pixels
y.train <- as.factor(train.mat[, 1])
image.train <- train.mat[, -1]
y.test <- as.factor(test.mat[, 1])
image.test <- test.mat[, -1]

# Scale pixel intensities from 0..255 to [0, 1]
image.train <- image.train / 255
image.test <- image.test / 255

# One-hot encode the labels: one 0/1 indicator column per digit, no intercept
outnode.train <- model.matrix(~ y.train - 1)
outnode.test <- model.matrix(~ y.test - 1)
#http://stackoverflow.com/questions/5048638/automatically-expanding-an-r-factor-into-a-collection-of-1-0-indicator-variables


library(deepnet)

# Free the memory held by the raw data
rm(test); rm(test.mat); rm(train); rm(train.mat)

t.start <- Sys.time()
nn <- dbn.dnn.train(image.train, outnode.train,
                    hidden = c(500, 500, 250, 125),
                    output = "softmax",
                    batchsize = 100, numepochs = 100, learningrate = 0.1)
# the network trained here is the one saved as deepnet01.RData
t.end <- Sys.time()


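# Predicted digit = column of the largest output minus 1 (columns are digits 0..9).
# max.col() returns exactly one column per row even when outputs tie, whereas
# which(max(x) == x) returns several indices on a tie and turns the apply() result into a list.
train.pred <- nn.predict(nn, image.train)
train.pred.num <- max.col(train.pred, ties.method = "first") - 1

# Training accuracy (error = 1 - accuracy)
sum(y.train == train.pred.num) / length(train.pred.num)

test.pred <- nn.predict(nn, image.test)
test.pred.num <- max.col(test.pred, ties.method = "first") - 1

# Test accuracy
sum(y.test == test.pred.num) / length(test.pred.num)

#save(nn, file="deepnet01.RData")
#  test error 2%, train 0.004%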
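The label encoding and the prediction decoding in Source 1 undo each other: model.matrix(~ y - 1) turns a factor into one 0/1 column per level, and taking the column of the largest output per row (minus 1) turns network outputs back into digits. A small base-R sketch on made-up labels; the scores matrix is a hypothetical stand-in for nn.predict() output, used to show why which(max(x) == x) is fragile when two outputs tie.

```r
# Made-up labels standing in for y.train; all ten digits are declared as levels
y <- factor(c(3, 0, 9, 3, 1), levels = 0:9)

# One-hot encoding as in Source 1: one indicator column per digit, no intercept
onehot <- model.matrix(~ y - 1)
dim(onehot)        # 5 10

# Decoding: column of the row maximum, minus 1 (columns are digits 0..9)
decoded <- max.col(onehot, ties.method = "first") - 1
decoded            # 3 0 9 3 1

# Hypothetical network output for one image, with a tie at the top
scores <- matrix(c(0.4, 0.4, rep(0.2 / 8, 8)), nrow = 1)
which(max(scores) == scores) - 1              # 0 1: two answers
max.col(scores, ties.method = "first") - 1    # 0: exactly one answer
```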



Source 2: interactive testing

# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/digitsInteractiveStart.R
# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/solutionsGoodExamples/assignment6Muriel.pdf

library(deepnet)               # for nn.predict()
load("deepnet01.RData")        # restores the trained network `nn`
library(cairoDevice)           # provides the Cairo() device used for drawing


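# Draw a 784-pixel MNIST row vector as a 28x28 grayscale image (row 1 at the top)
drawMatrix <- function(x) {
  image(matrix(x, 28, 28)[, 28:1], col = rev(gray((0:100) / 100)),
        xaxt = "n", yaxt = "n", xlab = "", ylab = "", bty = "n")
}

# Rasterize the mouse points in `coords` (plot units 0..28) into a 28x28 image.
# img[column, row] with row 1 at the top, so c(img) follows the row-major pixel
# order of the MNIST CSV.
drawImage <- function() {
  img <- matrix(0, 28, 28)
  if (!is.null(coords)) {
    coordsI <- ceiling(coords)
    coordsI[coordsI > 28] <- 28
    coordsI[coordsI < 1] <- 1
    coordsI[, 2] <- 29 - coordsI[, 2]            # flip y so that row 1 is the top
    counts <- table(coordsI[, 1], coordsI[, 2])  # number of points per pixel
    img <- matrix(0, 28, 28)
    #img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- 1
    img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- counts
    mx <- max(img)
    mn <- min(img)
    img <- (img - mn) * 5 / (mx - mn)            # boost contrast ...
    img[img >= 1] <- 1                           # ... and clip to 1
  }
  drawMatrix(img)
  img
}

plotEmpty <- function(a = 0, b = 1) {
  plot(c(a, b), c(a, b), type = "n", bty = "n", xaxt = "n", yaxt = "n", xlab = "", ylab = "")
}

# Display device for the rasterized image. windows() exists only on Windows;
# the commented-out x11() line is the alternative on other platforms.
#x11(type="Xlib",width=3,height=3)
windows()

Cairo(width = 3, height = 3)  ## for drawing
par(mar = c(0, 0, 0, 0))

coords <- NULL
drawingG <- FALSE
plotEmpty(0, 28)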


x = c()
getGraphicsEvent("Hold left down to draw. Click right to restart. press <Return/Enter> to read.",
                 onMouseDown = function(buttons, x, y) {
                   if (buttons == 0) {          # left button: start drawing
                     drawingG <<- TRUE
                     NULL
                   } else if (buttons == 2) {   # right button: clear and restart
                     coords <<- NULL
                     plotEmpty(0, 28)
                     NULL
                   } else if (buttons == 1) {   # middle button: leave the event loop
                     return(TRUE)
                   }
                 },
                 onMouseUp = function(buttons, x, y) {
                   drawingG <<- FALSE
                   # processAndClassify()
                   NULL
                 },
                 onMouseMove = function(buttons, x, y) {
                   if (drawingG) {
                     #print(c(x,y))
                     # event coordinates arrive in NDC units; convert to plot (0..28) units
                     px <- grconvertX(x, "ndc", "user")
                     py <- grconvertY(y, "ndc", "user")
                     coords <<- rbind(coords, c(px, py))
                     # print(px,py)
                     points(px, py, pch = 19, cex = 3)
                   }
                   NULL
                 },
                 onKeybd = function(key) {
                   if (key == "Return") {
                     dev.set(which = dev.next())   # switch to the display device
                     # drawImage() returns values in [0, 1], the same scale as the
                     # training images (pixels / 255), so no rescaling to 0..255
                     x <- matrix(c(drawImage()), nrow = 1)   # 1 x 784, MNIST pixel order
                     #colnames(x) <- colnames(test_h2o)
                     #save(x, file="x.RData")
                     #x.h2o <- as.h2o(localH2O, key="x.h2o",x)
                     x.pred.v <- nn.predict(nn, x)
                     #colnames(x.pred.v)=0:9
                     #barplot(x.pred.v)
                     x.pred <- which.max(x.pred.v) - 1
                     print(paste("The Number is : ", x.pred, sep = ""))
                     dev.set(which = dev.prev())   # back to the drawing device
                   }
                   NULL
                 }
)
dev.off(); dev.off()   # close both devices
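drawImage() in Source 2 has to turn the mouse points into the same layout the network was trained on: 784 values in row-major order (left to right, top to bottom), with intensities in [0, 1] because the training pixels were divided by 255. The device-free sketch below reproduces that conversion so it can be checked without opening a graphics window. rasterize_points() is a hypothetical helper; its contrast boost (times 5, clipped to 1) mirrors drawImage(), assuming the darkest pixel is 0, as it is whenever some of the grid is left blank.

```r
# Hypothetical helper mirroring drawImage(): points in plot units (0..28, y up)
# become a 1 x 784 row vector in MNIST pixel order with values in [0, 1]
rasterize_points <- function(coords, size = 28) {
  img <- matrix(0, size, size)                        # img[row, col], row 1 = top
  if (!is.null(coords) && nrow(coords) > 0) {
    col <- pmin(pmax(ceiling(coords[, 1]), 1), size)
    row <- size + 1 - pmin(pmax(ceiling(coords[, 2]), 1), size)  # flip the y axis
    for (k in seq_along(col)) {
      img[row[k], col[k]] <- img[row[k], col[k]] + 1  # count points per pixel
    }
    img <- img * 5 / max(img)                         # contrast boost as in drawImage()
    img[img > 1] <- 1                                 # clip to [0, 1]
  }
  matrix(t(img), nrow = 1)                            # row-major flatten
}

# Two made-up points: one in the top-left pixel, one in the bottom-right pixel
pts <- rbind(c(0.5, 27.5), c(27.9, 0.2))
v <- rasterize_points(pts)
dim(v)          # 1 784
which(v > 0)    # 1 784
```

The resulting row vector is what nn.predict() would receive; scoring it requires the trained network from deepnet01.RData, which the sketch does not load.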



Posted 2014.11.13 in the category "Next project: Let's learn R".