Refactor #8792 to integrate latest master (#8956)

This commit is contained in:
Joseph Suarez 2020-06-17 04:55:52 -04:00 committed by GitHub
parent 9f0f542660
commit c6ee3cdff4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 141 additions and 1 deletions

View file

@ -1,4 +1,5 @@
from datetime import datetime
import numpy as np
import copy
import logging
import math
@ -20,6 +21,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.rllib.utils.spaces import space_utils
from ray.rllib.utils.framework import try_import_tf, TensorStructType
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
@ -826,6 +828,95 @@ class Trainer(Trainable):
else:
return result[0] # backwards compatibility
def compute_actions(self,
observations,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id=DEFAULT_POLICY_ID,
full_fetch=False,
explore=None):
"""Computes an action for the specified policy on the local Worker.
Note that you can also access the policy object through
self.get_policy(policy_id) and call compute_actions() on it directly.
Arguments:
observation (obj): observation from the environment.
state (dict): RNN hidden state, if any. If state is not None,
then all of compute_single_action(...) is returned
(computed action, rnn state(s), logits dictionary).
Otherwise compute_single_action(...)[0] is returned
(computed action).
prev_action (obj): previous action value, if any
prev_reward (int): previous reward, if any
info (dict): info object, if any
policy_id (str): Policy to query (only applies to multi-agent).
full_fetch (bool): Whether to return extra action fetch results.
This is always set to True if RNN state is specified.
explore (bool): Whether to pick an exploitation or exploration
action (default: None -> use self.config["explore"]).
Returns:
any: The computed action if full_fetch=False, or
tuple: The full output of policy.compute_actions() if
full_fetch=True or we have an RNN-based Policy.
"""
# Preprocess obs and states
stateDefined = state is not None
policy = self.get_policy(policy_id)
filtered_obs, filtered_state = [], []
for agent_id, ob in observations.items():
worker = self.workers.local_worker()
preprocessed = worker.preprocessors[policy_id].transform(ob)
filtered = worker.filters[policy_id](preprocessed, update=False)
filtered_obs.append(filtered)
if state is None:
continue
elif agent_id in state:
filtered_state.append(state[agent_id])
else:
filtered_state.append(policy.get_initial_state())
# Batch obs and states
obs_batch = np.stack(filtered_obs)
if state is None:
state = []
else:
state = list(zip(*filtered_state))
state = [np.stack(s) for s in state]
# Figure out the current (sample) time step and pass it into Policy.
self.global_vars["timestep"] += 1
# Batch compute actions
actions, states, infos = policy.compute_actions(
obs_batch,
state,
prev_action,
prev_reward,
info,
clip_actions=self.config["clip_actions"],
explore=explore,
timestep=self.global_vars["timestep"])
# Unbatch actions for the environment
atns, actions = space_utils.unbatch(actions), {}
for key, atn in zip(observations, atns):
actions[key] = atn
# Unbatch states into a dict
unbatched_states = {}
for idx, agent_id in enumerate(observations):
unbatched_states[agent_id] = [s[idx] for s in states]
# Return only actions or full tuple
if stateDefined or full_fetch:
return actions, unbatched_states, infos
else:
return actions
@property
def _name(self) -> str:
"""Subclasses should override this to declare their name."""

View file

@ -0,0 +1,48 @@
import gym
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class FlexDict(gym.spaces.Dict):
"""Gym Dictionary with arbitrary keys updatable after instantiation
Example:
space = FlexDict({})
space['key'] = spaces.Box(4,)
See also: documentation for gym.spaces.Dict
"""
def __init__(self, spaces=None, **spaces_kwargs):
err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
assert (spaces is None) or (not spaces_kwargs), err
if spaces is None:
spaces = spaces_kwargs
self.spaces = spaces
for space in spaces.values():
self.assertSpace(space)
# None for shape and dtype, since it'll require special handling
self.np_random = None
self.shape = None
self.dtype = None
self.seed()
def assertSpace(self, space):
err = "Values of the dict should be instances of gym.Space"
assert issubclass(type(space), gym.spaces.Space), err
def sample(self):
return {k: space.sample() for k, space in self.spaces.items()}
def __getitem__(self, key):
return self.spaces[key]
def __setitem__(self, key, space):
self.assertSpace(space)
self.spaces[key] = space
def __repr__(self):
return "FlexDict(" + ", ".join(
[str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"

View file

@ -21,10 +21,11 @@ def flatten_space(space):
"""
def _helper_flatten(space_, l):
from ray.rllib.utils.spaces.flexdict import FlexDict
if isinstance(space_, Tuple):
for s in space_:
_helper_flatten(s, l)
elif isinstance(space_, Dict):
elif isinstance(space_, (Dict, FlexDict)):
for k in space_.spaces:
_helper_flatten(space_[k], l)
else: