]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/layers/multinomial_logistic_loss_layer.cu
5ffa4accc5029ce847afb66438a4aa4aae1c6e30
[jacinto-ai/caffe-jacinto.git] / src / caffe / layers / multinomial_logistic_loss_layer.cu
1 // Copyright 2013 Yangqing Jia
3 #include "caffe/layer.hpp"
4 #include "caffe/vision_layers.hpp"
5 #include "caffe/util/math_functions.hpp"
6 #include <algorithm>
7 #include <cmath>
9 using std::max;
11 namespace caffe {
13 template <typename Dtype>
14 void MultinomialLogisticLossLayer<Dtype>::SetUp(
15     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
16   CHECK_EQ(bottom.size(), 2) << "Loss Layer takes two blobs as input.";
17   CHECK_EQ(top->size(), 0) << "Loss Layer takes no as output.";
18   CHECK_EQ(bottom[0]->num(), bottom[1]->num())
19       << "The data and label should have the same number.";
20   CHECK_EQ(bottom[1]->channels(), 1);
21   CHECK_EQ(bottom[1]->height(), 1);
22   CHECK_EQ(bottom[1]->width(), 1);
23 };
26 template <typename Dtype>
27 Dtype MultinomialLogisticLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
28     const bool propagate_down,
29     vector<Blob<Dtype>*>* bottom) {
30   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
31   const Dtype* bottom_label = (*bottom)[1]->cpu_data();
32   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
33   int num = (*bottom)[0]->num();
34   int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
35   memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
36   Dtype loss = 0;
37   const Dtype kLOG_THRESHOLD = 1e-8;
38   for (int i = 0; i < num; ++i) {
39     int label = static_cast<int>(bottom_label[i]);
40     Dtype prob = max(bottom_data[i * dim + label], kLOG_THRESHOLD);
41     loss -= log(prob);
42     bottom_diff[i * dim + label] = - 1. / prob / num;
43   }
44   return loss / num;
45 }
47 // TODO: implement the GPU version
49 INSTANTIATE_CLASS(MultinomialLogisticLossLayer);
52 }  // namespace caffe