mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent 9f0f542660
commit c6ee3cdff4
3 changed files with 141 additions and 1 deletion
@@ -1,4 +1,5 @@
 from datetime import datetime
+import numpy as np
 import copy
 import logging
 import math
@@ -20,6 +21,7 @@ from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.worker_set import WorkerSet
 from ray.rllib.utils import FilterManager, deep_update, merge_dicts
+from ray.rllib.utils.spaces import space_utils
 from ray.rllib.utils.framework import try_import_tf, TensorStructType
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
@@ -826,6 +828,95 @@ class Trainer(Trainable):
         else:
             return result[0]  # backwards compatibility
 
+    def compute_actions(self,
+                        observations,
+                        state=None,
+                        prev_action=None,
+                        prev_reward=None,
+                        info=None,
+                        policy_id=DEFAULT_POLICY_ID,
+                        full_fetch=False,
+                        explore=None):
+        """Computes actions for the specified policy on the local Worker.
+
+        Note that you can also access the policy object through
+        self.get_policy(policy_id) and call compute_actions() on it directly.
+
+        Arguments:
+            observations (dict): Observations from the environment, keyed by
+                agent id.
+            state (dict): RNN hidden states, if any, keyed by agent id. If
+                state is not None, the full tuple (computed actions, RNN
+                state(s), logits dictionary) is returned; otherwise only the
+                computed actions are returned.
+            prev_action (obj): Previous action value, if any.
+            prev_reward (int): Previous reward, if any.
+            info (dict): Info object, if any.
+            policy_id (str): Policy to query (only applies to multi-agent).
+            full_fetch (bool): Whether to return extra action fetch results.
+                This is always set to True if RNN state is specified.
+            explore (bool): Whether to pick an exploitation or exploration
+                action (default: None -> use self.config["explore"]).
+
+        Returns:
+            any: The computed actions if full_fetch=False, or
+            tuple: The full output of policy.compute_actions() if
+                full_fetch=True or we have an RNN-based Policy.
+        """
+        # Preprocess obs and states
+        stateDefined = state is not None
+        policy = self.get_policy(policy_id)
+        filtered_obs, filtered_state = [], []
+        for agent_id, ob in observations.items():
+            worker = self.workers.local_worker()
+            preprocessed = worker.preprocessors[policy_id].transform(ob)
+            filtered = worker.filters[policy_id](preprocessed, update=False)
+            filtered_obs.append(filtered)
+            if state is None:
+                continue
+            elif agent_id in state:
+                filtered_state.append(state[agent_id])
+            else:
+                filtered_state.append(policy.get_initial_state())
+
+        # Batch obs and states
+        obs_batch = np.stack(filtered_obs)
+        if state is None:
+            state = []
+        else:
+            state = list(zip(*filtered_state))
+            state = [np.stack(s) for s in state]
+
+        # Figure out the current (sample) time step and pass it into Policy.
+        self.global_vars["timestep"] += 1
+
+        # Batch compute actions
+        actions, states, infos = policy.compute_actions(
+            obs_batch,
+            state,
+            prev_action,
+            prev_reward,
+            info,
+            clip_actions=self.config["clip_actions"],
+            explore=explore,
+            timestep=self.global_vars["timestep"])
+
+        # Unbatch actions for the environment
+        atns, actions = space_utils.unbatch(actions), {}
+        for key, atn in zip(observations, atns):
+            actions[key] = atn
+
+        # Unbatch states into a dict
+        unbatched_states = {}
+        for idx, agent_id in enumerate(observations):
+            unbatched_states[agent_id] = [s[idx] for s in states]
+
+        # Return only actions or full tuple
+        if stateDefined or full_fetch:
+            return actions, unbatched_states, infos
+        else:
+            return actions
+
     @property
     def _name(self) -> str:
         """Subclasses should override this to declare their name."""
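
For context, a minimal usage sketch of the compute_actions() method added above. It is not part of this commit; the PPOTrainer class, the "CartPole-v0" environment, and the "agent_0" key are illustrative assumptions.

# Hedged usage sketch (not from this commit); names are illustrative only.
import gym
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(env="CartPole-v0", config={"num_workers": 0})

env = gym.make("CartPole-v0")
obs = env.reset()

# compute_actions() takes a dict of observations keyed by agent id and
# returns a dict of actions keyed the same way; with state=None and
# full_fetch=False only the actions dict is returned.
actions = trainer.compute_actions({"agent_0": obs})
obs, reward, done, info = env.step(actions["agent_0"])
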
rllib/utils/spaces/flexdict.py (new file, 48 lines)
@ -0,0 +1,48 @@
|
||||||
|
import gym
|
||||||
|
|
||||||
|
from ray.rllib.utils.annotations import PublicAPI
|
||||||
|
|
||||||
|
|
||||||
|
@PublicAPI
|
||||||
|
class FlexDict(gym.spaces.Dict):
|
||||||
|
"""Gym Dictionary with arbitrary keys updatable after instantiation
|
||||||
|
|
||||||
|
Example:
|
||||||
|
space = FlexDict({})
|
||||||
|
space['key'] = spaces.Box(4,)
|
||||||
|
See also: documentation for gym.spaces.Dict
|
||||||
|
"""
|
||||||
|
def __init__(self, spaces=None, **spaces_kwargs):
|
||||||
|
err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
|
||||||
|
assert (spaces is None) or (not spaces_kwargs), err
|
||||||
|
|
||||||
|
if spaces is None:
|
||||||
|
spaces = spaces_kwargs
|
||||||
|
|
||||||
|
self.spaces = spaces
|
||||||
|
for space in spaces.values():
|
||||||
|
self.assertSpace(space)
|
||||||
|
|
||||||
|
# None for shape and dtype, since it'll require special handling
|
||||||
|
self.np_random = None
|
||||||
|
self.shape = None
|
||||||
|
self.dtype = None
|
||||||
|
self.seed()
|
||||||
|
|
||||||
|
def assertSpace(self, space):
|
||||||
|
err = "Values of the dict should be instances of gym.Space"
|
||||||
|
assert issubclass(type(space), gym.spaces.Space), err
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
return {k: space.sample() for k, space in self.spaces.items()}
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self.spaces[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, space):
|
||||||
|
self.assertSpace(space)
|
||||||
|
self.spaces[key] = space
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "FlexDict(" + ", ".join(
|
||||||
|
[str(k) + ":" + str(s) for k, s in self.spaces.items()]) + ")"
|
|
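
A short usage sketch of the new FlexDict space (not part of the diff); the subspace names and Box/Discrete parameters below are made-up examples.

# Illustrative sketch only; keys and bounds are assumptions.
import numpy as np
from gym.spaces import Box, Discrete
from ray.rllib.utils.spaces.flexdict import FlexDict

space = FlexDict({})
# Unlike gym.spaces.Dict, subspaces may be added after instantiation.
space["position"] = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
space["mode"] = Discrete(4)

sample = space.sample()  # dict with one entry per registered subspace
assert space["mode"].contains(sample["mode"])
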
@@ -21,10 +21,11 @@ def flatten_space(space):
     """
 
     def _helper_flatten(space_, l):
+        from ray.rllib.utils.spaces.flexdict import FlexDict
         if isinstance(space_, Tuple):
             for s in space_:
                 _helper_flatten(s, l)
-        elif isinstance(space_, Dict):
+        elif isinstance(space_, (Dict, FlexDict)):
             for k in space_.spaces:
                 _helper_flatten(space_[k], l)
         else:
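
To show what the extra isinstance case buys, a hedged sketch follows; it assumes flatten_space is the public helper in ray.rllib.utils.spaces.space_utils (consistent with the space_utils import added to the trainer above) and is not taken from the commit itself.

# Sketch under the assumption that flatten_space lives in
# ray.rllib.utils.spaces.space_utils.
from gym.spaces import Box, Discrete
from ray.rllib.utils.spaces.flexdict import FlexDict
from ray.rllib.utils.spaces.space_utils import flatten_space

space = FlexDict({})
space["obs"] = Box(low=0.0, high=1.0, shape=(2,))
space["task"] = Discrete(3)

# With FlexDict named in the isinstance check, its subspaces are flattened
# individually, yielding a flat list such as [Box(2,), Discrete(3)].
print(flatten_space(space))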