]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - jacinto-ai/caffe-jacinto.git/blob - src/caffe/net.cpp
doc update
[jacinto-ai/caffe-jacinto.git] / src / caffe / net.cpp
1 #include <algorithm>
2 #include <map>
3 #include <set>
4 #include <boost/thread.hpp>
5 #include <hdf5.h>
7 #include "caffe/common.hpp"
8 #include "caffe/layer.hpp"
9 #include "caffe/net.hpp"
10 #include "caffe/parallel.hpp"
11 #include "caffe/util/hdf5.hpp"
12 #include "caffe/util/insert_splits.hpp"
13 #include "caffe/util/math_functions.hpp"
14 #include "caffe/util/signal_handler.h"
15 #include "caffe/util/upgrade_proto.hpp"
16 #include "caffe/filler.hpp"
18 #include "caffe/test/test_caffe_main.hpp"
20 namespace caffe {
22 constexpr int Net::END_OF_ITERATION;
23 constexpr int Net::END_OF_TRAIN;
25 Net::Net(const NetParameter& param,
26     size_t solver_rank,
27     Flag* solver_init_flag,
28     const Net* root_net,
29     bool inner_net,
30     int level,
31     const vector<string>* stages)
32     : root_net_(root_net),
33       solver_(nullptr),
34       solver_rank_(solver_rank),
35       solver_init_flag_(solver_init_flag),
36       inner_net_(inner_net) {
37   Init(param);
38 }
40 Net::Net(const string& param_file,
41     Phase phase,
42     size_t solver_rank,
43     Flag* solver_init_flag,
44     const Net* root_net,
45     bool inner_net,
46     int level,
47     const vector<string>* stages)
48     : root_net_(root_net),
49       solver_(nullptr),
50       solver_rank_(solver_rank),
51       solver_init_flag_(solver_init_flag),
52       inner_net_(inner_net) {
53   NetParameter param;
54   ReadNetParamsFromTextFileOrDie(param_file, &param);
55   // Set phase, stages and level
56   param.mutable_state()->set_phase(phase);
57   if (stages != NULL) {
58     for (int i = 0; i < stages->size(); ++i) {
59       param.mutable_state()->add_stage(stages->at(i));
60     }
61   }
62   param.mutable_state()->set_level(level);
63   Init(param);
64 }
66 Net::~Net() {
67 }
69 void Net::Init(const NetParameter& in_param) {
70   CHECK(inner_net_ || Caffe::root_solver() || root_net_)
71       << "root_net_ needs to be set for all non-root solvers";
72   // Set phase from the state.
73   phase_ = in_param.state().phase();
74   // Filter layers based on their include/exclude rules and
75   // the current NetState.
76   NetParameter filtered_param;
77   FilterNet(in_param, &filtered_param);
78   net_param_ = filtered_param;
79   batch_per_solver_ = caffe::P2PSync::divide_batch_size(&filtered_param);
80   LOG_IF(INFO, Caffe::root_solver())
81       << "Initializing net from parameters: " << std::endl
82       << filtered_param.DebugString();
83   infer_count_ = 0UL;
84   // Create a copy of filtered_param with splits added where necessary.
85   NetParameter param;
86   InsertSplits(filtered_param, &param);
87   // Basically, build all the layers and set up their connections.
88   name_ = param.name();
89   map<string, int> blob_name_to_idx;
90   set<string> available_blobs;
91   gpu_top_memory_data_use_ = gpu_top_memory_diff_use_ = 0UL;
92   gpu_btm_memory_data_use_ = gpu_btm_memory_diff_use_ = 0UL;
93   gpu_shr_memory_data_use_ = gpu_shr_memory_diff_use_ = 0UL;
94   gpu_prm_memory_data_use_ = gpu_prm_memory_diff_use_ = 0UL;
95   gpu_shp_memory_data_use_ = gpu_shp_memory_diff_use_ = 0UL;
96   // For each layer, set up its input and output
97   bottom_vecs_.resize(param.layer_size());
98   top_vecs_.resize(param.layer_size());
99   bottom_id_vecs_.resize(param.layer_size());
100   param_id_vecs_.resize(param.layer_size());
101   top_id_vecs_.resize(param.layer_size());
102   bottom_need_backward_.resize(param.layer_size());
104   // If user skips default math type we use default data type:
105   Type default_fmath, default_bmath;
106   if (in_param.has_default_forward_math()) {
107     default_fmath = in_param.default_forward_math();
108   } else {
109     default_fmath = in_param.default_forward_type();
110     LOG(INFO) << "Using " << Type_Name(default_fmath) << " as default forward math type";
111   }
112   if (in_param.has_default_backward_math()) {
113     default_bmath = in_param.default_backward_math();
114   } else {
115     default_bmath = in_param.default_backward_type();
116     LOG(INFO) << "Using " << Type_Name(default_bmath) << " as default backward math type";
117   }
119   wgrad_sq_.store(0LL);
120   global_grad_scale_coeff_ = 1.F;
121   has_global_grad_scale_param_ = in_param.has_global_grad_scale();
122   global_grad_scale_param_ = in_param.global_grad_scale();
123   global_grad_scale_adaptive_ = in_param.global_grad_scale_adaptive();
125   for (int layer_id = 0; layer_id < param.layer_size(); ++layer_id) {
126     // For non-root solvers, whether this layer is shared from root_net_.
127     bool share_from_root = !inner_net_ && !Caffe::root_solver()
128         && root_net_->layers_[layer_id]->ShareInParallel();
130     const LayerParameter& layer_param = param.layer(layer_id);
131     LayerParameter* mutable_layer_param = param.mutable_layer(layer_id);
133     DLOG_IF(INFO, Caffe::root_solver())
134         << "Setting types for Layer " << layer_param.name();
136     // Inherit phase from net if unset.
137     if (!layer_param.has_phase()) {
138       mutable_layer_param->set_phase(phase_);
139     }
140     const bool is_data_layer = layer_param.has_transform_param();
142     // Data&Math types
143     const bool fm_by_user = layer_param.has_forward_math();
144     if (!fm_by_user) {
145       if (layer_param.has_forward_type()) {
146         mutable_layer_param->set_forward_math(layer_param.forward_type());
147       } else {
148         mutable_layer_param->set_forward_math(default_fmath);
149       }
150     }
151     const bool bm_by_user = layer_param.has_backward_math();
152     if (!bm_by_user) {
153       if (layer_param.has_backward_type()) {
154         mutable_layer_param->set_backward_math(layer_param.backward_type());
155       } else {
156         mutable_layer_param->set_backward_math(default_bmath);
157       }
158     }
160     if (!layer_param.has_forward_type()) {
161       mutable_layer_param->set_forward_type(in_param.default_forward_type());
162     }
163     if (!layer_param.has_backward_type()) {
164       if (is_data_layer) {
165         // In majority of cases we manage to avoid redundant conversion:
166         mutable_layer_param->set_backward_type(FLOAT);
167       } else {
168         mutable_layer_param->set_backward_type(in_param.default_backward_type());
169       }
170     }
172     // Convolution algorithms
173     if (param.has_default_conv_algos_override() && layer_param.has_convolution_param() &&
174         !layer_param.convolution_param().has_conv_algos_override()) {
175       mutable_layer_param->mutable_convolution_param()->
176           set_conv_algos_override(param.default_conv_algos_override());
177     }
179     // cuDNN math
180     if (param.has_default_cudnn_math_override() &&
181         !layer_param.has_cudnn_math_override()) {
182       mutable_layer_param->set_cudnn_math_override(param.default_cudnn_math_override());
183     }
185     // Setup layer.
186     if (layer_param.propagate_down_size() > 0) {
187       CHECK_EQ(layer_param.propagate_down_size(),
188           layer_param.bottom_size())
189           << "propagate_down param must be specified "
190           << "either 0 or bottom_size times ";
191     }
192     if (share_from_root) {
193       LOG(INFO) << "Sharing layer " << layer_param.name() << " from root net";
194       layers_.push_back(root_net_->layers_[layer_id]);
195       layers_[layer_id]->SetShared(true);
196     } else {
197       layers_.push_back(LayerRegistry::CreateLayer(layer_param, solver_rank_));
198     }
199     layer_names_.push_back(layer_param.name());
200     LOG_IF(INFO, Caffe::root_solver())
201         << "Created Layer " << layer_param.name() << " (" << layer_id << ")";
202     bool need_backward = false;
204     // Figure out this layer's input and output
205     for (int bottom_id = 0; bottom_id < layer_param.bottom_size(); ++bottom_id) {
206       const int blob_id = AppendBottom(param, layer_id, bottom_id,
207                                        &available_blobs, &blob_name_to_idx);
208       // If a blob needs backward, this layer should provide it.
209       need_backward |= blob_need_backward_[blob_id];
210     }
211     int num_top = layer_param.top_size();
212     for (int top_id = 0; top_id < num_top; ++top_id) {
213       AppendTop(param, layer_id, top_id, &available_blobs, &blob_name_to_idx);
214       // Collect Input layer tops as Net inputs.
215       if (layer_param.type() == "Input") {
216         const int blob_id = blobs_.size() - 1;
217         net_input_blob_indices_.push_back(blob_id);
218         net_input_blobs_.push_back(blobs_[blob_id].get());
219       }
220     }
221     // If the layer specifies that AutoTopBlobs() -> true and the LayerParameter
222     // specified fewer than the required number (as specified by
223     // ExactNumTopBlobs() or MinTopBlobs()), allocate them here.
224     LayerBase* layer = layers_[layer_id].get();
225     if (layer->AutoTopBlobs()) {
226       const int needed_num_top =
227           std::max(layer->MinTopBlobs(), layer->ExactNumTopBlobs());
228       for (; num_top < needed_num_top; ++num_top) {
229         // Add "anonymous" top blobs -- do not modify available_blobs or
230         // blob_name_to_idx as we don't want these blobs to be usable as input
231         // to other layers.
232         AppendTop(param, layer_id, num_top, NULL, NULL);
233       }
234     }
235     layer->fm_by_user(fm_by_user);
236     layer->bm_by_user(bm_by_user);
238     layers_[layer_id]->set_net_initialized_flag(solver_init_flag_);
240     Flag* layer_inititialized_flag = layers_[layer_id]->layer_inititialized_flag();
241     if (layer_inititialized_flag != nullptr) {
242       layer_inititialized_flags_.push_back(layer_inititialized_flag);
243     }
245     // After this layer is connected, set it up.
246     if (share_from_root) {
247       // Set up size of top blobs using root_net_
248       const vector<Blob*>& base_top = root_net_->top_vecs_[layer_id];
249       const vector<Blob*>& this_top = this->top_vecs_[layer_id];
250       for (int top_id = 0; top_id < base_top.size(); ++top_id) {
251         this_top[top_id]->ReshapeLike(*base_top[top_id]);
252         LOG(INFO) << "Created top blob " << top_id << " (shape: "
253             << this_top[top_id]->shape_string() <<  ") for shared layer "
254             << layer_param.name();
255       }
256     } else {
257       layers_[layer_id]->set_parent_net(this);
258       layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]);
259     }
260     LOG_IF(INFO, Caffe::root_solver())
261         << "Setting up " << layer_names_[layer_id];
262     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
263       if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) {
264         blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, 0.F);
265       }
266       blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id);
267       LOG_IF(INFO, Caffe::root_solver())
268           << Phase_Name(phase_) << " Top shape for layer " << layer_id << " '"
269           << layer_names_[layer_id] << "' " <<  top_vecs_[layer_id][top_id]->shape_string();
270       if (layer->loss(top_id) != 0.F) {
271         LOG_IF(INFO, Caffe::root_solver())
272           << "    with loss weight " << layer->loss(top_id);
273       }
274       gpu_top_memory_data_use_ += top_vecs_[layer_id][top_id]->gpu_memory_data_use();
275       gpu_top_memory_diff_use_ += top_vecs_[layer_id][top_id]->gpu_memory_diff_use();
276     }
277     const int param_size = layer_param.param_size();
278     const int num_param_blobs = layers_[layer_id]->blobs().size();
279     CHECK_LE(param_size, num_param_blobs)
280         << "Too many params specified for layer " << layer_param.name();
281     ParamSpec default_param_spec;
282     for (int param_id = 0; param_id < num_param_blobs; ++param_id) {
283       const ParamSpec* param_spec = (param_id < param_size) ?
284           &layer_param.param(param_id) : &default_param_spec;
285       const bool param_need_backward = param_spec->lr_mult() != 0;
286       need_backward |= param_need_backward;
287       layers_[layer_id]->set_param_propagate_down(param_id,
288                                                   param_need_backward);
289     }
290     for (int param_id = 0; param_id < num_param_blobs; ++param_id) {
291       AppendParam(param, layer_id, param_id);
292     }
293     // Finally, set the backward flag
294     layer_need_backward_.push_back(need_backward);
295     if (need_backward) {
296       for (int top_id = 0; top_id < top_id_vecs_[layer_id].size(); ++top_id) {
297         blob_need_backward_[top_id_vecs_[layer_id][top_id]] = true;
298       }
299     }
300   }
301   // Go through the net backwards to determine which blobs contribute to the
302   // loss.  We can skip backward computation for blobs that don't contribute
303   // to the loss.
304   // Also checks if all bottom blobs don't need backward computation (possible
305   // because the skip_propagate_down param) and so we can skip backward
306   // computation for the entire layer
307   set<string> blobs_under_loss;
308   set<string> blobs_skip_backp;
309   for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
310     bool layer_contributes_loss = false;
311     bool layer_skip_propagate_down = true;
312     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
313       const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
314       if (layers_[layer_id]->loss(top_id) != 0.F ||
315           (blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
316         layer_contributes_loss = true;
317       }
318       if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
319         layer_skip_propagate_down = false;
320       }
321       if (layer_contributes_loss && !layer_skip_propagate_down)
322         break;
323     }
324     // If this layer can skip backward computation, also all his bottom blobs
325     // don't need backpropagation
326     if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
327       layer_need_backward_[layer_id] = false;
328       for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
329                ++bottom_id) {
330         bottom_need_backward_[layer_id][bottom_id] = false;
331       }
332     }
333     if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
334     if (Caffe::root_solver()) {
335       if (layer_need_backward_[layer_id]) {
336         LOG(INFO) << layer_names_[layer_id] << " needs backward computation.";
337       } else {
338         LOG(INFO) << layer_names_[layer_id]
339             << " does not need backward computation.";
340       }
341     }
342     for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
343          ++bottom_id) {
344       if (layer_contributes_loss) {
345         const string& blob_name =
346             blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
347         blobs_under_loss.insert(blob_name);
348       } else {
349         bottom_need_backward_[layer_id][bottom_id] = false;
350       }
351       if (!bottom_need_backward_[layer_id][bottom_id]) {
352         const string& blob_name =
353                    blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
354         blobs_skip_backp.insert(blob_name);
355       }
356     }
357   }
358   // Handle force_backward if needed.
359   if (param.force_backward()) {
360     for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
361       layer_need_backward_[layer_id] = true;
362       for (int bottom_id = 0;
363            bottom_id < bottom_need_backward_[layer_id].size(); ++bottom_id) {
364         bottom_need_backward_[layer_id][bottom_id] =
365             bottom_need_backward_[layer_id][bottom_id] ||
366             layers_[layer_id]->AllowForceBackward(bottom_id);
367         blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] =
368             blob_need_backward_[bottom_id_vecs_[layer_id][bottom_id]] ||
369             bottom_need_backward_[layer_id][bottom_id];
370       }
371       for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
372            ++param_id) {
373         layers_[layer_id]->set_param_propagate_down(param_id, true);
374       }
375     }
376   }
377   // In the end, all remaining blobs are considered output blobs.
378   for (set<string>::iterator it = available_blobs.begin();
379       it != available_blobs.end(); ++it) {
380     LOG_IF(INFO, Caffe::root_solver())
381         << "This network produces output " << *it;
382     net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get());
383     net_output_blob_indices_.push_back(blob_name_to_idx[*it]);
384   }
385   for (int blob_id = 0; blob_id < blob_names_.size(); ++blob_id) {
386     blob_names_index_[blob_names_[blob_id]] = blob_id;
387   }
388   for (int layer_id = 0; layer_id < layer_names_.size(); ++layer_id) {
389     layer_names_index_[layer_names_[layer_id]] = layer_id;
390   }
391   ShareWeights();
393   // invert param_layer_indices_ to give map of
394   // (level_id, local param_id) -> global param_id
395   for (int i = 0; i < param_layer_indices_.size(); ++i) {
396     layer_index_params_[param_layer_indices_[i]] = i;
397   }
399   learnable_space_size_[0] = 0UL;
400   learnable_space_size_[1] = 0UL;
401   reduce_buckets_ = (size_t) in_param.reduce_buckets();
402   if (Caffe::device_count() > 0) {
403     LOG_IF(INFO, Caffe::root_solver())
404         << "Top memory (" << Phase_Name(phase_) << ") required for data: "
405         << gpu_top_memory_data_use_ << " diff: " << gpu_top_memory_diff_use_;
406     LOG_IF(INFO, Caffe::root_solver())
407         << "Bottom memory (" << Phase_Name(phase_) << ") required for data: "
408         << gpu_btm_memory_data_use_ << " diff: " << gpu_btm_memory_diff_use_;
409     LOG_IF(INFO, Caffe::root_solver())
410         << "Shared (in-place) memory (" << Phase_Name(phase_) << ") by data: "
411         << gpu_shr_memory_data_use_ << " diff: " << gpu_shr_memory_diff_use_;
412     LOG_IF(INFO, Caffe::root_solver())
413         << "Parameters memory (" << Phase_Name(phase_) << ") required for data: "
414         << gpu_prm_memory_data_use_ << " diff: " << gpu_prm_memory_diff_use_;
415     LOG_IF(INFO, Caffe::root_solver())
416         << "Parameters shared memory (" << Phase_Name(phase_) << ") by data: "
417         << gpu_shp_memory_data_use_ << " diff: " << gpu_shp_memory_diff_use_;
418   }
419   debug_info_ = param.debug_info();
420   trained_layers_shared_ = false;
421   LOG_IF(INFO, Caffe::root_solver()) << "Network initialization done.";
424 void Net::FilterNet(const NetParameter& param, NetParameter* param_filtered) {
425   NetState net_state(param.state());
426   param_filtered->CopyFrom(param);
427   param_filtered->clear_layer();
428   for (int i = 0; i < param.layer_size(); ++i) {
429     const LayerParameter& layer_param = param.layer(i);
430     const string& layer_name = layer_param.name();
431     CHECK(layer_param.include_size() == 0 || layer_param.exclude_size() == 0)
432           << "Specify either include rules or exclude rules; not both.";
433     // If no include rules are specified, the layer is included by default and
434     // only excluded if it meets one of the exclude rules.
435     bool layer_included = (layer_param.include_size() == 0);
436     for (int j = 0; layer_included && j < layer_param.exclude_size(); ++j) {
437       if (StateMeetsRule(net_state, layer_param.exclude(j), layer_name)) {
438         layer_included = false;
439       }
440     }
441     for (int j = 0; !layer_included && j < layer_param.include_size(); ++j) {
442       if (StateMeetsRule(net_state, layer_param.include(j), layer_name)) {
443         layer_included = true;
444       }
445     }
446     if (layer_included) {
447       param_filtered->add_layer()->CopyFrom(layer_param);
448     }
449   }
452 bool Net::StateMeetsRule(const NetState& state,
453     const NetStateRule& rule, const string& layer_name) {
454   // Check whether the rule is broken due to phase.
455   if (rule.has_phase()) {
456       if (rule.phase() != state.phase()) {
457         LOG_IF(INFO, Caffe::root_solver())
458             << "The NetState phase (" << state.phase()
459             << ") differed from the phase (" << rule.phase()
460             << ") specified by a rule in layer " << layer_name;
461         return false;
462       }
463   }
464   // Check whether the rule is broken due to min level.
465   if (rule.has_min_level()) {
466     if (state.level() < rule.min_level()) {
467       LOG_IF(INFO, Caffe::root_solver())
468           << "The NetState level (" << state.level()
469           << ") is above the min_level (" << rule.min_level()
470           << ") specified by a rule in layer " << layer_name;
471       return false;
472     }
473   }
474   // Check whether the rule is broken due to max level.
475   if (rule.has_max_level()) {
476     if (state.level() > rule.max_level()) {
477       LOG_IF(INFO, Caffe::root_solver())
478           << "The NetState level (" << state.level()
479           << ") is above the max_level (" << rule.max_level()
480           << ") specified by a rule in layer " << layer_name;
481       return false;
482     }
483   }
484   // Check whether the rule is broken due to stage. The NetState must
485   // contain ALL of the rule's stages to meet it.
486   for (int i = 0; i < rule.stage_size(); ++i) {
487     // Check that the NetState contains the rule's ith stage.
488     bool has_stage = false;
489     for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
490       if (rule.stage(i) == state.stage(j)) { has_stage = true; }
491     }
492     if (!has_stage) {
493       LOG_IF(INFO, Caffe::root_solver())
494           << "The NetState did not contain stage '" << rule.stage(i)
495           << "' specified by a rule in layer " << layer_name;
496       return false;
497     }
498   }
499   // Check whether the rule is broken due to not_stage. The NetState must
500   // contain NONE of the rule's not_stages to meet it.
501   for (int i = 0; i < rule.not_stage_size(); ++i) {
502     // Check that the NetState contains the rule's ith not_stage.
503     bool has_stage = false;
504     for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
505       if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
506     }
507     if (has_stage) {
508       LOG_IF(INFO, Caffe::root_solver())
509           << "The NetState contained a not_stage '" << rule.not_stage(i)
510           << "' specified by a rule in layer " << layer_name;
511       return false;
512     }
513   }
514   return true;
517 // Helper for Net::Init: add a new top blob to the net.
518 void Net::AppendTop(const NetParameter& param, const int layer_id, const int top_id,
519     set<string>* available_blobs, map<string, int>* blob_name_to_idx) {
520   const LayerParameter& layer_param = param.layer(layer_id);
521   const string& blob_name = (layer_param.top_size() > top_id) ?
522       layer_param.top(top_id) : "(automatic)";
523   // Check if we are doing in-place computation
524   if (blob_name_to_idx && layer_param.bottom_size() > top_id &&
525       blob_name == layer_param.bottom(top_id)) {
526     // In-place computation
527     LOG_IF(INFO, Caffe::root_solver())
528         << layer_param.name() << " -> " << blob_name << " (in-place)";
529     top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get());
530     top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]);
531     gpu_shr_memory_data_use_ += top_vecs_[layer_id].back()->gpu_memory_data_use();
532     gpu_shr_memory_diff_use_ += top_vecs_[layer_id].back()->gpu_memory_diff_use();
533   } else if (blob_name_to_idx &&
534              blob_name_to_idx->find(blob_name) != blob_name_to_idx->end()) {
535     // If we are not doing in-place computation but have duplicated blobs,
536     // raise an error.
537     LOG(FATAL) << "Top blob '" << blob_name
538                << "' produced by multiple sources.";
539   } else {
540     // Normal output.
541     if (Caffe::root_solver()) {
542       LOG(INFO) << layer_param.name() << " -> " << blob_name;
543     }
545     Type ftype = layer_param.has_forward_type() ? layer_param.forward_type() :
546         param.default_forward_type();
547     Type btype = layer_param.has_backward_type() ? layer_param.backward_type() :
548         param.default_backward_type();
549     shared_ptr<Blob> blob_pointer = Blob::create(ftype, btype);
550     const int blob_id = blobs_.size();
551     blobs_.push_back(blob_pointer);
552     blob_names_.push_back(blob_name);
553     blob_need_backward_.push_back(false);
554     if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
555     top_id_vecs_[layer_id].push_back(blob_id);
556     top_vecs_[layer_id].push_back(blob_pointer.get());
557   }
558   if (available_blobs) { available_blobs->insert(blob_name); }
561 // Helper for Net::Init: add a new bottom blob to the net.
562 int Net::AppendBottom(const NetParameter& param, const int layer_id,
563     const int bottom_id, set<string>* available_blobs,
564     map<string, int>* blob_name_to_idx) {
565   const LayerParameter& layer_param = param.layer(layer_id);
566   const string& blob_name = layer_param.bottom(bottom_id);
567   if (available_blobs->find(blob_name) == available_blobs->end()) {
568     LOG(FATAL) << "Unknown bottom blob '" << blob_name << "' (layer '"
569                << layer_param.name() << "', bottom index " << bottom_id << ")";
570   }
571   const int blob_id = (*blob_name_to_idx)[blob_name];
572   LOG_IF(INFO, Caffe::root_solver())
573       << layer_names_[layer_id] << " <- " << blob_name;
574   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
575   bottom_id_vecs_[layer_id].push_back(blob_id);
576   available_blobs->erase(blob_name);
577   bool need_backward = blob_need_backward_[blob_id];
578   // Check if the backpropagation on bottom_id should be skipped
579   if (layer_param.propagate_down_size() > 0) {
580     need_backward = layer_param.propagate_down(bottom_id);
581   }
582   bottom_need_backward_[layer_id].push_back(need_backward);
583   gpu_btm_memory_data_use_ += bottom_vecs_[layer_id].back()->gpu_memory_data_use();
584   gpu_btm_memory_diff_use_ += bottom_vecs_[layer_id].back()->gpu_memory_diff_use();
585   return blob_id;
588 void Net::AppendParam(const NetParameter& param, const int layer_id, const int param_id) {
589   const LayerParameter& layer_param = layers_[layer_id]->layer_param();
590   const int param_size = layer_param.param_size();
591   string param_name =
592       (param_size > param_id) ? layer_param.param(param_id).name() : "";
593   if (param_name.size()) {
594     param_display_names_.push_back(param_name);
595   } else {
596     ostringstream param_display_name;
597     param_display_name << param_id;
598     param_display_names_.push_back(param_display_name.str());
599   }
600   const int net_param_id = params_.size();
601   params_.push_back(layers_[layer_id]->blobs()[param_id]);
602   param_id_vecs_[layer_id].push_back(net_param_id);
603   param_layer_indices_.push_back(make_pair(layer_id, param_id));
604   ParamSpec default_param_spec;
605   const ParamSpec* param_spec = (layer_param.param_size() > param_id) ?
606       &layer_param.param(param_id) : &default_param_spec;
607   if (!param_size || !param_name.size() || (param_name.size() &&
608       param_names_index_.find(param_name) == param_names_index_.end())) {
609     // This layer "owns" this parameter blob -- it is either anonymous
610     // (i.e., not given a param_name) or explicitly given a name that we
611     // haven't already seen.
612     param_owners_.push_back(-1);
613     if (param_name.size()) {
614       param_names_index_[param_name] = net_param_id;
615     }
616     const int learnable_param_id = learnable_params_.size();
617     learnable_params_.push_back(params_[net_param_id]);
618     learnable_param_ids_.push_back(learnable_param_id);
619     has_params_lr_.push_back(param_spec->has_lr_mult());
620     has_params_decay_.push_back(param_spec->has_decay_mult());
621     params_lr_.push_back(param_spec->lr_mult());
622     params_weight_decay_.push_back(param_spec->decay_mult());
623   } else {
624     // Named param blob with name we've seen before: share params
625     const int owner_net_param_id = param_names_index_[param_name];
626     param_owners_.push_back(owner_net_param_id);
627     const pair<int, int>& owner_index =
628         param_layer_indices_[owner_net_param_id];
629     const int owner_layer_id = owner_index.first;
630     const int owner_param_id = owner_index.second;
631     LOG_IF(INFO, Caffe::root_solver()) << "Sharing parameters '" << param_name
632         << "' owned by "
633         << "layer '" << layer_names_[owner_layer_id] << "', param "
634         << "index " << owner_param_id;
635     Blob* this_blob = layers_[layer_id]->blobs()[param_id].get();
636     Blob* owner_blob =
637         layers_[owner_layer_id]->blobs()[owner_param_id].get();
638     const int param_size = layer_param.param_size();
639     if (param_size > param_id && (layer_param.param(param_id).share_mode() ==
640                                   ParamSpec_DimCheckMode_PERMISSIVE)) {
641       // Permissive dimension checking -- only check counts are the same.
642       CHECK_EQ(this_blob->count(), owner_blob->count())
643           << "Cannot share param '" << param_name << "' owned by layer '"
644           << layer_names_[owner_layer_id] << "' with layer '"
645           << layer_names_[layer_id] << "'; count mismatch.  Owner layer param "
646           << "shape is " << owner_blob->shape_string() << "; sharing layer "
647           << "shape is " << this_blob->shape_string();
648     } else {
649       // Strict dimension checking -- all dims must be the same.
650       CHECK(this_blob->shape() == owner_blob->shape())
651           << "Cannot share param '" << param_name << "' owned by layer '"
652           << layer_names_[owner_layer_id] << "' with layer '"
653           << layer_names_[layer_id] << "'; shape mismatch.  Owner layer param "
654           << "shape is " << owner_blob->shape_string() << "; sharing layer "
655           << "expects shape " << this_blob->shape_string();
656     }
657     const int learnable_param_id = learnable_param_ids_[owner_net_param_id];
658     learnable_param_ids_.push_back(learnable_param_id);
659     if (param_spec->has_lr_mult()) {
660       if (has_params_lr_[learnable_param_id]) {
661         CHECK_EQ(param_spec->lr_mult(), params_lr_[learnable_param_id])
662             << "Shared param '" << param_name << "' has mismatched lr_mult.";
663       } else {
664         has_params_lr_[learnable_param_id] = true;
665         params_lr_[learnable_param_id] = param_spec->lr_mult();
666       }
667     }
668     if (param_spec->has_decay_mult()) {
669       if (has_params_decay_[learnable_param_id]) {
670         CHECK_EQ(param_spec->decay_mult(),
671                  params_weight_decay_[learnable_param_id])
672             << "Shared param '" << param_name << "' has mismatched decay_mult.";
673       } else {
674         has_params_decay_[learnable_param_id] = true;
675         params_weight_decay_[learnable_param_id] = param_spec->decay_mult();
676       }
677     }
678   }
681 float Net::ForwardFromTo(int start, int end) {
682   CHECK_GE(start, 0);
683   CHECK_LT(end, layers_.size());
685   this->StartQuantization();
687   float loss = 0;
688   for (int i = start; i <= end; ++i) {
689     // LOG(INFO) << " ****** [Forward] (" << i << ") Layer '" << layer_names_[i];
690     // << "' FT " << Type_Name(layers_[i]->forward_type())
691     // << " BT " << Type_Name(layers_[i]->backward_type());
692     float layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
693     loss += layer_loss;
694     if (debug_info_) { ForwardDebugInfo(i); }
695   }
697   this->FinishQuantization();
699   ++infer_count_;
700   return loss;
703 float Net::ForwardFrom(int start) {
704   return ForwardFromTo(start, layers_.size() - 1);
707 float Net::ForwardTo(int end) {
708   return ForwardFromTo(0, end);
711 const vector<Blob*>& Net::Forward(float* loss) {
712   if (loss != NULL) {
713     *loss = ForwardFromTo(0, layers_.size() - 1);
714   } else {
715     ForwardFromTo(0, layers_.size() - 1);
716   }
717   return net_output_blobs_;
720 const vector<Blob*>& Net::Forward(const vector<Blob*>& bottom, float* loss) {
721   LOG_EVERY_N(WARNING, 1000) << "DEPRECATED: Forward(bottom, loss) "
722       << "will be removed in a future version. Use Forward(loss).";
723   // Copy bottom to net bottoms
724   for (int i = 0; i < bottom.size(); ++i) {
725     net_input_blobs_[i]->CopyFrom(*bottom[i]);
726   }
727   return Forward(loss);
730 float Net::ForwardBackward(bool apply_update) {
731   float loss;
732   Forward(&loss);
733   Backward(apply_update);
734   return loss;
737 void Net::BackwardFromTo(int start, int end) {
738   BackwardFromToAu(start, end, true);
741 void Net::BackwardFromToAu(int start, int end, bool apply_update) {
742   CHECK_GE(end, 0);
743   CHECK_LT(start, layers_.size());
744   for (int i = start; i >= end; --i) {
745     if (!layer_need_backward_[i]) {
746       continue;
747     }
749     layers_[i]->Backward(top_vecs_[i], bottom_need_backward_[i], bottom_vecs_[i]);
751     if (debug_info_) {
752       BackwardDebugInfo(i);
753     }
754     if (!apply_update) {
755       continue;
756     }
757     for (int j = 0; j < layers_[i]->blobs().size(); ++j) {
758       if (layers_[i]->skip_apply_update(j)) {
759         continue;
760       }
761       const int param_id = layer_index_params_[make_pair(i, j)];
762       if (param_owners_[param_id] < 0) {
763         const int lparam_id = learnable_param_ids_[param_id];
764         int t = (int)learnable_params_[lparam_id]->diff_type();
765         for (int type_id = 0; type_id < learnable_types().size(); ++type_id) {
766           if (t == learnable_types_[type_id]) {
767             reduction_queue_[type_id].push(lparam_id);
768             break;
769           }
770         }
771       }  // leave it to the owner otherwise
772     }
773   }
774   if (apply_update) {
775     for (int type_id = 0; type_id < learnable_types_.size(); ++type_id) {
776       reduction_queue_[type_id].push(END_OF_ITERATION);
777     }
778   }
781 void Net::Finalize() {
782   for (int type_id = 0; type_id < learnable_types_.size(); ++type_id) {
783     reduction_queue_[type_id].push(END_OF_TRAIN);
784   }
787 size_t Net::received_contiguous_count(int type_id, const std::set<int>& au_ids, int& id_from_ret) {
788   if (learnable_params_.empty() || au_ids.empty() || param_id_vecs_.empty()) {
789     return 0;
790   }
791   size_t cnt_ret = 0UL, cnt = 0UL;
792   const int bottom = *au_ids.begin();
793   const int top = *au_ids.rbegin();
794   id_from_ret = -1;
795   const std::map<size_t, std::set<int>>& ltop = ltop_[type_id];
796   for (auto lit = ltop.rbegin(); lit != ltop.rend(); ++lit) {
797     if (lit->second.empty() || *lit->second.begin() > top) {
798       continue;
799     }
800     bool layer_complete = true;
801     for (auto p = lit->second.begin(); p != lit->second.end(); ++p) {
802       int param_id = *p;
803       if (param_id < bottom || au_ids.find(param_id) == au_ids.end()) {
804         layer_complete = false;
805         break;
806       }
807       cnt += lp_aligned_count(param_id);
808     }
809     if (layer_complete) {
810       id_from_ret = *lit->second.begin();
811       cnt_ret = cnt;
812     } else {
813       break;
814     }
815   }
816   return cnt_ret;
819 void Net::ReduceAndUpdate(int type_id) {
820   DLOG(INFO) << "[" << Caffe::current_device()
821              << "] Entering ReduceAndUpdate thread " << lwp_id()
822              <<  ", type_id " << type_id;
824   size_t bucket_size = 0UL;
825   cublasHandle_t handle = Caffe::cublas_handle(type_id);
826   CHECK_GE(reduce_buckets_, 0);
827   if (Caffe::solver_count() > 1 && reduce_buckets_ > 0) {
828     bucket_size = align_up<6>(learnable_space_size_[type_id] / reduce_buckets_);
829   }
830   std::set<int> au_ids;
832   const bool clip_grads = solver_->param().clip_gradients() >= 0.F;
833   const bool clear_grads = !solver_->param().snapshot_diff() && !clip_grads;
834   const bool use_buckets = reduce_buckets_ > 0;
835   float rate = -1.F;
836   while (!solver_->stop_reducing_requested(type_id)) {
837     const int param_id = reduction_queue_[type_id].pop();
838     SolverAction::Enum request = solver_->GetRequestedAction();
839     if (SolverAction::STOP == request) {
840       solver_->request_early_exit();
841       break;
842     }
843     if (param_id == END_OF_TRAIN) {
844       break;
845     }
846     if (rate < 0.F) {
847       rate = solver_->GetLearningRate();
848     }
849     if (param_id != END_OF_ITERATION) {
850       if (Caffe::solver_count() > 1) {
851         if (!use_buckets && !clip_grads) {
852           Reduce(type_id, param_id);
853           if (solver_->stop_reducing_requested(type_id)) {
854             break;
855           }
856           add_wgrad_sq(solver_->ApplyUpdate(param_id, handle, rate, true, clear_grads));
857           continue;
858         }
859       } else {
860         if (!clip_grads) {
861           this->learnable_params()[param_id]->scale_diff(1.F / global_grad_scale(), handle);
862           add_wgrad_sq(solver_->ApplyUpdate(param_id, handle, rate, true, clear_grads));
863         }
864         continue;
865       }
866     } else if (clip_grads && Caffe::solver_count() == 1) {
867       solver_->ClipGradientsAndNormalize(handle, type_id, au_ids);
868       for (int i : au_ids) {
869         add_wgrad_sq(solver_->ApplyUpdate(i, handle, rate, false, clear_grads));
870       }
871       au_ids.clear();
872     }
874     if (!learnable_params_.empty() && Caffe::solver_count() > 1) {
875       int id_from = -1;
876       // Is bucket big enough? Done with iteration?
877       const size_t received_count = received_contiguous_count(type_id, au_ids, id_from);
878       if (id_from >= 0) {
879         const size_t received_size = received_count * lp_size(id_from);
880         if ((received_size >= bucket_size && !clip_grads) || param_id == END_OF_ITERATION) {
881 //#ifdef DEBUG
882 //          {
883 //            size_t c = 0UL;
884 //            for (int i : au_ids) {
885 //              if (i < id_from) {
886 //                continue;
887 //              }
888 //              c += lp_aligned_count(i);
889 //            }
890 //            CHECK_EQ(c, received_count);
891 //          }
892 //#endif
893           CHECK_EQ((int) learnable_params_[id_from]->diff_type(), learnable_types_[type_id]);
894           ReduceBucket(type_id, received_count, learnable_params_[id_from]->diff_type(),
895               learnable_params_ptrs_[type_id][id_from]);
896           if (solver_->stop_reducing_requested(type_id)) {
897             break;
898           }
900           if (clip_grads) {
901             solver_->ClipGradientsAndNormalize(handle, type_id, au_ids);
902           }
904           for (int i : au_ids) {
905             add_wgrad_sq(solver_->ApplyUpdate(i, handle, rate, !clip_grads, clear_grads));
906           }
907           au_ids.erase(au_ids.find(id_from), au_ids.end());
908         }
909       }
910     }
911     if (param_id == END_OF_ITERATION) {
912       CHECK(au_ids.empty());
913       rate = -1.F;
914       solver_->iteration_complete_signal(type_id);
915     } else {
916       au_ids.emplace(param_id);
917     }
918   }
919   DLOG(INFO) << "[" << Caffe::current_device()
920              << "] Leaving ReduceAndUpdate thread " << lwp_id();
923 void Net::add_wgrad_sq(float wgrad_sq) {
924   if (wgrad_sq > 0.F) {
925     wgrad_sq_.fetch_add(std::llround(wgrad_sq * GRAD_FACTOR));
926   }
929 float Net::wgrad_sq() {
930   return wgrad_sq_.exchange(0LL) / GRAD_FACTOR;
933 void Net::update_grad_scale() {
934   global_grad_scale_coeff_ = 1.F;
935   if (global_grad_scale_enabled()) {
936     if (global_grad_scale_adaptive_) {
937       const float wgsq = wgrad_sq();
938       if (wgsq > 0.F) {
939         global_grad_scale_coeff_ = std::sqrt(wgsq) * global_grad_scale_param_;
940         return;
941       }
942     }
943     global_grad_scale_coeff_ = global_grad_scale_param_;
944   }
947 void Net::Reduce(int type_id, int param_id) {
948   Solver::Callback* cb = solver_->callback();
949   cb->reduce_barrier(type_id);
950   {
951     unique_ptr<unique_lock<shared_mutex>> lock;
952     if (solver_->is_root()) {
953       lock.reset(new unique_lock<shared_mutex>(GPUMemory::read_write_mutex()));
954     }
955     cb->reduce_barrier(type_id);
956     cb->allreduce(type_id, param_id);
957     cb->reduce_barrier(type_id);
958   }
959   this->learnable_params()[param_id]->
960       scale_diff(1.F / (Caffe::solver_count() * global_grad_scale()),
961       Caffe::cublas_handle(type_id));
962   // Also need to barrier to make sure lock isn't undone
963   // until all have completed, but the current nature of
964   // NCCL makes this unnecessary.
965   // solver_->callback()->reduce_barrier();
968 void Net::ReduceBucket(int type_id, size_t count, Type bucket_type, void* bucket) {
969   Solver::Callback* cb = solver_->callback();
970   cb->reduce_barrier(type_id);
971   {
972     unique_ptr<unique_lock<shared_mutex>> lock;
973     if (solver_->is_root()) {
974       lock.reset(new unique_lock<shared_mutex>(GPUMemory::read_write_mutex()));
975     }
976     cb->reduce_barrier(type_id);
977     cb->allreduce_bucket(type_id, count, bucket, bucket_type);
978     cb->reduce_barrier(type_id);
979   }
980   Tensor::gpu_scal(count, bucket_type, bucket, 1.F / (Caffe::solver_count() * global_grad_scale()),
981       Caffe::cublas_handle(type_id));
984 void Net::ForwardDebugInfo(const int layer_id) {
985   LOG_IF(INFO, Caffe::root_solver())
986       << "[Forward] Layer " << layer_names_[layer_id];
987   for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
988     const Blob& blob = *top_vecs_[layer_id][top_id];
989     const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
990     const double data_abs_val_mean = blob.asum_data() / blob.count();
991     LOG_IF(INFO, Caffe::root_solver())
992         << " -> top blob " << blob_name
993         << ", count: " << blob.count()
994         << " data: " << data_abs_val_mean;
995   }
996   for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
997        ++param_id) {
998     const Blob& blob = *layers_[layer_id]->blobs()[param_id];
999     const int net_param_id = param_id_vecs_[layer_id][param_id];
1000     const string& blob_name = param_display_names_[net_param_id];
1001     const double data_abs_val_mean = blob.asum_data() / blob.count();
1002     LOG_IF(INFO, Caffe::root_solver())
1003         << " -> param blob " << blob_name
1004         << ", count: " << blob.count()
1005         << " data: " << data_abs_val_mean;
1006   }
1009 void Net::BackwardDebugInfo(const int layer_id) {
1010   LOG_IF(INFO, Caffe::root_solver())
1011       << "[Backward] Layer " << layer_names_[layer_id];
1012   const vector<Blob*>& bottom_vec = bottom_vecs_[layer_id];
1013   for (int bottom_id = 0; bottom_id < bottom_vec.size(); ++bottom_id) {
1014     if (!bottom_need_backward_[layer_id][bottom_id]) { continue; }
1015     const Blob& blob = *bottom_vec[bottom_id];
1016     const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
1017     const double diff_abs_val_mean = blob.asum_diff() / blob.count();
1018     LOG_IF(INFO, Caffe::root_solver())
1019         << " -> bottom blob " << blob_name
1020         << ", count: " << blob.count()
1021         << ", diff: " << diff_abs_val_mean;
1022   }
1023   for (int param_id = 0; param_id < layers_[layer_id]->blobs().size();
1024        ++param_id) {
1025     if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; }
1026     const Blob& blob = *layers_[layer_id]->blobs()[param_id];
1027     double diff_abs_val_mean = blob.asum_diff() / blob.count();
1028     LOG_IF(INFO, Caffe::root_solver())
1029         << " -> param blob " << param_id
1030         << ", count: " << blob.count()
1031         << ", diff: " << diff_abs_val_mean;
1032   }
1035 void Net::UpdateDebugInfo(const int param_id) {
1036   const Blob& blob = *params_[param_id];
1037   const int param_owner = param_owners_[param_id];
1038   const string& layer_name = layer_names_[param_layer_indices_[param_id].first];
1039   const string& param_display_name = param_display_names_[param_id];
1040   const double diff_abs_val_mean = blob.asum_diff() / blob.count();
1041   if (param_owner < 0) {
1042     double data_abs_val_mean = blob.asum_data() / blob.count();
1043     LOG_IF(INFO, Caffe::root_solver())
1044         << "    [Update] Layer " << layer_name
1045         << ", param " << param_display_name
1046         << " data: " << data_abs_val_mean
1047         << "; diff: " << diff_abs_val_mean;
1048   } else {
1049     const string& owner_layer_name =
1050         layer_names_[param_layer_indices_[param_owner].first];
1051     LOG_IF(INFO, Caffe::root_solver())
1052         << "    [Update] Layer " << layer_name
1053         << ", param blob " << param_display_name
1054         << " (owned by layer " << owner_layer_name << ", " << "param "
1055         << param_display_names_[param_owners_[param_id]] << ")"
1056         << " diff: " << diff_abs_val_mean;
1057   }
1060 void Net::ShareTrainedLayersWith(const Net* other) {
1061   int num_source_layers = other->layers().size();
1062   for (int i = 0; i < num_source_layers; ++i) {
1063     LayerBase* source_layer = other->layers()[i].get();
1064     const string& source_layer_name = other->layer_names()[i];
1065     int target_layer_id = 0;
1066     while (target_layer_id != layer_names_.size() &&
1067         layer_names_[target_layer_id] != source_layer_name) {
1068       ++target_layer_id;
1069     }
1070     if (target_layer_id == layer_names_.size()) {
1071       LOG(INFO) << "Ignoring source layer " << source_layer_name;
1072       continue;
1073     }
1074     DLOG(INFO) << "Copying source layer " << source_layer_name;
1075     vector<shared_ptr<Blob> >& target_blobs =
1076         layers_[target_layer_id]->blobs();
1077     CHECK_EQ(target_blobs.size(), source_layer->blobs().size())
1078         << "Incompatible number of blobs for layer " << source_layer_name;
1079     for (int j = 0; j < target_blobs.size(); ++j) {
1080       Blob* source_blob = source_layer->blobs()[j].get();
1081       CHECK(target_blobs[j]->shape() == source_blob->shape())
1082           << "Cannot share param " << j << " weights from layer '"
1083           << source_layer_name << "'; shape mismatch.  Source param shape is "
1084           << source_blob->shape_string() << "; target param shape is "
1085           << target_blobs[j]->shape_string();
1086       target_blobs[j]->ShareData(*source_blob);
1087     }
1088   }
1089   trained_layers_shared_ = true;
1092 void Net::BackwardFrom(int start) {
1093   BackwardFromTo(start, 0);
1096 void Net::BackwardTo(int end) {
1097   BackwardFromTo(layers_.size() - 1, end);
1100 void Net::Backward(bool apply_update) {
1101   BackwardFromToAu(layers_.size() - 1, 0, apply_update);
1102   if (debug_info_) {
1103     float asum_data = 0.F, asum_diff = 0.F, sumsq_data = 0.F, sumsq_diff = 0.F;
1104     for (int i = 0; i < learnable_params_.size(); ++i) {
1105       asum_data += learnable_params_[i]->asum_data();
1106       asum_diff += learnable_params_[i]->asum_diff();
1107       sumsq_data += learnable_params_[i]->sumsq_data();
1108       sumsq_diff += learnable_params_[i]->sumsq_diff();
1109     }
1110     const double l2norm_data = std::sqrt(sumsq_data);
1111     const double l2norm_diff = std::sqrt(sumsq_diff);
1112     LOG(ERROR) << "    [Backward] All net params (data, diff): "
1113                << "L1 norm = (" << asum_data << ", " << asum_diff << "); "
1114                << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")";
1115   }
1118 void Net::Reshape() {
1119   for (int i = 0; i < layers_.size(); ++i) {
1120     layers_[i]->Reshape(bottom_vecs_[i], top_vecs_[i]);
1121   }
1124 void Net::CopyTrainedLayersFrom(const NetParameter& param) {
1125   int num_source_layers = param.layer_size();
1126   for (int i = 0; i < num_source_layers; ++i) {
1127     const LayerParameter& source_layer = param.layer(i);
1128     const string& source_layer_name = source_layer.name();
1129     const string& source_layer_type = source_layer.type();
1130     const bool ignore_shape_mismatch = ((solver_==NULL) || solver_->param().ignore_shape_mismatch());
1131     int target_layer_id = 0;
1132     while (target_layer_id != layer_names_.size() &&
1133         layer_names_[target_layer_id] != source_layer_name) {
1134       ++target_layer_id;
1135     }
1136     if (target_layer_id == layer_names_.size()) {
1137       LOG(INFO) << "Ignoring source layer " << source_layer_name;
1138       continue;
1139     }
1140     DLOG(INFO) << "Copying source layer " << source_layer_name;
1141     vector<shared_ptr<Blob> >& target_blobs =
1142         layers_[target_layer_id]->blobs();
1143     if (target_blobs.size() != source_layer.blobs_size()) {
1144       if(source_layer_type == "BatchNorm" && ignore_shape_mismatch) {
1145         LOG(WARNING) << "Incompatible number of blobs for layer " << source_layer_name 
1146             << " target(" << target_blobs.size() << ") vs source(" << source_layer.blobs_size() << ")";    
1147       } else {    
1148         CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
1149             << "Incompatible number of blobs for layer " << source_layer_name 
1150             << " target(" << target_blobs.size() << ") vs source(" << source_layer.blobs_size() << ")";    
1151       }
1152     }
1153     LOG(INFO) << "Copying source layer " << source_layer_name << " Type:"
1154               << source_layer_type << " #blobs=" << source_layer.blobs_size();
1155     int num_blobs_to_copy = std::min<int>(target_blobs.size(), source_layer.blobs_size());              
1156     // check if BN is in legacy DIGITS format?
1157     if (source_layer_type == "BatchNorm") {
1158       for (int j = 0; j < num_blobs_to_copy; ++j) {
1159         const bool kReshape = true;
1160         target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
1161         DLOG(INFO) << target_blobs[j]->count();
1162       }
1163       if (source_layer.blobs_size() == 5 && target_blobs[4]->count() == 1) {
1164         // old format: 0 - scale , 1 - bias,  2 - mean , 3 - var, 4 - reserved
1165         // new format: 0 - mean  , 1 - var,  2 - reserved , 3- scale, 4 - bias
1166         LOG(INFO) << "BN legacy DIGITS format detected ... ";
1167         std::swap(target_blobs[0], target_blobs[2]);
1168         std::swap(target_blobs[1], target_blobs[3]);
1169         // ==> 0 - mean , 1 -var,  2 - scale , 3 - bias; 4 - reserved
1170         std::swap(target_blobs[2], target_blobs[4]);
1171         std::swap(target_blobs[3], target_blobs[4]);
1172         LOG(INFO) << "BN Transforming to new format completed.";
1173       }
1174       if (source_layer.blobs_size() == 3) {
1175         const float scale_factor = target_blobs[2]->cpu_data<float>()[0] == 0.F ?
1176                                    0.F : 1.F / target_blobs[2]->cpu_data<float>()[0];
1177         caffe_cpu_scale(target_blobs[0]->count(), scale_factor,
1178                         target_blobs[0]->cpu_data<float>(),
1179                         target_blobs[0]->mutable_cpu_data<float>());
1180         caffe_cpu_scale(target_blobs[1]->count(), scale_factor,
1181                         target_blobs[1]->cpu_data<float>(),
1182                         target_blobs[1]->mutable_cpu_data<float>());
1183         target_blobs[2]->mutable_cpu_data<float>()[0] = 1.F;
1184       }
1185       for (int j = 0; j < target_blobs.size(); ++j) {
1186         DLOG(INFO) << target_blobs[j]->count();
1187       }
1188     } else {
1189       for (int j = 0; j < num_blobs_to_copy; ++j) {      
1190         if (!target_blobs[j]->ShapeEquals(source_layer.blobs(j))) {
1191           shared_ptr<Blob> source_blob = Blob::create(target_blobs[j]->data_type(),
1192               target_blobs[j]->diff_type());
1193           const bool kReshape = true;
1194           LOG(WARNING) << "Copying from " << source_layer_name << " to " <<
1195             layers_[target_layer_id]->layer_param().name() <<
1196             " target blob " << j;
1197           source_blob->FromProto(source_layer.blobs(j), kReshape);
1199           //Shape doesn't match. Check if atleast size matches.
1200           if(target_blobs[j]->count() == source_blob->count() && ignore_shape_mismatch) {
1201             LOG(WARNING) << "During copy param " << j << " weights from layer '"
1202                 << source_layer_name << "'; Ignoring shape mismatch and copying forcefully.  Source param shape is "
1203                 << source_blob->shape_string() << "; target param shape is "
1204                 << target_blobs[j]->shape_string() << ". ";
1205                           
1206             const bool kReshape = false;
1207             target_blobs[j]->FromProto(source_layer.blobs(j), kReshape, ignore_shape_mismatch);
1208           } else {
1209             if(ignore_shape_mismatch) {
1210               LOG(WARNING) << "Cannot copy param " << j << " weights from layer '"
1211                 << source_layer_name << "'; shape mismatch.  Source param shape is "
1212                 << source_blob->shape_string() << "; target param shape is "
1213                 << target_blobs[j]->shape_string() << ". "
1214                 << "To learn this layer's parameters from scratch rather than "
1215                 << "copying from a saved net, rename the layer.";
1216               } else {
1217                 LOG(FATAL) << "Cannot copy param " << j << " weights from layer '"
1218                   << source_layer_name << "'; shape mismatch.  Source param shape is "
1219                   << source_blob->shape_string() << "; target param shape is "
1220                   << target_blobs[j]->shape_string() << ". "
1221                   << "To learn this layer's parameters from scratch rather than "
1222                   << "copying from a saved net, rename the layer.";
1223              }
1224           }
1225         } else {
1226           //Go ahead and copy: exactly matching blobs
1227           const bool kReshape = false;
1228           target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
1229         }
1230       }
1231     }
1232   }
1233   CopyQuantizationRangeInLayers();    
1236 void Net::CopyTrainedLayersFrom(const string trained_filename) {
1237   if (trained_filename.size() >= 3 &&
1238       trained_filename.compare(trained_filename.size() - 3, 3, ".h5") == 0) {
1239     CopyTrainedLayersFromHDF5(trained_filename);
1240   } else {
1241     CopyTrainedLayersFromBinaryProto(trained_filename);
1242   }
1245 void Net::CopyTrainedLayersFromBinaryProto(
1246     const string trained_filename) {
1247   NetParameter param;
1248   ReadNetParamsFromBinaryFileOrDie(trained_filename, &param);
1249   CopyTrainedLayersFrom(param);
1252 void Net::CopyTrainedLayersFromHDF5(const string trained_filename) {
1253   hid_t file_hid = H5Fopen(trained_filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
1254   CHECK_GE(file_hid, 0) << "Couldn't open " << trained_filename;
1255   hid_t data_hid = H5Gopen2(file_hid, "data", H5P_DEFAULT);
1256   CHECK_GE(data_hid, 0) << "Error reading weights from " << trained_filename;
1257   int num_layers = hdf5_get_num_links(data_hid);
1258   for (int i = 0; i < num_layers; ++i) {
1259     string source_layer_name = hdf5_get_name_by_idx(data_hid, i);
1260     if (!layer_names_index_.count(source_layer_name)) {
1261       LOG(INFO) << "Ignoring source layer " << source_layer_name;
1262       continue;
1263     }
1264     int target_layer_id = layer_names_index_[source_layer_name];
1265     DLOG(INFO) << "Copying source layer " << source_layer_name;
1266     vector<shared_ptr<Blob> >& target_blobs =
1267         layers_[target_layer_id]->blobs();
1268     hid_t layer_hid = H5Gopen2(data_hid, source_layer_name.c_str(),
1269         H5P_DEFAULT);
1270     CHECK_GE(layer_hid, 0)
1271         << "Error reading weights from " << trained_filename;
1272     // Check that source layer doesn't have more params than target layer
1273     int num_source_params = hdf5_get_num_links(layer_hid);
1274     CHECK_LE(num_source_params, target_blobs.size())
1275         << "Incompatible number of blobs for layer " << source_layer_name;
1276     for (int j = 0; j < target_blobs.size(); ++j) {
1277       ostringstream oss;
1278       oss << j;
1279       string dataset_name = oss.str();
1280       int target_net_param_id = param_id_vecs_[target_layer_id][j];
1281       if (!H5Lexists(layer_hid, dataset_name.c_str(), H5P_DEFAULT)) {
1282         // Target param doesn't exist in source weights...
1283         if (param_owners_[target_net_param_id] != -1) {
1284           // ...but it's weight-shared in target, so that's fine.
1285           continue;
1286         } else {
1287           LOG(FATAL) << "Incompatible number of blobs for layer "
1288               << source_layer_name;
1289         }
1290       }
1291       hdf5_load_nd_dataset(layer_hid, dataset_name.c_str(), 0, kMaxBlobAxes,
1292           target_blobs[j].get());
1293     }
1294     H5Gclose(layer_hid);
1295   }
1296   H5Gclose(data_hid);
1297   H5Fclose(file_hid);
1300 void Net::ToProto(NetParameter* param, bool write_diff, bool write_data) const {
1301   param->Clear();
1303   // Add bottom and top
1304   if(net_param_.has_quantize()) {
1305       param->set_name(name_ + "___QUANTIZED__SEE_END_OFTHE_FILE");
1306       param->set_quantize(net_param_.quantize());
1307   } else {
1308       param->set_name(name_);
1309   }
1311   if(net_param_.has_net_quantization_param()) {
1312       *param->mutable_net_quantization_param() = net_param_.net_quantization_param();
1313   }
1314   DLOG(INFO) << "Serializing " << layers_.size() << " layers";
1315   for (int i = 0; i < layers_.size(); ++i) {
1316     LayerParameter* layer_param = param->add_layer();
1317     layers_[i]->ToProto(layer_param, write_diff, write_data);
1318   }
1321 template<typename Dtype>
1322 void Net::Convert2FixedPoint_cpu(Dtype* data, const int cnt, const int bit_width, int fl, bool is_unsigned, bool clip) const {
1323   for (int index = 0; index < cnt; ++index) {
1324     data[index] = data[index] * powf(2, fl);
1325     // Saturate data
1326 #if CLIP_QUANT
1327       if(clip) {
1328           int qrange = is_unsigned? bit_width :  (bit_width - 1);
1329           Dtype max_data = +(powf(2, qrange) - 1);
1330           Dtype min_data = is_unsigned? 0 : -(powf(2, qrange));
1331           data[index] = std::max(std::min(data[index], max_data), min_data);
1332       }
1333 #endif
1334     data[index] = round(data[index]);
1335     //data[index] = data[index] * pow(2, -fl);
1336   }
1339 void Net::ToHDF5(const string& filename, bool write_diff) const {
1340   hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
1341       H5P_DEFAULT);
1342   CHECK_GE(file_hid, 0)
1343       << "Couldn't open " << filename << " to save weights.";
1344   hid_t data_hid = H5Gcreate2(file_hid, "data", H5P_DEFAULT, H5P_DEFAULT,
1345       H5P_DEFAULT);
1346   CHECK_GE(data_hid, 0) << "Error saving weights to " << filename << ".";
1347   hid_t diff_hid = -1;
1348   if (write_diff) {
1349     diff_hid = H5Gcreate2(file_hid, "diff", H5P_DEFAULT, H5P_DEFAULT,
1350         H5P_DEFAULT);
1351     CHECK_GE(diff_hid, 0) << "Error saving weights to " << filename << ".";
1352   }
1353   for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
1354     const LayerParameter& layer_param = layers_[layer_id]->layer_param();
1355     string layer_name = layer_param.name();
1356     hid_t layer_data_hid = H5Gcreate2(data_hid, layer_name.c_str(),
1357         H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
1358     CHECK_GE(layer_data_hid, 0)
1359         << "Error saving weights to " << filename << ".";
1360     hid_t layer_diff_hid = -1;
1361     if (write_diff) {
1362       layer_diff_hid = H5Gcreate2(diff_hid, layer_name.c_str(),
1363           H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT);
1364       CHECK_GE(layer_diff_hid, 0)
1365           << "Error saving weights to " << filename << ".";
1366     }
1367     int num_params = layers_[layer_id]->blobs().size();
1368     for (int param_id = 0; param_id < num_params; ++param_id) {
1369       ostringstream dataset_name;
1370       dataset_name << param_id;
1371       const int net_param_id = param_id_vecs_[layer_id][param_id];
1372       if (param_owners_[net_param_id] == -1) {
1373         // Only save params that own themselves
1374         hdf5_save_nd_dataset(layer_data_hid, dataset_name.str(),
1375             *params_[net_param_id]);
1376       }
1377       if (write_diff) {
1378         // Write diffs regardless of weight-sharing
1379         hdf5_save_nd_dataset(layer_diff_hid, dataset_name.str(),
1380             *params_[net_param_id], true);
1381       }
1382     }
1383     H5Gclose(layer_data_hid);
1384     if (write_diff) {
1385       H5Gclose(layer_diff_hid);
1386     }
1387   }
1388   H5Gclose(data_hid);
1389   if (write_diff) {
1390     H5Gclose(diff_hid);
1391   }
1392   H5Fclose(file_hid);
1395 void Net::Update() {
1396   for (int i = 0; i < learnable_params_.size(); ++i) {
1397     learnable_params_[i]->Update();
1398   }
1401 void Net::ClearParamDiffs() {
1402   if (Caffe::mode() == Caffe::GPU) {
1403     caffe_gpu_memset(learnable_space_[0].size(), 0, learnable_space_[0].data());
1404     caffe_gpu_memset(learnable_space_[1].size(), 0, learnable_space_[1].data());
1405   } else {
1406     for (int i = 0; i < learnable_params_.size(); ++i) {
1407       learnable_params_[i]->set_diff(0.F);
1408     }
1409   }
1412 void Net::ShareWeights() {
1413   for (int i = 0; i < params_.size(); ++i) {
1414     if (param_owners_[i] < 0) {
1415       gpu_prm_memory_data_use_ += params_[i]->gpu_memory_data_use();
1416       gpu_prm_memory_diff_use_ += params_[i]->gpu_memory_diff_use();
1417       continue;
1418     }
1419 //    DLOG(INFO) << "param " << i << " has owner " << param_owners_[i];
1420     params_[i]->ShareData(*params_[param_owners_[i]]);
1421     params_[i]->ShareDiff(*params_[param_owners_[i]]);
1422     gpu_shp_memory_data_use_ += params_[i]->gpu_memory_data_use();
1423     gpu_shp_memory_diff_use_ += params_[i]->gpu_memory_diff_use();
1424   }
1427 bool Net::has_blob(const string& blob_name) const {
1428   return blob_names_index_.find(blob_name) != blob_names_index_.end();
1431 const shared_ptr<Blob> Net::blob_by_name(
1432     const string& blob_name) const {
1433   shared_ptr<Blob> blob_ptr;
1434   if (has_blob(blob_name)) {
1435     blob_ptr = blobs_[blob_names_index_.find(blob_name)->second];
1436   } else {
1437     LOG(WARNING) << "Unknown blob name " << blob_name;
1438   }
1439   return blob_ptr;
1442 bool Net::has_layer(const string& layer_name) const {
1443   return layer_names_index_.find(layer_name) != layer_names_index_.end();
1446 const shared_ptr<LayerBase> Net::layer_by_name(
1447     const string& layer_name) const {
1448   shared_ptr<LayerBase> layer_ptr;
1449   if (has_layer(layer_name)) {
1450     layer_ptr = layers_[layer_names_index_.find(layer_name)->second];
1451   } else {
1452     LOG(WARNING) << "Unknown layer name " << layer_name;
1453   }
1454   return layer_ptr;
1457 void Net::set_solver(Solver* s) {
1458   solver_ = s;
1459   for (auto& layer : layers_) {
1460     layer->set_parent_net(this);
1461   }
1464 void Net::InitializeLearnableDiffSpace(int type_id) {
1465   CHECK_GE(type_id, 0);
1466   CHECK_LT(type_id, 2);
1467   const Type t = (Type) learnable_types_[type_id];
1468   if (learnable_params_ptrs_[type_id].size() == learnable_params_.size()) {
1469     LOG(INFO) << print_current_device() << " Already reserved "
1470               << learnable_space_size_[type_id] << " bytes of shared learnable space for type "
1471               << Type_Name(t);
1472     return;
1473   }
1474   learnable_space_size_[type_id] = 0UL;
1475   learnable_params_ptrs_[type_id].resize(learnable_params_.size(), nullptr);
1476   for (int i = 0; i < layers_.size(); ++i) {
1477     for (int j = 0; j < layers_[i]->blobs().size(); ++j) {
1478       if (!layers_[i]->skip_apply_update(j)) {
1479         const int lip = layer_index_params_[make_pair(i, j)];
1480         if (param_owners_[lip] < 0) {
1481           const int param_id = learnable_param_ids_[lip];
1482           if (learnable_params_[param_id]->diff_type() == t) {
1483             learnable_space_size_[type_id] += lp_aligned_count(param_id) * lp_size(param_id);
1484           }
1485         }
1486       }
1487     }
1488   }
1489   // Size have at least one byte, otherwise cudaMalloc fails if net has no
1490   // learnable parameters. Times two.
1491   if (learnable_space_size_[type_id] < 2) {
1492     learnable_space_size_[type_id] = 2;
1493   }
1494   LOG(INFO) << print_current_device() << " Reserving "
1495             << learnable_space_size_[type_id] << " bytes of shared learnable space for type "
1496             << Type_Name(t);
1497   learnable_space_[type_id].reserve(learnable_space_size_[type_id]);
1498   unsigned char* ptr = reinterpret_cast<unsigned char*>(learnable_space_[type_id].data());
1499   caffe_gpu_memset(learnable_space_size_[type_id], 0, ptr);
1500   for (int i = 0; i < layers_.size(); ++i) {
1501     for (int j = 0; j < layers_[i]->blobs().size(); ++j) {
1502       if (!layers_[i]->skip_apply_update(j)) {
1503         const int lip = layer_index_params_[make_pair(i, j)];
1504         if (param_owners_[lip] < 0) {
1505           const int param_id = learnable_param_ids_[lip];
1506           if (learnable_params_[param_id]->diff_type() == t) {
1507             learnable_params_[param_id]->set_gpu_diff(ptr);
1508             learnable_params_ptrs_[type_id][param_id] = static_cast<void*>(ptr);
1509             ptr += lp_aligned_count(param_id) * lp_size(param_id);
1510             learnable_params_mapped_.push_back(learnable_params_[param_id]);
1511             ltop_[type_id][i].insert(param_id);
1512             void *p = learnable_params_[param_id]->
1513                 current_mutable_data_memory(Caffe::mode() == Caffe::GPU);
1514             (void) p;
1515           }
1516         }
1517       } else {
1518         DLOG(INFO) << print_current_device()
1519             << "** Skipping non-learnable blob from " << layers_[i]->name()
1520             << " of type " << layers_[i]->type();
1521       }
1522     }
1523   }
1526 const vector<Type>& Net::learnable_types(bool reset) {
1527   if (reset || learnable_types_.empty()) {
1528     learnable_types_.clear();
1529     int type0 = -1;
1530     int type1 = -1;
1531     for (shared_ptr<Blob> lp : learnable_params_) {
1532       Type t = lp->diff_type();
1533       if (type0 < 0) {
1534         type0 = (int) t;
1535         learnable_types_.push_back(t);
1536       } else if (type1 < 0 && type0 != (int) t) {
1537         type1 = (int) t;
1538         learnable_types_.push_back(t);
1539       }
1540     }
1541     if (learnable_types_.empty() && solver_ != nullptr) {
1542       learnable_types_.push_back(solver_->data_type());
1543     }
1544     CHECK_LE(learnable_types_.size(), 2);
1545   }
1546   return learnable_types_;
1549 template <typename Dtype>
1550 void Net::OptimizeNet() {
1551   auto set_blob_data_at = [&](shared_ptr<Blob>& blob, const int n, const int c, const int h, const int w, const Dtype& value) {
1552     if(blob != NULL && blob->count() > 0) {
1553       Dtype* data = blob->mutable_cpu_data<Dtype>();
1554       int idx = blob->offset(n, c, h, w);
1555       data[idx] = value;
1556     }
1557   };
1558   
1559   auto set_blob_data_at_chan = [&](shared_ptr<Blob>& blob, const int c, const Dtype& value) {
1560     if(blob != NULL && blob->count() > 0) {  
1561       Dtype* data = blob->mutable_cpu_data<Dtype>();  
1562       int idx = blob->shape().size()>1 && blob->shape(0)==1? blob->offset(0,c,0,0): blob->offset(c);
1563       data[idx] = value;
1564     }
1565   };
1566     
1567   enum LayerSequenceType {
1568       LAYER_SEQ_TYPE_OTHER,
1569       LAYER_SEQ_TYPE_CONV_BN,
1570       LAYER_SEQ_TYPE_CONV_BN_SCALE,
1571       LAYER_SEQ_TYPE_BN_CONV,
1572       LAYER_SEQ_TYPE_BN_SCALE_CONV
1573   };
1575   auto layer_sequence_type = [&](int layer_id) {
1576     std::string layer_name = layers_[layer_id]->layer_param().name();
1577     std::string layer_type = layers_[layer_id]->layer_param().type();
1578     std::string layer_type_next = GetTopLayerType(layer_id);
1579     std::string layer_type_next2 = GetTopLayerType2(layer_id);
1580     std::string layer_type_prev = GetBottomLayerType(layer_id);
1581     std::string layer_type_prev2 = GetBottomLayerType2(layer_id);
1583     LayerSequenceType type = LAYER_SEQ_TYPE_OTHER;
1584     if ((layer_id < (layers_.size()-2)) && layer_type == std::string("Convolution") &&
1585         layer_type_next == std::string("BatchNorm") && layer_type_next2 == std::string("Scale")) {
1586         type = LAYER_SEQ_TYPE_CONV_BN_SCALE;
1587     } else if ((layer_id < (layers_.size()-1)) && layer_type == std::string("Convolution") &&
1588         layer_type_next == std::string("BatchNorm")) {
1589         type = LAYER_SEQ_TYPE_CONV_BN;
1590     } else if ((layer_id < (layers_.size()-1)) && layer_type == std::string("BatchNorm") &&
1591         layer_type_next == std::string("Convolution")) {
1592         type = LAYER_SEQ_TYPE_BN_CONV;
1593     } else if ((layer_id < (layers_.size()-2)) && layer_type == std::string("BatchNorm") &&
1594         layer_type_next == std::string("Scale") && layer_type_next2 == std::string("Convolution")) {
1595         if(layer_id>0 && layer_type_prev != std::string("Convolution")) {
1596           type = LAYER_SEQ_TYPE_BN_SCALE_CONV;
1597         }
1598     }
1599     return type;
1600   };
1603   for (int i = 0; i < (layers_.size()-1); i++) {
1604     LayerSequenceType layer_seq_type = layer_sequence_type(i);
1605     if (layer_seq_type == LAYER_SEQ_TYPE_CONV_BN || layer_seq_type == LAYER_SEQ_TYPE_CONV_BN_SCALE) {
1606       LOG(INFO) << "Optimizing layer: " << layers_[i]->type() << " " << layers_[i]->name();
1607       LayerBase& conv_layer = *layers_[i];
1608       //int num_groups = conv_layer.layer_param().convolution_param().group();
1610       // Set bias term if it not there, as it is needed when combining BN
1611       if(conv_layer.blobs().size()==1) {
1612         shared_ptr<Blob> conv_weights = conv_layer.blobs()[0];
1613         int channels = (conv_weights->num_axes() == 1)? conv_weights->count() : conv_weights->shape(0);
1614         int outputs = channels;
1616         bool bias_term = true;
1617         conv_layer.mutable_layer_param().mutable_convolution_param()->set_bias_term(bias_term);
1618         conv_layer.mutable_layer_param().mutable_convolution_param()->mutable_bias_filler()->set_type("constant");
1619         conv_layer.mutable_layer_param().mutable_convolution_param()->mutable_bias_filler()->set_value(0);
1621         //TODO: Revisit if needed
1622         conv_layer.blobs().resize(2);
1623         vector<int> bias_shape(1, outputs);
1624         conv_layer.blobs()[1] = Blob::create<Dtype>(bias_shape);
1625         shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
1626             conv_layer.layer_param().convolution_param().bias_filler()));
1627         bias_filler->Fill(conv_layer.blobs()[1].get());
1628       }
1630       shared_ptr<Blob>& conv_weights = conv_layer.blobs()[0];
1631       int channels = (conv_weights->num_axes() == 1)? conv_weights->count() : conv_weights->shape(0);
1632       //int outputs = channels;
1634       int batch_norm_layer_id = this->GetTopLayerIds(i)[0];
1635       LayerBase& batch_norm_layer = *layers_[batch_norm_layer_id];
1636       bool scale_bias = (batch_norm_layer.blobs().size() == 5);//layers_[i+1]->layer_param().batch_norm_param().scale_bias();
1637       bool has_scale_layer = (layer_seq_type == LAYER_SEQ_TYPE_CONV_BN_SCALE);
1638       int scale_layer_layer_id = this->GetTopLayerIds(batch_norm_layer_id)[0];
1639       shared_ptr<LayerBase> scale_layer = has_scale_layer? layers_[scale_layer_layer_id] : NULL;
1640       bool has_scale_layer_bias = (has_scale_layer && scale_layer->blobs().size()>1);
1642       shared_ptr<Blob>& batch_norm_mean = batch_norm_layer.blobs()[0];
1643       shared_ptr<Blob>& batch_norm_var = batch_norm_layer.blobs()[1];
1644       double eps = batch_norm_layer.layer_param().batch_norm_param().eps();
1646       // Absorb the BatchNorm into convolution
1647       for(int no=0; no<conv_weights->shape(0); no++) {
1648         double var = batch_norm_var->data_at(no) + eps;
1649         double stdev_inv = std::pow(var, double(-0.5));
1650         double scale = scale_bias? batch_norm_layer.blobs()[3]->data_at(no) :
1651                 (has_scale_layer? scale_layer->blobs()[0]->data_at(no) : 1.0);
1652         for(int ni=0; ni<conv_weights->shape(1); ni++) {
1653           for(int w=0; w<conv_weights->shape(2); w++) {
1654             for(int h=0; h<conv_weights->shape(3); h++) {
1655                 double weights = conv_weights->data_at(no,ni,w,h);
1656                 weights = (weights * stdev_inv * scale);
1657                 set_blob_data_at(conv_weights, no, ni, w, h, weights);
1658             }
1659           }
1660         }
1661       }
1663       shared_ptr<Blob>& conv_bias = conv_layer.blobs()[1];
1664       for(int no=0; no<channels; no++) {
1665         double var = batch_norm_var->data_at(no) + eps;
1666         double stdev_inv = std::pow(var, double(-0.5));
1667         double scale = scale_bias? batch_norm_layer.blobs()[3]->data_at(no) :
1668                 (has_scale_layer? scale_layer->blobs()[0]->data_at(no) : 1.0);
1669         double bias = scale_bias? batch_norm_layer.blobs()[4]->data_at(no) :
1670                 (has_scale_layer && has_scale_layer_bias? scale_layer->blobs()[1]->data_at(no) : 0.0);
1671         double mean = batch_norm_mean->data_at(no);
1672         double weights_bias = conv_bias->data_at(no);
1673         weights_bias = ((weights_bias - mean) * stdev_inv * scale + bias);
1674         set_blob_data_at_chan(conv_bias, no, weights_bias);
1675       }
1677       // Set the batch norm (and subsequent scale layer if present) to identity
1678       for(int no=0; no<channels; no++) {
1679         if(scale_bias) {
1680           set_blob_data_at_chan(batch_norm_layer.blobs()[3], no, Dtype(1.0));
1681           set_blob_data_at_chan(batch_norm_layer.blobs()[4], no, Dtype(0.0));
1682         }
1683         if(has_scale_layer) {
1684           set_blob_data_at_chan(scale_layer->blobs()[0], no, Dtype(1.0));
1685           if(has_scale_layer_bias) {
1686             set_blob_data_at_chan(scale_layer->blobs()[1], no, Dtype(0.0));
1687           }
1688         }
1689         set_blob_data_at_chan(batch_norm_mean, no, Dtype(0.0));
1690         //Change var so that after adding eps, it becomes 1.0
1691         set_blob_data_at_chan(batch_norm_var, no, Dtype(1.0 - eps));
1692       }
1694       //transfer back to gpu
1695       for(int blob_id=0; blob_id<conv_layer.blobs().size(); blob_id++) {
1696           conv_layer.blobs()[blob_id]->gpu_data<Dtype>();
1697       }
1698       for(int blob_id=0; blob_id<batch_norm_layer.blobs().size(); blob_id++) {
1699           batch_norm_layer.blobs()[blob_id]->gpu_data<Dtype>();
1700       }
1701       if(has_scale_layer) {
1702           for(int blob_id=0; blob_id<scale_layer->blobs().size(); blob_id++) {
1703               scale_layer->blobs()[blob_id]->gpu_data<Dtype>();
1704           }
1705       }
1706     }
1707   }
1710 #if 0
1711   //Merge a BatchNorm layer that comes before convolution layer
1712   //This is not needed now. BN can be run separately if it cannot be merged
1713   for (int i = 0; i < (layers_.size()-1); i++) {
1714     LayerSequenceType layer_seq_type = layer_sequence_type(i);
1715     if (layer_seq_type == LAYER_SEQ_TYPE_BN_CONV || layer_seq_type == LAYER_SEQ_TYPE_BN_SCALE_CONV) {
1716       LayerBase& batch_norm_layer = *layers_[i];
1717       LayerBase& conv_layer = (layer_seq_type == LAYER_SEQ_TYPE_BN_SCALE_CONV)? *layers_[i+2] : *layers_[i+1];
1718       shared_ptr<Blob>& conv_weights = conv_layer.blobs()[0];
1719       shared_ptr<Blob>& conv_bias = conv_layer.blobs()[1];
1720       //int channels = (conv_weights->num_axes() == 1)? conv_weights->count() : conv_weights->shape(0);
1721       int bn_channels = conv_weights->shape(1);
1722       LOG(INFO) << "Optimizing layer: " << layers_[i]->type() << " " << layers_[i]->name();
1723       //int num_groups = conv_layer.layer_param().convolution_param().group();
1725       // Set bias term if it not there, as it is needed when combining BN
1726       if(conv_layer.blobs().size()==1) {
1727         shared_ptr<Blob> conv_weights = conv_layer.blobs()[0];
1728         int channels = (conv_weights->num_axes() == 1)? conv_weights->count() : conv_weights->shape(0);
1729         int outputs = channels;
1731         bool bias_term = true;
1732         conv_layer.mutable_layer_param().mutable_convolution_param()->set_bias_term(bias_term);
1733         conv_layer.mutable_layer_param().mutable_convolution_param()->mutable_bias_filler()->set_type("constant");
1734         conv_layer.mutable_layer_param().mutable_convolution_param()->mutable_bias_filler()->set_value(0);
1736         //TODO: Revisit if needed
1737         conv_layer.blobs().resize(2);
1738         vector<int> bias_shape(1, outputs);
1739         conv_layer.blobs()[1] = Blob::create<Dtype>(bias_shape);
1740         shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
1741             conv_layer.layer_param().convolution_param().bias_filler()));
1742         bias_filler->Fill(conv_layer.blobs()[1].get());
1743       }
1745       bool scale_bias = layers_[i]->layer_param().batch_norm_param().scale_bias();
1746       bool has_scale_layer = (layer_seq_type == LAYER_SEQ_TYPE_BN_SCALE_CONV);
1747       shared_ptr<LayerBase> scale_layer = has_scale_layer? layers_[i+1] : NULL;
1748       bool has_scale_layer_bias = (has_scale_layer && scale_layer->blobs().size()>1);
1750       shared_ptr<Blob>& batch_norm_mean = batch_norm_layer.blobs()[0];
1751       shared_ptr<Blob>& batch_norm_var = batch_norm_layer.blobs()[1];
1753       Dtype eps = batch_norm_layer.layer_param().batch_norm_param().eps();
1755       // Absorb the BatchNorm into convolution
1756       for(int no=0; no<conv_weights->shape(0); no++) {
1757         Dtype bias_sum1 = 0;
1758         Dtype bias_sum2 = 0;
1759         for(int ni=0; ni<bn_channels; ni++) {
1760           Dtype var = batch_norm_var->data_at(ni) + eps;
1761           Dtype stdev_inv = std::pow(var, Dtype(-0.5));
1762           Dtype scale = scale_bias? batch_norm_layer.blobs()[3]->data_at(ni) : (has_scale_layer? scale_layer->blobs()[0]->data_at(ni) : 1.0);
1763           Dtype bias = scale_bias? batch_norm_layer.blobs()[4]->data_at(ni) : (has_scale_layer && has_scale_layer_bias? scale_layer->blobs()[1]->data_at(ni) : 0.0);
1764           Dtype mean = batch_norm_mean->data_at(ni);
1766           for(int w=0; w<conv_weights->shape(2); w++) {
1767             for(int h=0; h<conv_weights->shape(3); h++) {
1768               bias_sum1 += bias * conv_weights->data_at(no,ni,w,h);
1769               bias_sum2 += mean * stdev_inv * conv_weights->data_at(no,ni,w,h);
1770               set_blob_data_at(conv_weights,no,ni,w,h, conv_weights->data_at(no,ni,w,h) * stdev_inv * scale);
1771             }
1772           }
1773         }
1774         set_blob_data_at_chan(conv_bias,no, conv_bias->data_at(no) + bias_sum1 - bias_sum2);
1775       }
1777       // Set the batch norm to identity
1778       for(int ni=0; ni<bn_channels; ni++) {
1779         if(scale_bias) {      
1780           set_blob_data_at_chan(batch_norm_layer.blobs()[3], ni, Dtype(1.0));
1781           set_blob_data_at_chan(batch_norm_layer.blobs()[4], ni, Dtype(0.0));
1782         } else if(has_scale_layer) {
1783           set_blob_data_at_chan(scale_layer->blobs()[0], ni, Dtype(1.0));
1784           if(has_scale_layer_bias) {
1785             set_blob_data_at_chan(scale_layer->blobs()[1], ni, Dtype(0.0));
1786           }
1787         }
1788         set_blob_data_at_chan(batch_norm_mean, ni, Dtype(0.0));
1789         //Change var so that after adding eps, it becomes 1.0
1790         set_blob_data_at_chan(batch_norm_var, ni, Dtype(1.0 - eps));
1791       }
1793       //transfer back to gpu
1794       for(int blob_id=0; blob_id<conv_layer.blobs().size(); blob_id++) {
1795           conv_layer.blobs()[blob_id]->gpu_data<Dtype>();
1796       }
1797       for(int blob_id=0; blob_id<batch_norm_layer.blobs().size(); blob_id++) {
1798           batch_norm_layer.blobs()[blob_id]->gpu_data<Dtype>();
1799       }
1800       if(has_scale_layer) {
1801           for(int blob_id=0; blob_id<scale_layer->blobs().size(); blob_id++) {
1802               scale_layer->blobs()[blob_id]->gpu_data<Dtype>();
1803           }
1804       }
1806     }
1807   }
1808 #endif
1813 template void Net::OptimizeNet<float>();
1816 void Net::StartQuantization() {
1817   bool quantize = (net_param_.quantize() && net_param_.net_quantization_param().quantization_start() > 0);
1818   if(quantize) {
1819     const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
1820     if(infer_count_ >= net_qparam.quantization_start()) {
1821       if(infer_count_ == net_qparam.quantization_start()) {
1822         LOG(INFO)<< "Enabling quantization flag in quantization_param at infer/iter index: " << infer_count_;
1823         this->EnableQuantizationForSelectedLayers();
1824       }
1825       this->SetQuantizationParams();
1826     }
1827   }
1830 void Net::FinishQuantization() {
1831   bool quantize = (net_param_.quantize() && net_param_.net_quantization_param().quantization_start() > 0);
1832   if(quantize) {
1833     const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
1834     if(net_qparam.update_quantization_param()) {
1835         this->UpdateQuantizationRangeInLayers();
1836     }
1838     if(net_qparam.quantization_start() > 0 && infer_count_ >= net_qparam.quantization_start()) {
1839       string phase = this->phase() == caffe::TRAIN ? "Train" : "Test";
1840       if (net_qparam.display_quantization() > 0 && (infer_count_ % net_qparam.display_quantization() == 0)) {
1841         LOG(INFO)<< "Quantizing the net: " << this->name() + " " + phase;
1842         this->DisplayQuantizationParams();
1843       }
1844     }
1845   }
1848 void Net::ClearQuantizationRangeInLayers() {
1849   max_in_.clear();
1850   max_out_.clear();
1851   max_weights_.clear();
1853   min_in_.clear();
1854   min_out_.clear();
1855   min_weights_.clear();
1858 void Net::CopyQuantizationRangeInLayers() {
1859   max_in_.resize(layers_.size());
1860   max_out_.resize(layers_.size());
1861   max_weights_.resize(layers_.size());
1863   min_in_.resize(layers_.size());
1864   min_out_.resize(layers_.size());
1865   min_weights_.resize(layers_.size());
1867   for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
1868     min_in_[layer_id].resize(bottom_vecs_[layer_id].size(), 0);
1869     max_in_[layer_id].resize(bottom_vecs_[layer_id].size(), 0);
1870     min_weights_[layer_id].resize(layers_[layer_id]->blobs().size(), 0);
1871     max_weights_[layer_id].resize(layers_[layer_id]->blobs().size(), 0);
1872     min_out_[layer_id].resize(top_vecs_[layer_id].size(), 0);
1873     max_out_[layer_id].resize(top_vecs_[layer_id].size(), 0);
1874   }
1876   for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
1877     if(!layers_[layer_id]->layer_param().has_quantization_param()) {
1878       continue;
1879     }
1880     QuantizationParameter& source_quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
1881     for(int blob_id = 0; blob_id<min_in_[layer_id].size(); blob_id++) {
1882       min_in_[layer_id][blob_id] = source_quantization_param.qparam_in(blob_id).min();
1883       max_in_[layer_id][blob_id] = source_quantization_param.qparam_in(blob_id).max();
1884     }
1885     for(int blob_id = 0; blob_id<layers_[layer_id]->blobs().size(); blob_id++) {
1886       min_weights_[layer_id][blob_id] = source_quantization_param.qparam_w(blob_id).min();
1887       max_weights_[layer_id][blob_id] = source_quantization_param.qparam_w(blob_id).max();
1888     }
1889     for(int blob_id = 0; blob_id<min_out_[layer_id].size(); blob_id++) {
1890       min_out_[layer_id][blob_id] = source_quantization_param.qparam_out(blob_id).min();
1891       max_out_[layer_id][blob_id] = source_quantization_param.qparam_out(blob_id).max();
1892     }
1893     source_quantization_param.set_quantized_infer_count(0);
1894   }
1897 void Net::UpdateQuantizationRangeInLayers() {
1898   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
1900   max_in_.resize(layers_.size());
1901   max_out_.resize(layers_.size());
1902   max_weights_.resize(layers_.size());
1904   min_in_.resize(layers_.size());
1905   min_out_.resize(layers_.size());
1906   min_weights_.resize(layers_.size());
1908   for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
1909     min_in_[layer_id].resize(bottom_vecs_[layer_id].size(), 0);
1910     max_in_[layer_id].resize(bottom_vecs_[layer_id].size(), 0);
1911     min_weights_[layer_id].resize(layers_[layer_id]->blobs().size(), 0);
1912     max_weights_[layer_id].resize(layers_[layer_id]->blobs().size(), 0);
1913     min_out_[layer_id].resize(top_vecs_[layer_id].size(), 0);
1914     max_out_[layer_id].resize(top_vecs_[layer_id].size(), 0);
1915   }
1917   // Find maximal values.
1918   float range_expansion_factor = net_qparam.range_expansion_factor();
1919   float beta = (infer_count_ == 0? 1.0 : net_qparam.range_update_factor());
1920   float alpha = (1.0 - beta);
1922   for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
1923     if(bottom_vecs_[layer_id].size()>0) {
1924         for(int blob_id = 0; blob_id<bottom_vecs_[layer_id].size(); blob_id++) {
1925             float min_in = bottom_vecs_[layer_id][blob_id]->min(0, 0);
1926             float max_in = bottom_vecs_[layer_id][blob_id]->max(0, 0);
1927             min_in *= range_expansion_factor;
1928             max_in *= range_expansion_factor;
1929             min_in_[layer_id][blob_id] = min_in_[layer_id][blob_id] * alpha +  min_in * beta;
1930             max_in_[layer_id][blob_id] = max_in_[layer_id][blob_id] * alpha +  max_in * beta;
1931         }
1932     }
1934     if(top_vecs_[layer_id].size()>0) {
1935         for(int blob_id = 0; blob_id<top_vecs_[layer_id].size(); blob_id++) {
1936             float min_out = top_vecs_[layer_id][blob_id]->min(0, 0);
1937             float max_out = top_vecs_[layer_id][blob_id]->max(0, 0);
1938             min_out *= range_expansion_factor;
1939             max_out *= range_expansion_factor;
1940             min_out_[layer_id][blob_id] = min_out_[layer_id][blob_id] * alpha +  min_out * beta;
1941             max_out_[layer_id][blob_id] = max_out_[layer_id][blob_id] * alpha +  max_out * beta;
1942         }
1943     }
1945     //TODO: Set to 1 to consider the weights only, and ignore the bias
1946     int max_params_to_consider = INT_MAX;
1947     int num_params = std::min((int)layers_[layer_id]->blobs().size(), max_params_to_consider);
1948     if(num_params > 0) {
1949         for(int blob_id = 0; blob_id < num_params; blob_id++) {
1950           float min_weights = (float)layers_[layer_id]->blobs()[blob_id]->min(0, 0);
1951           float max_weights = (float)layers_[layer_id]->blobs()[blob_id]->max(0, 0);
1952           //for weights, we can use the actual range - no need for running average.
1953           //min_weights *= expansion_factor;
1954           //max_weights *= expansion_factor;
1955           //min_weights_[layer_id] = min_weights_[layer_id] * alpha + min_weights * beta;
1956           //max_weights_[layer_id] = max_weights_[layer_id] * alpha + max_weights * beta;
1957           min_weights_[layer_id][blob_id] = min_weights;
1958           max_weights_[layer_id][blob_id] = max_weights;
1959         }
1960     }
1961   }
1964 vector<int> Net::GetTopLayerIds(int layer_id, bool only_one) {
1965     vector<int> top_layer_ids;
1966     vector<int> cur_top_ids = this->top_ids(layer_id);
1967     for(int t=0; t<cur_top_ids.size(); t++) {
1968       int tid = cur_top_ids[t];
1969       int found_lid = -1;
1970       for(int search_lid=(layer_id+1); search_lid<layers_.size(); search_lid++) {
1971         vector<int> seach_bids = this->bottom_ids(search_lid);
1972         for(int b=0; b<seach_bids.size(); b++) {
1973             if(tid == seach_bids[b]) {
1974                 found_lid = search_lid;
1975                 break;
1976             }
1977         }
1978         if(found_lid >= 0) {
1979             top_layer_ids.push_back(found_lid);
1980             if(only_one) {
1981                 break;
1982             }
1983         }
1984       }
1985     }
1986     return top_layer_ids;
1988 vector<int> Net::GetTopLayerIds2(int layer_id, bool only_one) {
1989     vector<int> layer_ids = GetTopLayerIds(layer_id);
1990     string layer_types = "";
1991     if(layer_ids.size()==1) {
1992         int top_layer_id = layer_ids[0];
1993         vector<int> layer_ids2 = GetTopLayerIds(top_layer_id);
1994         return layer_ids2;
1995     } else {
1996         return vector<int>();
1997     }
1999 vector<int> Net::GetBottomLayerIds(int layer_id, bool only_one) {
2000     vector<int> bottom_layer_ids;
2001     vector<int> cur_bottom_ids = this->bottom_ids(layer_id);
2002     for(int b=0; b<cur_bottom_ids.size(); b++) {
2003       int bid = cur_bottom_ids[b];
2004       int found_lid = -1;
2005       for(int search_lid=0; search_lid<layer_id; search_lid++) {
2006         vector<int> seach_tids = this->top_ids(search_lid);
2007         for(int t=0; t<seach_tids.size(); t++) {
2008             if(search_lid < layer_id && bid == seach_tids[t]) {
2009                 found_lid = search_lid;
2010             }
2011         }
2012       }
2013       if(found_lid >= 0) {
2014           bottom_layer_ids.push_back(found_lid);
2015           if(only_one) {
2016               break;
2017           }
2018       }
2019     }
2020     return bottom_layer_ids;
2022 string Net::GetTopLayerType(int layer_id, bool only_one) {
2023     vector<int> layer_ids = GetTopLayerIds(layer_id);
2024     string layer_types = "";
2025     for(int i=0; i<layer_ids.size(); i++) {
2026         if(layer_types != "") {
2027             layer_types += ",";
2028         }
2029         layer_types += layers_[layer_ids[i]]->type();
2030         if(only_one) {
2031             break;
2032         }
2033     }
2034     return layer_types;
2037 string Net::GetBottomLayerType(int layer_id, bool only_one) {
2038     vector<int> layer_ids = GetBottomLayerIds(layer_id);
2039     string layer_types = "";
2040     for(int i=0; i<layer_ids.size(); i++) {
2041         if(layer_types != "") {
2042             layer_types += ",";
2043         }
2044         layer_types += layers_[layer_ids[i]]->type();
2045         if(only_one) {
2046             break;
2047         }
2048     }
2049     return layer_types;
2051 string Net::GetTopLayerType2(int layer_id, bool only_one) {
2052     vector<int> layer_ids = GetTopLayerIds(layer_id);
2053     string layer_types = "";
2054     if(layer_ids.size()==1) {
2055         int top_layer_id = layer_ids[0];
2056         vector<int> layer_ids2 = GetTopLayerIds(top_layer_id);
2057         for(int i=0; i<layer_ids2.size(); i++) {
2058             if(layer_types != "") {
2059                 layer_types += ",";
2060             }
2061             layer_types += layers_[layer_ids2[i]]->type();
2062             if(only_one) {
2063                 break;
2064             }
2065         }
2066         return layer_types;
2067     } else {
2068         return "";
2069     }
2071 string Net::GetBottomLayerType2(int layer_id, bool only_one) {
2072     vector<int> layer_ids = GetBottomLayerIds(layer_id);
2073     string layer_types = "";
2074     if(layer_ids.size()==1) {
2075         int bottom_layer_id = layer_ids[0];
2076         vector<int> layer_ids2 = GetBottomLayerIds(bottom_layer_id);
2077         for(int i=0; i<layer_ids2.size(); i++) {
2078             if(layer_types != "") {
2079                 layer_types += ",";
2080             }
2081             layer_types += layers_[layer_ids2[i]]->type();
2082             if(only_one) {
2083                 break;
2084             }
2085         }
2086         return layer_types;
2087     } else {
2088         return "";
2089     }
2091 vector<const QuantizationParameter::QParams*> Net::GetBottomLayerQParams(int layer_id) {
2092     vector<const QuantizationParameter::QParams*> bottom_qparams;
2093     for(int b=0; b<this->bottom_ids(layer_id).size(); b++) {
2094         int search_bid = this->bottom_ids(layer_id)[b];
2095         int found_lid = -1;
2096         //int found_tid = -1;
2097         int found_top = -1;
2098         for (int search_lid = 0; search_lid < layer_id; search_lid++) {
2099             for(int t=0; t<this->top_ids(search_lid).size(); t++) {
2100               int search_tid = this->top_ids(search_lid)[t];
2101               if(search_bid == search_tid) {
2102                   found_lid = search_lid;
2103                   //found_tid = search_tid;
2104                   found_top = t;
2105               }
2106             }
2107         }
2108         if(found_lid >= 0) {
2109             const QuantizationParameter::QParams& qparam_out = layers_[found_lid]->layer_param().quantization_param().qparam_out(found_top);
2110             bottom_qparams.push_back(&qparam_out);
2111         }
2112     }
2113     return bottom_qparams;
2116 void Net::EnableQuantizationForSelectedLayers() {
2117   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2118   if(net_qparam.update_quantization_param()) {
2119     for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
2120       std::string layer_name = layers_[layer_id]->layer_param().name();
2121       std::string layer_type = layers_[layer_id]->layer_param().type();
2123       //LOG(INFO) << "Checking whether quantization is needed for: " << layer_type << " " << layer_name;
2125      //It is assumed that operations across multiple blobs (such as eltwise) is done in high precision.
2126      //So, we don't need to align the quantization ranges of different inputs.
2127      //Infact we don't quantize the input - it is assumed to be already quantized when it reaches the current layer.
2129       //find if this is a merged layer
2130       std::string layer_type_next = GetTopLayerType(layer_id);
2131       std::string layer_type_next2 = GetTopLayerType2(layer_id);
2132       std::string layer_type_prev = GetBottomLayerType(layer_id);
2133       std::string layer_type_prev2 = GetBottomLayerType2(layer_id);
2135       //std::cout << "LinearLayerChain: " << "(" << layer_type_prev2 << ")(" << layer_type_prev << ")("
2136       //        << layer_type << " - " << layer_name << ")" << "(" << layer_type_next << ")("
2137       //        << layer_type_next2 << ")" << std::endl;
2139       bool is_ignored_layer_name = false;
2140       for(int i=0; i<net_qparam.ignored_layer_names_size(); i++) {
2141           if(layer_name == net_qparam.ignored_layer_names(i)) {
2142               is_ignored_layer_name = true;
2143           }
2144       }
2146       bool is_merged_layer = false;
2147       if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution") {
2148           if(layer_type_next == "BatchNorm" || layer_type_next == "Scale" || layer_type_next == "ReLU") {
2149               is_merged_layer = true;
2150           }
2151       } else if(layer_type == "BatchNorm") {
2152           if(layer_type_next == "Scale" || layer_type_next == "ReLU") {
2153               is_merged_layer = true;
2154           } else if(layer_type_next == "Convolution") {
2155               if(layer_type_prev != "Convolution") {
2156                   is_merged_layer = true;
2157               }
2158           } else if(layer_type_next == "Convolution" || layer_type_next == "Scale") {
2159               if(layer_type_prev != "Convolution") {
2160                   is_merged_layer = true;
2161               }
2162           }
2163       } else if(layer_type == "Scale") {
2164           if(layer_type_next == "ReLU") {
2165               is_merged_layer = true;
2166           } else if(layer_type_next == "Convolution") {
2167               if(layer_type_prev != "Convolution" && layer_type_prev2 != "Convolution") {
2168                   is_merged_layer = true;
2169               }
2170           }
2171       }
2173       //for data layers, only quantize the first output and ignore the others (eg. label)
2174       int max_blobs_to_quantize = INT_MAX;
2175       string layer_type_lower = layer_type;
2176       std::transform(layer_type_lower.begin(), layer_type_lower.end(), layer_type_lower.begin(),
2177               [](unsigned char c) {return std::tolower(c);}
2178       );
2179       if(layer_type_lower.find("Data") != string::npos) {
2180           max_blobs_to_quantize = 1;
2181       }
2182       if(layer_type_lower.find("Input") != string::npos) {
2183           max_blobs_to_quantize = 1;
2184       }
2186       bool is_quantized_layer_type = false;
2187       if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution" ||
2188               layer_type == "BatchNorm" || layer_type == "Scale" || layer_type == "ReLU" ||
2189               layer_type == "PReLU" || layer_type == "Eltwise" || layer_type == "Concat" ||
2190               layer_type == "Bias" || layer_type == "Pooling") {
2191           is_quantized_layer_type = true;
2192       }
2193       if(layer_type.find("Data") != string::npos) {
2194           is_quantized_layer_type = true;
2195       }
2196       if(layer_type.find("Input") != string::npos) {
2197           is_quantized_layer_type = true;
2198       }
2200       //quantize weights
2201       if(net_qparam.quantize_weights()) {
2202           if(is_quantized_layer_type /*&& (!is_merged_layer)*/ && (!is_ignored_layer_name)) {
2203               if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution") {
2204                   QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2205                   for(int blob_id=0; blob_id<layers_[layer_id]->blobs().size(); blob_id++) {
2206                     if(quantization_param.qparam_w_size() <= blob_id) {
2207                       quantization_param.add_qparam_w();
2208                     }
2209                     quantization_param.mutable_qparam_w(blob_id)->set_quantize(true);
2210                   }
2211               }
2212           }
2213       }
2215       //quantize output activations
2216       if(net_qparam.quantize_activations()) {
2217           if(is_quantized_layer_type && (!is_merged_layer) && (!is_ignored_layer_name)) {
2218               QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2219               for(int blob_id=0; blob_id<top_vecs_[layer_id].size(); blob_id++) {
2220                 if(quantization_param.qparam_out_size() <= blob_id) {
2221                   quantization_param.add_qparam_out();
2222                 }
2223                 if(blob_id < max_blobs_to_quantize) {
2224                   quantization_param.mutable_qparam_out(blob_id)->set_quantize(true);
2225                 }
2226               }
2227               LOG(INFO) << "Enabling quantization at output of: " << layer_type << " " << layer_name;
2228           }
2229       }
2230     }
2231   }
2234 void Net::SetQuantizationParams() {
2235   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2237   if(net_qparam.update_quantization_param()) {
2238     //insert quantization_param in the layers that do not have it
2239     QuantizationParameter_Rounding rounding_scheme = (this->phase() == caffe::TRAIN ?
2240             QuantizationParameter_Rounding_STOCHASTIC : net_qparam.rounding_scheme());
2242     for (int layer_id = 0; layer_id < layers_.size(); layer_id++) {
2243         QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2245         quantization_param.set_precision(net_qparam.precision());
2246         quantization_param.set_rounding_scheme(rounding_scheme);
2247         quantization_param.set_power2_scale_weights(net_qparam.power2_scale_weights());
2248         quantization_param.set_power2_scale_activations(net_qparam.power2_scale_activations());
2249         quantization_param.set_quantized_infer_count(infer_count_ - net_qparam.quantization_start());
2251         // quantize parameters
2252         SetQuantizationParamsLayerWeights(layer_id);
2254         // quantize input activations
2255         SetQuantizationParamsLayerInput(layer_id);
2257         // quantize output activations
2258         SetQuantizationParamsLayerOutput(layer_id);
2259     }
2260   }
2262   //this->DisplayQuantizationParams();
2265 int Net::EstimateAbsBits(float val) {
2266     return (val!=0)? ceil(log2(std::fabs(val))) : 0;
2269 void Net::EstiamteQScaleParams(float min, float max, int bitwidth, bool power2_scale,
2270     bool unsigned_data, bool apply_offset, QuantizationParameter::QParams& qparam_xx) {
2271   float scale_applied = qparam_xx.scale_applied();
2273   qparam_xx.set_bitwidth(bitwidth);
2274   qparam_xx.set_unsigned_data(unsigned_data);
2275   qparam_xx.set_unsigned_quant(unsigned_data || apply_offset);
2276   qparam_xx.set_min(min);
2277   qparam_xx.set_max(max);
2279   //fractbits cannot be computed in cases where non-power of 2 quant is involved. not using it.
2280   qparam_xx.set_fracbits(0);
2282   float max_val_abs = std::max(std::fabs(max), std::fabs(min))*scale_applied;
2283   float max_val_range = std::abs(max - min)*scale_applied;
2285   if(power2_scale) {
2286     int estimated_bits = apply_offset? EstimateAbsBits(max_val_range) :
2287         (unsigned_data? EstimateAbsBits(max_val_abs) : (EstimateAbsBits(max_val_abs)+1));
2288     int shiftbits = estimated_bits - bitwidth;
2289     //account for the scaling due to right shift by shiftbits
2290     float scale_relative = 1.0/((shiftbits>=0)? float(1<<shiftbits) : float(1.0/(1<<std::abs(shiftbits))));
2291     float scale_target = scale_relative*scale_applied;
2293     qparam_xx.set_shiftbits(shiftbits);
2294     qparam_xx.set_scale_target(scale_target);
2295     qparam_xx.set_offset(apply_offset? (0 - min * scale_target) : 0);
2296   } else {
2297     //We can even use (1L<<bitwidth). Since we clip the quantized output - this should not be an issue.
2298     //However found that ((1L<<bitwidth)-1) gave slightly better quality
2299     float max_qrange = ((1L<<bitwidth)); //TODO: change back to: ((1L<<bitwidth)-1);
2300     float max_qrange_half = ((1L<<(bitwidth-1))); //TODO: change back to: ((1L<<(bitwidth-1))-1);
2301     float scale_relative = apply_offset? max_qrange/max_val_range :
2302       (unsigned_data? max_qrange/max_val_abs : max_qrange_half/max_val_abs);
2303     float scale_target = scale_relative*scale_applied;
2305     //shiftbits is not integer in this case - so cannot be set accurately.
2306     //but it may still be needed to simulate rounding in Trim2FixedPoint kernel
2307     int estimated_bits = apply_offset? EstimateAbsBits(max_val_range) :
2308         (unsigned_data? EstimateAbsBits(max_val_abs) : (EstimateAbsBits(max_val_abs)+1));
2309     int shiftbits = estimated_bits - bitwidth;
2311     qparam_xx.set_shiftbits(shiftbits);
2312     qparam_xx.set_offset(apply_offset? (0 - min * scale_target) : 0);
2313     qparam_xx.set_scale_target(scale_target);
2314   }
2316   //the new scale target is on top of the scale applied. so the effective scale will be a product of both,
2317   //after the scaling is applied.
2320 void Net::SetQuantizationParamsLayerInput(const int layer_id) {
2321   //const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2322   QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2324   vector<const QuantizationParameter::QParams*> qparam_bot_vec = this->GetBottomLayerQParams(layer_id);
2325   int num_bottom_vecs = bottom_vecs_[layer_id].size();
2326   CHECK(qparam_bot_vec.size() == num_bottom_vecs) << "Incorrect number of bottom qparams obtained";
2328   //It is assumed that operations across multiple blobs (such as eltwise) is done in high precision.
2329   //So, we don't need to align the quantization ranges of different inputs.
2330   for(int blob_id = 0; blob_id<num_bottom_vecs; blob_id++) {
2331     if(quantization_param.qparam_in_size() <= blob_id) {
2332       quantization_param.add_qparam_in();
2333     }
2335     //Get the scale applied to this blob
2336     QuantizationParameter::QParams& qparam_in = *quantization_param.mutable_qparam_in(blob_id);
2337     CHECK(qparam_bot_vec[blob_id] != NULL) << "Input QParams should not be NULL";
2338     if(qparam_bot_vec[blob_id]->has_scale_applied()) {
2339       qparam_in.set_scale_applied(qparam_bot_vec[blob_id]->scale_applied());
2340     }
2342     const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2343     float min_layer = min_in_[layer_id][blob_id];
2344     float max_layer = max_in_[layer_id][blob_id];
2345     bool unsigned_data = (min_layer>=0);
2346     EstiamteQScaleParams(min_layer, max_layer, net_qparam.bitwidth_activations(),
2347        net_qparam.power2_scale_activations(), unsigned_data, net_qparam.apply_offset_activations(), qparam_in);
2349     if(qparam_in.quantize()) {
2350         qparam_in.set_scale_applied(qparam_in.scale_target());
2351     }
2352   }
2355 void Net::SetQuantizationParamsLayerOutput(const int layer_id) {
2356   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2357   QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2358   std::string layer_type = layers_[layer_id]->layer_param().type();
2360   int num_bottom_vecs = bottom_vecs_[layer_id].size();
2361   float scale_applied_in_max = num_bottom_vecs>0? -1.0: 1.0;
2362   //Get the max scale applied in the case of multiple blobs
2363   for(int blob_id = 0; blob_id<num_bottom_vecs; blob_id++) {
2364     QuantizationParameter::QParams& qparam_in = *quantization_param.mutable_qparam_in(blob_id);
2365     scale_applied_in_max = std::max(scale_applied_in_max, qparam_in.scale_applied());
2366   }
2368   //this scale will be applied only in the quantization function,
2369   //but we need to calculate the scaling effect (on the output) due to quantization scale of weights.
2370   float scale_target_w = 1.0;
2371   if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution") {
2372       scale_target_w = quantization_param.mutable_qparam_w(0)->scale_target();
2373   } else if(layer_type == "BatchNorm") {
2374       scale_target_w = quantization_param.mutable_qparam_w(3)->scale_target();
2375   }
2377   int num_top_vecs = top_vecs_[layer_id].size();
2378   for(int blob_id = 0; blob_id<num_top_vecs; blob_id++) {
2379       if(quantization_param.qparam_out_size() <= blob_id) {
2380         quantization_param.add_qparam_out();
2381       }
2382       float min_layer = min_out_[layer_id][blob_id];
2383       float max_layer = max_out_[layer_id][blob_id];
2384       bool unsigned_data = (min_layer>=0);
2386       //the convolution operations in floating point doesn't change the applied scale,
2387       //but since the weights are quantized, there is a scaling that happens due to it.
2388       //applied for it by multiplying with it.
2389       QuantizationParameter::QParams& qparam_out = *quantization_param.mutable_qparam_out(blob_id);
2390       if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution" ||
2391               layer_type == "BatchNorm" || layer_type == "Scale") {
2392           qparam_out.set_scale_applied(scale_applied_in_max * scale_target_w);
2393       }
2395       EstiamteQScaleParams(min_layer, max_layer, net_qparam.bitwidth_activations(),
2396           net_qparam.power2_scale_activations(), unsigned_data, net_qparam.apply_offset_activations(), qparam_out);
2398       int num_blobs = layers_[layer_id]->blobs().size();
2399       int fracbits_in = quantization_param.qparam_in_size()>0? quantization_param.qparam_in(0).fracbits() : 0;
2400       int fracbits_out = qparam_out.fracbits();
2401       int fracbits_weights = num_blobs>0? quantization_param.qparam_w(0).fracbits() : 0;
2403       if(layer_type == "Convolution" || layer_type == "InnerProduct" || layer_type == "Deconvolution") {
2404           //special handling for bias.
2405           //scale factor for bias = scale_in * scale_w
2406           if(blob_id > 0) {
2407               QuantizationParameter::QParams& qparam_in = *quantization_param.mutable_qparam_out(0);
2408               QuantizationParameter::QParams& qparam_w = *quantization_param.mutable_qparam_w(0);
2409               float scale_target = qparam_in.scale_target() * qparam_w.scale_target();
2410               qparam_out.set_scale_target(scale_target);
2411           } else {
2412               //avoid left shift at output - will lose accuracy
2413               if((fracbits_in + fracbits_weights) < fracbits_out) {
2414                 fracbits_out = (fracbits_in + fracbits_weights);
2415                 qparam_out.set_fracbits(fracbits_out);
2416               }
2417               //if((fracbits_in + fracbits_weights) < fracbits_out) {
2418               //  LOG(FATAL) << "Qformat error for layer: " << layers_[layer_id]->layer_param().name()
2419               //      << " fracbits_in:" << fracbits_in << " fracbits_weights:" << fracbits_weights
2420               //      << " fracbits_out:" << fracbits_out;
2421               //}
2422           }
2423       }
2425       if(qparam_out.quantize()) {
2426         qparam_out.set_scale_applied(qparam_out.scale_target());
2427         //LOG(INFO) << layer_type << ": " << blob_id << ": " << "scale_applied:" << qparam_out.scale_applied();
2428       }
2429   }
2432 void Net::SetQuantizationParamsLayerWeights(const int layer_id) {
2433   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2434   QuantizationParameter& quantization_param = *layers_[layer_id]->mutable_layer_param().mutable_quantization_param();
2435   std::string layer_type = layers_[layer_id]->layer_param().type();
2437   int num_blobs = layers_[layer_id]->blobs().size();
2438   for(int blob_id = 0; blob_id<num_blobs; blob_id++) {
2439     if(quantization_param.qparam_w_size() <= blob_id) {
2440       quantization_param.add_qparam_w();
2441     }
2443     float min_layer = min_weights_[layer_id][blob_id];
2444     float max_layer = max_weights_[layer_id][blob_id];
2445     bool unsigned_data = (min_layer>=0);
2446     int bitwidth = (layer_type == "BatchNorm" || blob_id > 0)? net_qparam.bitwidth_bias() : net_qparam.bitwidth_weights();
2447     QuantizationParameter::QParams& qparam_w = *quantization_param.mutable_qparam_w(blob_id);
2448     EstiamteQScaleParams(min_layer, max_layer, bitwidth,
2449         net_qparam.power2_scale_weights(), unsigned_data, net_qparam.apply_offset_weights(), qparam_w);
2451     //if we set scale_applied for weights, the next time scale_target computation will go wrong.
2452     //since the same values and blob are re-used every frame. Anyway, scale_applied of weights is not used.
2453     //if(qparam_w.quantize()) {
2454     //    qparam_w.set_scale_applied(qparam_w.scale_target());
2455     //}
2456   }
2460 void Net::DisplayQuantizationParams() {
2461   const NetQuantizationParameter& net_qparam = net_param_.net_quantization_param();
2463   for (int i = 0; i < layers_.size(); ++i) {
2464     if (layers_[i]->layer_param().has_quantization_param()) {
2465       // if this is a convolutional layer which should be quantized ...
2466       QuantizationParameter& quantization_param = *layers_[i]->mutable_layer_param().mutable_quantization_param();
2467       int num_blobs = layers_[i]->blobs().size();
2468       if (quantization_param.qparam_w_size()>0 && net_qparam.quantize_weights() && num_blobs>0 &&
2469               quantization_param.qparam_w(0).quantize()) {
2470         LOG(INFO)<<" Q weights:" << i << " Name:" << layers_[i]->layer_param().name() <<
2471         " bitwidth:" << quantization_param.qparam_w(0).bitwidth() <<
2472         " fracbits:" << quantization_param.qparam_w(0).fracbits() <<
2473         " scale:" << quantization_param.qparam_w(0).scale_target() <<
2474         " offset:" << quantization_param.qparam_w(0).offset() <<
2475         " unsigned_data:" << quantization_param.qparam_w(0).unsigned_data() <<
2476         " min:" << quantization_param.qparam_w(0).min() <<
2477         " max:" << quantization_param.qparam_w(0).max() <<
2478         " scale_applied:" << quantization_param.qparam_w(0).scale_applied() <<
2479         " scale_target:" << quantization_param.qparam_w(0).scale_target();
2480       }
2482       if (quantization_param.qparam_w_size()>1 && net_qparam.quantize_weights() && num_blobs>1 &&
2483               quantization_param.qparam_w(1).quantize()) {
2484         LOG(INFO)<<" Q bias:" << i << " Name:" << layers_[i]->layer_param().name() <<
2485         " bitwidth:" << quantization_param.qparam_w(1).bitwidth() <<
2486         " fracbits:" << quantization_param.qparam_w(1).fracbits() <<
2487         " scale:" << quantization_param.qparam_w(1).scale_target() <<
2488         " offset:" << quantization_param.qparam_w(1).offset() <<
2489         " unsigned_data:" << quantization_param.qparam_w(1).unsigned_data() <<
2490         " min:" << quantization_param.qparam_w(1).min() <<
2491         " max:" << quantization_param.qparam_w(1).max() <<
2492         " scale_applied:" << quantization_param.qparam_w(1).scale_applied() <<
2493         " scale_target:" << quantization_param.qparam_w(1).scale_target();
2494       }
2496       if (quantization_param.qparam_in_size()>0 && net_qparam.quantize_activations() &&
2497               quantization_param.qparam_in(0).quantize()) {
2498         int num_bottom_vecs = bottom_vecs_[i].size();
2499         std::stringstream ss;
2500         ss << " Q input :" << i << " Name:" << layers_[i]->layer_param().name();
2501         for(int blob_id=0; blob_id<std::min<int>(num_bottom_vecs, quantization_param.qparam_in_size()); blob_id++) {
2502           ss << " bitwidth:" << quantization_param.qparam_in(blob_id).bitwidth();
2503           ss << " fracbits:" << quantization_param.qparam_in(blob_id).fracbits();
2504           ss << " scale:" << quantization_param.qparam_in(blob_id).scale_target() ;
2505           ss << " offset:" << quantization_param.qparam_in(blob_id).offset() ;
2506           ss << " unsigned_data:" << quantization_param.qparam_in(blob_id).unsigned_data();
2507           ss << " min:" << quantization_param.qparam_in(blob_id).min();
2508           ss << " max:" << quantization_param.qparam_in(blob_id).max();
2509           ss << " scale_applied:" << quantization_param.qparam_in(blob_id).scale_applied();
2510           ss << " scale_target:" << quantization_param.qparam_in(blob_id).scale_target();
2511         }
2512         LOG(INFO) << ss.str();
2513       }
2515       if (quantization_param.qparam_out_size()>0 && net_qparam.quantize_activations() &&
2516               quantization_param.qparam_out(0).quantize()) {
2517         LOG(INFO)<< " Q output:" << i << " Name:" << layers_[i]->layer_param().name() <<
2518         " bitwidth:" << quantization_param.qparam_out(0).bitwidth() <<
2519         " fracbits:" << quantization_param.qparam_out(0).fracbits() <<
2520         " scale:" << quantization_param.qparam_out(0).scale_target() <<
2521         " offset:" << quantization_param.qparam_out(0).offset() <<
2522         " unsigned_data:" << quantization_param.qparam_out(0).unsigned_data() <<
2523         " min:" << quantization_param.qparam_out(0).min() <<
2524         " max:" << quantization_param.qparam_out(0).max() <<
2525         " scale_applied:" << quantization_param.qparam_out(0).scale_applied() <<
2526         " scale_target:" << quantization_param.qparam_out(0).scale_target();
2527       }
2528     }
2529   }
2532 void Net::DisableQuantization() {
2533   for (int i = 0; i < layers_.size(); ++i) {
2534     if (layers_[i]->layer_param().has_quantization_param()) {
2535       QuantizationParameter& quantization_param = *layers_[i]->mutable_layer_param().mutable_quantization_param();
2536       quantization_param.set_precision(QuantizationParameter_Precision_FLOAT);
2537     }
2538   }
2542 //Old, deprecated function.
2543 void Net::FindAndApplyThresholdNet(float threshold_fraction_low, float threshold_fraction_mid, float threshold_fraction_high,
2544     float threshold_value_maxratio, float threshold_value_max, float threshold_step_factor, bool verbose) {
2546   for (int i = 0; i < layers_.size(); i++) {
2547     if (layers_[i]->type() == std::string("Convolution")) {
2548       LayerBase& conv_layer = *layers_[i];
2549       Blob& conv_weights = *conv_layer.blobs()[0];
2550       const ConvolutionParameter& conv_param = layers_[i]->layer_param().convolution_param();
2551       int num_group = layers_[i]->layer_param().convolution_param().group();
2552       //int stride = layers_[i]->layer_param().convolution_param().stride_size()>0? layers_[i]->layer_param().convolution_param().stride(0) : 1;
2554       int no = (conv_weights.num_axes() == 1)? conv_weights.count() : conv_weights.shape(0);
2555       int ni = ((conv_weights.num_axes() == 1)? conv_weights.count() : conv_weights.shape(1))*num_group;
2556       float count = conv_weights.count();
2557       if(verbose) {
2558         LOG(WARNING) << layers_[i]->layer_param().name() << " ni=" << ni << " no=" << no;
2559       }
2561       int kernel_shape_data[2];
2562       if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) {
2563         kernel_shape_data[0] = conv_param.kernel_h();
2564         kernel_shape_data[1] = conv_param.kernel_w();
2565       } else {
2566         const int num_kernel_dims = conv_param.kernel_size_size();
2567         for (int i = 0; i < 2; ++i) {
2568           kernel_shape_data[i] = conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
2569         }
2570       }
2572       //need to add it as cfg option :FIX_ME:SN
2573       const bool no_sparsity_for_small_kernel = true;
2574       bool need_sparsity_for_this_layer = true;
2575       if (no_sparsity_for_small_kernel) {
2576         need_sparsity_for_this_layer = (kernel_shape_data[0] > 2) && (kernel_shape_data[1] > 2) && (num_group != no);
2577       }
2579       if((ni>=32 || no >= 32) && num_group<no && need_sparsity_for_this_layer) {
2580         float threshold_fraction_selected = ((ni>=256 && no >= 512)? threshold_fraction_high :
2581             ((ni>=32 && no >= 32)? threshold_fraction_mid: threshold_fraction_low));
2582         float selected_threshold = 0;
2583         float max_abs = std::abs(conv_weights.max(0, 0));
2584         float min_abs = std::abs(conv_weights.min(0, 0));
2585         float max_abs_value = std::max<float>(max_abs, min_abs);
2586         float step_size = max_abs_value * threshold_step_factor;
2587         float max_threshold_value = std::min<float>(threshold_value_max, max_abs_value*threshold_value_maxratio);
2589         float step_sizeX = step_size*100;
2590         float selected_thresholdX = 0;
2591         for(float step=0; step<max_abs_value && step<max_threshold_value; step+=step_sizeX) {
2592           float zcount = conv_weights.count_zero((float)step, 0, 0);
2593           float zratio = zcount / count;
2594           if(zratio <= threshold_fraction_selected) {
2595             selected_thresholdX = step;
2596           } else {
2597             break;
2598           }
2599         }
2601         for(float step=std::max((selected_thresholdX-step_sizeX),0.0f);
2602             step<(selected_thresholdX+step_sizeX) && step<max_abs_value && step<max_threshold_value;
2603             step+=step_size) {
2604           float zcount = conv_weights.count_zero((float)step, 0, 0);
2605           float zratio = zcount / count;
2606           if(zratio <= threshold_fraction_selected) {
2607             selected_threshold = step;
2608           } else {
2609             break;
2610           }
2611         }
2613         conv_weights.zerout(selected_threshold, 0, 0);
2615         if(verbose) {
2616           float zcount = conv_weights.count_zero(0.0, 0, 0);
2617           LOG(WARNING) << layers_[i]->layer_param().name() << " MaxAbsWeight=" << max_abs_value
2618               << " MaxThreshold=" << max_threshold_value << " SelectedThreshold=" << selected_threshold
2619               << " ZeroPercentage=" << (zcount*100/count);
2620         }
2621       }
2622     }
2623   }
2627 void Net::FindAndApplyChannelThresholdNet(float threshold_fraction_low, float threshold_fraction_mid, float threshold_fraction_high,
2628     float threshold_value_maxratio, float threshold_value_max, float threshold_step_factor, bool verbose) {
2630   for (int i = 0; i < layers_.size(); i++) {
2631     if (layers_[i]->type() == std::string("Convolution")) {
2632       LayerBase& conv_layer = *layers_[i];
2633       Blob& conv_weights = *conv_layer.blobs()[0];
2634       const ConvolutionParameter& conv_param = layers_[i]->layer_param().convolution_param();
2635       const string layer_name = layers_[i]->layer_param().name();
2637       int num_group = conv_param.group();
2638       //int stride = conv_param.stride_size()>0? conv_param.stride(0) : 1;
2639       int kernel_shape_data[2];
2640       if (conv_param.has_kernel_h() || conv_param.has_kernel_w()) {
2641         kernel_shape_data[0] = conv_param.kernel_h();
2642         kernel_shape_data[1] = conv_param.kernel_w();
2643       } else {
2644         const int num_kernel_dims = conv_param.kernel_size_size();
2645         for (int i = 0; i < 2; ++i) {
2646           kernel_shape_data[i] = conv_param.kernel_size((num_kernel_dims == 1) ? 0 : i);
2647         }
2648       }
2650       int no = (conv_weights.num_axes() == 1)? conv_weights.count() : conv_weights.shape(0);
2651       int ni = ((conv_weights.num_axes() == 1)? conv_weights.count() : conv_weights.shape(1))*num_group;
2652       float count = conv_weights.count();
2653       if(verbose) {
2654         LOG(WARNING) << layers_[i]->layer_param().name() << " ni=" << ni << " no=" << no;
2655       }
2657       //need to add it as cfg option :FIX_ME:SN
2658       const bool no_sparsity_for_small_kernel = true;
2659       bool need_sparsity_for_this_layer = true;
2660       if (no_sparsity_for_small_kernel) {
2661         need_sparsity_for_this_layer = (kernel_shape_data[0] > 2) && (kernel_shape_data[1] > 2) && (num_group != no);
2662       }
2664       //apply sparsity only to certain layers. exclude layers with small number of input and outputs
2665       //also exclude depth-wise separable layers.
2666       if((ni>=32 || no >= 32)  && (num_group<no) && need_sparsity_for_this_layer) {
2667         float threshold_fraction_selected = ((ni>=256 && no >= 512)? threshold_fraction_high :
2668             ((ni>=32 && no >= 32)? threshold_fraction_mid: threshold_fraction_low));
2670         for(int c=0; c<no; c++) {
2671           int weight_count_channel = ni * kernel_shape_data[0] * kernel_shape_data[1] / num_group;
2672           int start_index = weight_count_channel * c;
2674           float max_abs = std::abs(conv_weights.max(start_index, weight_count_channel));
2675           float min_abs = std::abs(conv_weights.min(start_index, weight_count_channel));
2676           float max_abs_value = std::max<float>(max_abs, min_abs);
2677           float step_size = max_abs_value * threshold_step_factor;
2678           float max_threshold_value = std::min<float>(std::min<float>(threshold_value_max, max_abs_value*threshold_value_maxratio), max_abs_value);
2679           bool verbose_th_val = false;
2680           if(verbose && verbose_th_val) {
2681             if ((max_abs_value*threshold_value_maxratio) > threshold_value_max) {
2682                 LOG(INFO) << "threshold_value_max " << threshold_value_max;
2683                 LOG(INFO) << "threshold_value_maxratio " << threshold_value_maxratio;
2684                 LOG(INFO) << "max_abs_value*threshold_value_maxratio " << (max_abs_value*threshold_value_maxratio);
2685                 LOG(INFO) << "final threshold_value used" << max_threshold_value; 
2686             }
2687           }
2689           float selected_threshold = 0;
2690           float granurality_start = 1000;
2691           for(float granurality = granurality_start, search_iter=0; granurality>=1; granurality=granurality/10, search_iter++) {
2692             float step_sizeX = step_size * granurality;
2693             float range_sizeX = step_sizeX*10*2;
2694             float start_valueX = selected_threshold;
2696             float min_step_val = search_iter>0? std::max((start_valueX-range_sizeX),0.0f) : 0;
2697             float max_step_val = search_iter>0? (start_valueX+range_sizeX) : max_threshold_value;
2698             for(float step= min_step_val; step<max_step_val && step<max_threshold_value; step+=step_sizeX) {
2699               float zcount = conv_weights.count_zero((float)step, start_index, weight_count_channel);
2700               float zratio = zcount / weight_count_channel;
2701               if(zratio <= threshold_fraction_selected) {
2702                 selected_threshold = step;
2703               } else {
2704                 break;
2705               }
2706             }
2707           }
2709           conv_weights.zerout(selected_threshold, start_index, weight_count_channel);
2710           //LOG(INFO) << "Layer:" << layer_name << " channel:" << c << " threshold:"
2711           //   << selected_threshold << " sparsity:"<< conv_weights.count_zero(0.0, start_index, weight_count_channel);
2712         }
2714         if(verbose) {
2715           float zcount = conv_weights.count_zero(0.0, 0, 0);
2716           LOG(WARNING) << layers_[i]->layer_param().name()
2717               //<< " MaxAbsWeight=" << max_abs_value
2718               //<< " MaxThreshold=" << max_threshold_value << " SelectedThreshold=" << selected_threshold
2719               << " ZeroWeightsFraction=" << (zcount/count);
2720         }
2721       }
2722     }
2723   }
2727 /**
2728  * ApplySparseModeConnectivity
2729  * Yet another way to do this is to store the threshold for each layer in FindAndApplyThresholdNet
2730  * And just use it here. But the current implementation of this cuntion is more generic
2731  * since it can be used when thresholding is completely outside.
2732  */
2733 void Net::ApplySparseModeConnectivity() {
2734   for (int i = 0; i < layers_.size(); i++) {
2735     if (layers_[i]->type() == std::string("Convolution")) {
2736       LayerBase& conv_layer = *layers_[i];
2737       Blob& conv_weights = *conv_layer.blobs()[0];
2739       //Use the connectivity information in the blob and zerout values accordingly.
2740       conv_weights.ComputeSparseData();
2742       //This is strictly not necessary
2743       //conv_weights.ComputeSparseDiff();
2744     }
2745   }
2748 void Net::StoreSparseModeConnectivity(SparseMode mode) {
2749   LOG_IF(INFO, Caffe::root_solver()) << "All zero weights of convolution layers are frozen";
2750   if(mode != SPARSE_NONE) {
2751     for(int i=0; i<layers_.size(); i++) {
2752       if(layers_[i]->type() == std::string("Convolution")) {
2753         LayerBase& conv_layer = *layers_[i];
2754         Blob& conv_weights = *conv_layer.blobs()[0];
2756         //Store the non-zero weight information
2757         conv_weights.StoreSparseModeConnectivity(mode);
2758       }
2759     }
2760   }
2763 float Net::DisplaySparsity(bool verbose) {
2764   float total_zero_count = 0, total_count = 0;
2765   {
2766     std::map<std::string, std::pair<int,int> > spasity_map;
2767     int blob_count = this->GetSparsity(spasity_map);
2768     if(verbose) {
2769       LOG(INFO) << "Num Params(" << blob_count << "), " << "Sparsity (zero_weights/count): ";
2770     }
2772     for(std::map<std::string, std::pair<int,int> >::iterator
2773         iter = spasity_map.begin(); iter != spasity_map.end(); iter++) {
2774       std::string param_name = iter->first;
2775       float zero_count = iter->second.first;
2776       float count = iter->second.second;
2777       total_zero_count += zero_count;
2778       total_count += count;
2779       if(verbose) {
2780         LOG(INFO) << param_name << "(" << std::setprecision(3) << (zero_count/count) << ") ";
2781       }
2782     }
2783     if(verbose) {
2784       LOG(INFO) << "Total Sparsity (zero_weights/count) = "
2785           << " (" << total_zero_count << "/" << total_count << ") "
2786           << std::setprecision(3) << (total_zero_count/total_count);
2787     }
2788   }
2790   return (total_zero_count/total_count);
2793 float Net::DisplayConnectivitySparsity(bool verbose) {
2794   float total_zero_count = 0, total_count = 0;
2796   std::map<std::string, std::pair<int,int> > spasity_map;
2797   int blob_count = this->GetConnectivitySparsity(spasity_map);
2798   if(verbose) {
2799     LOG(INFO) << "Num Params(" << blob_count << "), " << "ConnectivitySparsity (zero_weights/count): ";
2800   }
2802   for(std::map<std::string, std::pair<int,int> >::iterator
2803       iter = spasity_map.begin(); iter != spasity_map.end(); iter++) {
2804     std::string param_name = iter->first;
2805     float zero_count = iter->second.first;
2806     float count = iter->second.second;
2807     total_zero_count += zero_count;
2808     total_count += count;
2809     if(verbose) {
2810       LOG(INFO) << param_name << "(" << std::setprecision(3) << (zero_count/count) << ") ";
2811     }
2812   }
2813   if(verbose) {
2814     LOG(INFO) << "Total ConnectivitySparsity (zero_weights/count) = "
2815         << " (" << total_zero_count << "/" << total_count << ") "
2816         << std::setprecision(3) << (total_zero_count/total_count);
2817   }
2819   return (total_zero_count/total_count);
2822 int Net::GetSparsity(std::map<std::string, std::pair<int,int> >& sparsity_map){
2823   int blob_count = 0;
2824   float threshold = 0.0f;
2825   sparsity_map.clear();
2826   int max_params_to_check = 1;
2827   for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
2828       const LayerParameter& layer_param = layers_[layer_id]->layer_param();
2830       bool next_layer_is_softmax = false;
2831       if((layer_id+1) < layers_.size() && (layers_[layer_id+1]->layer_param().type() == "Softmax" ||
2832           layers_[layer_id+1]->layer_param().type() == "SoftmaxWithLoss")) {
2833         next_layer_is_softmax = true;
2834       }
2835       bool next_layer_is_not_softmax = (!next_layer_is_softmax);
2836       bool is_candidate_layer = (layer_param.type() == "Convolution" /*|| layer_param.type() == "InnerProduct"*/);
2838       if(next_layer_is_not_softmax && is_candidate_layer)  {
2839           int num_params_to_check = std::min<int>(max_params_to_check, layers_[layer_id]->blobs().size());
2840           for (int param_id = 0; param_id < num_params_to_check;++param_id) {
2841             const Blob& blob = *layers_[layer_id]->blobs()[param_id];
2842             const int net_param_id = param_id_vecs_[layer_id][param_id];
2843             const string& blob_name = param_display_names_[net_param_id];
2844             //const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
2845             std::pair<int,int> sp_map = std::make_pair(blob.count_zero(threshold, 0, 0), blob.count());
2846             sparsity_map[layer_names_[layer_id] + "_param_" + blob_name] = sp_map;
2847             blob_count++;
2848           }
2849       }
2850   }
2851   return blob_count;
2854 int Net::GetConnectivitySparsity(std::map<std::string, std::pair<int,int> >& sparsity_map){
2855   int blob_count = 0;
2856   float threshold = 0.0f;
2857   sparsity_map.clear();
2858   int max_params_to_check = 1;
2859   for (int layer_id = 0; layer_id < layers_.size(); ++layer_id) {
2860       const LayerParameter& layer_param = layers_[layer_id]->layer_param();
2862       bool next_layer_is_softmax = false;
2863       if((layer_id+1) < layers_.size() && (layers_[layer_id+1]->layer_param().type() == "Softmax" ||
2864           layers_[layer_id+1]->layer_param().type() == "SoftmaxWithLoss")) {
2865         next_layer_is_softmax = true;
2866       }
2867       bool next_layer_is_not_softmax = (!next_layer_is_softmax);
2868       bool is_candidate_layer = (layer_param.type() == "Convolution" /*|| layer_param.type() == "InnerProduct"*/);
2870       if(next_layer_is_not_softmax && is_candidate_layer) {
2871           int num_params_to_check = std::min<int>(max_params_to_check, layers_[layer_id]->blobs().size());
2872           for (int param_id = 0; param_id < num_params_to_check;++param_id) {
2873             const Blob& blob = *layers_[layer_id]->blobs()[param_id];
2874             const int net_param_id = param_id_vecs_[layer_id][param_id];
2875             const string& blob_name = param_display_names_[net_param_id];
2876             //const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
2877             std::pair<int,int> sp_map = std::make_pair(blob.count_zero_connectivity(threshold, 0, 0), blob.count());
2878             sparsity_map[layer_names_[layer_id] + "_param_" + blob_name] = sp_map;
2879             blob_count++;
2880           }
2881       }
2882   }
2883   return blob_count;
2886 template void Net::Convert2FixedPoint_cpu(float* data, const int cnt, const int bw, int fl, bool unsigned_data, bool clip) const;
2888 }  // namespace caffe