index 87f2c2ccf3b85ed642607d24c7e4d3e5af0c8eea..e1be61c35e15780691c4b9ae34115857707fa82d 100644
repeated float diff = 6 [packed=true];
}
+// The BlobProtoVector is simply a way to pass multiple BlobProto instances
+// around.
+message BlobProtoVector {
+ repeated BlobProto blobs = 1;
+}
+
message Datum {
optional int32 channels = 1;
optional int32 height = 2;
enum PoolMethod {
MAX = 0;
AVE = 1;
+ STOCHASTIC = 2;
}
optional PoolMethod pool = 11 [default = MAX]; // The pooling method
optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
// For data layers, specify the data source
optional string source = 16;
- // For data pre-processing, we can do simple scaling and constant subtraction
+ // For data pre-processing, we can do simple scaling and subtraction of the
+ // data mean, if provided. Note that the mean subtraction is always carried
+ // out before scaling.
optional float scale = 17 [ default = 1 ];
- optional float subtraction = 18 [ default = 0 ];
+ optional string meanfile = 18;
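+ // As a sketch of the intended semantics, each value is transformed as
+ //   transformed_value = (raw_value - mean_value) * scale
+ // where mean_value comes from the blob stored in meanfile, if provided.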
// For data layers, specify the batch size.
optional uint32 batchsize = 19;
// For data layers, specify if we would like to randomly crop an image.
// The ratio that is multiplied on the global learning rate. If you want to set
// the learning ratio for one blob, you need to set it for all blobs.
repeated float blobs_lr = 51;
+ // The ratio that is multiplied on the global weight decay, one per blob.
+ repeated float weight_decay = 52;
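+ // For example (the values below are illustrative only), a layer with a
+ // weight blob and a bias blob might specify:
+ //   blobs_lr: 1       // learning rate multiplier for the weights
+ //   blobs_lr: 2       // learning rate multiplier for the bias
+ //   weight_decay: 1   // apply weight decay to the weights
+ //   weight_decay: 0   // do not decay the bias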
+
+ // The rand_skip variable lets the data layer skip a few data points so that
+ // asynchronous SGD clients do not all start at the same point. The number of
+ // points skipped is set to rand_skip * rand(0,1). Note that rand_skip should
+ // not be larger than the number of keys in the leveldb.
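+ // For example, with rand_skip = 1000 (an illustrative value), each client
+ // would skip a random number of entries between 0 and 1000.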
+ optional uint32 rand_skip = 53 [ default = 0 ];
}
message LayerConnection {
message NetParameter {
optional string name = 1; // consider giving the network a name
repeated LayerConnection layers = 2; // a bunch of layers.
- repeated string input = 3; // The input to the network
+ // The input blobs to the network.
+ repeated string input = 3;
+ // The dimensions of the input blobs. For each input blob there should be
+ // four values specifying the num, channels, height and width of the blob.
+ // Thus, there should be a total of (4 * #input) numbers.
+ repeated int32 input_dim = 4;
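+ // As a sketch (the blob name and sizes below are illustrative), a network
+ // with a single 10 x 3 x 227 x 227 input blob would specify:
+ //   input: "data"
+ //   input_dim: 10
+ //   input_dim: 3
+ //   input_dim: 227
+ //   input_dim: 227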
+ // Whether the network will force every layer to carry out the backward
+ // operation. If set to false, whether to carry out backward is determined
+ // automatically according to the net structure and learning rates.
+ optional bool force_backward = 5 [ default = false ];
}
message SolverParameter {
- optional float base_lr = 1; // The base learning rate
+ optional string train_net = 1; // The proto file for the training net.
+ optional string test_net = 2; // The proto file for the testing net.
+ // The number of iterations for each testing phase.
+ optional int32 test_iter = 3 [ default = 0 ];
+ // The number of iterations between two testing phases.
+ optional int32 test_interval = 4 [ default = 0 ];
+ optional float base_lr = 5; // The base learning rate
// the number of iterations between displaying info. If display = 0, no info
// will be displayed.
- optional int32 display = 2;
- optional int32 max_iter = 3; // the maximum number of iterations
- optional int32 snapshot = 4 [default = 0]; // The snapshot interval
- optional string lr_policy = 5; // The learning rate decay policy.
- optional float min_lr = 6 [default = 0]; // The mininum learning rate
- optional float max_lr = 7 [default = 1e10]; // The maximum learning rate
- optional float gamma = 8; // The parameter to compute the learning rate.
- optional float power = 9; // The parameter to compute the learning rate.
- optional float momentum = 10; // The momentum value.
- optional float weight_decay = 11; // The weight decay.
- optional float stepsize = 12; // the stepsize for learning rate policy "step"
-
- optional string snapshot_prefix = 13; // The prefix for the snapshot.
+ optional int32 display = 6;
+ optional int32 max_iter = 7; // the maximum number of iterations
+ optional string lr_policy = 8; // The learning rate decay policy.
+ optional float gamma = 9; // The parameter to compute the learning rate.
+ optional float power = 10; // The parameter to compute the learning rate.
+ optional float momentum = 11; // The momentum value.
+ optional float weight_decay = 12; // The weight decay.
+ optional int32 stepsize = 13; // the stepsize for learning rate policy "step"
+ optional int32 snapshot = 14 [default = 0]; // The snapshot interval
+ optional string snapshot_prefix = 15; // The prefix for the snapshot.
+ // Whether to snapshot the diff in the results or not. Snapshotting the diff
+ // will help debugging but the final protocol buffer size will be much larger.
+ optional bool snapshot_diff = 16 [ default = false];
+ // The mode the solver will use: 0 for CPU and 1 for GPU. GPU is used by default.
+ optional int32 solver_mode = 17 [default = 1];
+}
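+
+// As a sketch, a complete solver definition using the fields above might
+// look like the following (all file names and values are illustrative):
+//   train_net: "train.prototxt"
+//   test_net: "test.prototxt"
+//   test_iter: 100
+//   test_interval: 500
+//   base_lr: 0.01
+//   lr_policy: "step"
+//   gamma: 0.1
+//   stepsize: 10000
+//   display: 20
+//   max_iter: 45000
+//   momentum: 0.9
+//   weight_decay: 0.0005
+//   snapshot: 5000
+//   snapshot_prefix: "model"
+//   solver_mode: 1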
+
+// A message that stores the solver snapshots
+message SolverState {
+ optional int32 iter = 1; // The current iteration
+ optional string learned_net = 2; // The file that stores the learned net.
+ repeated BlobProto history = 3; // The history for SGD solvers
}