
Solution below

If you are just interested in solving this problem, you can skip to my answer below.

Original question

I'm using TensorFlow for reinforcement learning. A swarm of agents uses the model in parallel and one central entity trains it on the collected data.

I had found here (Is it thread-safe when using tf.Session in inference service?) that TensorFlow sessions are thread-safe, so I simply let prediction and training run in parallel.
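To make that concrete, here is a minimal sketch of the parallel setup I had (TF 1.x with standalone Keras; the agent/trainer split and the thread counts are illustrative, not my actual production code):

# minimal sketch of the parallel setup, assuming TF 1.x and standalone Keras
import threading

import numpy as np
import tensorflow as tf

from keras.models import Model
from keras.layers import Input, Dense

inp = Input(shape=(10,))
model = Model(inputs=inp, outputs=Dense(1)(inp))
model.compile(optimizer='sgd', loss='mse')

# build the predict and train functions up front, so no new graph ops
# are created later inside the threads
model._make_predict_function()
model._make_train_function()
graph = tf.get_default_graph()

def agent():
    # each thread has to enter the shared graph explicitly
    with graph.as_default():
        for _ in range(10):
            model.predict(np.ones((1, 10)))

def trainer():
    with graph.as_default():
        for _ in range(10):
            model.train_on_batch(np.ones((4, 10)), np.zeros((4, 1)))

threads = [threading.Thread(target=agent) for _ in range(3)]
threads.append(threading.Thread(target=trainer))
for t in threads:
    t.start()
for t in threads:
    t.join()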

But now I would like to change the setup. Instead of training and predicting on a single model, I now need to keep two models: one is used for prediction, the second one is trained, and after some training steps the weights of the second are copied over to the first. Below is a minimal example in Keras. For multiprocessing it is recommended to finalize the graph, but then I can't copy the weights anymore:

# the usual imports
import numpy as np
import tensorflow as tf

from keras.models import *
from keras.layers import *

# set up the first model
i = Input(shape=(10,))
b = Dense(1)(i)
prediction_model = Model(inputs=i, outputs=b)

# set up the second model
i2 = Input(shape=(10,))
b2 = Dense(1)(i2)
training_model = Model(inputs=i2, outputs=b2)

# sanity check: the two models were initialized independently,
# so their outputs on the same input differ
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))

# now to use them in multiprocessing, the following is necessary
prediction_model._make_predict_function()
training_model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()

# the following line is the critical part
# if this is uncommented, the two options below both fail
# default_graph.finalize()

# option 1, use keras methods to update the weights
prediction_model.set_weights(training_model.get_weights())

# option 2, use tensorflow to update the weights
update_ops = [tf.assign(to_var, from_var) for to_var, from_var in
              zip(prediction_model.trainable_weights, training_model.trainable_weights)]
sess.run(update_ops)

# now the predictions are the same
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))

According to the question linked above, it is recommended to finalize the graph; if it is not finalized, there can be memory leaks, so that seems like a strong recommendation.

But if I finalize it, I can no longer update the weights. What confuses me is: training the network is still possible, so changing the weights is obviously allowed. Assignment looks to me like the weights are simply overwritten, so why is it different from applying an optimizer step?
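For illustration, here is a stripped-down pure-TensorFlow example (no Keras) showing that a pre-built optimizer step still runs fine on a finalized graph:

# pure TF 1.x sketch: a pre-built optimizer step runs on a finalized graph
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(None, 10))
y = tf.placeholder(tf.float32, shape=(None, 1))
w = tf.Variable(tf.zeros((10, 1)))

loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
# minimize() adds the gradient and update ops to the graph *here*
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.get_default_graph().finalize()

# running the pre-built op only executes the graph, it does not modify it
sess.run(train_op, {x: np.ones((4, 10)), y: np.zeros((4, 1))})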


1 Answer


In short, my problem was assigning values to the weights of a finalized graph. If the assignment ops are created after finalization, TensorFlow complains that the graph can no longer be changed.

I was confused why this is forbidden. After all, changing the weights by backpropagation is allowed.

But the problem is not related to changing the weight values. Keras's set_weights() is confusing because it looks as if the weights were simply overwritten in place (as in backprop). Actually, behind the scenes, new assignment operations are added to the graph and then executed. These new operations are a change to the graph itself, and that change is what is forbidden after finalization.
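You can see this directly in a tiny repro, independent of Keras: executing a pre-built assign op is fine, but creating a new one after finalization raises an error:

# minimal repro: creating new ops after finalize() is what fails
import tensorflow as tf

v = tf.Variable(1.0)
assign_op = tf.assign(v, 2.0)   # created before finalization: fine

sess = tf.Session()
sess.run(tf.global_variables_initializer())
tf.get_default_graph().finalize()

sess.run(assign_op)             # executing the existing op: fine

# tf.assign(v, 3.0)             # would raise:
# RuntimeError: Graph is finalized and cannot be modified.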

So the solution is to set up the assignment operations before finalizing the graph. You have to reorder the code:

# the usual imports
import numpy as np
import tensorflow as tf

from keras.models import *
from keras.layers import *

# set up the first model
i = Input(shape=(10,))
b = Dense(1)(i)
prediction_model = Model(inputs=i, outputs=b)

# set up the second model
i2 = Input(shape=(10,))
b2 = Dense(1)(i2)
training_model = Model(inputs=i2, outputs=b2)

# set up operations to move weights from training to prediction
update_ops = [tf.assign(to_var, from_var) for to_var, from_var in
              zip(prediction_model.trainable_weights, training_model.trainable_weights)]

# now to use them in multiprocessing, the following is necessary
prediction_model._make_predict_function()
training_model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()

default_graph.finalize()

# this can be executed now
sess.run(update_ops)

# now the predictions are the same
prediction_model.predict(np.ones((1, 10)))
training_model.predict(np.ones((1, 10)))
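In the actual reinforcement learning loop this then looks roughly as follows. This is only a sketch continuing the code above: the random data and SYNC_INTERVAL are illustrative assumptions, and training_model.compile(...) plus training_model._make_train_function() would also have to happen before finalize():

# sketch of the surrounding training loop; assumes training_model was
# compiled and its train function built before the graph was finalized
SYNC_INTERVAL = 100   # illustrative value, tune for your application

for step in range(1000):
    x = np.random.rand(32, 10)
    y = np.random.rand(32, 1)
    training_model.train_on_batch(x, y)

    # periodically run the pre-built assign ops to sync the two models;
    # no new graph nodes are created here, so the finalized graph is fine
    if step % SYNC_INTERVAL == 0:
        sess.run(update_ops)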