summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: eb058bd)
author: Yangqing Jia <jiayq84@gmail.com> — Tue, 8 Oct 2013 00:38:03 +0000 (17:38 -0700)
committer: Yangqing Jia <jiayq84@gmail.com> — Tue, 8 Oct 2013 00:38:03 +0000 (17:38 -0700)
diff --git a/src/Makefile b/src/Makefile
index deebe75715d7fbd3ac3ac55a52b1d44fe2773b97..d78e99bc4006a631592a7df089536c06320c9709 100644 (file)
--- a/src/Makefile
+++ b/src/Makefile
# a lowercase prefix (in this case "program") and an uppercased suffix (in this case "NAME"), separated
# by an underscore is used to name attributes for a common element. Think of this like
# using program.NAME, program.C_SRCS, etc. There are no structs in Make, so we use this convention
-# to keep track of attributes that all belong to the same target or program.
+# to keep track of attributes that all belong to the same target or program.
#
PROJECT := caffe
NAME := lib$(PROJECT).so
program: $(OBJS) $(PROGRAM_BINS)
runtest: test
- for testbin in $(TEST_BINS); do $$testbin; done
+ for testbin in $(TEST_BINS); do $$testbin 1; done
$(TEST_BINS): %.testbin : %.o
$(CXX) -pthread $< $(OBJS) $(GTEST_OBJ) -o $@ $(LDFLAGS) $(WARNINGS)
diff --git a/src/caffe/filler.hpp b/src/caffe/filler.hpp
index 99cb5bcc92e9573e78c71130c950a2c4ac83925f..effe62ff2c56155905480cdc0d98fd44510ccee3 100644 (file)
--- a/src/caffe/filler.hpp
+++ b/src/caffe/filler.hpp
//
// It fills the incoming matrix by randomly sampling uniform data from
// [-scale, scale] where scale = sqrt(3 / fan_in) where fan_in is the number
-// of input nodes, and in our case we consider the blob width as the scale.
-// You should make sure the input blob has shape (1, 1, height, width).
+// of input nodes. You should make sure the input blob has shape (num, a, b, c)
+// where a * b * c = fan_in.
template <typename Dtype>
class XavierFiller : public Filler<Dtype> {
public:
: Filler<Dtype>(param) {}
virtual void Fill(Blob<Dtype>* blob) {
CHECK(blob->count());
- CHECK_EQ(blob->num(), 1) << "XavierFiller requires blob.num() = 1.";
- CHECK_EQ(blob->channels(), 1)
- << "XavierFiller requires blob.channels() = 1.";
- int fan_in = blob->width();
+ int fan_in = blob->count() / blob->num();
Dtype scale = sqrt(Dtype(3) / fan_in);
caffe_vRngUniform<Dtype>(blob->count(), blob->mutable_cpu_data(),
-scale, scale);
index 9560e47d36b47eb949e9c329d793b3fd8c3689af..8bf913a669eefdd116d58501fbd9b6c2406a0eb9 100644 (file)
this->blobs_.resize(1);
}
// Intialize the weight
- this->blobs_[0].reset(new Blob<Dtype>(1, 1, NUM_OUTPUT_, K_));
+ this->blobs_[0].reset(
+ new Blob<Dtype>(NUM_OUTPUT_, CHANNELS_ / GROUP_, KSIZE_, KSIZE_));
// fill the weights
shared_ptr<Filler<Dtype> > weight_filler(
GetFiller<Dtype>(this->layer_param_.weight_filler()));
index cb288b33187964f089855a77d00804bc76d00f70..2b2656d2a318729542a0f813128e46c70eb7ed45 100644 (file)
net_->Update();
// Check if we need to do snapshot
- if (param_.snapshot() > 0 && iter_ % param_.snapshot()) {
- // TODO(Yangqing): snapshot
- NOT_IMPLEMENTED;
+ if (param_.snapshot() > 0 && iter_ % param_.snapshot() == 0) {
+ Snapshot(false);
}
if (param_.display()) {
LOG(ERROR) << "Iteration " << iter_ << ", loss = " << loss;
} else {
ss << "_iter_" << iter_;
}
+ string filename = ss.str();
+ LOG(ERROR) << "Snapshotting to " << filename;
ofstream output_file;
- output_file.open(ss.str().c_str());
+ output_file.open(filename.c_str());
CHECK(net_param.SerializeToOstream(&output_file));
output_file.close();
}
index 8a76a508f7a1d07bf1781c2dba1fe5bbb6c56ea7..6bc76aef10a497d74c8a985f9c45e2195796372e 100644 (file)
import numpy as np
def blobproto_to_array(blob):
- arr = np.array(blob.data).reshape(blob.num(), blob.channels(), blobs.height(),
- blobs.width())
+ arr = np.array(blob.data).reshape(blob.num, blob.channels, blob.height,
+ blob.width)
return arr
def array_to_blobproto(arr):
raise ValueError('Incorrect array shape.')
blob = caffe_pb2.BlobProto()
blob.num, blob.channels, blob.height, blob.width = arr.shape;
- blob.data.extend(arr.flat)
+ blob.data.extend(arr.astype(float).flat)
return blob
def array_to_datum(arr):
index c991d8d84cb88fcbf24a6173725fff41d7cea262..e6600f2119de8b790e724575b088f90f0f31cf85 100644 (file)
--- a/src/caffe/syncedmem.cpp
+++ b/src/caffe/syncedmem.cpp
}
}
-
const void* SyncedMemory::cpu_data() {
to_cpu();
return (const void*)cpu_ptr_;
index d6a4ca5fe49f7df2114e960e8b56c20f640b8b88..d908859422f29652a9fc542f8c441428f02bbb91 100644 (file)
//solver_param.set_power(0.75);
solver_param.set_momentum(0.9);
solver_param.set_weight_decay(0.0005);
+ solver_param.set_snapshot(100);
+ solver_param.set_snapshot_prefix("alexnet");
LOG(ERROR) << "Starting Optimization";
SGDSolver<float> solver(solver_param);