From ee4b6e7e3ba9a58e0062db6ed13e96686977c71e Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Fri, 12 Mar 2021 18:27:25 +0100 Subject: [PATCH] [RLlib] Unity3D example broken due to change in ML-Agents API. Attention-net prev-n-a/r. Attention-wrapper works with images. (#14569) --- rllib/agents/trainer.py | 26 ++++--- rllib/env/wrappers/unity3d_env.py | 56 ++++++++++++-- rllib/examples/unity3d_env_local.py | 25 +++++-- rllib/examples/vizdoom_with_attention_net.py | 77 ++++++++++++++++++++ rllib/models/catalog.py | 6 +- rllib/models/modelv2.py | 12 ++- rllib/models/tf/attention_net.py | 70 +++++++++++++++++- rllib/models/tf/visionnet.py | 20 +++-- rllib/models/torch/attention_net.py | 70 +++++++++++++++++- rllib/models/torch/misc.py | 5 +- rllib/models/torch/visionnet.py | 5 +- rllib/models/utils.py | 15 +++- rllib/policy/dynamic_tf_policy.py | 14 +++- 13 files changed, 350 insertions(+), 51 deletions(-) create mode 100644 rllib/examples/vizdoom_with_attention_net.py diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index b8c7b469e..cd5f6ac8b 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -609,19 +609,25 @@ class Trainer(Trainable): elif "." in env: self.env_creator = \ lambda env_context: from_config(env, env_context) - # Try gym/PyBullet. + # Try gym/PyBullet/Vizdoom. else: def _creator(env_context): import gym - # Allow for PyBullet envs to be used as well (via string). - # This allows for doing things like - # `env=CartPoleContinuousBulletEnv-v0`. + # Allow for PyBullet or VizdoomGym envs to be used as well + # (via string). This allows for doing things like + # `env=CartPoleContinuousBulletEnv-v0` or + # `env=VizdoomBasic-v0`. try: import pybullet_envs pybullet_envs.getList() except (ModuleNotFoundError, ImportError): pass + try: + import vizdoomgym + vizdoomgym.__name__ # trick LINTer. + except (ModuleNotFoundError, ImportError): + pass # Try creating a gym env. If this fails we can output a # decent error message. try: @@ -629,12 +635,12 @@ class Trainer(Trainable): except gym.error.Error: raise ValueError( "The env string you provided ({}) is a) not a " - "known gym/PyBullet environment specifier or b) " - "not registered! To register your custom envs, " - "do `from ray import tune; tune.register('[name]'," - " lambda cfg: [return actual " - "env from here using cfg])`. Then you can use " - "[name] as your config['env'].".format(env)) + "known gym/PyBullet/VizdoomEnv environment " + "specifier or b) not registered! To register your " + "custom envs, do `from ray import tune; " + "tune.register('[name]', lambda cfg: [return " + "actual env from here using cfg])`. Then you can " + "use [name] as your config['env'].".format(env)) self.env_creator = _creator else: diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 876c06e96..a82a9955a 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -73,11 +73,13 @@ class Unity3DEnv(MultiAgentEnv): from mlagents_envs.environment import UnityEnvironment # Try connecting to the Unity3D game instance. If a port is blocked + port_ = None while True: # Sleep for random time to allow for concurrent startup of many # environments (num_workers >> 1). Otherwise, would lead to port # conflicts sometimes. 
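# --- Illustrative sketch, not part of the patch: the randomized-startup-delay
# --- pattern described in the comment above. `connect_with_random_backoff` and
# --- `try_connect` are hypothetical stand-ins for the UnityEnvironment(...)
# --- call that may hit an already-used port when many workers start at once.
import random
import time

def connect_with_random_backoff(try_connect, max_attempts=10):
    # Sleep a random 1-10s before every retry (but not before the first
    # attempt), so concurrently starting workers do not race for ports.
    for attempt in range(max_attempts):
        if attempt > 0:
            time.sleep(random.randint(1, 10))
        try:
            return try_connect()
        except OSError:
            continue
    raise RuntimeError("Could not connect after {} attempts.".format(max_attempts))
# --- End of sketch.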
- time.sleep(random.randint(1, 10)) + if port_ is not None: + time.sleep(random.randint(1, 10)) port_ = port or (self._BASE_PORT_ENVIRONMENT if file_name else self._BASE_PORT_EDITOR) # cache the worker_id and @@ -101,6 +103,10 @@ class Unity3DEnv(MultiAgentEnv): else: break + # ML-Agents API version. + self.api_version = self.unity_env.API_VERSION.split(".") + self.api_version = [int(s) for s in self.api_version] + # Reset entire env every this number of step calls. self.episode_horizon = episode_horizon # Keep track of how many times we have called `step` so far. @@ -128,16 +134,37 @@ class Unity3DEnv(MultiAgentEnv): it. __all__=True, if episode is done for all agents. - infos: An (empty) info dict. """ + from mlagents_envs.base_env import ActionTuple # Set only the required actions (from the DecisionSteps) in Unity3D. all_agents = [] for behavior_name in self.unity_env.behavior_specs: - for agent_id in self.unity_env.get_steps(behavior_name)[ - 0].agent_id_to_index.keys(): - key = behavior_name + "_{}".format(agent_id) - all_agents.append(key) - self.unity_env.set_action_for_agent(behavior_name, agent_id, - action_dict[key]) + # New ML-Agents API: Set all agents actions at the same time + # via an ActionTuple. Since API v1.4.0. + if self.api_version[0] > 1 or (self.api_version[0] == 1 + and self.api_version[1] >= 4): + actions = [] + for agent_id in self.unity_env.get_steps(behavior_name)[ + 0].agent_id: + key = behavior_name + "_{}".format(agent_id) + all_agents.append(key) + actions.append(action_dict[key]) + if actions: + if actions[0].dtype == np.float32: + action_tuple = ActionTuple( + continuous=np.array(actions)) + else: + action_tuple = ActionTuple(discrete=np.array(actions)) + self.unity_env.set_actions(behavior_name, action_tuple) + # Old behavior: Do not use an ActionTuple and set each agent's + # action individually. + else: + for agent_id in self.unity_env.get_steps(behavior_name)[ + 0].agent_id_to_index.keys(): + key = behavior_name + "_{}".format(agent_id) + all_agents.append(key) + self.unity_env.set_action_for_agent( + behavior_name, agent_id, action_dict[key]) # Do the step. self.unity_env.step() @@ -215,6 +242,8 @@ class Unity3DEnv(MultiAgentEnv): "3DBall": Box(float("-inf"), float("inf"), (8, )), # 3DBallHard. "3DBallHard": Box(float("-inf"), float("inf"), (45, )), + # GridFoodCollector + "GridFoodCollector": Box(float("-inf"), float("inf"), (40, 40, 6)), # Pyramids. "Pyramids": TupleSpace([ Box(float("-inf"), float("inf"), (56, )), @@ -228,6 +257,15 @@ class Unity3DEnv(MultiAgentEnv): Box(float("-inf"), float("inf"), (231, )), Box(float("-inf"), float("inf"), (63, )), ]), + # Sorter. + "Sorter": TupleSpace([ + Box(float("-inf"), float("inf"), ( + 20, + 23, + )), + Box(float("-inf"), float("inf"), (10, )), + Box(float("-inf"), float("inf"), (8, )), + ]), # Tennis. "Tennis": Box(float("-inf"), float("inf"), (27, )), # VisualHallway. @@ -247,11 +285,15 @@ class Unity3DEnv(MultiAgentEnv): # 3DBallHard. "3DBallHard": Box( float("-inf"), float("inf"), (2, ), dtype=np.float32), + # GridFoodCollector. + "GridFoodCollector": MultiDiscrete([3, 3, 3, 2]), # Pyramids. "Pyramids": MultiDiscrete([5]), # SoccerStrikersVsGoalie. "Goalie": MultiDiscrete([3, 3, 3]), "Striker": MultiDiscrete([3, 3, 3]), + # Sorter. + "Sorter": MultiDiscrete([3, 3, 3]), # Tennis. "Tennis": Box(float("-inf"), float("inf"), (3, )), # VisualHallway. 
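# --- Illustrative sketch, not part of the patch: the ML-Agents API-version
# --- gate used in step() above, factored into a hypothetical helper.
# --- `api_version` is the parsed [major, minor, ...] list that __init__ now
# --- stores from UnityEnvironment.API_VERSION.
def supports_action_tuples(api_version):
    # The ActionTuple-based set_actions() path is used from API 1.4.0 on;
    # older versions fall back to per-agent set_action_for_agent() calls.
    major, minor = api_version[0], api_version[1]
    return major > 1 or (major == 1 and minor >= 4)

assert supports_action_tuples([1, 4, 0])
assert not supports_action_tuples([1, 3, 0])
# --- End of sketch.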
diff --git a/rllib/examples/unity3d_env_local.py b/rllib/examples/unity3d_env_local.py index 7dea67ad6..fcd662677 100644 --- a/rllib/examples/unity3d_env_local.py +++ b/rllib/examples/unity3d_env_local.py @@ -34,12 +34,19 @@ parser.add_argument( type=str, default="3DBall", choices=[ - "3DBall", "3DBallHard", "Pyramids", "SoccerStrikersVsGoalie", "Tennis", - "VisualHallway", "Walker" + "3DBall", + "3DBallHard", + "GridFoodCollector", + "Pyramids", + "SoccerStrikersVsGoalie", + "Sorter", + "Tennis", + "VisualHallway", + "Walker", ], help="The name of the Env to run in the Unity3D editor: `3DBall(Hard)?|" - "Pyramids|SoccerStrikersVsGoalie|Tennis|VisualHallway|Walker`" - "(feel free to add more and PR!)") + "Pyramids|GridFoodCollector|SoccerStrikersVsGoalie|Sorter|Tennis|" + "VisualHallway|Walker` (feel free to add more and PR!)") parser.add_argument( "--file-name", type=str, @@ -135,6 +142,13 @@ if __name__ == "__main__": "forward_net_activation": "relu", "inverse_net_activation": "relu", } + elif args.env == "GridFoodCollector": + config["model"] = { + "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], + [256, [10, 10], 1]], + } + elif args.env == "Sorter": + config["model"]["use_attention"] = True stop = { "training_iteration": args.stop_iters, @@ -148,7 +162,8 @@ if __name__ == "__main__": config=config, stop=stop, verbose=1, - checkpoint_freq=10, + checkpoint_freq=5, + checkpoint_at_end=True, restore=args.from_checkpoint) # And check the results. diff --git a/rllib/examples/vizdoom_with_attention_net.py b/rllib/examples/vizdoom_with_attention_net.py new file mode 100644 index 000000000..3bc8ad0ed --- /dev/null +++ b/rllib/examples/vizdoom_with_attention_net.py @@ -0,0 +1,77 @@ +import argparse +import os + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--num-cpus", type=int, default=0) +parser.add_argument( + "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf") +parser.add_argument( + "--from-checkpoint", + type=str, + default=None, + help="Full path to a checkpoint file for restoring a previously saved " + "Trainer state.") +parser.add_argument("--num-workers", type=int, default=0) +parser.add_argument( + "--use-n-prev-actions", + type=int, + default=0, + help="How many of the previous actions to use as attention input.") +parser.add_argument( + "--use-n-prev-rewards", + type=int, + default=0, + help="How many of the previous rewards to use as attention input.") +parser.add_argument("--stop-iters", type=int, default=9999) +parser.add_argument("--stop-timesteps", type=int, default=100000000) +parser.add_argument("--stop-reward", type=float, default=1000.0) + +if __name__ == "__main__": + import ray + from ray import tune + + args = parser.parse_args() + + ray.init(num_cpus=args.num_cpus or None) + + config = { + "env": "VizdoomBasic-v0", + # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. + "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), + "model": { + "conv_filters": [], + "use_attention": True, + "attention_num_transformer_units": 1, + "attention_dim": 64, + "attention_num_heads": 2, + "attention_memory_inference": 100, + "attention_memory_training": 50, + "vf_share_layers": True, + "attention_use_n_prev_actions": args.use_n_prev_actions, + "attention_use_n_prev_rewards": args.use_n_prev_rewards, + }, + "framework": args.framework, + # Run with tracing enabled for tfe/tf2. 
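# --- Illustrative sketch, not part of the patch: why the GridFoodCollector
# --- conv_filters in the unity3d_env_local.py hunk above end with a [10, 10]
# --- kernel. The intermediate layers use "same" padding, which shrinks the
# --- spatial size to ceil(in / stride); the final "valid" [10, 10] kernel then
# --- reduces the remaining 10x10 map to 1x1.
import math

size = 40                    # GridFoodCollector obs is (40, 40, 6)
size = math.ceil(size / 2)   # [16, [4, 4], 2]   -> 20
size = math.ceil(size / 2)   # [32, [4, 4], 2]   -> 10
assert size == 10            # [256, [10, 10], 1] ("valid") -> 1x1
# --- End of sketch.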
+ "eager_tracing": args.framework in ["tfe", "tf2"], + "num_workers": args.num_workers, + "vf_loss_coeff": 0.01, + } + + stop = { + "training_iteration": args.stop_iters, + "timesteps_total": args.stop_timesteps, + "episode_reward_mean": args.stop_reward, + } + + results = tune.run( + args.run, + config=config, + stop=stop, + verbose=2, + checkpoint_freq=5, + checkpoint_at_end=True, + restore=args.from_checkpoint, + ) + print(results) + ray.shutdown() diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 74ddcbeab..4c13994c8 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -115,10 +115,10 @@ MODEL_DEFAULTS: ModelConfigDict = { "attention_position_wise_mlp_dim": 32, # The initial bias values for the 2 GRU gates within a transformer unit. "attention_init_gru_gate_bias": 2.0, - # TODO: Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete). - # "attention_use_n_prev_actions": 0, + # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete). + "attention_use_n_prev_actions": 0, # Whether to feed r_{t-n:t-1} to GTrXL. - # "attention_use_n_prev_rewards": 0, + "attention_use_n_prev_rewards": 0, # == Atari == # Which framestacking size to use for Atari envs. diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index b254e933c..a090fc701 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -414,14 +414,13 @@ def restore_original_dimensions(obs: TensorType, observation space. """ - if tensorlib == "tf": + if tensorlib in ["tf", "tfe", "tf2"]: + assert tf is not None tensorlib = tf elif tensorlib == "torch": assert torch is not None tensorlib = torch original_space = getattr(obs_space, "original_space", obs_space) - if original_space is obs_space: - return obs return _unpack_obs(obs, original_space, tensorlib=tensorlib) @@ -450,7 +449,12 @@ def _unpack_obs(obs: TensorType, space: gym.Space, # Make an attempt to cache the result, if enough space left. if len(_cache) < 999: _cache[id(space)] = prep - if len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]: + # Already unpacked? + if (isinstance(space, gym.spaces.Tuple) and + isinstance(obs, (list, tuple))) or \ + (isinstance(space, gym.spaces.Dict) and isinstance(obs, dict)): + return obs + elif len(obs.shape) < 2 or obs.shape[-1] != prep.shape[0]: raise ValueError( "Expected flattened obs shape of [..., {}], got {}".format( prep.shape[0], obs.shape)) diff --git a/rllib/models/tf/attention_net.py b/rllib/models/tf/attention_net.py index fadd5ed89..9c486fa65 100644 --- a/rllib/models/tf/attention_net.py +++ b/rllib/models/tf/attention_net.py @@ -22,6 +22,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_ops import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType, List tf1, tf, tfv = try_import_tf() @@ -353,6 +354,9 @@ class AttentionWrapper(TFModelV2): super().__init__(obs_space, action_space, None, model_config, name) + self.use_n_prev_actions = model_config["attention_use_n_prev_actions"] + self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"] + if isinstance(action_space, Discrete): self.action_dim = action_space.n elif isinstance(action_space, MultiDiscrete): @@ -362,15 +366,30 @@ class AttentionWrapper(TFModelV2): else: self.action_dim = int(len(action_space)) + # Add prev-action/reward nodes to input to LSTM. 
+ if self.use_n_prev_actions: + self.num_outputs += self.use_n_prev_actions * self.action_dim + if self.use_n_prev_rewards: + self.num_outputs += self.use_n_prev_rewards + cfg = model_config self.attention_dim = cfg["attention_dim"] + if self.num_outputs is not None: + in_space = gym.spaces.Box( + float("-inf"), + float("inf"), + shape=(self.num_outputs, ), + dtype=np.float32) + else: + in_space = obs_space + # Construct GTrXL sub-module w/ num_outputs=None (so it does not # create a logits/value output; we'll do this ourselves in this wrapper # here). self.gtrxl = GTrXLNet( - obs_space, + in_space, action_space, None, model_config, @@ -401,6 +420,20 @@ class AttentionWrapper(TFModelV2): self._value_branch = tf.keras.models.Model([input_], [out]) self.view_requirements = self.gtrxl.view_requirements + self.view_requirements["obs"].space = self.obs_space + + # Add prev-a/r to this model's view, if required. + if self.use_n_prev_actions: + self.view_requirements[SampleBatch.PREV_ACTIONS] = \ + ViewRequirement( + SampleBatch.ACTIONS, + space=self.action_space, + shift="-{}:-1".format(self.use_n_prev_actions)) + if self.use_n_prev_rewards: + self.view_requirements[SampleBatch.PREV_REWARDS] = \ + ViewRequirement( + SampleBatch.REWARDS, + shift="-{}:-1".format(self.use_n_prev_rewards)) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], @@ -410,8 +443,41 @@ class AttentionWrapper(TFModelV2): # Push obs through "unwrapped" net's `forward()` first. wrapped_out, _ = self._wrapped_forward(input_dict, [], None) + # Concat. prev-action/reward if required. + prev_a_r = [] + if self.use_n_prev_actions: + if isinstance(self.action_space, Discrete): + for i in range(self.use_n_prev_actions): + prev_a_r.append( + one_hot(input_dict[SampleBatch.PREV_ACTIONS][:, i], + self.action_space)) + elif isinstance(self.action_space, MultiDiscrete): + for i in range( + self.use_n_prev_actions, + step=self.action_space.shape[0]): + prev_a_r.append( + one_hot( + tf.cast( + input_dict[SampleBatch.PREV_ACTIONS] + [:, i:i + self.action_space.shape[0]], + tf.float32), self.action_space)) + else: + prev_a_r.append( + tf.reshape( + tf.cast(input_dict[SampleBatch.PREV_ACTIONS], + tf.float32), + [-1, self.use_n_prev_actions * self.action_dim])) + if self.use_n_prev_rewards: + prev_a_r.append( + tf.reshape( + tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), + [-1, self.use_n_prev_rewards])) + + if prev_a_r: + wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1) + # Then through our GTrXL. 
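# --- Illustrative sketch, not part of the patch: what the prev-action
# --- concatenation above produces for a Discrete(3) action space with
# --- attention_use_n_prev_actions=2. Plain numpy and `one_hot_np` stand in
# --- for the framework ops.
import numpy as np

def one_hot_np(x, n):
    out = np.zeros(n, dtype=np.float32)
    out[x] = 1.0
    return out

wrapped_out = np.zeros((1, 4), dtype=np.float32)   # pretend 4 wrapped features
prev_actions = [2, 0]                              # a_{t-2}, a_{t-1}
prev_a_r = [one_hot_np(a, 3)[None, :] for a in prev_actions]
gtrxl_in = np.concatenate([wrapped_out] + prev_a_r, axis=1)
assert gtrxl_in.shape == (1, 4 + 2 * 3)
# --- End of sketch.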
- input_dict["obs_flat"] = wrapped_out + input_dict["obs_flat"] = input_dict["obs"] = wrapped_out self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens) model_out = self._logits_branch(self._features) diff --git a/rllib/models/tf/visionnet.py b/rllib/models/tf/visionnet.py index 955ac1e52..6f29640f9 100644 --- a/rllib/models/tf/visionnet.py +++ b/rllib/models/tf/visionnet.py @@ -70,7 +70,8 @@ class VisionNetwork(TFModelV2): last_layer = tf.keras.layers.Conv2D( out_size, kernel, - strides=(stride, stride), + strides=stride + if isinstance(stride, (list, tuple)) else (stride, stride), activation=activation, padding="same", data_format="channels_last", @@ -85,7 +86,8 @@ class VisionNetwork(TFModelV2): last_layer = tf.keras.layers.Conv2D( out_size if post_fcnet_hiddens else num_outputs, kernel, - strides=(stride, stride), + strides=stride + if isinstance(stride, (list, tuple)) else (stride, stride), activation=activation, padding="valid", data_format="channels_last", @@ -107,7 +109,8 @@ class VisionNetwork(TFModelV2): last_layer = tf.keras.layers.Conv2D( out_size, kernel, - strides=(stride, stride), + strides=stride + if isinstance(stride, (list, tuple)) else (stride, stride), activation=activation, padding="valid", data_format="channels_last", @@ -169,8 +172,9 @@ class VisionNetwork(TFModelV2): # Build the value layers if vf_share_layers: - last_layer = tf.keras.layers.Lambda( - lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) + if not self.last_layer_is_flattened: + last_layer = tf.keras.layers.Lambda( + lambda x: tf.squeeze(x, axis=[1, 2]))(last_layer) value_out = tf.keras.layers.Dense( 1, name="value_out", @@ -183,7 +187,8 @@ class VisionNetwork(TFModelV2): last_layer = tf.keras.layers.Conv2D( out_size, kernel, - strides=(stride, stride), + strides=stride + if isinstance(stride, (list, tuple)) else (stride, stride), activation=activation, padding="same", data_format="channels_last", @@ -192,7 +197,8 @@ class VisionNetwork(TFModelV2): last_layer = tf.keras.layers.Conv2D( out_size, kernel, - strides=(stride, stride), + strides=stride + if isinstance(stride, (list, tuple)) else (stride, stride), activation=activation, padding="valid", data_format="channels_last", diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py index 873c86b33..0a48bd91e 100644 --- a/rllib/models/torch/attention_net.py +++ b/rllib/models/torch/attention_net.py @@ -23,6 +23,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.view_requirement import ViewRequirement from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.torch_ops import one_hot from ray.rllib.utils.typing import ModelConfigDict, TensorType, List torch, nn = try_import_torch() @@ -250,6 +251,9 @@ class AttentionWrapper(TorchModelV2, nn.Module): nn.Module.__init__(self) super().__init__(obs_space, action_space, None, model_config, name) + self.use_n_prev_actions = model_config["attention_use_n_prev_actions"] + self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"] + if isinstance(action_space, Discrete): self.action_dim = action_space.n elif isinstance(action_space, MultiDiscrete): @@ -259,15 +263,30 @@ class AttentionWrapper(TorchModelV2, nn.Module): else: self.action_dim = int(len(action_space)) + # Add prev-action/reward nodes to input to LSTM. 
+ if self.use_n_prev_actions: + self.num_outputs += self.use_n_prev_actions * self.action_dim + if self.use_n_prev_rewards: + self.num_outputs += self.use_n_prev_rewards + cfg = model_config self.attention_dim = cfg["attention_dim"] + if self.num_outputs is not None: + in_space = gym.spaces.Box( + float("-inf"), + float("inf"), + shape=(self.num_outputs, ), + dtype=np.float32) + else: + in_space = obs_space + # Construct GTrXL sub-module w/ num_outputs=None (so it does not # create a logits/value output; we'll do this ourselves in this wrapper # here). self.gtrxl = GTrXLNet( - obs_space, + in_space, action_space, None, model_config, @@ -299,6 +318,20 @@ class AttentionWrapper(TorchModelV2, nn.Module): initializer=torch.nn.init.xavier_uniform_) self.view_requirements = self.gtrxl.view_requirements + self.view_requirements["obs"].space = self.obs_space + + # Add prev-a/r to this model's view, if required. + if self.use_n_prev_actions: + self.view_requirements[SampleBatch.PREV_ACTIONS] = \ + ViewRequirement( + SampleBatch.ACTIONS, + space=self.action_space, + shift="-{}:-1".format(self.use_n_prev_actions)) + if self.use_n_prev_rewards: + self.view_requirements[SampleBatch.PREV_REWARDS] = \ + ViewRequirement( + SampleBatch.REWARDS, + shift="-{}:-1".format(self.use_n_prev_rewards)) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], @@ -308,12 +341,43 @@ class AttentionWrapper(TorchModelV2, nn.Module): # Push obs through "unwrapped" net's `forward()` first. wrapped_out, _ = self._wrapped_forward(input_dict, [], None) + # Concat. prev-action/reward if required. + prev_a_r = [] + if self.use_n_prev_actions: + if isinstance(self.action_space, Discrete): + for i in range(self.use_n_prev_actions): + prev_a_r.append( + one_hot( + input_dict[SampleBatch.PREV_ACTIONS][:, i].float(), + self.action_space)) + elif isinstance(self.action_space, MultiDiscrete): + for i in range( + self.use_n_prev_actions, + step=self.action_space.shape[0]): + prev_a_r.append( + one_hot( + input_dict[SampleBatch.PREV_ACTIONS] + [:, i:i + self.action_space.shape[0]].float(), + self.action_space)) + else: + prev_a_r.append( + torch.reshape( + input_dict[SampleBatch.PREV_ACTIONS].float(), + [-1, self.use_n_prev_actions * self.action_dim])) + if self.use_n_prev_rewards: + prev_a_r.append( + torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), + [-1, self.use_n_prev_rewards])) + + if prev_a_r: + wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) + # Then through our GTrXL. 
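# --- Illustrative sketch, not part of the patch: the continuous-action branch
# --- above flattens the stacked previous actions into one row per timestep.
# --- Plain numpy stands in for torch; with a Box action space of shape (2,)
# --- and attention_use_n_prev_actions=3, a [B, 3, 2] stack becomes [B, 6]
# --- before being concatenated onto the wrapped features.
import numpy as np

batch, n_prev, action_dim = 1, 3, 2
prev_actions = np.zeros((batch, n_prev, action_dim), dtype=np.float32)
flat = prev_actions.reshape(-1, n_prev * action_dim)
assert flat.shape == (batch, n_prev * action_dim)
# --- End of sketch.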
- input_dict["obs_flat"] = wrapped_out + input_dict["obs_flat"] = input_dict["obs"] = wrapped_out self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens) model_out = self._logits_branch(self._features) - return model_out, [torch.squeeze(m, 0) for m in memory_outs] + return model_out, memory_outs @override(ModelV2) def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]: diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py index 9f6d8234e..b314b3e87 100644 --- a/rllib/models/torch/misc.py +++ b/rllib/models/torch/misc.py @@ -39,7 +39,10 @@ def same_padding(in_size: Tuple[int, int], filter_size: Tuple[int, int], filter_height, filter_width = filter_size, filter_size else: filter_height, filter_width = filter_size - stride_height, stride_width = stride_size + if isinstance(stride_size, (int, float)): + stride_height, stride_width = int(stride_size), int(stride_size) + else: + stride_height, stride_width = int(stride_size[0]), int(stride_size[1]) out_height = np.ceil(float(in_height) / float(stride_height)) out_width = np.ceil(float(in_width) / float(stride_width)) diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py index 133c851f5..2eec7c382 100644 --- a/rllib/models/torch/visionnet.py +++ b/rllib/models/torch/visionnet.py @@ -60,7 +60,7 @@ class VisionNetwork(TorchModelV2, nn.Module): in_size = [w, h] for out_channels, kernel, stride in filters[:-1]: - padding, out_size = same_padding(in_size, kernel, [stride, stride]) + padding, out_size = same_padding(in_size, kernel, stride) layers.append( SlimConv2d( in_channels, @@ -172,8 +172,7 @@ class VisionNetwork(TorchModelV2, nn.Module): (w, h, in_channels) = obs_space.shape in_size = [w, h] for out_channels, kernel, stride in filters[:-1]: - padding, out_size = same_padding(in_size, kernel, - [stride, stride]) + padding, out_size = same_padding(in_size, kernel, stride) vf_layers.append( SlimConv2d( in_channels, diff --git a/rllib/models/utils.py b/rllib/models/utils.py index f866cc944..b1e84cba0 100644 --- a/rllib/models/utils.py +++ b/rllib/models/utils.py @@ -70,18 +70,29 @@ def get_filter_config(shape): inside a model config dict. """ shape = list(shape) + # VizdoomGym. + filters_240x320x = [ + [16, [12, 16], [7, 9]], + [32, [6, 6], 4], + [256, [9, 9], 1], + ] + # Atari. filters_84x84 = [ [16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1], ] + # Small (1/2) Atari. filters_42x42 = [ [16, [4, 4], 2], [32, [4, 4], 2], [256, [11, 11], 1], ] - if len(shape) in [2, 3] and (shape[:2] == [84, 84] - or shape[1:] == [84, 84]): + if len(shape) in [2, 3] and (shape[:2] == [240, 320] + or shape[1:] == [240, 320]): + return filters_240x320x + elif len(shape) in [2, 3] and (shape[:2] == [84, 84] + or shape[1:] == [84, 84]): return filters_84x84 elif len(shape) in [2, 3] and (shape[:2] == [42, 42] or shape[1:] == [42, 42]): diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index b2d491d00..6a0d88d30 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -214,15 +214,21 @@ class DynamicTFPolicy(TFPolicy): self.view_requirements, existing_inputs) else: action_ph = ModelCatalog.get_action_placeholder(action_space) - prev_action_ph = ModelCatalog.get_action_placeholder( - action_space, "prev_action") if self.config["_use_trajectory_view_api"]: + prev_action_ph = {} + if SampleBatch.PREV_ACTIONS not in self.view_requirements: + prev_action_ph = { + SampleBatch.PREV_ACTIONS: ModelCatalog. 
+ get_action_placeholder(action_space, "prev_action") + } self._input_dict, self._dummy_batch = \ self._get_input_dict_and_dummy_batch( self.view_requirements, - {SampleBatch.ACTIONS: action_ph, - SampleBatch.PREV_ACTIONS: prev_action_ph}) + dict({SampleBatch.ACTIONS: action_ph}, + **prev_action_ph)) else: + prev_action_ph = ModelCatalog.get_action_placeholder( + action_space, "prev_action") self._input_dict = { SampleBatch.CUR_OBS: tf1.placeholder( tf.float32,
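# --- Illustrative sketch, not part of the patch: how the new 240x320 branch in
# --- the models/utils.py hunk above matches VizdoomGym observations regardless
# --- of channel ordering. `is_240x320` is a hypothetical helper mirroring the
# --- shape check in get_filter_config().
def is_240x320(shape):
    shape = list(shape)
    return len(shape) in (2, 3) and (shape[:2] == [240, 320]
                                     or shape[1:] == [240, 320])

assert is_240x320((240, 320, 3))      # channels-last
assert is_240x320((3, 240, 320))      # channels-first
assert not is_240x320((84, 84, 3))    # falls through to the Atari filters
# --- End of sketch.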