mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[RLlib] Auto-framework, retire use_pytorch
in favor of framework=...
(#8520)
This commit is contained in:
parent
bcdbe2d3d4
commit
2746fc0476
168 changed files with 1447 additions and 1213 deletions
|
@ -229,8 +229,7 @@ matrix:
|
|||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=quick_train rllib/...
|
||||
# Test everything that does not have any of the "main" labels:
|
||||
# "learning_tests|quick_train|examples|tests_dir".
|
||||
#- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests_tf,-learning_tests_torch,-quick_train,-examples,-tests_dir rllib/...
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=agents_dir_X rllib/...
|
||||
- ./ci/keep_alive bazel test --config=ci --build_tests_only --test_tag_filters=-learning_tests_tf,-learning_tests_torch,-quick_train,-examples,-tests_dir rllib/...
|
||||
|
||||
# RLlib: Everything in rllib/examples/ directory.
|
||||
- os: linux
|
||||
|
|
|
@ -55,7 +55,7 @@ Ape-X variations of DQN and DDPG (`APEX_DQN <https://github.com/ray-project/ray/
|
|||
|
||||
Ape-X architecture
|
||||
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-apex.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pendulum-apex-ddpg.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mountaincarcontinuous-apex-ddpg.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-apex.yaml>`__.
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/pong-apex.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddpg/pendulum-apex-ddpg.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddpg/mountaincarcontinuous-apex-ddpg.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-apex.yaml>`__.
|
||||
|
||||
**Atari results @10M steps**: `more details <https://github.com/ray-project/rl-experiments>`__
|
||||
|
||||
|
@ -103,7 +103,7 @@ In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulli
|
|||
|
||||
IMPALA architecture
|
||||
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-impala-vectorized.yaml>`__, `multi-gpu configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-impala-fast.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-impala.yaml>`__
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/impala/pong-impala.yaml>`__, `vectorized configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/impala/pong-impala-vectorized.yaml>`__, `multi-gpu configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/impala/pong-impala-fast.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/impala/atari-impala.yaml>`__
|
||||
|
||||
**Atari results @10M steps**: `more details <https://github.com/ray-project/rl-experiments>`__
|
||||
|
||||
|
@ -156,7 +156,7 @@ We include an asynchronous variant of Proximal Policy Optimization (PPO) based o
|
|||
|
||||
APPO architecture (same as IMPALA)
|
||||
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-appo.yaml>`__
|
||||
Tuned examples: `PongNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/pong-appo.yaml>`__
|
||||
|
||||
**APPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
@ -182,7 +182,7 @@ Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centralized
|
|||
|
||||
DD-PPO architecture (both sampling and learning are done on worker GPUs)
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/cartpole-ddppo.yaml>`__, `BreakoutNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-ddppo.yaml>`__
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/cartpole-ddppo.yaml>`__, `BreakoutNoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/atari-ddppo.yaml>`__
|
||||
|
||||
**DDPPO-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
@ -208,7 +208,7 @@ A2C also supports microbatching (i.e., gradient accumulation), which can be enab
|
|||
|
||||
A2C architecture
|
||||
|
||||
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-a3c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-a3c-pytorch.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-a2c.yaml>`__
|
||||
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/a3c/pong-a3c.yaml>`__, `PyTorch version <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/a3c/pong-a3c-pytorch.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/a3c/atari-a2c.yaml>`__
|
||||
|
||||
.. tip::
|
||||
Consider using `IMPALA <#importance-weighted-actor-learner-architecture-impala>`__ for faster training with similar timestep efficiency.
|
||||
|
@ -243,7 +243,7 @@ DDPG is implemented similarly to DQN (below). The algorithm can be scaled by inc
|
|||
|
||||
DDPG architecture (same as DQN)
|
||||
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pendulum-ddpg.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mountaincarcontinuous-ddpg.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/halfcheetah-ddpg.yaml>`__, `TD3 Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pendulum-td3.yaml>`__, `TD3 InvertedPendulum-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/invertedpendulum-td3.yaml>`__, `TD3 Mujoco suite (Ant-v2, HalfCheetah-v2, Hopper-v2, Walker2d-v2) <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mujoco-td3.yaml>`__.
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddpg/pendulum-ddpg.yaml>`__, `MountainCarContinuous-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddpg/mountaincarcontinuous-ddpg.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ddpg/halfcheetah-ddpg.yaml>`__, `TD3 Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pendulum-td3.yaml>`__, `TD3 InvertedPendulum-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/invertedpendulum-td3.yaml>`__, `TD3 Mujoco suite (Ant-v2, HalfCheetah-v2, Hopper-v2, Walker2d-v2) <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/mujoco-td3.yaml>`__.
|
||||
|
||||
**DDPG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
@ -264,7 +264,7 @@ DQN can be scaled by increasing the number of workers or using Ape-X. Memory usa
|
|||
|
||||
DQN architecture
|
||||
|
||||
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-dqn.yaml>`__, `Rainbow configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-rainbow.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-dqn.yaml>`__, `with Dueling and Double-Q <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-duel-ddqn.yaml>`__, `with Distributional DQN <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-dist-dqn.yaml>`__.
|
||||
Tuned examples: `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/pong-dqn.yaml>`__, `Rainbow configuration <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/pong-rainbow.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-dqn.yaml>`__, `with Dueling and Double-Q <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-duel-ddqn.yaml>`__, `with Distributional DQN <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/dqn/atari-dist-dqn.yaml>`__.
|
||||
|
||||
.. tip::
|
||||
Consider using `Ape-X <#distributed-prioritized-experience-replay-ape-x>`__ for faster training with similar timestep efficiency.
|
||||
|
@ -308,7 +308,7 @@ Policy Gradients
|
|||
|
||||
Policy gradients architecture (same as A2C)
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/cartpole-pg-tf.yaml>`__
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pg/cartpole-pg.yaml>`__
|
||||
|
||||
**PG-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
@ -333,7 +333,7 @@ PPO's clipped objective supports multiple SGD passes over the same batch of expe
|
|||
|
||||
PPO architecture
|
||||
|
||||
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/walker2d-ppo.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/halfcheetah-ppo.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/atari-ppo.yaml>`__
|
||||
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/humanoid-ppo-gae.yaml>`__, `Hopper-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/hopper-ppo.yaml>`__, `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/pendulum-ppo.yaml>`__, `PongDeterministic-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/pong-ppo.yaml>`__, `Walker2d-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/walker2d-ppo.yaml>`__, `HalfCheetah-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/halfcheetah-ppo.yaml>`__, `{BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ppo/atari-ppo.yaml>`__
|
||||
|
||||
|
||||
**Atari results**: `more details <https://github.com/ray-project/rl-experiments>`__
|
||||
|
@ -381,7 +381,7 @@ Soft Actor Critic (SAC)
|
|||
|
||||
RLlib's soft-actor critic implementation is ported from the `official SAC repo <https://github.com/rail-berkeley/softlearning>`__ to better integrate with RLlib APIs. Note that SAC has two fields to configure for custom models: ``policy_model`` and ``Q_model``.
|
||||
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/pendulum-sac.yaml>`__, `HalfCheetah-v3 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/halfcheetah-sac.yaml>`__
|
||||
Tuned examples: `Pendulum-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/sac/pendulum-sac.yaml>`__, `HalfCheetah-v3 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/sac/halfcheetah-sac.yaml>`__
|
||||
|
||||
**MuJoCo results @3M steps:** `more details <https://github.com/ray-project/rl-experiments>`__
|
||||
|
||||
|
@ -409,7 +409,7 @@ Augmented Random Search (ARS)
|
|||
`[paper] <https://arxiv.org/abs/1803.07055>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/ars/ars.py>`__
|
||||
ARS is a random search method for training linear policies for continuous control problems. Code here is adapted from https://github.com/modestyachts/ARS to integrate with RLlib APIs.
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/regression_tests/cartpole-ars.yaml>`__, `Swimmer-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/swimmer-ars.yaml>`__
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ars/cartpole-ars.yaml>`__, `Swimmer-v2 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/ars/swimmer-ars.yaml>`__
|
||||
|
||||
**ARS-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
@ -426,7 +426,7 @@ Evolution Strategies
|
|||
`[paper] <https://arxiv.org/abs/1703.03864>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/es/es.py>`__
|
||||
Code here is adapted from https://github.com/openai/evolution-strategies-starter to execute in the distributed setting with Ray.
|
||||
|
||||
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/humanoid-es.yaml>`__
|
||||
Tuned examples: `Humanoid-v1 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/es/humanoid-es.yaml>`__
|
||||
|
||||
**Scalability:**
|
||||
|
||||
|
@ -481,7 +481,7 @@ Advantage Re-Weighted Imitation Learning (MARWIL)
|
|||
|pytorch| |tensorflow|
|
||||
`[paper] <http://papers.nips.cc/paper/7866-exponentially-weighted-imitation-learning-for-batched-historical-data>`__ `[implementation] <https://github.com/ray-project/ray/blob/master/rllib/agents/marwil/marwil.py>`__ MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data. When the ``beta`` hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning. MARWIL requires the `offline datasets API <rllib-offline.html>`__ to be used.
|
||||
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/cartpole-marwil.yaml>`__
|
||||
Tuned examples: `CartPole-v0 <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples/marwil/cartpole-marwil.yaml>`__
|
||||
|
||||
**MARWIL-specific configs** (see also `common configs <rllib-training.html#common-parameters>`__):
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ Feature development, discussion, and upcoming priorities are tracked on the `Git
|
|||
Benchmarks
|
||||
----------
|
||||
|
||||
A number of training run results are available in the `rl-experiments repo <https://github.com/ray-project/rl-experiments>`__, and there is also a list of working hyperparameter configurations in `tuned_examples <https://github.com/ray-project/ray/tree/master/rllib/tuned_examples>`__. Benchmark results are extremely valuable to the community, so if you happen to have results that may be of interest, consider making a pull request to either repo.
|
||||
A number of training run results are available in the `rl-experiments repo <https://github.com/ray-project/rl-experiments>`__, and there is also a list of working hyperparameter configurations in `tuned_examples <https://github.com/ray-project/ray/tree/master/rllib/tuned_examples>`__, sorted by algorithm. Benchmark results are extremely valuable to the community, so if you happen to have results that may be of interest, consider making a pull request to either repo.
|
||||
|
||||
Contributing Algorithms
|
||||
-----------------------
|
||||
|
|
|
@ -9,9 +9,9 @@ Tuned Examples
|
|||
--------------
|
||||
|
||||
- `Tuned examples <https://github.com/ray-project/ray/blob/master/rllib/tuned_examples>`__:
|
||||
Collection of tuned algorithm hyperparameters.
|
||||
- `Atari benchmarks <https://github.com/ray-project/rl-experiments>`__:
|
||||
Collection of reasonably optimized Atari results.
|
||||
Collection of tuned hyperparameters by algorithm.
|
||||
- `MuJoCo and Atari benchmarks <https://github.com/ray-project/rl-experiments>`__:
|
||||
Collection of reasonably optimized Atari and MuJoCo results.
|
||||
|
||||
Blog Posts
|
||||
----------
|
||||
|
|
|
@ -130,7 +130,7 @@ Once implemented, the model can then be registered and used in place of a built-
|
|||
|
||||
ray.init()
|
||||
trainer = a3c.A2CTrainer(env="CartPole-v0", config={
|
||||
"use_pytorch": True,
|
||||
"framework": "torch",
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
|
|
|
@ -38,8 +38,8 @@ Then, you can try out training in the following equivalent ways:
|
|||
from ray import tune
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
tune.run(PPOTrainer, config={"env": "CartPole-v0"}) # "log_level": "INFO" for verbose,
|
||||
# "eager": True for eager execution,
|
||||
# "use_pytorch": True for PyTorch
|
||||
# "framework": "tfe" for tf-eager execution,
|
||||
# "framework": "torch" for PyTorch
|
||||
|
||||
Next, we'll cover three key concepts in RLlib: Policies, Samples, and Trainers.
|
||||
|
||||
|
|
330
rllib/BUILD
330
rllib/BUILD
|
@ -352,14 +352,21 @@ py_test(
|
|||
# Tag: agents_dir
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
# A2CTrainer
|
||||
# A2/3CTrainer
|
||||
py_test(
|
||||
name = "test_a2c",
|
||||
tags = ["agents_dir"],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["agents/a3c/tests/test_a2c.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c",
|
||||
tags = ["agents_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/a3c/tests/test_a3c.py"]
|
||||
)
|
||||
|
||||
# APEXTrainer (DQN)
|
||||
py_test(
|
||||
name = "test_apex_dqn",
|
||||
|
@ -379,7 +386,7 @@ py_test(
|
|||
# ARS
|
||||
py_test(
|
||||
name = "test_ars",
|
||||
tags = ["agents_dir_X"],
|
||||
tags = ["agents_dir"],
|
||||
size = "medium",
|
||||
srcs = ["agents/ars/tests/test_ars.py"]
|
||||
)
|
||||
|
@ -519,33 +526,6 @@ py_test(
|
|||
|
||||
# A2C/A3C
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_tf_cartpole_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_torch_cartpole_v1",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "CartPole-v1",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"sample_async\": false}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_tf_cartpole_v1_lstm",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
|
@ -554,49 +534,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_torch_pendulum_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
size = "small",
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "Pendulum-v0",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"sample_async\": false}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_tf_pong_deterministic_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "PongDeterministic-v0",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a3c_torch_pong_deterministic_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "PongDeterministic-v0",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"sample_async\": false}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -606,11 +544,10 @@ py_test(
|
|||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "PongDeterministic-v0",
|
||||
"--env", "PongDeterministic-v4",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"use_pytorch\": true, \"sample_async\": false, \"model\": {\"use_lstm\": false, \"grayscale\": true, \"zero_mean\": false, \"dim\": 84}, \"preprocessor_pref\": \"rllib\"}'",
|
||||
"--config", "'{\"framework\": \"torch\", \"num_workers\": 2, \"sample_async\": false, \"model\": {\"use_lstm\": false, \"grayscale\": true, \"zero_mean\": false, \"dim\": 84}, \"preprocessor_pref\": \"rllib\"}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -623,39 +560,13 @@ py_test(
|
|||
"--env", "Pong-ram-v4",
|
||||
"--run", "A3C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_a2c_tf_pong_deterministic_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "PongDeterministic-v0",
|
||||
"--run", "A2C",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# DDPG/APEX-DDPG/TD3
|
||||
|
||||
py_test(
|
||||
name = "test_ddpg_pendulum_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "Pendulum-v0",
|
||||
"--run", "DDPG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_ddpg_mountaincar_continuous_v0_num_workers_0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
|
@ -664,7 +575,7 @@ py_test(
|
|||
"--env", "MountainCarContinuous-v0",
|
||||
"--run", "DDPG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 0}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 0}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -676,21 +587,7 @@ py_test(
|
|||
"--env", "MountainCarContinuous-v0",
|
||||
"--run", "DDPG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_apex_ddpg_pendulum_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "Pendulum-v0",
|
||||
"--run", "APEX_DDPG",
|
||||
"--ray-num-cpus", "8",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_iter_time_s\": 1}'",
|
||||
"--ray-num-cpus", "4"
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -702,23 +599,11 @@ py_test(
|
|||
"--env", "Pendulum-v0",
|
||||
"--run", "APEX_DDPG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_iter_time_s\": 1, \"batch_mode\": \"complete_episodes\"}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_iter_time_s\": 1, \"batch_mode\": \"complete_episodes\"}'",
|
||||
"--ray-num-cpus", "4",
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_td3_pendulum_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "Pendulum-v0",
|
||||
"--run", "TD3",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
# DQN/APEX
|
||||
|
||||
py_test(
|
||||
|
@ -729,6 +614,7 @@ py_test(
|
|||
args = [
|
||||
"--env", "FrozenLake-v0",
|
||||
"--run", "DQN",
|
||||
"--config", "'{\"framework\": \"tf\"}'",
|
||||
"--stop", "'{\"training_iteration\": 1}'"
|
||||
]
|
||||
)
|
||||
|
@ -742,7 +628,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "DQN",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"lr\": 1e-3, \"exploration_config\": {\"epsilon_timesteps\": 10000, \"final_epsilon\": 0.02}, \"dueling\": false, \"hiddens\": [], \"model\": {\"fcnet_hiddens\": [64], \"fcnet_activation\": \"relu\"}}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-3, \"exploration_config\": {\"epsilon_timesteps\": 10000, \"final_epsilon\": 0.02}, \"dueling\": false, \"hiddens\": [], \"model\": {\"fcnet_hiddens\": [64], \"fcnet_activation\": \"relu\"}}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -754,7 +640,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "DQN",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -770,7 +656,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "DQN",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -782,20 +668,7 @@ py_test(
|
|||
"--env", "PongDeterministic-v4",
|
||||
"--run", "DQN",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_apex_cartpole_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "APEX",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"timesteps_per_iteration\": 1000, \"num_gpus\": 0, \"min_iter_time_s\": 1}'",
|
||||
"--ray-num-cpus", "4"
|
||||
"--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -809,7 +682,7 @@ py_test(
|
|||
"--env", "Pendulum-v0",
|
||||
"--run", "ES",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"stepsize\": 0.01, \"episodes_per_batch\": 20, \"train_batch_size\": 100, \"num_workers\": 2}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"stepsize\": 0.01, \"episodes_per_batch\": 20, \"train_batch_size\": 100, \"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -822,52 +695,13 @@ py_test(
|
|||
"--env", "Pong-v0",
|
||||
"--run", "ES",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"stepsize\": 0.01, \"episodes_per_batch\": 20, \"train_batch_size\": 100, \"num_workers\": 2}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"stepsize\": 0.01, \"episodes_per_batch\": 20, \"train_batch_size\": 100, \"num_workers\": 2}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
||||
# IMPALA
|
||||
|
||||
py_test(
|
||||
name = "test_impala_cartpole_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1}'",
|
||||
"--ray-num-cpus", "4",
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_impala_cartpole_v0_num_aggregation_workers_2",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_gpus\": 0, \"num_workers\": 2, \"num_aggregation_workers\": 2, \"min_iter_time_s\": 1}'",
|
||||
"--ray-num-cpus", "5",
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_impala_cartpole_v0_lstm",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4",
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_impala_buffers_2",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
|
@ -876,7 +710,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
|
||||
"--ray-num-cpus", "4",
|
||||
]
|
||||
)
|
||||
|
@ -889,7 +723,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "IMPALA",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_iter_time_s\": 1, \"num_data_loader_buffers\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4",
|
||||
]
|
||||
)
|
||||
|
@ -903,7 +737,7 @@ py_test(
|
|||
"--run", "IMPALA",
|
||||
"--stop", "'{\"timesteps_total\": 40000}'",
|
||||
"--ray-object-store-memory=1000000000",
|
||||
"--config", "'{\"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -921,7 +755,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "MARWIL",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"shuffle_buffer_size\": 10}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"shuffle_buffer_size\": 10}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -937,7 +771,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "MARWIL",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"use_pytorch\": true, \"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"shuffle_buffer_size\": 10}'"
|
||||
"--config", "'{\"framework\": \"torch\", \"input\": \"tests/data/cartpole_small\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"shuffle_buffer_size\": 10}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -951,7 +785,7 @@ py_test(
|
|||
"--env", "FrozenLake-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -961,37 +795,10 @@ py_test(
|
|||
size = "small",
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "FrozenLake-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_pg_tf_cartpole_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
size = "small",
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_pg_torch_cartpole_v0",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500}'"
|
||||
"--config", "'{\"framework\": \"torch\", \"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1003,7 +810,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1016,7 +823,7 @@ py_test(
|
|||
"--env", "CartPole-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1029,7 +836,7 @@ py_test(
|
|||
"--env", "Pong-v0",
|
||||
"--run", "PG",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1044,7 +851,7 @@ py_test(
|
|||
"--env", "FrozenLake-v0",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_sgd_iter\": 10, \"sgd_minibatch_size\": 64, \"train_batch_size\": 1000, \"num_workers\": 1}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"num_sgd_iter\": 10, \"sgd_minibatch_size\": 64, \"train_batch_size\": 1000, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1054,37 +861,10 @@ py_test(
|
|||
size = "small",
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "FrozenLake-v0",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_sgd_iter\": 10, \"sgd_minibatch_size\": 64, \"train_batch_size\": 1000, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_ppo_tf_cartpole_v1",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"free_log_std\": true}}'"
|
||||
]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_ppo_torch_cartpole_v1",
|
||||
main = "train.py", srcs = ["train.py"],
|
||||
size = "small",
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"free_log_std\": true}}'"
|
||||
"--config", "'{\"framework\": \"torch\", \"num_sgd_iter\": 10, \"sgd_minibatch_size\": 64, \"train_batch_size\": 1000, \"num_workers\": 1}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1096,7 +876,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"simple_optimizer\": false, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"simple_optimizer\": false, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -1109,7 +889,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"simple_optimizer\": true, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"simple_optimizer\": true, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -1119,11 +899,10 @@ py_test(
|
|||
# name = "test_ppo_torch_cartpole_v1_lstm_simple_optimizer",
|
||||
# main = "train.py", srcs = ["train.py"],
|
||||
# args = [
|
||||
# "--torch",
|
||||
# "--env", "CartPole-v1",
|
||||
# "--run", "PPO",
|
||||
# "--stop", "'{\"training_iteration\": 1}'",
|
||||
# "--config", "'{\"simple_optimizer\": true, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
# "--config", "'{\"framework\": \"torch\", \"simple_optimizer\": true, \"num_sgd_iter\": 2, \"model\": {\"use_lstm\": true}}'",
|
||||
# "--ray-num-cpus", "4"
|
||||
# ]
|
||||
#)
|
||||
|
@ -1136,7 +915,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"use_gae\": false, \"batch_mode\": \"complete_episodes\"}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"use_gae\": false, \"batch_mode\": \"complete_episodes\"}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1148,7 +927,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"remote_worker_envs\": true, \"remote_env_batch_wait_ms\": 99999999, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"remote_worker_envs\": true, \"remote_env_batch_wait_ms\": 99999999, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1160,7 +939,7 @@ py_test(
|
|||
"--env", "CartPole-v1",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 2}'",
|
||||
"--config", "'{\"remote_worker_envs\": true, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"remote_worker_envs\": true, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1172,7 +951,7 @@ py_test(
|
|||
"--env", "MontezumaRevenge-v0",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"dim\": 40, \"conv_filters\": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}'"
|
||||
"--config", "'{\"framework\": \"tf\", \"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"dim\": 40, \"conv_filters\": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1181,11 +960,10 @@ py_test(
|
|||
main = "train.py", srcs = ["train.py"],
|
||||
tags = ["quick_train"],
|
||||
args = [
|
||||
"--torch",
|
||||
"--env", "MontezumaRevenge-v0",
|
||||
"--run", "PPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"dim\": 40, \"conv_filters\": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}'"
|
||||
"--config", "'{\"framework\": \"torch\", \"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"model\": {\"dim\": 40, \"conv_filters\": [[16, [8, 8], 4], [32, [4, 4], 2], [512, [5, 5], 1]]}}'"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1197,7 +975,7 @@ py_test(
|
|||
"--env", "Pendulum-v0",
|
||||
"--run", "APPO",
|
||||
"--stop", "'{\"training_iteration\": 1}'",
|
||||
"--config", "'{\"num_workers\": 2, \"num_gpus\": 0}'",
|
||||
"--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"num_gpus\": 0}'",
|
||||
"--ray-num-cpus", "4"
|
||||
]
|
||||
)
|
||||
|
@ -1370,6 +1148,13 @@ py_test(
|
|||
srcs = ["tests/test_evaluators.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_export",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
size = "medium",
|
||||
srcs = ["tests/test_export.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_external_env",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
|
@ -1459,7 +1244,7 @@ py_test(
|
|||
py_test(
|
||||
name = "tests/test_exec_api",
|
||||
tags = ["tests_dir", "tests_dir_E"],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["tests/test_exec_api.py"]
|
||||
)
|
||||
|
||||
|
@ -1486,10 +1271,17 @@ py_test(
|
|||
srcs = ["tests/test_rollout_worker.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_supported_multi_agent",
|
||||
tags = ["tests_dir", "tests_dir_S"],
|
||||
size = "large",
|
||||
srcs = ["tests/test_supported_multi_agent.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_supported_spaces",
|
||||
tags = ["tests_dir", "tests_dir_S"],
|
||||
size = "enormous",
|
||||
size = "large",
|
||||
srcs = ["tests/test_supported_spaces.py"]
|
||||
)
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import \
|
||||
A3CTorchPolicy
|
||||
return A3CTorchPolicy
|
||||
|
|
|
@ -1,36 +1,53 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.utils.test_utils import check_compute_action
|
||||
import ray.rllib.agents.a3c as a3c
|
||||
from ray.rllib.utils.test_utils import check_compute_action, framework_iterator
|
||||
|
||||
|
||||
class TestA2C(unittest.TestCase):
|
||||
"""Sanity tests for A2C exec impl."""
|
||||
|
||||
def setUp(self):
|
||||
ray.init()
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def test_a2c_compilation(self):
|
||||
"""Test whether an A2CTrainer can be built with both frameworks."""
|
||||
config = a3c.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 2
|
||||
config["num_envs_per_worker"] = 2
|
||||
|
||||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for fw in framework_iterator(config, ("tf", "torch")):
|
||||
config["sample_async"] = fw == "tf"
|
||||
for env in ["PongDeterministic-v0"]:
|
||||
trainer = a3c.A2CTrainer(config=config, env=env)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
check_compute_action(trainer)
|
||||
|
||||
def test_a2c_exec_impl(ray_start_regular):
|
||||
trainer = A2CTrainer(
|
||||
env="CartPole-v0", config={
|
||||
"min_iter_time_s": 0,
|
||||
})
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
config = {"min_iter_time_s": 0}
|
||||
for _ in framework_iterator(config, ("tf", "torch")):
|
||||
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
|
||||
def test_a2c_exec_impl_microbatch(ray_start_regular):
|
||||
trainer = A2CTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"min_iter_time_s": 0,
|
||||
"microbatch_size": 10,
|
||||
})
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
config = {
|
||||
"min_iter_time_s": 0,
|
||||
"microbatch_size": 10,
|
||||
}
|
||||
for _ in framework_iterator(config, ("tf", "torch")):
|
||||
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
|
||||
assert isinstance(trainer.train(), dict)
|
||||
check_compute_action(trainer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
39
rllib/agents/a3c/tests/test_a3c.py
Normal file
39
rllib/agents/a3c/tests/test_a3c.py
Normal file
|
@ -0,0 +1,39 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.a3c as a3c
|
||||
from ray.rllib.utils.test_utils import check_compute_action, framework_iterator
|
||||
|
||||
|
||||
class TestA3C(unittest.TestCase):
|
||||
"""Sanity tests for A2C exec impl."""
|
||||
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=4)
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
def test_a3c_compilation(self):
|
||||
"""Test whether an A3CTrainer can be built with both frameworks."""
|
||||
config = a3c.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 2
|
||||
config["num_envs_per_worker"] = 2
|
||||
|
||||
num_iterations = 1
|
||||
|
||||
# Test against all frameworks.
|
||||
for fw in framework_iterator(config, ("tf", "torch")):
|
||||
config["sample_async"] = fw == "tf"
|
||||
for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
|
||||
trainer = a3c.A3CTrainer(config=config, env=env)
|
||||
for i in range(num_iterations):
|
||||
results = trainer.train()
|
||||
print(results)
|
||||
check_compute_action(trainer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -163,7 +163,7 @@ class Worker:
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.ars.ars_torch_policy import ARSTorchPolicy
|
||||
policy_cls = ARSTorchPolicy
|
||||
else:
|
||||
|
@ -207,6 +207,13 @@ class ARSTrainer(Trainer):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Trainer)
|
||||
def get_policy(self, policy=DEFAULT_POLICY_ID):
|
||||
if policy != DEFAULT_POLICY_ID:
|
||||
raise ValueError("ARS has no policy '{}'! Use {} "
|
||||
"instead.".format(policy, DEFAULT_POLICY_ID))
|
||||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
|
|
@ -203,7 +203,7 @@ def validate_config(config):
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.ddpg.ddpg_torch_policy import DDPGTorchPolicy
|
||||
return DDPGTorchPolicy
|
||||
else:
|
||||
|
|
|
@ -57,7 +57,8 @@ def build_ddpg_models(policy, observation_space, action_space, config):
|
|||
num_outputs = 256 # arbitrary
|
||||
config["model"]["no_final_linear"] = True
|
||||
else:
|
||||
default_model = TorchNoopModel if config["use_pytorch"] else NoopModel
|
||||
default_model = TorchNoopModel if config["framework"] == "torch" \
|
||||
else NoopModel
|
||||
num_outputs = int(np.product(observation_space.shape))
|
||||
|
||||
policy.model = ModelCatalog.get_model_v2(
|
||||
|
@ -65,9 +66,9 @@ def build_ddpg_models(policy, observation_space, action_space, config):
|
|||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
model_interface=DDPGTorchModel
|
||||
if config["use_pytorch"] else DDPGTFModel,
|
||||
framework=config["framework"],
|
||||
model_interface=(DDPGTorchModel
|
||||
if config["framework"] == "torch" else DDPGTFModel),
|
||||
default_model=default_model,
|
||||
name="ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
|
@ -84,9 +85,9 @@ def build_ddpg_models(policy, observation_space, action_space, config):
|
|||
action_space=action_space,
|
||||
num_outputs=num_outputs,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
model_interface=DDPGTorchModel
|
||||
if config["use_pytorch"] else DDPGTFModel,
|
||||
framework=config["framework"],
|
||||
model_interface=(DDPGTorchModel
|
||||
if config["framework"] == "torch" else DDPGTFModel),
|
||||
default_model=default_model,
|
||||
name="target_ddpg_model",
|
||||
actor_hidden_activation=config["actor_hidden_activation"],
|
||||
|
@ -114,9 +115,9 @@ def get_distribution_inputs_and_class(policy,
|
|||
}, [], None)
|
||||
dist_inputs = model.get_policy_output(model_out)
|
||||
|
||||
return dist_inputs,\
|
||||
TorchDeterministic if policy.config["use_pytorch"] else Deterministic,\
|
||||
[] # []=state out
|
||||
return dist_inputs, (TorchDeterministic
|
||||
if policy.config["framework"] == "torch" else
|
||||
Deterministic), [] # []=state out
|
||||
|
||||
|
||||
def ddpg_actor_critic_loss(policy, model, _, train_batch):
|
||||
|
|
|
@ -17,7 +17,7 @@ class TestApexDDPG(unittest.TestCase):
|
|||
def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
|
||||
"""Test whether an APEX-DDPGTrainer can be built on all frameworks."""
|
||||
config = apex_ddpg.APEX_DDPG_DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 3
|
||||
config["num_workers"] = 2
|
||||
config["prioritized_replay"] = True
|
||||
config["timesteps_per_iteration"] = 100
|
||||
config["min_iter_time_s"] = 1
|
||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import re
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.rllib.agents.ddpg as ddpg
|
||||
from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
|
||||
loss_torch
|
||||
|
@ -19,10 +20,18 @@ torch, _ = try_import_torch()
|
|||
|
||||
|
||||
class TestDDPG(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_ddpg_compilation(self):
|
||||
"""Test whether a DDPGTrainer can be built with both frameworks."""
|
||||
config = ddpg.DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
config["num_workers"] = 1
|
||||
config["num_envs_per_worker"] = 2
|
||||
config["learning_starts"] = 0
|
||||
config["exploration_config"]["random_timesteps"] = 100
|
||||
|
@ -367,9 +376,9 @@ class TestDDPG(unittest.TestCase):
|
|||
else:
|
||||
torch_var = policy.model.state_dict()[map_[tf_key]]
|
||||
if tf_var.shape != torch_var.shape:
|
||||
check(tf_var, np.transpose(torch_var), rtol=0.07)
|
||||
check(tf_var, np.transpose(torch_var), atol=0.1)
|
||||
else:
|
||||
check(tf_var, torch_var, rtol=0.07)
|
||||
check(tf_var, torch_var, atol=0.1)
|
||||
|
||||
trainer.stop()
|
||||
|
||||
|
|
|
@ -305,7 +305,7 @@ def calculate_rr_weights(config):
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
|
||||
return DQNTorchPolicy
|
||||
else:
|
||||
|
@ -313,7 +313,7 @@ def get_policy_class(config):
|
|||
|
||||
|
||||
def get_simple_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.dqn.simple_q_torch_policy import \
|
||||
SimpleQTorchPolicy
|
||||
return SimpleQTorchPolicy
|
||||
|
|
|
@ -79,7 +79,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.dqn.simple_q_torch_policy import \
|
||||
SimpleQTorchPolicy
|
||||
return SimpleQTorchPolicy
|
||||
|
|
|
@ -55,7 +55,7 @@ def build_q_models(policy, obs_space, action_space, config):
|
|||
action_space=action_space,
|
||||
num_outputs=action_space.n,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
framework=config["framework"],
|
||||
name=Q_SCOPE)
|
||||
|
||||
policy.target_q_model = ModelCatalog.get_model_v2(
|
||||
|
@ -63,7 +63,7 @@ def build_q_models(policy, obs_space, action_space, config):
|
|||
action_space=action_space,
|
||||
num_outputs=action_space.n,
|
||||
model_config=config["model"],
|
||||
framework="torch" if config["use_pytorch"] else "tf",
|
||||
framework=config["framework"],
|
||||
name=Q_TARGET_SCOPE)
|
||||
|
||||
policy.q_func_vars = policy.q_model.variables()
|
||||
|
@ -83,9 +83,9 @@ def get_distribution_inputs_and_class(policy,
|
|||
q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
|
||||
|
||||
policy.q_values = q_vals
|
||||
return policy.q_values,\
|
||||
TorchCategorical if policy.config["use_pytorch"] else Categorical,\
|
||||
[] # state-outs
|
||||
return policy.q_values, (TorchCategorical
|
||||
if policy.config["framework"] == "torch" else
|
||||
Categorical), [] # state-outs
|
||||
|
||||
|
||||
def build_q_losses(policy, model, dist_class, train_batch):
|
||||
|
|
|
@@ -21,9 +21,10 @@ class TestApexDQN(unittest.TestCase):
         config["timesteps_per_iteration"] = 100
         config["min_iter_time_s"] = 1
         config["optimizer"]["num_replay_buffer_shards"] = 1
-        trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
-        trainer.train()
-        trainer.stop()
+        for _ in framework_iterator(config, frameworks=("torch", "tf")):
+            trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
+            trainer.train()
+            trainer.stop()

     def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
         """Test whether an APEX-DQNTrainer can be built on all frameworks."""
@@ -34,7 +35,7 @@ class TestApexDQN(unittest.TestCase):
         config["min_iter_time_s"] = 1
         config["optimizer"]["num_replay_buffer_shards"] = 1

-        for _ in framework_iterator(config, ("torch", "tf", "eager")):
+        for _ in framework_iterator(config):
             plain_config = config.copy()
             trainer = apex.ApexTrainer(config=plain_config, env="CartPole-v0")
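For illustration, a minimal sketch of the test pattern introduced here, assuming a DQN/CartPole setup (the trainer, env, and iteration choices are placeholders, not taken from the diff):

import unittest

import ray
import ray.rllib.agents.dqn as dqn
from ray.rllib.utils.test_utils import framework_iterator


class TestMyAgentAllFrameworks(unittest.TestCase):
    """Hypothetical test showing the framework_iterator pattern above."""

    def test_compilation(self):
        ray.init(ignore_reinit_error=True)
        config = dqn.DEFAULT_CONFIG.copy()
        config["num_workers"] = 0
        # framework_iterator sets config["framework"] to each requested
        # backend in turn and yields its name.
        for fw in framework_iterator(config, frameworks=("tf", "torch")):
            trainer = dqn.DQNTrainer(config=config, env="CartPole-v0")
            trainer.train()
            trainer.stop()
        ray.shutdown()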
@@ -1,6 +1,7 @@
 import numpy as np
 import unittest

+import ray
 import ray.rllib.agents.dqn as dqn
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.test_utils import check, framework_iterator, \
@@ -10,11 +11,19 @@ tf = try_import_tf()


 class TestDQN(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        ray.init()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        ray.shutdown()
+
     def test_dqn_compilation(self):
         """Test whether a DQNTrainer can be built on all frameworks."""
         config = dqn.DEFAULT_CONFIG.copy()
-        config["num_workers"] = 0  # Run locally.
-        num_iterations = 2
+        config["num_workers"] = 2
+        num_iterations = 1

         for fw in framework_iterator(config):
             # double-dueling DQN.
@ -161,7 +161,7 @@ class Worker:
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.es.es_torch_policy import ESTorchPolicy
|
||||
policy_cls = ESTorchPolicy
|
||||
else:
|
||||
|
@ -203,6 +203,13 @@ class ESTrainer(Trainer):
|
|||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@override(Trainer)
|
||||
def get_policy(self, policy=DEFAULT_POLICY_ID):
|
||||
if policy != DEFAULT_POLICY_ID:
|
||||
raise ValueError("ES has no policy '{}'! Use {} "
|
||||
"instead.".format(policy, DEFAULT_POLICY_ID))
|
||||
return self.policy
|
||||
|
||||
@override(Trainer)
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
|
|
@ -8,16 +8,20 @@ import ray
|
|||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.filter import get_filter
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.spaces.space_utils import unbatch
|
||||
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
|
||||
unbatch
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
torch, _ = try_import_torch()
|
||||
tree = try_import_tree()
|
||||
|
||||
|
||||
def before_init(policy, observation_space, action_space, config):
|
||||
policy.action_noise_std = config["action_noise_std"]
|
||||
policy.action_space_struct = get_base_struct_from_space(action_space)
|
||||
policy.preprocessor = ModelCatalog.get_preprocessor_for_space(
|
||||
observation_space)
|
||||
policy.observation_filter = get_filter(config["observation_filter"],
|
||||
|
@ -60,10 +64,18 @@ def before_init(policy, observation_space, action_space, config):
|
|||
SampleBatch.CUR_OBS: observation
|
||||
}, [], None)
|
||||
dist = policy.dist_class(dist_inputs, policy.model)
|
||||
action = dist.sample().detach().numpy()
|
||||
action = dist.sample()
|
||||
|
||||
def _add_noise(single_action, single_action_space):
|
||||
single_action = single_action.detach().numpy()
|
||||
if add_noise and isinstance(single_action_space, gym.spaces.Box):
|
||||
single_action += np.random.randn(*single_action.shape) * \
|
||||
policy.action_noise_std
|
||||
return single_action
|
||||
|
||||
action = tree.map_structure(_add_noise, action,
|
||||
policy.action_space_struct)
|
||||
action = unbatch(action)
|
||||
if add_noise and isinstance(policy.action_space, gym.spaces.Box):
|
||||
action += np.random.randn(*action.shape) * policy.action_noise_std
|
||||
return action
|
||||
|
||||
type(policy).compute_actions = _compute_actions
|
||||
|
@ -87,7 +99,7 @@ def make_model_and_action_dist(policy, observation_space, action_space,
|
|||
dist_type="deterministic",
|
||||
framework="torch")
|
||||
model = ModelCatalog.get_model_v2(
|
||||
observation_space,
|
||||
policy.preprocessor.observation_space,
|
||||
action_space,
|
||||
num_outputs=dist_dim,
|
||||
model_config=config["model"],
|
||||
|
|
|
@ -13,6 +13,7 @@ class TestES(unittest.TestCase):
|
|||
# Keep it simple.
|
||||
config["model"]["fcnet_hiddens"] = [10]
|
||||
config["model"]["fcnet_activation"] = None
|
||||
config["noise_size"] = 2500000
|
||||
|
||||
num_iterations = 2
|
||||
|
||||
|
|
|
@ -147,7 +147,7 @@ def make_learner_thread(local_worker, config):
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
if config["vtrace"]:
|
||||
from ray.rllib.agents.impala.vtrace_torch_policy import \
|
||||
VTraceTorchPolicy
|
||||
|
|
|
@ -28,6 +28,7 @@ class TestIMPALA(unittest.TestCase):
|
|||
print("Env={}".format(env))
|
||||
print("w/ LSTM")
|
||||
# Test w/o LSTM.
|
||||
local_cfg["num_aggregation_workers"] = 0
|
||||
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
|
@ -37,6 +38,7 @@ class TestIMPALA(unittest.TestCase):
|
|||
# Test w/ LSTM.
|
||||
print("w/o LSTM")
|
||||
local_cfg["model"]["use_lstm"] = True
|
||||
local_cfg["num_aggregation_workers"] = 2
|
||||
trainer = impala.ImpalaTrainer(config=local_cfg, env=env)
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
|
|
|
@ -87,7 +87,7 @@ class LogProbsFromLogitsAndActionsTest(unittest.TestCase):
|
|||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf"), session=True):
|
||||
vtrace = vtrace_tf if fw == "tf" else vtrace_torch
|
||||
vtrace = vtrace_tf if fw != "torch" else vtrace_torch
|
||||
policy_logits = Box(-1.0, 1.0, (seq_len, batch_size, num_actions),
|
||||
np.float32).sample()
|
||||
actions = np.random.randint(
|
||||
|
@ -149,7 +149,7 @@ class VtraceTest(unittest.TestCase):
|
|||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf"), session=True):
|
||||
vtrace = vtrace_tf if fw == "tf" else vtrace_torch
|
||||
vtrace = vtrace_tf if fw != "torch" else vtrace_torch
|
||||
output = vtrace.from_importance_weights(**values)
|
||||
if sess:
|
||||
output = sess.run(output)
|
||||
|
@ -178,7 +178,7 @@ class VtraceTest(unittest.TestCase):
|
|||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf"), session=True):
|
||||
vtrace = vtrace_tf if fw == "tf" else vtrace_torch
|
||||
vtrace = vtrace_tf if fw != "torch" else vtrace_torch
|
||||
|
||||
if fw == "tf":
|
||||
# Intentionally leaving shapes unspecified to test if V-trace
|
||||
|
@ -218,7 +218,7 @@ class VtraceTest(unittest.TestCase):
|
|||
clip_pg_rho_threshold=clip_pg_rho_threshold,
|
||||
**inputs_)
|
||||
|
||||
if fw == "tf":
|
||||
if fw != "torch":
|
||||
target_log_probs = vtrace.log_probs_from_logits_and_actions(
|
||||
inputs_["target_policy_logits"], inputs_["actions"])
|
||||
behaviour_log_probs = vtrace.log_probs_from_logits_and_actions(
|
||||
|
@ -279,7 +279,7 @@ class VtraceTest(unittest.TestCase):
|
|||
def test_higher_rank_inputs_for_importance_weights(self):
|
||||
"""Checks support for additional dimensions in inputs."""
|
||||
for fw in framework_iterator(frameworks=("torch", "tf"), session=True):
|
||||
vtrace = vtrace_tf if fw == "tf" else vtrace_torch
|
||||
vtrace = vtrace_tf if fw != "torch" else vtrace_torch
|
||||
if fw == "tf":
|
||||
inputs_ = {
|
||||
"log_rhos": tf.placeholder(
|
||||
|
@ -307,7 +307,7 @@ class VtraceTest(unittest.TestCase):
|
|||
def test_inconsistent_rank_inputs_for_importance_weights(self):
|
||||
"""Test one of many possible errors in shape of inputs."""
|
||||
for fw in framework_iterator(frameworks=("torch", "tf"), session=True):
|
||||
vtrace = vtrace_tf if fw == "tf" else vtrace_torch
|
||||
vtrace = vtrace_tf if fw != "torch" else vtrace_torch
|
||||
if fw == "tf":
|
||||
inputs_ = {
|
||||
"log_rhos": tf.placeholder(
|
||||
|
|
|
@ -91,20 +91,20 @@ class VTraceLoss:
|
|||
tf.float32))
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
|
||||
# The policy gradients loss
|
||||
# The policy gradients loss.
|
||||
self.pi_loss = -tf.reduce_sum(
|
||||
tf.boolean_mask(actions_logp * self.vtrace_returns.pg_advantages,
|
||||
valid_mask))
|
||||
|
||||
# The baseline loss
|
||||
# The baseline loss.
|
||||
delta = tf.boolean_mask(values - self.vtrace_returns.vs, valid_mask)
|
||||
self.vf_loss = 0.5 * tf.reduce_sum(tf.square(delta))
|
||||
|
||||
# The entropy loss
|
||||
# The entropy loss.
|
||||
self.entropy = tf.reduce_sum(
|
||||
tf.boolean_mask(actions_entropy, valid_mask))
|
||||
|
||||
# The summed weighted loss
|
||||
# The summed weighted loss.
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
|
|
@ -94,19 +94,19 @@ class VTraceLoss:
|
|||
# Move v-trace results back to GPU for actual loss computing.
|
||||
self.value_targets = self.vtrace_returns.vs.to(device)
|
||||
|
||||
# The policy gradients loss
|
||||
# The policy gradients loss.
|
||||
self.pi_loss = -torch.sum(
|
||||
actions_logp * self.vtrace_returns.pg_advantages.to(device) *
|
||||
valid_mask)
|
||||
|
||||
# The baseline loss
|
||||
# The baseline loss.
|
||||
delta = (values - self.value_targets) * valid_mask
|
||||
self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss
|
||||
# The entropy loss.
|
||||
self.entropy = torch.sum(actions_entropy * valid_mask)
|
||||
|
||||
# The summed weighted loss
|
||||
# The summed weighted loss.
|
||||
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
|
||||
self.entropy * entropy_coeff)
|
||||
|
||||
|
@ -135,10 +135,11 @@ def build_vtrace_loss(policy, model, dist_class, train_batch):
|
|||
rewards = train_batch[SampleBatch.REWARDS]
|
||||
behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
|
||||
behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
|
||||
if isinstance(output_hidden_shape, list):
|
||||
if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
|
||||
behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_outputs = torch.split(
|
||||
model_out, list(output_hidden_shape), dim=1)
|
||||
else:
|
||||
unpacked_behaviour_logits = torch.chunk(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
|
@ -162,7 +163,7 @@ def build_vtrace_loss(policy, model, dist_class, train_batch):
|
|||
actions_logp=_make_time_major(
|
||||
action_dist.logp(actions), drop_last=True),
|
||||
actions_entropy=_make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
action_dist.entropy(), drop_last=True),
|
||||
dones=_make_time_major(dones, drop_last=True),
|
||||
behaviour_action_logp=_make_time_major(
|
||||
behaviour_action_logp, drop_last=True),
|
||||
|
|
|
@ -35,15 +35,13 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"learning_starts": 0,
|
||||
# === Parallelism ===
|
||||
"num_workers": 0,
|
||||
# Use PyTorch as framework?
|
||||
"use_pytorch": False
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config.get("use_pytorch") is True:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.marwil.marwil_torch_policy import \
|
||||
MARWILTorchPolicy
|
||||
return MARWILTorchPolicy
|
||||
|
|
|
@ -15,7 +15,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.pg.pg_torch_policy import PGTorchPolicy
|
||||
return PGTorchPolicy
|
||||
else:
|
||||
|
|
|
@ -75,12 +75,11 @@ class TestPG(unittest.TestCase):
|
|||
feed_dict=policy._get_loss_inputs_dict(
|
||||
train_batch, shuffle=False))
|
||||
else:
|
||||
results = (pg.pg_tf_loss
|
||||
if fw == "eager" else pg.pg_torch_loss)(
|
||||
policy,
|
||||
policy.model,
|
||||
dist_class=dist_cls,
|
||||
train_batch=train_batch)
|
||||
results = (pg.pg_tf_loss if fw == "tfe" else pg.pg_torch_loss)(
|
||||
policy,
|
||||
policy.model,
|
||||
dist_class=dist_cls,
|
||||
train_batch=train_batch)
|
||||
|
||||
# Calculate expected results.
|
||||
if fw != "torch":
|
||||
|
|
|
@ -101,7 +101,7 @@ def add_target_callback(config):
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config.get("use_pytorch") is True:
|
||||
if config.get("framework") == "torch":
|
||||
from ray.rllib.agents.ppo.appo_torch_policy import AsyncPPOTorchPolicy
|
||||
return AsyncPPOTorchPolicy
|
||||
else:
|
||||
|
|
|
@ -209,7 +209,7 @@ def build_appo_model(policy, obs_space, action_space, config):
|
|||
logit_dim,
|
||||
config["model"],
|
||||
name=POLICY_SCOPE,
|
||||
framework="torch" if config["use_pytorch"] else "tf")
|
||||
framework="torch" if config["framework"] == "torch" else "tf")
|
||||
policy.model_variables = policy.model.variables()
|
||||
|
||||
policy.target_model = ModelCatalog.get_model_v2(
|
||||
|
@ -218,7 +218,7 @@ def build_appo_model(policy, obs_space, action_space, config):
|
|||
logit_dim,
|
||||
config["model"],
|
||||
name=TARGET_POLICY_SCOPE,
|
||||
framework="torch" if config["use_pytorch"] else "tf")
|
||||
framework="torch" if config["framework"] == "torch" else "tf")
|
||||
policy.target_model_variables = policy.target_model.variables()
|
||||
|
||||
return policy.model
|
||||
|
|
|
@ -241,11 +241,20 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
target_model_out, _ = policy.target_model.from_batch(train_batch)
|
||||
old_policy_behaviour_logits = target_model_out.detach()
|
||||
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.split(
|
||||
old_policy_behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.split(model_out, output_hidden_shape, dim=1)
|
||||
if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
|
||||
unpacked_behaviour_logits = torch.split(
|
||||
behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.split(
|
||||
old_policy_behaviour_logits, list(output_hidden_shape), dim=1)
|
||||
unpacked_outputs = torch.split(
|
||||
model_out, list(output_hidden_shape), dim=1)
|
||||
else:
|
||||
unpacked_behaviour_logits = torch.chunk(
|
||||
behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_old_policy_behaviour_logits = torch.chunk(
|
||||
old_policy_behaviour_logits, output_hidden_shape, dim=1)
|
||||
unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
|
||||
|
||||
old_policy_action_dist = dist_class(old_policy_behaviour_logits, model)
|
||||
prev_action_dist = dist_class(behaviour_logits, policy.model)
|
||||
values = policy.model.value_function()
|
||||
|
@ -269,7 +278,7 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = _make_time_major(
|
||||
old_policy_action_dist.multi_kl(action_dist), drop_last=True)
|
||||
old_policy_action_dist.kl(action_dist), drop_last=True)
|
||||
|
||||
policy.loss = VTraceSurrogateLoss(
|
||||
actions=_make_time_major(loss_actions, drop_last=True),
|
||||
|
@ -279,10 +288,9 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
action_dist.logp(actions), drop_last=True),
|
||||
old_policy_actions_logp=_make_time_major(
|
||||
old_policy_action_dist.logp(actions), drop_last=True),
|
||||
action_kl=torch.mean(mean_kl, dim=0)
|
||||
if is_multidiscrete else mean_kl,
|
||||
action_kl=mean_kl,
|
||||
actions_entropy=_make_time_major(
|
||||
action_dist.multi_entropy(), drop_last=True),
|
||||
action_dist.entropy(), drop_last=True),
|
||||
dones=_make_time_major(dones, drop_last=True),
|
||||
behaviour_logits=_make_time_major(
|
||||
unpacked_behaviour_logits, drop_last=True),
|
||||
|
@ -308,14 +316,13 @@ def build_appo_surrogate_loss(policy, model, dist_class, train_batch):
|
|||
logger.debug("Using PPO surrogate loss (vtrace=False)")
|
||||
|
||||
# Prepare KL for Loss
|
||||
mean_kl = _make_time_major(prev_action_dist.multi_kl(action_dist))
|
||||
mean_kl = _make_time_major(prev_action_dist.kl(action_dist))
|
||||
|
||||
policy.loss = PPOSurrogateLoss(
|
||||
prev_actions_logp=_make_time_major(prev_action_dist.logp(actions)),
|
||||
actions_logp=_make_time_major(action_dist.logp(actions)),
|
||||
action_kl=torch.mean(mean_kl, dim=0)
|
||||
if is_multidiscrete else mean_kl,
|
||||
actions_entropy=_make_time_major(action_dist.multi_entropy()),
|
||||
action_kl=mean_kl,
|
||||
actions_entropy=_make_time_major(action_dist.entropy()),
|
||||
values=_make_time_major(values),
|
||||
valid_mask=_make_time_major(mask),
|
||||
advantages=_make_time_major(
|
||||
|
|
|
@ -49,7 +49,7 @@ DEFAULT_CONFIG = with_base_config(ppo.DEFAULT_CONFIG, {
|
|||
|
||||
# *** WARNING: configs below are DDPPO overrides over PPO; you
|
||||
# shouldn't need to adjust them. ***
|
||||
"use_pytorch": True, # DDPPO requires PyTorch distributed.
|
||||
"framework": "torch", # DDPPO requires PyTorch distributed.
|
||||
"num_gpus": 0, # Learning is no longer done on the driver process, so
|
||||
# giving GPUs to the driver does not make sense!
|
||||
"num_gpus_per_worker": 1, # Each rollout worker gets a GPU.
|
||||
|
@ -70,7 +70,7 @@ def validate_config(config):
|
|||
raise ValueError(
|
||||
"Set rollout_fragment_length instead of train_batch_size "
|
||||
"for DDPPO.")
|
||||
if not config["use_pytorch"]:
|
||||
if config["framework"] != "torch":
|
||||
raise ValueError(
|
||||
"Distributed data parallel is only supported for PyTorch")
|
||||
if config["num_gpus"]:
|
||||
|
|
|
@ -73,8 +73,6 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# Whether to fake GPUs (using CPUs).
|
||||
# Set this to True for debugging on non-GPU machines (set `num_gpus` > 0).
|
||||
"_fake_gpus": False,
|
||||
# Use PyTorch as framework?
|
||||
"use_pytorch": False,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
@ -138,12 +136,13 @@ def validate_config(config):
|
|||
logger.warning(
|
||||
"Using the simple minibatch optimizer. This will significantly "
|
||||
"reduce performance, consider simple_optimizer=False.")
|
||||
elif config["use_pytorch"] or (tf and tf.executing_eagerly()):
|
||||
config["simple_optimizer"] = True # multi-gpu not supported
|
||||
# Multi-gpu not supported for PyTorch and tf-eager.
|
||||
elif config["framework"] in ["tfe", "torch"]:
|
||||
config["simple_optimizer"] = True
|
||||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
|
||||
return PPOTorchPolicy
|
||||
else:
|
||||
|
|
|
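As a hedged illustration of the validate_config() behavior changed above (trainer, env, and config values are placeholders): with the unified framework key, PPO's multi-GPU optimizer remains TF-graph-only, so "tfe" and "torch" fall back to the simple optimizer.

import ray
import ray.rllib.agents.ppo as ppo

ray.init(ignore_reinit_error=True)
config = ppo.DEFAULT_CONFIG.copy()
config["framework"] = "torch"  # or "tfe"
config["num_workers"] = 0
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
# validate_config() has forced the simple (non-multi-GPU) optimizer here:
print(trainer.config["simple_optimizer"])
trainer.stop()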
@ -49,7 +49,7 @@ class TestPPO(unittest.TestCase):
|
|||
def test_ppo_compilation(self):
|
||||
"""Test whether a PPOTrainer can be built with both frameworks."""
|
||||
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
|
||||
config["num_workers"] = 0 # Run locally.
|
||||
config["num_workers"] = 1
|
||||
num_iterations = 2
|
||||
|
||||
for _ in framework_iterator(config):
|
||||
|
@ -64,6 +64,7 @@ class TestPPO(unittest.TestCase):
|
|||
# Fake GPU setup.
|
||||
config["num_gpus"] = 2
|
||||
config["_fake_gpus"] = True
|
||||
config["framework"] = "tf"
|
||||
# Mimick tuned_example for PPO CartPole.
|
||||
config["num_workers"] = 1
|
||||
config["lr"] = 0.0003
|
||||
|
@ -164,7 +165,7 @@ class TestPPO(unittest.TestCase):
|
|||
init_std = get_value()
|
||||
assert init_std == 0.0, init_std
|
||||
|
||||
if fw == "tf" or fw == "eager":
|
||||
if fw in ["tf", "tfe"]:
|
||||
batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH)
|
||||
else:
|
||||
batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH)
|
||||
|
@ -205,7 +206,7 @@ class TestPPO(unittest.TestCase):
|
|||
# to train_batch dict.
|
||||
# A = [0.99^2 * 0.5 + 0.99 * -1.0 + 1.0, 0.99 * 0.5 - 1.0, 0.5] =
|
||||
# [0.50005, -0.505, 0.5]
|
||||
if fw == "tf" or fw == "eager":
|
||||
if fw == "tf" or fw == "tfe":
|
||||
train_batch = postprocess_ppo_gae_tf(policy, FAKE_BATCH)
|
||||
else:
|
||||
train_batch = postprocess_ppo_gae_torch(policy, FAKE_BATCH)
|
||||
|
@ -216,7 +217,7 @@ class TestPPO(unittest.TestCase):
|
|||
[0.50005, -0.505, 0.5])
|
||||
|
||||
# Calculate actual PPO loss.
|
||||
if fw == "eager":
|
||||
if fw == "tfe":
|
||||
ppo_surrogate_loss_tf(policy, policy.model, Categorical,
|
||||
train_batch)
|
||||
elif fw == "torch":
|
||||
|
|
|
@ -123,7 +123,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
if config.get("use_pytorch") is True:
|
||||
if config["framework"] == "torch":
|
||||
from ray.rllib.agents.sac.sac_torch_policy import SACTorchPolicy
|
||||
return SACTorchPolicy
|
||||
else:
|
||||
|
|
|
@@ -63,8 +63,9 @@ def build_sac_model(policy, obs_space, action_space, config):
         action_space=action_space,
         num_outputs=num_outputs,
         model_config=config["model"],
-        framework="torch" if config["use_pytorch"] else "tf",
-        model_interface=SACTorchModel if config["use_pytorch"] else SACTFModel,
+        framework=config["framework"],
+        model_interface=SACTorchModel
+        if config["framework"] == "torch" else SACTFModel,
         name="sac_model",
         actor_hidden_activation=config["policy_model"]["fcnet_activation"],
         actor_hiddens=config["policy_model"]["fcnet_hiddens"],
@@ -79,8 +80,9 @@ def build_sac_model(policy, obs_space, action_space, config):
         action_space=action_space,
         num_outputs=num_outputs,
         model_config=config["model"],
-        framework="torch" if config["use_pytorch"] else "tf",
-        model_interface=SACTorchModel if config["use_pytorch"] else SACTFModel,
+        framework=config["framework"],
+        model_interface=SACTorchModel
+        if config["framework"] == "torch" else SACTFModel,
         name="target_sac_model",
         actor_hidden_activation=config["policy_model"]["fcnet_activation"],
         actor_hiddens=config["policy_model"]["fcnet_hiddens"],
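For context, a hedged sketch of the catalog call pattern used above, with hypothetical spaces and model config (only the framework=config["framework"] pass-through reflects this diff):

import gym
from ray.rllib.models import ModelCatalog

obs_space = gym.spaces.Box(-1.0, 1.0, shape=(4, ))
action_space = gym.spaces.Discrete(2)
config = {"framework": "torch", "model": {"fcnet_hiddens": [64, 64]}}

# The resolved framework string is now forwarded directly to the catalog
# instead of being re-derived from a use_pytorch boolean.
model = ModelCatalog.get_model_v2(
    obs_space=obs_space,
    action_space=action_space,
    num_outputs=action_space.n,
    model_config=config["model"],
    framework=config["framework"],
    name="example_model")  # hypothetical scope name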
@@ -10,22 +10,22 @@ import tempfile
 import ray
 from ray.exceptions import RayError
 from ray.rllib.agents.callbacks import DefaultCallbacks
+from ray.rllib.env.normalize_actions import NormalizeActionWrapper
 from ray.rllib.models import MODEL_DEFAULTS
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.evaluation.metrics import collect_metrics
 from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
 from ray.rllib.evaluation.worker_set import WorkerSet
-from ray.rllib.utils import FilterManager, deep_update, merge_dicts, \
-    try_import_tf
+from ray.rllib.utils import FilterManager, deep_update, merge_dicts
+from ray.rllib.utils.framework import check_framework, try_import_tf
 from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
+from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning
 from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
 from ray.tune.trainable import Trainable
 from ray.tune.trial import ExportFormat
 from ray.tune.resources import Resources
 from ray.tune.logger import UnifiedLogger
 from ray.tune.result import DEFAULT_RESULTS_DIR
-from ray.rllib.env.normalize_actions import NormalizeActionWrapper
-from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning

 tf = try_import_tf()
@@ -140,18 +140,15 @@ COMMON_CONFIG = {
     # Use fake (infinite speed) sampler. For testing only.
     "fake_sampler": False,

-    # === Framework Settings ===
-    # Use PyTorch (instead of tf). If using `rllib train`, this can also be
-    # enabled with the `--torch` flag.
-    # NOTE: Some agents may not support `torch` yet and throw an error.
-    "use_pytorch": False,
-
-    # Enable TF eager execution (TF policies only). If using `rllib train`,
-    # this can also be enabled with the `--eager` flag.
-    "eager": False,
+    # === Deep Learning Framework Settings ===
+    # tf: TensorFlow
+    # tfe: TensorFlow eager
+    # torch: PyTorch
+    # auto: "torch" if only PyTorch installed, "tf" otherwise.
+    "framework": "auto",
     # Enable tracing in eager mode. This greatly improves performance, but
     # makes it slightly harder to debug since Python code won't be evaluated
-    # after the initial eager pass.
+    # after the initial eager pass. Only possible if framework=tfe.
     "eager_tracing": False,
     # Disable eager execution on workers (but allow it on the driver). This
     # only has an effect if eager is enabled.
@@ -348,6 +345,10 @@ COMMON_CONFIG = {
         # See rllib/evaluation/observation_function.py for more info.
         "observation_fn": None,
     },
+
+    # Deprecated keys:
+    "use_pytorch": DEPRECATED_VALUE,  # Replaced by `framework=torch`.
+    "eager": DEPRECATED_VALUE,  # Replaced by `framework=tfe`.
 }
 # __sphinx_doc_end__
 # yapf: enable
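To make the deprecation concrete, a hypothetical before/after user config (env and values are placeholders):

# Old style -- now flagged via DEPRECATED_VALUE and translated with a warning:
old_config = {
    "env": "CartPole-v0",
    "use_pytorch": True,  # deprecation_warning("use_pytorch", "framework=torch")
    "eager": False,       # deprecation_warning("eager", "framework=tfe")
}

# New style -- one key covering tf, tf-eager, and torch:
new_config = {
    "env": "CartPole-v0",
    "framework": "torch",  # one of "tf", "tfe", "torch", "auto"
}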
@@ -411,18 +412,11 @@ class Trainer(Trainable):
                 object. If unspecified, a default logger is created.
         """

+        # User provided config (this is w/o the default Trainer's
+        # `COMMON_CONFIG` (see above)). Will get merged with COMMON_CONFIG
+        # in self._setup().
         config = config or {}

-        if tf and config.get("eager"):
-            if not tf.executing_eagerly():
-                tf.enable_eager_execution()
-            logger.info("Executing eagerly, with eager_tracing={}".format(
-                "True" if config.get("eager_tracing") else "False"))
-
-        if tf and not tf.executing_eagerly() and not config.get("use_pytorch"):
-            logger.info("Tip: set 'eager': true or the --eager flag to enable "
-                        "TensorFlow eager execution")
-
         # Vars to synchronize to workers on each train call
         self.global_vars = {"timestep": 0}
@@ -555,6 +549,32 @@ class Trainer(Trainable):
         self.config = Trainer.merge_trainer_configs(self._default_config,
                                                     config)

+        # Check and resolve DL framework settings.
+        if "use_pytorch" in self.config and \
+                self.config["use_pytorch"] != DEPRECATED_VALUE:
+            deprecation_warning("use_pytorch", "framework=torch", error=False)
+            if self.config["use_pytorch"]:
+                self.config["framework"] = "torch"
+            self.config.pop("use_pytorch")
+        if "eager" in self.config and self.config["eager"] != DEPRECATED_VALUE:
+            deprecation_warning("eager", "framework=tfe", error=False)
+            if self.config["eager"]:
+                self.config["framework"] = "tfe"
+            self.config.pop("eager")
+
+        # Check all dependencies and resolve "auto" framework.
+        self.config["framework"] = check_framework(self.config["framework"])
+        # Notify about eager/tracing support.
+        if tf and self.config["framework"] == "tfe":
+            if not tf.executing_eagerly():
+                tf.enable_eager_execution()
+            logger.info("Executing eagerly, with eager_tracing={}".format(
+                self.config["eager_tracing"]))
+        if tf and not tf.executing_eagerly() and \
+                self.config["framework"] != "torch":
+            logger.info("Tip: set framework=tfe or the --eager flag to enable "
+                        "TensorFlow eager execution")
+
         if self.config["normalize_actions"]:
             inner = self.env_creator
@ -106,7 +106,7 @@ DEFAULT_CONFIG = with_common_config({
|
|||
# === Callbacks ===
|
||||
"callbacks": AlphaZeroDefaultCallbacks,
|
||||
|
||||
"use_pytorch": True,
|
||||
"framework": "torch", # Only PyTorch supported so far.
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
|||
TS_CONFIG = with_common_config({
|
||||
# No remote workers by default.
|
||||
"num_workers": 0,
|
||||
"use_pytorch": True,
|
||||
"framework": "torch", # Only PyTorch supported so far.
|
||||
|
||||
# Do online learning one step at a time.
|
||||
"rollout_fragment_length": 1,
|
||||
|
|
|
@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
|
|||
UCB_CONFIG = with_common_config({
|
||||
# No remote workers by default.
|
||||
"num_workers": 0,
|
||||
"use_pytorch": True,
|
||||
"framework": "torch", # Only PyTorch supported so far.
|
||||
|
||||
# Do online learning one step at a time.
|
||||
"rollout_fragment_length": 1,
|
||||
|
|
|
@ -18,15 +18,13 @@ from ray.util.debug import log_once
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TS_PATH = "ray.rllib.contrib.bandits.exploration.ThompsonSampling"
|
||||
UCB_PATH = "ray.rllib.contrib.bandits.exploration.UCB"
|
||||
|
||||
|
||||
DEFAULT_CONFIG = with_common_config({
|
||||
# No remote workers by default.
|
||||
"num_workers": 0,
|
||||
"use_pytorch": True,
|
||||
"framework": "torch", # Only PyTorch supported so far.
|
||||
|
||||
# Do online learning one step at a time.
|
||||
"rollout_fragment_length": 1,
|
||||
|
|
|
@ -8,7 +8,6 @@ from matplotlib import pyplot as plt
|
|||
import pandas as pd
|
||||
|
||||
from ray import tune
|
||||
from ray.rllib.contrib.bandits.agents import LinUCBTrainer
|
||||
from ray.rllib.contrib.bandits.agents.lin_ucb import UCB_CONFIG
|
||||
from ray.rllib.contrib.bandits.envs import ParametricItemRecoEnv
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ class RandomAgent(Trainer):
|
|||
_name = "RandomAgent"
|
||||
_default_config = with_common_config({
|
||||
"rollouts_per_iteration": 10,
|
||||
"framework": "tf", # not used
|
||||
})
|
||||
|
||||
@override(Trainer)
|
||||
|
|
|
@ -264,7 +264,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
|||
ParallelIteratorWorker.__init__(self, gen_rollouts, False)
|
||||
|
||||
policy_config = policy_config or {}
|
||||
if (tf and policy_config.get("eager")
|
||||
if (tf and policy_config.get("framework") == "tfe"
|
||||
and not policy_config.get("no_eager_on_workers")
|
||||
# This eager check is necessary for certain all-framework tests
|
||||
# that use tf's eager_mode() context generator.
|
||||
|
|
|
@@ -59,7 +59,7 @@ if __name__ == "__main__":
                "ff_hidden_dim": 32,
            },
        },
-        "use_pytorch": args.torch,
+        "framework": "torch" if args.torch else "tf",
    }

    stop = {
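The example scripts above and below all share one pattern; a hedged standalone sketch of it follows (script name, env, and stop criteria are illustrative):

import argparse

from ray import tune

parser = argparse.ArgumentParser()
parser.add_argument("--torch", action="store_true")
args = parser.parse_args()

config = {
    "env": "CartPole-v0",
    "num_workers": 0,
    # The --torch flag now maps onto the unified "framework" key.
    "framework": "torch" if args.torch else "tf",
}
tune.run("PPO", config=config, stop={"training_iteration": 1})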
@ -49,7 +49,7 @@ if __name__ == "__main__":
|
|||
"custom_model": "autoregressive_model",
|
||||
"custom_action_dist": "binary_autoreg_dist",
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -33,7 +33,7 @@ if __name__ == "__main__":
|
|||
"custom_model": "bn_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -41,7 +41,7 @@ if __name__ == "__main__":
|
|||
"use_lstm": True,
|
||||
"lstm_use_prev_action_reward": args.use_prev_action_reward,
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
})
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -58,7 +58,7 @@ class CentralizedValueMixin:
|
|||
"""Add method to evaluate the central value function from the model."""
|
||||
|
||||
def __init__(self):
|
||||
if not self.config["use_pytorch"]:
|
||||
if self.config["framework"] != "torch":
|
||||
self.compute_central_vf = make_tf_callable(self.get_session())(
|
||||
self.model.central_value_function)
|
||||
else:
|
||||
|
@ -71,7 +71,7 @@ def centralized_critic_postprocessing(policy,
|
|||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
pytorch = policy.config["use_pytorch"]
|
||||
pytorch = policy.config["framework"] == "torch"
|
||||
if (pytorch and hasattr(policy, "compute_central_vf")) or \
|
||||
(not pytorch and policy.loss_initialized()):
|
||||
assert other_agent_batches is not None
|
||||
|
@ -126,9 +126,9 @@ def loss_with_central_critic(policy, model, dist_class, train_batch):
|
|||
train_batch[SampleBatch.CUR_OBS], train_batch[OPPONENT_OBS],
|
||||
train_batch[OPPONENT_ACTION])
|
||||
|
||||
func = TFLoss if not policy.config["use_pytorch"] else TorchLoss
|
||||
func = TFLoss if not policy.config["framework"] == "torch" else TorchLoss
|
||||
adv = tf.ones_like(train_batch[Postprocessing.ADVANTAGES], dtype=tf.bool) \
|
||||
if not policy.config["use_pytorch"] else \
|
||||
if policy.config["framework"] != "torch" else \
|
||||
torch.ones_like(train_batch[Postprocessing.ADVANTAGES],
|
||||
dtype=torch.bool)
|
||||
|
||||
|
@ -194,7 +194,8 @@ CCPPOTorchPolicy = PPOTorchPolicy.with_updates(
|
|||
|
||||
|
||||
def get_policy_class(config):
|
||||
return CCPPOTorchPolicy if config["use_pytorch"] else CCPPOTFPolicy
|
||||
return CCPPOTorchPolicy if config["framework"] == "torch" \
|
||||
else CCPPOTFPolicy
|
||||
|
||||
|
||||
CCTrainer = PPOTrainer.with_updates(
|
||||
|
@ -214,15 +215,14 @@ if __name__ == "__main__":
|
|||
config = {
|
||||
"env": TwoStepGame,
|
||||
"batch_mode": "complete_episodes",
|
||||
"eager": False,
|
||||
"num_workers": 0,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"pol1": (None, Discrete(6), TwoStepGame.action_space, {
|
||||
"use_pytorch": args.torch
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}),
|
||||
"pol2": (None, Discrete(6), TwoStepGame.action_space, {
|
||||
"use_pytorch": args.torch
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
|
@ -230,7 +230,7 @@ if __name__ == "__main__":
|
|||
"model": {
|
||||
"custom_model": "cc_model",
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -99,7 +99,7 @@ if __name__ == "__main__":
|
|||
"model": {
|
||||
"custom_model": "cc_model",
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -120,7 +120,7 @@ if __name__ == "__main__":
|
|||
"vf_share_layers": True,
|
||||
"lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs
|
||||
"num_workers": 1, # parallelism
|
||||
"use_pytorch": args.torch
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -162,7 +162,7 @@ if __name__ == "__main__":
|
|||
"corridor_length": 5,
|
||||
},
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -42,7 +42,7 @@ if __name__ == "__main__":
|
|||
"train_batch_size": sample_from(
|
||||
lambda spec: 1000 * max(1, spec.config.num_gpus)),
|
||||
"fake_sampler": True,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -131,4 +131,5 @@ if __name__ == "__main__":
|
|||
"custom_model": "keras_q_model"
|
||||
if args.run == "DQN" else "keras_model"
|
||||
},
|
||||
"framework": "tf",
|
||||
}))
|
||||
|
|
|
@ -57,7 +57,7 @@ if __name__ == "__main__":
|
|||
"input_files": args.input_files,
|
||||
},
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -76,6 +76,7 @@ if __name__ == "__main__":
|
|||
config={
|
||||
"env": "CartPole-v0",
|
||||
"callbacks": MyCallbacks,
|
||||
"framework": "tf",
|
||||
},
|
||||
return_trials=True)
|
||||
|
||||
|
|
|
@ -72,6 +72,7 @@ if __name__ == "__main__":
|
|||
"on_train_result": on_train_result,
|
||||
"on_postprocess_traj": on_postprocess_traj,
|
||||
},
|
||||
"framework": "tf",
|
||||
},
|
||||
return_trials=True)
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ if __name__ == "__main__":
|
|||
"custom_model": "rnn",
|
||||
"max_seq_len": 20,
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -51,4 +51,5 @@ if __name__ == "__main__":
|
|||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 2,
|
||||
"framework": "tf",
|
||||
})
|
||||
|
|
|
@ -37,5 +37,5 @@ if __name__ == "__main__":
|
|||
config={
|
||||
"env": "CartPole-v0",
|
||||
"num_workers": 2,
|
||||
"use_pytorch": True
|
||||
"framework": "torch",
|
||||
})
|
||||
|
|
|
@ -44,7 +44,7 @@ if __name__ == "__main__":
|
|||
config = {
|
||||
"lr": 0.01,
|
||||
"num_workers": 0,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
resources = PPOTrainer.default_resource_request(config).to_json()
|
||||
tune.run(my_train_fn, resources_per_trial=resources, config=config)
|
||||
|
|
|
@ -68,6 +68,7 @@ if __name__ == "__main__":
|
|||
"model": {
|
||||
"custom_model": "eager_model"
|
||||
},
|
||||
"framework": "tfe",
|
||||
}
|
||||
stop = {
|
||||
"timesteps_total": args.stop_timesteps,
|
||||
|
|
|
@ -60,7 +60,7 @@ if __name__ == "__main__":
|
|||
config={
|
||||
"env": WindyMazeEnv,
|
||||
"num_workers": 0,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
},
|
||||
)
|
||||
else:
|
||||
|
@ -93,7 +93,7 @@ if __name__ == "__main__":
|
|||
},
|
||||
"policy_mapping_fn": function(policy_mapping_fn),
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
results = tune.run("PPO", stop=stop, config=config)
|
||||
|
|
|
@ -32,7 +32,7 @@ if __name__ == "__main__":
|
|||
|
||||
# Configure our Trainer.
|
||||
config = {
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
# Extra config passed to the custom model's c'tor as kwargs.
|
||||
|
|
|
@ -81,7 +81,7 @@ if __name__ == "__main__":
|
|||
"policies": policies,
|
||||
"policy_mapping_fn": (lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
stop = {
|
||||
"episode_reward_mean": args.stop_reward,
|
||||
|
|
|
@ -55,14 +55,14 @@ if __name__ == "__main__":
|
|||
"multiagent": {
|
||||
"policies": {
|
||||
"pg_policy": (None, obs_space, act_space, {
|
||||
"use_pytorch": args.torch
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}),
|
||||
"random": (RandomPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: ["pg_policy", "random"][agent_id % 2]),
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -69,7 +69,7 @@ if __name__ == "__main__":
|
|||
# disable filters, otherwise we would need to synchronize those
|
||||
# as well to the DQN agent
|
||||
"observation_filter": "NoFilter",
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
})
|
||||
|
||||
dqn_trainer = DQNTrainer(
|
||||
|
@ -82,7 +82,7 @@ if __name__ == "__main__":
|
|||
},
|
||||
"gamma": 0.95,
|
||||
"n_step": 3,
|
||||
"use_pytorch": args.torch or args.mixed_torch_tf,
|
||||
"framework": "torch" if args.torch or args.mixed_torch_tf else "tf"
|
||||
})
|
||||
|
||||
# You should see both the printed X and Y approach 200 as this trains:
|
||||
|
|
|
@ -43,7 +43,7 @@ if __name__ == "__main__":
|
|||
"num_sgd_iter": 4,
|
||||
"num_workers": 0,
|
||||
"vf_loss_coeff": 0.01,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -61,7 +61,7 @@ if __name__ == "__main__":
|
|||
"custom_model": "pa_model",
|
||||
},
|
||||
"num_workers": 0,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}, **cfg)
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -35,7 +35,7 @@ def run_same_policy(args, stop):
|
|||
"""Use the same policy for both agents (trivial case)."""
|
||||
config = {
|
||||
"env": RockPaperScissors,
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
results = tune.run("PG", config=config, stop=stop)
|
||||
|
@ -77,12 +77,12 @@ def run_heuristic_vs_learned(args, use_lstm=False, trainer="PG"):
|
|||
"model": {
|
||||
"use_lstm": use_lstm
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}),
|
||||
},
|
||||
"policy_mapping_fn": select_policy,
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
cls = get_agent_class(trainer) if isinstance(trainer, str) else trainer
|
||||
trainer_obj = cls(config=config)
|
||||
|
|
|
@ -21,6 +21,8 @@ CHECKPOINT_FILE = "last_checkpoint_{}.out"
|
|||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--run", type=str, default="DQN")
|
||||
parser.add_argument(
|
||||
"--framework", type=str, choices=["tf", "torch"], default="tf")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
@ -54,6 +56,7 @@ if __name__ == "__main__":
|
|||
"learning_starts": 100,
|
||||
"timesteps_per_iteration": 200,
|
||||
"log_level": "INFO",
|
||||
"framework": args.framework,
|
||||
}))
|
||||
elif args.run == "PPO":
|
||||
# Example of using PPO (does NOT support off-policy actions).
|
||||
|
@ -63,6 +66,7 @@ if __name__ == "__main__":
|
|||
connector_config, **{
|
||||
"sample_batch_size": 1000,
|
||||
"train_batch_size": 4000,
|
||||
"framework": args.framework,
|
||||
}))
|
||||
else:
|
||||
raise ValueError("--run must be DQN or PPO")
|
||||
|
|
|
@ -139,7 +139,7 @@ if __name__ == "__main__":
|
|||
"policy_mapping_fn": policy_mapping_fn,
|
||||
"policies_to_train": ["dqn_policy", "ppo_policy"],
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
|
|
@ -76,7 +76,7 @@ if __name__ == "__main__":
|
|||
},
|
||||
"policy_mapping_fn": lambda x: "pol1" if x == 0 else "pol2",
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
group = False
|
||||
elif args.run == "QMIX":
|
||||
|
@ -91,11 +91,11 @@ if __name__ == "__main__":
|
|||
"separate_state_space": True,
|
||||
"one_hot_state_encoding": True
|
||||
},
|
||||
"use_pytorch": args.torch,
|
||||
"framework": "torch" if args.torch else "tf",
|
||||
}
|
||||
group = True
|
||||
else:
|
||||
config = {}
|
||||
config = {"framework": "torch" if args.torch else "tf"}
|
||||
group = False
|
||||
|
||||
ray.init(num_cpus=args.num_cpus or None)
|
||||
|
|
|
@ -3,6 +3,7 @@ import timeit
|
|||
import unittest
|
||||
|
||||
from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
from ray.rllib.utils.test_utils import check
|
||||
|
||||
|
||||
class TestSegmentTree(unittest.TestCase):
|
||||
|
@ -114,6 +115,7 @@ class TestSegmentTree(unittest.TestCase):
|
|||
number=10000)
|
||||
"""
|
||||
capacity = 2**20
|
||||
# Expect reductions to be much faster now.
|
||||
new = timeit.timeit(
|
||||
"tree.sum(5, 60000)",
|
||||
setup="from ray.rllib.execution.segment_tree import "
|
||||
|
@ -124,8 +126,24 @@ class TestSegmentTree(unittest.TestCase):
|
|||
setup="from ray.rllib.execution.tests.old_segment_tree import "
|
||||
"OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity),
|
||||
number=10000)
|
||||
print("Sum performance (time spent) old={} new={}".format(old, new))
|
||||
self.assertGreater(old, new)
|
||||
|
||||
# Expect insertions to be roughly the same.
|
||||
new = timeit.timeit(
|
||||
"tree[50000] = 10; tree[50001] = 11",
|
||||
setup="from ray.rllib.execution.segment_tree import "
|
||||
"SumSegmentTree; tree = SumSegmentTree({})".format(capacity),
|
||||
number=100000)
|
||||
old = timeit.timeit(
|
||||
"tree[50000] = 10; tree[50001] = 11",
|
||||
setup="from ray.rllib.execution.tests.old_segment_tree import "
|
||||
"OldSumSegmentTree; tree = OldSumSegmentTree({})".format(capacity),
|
||||
number=100000)
|
||||
print("Insertion performance (time spent) "
|
||||
"old={} new={}".format(old, new))
|
||||
check(old, new, rtol=0.15)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
|
@ -20,10 +20,11 @@ from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
|
|||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
||||
TorchDeterministic, TorchDiagGaussian, \
|
||||
TorchMultiActionDistribution, TorchMultiCategorical
|
||||
from ray.rllib.utils import try_import_tf, try_import_tree
|
||||
from ray.rllib.utils import try_import_tree
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.framework import check_framework, try_import_tf
|
||||
from ray.rllib.utils.spaces.simplex import Simplex
|
||||
from ray.rllib.utils.spaces.space_utils import flatten_space
|
||||
|
||||
|
@ -125,7 +126,7 @@ class ModelCatalog:
|
|||
action_space (Space): Action space of the target gym env.
|
||||
config (Optional[dict]): Optional model config.
|
||||
dist_type (Optional[str]): Identifier of the action distribution.
|
||||
framework (str): One of "tf" or "torch".
|
||||
framework (str): One of "tf", "tfe", or "torch".
|
||||
kwargs (dict): Optional kwargs to pass on to the Distribution's
|
||||
constructor.
|
||||
|
||||
|
@ -133,6 +134,10 @@ class ModelCatalog:
|
|||
dist_class (ActionDistribution): Python class of the distribution.
|
||||
dist_dim (int): The size of the input vector to the distribution.
|
||||
"""
|
||||
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
dist = None
|
||||
config = config or MODEL_DEFAULTS
|
||||
# Custom distribution given.
|
||||
|
@ -158,13 +163,14 @@ class ModelCatalog:
|
|||
"using a Tuple action space, or the multi-agent API.")
|
||||
# TODO(sven): Check for bounds and return SquashedNormal, etc..
|
||||
if dist_type is None:
|
||||
dist = DiagGaussian if framework == "tf" else TorchDiagGaussian
|
||||
dist = TorchDiagGaussian if framework == "torch" \
|
||||
else DiagGaussian
|
||||
elif dist_type == "deterministic":
|
||||
dist = Deterministic if framework == "tf" else \
|
||||
TorchDeterministic
|
||||
dist = TorchDeterministic if framework == "torch" \
|
||||
else Deterministic
|
||||
# Discrete Space -> Categorical.
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
dist = Categorical if framework == "tf" else TorchCategorical
|
||||
dist = TorchCategorical if framework == "torch" else Categorical
|
||||
# Tuple/Dict Spaces -> MultiAction.
|
||||
elif dist_type in (MultiActionDistribution,
|
||||
TorchMultiActionDistribution) or \
|
||||
|
@ -190,8 +196,8 @@ class ModelCatalog:
|
|||
dist = Dirichlet
|
||||
# MultiDiscrete -> MultiCategorical.
|
||||
elif isinstance(action_space, gym.spaces.MultiDiscrete):
|
||||
dist = MultiCategorical if framework == "tf" else \
|
||||
TorchMultiCategorical
|
||||
dist = TorchMultiCategorical if framework == "torch" else \
|
||||
MultiCategorical
|
||||
return partial(dist, input_lens=action_space.nvec), \
|
||||
int(sum(action_space.nvec))
|
||||
# Unknown type -> Error.
|
||||
|
@ -271,7 +277,7 @@ class ModelCatalog:
|
|||
unflatten the tensor into a ragged tensor.
|
||||
action_space (Space): Action space of the target gym env.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
framework (str): One of "tf" or "torch".
|
||||
framework (str): One of "tf", "tfe", or "torch".
|
||||
name (str): Name (scope) for the model.
|
||||
model_interface (cls): Interface required for the model
|
||||
default_model (cls): Override the default class for the model. This
|
||||
|
@ -282,6 +288,9 @@ class ModelCatalog:
|
|||
model (ModelV2): Model to use for the policy.
|
||||
"""
|
||||
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
if model_config.get("custom_model"):
|
||||
|
||||
if "custom_options" in model_config and \
|
||||
|
@ -305,7 +314,7 @@ class ModelCatalog:
|
|||
model_cls = ModelCatalog._wrap_if_needed(
|
||||
model_cls, model_interface)
|
||||
|
||||
if framework == "tf":
|
||||
if framework in ["tf", "tfe"]:
|
||||
# Track and warn if vars were created but not registered.
|
||||
created = set()
|
||||
|
||||
|
@ -363,7 +372,7 @@ class ModelCatalog:
|
|||
"used, however you specified a custom model {}".format(
|
||||
model_cls))
|
||||
|
||||
if framework == "tf":
|
||||
if framework in ["tf", "tfe"]:
|
||||
v2_class = None
|
||||
# try to get a default v2 model
|
||||
if not model_config.get("custom_model"):
|
||||
|
@ -511,7 +520,7 @@ class ModelCatalog:
|
|||
options,
|
||||
state_in=None,
|
||||
seq_lens=None):
|
||||
"""Deprecated: use get_model_v2() instead."""
|
||||
"""Deprecated: Use get_model_v2() instead."""
|
||||
|
||||
deprecation_warning("get_model", "get_model_v2", error=False)
|
||||
assert isinstance(input_dict, dict)
|
||||
|
@ -563,7 +572,9 @@ class ModelCatalog:
|
|||
|
||||
@staticmethod
|
||||
def _get_v2_model_class(obs_space, model_config, framework="tf"):
|
||||
model_config = model_config or MODEL_DEFAULTS
|
||||
# Make sure, framework is ok.
|
||||
framework = check_framework(framework)
|
||||
|
||||
if framework == "torch":
|
||||
from ray.rllib.models.torch.fcnet import (FullyConnectedNetwork as
|
||||
FCNet)
|
||||
|
|
|
@ -150,7 +150,7 @@ class TestDistributions(unittest.TestCase):
|
|||
low, high = -2.0, 1.0
|
||||
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=("torch", "tf", "eager"), session=True):
|
||||
frameworks=("torch", "tf", "tfe"), session=True):
|
||||
cls = SquashedGaussian if fw != "torch" else TorchSquashedGaussian
|
||||
|
||||
# Do a stability test using extreme NN outputs to see whether
|
||||
|
@ -296,7 +296,7 @@ class TestDistributions(unittest.TestCase):
|
|||
def test_gumbel_softmax(self):
|
||||
"""Tests the GumbelSoftmax ActionDistribution (tf + eager only)."""
|
||||
for fw, sess in framework_iterator(
|
||||
frameworks=["tf", "eager"], session=True):
|
||||
frameworks=["tf", "tfe"], session=True):
|
||||
batch_size = 1000
|
||||
num_categories = 5
|
||||
input_space = Box(-1.0, 1.0, shape=(batch_size, num_categories))
|
||||
|
|
|
@@ -188,7 +188,7 @@ def build_eager_tf_policy(name,
     much simpler, but has lower performance.

     You shouldn't need to call this directly. Rather, prefer to build a TF
-    graph policy and use set {"eager": true} in the trainer config to have
+    graph policy and use set {"framework": "tfe"} in the trainer config to have
     it automatically be converted to an eager policy.

     This has the same signature as build_tf_policy()."""
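A hedged sketch of opting into the eager policy path via the new key (PPO and CartPole are illustrative choices):

from ray import tune

tune.run(
    "PPO",
    stop={"training_iteration": 2},
    config={
        "env": "CartPole-v0",
        "framework": "tfe",     # policy built through build_eager_tf_policy()
        "eager_tracing": True,  # optional tracing for speed (tfe only)
    })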
@ -38,7 +38,7 @@ def do_test_log_likelihood(run,
|
|||
|
||||
# Test against all frameworks.
|
||||
for fw in framework_iterator(config):
|
||||
if run in [sac.SACTrainer] and fw == "eager":
|
||||
if run in [sac.SACTrainer] and fw == "tfe":
|
||||
continue
|
||||
|
||||
trainer = run(config=config, env=env)
|
||||
|
@ -62,7 +62,7 @@ def do_test_log_likelihood(run,
|
|||
if continuous:
|
||||
for idx in range(num_actions):
|
||||
a = actions[idx]
|
||||
if fw == "tf" or fw == "eager":
|
||||
if fw != "torch":
|
||||
if isinstance(vars, list):
|
||||
expected_mean_logstd = fc(
|
||||
fc(obs_batch, vars[layer_key[1][0]]),
|
||||
|
|
|
@ -67,7 +67,7 @@ if __name__ == "__main__":
|
|||
# Add torch option to exp configs.
|
||||
for exp in experiments.values():
|
||||
if args.torch:
|
||||
exp["config"]["use_pytorch"] = True
|
||||
exp["config"]["framework"] = "torch"
|
||||
|
||||
# Try running each test 3 times and make sure it reaches the given
|
||||
# reward.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import unittest
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
@ -12,7 +13,7 @@ class TestAttentionNetLearning(unittest.TestCase):
|
|||
"env": StatelessCartPole,
|
||||
"gamma": 0.99,
|
||||
"num_envs_per_worker": 20,
|
||||
# "framework": "tf",
|
||||
"framework": "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
|
@ -20,6 +21,14 @@ class TestAttentionNetLearning(unittest.TestCase):
|
|||
"timesteps_total": 5000000,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
ray.init(num_cpus=5, ignore_reinit_error=True)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls) -> None:
|
||||
ray.shutdown()
|
||||
|
||||
def test_ppo_attention_net_learning(self):
|
||||
ModelCatalog.register_custom_model("attention_net", GTrXLNet)
|
||||
config = dict(
|
||||
|
|
|
@ -73,6 +73,7 @@ class TestAvailActionsQMix(unittest.TestCase):
|
|||
"env_config": {
|
||||
"avail_action": 3,
|
||||
},
|
||||
"framework": "torch",
|
||||
})
|
||||
for _ in range(5):
|
||||
agent.train() # OK if it doesn't trip the action assertion error
|
||||
|
|
|
@@ -1,14 +1,11 @@
#!/usr/bin/env python

import gym
import numpy as np
import os
import shutil
import unittest

import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.trial import ExportFormat
from ray.rllib.utils.test_utils import framework_iterator


def get_mean_action(alg, obs):

@@ -21,7 +18,7 @@ def get_mean_action(alg, obs):
CONFIGS = {
    "A3C": {
        "explore": False,
        "num_workers": 1
        "num_workers": 1,
    },
    "APEX_DDPG": {
        "explore": False,

@@ -37,14 +34,14 @@ CONFIGS = {
        "num_rollouts": 10,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter"
        "observation_filter": "MeanStdFilter",
    },
    "DDPG": {
        "explore": False,
        "timesteps_per_iteration": 100
        "timesteps_per_iteration": 100,
    },
    "DQN": {
        "explore": False
        "explore": False,
    },
    "ES": {
        "explore": False,

@@ -52,13 +49,13 @@ CONFIGS = {
        "train_batch_size": 100,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter"
        "observation_filter": "MeanStdFilter",
    },
    "PPO": {
        "explore": False,
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2
        "num_workers": 2,
    },
    "SAC": {
        "explore": False,

@@ -66,16 +63,18 @@ CONFIGS = {
}


def ckpt_restore_test(use_object_store, alg_name, failures):
def ckpt_restore_test(use_object_store, alg_name, failures, framework="tf"):
    cls = get_agent_class(alg_name)
    config = CONFIGS[alg_name]
    config["framework"] = framework
    if "DDPG" in alg_name or "SAC" in alg_name:
        alg1 = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
        alg2 = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
        env = gym.make("Pendulum-v0")
        alg1 = cls(config=config, env="Pendulum-v0")
        alg2 = cls(config=config, env="Pendulum-v0")
    else:
        alg1 = cls(config=CONFIGS[alg_name], env="CartPole-v0")
        alg2 = cls(config=CONFIGS[alg_name], env="CartPole-v0")
        env = gym.make("CartPole-v0")
        alg1 = cls(config=config, env="CartPole-v0")
        alg2 = cls(config=config, env="CartPole-v0")

    policy1 = alg1.get_policy()

    for _ in range(1):
        res = alg1.train()

@@ -87,17 +86,17 @@ def ckpt_restore_test(use_object_store, alg_name, failures):
    else:
        alg2.restore(alg1.save())

    for _ in range(2):
    for _ in range(1):
        if "DDPG" in alg_name or "SAC" in alg_name:
            obs = np.clip(
                np.random.uniform(size=3),
                env.observation_space.low,
                env.observation_space.high)
                policy1.observation_space.low,
                policy1.observation_space.high)
        else:
            obs = np.clip(
                np.random.uniform(size=4),
                env.observation_space.low,
                env.observation_space.high)
                policy1.observation_space.low,
                policy1.observation_space.high)
        a1 = get_mean_action(alg1, obs)
        a2 = get_mean_action(alg2, obs)
        print("Checking computed actions", alg1, obs, a1, a2)

@@ -105,50 +104,6 @@ def ckpt_restore_test(use_object_store, alg_name, failures):
            failures.append((alg_name, [a1, a2]))


def export_test(alg_name, failures):
    def valid_tf_model(model_dir):
        return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
            and os.listdir(os.path.join(model_dir, "variables"))

    def valid_tf_checkpoint(checkpoint_dir):
        return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
            and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
            and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))

    cls = get_agent_class(alg_name)
    if "DDPG" in alg_name or "SAC" in alg_name:
        algo = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
    else:
        algo = cls(config=CONFIGS[alg_name], env="CartPole-v0")

    for _ in range(1):
        res = algo.train()
        print("current status: " + str(res))

    export_dir = os.path.join(ray.utils.get_user_temp_dir(),
                              "export_dir_%s" % alg_name)
    print("Exporting model ", alg_name, export_dir)
    algo.export_policy_model(export_dir)
    if not valid_tf_model(export_dir):
        failures.append(alg_name)
    shutil.rmtree(export_dir)

    print("Exporting checkpoint", alg_name, export_dir)
    algo.export_policy_checkpoint(export_dir)
    if not valid_tf_checkpoint(export_dir):
        failures.append(alg_name)
    shutil.rmtree(export_dir)

    print("Exporting default policy", alg_name, export_dir)
    algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
                      export_dir)
    if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
            or not valid_tf_checkpoint(os.path.join(export_dir,
                                                    ExportFormat.CHECKPOINT)):
        failures.append(alg_name)
    shutil.rmtree(export_dir)


class TestCheckpointRestore(unittest.TestCase):
    @classmethod
    def setUpClass(cls):

@@ -160,22 +115,20 @@ class TestCheckpointRestore(unittest.TestCase):

    def test_checkpoint_restore(self):
        failures = []
        for use_object_store in [False, True]:
            for name in [
                    "SAC", "ES", "DQN", "DDPG", "PPO", "A3C", "APEX_DDPG",
                    "ARS"
            ]:
                ckpt_restore_test(use_object_store, name, failures)
        for fw in framework_iterator(frameworks=("tf", "torch")):
            for use_object_store in [False, True]:
                for name in [
                        "A3C", "APEX_DDPG", "ARS", "DDPG", "DQN", "ES", "PPO",
                        "SAC"
                ]:
                    print("Testing algo={} (use_object_store={})".format(
                        name, use_object_store))
                    ckpt_restore_test(
                        use_object_store, name, failures, framework=fw)

        assert not failures, failures
        print("All checkpoint restore tests passed!")

        failures = []
        for name in ["SAC", "DQN", "DDPG", "PPO", "A3C"]:
            export_test(name, failures)
        assert not failures, failures
        print("All export tests passed!")


if __name__ == "__main__":
    import pytest

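The round trip exercised by this checkpoint-restore test can be reproduced in a few lines. The sketch below is illustrative only: it assumes an RLlib of roughly this vintage, where trainers accept the new top-level "framework" config key and expose save()/restore(); the algorithm and hyperparameters are arbitrary.

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init(ignore_reinit_error=True)

config = {"num_workers": 1, "framework": "torch"}  # "tf" works here too
trainer1 = PPOTrainer(config=config, env="CartPole-v0")
trainer1.train()
checkpoint = trainer1.save()  # returns the checkpoint path

# A second trainer restored from the checkpoint should act like the first.
trainer2 = PPOTrainer(config=config, env="CartPole-v0")
trainer2.restore(checkpoint)

obs = trainer1.get_policy().observation_space.sample()
print(trainer1.compute_action(obs, explore=False),
      trainer2.compute_action(obs, explore=False))

ray.shutdown()
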
@@ -14,7 +14,7 @@ if __name__ == "__main__":
    # note: no ray.init(), to test it works without Ray
    trainer = A2CTrainer(
        env="CartPole-v0", config={
            "use_pytorch": True,
            "framework": "torch",
            "num_workers": 0
        })
    trainer.train()

@@ -14,7 +14,7 @@ if __name__ == "__main__":
    # note: no ray.init(), to test it works without Ray
    trainer = A2CTrainer(
        env="CartPole-v0", config={
            "use_pytorch": False,
            "framework": "tf",
            "num_workers": 0
        })
    trainer.train()

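The two hunks above show the migration this commit makes everywhere: the boolean use_pytorch flag is replaced by a single string-valued "framework" key. A hedged before/after sketch (trainer and env are just placeholders); as in the tests above, a single-process trainer with num_workers=0 does not need an explicit ray.init():

from ray.rllib.agents.a3c import A2CTrainer

# Old style (pre-commit): backend selected via boolean flags, e.g.
#   config={"use_pytorch": True, "num_workers": 0}

# New style: one config key selects the backend:
#   "tf" (graph-mode TF), "tfe" (eager TF), or "torch".
trainer = A2CTrainer(
    env="CartPole-v0",
    config={"framework": "torch", "num_workers": 0})
trainer.train()
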
@@ -6,8 +6,7 @@ from ray.rllib.agents.registry import get_agent_class


def check_support(alg, config, test_trace=True):
    config["eager"] = True
    config["framework"] = "tfe"
    # Test both continuous and discrete actions.
    for cont in [True, False]:
        if cont and alg in ["DQN", "APEX", "SimpleQ"]:

@@ -24,7 +23,6 @@ def check_support(alg, config, test_trace=True):

    a = get_agent_class(alg)
    config["log_level"] = "ERROR"
    config["eager_tracing"] = False
    tune.run(a, config=config, stop={"training_iteration": 1})

@@ -35,7 +33,7 @@ def check_support(alg, config, test_trace=True):

class TestEagerSupport(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=4)
        ray.init(num_cpus=4, local_mode=True)

    def tearDown(self):
        ray.shutdown()

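For eager-mode TF the test now sets framework="tfe" instead of the old eager=True flag, while eager_tracing stays a separate toggle. A minimal hedged sketch of the same idea through tune.run (algorithm choice and stopping criterion are arbitrary, and it assumes a TF2-capable installation):

import ray
from ray import tune

ray.init(num_cpus=4, local_mode=True)
tune.run(
    "PG",
    config={
        "env": "CartPole-v0",
        "framework": "tfe",      # eager TF; use "tf" for graph mode
        "eager_tracing": False,  # set True to trace/compile the eager code
        "num_workers": 0,
    },
    stop={"training_iteration": 1})
ray.shutdown()
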
@@ -49,6 +49,7 @@ if __name__ == "__main__":
                "tmp_file3": tmp3,
                "tmp_file4": tmp4,
            },
            "framework": "tf",
        },
        "stop": {
            "training_iteration": 1

@@ -5,6 +5,7 @@ import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.a3c import A3CTrainer
from ray.rllib.agents.dqn.dqn_tf_policy import _adjust_nstep
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env


@@ -30,34 +31,36 @@ class EvalTest(unittest.TestCase):
        agent_classes = [A3CTrainer, DQNTrainer]

        for agent_cls in agent_classes:
            ray.init(object_store_memory=1000 * 1024 * 1024)
            register_env("CartPoleWrapped-v0", env_creator)
            agent = agent_cls(
                env="CartPoleWrapped-v0",
                config={
                    "evaluation_interval": 2,
                    "evaluation_num_episodes": 2,
                    "evaluation_config": {
                        "gamma": 0.98,
                        "env_config": {
                            "fake_arg": True
                        }
                    },
                })
            # Given evaluation_interval=2, r0, r2, r4 should not contain
            # evaluation metrics while r1, r3 should do.
            r0 = agent.train()
            r1 = agent.train()
            r2 = agent.train()
            r3 = agent.train()
            for fw in framework_iterator(frameworks=("torch", "tf")):
                ray.init(object_store_memory=1000 * 1024 * 1024)
                register_env("CartPoleWrapped-v0", env_creator)
                agent = agent_cls(
                    env="CartPoleWrapped-v0",
                    config={
                        "evaluation_interval": 2,
                        "evaluation_num_episodes": 2,
                        "evaluation_config": {
                            "gamma": 0.98,
                            "env_config": {
                                "fake_arg": True
                            }
                        },
                        "framework": fw,
                    })
                # Given evaluation_interval=2, r0, r2, r4 should not contain
                # evaluation metrics while r1, r3 should do.
                r0 = agent.train()
                r1 = agent.train()
                r2 = agent.train()
                r3 = agent.train()

            self.assertTrue("evaluation" in r1)
            self.assertTrue("evaluation" in r3)
            self.assertFalse("evaluation" in r0)
            self.assertFalse("evaluation" in r2)
            self.assertTrue("episode_reward_mean" in r1["evaluation"])
            self.assertNotEqual(r1["evaluation"], r3["evaluation"])
            ray.shutdown()
                self.assertTrue("evaluation" in r1)
                self.assertTrue("evaluation" in r3)
                self.assertFalse("evaluation" in r0)
                self.assertFalse("evaluation" in r2)
                self.assertTrue("episode_reward_mean" in r1["evaluation"])
                self.assertNotEqual(r1["evaluation"], r3["evaluation"])
                ray.shutdown()


if __name__ == "__main__":

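The evaluation behavior this test relies on is driven by plain config keys. A hedged sketch of how evaluation_interval and evaluation_config are typically combined with the new framework key (the values are arbitrary):

import ray
from ray.rllib.agents.dqn import DQNTrainer

ray.init(ignore_reinit_error=True)
trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        "framework": "tf",
        "evaluation_interval": 2,      # evaluate every 2nd train() call
        "evaluation_num_episodes": 2,  # episodes per evaluation run
        "evaluation_config": {
            # overrides applied only to the evaluation workers
            "explore": False,
        },
    })
for i in range(4):
    result = trainer.train()
    # "evaluation" only appears in results of iterations that evaluated.
    print(i, "evaluation" in result)
ray.shutdown()
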
@@ -2,6 +2,7 @@ import unittest

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.utils.test_utils import framework_iterator


class TestDistributedExecution(unittest.TestCase):

@@ -9,45 +10,53 @@ class TestDistributedExecution(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        ray.init()
        ray.init(ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls):
        ray.shutdown()

    def test_exec_plan_stats(ray_start_regular):
        trainer = A2CTrainer(
            env="CartPole-v0", config={
                "min_iter_time_s": 0,
            })
        result = trainer.train()
        assert isinstance(result, dict)
        assert "info" in result
        assert "learner" in result["info"]
        assert "num_steps_sampled" in result["info"]
        assert "num_steps_trained" in result["info"]
        assert "timers" in result
        assert "learn_time_ms" in result["timers"]
        assert "learn_throughput" in result["timers"]
        assert "sample_time_ms" in result["timers"]
        assert "sample_throughput" in result["timers"]
        assert "update_time_ms" in result["timers"]
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_iter_time_s": 0,
                    "framework": fw,
                })
            result = trainer.train()
            assert isinstance(result, dict)
            assert "info" in result
            assert "learner" in result["info"]
            assert "num_steps_sampled" in result["info"]
            assert "num_steps_trained" in result["info"]
            assert "timers" in result
            assert "learn_time_ms" in result["timers"]
            assert "learn_throughput" in result["timers"]
            assert "sample_time_ms" in result["timers"]
            assert "sample_throughput" in result["timers"]
            assert "update_time_ms" in result["timers"]

    def test_exec_plan_save_restore(ray_start_regular):
        trainer = A2CTrainer(
            env="CartPole-v0", config={
                "min_iter_time_s": 0,
            })
        res1 = trainer.train()
        checkpoint = trainer.save()
        for _ in range(2):
            res2 = trainer.train()
        assert res2["timesteps_total"] > res1["timesteps_total"], (res1, res2)
        trainer.restore(checkpoint)
        for fw in framework_iterator(frameworks=("torch", "tf")):
            trainer = A2CTrainer(
                env="CartPole-v0",
                config={
                    "min_iter_time_s": 0,
                    "framework": fw,
                })
            res1 = trainer.train()
            checkpoint = trainer.save()
            for _ in range(2):
                res2 = trainer.train()
            assert res2["timesteps_total"] > res1["timesteps_total"], \
                (res1, res2)
            trainer.restore(checkpoint)

        # Should restore the timesteps counter to the same as res2.
        res3 = trainer.train()
        assert res3["timesteps_total"] < res2["timesteps_total"], (res2, res3)
            # Should restore the timesteps counter to the same as res2.
            res3 = trainer.train()
            assert res3["timesteps_total"] < res2["timesteps_total"], \
                (res2, res3)


if __name__ == "__main__":

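framework_iterator (from ray.rllib.utils.test_utils) is the helper these tests now lean on: as used above, it yields one framework string per pass and, when handed a config dict, sets config["framework"] in place for that pass. A rough usage sketch under those assumptions (this shows the calling pattern, not the helper's implementation):

import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.utils.test_utils import framework_iterator

ray.init(ignore_reinit_error=True)
config = {"num_workers": 0, "min_iter_time_s": 0}
for fw in framework_iterator(config, frameworks=("tf", "torch")):
    # config["framework"] has been set to fw for this pass.
    trainer = A2CTrainer(env="CartPole-v0", config=config)
    result = trainer.train()
    assert "timesteps_total" in result
    trainer.stop()
ray.shutdown()
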
rllib/tests/test_export.py (new file, 132 lines)

@@ -0,0 +1,132 @@
#!/usr/bin/env python

import os
import shutil
import unittest

import ray
from ray.rllib.agents.registry import get_agent_class
from ray.tune.trial import ExportFormat

CONFIGS = {
    "A3C": {
        "explore": False,
        "num_workers": 1,
        "framework": "tf",
    },
    "APEX_DDPG": {
        "explore": False,
        "observation_filter": "MeanStdFilter",
        "num_workers": 2,
        "min_iter_time_s": 1,
        "optimizer": {
            "num_replay_buffer_shards": 1,
        },
        "framework": "tf",
    },
    "ARS": {
        "explore": False,
        "num_rollouts": 10,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
        "framework": "tf",
    },
    "DDPG": {
        "explore": False,
        "timesteps_per_iteration": 100,
        "framework": "tf",
    },
    "DQN": {
        "explore": False,
        "framework": "tf",
    },
    "ES": {
        "explore": False,
        "episodes_per_batch": 10,
        "train_batch_size": 100,
        "num_workers": 2,
        "noise_size": 2500000,
        "observation_filter": "MeanStdFilter",
        "framework": "tf",
    },
    "PPO": {
        "explore": False,
        "num_sgd_iter": 5,
        "train_batch_size": 1000,
        "num_workers": 2,
        "framework": "tf",
    },
    "SAC": {
        "explore": False,
        "framework": "tf",
    },
}


def export_test(alg_name, failures):
    def valid_tf_model(model_dir):
        return os.path.exists(os.path.join(model_dir, "saved_model.pb")) \
            and os.listdir(os.path.join(model_dir, "variables"))

    def valid_tf_checkpoint(checkpoint_dir):
        return os.path.exists(os.path.join(checkpoint_dir, "model.meta")) \
            and os.path.exists(os.path.join(checkpoint_dir, "model.index")) \
            and os.path.exists(os.path.join(checkpoint_dir, "checkpoint"))

    cls = get_agent_class(alg_name)
    if "DDPG" in alg_name or "SAC" in alg_name:
        algo = cls(config=CONFIGS[alg_name], env="Pendulum-v0")
    else:
        algo = cls(config=CONFIGS[alg_name], env="CartPole-v0")

    for _ in range(1):
        res = algo.train()
        print("current status: " + str(res))

    export_dir = os.path.join(ray.utils.get_user_temp_dir(),
                              "export_dir_%s" % alg_name)
    print("Exporting model ", alg_name, export_dir)
    algo.export_policy_model(export_dir)
    if not valid_tf_model(export_dir):
        failures.append(alg_name)
    shutil.rmtree(export_dir)

    print("Exporting checkpoint", alg_name, export_dir)
    algo.export_policy_checkpoint(export_dir)
    if not valid_tf_checkpoint(export_dir):
        failures.append(alg_name)
    shutil.rmtree(export_dir)

    print("Exporting default policy", alg_name, export_dir)
    algo.export_model([ExportFormat.CHECKPOINT, ExportFormat.MODEL],
                      export_dir)
    if not valid_tf_model(os.path.join(export_dir, ExportFormat.MODEL)) \
            or not valid_tf_checkpoint(os.path.join(export_dir,
                                                    ExportFormat.CHECKPOINT)):
        failures.append(alg_name)
    shutil.rmtree(export_dir)


class TestExport(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(
            num_cpus=10, object_store_memory=1e9, ignore_reinit_error=True)

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_export(self):
        failures = []
        for name in ["A3C", "DQN", "DDPG", "PPO", "SAC"]:
            export_test(name, failures)
        assert not failures, failures
        print("All export tests passed!")


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))

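The new rllib/tests/test_export.py above (the export helper that used to live in the checkpoint-restore test) checks the TF export paths end to end, which is why every config there pins framework="tf". A condensed, hedged sketch of the same export calls (the algorithm and output paths are placeholders):

import os
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.tune.trial import ExportFormat

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(env="CartPole-v0", config={"framework": "tf"})
trainer.train()

trainer.export_policy_model("/tmp/rllib_export_model")      # TF SavedModel
trainer.export_policy_checkpoint("/tmp/rllib_export_ckpt")  # TF checkpoint
trainer.export_model([ExportFormat.MODEL, ExportFormat.CHECKPOINT],
                     "/tmp/rllib_export_both")              # both formats
print(os.listdir("/tmp/rllib_export_model"))
ray.shutdown()
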
@@ -11,6 +11,7 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_rollout_worker import (BadPolicy, MockPolicy,
                                                 MockEnv)
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env


@@ -114,10 +115,12 @@ class MultiServing(ExternalEnv):


class TestExternalEnv(unittest.TestCase):
    def setUp(self) -> None:
        ray.init()
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(ignore_reinit_error=True)

    def tearDown(self) -> None:
    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_external_env_complete_episodes(self):

@@ -165,41 +168,60 @@ class TestExternalEnv(unittest.TestCase):
        register_env(
            "test3", lambda _: PartOffPolicyServing(
                gym.make("CartPole-v0"), off_pol_frac=0.2))
        dqn = DQNTrainer(
            env="test3",
            config={"exploration_config": {
        config = {
            "num_workers": 0,
            "exploration_config": {
                "epsilon_timesteps": 100
            }})
        for i in range(100):
            result = dqn.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
            if result["episode_reward_mean"] >= 100:
                return
        raise Exception("failed to improve reward")
            },
        }
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            dqn = DQNTrainer(env="test3", config=config)
            reached = False
            for i in range(50):
                result = dqn.train()
                print("Iteration {}, reward {}, timesteps {}".format(
                    i, result["episode_reward_mean"],
                    result["timesteps_total"]))
                if result["episode_reward_mean"] >= 80:
                    reached = True
                    break
            if not reached:
                raise Exception("failed to improve reward")

    def test_train_cartpole(self):
        register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0")))
        pg = PGTrainer(env="test", config={"num_workers": 0})
        for i in range(100):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
            if result["episode_reward_mean"] >= 100:
                return
        raise Exception("failed to improve reward")
        config = {"num_workers": 0}
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            pg = PGTrainer(env="test", config=config)
            reached = False
            for i in range(80):
                result = pg.train()
                print("Iteration {}, reward {}, timesteps {}".format(
                    i, result["episode_reward_mean"],
                    result["timesteps_total"]))
                if result["episode_reward_mean"] >= 80:
                    reached = True
                    break
            if not reached:
                raise Exception("failed to improve reward")

    def test_train_cartpole_multi(self):
        register_env("test2",
                     lambda _: MultiServing(lambda: gym.make("CartPole-v0")))
        pg = PGTrainer(env="test2", config={"num_workers": 0})
        for i in range(100):
            result = pg.train()
            print("Iteration {}, reward {}, timesteps {}".format(
                i, result["episode_reward_mean"], result["timesteps_total"]))
            if result["episode_reward_mean"] >= 100:
                return
        raise Exception("failed to improve reward")
        config = {"num_workers": 0}
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            pg = PGTrainer(env="test2", config=config)
            reached = False
            for i in range(80):
                result = pg.train()
                print("Iteration {}, reward {}, timesteps {}".format(
                    i, result["episode_reward_mean"],
                    result["timesteps_total"]))
                if result["episode_reward_mean"] >= 80:
                    reached = True
                    break
            if not reached:
                raise Exception("failed to improve reward")

    def test_external_env_horizon_not_supported(self):
        ev = RolloutWorker(

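For context, the SimpleServing/MultiServing classes used above are thin ExternalEnv wrappers. A hedged sketch of what such a wrapper can look like; the class name EnvAsExternal is made up for illustration, and it assumes the ExternalEnv query methods start_episode, get_action, log_returns, and end_episode:

import gym
from ray.rllib.env.external_env import ExternalEnv


class EnvAsExternal(ExternalEnv):
    """Drives a regular gym env through the ExternalEnv query API."""

    def __init__(self, env):
        super().__init__(env.action_space, env.observation_space)
        self.env = env

    def run(self):
        while True:
            eid = self.start_episode()
            obs = self.env.reset()
            done = False
            while not done:
                action = self.get_action(eid, obs)
                obs, reward, done, _ = self.env.step(action)
                self.log_returns(eid, reward)
            self.end_episode(eid, obs)


# Usage (hypothetical): register_env("served_cartpole",
#     lambda _: EnvAsExternal(gym.make("CartPole-v0")))
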
@@ -13,10 +13,12 @@ SimpleMultiServing = make_simple_serving(True, ExternalMultiAgentEnv)


class TestExternalMultiAgentEnv(unittest.TestCase):
    def setUp(self) -> None:
        ray.init()
    @classmethod
    def setUpClass(cls) -> None:
        ray.init(ignore_reinit_error=True)

    def tearDown(self) -> None:
    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_external_multi_agent_env_complete_episodes(self):

@@ -71,7 +71,10 @@ class MSFTest(unittest.TestCase):

class FilterManagerTest(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=1, object_store_memory=1000 * 1024 * 1024)
        ray.init(
            num_cpus=1,
            object_store_memory=1000 * 1024 * 1024,
            ignore_reinit_error=True)

    def tearDown(self):
        ray.shutdown()

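The MeanStdFilter referenced in several configs in this diff normalizes observations with running mean/std statistics; the FilterManager tested here keeps those statistics in sync across rollout workers. A small hedged sketch of the filter on its own (shape and inputs are arbitrary):

import numpy as np
from ray.rllib.utils.filter import MeanStdFilter

filt = MeanStdFilter(shape=(4,))
for _ in range(100):
    filt(np.random.randn(4))  # each call updates the running statistics
print(filt(np.zeros(4)))      # returns the normalized observation
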
@@ -4,6 +4,7 @@ import unittest
import ray
from ray.rllib import _register_all
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.utils.test_utils import framework_iterator
from ray.tune.registry import register_env


@@ -28,7 +29,7 @@ class IgnoresWorkerFailure(unittest.TestCase):
    def do_test(self, alg, config, fn=None):
        fn = fn or self._do_test_fault_recover
        try:
            ray.init(num_cpus=6)
            ray.init(num_cpus=6, ignore_reinit_error=True)
            fn(alg, config)
        finally:
            ray.shutdown()

@@ -42,22 +43,24 @@
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1]}
        a = agent_cls(config=config, env="fault_env")
        result = a.train()
        self.assertTrue(result["num_healthy_workers"], 1)
        a.stop()
        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            result = a.train()
            self.assertTrue(result["num_healthy_workers"], 1)
            a.stop()

    def _do_test_fault_fatal(self, alg, config):
        register_env("fault_env", lambda c: FaultInjectEnv(c))
        agent_cls = get_agent_class(alg)

        # Test raises real error when out of workers
        config["num_workers"] = 2
        config["ignore_worker_failures"] = True
        config["env_config"] = {"bad_indices": [1, 2]}
        a = agent_cls(config=config, env="fault_env")
        self.assertRaises(Exception, lambda: a.train())
        a.stop()

        for _ in framework_iterator(config, frameworks=("torch", "tf")):
            a = agent_cls(config=config, env="fault_env")
            self.assertRaises(Exception, lambda: a.train())
            a.stop()

    def test_fatal(self):
        # test the case where all workers fail

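The fault-tolerance behavior being tested is driven by one config flag. A hedged configuration sketch (CartPole-v0 stands in for the fault-injecting env the test registers; num_healthy_workers is the result key the test itself checks):

import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init(ignore_reinit_error=True)
trainer = PPOTrainer(
    env="CartPole-v0",  # placeholder for an env that can crash on workers
    config={
        "num_workers": 2,
        # Keep training as long as at least one rollout worker survives;
        # without this, any worker failure aborts train().
        "ignore_worker_failures": True,
        "framework": "tf",
    })
result = trainer.train()
print(result["num_healthy_workers"])
ray.shutdown()
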
@ -17,6 +17,7 @@ from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
|
|||
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
|
||||
from ray.rllib.offline.json_writer import _to_json
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.test_utils import framework_iterator
|
||||
|
||||
SAMPLES = SampleBatch({
|
||||
"actions": np.array([1, 2, 3, 4]),
|
||||
|
@ -39,37 +40,44 @@ class AgentIOTest(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def writeOutputs(self, output):
|
||||
def writeOutputs(self, output, fw):
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"output": output,
|
||||
"output": output + (fw if output != "logdir" else ""),
|
||||
"rollout_fragment_length": 250,
|
||||
"framework": fw,
|
||||
})
|
||||
agent.train()
|
||||
return agent
|
||||
|
||||
def testAgentOutputOk(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 1)
|
||||
reader = JsonReader(self.test_dir + "/*.json")
|
||||
reader.next()
|
||||
for fw in framework_iterator(frameworks=("torch", "tf")):
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
self.assertEqual(len(os.listdir(self.test_dir + fw)), 1)
|
||||
reader = JsonReader(self.test_dir + fw + "/*.json")
|
||||
reader.next()
|
||||
|
||||
def testAgentOutputLogdir(self):
|
||||
agent = self.writeOutputs("logdir")
|
||||
self.assertEqual(len(glob.glob(agent.logdir + "/output-*.json")), 1)
|
||||
"""Test special value 'logdir' as Agent's output."""
|
||||
for fw in framework_iterator():
|
||||
agent = self.writeOutputs("logdir", fw)
|
||||
self.assertEqual(
|
||||
len(glob.glob(agent.logdir + "/output-*.json")), 1)
|
||||
|
||||
def testAgentInputDir(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
for fw in framework_iterator(frameworks=("torch", "tf")):
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir + fw,
|
||||
"input_evaluation": [],
|
||||
"framework": fw,
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testSplitByEpisode(self):
|
||||
splits = SAMPLES.split_by_episode()
|
||||
|
@ -79,74 +87,84 @@ class AgentIOTest(unittest.TestCase):
|
|||
self.assertEqual(splits[2].count, 1)
|
||||
|
||||
def testAgentInputPostprocessingEnabled(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
for fw in framework_iterator(frameworks=("tf", "torch")):
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
|
||||
# Rewrite the files to drop advantages and value_targets for testing
|
||||
for path in glob.glob(self.test_dir + "/*.json"):
|
||||
out = []
|
||||
for line in open(path).readlines():
|
||||
data = json.loads(line)
|
||||
del data["advantages"]
|
||||
del data["value_targets"]
|
||||
out.append(data)
|
||||
with open(path, "w") as f:
|
||||
for data in out:
|
||||
f.write(json.dumps(data))
|
||||
# Rewrite the files to drop advantages and value_targets for
|
||||
# testing
|
||||
for path in glob.glob(self.test_dir + fw + "/*.json"):
|
||||
out = []
|
||||
with open(path) as f:
|
||||
for line in f.readlines():
|
||||
data = json.loads(line)
|
||||
del data["advantages"]
|
||||
del data["value_targets"]
|
||||
out.append(data)
|
||||
with open(path, "w") as f:
|
||||
for data in out:
|
||||
f.write(json.dumps(data))
|
||||
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": [],
|
||||
"postprocess_inputs": True, # adds back 'advantages'
|
||||
})
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir + fw,
|
||||
"input_evaluation": [],
|
||||
"postprocess_inputs": True, # adds back 'advantages'
|
||||
"framework": fw,
|
||||
})
|
||||
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testAgentInputEvalSim(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": ["simulation"],
|
||||
})
|
||||
for _ in range(50):
|
||||
result = agent.train()
|
||||
if not np.isnan(result["episode_reward_mean"]):
|
||||
return # simulation ok
|
||||
time.sleep(0.1)
|
||||
assert False, "did not see any simulation results"
|
||||
for fw in framework_iterator():
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": self.test_dir + fw,
|
||||
"input_evaluation": ["simulation"],
|
||||
"framework": fw,
|
||||
})
|
||||
for _ in range(50):
|
||||
result = agent.train()
|
||||
if not np.isnan(result["episode_reward_mean"]):
|
||||
return # simulation ok
|
||||
time.sleep(0.1)
|
||||
assert False, "did not see any simulation results"
|
||||
|
||||
def testAgentInputList(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": glob.glob(self.test_dir + "/*.json"),
|
||||
"input_evaluation": [],
|
||||
"rollout_fragment_length": 99,
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
for fw in framework_iterator(frameworks=("torch", "tf")):
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": glob.glob(self.test_dir + fw + "/*.json"),
|
||||
"input_evaluation": [],
|
||||
"rollout_fragment_length": 99,
|
||||
"framework": fw,
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testAgentInputDict(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": {
|
||||
self.test_dir: 0.1,
|
||||
"sampler": 0.9,
|
||||
},
|
||||
"train_batch_size": 2000,
|
||||
"input_evaluation": [],
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
|
||||
for fw in framework_iterator():
|
||||
self.writeOutputs(self.test_dir, fw)
|
||||
agent = PGTrainer(
|
||||
env="CartPole-v0",
|
||||
config={
|
||||
"input": {
|
||||
self.test_dir + fw: 0.1,
|
||||
"sampler": 0.9,
|
||||
},
|
||||
"train_batch_size": 2000,
|
||||
"input_evaluation": [],
|
||||
"framework": fw,
|
||||
})
|
||||
result = agent.train()
|
||||
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testMultiAgent(self):
|
||||
register_env("multi_agent_cartpole",
|
||||
|
@ -158,48 +176,51 @@ class AgentIOTest(unittest.TestCase):
|
|||
act_space = single_env.action_space
|
||||
return (PGTFPolicy, obs_space, act_space, {})
|
||||
|
||||
pg = PGTrainer(
|
||||
env="multi_agent_cartpole",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"output": self.test_dir,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
for fw in framework_iterator():
|
||||
pg = PGTrainer(
|
||||
env="multi_agent_cartpole",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"output": self.test_dir,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(
|
||||
["policy_1", "policy_2"])),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(
|
||||
["policy_1", "policy_2"])),
|
||||
},
|
||||
})
|
||||
pg.train()
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 1)
|
||||
"framework": fw,
|
||||
})
|
||||
pg.train()
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 1)
|
||||
|
||||
pg.stop()
|
||||
pg = PGTrainer(
|
||||
env="multi_agent_cartpole",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": ["simulation"],
|
||||
"train_batch_size": 2000,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
pg.stop()
|
||||
pg = PGTrainer(
|
||||
env="multi_agent_cartpole",
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"input": self.test_dir,
|
||||
"input_evaluation": ["simulation"],
|
||||
"train_batch_size": 2000,
|
||||
"multiagent": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(
|
||||
["policy_1", "policy_2"])),
|
||||
},
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(
|
||||
["policy_1", "policy_2"])),
|
||||
},
|
||||
})
|
||||
for _ in range(50):
|
||||
result = pg.train()
|
||||
if not np.isnan(result["episode_reward_mean"]):
|
||||
return # simulation ok
|
||||
time.sleep(0.1)
|
||||
assert False, "did not see any simulation results"
|
||||
"framework": fw,
|
||||
})
|
||||
for _ in range(50):
|
||||
result = pg.train()
|
||||
if not np.isnan(result["episode_reward_mean"]):
|
||||
return # simulation ok
|
||||
time.sleep(0.1)
|
||||
assert False, "did not see any simulation results"
|
||||
|
||||
|
||||
class JsonIOTest(unittest.TestCase):
|
||||
|
|
|
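The offline I/O test diff above applies the same framework migration; the underlying feature is just the "output" and "input" config keys. A hedged sketch of writing experiences to JSON rollout files and then training purely from them (the output directory is a placeholder):

import ray
from ray.rllib.agents.pg import PGTrainer

ray.init(ignore_reinit_error=True)
out_dir = "/tmp/rllib_offline_demo"  # placeholder path

# 1) Generate experiences and write them as JSON rollout files.
writer = PGTrainer(
    env="CartPole-v0",
    config={"output": out_dir, "rollout_fragment_length": 250,
            "framework": "tf"})
writer.train()

# 2) Train a new agent from the logged data only (no env interaction).
reader = PGTrainer(
    env="CartPole-v0",
    config={"input": out_dir, "input_evaluation": [], "framework": "tf"})
print(reader.train()["timesteps_total"])
ray.shutdown()
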
@@ -1,7 +1,8 @@
import unittest

from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
import ray
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.utils.test_utils import framework_iterator


class LocalModeTest(unittest.TestCase):

@@ -13,8 +14,9 @@ class LocalModeTest(unittest.TestCase):

    def test_local(self):
        cf = DEFAULT_CONFIG.copy()
        agent = PPOTrainer(cf, "CartPole-v0")
        print(agent.train())
        for fw in framework_iterator(cf):
            agent = PPOTrainer(cf, "CartPole-v0")
            print(agent.train())


if __name__ == "__main__":

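Local mode (all Ray tasks run serially in one process) combined with the new framework key is often the quickest way to debug a trainer. A short hedged sketch:

import ray
from ray.rllib.agents.ppo import PPOTrainer, DEFAULT_CONFIG

ray.init(local_mode=True)  # run everything in this process, serially
cf = DEFAULT_CONFIG.copy()
cf["framework"] = "torch"  # or "tf" / "tfe"
agent = PPOTrainer(cf, "CartPole-v0")
print(agent.train()["episode_reward_mean"])
ray.shutdown()
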