diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py
index 07b0ecddc..0ab96e8d9 100644
--- a/rllib/agents/trainer.py
+++ b/rllib/agents/trainer.py
@@ -1,4 +1,5 @@
 from datetime import datetime
+import numpy as np
 import copy
 import logging
 import math
@@ -20,6 +21,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.utils import FilterManager, deep_update, merge_dicts
+from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.framework import try_import_tf, TensorStructType
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
@@ -826,6 +828,95 @@ class Trainer(Trainable):
         else:
             return result[0]  # backwards compatibility
 
+    def compute_actions(self,
+                        observations,
+                        state=None,
+                        prev_action=None,
+                        prev_reward=None,
+                        info=None,
+                        policy_id=DEFAULT_POLICY_ID,
+                        full_fetch=False,
+                        explore=None):
+        """Computes actions for the specified policy on the local Worker.
+
+        Note that you can also access the policy object through
+        self.get_policy(policy_id) and call compute_actions() on it directly.
+
+        Arguments:
+            observations (dict): Observations from the environment, keyed
+                by agent ID.
+            state (dict): RNN hidden state, if any. If state is not None,
+                then all of compute_single_action(...) is returned
+                (computed action, rnn state(s), logits dictionary).
+                Otherwise compute_single_action(...)[0] is returned
+                (computed action).
+            prev_action (obj): Previous action value, if any.
+            prev_reward (float): Previous reward, if any.
+            info (dict): Info object, if any.
+            policy_id (str): Policy to query (only applies to multi-agent).
+            full_fetch (bool): Whether to return extra action fetch results.
+                This is always set to True if RNN state is specified.
+            explore (bool): Whether to pick an exploitation or exploration
+                action (default: None -> use self.config["explore"]).
+
+        Returns:
+            any: The computed action if full_fetch=False, or
+            tuple: The full output of policy.compute_actions() if
+                full_fetch=True or we have an RNN-based Policy.
+        """
+        # Preprocess obs and states
+        stateDefined = state is not None
+        policy = self.get_policy(policy_id)
+        filtered_obs, filtered_state = [], []
+        for agent_id, ob in observations.items():
+            worker = self.workers.local_worker()
+            preprocessed = worker.preprocessors[policy_id].transform(ob)
+            filtered = worker.filters[policy_id](preprocessed, update=False)
+            filtered_obs.append(filtered)
+            if state is None:
+                continue
+            elif agent_id in state:
+                filtered_state.append(state[agent_id])
+            else:
+                filtered_state.append(policy.get_initial_state())
+
+        # Batch obs and states
+        obs_batch = np.stack(filtered_obs)
+        if state is None:
+            state = []
+        else:
+            state = list(zip(*filtered_state))
+            state = [np.stack(s) for s in state]
+
+        # Figure out the current (sample) time step and pass it into Policy.
+        self.global_vars["timestep"] += 1
+
+        # Batch compute actions
+        actions, states, infos = policy.compute_actions(
+            obs_batch,
+            state,
+            prev_action,
+            prev_reward,
+            info,
+            clip_actions=self.config["clip_actions"],
+            explore=explore,
+            timestep=self.global_vars["timestep"])
+
+        # Unbatch actions for the environment
+        atns, actions = space_utils.unbatch(actions), {}
+        for key, atn in zip(observations, atns):
+            actions[key] = atn
+
+        # Unbatch states into a dict
+        unbatched_states = {}
+        for idx, agent_id in enumerate(observations):
+            unbatched_states[agent_id] = [s[idx] for s in states]
+
+        # Return only actions or full tuple
+        if stateDefined or full_fetch:
+            return actions, unbatched_states, infos
+        else:
+            return actions
+
     @property
     def _name(self) -> str:
         """Subclasses should override this to declare their name."""
diff --git a/rllib/utils/spaces/flexdict.py b/rllib/utils/spaces/flexdict.py
new file mode 100644
index 000000000..f9bf5fdcd
--- /dev/null
+++ b/rllib/utils/spaces/flexdict.py
@@ -0,0 +1,48 @@
+import gym
+
+from ray.rllib.utils.annotations import PublicAPI
+
+
+@PublicAPI
+class FlexDict(gym.spaces.Dict):
+    """Gym Dictionary with arbitrary keys updatable after instantiation.
+
+    Example:
+       space = FlexDict({})
+       space['key'] = spaces.Box(low=-1.0, high=1.0, shape=(4, ))
+    See also: documentation for gym.spaces.Dict
+    """
+    def __init__(self, spaces=None, **spaces_kwargs):
+        err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
+        assert (spaces is None) or (not spaces_kwargs), err
+
+        if spaces is None:
+            spaces = spaces_kwargs
+
+        self.spaces = spaces
+        for space in spaces.values():
+            self.assertSpace(space)
+
+        # None for shape and dtype, since it'll require special handling
+        self.np_random = None
+        self.shape = None
+        self.dtype = None
+        self.seed()
+
+    def assertSpace(self, space):
+        err = "Values of the dict should be instances of gym.Space"
+        assert issubclass(type(space), gym.spaces.Space), err
+
+    def sample(self):
+        return {k: space.sample() for k, space in self.spaces.items()}
+
+    def __getitem__(self, key):
+        return self.spaces[key]
+
+    def __setitem__(self, key, space):
+        self.assertSpace(space)
+        self.spaces[key] = space
+
+    def __repr__(self):
+        return "FlexDict(" + ", ".join(
+            [str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py
index 8111723e3..8cdb7d3b6 100644
--- a/rllib/utils/spaces/space_utils.py
+++ b/rllib/utils/spaces/space_utils.py
@@ -21,10 +21,11 @@ def flatten_space(space):
     """
 
     def _helper_flatten(space_, l):
+        from ray.rllib.utils.spaces.flexdict import FlexDict
         if isinstance(space_, Tuple):
             for s in space_:
                 _helper_flatten(s, l)
-        elif isinstance(space_, Dict):
+        elif isinstance(space_, (Dict, FlexDict)):
             for k in space_.spaces:
                 _helper_flatten(space_[k], l)
         else:
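Example usage (illustrative only, not part of the diff): the sketch below exercises both additions, FlexDict's ability to accept new sub-spaces after construction and the new dict-in/dict-out Trainer.compute_actions() API. PPOTrainer, CartPole-v0, and the agent IDs are placeholder choices for demonstration, not anything prescribed by this change.

import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.utils.spaces.flexdict import FlexDict

# FlexDict: sub-spaces can be added after the space has been created,
# which a plain gym.spaces.Dict does not allow.
space = FlexDict({})
space["pos"] = gym.spaces.Box(low=-1.0, high=1.0, shape=(3, ))
space["vel"] = gym.spaces.Box(low=-1.0, high=1.0, shape=(3, ))
print(space.sample())  # {"pos": array([...]), "vel": array([...])}

# Trainer.compute_actions(): pass a dict of observations keyed by an
# arbitrary (agent) ID and get back a dict of actions, computed in one
# batched forward pass on the local worker's policy.
ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})
env = gym.make("CartPole-v0")
obs = {"agent_0": env.reset(), "agent_1": env.reset()}
actions = trainer.compute_actions(obs)
print(actions)  # e.g. {"agent_0": 1, "agent_1": 0}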