Change tf_utils.py get_weights to evaluate all tensors at once rather than calling tensor.eval per-tensor. (#8491)

This commit is contained in:
internetcoffeephone 2020-05-19 07:06:03 +02:00 committed by GitHub
parent 6c5ea32857
commit a73c488c74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -161,10 +161,7 @@ class TensorFlowVariables:
Dictionary mapping variable names to their weights.
"""
self._check_sess()
return {
k: v.eval(session=self.sess)
for k, v in self.variables.items()
}
return self.sess.run(self.variables)
def set_weights(self, new_weights):
"""Sets the weights to new_weights.