summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: 34bd702)
raw | patch | inline | side by side (parent: 34bd702)
author | Manu Mathew <mathmanu@users.noreply.github.com> | |
Mon, 11 Jun 2018 09:51:03 +0000 (15:21 +0530) | ||
committer | Manu Mathew <mathmanu@users.noreply.github.com> | |
Thu, 5 Jul 2018 08:29:18 +0000 (13:59 +0530) |
src/caffe/net.cpp | patch | blob | history |
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index a51715e74312d96506a1eb4aa256e6c5de358e52..a4bab3399a026fee56aaa6cd341fd11593a1f4b6 100644 (file)
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
<< source_layer_type << " #blobs=" << source_layer.blobs_size();
int num_blobs_to_copy = std::min<int>(target_blobs.size(), source_layer.blobs_size());
// check if BN is in legacy DIGITS format?
- if (source_layer_type == "BatchNorm" && source_layer.blobs_size() == 5) {
+ if (source_layer_type == "BatchNorm") {
for (int j = 0; j < num_blobs_to_copy; ++j) {
const bool kReshape = true;
target_blobs[j]->FromProto(source_layer.blobs(j), kReshape);
DLOG(INFO) << target_blobs[j]->count();
}
- if (target_blobs[4]->count() == 1) {
+ if (source_layer.blobs_size() == 5 && target_blobs[4]->count() == 1) {
// old format: 0 - scale , 1 - bias, 2 - mean , 3 - var, 4 - reserved
// new format: 0 - mean , 1 - var, 2 - reserved , 3- scale, 4 - bias
LOG(INFO) << "BN legacy DIGITS format detected ... ";
std::swap(target_blobs[3], target_blobs[4]);
LOG(INFO) << "BN Transforming to new format completed.";
}
+ if (source_layer.blobs_size() == 3) {
+ const float scale_factor = target_blobs[2]->cpu_data<float>()[0] == 0.F ?
+ 0.F : 1.F / target_blobs[2]->cpu_data<float>()[0];
+ caffe_cpu_scale(target_blobs[0]->count(), scale_factor,
+ target_blobs[0]->cpu_data<float>(),
+ target_blobs[0]->mutable_cpu_data<float>());
+ caffe_cpu_scale(target_blobs[1]->count(), scale_factor,
+ target_blobs[1]->cpu_data<float>(),
+ target_blobs[1]->mutable_cpu_data<float>());
+ target_blobs[2]->mutable_cpu_data<float>()[0] = 1.F;
+ }
for (int j = 0; j < target_blobs.size(); ++j) {
DLOG(INFO) << target_blobs[j]->count();
}