Training a deep network to read handwritten digits (MNIST) with the deepnet package
Model:
nn <- dbn.dnn.train(image.train,outnode.train,
hidden=c(500,500,250,125),
output="softmax",
batchsize=100, numepochs=100, learningrate = 0.1)
Why these layer sizes? Pure intuition!
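To get a feel for how big this network is, the sketch below counts its weights and biases. It assumes the usual fully connected layout with one bias per unit, 784 inputs (28 x 28 pixels) and 10 softmax outputs (digits 0-9); it is a back-of-the-envelope calculation, not something read out of the deepnet object.

```r
# Layer widths: 784 input pixels, the hidden sizes given to dbn.dnn.train(),
# and 10 softmax outputs. Assumes a standard fully connected network.
layers <- c(784, 500, 500, 250, 125, 10)

# A fully connected layer from n_in to n_out units has
# n_in * n_out weights plus n_out biases.
n_in  <- head(layers, -1)
n_out <- tail(layers, -1)
per_layer <- n_in * n_out + n_out
total <- sum(per_layer)

print(per_layer)
print(total)
```

Roughly 0.8 million free parameters against 60,000 training images is a lot of capacity, which is worth keeping in mind when reading the train/test errors below.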
Training time: 7.8 hours (t.end - t.start)
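Subtracting two `Sys.time()` values gives a `difftime` whose unit R picks automatically (seconds, minutes, hours or days). A minimal sketch for reporting the training time in a fixed unit, using hypothetical stand-in timestamps 7.8 hours apart instead of the real ones:

```r
# Hypothetical stand-in timestamps; in the post they come from
# Sys.time() before and after dbn.dnn.train().
t.start <- as.POSIXct("2000-01-01 00:00:00", tz = "UTC")
t.end   <- t.start + 7.8 * 3600

# Plain subtraction chooses the unit automatically ...
elapsed <- t.end - t.start
# ... while difftime() lets you fix it explicitly.
hours <- as.numeric(difftime(t.end, t.start, units = "hours"))

print(elapsed)
print(hours)
```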
Training results:
- trained network: https://www.dropbox.com/s/6crgw9d9okjimwp/deepnet01.RData?dl=0
- training-set error: 0.004%
- test-set error: 2% (the gap looks like overfitting!)
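The training script prints accuracies, while the figures above are error rates (error = 1 - accuracy). A small sketch of that conversion and of the train/test gap; `error_pct()` is a hypothetical helper and the toy labels are made up, not MNIST results:

```r
# Hypothetical helper: misclassification rate in percent.
# Comparing as character works for both factors and numeric digits.
error_pct <- function(pred, truth) {
  100 * mean(as.character(pred) != as.character(truth))
}

# Toy labels, made up for illustration only.
truth      <- factor(c(7, 2, 1, 0, 4))
train_like <- c(7, 2, 1, 0, 4)   # every prediction correct
test_like  <- c(7, 2, 1, 0, 9)   # one mistake out of five

gap_toy <- error_pct(test_like, truth) - error_pct(train_like, truth)

# Gap between the reported errors, in percentage points.
reported_gap <- 2 - 0.004

print(gap_toy)
print(reported_gap)
```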
Error plot as training progresses:
Screenshot of a run:
Source:
trainMNIST_deepnet_interactive.R
(To run the interactive part without training, download deepnet01.RData from the link above.)
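A minimal sketch for fetching the saved network from R, assuming the Dropbox link above is still live. Dropbox share links ending in `?dl=0` open a preview page, while `?dl=1` requests the file itself; `dropbox_direct()` and `fetch_network()` are hypothetical helpers, not part of the original script:

```r
# Hypothetical helper: turn a Dropbox preview link into a direct download link.
dropbox_direct <- function(url) sub("dl=0$", "dl=1", url)

# Hypothetical helper: download the .RData file once and load it into `env`.
# mode = "wb" keeps the binary file intact on Windows.
fetch_network <- function(url, destfile = "deepnet01.RData", env = globalenv()) {
  if (!file.exists(destfile)) {
    download.file(dropbox_direct(url), destfile = destfile, mode = "wb")
  }
  load(destfile, envir = env)  # invisibly returns the loaded names, here "nn"
}

share_url  <- "https://www.dropbox.com/s/6crgw9d9okjimwp/deepnet01.RData?dl=0"
direct_url <- dropbox_direct(share_url)
print(direct_url)
# fetch_network(share_url)  # needs network access; creates `nn`
```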
Source 1: training
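library(RCurl)
# These CSVs have no header row (each line is: label, then 784 pixel values),
# so header = FALSE keeps the first image instead of turning it into column names.
URL <- "http://www.pjreddie.com/media/files/mnist_train.csv"
x <- getURL(URL)
## Or, if SSL certificate verification fails:
## x <- getURL(URL, ssl.verifypeer = FALSE)
train <- read.csv(textConnection(x), header = FALSE)
URL <- "http://www.pjreddie.com/media/files/mnist_test.csv"
x <- getURL(URL)
## Or
## x <- getURL(URL, ssl.verifypeer = FALSE)
test <- read.csv(textConnection(x), header = FALSE)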
train.mat <- data.matrix(train)
test.mat <- data.matrix(test)
y.train <- as.factor(train.mat[,1])
image.train <- train.mat[, -1]
y.test <- as.factor(test.mat[,1])
image.test <- test.mat[, -1]
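# Scale pixel intensities from 0-255 to [0, 1]
image.train <- image.train/255
image.test <- image.test/255
# One-hot targets: "-1" drops the intercept, so each digit 0-9 gets its own 0/1 column
outnode.train <- model.matrix( ~ y.train -1)
outnode.test <- model.matrix( ~ y.test -1)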
#http://stackoverflow.com/questions/5048638/automatically-expanding-an-r-factor-into-a-collection-of-1-0-indicator-variables
library(deepnet)
rm(test); rm(test.mat); rm(train); rm(train.mat)
t.start <- Sys.time()
nn <- dbn.dnn.train(image.train,outnode.train,
hidden=c(500,500,250,125),
output="softmax",
batchsize=100, numepochs=100, learningrate = 0.1)
# the network trained here is the one shared as deepnet01.RData
t.end <- Sys.time()
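# Predicted digit = column of the largest softmax output minus 1 (columns are digits 0-9).
# max.col() returns exactly one column per row, even when two outputs tie.
train.pred <- nn.predict(nn, image.train)
train.pred.num <- max.col(train.pred, ties.method = "first") - 1
mean(y.train == train.pred.num)  # training accuracy (error = 1 - accuracy)
test.pred <- nn.predict(nn, image.test)
test.pred.num <- max.col(test.pred, ties.method = "first") - 1
mean(y.test == test.pred.num)    # test accuracy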
#save(nn, file="deepnet01.RData")
# test error 2%, train 0.004%
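The label handling in Source 1 is a round trip: `model.matrix(~ y - 1)` turns the digit factor into one-hot target columns, and the network's output matrix is turned back into digits by taking the largest column per row. A toy sketch of both directions plus a confusion matrix; the labels and "network outputs" are made up, not deepnet results:

```r
# Toy labels standing in for y.train (made up for illustration).
y <- factor(c(2, 0, 1, 2))

# One-hot targets: "- 1" drops the intercept, so every level gets a 0/1 column.
onehot <- model.matrix(~ y - 1)

# Toy "network outputs": rows are samples, columns are classes 0, 1, 2.
probs <- rbind(c(0.1, 0.2, 0.7),
               c(0.6, 0.3, 0.1),
               c(0.2, 0.5, 0.3),
               c(0.4, 0.4, 0.2))   # tie between classes 0 and 1

# One column per row; ties.method = "first" resolves the tie to the lower digit,
# whereas which(max(x) == x) would return two indices for the last row.
pred <- max.col(probs, ties.method = "first") - 1

accuracy  <- mean(as.character(pred) == as.character(y))
confusion <- table(truth = y, predicted = factor(pred, levels = levels(y)))

print(onehot)
print(pred)
print(accuracy)
print(confusion)
```

With the real data, the same `table()` call on `y.test` and `test.pred.num` shows which digits account for the 2% test error.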
Source 2: interactive testing
# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/digitsInteractiveStart.R
# ref : http://www.cs.colostate.edu/~anderson/cs545/assignments/solutionsGoodExamples/assignment6Muriel.pdf
load("deepnet01.RData")
library(cairoDevice)
drawMatrix <- function(x) {
image(matrix(x,28,28)[,28:1],col=rev(gray((0:100)/100)),xaxt="n",yaxt="n",xlab="",ylab="",bty="n")
}
drawImage <- function() {
img <- matrix(0,28,28)
if (!is.null(coords)) {
coordsI <- ceiling(coords)
coordsI[coordsI > 28] <- 28
coordsI[coordsI < 1] <- 1
coordsI[,2] <- 29-coordsI[,2]
counts <- table(coordsI[,1],coordsI[,2])
img <- matrix(0,28,28)
#img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- 1
img[as.numeric(rownames(counts)), as.numeric(colnames(counts))] <- counts
mx <- max(img)
mn <- min(img)
img <- (img-mn)*5 / (mx-mn)
img[img>=1]=1
}
drawMatrix(img)
img
}
plotEmpty <- function(a=0,b=1) {
plot(c(a,b),c(a,b),type="n",bty="n",xaxt="n",yaxt="n",xlab="",ylab="")
}
#x11(type="Xlib",width=3,height=3)
windows()                ## device for showing the rasterized 28 x 28 image (Windows only)
Cairo(width=3,height=3)  ## for drawing
par(mar=c(0,0,0,0))
coords <- NULL
drawingG <- FALSE
plotEmpty(0,28)
x=c()
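getGraphicsEvent("Hold left down to draw. Click right to restart. Press <Return/Enter> to read. Click middle to quit.",
onMouseDown = function(buttons,x,y) {
# buttons holds the pressed buttons: 0 = left, 1 = middle, 2 = right
if (0 %in% buttons) {         # left: start drawing
drawingG <<- TRUE
NULL
} else if (2 %in% buttons) {  # right: clear the canvas
coords <<- NULL
plotEmpty(0,28)
NULL
} else if (1 %in% buttons) {  # middle: a non-NULL return ends getGraphicsEvent()
return(TRUE)
}
},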
onMouseUp = function(buttons,x,y) {
drawingG <<- FALSE
# processAndClassify()
NULL
},
onMouseMove = function(buttons,x,y) {
if (drawingG) {
#print(c(x,y))
px <- grconvertX(x, "ndc", "user")
py <- grconvertY(y, "ndc", "user")
coords <<- rbind(coords,c(px,py))
# print(px,py)
points(px,py,pch=19,cex=3)
}
NULL
},
onKeybd = function(key) {
if (key == "Return") {
dev.set(which=dev.next())             # switch to the windows() device to show the image
# drawImage() already returns pixels in [0, 1], the same scale as image.train/255,
# so the image is not multiplied by 255 before prediction.
x <- matrix(c(drawImage()), nrow=1)   # one 1 x 784 row in the CSV's pixel order
x.pred.v <- nn.predict(nn, x)
#colnames(x.pred.v)=0:9
#barplot(x.pred.v)
x.pred <- max.col(x.pred.v, ties.method="first") - 1
print(paste("The Number is : ",x.pred, sep=""))
dev.set(which=dev.prev())             # back to the drawing canvas
}
NULL
}
)
dev.off(); dev.off()
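The core of Source 2 is turning mouse coordinates into a 28 x 28 input row. Below is a hypothetical, plot-free re-implementation of that step so it can be checked without a graphics device: it bins the points into cells, flips the y axis so row 1 is the top, boosts intensities by a gain of 5 and clips to [0, 1]. `rasterize_strokes()` and `predict_digit()` are illustrative names, and `mock_predict()` stands in for `nn.predict(nn, x)`. The network was trained on pixels divided by 255, so the image is fed in on that same [0, 1] scale.

```r
# Hypothetical sketch of drawImage()'s rasterization, without plotting.
# coords: two-column matrix of (x, y) points in user units 0..28, y pointing up.
rasterize_strokes <- function(coords, gain = 5) {
  img <- matrix(0, 28, 28)
  if (is.null(coords) || nrow(coords) == 0) return(img)
  xi <- pmin(pmax(ceiling(coords[, 1]), 1), 28)        # image column 1..28
  yi <- 29 - pmin(pmax(ceiling(coords[, 2]), 1), 28)   # image row, 1 = top
  # Count points per cell; cell [xi, yi] sits at linear index (yi - 1) * 28 + xi.
  img[] <- tabulate((yi - 1) * 28 + xi, nbins = 28 * 28)
  rng <- max(img) - min(img)
  if (rng > 0) img <- (img - min(img)) * gain / rng
  img[img > 1] <- 1                                    # clip to [0, 1]
  img
}

# c(img) walks down each matrix column, i.e. one image row at a time,
# which matches the row-major pixel order of the MNIST CSV.
predict_digit <- function(img, predict_fn) {
  x <- matrix(c(img), nrow = 1)
  max.col(predict_fn(x), ties.method = "first") - 1
}

# Toy strokes and a mock classifier (hypothetical stand-in for nn.predict).
strokes <- rbind(c(0.5, 27.5), c(0.5, 27.5), c(10.2, 3.1))
img <- rasterize_strokes(strokes)
mock_predict <- function(x) {
  stopifnot(nrow(x) == 1, ncol(x) == 784, max(x) <= 1)
  scores <- matrix(0, 1, 10)
  scores[1, 8] <- 1          # pretend the network votes for "7"
  scores
}
digit <- predict_digit(img, mock_predict)
print(digit)
# With the trained network: predict_digit(drawImage(), function(x) nn.predict(nn, x))
```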