I am trying to retrain the last layer of inception-resnet-v2. Here's what I came up with:
- Get the names of the variables in the final layer
- Create a train_op to minimise only these variables w.r.t. the loss
- Restore the whole graph except the final layer, while initialising only the last layer randomly.
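The plan splits the model's variables into two disjoint groups by name: the final layer (to train) and everything else (to restore). Here is a minimal plain-Python sketch of that split. It assumes hypothetical variable names modeled on the slim checkpoint. In TensorFlow you would take the names from each variable's `.name`.

```python
def split_variables(names, final_scope="InceptionResnetV2/Logits"):
    """Split variable names into (to_train, to_restore) by scope prefix."""
    prefix = final_scope + "/"
    to_train = [n for n in names if n.startswith(prefix)]
    to_restore = [n for n in names if not n.startswith(prefix)]
    return to_train, to_restore


# Hypothetical variable names, for illustration only
names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights:0",
    "InceptionResnetV2/Conv2d_7b_1x1/weights:0",
    "InceptionResnetV2/Logits/Logits/weights:0",
    "InceptionResnetV2/Logits/Logits/biases:0",
]
to_train, to_restore = split_variables(names)
```

Every name lands in exactly one of the two lists, which is the property the restore/train code should preserve.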
And I implemented that as follows:
```python
import tensorflow as tf

slim = tf.contrib.slim

with slim.arg_scope(arg_scope):
    logits = model(images_ph, is_training=True, reuse=None)

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_ph, logits=logits))
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, 1), labels_ph)

# train only the variables of the final (Logits) layer
train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'InceptionResnetV2/Logits')
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
train_op = optimizer.minimize(loss, var_list=train_list)

# restore all variables whose names don't contain 'Logits'
restore_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='^((?!Logits).)*$')
saver = tf.train.Saver(restore_list, write_version=tf.train.SaverDef.V2)

with tf.Session() as session:
    init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    session.run(init_op)
    saver.restore(session, '../models/inception_resnet_v2_2016_08_30.ckpt')
    # followed by code for running train_op
```
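`tf.get_collection` matches its `scope` argument against each variable name with `re.match`. You can therefore reproduce both filters above in plain Python to see what they select. The sketch below uses hypothetical names modeled on the slim checkpoint. If the model has an auxiliary head, the negative-lookahead pattern also excludes `AuxLogits`, so those variables are neither restored nor trained. That is harmless only while the loss ignores the auxiliary logits.

```python
import re

RESTORE_PATTERN = r"^((?!Logits).)*$"      # pattern used for restore_list
TRAIN_SCOPE = "InceptionResnetV2/Logits"   # scope used for train_list


def selected(names, pattern):
    """Mimic tf.get_collection(..., scope=pattern): keep names that re.match."""
    regex = re.compile(pattern)
    return [n for n in names if regex.match(n)]


# Hypothetical variable names, for illustration only
names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights:0",
    "InceptionResnetV2/AuxLogits/Logits/weights:0",
    "InceptionResnetV2/Logits/Logits/weights:0",
]
restored = selected(names, RESTORE_PATTERN)
trained = selected(names, TRAIN_SCOPE)
# variables that are neither restored from the checkpoint nor trained
neither = [n for n in names if n not in restored and n not in trained]
```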
This doesn't seem to work: the training loss and error barely improve from their initial values. Is there a better or more elegant way to do this? It would also be good learning for me if you could tell me what's going wrong here.
1 Answer
There are several things to check:
- What is the learning rate? A value that is too high can break everything (though it is probably not the cause here).
- Try plain stochastic gradient descent; you should run into fewer problems.
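Here is a toy illustration of the learning-rate point, unrelated to the model itself. With plain gradient descent on f(w) = w², each step multiplies w by (1 − 2·lr). Any lr above 1 therefore makes the iterates grow instead of shrink.

```python
def gradient_descent(lr, steps=10, w=1.0):
    """Minimise f(w) = w**2 with plain gradient descent and return the final w."""
    for _ in range(steps):
        grad = 2.0 * w       # derivative of w**2
        w = w - lr * grad    # plain SGD update rule
    return w


small = gradient_descent(lr=0.1)  # each step scales w by 0.8, so w shrinks toward 0
large = gradient_descent(lr=1.1)  # each step scales w by -1.2, so |w| blows up
```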
Is the arg scope set correctly? If you don't use L2 regularization and batch normalization, you might fall into a local minimum very early and the network won't be able to learn. The model's own arg scope sets up both:
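```python
import tensorflow as tf
from nets import inception_resnet_v2 as net  # from the tensorflow/models slim directory

slim = tf.contrib.slim

# inception_resnet_v2_arg_scope() returns an arg scope (weight decay and batch
# norm settings), which has to be entered through slim.arg_scope
with slim.arg_scope(net.inception_resnet_v2_arg_scope()):
    logits, end_points = net.inception_resnet_v2(
        images_ph, num_classes=num_classes, is_training=True)
```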
You should add the regularization losses to the loss (or at least those of the last layer):
```python
# collect the L2 penalties registered by the regularizers in the arg scope
regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_losses = [loss] + regularization_losses
total_loss = tf.add_n(all_losses, name='total_loss')
```
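Some intuition for what `total_loss` contains: `slim.l2_regularizer(scale)` contributes `scale * tf.nn.l2_loss(w)`, i.e. scale · Σw²/2, for each regularized weight tensor. `tf.add_n` then sums these with the data loss. Below is a plain-Python sketch of that arithmetic, with made-up weights and an illustrative scale.

```python
def l2_penalty(weights, scale):
    """scale * sum(w**2) / 2, the same quantity as scale * tf.nn.l2_loss(w)."""
    return scale * sum(w * w for w in weights) / 2.0


def total_loss(data_loss, weight_tensors, scale):
    """Data loss plus one L2 penalty per weight tensor (what tf.add_n sums)."""
    regularization_losses = [l2_penalty(w, scale) for w in weight_tensors]
    return sum([data_loss] + regularization_losses)


# Made-up numbers for illustration
loss = total_loss(data_loss=1.0, weight_tensors=[[1.0, 2.0], [3.0]], scale=0.5)
```

Here the two penalties are 0.5 · 5/2 = 1.25 and 0.5 · 9/2 = 2.25, so the total is 4.5.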
Training only the fully connected layer might not be a good idea. I would train the whole network, because the features you need for your classes aren't necessarily formed in the last layer but a few layers before it, and you need to adapt those too.
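One way to switch between fine-tuning only the last layer and training more of the network is to build the variable list from a list of scope prefixes, where no prefixes means "train everything". The `--trainable_scopes` flag of the slim `train_image_classifier.py` script follows a similar idea. This plain-Python sketch works on hypothetical variable names. In TensorFlow you would filter `tf.trainable_variables()` by `v.op.name` the same way.

```python
def variables_to_train(names, trainable_scopes=None):
    """Return the names to train; None or an empty list means train everything."""
    if not trainable_scopes:
        return list(names)
    return [n for n in names
            if any(n.startswith(scope + "/") for scope in trainable_scopes)]


# Hypothetical variable names, for illustration only
names = [
    "InceptionResnetV2/Conv2d_1a_3x3/weights",
    "InceptionResnetV2/Conv2d_7b_1x1/weights",
    "InceptionResnetV2/Logits/Logits/weights",
]
last_layer_only = variables_to_train(names, ["InceptionResnetV2/Logits"])
last_layers = variables_to_train(
    names, ["InceptionResnetV2/Conv2d_7b_1x1", "InceptionResnetV2/Logits"])
whole_network = variables_to_train(names)
```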
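Double-check that train_op and the loss are evaluated in the right order, for example by making the train op return the total loss once the update has run: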
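```python
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops

# running this op applies the update (train_op) and then returns total_loss
with ops.name_scope('train_op'):
    train_op = control_flow_ops.with_dependencies([train_op], total_loss)
```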