Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00

[RLlib] Issue 8412 (Adam vars not stored in ModelV2). (#8480)
parent e62c1d2051
commit 25c0974543

9 changed files with 203 additions and 77 deletions
@@ -1,7 +1,8 @@
 from collections import deque, OrderedDict
 import numpy as np

-from ray.rllib.utils import try_import_tf
+from ray.rllib.utils import force_list
+from ray.rllib.utils.framework import try_import_tf

 tf = try_import_tf()

@@ -47,8 +48,7 @@ class TensorFlowVariables:
             list.
         """
         self.sess = sess
-        if not isinstance(output, (list, tuple)):
-            output = [output]
+        output = force_list(output)
         queue = deque(output)
         variable_names = []
         explored_inputs = set(output)
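Note (illustration, not part of the commit): the hunk above swaps a manual isinstance check for the force_list helper. A minimal stand-in for the assumed behavior — wrap a single output in a list, pass lists/tuples through as lists:

    # Illustrative only; the real helper lives in ray.rllib.utils.
    def force_list(elements=None):
        if elements is None:
            return []
        if isinstance(elements, (list, tuple)):
            return list(elements)
        return [elements]

    assert force_list("out:0") == ["out:0"]
    assert force_list(("a", "b")) == ["a", "b"]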
@@ -64,6 +64,12 @@ class ARSTFPolicy:
         action += np.random.randn(*action.shape) * self.action_noise_std
         return action

+    def get_state(self):
+        return {"state": self.get_flat_weights()}
+
+    def set_state(self, state):
+        return self.set_flat_weights(state["state"])
+
     def set_flat_weights(self, x):
         self.variables.set_flat(x)

@@ -120,6 +120,12 @@ class ESTFPolicy:
                 self.action_noise_std
         return single_action

+    def get_state(self):
+        return {"state": self.get_flat_weights()}
+
+    def set_state(self, state):
+        return self.set_flat_weights(state["state"])
+
     def set_flat_weights(self, x):
         self.variables.set_flat(x)

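Note (illustration, not part of the commit): the ARS and ES TF policies above now expose their flat weight vector through the generic Policy get_state/set_state API. A toy, self-contained sketch of that convention (class name and sizes made up):

    import numpy as np

    class FlatWeightsPolicy:
        """Toy policy illustrating the {"state": flat_weights} convention above."""
        def __init__(self, num_weights):
            self.weights = np.zeros(num_weights, dtype=np.float32)

        def get_flat_weights(self):
            return self.weights.copy()

        def set_flat_weights(self, x):
            self.weights = np.asarray(x, dtype=np.float32)

        def get_state(self):
            return {"state": self.get_flat_weights()}

        def set_state(self, state):
            return self.set_flat_weights(state["state"])

    p = FlatWeightsPolicy(4)
    p.set_state({"state": np.arange(4, dtype=np.float32)})
    assert (p.get_state()["state"] == np.arange(4)).all()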
@@ -332,10 +332,12 @@ def build_eager_tf_policy(name,
                 "is_training": tf.constant(False),
             }
             if obs_include_prev_action_reward:
-                input_dict[SampleBatch.PREV_ACTIONS] = \
-                    tf.convert_to_tensor(prev_action_batch)
-                input_dict[SampleBatch.PREV_REWARDS] = \
-                    tf.convert_to_tensor(prev_reward_batch)
+                if prev_action_batch is not None:
+                    input_dict[SampleBatch.PREV_ACTIONS] = \
+                        tf.convert_to_tensor(prev_action_batch)
+                if prev_reward_batch is not None:
+                    input_dict[SampleBatch.PREV_REWARDS] = \
+                        tf.convert_to_tensor(prev_reward_batch)

             # Use Exploration object.
             with tf.variable_creator_scope(_disallow_var_creation):
@@ -464,6 +466,29 @@ def build_eager_tf_policy(name,
             for v, w in zip(variables, weights):
                 v.assign(w)

+        @override(Policy)
+        def get_state(self):
+            state = {"_state": super().get_state()}
+            state["_optimizer_variables"] = self._optimizer.variables()
+            return state
+
+        @override(Policy)
+        def set_state(self, state):
+            state = state.copy()  # shallow copy
+            # Set optimizer vars first.
+            optimizer_vars = state.pop("_optimizer_variables", None)
+            if optimizer_vars and self._optimizer.variables():
+                logger.warning(
+                    "Cannot restore an optimizer's state for tf eager! Keras "
+                    "is not able to save the v1.x optimizers (from "
+                    "tf.compat.v1.train) since they aren't compatible with "
+                    "checkpoints.")
+                for opt_var, value in zip(self._optimizer.variables(),
+                                          optimizer_vars):
+                    opt_var.assign(value)
+            # Then the Policy's (NN) weights.
+            super().set_state(state["_state"])
+
         def variables(self):
            """Return the list of all savable variables for this policy."""
            return self.model.variables()
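Note (illustration, not part of the commit): the eager set_state above first pops the optimizer values out of the combined state dict and only then hands the rest to the parent. The splitting pattern in plain Python (no TF involved):

    # Illustrative pattern only: split a combined policy state dict into
    # optimizer values and the remaining (network) state.
    def split_policy_state(state):
        state = state.copy()  # shallow copy, as in the diff
        optimizer_vars = state.pop("_optimizer_variables", None)
        return optimizer_vars, state

    opt_vars, rest = split_policy_state(
        {"_optimizer_variables": [0.9, 0.999], "_state": {"w": [1, 2]}})
    assert opt_vars == [0.9, 0.999] and "_optimizer_variables" not in rest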
@@ -139,23 +139,11 @@ class TFPolicy(Policy):
         self._action_input = action_input  # For logp calculations.
         self._dist_inputs = dist_inputs
         self.dist_class = dist_class
-        self._log_likelihood = log_likelihood
+
         self._state_inputs = state_inputs or []
         self._state_outputs = state_outputs or []
         self._seq_lens = seq_lens
         self._max_seq_len = max_seq_len
-        self._batch_divisibility_req = batch_divisibility_req
-        self._update_ops = update_ops
-        self._stats_fetches = {}
-        self._loss_input_dict = None
-        self._timestep = timestep if timestep is not None else \
-            tf.placeholder(tf.int32, (), name="timestep")
-
-        if loss is not None:
-            self._initialize_loss(loss, loss_inputs)
-        else:
-            self._loss = None
-
         if len(self._state_inputs) != len(self._state_outputs):
             raise ValueError(
                 "Number of state input and output tensors must match, got: "
@@ -169,9 +157,34 @@ class TFPolicy(Policy):
             raise ValueError(
                 "seq_lens tensor must be given if state inputs are defined")

+        self._batch_divisibility_req = batch_divisibility_req
+        self._update_ops = update_ops
+        self._apply_op = None
+        self._stats_fetches = {}
+        self._timestep = timestep if timestep is not None else \
+            tf.placeholder(tf.int32, (), name="timestep")
+
+        self._optimizer = None
+        self._grads_and_vars = None
+        self._grads = None
+        # Policy tf-variables (weights), whose values to get/set via
+        # get_weights/set_weights.
+        self._variables = None
+        # Local optimizer's tf-variables (e.g. state vars for Adam).
+        # Will be stored alongside `self._variables` when checkpointing.
+        self._optimizer_variables = None
+
+        # The loss tf-op.
+        self._loss = None
+        # A batch dict passed into loss function as input.
+        self._loss_input_dict = None
+        if loss is not None:
+            self._initialize_loss(loss, loss_inputs)
+
         # The log-likelihood calculator op.
-        self._log_likelihood = None
-        if self._dist_inputs is not None and self.dist_class is not None:
+        self._log_likelihood = log_likelihood
+        if self._log_likelihood is None and self._dist_inputs is not None and \
+                self.dist_class is not None:
             self._log_likelihood = self.dist_class(
                 self._dist_inputs, self.model).logp(self._action_input)

@@ -250,6 +263,11 @@ class TFPolicy(Policy):
                     summarize(self._loss_input_dict)))

         self._sess.run(tf.global_variables_initializer())
+        self._optimizer_variables = None
+        if self._optimizer:
+            self._optimizer_variables = \
+                ray.experimental.tf_utils.TensorFlowVariables(
+                    self._optimizer.variables(), self._sess)

     @override(Policy)
     def compute_actions(self,
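Note (illustration, not part of the commit): the wrapper above captures the optimizer's slot variables so they can be fetched and re-assigned through the session. A rough TF1-style sketch of why that works (assumes tf.compat.v1 behavior; this is not the RLlib code):

    # Adam keeps per-weight slot variables ("m", "v") plus beta-power vars in
    # the graph, so they can be fetched like any other variable.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    w = tf.get_variable("w", initializer=[1.0, 2.0])
    loss = tf.reduce_sum(tf.square(w))
    opt = tf.train.AdamOptimizer(learning_rate=0.1)
    train_op = opt.minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        slot_values = sess.run(opt.variables())  # what get_state() would capture
        print(len(slot_values), "optimizer variables captured")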
@@ -355,6 +373,26 @@ class TFPolicy(Policy):
     def set_weights(self, weights):
         return self._variables.set_weights(weights)

+    @override(Policy)
+    def get_state(self):
+        # For tf Policies, return Policy weights and optimizer var values.
+        state = super().get_state()
+        if self._optimizer_variables and \
+                len(self._optimizer_variables.variables) > 0:
+            state["_optimizer_variables"] = \
+                self._sess.run(self._optimizer_variables.variables)
+        return state
+
+    @override(Policy)
+    def set_state(self, state):
+        state = state.copy()  # shallow copy
+        # Set optimizer vars first.
+        optimizer_vars = state.pop("_optimizer_variables", None)
+        if optimizer_vars:
+            self._optimizer_variables.set_weights(optimizer_vars)
+        # Then the Policy's (NN) weights.
+        super().set_state(state)
+
     @override(Policy)
     def export_model(self, export_dir):
         """Export tensorflow graph to export_dir for serving."""
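Note (illustration, not part of the commit): with the two methods above, a graph-mode policy's checkpoint now carries the Adam slot values next to the network weights. Hedged usage sketch (trainer/policy objects assumed, names not from the diff):

    # `alg1` and `alg2` are two identically configured RLlib Trainers (assumed).
    state = alg1.get_policy().get_state()
    # If a loss/optimizer has been built, the dict also holds the slot values:
    opt_vars = state.get("_optimizer_variables")

    alg2.restore(alg1.save())  # checkpoint round trip
    restored = alg2.get_policy().get_state().get("_optimizer_variables")
    # `restored` should now match `opt_vars`; that is what the new test checks.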
@@ -441,7 +479,7 @@ class TFPolicy(Policy):
     def optimizer(self):
         """TF optimizer to use for policy optimization."""
         if hasattr(self, "config"):
-            return tf.train.AdamOptimizer(self.config["lr"])
+            return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
         else:
             return tf.train.AdamOptimizer()

@@ -686,7 +724,7 @@ class LearningRateSchedule:

     @override(TFPolicy)
     def optimizer(self):
-        return tf.train.AdamOptimizer(self.cur_lr)
+        return tf.train.AdamOptimizer(learning_rate=self.cur_lr)


 @DeveloperAPI
@@ -323,6 +323,26 @@ class TorchPolicy(Policy):
         weights = convert_to_torch_tensor(weights, device=self.device)
         self.model.load_state_dict(weights)

+    @override(Policy)
+    def get_state(self):
+        state = super().get_state()
+        state["_optimizer_variables"] = []
+        for i, o in enumerate(self._optimizers):
+            state["_optimizer_variables"].append(o.state_dict())
+        return state
+
+    @override(Policy)
+    def set_state(self, state):
+        state = state.copy()  # shallow copy
+        # Set optimizer vars first.
+        optimizer_vars = state.pop("_optimizer_variables", None)
+        if optimizer_vars:
+            assert len(optimizer_vars) == len(self._optimizers)
+            for o, s in zip(self._optimizers, optimizer_vars):
+                o.load_state_dict(s)
+        # Then the Policy's (NN) weights.
+        super().set_state(state)
+
     @override(Policy)
     def is_recurrent(self):
         return len(self.model.get_initial_state()) > 0
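Note (illustration, not part of the commit): the torch methods above lean on torch.optim's own state_dict()/load_state_dict(). A standalone sketch of the same round trip outside RLlib:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)

    out = model(torch.randn(8, 4)).sum()
    out.backward()
    opt.step()  # populates Adam's exp_avg / exp_avg_sq buffers

    state = {
        "weights": model.state_dict(),
        "_optimizer_variables": [opt.state_dict()],  # one entry per optimizer
    }

    model2 = torch.nn.Linear(4, 2)
    opt2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
    model2.load_state_dict(state["weights"])
    for o, s in zip([opt2], state["_optimizer_variables"]):
        o.load_state_dict(s)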
@@ -5,7 +5,7 @@ import unittest

 import ray
 from ray.rllib.agents.registry import get_agent_class
-from ray.rllib.utils.test_utils import framework_iterator
+from ray.rllib.utils.test_utils import check, framework_iterator


 def get_mean_action(alg, obs):
@@ -63,45 +63,67 @@ CONFIGS = {
 }


-def ckpt_restore_test(use_object_store, alg_name, failures, framework="tf"):
-    cls = get_agent_class(alg_name)
+def ckpt_restore_test(alg_name, tfe=False):
     config = CONFIGS[alg_name]
-    config["framework"] = framework
-    if "DDPG" in alg_name or "SAC" in alg_name:
-        alg1 = cls(config=config, env="Pendulum-v0")
-        alg2 = cls(config=config, env="Pendulum-v0")
-    else:
-        alg1 = cls(config=config, env="CartPole-v0")
-        alg2 = cls(config=config, env="CartPole-v0")
+    frameworks = (["tfe"] if tfe else []) + ["torch", "tf"]
+    for fw in framework_iterator(config, frameworks=frameworks):
+        for use_object_store in [False, True]:
+            print("use_object_store={}".format(use_object_store))
+            cls = get_agent_class(alg_name)
+            if "DDPG" in alg_name or "SAC" in alg_name:
+                alg1 = cls(config=config, env="Pendulum-v0")
+                alg2 = cls(config=config, env="Pendulum-v0")
+            else:
+                alg1 = cls(config=config, env="CartPole-v0")
+                alg2 = cls(config=config, env="CartPole-v0")

-    policy1 = alg1.get_policy()
+            policy1 = alg1.get_policy()

-    for _ in range(1):
-        res = alg1.train()
-        print("current status: " + str(res))
+            for _ in range(1):
+                res = alg1.train()
+                print("current status: " + str(res))

-    # Sync the models
-    if use_object_store:
-        alg2.restore_from_object(alg1.save_to_object())
-    else:
-        alg2.restore(alg1.save())
+            # Check optimizer state as well.
+            optim_state = policy1.get_state().get("_optimizer_variables")

-    for _ in range(1):
-        if "DDPG" in alg_name or "SAC" in alg_name:
-            obs = np.clip(
-                np.random.uniform(size=3),
-                policy1.observation_space.low,
-                policy1.observation_space.high)
-        else:
-            obs = np.clip(
-                np.random.uniform(size=4),
-                policy1.observation_space.low,
-                policy1.observation_space.high)
-        a1 = get_mean_action(alg1, obs)
-        a2 = get_mean_action(alg2, obs)
-        print("Checking computed actions", alg1, obs, a1, a2)
-        if abs(a1 - a2) > .1:
-            failures.append((alg_name, [a1, a2]))
+            # Sync the models
+            if use_object_store:
+                alg2.restore_from_object(alg1.save_to_object())
+            else:
+                alg2.restore(alg1.save())
+
+            # Compare optimizer state with re-loaded one.
+            if optim_state:
+                s2 = alg2.get_policy().get_state().get("_optimizer_variables")
+                # Tf -> Compare states 1:1.
+                if fw in ["tf", "tfe"]:
+                    check(s2, optim_state)
+                # For torch, optimizers have state_dicts with keys=params,
+                # which are different for the two models (ignore these
+                # different keys, but compare all values nevertheless).
+                else:
+                    for i, s2_ in enumerate(s2):
+                        check(
+                            list(s2_["state"].values()),
+                            list(optim_state[i]["state"].values()))
+
+            for _ in range(1):
+                if "DDPG" in alg_name or "SAC" in alg_name:
+                    obs = np.clip(
+                        np.random.uniform(size=3),
+                        policy1.observation_space.low,
+                        policy1.observation_space.high)
+                else:
+                    obs = np.clip(
+                        np.random.uniform(size=4),
+                        policy1.observation_space.low,
+                        policy1.observation_space.high)
+                a1 = get_mean_action(alg1, obs)
+                a2 = get_mean_action(alg2, obs)
+                print("Checking computed actions", alg1, obs, a1, a2)
+                if abs(a1 - a2) > .1:
+                    raise AssertionError("algo={} [a1={} a2={}]".format(
+                        alg_name, a1, a2))


 class TestCheckpointRestore(unittest.TestCase):
@@ -113,21 +135,29 @@ class TestCheckpointRestore(unittest.TestCase):
     def tearDownClass(cls):
         ray.shutdown()

-    def test_checkpoint_restore(self):
-        failures = []
-        for fw in framework_iterator(frameworks=("tf", "torch")):
-            for use_object_store in [False, True]:
-                for name in [
-                        "A3C", "APEX_DDPG", "ARS", "DDPG", "DQN", "ES", "PPO",
-                        "SAC"
-                ]:
-                    print("Testing algo={} (use_object_store={})".format(
-                        name, use_object_store))
-                    ckpt_restore_test(
-                        use_object_store, name, failures, framework=fw)
+    def test_a3c_checkpoint_restore(self):
+        ckpt_restore_test("A3C")

-        assert not failures, failures
-        print("All checkpoint restore tests passed!")
+    def test_apex_ddpg_checkpoint_restore(self):
+        ckpt_restore_test("APEX_DDPG")
+
+    def test_ars_checkpoint_restore(self):
+        ckpt_restore_test("ARS")
+
+    def test_ddpg_checkpoint_restore(self):
+        ckpt_restore_test("DDPG")
+
+    def test_dqn_checkpoint_restore(self):
+        ckpt_restore_test("DQN")
+
+    def test_es_checkpoint_restore(self):
+        ckpt_restore_test("ES")
+
+    def test_ppo_checkpoint_restore(self):
+        ckpt_restore_test("PPO")
+
+    def test_sac_checkpoint_restore(self):
+        ckpt_restore_test("SAC")


 if __name__ == "__main__":
@@ -14,9 +14,10 @@ class LocalModeTest(unittest.TestCase):

     def test_local(self):
         cf = DEFAULT_CONFIG.copy()
-        for fw in framework_iterator(cf):
+        for _ in framework_iterator(cf):
             agent = PPOTrainer(cf, "CartPole-v0")
             print(agent.train())
+            agent.stop()


 if __name__ == "__main__":
@@ -58,9 +58,9 @@ def minimize_and_clip(optimizer, clip_val=10):
                 torch.nn.utils.clip_grad_norm_(p.grad, clip_val)


-def sequence_mask(lengths, maxlen, dtype=None):
-    """
-    Exact same behavior as tf.sequence_mask.
+def sequence_mask(lengths, maxlen=None, dtype=None):
+    """Offers same behavior as tf.sequence_mask for torch.
+
     Thanks to Dimitris Papatheodorou
     (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
     39036).
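Note (illustration, not part of the commit): one possible body for the signature shown above — a torch stand-in for tf.sequence_mask that handles 1-D lengths and an optional maxlen; RLlib's actual implementation may differ:

    import torch

    def sequence_mask(lengths, maxlen=None, dtype=None):
        if maxlen is None:
            maxlen = int(lengths.max())
        row = torch.arange(maxlen, device=lengths.device)
        mask = row[None, :] < lengths[:, None]  # [batch, maxlen] booleans
        return mask if dtype is None else mask.type(dtype)

    print(sequence_mask(torch.tensor([1, 3, 2]), maxlen=4))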