index 31c6b1d4a5b20952e16b55c0ffe4177b844cbd76..29d2eb363837f0c5bab4bc80dad159fee58bdb96 100644 (file)
+// Copyright 2013 Yangqing Jia
+
#ifndef CAFFE_VISION_LAYERS_HPP_
#define CAFFE_VISION_LAYERS_HPP_
+#include <leveldb/db.h>
+
+#include <vector>
+
#include "caffe/layer.hpp"
namespace caffe {
class NeuronLayer : public Layer<Dtype> {
public:
explicit NeuronLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
};
class ReLULayer : public NeuronLayer<Dtype> {
public:
explicit ReLULayer(const LayerParameter& param)
- : NeuronLayer<Dtype>(param) {};
+ : NeuronLayer<Dtype>(param) {}
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class DropoutLayer : public NeuronLayer<Dtype> {
public:
explicit DropoutLayer(const LayerParameter& param)
- : NeuronLayer<Dtype>(param) {};
+ : NeuronLayer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class InnerProductLayer : public Layer<Dtype> {
public:
explicit InnerProductLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class PaddingLayer : public Layer<Dtype> {
public:
explicit PaddingLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class LRNLayer : public Layer<Dtype> {
public:
explicit LRNLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class Im2colLayer : public Layer<Dtype> {
public:
explicit Im2colLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
class PoolingLayer : public Layer<Dtype> {
public:
explicit PoolingLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
- //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
- // vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
- //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
- // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
int KSIZE_;
int STRIDE_;
int CHANNELS_;
class ConvolutionLayer : public Layer<Dtype> {
public:
explicit ConvolutionLayer(const LayerParameter& param)
- : Layer<Dtype>(param) {};
+ : Layer<Dtype>(param) {}
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
+
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
Blob<Dtype> col_bob_;
- protected:
+
int KSIZE_;
int STRIDE_;
int NUM_;
int N_;
};
+template <typename Dtype>
+class DataLayer : public Layer<Dtype> {
+ public:
+ explicit DataLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ shared_ptr<leveldb::DB> db_;
+ shared_ptr<leveldb::Iterator> iter_;
+ int datum_size_;
+};
+
+
+template <typename Dtype>
+class SoftmaxLayer : public Layer<Dtype> {
+ public:
+ explicit SoftmaxLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ // virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ // vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ // sum_multiplier is just used to carry out sum using blas
+ Blob<Dtype> sum_multiplier_;
+ // scale is an intermediate blob to hold temporary results.
+ Blob<Dtype> scale_;
+};
+
+template <typename Dtype>
+class MultinomialLogisticLossLayer : public Layer<Dtype> {
+ public:
+ explicit MultinomialLogisticLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ // The loss layer will do nothing during forward - all computation are
+ // carried out in the backward pass.
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+};
+
+template <typename Dtype>
+class EuclideanLossLayer : public Layer<Dtype> {
+ public:
+ explicit EuclideanLossLayer(const LayerParameter& param)
+ : Layer<Dtype>(param), difference_() {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ // The loss layer will do nothing during forward - all computation are
+ // carried out in the backward pass.
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) { return; }
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ // virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ Blob<Dtype> difference_;
+};
+
+
} // namespace caffe
#endif // CAFFE_VISION_LAYERS_HPP_