/******************************************************************************/
/*!
 * \file timlUtilReadMNIST.c
 */
/* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 *    Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the
 *    distribution.
 *
 *    Neither the name of Texas Instruments Incorporated nor the names of
 *    its contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/


/*******************************************************************************
 *
 * INCLUDES
 *
 ******************************************************************************/

#include "../api/timl.h"


/*******************************************************************************
 *
 * DEFINES
 *
 ******************************************************************************/

#define DATA_MAGIC_NUM  2051
#define LABEL_MAGIC_NUM 2049
#define CHANNEL         1


/******************************************************************************/
/*!
 * \brief Read one 32-bit header field and byte-swap it
 * \param[in]  fp    Open binary stream positioned at the field
 * \param[out] value Field value after timlUtilReverseEndian32()
 * \return 0 on success, -1 if fewer than four bytes could be read
 */
/******************************************************************************/

static int timlUtilMNISTReadHeaderField(FILE *fp, int *value)
{
   uint32_t buffer;

   if (fread(&buffer, sizeof(buffer), 1, fp) != 1) {
      return -1;
   }
   *value = timlUtilReverseEndian32(buffer);
   return 0;
}


/******************************************************************************/
/*!
 * \brief Open "<path><fileName>" for binary reading
 * \param[in] path     Database directory
 * \param[in] fileName File name including its leading '/'
 * \return Open stream, or NULL if the combined name does not fit in
 *         TIML_UTIL_MAX_STR bytes or the file cannot be opened
 */
/******************************************************************************/

static FILE *timlUtilMNISTOpen(const char *path, const char *fileName)
{
   char str[TIML_UTIL_MAX_STR];
   int  len;

   // bounded formatting: an over-long path is rejected instead of overflowing str
   len = snprintf(str, sizeof(str), "%s%s", path, fileName);
   if (len < 0 || (size_t)len >= sizeof(str)) {
      return NULL;
   }
   return fopen(str, "rb");
}


/******************************************************************************/
/*!
 * \brief Load one image file into an image set
 * \param[in]  path       Database directory
 * \param[in]  fileName   Image file name including its leading '/'
 * \param[out] set        Receives data, num, channel, row and col (and mean
 *                        when withMean is non-zero); left untouched on failure
 * \param[in]  withMean   Non-zero to also allocate and fill set->mean
 * \param[in]  readError  Code returned for open/header/short-read failures
 * \param[in]  allocError Code returned when the buffers cannot be allocated
 * \return 0 on success, otherwise readError or allocError
 */
/******************************************************************************/

static int timlUtilMNISTLoadImages(const char *path, const char *fileName,
                                   timlUtilImageSet *set, int withMean,
                                   int readError, int allocError)
{
   FILE    *fp;
   int      magicNum;
   int      num;
   int      row;
   int      col;
   int      err;
   size_t   total;
   size_t   i;
   uint8_t *raw  = NULL;
   float   *data = NULL;
   float   *mean = NULL;

   fp = timlUtilMNISTOpen(path, fileName);
   if (fp == NULL) {
      return readError;
   }

   // header: magic number, image count, rows, columns
   err = readError;
   if (timlUtilMNISTReadHeaderField(fp, &magicNum) != 0 || magicNum != DATA_MAGIC_NUM) {
      goto cleanup;
   }
   if (timlUtilMNISTReadHeaderField(fp, &num) != 0 ||
       timlUtilMNISTReadHeaderField(fp, &row) != 0 ||
       timlUtilMNISTReadHeaderField(fp, &col) != 0) {
      goto cleanup;
   }
   // the dimensions come straight from the file; refuse nonsensical ones
   if (num <= 0 || row <= 0 || col <= 0) {
      goto cleanup;
   }

   // compute row*col*num*sizeof(float) in size_t, rejecting overflow
   err = allocError;
   if ((size_t)row > SIZE_MAX / (size_t)col) {
      goto cleanup;
   }
   total = (size_t)row * (size_t)col;
   if (total > SIZE_MAX / (size_t)num) {
      goto cleanup;
   }
   total *= (size_t)num;
   if (total > SIZE_MAX / sizeof(*data)) {
      goto cleanup;
   }

   raw  = malloc(total * sizeof(*raw));
   data = malloc(total * sizeof(*data));
   if (withMean) {
      mean = malloc(total * sizeof(*mean));
   }
   if (raw == NULL || data == NULL || (withMean && mean == NULL)) {
      goto cleanup;
   }

   // a truncated pixel payload is a reading error, not silent garbage
   if (fread(raw, sizeof(*raw), total, fp) != total) {
      err = readError;
      goto cleanup;
   }

   for (i = 0; i < total; i++) {
      data[i] = raw[i];
      if (withMean) {
         // Same result as the previous "zero then accumulate" code: mean has
         // one entry per pixel of the whole set and equals that pixel.
         // NOTE(review): this is not an average over images - confirm intent.
         mean[i] = raw[i];
      }
   }

   // publish the result; ownership of the buffers moves to the caller
   set->data = data;
   if (withMean) {
      set->mean = mean;
   }
   set->num     = num;
   set->channel = CHANNEL;
   set->row     = row;
   set->col     = col;
   data = NULL;
   mean = NULL;
   err  = 0;

cleanup:
   free(raw);
   free(data);
   free(mean);
   fclose(fp);
   return err;
}


/******************************************************************************/
/*!
 * \brief Load one label file into an image set
 * \param[in]  path       Database directory
 * \param[in]  fileName   Label file name including its leading '/'
 * \param[out] set        Receives label; left untouched on failure
 * \param[in]  readError  Code returned for open/header/short-read failures
 * \param[in]  allocError Code returned when the buffers cannot be allocated
 * \return 0 on success, otherwise readError or allocError
 */
/******************************************************************************/

static int timlUtilMNISTLoadLabels(const char *path, const char *fileName,
                                   timlUtilImageSet *set,
                                   int readError, int allocError)
{
   FILE    *fp;
   int      magicNum;
   int      num;
   int      err;
   size_t   total;
   size_t   i;
   uint8_t *raw   = NULL;
   int     *label = NULL;

   fp = timlUtilMNISTOpen(path, fileName);
   if (fp == NULL) {
      return readError;
   }

   // header: magic number, label count
   err = readError;
   if (timlUtilMNISTReadHeaderField(fp, &magicNum) != 0 || magicNum != LABEL_MAGIC_NUM) {
      goto cleanup;
   }
   if (timlUtilMNISTReadHeaderField(fp, &num) != 0 || num <= 0) {
      goto cleanup;
   }
   // NOTE(review): the label count is not compared with set->num (the image
   // count), matching the previous behavior - confirm whether it should be.

   err   = allocError;
   total = (size_t)num;
   if (total > SIZE_MAX / sizeof(*label)) {
      goto cleanup;
   }
   raw   = malloc(total * sizeof(*raw));
   label = malloc(total * sizeof(*label));
   if (raw == NULL || label == NULL) {
      goto cleanup;
   }

   if (fread(raw, sizeof(*raw), total, fp) != total) {
      err = readError;
      goto cleanup;
   }

   // widen each one-byte label to int
   for (i = 0; i < total; i++) {
      label[i] = raw[i];
   }

   set->label = label;
   label = NULL;
   err   = 0;

cleanup:
   free(raw);
   free(label);
   fclose(fp);
   return err;
}


/******************************************************************************/
/*!
 * \ingroup util
 * \brief Read MNIST database from binary files
 *
 * Loads, in this order, train-images.idx3-ubyte, train-labels.idx1-ubyte,
 * t10k-images.idx3-ubyte and t10k-labels.idx1-ubyte from the directory
 * \a path. Pixels are stored as float and labels as int. Only the training
 * set gets a mean buffer; testing->mean is not touched.
 *
 * The data, mean and label buffers are allocated with malloc() and are owned
 * by the caller after a successful return. Loading stops at the first
 * failure: buffers stored by stages that already succeeded stay in place,
 * while the stage that failed releases everything it allocated and closes
 * its file.
 *
 * \param[in]  path     Database path
 * \param[out] training Training database
 * \param[out] testing  Testing database
 * \return 0 on success; ERROR_UTIL_NULL_PTR if any argument is NULL;
 *         otherwise the ERROR_UTIL_MNIST_{TRAINING,TESTING}_{DATA,LABEL}_
 *         {READING,ALLOCATION} code of the stage that failed (READING covers
 *         an over-long path, open failure, bad magic number, invalid
 *         dimensions and truncated files)
 */
/******************************************************************************/

int timlUtilReadMNIST(const char *path, timlUtilImageSet *training, timlUtilImageSet *testing)
{
   int err;

   if (path == NULL || training == NULL || testing == NULL) {
      return ERROR_UTIL_NULL_PTR;
   }

   // read training data
   err = timlUtilMNISTLoadImages(path, "/train-images.idx3-ubyte", training, 1,
                                 ERROR_UTIL_MNIST_TRAINING_DATA_READING,
                                 ERROR_UTIL_MNIST_TRAINING_DATA_ALLOCATION);
   if (err != 0) {
      return err;
   }

   // read training labels
   err = timlUtilMNISTLoadLabels(path, "/train-labels.idx1-ubyte", training,
                                 ERROR_UTIL_MNIST_TRAINING_LABEL_READING,
                                 ERROR_UTIL_MNIST_TRAINING_LABEL_ALLOCATION);
   if (err != 0) {
      return err;
   }

   // read testing data
   err = timlUtilMNISTLoadImages(path, "/t10k-images.idx3-ubyte", testing, 0,
                                 ERROR_UTIL_MNIST_TESTING_DATA_READING,
                                 ERROR_UTIL_MNIST_TESTING_DATA_ALLOCATION);
   if (err != 0) {
      return err;
   }

   // read testing labels
   err = timlUtilMNISTLoadLabels(path, "/t10k-labels.idx1-ubyte", testing,
                                 ERROR_UTIL_MNIST_TESTING_LABEL_READING,
                                 ERROR_UTIL_MNIST_TESTING_LABEL_ALLOCATION);
   if (err != 0) {
      return err;
   }

   return 0;
}