[ti-machine-learning/ti-machine-learning.git] / src / app / cnn / class / mnist / appCNNClassMNISTTraining.c
diff --git a/src/app/cnn/class/mnist/appCNNClassMNISTTraining.c b/src/app/cnn/class/mnist/appCNNClassMNISTTraining.c
index 940a6a8057d513c9646012e06fd55f07146a653c..9f2709ae733c240655b7348f10f984fba2163572 100644 (file)
#define BATCH_SIZE 100\r
#define IMAGE_CHANNEL 1\r
#define LEARN_RATE 0.1\r
+#define EPOCH 10\r
\r
/*******************************************************************************\r
*\r
\r
int main()\r
{\r
-\r
return appCNNClassMNISTTraining();\r
}\r
\r
\r
int appCNNClassMNISTTraining()\r
{\r
- int i;\r
+ int i, j;\r
int dim;\r
long mem;\r
struct timespec startTime;\r
\r
// setup CNN\r
printf("1. Build up the CNN\n");\r
- timlConvNeuralNetwork *cnn = timlCNNCreateConvNeuralNetwork(timlCNNTrainingParamsDefault(), 0);\r
+ timlConvNeuralNetwork *cnn = timlCNNCreateConvNeuralNetwork(timlCNNTrainingParamsDefault());\r
cnn->params.learningRate = LEARN_RATE;\r
+ cnn->params.maxBatchSize = BATCH_SIZE;\r
cnn->params.batchSize = BATCH_SIZE;\r
+ cnn->params.batchUpdate = BATCH_SIZE;\r
inputParams = timlCNNInputParamsDefault();\r
inputParams.scale = 1.0/256.0;\r
timlCNNAddInputLayer(cnn, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, inputParams); // input layer\r
timlCNNAddLinearLayer(cnn, 500, timlCNNLinearParamsDefault()); // linear layer\r
timlCNNAddNonlinearLayer(cnn, Util_Relu); // relu layer\r
timlCNNAddLinearLayer(cnn, 10, timlCNNLinearParamsDefault()); // linear layer\r
- timlCNNAddNonlinearLayer(cnn, Util_Softmax); // softmax layer\r
+ timlCNNAddSoftmaxCostLayer(cnn); // softmax cost layer\r
timlCNNInitialize(cnn);\r
timlCNNReset(cnn);\r
- mem = timlCNNMemory(cnn);\r
timlCNNPrint(cnn);\r
- printf("CNN memory allocation = %.10f MB.\n", (float) mem/1024.0/1024.0);\r
- printf("CNN parameter # = %ld.\n", timlCNNGetParamsNum(cnn));\r
\r
// read MNIST database\r
printf("2. Read the MNIST database\n");\r
// training\r
printf("3. Start training\n");\r
clock_gettime(CLOCK_REALTIME, &startTime);\r
- for (i = 0; i < batchNum; i++) {\r
- timlCNNSupervisedTrainingWithLabelBatchMode(cnn, training.data + i*batchSize*dim, training.label + i*batchSize, dim, batchSize);\r
+ for (j =0; j < EPOCH; j++) {\r
+ for (i = 0; i < batchNum; i++) {\r
+ timlCNNSupervisedTrainingWithLabel(cnn, training.data + i*batchSize*dim, IMAGE_ROW, IMAGE_COL, IMAGE_CHANNEL, training.label + i*batchSize, 1, 1, batchSize);\r
+ }\r
}\r
clock_gettime(CLOCK_REALTIME, &endTime);\r
trainingTime = timlUtilDiffTime(startTime, endTime);\r