[RLlib Testing] Add A3C/APPO/BC/DDPPO/MARWIL/CQL/ES/ARS/TD3 to weekly learning tests. (#18381)
Parent: 64040a90a5
Commit: cabaa3b3c6
10 changed files with 333 additions and 58 deletions
@@ -1,6 +1,6 @@
# Deep learning.
# --------------
tensorflow==2.5.0
tensorflow==2.4.3
tensorflow-probability==0.12.2
torch==1.8.1;sys_platform=="darwin"
torchvision==0.9.1;sys_platform=="darwin"
@ -1,6 +1,8 @@
|
|||
base_image: "anyscale/ray-ml:pinned-nightly-py37-gpu"
|
||||
env_vars: {}
|
||||
debian_packages: []
|
||||
debian_packages:
|
||||
- unzip
|
||||
- zip
|
||||
|
||||
python:
|
||||
# These dependencies should be handled by requirements_rllib.txt and
|
||||
|
@ -10,3 +12,6 @@ python:
|
|||
|
||||
post_build_cmds:
|
||||
- pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
|
||||
# Clone the rl-experiments repo for offline-RL files.
|
||||
- git clone https://github.com/ray-project/rl-experiments.git
|
||||
- cp rl-experiments/halfcheetah-sac/2021-09-06/halfcheetah_expert_sac.zip ~/.
|
||||
|
|
|
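The post_build_cmds above stage the expert SAC dataset in the image's home directory so that the offline-RL entries further down can reference it as "~/halfcheetah_expert_sac.zip". A small, hypothetical sanity check (not part of this commit) that the file landed where those configs expect it:

import os

# Hypothetical sanity check: confirm the expert dataset staged by
# post_build_cmds is where the offline-RL configs expect it.
expected = os.path.expanduser("~/halfcheetah_expert_sac.zip")
if not os.path.isfile(expected):
    raise FileNotFoundError(
        f"Expected offline dataset at {expected}; "
        "did the app-config post_build_cmds run?")
print(f"Found offline dataset: {expected}")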
@@ -19,6 +19,41 @@ a2c-breakoutnoframeskip-v4:
            [20000000, 0.000000000001],
        ]

a3c-pongdeterministic-v4:
    env: PongDeterministic-v4
    run: A3C
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 18.0
        timesteps_total: 5000000
    stop:
        time_total_s: 3600
    config:
        num_gpus: 0
        num_workers: 16
        rollout_fragment_length: 20
        vf_loss_coeff: 0.5
        entropy_coeff: 0.01
        gamma: 0.99
        grad_clip: 40.0
        lambda: 1.0
        lr: 0.0001
        observation_filter: NoFilter
        preprocessor_pref: rllib
        model:
            use_lstm: true
            conv_activation: elu
            dim: 42
            grayscale: true
            zero_mean: false
            # Reduced channel depth and kernel size from default.
            conv_filters: [
                [32, [3, 3], 2],
                [32, [3, 3], 2],
                [32, [3, 3], 2],
                [32, [3, 3], 2],
            ]

apex-breakoutnoframeskip-v4:
    env: BreakoutNoFrameskip-v4
    run: APEX
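The new A3C entry above uses the same pass_criteria / stop / config layout as the other weekly tests. As a rough illustration of how such an entry maps onto a plain Tune run, here is a minimal sketch (not the release-test harness itself; the harness changed further below additionally duplicates each entry per framework, shortens smoke tests, and tracks throughput):

import ray
from ray import tune

# Minimal sketch: launch the A3C Pong entry above as a plain Tune experiment.
# Note that "pass_criteria" is bookkeeping for the weekly harness and must NOT
# be passed to Tune; only "stop" and "config" are Tune/RLlib arguments.
config = {
    "env": "PongDeterministic-v4",
    "num_gpus": 0,
    "num_workers": 16,
    "rollout_fragment_length": 20,
    "vf_loss_coeff": 0.5,
    "entropy_coeff": 0.01,
    "gamma": 0.99,
    "grad_clip": 40.0,
    "lambda": 1.0,
    "lr": 0.0001,
    "observation_filter": "NoFilter",
    "preprocessor_pref": "rllib",
    # (model sub-config trimmed for brevity; see the YAML entry above.)
    "framework": "tf",  # the harness runs both tf and torch variants
}

ray.init()
tune.run("A3C", config=config, stop={"time_total_s": 3600})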
@@ -39,8 +74,8 @@ apex-breakoutnoframeskip-v4:
        hiddens: [512]
        buffer_size: 1000000
        exploration_config:
            epsilon_timesteps: 200000
            final_epsilon: 0.01
            epsilon_timesteps: 200000
            final_epsilon: 0.01
        prioritized_replay_alpha: 0.5
        final_prioritized_replay_beta: 1.0
        prioritized_replay_beta_annealing_timesteps: 2000000

@@ -52,7 +87,7 @@ apex-breakoutnoframeskip-v4:
        target_network_update_freq: 50000
        timesteps_per_iteration: 25000

appo-pong-no-frameskip-v4:
appo-pongnoframeskip-v4:
    env: PongNoFrameskip-v4
    run: APPO
    # Minimum reward and total ts (in given time_total_s) to pass this test.
@@ -77,7 +112,103 @@ appo-pong-no-frameskip-v4:
        num_gpus: 1
        grad_clip: 10
        model:
            dim: 42
            dim: 42

ars-hopperbulletenv-v0:
    env: HopperBulletEnv-v0
    run: ARS
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 100.0
        timesteps_total: 2000000
    stop:
        time_total_s: 2000
    config:
        noise_stdev: 0.01
        num_rollouts: 1
        rollouts_used: 1
        num_workers: 1
        sgd_stepsize: 0.02
        noise_size: 250000000
        eval_prob: 0.2
        offset: 0
        observation_filter: NoFilter
        report_length: 3

# bc-halfcheetahbulletenv-v0:
#     env: HalfCheetahBulletEnv-v0
#     run: BC
#     pass_criteria:
#         episode_reward_mean: 400.0
#         timesteps_total: 10000000
#     stop:
#         time_total_s: 3600
#     config:
#         # Use input produced by expert SAC algo.
#         input: ["~/halfcheetah_expert_sac.zip"]
#         actions_in_input_normalized: true

#         num_gpus: 1

#         model:
#             fcnet_activation: relu
#             fcnet_hiddens: [256, 256, 256]

#         evaluation_num_workers: 1
#         evaluation_interval: 3
#         evaluation_config:
#             input: sampler

cql-halfcheetahbulletenv-v0:
    env: HalfCheetahBulletEnv-v0
    run: CQL
    pass_criteria:
        episode_reward_mean: 400.0
        timesteps_total: 10000000
    stop:
        time_total_s: 3600
    config:
        # Use input produced by expert SAC algo.
        input: ["~/halfcheetah_expert_sac.zip"]
        actions_in_input_normalized: true

        soft_horizon: False
        horizon: 1000
        Q_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
        policy_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
        tau: 0.005
        target_entropy: auto
        no_done_at_end: false
        n_step: 3
        rollout_fragment_length: 1
        prioritized_replay: false
        train_batch_size: 256
        target_network_update_freq: 0
        timesteps_per_iteration: 1000
        learning_starts: 256
        optimization:
            actor_learning_rate: 0.0001
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0001
        num_workers: 0
        num_gpus: 1
        metrics_smoothing_episodes: 5

        # CQL Configs
        min_q_weight: 5.0
        bc_iters: 20000
        temperature: 1.0
        num_actions: 10
        lagrangian: False

        # Switch on online evaluation.
        evaluation_interval: 3
        evaluation_config:
            input: sampler

ddpg-hopperbulletenv-v0:
    env: HopperBulletEnv-v0
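The CQL entry above trains purely from the recorded SAC transitions referenced by input, while evaluation_config switches the evaluation workers back to input: sampler so the learned policy is scored against the real environment. A minimal stand-alone sketch of that offline-training / online-evaluation split (assumes pybullet is installed and the expert zip was staged by the app config above):

import ray
from ray import tune

# Minimal sketch mirroring the cql-halfcheetahbulletenv-v0 entry above:
# train offline from the expert SAC data, but evaluate online in the real env.
config = {
    "env": "HalfCheetahBulletEnv-v0",           # requires pybullet/pybullet_envs
    "input": ["~/halfcheetah_expert_sac.zip"],  # offline data only
    "actions_in_input_normalized": True,
    "framework": "torch",
    "num_workers": 0,
    # Evaluation workers sample from the actual env instead of the dataset.
    "evaluation_interval": 3,
    "evaluation_num_workers": 1,
    "evaluation_config": {"input": "sampler"},
}

ray.init()
tune.run("CQL", config=config, stop={"time_total_s": 3600})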
@@ -124,6 +255,45 @@ ddpg-hopperbulletenv-v0:
        num_gpus_per_worker: 0
        worker_side_prioritization: false

# Basically the same as atari-ppo, but adapted for DDPPO. Note that DDPPO
# isn't actually any more efficient on Atari, since the network size is
# relatively small and the env doesn't require a GPU.
ddppo-breakoutnoframeskip-v4:
    env: BreakoutNoFrameskip-v4
    run: DDPPO
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 50.0
        timesteps_total: 10000000
    stop:
        time_total_s: 3600
    config:
        # DDPPO only supports PyTorch so far.
        framework: torch
        # Worker config: 10 workers, each of which requires a GPU.
        num_workers: 8
        # Workers require GPUs, but share 1 GPU amongst 2 workers.
        num_gpus_per_worker: 0.5
        # Each worker will sample 100 * 5 envs per worker steps = 500 steps
        # per optimization round. This is 5000 steps summed across workers.
        rollout_fragment_length: 100
        num_envs_per_worker: 5
        # Each worker will take a minibatch of 50. There are 10 workers total,
        # so the effective minibatch size will be 500.
        sgd_minibatch_size: 50
        num_sgd_iter: 10
        # Params from standard PPO Atari config:
        lambda: 0.95
        kl_coeff: 0.5
        clip_rewards: true
        clip_param: 0.1
        vf_clip_param: 10.0
        entropy_coeff: 0.01
        batch_mode: truncate_episodes
        observation_filter: NoFilter
        model:
            vf_share_layers: true

dqn-breakoutnoframeskip-v4:
    env: BreakoutNoFrameskip-v4
    run: DQN
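The inline comments above still describe the 10-worker atari-ppo setup they were adapted from; with num_workers: 8 as configured here, each optimization round collects 8 * 100 * 5 = 4000 steps and the effective SGD minibatch is 8 * 50 = 400. A small sketch of that bookkeeping:

# Sketch: sample and minibatch arithmetic for the DDPPO entry above.
num_workers = 8
num_envs_per_worker = 5
rollout_fragment_length = 100
sgd_minibatch_size = 50
num_gpus_per_worker = 0.5

steps_per_worker = rollout_fragment_length * num_envs_per_worker  # 500
steps_per_round = steps_per_worker * num_workers                  # 4000
effective_minibatch = sgd_minibatch_size * num_workers            # 400
total_gpus = num_workers * num_gpus_per_worker                    # 4.0

print(steps_per_worker, steps_per_round, effective_minibatch, total_gpus)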
@@ -149,14 +319,26 @@ dqn-breakoutnoframeskip-v4:
        rollout_fragment_length: 4
        train_batch_size: 32
        exploration_config:
            epsilon_timesteps: 200000
            final_epsilon: 0.01
            epsilon_timesteps: 200000
            final_epsilon: 0.01
        prioritized_replay_alpha: 0.5
        final_prioritized_replay_beta: 1.0
        prioritized_replay_beta_annealing_timesteps: 2000000
        num_gpus: 0.5
        timesteps_per_iteration: 10000

es-humanoid-v2:
    env: Humanoid-v2
    run: ES
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 6000.0
        timesteps_total: 10000000
    stop:
        time_total_s: 3600
    config:
        num_workers: 50

impala-breakoutnoframeskip-v4:
    env: BreakoutNoFrameskip-v4
    run: IMPALA
@@ -178,6 +360,32 @@ impala-breakoutnoframeskip-v4:
        ]
        num_gpus: 1

# marwil-halfcheetahbulletenv-v0:
#     env: HalfCheetahBulletEnv-v0
#     run: MARWIL
#     pass_criteria:
#         episode_reward_mean: 400.0
#         timesteps_total: 10000000
#     stop:
#         time_total_s: 3600
#     config:
#         # Use input produced by expert SAC algo.
#         input: ["~/halfcheetah_expert_sac.zip"]
#         actions_in_input_normalized: true
#         # Switch off input evaluation (data does not contain action probs).
#         input_evaluation: []

#         num_gpus: 1

#         model:
#             fcnet_activation: relu
#             fcnet_hiddens: [256, 256, 256]

#         evaluation_num_workers: 1
#         evaluation_interval: 1
#         evaluation_config:
#             input: sampler

ppo-breakoutnoframeskip-v4:
    env: BreakoutNoFrameskip-v4
    run: PPO
@@ -212,23 +420,23 @@ sac-halfcheetahbulletenv-v0:
    run: SAC
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 400.0
        timesteps_total: 80000
        episode_reward_mean: 600.0
        timesteps_total: 100000
    stop:
        time_total_s: 7200
    config:
        horizon: 1000
        soft_horizon: false
        Q_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
        policy_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
            fcnet_activation: relu
            fcnet_hiddens: [256, 256]
        tau: 0.005
        target_entropy: auto
        no_done_at_end: true
        n_step: 1
        no_done_at_end: false
        n_step: 3
        rollout_fragment_length: 1
        prioritized_replay: true
        train_batch_size: 256

@@ -236,12 +444,24 @@ sac-halfcheetahbulletenv-v0:
        timesteps_per_iteration: 1000
        learning_starts: 10000
        optimization:
            actor_learning_rate: 0.0003
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0003
            actor_learning_rate: 0.0003
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0003
        num_workers: 0
        num_gpus: 1
        clip_actions: false
        normalize_actions: true
        evaluation_interval: 1
        metrics_smoothing_episodes: 5

td3-halfcheetahbulletenv-v0:
    env: HalfCheetahBulletEnv-v0
    run: TD3
    # Minimum reward and total ts (in given time_total_s) to pass this test.
    pass_criteria:
        episode_reward_mean: 400.0
        timesteps_total: 1000000
    stop:
        time_total_s: 7200
    config:
        num_gpus: 1
        learning_starts: 10000
        exploration_config:
            random_timesteps: 10000
@@ -11,6 +11,8 @@
  smoke_test:
    run:
      timeout: 900
    cluster:
      compute_template: 4gpus_64cpus.yaml

# 2-GPU learning tests (CartPole and RepeatAfterMeEnv) for major algos.
- name: multi_gpu_learning_tests
@@ -10,7 +10,7 @@ from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType

_, nn = try_import_torch()
torch, nn = try_import_torch()


class VisionNetwork(TorchModelV2, nn.Module):
@@ -141,10 +141,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
        else:
            self.last_layer_is_flattened = True
            layers.append(nn.Flatten())
            self.num_outputs = out_channels

        self._convs = nn.Sequential(*layers)

        # If our num_outputs still unknown, we need to do a test pass to
        # figure out the output dimensions. This could be the case, if we have
        # the Flatten layer at the end.
        if self.num_outputs is None:
            # Create a B=1 dummy sample and push it through out conv-net.
            dummy_in = torch.from_numpy(self.obs_space.sample()).permute(
                2, 0, 1).unsqueeze(0).float()
            dummy_out = self._convs(dummy_in)
            self.num_outputs = dummy_out.shape[1]

        # Build the value layers
        self._value_branch_separate = self._value_branch = None
        if vf_share_layers:
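The new branch above infers the flattened output size by pushing a single dummy observation through the conv stack instead of relying on the last layer's out_channels. The same pattern in isolation, as a stand-alone PyTorch sketch (not RLlib code):

import torch
from torch import nn

# Stand-alone sketch of the "dummy forward pass" trick used above: when a
# Flatten layer ends the stack, the flat output size is easiest to obtain by
# running one fake observation through the network.
convs = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=4, stride=2),
    nn.ReLU(),
    nn.Flatten(),
)

# Fake [H, W, C] observation (e.g. a 42x42 RGB frame), reordered to [B, C, H, W].
dummy_obs = torch.rand(42, 42, 3)
dummy_in = dummy_obs.permute(2, 0, 1).unsqueeze(0).float()
num_outputs = convs(dummy_in).shape[1]
print(num_outputs)  # flattened feature size inferred at runtime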
@@ -8,8 +8,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu" (default), "tanh", "swish", or
            "linear" or None.
        name (Optional[str]): One of "relu" (default), "tanh", "elu",
            "swish", or "linear" (same as None).
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
@@ -35,6 +35,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
            return nn.ReLU
        elif name == "tanh":
            return nn.Tanh
        elif name == "elu":
            return nn.ELU
    elif framework == "jax":
        if name in ["linear", None]:
            return None

@@ -45,6 +47,8 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
            return jax.nn.relu
        elif name == "tanh":
            return jax.nn.hard_tanh
        elif name == "elu":
            return jax.nn.elu
    else:
        assert framework in ["tf", "tfe", "tf2"],\
            "Unsupported framework `{}`!".format(framework)
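Both the torch and jax branches now resolve "elu" in addition to "relu" and "tanh", which the new A3C model config above (conv_activation: elu) relies on. A stand-alone sketch of the same name-to-class dispatch for the torch case (a mirror of the idea, not the RLlib helper itself):

from typing import Optional, Type

from torch import nn

# Stand-alone mirror of the torch branch of the activation lookup, including
# the newly supported "elu".
_TORCH_ACTIVATIONS = {
    "relu": nn.ReLU,
    "tanh": nn.Tanh,
    "elu": nn.ELU,
}

def torch_activation(name: Optional[str] = None) -> Optional[Type[nn.Module]]:
    if name in ("linear", None):
        return None  # "linear" means: no activation layer at all.
    try:
        return _TORCH_ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f"Unknown activation {name!r}!")

print(torch_activation("elu"))  # <class 'torch.nn.modules.activation.ELU'>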
@@ -123,6 +123,10 @@ class JsonReader(InputReader):
                    "from URIs like {}".format(path))
            ctx = smart_open
        else:
            # Allow shortcut for home directory ("~/" -> env[HOME]).
            if path.startswith("~/"):
                path = os.path.join(os.environ.get("HOME", ""), path[2:])

            # If path doesn't exist, try to interpret is as relative to the
            # rllib directory (located ../../ from this very module).
            path_orig = path
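This lets offline inputs such as the "~/halfcheetah_expert_sac.zip" entries in the configs above resolve to the user's home directory even when no shell ever expands the tilde. The same effect in isolation (a sketch; the general-purpose os.path.expanduser would also cover "~user" forms):

import os

def expand_home(path: str) -> str:
    # Sketch of the "~/" shortcut handling: map "~/foo" to $HOME/foo and leave
    # every other path untouched.
    if path.startswith("~/"):
        return os.path.join(os.environ.get("HOME", ""), path[2:])
    return path

print(expand_home("~/halfcheetah_expert_sac.zip"))
print(expand_home("/tmp/data.json"))  # unchanged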
@@ -130,9 +130,11 @@ if __name__ == "__main__":
                t.stopping_criterion.get("episode_reward_mean"))
        # Otherwise, expect `episode_reward_mean` to be set.
        else:
            min_reward = t.stopping_criterion["episode_reward_mean"]
            min_reward = t.stopping_criterion.get(
                "episode_reward_mean")

        if reward_mean >= min_reward:
        # If min reward not defined, always pass.
        if min_reward is None or reward_mean >= min_reward:
            passed = True
            break
@@ -12,34 +12,33 @@ halfcheetah_cql:
        #input: d4rl.halfcheetah-medium-v0
        input: d4rl.halfcheetah-expert-v0
        #input: d4rl.halfcheetah-medium-replay-v0
        framework: torch

        # Works for both torch and tf.
        framework: tf
        soft_horizon: False
        horizon: 1000
        Q_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
        policy_model:
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
            fcnet_activation: relu
            fcnet_hiddens: [256, 256, 256]
        tau: 0.005
        target_entropy: auto
        no_done_at_end: false
        n_step: 1
        n_step: 3
        rollout_fragment_length: 1
        prioritized_replay: false
        train_batch_size: 256
        target_network_update_freq: 0
        timesteps_per_iteration: 1000
        learning_starts: 10
        learning_starts: 256
        optimization:
            actor_learning_rate: 0.0001
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0001
            actor_learning_rate: 0.0001
            critic_learning_rate: 0.0003
            entropy_learning_rate: 0.0001
        num_workers: 0
        num_gpus: 1
        clip_actions: false
        normalize_actions: true
        evaluation_interval: 1
        metrics_smoothing_episodes: 5
        # CQL Configs
        min_q_weight: 5.0

@@ -47,6 +46,7 @@ halfcheetah_cql:
        temperature: 1.0
        num_actions: 10
        lagrangian: False

        evaluation_interval: 3
        evaluation_config:
            input: sampler
@@ -406,7 +406,14 @@ def run_learning_tests_from_yaml(

    # Add torch version of all experiments to the list.
    for k, e in tf_experiments.items():
        e["config"]["framework"] = "tf"
        # If framework explicitly given, only test for that framework.
        # Some algos do not have both versions available.
        if "framework" in e["config"]:
            frameworks = [e["config"]["framework"]]
        else:
            frameworks = ["tf", "torch"]
            e["config"]["framework"] = "tf"

        # For smoke-tests, we just run for n min.
        if smoke_test:
            # 15min hardcoded for now.
@@ -420,15 +427,27 @@ def run_learning_tests_from_yaml(
            e["stop"]["episode_reward_mean"] = \
                e["pass_criteria"]["episode_reward_mean"]

        keys = []
        # Generate the torch copy of the experiment.
        e_torch = copy.deepcopy(e)
        e_torch["config"]["framework"] = "torch"
        k_tf = re.sub("^(\\w+)-", "\\1-tf-", k)
        k_torch = re.sub("-tf-", "-torch-", k_tf)
        experiments[k_tf] = e
        experiments[k_torch] = e_torch
        # Generate `checks` dict.
        for k_ in [k_tf, k_torch]:
        if len(frameworks) == 2:
            e_torch = copy.deepcopy(e)
            e_torch["config"]["framework"] = "torch"
            keys.append(re.sub("^(\\w+)-", "\\1-tf-", k))
            keys.append(re.sub("-tf-", "-torch-", keys[0]))
            experiments[keys[0]] = e
            experiments[keys[1]] = e_torch
        # tf-only.
        elif frameworks[0] == "tf":
            keys.append(re.sub("^(\\w+)-", "\\1-tf-", k))
            experiments[keys[0]] = e
        # torch-only.
        else:
            keys.append(re.sub("^(\\w+)-", "\\1-torch-", k))
            experiments[keys[0]] = e

        # Generate `checks` dict for all experiments (tf and/or torch).
        for k_ in keys:
            e = experiments[k_]
            checks[k_] = {
                "min_reward": e["pass_criteria"]["episode_reward_mean"],
                "min_timesteps": e["pass_criteria"]["timesteps_total"],
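With the new keys list, an experiment is duplicated into -tf- and -torch- variants only when its YAML entry does not pin a framework; DDPPO, for example, pins framework: torch and therefore gets a single -torch- key. A stand-alone sketch of that key-generation rule (hypothetical helper name; the real logic lives inside run_learning_tests_from_yaml):

import copy
import re

def expand_frameworks(name, experiment):
    """Sketch: return {test_key: experiment} variants for one YAML entry."""
    config = experiment["config"]
    if "framework" in config:
        frameworks = [config["framework"]]   # pinned: test only that framework
    else:
        frameworks = ["tf", "torch"]
        config["framework"] = "tf"           # original copy becomes the tf run

    out = {}
    for fw in frameworks:
        e = experiment if fw == config["framework"] else copy.deepcopy(experiment)
        e["config"]["framework"] = fw
        out[re.sub("^(\\w+)-", "\\1-%s-" % fw, name)] = e
    return out

print(expand_frameworks("a3c-pongdeterministic-v4", {"config": {}}).keys())
# dict_keys(['a3c-tf-pongdeterministic-v4', 'a3c-torch-pongdeterministic-v4'])
print(expand_frameworks("ddppo-breakoutnoframeskip-v4",
                        {"config": {"framework": "torch"}}).keys())
# dict_keys(['ddppo-torch-breakoutnoframeskip-v4'])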
@@ -436,9 +455,8 @@ def run_learning_tests_from_yaml(
                "failures": 0,
                "passed": False,
            }
        # These keys would break tune.
        del e["pass_criteria"]
        del e_torch["pass_criteria"]
            # This key would break tune.
            del e["pass_criteria"]

    # Print out the actual config.
    print("== Test config ==")
@@ -471,21 +489,32 @@ def run_learning_tests_from_yaml(
        for t in trials:
            experiment = re.sub(".+/([^/]+)$", "\\1", t.local_dir)

            # If we have evaluation workers, use their rewards.
            # This is useful for offline learning tests, where
            # we evaluate against an actual environment.
            check_eval = experiments[experiment]["config"].get(
                "evaluation_interval", None) is not None

            if t.status == "ERROR":
                checks[experiment]["failures"] += 1
            else:
                reward_mean = \
                    t.last_result["evaluation"]["episode_reward_mean"] if \
                    check_eval else t.last_result["episode_reward_mean"]
                desired_reward = checks[experiment]["min_reward"]
                desired_timesteps = checks[experiment]["min_timesteps"]

                throughput = t.last_result["timesteps_total"] / \
                    t.last_result["time_total_s"]

                desired_timesteps = checks[experiment]["min_timesteps"]
                desired_throughput = \
                    desired_timesteps / t.stopping_criterion["time_total_s"]

                if t.last_result["episode_reward_mean"] < desired_reward or \
                        desired_throughput and throughput < desired_throughput:
                # We failed to reach desired reward or the desired throughput.
                if reward_mean < desired_reward or \
                        (desired_throughput and
                         throughput < desired_throughput):
                    checks[experiment]["failures"] += 1
                # We succeeded!
                else:
                    checks[experiment]["passed"] = True
                    del experiments_to_run[experiment]
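A trial now passes only if its mean reward (taken from the evaluation workers for offline tests) reaches min_reward and its measured throughput reaches the rate implied by pass_criteria, i.e. min_timesteps divided by the configured time_total_s. A small stand-alone sketch of that check with made-up numbers (not the harness code itself):

# Sketch of the pass/fail rule applied to one finished trial.
def trial_passes(last_result, min_reward, min_timesteps, time_total_s,
                 check_eval=False):
    # Offline tests evaluate against a real env via evaluation workers,
    # so their reward lives under last_result["evaluation"].
    reward_mean = (last_result["evaluation"]["episode_reward_mean"]
                   if check_eval else last_result["episode_reward_mean"])
    throughput = last_result["timesteps_total"] / last_result["time_total_s"]
    desired_throughput = min_timesteps / time_total_s
    return reward_mean >= min_reward and throughput >= desired_throughput

# Example with made-up numbers for the A3C entry above
# (pass_criteria: 18.0 reward / 5M timesteps within 3600s).
result = {"episode_reward_mean": 19.2, "timesteps_total": 5_400_000,
          "time_total_s": 3500.0}
print(trial_passes(result, min_reward=18.0, min_timesteps=5_000_000,
                   time_total_s=3600))  # True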