[RLlib] Unity3D integration (n Unity3D clients vs learning server). (#8590)
parent: 016337d4eb, commit: d8a081a185
31 changed files with 870 additions and 191 deletions
BIN doc/source/images/rllib-training-inside-a-unity3d-env.png (new binary file, 264 KiB; not shown)
|
@ -36,7 +36,7 @@ You can pass either a string name or a Python class to specify an environment. B
|
||||||
while True:
|
while True:
|
||||||
print(trainer.train())
|
print(trainer.train())
|
||||||
|
|
||||||
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` parameter and return an env instance:
|
You can also register a custom env creator function with a string name. This function must take a single ``env_config`` (dict) parameter and return an env instance:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
@ -113,19 +113,20 @@ When using remote envs, you can control the batching level for inference with ``
|
||||||
Multi-Agent and Hierarchical
|
Multi-Agent and Hierarchical
|
||||||
----------------------------
|
----------------------------
|
||||||
|
|
||||||
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib as follows: (1) as a user you define the number of policies available up front, and (2) a function that maps agent ids to policy ids. This is summarized by the below figure:
|
A multi-agent environment is one which has multiple acting entities per step, e.g., in a traffic simulation, there may be multiple "car" and "traffic light" agents in the environment. The model for multi-agent in RLlib is as follows: (1) as a user, you define the number of policies available up front, and (2) you provide a function that maps agent ids to policy ids. This is summarized by the figure below:
|
||||||
|
|
||||||
.. image:: multi-agent.svg
|
.. image:: multi-agent.svg
|
||||||
|
|
||||||
The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-project/ray/blob/master/rllib/env/multi_agent_env.py>`__ interface, which can returns observations and rewards from multiple ready agents per step:
|
The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-project/ray/blob/master/rllib/env/multi_agent_env.py>`__ interface, which can return observations and rewards from multiple ready agents per step:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
# Example: using a multi-agent env
|
# Example: using a multi-agent env
|
||||||
> env = MultiAgentTrafficEnv(num_cars=20, num_traffic_lights=5)
|
> env = MultiAgentTrafficEnv(num_cars=20, num_traffic_lights=5)
|
||||||
|
|
||||||
# Observations are a dict mapping agent names to their obs. Not all agents
|
# Observations are a dict mapping agent names to their obs. Only those
|
||||||
# may be present in the dict in each time step.
|
# agents that require actions in the next call to `step()` will
|
||||||
|
# be present in the returned observation dict.
|
||||||
> print(env.reset())
|
> print(env.reset())
|
||||||
{
|
{
|
||||||
"car_1": [[...]],
|
"car_1": [[...]],
|
||||||
|
@ -133,14 +134,15 @@ The environment itself must subclass the `MultiAgentEnv <https://github.com/ray-
|
||||||
"traffic_light_1": [[...]],
|
"traffic_light_1": [[...]],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Actions should be provided for each agent that returned an observation.
|
# In the following call to `step`, actions should be provided for each
|
||||||
> new_obs, rewards, dones, infos = env.step(actions={"car_1": ..., "car_2": ...})
|
# agent that returned an observation before:
|
||||||
|
> new_obs, rewards, dones, infos = env.step(actions={"car_1": ..., "car_2": ..., "traffic_light_1": ...})
|
||||||
|
|
||||||
# Similarly, new_obs, rewards, dones, etc. also become dicts
|
# Similarly, new_obs, rewards, dones, etc. also become dicts
|
||||||
> print(rewards)
|
> print(rewards)
|
||||||
{"car_1": 3, "car_2": -1, "traffic_light_1": 0}
|
{"car_1": 3, "car_2": -1, "traffic_light_1": 0}
|
||||||
|
|
||||||
# Individual agents can early exit; env is done when "__all__" = True
|
# Individual agents can exit early; the entire episode is done when "__all__" = True
|
||||||
> print(dones)
|
> print(dones)
|
||||||
{"car_2": True, "__all__": False}
|
{"car_2": True, "__all__": False}
|
||||||
|
|
||||||
|
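A minimal sketch of points (1) and (2) for the traffic example above (the spaces, the policy ids, and the ``MultiAgentTrafficEnv`` class are illustrative assumptions, not RLlib-provided names):

.. code-block:: python

    from gym.spaces import Box, Discrete
    from ray import tune

    # Hypothetical per-agent-type spaces (illustration only).
    car_obs_space = Box(-1.0, 1.0, (8, ))
    car_act_space = Discrete(3)
    light_obs_space = Box(-1.0, 1.0, (4, ))
    light_act_space = Discrete(2)

    config = {
        "env": MultiAgentTrafficEnv,  # assumed importable, as in the example above
        "multiagent": {
            # (1) Define the available policies up front:
            # (policy_cls or None for the default, obs_space, act_space, config).
            "policies": {
                "car_policy": (None, car_obs_space, car_act_space, {}),
                "light_policy": (None, light_obs_space, light_act_space, {}),
            },
            # (2) Map each agent id to one of the policy ids above.
            "policy_mapping_fn": lambda agent_id: (
                "car_policy" if agent_id.startswith("car") else "light_policy"),
        },
    }
    tune.run("PPO", config=config)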
@ -305,9 +307,14 @@ See this file for a runnable example: `hierarchical_training.py <https://github.
|
||||||
External Agents and Applications
|
External Agents and Applications
|
||||||
--------------------------------
|
--------------------------------
|
||||||
|
|
||||||
In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. This case also naturally arises with **external simulators** that run independently outside the control of RLlib, but may still want to leverage RLlib for training.
|
In many situations, it does not make sense for an environment to be "stepped" by RLlib. For example, if a policy is to be used in a web serving system, then it is more natural for an agent to query a service that serves policy decisions, and for that service to learn from experience over time. This case also naturally arises with **external simulators** (e.g. Unity3D, other game engines, or the Gazebo robotics simulator) that run independently outside the control of RLlib, but may still want to leverage RLlib for training.
|
||||||
|
|
||||||
RLlib provides the `ExternalEnv <https://github.com/ray-project/ray/blob/master/rllib/env/external_env.py>`__ class for this purpose. Unlike other envs, ExternalEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and reports rewards via ``self.log_returns()``. This can be done for multiple concurrent episodes as well.
|
.. figure:: images/rllib-training-inside-a-unity3d-env.png
|
||||||
|
:scale: 75 %
|
||||||
|
|
||||||
|
A Unity3D soccer game being learned by RLlib via the ExternalEnv API.
|
||||||
|
|
||||||
|
RLlib provides the `ExternalEnv <https://github.com/ray-project/ray/blob/master/rllib/env/external_env.py>`__ class for this purpose. Unlike other envs, ExternalEnv has its own thread of control. At any point, agents on that thread can query the current policy for decisions via ``self.get_action()`` and report rewards, done-dicts, and infos via ``self.log_returns()``. This can be done for multiple concurrent episodes as well.
|
||||||
|
|
||||||
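As a rough sketch only (``external_sim`` stands in for whatever system actually produces observations and consumes actions; it is not an RLlib object), an ExternalEnv subclass could look like this:

.. code-block:: python

    from ray.rllib.env.external_env import ExternalEnv

    class MyServingEnv(ExternalEnv):
        def __init__(self, external_sim, action_space, observation_space):
            super().__init__(action_space, observation_space)
            self.external_sim = external_sim

        def run(self):
            # Runs in ExternalEnv's own thread of control.
            while True:
                episode_id = self.start_episode()
                obs = self.external_sim.reset()
                done = False
                while not done:
                    action = self.get_action(episode_id, obs)
                    obs, reward, done, info = self.external_sim.step(action)
                    self.log_returns(episode_id, reward, info=info)
                self.end_episode(episode_id, obs)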
Logging off-policy actions
|
Logging off-policy actions
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -330,8 +337,8 @@ You can configure any Trainer to launch a policy server with the following confi
|
||||||
trainer_config = {
|
trainer_config = {
|
||||||
# An environment class is still required, but it doesn't need to be runnable.
|
# An environment class is still required, but it doesn't need to be runnable.
|
||||||
# You only need to define its action and observation space attributes.
|
# You only need to define its action and observation space attributes.
|
||||||
|
# See examples/serving/unity3d_server.py for an example using a RandomMultiAgentEnv stub.
|
||||||
"env": YOUR_ENV_STUB,
|
"env": YOUR_ENV_STUB,
|
||||||
|
|
||||||
# Use the policy server to generate experiences.
|
# Use the policy server to generate experiences.
|
||||||
"input": (
|
"input": (
|
||||||
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
|
lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
|
||||||
|
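On the client side, the corresponding loop looks roughly as follows (a sketch close to what ``cartpole_client.py`` does; the server address and the use of a local gym env as the "external" simulator are placeholders):

.. code-block:: python

    import gym
    from ray.rllib.env.policy_client import PolicyClient

    env = gym.make("CartPole-v0")
    client = PolicyClient("http://localhost:9900", inference_mode="local")

    obs = env.reset()
    episode_id = client.start_episode(training_enabled=True)
    done = False
    while not done:
        action = client.get_action(episode_id, obs)
        obs, reward, done, info = env.step(action)
        client.log_returns(episode_id, reward, info=info)
    client.end_episode(episode_id, obs)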
@ -360,7 +367,13 @@ To understand the difference between standard envs, external envs, and connectin
|
||||||
.. https://docs.google.com/drawings/d/1hJvT9bVGHVrGTbnCZK29BYQIcYNRbZ4Dr6FOPMJDjUs/edit
|
.. https://docs.google.com/drawings/d/1hJvT9bVGHVrGTbnCZK29BYQIcYNRbZ4Dr6FOPMJDjUs/edit
|
||||||
.. image:: rllib-external.svg
|
.. image:: rllib-external.svg
|
||||||
|
|
||||||
Try it yourself by launching a `cartpole_server.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_server.py>`__, and connecting to it with any number of clients (`cartpole_client.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_client.py>`__):
|
Try it yourself by either launching a
|
||||||
|
`simple CartPole server <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_server.py>`__ (see below) and connecting to it with any number of clients
|
||||||
|
(`cartpole_client.py <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/cartpole_client.py>`__) or
|
||||||
|
running a `Unity3D learning server <https://github.com/ray-project/ray/blob/master/rllib/examples/serving/unity3d_server.py>`__
|
||||||
|
against distributed Unity game engines in the cloud.
|
||||||
|
|
||||||
|
CartPole Example:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
@ -391,9 +404,9 @@ Try it yourself by launching a `cartpole_server.py <https://github.com/ray-proje
|
||||||
Total reward: 200.0
|
Total reward: 200.0
|
||||||
...
|
...
|
||||||
|
|
||||||
For the best performance, when possible we recommend using ``inference_mode="local"`` when possible.
|
For the best performance, we recommend using ``inference_mode="local"`` when possible.
|
||||||
|
|
||||||
Advanced Integrations
|
Advanced Integrations
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
For more complex / high-performance environment integrations, you can instead extend the low-level `BaseEnv <https://github.com/ray-project/ray/blob/master/rllib/env/base_env.py>`__ class. This low-level API models multiple agents executing asynchronously in multiple environments. A call to ``BaseEnv:poll()`` returns observations from ready agents keyed by their environment and agent ids, and actions for those agents are sent back via ``BaseEnv:send_actions()``. BaseEnv is used to implement all the other env types in RLlib, so it offers a superset of their functionality. For example, ``BaseEnv`` is used to implement dynamic batching of observations for inference over `multiple simulator actors <https://github.com/ray-project/ray/blob/master/rllib/env/remote_vector_env.py>`__.
|
For more complex / high-performance environment integrations, you can instead extend the low-level `BaseEnv <https://github.com/ray-project/ray/blob/master/rllib/env/base_env.py>`__ class. This low-level API models multiple agents executing asynchronously in multiple environments. A call to ``BaseEnv:poll()`` returns observations from ready agents keyed by 1) their environment, then 2) agent ids. Actions for those agents are sent back via ``BaseEnv:send_actions()``. BaseEnv is used to implement all the other env types in RLlib, so it offers a superset of their functionality. For example, ``BaseEnv`` is used to implement dynamic batching of observations for inference over `multiple simulator actors <https://github.com/ray-project/ray/blob/master/rllib/env/remote_vector_env.py>`__.
|
||||||
|
|
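A minimal polling sketch (wrapping a plain gym env just for illustration; ``compute_action_somehow`` is a placeholder, not an RLlib function):

.. code-block:: python

    import gym
    from ray.rllib.env.base_env import BaseEnv

    base_env = BaseEnv.to_base_env(gym.make("CartPole-v0"))

    obs, rewards, dones, infos, off_policy_actions = base_env.poll()
    while True:
        # obs is keyed by env id, then by agent id.
        actions = {
            env_id: {agent_id: compute_action_somehow(agent_obs)
                     for agent_id, agent_obs in agent_obs_dict.items()}
            for env_id, agent_obs_dict in obs.items()
        }
        base_env.send_actions(actions)
        obs, rewards, dones, infos, off_policy_actions = base_env.poll()
        # Done sub-envs would be restarted via base_env.try_reset(env_id)
        # (omitted in this sketch).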
|
@ -36,12 +36,18 @@ Training Workflows
|
||||||
Custom Envs and Models
|
Custom Envs and Models
|
||||||
----------------------
|
----------------------
|
||||||
|
|
||||||
|
- `Local Unity3D multi-agent environment example <https://github.com/ray-project/ray/tree/master/rllib/examples/unity3d_env_local.py>`__:
|
||||||
|
Example of how to set up an RLlib Trainer against a locally running Unity3D editor instance to
|
||||||
|
learn any Unity3D game (including support for multi-agent).
|
||||||
|
Use this example to try things out and watch the game and the learning progress live in the editor.
|
||||||
|
Given a compiled game, this example could also run in a distributed fashion with `num_workers > 0`.
|
||||||
|
For a more heavyweight, distributed, cloud-based example, see `Unity3D client/server`_ below.
|
||||||
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__:
|
- `Registering a custom env and model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_env.py>`__:
|
||||||
Example of defining and registering a gym env and model for use with RLlib.
|
Example of defining and registering a gym env and model for use with RLlib.
|
||||||
- `Custom Keras model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py>`__:
|
- `Custom Keras model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_model.py>`__:
|
||||||
Example of using a custom Keras model.
|
Example of using a custom Keras model.
|
||||||
- `Custom Keras RNN model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_keras_rnn_model.py>`__:
|
- `Custom Keras RNN model <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_rnn_model.py>`__:
|
||||||
Example of using a custom Keras RNN model.
|
Example of using a custom Keras or PyTorch RNN model.
|
||||||
- `Registering a custom model with supervised loss <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_loss.py>`__:
|
- `Registering a custom model with supervised loss <https://github.com/ray-project/ray/blob/master/rllib/examples/custom_loss.py>`__:
|
||||||
Example of defining and registering a custom model with a supervised loss.
|
Example of defining and registering a custom model with a supervised loss.
|
||||||
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/rllib/tests/test_env_with_subprocess.py>`__:
|
- `Subprocess environment <https://github.com/ray-project/ray/blob/master/rllib/tests/test_env_with_subprocess.py>`__:
|
||||||
|
@ -55,7 +61,16 @@ Custom Envs and Models
|
||||||
|
|
||||||
Serving and Offline
|
Serving and Offline
|
||||||
-------------------
|
-------------------
|
||||||
- `CartPole server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving>`__:
|
|
||||||
|
.. _Unity3D client/server:
|
||||||
|
|
||||||
|
- `Unity3D client/server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving/unity3d_server.py>`__:
|
||||||
|
Example of how to set up n distributed Unity3D (compiled) games in the cloud that function as data-collecting
|
||||||
|
clients against a central RLlib Policy server learning how to play the game.
|
||||||
|
The n distributed clients could themselves be servers for external/human players and allow control
|
||||||
|
to remain fully in the hands of the Unity entities instead of RLlib.
|
||||||
|
Note: Uses Unity's MLAgents SDK (>=1.0) and supports all provided MLAgents example games and multi-agent setups.
|
||||||
|
- `CartPole client/server <https://github.com/ray-project/ray/tree/master/rllib/examples/serving/cartpole_server.py>`__:
|
||||||
Example of online serving of predictions for a simple CartPole policy.
|
Example of online serving of predictions for a simple CartPole policy.
|
||||||
- `Saving experiences <https://github.com/ray-project/ray/blob/master/rllib/examples/saving_experiences.py>`__:
|
- `Saving experiences <https://github.com/ray-project/ray/blob/master/rllib/examples/saving_experiences.py>`__:
|
||||||
Example of how to externally generate experience batches in RLlib-compatible format.
|
Example of how to externally generate experience batches in RLlib-compatible format.
|
||||||
|
|
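Tying the Unity3D examples above together, a rough sketch of the local setup (the exact arguments of `unity3d_env_local.py` may differ; "3DBall" is one of the supported MLAgents example games):

.. code-block:: python

    from ray import tune
    from ray.rllib.env.unity3d_env import Unity3DEnv
    from ray.tune.registry import register_env

    # Connect to a locally running Unity editor (press Play when prompted).
    register_env(
        "unity3d_local",
        lambda cfg: Unity3DEnv(file_name=None, episode_horizon=1000))

    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game("3DBall")

    tune.run(
        "PPO",
        config={
            "env": "unity3d_local",
            "num_workers": 0,  # editor-based env -> local worker only
            "multiagent": {
                "policies": policies,
                "policy_mapping_fn": policy_mapping_fn,
            },
        },
        stop={"training_iteration": 10})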
41
rllib/BUILD
41
rllib/BUILD
|
@ -983,6 +983,29 @@ py_test(
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# Env tests
|
||||||
|
# rllib/env/
|
||||||
|
#
|
||||||
|
# Tag: env
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
|
sh_test(
|
||||||
|
name = "env/tests/test_local_inference",
|
||||||
|
tags = ["env"],
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["env/tests/test_local_inference.sh"],
|
||||||
|
data = glob(["examples/serving/*.py"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
sh_test(
|
||||||
|
name = "env/tests/test_remote_inference",
|
||||||
|
tags = ["env"],
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["env/tests/test_remote_inference.sh"],
|
||||||
|
data = glob(["examples/serving/*.py"]),
|
||||||
|
)
|
||||||
|
|
||||||
# --------------------------------------------------------------------
|
# --------------------------------------------------------------------
|
||||||
# Models and Distributions
|
# Models and Distributions
|
||||||
# rllib/models/
|
# rllib/models/
|
||||||
|
@ -1692,7 +1715,7 @@ py_test(
|
||||||
name = "examples/multi_agent_cartpole_torch",
|
name = "examples/multi_agent_cartpole_torch",
|
||||||
main = "examples/multi_agent_cartpole.py",
|
main = "examples/multi_agent_cartpole.py",
|
||||||
tags = ["examples", "examples_M"],
|
tags = ["examples", "examples_M"],
|
||||||
size = "small",
|
size = "medium",
|
||||||
srcs = ["examples/multi_agent_cartpole.py"],
|
srcs = ["examples/multi_agent_cartpole.py"],
|
||||||
args = ["--as-test", "--torch", "--stop-reward=70.0", "--num-cpus=4"]
|
args = ["--as-test", "--torch", "--stop-reward=70.0", "--num-cpus=4"]
|
||||||
)
|
)
|
||||||
|
@ -1822,22 +1845,6 @@ py_test(
|
||||||
args = ["--as-test", "--torch"],
|
args = ["--as-test", "--torch"],
|
||||||
)
|
)
|
||||||
|
|
||||||
sh_test(
|
|
||||||
name = "examples/serving/test_local_inference",
|
|
||||||
tags = ["examples", "examples_S", "exclusive"],
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["examples/serving/test_local_inference.sh"],
|
|
||||||
data = glob(["examples/serving/*.py"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
sh_test(
|
|
||||||
name = "examples/serving/test_remote_inference",
|
|
||||||
tags = ["examples", "examples_S", "exclusive"],
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["examples/serving/test_remote_inference.sh"],
|
|
||||||
data = glob(["examples/serving/*.py"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
py_test(
|
py_test(
|
||||||
name = "examples/two_trainer_workflow_tf",
|
name = "examples/two_trainer_workflow_tf",
|
||||||
main = "examples/two_trainer_workflow.py",
|
main = "examples/two_trainer_workflow.py",
|
||||||
|
|
|
@ -26,7 +26,7 @@ class TestDQN(unittest.TestCase):
|
||||||
num_iterations = 1
|
num_iterations = 1
|
||||||
|
|
||||||
for fw in framework_iterator(config):
|
for fw in framework_iterator(config):
|
||||||
# double-dueling DQN.
|
# Double-dueling DQN.
|
||||||
plain_config = config.copy()
|
plain_config = config.copy()
|
||||||
trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
|
trainer = dqn.DQNTrainer(config=plain_config, env="CartPole-v0")
|
||||||
for i in range(num_iterations):
|
for i in range(num_iterations):
|
||||||
|
|
|
@ -68,7 +68,8 @@ class OnlineLinearRegression(nn.Module):
|
||||||
return batch_dots.sqrt()
|
return batch_dots.sqrt()
|
||||||
|
|
||||||
def forward(self, x, sample_theta=False):
|
def forward(self, x, sample_theta=False):
|
||||||
""" Predict the scores on input batch using the underlying linear model
|
""" Predict scores on input batch using the underlying linear model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input feature tensor of shape
|
x (torch.Tensor): Input feature tensor of shape
|
||||||
(batch_size, feature_dim)
|
(batch_size, feature_dim)
|
||||||
|
|
29
rllib/env/base_env.py
vendored
29
rllib/env/base_env.py
vendored
|
@ -1,7 +1,7 @@
|
||||||
from ray.rllib.env.external_env import ExternalEnv
|
from ray.rllib.env.external_env import ExternalEnv
|
||||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||||
from ray.rllib.env.vector_env import VectorEnv
|
|
||||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||||
|
from ray.rllib.env.vector_env import VectorEnv
|
||||||
from ray.rllib.utils.annotations import override, PublicAPI
|
from ray.rllib.utils.annotations import override, PublicAPI
|
||||||
|
|
||||||
ASYNC_RESET_RETURN = "async_reset_return"
|
ASYNC_RESET_RETURN = "async_reset_return"
|
||||||
|
@ -99,16 +99,13 @@ class BaseEnv:
|
||||||
make_env=make_env,
|
make_env=make_env,
|
||||||
existing_envs=[env],
|
existing_envs=[env],
|
||||||
num_envs=num_envs)
|
num_envs=num_envs)
|
||||||
elif isinstance(env, ExternalMultiAgentEnv):
|
|
||||||
if num_envs != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"ExternalMultiAgentEnv does not currently support "
|
|
||||||
"num_envs > 1.")
|
|
||||||
env = _ExternalEnvToBaseEnv(env, multiagent=True)
|
|
||||||
elif isinstance(env, ExternalEnv):
|
elif isinstance(env, ExternalEnv):
|
||||||
if num_envs != 1:
|
if num_envs != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ExternalEnv does not currently support num_envs > 1.")
|
"External(MultiAgent)Env does not currently support "
|
||||||
|
"num_envs > 1. One way of solving this would be to "
|
||||||
|
"treat your Env as a MultiAgentEnv hosting only one "
|
||||||
|
"type of agent but with several copies.")
|
||||||
env = _ExternalEnvToBaseEnv(env)
|
env = _ExternalEnvToBaseEnv(env)
|
||||||
elif isinstance(env, VectorEnv):
|
elif isinstance(env, VectorEnv):
|
||||||
env = _VectorEnvToBaseEnv(env)
|
env = _VectorEnvToBaseEnv(env)
|
||||||
|
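# Hedged sketch (not part of this file) of the workaround suggested in the
# error message above; `make_single_env` is an assumed factory returning a
# gym.Env:
#
#     from ray.rllib.env.multi_agent_env import MultiAgentEnv
#
#     class MultiCopyEnv(MultiAgentEnv):
#         def __init__(self, make_single_env, num_copies=4):
#             self.envs = [make_single_env() for _ in range(num_copies)]
#
#         def reset(self):
#             return {i: e.reset() for i, e in enumerate(self.envs)}
#
#         def step(self, action_dict):
#             obs, rew, done, info = {}, {}, {}, {}
#             for i, action in action_dict.items():
#                 obs[i], rew[i], done[i], info[i] = self.envs[i].step(action)
#             done["__all__"] = all(done.values())
#             return obs, rew, done, info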
@ -166,12 +163,16 @@ class BaseEnv:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def try_reset(self, env_id):
|
def try_reset(self, env_id=None):
|
||||||
"""Attempt to reset the env with the given id.
|
"""Attempt to reset the sub-env with the given id or all sub-envs.
|
||||||
|
|
||||||
If the environment does not support synchronous reset, None can be
|
If the environment does not support synchronous reset, None can be
|
||||||
returned here.
|
returned here.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_id (Optional[int]): The sub-env ID if applicable. If None,
|
||||||
|
reset the entire Env (i.e. all sub-envs).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
obs (dict|None): Reset observation, or None if not supported.
|
obs (dict|None): Reset observation, or None if not supported.
|
||||||
"""
|
"""
|
||||||
|
@ -206,10 +207,10 @@ def _with_dummy_agent_id(env_id_to_values, dummy_id=_DUMMY_AGENT_ID):
|
||||||
class _ExternalEnvToBaseEnv(BaseEnv):
|
class _ExternalEnvToBaseEnv(BaseEnv):
|
||||||
"""Internal adapter of ExternalEnv to BaseEnv."""
|
"""Internal adapter of ExternalEnv to BaseEnv."""
|
||||||
|
|
||||||
def __init__(self, external_env, preprocessor=None, multiagent=False):
|
def __init__(self, external_env, preprocessor=None):
|
||||||
self.external_env = external_env
|
self.external_env = external_env
|
||||||
self.prep = preprocessor
|
self.prep = preprocessor
|
||||||
self.multiagent = multiagent
|
self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
|
||||||
self.action_space = external_env.action_space
|
self.action_space = external_env.action_space
|
||||||
if preprocessor:
|
if preprocessor:
|
||||||
self.observation_space = preprocessor.observation_space
|
self.observation_space = preprocessor.observation_space
|
||||||
|
@ -262,8 +263,8 @@ class _ExternalEnvToBaseEnv(BaseEnv):
|
||||||
if "off_policy_action" in data:
|
if "off_policy_action" in data:
|
||||||
off_policy_actions[eid] = data["off_policy_action"]
|
off_policy_actions[eid] = data["off_policy_action"]
|
||||||
if self.multiagent:
|
if self.multiagent:
|
||||||
# ensure a consistent set of keys
|
# Ensure a consistent set of keys
|
||||||
# rely on all_obs having all possible keys for now
|
# rely on all_obs having all possible keys for now.
|
||||||
for eid, eid_dict in all_obs.items():
|
for eid, eid_dict in all_obs.items():
|
||||||
for agent_id in eid_dict.keys():
|
for agent_id in eid_dict.keys():
|
||||||
|
|
||||||
|
|
24
rllib/env/external_env.py
vendored
24
rllib/env/external_env.py
vendored
|
@ -32,16 +32,14 @@ class ExternalEnv(threading.Thread):
|
||||||
>>> register_env("my_env", lambda config: YourExternalEnv(config))
|
>>> register_env("my_env", lambda config: YourExternalEnv(config))
|
||||||
>>> trainer = DQNTrainer(env="my_env")
|
>>> trainer = DQNTrainer(env="my_env")
|
||||||
>>> while True:
|
>>> while True:
|
||||||
print(trainer.train())
|
>>> print(trainer.train())
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def __init__(self, action_space, observation_space, max_concurrent=100):
|
def __init__(self, action_space, observation_space, max_concurrent=100):
|
||||||
"""Initialize an external env.
|
"""Initializes an external env.
|
||||||
|
|
||||||
ExternalEnv subclasses must call this during their __init__.
|
Args:
|
||||||
|
|
||||||
Arguments:
|
|
||||||
action_space (gym.Space): Action space of the env.
|
action_space (gym.Space): Action space of the env.
|
||||||
observation_space (gym.Space): Observation space of the env.
|
observation_space (gym.Space): Observation space of the env.
|
||||||
max_concurrent (int): Max number of active episodes to allow at
|
max_concurrent (int): Max number of active episodes to allow at
|
||||||
|
@ -49,6 +47,7 @@ class ExternalEnv(threading.Thread):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
|
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self.observation_space = observation_space
|
self.observation_space = observation_space
|
||||||
|
@ -78,9 +77,9 @@ class ExternalEnv(threading.Thread):
|
||||||
def start_episode(self, episode_id=None, training_enabled=True):
|
def start_episode(self, episode_id=None, training_enabled=True):
|
||||||
"""Record the start of an episode.
|
"""Record the start of an episode.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Unique string id for the episode or None for
|
episode_id (Optional[str]): Unique string id for the episode or
|
||||||
it to be auto-assigned.
|
None for it to be auto-assigned and returned.
|
||||||
training_enabled (bool): Whether to use experiences for this
|
training_enabled (bool): Whether to use experiences for this
|
||||||
episode to improve the policy.
|
episode to improve the policy.
|
||||||
|
|
||||||
|
@ -108,7 +107,7 @@ class ExternalEnv(threading.Thread):
|
||||||
def get_action(self, episode_id, observation):
|
def get_action(self, episode_id, observation):
|
||||||
"""Record an observation and get the on-policy action.
|
"""Record an observation and get the on-policy action.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Episode id returned from start_episode().
|
episode_id (str): Episode id returned from start_episode().
|
||||||
observation (obj): Current environment observation.
|
observation (obj): Current environment observation.
|
||||||
|
|
||||||
|
@ -123,7 +122,7 @@ class ExternalEnv(threading.Thread):
|
||||||
def log_action(self, episode_id, observation, action):
|
def log_action(self, episode_id, observation, action):
|
||||||
"""Record an observation and (off-policy) action taken.
|
"""Record an observation and (off-policy) action taken.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Episode id returned from start_episode().
|
episode_id (str): Episode id returned from start_episode().
|
||||||
observation (obj): Current environment observation.
|
observation (obj): Current environment observation.
|
||||||
action (obj): Action for the observation.
|
action (obj): Action for the observation.
|
||||||
|
@ -140,7 +139,7 @@ class ExternalEnv(threading.Thread):
|
||||||
episode. Rewards accumulate until the next action. If no reward is
|
episode. Rewards accumulate until the next action. If no reward is
|
||||||
logged before the next action, a reward of 0.0 is assumed.
|
logged before the next action, a reward of 0.0 is assumed.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Episode id returned from start_episode().
|
episode_id (str): Episode id returned from start_episode().
|
||||||
reward (float): Reward from the environment.
|
reward (float): Reward from the environment.
|
||||||
info (dict): Optional info dict.
|
info (dict): Optional info dict.
|
||||||
|
@ -156,7 +155,7 @@ class ExternalEnv(threading.Thread):
|
||||||
def end_episode(self, episode_id, observation):
|
def end_episode(self, episode_id, observation):
|
||||||
"""Record the end of an episode.
|
"""Record the end of an episode.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Episode id returned from start_episode().
|
episode_id (str): Episode id returned from start_episode().
|
||||||
observation (obj): Current environment observation.
|
observation (obj): Current environment observation.
|
||||||
"""
|
"""
|
||||||
|
@ -267,6 +266,7 @@ class _ExternalEnvEpisode:
|
||||||
self.cur_reward = 0.0
|
self.cur_reward = 0.0
|
||||||
if not self.training_enabled:
|
if not self.training_enabled:
|
||||||
item["info"]["training_enabled"] = False
|
item["info"]["training_enabled"] = False
|
||||||
|
|
||||||
with self.results_avail_condition:
|
with self.results_avail_condition:
|
||||||
self.data_queue.put_nowait(item)
|
self.data_queue.put_nowait(item)
|
||||||
self.results_avail_condition.notify()
|
self.results_avail_condition.notify()
|
||||||
|
|
7
rllib/env/external_multi_agent_env.py
vendored
7
rllib/env/external_multi_agent_env.py
vendored
|
@ -14,7 +14,7 @@ class ExternalMultiAgentEnv(ExternalEnv):
|
||||||
|
|
||||||
ExternalMultiAgentEnv subclasses must call this during their __init__.
|
ExternalMultiAgentEnv subclasses must call this during their __init__.
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
action_space (gym.Space): Action space of the env.
|
action_space (gym.Space): Action space of the env.
|
||||||
observation_space (gym.Space): Observation space of the env.
|
observation_space (gym.Space): Observation space of the env.
|
||||||
max_concurrent (int): Max number of active episodes to allow at
|
max_concurrent (int): Max number of active episodes to allow at
|
||||||
|
@ -135,10 +135,7 @@ class ExternalMultiAgentEnv(ExternalEnv):
|
||||||
|
|
||||||
if multiagent_done_dict:
|
if multiagent_done_dict:
|
||||||
for agent, done in multiagent_done_dict.items():
|
for agent, done in multiagent_done_dict.items():
|
||||||
if agent in episode.cur_done_dict:
|
episode.cur_done_dict[agent] = done
|
||||||
episode.cur_done_dict[agent] = done
|
|
||||||
else:
|
|
||||||
episode.cur_done_dict[agent] = done
|
|
||||||
|
|
||||||
if info_dict:
|
if info_dict:
|
||||||
episode.cur_info_dict = info_dict or {}
|
episode.cur_info_dict = info_dict or {}
|
||||||
|
|
6
rllib/env/multi_agent_env.py
vendored
6
rllib/env/multi_agent_env.py
vendored
|
@ -21,9 +21,9 @@ class MultiAgentEnv:
|
||||||
"traffic_light_1": [0, 3, 5, 1],
|
"traffic_light_1": [0, 3, 5, 1],
|
||||||
}
|
}
|
||||||
>>> obs, rewards, dones, infos = env.step(
|
>>> obs, rewards, dones, infos = env.step(
|
||||||
action_dict={
|
... action_dict={
|
||||||
"car_0": 1, "car_1": 0, "traffic_light_1": 2,
|
... "car_0": 1, "car_1": 0, "traffic_light_1": 2,
|
||||||
})
|
... })
|
||||||
>>> print(rewards)
|
>>> print(rewards)
|
||||||
{
|
{
|
||||||
"car_0": 3,
|
"car_0": 3,
|
||||||
|
|
61
rllib/env/policy_client.py
vendored
61
rllib/env/policy_client.py
vendored
|
@ -11,6 +11,7 @@ import time
|
||||||
import ray.cloudpickle as pickle
|
import ray.cloudpickle as pickle
|
||||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||||
from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
|
from ray.rllib.env import ExternalEnv, MultiAgentEnv, ExternalMultiAgentEnv
|
||||||
|
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||||
from ray.rllib.utils.annotations import PublicAPI
|
from ray.rllib.utils.annotations import PublicAPI
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -54,6 +55,7 @@ class PolicyClient:
|
||||||
or None for manual control via client.
|
or None for manual control via client.
|
||||||
"""
|
"""
|
||||||
self.address = address
|
self.address = address
|
||||||
|
self.env = None
|
||||||
if inference_mode == "local":
|
if inference_mode == "local":
|
||||||
self.local = True
|
self.local = True
|
||||||
self._setup_local_rollout_worker(update_interval)
|
self._setup_local_rollout_worker(update_interval)
|
||||||
|
@ -65,11 +67,11 @@ class PolicyClient:
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def start_episode(self, episode_id=None, training_enabled=True):
|
def start_episode(self, episode_id=None, training_enabled=True):
|
||||||
"""Record the start of an episode.
|
"""Record the start of one or more episode(s).
|
||||||
|
|
||||||
Arguments:
|
Args:
|
||||||
episode_id (str): Unique string id for the episode or None for
|
episode_id (Optional[str]): Unique string id for the episode or
|
||||||
it to be auto-assigned.
|
None for it to be auto-assigned.
|
||||||
training_enabled (bool): Whether to use experiences for this
|
training_enabled (bool): Whether to use experiences for this
|
||||||
episode to improve the policy.
|
episode to improve the policy.
|
||||||
|
|
||||||
|
@ -101,13 +103,20 @@ class PolicyClient:
|
||||||
|
|
||||||
if self.local:
|
if self.local:
|
||||||
self._update_local_policy()
|
self._update_local_policy()
|
||||||
return self.env.get_action(episode_id, observation)
|
if isinstance(episode_id, (list, tuple)):
|
||||||
|
actions = {
|
||||||
return self._send({
|
eid: self.env.get_action(eid, observation[eid])
|
||||||
"command": PolicyClient.GET_ACTION,
|
for eid in episode_id
|
||||||
"observation": observation,
|
}
|
||||||
"episode_id": episode_id,
|
return actions
|
||||||
})["action"]
|
else:
|
||||||
|
return self.env.get_action(episode_id, observation)
|
||||||
|
else:
|
||||||
|
return self._send({
|
||||||
|
"command": PolicyClient.GET_ACTION,
|
||||||
|
"observation": observation,
|
||||||
|
"episode_id": episode_id,
|
||||||
|
})["action"]
|
||||||
|
|
||||||
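# Hedged usage note (sketch, local inference mode only): with the change
# above, several ongoing episodes can be queried in a single call, e.g.:
#
#     actions = client.get_action(
#         [eid_1, eid_2], {eid_1: obs_1, eid_2: obs_2})
#     # -> {eid_1: action_1, eid_2: action_2}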
@PublicAPI
|
@PublicAPI
|
||||||
def log_action(self, episode_id, observation, action):
|
def log_action(self, episode_id, observation, action):
|
||||||
|
@ -151,11 +160,11 @@ class PolicyClient:
|
||||||
|
|
||||||
if self.local:
|
if self.local:
|
||||||
self._update_local_policy()
|
self._update_local_policy()
|
||||||
if multiagent_done_dict:
|
if multiagent_done_dict is not None:
|
||||||
|
assert isinstance(reward, dict)
|
||||||
return self.env.log_returns(episode_id, reward, info,
|
return self.env.log_returns(episode_id, reward, info,
|
||||||
multiagent_done_dict)
|
multiagent_done_dict)
|
||||||
else:
|
return self.env.log_returns(episode_id, reward, info)
|
||||||
return self.env.log_returns(episode_id, reward, info)
|
|
||||||
|
|
||||||
self._send({
|
self._send({
|
||||||
"command": PolicyClient.LOG_RETURNS,
|
"command": PolicyClient.LOG_RETURNS,
|
||||||
|
@ -207,7 +216,6 @@ class PolicyClient:
|
||||||
kwargs = self._send({
|
kwargs = self._send({
|
||||||
"command": PolicyClient.GET_WORKER_ARGS,
|
"command": PolicyClient.GET_WORKER_ARGS,
|
||||||
})["worker_args"]
|
})["worker_args"]
|
||||||
|
|
||||||
(self.rollout_worker,
|
(self.rollout_worker,
|
||||||
self.inference_thread) = create_embedded_rollout_worker(
|
self.inference_thread) = create_embedded_rollout_worker(
|
||||||
kwargs, self._send)
|
kwargs, self._send)
|
||||||
|
@ -245,8 +253,14 @@ class _LocalInferenceThread(threading.Thread):
|
||||||
logger.info("Generating new batch of experiences.")
|
logger.info("Generating new batch of experiences.")
|
||||||
samples = self.rollout_worker.sample()
|
samples = self.rollout_worker.sample()
|
||||||
metrics = self.rollout_worker.get_metrics()
|
metrics = self.rollout_worker.get_metrics()
|
||||||
logger.info("Sending batch of {} steps back to server.".format(
|
if isinstance(samples, MultiAgentBatch):
|
||||||
samples.count))
|
logger.info(
|
||||||
|
"Sending batch of {} env steps ({} agent steps) to "
|
||||||
|
"server.".format(samples.count, samples.total()))
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Sending batch of {} steps back to server.".format(
|
||||||
|
samples.count))
|
||||||
self.send_fn({
|
self.send_fn({
|
||||||
"command": PolicyClient.REPORT_SAMPLES,
|
"command": PolicyClient.REPORT_SAMPLES,
|
||||||
"samples": samples,
|
"samples": samples,
|
||||||
|
@ -265,11 +279,11 @@ def auto_wrap_external(real_env_creator):
|
||||||
|
|
||||||
def wrapped_creator(env_config):
|
def wrapped_creator(env_config):
|
||||||
real_env = real_env_creator(env_config)
|
real_env = real_env_creator(env_config)
|
||||||
if not (isinstance(real_env, ExternalEnv)
|
if not isinstance(real_env, (ExternalEnv, ExternalMultiAgentEnv)):
|
||||||
or isinstance(real_env, ExternalMultiAgentEnv)):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"The env you specified is not a type of ExternalEnv. "
|
"The env you specified is not a supported (sub-)type of "
|
||||||
"Attempting to convert it automatically to ExternalEnv.")
|
"ExternalEnv. Attempting to convert it automatically to "
|
||||||
|
"ExternalEnv.")
|
||||||
|
|
||||||
if isinstance(real_env, MultiAgentEnv):
|
if isinstance(real_env, MultiAgentEnv):
|
||||||
external_cls = ExternalMultiAgentEnv
|
external_cls = ExternalMultiAgentEnv
|
||||||
|
@ -278,8 +292,9 @@ def auto_wrap_external(real_env_creator):
|
||||||
|
|
||||||
class ExternalEnvWrapper(external_cls):
|
class ExternalEnvWrapper(external_cls):
|
||||||
def __init__(self, real_env):
|
def __init__(self, real_env):
|
||||||
super().__init__(real_env.action_space,
|
super().__init__(
|
||||||
real_env.observation_space)
|
observation_space=real_env.observation_space,
|
||||||
|
action_space=real_env.action_space)
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# Since we are calling methods on this class in the
|
# Since we are calling methods on this class in the
|
||||||
|
|
|
@ -4,8 +4,8 @@ rm -f last_checkpoint.out
|
||||||
pkill -f cartpole_server.py
|
pkill -f cartpole_server.py
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
if [ -f cartpole_server.py ]; then
|
if [ -f test_local_inference.sh ]; then
|
||||||
basedir="."
|
basedir="../../examples/serving"
|
||||||
else
|
else
|
||||||
basedir="rllib/examples/serving" # In bazel.
|
basedir="rllib/examples/serving" # In bazel.
|
||||||
fi
|
fi
|
||||||
|
@ -19,5 +19,5 @@ while ! curl localhost:9900; do
|
||||||
done
|
done
|
||||||
|
|
||||||
sleep 2
|
sleep 2
|
||||||
python $basedir/cartpole_client.py --stop-at-reward=100 --inference-mode=local
|
python $basedir/cartpole_client.py --stop-reward=150 --inference-mode=local
|
||||||
kill $pid
|
kill $pid
|
|
@ -4,8 +4,8 @@ rm -f last_checkpoint.out
|
||||||
pkill -f cartpole_server.py
|
pkill -f cartpole_server.py
|
||||||
sleep 1
|
sleep 1
|
||||||
|
|
||||||
if [ -f cartpole_server.py ]; then
|
if [ -f test_local_inference.sh ]; then
|
||||||
basedir="."
|
basedir="../../examples/serving"
|
||||||
else
|
else
|
||||||
basedir="rllib/examples/serving" # In bazel.
|
basedir="rllib/examples/serving" # In bazel.
|
||||||
fi
|
fi
|
||||||
|
@ -19,6 +19,6 @@ while ! curl localhost:9900; do
|
||||||
done
|
done
|
||||||
|
|
||||||
sleep 2
|
sleep 2
|
||||||
python $basedir/cartpole_client.py --stop-at-reward=100 --inference-mode=remote
|
python $basedir/cartpole_client.py --stop-reward=150 --inference-mode=remote
|
||||||
kill $pid
|
kill $pid
|
||||||
|
|
232
rllib/env/unity3d_env.py
vendored
Normal file
232
rllib/env/unity3d_env.py
vendored
Normal file
|
@ -0,0 +1,232 @@
|
||||||
|
from gym.spaces import Box, MultiDiscrete, Tuple
|
||||||
|
import logging
|
||||||
|
import mlagents_envs
|
||||||
|
from mlagents_envs.environment import UnityEnvironment
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||||
|
from ray.rllib.utils.annotations import override
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Unity3DEnv(MultiAgentEnv):
|
||||||
|
"""A MultiAgentEnv representing a single Unity3D game instance.
|
||||||
|
|
||||||
|
For an example on how to use this class inside a Unity game client, which
|
||||||
|
connects to an RLlib Policy server, see:
|
||||||
|
`rllib/examples/serving/unity3d_[client|server].py`
|
||||||
|
|
||||||
|
Supports all Unity3D (MLAgents) examples, multi- or single-agent and
|
||||||
|
gets converted automatically into an ExternalMultiAgentEnv, when used
|
||||||
|
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
file_name=None,
|
||||||
|
worker_id=0,
|
||||||
|
base_port=5004,
|
||||||
|
seed=0,
|
||||||
|
no_graphics=False,
|
||||||
|
timeout_wait=60,
|
||||||
|
episode_horizon=1000):
|
||||||
|
"""Initializes a Unity3DEnv object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_name (Optional[str]): Name of the Unity game binary.
|
||||||
|
If None, will assume a locally running Unity3D editor
|
||||||
|
to be used, instead.
|
||||||
|
worker_id (int): Number to add to `base_port`. Used when more than
|
||||||
|
one Unity3DEnv (games) are running on the same machine. This
|
||||||
|
will be determined automatically, if possible, so a value
|
||||||
|
of 0 should always suffice here.
|
||||||
|
base_port (int): Port number to connect to Unity environment.
|
||||||
|
`worker_id` increments on top of this.
|
||||||
|
seed (int): A random seed value to use for the Unity3D game.
|
||||||
|
no_graphics (bool): Whether to run the Unity3D simulator in
|
||||||
|
no-graphics mode. Default: False.
|
||||||
|
timeout_wait (int): Time (in seconds) to wait for connection from
|
||||||
|
the Unity3D instance.
|
||||||
|
episode_horizon (int): A hard horizon to abide to. After at most
|
||||||
|
this many steps (per-agent episode `step()` calls), the
|
||||||
|
Unity3D game is reset and will start again (finishing the
|
||||||
|
multi-agent episode that the game represents).
|
||||||
|
Note: The game itself may contain its own episode length
|
||||||
|
limits, which are always obeyed (on top of this value here).
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if file_name is None:
|
||||||
|
print(
|
||||||
|
"No game binary provided, will use a running Unity editor "
|
||||||
|
"instead.\nMake sure you are pressing the Play (|>) button in "
|
||||||
|
"your editor to start.")
|
||||||
|
|
||||||
|
# Try connecting to the Unity3D game instance. If the port
# (base_port + worker_id) is already in use, try the next worker_id.
|
||||||
|
while True:
|
||||||
|
self.worker_id = worker_id
|
||||||
|
try:
|
||||||
|
self.unity_env = UnityEnvironment(
|
||||||
|
file_name=file_name,
|
||||||
|
worker_id=worker_id,
|
||||||
|
base_port=base_port,
|
||||||
|
seed=seed,
|
||||||
|
no_graphics=no_graphics,
|
||||||
|
timeout_wait=timeout_wait,
|
||||||
|
)
|
||||||
|
except mlagents_envs.exception.UnityWorkerInUseException as e:
|
||||||
|
worker_id += 1
|
||||||
|
# Hard limit.
|
||||||
|
if worker_id > 100:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Reset entire env every this number of step calls.
|
||||||
|
self.episode_horizon = episode_horizon
|
||||||
|
# Keep track of how many times we have called `step` so far.
|
||||||
|
self.episode_timesteps = 0
|
||||||
|
|
||||||
|
@override(MultiAgentEnv)
|
||||||
|
def step(self, action_dict):
|
||||||
|
"""Performs one multi-agent step through the game.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_dict (dict): Multi-agent action dict with:
|
||||||
|
keys=agent identifier consisting of
|
||||||
|
[MLagents behavior name, e.g. "Goalie?team=1"] + "_" +
|
||||||
|
[Agent index, a unique MLAgent-assigned index per single
|
||||||
|
agent]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple:
|
||||||
|
obs: Multi-agent observation dict.
|
||||||
|
Only those observations for which to get new actions are
|
||||||
|
returned.
|
||||||
|
rewards: Rewards dict matching `obs`.
|
||||||
|
dones: Done dict with only an __all__ multi-agent entry in it.
|
||||||
|
__all__=True, if episode is done for all agents.
|
||||||
|
infos: An (empty) info dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Set only the required actions (from the DecisionSteps) in Unity3D.
|
||||||
|
all_agents = []
|
||||||
|
for behavior_name in self.unity_env.get_behavior_names():
|
||||||
|
for agent_id in self.unity_env.get_steps(behavior_name)[
|
||||||
|
0].agent_id_to_index.keys():
|
||||||
|
key = behavior_name + "_{}".format(agent_id)
|
||||||
|
all_agents.append(key)
|
||||||
|
self.unity_env.set_action_for_agent(behavior_name, agent_id,
|
||||||
|
action_dict[key])
|
||||||
|
# Do the step.
|
||||||
|
self.unity_env.step()
|
||||||
|
|
||||||
|
obs, rewards, dones, infos = self._get_step_results()
|
||||||
|
|
||||||
|
# Global horizon reached? -> Return __all__ done=True, so user
|
||||||
|
# can reset. Set all agents' individual `done` to True as well.
|
||||||
|
self.episode_timesteps += 1
|
||||||
|
if self.episode_timesteps > self.episode_horizon:
|
||||||
|
return obs, rewards, dict({
|
||||||
|
"__all__": True
|
||||||
|
}, **{agent_id: True
|
||||||
|
for agent_id in all_agents}), infos
|
||||||
|
|
||||||
|
return obs, rewards, dones, infos
|
||||||
|
|
||||||
|
@override(MultiAgentEnv)
|
||||||
|
def reset(self):
|
||||||
|
"""Resets the entire Unity3D scene (a single multi-agent episode)."""
|
||||||
|
self.episode_timesteps = 0
|
||||||
|
self.unity_env.reset()
|
||||||
|
obs, _, _, _ = self._get_step_results()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def _get_step_results(self):
|
||||||
|
"""Collects those agents' obs/rewards that have to act in next `step`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple:
|
||||||
|
obs: Multi-agent observation dict.
|
||||||
|
Only those observations for which to get new actions are
|
||||||
|
returned.
|
||||||
|
rewards: Rewards dict matching `obs`.
|
||||||
|
dones: Done dict with only an __all__ multi-agent entry in it.
|
||||||
|
__all__=True, if episode is done for all agents.
|
||||||
|
infos: An (empty) info dict.
|
||||||
|
"""
|
||||||
|
obs = {}
|
||||||
|
rewards = {}
|
||||||
|
infos = {}
|
||||||
|
for behavior_name in self.unity_env.get_behavior_names():
|
||||||
|
decision_steps, terminal_steps = self.unity_env.get_steps(
|
||||||
|
behavior_name)
|
||||||
|
# Important: Only update those sub-envs that are currently
|
||||||
|
# available within _env_state.
|
||||||
|
# Loop through all envs ("agents") and fill in, whatever
|
||||||
|
# information we have.
|
||||||
|
for agent_id, idx in decision_steps.agent_id_to_index.items():
|
||||||
|
key = behavior_name + "_{}".format(agent_id)
|
||||||
|
os = tuple(o[idx] for o in decision_steps.obs)
|
||||||
|
os = os[0] if len(os) == 1 else os
|
||||||
|
obs[key] = os
|
||||||
|
rewards[key] = decision_steps.reward[idx] # rewards vector
|
||||||
|
for agent_id, idx in terminal_steps.agent_id_to_index.items():
|
||||||
|
key = behavior_name + "_{}".format(agent_id)
|
||||||
|
# Only overwrite rewards (last reward in episode), b/c obs
|
||||||
|
# here is the last obs (which doesn't matter anyways).
|
||||||
|
# Unless key does not exist in obs.
|
||||||
|
if key not in obs:
|
||||||
|
os = tuple(o[idx] for o in terminal_steps.obs)
|
||||||
|
obs[key] = os = os[0] if len(os) == 1 else os
|
||||||
|
rewards[key] = terminal_steps.reward[idx] # rewards vector
|
||||||
|
|
||||||
|
# Only use dones if all agents are done, then we should do a reset.
|
||||||
|
return obs, rewards, {"__all__": False}, infos
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_policy_configs_for_game(game_name):
|
||||||
|
|
||||||
|
# The RLlib server must know about the Spaces that the Client will be
|
||||||
|
# using inside Unity3D, up-front.
|
||||||
|
obs_spaces = {
|
||||||
|
# SoccerStrikersVsGoalie.
|
||||||
|
"Striker": Tuple([
|
||||||
|
Box(float("-inf"), float("inf"), (231, )),
|
||||||
|
Box(float("-inf"), float("inf"), (63, )),
|
||||||
|
]),
|
||||||
|
"Goalie": Box(float("-inf"), float("inf"), (738, )),
|
||||||
|
# 3DBall.
|
||||||
|
"Agent": Box(float("-inf"), float("inf"), (8, )),
|
||||||
|
}
|
||||||
|
action_spaces = {
|
||||||
|
# SoccerStrikersVsGoalie.
|
||||||
|
"Striker": MultiDiscrete([3, 3, 3]),
|
||||||
|
"Goalie": MultiDiscrete([3, 3, 3]),
|
||||||
|
# 3DBall.
|
||||||
|
"Agent": Box(float("-inf"), float("inf"), (2, ), dtype=np.float32),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Policies (Unity: "behaviors") and agent-to-policy mapping fns.
|
||||||
|
if game_name == "SoccerStrikersVsGoalie":
|
||||||
|
policies = {
|
||||||
|
"Striker": (None, obs_spaces["Striker"],
|
||||||
|
action_spaces["Striker"], {}),
|
||||||
|
"Goalie": (None, obs_spaces["Goalie"], action_spaces["Goalie"],
|
||||||
|
{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
def policy_mapping_fn(agent_id):
|
||||||
|
return "Striker" if "Striker" in agent_id else "Goalie"
|
||||||
|
|
||||||
|
else: # 3DBall
|
||||||
|
policies = {
|
||||||
|
"Agent": (None, obs_spaces["Agent"], action_spaces["Agent"],
|
||||||
|
{})
|
||||||
|
}
|
||||||
|
|
||||||
|
def policy_mapping_fn(agent_id):
|
||||||
|
return "Agent"
|
||||||
|
|
||||||
|
return policies, policy_mapping_fn
|
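# Hedged usage sketch (not part of this file): a minimal smoke test for the
# class above, assuming a Unity editor running the 3DBall scene locally:
#
#     policies, mapping_fn = Unity3DEnv.get_policy_configs_for_game("3DBall")
#     env = Unity3DEnv(file_name=None, episode_horizon=100)
#     obs = env.reset()
#     # Sample a random action for every currently acting agent from its
#     # policy's action space (tuple index 2 = action space, see above).
#     actions = {
#         agent_id: policies[mapping_fn(agent_id)][2].sample()
#         for agent_id in obs.keys()
#     }
#     obs, rewards, dones, infos = env.step(actions)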
104
rllib/env/vector_env.py
vendored
104
rllib/env/vector_env.py
vendored
|
@ -8,31 +8,43 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
class VectorEnv:
|
class VectorEnv:
|
||||||
"""An environment that supports batch evaluation.
|
"""An environment that supports batch evaluation using clones of sub-envs.
|
||||||
|
|
||||||
Subclasses must define the following attributes:
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
action_space (gym.Space): Action space of individual envs.
|
|
||||||
observation_space (gym.Space): Observation space of individual envs.
|
|
||||||
num_envs (int): Number of envs in this vector env.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, observation_space, action_space, num_envs):
|
||||||
|
"""Initializes a VectorEnv object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
observation_space (Space): The observation Space of a single
|
||||||
|
sub-env.
|
||||||
|
action_space (Space): The action Space of a single sub-env.
|
||||||
|
num_envs (int): The number of clones to make of the given sub-env.
|
||||||
|
"""
|
||||||
|
self.observation_space = observation_space
|
||||||
|
self.action_space = action_space
|
||||||
|
self.num_envs = num_envs
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def wrap(make_env=None,
|
def wrap(make_env=None,
|
||||||
existing_envs=None,
|
existing_envs=None,
|
||||||
num_envs=1,
|
num_envs=1,
|
||||||
action_space=None,
|
action_space=None,
|
||||||
observation_space=None):
|
observation_space=None,
|
||||||
return _VectorizedGymEnv(make_env, existing_envs or [], num_envs,
|
env_config=None):
|
||||||
action_space, observation_space)
|
return _VectorizedGymEnv(
|
||||||
|
make_env=make_env,
|
||||||
|
existing_envs=existing_envs or [],
|
||||||
|
num_envs=num_envs,
|
||||||
|
observation_space=observation_space,
|
||||||
|
action_space=action_space,
|
||||||
|
env_config=env_config)
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def vector_reset(self):
|
def vector_reset(self):
|
||||||
"""Resets all environments.
|
"""Resets all sub-environments.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
obs (list): Vector of observations from each environment.
|
obs (List[any]): List of observations from each environment.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -41,55 +53,73 @@ class VectorEnv:
|
||||||
"""Resets a single environment.
|
"""Resets a single environment.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
obs (obj): Observations from the resetted environment.
|
obs (obj): Observations from the reset sub environment.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def vector_step(self, actions):
|
def vector_step(self, actions):
|
||||||
"""Vectorized step.
|
"""Performs a vectorized step on all sub environments using `actions`.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
actions (list): Actions for each env.
|
actions (List[any]): List of actions (one for each sub-env).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
obs (list): New observations for each env.
|
obs (List[any]): New observations for each sub-env.
|
||||||
rewards (list): Reward values for each env.
|
rewards (List[any]): Reward values for each sub-env.
|
||||||
dones (list): Done values for each env.
|
dones (List[any]): Done values for each sub-env.
|
||||||
infos (list): Info values for each env.
|
infos (List[any]): Info values for each sub-env.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@PublicAPI
|
@PublicAPI
|
||||||
def get_unwrapped(self):
|
def get_unwrapped(self):
|
||||||
"""Returns the underlying env instances."""
|
"""Returns the underlying sub environments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Env]: List of all underlying sub environments.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class _VectorizedGymEnv(VectorEnv):
|
class _VectorizedGymEnv(VectorEnv):
|
||||||
"""Internal wrapper for gym envs to implement VectorEnv.
|
"""Internal wrapper to translate any gym envs into a VectorEnv object.
|
||||||
|
|
||||||
Arguments:
|
|
||||||
make_env (func|None): Factory that produces a new gym env. Must be
|
|
||||||
defined if the number of existing envs is less than num_envs.
|
|
||||||
existing_envs (list): List of existing gym envs.
|
|
||||||
num_envs (int): Desired num gym envs to keep total.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
make_env,
|
make_env=None,
|
||||||
existing_envs,
|
existing_envs=None,
|
||||||
num_envs,
|
num_envs=1,
|
||||||
|
*,
|
||||||
|
observation_space=None,
|
||||||
action_space=None,
|
action_space=None,
|
||||||
observation_space=None):
|
env_config=None):
|
||||||
|
"""Initializes a _VectorizedGymEnv object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
make_env (Optional[callable]): Factory that produces a new gym env
|
||||||
|
taking a single `config` dict arg. Must be defined if the
|
||||||
|
number of `existing_envs` is less than `num_envs`.
|
||||||
|
existing_envs (Optional[List[Env]]): Optional list of already
|
||||||
|
instantiated sub environments.
|
||||||
|
num_envs (int): Total number of sub environments in this VectorEnv.
|
||||||
|
action_space (Optional[Space]): The action space. If None, use
|
||||||
|
existing_envs[0]'s action space.
|
||||||
|
observation_space (Optional[Space]): The observation space.
|
||||||
|
If None, use existing_envs[0]'s observation space.
|
||||||
|
env_config (Optional[dict]): Additional sub env config to pass to
|
||||||
|
make_env as first arg.
|
||||||
|
"""
|
||||||
self.make_env = make_env
|
self.make_env = make_env
|
||||||
self.envs = existing_envs
|
self.envs = existing_envs
|
||||||
self.num_envs = num_envs
|
while len(self.envs) < num_envs:
|
||||||
while len(self.envs) < self.num_envs:
|
|
||||||
self.envs.append(self.make_env(len(self.envs)))
|
self.envs.append(self.make_env(len(self.envs)))
|
||||||
self.action_space = action_space or self.envs[0].action_space
|
|
||||||
self.observation_space = observation_space or \
|
super().__init__(
|
||||||
self.envs[0].observation_space
|
observation_space=observation_space
|
||||||
|
or self.envs[0].observation_space,
|
||||||
|
action_space=action_space or self.envs[0].action_space,
|
||||||
|
num_envs=num_envs)
|
||||||
|
|
||||||
@override(VectorEnv)
|
@override(VectorEnv)
|
||||||
def vector_reset(self):
|
def vector_reset(self):
|
||||||
|
|
|
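# Hedged usage sketch (not part of this file) for the new `wrap()` signature;
# note that `make_env` is called with the index of the sub-env to create:
#
#     import gym
#     vec_env = VectorEnv.wrap(
#         make_env=lambda idx: gym.make("CartPole-v0"), num_envs=4)
#     obs_batch = vec_env.vector_reset()  # -> list of 4 observations
#     actions = [vec_env.action_space.sample() for _ in range(4)]
#     obs, rewards, dones, infos = vec_env.vector_step(actions)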
@ -303,11 +303,11 @@ class RolloutWorker(ParallelIteratorWorker):
         self.fake_sampler = fake_sampler

         self.env = _validate_env(env_creator(env_context))
-        if isinstance(self.env, MultiAgentEnv) or \
-                isinstance(self.env, BaseEnv):
+        if isinstance(self.env, (BaseEnv, MultiAgentEnv)):

             def wrap(env):
                 return env  # we can't auto-wrap these env types

         elif is_atari(self.env) and \
                 not model_config.get("custom_preprocessor") and \
                 preprocessor_pref == "deepmind":

@ -411,7 +411,7 @@ class RolloutWorker(ParallelIteratorWorker):
         if self.worker_index == 0:
             logger.info("Built filter map: {}".format(self.filters))

-        # Always use vector env for consistency even if num_envs = 1
+        # Always use vector env for consistency even if num_envs = 1.
         self.async_env = BaseEnv.to_base_env(
             self.env,
             make_env=make_env,
@ -2,10 +2,10 @@ import collections
 import logging
 import numpy as np

-from ray.util.debug import log_once
 from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
 from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
 from ray.rllib.utils.debug import summarize
+from ray.util.debug import log_once

 logger = logging.getLogger(__name__)

@ -124,8 +124,9 @@ class MultiAgentSampleBatchBuilder:
         This pushes the postprocessed per-agent batches onto the per-policy
         builders, clearing per-agent state.

-        Arguments:
-            episode: current MultiAgentEpisode object or None
+        Args:
+            episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
+                object.
         """

         # Materialize the batches so far

@ -198,8 +199,9 @@ class MultiAgentSampleBatchBuilder:
         Any unprocessed rows will be first postprocessed with a policy
         postprocessor. The internal state of this builder will be reset.

-        Arguments:
-            episode: current MultiAgentEpisode object or None
+        Args:
+            episode (Optional[MultiAgentEpisode]): Current MultiAgentEpisode
+                object.
         """

         self.postprocess_batch_so_far(episode)
@ -341,7 +341,7 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
     while True:
         perf_stats.iters += 1
         t0 = time.time()
-        # Get observations from all ready agents
+        # Get observations from all ready agents.
         unfiltered_obs, rewards, dones, infos, off_policy_actions = \
             base_env.poll()
         perf_stats.env_wait_time += time.time() - t0

@ -351,7 +351,7 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
                 summarize(unfiltered_obs)))
             logger.info("Info return from env: {}".format(summarize(infos)))

-        # Process observations and prepare for policy evaluation
+        # Process observations and prepare for policy evaluation.
         t1 = time.time()
         active_envs, to_eval, outputs = _process_observations(
             worker, base_env, policies, batch_builder_pool, active_episodes,

@ -362,13 +362,13 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
         for o in outputs:
             yield o

-        # Do batched policy eval
+        # Do batched policy eval (across vectorized envs).
         t2 = time.time()
         eval_results = _do_policy_eval(tf_sess, to_eval, policies,
                                        active_episodes)
         perf_stats.inference_time += time.time() - t2

-        # Process results and update episode state
+        # Process results and update episode state.
         t3 = time.time()
         actions_to_send = _process_policy_eval_results(
             to_eval, eval_results, active_episodes, active_envs,
@ -401,11 +401,11 @@ def _process_observations(
     large_batch_threshold = max(1000, rollout_fragment_length * 10) if \
         rollout_fragment_length != float("inf") else 5000

-    # For each environment
+    # For each environment.
     for env_id, agent_obs in unfiltered_obs.items():
-        new_episode = env_id not in active_episodes
+        is_new_episode = env_id not in active_episodes
         episode = active_episodes[env_id]
-        if not new_episode:
+        if not is_new_episode:
             episode.length += 1
             episode.batch_builder.count += 1
             episode._add_agent_rewards(rewards[env_id])

@ -427,11 +427,11 @@ def _process_observations(
                 "to terminate (batch_mode=`complete_episodes`). Make sure it "
                 "does at some point.")

-        # Check episode termination conditions
+        # Check episode termination conditions.
         if dones[env_id]["__all__"] or episode.length >= horizon:
             hit_horizon = (episode.length >= horizon
                            and not dones[env_id]["__all__"])
-            all_done = True
+            all_agents_done = True
             atari_metrics = _fetch_atari_metrics(base_env)
             if atari_metrics is not None:
                 for m in atari_metrics:

@ -445,7 +445,7 @@ def _process_observations(
                         episode.hist_data))
         else:
             hit_horizon = False
-            all_done = False
+            all_agents_done = False
             active_envs.add(env_id)

         # Custom observation function is applied before preprocessing.

@ -473,7 +473,7 @@ def _process_observations(
             if log_once("filtered_obs"):
                 logger.info("Filtered obs: {}".format(summarize(filtered_obs)))

-            agent_done = bool(all_done or dones[env_id].get(agent_id))
+            agent_done = bool(all_agents_done or dones[env_id].get(agent_id))
             if not agent_done:
                 to_eval[policy_id].append(
                     PolicyEvalData(env_id, agent_id, filtered_obs,

@ -517,15 +517,15 @@ def _process_observations(
         if episode.batch_builder.has_pending_agent_data():
             if dones[env_id]["__all__"] and not no_done_at_end:
                 episode.batch_builder.check_missing_dones()
-            if (all_done and not pack) or \
+            if (all_agents_done and not pack) or \
                     episode.batch_builder.count >= rollout_fragment_length:
                 outputs.append(episode.batch_builder.build_and_reset(episode))
-            elif all_done:
+            elif all_agents_done:
                 # Make sure postprocessor stays within one episode
                 episode.batch_builder.postprocess_batch_so_far(episode)

-        if all_done:
-            # Handle episode termination
+        if all_agents_done:
+            # Handle episode termination.
             batch_builder_pool.append(episode.batch_builder)
             # Call each policy's Exploration.on_episode_end method.
             for p in policies.values():

@ -548,13 +548,13 @@ def _process_observations(
             del active_episodes[env_id]
             resetted_obs = base_env.try_reset(env_id)
             if resetted_obs is None:
-                # Reset not supported, drop this env from the ready list
+                # Reset not supported, drop this env from the ready list.
                 if horizon != float("inf"):
                     raise ValueError(
                         "Setting episode horizon requires reset() support "
                         "from the environment.")
             elif resetted_obs != ASYNC_RESET_RETURN:
-                # Creates a new episode if this is not async return
+                # Creates a new episode if this is not async return.
                 # If reset is async, we will get its result in some future poll
                 episode = active_episodes[env_id]
                 if observation_fn:
@ -623,7 +623,6 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
                 prev_reward_batch=prev_reward_batch,
                 timestep=policy.global_timestep)
         else:
-            # TODO(sven): Does this work for LSTM torch?
             rnn_in_cols = [
                 np.stack([row[i] for row in rnn_in])
                 for i in range(len(rnn_in[0]))
@ -28,7 +28,7 @@ if __name__ == "__main__":

     assert not args.torch, "PyTorch not supported for AttentionNets yet!"

-    ray.init(num_cpus=args.num_cpus or None, local_mode=True)
+    ray.init(num_cpus=args.num_cpus or None)

     registry.register_env("RepeatAfterMeEnv", lambda c: RepeatAfterMeEnv(c))
     registry.register_env("RepeatInitialObsEnv",
rllib/examples/env/multi_agent.py (12 lines changed)

@ -4,12 +4,16 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.tests.test_rollout_worker import MockEnv, MockEnv2


-def make_multiagent(env_name):
+def make_multiagent(env_name_or_creator):
     class MultiEnv(MultiAgentEnv):
         def __init__(self, config):
-            self.agents = [
-                gym.make(env_name) for _ in range(config["num_agents"])
-            ]
+            num = config.pop("num_agents", 1)
+            if isinstance(env_name_or_creator, str):
+                self.agents = [
+                    gym.make(env_name_or_creator) for _ in range(num)
+                ]
+            else:
+                self.agents = [env_name_or_creator(config) for _ in range(num)]
             self.dones = set()
             self.observation_space = self.agents[0].observation_space
             self.action_space = self.agents[0].action_space
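A quick, hedged sketch of how the generalized `make_multiagent` helper can now be used with either a gym env name or an env-creator callable. It assumes `CartPole-v0` is available via gym and that the factory returns an env class taking the config dict, as the `RandomMultiAgentEnv = make_multiagent(...)` line added in `random_env.py` (below) suggests; the class names on the left are local to this sketch.

    from ray.rllib.examples.env.multi_agent import make_multiagent
    from ray.rllib.examples.env.random_env import RandomEnv

    # From a registered gym env name.
    MultiAgentCartPole = make_multiagent("CartPole-v0")
    env1 = MultiAgentCartPole({"num_agents": 2})

    # From an arbitrary env-creator callable taking the config dict.
    MultiAgentRandom = make_multiagent(lambda config: RandomEnv(config))
    env2 = MultiAgentRandom({"num_agents": 3})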
rllib/examples/env/random_env.py (12 lines changed)

@ -1,7 +1,9 @@
 import gym
-from gym.spaces import Tuple
+from gym.spaces import Discrete, Tuple
 import numpy as np

+from ray.rllib.examples.env.multi_agent import make_multiagent


 class RandomEnv(gym.Env):
     """A randomly acting environment.

@ -14,9 +16,9 @@ class RandomEnv(gym.Env):

     def __init__(self, config):
         # Action space.
-        self.action_space = config["action_space"]
+        self.action_space = config.get("action_space", Discrete(2))
         # Observation space from which to sample.
-        self.observation_space = config["observation_space"]
+        self.observation_space = config.get("observation_space", Discrete(2))
         # Reward space from which to sample.
         self.reward_space = config.get(
             "reward_space",

@ -43,3 +45,7 @@ class RandomEnv(gym.Env):
             bool(np.random.choice(
                 [True, False], p=[self.p_done, 1.0 - self.p_done]
             )), {}


+# Multi-agent version of the RandomEnv.
+RandomMultiAgentEnv = make_multiagent(lambda c: RandomEnv(c))
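With the new defaults, `RandomEnv` no longer requires explicit spaces in its config, and the module now also exposes a multi-agent variant. A small sketch under the assumption that the remaining config keys (reward space, done probability) also have defaults, as the `config.get(...)` pattern above indicates; the concrete spaces are only illustrative.

    from gym.spaces import Box, Discrete
    from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv

    # Defaults: Discrete(2) action and observation spaces.
    env = RandomEnv(config={})
    obs = env.reset()

    # Or override the spaces explicitly, as before.
    env = RandomEnv(config={
        "action_space": Discrete(4),
        "observation_space": Box(-1.0, 1.0, (3,)),
    })

    # Multi-agent wrapper built via make_multiagent().
    ma_env = RandomMultiAgentEnv({"num_agents": 2})
    obs_dict = ma_env.reset()  # dict mapping agent ids to observations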
@ -85,7 +85,12 @@ class SharedWeightsModel2(TFModelV2):

 TORCH_GLOBAL_SHARED_LAYER = None
 if torch:
-    TORCH_GLOBAL_SHARED_LAYER = SlimFC(32, 32)
+    TORCH_GLOBAL_SHARED_LAYER = SlimFC(
+        64,
+        64,
+        activation_fn=nn.ReLU,
+        initializer=torch.nn.init.xavier_uniform_,
+    )


 class TorchSharedWeightsModel(TorchModelV2, nn.Module):

@ -104,12 +109,22 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
         # Non-shared initial layer.
         self.first_layer = SlimFC(
             int(np.product(observation_space.shape)),
-            32,
-            activation_fn=nn.ReLU)
+            64,
+            activation_fn=nn.ReLU,
+            initializer=torch.nn.init.xavier_uniform_)

         # Non-shared final layer.
-        self.last_layer = SlimFC(32, self.num_outputs, activation_fn=nn.ReLU)
-        self.vf = SlimFC(32, 1, activation_fn=None)
+        self.last_layer = SlimFC(
+            64,
+            self.num_outputs,
+            activation_fn=None,
+            initializer=torch.nn.init.xavier_uniform_)
+        self.vf = SlimFC(
+            64,
+            1,
+            activation_fn=None,
+            initializer=torch.nn.init.xavier_uniform_,
+        )
         self._output = None

     @override(ModelV2)
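The reason `TORCH_GLOBAL_SHARED_LAYER` lives at module level is that every model instance referencing it shares the same parameters. A tiny PyTorch-only sketch of that mechanism, independent of RLlib's `SlimFC` wrapper and meant purely as an illustration:

    import torch
    import torch.nn as nn

    SHARED = nn.Linear(64, 64)  # one module object, created once at import time

    class ModelA(nn.Module):
        def __init__(self):
            super().__init__()
            self.shared = SHARED  # reused, not copied

    class ModelB(nn.Module):
        def __init__(self):
            super().__init__()
            self.shared = SHARED

    a, b = ModelA(), ModelB()
    # Both models hold the exact same parameter tensors, so gradient updates
    # through either model move the one shared layer.
    assert a.shared.weight is b.shared.weight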
@ -28,7 +28,7 @@ parser = argparse.ArgumentParser()

 parser.add_argument("--num-agents", type=int, default=4)
 parser.add_argument("--num-policies", type=int, default=2)
-parser.add_argument("--stop-iters", type=int, default=20)
+parser.add_argument("--stop-iters", type=int, default=200)
 parser.add_argument("--stop-reward", type=float, default=150)
 parser.add_argument("--stop-timesteps", type=int, default=100000)
 parser.add_argument("--simple", action="store_true")

@ -74,7 +74,6 @@ if __name__ == "__main__":
         "env_config": {
             "num_agents": args.num_agents,
         },
-        "log_level": "DEBUG",
         "simple_optimizer": args.simple,
         "num_sgd_iter": 10,
         "multiagent": {

@ -89,7 +88,7 @@ if __name__ == "__main__":
         "training_iteration": args.stop_iters,
     }

-    results = tune.run("PPO", stop=stop, config=config)
+    results = tune.run("PPO", stop=stop, config=config, verbose=1)

     if args.as_test:
         check_learning_achieved(results, args.stop_reward)
@ -23,7 +23,7 @@ parser.add_argument(
     action="store_true",
     help="Whether to take random instead of on-policy actions.")
 parser.add_argument(
-    "--stop-at-reward",
+    "--stop-reward",
     type=int,
     default=9999,
     help="Stop once the specified reward is reached.")

@ -49,7 +49,7 @@ if __name__ == "__main__":
         client.log_returns(eid, reward, info=info)
         if done:
             print("Total reward:", rewards)
-            if rewards >= args.stop_at_reward:
+            if rewards >= args.stop_reward:
                 print("Target reward achieved, exiting")
                 exit(0)
             rewards = 0
@ -32,8 +32,7 @@ if __name__ == "__main__":
     connector_config = {
         # Use the connector server to generate experiences.
         "input": (
-            lambda ioctx: PolicyServerInput( \
-                ioctx, SERVER_ADDRESS, SERVER_PORT)
+            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
         ),
         # Use a single worker process to run the server.
         "num_workers": 0,
rllib/examples/serving/unity3d_client.py (new file, 120 lines)

@ -0,0 +1,120 @@
"""
Example of running a Unity3D client instance against an RLlib Policy server.
Unity3D clients can be run in distributed fashion on n nodes in the cloud
and all connect to the same RLlib server for faster sample collection.
For a locally running Unity3D example, see:
`examples/unity3d_env_local.py`

To run this script on possibly different machines
against a central Policy server:
1) Install Unity3D and `pip install mlagents`.

2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
   other one that you created yourself) and place the compiled binary
   somewhere, where your RLlib client script (see below) can access it.

2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
   then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
   folder.

3) Change your RLlib Policy server code so it knows the observation- and
   action Spaces, the different Policies (called "behaviors" in Unity3D
   MLAgents), and Agent-to-Policy mappings for your particular game.
   Alternatively, use one of the two already existing setups (3DBall or
   SoccerStrikersVsGoalie).

4) Then run (two separate shells/machines):
$ python unity3d_server.py --env 3DBall
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
"""

import argparse

from ray.rllib.env.policy_client import PolicyClient
from ray.rllib.env.unity3d_env import Unity3DEnv

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900

parser = argparse.ArgumentParser()
parser.add_argument(
    "--game",
    type=str,
    default=None,
    help="The game executable to run as RL env. If not provided, uses local "
    "Unity3D editor instance.")
parser.add_argument(
    "--horizon",
    type=int,
    default=200,
    help="The max. number of `step()`s for any episode (per agent) before "
    "it'll be reset again automatically.")
parser.add_argument(
    "--server",
    type=str,
    default=SERVER_ADDRESS + ":" + str(SERVER_PORT),
    help="The Policy server's address and port to connect to from this client."
)
parser.add_argument(
    "--no-train",
    action="store_true",
    help="Whether to disable training (on the server side).")
parser.add_argument(
    "--inference-mode",
    type=str,
    default="local",
    choices=["local", "remote"],
    help="Whether to compute actions `local`ly or `remote`ly. Note that "
    "`local` is much faster b/c observations/actions do not have to be "
    "sent via the network.")
parser.add_argument(
    "--update-interval-local-mode",
    type=float,
    default=10.0,
    help="For `inference-mode=local`, every how many seconds do we update "
    "learnt policy weights from the server?")
parser.add_argument(
    "--stop-reward",
    type=int,
    default=9999,
    help="Stop once the specified reward is reached.")

if __name__ == "__main__":
    args = parser.parse_args()

    # Start the client for sending environment information (e.g. observations,
    # actions) to a policy server (listening on port 9900).
    client = PolicyClient(
        "http://" + args.server,
        inference_mode=args.inference_mode,
        update_interval=args.update_interval_local_mode)

    # Start and reset the actual Unity3DEnv (either already running Unity3D
    # editor or a binary (game) to be started automatically).
    env = Unity3DEnv(file_name=args.game, episode_horizon=args.horizon)
    obs = env.reset()
    eid = client.start_episode(training_enabled=not args.no_train)

    # Keep track of the total reward per episode.
    total_rewards_this_episode = 0.0

    # Loop infinitely through the env.
    while True:
        # Get actions from the Policy server given our current obs.
        actions = client.get_action(eid, obs)
        # Apply actions to our env.
        obs, rewards, dones, infos = env.step(actions)
        total_rewards_this_episode += sum(rewards.values())
        # Log rewards and single-agent dones.
        client.log_returns(eid, rewards, infos, multiagent_done_dict=dones)
        # Check whether all agents are done and end the episode, if necessary.
        if dones["__all__"]:
            print("Episode done: Reward={}".format(total_rewards_this_episode))
            if total_rewards_this_episode >= args.stop_reward:
                quit(0)
            # End the episode and reset Unity Env.
            total_rewards_this_episode = 0.0
            client.end_episode(eid, obs)
            obs = env.reset()
            # Start a new episode.
            eid = client.start_episode(training_enabled=not args.no_train)
rllib/examples/serving/unity3d_server.py (new executable file, 129 lines)

@ -0,0 +1,129 @@
"""
Example of running a Unity3D (MLAgents) Policy server that can learn
Policies via sampling inside many connected Unity game clients (possibly
running in the cloud on n nodes).
For a locally running Unity3D example, see:
`examples/unity3d_env_local.py`

To run this script against one or more possibly cloud-based clients:
1) Install Unity3D and `pip install mlagents`.

2) Compile a Unity3D example game with MLAgents support (e.g. 3DBall or any
   other one that you created yourself) and place the compiled binary
   somewhere, where your RLlib client script (see below) can access it.

2.1) To find Unity3D MLAgent examples, first `pip install mlagents`,
   then check out the `.../ml-agents/Project/Assets/ML-Agents/Examples/`
   folder.

3) Change this RLlib Policy server code so it knows the observation- and
   action Spaces, the different Policies (called "behaviors" in Unity3D
   MLAgents), and Agent-to-Policy mappings for your particular game.
   Alternatively, use one of the two already existing setups (3DBall or
   SoccerStrikersVsGoalie).

4) Then run (two separate shells/machines):
$ python unity3d_server.py --env 3DBall
$ python unity3d_client.py --inference-mode=local --game [path to game binary]
"""

import argparse
import os

import ray
from ray.tune import register_env
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.env.policy_server_input import PolicyServerInput
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
from ray.rllib.examples.env.unity3d_env import Unity3DEnv

SERVER_ADDRESS = "localhost"
SERVER_PORT = 9900
CHECKPOINT_FILE = "last_checkpoint_{}.out"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--env",
    type=str,
    default="3DBall",
    choices=["3DBall", "SoccerStrikersVsGoalie"],
    help="The name of the Env to run in the Unity3D editor. Either `3DBall` "
    "or `SoccerStrikersVsGoalie` (feel free to add more to this script!)")
parser.add_argument(
    "--port",
    type=int,
    default=SERVER_PORT,
    help="The Policy server's port to listen on for ExternalEnv client "
    "connections.")
parser.add_argument(
    "--checkpoint-freq",
    type=int,
    default=10,
    help="The frequency with which to create checkpoint files of the learnt "
    "Policies.")
parser.add_argument(
    "--no-restore",
    action="store_true",
    help="Whether to skip restoring the Policy weights from a previous "
    "checkpoint.")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(local_mode=True)

    # Create a fake-env for the server. This env will never be used (neither
    # for sampling, nor for evaluation) and its obs/action Spaces do not
    # matter either (multi-agent config below defines Spaces per Policy).
    register_env("fake_unity", lambda c: RandomMultiAgentEnv(c))

    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    # The entire config will be sent to connecting clients so they can
    # build their own samplers (and also Policy objects iff
    # `inference_mode=local` on clients' command line).
    config = {
        # Use the connector server to generate experiences.
        "input": (
            lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, args.port)),
        # Use a single worker process (w/ SyncSampler) to run the server.
        "num_workers": 0,
        # Disable OPE, since the rollouts are coming from online clients.
        "input_evaluation": [],

        # Other settings.
        "sample_batch_size": 64,
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the particular env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": "tf",
    }

    # Create the Trainer used for Policy serving.
    trainer = PPOTrainer(env="fake_unity", config=config)

    # Attempt to restore from checkpoint if possible.
    checkpoint_path = CHECKPOINT_FILE.format(args.env)
    if not args.no_restore and os.path.exists(checkpoint_path):
        checkpoint_path = open(checkpoint_path).read()
        print("Restoring from checkpoint path", checkpoint_path)
        trainer.restore(checkpoint_path)

    # Serving and training loop.
    count = 0
    while True:
        # Calls to train() will block on the configured `input` in the Trainer
        # config above (PolicyServerInput).
        print(trainer.train())
        if count % args.checkpoint_freq == 0:
            print("Saving learning progress to checkpoint file.")
            checkpoint = trainer.save()
            # Write the latest checkpoint location to CHECKPOINT_FILE,
            # so we can pick up from the latest one after a server re-start.
            with open(checkpoint_path, "w") as f:
                f.write(checkpoint)
        count += 1
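The exact return value of `Unity3DEnv.get_policy_configs_for_game` is not shown in this diff. As an assumption for illustration only, a hand-built setup that would be compatible with the multi-agent config above follows RLlib's standard convention of mapping policy ids to (policy_class_or_None, obs_space, action_space, per_policy_config) tuples; the concrete spaces and the "3DBall" policy id here are placeholders, not the helper's actual output.

    from gym.spaces import Box

    # Hypothetical, hand-built equivalent of what the helper is expected to
    # return for a single-behavior game.
    obs_space = Box(-1.0, 1.0, (8,))   # assumed observation space
    act_space = Box(-1.0, 1.0, (2,))   # assumed (continuous) action space
    policies = {
        "3DBall": (None, obs_space, act_space, {}),
    }

    # Map every agent id to the single "3DBall" policy.
    def policy_mapping_fn(agent_id):
        return "3DBall"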
rllib/examples/unity3d_env_local.py (new file, 95 lines)

@ -0,0 +1,95 @@
"""
Example of running an RLlib Trainer against a locally running Unity3D editor
instance (available as Unity3DEnv inside RLlib).
For a distributed cloud setup example with Unity,
see `examples/serving/unity3d_[server|client].py`

To run this script against a local Unity3D engine:
1) Install Unity3D and `pip install mlagents`.

2) Open the Unity3D Editor and load an example scene from the following
   ml-agents pip package location:
   `.../ml-agents/Project/Assets/ML-Agents/Examples/`
   This script supports the `3DBall` and `SoccerStrikersVsGoalie` examples.
   Specify the game you chose on your command line via e.g. `--env 3DBall`.
   Feel free to add more supported examples here.

3) Then run this script (you will have to press Play in your Unity editor
   at some point to start the game and the learning process):
$ python unity3d_env_local.py --env 3DBall --stop-reward [..] [--torch]?
"""

import argparse

import ray
from ray import tune
from ray.rllib.env.unity3d_env import Unity3DEnv
from ray.rllib.utils.test_utils import check_learning_achieved

parser = argparse.ArgumentParser()
parser.add_argument(
    "--env",
    type=str,
    default="3DBall",
    choices=["3DBall", "SoccerStrikersVsGoalie"],
    help="The name of the Env to run in the Unity3D editor. Either `3DBall` "
    "or `SoccerStrikersVsGoalie` (feel free to add more to this script!)")
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-iters", type=int, default=150)
parser.add_argument("--stop-reward", type=float, default=9999.0)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument(
    "--horizon",
    type=int,
    default=200,
    help="The max. number of `step()`s for any episode (per agent) before "
    "it'll be reset again automatically.")
parser.add_argument("--torch", action="store_true")

if __name__ == "__main__":
    ray.init(local_mode=True)

    args = parser.parse_args()

    tune.register_env(
        "unity3d",
        lambda c: Unity3DEnv(episode_horizon=c.get("episode_horizon", 1000)))

    # Get policies (different agent types; "behaviors" in MLAgents) and
    # the mappings from individual agents to Policies.
    policies, policy_mapping_fn = \
        Unity3DEnv.get_policy_configs_for_game(args.env)

    config = {
        "env": "unity3d",
        "env_config": {
            "episode_horizon": args.horizon,
        },
        # IMPORTANT: Just use one Worker (we only have one Unity running)!
        "num_workers": 0,
        # Other settings.
        "sample_batch_size": 64,
        "train_batch_size": 256,
        "rollout_fragment_length": 20,
        # Multi-agent setup for the particular env.
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": "tf",
    }

    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    # Run the experiment.
    results = tune.run("PPO", config=config, stop=stop, verbose=1)

    # And check the results.
    if args.as_test:
        check_learning_achieved(results, args.stop_reward)

    ray.shutdown()
@ -232,7 +232,7 @@ class TorchPolicy(Policy):
         loss_out = force_list(
             self._loss(self, self.model, self.dist_class, train_batch))
         assert len(loss_out) == len(self._optimizers)
-        # assert not any(np.isnan(l.detach().numpy()) for l in loss_out)
+        # assert not any(torch.isnan(l) for l in loss_out)

         # Loop through all optimizers.
         grad_info = {"allreduce_latency": 0.0}
@ -100,11 +100,11 @@ class MockEnv2(gym.Env):

 class MockVectorEnv(VectorEnv):
     def __init__(self, episode_length, num_envs):
-        super().__init__()
+        super().__init__(
+            observation_space=gym.spaces.Discrete(1),
+            action_space=gym.spaces.Discrete(2),
+            num_envs=num_envs)
         self.envs = [MockEnv(episode_length) for _ in range(num_envs)]
-        self.observation_space = gym.spaces.Discrete(1)
-        self.action_space = gym.spaces.Discrete(2)
-        self.num_envs = num_envs

     def vector_reset(self):
         return [e.reset() for e in self.envs]
@ -38,8 +38,8 @@ class PolicyClient:
         """Record the start of an episode.

         Arguments:
-            episode_id (str): Unique string id for the episode or None for
-                it to be auto-assigned.
+            episode_id (Optional[str]): Unique string id for the episode or
+                None for it to be auto-assigned.
             training_enabled (bool): Whether to use experiences for this
                 episode to improve the policy.
         """
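A short usage sketch of the clarified `start_episode` signature. It assumes a policy server like the ones above is already listening on the given address; the address, port, and episode id string are placeholders.

    from ray.rllib.env.policy_client import PolicyClient

    client = PolicyClient("http://localhost:9900", inference_mode="remote")

    # Let the client auto-assign an episode id (episode_id=None) ...
    eid = client.start_episode(training_enabled=True)

    # ... or pin one explicitly, e.g. to correlate logs across processes.
    eid2 = client.start_episode(
        episode_id="my-episode-001", training_enabled=True)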