mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
parent
9f0f542660
commit
c6ee3cdff4
3 changed files with 141 additions and 1 deletions
|
@ -1,4 +1,5 @@
|
|||
from datetime import datetime
|
||||
import numpy as np
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
|
@ -20,6 +21,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
|
|||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.worker_set import WorkerSet
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.rllib.utils.spaces import space_utils
|
||||
from ray.rllib.utils.framework import try_import_tf, TensorStructType
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
|
||||
|
@ -826,6 +828,95 @@ class Trainer(Trainable):
|
|||
else:
|
||||
return result[0] # backwards compatibility
|
||||
|
||||
def compute_actions(self,
|
||||
observations,
|
||||
state=None,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
policy_id=DEFAULT_POLICY_ID,
|
||||
full_fetch=False,
|
||||
explore=None):
|
||||
"""Computes an action for the specified policy on the local Worker.
|
||||
|
||||
Note that you can also access the policy object through
|
||||
self.get_policy(policy_id) and call compute_actions() on it directly.
|
||||
|
||||
Arguments:
|
||||
observation (obj): observation from the environment.
|
||||
state (dict): RNN hidden state, if any. If state is not None,
|
||||
then all of compute_single_action(...) is returned
|
||||
(computed action, rnn state(s), logits dictionary).
|
||||
Otherwise compute_single_action(...)[0] is returned
|
||||
(computed action).
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (int): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
policy_id (str): Policy to query (only applies to multi-agent).
|
||||
full_fetch (bool): Whether to return extra action fetch results.
|
||||
This is always set to True if RNN state is specified.
|
||||
explore (bool): Whether to pick an exploitation or exploration
|
||||
action (default: None -> use self.config["explore"]).
|
||||
|
||||
Returns:
|
||||
any: The computed action if full_fetch=False, or
|
||||
tuple: The full output of policy.compute_actions() if
|
||||
full_fetch=True or we have an RNN-based Policy.
|
||||
"""
|
||||
# Preprocess obs and states
|
||||
stateDefined = state is not None
|
||||
policy = self.get_policy(policy_id)
|
||||
filtered_obs, filtered_state = [], []
|
||||
for agent_id, ob in observations.items():
|
||||
worker = self.workers.local_worker()
|
||||
preprocessed = worker.preprocessors[policy_id].transform(ob)
|
||||
filtered = worker.filters[policy_id](preprocessed, update=False)
|
||||
filtered_obs.append(filtered)
|
||||
if state is None:
|
||||
continue
|
||||
elif agent_id in state:
|
||||
filtered_state.append(state[agent_id])
|
||||
else:
|
||||
filtered_state.append(policy.get_initial_state())
|
||||
|
||||
# Batch obs and states
|
||||
obs_batch = np.stack(filtered_obs)
|
||||
if state is None:
|
||||
state = []
|
||||
else:
|
||||
state = list(zip(*filtered_state))
|
||||
state = [np.stack(s) for s in state]
|
||||
|
||||
# Figure out the current (sample) time step and pass it into Policy.
|
||||
self.global_vars["timestep"] += 1
|
||||
|
||||
# Batch compute actions
|
||||
actions, states, infos = policy.compute_actions(
|
||||
obs_batch,
|
||||
state,
|
||||
prev_action,
|
||||
prev_reward,
|
||||
info,
|
||||
clip_actions=self.config["clip_actions"],
|
||||
explore=explore,
|
||||
timestep=self.global_vars["timestep"])
|
||||
|
||||
# Unbatch actions for the environment
|
||||
atns, actions = space_utils.unbatch(actions), {}
|
||||
for key, atn in zip(observations, atns):
|
||||
actions[key] = atn
|
||||
|
||||
# Unbatch states into a dict
|
||||
unbatched_states = {}
|
||||
for idx, agent_id in enumerate(observations):
|
||||
unbatched_states[agent_id] = [s[idx] for s in states]
|
||||
|
||||
# Return only actions or full tuple
|
||||
if stateDefined or full_fetch:
|
||||
return actions, unbatched_states, infos
|
||||
else:
|
||||
return actions
|
||||
|
||||
@property
|
||||
def _name(self) -> str:
|
||||
"""Subclasses should override this to declare their name."""
|
||||
|
|
48
rllib/utils/spaces/flexdict.py
Normal file
48
rllib/utils/spaces/flexdict.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import gym
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class FlexDict(gym.spaces.Dict):
|
||||
"""Gym Dictionary with arbitrary keys updatable after instantiation
|
||||
|
||||
Example:
|
||||
space = FlexDict({})
|
||||
space['key'] = spaces.Box(4,)
|
||||
See also: documentation for gym.spaces.Dict
|
||||
"""
|
||||
def __init__(self, spaces=None, **spaces_kwargs):
|
||||
err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||
assert (spaces is None) or (not spaces_kwargs), err
|
||||
|
||||
if spaces is None:
|
||||
spaces = spaces_kwargs
|
||||
|
||||
self.spaces = spaces
|
||||
for space in spaces.values():
|
||||
self.assertSpace(space)
|
||||
|
||||
# None for shape and dtype, since it'll require special handling
|
||||
self.np_random = None
|
||||
self.shape = None
|
||||
self.dtype = None
|
||||
self.seed()
|
||||
|
||||
def assertSpace(self, space):
|
||||
err = "Values of the dict should be instances of gym.Space"
|
||||
assert issubclass(type(space), gym.spaces.Space), err
|
||||
|
||||
def sample(self):
|
||||
return {k: space.sample() for k, space in self.spaces.items()}
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.spaces[key]
|
||||
|
||||
def __setitem__(self, key, space):
|
||||
self.assertSpace(space)
|
||||
self.spaces[key] = space
|
||||
|
||||
def __repr__(self):
|
||||
return "FlexDict(" + ", ".join(
|
||||
[str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
|
|
@ -21,10 +21,11 @@ def flatten_space(space):
|
|||
"""
|
||||
|
||||
def _helper_flatten(space_, l):
|
||||
from ray.rllib.utils.spaces.flexdict import FlexDict
|
||||
if isinstance(space_, Tuple):
|
||||
for s in space_:
|
||||
_helper_flatten(s, l)
|
||||
elif isinstance(space_, Dict):
|
||||
elif isinstance(space_, (Dict, FlexDict)):
|
||||
for k in space_.spaces:
|
||||
_helper_flatten(space_[k], l)
|
||||
else:
|
||||
|
|
Loading…
Add table
Reference in a new issue