/******************************************************************************/ /*! * \file timlUtilReadCIFAR10.c */ /* Copyright (C) 2015 Texas Instruments Incorporated - http://www.ti.com/ * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the * distribution. * * Neither the name of Texas Instruments Incorporated nor the names of * its contributors may be used to endorse or promote products derived * from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ /******************************************************************************* * * INCLUDES * ******************************************************************************/ #include "../api/timl.h" /******************************************************************************* * * DEFINES * ******************************************************************************/ #define TRAIN_FILE_NUM 5 #define SIZE 32 #define CHANNEL 3 #define SAMPLE_PER_FILE 10000 /******************************************************************************/ /*! * \ingroup util * \brief Read CIFA10 database from binary files * \param[out] training Training database * \param[out] testing Testing database * \return Error code */ /******************************************************************************/ int timlUtilReadCIFAR10(const char* path, timlUtilImageSet *training, timlUtilImageSet *testing) { int err; int size; int channel; int num; int fileNum; int i; int j; int k; int read; char str[TIML_UTIL_MAX_STR]; char strNum[TIML_UTIL_MAX_STR]; FILE *fp; uint8_t labelBuffer; uint8_t *dataBuffer; fileNum = TRAIN_FILE_NUM; dataBuffer = NULL; num = SAMPLE_PER_FILE; size = SIZE; channel = CHANNEL; err = 0; training->channel = CHANNEL; training->col = SIZE; training->row = SIZE; training->num = num * fileNum; training->label = NULL; training->data = NULL; testing->channel = CHANNEL; testing->col = SIZE; testing->row = SIZE; testing->num = num; testing->label = NULL; testing->data = NULL; dataBuffer = malloc(sizeof(uint8_t)*size*size*channel); training->label = malloc(sizeof(int)*num*fileNum); training->data = malloc( sizeof(float)*num*fileNum*size*size*channel); training->mean = malloc(sizeof(float)*size*size*3); testing->label = malloc(sizeof(int)*num); testing->data = malloc(sizeof(float)*num*size*size*channel); testing->mean = NULL; for (i = 0; i < size*size*channel; i++) { training->mean[i] = 0.0; } // read all 5 training bin files for (i = 0; i < fileNum; i++) { strcpy(str, path); strcat(str, "/data_batch_"); sprintf(strNum, "%d", i + 1); strcat(str, strNum); strcat(str, ".bin"); fp = fopen(str, "rb"); if (fp == NULL) { return ERROR_UTIL_CIFAR10_TRAINING_READING; } for (j = 0; j < num; j++) { read = fread(&labelBuffer, 1, sizeof(uint8_t), fp); read = fread(dataBuffer, size*size*channel, sizeof(uint8_t), fp); training->label[i*num + j] = labelBuffer; for (k = 0; k < size*size*channel; k++) { training->data[k + i*num*size*size*channel + j*size*size*channel] = dataBuffer[k]; } cblas_saxpy(size*size*channel, 1.0, training->data + i*num*size*size*channel + j*size*size*channel, 1, training->mean, 1); } fclose(fp); } cblas_sscal(size*size*channel, 1.0/(num*fileNum), training->mean, 1); // read testing bin strcpy(str, path); strcat(str, "/test_batch.bin"); fp = fopen(str, "rb"); if (fp == NULL) { return ERROR_UTIL_CIFAR10_TESTING_READING; } for (j = 0; j < num; j++) { read = fread(&labelBuffer, 1, sizeof(uint8_t), fp); read = fread(dataBuffer, size*size*channel, sizeof(uint8_t), fp); testing->label[j] = labelBuffer; for (k = 0; k < size*size*channel; k++) { testing->data[k + j*size*size*channel] = dataBuffer[k]; } } fclose(fp); free(dataBuffer); return err; }