[RLlib] Issue 8412 (Adam vars not stored in ModelV2). (#8480)

Sven Mika 2020-06-05 21:07:02 +02:00 committed by GitHub
parent e62c1d2051
commit 25c0974543
9 changed files with 203 additions and 77 deletions

View file

@@ -1,7 +1,8 @@
from collections import deque, OrderedDict
import numpy as np
from ray.rllib.utils import try_import_tf
from ray.rllib.utils import force_list
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
@@ -47,8 +48,7 @@ class TensorFlowVariables:
list.
"""
self.sess = sess
if not isinstance(output, (list, tuple)):
output = [output]
output = force_list(output)
queue = deque(output)
variable_names = []
explored_inputs = set(output)
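Note: the only behavioral change in this hunk is that the manual isinstance check is replaced by RLlib's force_list helper, which normalizes a single output tensor into a one-element list. Roughly, and only as a sketch (the real utility lives in ray/rllib/utils and may cover more cases):

def force_list(elements=None):
    # Normalize `elements` into a list: None -> [], list/tuple -> list,
    # anything else -> one-element list.
    if elements is None:
        return []
    return list(elements) if isinstance(elements, (list, tuple)) else [elements]

assert force_list("output_op") == ["output_op"]
assert force_list(["a", "b"]) == ["a", "b"]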

View file

@@ -64,6 +64,12 @@ class ARSTFPolicy:
action += np.random.randn(*action.shape) * self.action_noise_std
return action
def get_state(self):
return {"state": self.get_flat_weights()}
def set_state(self, state):
return self.set_flat_weights(state["state"])
def set_flat_weights(self, x):
self.variables.set_flat(x)

View file

@@ -120,6 +120,12 @@ class ESTFPolicy:
self.action_noise_std
return single_action
def get_state(self):
return {"state": self.get_flat_weights()}
def set_state(self, state):
return self.set_flat_weights(state["state"])
def set_flat_weights(self, x):
self.variables.set_flat(x)
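Note: both the ES policy in this hunk and the ARS policy in the previous one gain get_state/set_state wrappers around their existing flat-weight accessors, so checkpointing them becomes a one-key dict roundtrip. A usage sketch, assuming `policy` is an instance of either class:

# Save: a dict holding the flat (1-D) weight vector.
state = policy.get_state()   # {"state": <np.ndarray of flat weights>}

# Restore: equivalent to policy.set_flat_weights(state["state"]).
policy.set_state(state)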

View file

@@ -332,10 +332,12 @@ def build_eager_tf_policy(name,
"is_training": tf.constant(False),
}
if obs_include_prev_action_reward:
input_dict[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
input_dict[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
if prev_action_batch is not None:
input_dict[SampleBatch.PREV_ACTIONS] = \
tf.convert_to_tensor(prev_action_batch)
if prev_reward_batch is not None:
input_dict[SampleBatch.PREV_REWARDS] = \
tf.convert_to_tensor(prev_reward_batch)
# Use Exploration object.
with tf.variable_creator_scope(_disallow_var_creation):
@@ -464,6 +466,29 @@ def build_eager_tf_policy(name,
for v, w in zip(variables, weights):
v.assign(w)
@override(Policy)
def get_state(self):
state = {"_state": super().get_state()}
state["_optimizer_variables"] = self._optimizer.variables()
return state
@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
if optimizer_vars and self._optimizer.variables():
logger.warning(
"Cannot restore an optimizer's state for tf eager! Keras "
"is not able to save the v1.x optimizers (from "
"tf.compat.v1.train) since they aren't compatible with "
"checkpoints.")
for opt_var, value in zip(self._optimizer.variables(),
optimizer_vars):
opt_var.assign(value)
# Then the Policy's (NN) weights.
super().set_state(state["_state"])
def variables(self):
"""Return the list of all savable variables for this policy."""
return self.model.variables()
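Note: this is the eager-mode half of the fix for issue 8412. get_state now bundles the optimizer's slot variables (e.g. Adam's moment accumulators) with the network weights, and set_state assigns them back when the optimizer has already created its variables. A minimal sketch of the same pattern with a plain Keras model and optimizer (assuming the TF-2.x-era Keras API, where optimizer.variables() is a method; this is not RLlib code):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
opt = tf.keras.optimizers.Adam(1e-3)

# One update so the optimizer actually creates its slot variables.
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(tf.zeros([1, 8])) ** 2)
grads = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grads, model.trainable_variables))

# "get_state": network weights plus the optimizer's variable values.
state = {
    "_state": [v.numpy() for v in model.variables],
    "_optimizer_variables": [v.numpy() for v in opt.variables()],
}

# "set_state": restore the optimizer variables first, then the weights.
for var, val in zip(opt.variables(), state["_optimizer_variables"]):
    var.assign(val)
for var, val in zip(model.variables, state["_state"]):
    var.assign(val)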

View file

@@ -139,23 +139,11 @@ class TFPolicy(Policy):
self._action_input = action_input # For logp calculations.
self._dist_inputs = dist_inputs
self.dist_class = dist_class
self._log_likelihood = log_likelihood
self._state_inputs = state_inputs or []
self._state_outputs = state_outputs or []
self._seq_lens = seq_lens
self._max_seq_len = max_seq_len
self._batch_divisibility_req = batch_divisibility_req
self._update_ops = update_ops
self._stats_fetches = {}
self._loss_input_dict = None
self._timestep = timestep if timestep is not None else \
tf.placeholder(tf.int32, (), name="timestep")
if loss is not None:
self._initialize_loss(loss, loss_inputs)
else:
self._loss = None
if len(self._state_inputs) != len(self._state_outputs):
raise ValueError(
"Number of state input and output tensors must match, got: "
@@ -169,9 +157,34 @@ class TFPolicy(Policy):
raise ValueError(
"seq_lens tensor must be given if state inputs are defined")
self._batch_divisibility_req = batch_divisibility_req
self._update_ops = update_ops
self._apply_op = None
self._stats_fetches = {}
self._timestep = timestep if timestep is not None else \
tf.placeholder(tf.int32, (), name="timestep")
self._optimizer = None
self._grads_and_vars = None
self._grads = None
# Policy tf-variables (weights), whose values to get/set via
# get_weights/set_weights.
self._variables = None
# Local optimizer's tf-variables (e.g. state vars for Adam).
# Will be stored alongside `self._variables` when checkpointing.
self._optimizer_variables = None
# The loss tf-op.
self._loss = None
# A batch dict passed into loss function as input.
self._loss_input_dict = None
if loss is not None:
self._initialize_loss(loss, loss_inputs)
# The log-likelihood calculator op.
self._log_likelihood = None
if self._dist_inputs is not None and self.dist_class is not None:
self._log_likelihood = log_likelihood
if self._log_likelihood is None and self._dist_inputs is not None and \
self.dist_class is not None:
self._log_likelihood = self.dist_class(
self._dist_inputs, self.model).logp(self._action_input)
@@ -250,6 +263,11 @@ class TFPolicy(Policy):
summarize(self._loss_input_dict)))
self._sess.run(tf.global_variables_initializer())
self._optimizer_variables = None
if self._optimizer:
self._optimizer_variables = \
ray.experimental.tf_utils.TensorFlowVariables(
self._optimizer.variables(), self._sess)
@override(Policy)
def compute_actions(self,
@@ -355,6 +373,26 @@ class TFPolicy(Policy):
def set_weights(self, weights):
return self._variables.set_weights(weights)
@override(Policy)
def get_state(self):
# For tf Policies, return Policy weights and optimizer var values.
state = super().get_state()
if self._optimizer_variables and \
len(self._optimizer_variables.variables) > 0:
state["_optimizer_variables"] = \
self._sess.run(self._optimizer_variables.variables)
return state
@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
if optimizer_vars:
self._optimizer_variables.set_weights(optimizer_vars)
# Then the Policy's (NN) weights.
super().set_state(state)
@override(Policy)
def export_model(self, export_dir):
"""Export tensorflow graph to export_dir for serving."""
@@ -441,7 +479,7 @@ class TFPolicy(Policy):
def optimizer(self):
"""TF optimizer to use for policy optimization."""
if hasattr(self, "config"):
return tf.train.AdamOptimizer(self.config["lr"])
return tf.train.AdamOptimizer(learning_rate=self.config["lr"])
else:
return tf.train.AdamOptimizer()
@@ -686,7 +724,7 @@ class LearningRateSchedule:
@override(TFPolicy)
def optimizer(self):
return tf.train.AdamOptimizer(self.cur_lr)
return tf.train.AdamOptimizer(learning_rate=self.cur_lr)
@DeveloperAPI
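Note: for the graph-mode TFPolicy, the optimizer's variables are wrapped in a TensorFlowVariables helper right after global initialization; get_state then fetches their current values with sess.run, and set_state writes them back before the network weights are restored. A standalone tf1-style sketch of why these extra Adam variables exist and what restoring them amounts to (assuming tf.compat.v1; not the RLlib code path itself):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

w = tf.Variable(tf.zeros([2]), name="w")
loss = tf.reduce_sum(tf.square(w - 1.0))
opt = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = opt.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    # Adam keeps per-variable slots ("m", "v") plus beta-power accumulators;
    # these are the variables TFPolicy now exposes for checkpointing.
    slot_values = sess.run(opt.variables())
    # Restoring them (what set_state does via TensorFlowVariables) is just
    # assigning the saved values back.
    sess.run([v.assign(val) for v, val in zip(opt.variables(), slot_values)])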

View file

@@ -323,6 +323,26 @@ class TorchPolicy(Policy):
weights = convert_to_torch_tensor(weights, device=self.device)
self.model.load_state_dict(weights)
@override(Policy)
def get_state(self):
state = super().get_state()
state["_optimizer_variables"] = []
for i, o in enumerate(self._optimizers):
state["_optimizer_variables"].append(o.state_dict())
return state
@override(Policy)
def set_state(self, state):
state = state.copy() # shallow copy
# Set optimizer vars first.
optimizer_vars = state.pop("_optimizer_variables", None)
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
o.load_state_dict(s)
# Then the Policy's (NN) weights.
super().set_state(state)
@override(Policy)
def is_recurrent(self):
return len(self.model.get_initial_state()) > 0
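Note: on the torch side the optimizer state is simply a list of state_dicts, one per optimizer, which is why set_state asserts equal lengths and zips them back. The same save/restore pattern with a bare torch.optim.Adam, as a sketch:

import torch

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Take one step so Adam's exp_avg / exp_avg_sq buffers exist.
model(torch.zeros(1, 4)).sum().backward()
optim.step()

# "get_state": model weights plus one state_dict per optimizer.
state = {
    "weights": model.state_dict(),
    "_optimizer_variables": [optim.state_dict()],
}

# "set_state": restore the optimizers first, then the network weights.
for o, s in zip([optim], state["_optimizer_variables"]):
    o.load_state_dict(s)
model.load_state_dict(state["weights"])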

View file

@@ -5,7 +5,7 @@ import unittest
import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.utils.test_utils import framework_iterator
from ray.rllib.utils.test_utils import check, framework_iterator
def get_mean_action(alg, obs):
@@ -63,45 +63,67 @@ CONFIGS = {
}
def ckpt_restore_test(use_object_store, alg_name, failures, framework="tf"):
cls = get_agent_class(alg_name)
def ckpt_restore_test(alg_name, tfe=False):
config = CONFIGS[alg_name]
config["framework"] = framework
if "DDPG" in alg_name or "SAC" in alg_name:
alg1 = cls(config=config, env="Pendulum-v0")
alg2 = cls(config=config, env="Pendulum-v0")
else:
alg1 = cls(config=config, env="CartPole-v0")
alg2 = cls(config=config, env="CartPole-v0")
frameworks = (["tfe"] if tfe else []) + ["torch", "tf"]
for fw in framework_iterator(config, frameworks=frameworks):
for use_object_store in [False, True]:
print("use_object_store={}".format(use_object_store))
cls = get_agent_class(alg_name)
if "DDPG" in alg_name or "SAC" in alg_name:
alg1 = cls(config=config, env="Pendulum-v0")
alg2 = cls(config=config, env="Pendulum-v0")
else:
alg1 = cls(config=config, env="CartPole-v0")
alg2 = cls(config=config, env="CartPole-v0")
policy1 = alg1.get_policy()
policy1 = alg1.get_policy()
for _ in range(1):
res = alg1.train()
print("current status: " + str(res))
for _ in range(1):
res = alg1.train()
print("current status: " + str(res))
# Sync the models
if use_object_store:
alg2.restore_from_object(alg1.save_to_object())
else:
alg2.restore(alg1.save())
# Check optimizer state as well.
optim_state = policy1.get_state().get("_optimizer_variables")
for _ in range(1):
if "DDPG" in alg_name or "SAC" in alg_name:
obs = np.clip(
np.random.uniform(size=3),
policy1.observation_space.low,
policy1.observation_space.high)
else:
obs = np.clip(
np.random.uniform(size=4),
policy1.observation_space.low,
policy1.observation_space.high)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
if abs(a1 - a2) > .1:
failures.append((alg_name, [a1, a2]))
# Sync the models
if use_object_store:
alg2.restore_from_object(alg1.save_to_object())
else:
alg2.restore(alg1.save())
# Compare optimizer state with re-loaded one.
if optim_state:
s2 = alg2.get_policy().get_state().get("_optimizer_variables")
# Tf -> Compare states 1:1.
if fw in ["tf", "tfe"]:
check(s2, optim_state)
# For torch, optimizers have state_dicts with keys=params,
# which are different for the two models (ignore these
# different keys, but compare all values nevertheless).
else:
for i, s2_ in enumerate(s2):
check(
list(s2_["state"].values()),
list(optim_state[i]["state"].values()))
for _ in range(1):
if "DDPG" in alg_name or "SAC" in alg_name:
obs = np.clip(
np.random.uniform(size=3),
policy1.observation_space.low,
policy1.observation_space.high)
else:
obs = np.clip(
np.random.uniform(size=4),
policy1.observation_space.low,
policy1.observation_space.high)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
if abs(a1 - a2) > .1:
raise AssertionError("algo={} [a1={} a2={}]".format(
alg_name, a1, a2))
class TestCheckpointRestore(unittest.TestCase):
@@ -113,21 +135,29 @@ class TestCheckpointRestore(unittest.TestCase):
def tearDownClass(cls):
ray.shutdown()
def test_checkpoint_restore(self):
failures = []
for fw in framework_iterator(frameworks=("tf", "torch")):
for use_object_store in [False, True]:
for name in [
"A3C", "APEX_DDPG", "ARS", "DDPG", "DQN", "ES", "PPO",
"SAC"
]:
print("Testing algo={} (use_object_store={})".format(
name, use_object_store))
ckpt_restore_test(
use_object_store, name, failures, framework=fw)
def test_a3c_checkpoint_restore(self):
ckpt_restore_test("A3C")
assert not failures, failures
print("All checkpoint restore tests passed!")
def test_apex_ddpg_checkpoint_restore(self):
ckpt_restore_test("APEX_DDPG")
def test_ars_checkpoint_restore(self):
ckpt_restore_test("ARS")
def test_ddpg_checkpoint_restore(self):
ckpt_restore_test("DDPG")
def test_dqn_checkpoint_restore(self):
ckpt_restore_test("DQN")
def test_es_checkpoint_restore(self):
ckpt_restore_test("ES")
def test_ppo_checkpoint_restore(self):
ckpt_restore_test("PPO")
def test_sac_checkpoint_restore(self):
ckpt_restore_test("SAC")
if __name__ == "__main__":
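Note: in the torch branch above, optimizer states are compared value-by-value because torch state_dicts key per-parameter state by parameter id, and those ids differ between the two independently built trainers. A hedged illustration of that comparison, using numpy.allclose in place of RLlib's check() helper:

import numpy as np
import torch

def compare_optimizer_states(sd1, sd2, atol=1e-6):
    # Ignore the per-parameter ids used as keys; compare the values in order.
    for v1, v2 in zip(sd1["state"].values(), sd2["state"].values()):
        for key in v1:
            a, b = v1[key], v2[key]
            if torch.is_tensor(a):
                a, b = a.numpy(), b.numpy()
            assert np.allclose(a, b, atol=atol), key

# e.g. compare_optimizer_states(optim1.state_dict(), optim2.state_dict())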

View file

@@ -14,9 +14,10 @@ class LocalModeTest(unittest.TestCase):
def test_local(self):
cf = DEFAULT_CONFIG.copy()
for fw in framework_iterator(cf):
for _ in framework_iterator(cf):
agent = PPOTrainer(cf, "CartPole-v0")
print(agent.train())
agent.stop()
if __name__ == "__main__":

View file

@@ -58,9 +58,9 @@ def minimize_and_clip(optimizer, clip_val=10):
torch.nn.utils.clip_grad_norm_(p.grad, clip_val)
def sequence_mask(lengths, maxlen, dtype=None):
"""
Exact same behavior as tf.sequence_mask.
def sequence_mask(lengths, maxlen=None, dtype=None):
"""Offers same behavior as tf.sequence_mask for torch.
Thanks to Dimitris Papatheodorou
(https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
39036).
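Note: tf.sequence_mask turns a vector of lengths into a [batch, maxlen] boolean mask, with maxlen defaulting to the largest length; that is the behavior this torch helper mirrors. A quick illustration of the expected output (not the RLlib implementation itself):

import torch

lengths = torch.tensor([1, 3, 2])
maxlen = int(lengths.max())  # tf.sequence_mask defaults maxlen to max(lengths)
mask = torch.arange(maxlen)[None, :] < lengths[:, None]
# tensor([[ True, False, False],
#         [ True,  True,  True],
#         [ True,  True, False]])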