[RLlib] Unity3D integration (n Unity3D clients vs learning server). (#8590)

Sven Mika 2020-05-30 22:48:34 +02:00 committed by GitHub
parent 016337d4eb
commit d8a081a185
31 changed files with 870 additions and 191 deletions

Binary image file added (264 KiB); not shown.

View file

@ -36,7 +36,7 @@ You can pass either a string name or a Python class to specify an environment. B
while True:
print(trainer.train())
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` parameter and return an env instance:
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` (dict) parameter and return an env instance:
.. code-block:: python
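# A minimal sketch of such a creator function ("MyEnv" is a hypothetical
# env class used here for illustration). The function receives the
# env_config dict, must return an env instance, and is registered under
# a string name that can then be passed to a Trainer as its "env".
from ray.rllib.agents import ppo
from ray.tune.registry import register_env

def env_creator(env_config):
    return MyEnv(env_config)  # return an env instance

register_env("my_env", env_creator)
trainer = ppo.PPOTrainer(env="my_env")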
@ -113,19 +113,20 @@ When using remote envs, you can control the batching level for inference with ``
Multi-Agent and Hierarchical
----------------------------
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure:
A multi-agent environment is one in which multiple acting entities are present per step, e.g., in a traffic simulation there may be multiple "car" and "traffic light" agents in the environment. The multi-agent model in RLlib is as follows: (1) as a user, you define the set of policies available up front, and (2) you provide a function that maps agent ids to policy ids. This is summarized by the figure below:
.. image:: multi-agent.svg
The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-project/ray/blob/master/rllib/env/multi_agent_env.py>`__ interface, which can returns observations and rewards from multiple ready agents per step:
The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-project/ray/blob/master/rllib/env/multi_agent_env.py>`__ interface, which can return observations and rewards from multiple ready agents per step:
.. code-block:: python
# Example: using a multi-agent env
> env = MultiAgentTrafficEnv(num_cars=20, num_traffic_lights=5)
# Observations are a dict mapping agent names to their obs. Not all agents
# may be present in the dict in each time step.
# Observations are a dict mapping agent names to their obs. Only those
# agents that need to act in the next call to `step()` will be
# present in the returned observation dict.
> print(env.reset())
{
"car_1": [[...]],
@ -133,14 +134,15 @@ The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-
"traffic_light_1": [[...]],
}
# Actions should be provided for each agent that returned an observation.
> new_obs, rewards, dones, infos = env.step(actions={"car_1": ..., "car_2": ...})
# In the following call to `step()`, actions should be provided for each
# agent that returned an observation in the previous step:
> new_obs, rewards, dones, infos = env.step(actions={"car_1": ..., "car_2": ..., "traffic_light_1": ...})
# Similarly, new_obs, rewards, dones, etc. also become dicts
> print(rewards)
{"car_1": 3, "car_2": -1, "traffic_light_1": 0}
# Individual agents can early exit; env is done when "__all__" = True
# Individual agents can exit early; the entire episode is done when "__all__" = True
> print(dones)
{"car_2": True, "__all__": False}
@ -305,9 +307,14 @@ See this file for a runnable example: `hierarchical_training.py <https://github.
External Agents and Applications
--------------------------------
In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. This case also naturally arises with **external simulators** that run independently outside the control of RLlib, but may still want to leverage RLlib for training.
In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. This case also naturally arises with **external simulators** (e.g. Unity3D, other game engines, or the Gazebo robotics simulator) that run independently outside the control of RLlib, but may still want to leverage RLlib for training.
RLlib provides the `ExternalEnv <https://github.com/ray-project/ray/blob/master/rllib/env/external_env.py>`__ class for this purpose. Unlike other envs, ExternalEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and reports rewards via ``self.log_returns()``. This can be done for multiple concurrent episodes as well.
.. figure:: images/rllib-training-inside-a-unity3d-env.png
:scale: 75 %
A Unity3D soccer game being learnt by RLlib via the ExternalEnv API.
RLlib provides the `ExternalEnv <https://github.com/ray-project/ray/blob/master/rllib/env/external_env.py>`__ class for this purpose. Unlike other envs, ExternalEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and report rewards, done-dicts, and infos via ``self.log_returns()``. This can be done for multiple concurrent episodes as well.
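Below is a rough sketch (not part of this change; the wrapped ``env`` object and the ``MyExternalEnv`` name are illustrative assumptions) of what an ``ExternalEnv`` subclass's ``run()`` loop can look like:
.. code-block:: python

    from ray.rllib.env.external_env import ExternalEnv

    class MyExternalEnv(ExternalEnv):
        def __init__(self, env):
            # `env` is some gym-style environment driven by this thread.
            super().__init__(env.action_space, env.observation_space)
            self.env = env

        def run(self):
            # The thread drives episodes itself; RLlib is only asked for
            # actions and receives rewards via log_returns().
            while True:
                eid = self.start_episode()
                obs = self.env.reset()
                done = False
                while not done:
                    action = self.get_action(eid, obs)
                    obs, reward, done, info = self.env.step(action)
                    self.log_returns(eid, reward, info=info)
                self.end_episode(eid, obs)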
Logging off-policy actions
~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -330,8 +337,8 @@ You can configure any Trainer to launch a policy server with the following confi
trainer_config = {
# An environment class is still required, but it doesn't need to be runnable.
# You only need to define its action and observation space attributes.
# See examples/serving/unity3d_server.py for an example using a RandomMultiAgentEnv stub.
"env": YOUR_ENV_STUB,
# Use the policy server to generate experiences.
"input": (
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
@ -360,7 +367,13 @@ To understand the difference between standard envs, external envs, and connectin
.. https://docs.google.com/drawings/d/1hJvT9bVGHVrGTbnCZK29BYQIcYNRbZ4Dr6FOPMJDjUs/edit
.. image:: rllib-external.svg
Try it yourself by launching a `cartpole_server.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_server.py>`__, and connecting to it with any number of clients (`cartpole_client.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_client.py>`__):
Try it yourself by either launching a
`simple CartPole server <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_server.py>`__ (see below) and connecting any number of clients
(`cartpole_client.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_client.py>`__) to it, or
by running a `Unity3D learning server <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/unity3d_server.py>`__
against distributed Unity game engines in the cloud.
CartPole Example:
.. code-block:: bash
@ -391,9 +404,9 @@ Try it yourself by launching a `cartpole_server.py <https://github.com/ray-proje
Total reward: 200.0
...
For the best performance, when possible we recommend using ``inference_mode="local"`` when possible.
For the best performance, we recommend using ``inference_mode="local"`` when possible.
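For reference, a minimal client-side loop (closely following ``cartpole_client.py``; the ``gym`` env and server address are assumptions here) looks like this:
.. code-block:: python

    import gym
    from ray.rllib.env.policy_client import PolicyClient

    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:9900", inference_mode="local")

    eid = client.start_episode(training_enabled=True)
    obs = env.reset()
    while True:
        action = client.get_action(eid, obs)
        obs, reward, done, info = env.step(action)
        client.log_returns(eid, reward, info=info)
        if done:
            client.end_episode(eid, obs)
            obs = env.reset()
            eid = client.start_episode(training_enabled=True)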
Advanced Integrations
---------------------
For more complex / high-performance environment integrations, you can instead extend the low-level `BaseEnv <https://github.com/ray-project/ray/blob/master/rllib/env/base_env.py>`__ class. This low-level API models multiple agents executing asynchronously in multiple environments. A call to ``BaseEnv:poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents are sent back via ``BaseEnv:send_actions()``. BaseEnv is used to implement all the other env types in RLlib, so it offers a superset of their functionality. For example, ``BaseEnv`` is used to implement dynamic batching of observations for inference over `multiple simulator actors <https://github.com/ray-project/ray/blob/master/rllib/env/remote_vector_env.py>`__.
For more complex / high-performance environment integrations, you can instead extend the low-level `BaseEnv <https://github.com/ray-project/ray/blob/master/rllib/env/base_env.py>`__ class. This low-level API models multiple agents executing asynchronously in multiple environments. A call to ``BaseEnv:poll()`` returns observations from ready agents keyed by 1) their environment, then 2) agent ids. Actions for those agents are sent back via ``BaseEnv:send_actions()``. BaseEnv is used to implement all the other env types in RLlib, so it offers a superset of their functionality. For example, ``BaseEnv`` is used to implement dynamic batching of observations for inference over `multiple simulator actors <https://github.com/ray-project/ray/blob/master/rllib/env/remote_vector_env.py>`__.
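As a rough illustration (``base_env`` and ``compute_action`` are hypothetical placeholders), the poll/send_actions cycle looks like:
.. code-block:: python

    while True:
        # Nested dicts keyed by env id, then agent id.
        obs, rewards, dones, infos, off_policy_actions = base_env.poll()
        actions = {
            env_id: {
                agent_id: compute_action(agent_obs)
                for agent_id, agent_obs in agent_obs_dict.items()
            }
            for env_id, agent_obs_dict in obs.items()
        }
        base_env.send_actions(actions)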

View file

@ -36,12 +36,18 @@ Training Workflows
Custom Envs and Models
----------------------
- `Local Unity3D multi-agent environment example <https://github.com/ray-project/ray/tree/master/rllib/examples/unity3d_env_local.py>`__:
Example of how to set up an RLlib Trainer against a locally running Unity3D editor instance to
learn any Unity3D game (including support for multi-agent setups).
Use this example to try things out and watch the game and the learning progress live in the editor.
If you provide a compiled game binary, this example can also run in a distributed fashion with `num_workers > 0`.
For a more heavyweight, distributed, cloud-based example, see `Unity3D client/server`_ below.
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__:
Example of defining and registering a gym env and model for use with RLlib.
- `Custom Keras model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py>`__:
Example of using a custom Keras model.
- `Custom Keras RNN model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_rnn_model.py>`__:
Example of using a custom Keras RNN model.
- `Custom Keras RNN model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_rnn_model.py>`__:
Example of using a custom Keras or PyTorch RNN model.
- `Registering a custom model with supervised loss <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_loss.py>`__:
Example of defining and registering a custom model with a supervised loss.
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/rllib/tests/test_env_with_subprocess.py>`__:
@ -55,7 +61,16 @@ Custom Envs and Models
Serving and Offline
-------------------
- `CartPole server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving>`__:
.. _Unity3D client/server:
- `Unity3D client/server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving/unity3d_server.py>`__:
Example of how to set up n distributed Unity3D (compiled) games in the cloud that function as data-collecting
clients against a central RLlib Policy server that learns how to play the game.
The n distributed clients could themselves be servers for external/human players, allowing control
to rest fully in the hands of the Unity entities instead of RLlib.
Note: Uses Unity's MLAgents SDK (>=1.0) and supports all provided MLAgents example games and multi-agent setups.
- `CartPole client/server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving/cartpole_server.py>`__:
Example of online serving of predictions for a simple CartPole policy.
- `Saving experiences <https://github.com/ray-project/ray/blob/master/rllib/examples/saving_experiences.py>`__:
Example of how to externally generate experience batches in RLlib-compatible format.

View file

@ -983,6 +983,29 @@ py_test(
]
)
# --------------------------------------------------------------------
# Env tests
# rllib/env/
#
# Tag: env
# --------------------------------------------------------------------
sh_test(
name = "env/tests/test_local_inference",
tags = ["env"],
size = "medium",
srcs = ["env/tests/test_local_inference.sh"],
data = glob(["examples/serving/*.py"]),
)
sh_test(
name = "env/tests/test_remote_inference",
tags = ["env"],
size = "medium",
srcs = ["env/tests/test_remote_inference.sh"],
data = glob(["examples/serving/*.py"]),
)
# --------------------------------------------------------------------
# Models and Distributions
# rllib/models/
@ -1692,7 +1715,7 @@ py_test(
name = "examples/multi_agent_cartpole_torch",
main = "examples/multi_agent_cartpole.py",
tags = ["examples", "examples_M"],
size = "small",
size = "medium",
srcs = ["examples/multi_agent_cartpole.py"],
args = ["--as-test", "--torch", "--stop-reward=70.0", "--num-cpus=4"]
)
@ -1822,22 +1845,6 @@ py_test(
args = ["--as-test", "--torch"],
)
sh_test(
name = "examples/serving/test_local_inference",
tags = ["examples", "examples_S", "exclusive"],
size = "medium",
srcs = ["examples/serving/test_local_inference.sh"],
data = glob(["examples/serving/*.py"]),
)
sh_test(
name = "examples/serving/test_remote_inference",
tags = ["examples", "examples_S", "exclusive"],
size = "medium",
srcs = ["examples/serving/test_remote_inference.sh"],
data = glob(["examples/serving/*.py"]),
)
py_test(
name = "examples/two_trainer_workflow_tf",
main = "examples/two_trainer_workflow.py",

View file

@ -26,7 +26,7 @@ class TestDQN(unittest.TestCase):
num_iterations = 1
for fw in framework_iterator(config):
# double-dueling DQN.
# Double-dueling DQN.
plain_config = config.copy()
trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
for i in range(num_iterations):

View file

@ -68,7 +68,8 @@ class OnlineLinearRegression(nn.Module):
return batch_dots.sqrt()
def forward(self, x, sample_theta=False):
""" Predict the scores on input batch using the underlying linear model
""" Predict scores on input batch using the underlying linear model.
Args:
x (torch.Tensor): Input feature tensor of shape
(batch_size, feature_dim)

rllib/env/base_env.py vendored (29 changed lines)
View file

@ -1,7 +1,7 @@
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.utils.annotations import override, PublicAPI
ASYNC_RESET_RETURN = "async_reset_return"
@ -99,16 +99,13 @@ class BaseEnv:
make_env=make_env,
existing_envs=[env],
num_envs=num_envs)
elif isinstance(env, ExternalMultiAgentEnv):
if num_envs != 1:
raise ValueError(
"ExternalMultiAgentEnv does not currently support "
"num_envs > 1.")
env = _ExternalEnvToBaseEnv(env, multiagent=True)
elif isinstance(env, ExternalEnv):
if num_envs != 1:
raise ValueError(
"ExternalEnv does not currently support num_envs > 1.")
"External(MultiAgent)Env does not currently support "
"num_envs > 1. One way of solving this would be to "
"treat your Env as a MultiAgentEnv hosting only one "
"type of agent but with several copies.")
env = _ExternalEnvToBaseEnv(env)
elif isinstance(env, VectorEnv):
env = _VectorEnvToBaseEnv(env)
@ -166,12 +163,16 @@ class BaseEnv:
raise NotImplementedError
@PublicAPI
def try_reset(self, env_id):
"""Attempt to reset the env with the given id.
def try_reset(self, env_id=None):
"""Attempt to reset the sub-env with the given id or all sub-envs.
If the environment does not support synchronous reset, None can be
returned here.
Args:
env_id (Optional[int]): The sub-env ID if applicable. If None,
reset the entire Env (i.e. all sub-envs).
Returns:
obs (dict|None): The reset observation, or None if not supported.
"""
@ -206,10 +207,10 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
class _ExternalEnvToBaseEnv(BaseEnv):
"""Internal adapter of ExternalEnv to BaseEnv."""
def __init__(self, external_env, preprocessor=None, multiagent=False):
def __init__(self, external_env, preprocessor=None):
self.external_env = external_env
self.prep = preprocessor
self.multiagent = multiagent
self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
self.action_space = external_env.action_space
if preprocessor:
self.observation_space = preprocessor.observation_space
@ -262,8 +263,8 @@ class _ExternalEnvToBaseEnv(BaseEnv):
if "off_policy_action" in data:
off_policy_actions[eid] = data["off_policy_action"]
if self.multiagent:
# ensure a consistent set of keys
# rely on all_obs having all possible keys for now
# Ensure a consistent set of keys
# rely on all_obs having all possible keys for now.
for eid, eid_dict in all_obs.items():
for agent_id in eid_dict.keys():

View file

@ -32,16 +32,14 @@ class ExternalEnv(threading.Thread):
>>> register_env("my_env", lambda config: YourExternalEnv(config))
>>> trainer = DQNTrainer(env="my_env")
>>> while True:
print(trainer.train())
>>> print(trainer.train())
"""
@PublicAPI
def __init__(self, action_space, observation_space, max_concurrent=100):
"""Initialize an external env.
"""Initializes an external env.
ExternalEnv subclasses must call this during their __init__.
Arguments:
Args:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
max_concurrent (int): Max number of active episodes to allow at
@ -49,6 +47,7 @@ class ExternalEnv(threading.Thread):
"""
threading.Thread.__init__(self)
self.daemon = True
self.action_space = action_space
self.observation_space = observation_space
@ -78,9 +77,9 @@ class ExternalEnv(threading.Thread):
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned.
Args:
episode_id (Optional[str]): Unique string id for the episode or
None for it to be auto-assigned and returned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.
@ -108,7 +107,7 @@ class ExternalEnv(threading.Thread):
def get_action(self, episode_id, observation):
"""Record an observation and get the on-policy action.
Arguments:
Args:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
@ -123,7 +122,7 @@ class ExternalEnv(threading.Thread):
def log_action(self, episode_id, observation, action):
"""Record an observation and (off-policy) action taken.
Arguments:
Args:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
action (obj): Action for the observation.
@ -140,7 +139,7 @@ class ExternalEnv(threading.Thread):
episode. Rewards accumulate until the next action. If no reward is
logged before the next action, a reward of 0.0 is assumed.
Arguments:
Args:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Optional info dict.
@ -156,7 +155,7 @@ class ExternalEnv(threading.Thread):
def end_episode(self, episode_id, observation):
"""Record the end of an episode.
Arguments:
Args:
episode_id (str): Episode id returned from start_episode().
observation (obj): Current environment observation.
"""
@ -267,6 +266,7 @@ class _ExternalEnvEpisode:
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()

View file

@ -14,7 +14,7 @@ class ExternalMultiAgentEnv(ExternalEnv):
ExternalMultiAgentEnv subclasses must call this during their __init__.
Arguments:
Args:
action_space (gym.Space): Action space of the env.
observation_space (gym.Space): Observation space of the env.
max_concurrent (int): Max number of active episodes to allow at
@ -135,10 +135,7 @@ class ExternalMultiAgentEnv(ExternalEnv):
if multiagent_done_dict:
for agent, done in multiagent_done_dict.items():
if agent in episode.cur_done_dict:
episode.cur_done_dict[agent] = done
else:
episode.cur_done_dict[agent] = done
episode.cur_done_dict[agent] = done
if info_dict:
episode.cur_info_dict = info_dict or {}

View file

@ -21,9 +21,9 @@ class MultiAgentEnv:
"traffic_light_1": [0, 3, 5, 1],
}
>>> obs, rewards, dones, infos = env.step(
action_dict={
"car_0": 1, "car_1": 0, "traffic_light_1": 2,
})
... action_dict={
... "car_0": 1, "car_1": 0, "traffic_light_1": 2,
... })
>>> print(rewards)
{
"car_0": 3,

View file

@ -11,6 +11,7 @@ import time
import ray.cloudpickle as pickle
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
logger = logging.getLogger(__name__)
@ -54,6 +55,7 @@ class PolicyClient:
or None for manual control via client.
"""
self.address = address
self.env = None
if inference_mode == "local":
self.local = True
self._setup_local_rollout_worker(update_interval)
@ -65,11 +67,11 @@ class PolicyClient:
@PublicAPI
def start_episode(self, episode_id=None, training_enabled=True):
"""Record the start of an episode.
"""Record the start of one or more episode(s).
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned.
Args:
episode_id (Optional[str]): Unique string id for the episode or
None for it to be auto-assigned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.
@ -101,13 +103,20 @@ class PolicyClient:
if self.local:
self._update_local_policy()
return self.env.get_action(episode_id, observation)
return self._send({
"command": PolicyClient.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
})["action"]
if isinstance(episode_id, (list, tuple)):
actions = {
eid: self.env.get_action(eid, observation[eid])
for eid in episode_id
}
return actions
else:
return self.env.get_action(episode_id, observation)
else:
return self._send({
"command": PolicyClient.GET_ACTION,
"observation": observation,
"episode_id": episode_id,
})["action"]
@PublicAPI
def log_action(self, episode_id, observation, action):
@ -151,11 +160,11 @@ class PolicyClient:
if self.local:
self._update_local_policy()
if multiagent_done_dict:
if multiagent_done_dict is not None:
assert isinstance(reward, dict)
return self.env.log_returns(episode_id, reward, info,
multiagent_done_dict)
else:
return self.env.log_returns(episode_id, reward, info)
return self.env.log_returns(episode_id, reward, info)
self._send({
"command": PolicyClient.LOG_RETURNS,
@ -207,7 +216,6 @@ class PolicyClient:
kwargs = self._send({
"command": PolicyClient.GET_WORKER_ARGS,
})["worker_args"]
(self.rollout_worker,
self.inference_thread) = create_embedded_rollout_worker(
kwargs, self._send)
@ -245,8 +253,14 @@ class _LocalInferenceThread(threading.Thread):
logger.info("Generating new batch of experiences.")
samples = self.rollout_worker.sample()
metrics = self.rollout_worker.get_metrics()
logger.info("Sending batch of {} steps back to server.".format(
samples.count))
if isinstance(samples, MultiAgentBatch):
logger.info(
"Sending batch of {} env steps ({} agent steps) to "
"server.".format(samples.count, samples.total()))
else:
logger.info(
"Sending batch of {} steps back to server.".format(
samples.count))
self.send_fn({
"command": PolicyClient.REPORT_SAMPLES,
"samples": samples,
@ -265,11 +279,11 @@ def auto_wrap_external(real_env_creator):
def wrapped_creator(env_config):
real_env = real_env_creator(env_config)
if not (isinstance(real_env, ExternalEnv)
or isinstance(real_env, ExternalMultiAgentEnv)):
if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
logger.info(
"The env you specified is not a type of ExternalEnv. "
"Attempting to convert it automatically to ExternalEnv.")
"The env you specified is not a supported (sub-)type of "
"ExternalEnv. Attempting to convert it automatically to "
"ExternalEnv.")
if isinstance(real_env, MultiAgentEnv):
external_cls = ExternalMultiAgentEnv
@ -278,8 +292,9 @@ def auto_wrap_external(real_env_creator):
class ExternalEnvWrapper(external_cls):
def __init__(self, real_env):
super().__init__(real_env.action_space,
real_env.observation_space)
super().__init__(
observation_space=real_env.observation_space,
action_space=real_env.action_space)
def run(self):
# Since we are calling methods on this class in the

View file

@ -4,8 +4,8 @@ rm -f last_checkpoint.out
pkill -f cartpole_server.py
sleep 1
if [ -f cartpole_server.py ]; then
basedir="."
if [ -f test_local_inference.sh ]; then
basedir="../../examples/serving"
else
basedir="rllib/examples/serving" # In bazel.
fi
@ -14,10 +14,10 @@ fi
pid=$!
echo "Waiting for server to start"
while ! curl localhost:9900; do
while ! curl localhost:9900; do
sleep 1
done
sleep 2
python $basedir/cartpole_client.py --stop-at-reward=100 --inference-mode=local
python $basedir/cartpole_client.py --stop-reward=150 --inference-mode=local
kill $pid

View file

@ -4,8 +4,8 @@ rm -f last_checkpoint.out
pkill -f cartpole_server.py
sleep 1
if [ -f cartpole_server.py ]; then
basedir="."
if [ -f test_local_inference.sh ]; then
basedir="../../examples/serving"
else
basedir="rllib/examples/serving" # In bazel.
fi
@ -14,11 +14,11 @@ fi
pid=$!
echo "Waiting for server to start"
while ! curl localhost:9900; do
while ! curl localhost:9900; do
sleep 1
done
sleep 2
python $basedir/cartpole_client.py --stop-at-reward=100 --inference-mode=remote
python $basedir/cartpole_client.py --stop-reward=150 --inference-mode=remote
kill $pid

rllib/env/unity3d_env.py vendored (new file, 232 lines)
View file

@ -0,0 +1,232 @@
from gym.spaces import Box, MultiDiscrete, Tuple
import logging
import mlagents_envs
from mlagents_envs.environment import UnityEnvironment
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import override
logger = logging.getLogger(__name__)
class Unity3DEnv(MultiAgentEnv):
"""A MultiAgentEnv representing a single Unity3D game instance.
For an example on how to use this class inside a Unity game client, which
connects to an RLlib Policy server, see:
`rllib/examples/serving/unity3d_[client|server].py`
Supports all Unity3D (MLAgents) examples, whether multi- or single-agent, and
is converted automatically into an ExternalMultiAgentEnv when used
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
"""
def __init__(self,
file_name=None,
worker_id=0,
base_port=5004,
seed=0,
no_graphics=False,
timeout_wait=60,
episode_horizon=1000):
"""Initializes a Unity3DEnv object.
Args:
file_name (Optional[str]): Name of the Unity game binary.
If None, will assume a locally running Unity3D editor
to be used, instead.
worker_id (int): Number to add to `base_port`. Used when more than
one Unity3DEnv (game) is running on the same machine. This
will be determined automatically, if possible, so a value
of 0 should always suffice here.
base_port (int): Port number to connect to Unity environment.
`worker_id` increments on top of this.
seed (int): A random seed value to use for the Unity3D game.
no_graphics (bool): Whether to run the Unity3D simulator in
no-graphics mode. Default: False.
timeout_wait (int): Time (in seconds) to wait for connection from
the Unity3D instance.
episode_horizon (int): A hard horizon to abide by. After at most
this many steps (per-agent episode `step()` calls), the
Unity3D game is reset and will start again (finishing the
multi-agent episode that the game represents).
Note: The game itself may contain its own episode length
limits, which are always obeyed (on top of this value here).
"""
super().__init__()
if file_name is None:
print(
"No game binary provided, will use a running Unity editor "
"instead.\nMake sure you are pressing the Play (|>) button in "
"your editor to start.")
# Try connecting to the Unity3D game instance. If the port is already
# in use, try the next worker_id (and thus the next port).
while True:
self.worker_id = worker_id
try:
self.unity_env = UnityEnvironment(
file_name=file_name,
worker_id=worker_id,
base_port=base_port,
seed=seed,
no_graphics=no_graphics,
timeout_wait=timeout_wait,
)
except mlagents_envs.exception.UnityWorkerInUseException as e:
worker_id += 1
# Hard limit.
if worker_id > 100:
raise e
else:
break
# Reset entire env every this number of step calls.
self.episode_horizon = episode_horizon
# Keep track of how many times we have called `step` so far.
self.episode_timesteps = 0
@override(MultiAgentEnv)
def step(self, action_dict):
"""Performs one multi-agent step through the game.
Args:
action_dict (dict): Multi-agent action dict with:
keys=agent identifier consisting of
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
[Agent index, a unique MLAgent-assigned index per single
agent]
Returns:
tuple:
obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
rewards: Rewards dict matching `obs`.
dones: Done dict with only an __all__ multi-agent entry in it.
__all__=True, if episode is done for all agents.
infos: An (empty) info dict.
"""
# Set only the required actions (from the DecisionSteps) in Unity3D.
all_agents = []
for behavior_name in self.unity_env.get_behavior_names():
for agent_id in self.unity_env.get_steps(behavior_name)[
0].agent_id_to_index.keys():
key = behavior_name + "_{}".format(agent_id)
all_agents.append(key)
self.unity_env.set_action_for_agent(behavior_name, agent_id,
action_dict[key])
# Do the step.
self.unity_env.step()
obs, rewards, dones, infos = self._get_step_results()
# Global horizon reached? -> Return __all__ done=True, so user
# can reset. Set all agents' individual `done` to True as well.
self.episode_timesteps += 1
if self.episode_timesteps > self.episode_horizon:
return obs, rewards, dict({
"__all__": True
}, **{agent_id: True
for agent_id in all_agents}), infos
return obs, rewards, dones, infos
@override(MultiAgentEnv)
def reset(self):
"""Resets the entire Unity3D scene (a single multi-agent episode)."""
self.episode_timesteps = 0
self.unity_env.reset()
obs, _, _, _ = self._get_step_results()
return obs
def _get_step_results(self):
"""Collects those agents' obs/rewards that have to act in next `step`.
Returns:
Tuple:
obs: Multi-agent observation dict.
Only those observations for which to get new actions are
returned.
rewards: Rewards dict matching `obs`.
dones: Done dict with only an __all__ multi-agent entry in it.
__all__=True, if episode is done for all agents.
infos: An (empty) info dict.
"""
obs = {}
rewards = {}
infos = {}
for behavior_name in self.unity_env.get_behavior_names():
decision_steps, terminal_steps = self.unity_env.get_steps(
behavior_name)
# Important: Only update those sub-envs that are currently
# available within _env_state.
# Loop through all envs ("agents") and fill in whatever
# information we have.
for agent_id, idx in decision_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
os = tuple(o[idx] for o in decision_steps.obs)
os = os[0] if len(os) == 1 else os
obs[key] = os
rewards[key] = decision_steps.reward[idx] # rewards vector
for agent_id, idx in terminal_steps.agent_id_to_index.items():
key = behavior_name + "_{}".format(agent_id)
# Only overwrite the reward (the last reward of this episode); the obs
# here is the terminal obs, which doesn't matter anyway, unless the
# key does not yet exist in `obs`.
if key not in obs:
os = tuple(o[idx] for o in terminal_steps.obs)
obs[key] = os = os[0] if len(os) == 1 else os
rewards[key] = terminal_steps.reward[idx] # rewards vector
# Only use dones if all agents are done, then we should do a reset.
return obs, rewards, {"__all__": False}, infos
@staticmethod
def get_policy_configs_for_game(game_name):
# The RLlib server must know about the Spaces that the Client will be
# using inside Unity3D, up-front.
obs_spaces = {
# SoccerStrikersVsGoalie.
"Striker": Tuple([
Box(float("-inf"), float("inf"), (231, )),
Box(float("-inf"), float("inf"), (63, )),
]),
"Goalie": Box(float("-inf"), float("inf"), (738, )),
# 3DBall.
"Agent": Box(float("-inf"), float("inf"), (8, )),
}
action_spaces = {
# SoccerStrikersVsGoalie.
"Striker": MultiDiscrete([3, 3, 3]),
"Goalie": MultiDiscrete([3, 3, 3]),
# 3DBall.
"Agent": Box(float("-inf"), float("inf"), (2, ), dtype=np.float32),
}
# Policies (Unity: "behaviors") and agent-to-policy mapping fns.
if game_name == "SoccerStrikersVsGoalie":
policies = {
"Striker": (None, obs_spaces["Striker"],
action_spaces["Striker"], {}),
"Goalie": (None, obs_spaces["Goalie"], action_spaces["Goalie"],
{}),
}
def policy_mapping_fn(agent_id):
return "Striker" if "Striker" in agent_id else "Goalie"
else: # 3DBall
policies = {
"Agent": (None, obs_spaces["Agent"], action_spaces["Agent"],
{})
}
def policy_mapping_fn(agent_id):
return "Agent"
return policies, policy_mapping_fn
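A brief usage sketch for the adapter above (illustration only; assumes a locally running Unity editor with the 3DBall scene loaded):
from ray.rllib.env.unity3d_env import Unity3DEnv

# Look up the 3DBall action space via the helper above.
policies, _ = Unity3DEnv.get_policy_configs_for_game("3DBall")
action_space = policies["Agent"][2]  # (None, obs_space, act_space, {})

env = Unity3DEnv(file_name=None, episode_horizon=100)
obs = env.reset()
while True:
    # Send a random action for every agent that requested one.
    actions = {agent_id: action_space.sample() for agent_id in obs}
    obs, rewards, dones, infos = env.step(actions)
    if dones["__all__"]:
        obs = env.reset()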

View file

@ -8,31 +8,43 @@ logger = logging.getLogger(__name__)
@PublicAPI
class VectorEnv:
"""An environment that supports batch evaluation.
Subclasses must define the following attributes:
Attributes:
action_space (gym.Space): Action space of individual envs.
observation_space (gym.Space): Observation space of individual envs.
num_envs (int): Number of envs in this vector env.
"""An environment that supports batch evaluation using clones of sub-envs.
"""
def __init__(self, observation_space, action_space, num_envs):
"""Initializes a VectorEnv object.
Args:
observation_space (Space): The observation Space of a single
sub-env.
action_space (Space): The action Space of a single sub-env.
num_envs (int): The number of clones to make of the given sub-env.
"""
self.observation_space = observation_space
self.action_space = action_space
self.num_envs = num_envs
@staticmethod
def wrap(make_env=None,
existing_envs=None,
num_envs=1,
action_space=None,
observation_space=None):
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
action_space, observation_space)
observation_space=None,
env_config=None):
return _VectorizedGymEnv(
make_env=make_env,
existing_envs=existing_envs or [],
num_envs=num_envs,
observation_space=observation_space,
action_space=action_space,
env_config=env_config)
@PublicAPI
def vector_reset(self):
"""Resets all environments.
"""Resets all sub-environments.
Returns:
obs (list): Vector of observations from each environment.
obs (List[any]): List of observations from each environment.
"""
raise NotImplementedError
@ -41,55 +53,73 @@ class VectorEnv:
"""Resets a single environment.
Returns:
obs (obj): Observations from the resetted environment.
obs (obj): Observations from the reset sub environment.
"""
raise NotImplementedError
@PublicAPI
def vector_step(self, actions):
"""Vectorized step.
"""Performs a vectorized step on all sub environments using `actions`.
Arguments:
actions (list): Actions for each env.
actions (List[any]): List of actions (one for each sub-env).
Returns:
obs (list): New observations for each env.
rewards (list): Reward values for each env.
dones (list): Done values for each env.
infos (list): Info values for each env.
obs (List[any]): New observations for each sub-env.
rewards (List[any]): Reward values for each sub-env.
dones (List[any]): Done values for each sub-env.
infos (List[any]): Info values for each sub-env.
"""
raise NotImplementedError
@PublicAPI
def get_unwrapped(self):
"""Returns the underlying env instances."""
"""Returns the underlying sub environments.
Returns:
List[Env]: List of all underlying sub environments.
"""
raise NotImplementedError
class _VectorizedGymEnv(VectorEnv):
"""Internal wrapper for gym envs to implement VectorEnv.
Arguments:
make_env (func|None): Factory that produces a new gym env. Must be
defined if the number of existing envs is less than num_envs.
existing_envs (list): List of existing gym envs.
num_envs (int): Desired num gym envs to keep total.
"""Internal wrapper to translate any gym envs into a VectorEnv object.
"""
def __init__(self,
make_env,
existing_envs,
num_envs,
make_env=None,
existing_envs=None,
num_envs=1,
*,
observation_space=None,
action_space=None,
observation_space=None):
env_config=None):
"""Initializes a _VectorizedGymEnv object.
Args:
make_env (Optional[callable]): Factory that produces a new gym env
taking a single `config` dict arg. Must be defined if the
number of `existing_envs` is less than `num_envs`.
existing_envs (Optional[List[Env]]): Optional list of already
instantiated sub environments.
num_envs (int): Total number of sub environments in this VectorEnv.
action_space (Optional[Space]): The action space. If None, use
existing_envs[0]'s action space.
observation_space (Optional[Space]): The observation space.
If None, use existing_envs[0]'s observation space.
env_config (Optional[dict]): Additional sub env config to pass to
make_env as first arg.
"""
self.make_env = make_env
self.envs = existing_envs
self.num_envs = num_envs
while len(self.envs) < self.num_envs:
while len(self.envs) < num_envs:
self.envs.append(self.make_env(len(self.envs)))
self.action_space = action_space or self.envs[0].action_space
self.observation_space = observation_space or \
self.envs[0].observation_space
super().__init__(
observation_space=observation_space
or self.envs[0].observation_space,
action_space=action_space or self.envs[0].action_space,
num_envs=num_envs)
@override(VectorEnv)
def vector_reset(self):

View file

@ -303,11 +303,11 @@ class RolloutWorker(ParallelIteratorWorker):
self.fake_sampler = fake_sampler
self.env = _validate_env(env_creator(env_context))
if isinstance(self.env, MultiAgentEnv) or \
isinstance(self.env, BaseEnv):
if isinstance(self.env, (BaseEnv, MultiAgentEnv)):
def wrap(env):
return env # we can't auto-wrap these env types
elif is_atari(self.env) and \
not model_config.get("custom_preprocessor") and \
preprocessor_pref == "deepmind":
@ -411,7 +411,7 @@ class RolloutWorker(ParallelIteratorWorker):
if self.worker_index == 0:
logger.info("Built filter map: {}".format(self.filters))
# Always use vector env for consistency even if num_envs = 1
# Always use vector env for consistency even if num_envs = 1.
self.async_env = BaseEnv.to_base_env(
self.env,
make_env=make_env,

View file

@ -2,10 +2,10 @@ import collections
import logging
import numpy as np
from ray.util.debug import log_once
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.debug import summarize
from ray.util.debug import log_once
logger = logging.getLogger(__name__)
@ -124,8 +124,9 @@ class MultiAgentSampleBatchBuilder:
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Arguments:
episode: current MultiAgentEpisode object or None
Args:
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
object.
"""
# Materialize the batches so far
@ -198,8 +199,9 @@ class MultiAgentSampleBatchBuilder:
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Arguments:
episode: current MultiAgentEpisode object or None
Args:
episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
object.
"""
self.postprocess_batch_so_far(episode)

View file

@ -341,7 +341,7 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
while True:
perf_stats.iters += 1
t0 = time.time()
# Get observations from all ready agents
# Get observations from all ready agents.
unfiltered_obs, rewards, dones, infos, off_policy_actions = \
base_env.poll()
perf_stats.env_wait_time += time.time() - t0
@ -351,7 +351,7 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
summarize(unfiltered_obs)))
logger.info("Info return from env: {}".format(summarize(infos)))
# Process observations and prepare for policy evaluation
# Process observations and prepare for policy evaluation.
t1 = time.time()
active_envs, to_eval, outputs = _process_observations(
worker, base_env, policies, batch_builder_pool, active_episodes,
@ -362,13 +362,13 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
for o in outputs:
yield o
# Do batched policy eval
# Do batched policy eval (across vectorized envs).
t2 = time.time()
eval_results = _do_policy_eval(tf_sess, to_eval, policies,
active_episodes)
perf_stats.inference_time += time.time() - t2
# Process results and update episode state
# Process results and update episode state.
t3 = time.time()
actions_to_send = _process_policy_eval_results(
to_eval, eval_results, active_episodes, active_envs,
@ -401,11 +401,11 @@ def _process_observations(
large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
rollout_fragment_length != float("inf") else 5000
# For each environment
# For each environment.
for env_id, agent_obs in unfiltered_obs.items():
new_episode = env_id not in active_episodes
is_new_episode = env_id not in active_episodes
episode = active_episodes[env_id]
if not new_episode:
if not is_new_episode:
episode.length += 1
episode.batch_builder.count += 1
episode._add_agent_rewards(rewards[env_id])
@ -427,11 +427,11 @@ def _process_observations(
"to terminate (batch_mode=`complete_episodes`). Make sure it "
"does at some point.")
# Check episode termination conditions
# Check episode termination conditions.
if dones[env_id]["__all__"] or episode.length >= horizon:
hit_horizon = (episode.length >= horizon
and not dones[env_id]["__all__"])
all_done = True
all_agents_done = True
atari_metrics = _fetch_atari_metrics(base_env)
if atari_metrics is not None:
for m in atari_metrics:
@ -445,7 +445,7 @@ def _process_observations(
episode.hist_data))
else:
hit_horizon = False
all_done = False
all_agents_done = False
active_envs.add(env_id)
# Custom observation function is applied before preprocessing.
@ -473,7 +473,7 @@ def _process_observations(
if log_once("filtered_obs"):
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
agent_done = bool(all_done or dones[env_id].get(agent_id))
agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
if not agent_done:
to_eval[policy_id].append(
PolicyEvalData(env_id, agent_id, filtered_obs,
@ -517,15 +517,15 @@ def _process_observations(
if episode.batch_builder.has_pending_agent_data():
if dones[env_id]["__all__"] and not no_done_at_end:
episode.batch_builder.check_missing_dones()
if (all_done and not pack) or \
if (all_agents_done and not pack) or \
episode.batch_builder.count >= rollout_fragment_length:
outputs.append(episode.batch_builder.build_and_reset(episode))
elif all_done:
elif all_agents_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far(episode)
if all_done:
# Handle episode termination
if all_agents_done:
# Handle episode termination.
batch_builder_pool.append(episode.batch_builder)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
@ -548,13 +548,13 @@ def _process_observations(
del active_episodes[env_id]
resetted_obs = base_env.try_reset(env_id)
if resetted_obs is None:
# Reset not supported, drop this env from the ready list
# Reset not supported, drop this env from the ready list.
if horizon != float("inf"):
raise ValueError(
"Setting episode horizon requires reset() support "
"from the environment.")
elif resetted_obs != ASYNC_RESET_RETURN:
# Creates a new episode if this is not async return
# Creates a new episode if this is not async return.
# If reset is async, we will get its result in some future poll
episode = active_episodes[env_id]
if observation_fn:
@ -623,7 +623,6 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
prev_reward_batch=prev_reward_batch,
timestep=policy.global_timestep)
else:
# TODO(sven): Does this work for LSTM torch?
rnn_in_cols = [
np.stack([row[i] for row in rnn_in])
for i in range(len(rnn_in[0]))

View file

@ -28,7 +28,7 @@ if __name__ == "__main__":
assert not args.torch, "PyTorch not supported for AttentionNets yet!"
ray.init(num_cpus=args.num_cpus or None, local_mode=True)
ray.init(num_cpus=args.num_cpus or None)
registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
registry.register_env("RepeatInitialObsEnv",

View file

@ -4,12 +4,16 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2
def make_multiagent(env_name):
def make_multiagent(env_name_or_creator):
class MultiEnv(MultiAgentEnv):
def __init__(self, config):
self.agents = [
gym.make(env_name) for _ in range(config["num_agents"])
]
num = config.pop("num_agents", 1)
if isinstance(env_name_or_creator, str):
self.agents = [
gym.make(env_name_or_creator) for _ in range(num)
]
else:
self.agents = [env_name_or_creator(config) for _ in range(num)]
self.dones = set()
self.observation_space = self.agents[0].observation_space
self.action_space = self.agents[0].action_space

View file

@ -1,7 +1,9 @@
import gym
from gym.spaces import Tuple
from gym.spaces import Discrete, Tuple
import numpy as np
from ray.rllib.examples.env.multi_agent import make_multiagent
class RandomEnv(gym.Env):
"""A randomly acting environment.
@ -14,9 +16,9 @@ class RandomEnv(gym.Env):
def __init__(self, config):
# Action space.
self.action_space = config["action_space"]
self.action_space = config.get("action_space", Discrete(2))
# Observation space from which to sample.
self.observation_space = config["observation_space"]
self.observation_space = config.get("observation_space", Discrete(2))
# Reward space from which to sample.
self.reward_space = config.get(
"reward_space",
@ -43,3 +45,7 @@ class RandomEnv(gym.Env):
bool(np.random.choice(
[True, False], p=[self.p_done, 1.0 - self.p_done]
)), {}
# Multi-agent version of the RandomEnv.
RandomMultiAgentEnv = make_multiagent(lambda c: RandomEnv(c))

View file

@ -85,7 +85,12 @@ class SharedWeightsModel2(TFModelV2):
TORCH_GLOBAL_SHARED_LAYER = None
if torch:
TORCH_GLOBAL_SHARED_LAYER = SlimFC(32, 32)
TORCH_GLOBAL_SHARED_LAYER = SlimFC(
64,
64,
activation_fn=nn.ReLU,
initializer=torch.nn.init.xavier_uniform_,
)
class TorchSharedWeightsModel(TorchModelV2, nn.Module):
@ -104,12 +109,22 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
# Non-shared initial layer.
self.first_layer = SlimFC(
int(np.product(observation_space.shape)),
32,
activation_fn=nn.ReLU)
64,
activation_fn=nn.ReLU,
initializer=torch.nn.init.xavier_uniform_)
# Non-shared final layer.
self.last_layer = SlimFC(32, self.num_outputs, activation_fn=nn.ReLU)
self.vf = SlimFC(32, 1, activation_fn=None)
self.last_layer = SlimFC(
64,
self.num_outputs,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_)
self.vf = SlimFC(
64,
1,
activation_fn=None,
initializer=torch.nn.init.xavier_uniform_,
)
self._output = None
@override(ModelV2)

View file

@ -28,7 +28,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--stop-iters", type=int, default=20)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-reward", type=float, default=150)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--simple", action="store_true")
@ -74,7 +74,6 @@ if __name__ == "__main__":
"env_config": {
"num_agents": args.num_agents,
},
"log_level": "DEBUG",
"simple_optimizer": args.simple,
"num_sgd_iter": 10,
"multiagent": {
@ -89,7 +88,7 @@ if __name__ == "__main__":
"training_iteration": args.stop_iters,
}
results = tune.run("PPO", stop=stop, config=config)
results = tune.run("PPO", stop=stop, config=config, verbose=1)
if args.as_test:
check_learning_achieved(results, args.stop_reward)

View file

@ -23,7 +23,7 @@ parser.add_argument(
action="store_true",
help="Whether to take random instead of on-policy actions.")
parser.add_argument(
"--stop-at-reward",
"--stop-reward",
type=int,
default=9999,
help="Stop once the specified reward is reached.")
@ -49,7 +49,7 @@ if __name__ == "__main__":
client.log_returns(eid, reward, info=info)
if done:
print("Total reward:", rewards)
if rewards >= args.stop_at_reward:
if rewards >= args.stop_reward:
print("Target reward achieved, exiting")
exit(0)
rewards = 0

View file

@ -32,8 +32,7 @@ if __name__ == "__main__":
connector_config = {
# Use the connector server to generate experiences.
"input": (
lambda ioctx: PolicyServerInput( \
ioctx, SERVER_ADDRESS, SERVER_PORT)
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
),
# Use a single worker process to run the server.
"num_workers": 0,

View file

@ -0,0 +1,120 @@
"""
Example of running a Unity3D client instance against an RLlib Policy server.
Unity3D clients can be run in distributed fashion on n nodes in the cloud
and all connect to the same RLlib server for faster sample collection.
For a locally running Unity3D example, see:
`examples/unity3d_env_local.py`
To run this script on possibly different machines
against a central Policy server:
1) Install Unity3D and `pip install mlagents`.
2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
other one that you created yourself) and place the compiled binary
somewhere, where your RLlib client script (see below) can access it.
2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
folder.
3) Change your RLlib Policy server code so it knows the observation- and
action Spaces, the different Policies (called "behaviors" in Unity3D
MLAgents), and Agent-to-Policy mappings for your particular game.
Alternatively, use one of the two already existing setups (3DBall or
SoccerStrikersVsGoalie).
4) Then run (two separate shells/machines):
$ python unity3d_server.py --env 3DBall
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
"""
import argparse
from ray.rllib.env.policy_client import PolicyClient
from ray.rllib.env.unity3d_env import Unity3DEnv
SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
parser = argparse.ArgumentParser()
parser.add_argument(
"--game",
type=str,
default=None,
help="The game executable to run as RL env. If not provided, uses local "
"Unity3D editor instance.")
parser.add_argument(
"--horizon",
type=int,
default=200,
help="The max. number of `step()`s for any episode (per agent) before "
"it'll be reset again automatically.")
parser.add_argument(
"--server",
type=str,
default=SERVER_ADDRESS + ":" + str(SERVER_PORT),
help="The Policy server's address and port to connect to from this client."
)
parser.add_argument(
"--no-train",
action="store_true",
help="Whether to disable training (on the server side).")
parser.add_argument(
"--inference-mode",
type=str,
default="local",
choices=["local", "remote"],
help="Whether to compute actions `local`ly or `remote`ly. Note that "
"`local` is much faster b/c observations/actions do not have to be "
"sent via the network.")
parser.add_argument(
"--update-interval-local-mode",
type=float,
default=10.0,
help="For `inference-mode=local`, every how many seconds do we update "
"learnt policy weights from the server?")
parser.add_argument(
"--stop-reward",
type=int,
default=9999,
help="Stop once the specified reward is reached.")
if __name__ == "__main__":
args = parser.parse_args()
# Start the client for sending environment information (e.g. observations,
# actions) to a policy server (listening on port 9900).
client = PolicyClient(
"http://" + args.server,
inference_mode=args.inference_mode,
update_interval=args.update_interval_local_mode)
# Start and reset the actual Unity3DEnv (either already running Unity3D
# editor or a binary (game) to be started automatically).
env = Unity3DEnv(file_name=args.game, episode_horizon=args.horizon)
obs = env.reset()
eid = client.start_episode(training_enabled=not args.no_train)
# Keep track of the total reward per episode.
total_rewards_this_episode = 0.0
# Loop infinitely through the env.
while True:
# Get actions from the Policy server given our current obs.
actions = client.get_action(eid, obs)
# Apply actions to our env.
obs, rewards, dones, infos = env.step(actions)
total_rewards_this_episode += sum(rewards.values())
# Log rewards and single-agent dones.
client.log_returns(eid, rewards, infos, multiagent_done_dict=dones)
# Check whether all agents are done and end the episode, if necessary.
if dones["__all__"]:
print("Episode done: Reward={}".format(total_rewards_this_episode))
if total_rewards_this_episode >= args.stop_reward:
quit(0)
# End the episode and reset Unity Env.
total_rewards_this_episode = 0.0
client.end_episode(eid, obs)
obs = env.reset()
# Start a new episode.
eid = client.start_episode(training_enabled=not args.no_train)

View file

@ -0,0 +1,129 @@
"""
Example of running a Unity3D (MLAgents) Policy server that can learn
Policies via sampling inside many connected Unity game clients (possibly
running in the cloud on n nodes).
For a locally running Unity3D example, see:
`examples/unity3d_env_local.py`
To run this script against one or more possibly cloud-based clients:
1) Install Unity3D and `pip install mlagents`.
2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
other one that you created yourself) and place the compiled binary
somewhere, where your RLlib client script (see below) can access it.
2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
folder.
3) Change this RLlib Policy server code so it knows the observation- and
action Spaces, the different Policies (called "behaviors" in Unity3D
MLAgents), and Agent-to-Policy mappings for your particular game.
Alternatively, use one of the two already existing setups (3DBall or
SoccerStrikersVsGoalie).
4) Then run (two separate shells/machines):
$ python unity3d_server.py --env 3DBall
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
"""
import argparse
import os
import ray
from ray.tune import register_env
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
from ray.rllib.examples.env.unity3d_env import Unity3DEnv
SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"
parser = argparse.ArgumentParser()
parser.add_argument(
"--env",
type=str,
default="3DBall",
choices=["3DBall", "SoccerStrikersVsGoalie"],
help="The name of the Env to run in the Unity3D editor. Either `3DBall` "
"or `SoccerStrikersVsGoalie` (feel free to add more to this script!)")
parser.add_argument(
"--port",
type=int,
default=SERVER_PORT,
help="The Policy server's port to listen on for ExternalEnv client "
"conections.")
parser.add_argument(
"--checkpoint-freq",
type=int,
default=10,
help="The frequency with which to create checkpoint files of the learnt "
"Policies.")
parser.add_argument(
"--no-restore",
action="store_true",
help="Whether to load the Policy "
"weights from a previous checkpoint")
if __name__ == "__main__":
args = parser.parse_args()
ray.init(local_mode=True)
# Create a fake-env for the server. This env will never be used (neither
# for sampling, nor for evaluation) and its obs/action Spaces do not
# matter either (multi-agent config below defines Spaces per Policy).
register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))
policies, policy_mapping_fn = \
Unity3DEnv.get_policy_configs_for_game(args.env)
# The entire config will be sent to connecting clients so they can
# build their own samplers (and also Policy objects iff
# `inference_mode=local` on clients' command line).
config = {
# Use the connector server to generate experiences.
"input": (
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)),
# Use a single worker process (w/ SyncSampler) to run the server.
"num_workers": 0,
# Disable OPE, since the rollouts are coming from online clients.
"input_evaluation": [],
# Other settings.
"sample_batch_size": 64,
"train_batch_size": 256,
"rollout_fragment_length": 20,
# Multi-agent setup for the particular env.
"multiagent": {
"policies": policies,
"policy_mapping_fn": policy_mapping_fn,
},
"framework": "tf",
}
# Create the Trainer used for Policy serving.
trainer = PPOTrainer(env="fake_unity", config=config)
# Attempt to restore from checkpoint if possible.
checkpoint_path = CHECKPOINT_FILE.format(args.env)
if not args.no_restore and os.path.exists(checkpoint_path):
checkpoint_path = open(checkpoint_path).read()
print("Restoring from checkpoint path", checkpoint_path)
trainer.restore(checkpoint_path)
# Serving and training loop.
count = 0
while True:
# Calls to train() will block on the configured `input` in the Trainer
# config above (PolicyServerInput).
print(trainer.train())
if count % args.checkpoint_freq == 0:
print("Saving learning progress to checkpoint file.")
checkpoint = trainer.save()
# Write the latest checkpoint location to CHECKPOINT_FILE,
# so we can pick up from the latest one after a server re-start.
with open(checkpoint_path, "w") as f:
f.write(checkpoint)
count += 1

View file

@ -0,0 +1,95 @@
"""
Example of running an RLlib Trainer against a locally running Unity3D editor
instance (available as Unity3DEnv inside RLlib).
For a distributed cloud setup example with Unity,
see `examples/serving/unity3d_[server|client].py`
To run this script against a local Unity3D engine:
1) Install Unity3D and `pip install mlagents`.
2) Open the Unity3D Editor and load an example scene from the following
ml-agents pip package location:
`.../ml-agents/Project/Assets/ML-Agents/Examples/`
This script supports the `3DBall` and `SoccerStrikersVsGoalie` examples.
Specify the game you chose on your command line via e.g. `--env 3DBall`.
Feel free to add more supported examples here.
3) Then run this script (you will have to press Play in your Unity editor
at some point to start the game and the learning process):
$ python unity3d_env_local.py --env 3DBall --stop-reward [..] [--torch]?
"""
import argparse
import ray
from ray import tune
from ray.rllib.env.unity3d_env import Unity3DEnv
from ray.rllib.utils.test_utils import check_learning_achieved
parser = argparse.ArgumentParser()
parser.add_argument(
"--env",
type=str,
default="3DBall",
choices=["3DBall", "SoccerStrikersVsGoalie"],
help="The name of the Env to run in the Unity3D editor. Either `3DBall` "
"or `SoccerStrikersVsGoalie` (feel free to add more to this script!)")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=150)
parser.add_argument("--stop-reward", type=float, default=9999.0)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument(
"--horizon",
type=int,
default=200,
help="The max. number of `step()`s for any episode (per agent) before "
"it'll be reset again automatically.")
parser.add_argument("--torch", action="store_true")
if __name__ == "__main__":
ray.init(local_mode=True)
args = parser.parse_args()
tune.register_env(
"unity3d",
lambda c: Unity3DEnv(episode_horizon=c.get("episode_horizon", 1000)))
# Get policies (different agent types; "behaviors" in MLAgents) and
# the mappings from individual agents to Policies.
policies, policy_mapping_fn = \
Unity3DEnv.get_policy_configs_for_game(args.env)
config = {
"env": "unity3d",
"env_config": {
"episode_horizon": args.horizon,
},
# IMPORTANT: Just use one Worker (we only have one Unity running)!
"num_workers": 0,
# Other settings.
"sample_batch_size": 64,
"train_batch_size": 256,
"rollout_fragment_length": 20,
# Multi-agent setup for the particular env.
"multiagent": {
"policies": policies,
"policy_mapping_fn": policy_mapping_fn,
},
"framework": "tf",
}
stop = {
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
}
# Run the experiment.
results = tune.run("PPO", config=config, stop=stop, verbose=1)
# And check the results.
if args.as_test:
check_learning_achieved(results, args.stop_reward)
ray.shutdown()

View file

@ -232,7 +232,7 @@ class TorchPolicy(Policy):
loss_out = force_list(
self._loss(self, self.model, self.dist_class, train_batch))
assert len(loss_out) == len(self._optimizers)
# assert not any(np.isnan(l.detach().numpy()) for l in loss_out)
# assert not any(torch.isnan(l) for l in loss_out)
# Loop through all optimizers.
grad_info = {"allreduce_latency": 0.0}

View file

@ -100,11 +100,11 @@ class MockEnv2(gym.Env):
class MockVectorEnv(VectorEnv):
def __init__(self, episode_length, num_envs):
super().__init__()
super().__init__(
observation_space=gym.spaces.Discrete(1),
action_space=gym.spaces.Discrete(2),
num_envs=num_envs)
self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
self.observation_space = gym.spaces.Discrete(1)
self.action_space = gym.spaces.Discrete(2)
self.num_envs = num_envs
def vector_reset(self):
return [e.reset() for e in self.envs]

View file

@ -38,8 +38,8 @@ class PolicyClient:
"""Record the start of an episode.
Arguments:
episode_id (str): Unique string id for the episode or None for
it to be auto-assigned.
episode_id (Optional[str]): Unique string id for the episode or
None for it to be auto-assigned.
training_enabled (bool): Whether to use experiences for this
episode to improve the policy.