[rllib] annotate public vs developer vs private APIs (#3808)

Authored by Eric Liang on 2019-01-23 21:27:26 -08:00; committed by GitHub
parent 01e18b47f4
commit 04ec47cbd4
45 changed files with 562 additions and 274 deletions

View file

@ -6,6 +6,15 @@ Development Install
You can develop RLlib locally without needing to compile Ray by using the `setup-rllib-dev.py <https://github.com/ray-project/ray/blob/master/python/ray/rllib/setup-rllib-dev.py>`__ script. This sets up links between the ``rllib`` dir in your git repo and the one bundled with the ``ray`` package. When using this script, make sure that your git branch is in sync with the installed Ray binaries (i.e., you are up-to-date on `master <https://github.com/ray-project/ray>`__ and have the latest `wheel <https://ray.readthedocs.io/en/latest/installation.html>`__ installed).
API Stability
-------------
Objects and methods annotated with ``@PublicAPI`` or ``@DeveloperAPI`` have the following API compatibility guarantees:
.. autofunction:: ray.rllib.utils.annotations.PublicAPI
.. autofunction:: ray.rllib.utils.annotations.DeveloperAPI
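For example, the annotations are applied as plain decorators on classes and methods (the class and method below are hypothetical):

.. code-block:: python

    from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI

    @PublicAPI
    class MyEnvWrapper(object):
        """Part of the user-facing API, covered by the guarantees above."""

        @DeveloperAPI
        def advanced_hook(self):
            """Intended for users extending RLlib internals."""
            pass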
Features
--------

View file

@ -310,6 +310,6 @@ Note that envs can read from different partitions of the logs based on the ``wor
Batch Asynchronous
------------------
The lowest-level "catch-all" environment supported by RLlib is `AsyncVectorEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/async_vector_env.py>`__. AsyncVectorEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.
The lowest-level "catch-all" environment supported by RLlib is `BaseEnv <https://github.com/ray-project/ray/blob/master/python/ray/rllib/env/base_env.py>`__. BaseEnv models multiple agents executing asynchronously in multiple environments. A call to ``poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents can be sent back via ``send_actions()``. This interface can be subclassed directly to support batched simulators such as `ELF <https://github.com/facebookresearch/ELF>`__.
Under the hood, all other envs are converted to AsyncVectorEnv by RLlib so that there is a common internal path for policy evaluation.
Under the hood, all other envs are converted to BaseEnv by RLlib so that there is a common internal path for policy evaluation.
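A rough sketch of the polling loop this interface implies (``my_env`` and ``compute_action`` are placeholders, not RLlib APIs):

.. code-block:: python

    from ray.rllib.env.base_env import BaseEnv

    env = BaseEnv.to_base_env(my_env)  # wraps a gym.Env, VectorEnv, MultiAgentEnv, ...
    while True:
        # Observations are keyed by env id, then by agent id.
        obs, rewards, dones, infos, off_policy_actions = env.poll()
        actions = {
            env_id: {agent_id: compute_action(ob)
                     for agent_id, ob in agent_obs.items()}
            for env_id, agent_obs in obs.items()
        }
        env.send_actions(actions)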

File diff suppressed because one or more lines are too long

Binary image diff (not rendered): 77 KiB before, 75 KiB after.

View file

@ -10,7 +10,7 @@ from ray.tune.registry import register_trainable
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.external_env import ExternalEnv
@ -47,7 +47,7 @@ __all__ = [
"TFPolicyGraph",
"PolicyEvaluator",
"SampleBatch",
"AsyncVectorEnv",
"BaseEnv",
"MultiAgentEnv",
"VectorEnv",
"ExternalEnv",

View file

@ -18,7 +18,7 @@ from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
from ray.tune.trainable import Trainable
@ -182,6 +182,7 @@ COMMON_CONFIG = {
# yapf: enable
@DeveloperAPI
def with_common_config(extra_config):
"""Returns the given config dict merged with common agent confs."""
@ -196,6 +197,7 @@ def with_base_config(base_config, extra_config):
return config
@PublicAPI
class Agent(Trainable):
"""All RLlib agents extend this base class.
@ -214,6 +216,7 @@ class Agent(Trainable):
"custom_resources_per_worker"
]
@PublicAPI
def __init__(self, config=None, env=None, logger_creator=None):
"""Initialize an RLLib agent.
@ -266,6 +269,7 @@ class Agent(Trainable):
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
@override(Trainable)
@PublicAPI
def train(self):
"""Overrides super.train to synchronize global vars."""
@ -344,11 +348,13 @@ class Agent(Trainable):
extra_data = pickle.load(open(checkpoint_path, "rb"))
self.__setstate__(extra_data)
@DeveloperAPI
def _init(self):
"""Subclasses should override this for custom initialization."""
raise NotImplementedError
@PublicAPI
def compute_action(self,
observation,
state=None,
@ -404,6 +410,7 @@ class Agent(Trainable):
raise NotImplementedError
@PublicAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
@ -413,6 +420,7 @@ class Agent(Trainable):
return self.local_evaluator.get_policy(policy_id)
@PublicAPI
def get_weights(self, policies=None):
"""Return a dictionary of policy ids to weights.
@ -422,6 +430,7 @@ class Agent(Trainable):
"""
return self.local_evaluator.get_weights(policies)
@PublicAPI
def set_weights(self, weights):
"""Set policy weights by policy id.
@ -430,6 +439,7 @@ class Agent(Trainable):
"""
self.local_evaluator.set_weights(weights)
@DeveloperAPI
def make_local_evaluator(self, env_creator, policy_graph):
"""Convenience method to return configured local evaluator."""
@ -444,6 +454,7 @@ class Agent(Trainable):
config["local_evaluator_tf_session_args"]
}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
"""Convenience method to return a number of remote evaluators."""
@ -459,6 +470,7 @@ class Agent(Trainable):
self.config) for i in range(count)
]
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
"""Export policy model with given policy_id to local directory.
@ -474,6 +486,7 @@ class Agent(Trainable):
"""
self.local_evaluator.export_policy_model(export_dir, policy_id)
@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
filename_prefix="model",
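Taken together, the methods tagged ``@PublicAPI`` above form the user-facing surface of an agent. A usage sketch, assuming some concrete ``Agent`` subclass (called ``MyAgent`` here), a registered env name, and an observation ``obs`` obtained elsewhere:

    import ray

    ray.init()
    agent = MyAgent(config={"num_workers": 2}, env="CartPole-v0")
    for _ in range(10):
        result = agent.train()             # @PublicAPI: runs one training iteration
        print(result)
    action = agent.compute_action(obs)     # @PublicAPI: single-observation inference
    weights = agent.get_weights()          # @PublicAPI: dict of policy id -> weights
    agent.set_weights(weights)             # @PublicAPI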

View file

@ -1,4 +1,4 @@
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.serving_env import ServingEnv
@ -6,6 +6,6 @@ from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.env_context import EnvContext
__all__ = [
"AsyncVectorEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv",
"ServingEnv", "EnvContext"
"BaseEnv", "MultiAgentEnv", "ExternalEnv", "VectorEnv", "ServingEnv",
"EnvContext"
]

View file

@ -5,23 +5,24 @@ from __future__ import print_function
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
class AsyncVectorEnv(object):
@PublicAPI
class BaseEnv(object):
"""The lowest-level env interface used by RLlib for sampling.
AsyncVectorEnv models multiple agents executing asynchronously in multiple
BaseEnv models multiple agents executing asynchronously in multiple
environments. A call to poll() returns observations from ready agents
keyed by their environment and agent ids, and actions for those agents
can be sent back via send_actions().
All other env types can be adapted to AsyncVectorEnv. RLlib handles these
All other env types can be adapted to BaseEnv. RLlib handles these
conversions internally in PolicyEvaluator, for example:
gym.Env => rllib.VectorEnv => rllib.AsyncVectorEnv
rllib.MultiAgentEnv => rllib.AsyncVectorEnv
rllib.ExternalEnv => rllib.AsyncVectorEnv
gym.Env => rllib.VectorEnv => rllib.BaseEnv
rllib.MultiAgentEnv => rllib.BaseEnv
rllib.ExternalEnv => rllib.BaseEnv
Attributes:
action_space (gym.Space): Action space. This must be defined for
@ -30,7 +31,7 @@ class AsyncVectorEnv(object):
for single-agent envs. Multi-agent envs can set this to None.
Examples:
>>> env = MyAsyncVectorEnv()
>>> env = MyBaseEnv()
>>> obs, rewards, dones, infos, off_policy_actions = env.poll()
>>> print(obs)
{
@ -65,26 +66,27 @@ class AsyncVectorEnv(object):
"""
@staticmethod
def wrap_async(env, make_env=None, num_envs=1):
def to_base_env(env, make_env=None, num_envs=1):
"""Wraps any env type as needed to expose the async interface."""
if not isinstance(env, AsyncVectorEnv):
if not isinstance(env, BaseEnv):
if isinstance(env, MultiAgentEnv):
env = _MultiAgentEnvToAsync(
env = _MultiAgentEnvToBaseEnv(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
"ExternalEnv does not currently support num_envs > 1.")
env = _ExternalEnvToAsync(env)
env = _ExternalEnvToBaseEnv(env)
elif isinstance(env, VectorEnv):
env = _VectorEnvToAsync(env)
env = _VectorEnvToBaseEnv(env)
else:
env = VectorEnv.wrap(
make_env=make_env, existing_envs=[env], num_envs=num_envs)
env = _VectorEnvToAsync(env)
assert isinstance(env, AsyncVectorEnv)
env = _VectorEnvToBaseEnv(env)
assert isinstance(env, BaseEnv)
return env
@PublicAPI
def poll(self):
"""Returns observations from ready agents.
@ -107,6 +109,7 @@ class AsyncVectorEnv(object):
"""
raise NotImplementedError
@PublicAPI
def send_actions(self, action_dict):
"""Called to send actions back to running agents in this env.
@ -118,6 +121,7 @@ class AsyncVectorEnv(object):
"""
raise NotImplementedError
@PublicAPI
def try_reset(self, env_id):
"""Attempt to reset the env with the given id.
@ -129,6 +133,7 @@ class AsyncVectorEnv(object):
"""
return None
@PublicAPI
def get_unwrapped(self):
"""Return a reference to the underlying gym envs, if any.
@ -146,8 +151,8 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
return {k: {dummy_id: v} for (k, v) in env_id_to_values.items()}
class _ExternalEnvToAsync(AsyncVectorEnv):
"""Internal adapter of ExternalEnv to AsyncVectorEnv."""
class _ExternalEnvToBaseEnv(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""
def __init__(self, external_env, preprocessor=None):
self.external_env = external_env
@ -159,7 +164,7 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
self.observation_space = external_env.observation_space
external_env.start()
@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
with self.external_env._results_avail_condition:
results = self._poll()
@ -174,7 +179,7 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
"ExternalEnv was created with max_concurrent={}".format(limit))
return results
@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
for eid, action in action_dict.items():
self.external_env._episodes[eid].action_queue.put(
@ -204,8 +209,8 @@ class _ExternalEnvToAsync(AsyncVectorEnv):
_with_dummy_agent_id(off_policy_actions)
class _VectorEnvToAsync(AsyncVectorEnv):
"""Internal adapter of VectorEnv to AsyncVectorEnv.
class _VectorEnvToBaseEnv(BaseEnv):
"""Internal adapter of VectorEnv to BaseEnv.
We assume the caller will always send the full vector of actions in each
call to send_actions(), and that they call reset_at() on all completed
@ -222,7 +227,7 @@ class _VectorEnvToAsync(AsyncVectorEnv):
self.cur_dones = [False for _ in range(self.num_envs)]
self.cur_infos = [None for _ in range(self.num_envs)]
@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
if self.new_obs is None:
self.new_obs = self.vector_env.vector_reset()
@ -239,7 +244,7 @@ class _VectorEnvToAsync(AsyncVectorEnv):
_with_dummy_agent_id(dones, "__all__"), \
_with_dummy_agent_id(infos), {}
@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
action_vector = [None] * self.num_envs
for i in range(self.num_envs):
@ -247,17 +252,17 @@ class _VectorEnvToAsync(AsyncVectorEnv):
self.new_obs, self.cur_rewards, self.cur_dones, self.cur_infos = \
self.vector_env.vector_step(action_vector)
@override(AsyncVectorEnv)
@override(BaseEnv)
def try_reset(self, env_id):
return {_DUMMY_AGENT_ID: self.vector_env.reset_at(env_id)}
@override(AsyncVectorEnv)
@override(BaseEnv)
def get_unwrapped(self):
return self.vector_env.get_unwrapped()
class _MultiAgentEnvToAsync(AsyncVectorEnv):
"""Internal adapter of MultiAgentEnv to AsyncVectorEnv.
class _MultiAgentEnvToBaseEnv(BaseEnv):
"""Internal adapter of MultiAgentEnv to BaseEnv.
This also supports vectorization if num_envs > 1.
"""
@ -282,14 +287,14 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
assert isinstance(env, MultiAgentEnv)
self.env_states = [_MultiAgentEnvState(env) for env in self.envs]
@override(AsyncVectorEnv)
@override(BaseEnv)
def poll(self):
obs, rewards, dones, infos = {}, {}, {}, {}
for i, env_state in enumerate(self.env_states):
obs[i], rewards[i], dones[i], infos[i] = env_state.poll()
return obs, rewards, dones, infos, {}
@override(AsyncVectorEnv)
@override(BaseEnv)
def send_actions(self, action_dict):
for env_id, agent_dict in action_dict.items():
if env_id in self.dones:
@ -311,7 +316,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
@override(AsyncVectorEnv)
@override(BaseEnv)
def try_reset(self, env_id):
obs = self.env_states[env_id].reset()
assert isinstance(obs, dict), "Not a multi-agent obs"
@ -319,7 +324,7 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
self.dones.remove(env_id)
return obs
@override(AsyncVectorEnv)
@override(BaseEnv)
def get_unwrapped(self):
return [state.env for state in self.env_states]

View file

@ -2,7 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class EnvContext(dict):
"""Wraps env configurations to include extra rllib metadata.

View file

@ -6,7 +6,10 @@ from six.moves import queue
import threading
import uuid
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class ExternalEnv(threading.Thread):
"""An environment that interfaces with external agents.
@ -36,6 +39,7 @@ class ExternalEnv(threading.Thread):
print(agent.train())
"""
@PublicAPI
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize an external env.
@ -57,6 +61,7 @@ class ExternalEnv(threading.Thread):
self._results_avail_condition = threading.Condition()
self._max_concurrent_episodes = max_concurrent
@PublicAPI
def run(self):
"""Override this to implement the run loop.
@ -73,6 +78,7 @@ class ExternalEnv(threading.Thread):
"""
raise NotImplementedError
@PublicAPI
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
@ -102,6 +108,7 @@ class ExternalEnv(threading.Thread):
return episode_id
@PublicAPI
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
@ -116,6 +123,7 @@ class ExternalEnv(threading.Thread):
episode = self._get(episode_id)
return episode.wait_for_action(observation)
@PublicAPI
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
@ -128,6 +136,7 @@ class ExternalEnv(threading.Thread):
episode = self._get(episode_id)
episode.log_action(observation, action)
@PublicAPI
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.
@ -146,6 +155,7 @@ class ExternalEnv(threading.Thread):
if info:
episode.cur_info = info or {}
@PublicAPI
def end_episode(self, episode_id, observation):
"""Record the end of an episode.

View file

@ -2,7 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class MultiAgentEnv(object):
"""An environment that hosts multiple independent agents.
@ -41,6 +44,7 @@ class MultiAgentEnv(object):
}
"""
@PublicAPI
def reset(self):
"""Resets the env and returns observations from ready agents.
@ -49,6 +53,7 @@ class MultiAgentEnv(object):
"""
raise NotImplementedError
@PublicAPI
def step(self, action_dict):
"""Returns observations from ready agents.
@ -68,6 +73,7 @@ class MultiAgentEnv(object):
# yapf: disable
# __grouping_doc_begin__
@PublicAPI
def with_agent_groups(self, groups, obs_space=None, act_space=None):
"""Convenience method for grouping together agents in this env.

View file

@ -2,9 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
@PublicAPI
class VectorEnv(object):
"""An environment that supports batch evaluation.
@ -20,6 +21,7 @@ class VectorEnv(object):
def wrap(make_env=None, existing_envs=None, num_envs=1):
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs)
@PublicAPI
def vector_reset(self):
"""Resets all environments.
@ -28,6 +30,7 @@ class VectorEnv(object):
"""
raise NotImplementedError
@PublicAPI
def reset_at(self, index):
"""Resets a single environment.
@ -36,6 +39,7 @@ class VectorEnv(object):
"""
raise NotImplementedError
@PublicAPI
def vector_step(self, actions):
"""Vectorized step.
@ -50,6 +54,7 @@ class VectorEnv(object):
"""
raise NotImplementedError
@PublicAPI
def get_unwrapped(self):
"""Returns the underlying env instances."""
raise NotImplementedError

View file

@ -4,9 +4,9 @@ from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.evaluation.sample_batch import (SampleBatch, MultiAgentBatch,
SampleBatchBuilder,
MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.evaluation.sample_batch_builder import (
SampleBatchBuilder, MultiAgentSampleBatchBuilder)
from ray.rllib.evaluation.sampler import SyncSampler, AsyncSampler
from ray.rllib.evaluation.postprocessing import (compute_advantages,
compute_targets)

View file

@ -7,16 +7,14 @@ import random
import numpy as np
from ray.rllib.env.async_vector_env import _DUMMY_AGENT_ID
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class MultiAgentEpisode(object):
"""Tracks the current state of a (possibly multi-agent) episode.
The APIs in this class should be considered experimental, but we should
avoid changing things for the sake of changing them since users may
depend on them for custom metrics or advanced algorithms.
Attributes:
new_batch_builder (func): Create a new MultiAgentSampleBatchBuilder.
add_extra_batch (func): Return a built MultiAgentBatch to the sampler.
@ -66,6 +64,7 @@ class MultiAgentEpisode(object):
self._agent_to_prev_action = {}
self._agent_reward_history = defaultdict(list)
@DeveloperAPI
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the policy graph for the specified agent.
@ -77,16 +76,19 @@ class MultiAgentEpisode(object):
self._agent_to_policy[agent_id] = self._policy_mapping_fn(agent_id)
return self._agent_to_policy[agent_id]
@DeveloperAPI
def last_observation_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last observation for the specified agent."""
return self._agent_to_last_obs.get(agent_id)
@DeveloperAPI
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last info for the specified agent."""
return self._agent_to_last_info.get(agent_id)
@DeveloperAPI
def last_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last action for the specified agent, or zeros."""
@ -97,6 +99,7 @@ class MultiAgentEpisode(object):
flat = _flatten_action(policy.action_space.sample())
return np.zeros_like(flat)
@DeveloperAPI
def prev_action_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the previous action for the specified agent."""
@ -106,6 +109,7 @@ class MultiAgentEpisode(object):
# We're at t=0, so return all zeros.
return np.zeros_like(self.last_action_for(agent_id))
@DeveloperAPI
def prev_reward_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the previous reward for the specified agent."""
@ -116,6 +120,7 @@ class MultiAgentEpisode(object):
# We're at t=0, so there is no previous reward, just return zero.
return 0.0
@DeveloperAPI
def rnn_state_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last RNN state for the specified agent."""
@ -124,6 +129,7 @@ class MultiAgentEpisode(object):
self._agent_to_rnn_state[agent_id] = policy.get_initial_state()
return self._agent_to_rnn_state[agent_id]
@DeveloperAPI
def last_pi_info_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last info object for the specified agent."""

View file

@ -4,13 +4,17 @@ from __future__ import print_function
import os
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class EvaluatorInterface(object):
"""This is the interface between policy optimizers and policy evaluation.
See also: PolicyEvaluator
"""
@DeveloperAPI
def sample(self):
"""Returns a batch of experience sampled from this evaluator.
@ -27,6 +31,7 @@ class EvaluatorInterface(object):
raise NotImplementedError
@DeveloperAPI
def compute_gradients(self, samples):
"""Returns a gradient computed w.r.t the specified samples.
@ -45,6 +50,7 @@ class EvaluatorInterface(object):
raise NotImplementedError
@DeveloperAPI
def apply_gradients(self, grads):
"""Applies the given gradients to this evaluator's weights.
@ -58,6 +64,7 @@ class EvaluatorInterface(object):
raise NotImplementedError
@DeveloperAPI
def get_weights(self):
"""Returns the model weights of this Evaluator.
@ -73,6 +80,7 @@ class EvaluatorInterface(object):
raise NotImplementedError
@DeveloperAPI
def set_weights(self, weights):
"""Sets the model weights of this Evaluator.
@ -85,6 +93,7 @@ class EvaluatorInterface(object):
raise NotImplementedError
@DeveloperAPI
def compute_apply(self, samples):
"""Fused compute gradients and apply gradients call.
@ -100,11 +109,13 @@ class EvaluatorInterface(object):
self.apply_gradients(grads)
return info
@DeveloperAPI
def get_host(self):
"""Returns the hostname of the process running this evaluator."""
return os.uname()[1]
@DeveloperAPI
def apply(self, func, *args):
"""Apply the given function to this evaluator instance."""

View file

@ -8,10 +8,12 @@ import collections
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@DeveloperAPI
def collect_metrics(local_evaluator, remote_evaluators=[],
timeout_seconds=180):
"""Gathers episode metrics from PolicyEvaluator instances."""
@ -22,6 +24,7 @@ def collect_metrics(local_evaluator, remote_evaluators=[],
return metrics
@DeveloperAPI
def collect_episodes(local_evaluator,
remote_evaluators=[],
timeout_seconds=180):
@ -43,6 +46,7 @@ def collect_episodes(local_evaluator,
return episodes, num_metric_batches_dropped
@DeveloperAPI
def summarize_episodes(episodes, new_episodes, num_dropped):
"""Summarizes a set of episode metrics tuples.

View file

@ -8,7 +8,7 @@ import pickle
import tensorflow as tf
import ray
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@ -22,7 +22,7 @@ from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import NoPreprocessor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@ -30,6 +30,7 @@ from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
@DeveloperAPI
class PolicyEvaluator(EvaluatorInterface):
"""Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
@ -83,11 +84,13 @@ class PolicyEvaluator(EvaluatorInterface):
"traffic_light_policy": SampleBatch(...)})
"""
@DeveloperAPI
@classmethod
def as_remote(cls, num_cpus=None, num_gpus=None, resources=None):
return ray.remote(
num_cpus=num_cpus, num_gpus=num_gpus, resources=resources)(cls)
@DeveloperAPI
def __init__(self,
env_creator,
policy_graph,
@ -214,7 +217,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.env = env_creator(env_context)
if isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, AsyncVectorEnv):
isinstance(self.env, BaseEnv):
def wrap(env):
return env # we can't auto-wrap these env types
@ -275,7 +278,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
if not (isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, AsyncVectorEnv)):
or isinstance(self.env, BaseEnv)):
raise ValueError(
"Have multiple policy graphs {}, but the env ".format(
self.policy_map) +
@ -288,7 +291,7 @@ class PolicyEvaluator(EvaluatorInterface):
}
# Always use vector env for consistency even if num_envs = 1
self.async_env = AsyncVectorEnv.wrap_async(
self.async_env = BaseEnv.to_base_env(
self.env, make_env=make_env, num_envs=num_envs)
self.num_envs = num_envs
@ -399,6 +402,7 @@ class PolicyEvaluator(EvaluatorInterface):
return batch
@DeveloperAPI
@ray.method(num_return_vals=2)
def sample_with_count(self):
"""Same as sample() but returns the count as a separate future."""
@ -489,6 +493,7 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples))
return grad_fetch
@DeveloperAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
@ -498,16 +503,19 @@ class PolicyEvaluator(EvaluatorInterface):
return self.policy_map.get(policy_id)
@DeveloperAPI
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""
return func(self.policy_map[policy_id])
@DeveloperAPI
def foreach_policy(self, func):
"""Apply the given function to each (policy, policy_id) tuple."""
return [func(policy, pid) for pid, policy in self.policy_map.items()]
@DeveloperAPI
def foreach_trainable_policy(self, func):
"""Apply the given function to each (policy, policy_id) tuple.
@ -518,6 +526,7 @@ class PolicyEvaluator(EvaluatorInterface):
if pid in self.policies_to_train
]
@DeveloperAPI
def sync_filters(self, new_filters):
"""Changes self's filter to given and rebases any accumulated delta.
@ -528,6 +537,7 @@ class PolicyEvaluator(EvaluatorInterface):
for k in self.filters:
self.filters[k].sync(new_filters[k])
@DeveloperAPI
def get_filters(self, flush_after=False):
"""Returns a snapshot of filters.
@ -544,6 +554,7 @@ class PolicyEvaluator(EvaluatorInterface):
f.clear_buffer()
return return_filters
@DeveloperAPI
def save(self):
filters = self.get_filters(flush_after=True)
state = {
@ -552,18 +563,22 @@ class PolicyEvaluator(EvaluatorInterface):
}
return pickle.dumps({"filters": filters, "state": state})
@DeveloperAPI
def restore(self, objs):
objs = pickle.loads(objs)
self.sync_filters(objs["filters"])
for pid, state in objs["state"].items():
self.policy_map[pid].set_state(state)
@DeveloperAPI
def set_global_vars(self, global_vars):
self.foreach_policy(lambda p, _: p.on_global_var_update(global_vars))
@DeveloperAPI
def export_policy_model(self, export_dir, policy_id=DEFAULT_POLICY_ID):
self.policy_map[policy_id].export_model(export_dir)
@DeveloperAPI
def export_policy_checkpoint(self,
export_dir,
filename_prefix="model",

View file

@ -2,7 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class PolicyGraph(object):
"""An agent policy and loss, i.e., a TFPolicyGraph or other subclass.
@ -21,6 +24,7 @@ class PolicyGraph(object):
action_space (gym.Space): Action space of the policy.
"""
@DeveloperAPI
def __init__(self, observation_space, action_space, config):
"""Initialize the graph.
@ -37,6 +41,7 @@ class PolicyGraph(object):
self.observation_space = observation_space
self.action_space = action_space
@DeveloperAPI
def compute_actions(self,
obs_batch,
state_batches,
@ -68,6 +73,7 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def compute_single_action(self,
obs,
state,
@ -116,6 +122,7 @@ class PolicyGraph(object):
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
@DeveloperAPI
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@ -140,6 +147,7 @@ class PolicyGraph(object):
"""
return sample_batch
@DeveloperAPI
def compute_gradients(self, postprocessed_batch):
"""Computes gradients against a batch of experiences.
@ -149,6 +157,7 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def apply_gradients(self, gradients):
"""Applies previously computed gradients.
@ -157,6 +166,7 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def compute_apply(self, samples):
"""Fused compute gradients and apply gradients call.
@ -173,6 +183,7 @@ class PolicyGraph(object):
apply_info = self.apply_gradients(grads)
return grad_info, apply_info
@DeveloperAPI
def get_weights(self):
"""Returns model weights.
@ -181,6 +192,7 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def set_weights(self, weights):
"""Sets model weights.
@ -189,10 +201,12 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def get_initial_state(self):
"""Returns initial RNN state for the current policy."""
return []
@DeveloperAPI
def get_state(self):
"""Saves all local state.
@ -201,6 +215,7 @@ class PolicyGraph(object):
"""
return self.get_weights()
@DeveloperAPI
def set_state(self, state):
"""Restores all local state.
@ -209,6 +224,7 @@ class PolicyGraph(object):
"""
self.set_weights(state)
@DeveloperAPI
def on_global_var_update(self, global_vars):
"""Called on an update to global vars.
@ -217,6 +233,7 @@ class PolicyGraph(object):
"""
pass
@DeveloperAPI
def export_model(self, export_dir):
"""Export PolicyGraph to local directory for serving.
@ -225,6 +242,7 @@ class PolicyGraph(object):
"""
raise NotImplementedError
@DeveloperAPI
def export_checkpoint(self, export_dir):
"""Export PolicyGraph checkpoint to local directory.

View file

@ -5,12 +5,14 @@ from __future__ import print_function
import numpy as np
import scipy.signal
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
@DeveloperAPI
def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
"""Given a rollout, compute its value targets and the advantage.
@ -54,6 +56,7 @@ def compute_advantages(rollout, last_r, gamma=0.9, lambda_=1.0, use_gae=True):
return SampleBatch(traj)
@DeveloperAPI
def compute_targets(rollout, action_space, last_r=0.0, gamma=0.9, lambda_=1.0):
"""Given a rollout, compute targets.

View file

@ -6,166 +6,13 @@ import six
import collections
import numpy as np
from ray.rllib.utils.annotations import PublicAPI
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default"
def to_float_array(v):
arr = np.array(v)
if arr.dtype == np.float64:
return arr.astype(np.float32) # save some memory
return arr
class SampleBatchBuilder(object):
"""Util to build a SampleBatch incrementally.
For efficiency, SampleBatches hold values in column form (as arrays).
However, it is useful to add data one row (dict) at a time.
"""
def __init__(self):
self.buffers = collections.defaultdict(list)
self.count = 0
def add_values(self, **values):
"""Add the given dictionary (row) of values to this batch."""
for k, v in values.items():
self.buffers[k].append(v)
self.count += 1
def add_batch(self, batch):
"""Add the given batch of values to this batch."""
for k, column in batch.items():
self.buffers[k].extend(column)
self.count += batch.count
def build_and_reset(self):
"""Returns a sample batch including all previously added values."""
batch = SampleBatch(
{k: to_float_array(v)
for k, v in self.buffers.items()})
self.buffers.clear()
self.count = 0
return batch
class MultiAgentSampleBatchBuilder(object):
"""Util to build SampleBatches for each policy in a multi-agent env.
Input data is per-agent, while output data is per-policy. There is an M:N
mapping between agents and policies. We retain one local batch builder
per agent. When an agent is done, then its local batch is appended into the
corresponding policy batch for the agent's policy.
"""
def __init__(self, policy_map, clip_rewards):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy graph instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
self.policy_builders = {
k: SampleBatchBuilder()
for k in policy_map.keys()
}
self.agent_builders = {}
self.agent_to_policy = {}
self.count = 0 # increment this manually
def total(self):
"""Returns summed number of steps across all agent buffers."""
return sum(p.count for p in self.policy_builders.values())
def has_pending_data(self):
"""Returns whether there is pending unprocessed data."""
return len(self.agent_builders) > 0
def add_values(self, agent_id, policy_id, **values):
"""Add the given dictionary (row) of values to this batch.
Arguments:
agent_id (obj): Unique id for the agent we are adding values for.
policy_id (obj): Unique id for policy controlling the agent.
values (dict): Row of values to add for this agent.
"""
if agent_id not in self.agent_builders:
self.agent_builders[agent_id] = SampleBatchBuilder()
self.agent_to_policy[agent_id] = policy_id
builder = self.agent_builders[agent_id]
builder.add_values(**values)
def postprocess_batch_so_far(self, episode):
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Arguments:
episode: current MultiAgentEpisode object or None
"""
# Materialize the batches so far
pre_batches = {}
for agent_id, builder in self.agent_builders.items():
pre_batches[agent_id] = (
self.policy_map[self.agent_to_policy[agent_id]],
builder.build_and_reset())
# Apply postprocessor
post_batches = {}
if self.clip_rewards:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
for agent_id, (_, pre_batch) in pre_batches.items():
other_batches = pre_batches.copy()
del other_batches[agent_id]
policy = self.policy_map[self.agent_to_policy[agent_id]]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches, episode)
# Append into policy batches and reset
for agent_id, post_batch in sorted(post_batches.items()):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
self.agent_builders.clear()
self.agent_to_policy.clear()
def build_and_reset(self, episode):
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Arguments:
episode: current MultiAgentEpisode object or None
"""
self.postprocess_batch_so_far(episode)
policy_batches = {}
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
policy_batches[policy_id] = builder.build_and_reset()
old_count = self.count
self.count = 0
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
@PublicAPI
class MultiAgentBatch(object):
"""A batch of experiences from multiple policies in the environment.
@ -177,17 +24,20 @@ class MultiAgentBatch(object):
batch contains across all policies in total.
"""
@PublicAPI
def __init__(self, policy_batches, count):
self.policy_batches = policy_batches
self.count = count
@staticmethod
@PublicAPI
def wrap_as_needed(batches, count):
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
return batches[DEFAULT_POLICY_ID]
return MultiAgentBatch(batches, count)
@staticmethod
@PublicAPI
def concat_samples(samples):
policy_batches = collections.defaultdict(list)
total_count = 0
@ -201,11 +51,13 @@ class MultiAgentBatch(object):
out[policy_id] = SampleBatch.concat_samples(batches)
return MultiAgentBatch(out, total_count)
@PublicAPI
def copy(self):
return MultiAgentBatch(
{k: v.copy()
for (k, v) in self.policy_batches.items()}, self.count)
@PublicAPI
def total(self):
ct = 0
for batch in self.policy_batches.values():
@ -221,6 +73,7 @@ class MultiAgentBatch(object):
str(self.policy_batches), self.count)
@PublicAPI
class SampleBatch(object):
"""Wrapper around a dictionary with string keys and array-like values.
@ -228,6 +81,7 @@ class SampleBatch(object):
samples, each with an "obs" and "reward" attribute.
"""
@PublicAPI
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor)."""
@ -243,6 +97,7 @@ class SampleBatch(object):
self.count = lengths[0]
@staticmethod
@PublicAPI
def concat_samples(samples):
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
@ -252,6 +107,7 @@ class SampleBatch(object):
out[k] = np.concatenate([s[k] for s in samples])
return SampleBatch(out)
@PublicAPI
def concat(self, other):
"""Returns a new SampleBatch with each data column concatenated.
@ -268,11 +124,13 @@ class SampleBatch(object):
out[k] = np.concatenate([self[k], other[k]])
return SampleBatch(out)
@PublicAPI
def copy(self):
return SampleBatch(
{k: np.array(v, copy=True)
for (k, v) in self.data.items()})
@PublicAPI
def rows(self):
"""Returns an iterator over data rows, i.e. dicts with column values.
@ -291,6 +149,7 @@ class SampleBatch(object):
row[k] = self[k][i]
yield row
@PublicAPI
def columns(self, keys):
"""Returns a list of just the specified columns.
@ -305,6 +164,7 @@ class SampleBatch(object):
out.append(self[k])
return out
@PublicAPI
def shuffle(self):
"""Shuffles the rows of this batch in-place."""
@ -312,6 +172,7 @@ class SampleBatch(object):
for key, val in self.items():
self[key] = val[permutation]
@PublicAPI
def split_by_episode(self):
"""Splits this batch's data by `eps_id`.
@ -335,6 +196,7 @@ class SampleBatch(object):
assert sum(s.count for s in slices) == self.count, (slices, self.count)
return slices
@PublicAPI
def slice(self, start, end):
"""Returns a slice of the row data of this batch.
@ -348,9 +210,19 @@ class SampleBatch(object):
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
@PublicAPI
def keys(self):
return self.data.keys()
@PublicAPI
def items(self):
return self.data.items()
@PublicAPI
def __getitem__(self, key):
return self.data[key]
@PublicAPI
def __setitem__(self, key, item):
self.data[key] = item
@ -360,12 +232,6 @@ class SampleBatch(object):
def __repr__(self):
return "SampleBatch({})".format(str(self.data))
def keys(self):
return self.data.keys()
def items(self):
return self.data.items()
def __iter__(self):
return self.data.__iter__()
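A small usage sketch of the ``@PublicAPI`` surface of ``SampleBatch`` (the column names below are arbitrary):

    import numpy as np
    from ray.rllib.evaluation.sample_batch import SampleBatch

    batch = SampleBatch({"obs": np.array([1, 2, 3]),
                         "rewards": np.array([0.1, 0.2, 0.3])})
    print(batch.count)             # 3 rows
    head = batch.slice(0, 2)       # first two rows as a new SampleBatch
    both = batch.concat(batch)     # 6 rows; columns concatenated
    for row in head.rows():        # iterate rows as dicts
        print(row["obs"], row["rewards"])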

View file

@ -0,0 +1,173 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
def to_float_array(v):
arr = np.array(v)
if arr.dtype == np.float64:
return arr.astype(np.float32) # save some memory
return arr
@PublicAPI
class SampleBatchBuilder(object):
"""Util to build a SampleBatch incrementally.
For efficiency, SampleBatches hold values in column form (as arrays).
However, it is useful to add data one row (dict) at a time.
"""
@PublicAPI
def __init__(self):
self.buffers = collections.defaultdict(list)
self.count = 0
@PublicAPI
def add_values(self, **values):
"""Add the given dictionary (row) of values to this batch."""
for k, v in values.items():
self.buffers[k].append(v)
self.count += 1
@PublicAPI
def add_batch(self, batch):
"""Add the given batch of values to this batch."""
for k, column in batch.items():
self.buffers[k].extend(column)
self.count += batch.count
@PublicAPI
def build_and_reset(self):
"""Returns a sample batch including all previously added values."""
batch = SampleBatch(
{k: to_float_array(v)
for k, v in self.buffers.items()})
self.buffers.clear()
self.count = 0
return batch
@DeveloperAPI
class MultiAgentSampleBatchBuilder(object):
"""Util to build SampleBatches for each policy in a multi-agent env.
Input data is per-agent, while output data is per-policy. There is an M:N
mapping between agents and policies. We retain one local batch builder
per agent. When an agent is done, then its local batch is appended into the
corresponding policy batch for the agent's policy.
"""
def __init__(self, policy_map, clip_rewards):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy graph instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
"""
self.policy_map = policy_map
self.clip_rewards = clip_rewards
self.policy_builders = {
k: SampleBatchBuilder()
for k in policy_map.keys()
}
self.agent_builders = {}
self.agent_to_policy = {}
self.count = 0 # increment this manually
def total(self):
"""Returns summed number of steps across all agent buffers."""
return sum(p.count for p in self.policy_builders.values())
def has_pending_data(self):
"""Returns whether there is pending unprocessed data."""
return len(self.agent_builders) > 0
@DeveloperAPI
def add_values(self, agent_id, policy_id, **values):
"""Add the given dictionary (row) of values to this batch.
Arguments:
agent_id (obj): Unique id for the agent we are adding values for.
policy_id (obj): Unique id for policy controlling the agent.
values (dict): Row of values to add for this agent.
"""
if agent_id not in self.agent_builders:
self.agent_builders[agent_id] = SampleBatchBuilder()
self.agent_to_policy[agent_id] = policy_id
builder = self.agent_builders[agent_id]
builder.add_values(**values)
def postprocess_batch_so_far(self, episode):
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Arguments:
episode: current MultiAgentEpisode object or None
"""
# Materialize the batches so far
pre_batches = {}
for agent_id, builder in self.agent_builders.items():
pre_batches[agent_id] = (
self.policy_map[self.agent_to_policy[agent_id]],
builder.build_and_reset())
# Apply postprocessor
post_batches = {}
if self.clip_rewards:
for _, (_, pre_batch) in pre_batches.items():
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
for agent_id, (_, pre_batch) in pre_batches.items():
other_batches = pre_batches.copy()
del other_batches[agent_id]
policy = self.policy_map[self.agent_to_policy[agent_id]]
if any(pre_batch["dones"][:-1]) or len(set(
pre_batch["eps_id"])) > 1:
raise ValueError(
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches, episode)
# Append into policy batches and reset
for agent_id, post_batch in sorted(post_batches.items()):
self.policy_builders[self.agent_to_policy[agent_id]].add_batch(
post_batch)
self.agent_builders.clear()
self.agent_to_policy.clear()
@DeveloperAPI
def build_and_reset(self, episode):
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Arguments:
episode: current MultiAgentEpisode object or None
"""
self.postprocess_batch_so_far(episode)
policy_batches = {}
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
policy_batches[policy_id] = builder.build_and_reset()
old_count = self.count
self.count = 0
return MultiAgentBatch.wrap_as_needed(policy_batches, old_count)
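Correspondingly, a usage sketch for ``SampleBatchBuilder`` (the column names are arbitrary):

    from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

    builder = SampleBatchBuilder()
    for t in range(5):
        # Rows are added one dict at a time; columns are materialized on build.
        builder.add_values(t=t, obs=[0.0, 1.0], actions=0, rewards=1.0)
    batch = builder.build_and_reset()   # a SampleBatch with 5 rows
    print(batch.count)                  # 5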

View file

@ -10,9 +10,10 @@ import six.moves.queue as queue
import threading
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.sample_batch import MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@ -44,7 +45,7 @@ class SyncSampler(object):
pack=False,
tf_sess=None,
clip_actions=True):
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.base_env = BaseEnv.to_base_env(env)
self.unroll_length = unroll_length
self.horizon = horizon
self.policies = policies
@ -53,7 +54,7 @@ class SyncSampler(object):
self.obs_filters = obs_filters
self.extra_batches = queue.Queue()
self.rollout_provider = _env_runner(
self.async_vector_env, self.extra_batches.put, self.policies,
self.base_env, self.extra_batches.put, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, clip_rewards, clip_actions,
pack, callbacks, tf_sess)
@ -104,7 +105,7 @@ class AsyncSampler(threading.Thread):
for _, f in obs_filters.items():
assert getattr(f, "is_concurrent", False), \
"Observation Filter must support concurrent updates."
self.async_vector_env = AsyncVectorEnv.wrap_async(env)
self.base_env = BaseEnv.to_base_env(env)
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.extra_batches = queue.Queue()
@ -140,7 +141,7 @@ class AsyncSampler(threading.Thread):
extra_batches_putter = (
lambda x: self.extra_batches.put(x, timeout=600.0))
rollout_provider = _env_runner(
self.async_vector_env, extra_batches_putter, self.policies,
self.base_env, extra_batches_putter, self.policies,
self.policy_mapping_fn, self.unroll_length, self.horizon,
self.preprocessors, self.obs_filters, self.clip_rewards,
self.clip_actions, self.pack, self.callbacks, self.tf_sess)
@ -182,7 +183,7 @@ class AsyncSampler(threading.Thread):
return extra
def _env_runner(async_vector_env,
def _env_runner(base_env,
extra_batch_callback,
policies,
policy_mapping_fn,
@ -198,7 +199,7 @@ def _env_runner(async_vector_env,
"""This implements the common experience collection logic.
Args:
async_vector_env (AsyncVectorEnv): env implementing AsyncVectorEnv.
base_env (BaseEnv): env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (dict): Map of policy ids to PolicyGraph instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
@ -226,8 +227,7 @@ def _env_runner(async_vector_env,
try:
if not horizon:
horizon = (
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
horizon = (base_env.get_unwrapped()[0].spec.max_episode_steps)
except Exception:
logger.debug("no episode horizon specified, assuming inf")
if not horizon:
@ -248,7 +248,7 @@ def _env_runner(async_vector_env,
get_batch_builder, extra_batch_callback)
if callbacks.get("on_episode_start"):
callbacks["on_episode_start"]({
"env": async_vector_env,
"env": base_env,
"episode": episode
})
return episode
@ -258,11 +258,11 @@ def _env_runner(async_vector_env,
while True:
# Get observations from all ready agents
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
async_vector_env.poll()
base_env.poll()
# Process observations and prepare for policy evaluation
active_envs, to_eval, outputs = _process_observations(
async_vector_env, policies, batch_builder_pool, active_episodes,
base_env, policies, batch_builder_pool, active_episodes,
unfiltered_obs, rewards, dones, infos, off_policy_actions, horizon,
preprocessors, obs_filters, unroll_length, pack, callbacks)
for o in outputs:
@ -279,10 +279,10 @@ def _env_runner(async_vector_env,
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
async_vector_env.send_actions(actions_to_send)
base_env.send_actions(actions_to_send)
def _process_observations(async_vector_env, policies, batch_builder_pool,
def _process_observations(base_env, policies, batch_builder_pool,
active_episodes, unfiltered_obs, rewards, dones,
infos, off_policy_actions, horizon, preprocessors,
obs_filters, unroll_length, pack, callbacks):
@ -325,7 +325,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
# Check episode termination conditions
if dones[env_id]["__all__"] or episode.length >= horizon:
all_done = True
atari_metrics = _fetch_atari_metrics(async_vector_env)
atari_metrics = _fetch_atari_metrics(base_env)
if atari_metrics is not None:
for m in atari_metrics:
outputs.append(
@ -379,10 +379,7 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
# Invoke the step callback after the step is logged to the episode
if callbacks.get("on_episode_step"):
callbacks["on_episode_step"]({
"env": async_vector_env,
"episode": episode
})
callbacks["on_episode_step"]({"env": base_env, "episode": episode})
# Cut the batch if we're not packing multiple episodes into one,
# or if we've exceeded the requested batch size.
@ -399,11 +396,11 @@ def _process_observations(async_vector_env, policies, batch_builder_pool,
batch_builder_pool.append(episode.batch_builder)
if callbacks.get("on_episode_end"):
callbacks["on_episode_end"]({
"env": async_vector_env,
"env": base_env,
"episode": episode
})
del active_episodes[env_id]
resetted_obs = async_vector_env.try_reset(env_id)
resetted_obs = base_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
if horizon != float("inf"):
@ -526,12 +523,12 @@ def _process_policy_eval_results(to_eval, eval_results, active_episodes,
return actions_to_send
def _fetch_atari_metrics(async_vector_env):
def _fetch_atari_metrics(base_env):
"""Atari games have multiple logical episodes, one per life.
However for metrics reporting we count full episodes all lives included.
"""
unwrapped = async_vector_env.get_unwrapped()
unwrapped = base_env.get_unwrapped()
if not unwrapped:
return None
atari_out = []

View file

@ -10,13 +10,14 @@ import numpy as np
import ray
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
logger = logging.getLogger(__name__)
@DeveloperAPI
class TFPolicyGraph(PolicyGraph):
"""An agent policy and loss implemented in TensorFlow.
@ -41,6 +42,7 @@ class TFPolicyGraph(PolicyGraph):
SampleBatch({"action": ..., "advantages": ..., ...})
"""
@DeveloperAPI
def __init__(self,
observation_space,
action_space,
@ -208,36 +210,63 @@ class TFPolicyGraph(PolicyGraph):
saver = tf.train.Saver()
saver.save(self._sess, save_path)
@DeveloperAPI
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders.
Optional, only required to work with the multi-GPU optimizer."""
raise NotImplementedError
@DeveloperAPI
def extra_compute_action_feed_dict(self):
"""Extra dict to pass to the compute actions session run."""
return {}
@DeveloperAPI
def extra_compute_action_fetches(self):
"""Extra values to fetch and return from compute_actions()."""
return {} # e.g, value function
@DeveloperAPI
def extra_compute_grad_feed_dict(self):
"""Extra dict to pass to the compute gradients session run."""
return {} # e.g, kl_coeff
@DeveloperAPI
def extra_compute_grad_fetches(self):
"""Extra values to fetch and return from compute_gradients()."""
return {} # e.g, td error
@DeveloperAPI
def extra_apply_grad_feed_dict(self):
"""Extra dict to pass to the apply gradients session run."""
return {}
@DeveloperAPI
def extra_apply_grad_fetches(self):
"""Extra values to fetch and return from apply_gradients()."""
return {} # e.g., batch norm updates
@DeveloperAPI
def optimizer(self):
"""TF optimizer to use for policy optimization."""
return tf.train.AdamOptimizer()
@DeveloperAPI
def gradients(self, optimizer):
"""Override for custom gradient computation."""
return optimizer.compute_gradients(self._loss)
@DeveloperAPI
def _get_is_training_placeholder(self):
"""Get the placeholder for _is_training, i.e., for batch norm layers.
This can be called safely before __init__ has run.
"""
if not hasattr(self, "_is_training"):
self._is_training = tf.placeholder_with_default(False, ())
return self._is_training
def _extra_input_signature_def(self):
"""Extra input signatures to add when exporting tf model.
Inferred from extra_compute_action_feed_dict()
@ -258,14 +287,6 @@ class TFPolicyGraph(PolicyGraph):
for k in fetches.keys()
}
def optimizer(self):
"""TF optimizer to use for policy optimization."""
return tf.train.AdamOptimizer()
def gradients(self, optimizer):
"""Override for custom gradient computation."""
return optimizer.compute_gradients(self._loss)
def _build_signature_def(self):
"""Build signature def map for tensorflow SavedModelBuilder.
"""
@ -364,15 +385,6 @@ class TFPolicyGraph(PolicyGraph):
])
return fetches[1], fetches[2]
def _get_is_training_placeholder(self):
"""Get the placeholder for _is_training, i.e., for batch norm layers.
This can be called safely before __init__ has run.
"""
if not hasattr(self, "_is_training"):
self._is_training = tf.placeholder_with_default(False, ())
return self._is_training
def _get_loss_inputs_dict(self, batch):
feed_dict = {}
if self._batch_divisibility_req > 1:
@ -414,9 +426,11 @@ class TFPolicyGraph(PolicyGraph):
return feed_dict
@DeveloperAPI
class LearningRateSchedule(object):
"""Mixin for TFPolicyGraph that adds a learning rate schedule."""
@DeveloperAPI
def __init__(self, lr, lr_schedule):
self.cur_lr = tf.get_variable("lr", initializer=lr)
if lr_schedule is None:

View file

@ -7,7 +7,7 @@ from __future__ import print_function
import gym
import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatchBuilder
from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder
from ray.rllib.offline.json_writer import JsonWriter
if __name__ == "__main__":

View file

@ -7,12 +7,13 @@ import distutils.version
import tensorflow as tf
import numpy as np
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI
use_tf150_api = (distutils.version.LooseVersion(tf.VERSION) >=
distutils.version.LooseVersion("1.5.0"))
@DeveloperAPI
class ActionDistribution(object):
"""The policy action distribution of an agent.
@ -20,21 +21,26 @@ class ActionDistribution(object):
inputs (Tensor): The input vector to compute samples from.
"""
@DeveloperAPI
def __init__(self, inputs):
self.inputs = inputs
@DeveloperAPI
def logp(self, x):
"""The log-likelihood of the action distribution."""
raise NotImplementedError
@DeveloperAPI
def kl(self, other):
"""The KL-divergence between two action distributions."""
raise NotImplementedError
@DeveloperAPI
def entropy(self):
"""The entroy of the action distribution."""
raise NotImplementedError
@DeveloperAPI
def sample(self):
"""Draw a sample from the action distribution."""
raise NotImplementedError

View file

@ -17,6 +17,7 @@ from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.fcnet import FullyConnectedNetwork
from ray.rllib.models.visionnet import VisionNetwork
from ray.rllib.models.lstm import LSTM
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
logger = logging.getLogger(__name__)
@ -69,6 +70,7 @@ MODEL_DEFAULTS = {
# yapf: enable
@PublicAPI
class ModelCatalog(object):
"""Registry of models, preprocessors, and action distributions for envs.
@ -84,6 +86,7 @@ class ModelCatalog(object):
"""
@staticmethod
@DeveloperAPI
def get_action_dist(action_space, config, dist_type=None):
"""Returns action distribution class and size for the given action space.
@ -134,6 +137,7 @@ class ModelCatalog(object):
action_space, dist_type))
@staticmethod
@DeveloperAPI
def get_action_placeholder(action_space):
"""Returns an action placeholder that is consistent with the action space
@ -166,6 +170,7 @@ class ModelCatalog(object):
" not supported".format(action_space))
@staticmethod
@DeveloperAPI
def get_model(input_dict,
obs_space,
num_outputs,
@ -230,6 +235,7 @@ class ModelCatalog(object):
options)
@staticmethod
@DeveloperAPI
def get_torch_model(obs_space,
num_outputs,
options=None,
@ -276,6 +282,7 @@ class ModelCatalog(object):
return PyTorchFCNet(obs_space, num_outputs, options)
@staticmethod
@DeveloperAPI
def get_preprocessor(env, options=None):
"""Returns a suitable preprocessor for the given env.
@ -286,6 +293,7 @@ class ModelCatalog(object):
options)
@staticmethod
@DeveloperAPI
def get_preprocessor_for_space(observation_space, options=None):
"""Returns a suitable preprocessor for the given observation space.
@ -317,6 +325,7 @@ class ModelCatalog(object):
return prep
@staticmethod
@PublicAPI
def register_custom_preprocessor(preprocessor_name, preprocessor_class):
"""Register a custom preprocessor class by name.
@ -331,6 +340,7 @@ class ModelCatalog(object):
preprocessor_class)
@staticmethod
@PublicAPI
def register_custom_model(model_name, model_class):
"""Register a custom model class by name.

View file

@ -23,7 +23,7 @@ import tensorflow.contrib.rnn as rnn
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.model import Model
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, DeveloperAPI, PublicAPI
class LSTM(Model):
@ -91,6 +91,7 @@ class LSTM(Model):
return logits, last_layer
@PublicAPI
def add_time_dimension(padded_inputs, seq_lens):
"""Adds a time dimension to padded inputs.
@ -118,6 +119,7 @@ def add_time_dimension(padded_inputs, seq_lens):
return tf.reshape(padded_inputs, new_shape)
@DeveloperAPI
def chop_into_sequences(episode_ids,
agent_indices,
feature_columns,

View file

@ -9,8 +9,10 @@ import tensorflow as tf
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class Model(object):
"""Defines an abstract network model for use with RLlib.
@ -90,6 +92,7 @@ class Model(object):
"""
raise NotImplementedError
@PublicAPI
def _build_layers_v2(self, input_dict, num_outputs, options):
"""Define the layers of a custom model.
@ -122,6 +125,7 @@ class Model(object):
"""
raise NotImplementedError
@PublicAPI
def value_function(self):
"""Builds the value function output.
@ -134,6 +138,7 @@ class Model(object):
return tf.reshape(
linear(self.last_layer, 1, "value", normc_initializer(1.0)), [-1])
@PublicAPI
def loss(self):
"""Builds any built-in (self-supervised) loss for the model.

View file

@ -8,7 +8,7 @@ import logging
import numpy as np
import gym
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
ATARI_OBS_SHAPE = (210, 160, 3)
ATARI_RAM_OBS_SHAPE = (128, )
@ -16,6 +16,7 @@ ATARI_RAM_OBS_SHAPE = (128, )
logger = logging.getLogger(__name__)
@PublicAPI
class Preprocessor(object):
"""Defines an abstract observation preprocessor function.
@ -23,25 +24,30 @@ class Preprocessor(object):
shape (obj): Shape of the preprocessed output.
"""
@PublicAPI
def __init__(self, obs_space, options=None):
legacy_patch_shapes(obs_space)
self._obs_space = obs_space
self._options = options or {}
self.shape = self._init_shape(obs_space, options)
@PublicAPI
def _init_shape(self, obs_space, options):
"""Returns the shape after preprocessing."""
raise NotImplementedError
@PublicAPI
def transform(self, observation):
"""Returns the preprocessed observation."""
raise NotImplementedError
@property
@PublicAPI
def size(self):
return int(np.product(self.shape))
@property
@PublicAPI
def observation_space(self):
obs_space = gym.spaces.Box(-1.0, 1.0, self.shape, dtype=np.float32)
# Stash the unwrapped space so that we can unwrap dict and tuple spaces
@ -186,6 +192,7 @@ class DictFlatteningPreprocessor(Preprocessor):
])
@PublicAPI
def get_preprocessor(space):
"""Returns an appropriate preprocessor class for the given space."""

View file

@ -6,8 +6,10 @@ import torch
import torch.nn as nn
from ray.rllib.models.model import _restore_original_dimensions
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class TorchModel(nn.Module):
"""Defines an abstract network model for use with RLlib / PyTorch."""
@ -25,6 +27,7 @@ class TorchModel(nn.Module):
self.num_outputs = num_outputs
self.options = options
@PublicAPI
def forward(self, input_dict, hidden_state):
"""Wraps _forward() to unpack flattened Dict and Tuple observations."""
input_dict["obs"] = input_dict["obs"].float() # TODO(ekl): avoid cast
@ -33,10 +36,12 @@ class TorchModel(nn.Module):
outputs, features, vf, h = self._forward(input_dict, hidden_state)
return outputs, features, vf, h
@PublicAPI
def state_init(self):
"""Returns a list of initial hidden state tensors, if any."""
return []
@PublicAPI
def _forward(self, input_dict, hidden_state):
"""Forward pass for the model.

View file

@ -3,11 +3,14 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class InputReader(object):
"""Input object for loading experiences in policy evaluation."""
@PublicAPI
def next(self):
"""Return the next batch of experiences read.

View file

@ -5,8 +5,10 @@ from __future__ import print_function
import os
from ray.rllib.offline.input_reader import SamplerInput
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class IOContext(object):
"""Attributes to pass to input / output class constructors.
@ -20,6 +22,7 @@ class IOContext(object):
evaluator (PolicyEvaluator): policy evaluator object reference.
"""
@PublicAPI
def __init__(self,
log_dir=None,
config=None,
@ -30,5 +33,6 @@ class IOContext(object):
self.worker_index = worker_index
self.evaluator = evaluator
@PublicAPI
def default_sampler_input(self):
return SamplerInput(self.evaluator.sampler)

View file

@ -19,17 +19,19 @@ from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import unpack_if_needed
logger = logging.getLogger(__name__)
@PublicAPI
class JsonReader(InputReader):
"""Reader object that loads experiences from JSON file chunks.
The input files will be read from in a random order."""
@PublicAPI
def __init__(self, inputs, ioctx=None):
"""Initialize a JsonReader.

View file

@ -18,15 +18,17 @@ except ImportError:
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack
logger = logging.getLogger(__name__)
@PublicAPI
class JsonWriter(OutputWriter):
"""Writer object that saves experiences in JSON file chunks."""
@PublicAPI
def __init__(self,
path,
ioctx=None,

View file

@ -6,9 +6,10 @@ import numpy as np
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import override, PublicAPI
@PublicAPI
class MixedInput(InputReader):
"""Mixes input from a number of other input sources.
@ -20,6 +21,7 @@ class MixedInput(InputReader):
}, ioctx)
"""
@PublicAPI
def __init__(self, dist, ioctx):
"""Initialize a MixedInput.

View file

@ -3,11 +3,14 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class OutputWriter(object):
"""Writer object for saving experiences from policy evaluation."""
@PublicAPI
def write(self, sample_batch):
"""Save a batch of experiences.

View file

@ -5,12 +5,14 @@ from __future__ import print_function
import logging
import ray
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
logger = logging.getLogger(__name__)
@DeveloperAPI
class PolicyOptimizer(object):
"""Policy optimizers encapsulate distributed RL optimization strategies.
@ -36,6 +38,7 @@ class PolicyOptimizer(object):
evaluators created by this optimizer.
"""
@DeveloperAPI
def __init__(self, local_evaluator, remote_evaluators=None, config=None):
"""Create an optimizer instance.
@ -59,11 +62,13 @@ class PolicyOptimizer(object):
logger.debug("Created policy optimizer with {}: {}".format(
config, self))
@DeveloperAPI
def _init(self):
"""Subclasses should prefer overriding this instead of __init__."""
raise NotImplementedError
@DeveloperAPI
def step(self):
"""Takes a logical optimization step.
@ -77,6 +82,7 @@ class PolicyOptimizer(object):
raise NotImplementedError
@DeveloperAPI
def stats(self):
"""Returns a dictionary of internal performance statistics."""
@ -85,21 +91,25 @@ class PolicyOptimizer(object):
"num_steps_sampled": self.num_steps_sampled,
}
@DeveloperAPI
def save(self):
"""Returns a serializable object representing the optimizer state."""
return [self.num_steps_trained, self.num_steps_sampled]
@DeveloperAPI
def restore(self, data):
"""Restores optimizer state from the given data object."""
self.num_steps_trained = data[0]
self.num_steps_sampled = data[1]
@DeveloperAPI
def stop(self):
"""Release any resources used by this optimizer."""
pass
@DeveloperAPI
def collect_metrics(self,
timeout_seconds,
min_history=100,
@ -132,6 +142,7 @@ class PolicyOptimizer(object):
res.update(info=self.stats())
return res
@DeveloperAPI
def foreach_evaluator(self, func):
"""Apply the given function to each evaluator instance."""
@ -140,6 +151,7 @@ class PolicyOptimizer(object):
[ev.apply.remote(func) for ev in self.remote_evaluators])
return local_result + remote_results
@DeveloperAPI
def foreach_evaluator_with_index(self, func):
"""Apply the given function to each evaluator instance.

View file

@ -7,11 +7,14 @@ import random
import sys
from ray.rllib.optimizers.segment_tree import SumSegmentTree, MinSegmentTree
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.window_stat import WindowStat
@DeveloperAPI
class ReplayBuffer(object):
@DeveloperAPI
def __init__(self, size):
"""Create Prioritized Replay buffer.
@ -34,6 +37,7 @@ class ReplayBuffer(object):
def __len__(self):
return len(self._storage)
@DeveloperAPI
def add(self, obs_t, action, reward, obs_tp1, done, weight):
data = (obs_t, action, reward, obs_tp1, done)
self._num_added += 1
@ -64,6 +68,7 @@ class ReplayBuffer(object):
return (np.array(obses_t), np.array(actions), np.array(rewards),
np.array(obses_tp1), np.array(dones))
@DeveloperAPI
def sample(self, batch_size):
"""Sample a batch of experiences.
@ -93,6 +98,7 @@ class ReplayBuffer(object):
self._num_sampled += batch_size
return self._encode_sample(idxes)
@DeveloperAPI
def stats(self, debug=False):
data = {
"added_count": self._num_added,
@ -105,7 +111,9 @@ class ReplayBuffer(object):
return data
@DeveloperAPI
class PrioritizedReplayBuffer(ReplayBuffer):
@DeveloperAPI
def __init__(self, size, alpha):
"""Create Prioritized Replay buffer.
@ -135,6 +143,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_priority = 1.0
self._prio_change_stats = WindowStat("reprio", 1000)
@DeveloperAPI
def add(self, obs_t, action, reward, obs_tp1, done, weight):
"""See ReplayBuffer.store_effect"""
@ -155,6 +164,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
res.append(idx)
return res
@DeveloperAPI
def sample(self, batch_size, beta):
"""Sample a batch of experiences.
@ -208,6 +218,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
@DeveloperAPI
def update_priorities(self, idxes, priorities):
"""Update priorities of sampled transitions.
@ -234,6 +245,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._max_priority = max(self._max_priority, priority)
@DeveloperAPI
def stats(self, debug=False):
parent = ReplayBuffer.stats(self, debug)
if debug:
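Putting the prioritized buffer's DeveloperAPI methods together, a usage sketch (the module path is assumed, and passing weight=None is assumed to fall back to the buffer's current max priority):

    import numpy as np
    from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer

    buf = PrioritizedReplayBuffer(1000, alpha=0.6)
    for i in range(100):
        buf.add(np.zeros(4), 0, 1.0, np.zeros(4), False, weight=None)

    # sample() returns the transition arrays plus importance weights and the
    # indices needed to update priorities afterwards.
    (obses_t, actions, rewards, obses_tp1, dones,
     weights, idxes) = buf.sample(32, beta=0.4)
    new_priorities = np.abs(np.random.randn(32)) + 1e-6
    buf.update_priorities(idxes, new_priorities)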

View file

@ -17,7 +17,7 @@ from ray.rllib.test.test_policy_evaluator import MockEnv, MockEnv2, \
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.env.async_vector_env import _MultiAgentEnvToAsync
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.tune.registry import register_env
@ -176,13 +176,13 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(done["__all__"], True)
def testNoResetUntilPoll(self):
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 1)
env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 1)
self.assertFalse(env.get_unwrapped()[0].resetted)
env.poll()
self.assertTrue(env.get_unwrapped()[0].resetted)
def testVectorizeBasic(self):
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
env = _MultiAgentEnvToBaseEnv(lambda v: BasicMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}})
self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}})
@ -258,7 +258,7 @@ class TestMultiAgentEnv(unittest.TestCase):
})
def testVectorizeRoundRobin(self):
env = _MultiAgentEnvToAsync(lambda v: RoundRobinMultiAgent(2), [], 2)
env = _MultiAgentEnvToBaseEnv(lambda v: RoundRobinMultiAgent(2), [], 2)
obs, rew, dones, _, _ = env.poll()
self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}})
self.assertEqual(rew, {0: {0: None}, 1: {0: None}})

View file

@ -16,7 +16,7 @@ from ray.rllib.agents.a3c import A2CAgent
from ray.rllib.agents.pg import PGAgent
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.async_vector_env import AsyncVectorEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.model import Model
@ -303,8 +303,7 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedDict(lambda _: SimpleServing(NestedDictEnv()))
def testNestedDictAsync(self):
self.doTestNestedDict(
lambda _: AsyncVectorEnv.wrap_async(NestedDictEnv()))
self.doTestNestedDict(lambda _: BaseEnv.to_base_env(NestedDictEnv()))
def testNestedTupleGym(self):
self.doTestNestedTuple(lambda _: NestedTupleEnv())
@ -317,8 +316,7 @@ class NestedSpacesTest(unittest.TestCase):
self.doTestNestedTuple(lambda _: SimpleServing(NestedTupleEnv()))
def testNestedTupleAsync(self):
self.doTestNestedTuple(
lambda _: AsyncVectorEnv.wrap_async(NestedTupleEnv()))
self.doTestNestedTuple(lambda _: BaseEnv.to_base_env(NestedTupleEnv()))
def testMultiAgentComplexSpaces(self):
ModelCatalog.register_custom_model("dict_spy", DictSpyModel)

View file

@ -18,3 +18,36 @@ def override(cls):
return method
return check_override
def PublicAPI(obj):
"""Annotation for documenting public APIs.
Public APIs are classes and methods exposed to end users of RLlib. You
can expect these APIs to remain stable across RLlib releases.
Subclasses that inherit from a ``@PublicAPI`` base class can be
assumed part of the RLlib public API as well (e.g., all agent classes
are in the public API because Agent is ``@PublicAPI``).
In addition, you can assume all agent configurations are part of their
public API as well.
"""
return obj
def DeveloperAPI(obj):
"""Annotation for documenting developer APIs.
Developer APIs are classes and methods explicitly exposed to developers
for the purposes of building custom algorithms or advanced training
strategies on top of RLlib internals. You can generally expect these APIs
to remain stable aside from minor changes (though less stable than public APIs).
Subclasses that inherit from a ``@DeveloperAPI`` base class can be
assumed part of the RLlib developer API as well (e.g., all policy
optimizers are developer API because PolicyOptimizer is ``@DeveloperAPI``).
"""
return obj
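Since both annotations simply return the decorated object, applying them has no runtime cost; a hedged sketch of the intended usage (the class and methods below are hypothetical):

    from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI

    @PublicAPI
    class MyAgentFacade(object):
        """End-user entry point; expected to stay stable across releases."""

        @PublicAPI
        def train(self):
            return {"episode_reward_mean": 0.0}

        @DeveloperAPI
        def _custom_training_hook(self):
            # For algorithm developers; may see minor changes between releases.
            pass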

View file

@ -9,6 +9,8 @@ import numpy as np
import pyarrow
from six import string_types
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
try:
@ -21,6 +23,7 @@ except ImportError:
LZ4_ENABLED = False
@DeveloperAPI
def pack(data):
if LZ4_ENABLED:
data = pyarrow.serialize(data).to_buffer().to_pybytes()
@ -31,12 +34,14 @@ def pack(data):
return data
@DeveloperAPI
def pack_if_needed(data):
if isinstance(data, np.ndarray):
data = pack(data)
return data
@DeveloperAPI
def unpack(data):
if LZ4_ENABLED:
data = base64.b64decode(data)
@ -45,6 +50,7 @@ def unpack(data):
return data
@DeveloperAPI
def unpack_if_needed(data):
if isinstance(data, bytes) or isinstance(data, string_types):
data = unpack(data)
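A quick round-trip example (if lz4 is not installed, these helpers pass the data through unchanged, so the round trip still works):

    import numpy as np
    from ray.rllib.utils.compression import pack, unpack

    arr = np.arange(12, dtype=np.float32).reshape(3, 4)
    blob = pack(arr)         # serialized (and LZ4-compressed, if available)
    restored = unpack(blob)  # recovers the original array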

View file

@ -2,7 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class UnsupportedSpaceException(Exception):
"""Error for an unsupported action or observation space."""
pass

View file

@ -3,14 +3,17 @@ from __future__ import division
from __future__ import print_function
import ray
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class FilterManager(object):
"""Manages filters and coordination across remote evaluators that expose
`get_filters` and `sync_filters`.
"""
@staticmethod
@DeveloperAPI
def synchronize(local_filters, remotes, update_remote=True):
"""Aggregates all filters from remote evaluators.

View file

@ -5,6 +5,8 @@ from __future__ import print_function
import logging
import pickle
from ray.rllib.utils.annotations import PublicAPI
logger = logging.getLogger(__name__)
try:
@ -16,6 +18,7 @@ except ImportError:
" the client side.")
@PublicAPI
class PolicyClient(object):
"""REST client to interact with a RLlib policy server."""
@ -25,9 +28,11 @@ class PolicyClient(object):
LOG_RETURNS = "LOG_RETURNS"
END_EPISODE = "END_EPISODE"
@PublicAPI
def __init__(self, address):
self._address = address
@PublicAPI
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
@ -47,6 +52,7 @@ class PolicyClient(object):
"training_enabled": training_enabled,
})["episode_id"]
@PublicAPI
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
@ -63,6 +69,7 @@ class PolicyClient(object):
"episode_id": episode_id,
})["action"]
@PublicAPI
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
@ -78,6 +85,7 @@ class PolicyClient(object):
"episode_id": episode_id,
})
@PublicAPI
def log_returns(self, episode_id, reward, info=None):
"""Record returns from the environment.
@ -96,6 +104,7 @@ class PolicyClient(object):
"episode_id": episode_id,
})
@PublicAPI
def end_episode(self, episode_id, observation):
"""Record the end of an episode.

View file

@ -6,6 +6,7 @@ import pickle
import sys
import traceback
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.policy_client import PolicyClient
if sys.version_info[0] == 2:
@ -17,6 +18,7 @@ elif sys.version_info[0] == 3:
from socketserver import ThreadingMixIn
@PublicAPI
class PolicyServer(ThreadingMixIn, HTTPServer):
"""REST server than can be launched from a ExternalEnv.
@ -50,6 +52,7 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
>>> client.log_returns(eps_id, reward)
"""
@PublicAPI
def __init__(self, external_env, address, port):
handler = _make_handler(external_env)
HTTPServer.__init__(self, (address, port), handler)
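On the serving side, the server is usually started from inside an ExternalEnv's run() method. A sketch (the spaces and port are illustrative, the PolicyServer module path is assumed, and the ExternalEnv constructor is assumed to take the action space before the observation space):

    import gym
    from ray.rllib.env.external_env import ExternalEnv
    from ray.rllib.utils.policy_server import PolicyServer  # path assumed

    class CartpoleServing(ExternalEnv):
        """ExternalEnv that serves policy queries over HTTP."""

        def __init__(self):
            ExternalEnv.__init__(
                self,
                gym.spaces.Discrete(2),
                gym.spaces.Box(low=-10.0, high=10.0, shape=(4, )))

        def run(self):
            server = PolicyServer(self, "localhost", 9900)
            server.serve_forever()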