diff --git a/rllib/BUILD b/rllib/BUILD index 894f48720..2f142d5a1 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1406,7 +1406,7 @@ py_test( name = "examples/nested_action_spaces_ppo", main = "examples/nested_action_spaces.py", tags = ["examples", "examples_N"], - size = "small", + size = "medium", srcs = ["examples/nested_action_spaces.py"], args = ["--stop=-500", "--run=PPO"] ) diff --git a/rllib/agents/ars/ars_tf_policy.py b/rllib/agents/ars/ars_tf_policy.py index 3f8292cdf..b3f5dfa1f 100644 --- a/rllib/agents/ars/ars_tf_policy.py +++ b/rllib/agents/ars/ars_tf_policy.py @@ -36,15 +36,18 @@ class ARSTFPolicy: dist_class, dist_dim = ModelCatalog.get_action_dist( self.action_space, config["model"], dist_type="deterministic") - model = ModelCatalog.get_model({ - SampleBatch.CUR_OBS: self.inputs - }, self.observation_space, self.action_space, dist_dim, - config["model"]) - dist = dist_class(model.outputs, model) + self.model = ModelCatalog.get_model_v2( + obs_space=self.preprocessor.observation_space, + action_space=self.action_space, + num_outputs=dist_dim, + model_config=config["model"]) + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) + dist = dist_class(dist_inputs, self.model) + self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( - model.outputs, self.sess) + dist_inputs, self.sess) self.num_params = sum( np.prod(variable.shape.as_list()) diff --git a/rllib/agents/ddpg/OBSOLETED_ddpg_policy.py b/rllib/agents/ddpg/OBSOLETED_ddpg_policy.py deleted file mode 100644 index 71e0c3ca4..000000000 --- a/rllib/agents/ddpg/OBSOLETED_ddpg_policy.py +++ /dev/null @@ -1,403 +0,0 @@ -from gym.spaces import Box -import logging -import numpy as np - -import ray -import ray.experimental.tf_utils -from ray.rllib.agents.ddpg.ddpg_model import DDPGModel -from ray.rllib.agents.ddpg.noop_model import NoopModel -from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \ - PRIO_WEIGHTS -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models import ModelCatalog -from ray.rllib.models.tf.tf_action_dist import Deterministic -from ray.rllib.utils.annotations import override -from ray.rllib.utils.error import UnsupportedSpaceException -from ray.rllib.policy.tf_policy import TFPolicy -from ray.rllib.policy.tf_policy_template import build_tf_policy -from ray.rllib.utils import try_import_tf -from ray.rllib.utils.tf_ops import huber_loss, minimize_and_clip, \ - make_tf_callable - -tf = try_import_tf() - -logger = logging.getLogger(__name__) - -ACTION_SCOPE = "action" -POLICY_SCOPE = "policy" -POLICY_TARGET_SCOPE = "target_policy" -Q_SCOPE = "critic" -Q_TARGET_SCOPE = "target_critic" -TWIN_Q_SCOPE = "twin_critic" -TWIN_Q_TARGET_SCOPE = "twin_target_critic" - - -def build_ddpg_models(policy, observation_space, action_space, config): - if config["model"]["custom_model"]: - logger.warning( - "Setting use_state_preprocessor=True since a custom model " - "was specified.") - config["use_state_preprocessor"] = True - - if not isinstance(action_space, Box): - raise UnsupportedSpaceException( - "Action space {} is not supported for DDPG.".format(action_space)) - elif len(action_space.shape) > 1: - raise UnsupportedSpaceException( - "Action space has multiple dimensions " - "{}. 
".format(action_space.shape) + - "Consider reshaping this into a single dimension, " - "using a Tuple action space, or the multi-agent API.") - - if policy.config["use_state_preprocessor"]: - default_model = None # catalog decides - num_outputs = 256 # arbitrary - config["model"]["no_final_linear"] = True - else: - default_model = NoopModel - num_outputs = int(np.product(observation_space.shape)) - - policy.model = ModelCatalog.get_model_v2( - obs_space=observation_space, - action_space=action_space, - num_outputs=num_outputs, - model_config=config["model"], - framework="tf", - model_interface=DDPGModel, - default_model=default_model, - name="ddpg_model", - actor_hidden_activation=config["actor_hidden_activation"], - actor_hiddens=config["actor_hiddens"], - critic_hidden_activation=config["critic_hidden_activation"], - critic_hiddens=config["critic_hiddens"], - twin_q=config["twin_q"], - add_layer_norm=(policy.config["exploration_config"].get("type") == - "ParameterNoise"), - ) - - policy.target_model = ModelCatalog.get_model_v2( - obs_space=observation_space, - action_space=action_space, - num_outputs=num_outputs, - model_config=config["model"], - framework="tf", - model_interface=DDPGModel, - default_model=default_model, - name="target_ddpg_model", - actor_hidden_activation=config["actor_hidden_activation"], - actor_hiddens=config["actor_hiddens"], - critic_hidden_activation=config["critic_hidden_activation"], - critic_hiddens=config["critic_hiddens"], - twin_q=config["twin_q"], - add_layer_norm=(policy.config["exploration_config"].get("type") == - "ParameterNoise"), - ) - - return policy.model - - -def get_distribution_inputs_and_class(policy, - model, - obs_batch, - *, - explore=True, - **kwargs): - model_out, _ = model({ - "obs": obs_batch, - "is_training": policy._get_is_training_placeholder() - }, [], None) - dist_inputs = model.get_policy_output(model_out) - - return dist_inputs, Deterministic, [] # []=state out - - -def ddpg_actor_critic_loss(policy, model, _, train_batch): - twin_q = policy.config["twin_q"] - gamma = policy.config["gamma"] - n_step = policy.config["n_step"] - use_huber = policy.config["use_huber"] - huber_threshold = policy.config["huber_threshold"] - l2_reg = policy.config["l2_reg"] - - input_dict = { - "obs": train_batch[SampleBatch.CUR_OBS], - "is_training": policy._get_is_training_placeholder(), - } - input_dict_next = { - "obs": train_batch[SampleBatch.NEXT_OBS], - "is_training": policy._get_is_training_placeholder(), - } - - model_out_t, _ = model(input_dict, [], None) - model_out_tp1, _ = model(input_dict_next, [], None) - target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) - - # Policy network evaluation. - with tf.variable_scope(POLICY_SCOPE, reuse=True): - # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - policy_t = model.get_policy_output(model_out_t) - # policy_batchnorm_update_ops = list( - # set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) - - with tf.variable_scope(POLICY_TARGET_SCOPE): - policy_tp1 = \ - policy.target_model.get_policy_output(target_model_out_tp1) - - # Action outputs. 
- with tf.variable_scope(ACTION_SCOPE, reuse=True): - if policy.config["smooth_target_policy"]: - target_noise_clip = policy.config["target_noise_clip"] - clipped_normal_sample = tf.clip_by_value( - tf.random_normal( - tf.shape(policy_tp1), - stddev=policy.config["target_noise"]), -target_noise_clip, - target_noise_clip) - policy_tp1_smoothed = tf.clip_by_value( - policy_tp1 + clipped_normal_sample, - policy.action_space.low * tf.ones_like(policy_tp1), - policy.action_space.high * tf.ones_like(policy_tp1)) - else: - # No smoothing, just use deterministic actions. - policy_tp1_smoothed = policy_tp1 - - # Q-net(s) evaluation. - # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - with tf.variable_scope(Q_SCOPE): - # Q-values for given actions & observations in given current - q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) - - with tf.variable_scope(Q_SCOPE, reuse=True): - # Q-values for current policy (no noise) in given current state - q_t_det_policy = model.get_q_values(model_out_t, policy_t) - - if twin_q: - with tf.variable_scope(TWIN_Q_SCOPE): - twin_q_t = model.get_twin_q_values( - model_out_t, train_batch[SampleBatch.ACTIONS]) - # q_batchnorm_update_ops = list( - # set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) - - # Target q-net(s) evaluation. - with tf.variable_scope(Q_TARGET_SCOPE): - q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, - policy_tp1_smoothed) - - if twin_q: - with tf.variable_scope(TWIN_Q_TARGET_SCOPE): - twin_q_tp1 = policy.target_model.get_twin_q_values( - target_model_out_tp1, policy_tp1_smoothed) - - q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) - if twin_q: - twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1) - q_tp1 = tf.minimum(q_tp1, twin_q_tp1) - - q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) - q_tp1_best_masked = \ - (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \ - q_tp1_best - - # Compute RHS of bellman equation. - q_t_selected_target = tf.stop_gradient(train_batch[SampleBatch.REWARDS] + - gamma**n_step * q_tp1_best_masked) - - # Compute the error (potentially clipped). - if twin_q: - td_error = q_t_selected - q_t_selected_target - twin_td_error = twin_q_t_selected - q_t_selected_target - td_error = td_error + twin_td_error - if use_huber: - errors = huber_loss(td_error, huber_threshold) \ - + huber_loss(twin_td_error, huber_threshold) - else: - errors = 0.5 * tf.square(td_error) + 0.5 * tf.square(twin_td_error) - else: - td_error = q_t_selected - q_t_selected_target - if use_huber: - errors = huber_loss(td_error, huber_threshold) - else: - errors = 0.5 * tf.square(td_error) - - critic_loss = tf.reduce_mean(train_batch[PRIO_WEIGHTS] * errors) - actor_loss = -tf.reduce_mean(q_t_det_policy) - - # Add l2-regularization if required. - if l2_reg is not None: - for var in policy.model.policy_variables(): - if "bias" not in var.name: - actor_loss += (l2_reg * tf.nn.l2_loss(var)) - for var in policy.model.q_variables(): - if "bias" not in var.name: - critic_loss += (l2_reg * tf.nn.l2_loss(var)) - - # Model self-supervised losses. - if policy.config["use_state_preprocessor"]: - # Expand input_dict in case custom_loss' need them. 
- input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS] - input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] - input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] - input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] - actor_loss, critic_loss = model.custom_loss([actor_loss, critic_loss], - input_dict) - - # Store values for stats function. - policy.actor_loss = actor_loss - policy.critic_loss = critic_loss - policy.td_error = td_error - policy.q_t = q_t - - # Return one loss value (even though we treat them separately in our - # 2 optimizers: actor and critic). - return policy.critic_loss + policy.actor_loss - - -def make_ddpg_optimizers(policy, config): - # Create separate optimizers for actor & critic losses. - policy._actor_optimizer = tf.train.AdamOptimizer( - learning_rate=config["actor_lr"]) - policy._critic_optimizer = tf.train.AdamOptimizer( - learning_rate=config["critic_lr"]) - return None - - -def build_apply_op(policy, optimizer, grads_and_vars): - # For policy gradient, update policy net one time v.s. - # update critic net `policy_delay` time(s). - should_apply_actor_opt = tf.equal( - tf.mod(policy.global_step, policy.config["policy_delay"]), 0) - - def make_apply_op(): - return policy._actor_optimizer.apply_gradients( - policy._actor_grads_and_vars) - - actor_op = tf.cond( - should_apply_actor_opt, - true_fn=make_apply_op, - false_fn=lambda: tf.no_op()) - critic_op = policy._critic_optimizer.apply_gradients( - policy._critic_grads_and_vars) - # Increment global step & apply ops. - with tf.control_dependencies([tf.assign_add(policy.global_step, 1)]): - return tf.group(actor_op, critic_op) - - -def gradients_fn(policy, optimizer, loss): - if policy.config["grad_norm_clipping"] is not None: - actor_grads_and_vars = minimize_and_clip( - policy._actor_optimizer, - policy.actor_loss, - var_list=policy.model.policy_variables(), - clip_val=policy.config["grad_norm_clipping"]) - critic_grads_and_vars = minimize_and_clip( - policy._critic_optimizer, - policy.critic_loss, - var_list=policy.model.q_variables(), - clip_val=policy.config["grad_norm_clipping"]) - else: - actor_grads_and_vars = policy._actor_optimizer.compute_gradients( - policy.actor_loss, var_list=policy.model.policy_variables()) - critic_grads_and_vars = policy._critic_optimizer.compute_gradients( - policy.critic_loss, var_list=policy.model.q_variables()) - # Save these for later use in build_apply_op. - policy._actor_grads_and_vars = [(g, v) for (g, v) in actor_grads_and_vars - if g is not None] - policy._critic_grads_and_vars = [(g, v) for (g, v) in critic_grads_and_vars - if g is not None] - grads_and_vars = policy._actor_grads_and_vars + \ - policy._critic_grads_and_vars - return grads_and_vars - - -def build_ddpg_stats(policy, batch): - stats = { - "mean_q": tf.reduce_mean(policy.q_t), - "max_q": tf.reduce_max(policy.q_t), - "min_q": tf.reduce_min(policy.q_t), - } - return stats - - -def before_init_fn(policy, obs_space, action_space, config): - # Create global step for counting the number of update operations. - policy.global_step = tf.train.get_or_create_global_step() - - -class ComputeTDErrorMixin: - def __init__(self, loss_fn): - @make_tf_callable(self.get_session(), dynamic_shape=True) - def compute_td_error(obs_t, act_t, rew_t, obs_tp1, done_mask, - importance_weights): - # Do forward pass on loss to update td errors attribute - # (one TD-error value per item in batch to update PR weights). 
- loss_fn( - self, self.model, None, { - SampleBatch.CUR_OBS: tf.convert_to_tensor(obs_t), - SampleBatch.ACTIONS: tf.convert_to_tensor(act_t), - SampleBatch.REWARDS: tf.convert_to_tensor(rew_t), - SampleBatch.NEXT_OBS: tf.convert_to_tensor(obs_tp1), - SampleBatch.DONES: tf.convert_to_tensor(done_mask), - PRIO_WEIGHTS: tf.convert_to_tensor(importance_weights), - }) - # `self.td_error` is set in loss_fn. - return self.td_error - - self.compute_td_error = compute_td_error - - -def setup_mid_mixins(policy, obs_space, action_space, config): - ComputeTDErrorMixin.__init__(policy, ddpg_actor_critic_loss) - - -class TargetNetworkMixin: - def __init__(self, config): - @make_tf_callable(self.get_session()) - def update_target_fn(tau): - tau = tf.convert_to_tensor(tau, dtype=tf.float32) - update_target_expr = [] - model_vars = self.model.trainable_variables() - target_model_vars = self.target_model.trainable_variables() - assert len(model_vars) == len(target_model_vars), \ - (model_vars, target_model_vars) - for var, var_target in zip(model_vars, target_model_vars): - update_target_expr.append( - var_target.assign(tau * var + (1.0 - tau) * var_target)) - logger.debug("Update target op {}".format(var_target)) - return tf.group(*update_target_expr) - - # Hard initial update. - self._do_update = update_target_fn - self.update_target(tau=1.0) - - # Support both hard and soft sync. - def update_target(self, tau=None): - self._do_update(np.float32(tau or self.config.get("tau"))) - - @override(TFPolicy) - def variables(self): - return self.model.variables() + self.target_model.variables() - - -def setup_late_mixins(policy, obs_space, action_space, config): - TargetNetworkMixin.__init__(policy, config) - - -DDPGTFPolicy = build_tf_policy( - name="DQNTFPolicy", - get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, - make_model=build_ddpg_models, - action_distribution_fn=get_distribution_inputs_and_class, - loss_fn=ddpg_actor_critic_loss, - stats_fn=build_ddpg_stats, - postprocess_fn=postprocess_nstep_and_prio, - optimizer_fn=make_ddpg_optimizers, - gradients_fn=gradients_fn, - apply_gradients_fn=build_apply_op, - extra_learn_fetches_fn=lambda policy: {"td_error": policy.td_error}, - before_init=before_init_fn, - before_loss_init=setup_mid_mixins, - after_init=setup_late_mixins, - obs_include_prev_action_reward=False, - mixins=[ - TargetNetworkMixin, - ComputeTDErrorMixin, - ]) diff --git a/rllib/agents/ddpg/noop_model.py b/rllib/agents/ddpg/noop_model.py index 8a1cd85b9..6c9145e1b 100644 --- a/rllib/agents/ddpg/noop_model.py +++ b/rllib/agents/ddpg/noop_model.py @@ -1,3 +1,4 @@ +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override @@ -11,7 +12,7 @@ class NoopModel(TFModelV2): This is the model used if use_state_preprocessor=False.""" - @override(TFModelV2) + @override(ModelV2) def forward(self, input_dict, state, seq_lens): return tf.cast(input_dict["obs_flat"], tf.float32), state @@ -21,6 +22,6 @@ class TorchNoopModel(TorchModelV2): This is the model used if use_state_preprocessor=False.""" - @override(TorchModelV2) + @override(ModelV2) def forward(self, input_dict, state, seq_lens): return input_dict["obs_flat"].float(), state diff --git a/rllib/agents/es/es_tf_policy.py b/rllib/agents/es/es_tf_policy.py index 40e33f3c5..82de66e65 100644 --- a/rllib/agents/es/es_tf_policy.py +++ b/rllib/agents/es/es_tf_policy.py @@ -82,14 +82,17 @@ 
class ESTFPolicy: # Policy network. dist_class, dist_dim = ModelCatalog.get_action_dist( self.action_space, config["model"], dist_type="deterministic") - model = ModelCatalog.get_model({ - SampleBatch.CUR_OBS: self.inputs - }, obs_space, action_space, dist_dim, config["model"]) - dist = dist_class(model.outputs, model) + self.model = ModelCatalog.get_model_v2( + obs_space=self.preprocessor.observation_space, + action_space=action_space, + num_outputs=dist_dim, + model_config=config["model"]) + dist_inputs, _ = self.model({SampleBatch.CUR_OBS: self.inputs}) + dist = dist_class(dist_inputs, self.model) self.sampler = dist.sample() self.variables = ray.experimental.tf_utils.TensorFlowVariables( - model.outputs, self.sess) + dist_inputs, self.sess) self.num_params = sum( np.prod(variable.shape.as_list()) diff --git a/rllib/agents/ppo/ppo_tf_policy.py b/rllib/agents/ppo/ppo_tf_policy.py index 2afaeb951..409bb0291 100644 --- a/rllib/agents/ppo/ppo_tf_policy.py +++ b/rllib/agents/ppo/ppo_tf_policy.py @@ -235,7 +235,7 @@ class ValueNetworkMixin: [prev_action]), SampleBatch.PREV_REWARDS: tf.convert_to_tensor( [prev_reward]), - "is_training": tf.convert_to_tensor(False), + "is_training": tf.convert_to_tensor([False]), }, [tf.convert_to_tensor([s]) for s in state], tf.convert_to_tensor([1])) return self.model.value_function()[0] diff --git a/rllib/agents/qmix/model.py b/rllib/agents/qmix/model.py index 905533302..66ec26975 100644 --- a/rllib/agents/qmix/model.py +++ b/rllib/agents/qmix/model.py @@ -1,3 +1,4 @@ +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.annotations import override @@ -20,12 +21,12 @@ class RNNModel(TorchModelV2, nn.Module): self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs) - @override(TorchModelV2) + @override(ModelV2) def get_initial_state(self): # make hidden states on same device as model return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)] - @override(TorchModelV2) + @override(ModelV2) def forward(self, input_dict, hidden_state, seq_lens): x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float())) h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim) diff --git a/rllib/agents/qmix/qmix_policy.py b/rllib/agents/qmix/qmix_policy.py index d8c254ef0..355588a2d 100644 --- a/rllib/agents/qmix/qmix_policy.py +++ b/rllib/agents/qmix/qmix_policy.py @@ -10,7 +10,7 @@ from ray.rllib.policy.policy import Policy from ray.rllib.policy.rnn_sequencing import chop_into_sequences from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.model import _unpack_obs +from ray.rllib.models.modelv2 import _unpack_obs from ray.rllib.env.constants import GROUP_REWARDS from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.annotations import override diff --git a/rllib/contrib/bandits/models/linear_regression.py b/rllib/contrib/bandits/models/linear_regression.py index 10c0f969f..285486dbb 100644 --- a/rllib/contrib/bandits/models/linear_regression.py +++ b/rllib/contrib/bandits/models/linear_regression.py @@ -1,4 +1,6 @@ import gym + +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils import try_import_torch from ray.rllib.utils.annotations import override @@ -111,7 +113,7 @@ class 
DiscreteLinearModel(TorchModelV2, nn.Module): self._cur_value = None self._cur_ctx = None - @override(TorchModelV2) + @override(ModelV2) def forward(self, input_dict, state, seq_lens): x = input_dict["obs"] scores = self.predict(x) @@ -137,7 +139,7 @@ class DiscreteLinearModel(TorchModelV2, nn.Module): f"It should be 0 <= arm < {len(self.arms)}" self.arms[arm].partial_fit(x, y) - @override(TorchModelV2) + @override(ModelV2) def value_function(self): assert self._cur_value is not None, "must call forward() first" return self._cur_value @@ -190,7 +192,7 @@ class ParametricLinearModel(TorchModelV2, nn.Module): assert x.size()[ 0] == 1, "Only batch size of 1 is supported for now." - @override(TorchModelV2) + @override(ModelV2) def forward(self, input_dict, state, seq_lens): x = input_dict["obs"]["item"] self._check_inputs(x) @@ -214,7 +216,7 @@ class ParametricLinearModel(TorchModelV2, nn.Module): action_id = arm.item() self.arm.partial_fit(x[:, action_id], y) - @override(TorchModelV2) + @override(ModelV2) def value_function(self): assert self._cur_value is not None, "must call forward() first" return self._cur_value diff --git a/rllib/examples/batch_norm_model.py b/rllib/examples/batch_norm_model.py index dd37a1eec..e58c9ff7a 100644 --- a/rllib/examples/batch_norm_model.py +++ b/rllib/examples/batch_norm_model.py @@ -4,9 +4,12 @@ import argparse import ray from ray import tune -from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override tf = try_import_tf() @@ -15,28 +18,127 @@ parser.add_argument("--num-iters", type=int, default=200) parser.add_argument("--run", type=str, default="PPO") -class BatchNormModel(Model): - def _build_layers_v2(self, input_dict, num_outputs, options): +class BatchNormModel(TFModelV2): + """Example of a TFModelV2 that is built w/o using tf.keras. + + NOTE: This example does not work when using a keras-based TFModelV2 due + to a bug in keras related to missing values for input placeholders, even + though these input values have been provided in a forward pass through the + actual keras Model. + + All Model logic (layers) is defined in the `forward` method (incl. + the batch_normalization layers). Also, all variables are registered + (only once) at the end of `forward`, so an optimizer knows which tensors + to train on. A standard `value_function` override is used. + """ + capture_index = 0 + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super().__init__(obs_space, action_space, num_outputs, model_config, + name) + # Have we registered our vars yet (see `forward`)? 
+        self._registered = False
+
+    @override(ModelV2)
+    def forward(self, input_dict, state, seq_lens):
         last_layer = input_dict["obs"]
         hiddens = [256, 256]
+        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
+            for i, size in enumerate(hiddens):
+                last_layer = tf.layers.dense(
+                    last_layer,
+                    size,
+                    kernel_initializer=normc_initializer(1.0),
+                    activation=tf.nn.tanh,
+                    name="fc{}".format(i))
+                # Add a batch norm layer
+                last_layer = tf.layers.batch_normalization(
+                    last_layer,
+                    training=input_dict["is_training"],
+                    name="bn_{}".format(i))
+
+            output = tf.layers.dense(
+                last_layer,
+                self.num_outputs,
+                kernel_initializer=normc_initializer(0.01),
+                activation=None,
+                name="out")
+            self._value_out = tf.layers.dense(
+                last_layer,
+                1,
+                kernel_initializer=normc_initializer(1.0),
+                activation=None,
+                name="vf")
+        if not self._registered:
+            self.register_variables(
+                tf.get_collection(
+                    tf.GraphKeys.TRAINABLE_VARIABLES, scope=".+/model/.+"))
+            self._registered = True
+
+        return output, []
+
+    @override(ModelV2)
+    def value_function(self):
+        return tf.reshape(self._value_out, [-1])
+
+
+class KerasBatchNormModel(TFModelV2):
+    """Keras version of above BatchNormModel with exactly the same structure.
+
+    IMPORTANT NOTE: This model will not work with PPO due to a bug in keras
+    that surfaces when having more than one input placeholder (here: `inputs`
+    and `is_training`) AND using the `make_tf_callable` helper (e.g. used by
+    PPO), in which auto-placeholders are generated, then passed through the
+    tf.keras.models.Model. In this last step, the connection between 1) the
+    provided value in the auto-placeholder and 2) the keras `is_training`
+    Input is broken and keras complains.
+    Use the above `BatchNormModel` (a non-keras based TFModelV2) instead.
+    """
+
+    def __init__(self, obs_space, action_space, num_outputs, model_config,
+                 name):
+        super().__init__(obs_space, action_space, num_outputs, model_config,
+                         name)
+        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="inputs")
+        is_training = tf.keras.layers.Input(
+            shape=(), dtype=tf.bool, batch_size=1, name="is_training")
+        last_layer = inputs
+        hiddens = [256, 256]
         for i, size in enumerate(hiddens):
             label = "fc{}".format(i)
-            last_layer = tf.layers.dense(
-                last_layer,
-                size,
+            last_layer = tf.keras.layers.Dense(
+                units=size,
                 kernel_initializer=normc_initializer(1.0),
                 activation=tf.nn.tanh,
-                name=label)
+                name=label)(last_layer)
             # Add a batch norm layer
-            last_layer = tf.layers.batch_normalization(
-                last_layer, training=input_dict["is_training"])
-        output = tf.layers.dense(
-            last_layer,
-            num_outputs,
+            last_layer = tf.keras.layers.BatchNormalization()(
+                last_layer, training=is_training[0])
+        output = tf.keras.layers.Dense(
+            units=self.num_outputs,
             kernel_initializer=normc_initializer(0.01),
             activation=None,
-            name="fc_out")
-        return output, last_layer
+            name="fc_out")(last_layer)
+        value_out = tf.keras.layers.Dense(
+            units=1,
+            kernel_initializer=normc_initializer(0.01),
+            activation=None,
+            name="value_out")(last_layer)
+
+        self.base_model = tf.keras.models.Model(
+            inputs=[inputs, is_training], outputs=[output, value_out])
+        self.register_variables(self.base_model.variables)
+
+    @override(ModelV2)
+    def forward(self, input_dict, state, seq_lens):
+        out, self._value_out = self.base_model(
+            [input_dict["obs"], input_dict["is_training"]])
+        return out, []
+
+    @override(ModelV2)
+    def value_function(self):
+        return tf.reshape(self._value_out, [-1])
 
 
 if __name__ == "__main__":
@@ -44,14 +146,21 @@ if __name__ == "__main__":
ray.init() ModelCatalog.register_custom_model("bn_model", BatchNormModel) + + config = { + "env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0", + "model": { + "custom_model": "bn_model", + }, + "num_workers": 0, + } + + from ray.rllib.agents.ppo import PPOTrainer + trainer = PPOTrainer(config=config) + trainer.train() + tune.run( args.run, stop={"training_iteration": args.num_iters}, - config={ - "env": "Pendulum-v0" if args.run == "DDPG" else "CartPole-v0", - "model": { - "custom_model": "bn_model", - }, - "num_workers": 0, - }, + config=config, ) diff --git a/rllib/examples/custom_keras_rnn_model.py b/rllib/examples/custom_keras_rnn_model.py index 1226672f3..ec02de923 100644 --- a/rllib/examples/custom_keras_rnn_model.py +++ b/rllib/examples/custom_keras_rnn_model.py @@ -169,7 +169,6 @@ if __name__ == "__main__": "max_seq_len": 20, }, } - tune.run( args.run, config=config, diff --git a/rllib/examples/eager_execution.py b/rllib/examples/eager_execution.py index f125c56c4..772d1e5b0 100644 --- a/rllib/examples/eager_execution.py +++ b/rllib/examples/eager_execution.py @@ -4,11 +4,14 @@ import random import ray from ray import tune from ray.rllib.agents.trainer_template import build_trainer -from ray.rllib.models import Model, ModelCatalog -from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override tf = try_import_tf() @@ -16,7 +19,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--iters", type=int, default=200) -class EagerModel(Model): +class EagerModel(TFModelV2): """Example of using embedded eager execution in a custom model. This shows how to use tf.py_function() to execute a snippet of TF code @@ -25,15 +28,39 @@ class EagerModel(Model): perform any TF eager operation in tf.py_function(). 
""" - def _build_layers_v2(self, input_dict, num_outputs, options): - self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space, - self.action_space, num_outputs, - options) - feature_out = tf.py_function(self.forward_eager, - [self.fcnet.last_layer], tf.float32) + def __init__(self, observation_space, action_space, num_outputs, + model_config, name): + super().__init__(observation_space, action_space, num_outputs, + model_config, name) - with tf.control_dependencies([feature_out]): - return tf.identity(self.fcnet.outputs), feature_out + inputs = tf.keras.layers.Input(shape=observation_space.shape) + self.fcnet = FullyConnectedNetwork( + obs_space=self.obs_space, + action_space=self.action_space, + num_outputs=self.num_outputs, + model_config=self.model_config, + name="fc1") + out, value_out = self.fcnet.base_model(inputs) + + def lambda_(x): + eager_out = tf.py_function(self.forward_eager, [x], tf.float32) + with tf.control_dependencies([eager_out]): + eager_out.set_shape(x.shape) + return eager_out + + out = tf.keras.layers.Lambda(lambda_)(out) + self.base_model = tf.keras.models.Model(inputs, [out, value_out]) + self.register_variables(self.base_model.variables) + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + out, self._value_out = self.base_model(input_dict["obs"], state, + seq_lens) + return out, [] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) def forward_eager(self, feature_layer): assert tf.executing_eagerly() @@ -84,13 +111,13 @@ if __name__ == "__main__": ray.init() args = parser.parse_args() ModelCatalog.register_custom_model("eager_model", EagerModel) - tune.run( - MyTrainer, - stop={"training_iteration": args.iters}, - config={ - "env": "CartPole-v0", - "num_workers": 0, - "model": { - "custom_model": "eager_model" - }, - }) + + config = { + "env": "CartPole-v0", + "num_workers": 0, + "model": { + "custom_model": "eager_model" + }, + } + + tune.run(MyTrainer, stop={"training_iteration": args.iters}, config=config) diff --git a/rllib/examples/multi_agent_cartpole.py b/rllib/examples/multi_agent_cartpole.py index 2a6c70c7c..40ade9eef 100644 --- a/rllib/examples/multi_agent_cartpole.py +++ b/rllib/examples/multi_agent_cartpole.py @@ -15,10 +15,13 @@ import random import ray from ray import tune -from ray.rllib.models import Model, ModelCatalog +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.tests.test_multi_agent_env import MultiCartpole from ray.tune.registry import register_env from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override tf = try_import_tf() @@ -31,8 +34,13 @@ parser.add_argument("--simple", action="store_true") parser.add_argument("--num-cpus", type=int, default=0) -class CustomModel1(Model): - def _build_layers_v2(self, input_dict, num_outputs, options): +class CustomModel1(TFModelV2): + def __init__(self, observation_space, action_space, num_outputs, + model_config, name): + super().__init__(observation_space, action_space, num_outputs, + model_config, name) + + inputs = tf.keras.layers.Input(observation_space.shape) # Example of (optional) weight sharing between two different policies. # Here, we share the variables defined in the 'shared' variable scope # by entering it explicitly with tf.AUTO_REUSE. 
This creates the @@ -42,29 +50,55 @@ class CustomModel1(Model): tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False): - last_layer = tf.layers.dense( - input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") - last_layer = tf.layers.dense( - last_layer, 64, activation=tf.nn.relu, name="fc2") - output = tf.layers.dense( - last_layer, num_outputs, activation=None, name="fc_out") - return output, last_layer + last_layer = tf.keras.layers.Dense( + units=64, activation=tf.nn.relu, name="fc1")(inputs) + output = tf.keras.layers.Dense( + units=num_outputs, activation=None, name="fc_out")(last_layer) + vf = tf.keras.layers.Dense( + units=1, activation=None, name="value_out")(last_layer) + self.base_model = tf.keras.models.Model(inputs, [output, vf]) + self.register_variables(self.base_model.variables) + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + out, self._value_out = self.base_model(input_dict["obs"]) + return out, [] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) -class CustomModel2(Model): - def _build_layers_v2(self, input_dict, num_outputs, options): - # Weights shared with CustomModel1 +class CustomModel2(TFModelV2): + def __init__(self, observation_space, action_space, num_outputs, + model_config, name): + super().__init__(observation_space, action_space, num_outputs, + model_config, name) + + inputs = tf.keras.layers.Input(observation_space.shape) + + # Weights shared with CustomModel1. with tf.variable_scope( tf.VariableScope(tf.AUTO_REUSE, "shared"), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False): - last_layer = tf.layers.dense( - input_dict["obs"], 64, activation=tf.nn.relu, name="fc1") - last_layer = tf.layers.dense( - last_layer, 64, activation=tf.nn.relu, name="fc2") - output = tf.layers.dense( - last_layer, num_outputs, activation=None, name="fc_out") - return output, last_layer + last_layer = tf.keras.layers.Dense( + units=64, activation=tf.nn.relu, name="fc1")(inputs) + output = tf.keras.layers.Dense( + units=num_outputs, activation=None, name="fc_out")(last_layer) + vf = tf.keras.layers.Dense( + units=1, activation=None, name="value_out")(last_layer) + self.base_model = tf.keras.models.Model(inputs, [output, vf]) + self.register_variables(self.base_model.variables) + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + out, self._value_out = self.base_model(input_dict["obs"]) + return out, [] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) if __name__ == "__main__": diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 289542f9c..92f96ac15 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -23,6 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \ TorchMultiActionDistribution, TorchMultiCategorical from ray.rllib.utils import try_import_tf, try_import_tree from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.space_utils import flatten_space @@ -471,6 +472,7 @@ class ModelCatalog: seq_lens=None): """Deprecated: use get_model_v2() instead.""" + deprecation_warning("get_model", "get_model_v2", error=False) assert isinstance(input_dict, dict) options = options or MODEL_DEFAULTS model = ModelCatalog._get_model(input_dict, obs_space, action_space, @@ -496,6 +498,7 @@ class ModelCatalog: 
@staticmethod def _get_model(input_dict, obs_space, action_space, num_outputs, options, state_in, seq_lens): + deprecation_warning("_get_model", "get_model_v2", error=False) if options.get("custom_model"): model = options["custom_model"] logger.debug("Using custom model {}".format(model)) diff --git a/rllib/models/model.py b/rllib/models/model.py index 52499c73b..ffb56c5eb 100644 --- a/rllib/models/model.py +++ b/rllib/models/model.py @@ -5,7 +5,8 @@ import gym from ray.rllib.models.tf.misc import linear, normc_initializer from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI -from ray.rllib.utils import try_import_tf, try_import_torch +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.framework import try_import_tf, try_import_torch tf = try_import_tf() torch, _ = try_import_torch() @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) class Model: - """This class is deprecated, please use TFModelV2 instead.""" + """This class is deprecated! Use ModelV2 instead.""" def __init__(self, input_dict, @@ -24,6 +25,9 @@ class Model: options, state_in=None, seq_lens=None): + # Soft-deprecate this class. All Models should use the ModelV2 + # API from here on. + deprecation_warning("Model", "ModelV2", error=False) assert isinstance(input_dict, dict), input_dict # Default attribute values for the non-RNN case diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 1f8341423..0a8374498 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -1,6 +1,13 @@ +from collections import OrderedDict +import gym + +from ray.rllib.models.preprocessors import get_preprocessor from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.model import restore_original_dimensions, flatten -from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI +from ray.rllib.utils.framework import try_import_tf, try_import_torch + +tf = try_import_tf() +torch, _ = try_import_torch() @PublicAPI @@ -70,7 +77,7 @@ class ModelV2: Custom models should override this instead of __call__. - Arguments: + Args: input_dict (dict): dictionary of input tensors, including "obs", "obs_flat", "prev_action", "prev_reward", "is_training" state (list): list of state tensors with sizes matching those @@ -80,6 +87,12 @@ class ModelV2: Returns: (outputs, state): The model output tensor of size [BATCH, num_outputs] + + Examples: + >>> def forward(self, input_dict, state, seq_lens): + >>> model_out, self._value_out = self.base_model( + ... input_dict["obs"]) + >>> return model_out, state """ raise NotImplementedError @@ -267,3 +280,98 @@ class NullContextManager: def __exit__(self, *args): pass + + +@DeveloperAPI +def flatten(obs, framework): + """Flatten the given tensor.""" + if framework == "tf": + return tf.layers.flatten(obs) + elif framework == "torch": + assert torch is not None + return torch.flatten(obs, start_dim=1) + else: + raise NotImplementedError("flatten", framework) + + +@DeveloperAPI +def restore_original_dimensions(obs, obs_space, tensorlib=tf): + """Unpacks Dict and Tuple space observations into their original form. + + This is needed since we flatten Dict and Tuple observations in transit. + Before sending them to the model though, we should unflatten them into + Dicts or Tuples of tensors. + + Arguments: + obs: The flattened observation tensor. + obs_space: The flattened obs space. 
If this has the `original_space` + attribute, we will unflatten the tensor to that shape. + tensorlib: The library used to unflatten (reshape) the array/tensor. + + Returns: + single tensor or dict / tuple of tensors matching the original + observation space. + """ + + if hasattr(obs_space, "original_space"): + if tensorlib == "tf": + tensorlib = tf + elif tensorlib == "torch": + assert torch is not None + tensorlib = torch + return _unpack_obs(obs, obs_space.original_space, tensorlib=tensorlib) + else: + return obs + + +# Cache of preprocessors, for if the user is calling unpack obs often. +_cache = {} + + +def _unpack_obs(obs, space, tensorlib=tf): + """Unpack a flattened Dict or Tuple observation array/tensor. + + Arguments: + obs: The flattened observation tensor + space: The original space prior to flattening + tensorlib: The library used to unflatten (reshape) the array/tensor + """ + + if (isinstance(space, gym.spaces.Dict) + or isinstance(space, gym.spaces.Tuple)): + if id(space) in _cache: + prep = _cache[id(space)] + else: + prep = get_preprocessor(space)(space) + # Make an attempt to cache the result, if enough space left. + if len(_cache) < 999: + _cache[id(space)] = prep + if len(obs.shape) != 2 or obs.shape[1] != prep.shape[0]: + raise ValueError( + "Expected flattened obs shape of [None, {}], got {}".format( + prep.shape[0], obs.shape)) + assert len(prep.preprocessors) == len(space.spaces), \ + (len(prep.preprocessors) == len(space.spaces)) + offset = 0 + if isinstance(space, gym.spaces.Tuple): + u = [] + for p, v in zip(prep.preprocessors, space.spaces): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u.append( + _unpack_obs( + tensorlib.reshape(obs_slice, [-1] + list(p.shape)), + v, + tensorlib=tensorlib)) + else: + u = OrderedDict() + for p, (k, v) in zip(prep.preprocessors, space.spaces.items()): + obs_slice = obs[:, offset:offset + p.size] + offset += p.size + u[k] = _unpack_obs( + tensorlib.reshape(obs_slice, [-1] + list(p.shape)), + v, + tensorlib=tensorlib) + return u + else: + return obs diff --git a/rllib/models/tf/__init__.py b/rllib/models/tf/__init__.py index 74b6cfed3..2ca11563f 100644 --- a/rllib/models/tf/__init__.py +++ b/rllib/models/tf/__init__.py @@ -1,14 +1,12 @@ -# TODO(sven): Add once ModelV1 is deprecated and we no longer cause circular -# dependencies b/c of that. 
-# from ray.rllib.models.tf.tf_modelv2 import TFModelV2 -# from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork -# from ray.rllib.models.tf.recurrent_tf_modelv2 import \ -# RecurrentTFModelV2 -# from ray.rllib.models.tf.visionnet_v2 import VisionNetwork +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.recurrent_tf_modelv2 import \ + RecurrentTFModelV2 +from ray.rllib.models.tf.visionnet_v2 import VisionNetwork -# __all__ = [ -# "FullyConnectedNetwork", -# "RecurrentTFModelV2", -# "TFModelV2", -# "VisionNetwork", -# ] +__all__ = [ + "FullyConnectedNetwork", + "RecurrentTFModelV2", + "TFModelV2", + "VisionNetwork", +] diff --git a/rllib/models/tf/fcnet_v1.py b/rllib/models/tf/fcnet_v1.py index c117bcf45..b5a8b075d 100644 --- a/rllib/models/tf/fcnet_v1.py +++ b/rllib/models/tf/fcnet_v1.py @@ -1,6 +1,7 @@ from ray.rllib.models.model import Model from ray.rllib.models.tf.misc import normc_initializer from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import get_activation_fn, try_import_tf tf = try_import_tf() @@ -17,6 +18,12 @@ class FullyConnectedNetwork(Model): Note that dict inputs will be flattened into a vector. To define a model that processes the components separately, use _build_layers_v2(). """ + # Soft deprecate this class. All Models should use the ModelV2 + # API from here on. + deprecation_warning( + "Model->FullyConnectedNetwork", + "ModelV2->FullyConnectedNetwork", + error=False) hiddens = options.get("fcnet_hiddens") activation = get_activation_fn(options.get("fcnet_activation")) diff --git a/rllib/models/tf/fcnet_v2.py b/rllib/models/tf/fcnet_v2.py index 5a1e2dd7b..303016f27 100644 --- a/rllib/models/tf/fcnet_v2.py +++ b/rllib/models/tf/fcnet_v2.py @@ -16,7 +16,7 @@ class FullyConnectedNetwork(TFModelV2): obs_space, action_space, num_outputs, model_config, name) activation = get_activation_fn(model_config.get("fcnet_activation")) - hiddens = model_config.get("fcnet_hiddens") + hiddens = model_config.get("fcnet_hiddens", []) no_final_linear = model_config.get("no_final_linear") vf_share_layers = model_config.get("vf_share_layers") diff --git a/rllib/models/tf/lstm_v1.py b/rllib/models/tf/lstm_v1.py index eea2fb686..972e9aedd 100644 --- a/rllib/models/tf/lstm_v1.py +++ b/rllib/models/tf/lstm_v1.py @@ -4,7 +4,8 @@ from ray.rllib.models.model import Model from ray.rllib.models.tf.misc import linear, normc_initializer from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.utils.annotations import override -from ray.rllib.utils import try_import_tf +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.framework import try_import_tf tf = try_import_tf() @@ -21,6 +22,10 @@ class LSTM(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): + # Hard deprecate this class. All Models should use the ModelV2 + # API from here on. 
+ deprecation_warning("Model->LSTM", "RecurrentTFModelV2", error=False) + cell_size = options.get("lstm_cell_size") if options.get("lstm_use_prev_action_reward"): action_dim = int( diff --git a/rllib/models/tf/tf_modelv2.py b/rllib/models/tf/tf_modelv2.py index 71b5b0109..9bd54755c 100644 --- a/rllib/models/tf/tf_modelv2.py +++ b/rllib/models/tf/tf_modelv2.py @@ -52,39 +52,6 @@ class TFModelV2(ModelV2): else: return ModelV2.context(self) - def forward(self, input_dict, state, seq_lens): - """Call the model with the given input tensors and state. - - Any complex observations (dicts, tuples, etc.) will be unpacked by - __call__ before being passed to forward(). To access the flattened - observation tensor, refer to input_dict["obs_flat"]. - - This method can be called any number of times. In eager execution, - each call to forward() will eagerly evaluate the model. In symbolic - execution, each call to forward creates a computation graph that - operates over the variables of this model (i.e., shares weights). - - Custom models should override this instead of __call__. - - Args: - input_dict (dict): dictionary of input tensors, including "obs", - "obs_flat", "prev_action", "prev_reward", "is_training" - state (list): list of state tensors with sizes matching those - returned by get_initial_state + the batch dimension - seq_lens (Tensor): 1d tensor holding input sequence lengths - - Returns: - (outputs, state): The model output tensor of size - [BATCH, num_outputs] - - Examples: - >>> def forward(self, input_dict, state, seq_lens): - >>> model_out, self._value_out = self.base_model( - ... input_dict["obs"]) - >>> return model_out, state - """ - raise NotImplementedError - def update_ops(self): """Return the list of update ops for this model. diff --git a/rllib/models/tf/visionnet_v1.py b/rllib/models/tf/visionnet_v1.py index 73c848183..02d5328ec 100644 --- a/rllib/models/tf/visionnet_v1.py +++ b/rllib/models/tf/visionnet_v1.py @@ -1,6 +1,7 @@ from ray.rllib.models.model import Model from ray.rllib.models.tf.misc import flatten from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import get_activation_fn, try_import_tf tf = try_import_tf() @@ -12,6 +13,10 @@ class VisionNetwork(Model): @override(Model) def _build_layers_v2(self, input_dict, num_outputs, options): + # Hard deprecate this class. All Models should use the ModelV2 + # API from here on. + deprecation_warning( + "Model->VisionNetwork", "ModelV2->VisionNetwork", error=False) inputs = input_dict["obs"] filters = options.get("conv_filters") if not filters: diff --git a/rllib/models/torch/recurrent_torch_model.py b/rllib/models/torch/recurrent_torch_model.py index 5be688259..0d99e0bfb 100644 --- a/rllib/models/torch/recurrent_torch_model.py +++ b/rllib/models/torch/recurrent_torch_model.py @@ -11,7 +11,7 @@ torch, nn = try_import_torch() @DeveloperAPI class RecurrentTorchModel(TorchModelV2, nn.Module): - """Helper class to simplify implementing RNN models with TFModelV2. + """Helper class to simplify implementing RNN models with TorchModelV2. Instead of implementing forward(), you can implement forward_rnn() which takes batches with the time dimension added already. 
diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index 8f310c190..c232ecd83 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -30,7 +30,7 @@ class TorchModelV2(ModelV2): if not isinstance(self, nn.Module): raise ValueError( "Subclasses of TorchModelV2 must also inherit from " - "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)") + "nn.Module, e.g., MyModel(TorchModel, nn.Module)") ModelV2.__init__( self, @@ -41,39 +41,6 @@ class TorchModelV2(ModelV2): name, framework="torch") - def forward(self, input_dict, state, seq_lens): - """Call the model with the given input tensors and state. - - Any complex observations (dicts, tuples, etc.) will be unpacked by - __call__ before being passed to forward(). To access the flattened - observation tensor, refer to input_dict["obs_flat"]. - - This method can be called any number of times. In eager execution, - each call to forward() will eagerly evaluate the model. In symbolic - execution, each call to forward creates a computation graph that - operates over the variables of this model (i.e., shares weights). - - Custom models should override this instead of __call__. - - Args: - input_dict (dict): dictionary of input tensors, including "obs", - "obs_flat", "prev_action", "prev_reward", "is_training" - state (list): list of state tensors with sizes matching those - returned by get_initial_state + the batch dimension - seq_lens (Tensor): 1d tensor holding input sequence lengths - - Returns: - (outputs, state): The model output tensor of size - [BATCH, num_outputs] - - Examples: - >>> def forward(self, input_dict, state, seq_lens): - >>> features = self._hidden_layers(input_dict["obs"]) - >>> self._value_out = self._value_branch(features) - >>> return self._logits(features), state - """ - raise NotImplementedError - @override(ModelV2) def variables(self, as_dict=False): if as_dict: diff --git a/rllib/tests/test_catalog.py b/rllib/tests/test_catalog.py index 1a2b33533..c7f8ab7d0 100644 --- a/rllib/tests/test_catalog.py +++ b/rllib/tests/test_catalog.py @@ -5,12 +5,12 @@ import unittest import ray from ray.rllib.models import ModelCatalog, MODEL_DEFAULTS, ActionDistribution -from ray.rllib.models.model import Model +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.preprocessors import (NoPreprocessor, OneHotPreprocessor, Preprocessor) -from ray.rllib.models.tf.fcnet_v1 import FullyConnectedNetwork -from ray.rllib.models.tf.visionnet_v1 import VisionNetwork +from ray.rllib.models.tf.fcnet_v2 import FullyConnectedNetwork +from ray.rllib.models.tf.visionnet_v2 import VisionNetwork from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf @@ -27,7 +27,7 @@ class CustomPreprocessor2(Preprocessor): return [1] -class CustomModel(Model): +class CustomModel(TFModelV2): def _build_layers(self, *args): return tf.constant([[0] * 5]), None @@ -101,25 +101,29 @@ class ModelCatalogTest(unittest.TestCase): ray.init(object_store_memory=1000 * 1024 * 1024) with tf.variable_scope("test1"): - p1 = ModelCatalog.get_model({ - "obs": tf.zeros((10, 3), dtype=tf.float32) - }, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {}) + p1 = ModelCatalog.get_model_v2( + obs_space=Box(0, 1, shape=(3, ), dtype=np.float32), + action_space=Discrete(5), + num_outputs=5, + model_config={}) self.assertEqual(type(p1), FullyConnectedNetwork) with 
tf.variable_scope("test2"): - p2 = ModelCatalog.get_model({ - "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32) - }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), Discrete(5), 5, - {}) + p2 = ModelCatalog.get_model_v2( + obs_space=Box(0, 1, shape=(84, 84, 3), dtype=np.float32), + action_space=Discrete(5), + num_outputs=5, + model_config={}) self.assertEqual(type(p2), VisionNetwork) def test_custom_model(self): ray.init(object_store_memory=1000 * 1024 * 1024) ModelCatalog.register_custom_model("foo", CustomModel) - p1 = ModelCatalog.get_model({ - "obs": tf.constant([1, 2, 3]) - }, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, - {"custom_model": "foo"}) + p1 = ModelCatalog.get_model_v2( + obs_space=Box(0, 1, shape=(3, ), dtype=np.float32), + action_space=Discrete(5), + num_outputs=5, + model_config={"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel)) def test_custom_action_distribution(self): diff --git a/rllib/tests/test_lstm.py b/rllib/tests/test_lstm.py index d3c59e344..bb3246e04 100644 --- a/rllib/tests/test_lstm.py +++ b/rllib/tests/test_lstm.py @@ -5,13 +5,14 @@ import unittest import ray from ray.rllib.agents.ppo import PPOTrainer -from ray.rllib.policy.rnn_sequencing import chop_into_sequences, \ - add_time_dimension +from ray.rllib.policy.rnn_sequencing import chop_into_sequences from ray.rllib.models import ModelCatalog -from ray.rllib.models.tf.misc import linear, normc_initializer -from ray.rllib.models.model import Model +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.misc import normc_initializer +from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2 from ray.tune.registry import register_env from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override tf = try_import_tf() @@ -91,83 +92,109 @@ class TestLSTMUtils(unittest.TestCase): self.assertEqual(seq_lens.tolist(), [1, 2]) -class RNNSpyModel(Model): +class RNNSpyModel(RecurrentTFModelV2): capture_index = 0 + cell_size = 3 - def _build_layers_v2(self, input_dict, num_outputs, options): - # Previously, a new class object was created during - # deserialization and this `capture_index` - # variable would be refreshed between class instantiations. - # This behavior is no longer the case, so we manually refresh - # the variable. - RNNSpyModel.capture_index = 0 + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + super().__init__(obs_space, action_space, num_outputs, model_config, + name) + self.cell_size = RNNSpyModel.cell_size - def spy(sequences, state_in, state_out, seq_lens): - if len(sequences) == 1: + def spy(inputs, seq_lens, h_in, c_in, h_out, c_out): + if len(inputs) == 1: return 0 # don't capture inference inputs # TF runs this function in an isolated context, so we have to use # redis to communicate back to our suite ray.experimental.internal_kv._internal_kv_put( "rnn_spy_in_{}".format(RNNSpyModel.capture_index), pickle.dumps({ - "sequences": sequences, - "state_in": state_in, - "state_out": state_out, - "seq_lens": seq_lens + "sequences": inputs, + "seq_lens": seq_lens, + "state_in": [h_in, c_in], + "state_out": [h_out, c_out] }), overwrite=True) RNNSpyModel.capture_index += 1 return 0 - features = input_dict["obs"] - cell_size = 3 - last_layer = add_time_dimension(features, self.seq_lens) + # Create a keras LSTM model. 
+ inputs = tf.keras.layers.Input( + shape=(None, ) + obs_space.shape, name="input") + state_in_h = tf.keras.layers.Input(shape=(self.cell_size, ), name="h") + state_in_c = tf.keras.layers.Input(shape=(self.cell_size, ), name="c") + seq_lens = tf.keras.layers.Input( + shape=(), name="seq_lens", dtype=tf.int32) - # Setup the LSTM cell - lstm = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - self.state_init = [ - np.zeros(lstm.state_size.c, np.float32), - np.zeros(lstm.state_size.h, np.float32) + lstm_out, state_out_h, state_out_c = tf.keras.layers.LSTM( + self.cell_size, + return_sequences=True, + return_state=True, + name="lstm")( + inputs=inputs, + mask=tf.sequence_mask(seq_lens), + initial_state=[state_in_h, state_in_c]) + self.dense = tf.keras.layers.Dense( + units=self.num_outputs, kernel_initializer=normc_initializer(0.01)) + + def lambda_(inputs): + spy_fn = tf.py_func( + spy, + [ + inputs[0], # observations + inputs[2], # seq_lens + inputs[3], # h_in + inputs[4], # c_in + inputs[5], # h_out + inputs[6], # c_out + ], + tf.int64, + stateful=True) + + # Compute outputs + with tf.control_dependencies([spy_fn]): + return self.dense(inputs[1]) # lstm_out + + logits = tf.keras.layers.Lambda(lambda_)([ + inputs, lstm_out, seq_lens, state_in_h, state_in_c, state_out_h, + state_out_c + ]) + + # Value branch. + value_out = tf.keras.layers.Dense( + units=1, kernel_initializer=normc_initializer(1.0))(lstm_out) + + self.base_model = tf.keras.Model( + [inputs, seq_lens, state_in_h, state_in_c], + [logits, value_out, state_out_h, state_out_c]) + self.base_model.summary() + self.register_variables(self.base_model.variables) + + @override(RecurrentTFModelV2) + def forward_rnn(self, inputs, state, seq_lens): + # Previously, a new class object was created during + # deserialization and this `capture_index` + # variable would be refreshed between class instantiations. + # This behavior is no longer the case, so we manually refresh + # the variable. 
+ RNNSpyModel.capture_index = 0 + model_out, value_out, h, c = self.base_model( + [inputs, seq_lens, state[0], state[1]]) + self._value_out = value_out + return model_out, [h, c] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + @override(ModelV2) + def get_initial_state(self): + return [ + np.zeros(self.cell_size, np.float32), + np.zeros(self.cell_size, np.float32) ] - # Setup LSTM inputs - if self.state_in: - c_in, h_in = self.state_in - else: - c_in = tf.placeholder( - tf.float32, [None, lstm.state_size.c], name="c") - h_in = tf.placeholder( - tf.float32, [None, lstm.state_size.h], name="h") - self.state_in = [c_in, h_in] - - # Setup LSTM outputs - state_in = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in) - lstm_out, lstm_state = tf.nn.dynamic_rnn( - lstm, - last_layer, - initial_state=state_in, - sequence_length=self.seq_lens, - time_major=False, - dtype=tf.float32) - - self.state_out = list(lstm_state) - spy_fn = tf.py_func( - spy, [ - last_layer, - self.state_in, - self.state_out, - self.seq_lens, - ], - tf.int64, - stateful=True) - - # Compute outputs - with tf.control_dependencies([spy_fn]): - last_layer = tf.reshape(lstm_out, [-1, cell_size]) - logits = linear(last_layer, num_outputs, "action", - normc_initializer(0.01)) - return logits, last_layer - class DebugCounterEnv(gym.Env): def __init__(self): diff --git a/rllib/tests/test_nested_observation_spaces.py b/rllib/tests/test_nested_observation_spaces.py index 2917291ab..8e9bf8ffa 100644 --- a/rllib/tests/test_nested_observation_spaces.py +++ b/rllib/tests/test_nested_observation_spaces.py @@ -12,7 +12,7 @@ from ray.rllib.env import MultiAgentEnv from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.vector_env import VectorEnv from ray.rllib.models import ModelCatalog -from ray.rllib.models.model import Model +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.rollout import rollout @@ -119,13 +119,13 @@ class NestedMultiAgentEnv(MultiAgentEnv): return obs, rew, dones, infos -class InvalidModel(Model): - def _build_layers_v2(self, input_dict, num_outputs, options): +class InvalidModel(TorchModelV2): + def forward(self, input_dict, state, seq_lens): return "not", "valid" -class InvalidModel2(Model): - def _build_layers_v2(self, input_dict, num_outputs, options): +class InvalidModel2(TFModelV2): + def forward(self, input_dict, state, seq_lens): return tf.constant(0), tf.constant(0) @@ -158,10 +158,10 @@ class TorchSpyModel(TorchModelV2, nn.Module): return self.fc.value_function() -class DictSpyModel(Model): +class DictSpyModel(TFModelV2): capture_index = 0 - def _build_layers_v2(self, input_dict, num_outputs, options): + def forward(self, input_dict, state, seq_lens): def spy(pos, front_cam, task): # TF runs this function in an isolated context, so we have to use # redis to communicate back to our suite @@ -183,14 +183,14 @@ class DictSpyModel(Model): with tf.control_dependencies([spy_fn]): output = tf.layers.dense(input_dict["obs"]["sensors"]["position"], - num_outputs) - return output, output + self.num_outputs) + return output, [] -class TupleSpyModel(Model): +class TupleSpyModel(TFModelV2): capture_index = 0 - def _build_layers_v2(self, input_dict, num_outputs, options): + def forward(self, input_dict, state, seq_lens): def spy(pos, cam, task): # TF runs this function in an isolated context, so we have to use # redis to communicate back 
to our suite @@ -211,8 +211,8 @@ class TupleSpyModel(Model): stateful=True) with tf.control_dependencies([spy_fn]): - output = tf.layers.dense(input_dict["obs"][0], num_outputs) - return output, output + output = tf.layers.dense(input_dict["obs"][0], self.num_outputs) + return output, [] class NestedSpacesTest(unittest.TestCase): @@ -226,17 +226,20 @@ class NestedSpacesTest(unittest.TestCase): def test_invalid_model(self): ModelCatalog.register_custom_model("invalid", InvalidModel) - self.assertRaises(ValueError, lambda: PGTrainer( - env="CartPole-v0", config={ - "model": { - "custom_model": "invalid", - }, - })) + self.assertRaisesRegexp( + ValueError, + "Subclasses of TorchModelV2 must also inherit from", + lambda: PGTrainer( + env="CartPole-v0", config={ + "model": { + "custom_model": "invalid", + }, + })) def test_invalid_model2(self): ModelCatalog.register_custom_model("invalid2", InvalidModel2) self.assertRaisesRegexp( - ValueError, "Expected output.*", + ValueError, "Expected output shape of", lambda: PGTrainer( env="CartPole-v0", config={ "model": { diff --git a/rllib/utils/tf_ops.py b/rllib/utils/tf_ops.py index bca05fb0d..ea636f744 100644 --- a/rllib/utils/tf_ops.py +++ b/rllib/utils/tf_ops.py @@ -86,7 +86,8 @@ def make_tf_callable(session_or_none, dynamic_shape=False): name="arg_{}".format(i))) symbolic_out[0] = fn(*placeholders) feed_dict = dict(zip(placeholders, args)) - return session_or_none.run(symbolic_out[0], feed_dict) + ret = session_or_none.run(symbolic_out[0], feed_dict) + return ret return call else:
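For reference, the ModelV2 pattern that the changes above migrate toward looks roughly like the sketch below. The class name `MyFCModel`, the single 64-unit hidden layer, and the registration key "my_fc_model" are illustrative assumptions, not part of this patch; the imports and API calls are the ones used in the diff itself.

from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.annotations import override

tf = try_import_tf()


class MyFCModel(TFModelV2):
    """Minimal custom ModelV2: build a keras net, register its variables."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        inputs = tf.keras.layers.Input(shape=obs_space.shape, name="obs")
        hidden = tf.keras.layers.Dense(64, activation=tf.nn.relu)(inputs)
        logits = tf.keras.layers.Dense(num_outputs)(hidden)
        value = tf.keras.layers.Dense(1)(hidden)
        self.base_model = tf.keras.models.Model(inputs, [logits, value])
        self.register_variables(self.base_model.variables)

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        # __call__ unpacks/flattens complex observations before forward().
        logits, self._value_out = self.base_model(input_dict["obs"])
        return logits, state

    @override(ModelV2)
    def value_function(self):
        return tf.reshape(self._value_out, [-1])


# Use it as a registered custom model ...
ModelCatalog.register_custom_model("my_fc_model", MyFCModel)
# ... or build a default ModelV2 directly, as the ES/ARS policies above now do:
# model = ModelCatalog.get_model_v2(
#     obs_space=obs_space,
#     action_space=action_space,
#     num_outputs=num_outputs,
#     model_config=config["model"])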