[RLlib Testing] Add A3C/APPO/BC/DDPPO/MARWIL/CQL/ES/ARS/TD3 to weekly learning tests. (#18381)

This commit is contained in:
Sven Mika 2021-09-07 11:48:41 +02:00 committed by GitHub
parent 64040a90a5
commit cabaa3b3c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 333 additions and 58 deletions

View file

@ -1,6 +1,6 @@
# Deep learning.
# --------------
tensorflow==2.5.0
tensorflow==2.4.3
tensorflow-probability==0.12.2
torch==1.8.1;sys_platform=="darwin"
torchvision==0.9.1;sys_platform=="darwin"

View file

@ -1,6 +1,8 @@
base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
env_vars: {}
debian_packages: []
debian_packages:
- unzip
- zip
python:
# These dependencies should be handled by requirements_rllib.txt and
@ -10,3 +12,6 @@ python:
post_build_cmds:
- pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
# Clone the rl-experiments repo for offline-RL files.
- git clone https://github.com/ray-project/rl-experiments.git
- cp rl-experiments/halfcheetah-sac/2021-09-06/halfcheetah_expert_sac.zip ~/.

View file

@ -19,6 +19,41 @@ a2c-breakoutnoframeskip-v4:
[20000000, 0.000000000001],
]
a3c-pongdeterministic-v4:
env: PongDeterministic-v4
run: A3C
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 18.0
timesteps_total: 5000000
stop:
time_total_s: 3600
config:
num_gpus: 0
num_workers: 16
rollout_fragment_length: 20
vf_loss_coeff: 0.5
entropy_coeff: 0.01
gamma: 0.99
grad_clip: 40.0
lambda: 1.0
lr: 0.0001
observation_filter: NoFilter
preprocessor_pref: rllib
model:
use_lstm: true
conv_activation: elu
dim: 42
grayscale: true
zero_mean: false
# Reduced channel depth and kernel size from default.
conv_filters: [
[32, [3, 3], 2],
[32, [3, 3], 2],
[32, [3, 3], 2],
[32, [3, 3], 2],
]
apex-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
run: APEX
@ -39,8 +74,8 @@ apex-breakoutnoframeskip-v4:
hiddens: [512]
buffer_size: 1000000
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
@ -52,7 +87,7 @@ apex-breakoutnoframeskip-v4:
target_network_update_freq: 50000
timesteps_per_iteration: 25000
appo-pong-no-frameskip-v4:
appo-pongnoframeskip-v4:
env: PongNoFrameskip-v4
run: APPO
# Minimum reward and total ts (in given time_total_s) to pass this test.
@ -77,7 +112,103 @@ appo-pong-no-frameskip-v4:
num_gpus: 1
grad_clip: 10
model:
dim: 42
dim: 42
ars-hopperbulletenv-v0:
env: HopperBulletEnv-v0
run: ARS
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 100.0
timesteps_total: 2000000
stop:
time_total_s: 2000
config:
noise_stdev: 0.01
num_rollouts: 1
rollouts_used: 1
num_workers: 1
sgd_stepsize: 0.02
noise_size: 250000000
eval_prob: 0.2
offset: 0
observation_filter: NoFilter
report_length: 3
# bc-halfcheetahbulletenv-v0:
# env: HalfCheetahBulletEnv-v0
# run: BC
# pass_criteria:
# episode_reward_mean: 400.0
# timesteps_total: 10000000
# stop:
# time_total_s: 3600
# config:
# # Use input produced by expert SAC algo.
# input: ["~/halfcheetah_expert_sac.zip"]
# actions_in_input_normalized: true
# num_gpus: 1
# model:
# fcnet_activation: relu
# fcnet_hiddens: [256, 256, 256]
# evaluation_num_workers: 1
# evaluation_interval: 3
# evaluation_config:
# input: sampler
cql-halfcheetahbulletenv-v0:
env: HalfCheetahBulletEnv-v0
run: CQL
pass_criteria:
episode_reward_mean: 400.0
timesteps_total: 10000000
stop:
time_total_s: 3600
config:
# Use input produced by expert SAC algo.
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
soft_horizon: False
horizon: 1000
Q_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005
target_entropy: auto
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: false
train_batch_size: 256
target_network_update_freq: 0
timesteps_per_iteration: 1000
learning_starts: 256
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0001
num_workers: 0
num_gpus: 1
metrics_smoothing_episodes: 5
# CQL Configs
min_q_weight: 5.0
bc_iters: 20000
temperature: 1.0
num_actions: 10
lagrangian: False
# Switch on online evaluation.
evaluation_interval: 3
evaluation_config:
input: sampler
ddpg-hopperbulletenv-v0:
env: HopperBulletEnv-v0
@ -124,6 +255,45 @@ ddpg-hopperbulletenv-v0:
num_gpus_per_worker: 0
worker_side_prioritization: false
# Basically the same as atari-ppo, but adapted for DDPPO. Note that DDPPO
# isn't actually any more efficient on Atari, since the network size is
# relatively small and the env doesn't require a GPU.
ddppo-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
run: DDPPO
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 50.0
timesteps_total: 10000000
stop:
time_total_s: 3600
config:
# DDPPO only supports PyTorch so far.
framework: torch
# Worker config: 10 workers, each of which requires a GPU.
num_workers: 8
# Workers require GPUs, but share 1 GPU amongst 2 workers.
num_gpus_per_worker: 0.5
# Each worker will sample 100 * 5 envs per worker steps = 500 steps
# per optimization round. This is 5000 steps summed across workers.
rollout_fragment_length: 100
num_envs_per_worker: 5
# Each worker will take a minibatch of 50. There are 10 workers total,
# so the effective minibatch size will be 500.
sgd_minibatch_size: 50
num_sgd_iter: 10
# Params from standard PPO Atari config:
lambda: 0.95
kl_coeff: 0.5
clip_rewards: true
clip_param: 0.1
vf_clip_param: 10.0
entropy_coeff: 0.01
batch_mode: truncate_episodes
observation_filter: NoFilter
model:
vf_share_layers: true
dqn-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
run: DQN
@ -149,14 +319,26 @@ dqn-breakoutnoframeskip-v4:
rollout_fragment_length: 4
train_batch_size: 32
exploration_config:
epsilon_timesteps: 200000
final_epsilon: 0.01
epsilon_timesteps: 200000
final_epsilon: 0.01
prioritized_replay_alpha: 0.5
final_prioritized_replay_beta: 1.0
prioritized_replay_beta_annealing_timesteps: 2000000
num_gpus: 0.5
timesteps_per_iteration: 10000
es-humanoid-v2:
env: Humanoid-v2
run: ES
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 6000.0
timesteps_total: 10000000
stop:
time_total_s: 3600
config:
num_workers: 50
impala-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
run: IMPALA
@ -178,6 +360,32 @@ impala-breakoutnoframeskip-v4:
]
num_gpus: 1
# marwil-halfcheetahbulletenv-v0:
# env: HalfCheetahBulletEnv-v0
# run: MARWIL
# pass_criteria:
# episode_reward_mean: 400.0
# timesteps_total: 10000000
# stop:
# time_total_s: 3600
# config:
# # Use input produced by expert SAC algo.
# input: ["~/halfcheetah_expert_sac.zip"]
# actions_in_input_normalized: true
# # Switch off input evaluation (data does not contain action probs).
# input_evaluation: []
# num_gpus: 1
# model:
# fcnet_activation: relu
# fcnet_hiddens: [256, 256, 256]
# evaluation_num_workers: 1
# evaluation_interval: 1
# evaluation_config:
# input: sampler
ppo-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4
run: PPO
@ -212,23 +420,23 @@ sac-halfcheetahbulletenv-v0:
run: SAC
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 400.0
timesteps_total: 80000
episode_reward_mean: 600.0
timesteps_total: 100000
stop:
time_total_s: 7200
config:
horizon: 1000
soft_horizon: false
Q_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
fcnet_activation: relu
fcnet_hiddens: [256, 256]
policy_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256]
fcnet_activation: relu
fcnet_hiddens: [256, 256]
tau: 0.005
target_entropy: auto
no_done_at_end: true
n_step: 1
no_done_at_end: false
n_step: 3
rollout_fragment_length: 1
prioritized_replay: true
train_batch_size: 256
@ -236,12 +444,24 @@ sac-halfcheetahbulletenv-v0:
timesteps_per_iteration: 1000
learning_starts: 10000
optimization:
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0003
actor_learning_rate: 0.0003
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0003
num_workers: 0
num_gpus: 1
clip_actions: false
normalize_actions: true
evaluation_interval: 1
metrics_smoothing_episodes: 5
td3-halfcheetahbulletenv-v0:
env: HalfCheetahBulletEnv-v0
run: TD3
# Minimum reward and total ts (in given time_total_s) to pass this test.
pass_criteria:
episode_reward_mean: 400.0
timesteps_total: 1000000
stop:
time_total_s: 7200
config:
num_gpus: 1
learning_starts: 10000
exploration_config:
random_timesteps: 10000

View file

@ -11,6 +11,8 @@
smoke_test:
run:
timeout: 900
cluster:
compute_template: 4gpus_64cpus.yaml
# 2-GPU learning tests (CartPole and RepeatAfterMeEnv) for major algos.
- name: multi_gpu_learning_tests

View file

@ -10,7 +10,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
_, nn = try_import_torch()
torch, nn = try_import_torch()
class VisionNetwork(TorchModelV2, nn.Module):
@ -141,10 +141,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
else:
self.last_layer_is_flattened = True
layers.append(nn.Flatten())
self.num_outputs = out_channels
self._convs = nn.Sequential(*layers)
# If our num_outputs still unknown, we need to do a test pass to
# figure out the output dimensions. This could be the case, if we have
# the Flatten layer at the end.
if self.num_outputs is None:
# Create a B=1 dummy sample and push it through out conv-net.
dummy_in = torch.from_numpy(self.obs_space.sample()).permute(
2, 0, 1).unsqueeze(0).float()
dummy_out = self._convs(dummy_in)
self.num_outputs = dummy_out.shape[1]
# Build the value layers
self._value_branch_separate = self._value_branch = None
if vf_share_layers:

View file

@ -8,8 +8,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
"""Returns a framework specific activation function, given a name string.
Args:
name (Optional[str]): One of "relu" (default), "tanh", "swish", or
"linear" or None.
name (Optional[str]): One of "relu" (default), "tanh", "elu",
"swish", or "linear" (same as None).
framework (str): One of "jax", "tf|tfe|tf2" or "torch".
Returns:
@ -35,6 +35,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
return nn.ReLU
elif name == "tanh":
return nn.Tanh
elif name == "elu":
return nn.ELU
elif framework == "jax":
if name in ["linear", None]:
return None
@ -45,6 +47,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
return jax.nn.relu
elif name == "tanh":
return jax.nn.hard_tanh
elif name == "elu":
return jax.nn.elu
else:
assert framework in ["tf", "tfe", "tf2"],\
"Unsupported framework `{}`!".format(framework)

View file

@ -123,6 +123,10 @@ class JsonReader(InputReader):
"from URIs like {}".format(path))
ctx = smart_open
else:
# Allow shortcut for home directory ("~/" -> env[HOME]).
if path.startswith("~/"):
path = os.path.join(os.environ.get("HOME", ""), path[2:])
# If path doesn't exist, try to interpret is as relative to the
# rllib directory (located ../../ from this very module).
path_orig = path

View file

@ -130,9 +130,11 @@ if __name__ == "__main__":
t.stopping_criterion.get("episode_reward_mean"))
# Otherwise, expect `episode_reward_mean` to be set.
else:
min_reward = t.stopping_criterion["episode_reward_mean"]
min_reward = t.stopping_criterion.get(
"episode_reward_mean")
if reward_mean >= min_reward:
# If min reward not defined, always pass.
if min_reward is None or reward_mean >= min_reward:
passed = True
break

View file

@ -12,34 +12,33 @@ halfcheetah_cql:
#input: d4rl.halfcheetah-medium-v0
input: d4rl.halfcheetah-expert-v0
#input: d4rl.halfcheetah-medium-replay-v0
framework: torch
# Works for both torch and tf.
framework: tf
soft_horizon: False
horizon: 1000
Q_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
policy_model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
tau: 0.005
target_entropy: auto
no_done_at_end: false
n_step: 1
n_step: 3
rollout_fragment_length: 1
prioritized_replay: false
train_batch_size: 256
target_network_update_freq: 0
timesteps_per_iteration: 1000
learning_starts: 10
learning_starts: 256
optimization:
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0001
actor_learning_rate: 0.0001
critic_learning_rate: 0.0003
entropy_learning_rate: 0.0001
num_workers: 0
num_gpus: 1
clip_actions: false
normalize_actions: true
evaluation_interval: 1
metrics_smoothing_episodes: 5
# CQL Configs
min_q_weight: 5.0
@ -47,6 +46,7 @@ halfcheetah_cql:
temperature: 1.0
num_actions: 10
lagrangian: False
evaluation_interval: 3
evaluation_config:
input: sampler

View file

@ -406,7 +406,14 @@ def run_learning_tests_from_yaml(
# Add torch version of all experiments to the list.
for k, e in tf_experiments.items():
e["config"]["framework"] = "tf"
# If framework explicitly given, only test for that framework.
# Some algos do not have both versions available.
if "framework" in e["config"]:
frameworks = [e["config"]["framework"]]
else:
frameworks = ["tf", "torch"]
e["config"]["framework"] = "tf"
# For smoke-tests, we just run for n min.
if smoke_test:
# 15min hardcoded for now.
@ -420,15 +427,27 @@ def run_learning_tests_from_yaml(
e["stop"]["episode_reward_mean"] = \
e["pass_criteria"]["episode_reward_mean"]
keys = []
# Generate the torch copy of the experiment.
e_torch = copy.deepcopy(e)
e_torch["config"]["framework"] = "torch"
k_tf = re.sub("^(\\w+)-", "\\1-tf-", k)
k_torch = re.sub("-tf-", "-torch-", k_tf)
experiments[k_tf] = e
experiments[k_torch] = e_torch
# Generate `checks` dict.
for k_ in [k_tf, k_torch]:
if len(frameworks) == 2:
e_torch = copy.deepcopy(e)
e_torch["config"]["framework"] = "torch"
keys.append(re.sub("^(\\w+)-", "\\1-tf-", k))
keys.append(re.sub("-tf-", "-torch-", keys[0]))
experiments[keys[0]] = e
experiments[keys[1]] = e_torch
# tf-only.
elif frameworks[0] == "tf":
keys.append(re.sub("^(\\w+)-", "\\1-tf-", k))
experiments[keys[0]] = e
# torch-only.
else:
keys.append(re.sub("^(\\w+)-", "\\1-torch-", k))
experiments[keys[0]] = e
# Generate `checks` dict for all experiments (tf and/or torch).
for k_ in keys:
e = experiments[k_]
checks[k_] = {
"min_reward": e["pass_criteria"]["episode_reward_mean"],
"min_timesteps": e["pass_criteria"]["timesteps_total"],
@ -436,9 +455,8 @@ def run_learning_tests_from_yaml(
"failures": 0,
"passed": False,
}
# These keys would break tune.
del e["pass_criteria"]
del e_torch["pass_criteria"]
# This key would break tune.
del e["pass_criteria"]
# Print out the actual config.
print("== Test config ==")
@ -471,21 +489,32 @@ def run_learning_tests_from_yaml(
for t in trials:
experiment = re.sub(".+/([^/]+)$", "\\1", t.local_dir)
# If we have evaluation workers, use their rewards.
# This is useful for offline learning tests, where
# we evaluate against an actual environment.
check_eval = experiments[experiment]["config"].get(
"evaluation_interval", None) is not None
if t.status == "ERROR":
checks[experiment]["failures"] += 1
else:
reward_mean = \
t.last_result["evaluation"]["episode_reward_mean"] if \
check_eval else t.last_result["episode_reward_mean"]
desired_reward = checks[experiment]["min_reward"]
desired_timesteps = checks[experiment]["min_timesteps"]
throughput = t.last_result["timesteps_total"] / \
t.last_result["time_total_s"]
desired_timesteps = checks[experiment]["min_timesteps"]
desired_throughput = \
desired_timesteps / t.stopping_criterion["time_total_s"]
if t.last_result["episode_reward_mean"] < desired_reward or \
desired_throughput and throughput < desired_throughput:
# We failed to reach desired reward or the desired throughput.
if reward_mean < desired_reward or \
(desired_throughput and
throughput < desired_throughput):
checks[experiment]["failures"] += 1
# We succeeded!
else:
checks[experiment]["passed"] = True
del experiments_to_run[experiment]