mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
Added option for user to not pass in the session and error messages if so (#192)
* Added option for user to not pass in the session * Small changes.
This commit is contained in:
parent
ab3448a9b4
commit
b9d6135aa1
1 changed files with 14 additions and 5 deletions
|
@ -3,17 +3,20 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
class TensorFlowVariables(object):
|
||||
"""An object used to extract variables from a loss function, and provide
|
||||
methods for getting and setting the weights of said variables.
|
||||
"""An object used to extract variables from a loss function.
|
||||
|
||||
This object also provides methods for getting and setting the weights of the
|
||||
relevant variables.
|
||||
|
||||
Attributes:
|
||||
sess (tf.Session): The tensorflow session used to run assignment.
|
||||
loss: The loss function passed in by the user.
|
||||
variables (List[tf.Variable]): Extracted variables from the loss.
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights get passed to.
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
|
||||
passed to.
|
||||
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
"""
|
||||
def __init__(self, loss, sess):
|
||||
def __init__(self, loss, sess=None):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
|
@ -28,10 +31,16 @@ class TensorFlowVariables(object):
|
|||
self.assignment_placeholders[var.op.node_def.name] = tf.placeholder(var.value().dtype, var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.assignment_placeholders[var.op.node_def.name]))
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Modifies the current session used by the class."""
|
||||
self.sess = sess
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the weights of the variables of the loss function in a list."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
return {v.op.node_def.name: v.eval(session=self.sess) for v in self.variables}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
assert self.sess is not None, "The session is not set. Set the session either by passing it into the TensorFlowVariables constructor or by calling set_session(sess)."
|
||||
self.sess.run(self.assignment_nodes, feed_dict={self.assignment_placeholders[name]: value for (name, value) in new_weights.items()})
|
||||
|
|
Loading…
Add table
Reference in a new issue