Mirror of https://github.com/vale981/ray, synced 2025-03-05 18:11:42 -05:00
[rllib] annotate public vs developer vs private APIs (#3808)
Parent: 01e18b47f4
Commit: 04ec47cbd4
45 changed files with 562 additions and 274 deletions
@@ -6,6 +6,15 @@ Development Install

You can develop RLlib locally without needing to compile Ray by using the `setup-rllib-dev.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/setup-rllib-dev.py>`__ script. This sets up links between the ``rllib`` dir in your git repo and the one bundled with the ``ray`` package. When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up to date on `master <https://github.com/ray-project/ray>`__ and have the latest `wheel <https://ray.readthedocs.io/en/latest/installation.html>`__ installed).
API Stability
-------------

Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the following API compatibility guarantees:

.. autofunction:: ray.rllib.utils.annotations.PublicAPI

.. autofunction:: ray.rllib.utils.annotations.DeveloperAPI
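For orientation, these annotations are plain marker decorators: they return the object unchanged and exist only to document the stability guarantee. A minimal sketch of the idea (the real definitions live in ``ray.rllib.utils.annotations`` and may differ; ``ExampleAgent`` is purely illustrative)::

    def PublicAPI(obj):
        # Marker only: public classes/methods are intended to stay stable
        # across RLlib releases once they appear in a release.
        return obj

    def DeveloperAPI(obj):
        # Marker only: developer APIs are lower-level building blocks that may
        # change between releases but are still documented for advanced users.
        return obj

    @PublicAPI
    class ExampleAgent(object):
        @DeveloperAPI
        def _init(self):
            # internal setup hook overridden by subclasses
            pass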
Features
--------
@@ -310,6 +310,6 @@ Note that envs can read from different partitions of the logs based on the ``wor

Batch Asynchronous
------------------
The lowest-level "catch-all" environment supported by RLlib is `AsyncVectorEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/async_vector_env.py>`__. AsyncVectorEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.
The lowest-level "catch-all" environment supported by RLlib is `BaseEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/base_env.py>`__. BaseEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.

Under the hood, all other envs are converted to AsyncVectorEnv by RLlib so that there is a common internal path for policy evaluation.
Under the hood, all other envs are converted to BaseEnv by RLlib so that there is a common internal path for policy evaluation.
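As a rough usage sketch of the polling interface described above (``my_env`` and ``my_policy_fn`` are placeholders for a user-provided env and action function)::

    from ray.rllib.env.base_env import BaseEnv

    env = BaseEnv.to_base_env(my_env)  # wraps gym.Env, VectorEnv, MultiAgentEnv, ...
    while True:
        # dicts keyed by environment id, then by agent id
        obs, rewards, dones, infos, off_policy_actions = env.poll()
        actions = {
            env_id: {agent_id: my_policy_fn(ob) for agent_id, ob in agent_obs.items()}
            for env_id, agent_obs in obs.items()
        }
        env.send_actions(actions)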
File diff suppressed because one or more lines are too long

Image changed: 77 KiB before, 75 KiB after
@@ -10,7 +10,7 @@ from ray.tune.registry import register_trainable

from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.external_env import ExternalEnv

@@ -47,7 +47,7 @@ __all__ = [

"TFPolicyGraph",
"PolicyEvaluator",
"SampleBatch",
"AsyncVectorEnv",
"BaseEnv",
"MultiAgentEnv",
"VectorEnv",
"ExternalEnv",
|
|
|
@ -18,7 +18,7 @@ from ray.rllib.models import MODEL_DEFAULTS
|
|||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
|
@ -182,6 +182,7 @@ COMMON_CONFIG = {
|
|||
# yapf: enable
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def with_common_config(extra_config):
|
||||
"""Returns the given config dict merged with common agent confs."""
|
||||
|
||||
|
@ -196,6 +197,7 @@ def with_base_config(base_config, extra_config):
|
|||
return config
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class Agent(Trainable):
|
||||
"""All RLlib agents extend this base class.
|
||||
|
||||
|
@ -214,6 +216,7 @@ class Agent(Trainable):
|
|||
"custom_resources_per_worker"
|
||||
]
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
|
@ -266,6 +269,7 @@ class Agent(Trainable):
|
|||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
@override(Trainable)
|
||||
@PublicAPI
|
||||
def train(self):
|
||||
"""Overrides super.train to synchronize global vars."""
|
||||
|
||||
|
@ -344,11 +348,13 @@ class Agent(Trainable):
|
|||
extra_data = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self):
|
||||
"""Subclasses should override this for custom initialization."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def compute_action(self,
|
||||
observation,
|
||||
state=None,
|
||||
|
@ -404,6 +410,7 @@ class Agent(Trainable):
|
|||
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
|
@ -413,6 +420,7 @@ class Agent(Trainable):
|
|||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
|
||||
@PublicAPI
|
||||
def get_weights(self, policies=None):
|
||||
"""Return a dictionary of policy ids to weights.
|
||||
|
||||
|
@ -422,6 +430,7 @@ class Agent(Trainable):
|
|||
"""
|
||||
return self.local_evaluator.get_weights(policies)
|
||||
|
||||
@PublicAPI
|
||||
def set_weights(self, weights):
|
||||
"""Set policy weights by policy id.
|
||||
|
||||
|
@ -430,6 +439,7 @@ class Agent(Trainable):
|
|||
"""
|
||||
self.local_evaluator.set_weights(weights)
|
||||
|
||||
@DeveloperAPI
|
||||
def make_local_evaluator(self, env_creator, policy_graph):
|
||||
"""Convenience method to return configured local evaluator."""
|
||||
|
||||
|
@ -444,6 +454,7 @@ class Agent(Trainable):
|
|||
config["local_evaluator_tf_session_args"]
|
||||
}))
|
||||
|
||||
@DeveloperAPI
|
||||
def make_remote_evaluators(self, env_creator, policy_graph, count):
|
||||
"""Convenience method to return a number of remote evaluators."""
|
||||
|
||||
|
@ -459,6 +470,7 @@ class Agent(Trainable):
|
|||
self.config) for i in range(count)
|
||||
]
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Export policy model with given policy_id to local directory.
|
||||
|
||||
|
@ -474,6 +486,7 @@ class Agent(Trainable):
|
|||
"""
|
||||
self.local_evaluator.export_policy_model(export_dir, policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
|
||||
export_dir,
|
||||
filename_prefix="model",
|
||||
|
|
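Taken together, the ``@PublicAPI`` methods annotated in ``agent.py`` above form the user-facing surface of ``Agent``. A hedged usage sketch (``PPOAgent`` is one concrete agent available in RLlib at the time; the config values and the 4-element observation are illustrative)::

    import ray
    from ray.rllib.agents.ppo import PPOAgent

    ray.init()
    agent = PPOAgent(config={"num_workers": 2}, env="CartPole-v0")
    for _ in range(3):
        print(agent.train())                                 # @PublicAPI
    action = agent.compute_action([0.0, 0.0, 0.0, 0.0])      # @PublicAPI
    weights = agent.get_weights()                            # @PublicAPI
    agent.set_weights(weights)                               # @PublicAPI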
python/ray/rllib/env/__init__.py (6 changes, vendored)

@@ -1,4 +1,4 @@

from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.serving_env import ServingEnv

@@ -6,6 +6,6 @@ from ray.rllib.env.vector_env import VectorEnv

from ray.rllib.env.env_context import EnvContext

__all__ = [
"AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv",
"ServingEnv", "EnvContext"
"BaseEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv", "ServingEnv",
"EnvContext"
]
|
|
@ -5,23 +5,24 @@ from __future__ import print_function
|
|||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
|
||||
class AsyncVectorEnv(object):
|
||||
@PublicAPI
|
||||
class BaseEnv(object):
|
||||
"""The lowest-level env interface used by RLlib for sampling.
|
||||
|
||||
AsyncVectorEnv models multiple agents executing asynchronously in multiple
|
||||
BaseEnv models multiple agents executing asynchronously in multiple
|
||||
environments. A call to poll() returns observations from ready agents
|
||||
keyed by their environment and agent ids, and actions for those agents
|
||||
can be sent back via send_actions().
|
||||
|
||||
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
|
||||
All other env types can be adapted to BaseEnv. RLlib handles these
|
||||
conversions internally in PolicyEvaluator, for example:
|
||||
|
||||
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
|
||||
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
|
||||
rllib.ExternalEnv => rllib.AsyncVectorEnv
|
||||
gym.Env => rllib.VectorEnv => rllib.BaseEnv
|
||||
rllib.MultiAgentEnv => rllib.BaseEnv
|
||||
rllib.ExternalEnv => rllib.BaseEnv
|
||||
|
||||
Attributes:
|
||||
action_space (gym.Space): Action space. This must be defined for
|
||||
|
@ -30,7 +31,7 @@ class AsyncVectorEnv(object):
|
|||
for single-agent envs. Multi-agent envs can set this to None.
|
||||
|
||||
Examples:
|
||||
>>> env = MyAsyncVectorEnv()
|
||||
>>> env = MyBaseEnv()
|
||||
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
|
||||
>>> print(obs)
|
||||
{
|
||||
|
@ -65,26 +66,27 @@ class AsyncVectorEnv(object):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def wrap_async(env, make_env=None, num_envs=1):
|
||||
def to_base_env(env, make_env=None, num_envs=1):
|
||||
"""Wraps any env type as needed to expose the async interface."""
|
||||
if not isinstance(env, AsyncVectorEnv):
|
||||
if not isinstance(env, BaseEnv):
|
||||
if isinstance(env, MultiAgentEnv):
|
||||
env = _MultiAgentEnvToAsync(
|
||||
env = _MultiAgentEnvToBaseEnv(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
elif isinstance(env, ExternalEnv):
|
||||
if num_envs != 1:
|
||||
raise ValueError(
|
||||
"ExternalEnv does not currently support num_envs > 1.")
|
||||
env = _ExternalEnvToAsync(env)
|
||||
env = _ExternalEnvToBaseEnv(env)
|
||||
elif isinstance(env, VectorEnv):
|
||||
env = _VectorEnvToAsync(env)
|
||||
env = _VectorEnvToBaseEnv(env)
|
||||
else:
|
||||
env = VectorEnv.wrap(
|
||||
make_env=make_env, existing_envs=[env], num_envs=num_envs)
|
||||
env = _VectorEnvToAsync(env)
|
||||
assert isinstance(env, AsyncVectorEnv)
|
||||
env = _VectorEnvToBaseEnv(env)
|
||||
assert isinstance(env, BaseEnv)
|
||||
return env
|
||||
|
||||
@PublicAPI
|
||||
def poll(self):
|
||||
"""Returns observations from ready agents.
|
||||
|
||||
|
@ -107,6 +109,7 @@ class AsyncVectorEnv(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def send_actions(self, action_dict):
|
||||
"""Called to send actions back to running agents in this env.
|
||||
|
||||
|
@ -118,6 +121,7 @@ class AsyncVectorEnv(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def try_reset(self, env_id):
|
||||
"""Attempt to reset the env with the given id.
|
||||
|
||||
|
@ -129,6 +133,7 @@ class AsyncVectorEnv(object):
|
|||
"""
|
||||
return None
|
||||
|
||||
@PublicAPI
|
||||
def get_unwrapped(self):
|
||||
"""Return a reference to the underlying gym envs, if any.
|
||||
|
||||
|
@ -146,8 +151,8 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
|
|||
return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}
|
||||
|
||||
|
||||
class _ExternalEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of ExternalEnv to AsyncVectorEnv."""
|
||||
class _ExternalEnvToBaseEnv(BaseEnv):
|
||||
"""Internal adapter of ExternalEnv to BaseEnv."""
|
||||
|
||||
def __init__(self, external_env, preprocessor=None):
|
||||
self.external_env = external_env
|
||||
|
@ -159,7 +164,7 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
|
|||
self.observation_space = external_env.observation_space
|
||||
external_env.start()
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def poll(self):
|
||||
with self.external_env._results_avail_condition:
|
||||
results = self._poll()
|
||||
|
@ -174,7 +179,7 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
|
|||
"ExternalEnv was created with max_concurrent={}".format(limit))
|
||||
return results
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def send_actions(self, action_dict):
|
||||
for eid, action in action_dict.items():
|
||||
self.external_env._episodes[eid].action_queue.put(
|
||||
|
@ -204,8 +209,8 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
|
|||
_with_dummy_agent_id(off_policy_actions)
|
||||
|
||||
|
||||
class _VectorEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of VectorEnv to AsyncVectorEnv.
|
||||
class _VectorEnvToBaseEnv(BaseEnv):
|
||||
"""Internal adapter of VectorEnv to BaseEnv.
|
||||
|
||||
We assume the caller will always send the full vector of actions in each
|
||||
call to send_actions(), and that they call reset_at() on all completed
|
||||
|
@ -222,7 +227,7 @@ class _VectorEnvToAsync(AsyncVectorEnv):
|
|||
self.cur_dones = [False for _ in range(self.num_envs)]
|
||||
self.cur_infos = [None for _ in range(self.num_envs)]
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def poll(self):
|
||||
if self.new_obs is None:
|
||||
self.new_obs = self.vector_env.vector_reset()
|
||||
|
@ -239,7 +244,7 @@ class _VectorEnvToAsync(AsyncVectorEnv):
|
|||
_with_dummy_agent_id(dones, "__all__"), \
|
||||
_with_dummy_agent_id(infos), {}
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def send_actions(self, action_dict):
|
||||
action_vector = [None] * self.num_envs
|
||||
for i in range(self.num_envs):
|
||||
|
@ -247,17 +252,17 @@ class _VectorEnvToAsync(AsyncVectorEnv):
|
|||
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
|
||||
self.vector_env.vector_step(action_vector)
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def try_reset(self, env_id):
|
||||
return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def get_unwrapped(self):
|
||||
return self.vector_env.get_unwrapped()
|
||||
|
||||
|
||||
class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
"""Internal adapter of MultiAgentEnv to AsyncVectorEnv.
|
||||
class _MultiAgentEnvToBaseEnv(BaseEnv):
|
||||
"""Internal adapter of MultiAgentEnv to BaseEnv.
|
||||
|
||||
This also supports vectorization if num_envs > 1.
|
||||
"""
|
||||
|
@ -282,14 +287,14 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
|||
assert isinstance(env, MultiAgentEnv)
|
||||
self.env_states = [_MultiAgentEnvState(env) for env in self.envs]
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def poll(self):
|
||||
obs, rewards, dones, infos = {}, {}, {}, {}
|
||||
for i, env_state in enumerate(self.env_states):
|
||||
obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
|
||||
return obs, rewards, dones, infos, {}
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def send_actions(self, action_dict):
|
||||
for env_id, agent_dict in action_dict.items():
|
||||
if env_id in self.dones:
|
||||
|
@ -311,7 +316,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
|||
self.dones.add(env_id)
|
||||
self.env_states[env_id].observe(obs, rewards, dones, infos)
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def try_reset(self, env_id):
|
||||
obs = self.env_states[env_id].reset()
|
||||
assert isinstance(obs, dict), "Not a multi-agent obs"
|
||||
|
@ -319,7 +324,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
|||
self.dones.remove(env_id)
|
||||
return obs
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
@override(BaseEnv)
|
||||
def get_unwrapped(self):
|
||||
return [state.env for state in self.env_states]
|
||||
|
python/ray/rllib/env/env_context.py (3 changes, vendored)

@@ -2,7 +2,10 @@ from __future__ import absolute_import

from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class EnvContext(dict):
"""Wraps env configurations to include extra rllib metadata.
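``EnvContext`` is the dict subclass that RLlib hands to each env constructor; beyond the user-supplied keys it carries per-worker metadata. A sketch of how an env creator might use it (``MyLogReplayEnv`` is hypothetical, and treat the exact attribute names such as ``worker_index`` as an assumption here)::

    from ray.tune.registry import register_env

    def env_creator(env_config):
        # env_config behaves like the user's config dict, plus metadata such as
        # env_config.worker_index identifying which evaluator created this env.
        return MyLogReplayEnv(partition=env_config.worker_index, **env_config)

    register_env("log_replay_env", env_creator)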
|
python/ray/rllib/env/external_env.py (10 changes, vendored)

@@ -6,7 +6,10 @@ from six.moves import queue

import threading
import uuid

from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class ExternalEnv(threading.Thread):
"""An environment that interfaces with external agents.

@@ -36,6 +39,7 @@ class ExternalEnv(threading.Thread):

print(agent.train())
"""

@PublicAPI
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize an external env.

@@ -57,6 +61,7 @@ class ExternalEnv(threading.Thread):

self._results_avail_condition = threading.Condition()
self._max_concurrent_episodes = max_concurrent

@PublicAPI
def run(self):
"""Override this to implement the run loop.

@@ -73,6 +78,7 @@ class ExternalEnv(threading.Thread):

"""
raise NotImplementedError

@PublicAPI
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.

@@ -102,6 +108,7 @@ class ExternalEnv(threading.Thread):

return episode_id

@PublicAPI
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.

@@ -116,6 +123,7 @@ class ExternalEnv(threading.Thread):

episode = self._get(episode_id)
return episode.wait_for_action(observation)

@PublicAPI
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.

@@ -128,6 +136,7 @@ class ExternalEnv(threading.Thread):

episode = self._get(episode_id)
episode.log_action(observation, action)

@PublicAPI
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.

@@ -146,6 +155,7 @@ class ExternalEnv(threading.Thread):

if info:
episode.cur_info = info or {}

@PublicAPI
def end_episode(self, episode_id, observation):
"""Record the end of an episode.
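The ``@PublicAPI`` episode methods above define the protocol for driving RLlib from an external simulator. A minimal ``run()`` loop might look like this sketch (``self.simulator`` with ``reset()``/``step()`` is a hypothetical handle to the external process)::

    from ray.rllib.env.external_env import ExternalEnv

    class MySimulatorEnv(ExternalEnv):
        def __init__(self, action_space, observation_space, simulator):
            ExternalEnv.__init__(self, action_space, observation_space)
            self.simulator = simulator

        def run(self):
            while True:
                episode_id = self.start_episode()
                obs = self.simulator.reset()
                done = False
                while not done:
                    action = self.get_action(episode_id, obs)
                    obs, reward, done = self.simulator.step(action)
                    self.log_returns(episode_id, reward)
                self.end_episode(episode_id, obs)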
python/ray/rllib/env/multi_agent_env.py (6 changes, vendored)

@@ -2,7 +2,10 @@ from __future__ import absolute_import

from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import PublicAPI


@PublicAPI
class MultiAgentEnv(object):
"""An environment that hosts multiple independent agents.

@@ -41,6 +44,7 @@ class MultiAgentEnv(object):

}
"""

@PublicAPI
def reset(self):
"""Resets the env and returns observations from ready agents.

@@ -49,6 +53,7 @@ class MultiAgentEnv(object):

"""
raise NotImplementedError

@PublicAPI
def step(self, action_dict):
"""Returns observations from ready agents.

@@ -68,6 +73,7 @@ class MultiAgentEnv(object):

# yapf: disable
# __grouping_doc_begin__
@PublicAPI
def with_agent_groups(self, groups, obs_space=None, act_space=None):
"""Convenience method for grouping together agents in this env.
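A minimal sketch of the per-agent dict convention that ``reset()`` and ``step()`` must follow (two hard-coded agent ids and a fixed 5-step horizon, purely illustrative)::

    from ray.rllib.env.multi_agent_env import MultiAgentEnv

    class TwoAgentEnv(MultiAgentEnv):
        def reset(self):
            self.t = 0
            return {"agent_1": 0.0, "agent_2": 0.0}      # obs keyed by agent id

        def step(self, action_dict):
            self.t += 1
            done = self.t >= 5
            obs = {aid: float(self.t) for aid in action_dict}
            rewards = {aid: 1.0 for aid in action_dict}
            dones = {aid: done for aid in action_dict}
            dones["__all__"] = done                      # signals episode end
            infos = {aid: {} for aid in action_dict}
            return obs, rewards, dones, infos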
python/ray/rllib/env/vector_env.py (7 changes, vendored)

@@ -2,9 +2,10 @@ from __future__ import absolute_import

from __future__ import division
from __future__ import print_function

from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI


@PublicAPI
class VectorEnv(object):
"""An environment that supports batch evaluation.

@@ -20,6 +21,7 @@ class VectorEnv(object):

def wrap(make_env=None, existing_envs=None, num_envs=1):
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)

@PublicAPI
def vector_reset(self):
"""Resets all environments.

@@ -28,6 +30,7 @@ class VectorEnv(object):

"""
raise NotImplementedError

@PublicAPI
def reset_at(self, index):
"""Resets a single environment.

@@ -36,6 +39,7 @@ class VectorEnv(object):

"""
raise NotImplementedError

@PublicAPI
def vector_step(self, actions):
"""Vectorized step.

@@ -50,6 +54,7 @@ class VectorEnv(object):

"""
raise NotImplementedError

@PublicAPI
def get_unwrapped(self):
"""Returns the underlying env instances."""
raise NotImplementedError
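A rough sketch of exposing a batch of plain ``gym`` environments through the ``VectorEnv`` interface annotated above (standing in for a real batched simulator; the attributes set in ``__init__`` mirror what the adapters appear to expect and should be treated as assumptions)::

    import gym
    from ray.rllib.env.vector_env import VectorEnv

    class GymVectorEnv(VectorEnv):
        def __init__(self, name, num_envs):
            self.envs = [gym.make(name) for _ in range(num_envs)]
            self.num_envs = num_envs
            self.action_space = self.envs[0].action_space
            self.observation_space = self.envs[0].observation_space

        def vector_reset(self):
            return [env.reset() for env in self.envs]

        def reset_at(self, index):
            return self.envs[index].reset()

        def vector_step(self, actions):
            results = [env.step(a) for env, a in zip(self.envs, actions)]
            obs, rewards, dones, infos = zip(*results)
            return list(obs), list(rewards), list(dones), list(infos)

        def get_unwrapped(self):
            return self.envs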
@@ -4,9 +4,9 @@ from ray.rllib.evaluation.interface import EvaluatorInterface

from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.evaluation.sample_batch import (SampleBatch, MultiAgentBatch,
SampleBatchBuilder,
MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.evaluation.sample_batch_builder import (
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
from ray.rllib.evaluation.postprocessing import (compute_advantages,
compute_targets)
|
@ -7,16 +7,14 @@ import random
|
|||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.env.async_vector_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class MultiAgentEpisode(object):
|
||||
"""Tracks the current state of a (possibly multi-agent) episode.
|
||||
|
||||
The APIs in this class should be considered experimental, but we should
|
||||
avoid changing things for the sake of changing them since users may
|
||||
depend on them for custom metrics or advanced algorithms.
|
||||
|
||||
Attributes:
|
||||
new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
|
||||
add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
|
||||
|
@ -66,6 +64,7 @@ class MultiAgentEpisode(object):
|
|||
self._agent_to_prev_action = {}
|
||||
self._agent_reward_history = defaultdict(list)
|
||||
|
||||
@DeveloperAPI
|
||||
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the policy graph for the specified agent.
|
||||
|
||||
|
@ -77,16 +76,19 @@ class MultiAgentEpisode(object):
|
|||
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
|
||||
return self._agent_to_policy[agent_id]
|
||||
|
||||
@DeveloperAPI
|
||||
def last_observation_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last observation for the specified agent."""
|
||||
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last info for the specified agent."""
|
||||
|
||||
return self._agent_to_last_info.get(agent_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last action for the specified agent, or zeros."""
|
||||
|
||||
|
@ -97,6 +99,7 @@ class MultiAgentEpisode(object):
|
|||
flat = _flatten_action(policy.action_space.sample())
|
||||
return np.zeros_like(flat)
|
||||
|
||||
@DeveloperAPI
|
||||
def prev_action_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the previous action for the specified agent."""
|
||||
|
||||
|
@ -106,6 +109,7 @@ class MultiAgentEpisode(object):
|
|||
# We're at t=0, so return all zeros.
|
||||
return np.zeros_like(self.last_action_for(agent_id))
|
||||
|
||||
@DeveloperAPI
|
||||
def prev_reward_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the previous reward for the specified agent."""
|
||||
|
||||
|
@ -116,6 +120,7 @@ class MultiAgentEpisode(object):
|
|||
# We're at t=0, so there is no previous reward, just return zero.
|
||||
return 0.0
|
||||
|
||||
@DeveloperAPI
|
||||
def rnn_state_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last RNN state for the specified agent."""
|
||||
|
||||
|
@ -124,6 +129,7 @@ class MultiAgentEpisode(object):
|
|||
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
|
||||
return self._agent_to_rnn_state[agent_id]
|
||||
|
||||
@DeveloperAPI
|
||||
def last_pi_info_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last info object for the specified agent."""
|
||||
|
||||
|
|
|
@@ -4,13 +4,17 @@ from __future__ import print_function

import os

from ray.rllib.utils.annotations import DeveloperAPI


@DeveloperAPI
class EvaluatorInterface(object):
"""This is the interface between policy optimizers and policy evaluation.

See also: PolicyEvaluator
"""

@DeveloperAPI
def sample(self):
"""Returns a batch of experience sampled from this evaluator.

@@ -27,6 +31,7 @@ class EvaluatorInterface(object):

raise NotImplementedError

@DeveloperAPI
def compute_gradients(self, samples):
"""Returns a gradient computed w.r.t the specified samples.

@@ -45,6 +50,7 @@ class EvaluatorInterface(object):

raise NotImplementedError

@DeveloperAPI
def apply_gradients(self, grads):
"""Applies the given gradients to this evaluator's weights.

@@ -58,6 +64,7 @@ class EvaluatorInterface(object):

raise NotImplementedError

@DeveloperAPI
def get_weights(self):
"""Returns the model weights of this Evaluator.

@@ -73,6 +80,7 @@ class EvaluatorInterface(object):

raise NotImplementedError

@DeveloperAPI
def set_weights(self, weights):
"""Sets the model weights of this Evaluator.

@@ -85,6 +93,7 @@ class EvaluatorInterface(object):

raise NotImplementedError

@DeveloperAPI
def compute_apply(self, samples):
"""Fused compute gradients and apply gradients call.

@@ -100,11 +109,13 @@ class EvaluatorInterface(object):

self.apply_gradients(grads)
return info

@DeveloperAPI
def get_host(self):
"""Returns the hostname of the process running this evaluator."""

return os.uname()[1]

@DeveloperAPI
def apply(self, func, *args):
"""Apply the given function to this evaluator instance."""
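The split into ``sample``, ``compute_gradients``, ``apply_gradients``, and the fused ``compute_apply`` is what lets policy optimizers decide where gradients are computed and applied. A sketch of one synchronous optimization step over Ray actor evaluators (``local_ev`` and ``remote_evs`` are assumed to be already-constructed ``PolicyEvaluator`` instances, the remote ones created via ``as_remote()``)::

    import ray

    def sync_step(local_ev, remote_evs):
        # Broadcast the current weights, collect experience remotely, then
        # do a fused gradient computation + update on the local evaluator.
        weights = ray.put(local_ev.get_weights())
        for ev in remote_evs:
            ev.set_weights.remote(weights)
        batches = ray.get([ev.sample.remote() for ev in remote_evs])
        info = None
        for batch in batches:
            info = local_ev.compute_apply(batch)
        return info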
|
@@ -8,10 +8,12 @@ import collections

import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI

logger = logging.getLogger(__name__)


@DeveloperAPI
def collect_metrics(local_evaluator, remote_evaluators=[],
timeout_seconds=180):
"""Gathers episode metrics from PolicyEvaluator instances."""

@@ -22,6 +24,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[],

return metrics


@DeveloperAPI
def collect_episodes(local_evaluator,
remote_evaluators=[],
timeout_seconds=180):

@@ -43,6 +46,7 @@ def collect_episodes(local_evaluator,

return episodes, num_metric_batches_dropped


@DeveloperAPI
def summarize_episodes(episodes, new_episodes, num_dropped):
"""Summarizes a set of episode metrics tuples.
|
@ -8,7 +8,7 @@ import pickle
|
|||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
@ -22,7 +22,7 @@ from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
|||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import NoPreprocessor
|
||||
from ray.rllib.utils import merge_dicts
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
@ -30,6 +30,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyEvaluator(EvaluatorInterface):
|
||||
"""Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
|
||||
|
||||
|
@ -83,11 +84,13 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
"traffic_light_policy": SampleBatch(...)})
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
@classmethod
|
||||
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
|
||||
return ray.remote(
|
||||
num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls)
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
|
@ -214,7 +217,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
|
||||
self.env = env_creator(env_context)
|
||||
if isinstance(self.env, MultiAgentEnv) or \
|
||||
isinstance(self.env, AsyncVectorEnv):
|
||||
isinstance(self.env, BaseEnv):
|
||||
|
||||
def wrap(env):
|
||||
return env # we can't auto-wrap these env types
|
||||
|
@ -275,7 +278,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
|
||||
if self.multiagent:
|
||||
if not (isinstance(self.env, MultiAgentEnv)
|
||||
or isinstance(self.env, AsyncVectorEnv)):
|
||||
or isinstance(self.env, BaseEnv)):
|
||||
raise ValueError(
|
||||
"Have multiple policy graphs {}, but the env ".format(
|
||||
self.policy_map) +
|
||||
|
@ -288,7 +291,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
}
|
||||
|
||||
# Always use vector env for consistency even if num_envs = 1
|
||||
self.async_env = AsyncVectorEnv.wrap_async(
|
||||
self.async_env = BaseEnv.to_base_env(
|
||||
self.env, make_env=make_env, num_envs=num_envs)
|
||||
self.num_envs = num_envs
|
||||
|
||||
|
@ -399,6 +402,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
|
||||
return batch
|
||||
|
||||
@DeveloperAPI
|
||||
@ray.method(num_return_vals=2)
|
||||
def sample_with_count(self):
|
||||
"""Same as sample() but returns the count as a separate future."""
|
||||
|
@ -489,6 +493,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
|
||||
return grad_fetch
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
|
||||
|
@ -498,16 +503,19 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
|
||||
return self.policy_map.get(policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Apply the given function to the specified policy graph."""
|
||||
|
||||
return func(self.policy_map[policy_id])
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_policy(self, func):
|
||||
"""Apply the given function to each (policy, policy_id) tuple."""
|
||||
|
||||
return [func(policy, pid) for pid, policy in self.policy_map.items()]
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_trainable_policy(self, func):
|
||||
"""Apply the given function to each (policy, policy_id) tuple.
|
||||
|
||||
|
@ -518,6 +526,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
if pid in self.policies_to_train
|
||||
]
|
||||
|
||||
@DeveloperAPI
|
||||
def sync_filters(self, new_filters):
|
||||
"""Changes self's filter to given and rebases any accumulated delta.
|
||||
|
||||
|
@ -528,6 +537,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
for k in self.filters:
|
||||
self.filters[k].sync(new_filters[k])
|
||||
|
||||
@DeveloperAPI
|
||||
def get_filters(self, flush_after=False):
|
||||
"""Returns a snapshot of filters.
|
||||
|
||||
|
@ -544,6 +554,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
f.clear_buffer()
|
||||
return return_filters
|
||||
|
||||
@DeveloperAPI
|
||||
def save(self):
|
||||
filters = self.get_filters(flush_after=True)
|
||||
state = {
|
||||
|
@ -552,18 +563,22 @@ class PolicyEvaluator(EvaluatorInterface):
|
|||
}
|
||||
return pickle.dumps({"filters": filters, "state": state})
|
||||
|
||||
@DeveloperAPI
|
||||
def restore(self, objs):
|
||||
objs = pickle.loads(objs)
|
||||
self.sync_filters(objs["filters"])
|
||||
for pid, state in objs["state"].items():
|
||||
self.policy_map[pid].set_state(state)
|
||||
|
||||
@DeveloperAPI
|
||||
def set_global_vars(self, global_vars):
|
||||
self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
|
||||
self.policy_map[policy_id].export_model(export_dir)
|
||||
|
||||
@DeveloperAPI
|
||||
def export_policy_checkpoint(self,
|
||||
export_dir,
|
||||
filename_prefix="model",
|
||||
|
|
|
@ -2,7 +2,10 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyGraph(object):
|
||||
"""An agent policy and loss, i.e., a TFPolicyGraph or other subclass.
|
||||
|
||||
|
@ -21,6 +24,7 @@ class PolicyGraph(object):
|
|||
action_space (gym.Space): Action space of the policy.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
"""Initialize the graph.
|
||||
|
||||
|
@ -37,6 +41,7 @@ class PolicyGraph(object):
|
|||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
|
@ -68,6 +73,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_single_action(self,
|
||||
obs,
|
||||
state,
|
||||
|
@ -116,6 +122,7 @@ class PolicyGraph(object):
|
|||
return action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
@DeveloperAPI
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
|
@ -140,6 +147,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
return sample_batch
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
"""Computes gradients against a batch of experiences.
|
||||
|
||||
|
@ -149,6 +157,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients):
|
||||
"""Applies previously computed gradients.
|
||||
|
||||
|
@ -157,6 +166,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_apply(self, samples):
|
||||
"""Fused compute gradients and apply gradients call.
|
||||
|
||||
|
@ -173,6 +183,7 @@ class PolicyGraph(object):
|
|||
apply_info = self.apply_gradients(grads)
|
||||
return grad_info, apply_info
|
||||
|
||||
@DeveloperAPI
|
||||
def get_weights(self):
|
||||
"""Returns model weights.
|
||||
|
||||
|
@ -181,6 +192,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights):
|
||||
"""Sets model weights.
|
||||
|
||||
|
@ -189,10 +201,12 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_initial_state(self):
|
||||
"""Returns initial RNN state for the current policy."""
|
||||
return []
|
||||
|
||||
@DeveloperAPI
|
||||
def get_state(self):
|
||||
"""Saves all local state.
|
||||
|
||||
|
@ -201,6 +215,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
return self.get_weights()
|
||||
|
||||
@DeveloperAPI
|
||||
def set_state(self, state):
|
||||
"""Restores all local state.
|
||||
|
||||
|
@ -209,6 +224,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
self.set_weights(state)
|
||||
|
||||
@DeveloperAPI
|
||||
def on_global_var_update(self, global_vars):
|
||||
"""Called on an update to global vars.
|
||||
|
||||
|
@ -217,6 +233,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
pass
|
||||
|
||||
@DeveloperAPI
|
||||
def export_model(self, export_dir):
|
||||
"""Export PolicyGraph to local directory for serving.
|
||||
|
||||
|
@ -225,6 +242,7 @@ class PolicyGraph(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def export_checkpoint(self, export_dir):
|
||||
"""Export PolicyGraph checkpoint to local directory.
|
||||
|
||||
|
|
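A minimal sketch of a custom ``PolicyGraph`` that fills in only the required hooks (a random policy with no learnable weights; any keyword arguments to ``compute_actions`` beyond those shown are treated as pass-through)::

    from ray.rllib.evaluation.policy_graph import PolicyGraph

    class RandomPolicyGraph(PolicyGraph):
        def compute_actions(self,
                            obs_batch,
                            state_batches,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            episodes=None,
                            **kwargs):
            # (actions, rnn state outputs, extra info fetches)
            return [self.action_space.sample() for _ in obs_batch], [], {}

        def get_weights(self):
            return {}

        def set_weights(self, weights):
            pass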
|
@@ -5,12 +5,14 @@ from __future__ import print_function

import numpy as np
import scipy.signal
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI


def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]


@DeveloperAPI
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
"""Given a rollout, compute its value targets and the advantage.

@@ -54,6 +56,7 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):

return SampleBatch(traj)


@DeveloperAPI
def compute_targets(rollout, action_space, last_r=0.0, gamma=0.9, lambda_=1.0):
"""Given a rollout, compute targets.
|
|
@ -6,166 +6,13 @@ import six
|
|||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
# Defaults policy id for single agent environments
|
||||
DEFAULT_POLICY_ID = "default"
|
||||
|
||||
|
||||
def to_float_array(v):
|
||||
arr = np.array(v)
|
||||
if arr.dtype == np.float64:
|
||||
return arr.astype(np.float32) # save some memory
|
||||
return arr
|
||||
|
||||
|
||||
class SampleBatchBuilder(object):
|
||||
"""Util to build a SampleBatch incrementally.
|
||||
|
||||
For efficiency, SampleBatches hold values in column form (as arrays).
|
||||
However, it is useful to add data one row (dict) at a time.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buffers = collections.defaultdict(list)
|
||||
self.count = 0
|
||||
|
||||
def add_values(self, **values):
|
||||
"""Add the given dictionary (row) of values to this batch."""
|
||||
|
||||
for k, v in values.items():
|
||||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
def add_batch(self, batch):
|
||||
"""Add the given batch of values to this batch."""
|
||||
|
||||
for k, column in batch.items():
|
||||
self.buffers[k].extend(column)
|
||||
self.count += batch.count
|
||||
|
||||
def build_and_reset(self):
|
||||
"""Returns a sample batch including all previously added values."""
|
||||
|
||||
batch = SampleBatch(
|
||||
{k: to_float_array(v)
|
||||
for k, v in self.buffers.items()})
|
||||
self.buffers.clear()
|
||||
self.count = 0
|
||||
return batch
|
||||
|
||||
|
||||
class MultiAgentSampleBatchBuilder(object):
|
||||
"""Util to build SampleBatches for each policy in a multi-agent env.
|
||||
|
||||
Input data is per-agent, while output data is per-policy. There is an M:N
|
||||
mapping between agents and policies. We retain one local batch builder
|
||||
per agent. When an agent is done, then its local batch is appended into the
|
||||
corresponding policy batch for the agent's policy.
|
||||
"""
|
||||
|
||||
def __init__(self, policy_map, clip_rewards):
|
||||
"""Initialize a MultiAgentSampleBatchBuilder.
|
||||
|
||||
Arguments:
|
||||
policy_map (dict): Maps policy ids to policy graph instances.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
self.policy_builders = {
|
||||
k: SampleBatchBuilder()
|
||||
for k in policy_map.keys()
|
||||
}
|
||||
self.agent_builders = {}
|
||||
self.agent_to_policy = {}
|
||||
self.count = 0 # increment this manually
|
||||
|
||||
def total(self):
|
||||
"""Returns summed number of steps across all agent buffers."""
|
||||
|
||||
return sum(p.count for p in self.policy_builders.values())
|
||||
|
||||
def has_pending_data(self):
|
||||
"""Returns whether there is pending unprocessed data."""
|
||||
|
||||
return len(self.agent_builders) > 0
|
||||
|
||||
def add_values(self, agent_id, policy_id, **values):
|
||||
"""Add the given dictionary (row) of values to this batch.
|
||||
|
||||
Arguments:
|
||||
agent_id (obj): Unique id for the agent we are adding values for.
|
||||
policy_id (obj): Unique id for policy controlling the agent.
|
||||
values (dict): Row of values to add for this agent.
|
||||
"""
|
||||
|
||||
if agent_id not in self.agent_builders:
|
||||
self.agent_builders[agent_id] = SampleBatchBuilder()
|
||||
self.agent_to_policy[agent_id] = policy_id
|
||||
builder = self.agent_builders[agent_id]
|
||||
builder.add_values(**values)
|
||||
|
||||
def postprocess_batch_so_far(self, episode):
|
||||
"""Apply policy postprocessors to any unprocessed rows.
|
||||
|
||||
This pushes the postprocessed per-agent batches onto the per-policy
|
||||
builders, clearing per-agent state.
|
||||
|
||||
Arguments:
|
||||
episode: current MultiAgentEpisode object or None
|
||||
"""
|
||||
|
||||
# Materialize the batches so far
|
||||
pre_batches = {}
|
||||
for agent_id, builder in self.agent_builders.items():
|
||||
pre_batches[agent_id] = (
|
||||
self.policy_map[self.agent_to_policy[agent_id]],
|
||||
builder.build_and_reset())
|
||||
|
||||
# Apply postprocessor
|
||||
post_batches = {}
|
||||
if self.clip_rewards:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
||||
for agent_id, (_, pre_batch) in pre_batches.items():
|
||||
other_batches = pre_batches.copy()
|
||||
del other_batches[agent_id]
|
||||
policy = self.policy_map[self.agent_to_policy[agent_id]]
|
||||
if any(pre_batch["dones"][:-1]) or len(set(
|
||||
pre_batch["eps_id"])) > 1:
|
||||
raise ValueError(
|
||||
"Batches sent to postprocessing must only contain steps "
|
||||
"from a single trajectory.", pre_batch)
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
pre_batch, other_batches, episode)
|
||||
|
||||
# Append into policy batches and reset
|
||||
for agent_id, post_batch in sorted(post_batches.items()):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
def build_and_reset(self, episode):
|
||||
"""Returns the accumulated sample batches for each policy.
|
||||
|
||||
Any unprocessed rows will be first postprocessed with a policy
|
||||
postprocessor. The internal state of this builder will be reset.
|
||||
|
||||
Arguments:
|
||||
episode: current MultiAgentEpisode object or None
|
||||
"""
|
||||
|
||||
self.postprocess_batch_so_far(episode)
|
||||
policy_batches = {}
|
||||
for policy_id, builder in self.policy_builders.items():
|
||||
if builder.count > 0:
|
||||
policy_batches[policy_id] = builder.build_and_reset()
|
||||
old_count = self.count
|
||||
self.count = 0
|
||||
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class MultiAgentBatch(object):
|
||||
"""A batch of experiences from multiple policies in the environment.
|
||||
|
||||
|
@ -177,17 +24,20 @@ class MultiAgentBatch(object):
|
|||
batch contains across all policies in total.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, policy_batches, count):
|
||||
self.policy_batches = policy_batches
|
||||
self.count = count
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def wrap_as_needed(batches, count):
|
||||
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
|
||||
return batches[DEFAULT_POLICY_ID]
|
||||
return MultiAgentBatch(batches, count)
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
policy_batches = collections.defaultdict(list)
|
||||
total_count = 0
|
||||
|
@ -201,11 +51,13 @@ class MultiAgentBatch(object):
|
|||
out[policy_id] = SampleBatch.concat_samples(batches)
|
||||
return MultiAgentBatch(out, total_count)
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
return MultiAgentBatch(
|
||||
{k: v.copy()
|
||||
for (k, v) in self.policy_batches.items()}, self.count)
|
||||
|
||||
@PublicAPI
|
||||
def total(self):
|
||||
ct = 0
|
||||
for batch in self.policy_batches.values():
|
||||
|
@ -221,6 +73,7 @@ class MultiAgentBatch(object):
|
|||
str(self.policy_batches), self.count)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class SampleBatch(object):
|
||||
"""Wrapper around a dictionary with string keys and array-like values.
|
||||
|
||||
|
@ -228,6 +81,7 @@ class SampleBatch(object):
|
|||
samples, each with an "obs" and "reward" attribute.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Constructs a sample batch (same params as dict constructor)."""
|
||||
|
||||
|
@ -243,6 +97,7 @@ class SampleBatch(object):
|
|||
self.count = lengths[0]
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
if isinstance(samples[0], MultiAgentBatch):
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
|
@ -252,6 +107,7 @@ class SampleBatch(object):
|
|||
out[k] = np.concatenate([s[k] for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def concat(self, other):
|
||||
"""Returns a new SampleBatch with each data column concatenated.
|
||||
|
||||
|
@ -268,11 +124,13 @@ class SampleBatch(object):
|
|||
out[k] = np.concatenate([self[k], other[k]])
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
return SampleBatch(
|
||||
{k: np.array(v, copy=True)
|
||||
for (k, v) in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def rows(self):
|
||||
"""Returns an iterator over data rows, i.e. dicts with column values.
|
||||
|
||||
|
@ -291,6 +149,7 @@ class SampleBatch(object):
|
|||
row[k] = self[k][i]
|
||||
yield row
|
||||
|
||||
@PublicAPI
|
||||
def columns(self, keys):
|
||||
"""Returns a list of just the specified columns.
|
||||
|
||||
|
@ -305,6 +164,7 @@ class SampleBatch(object):
|
|||
out.append(self[k])
|
||||
return out
|
||||
|
||||
@PublicAPI
|
||||
def shuffle(self):
|
||||
"""Shuffles the rows of this batch in-place."""
|
||||
|
||||
|
@ -312,6 +172,7 @@ class SampleBatch(object):
|
|||
for key, val in self.items():
|
||||
self[key] = val[permutation]
|
||||
|
||||
@PublicAPI
|
||||
def split_by_episode(self):
|
||||
"""Splits this batch's data by `eps_id`.
|
||||
|
||||
|
@ -335,6 +196,7 @@ class SampleBatch(object):
|
|||
assert sum(s.count for s in slices) == self.count, (slices, self.count)
|
||||
return slices
|
||||
|
||||
@PublicAPI
|
||||
def slice(self, start, end):
|
||||
"""Returns a slice of the row data of this batch.
|
||||
|
||||
|
@ -348,9 +210,19 @@ class SampleBatch(object):
|
|||
|
||||
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
@PublicAPI
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
@PublicAPI
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
@PublicAPI
|
||||
def __setitem__(self, key, item):
|
||||
self.data[key] = item
|
||||
|
||||
|
@ -360,12 +232,6 @@ class SampleBatch(object):
|
|||
def __repr__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __iter__(self):
|
||||
return self.data.__iter__()
|
||||
|
||||
|
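A short sketch of the column-oriented ``SampleBatch`` API being annotated as ``@PublicAPI`` above (toy columns; no claim about the exact column names RLlib uses internally)::

    import numpy as np
    from ray.rllib.evaluation.sample_batch import SampleBatch

    batch = SampleBatch({
        "obs": np.array([[0.1], [0.2], [0.3]]),
        "rewards": np.array([1.0, 0.5, 0.0]),
    })
    print(batch.count)                 # 3 rows
    print(batch["rewards"])            # column access via __getitem__
    for row in batch.rows():           # row-wise iteration as dicts
        print(row)
    merged = SampleBatch.concat_samples([batch, batch.copy()])  # 6 rows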
|
python/ray/rllib/evaluation/sample_batch_builder.py (new file, 173 lines)
|
@ -0,0 +1,173 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
|
||||
|
||||
def to_float_array(v):
|
||||
arr = np.array(v)
|
||||
if arr.dtype == np.float64:
|
||||
return arr.astype(np.float32) # save some memory
|
||||
return arr
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class SampleBatchBuilder(object):
|
||||
"""Util to build a SampleBatch incrementally.
|
||||
|
||||
For efficiency, SampleBatches hold values in column form (as arrays).
|
||||
However, it is useful to add data one row (dict) at a time.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self):
|
||||
self.buffers = collections.defaultdict(list)
|
||||
self.count = 0
|
||||
|
||||
@PublicAPI
|
||||
def add_values(self, **values):
|
||||
"""Add the given dictionary (row) of values to this batch."""
|
||||
|
||||
for k, v in values.items():
|
||||
self.buffers[k].append(v)
|
||||
self.count += 1
|
||||
|
||||
@PublicAPI
|
||||
def add_batch(self, batch):
|
||||
"""Add the given batch of values to this batch."""
|
||||
|
||||
for k, column in batch.items():
|
||||
self.buffers[k].extend(column)
|
||||
self.count += batch.count
|
||||
|
||||
@PublicAPI
|
||||
def build_and_reset(self):
|
||||
"""Returns a sample batch including all previously added values."""
|
||||
|
||||
batch = SampleBatch(
|
||||
{k: to_float_array(v)
|
||||
for k, v in self.buffers.items()})
|
||||
self.buffers.clear()
|
||||
self.count = 0
|
||||
return batch
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class MultiAgentSampleBatchBuilder(object):
|
||||
"""Util to build SampleBatches for each policy in a multi-agent env.
|
||||
|
||||
Input data is per-agent, while output data is per-policy. There is an M:N
|
||||
mapping between agents and policies. We retain one local batch builder
|
||||
per agent. When an agent is done, then its local batch is appended into the
|
||||
corresponding policy batch for the agent's policy.
|
||||
"""
|
||||
|
||||
def __init__(self, policy_map, clip_rewards):
|
||||
"""Initialize a MultiAgentSampleBatchBuilder.
|
||||
|
||||
Arguments:
|
||||
policy_map (dict): Maps policy ids to policy graph instances.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
"""
|
||||
|
||||
self.policy_map = policy_map
|
||||
self.clip_rewards = clip_rewards
|
||||
self.policy_builders = {
|
||||
k: SampleBatchBuilder()
|
||||
for k in policy_map.keys()
|
||||
}
|
||||
self.agent_builders = {}
|
||||
self.agent_to_policy = {}
|
||||
self.count = 0 # increment this manually
|
||||
|
||||
def total(self):
|
||||
"""Returns summed number of steps across all agent buffers."""
|
||||
|
||||
return sum(p.count for p in self.policy_builders.values())
|
||||
|
||||
def has_pending_data(self):
|
||||
"""Returns whether there is pending unprocessed data."""
|
||||
|
||||
return len(self.agent_builders) > 0
|
||||
|
||||
@DeveloperAPI
|
||||
def add_values(self, agent_id, policy_id, **values):
|
||||
"""Add the given dictionary (row) of values to this batch.
|
||||
|
||||
Arguments:
|
||||
agent_id (obj): Unique id for the agent we are adding values for.
|
||||
policy_id (obj): Unique id for policy controlling the agent.
|
||||
values (dict): Row of values to add for this agent.
|
||||
"""
|
||||
|
||||
if agent_id not in self.agent_builders:
|
||||
self.agent_builders[agent_id] = SampleBatchBuilder()
|
||||
self.agent_to_policy[agent_id] = policy_id
|
||||
builder = self.agent_builders[agent_id]
|
||||
builder.add_values(**values)
|
||||
|
||||
def postprocess_batch_so_far(self, episode):
|
||||
"""Apply policy postprocessors to any unprocessed rows.
|
||||
|
||||
This pushes the postprocessed per-agent batches onto the per-policy
|
||||
builders, clearing per-agent state.
|
||||
|
||||
Arguments:
|
||||
episode: current MultiAgentEpisode object or None
|
||||
"""
|
||||
|
||||
# Materialize the batches so far
|
||||
pre_batches = {}
|
||||
for agent_id, builder in self.agent_builders.items():
|
||||
pre_batches[agent_id] = (
|
||||
self.policy_map[self.agent_to_policy[agent_id]],
|
||||
builder.build_and_reset())
|
||||
|
||||
# Apply postprocessor
|
||||
post_batches = {}
|
||||
if self.clip_rewards:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
||||
for agent_id, (_, pre_batch) in pre_batches.items():
|
||||
other_batches = pre_batches.copy()
|
||||
del other_batches[agent_id]
|
||||
policy = self.policy_map[self.agent_to_policy[agent_id]]
|
||||
if any(pre_batch["dones"][:-1]) or len(set(
|
||||
pre_batch["eps_id"])) > 1:
|
||||
raise ValueError(
|
||||
"Batches sent to postprocessing must only contain steps "
|
||||
"from a single trajectory.", pre_batch)
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
pre_batch, other_batches, episode)
|
||||
|
||||
# Append into policy batches and reset
|
||||
for agent_id, post_batch in sorted(post_batches.items()):
|
||||
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
|
||||
post_batch)
|
||||
self.agent_builders.clear()
|
||||
self.agent_to_policy.clear()
|
||||
|
||||
@DeveloperAPI
|
||||
def build_and_reset(self, episode):
|
||||
"""Returns the accumulated sample batches for each policy.
|
||||
|
||||
Any unprocessed rows will be first postprocessed with a policy
|
||||
postprocessor. The internal state of this builder will be reset.
|
||||
|
||||
Arguments:
|
||||
episode: current MultiAgentEpisode object or None
|
||||
"""
|
||||
|
||||
self.postprocess_batch_so_far(episode)
|
||||
policy_batches = {}
|
||||
for policy_id, builder in self.policy_builders.items():
|
||||
if builder.count > 0:
|
||||
policy_batches[policy_id] = builder.build_and_reset()
|
||||
old_count = self.count
|
||||
self.count = 0
|
||||
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
|
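And a sketch of building such a batch incrementally with the builders that now live in ``sample_batch_builder.py`` (column names are again illustrative)::

    from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

    builder = SampleBatchBuilder()
    for t in range(3):
        builder.add_values(t=t, obs=[0.0, 0.0], rewards=1.0)  # one row at a time
    batch = builder.build_and_reset()   # -> SampleBatch with 3 rows
    print(batch.count)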
|
@ -10,9 +10,10 @@ import six.moves.queue as queue
|
|||
import threading
|
||||
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
@ -44,7 +45,7 @@ class SyncSampler(object):
|
|||
pack=False,
|
||||
tf_sess=None,
|
||||
clip_actions=True):
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
self.unroll_length = unroll_length
|
||||
self.horizon = horizon
|
||||
self.policies = policies
|
||||
|
@ -53,7 +54,7 @@ class SyncSampler(object):
|
|||
self.obs_filters = obs_filters
|
||||
self.extra_batches = queue.Queue()
|
||||
self.rollout_provider = _env_runner(
|
||||
self.async_vector_env, self.extra_batches.put, self.policies,
|
||||
self.base_env, self.extra_batches.put, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
|
||||
pack, callbacks, tf_sess)
|
||||
|
@ -104,7 +105,7 @@ class AsyncSampler(threading.Thread):
|
|||
for _, f in obs_filters.items():
|
||||
assert getattr(f, "is_concurrent", False), \
|
||||
"Observation Filter must support concurrent updates."
|
||||
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
|
||||
self.base_env = BaseEnv.to_base_env(env)
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.extra_batches = queue.Queue()
|
||||
|
@ -140,7 +141,7 @@ class AsyncSampler(threading.Thread):
|
|||
extra_batches_putter = (
|
||||
lambda x: self.extra_batches.put(x, timeout=600.0))
|
||||
rollout_provider = _env_runner(
|
||||
self.async_vector_env, extra_batches_putter, self.policies,
|
||||
self.base_env, extra_batches_putter, self.policies,
|
||||
self.policy_mapping_fn, self.unroll_length, self.horizon,
|
||||
self.preprocessors, self.obs_filters, self.clip_rewards,
|
||||
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
|
||||
|
@ -182,7 +183,7 @@ class AsyncSampler(threading.Thread):
|
|||
return extra
|
||||
|
||||
|
||||
def _env_runner(async_vector_env,
|
||||
def _env_runner(base_env,
|
||||
extra_batch_callback,
|
||||
policies,
|
||||
policy_mapping_fn,
|
||||
|
@ -198,7 +199,7 @@ def _env_runner(async_vector_env,
|
|||
"""This implements the common experience collection logic.
|
||||
|
||||
Args:
|
||||
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
|
||||
base_env (BaseEnv): env implementing BaseEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
policies (dict): Map of policy ids to PolicyGraph instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
|
@ -226,8 +227,7 @@ def _env_runner(async_vector_env,
|
|||
|
||||
try:
|
||||
if not horizon:
|
||||
horizon = (
|
||||
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
|
||||
horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
|
||||
except Exception:
|
||||
logger.debug("no episode horizon specified, assuming inf")
|
||||
if not horizon:
|
||||
|
@ -248,7 +248,7 @@ def _env_runner(async_vector_env,
|
|||
get_batch_builder, extra_batch_callback)
|
||||
if callbacks.get("on_episode_start"):
|
||||
callbacks["on_episode_start"]({
|
||||
"env": async_vector_env,
|
||||
"env": base_env,
|
||||
"episode": episode
|
||||
})
|
||||
return episode
|
||||
|
@ -258,11 +258,11 @@ def _env_runner(async_vector_env,
|
|||
while True:
|
||||
# Get observations from all ready agents
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
|
||||
async_vector_env.poll()
|
||||
base_env.poll()
|
||||
|
||||
# Process observations and prepare for policy evaluation
|
||||
active_envs, to_eval, outputs = _process_observations(
|
||||
async_vector_env, policies, batch_builder_pool, active_episodes,
|
||||
base_env, policies, batch_builder_pool, active_episodes,
|
||||
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
|
||||
preprocessors, obs_filters, unroll_length, pack, callbacks)
|
||||
for o in outputs:
|
||||
|
@ -279,10 +279,10 @@ def _env_runner(async_vector_env,
|
|||
|
||||
# Return computed actions to ready envs. We also send to envs that have
|
||||
# taken off-policy actions; those envs are free to ignore the action.
|
||||
async_vector_env.send_actions(actions_to_send)
|
||||
base_env.send_actions(actions_to_send)
|
||||
|
||||
|
||||
def _process_observations(async_vector_env, policies, batch_builder_pool,
|
||||
def _process_observations(base_env, policies, batch_builder_pool,
|
||||
active_episodes, unfiltered_obs, rewards, dones,
|
||||
infos, off_policy_actions, horizon, preprocessors,
|
||||
obs_filters, unroll_length, pack, callbacks):
|
||||
|
@ -325,7 +325,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
|||
# Check episode termination conditions
|
||||
if dones[env_id]["__all__"] or episode.length >= horizon:
|
||||
all_done = True
|
||||
atari_metrics = _fetch_atari_metrics(async_vector_env)
|
||||
atari_metrics = _fetch_atari_metrics(base_env)
|
||||
if atari_metrics is not None:
|
||||
for m in atari_metrics:
|
||||
outputs.append(
|
||||
|
@ -379,10 +379,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
|||
|
||||
# Invoke the step callback after the step is logged to the episode
|
||||
if callbacks.get("on_episode_step"):
|
||||
callbacks["on_episode_step"]({
|
||||
"env": async_vector_env,
|
||||
"episode": episode
|
||||
})
|
||||
callbacks["on_episode_step"]({"env": base_env, "episode": episode})
|
||||
|
||||
# Cut the batch if we're not packing multiple episodes into one,
|
||||
# or if we've exceeded the requested batch size.
|
||||
|
@ -399,11 +396,11 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
|
|||
batch_builder_pool.append(episode.batch_builder)
|
||||
if callbacks.get("on_episode_end"):
|
||||
callbacks["on_episode_end"]({
|
||||
"env": async_vector_env,
|
||||
"env": base_env,
|
||||
"episode": episode
|
||||
})
|
||||
del active_episodes[env_id]
|
||||
resetted_obs = async_vector_env.try_reset(env_id)
|
||||
resetted_obs = base_env.try_reset(env_id)
|
||||
if resetted_obs is None:
|
||||
# Reset not supported, drop this env from the ready list
|
||||
if horizon != float("inf"):
|
||||
|
@ -526,12 +523,12 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
|
|||
return actions_to_send
|
||||
|
||||
|
||||
def _fetch_atari_metrics(async_vector_env):
|
||||
def _fetch_atari_metrics(base_env):
|
||||
"""Atari games have multiple logical episodes, one per life.
|
||||
|
||||
However for metrics reporting we count full episodes all lives included.
|
||||
"""
|
||||
unwrapped = async_vector_env.get_unwrapped()
|
||||
unwrapped = base_env.get_unwrapped()
|
||||
if not unwrapped:
|
||||
return None
|
||||
atari_out = []
|
||||
|
|
|
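The samplers above now wrap whatever environment they are given with BaseEnv.to_base_env() and then drive it through poll()/send_actions(), as _env_runner does. A minimal sketch of that loop; the gym environment and the action value 0 are placeholders:

    import gym

    from ray.rllib.env.base_env import BaseEnv

    base_env = BaseEnv.to_base_env(gym.make("CartPole-v0"))

    # observations come back as {env_id: {agent_id: obs}}, mirroring the
    # multi-agent tests later in this diff
    obs, rewards, dones, infos, off_policy_actions = base_env.poll()

    # send one action back per ready agent (0 is a valid CartPole action)
    base_env.send_actions(
        {env_id: {agent_id: 0 for agent_id in agent_obs}
         for env_id, agent_obs in obs.items()})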
@ -10,13 +10,14 @@ import numpy as np
|
|||
import ray
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class TFPolicyGraph(PolicyGraph):
|
||||
"""An agent policy and loss implemented in TensorFlow.
|
||||
|
||||
|
@ -41,6 +42,7 @@ class TFPolicyGraph(PolicyGraph):
|
|||
SampleBatch({"action": ..., "advantages": ..., ...})
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
|
@ -208,36 +210,63 @@ class TFPolicyGraph(PolicyGraph):
|
|||
saver = tf.train.Saver()
|
||||
saver.save(self._sess, save_path)
|
||||
|
||||
@DeveloperAPI
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders.
|
||||
|
||||
Optional, only required to work with the multi-GPU optimizer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_feed_dict(self):
|
||||
"""Extra dict to pass to the compute actions session run."""
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_action_fetches(self):
|
||||
"""Extra values to fetch and return from compute_actions()."""
|
||||
return {} # e.g, value function
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_feed_dict(self):
|
||||
"""Extra dict to pass to the compute gradients session run."""
|
||||
return {} # e.g, kl_coeff
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_compute_grad_fetches(self):
|
||||
"""Extra values to fetch and return from compute_gradients()."""
|
||||
return {} # e.g, td error
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_apply_grad_feed_dict(self):
|
||||
"""Extra dict to pass to the apply gradients session run."""
|
||||
return {}
|
||||
|
||||
@DeveloperAPI
|
||||
def extra_apply_grad_fetches(self):
|
||||
"""Extra values to fetch and return from apply_gradients()."""
|
||||
return {} # e.g., batch norm updates
|
||||
|
||||
@DeveloperAPI
|
||||
def optimizer(self):
|
||||
"""TF optimizer to use for policy optimization."""
|
||||
return tf.train.AdamOptimizer()
|
||||
|
||||
@DeveloperAPI
|
||||
def gradients(self, optimizer):
|
||||
"""Override for custom gradient computation."""
|
||||
return optimizer.compute_gradients(self._loss)
|
||||
|
||||
@DeveloperAPI
|
||||
def _get_is_training_placeholder(self):
|
||||
"""Get the placeholder for _is_training, i.e., for batch norm layers.
|
||||
|
||||
This can be called safely before __init__ has run.
|
||||
"""
|
||||
if not hasattr(self, "_is_training"):
|
||||
self._is_training = tf.placeholder_with_default(False, ())
|
||||
return self._is_training
|
||||
|
||||
def _extra_input_signature_def(self):
|
||||
"""Extra input signatures to add when exporting tf model.
|
||||
Inferred from extra_compute_action_feed_dict()
|
||||
|
@ -258,14 +287,6 @@ class TFPolicyGraph(PolicyGraph):
|
|||
for k in fetches.keys()
|
||||
}
|
||||
|
||||
def optimizer(self):
|
||||
"""TF optimizer to use for policy optimization."""
|
||||
return tf.train.AdamOptimizer()
|
||||
|
||||
def gradients(self, optimizer):
|
||||
"""Override for custom gradient computation."""
|
||||
return optimizer.compute_gradients(self._loss)
|
||||
|
||||
def _build_signature_def(self):
|
||||
"""Build signature def map for tensorflow SavedModelBuilder.
|
||||
"""
|
||||
|
@ -364,15 +385,6 @@ class TFPolicyGraph(PolicyGraph):
|
|||
])
|
||||
return fetches[1], fetches[2]
|
||||
|
||||
def _get_is_training_placeholder(self):
|
||||
"""Get the placeholder for _is_training, i.e., for batch norm layers.
|
||||
|
||||
This can be called safely before __init__ has run.
|
||||
"""
|
||||
if not hasattr(self, "_is_training"):
|
||||
self._is_training = tf.placeholder_with_default(False, ())
|
||||
return self._is_training
|
||||
|
||||
def _get_loss_inputs_dict(self, batch):
|
||||
feed_dict = {}
|
||||
if self._batch_divisibility_req > 1:
|
||||
|
@ -414,9 +426,11 @@ class TFPolicyGraph(PolicyGraph):
|
|||
return feed_dict
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class LearningRateSchedule(object):
|
||||
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, lr, lr_schedule):
|
||||
self.cur_lr = tf.get_variable("lr", initializer=lr)
|
||||
if lr_schedule is None:
|
||||
|
|
|
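The @DeveloperAPI hooks above (optimizer, gradients, and the extra_* feed/fetch methods) are the intended override points for subclasses. A sketch of a graph that swaps in a smaller learning rate and clips gradients; the usual __init__ plumbing is omitted, so this is illustrative only:

    import tensorflow as tf

    from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
    from ray.rllib.utils.annotations import override


    class ClippedGradPolicyGraph(TFPolicyGraph):
        """Illustrative subclass; constructor arguments are omitted here."""

        @override(TFPolicyGraph)
        def optimizer(self):
            # replace the default AdamOptimizer shown above
            return tf.train.AdamOptimizer(learning_rate=1e-4)

        @override(TFPolicyGraph)
        def gradients(self, optimizer):
            # clip each gradient before it is applied; self._loss is the
            # loss tensor passed to the base class constructor
            grads_and_vars = optimizer.compute_gradients(self._loss)
            return [(tf.clip_by_norm(g, 40.0), v)
                    for g, v in grads_and_vars if g is not None]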
@ -7,7 +7,7 @@ from __future__ import print_function
|
|||
import gym
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatchBuilder
|
||||
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
|
||||
from ray.rllib.offline.json_writer import JsonWriter
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -7,12 +7,13 @@ import distutils.version
|
|||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
distutils.version.LooseVersion("1.5.0"))
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ActionDistribution(object):
|
||||
"""The policy action distribution of an agent.
|
||||
|
||||
|
@ -20,21 +21,26 @@ class ActionDistribution(object):
|
|||
inputs (Tensor): The input vector to compute samples from.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, inputs):
|
||||
self.inputs = inputs
|
||||
|
||||
@DeveloperAPI
|
||||
def logp(self, x):
|
||||
"""The log-likelihood of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def kl(self, other):
|
||||
"""The KL-divergence between two action distributions."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def entropy(self):
|
||||
"""The entroy of the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self):
|
||||
"""Draw a sample from the action distribution."""
|
||||
raise NotImplementedError
|
||||
|
|
|
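A custom distribution only has to fill in the @DeveloperAPI methods above. A minimal, illustrative deterministic distribution (not part of this diff, and assuming 2-D [batch, action_dim] inputs):

    import tensorflow as tf

    from ray.rllib.models.action_dist import ActionDistribution


    class Deterministic(ActionDistribution):
        """Returns the model output itself as the action (no sampling)."""

        def sample(self):
            return self.inputs

        def logp(self, x):
            # a point mass: zero log-likelihood for the emitted action
            return tf.zeros_like(self.inputs[:, 0])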
@ -17,6 +17,7 @@ from ray.rllib.models.preprocessors import get_preprocessor
|
|||
from ray.rllib.models.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.visionnet import VisionNetwork
|
||||
from ray.rllib.models.lstm import LSTM
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -69,6 +70,7 @@ MODEL_DEFAULTS = {
|
|||
# yapf: enable
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class ModelCatalog(object):
|
||||
"""Registry of models, preprocessors, and action distributions for envs.
|
||||
|
||||
|
@ -84,6 +86,7 @@ class ModelCatalog(object):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_action_dist(action_space, config, dist_type=None):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
|
||||
|
@ -134,6 +137,7 @@ class ModelCatalog(object):
|
|||
action_space, dist_type))
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_action_placeholder(action_space):
|
||||
"""Returns an action placeholder that is consistent with the action space
|
||||
|
||||
|
@ -166,6 +170,7 @@ class ModelCatalog(object):
|
|||
" not supported".format(action_space))
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_model(input_dict,
|
||||
obs_space,
|
||||
num_outputs,
|
||||
|
@ -230,6 +235,7 @@ class ModelCatalog(object):
|
|||
options)
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_torch_model(obs_space,
|
||||
num_outputs,
|
||||
options=None,
|
||||
|
@ -276,6 +282,7 @@ class ModelCatalog(object):
|
|||
return PyTorchFCNet(obs_space, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_preprocessor(env, options=None):
|
||||
"""Returns a suitable preprocessor for the given env.
|
||||
|
||||
|
@ -286,6 +293,7 @@ class ModelCatalog(object):
|
|||
options)
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def get_preprocessor_for_space(observation_space, options=None):
|
||||
"""Returns a suitable preprocessor for the given observation space.
|
||||
|
||||
|
@ -317,6 +325,7 @@ class ModelCatalog(object):
|
|||
return prep
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
|
||||
"""Register a custom preprocessor class by name.
|
||||
|
||||
|
@ -331,6 +340,7 @@ class ModelCatalog(object):
|
|||
preprocessor_class)
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def register_custom_model(model_name, model_class):
|
||||
"""Register a custom model class by name.
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ import tensorflow.contrib.rnn as rnn
|
|||
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.models.model import Model
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI
|
||||
|
||||
|
||||
class LSTM(Model):
|
||||
|
@ -91,6 +91,7 @@ class LSTM(Model):
|
|||
return logits, last_layer
|
||||
|
||||
|
||||
@PublicAPI
|
||||
def add_time_dimension(padded_inputs, seq_lens):
|
||||
"""Adds a time dimension to padded inputs.
|
||||
|
||||
|
@ -118,6 +119,7 @@ def add_time_dimension(padded_inputs, seq_lens):
|
|||
return tf.reshape(padded_inputs, new_shape)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def chop_into_sequences(episode_ids,
|
||||
agent_indices,
|
||||
feature_columns,
|
||||
|
|
|
@ -9,8 +9,10 @@ import tensorflow as tf
|
|||
|
||||
from ray.rllib.models.misc import linear, normc_initializer
|
||||
from ray.rllib.models.preprocessors import get_preprocessor
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class Model(object):
|
||||
"""Defines an abstract network model for use with RLlib.
|
||||
|
||||
|
@ -90,6 +92,7 @@ class Model(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def _build_layers_v2(self, input_dict, num_outputs, options):
|
||||
"""Define the layers of a custom model.
|
||||
|
||||
|
@ -122,6 +125,7 @@ class Model(object):
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def value_function(self):
|
||||
"""Builds the value function output.
|
||||
|
||||
|
@ -134,6 +138,7 @@ class Model(object):
|
|||
return tf.reshape(
|
||||
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
|
||||
|
||||
@PublicAPI
|
||||
def loss(self):
|
||||
"""Builds any built-in (self-supervised) loss for the model.
|
||||
|
||||
|
|
|
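Together with ModelCatalog.register_custom_model() from the catalog above, the @PublicAPI surface of Model boils down to _build_layers_v2() (plus the optional value_function()/loss() overrides). A hedged sketch of a small custom network; the layer sizes and the use of input_dict["obs"] are illustrative choices:

    import tensorflow as tf

    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.model import Model


    class TwoLayerModel(Model):
        def _build_layers_v2(self, input_dict, num_outputs, options):
            # input_dict["obs"] holds the (preprocessed) observation tensor
            hidden = tf.layers.dense(
                input_dict["obs"], 64, activation=tf.nn.relu)
            logits = tf.layers.dense(hidden, num_outputs, activation=None)
            return logits, hidden


    ModelCatalog.register_custom_model("two_layer", TwoLayerModel)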
@ -8,7 +8,7 @@ import logging
|
|||
import numpy as np
|
||||
import gym
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
ATARI_OBS_SHAPE = (210, 160, 3)
|
||||
ATARI_RAM_OBS_SHAPE = (128, )
|
||||
|
@ -16,6 +16,7 @@ ATARI_RAM_OBS_SHAPE = (128, )
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class Preprocessor(object):
|
||||
"""Defines an abstract observation preprocessor function.
|
||||
|
||||
|
@ -23,25 +24,30 @@ class Preprocessor(object):
|
|||
shape (obj): Shape of the preprocessed output.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, obs_space, options=None):
|
||||
legacy_patch_shapes(obs_space)
|
||||
self._obs_space = obs_space
|
||||
self._options = options or {}
|
||||
self.shape = self._init_shape(obs_space, options)
|
||||
|
||||
@PublicAPI
|
||||
def _init_shape(self, obs_space, options):
|
||||
"""Returns the shape after preprocessing."""
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def transform(self, observation):
|
||||
"""Returns the preprocessed observation."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@PublicAPI
|
||||
def size(self):
|
||||
return int(np.product(self.shape))
|
||||
|
||||
@property
|
||||
@PublicAPI
|
||||
def observation_space(self):
|
||||
obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32)
|
||||
# Stash the unwrapped space so that we can unwrap dict and tuple spaces
|
||||
|
@ -186,6 +192,7 @@ class DictFlatteningPreprocessor(Preprocessor):
|
|||
])
|
||||
|
||||
|
||||
@PublicAPI
|
||||
def get_preprocessor(space):
|
||||
"""Returns an appropriate preprocessor class for the given space."""
|
||||
|
||||
|
|
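A custom preprocessor similarly implements only the @PublicAPI hooks _init_shape() and transform(), and is registered through the catalog. A sketch that rescales image-like observations; the scaling factor is an arbitrary example:

    import numpy as np

    from ray.rllib.models import ModelCatalog
    from ray.rllib.models.preprocessors import Preprocessor


    class ScaledFloatPreprocessor(Preprocessor):
        """Casts observations to float32 and rescales them to [0, 1]."""

        def _init_shape(self, obs_space, options):
            return obs_space.shape  # output shape equals the input shape

        def transform(self, observation):
            return np.asarray(observation, dtype=np.float32) / 255.0


    ModelCatalog.register_custom_preprocessor(
        "scaled_float", ScaledFloatPreprocessor)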
|
@ -6,8 +6,10 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
from ray.rllib.models.model import _restore_original_dimensions
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class TorchModel(nn.Module):
|
||||
"""Defines an abstract network model for use with RLlib / PyTorch."""
|
||||
|
||||
|
@ -25,6 +27,7 @@ class TorchModel(nn.Module):
|
|||
self.num_outputs = num_outputs
|
||||
self.options = options
|
||||
|
||||
@PublicAPI
|
||||
def forward(self, input_dict, hidden_state):
|
||||
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
|
||||
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
|
||||
|
@ -33,10 +36,12 @@ class TorchModel(nn.Module):
|
|||
outputs, features, vf, h = self._forward(input_dict, hidden_state)
|
||||
return outputs, features, vf, h
|
||||
|
||||
@PublicAPI
|
||||
def state_init(self):
|
||||
"""Returns a list of initial hidden state tensors, if any."""
|
||||
return []
|
||||
|
||||
@PublicAPI
|
||||
def _forward(self, input_dict, hidden_state):
|
||||
"""Forward pass for the model.
|
||||
|
||||
|
|
|
@ -3,11 +3,14 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class InputReader(object):
|
||||
"""Input object for loading experiences in policy evaluation."""
|
||||
|
||||
@PublicAPI
|
||||
def next(self):
|
||||
"""Return the next batch of experiences read.
|
||||
|
||||
|
|
|
@ -5,8 +5,10 @@ from __future__ import print_function
|
|||
import os
|
||||
|
||||
from ray.rllib.offline.input_reader import SamplerInput
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class IOContext(object):
|
||||
"""Attributes to pass to input / output class constructors.
|
||||
|
||||
|
@ -20,6 +22,7 @@ class IOContext(object):
|
|||
evaluator (PolicyEvaluator): policy evaluator object reference.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self,
|
||||
log_dir=None,
|
||||
config=None,
|
||||
|
@ -30,5 +33,6 @@ class IOContext(object):
|
|||
self.worker_index = worker_index
|
||||
self.evaluator = evaluator
|
||||
|
||||
@PublicAPI
|
||||
def default_sampler_input(self):
|
||||
return SamplerInput(self.evaluator.sampler)
|
||||
|
|
|
@ -19,17 +19,19 @@ from ray.rllib.offline.input_reader import InputReader
|
|||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class JsonReader(InputReader):
|
||||
"""Reader object that loads experiences from JSON file chunks.
|
||||
|
||||
The input files will be read in a random order."""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, inputs, ioctx=None):
|
||||
"""Initialize a JsonReader.
|
||||
|
||||
|
|
|
@ -18,15 +18,17 @@ except ImportError:
|
|||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.output_writer import OutputWriter
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import pack
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class JsonWriter(OutputWriter):
|
||||
"""Writer object that saves experiences in JSON file chunks."""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self,
|
||||
path,
|
||||
ioctx=None,
|
||||
|
|
|
@ -6,9 +6,10 @@ import numpy as np
|
|||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class MixedInput(InputReader):
|
||||
"""Mixes input from a number of other input sources.
|
||||
|
||||
|
@ -20,6 +21,7 @@ class MixedInput(InputReader):
|
|||
}, ioctx)
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, dist, ioctx):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
|
|
|
@ -3,11 +3,14 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class OutputWriter(object):
|
||||
"""Writer object for saving experiences from policy evaluation."""
|
||||
|
||||
@PublicAPI
|
||||
def write(self, sample_batch):
|
||||
"""Save a batch of experiences.
|
||||
|
||||
|
|
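The offline classes annotated here compose into a simple round trip. A sketch that writes one batch with JsonWriter and reads it back with JsonReader, reusing the SampleBatchBuilder shown earlier in this diff; the output path and the glob form passed to the reader are assumptions:

    from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
    from ray.rllib.offline.json_reader import JsonReader
    from ray.rllib.offline.json_writer import JsonWriter

    builder = SampleBatchBuilder()
    builder.add_values(t=0, eps_id=0, obs=0, actions=1, rewards=1.0, dones=True)

    writer = JsonWriter("/tmp/demo-out")  # output directory: arbitrary example
    writer.write(builder.build_and_reset())

    reader = JsonReader("/tmp/demo-out/*.json")  # glob form assumed
    batch = reader.next()  # next SampleBatch read back from disk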
|
@ -5,12 +5,14 @@ from __future__ import print_function
|
|||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyOptimizer(object):
|
||||
"""Policy optimizers encapsulate distributed RL optimization strategies.
|
||||
|
||||
|
@ -36,6 +38,7 @@ class PolicyOptimizer(object):
|
|||
evaluators created by this optimizer.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, local_evaluator, remote_evaluators=None, config=None):
|
||||
"""Create an optimizer instance.
|
||||
|
||||
|
@ -59,11 +62,13 @@ class PolicyOptimizer(object):
|
|||
logger.debug("Created policy optimizer with {}: {}".format(
|
||||
config, self))
|
||||
|
||||
@DeveloperAPI
|
||||
def _init(self):
|
||||
"""Subclasses should prefer overriding this instead of __init__."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def step(self):
|
||||
"""Takes a logical optimization step.
|
||||
|
||||
|
@ -77,6 +82,7 @@ class PolicyOptimizer(object):
|
|||
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self):
|
||||
"""Returns a dictionary of internal performance statistics."""
|
||||
|
||||
|
@ -85,21 +91,25 @@ class PolicyOptimizer(object):
|
|||
"num_steps_sampled": self.num_steps_sampled,
|
||||
}
|
||||
|
||||
@DeveloperAPI
|
||||
def save(self):
|
||||
"""Returns a serializable object representing the optimizer state."""
|
||||
|
||||
return [self.num_steps_trained, self.num_steps_sampled]
|
||||
|
||||
@DeveloperAPI
|
||||
def restore(self, data):
|
||||
"""Restores optimizer state from the given data object."""
|
||||
|
||||
self.num_steps_trained = data[0]
|
||||
self.num_steps_sampled = data[1]
|
||||
|
||||
@DeveloperAPI
|
||||
def stop(self):
|
||||
"""Release any resources used by this optimizer."""
|
||||
pass
|
||||
|
||||
@DeveloperAPI
|
||||
def collect_metrics(self,
|
||||
timeout_seconds,
|
||||
min_history=100,
|
||||
|
@ -132,6 +142,7 @@ class PolicyOptimizer(object):
|
|||
res.update(info=self.stats())
|
||||
return res
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_evaluator(self, func):
|
||||
"""Apply the given function to each evaluator instance."""
|
||||
|
||||
|
@ -140,6 +151,7 @@ class PolicyOptimizer(object):
|
|||
[ev.apply.remote(func) for ev in self.remote_evaluators])
|
||||
return local_result + remote_results
|
||||
|
||||
@DeveloperAPI
|
||||
def foreach_evaluator_with_index(self, func):
|
||||
"""Apply the given function to each evaluator instance.
|
||||
|
||||
|
|
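The @DeveloperAPI methods above are the hooks a custom optimizer fills in. A skeleton subclass; the self.local_evaluator attribute and its sample() method are assumptions based on the constructor arguments above, not shown in this diff:

    from ray.rllib.optimizers import PolicyOptimizer
    from ray.rllib.utils.annotations import override


    class LocalSyncOptimizer(PolicyOptimizer):
        """Sketch: sample from the local evaluator and count the steps."""

        @override(PolicyOptimizer)
        def _init(self):
            # called from PolicyOptimizer.__init__; optimizer state goes here
            self.last_batch_size = 0

        @override(PolicyOptimizer)
        def step(self):
            # a real optimizer would also feed the batch to a learner here
            batch = self.local_evaluator.sample()
            self.last_batch_size = batch.count
            self.num_steps_sampled += batch.count

        @override(PolicyOptimizer)
        def stats(self):
            return dict(
                PolicyOptimizer.stats(self),
                last_batch_size=self.last_batch_size)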
|
@ -7,11 +7,14 @@ import random
|
|||
import sys
|
||||
|
||||
from ray.rllib.optimizers.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
from ray.rllib.utils.window_stat import WindowStat
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class ReplayBuffer(object):
|
||||
@DeveloperAPI
|
||||
def __init__(self, size):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
|
@ -34,6 +37,7 @@ class ReplayBuffer(object):
|
|||
def __len__(self):
|
||||
return len(self._storage)
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, obs_t, action, reward, obs_tp1, done, weight):
|
||||
data = (obs_t, action, reward, obs_tp1, done)
|
||||
self._num_added += 1
|
||||
|
@ -64,6 +68,7 @@ class ReplayBuffer(object):
|
|||
return (np.array(obses_t), np.array(actions), np.array(rewards),
|
||||
np.array(obses_tp1), np.array(dones))
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self, batch_size):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
|
@ -93,6 +98,7 @@ class ReplayBuffer(object):
|
|||
self._num_sampled += batch_size
|
||||
return self._encode_sample(idxes)
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self, debug=False):
|
||||
data = {
|
||||
"added_count": self._num_added,
|
||||
|
@ -105,7 +111,9 @@ class ReplayBuffer(object):
|
|||
return data
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
@DeveloperAPI
|
||||
def __init__(self, size, alpha):
|
||||
"""Create Prioritized Replay buffer.
|
||||
|
||||
|
@ -135,6 +143,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
self._max_priority = 1.0
|
||||
self._prio_change_stats = WindowStat("reprio", 1000)
|
||||
|
||||
@DeveloperAPI
|
||||
def add(self, obs_t, action, reward, obs_tp1, done, weight):
|
||||
"""See ReplayBuffer.store_effect"""
|
||||
|
||||
|
@ -155,6 +164,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
res.append(idx)
|
||||
return res
|
||||
|
||||
@DeveloperAPI
|
||||
def sample(self, batch_size, beta):
|
||||
"""Sample a batch of experiences.
|
||||
|
||||
|
@ -208,6 +218,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
encoded_sample = self._encode_sample(idxes)
|
||||
return tuple(list(encoded_sample) + [weights, idxes])
|
||||
|
||||
@DeveloperAPI
|
||||
def update_priorities(self, idxes, priorities):
|
||||
"""Update priorities of sampled transitions.
|
||||
|
||||
|
@ -234,6 +245,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||
|
||||
self._max_priority = max(self._max_priority, priority)
|
||||
|
||||
@DeveloperAPI
|
||||
def stats(self, debug=False):
|
||||
parent = ReplayBuffer.stats(self, debug)
|
||||
if debug:
|
||||
|
|
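A usage sketch for the replay buffers above; the constructor, add(), sample() and update_priorities() calls follow the signatures shown, while the module path and the priority values fed back are assumptions/placeholders:

    from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer

    buf = PrioritizedReplayBuffer(size=1000, alpha=0.6)
    for t in range(10):
        buf.add(obs_t=[t], action=0, reward=1.0, obs_tp1=[t + 1], done=False,
                weight=1.0)

    # sample() returns the transition columns plus importance weights and
    # the indices needed to report updated priorities back
    obs_t, actions, rewards, obs_tp1, dones, weights, idxes = buf.sample(
        4, beta=0.4)
    buf.update_priorities(idxes, [1.0] * len(idxes))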
|
@ -17,7 +17,7 @@ from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
|
|||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
|
||||
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
@ -176,13 +176,13 @@ class TestMultiAgentEnv(unittest.TestCase):
|
|||
self.assertEqual(done["__all__"], True)
|
||||
|
||||
def testNoResetUntilPoll(self):
|
||||
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 1)
|
||||
env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 1)
|
||||
self.assertFalse(env.get_unwrapped()[0].resetted)
|
||||
env.poll()
|
||||
self.assertTrue(env.get_unwrapped()[0].resetted)
|
||||
|
||||
def testVectorizeBasic(self):
|
||||
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
|
||||
env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
|
||||
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
|
||||
|
@ -258,7 +258,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
|||
})
|
||||
|
||||
def testVectorizeRoundRobin(self):
|
||||
env = _MultiAgentEnvToAsync(lambda v: RoundRobinMultiAgent(2), [], 2)
|
||||
env = _MultiAgentEnvToBaseEnv(lambda v: RoundRobinMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
|
||||
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})
|
||||
|
|
|
@ -16,7 +16,7 @@ from ray.rllib.agents.a3c import A2CAgent
|
|||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.async_vector_env import AsyncVectorEnv
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.model import Model
|
||||
|
@ -303,8 +303,7 @@ class NestedSpacesTest(unittest.TestCase):
|
|||
self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv()))
|
||||
|
||||
def testNestedDictAsync(self):
|
||||
self.doTestNestedDict(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))
|
||||
self.doTestNestedDict(lambda _: BaseEnv.to_base_env(NestedDictEnv()))
|
||||
|
||||
def testNestedTupleGym(self):
|
||||
self.doTestNestedTuple(lambda _: NestedTupleEnv())
|
||||
|
@ -317,8 +316,7 @@ class NestedSpacesTest(unittest.TestCase):
|
|||
self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv()))
|
||||
|
||||
def testNestedTupleAsync(self):
|
||||
self.doTestNestedTuple(
|
||||
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))
|
||||
self.doTestNestedTuple(lambda _: BaseEnv.to_base_env(NestedTupleEnv()))
|
||||
|
||||
def testMultiAgentComplexSpaces(self):
|
||||
ModelCatalog.register_custom_model("dict_spy", DictSpyModel)
|
||||
|
|
|
@ -18,3 +18,36 @@ def override(cls):
|
|||
return method
|
||||
|
||||
return check_override
|
||||
|
||||
|
||||
def PublicAPI(obj):
|
||||
"""Annotation for documenting public APIs.
|
||||
|
||||
Public APIs are classes and methods exposed to end users of RLlib. You
|
||||
can expect these APIs to remain stable across RLlib releases.
|
||||
|
||||
Subclasses that inherit from a ``@PublicAPI`` base class can be
|
||||
assumed part of the RLlib public API as well (e.g., all agent classes
|
||||
are in public API because Agent is ``@PublicAPI``).
|
||||
|
||||
In addition, you can assume all agent configurations are part of their
|
||||
public API as well.
|
||||
"""
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def DeveloperAPI(obj):
|
||||
"""Annotation for documenting developer APIs.
|
||||
|
||||
Developer APIs are classes and methods explicitly exposed to developers
|
||||
for the purposes of building custom algorithms or advanced training
|
||||
strategies on top of RLlib internals. You can generally expect these APIs
|
||||
to be stable sans minor changes (but less stable than public APIs).
|
||||
|
||||
Subclasses that inherit from a ``@DeveloperAPI`` base class can be
|
||||
assumed part of the RLlib developer API as well (e.g., all policy
|
||||
optimizers are developer API because PolicyOptimizer is ``@DeveloperAPI``).
|
||||
"""
|
||||
|
||||
return obj
|
||||
|
|
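Since both annotations simply return the object they wrap, they can be applied to any class or method without changing its behavior. A quick illustration with made-up names:

    from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI


    @PublicAPI
    class MyHelper(object):
        """Hypothetical end-user facing class; expected to stay stable."""

        @DeveloperAPI
        def advanced_hook(self):
            # exposed for algorithm developers; may see minor changes
            return None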
|
@ -9,6 +9,8 @@ import numpy as np
|
|||
import pyarrow
|
||||
from six import string_types
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
|
@ -21,6 +23,7 @@ except ImportError:
|
|||
LZ4_ENABLED = False
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def pack(data):
|
||||
if LZ4_ENABLED:
|
||||
data = pyarrow.serialize(data).to_buffer().to_pybytes()
|
||||
|
@ -31,12 +34,14 @@ def pack(data):
|
|||
return data
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def pack_if_needed(data):
|
||||
if isinstance(data, np.ndarray):
|
||||
data = pack(data)
|
||||
return data
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def unpack(data):
|
||||
if LZ4_ENABLED:
|
||||
data = base64.b64decode(data)
|
||||
|
@ -45,6 +50,7 @@ def unpack(data):
|
|||
return data
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def unpack_if_needed(data):
|
||||
if isinstance(data, bytes) or isinstance(data, string_types):
|
||||
data = unpack(data)
|
||||
|
|
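The compression helpers are symmetric and fall back to a pass-through when lz4 is unavailable (see the LZ4_ENABLED flag above); a round-trip sketch:

    import numpy as np

    from ray.rllib.utils.compression import pack, unpack

    arr = np.arange(100)
    blob = pack(arr)  # pyarrow-serialized, lz4-compressed, base64-encoded
    restored = unpack(blob)
    assert np.array_equal(arr, restored)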
|
@ -2,7 +2,10 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class UnsupportedSpaceException(Exception):
|
||||
"""Error for an unsupported action or observation space."""
|
||||
pass
|
||||
|
|
|
@ -3,14 +3,17 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class FilterManager(object):
|
||||
"""Manages filters and coordination across remote evaluators that expose
|
||||
`get_filters` and `sync_filters`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
def synchronize(local_filters, remotes, update_remote=True):
|
||||
"""Aggregates all filters from remote evaluators.
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ from __future__ import print_function
|
|||
import logging
|
||||
import pickle
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
|
@ -16,6 +18,7 @@ except ImportError:
|
|||
" the client side.")
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class PolicyClient(object):
|
||||
"""REST client to interact with a RLlib policy server."""
|
||||
|
||||
|
@ -25,9 +28,11 @@ class PolicyClient(object):
|
|||
LOG_RETURNS = "LOG_RETURNS"
|
||||
END_EPISODE = "END_EPISODE"
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, address):
|
||||
self._address = address
|
||||
|
||||
@PublicAPI
|
||||
def start_episode(self, episode_id=None, training_enabled=True):
|
||||
"""Record the start of an episode.
|
||||
|
||||
|
@ -47,6 +52,7 @@ class PolicyClient(object):
|
|||
"training_enabled": training_enabled,
|
||||
})["episode_id"]
|
||||
|
||||
@PublicAPI
|
||||
def get_action(self, episode_id, observation):
|
||||
"""Record an observation and get the on-policy action.
|
||||
|
||||
|
@ -63,6 +69,7 @@ class PolicyClient(object):
|
|||
"episode_id": episode_id,
|
||||
})["action"]
|
||||
|
||||
@PublicAPI
|
||||
def log_action(self, episode_id, observation, action):
|
||||
"""Record an observation and (off-policy) action taken.
|
||||
|
||||
|
@ -78,6 +85,7 @@ class PolicyClient(object):
|
|||
"episode_id": episode_id,
|
||||
})
|
||||
|
||||
@PublicAPI
|
||||
def log_returns(self, episode_id, reward, info=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
|
@ -96,6 +104,7 @@ class PolicyClient(object):
|
|||
"episode_id": episode_id,
|
||||
})
|
||||
|
||||
@PublicAPI
|
||||
def end_episode(self, episode_id, observation):
|
||||
"""Record the end of an episode.
|
||||
|
||||
|
|
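The @PublicAPI client methods above chain into the usual serving loop. A sketch with a placeholder environment object and an assumed server address/port:

    from ray.rllib.utils.policy_client import PolicyClient


    def run_one_episode(env, address="http://localhost:9900"):
        # `env` is any gym-style simulator; the address points at a
        # running PolicyServer (port chosen for illustration only)
        client = PolicyClient(address)
        eps_id = client.start_episode(training_enabled=True)
        obs = env.reset()
        done = False
        while not done:
            action = client.get_action(eps_id, obs)
            obs, reward, done, _ = env.step(action)
            client.log_returns(eps_id, reward)
        client.end_episode(eps_id, obs)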
|
@ -6,6 +6,7 @@ import pickle
|
|||
import sys
|
||||
import traceback
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.policy_client import PolicyClient
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
|
@ -17,6 +18,7 @@ elif sys.version_info[0] == 3:
|
|||
from socketserver import ThreadingMixIn
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class PolicyServer(ThreadingMixIn, HTTPServer):
|
||||
"""REST server than can be launched from a ExternalEnv.
|
||||
|
||||
|
@ -50,6 +52,7 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
|
|||
>>> client.log_returns(eps_id, reward)
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, external_env, address, port):
|
||||
handler = _make_handler(external_env)
|
||||
HTTPServer.__init__(self, (address, port), handler)
|
||||
|
|