[RLlib] Bring back BC and Marwil learning tests. (#21574)

This commit is contained in:
Jun Gong 2022-01-14 05:35:32 -08:00 committed by GitHub
parent ded4128ebf
commit 7517aefe05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 70 additions and 58 deletions

View file

@ -136,37 +136,35 @@ appo-pongnoframeskip-v4:
# observation_filter: NoFilter
# report_length: 3
# TODO: (sven) Fix all BC-dependent learning tests for cont. actions.
# These seem quite hard to learn from the SAC-recorded HalfCheetahBulletEnv.
# bc-halfcheetahbulletenv-v0:
# env: HalfCheetahBulletEnv-v0
# run: BC
# pass_criteria:
# episode_reward_mean: 400.0
# timesteps_total: 10000000
# stop:
# time_total_s: 3600
# config:
# # Use input produced by expert SAC algo.
# input: ["~/halfcheetah_expert_sac.zip"]
# actions_in_input_normalized: true
bc-halfcheetahbulletenv-v0:
env: HalfCheetahBulletEnv-v0
run: BC
pass_criteria:
evaluation/episode_reward_mean: 400.0
timesteps_total: 10000000
stop:
time_total_s: 3600
config:
# Use input produced by expert SAC algo.
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# num_gpus: 1
num_gpus: 1
# model:
# fcnet_activation: relu
# fcnet_hiddens: [256, 256, 256]
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
# evaluation_num_workers: 1
# evaluation_interval: 3
# evaluation_config:
# input: sampler
evaluation_num_workers: 1
evaluation_interval: 3
evaluation_config:
input: sampler
cql-halfcheetahbulletenv-v0:
env: HalfCheetahBulletEnv-v0
run: CQL
pass_criteria:
episode_reward_mean: 400.0
evaluation/episode_reward_mean: 400.0
timesteps_total: 10000000
stop:
time_total_s: 3600
@ -363,31 +361,31 @@ impala-breakoutnoframeskip-v4:
]
num_gpus: 1
# marwil-halfcheetahbulletenv-v0:
# env: HalfCheetahBulletEnv-v0
# run: MARWIL
# pass_criteria:
# episode_reward_mean: 400.0
# timesteps_total: 10000000
# stop:
# time_total_s: 3600
# config:
# # Use input produced by expert SAC algo.
# input: ["~/halfcheetah_expert_sac.zip"]
# actions_in_input_normalized: true
# # Switch off input evaluation (data does not contain action probs).
# input_evaluation: []
marwil-halfcheetahbulletenv-v0:
env: HalfCheetahBulletEnv-v0
run: MARWIL
pass_criteria:
evaluation/episode_reward_mean: 400.0
timesteps_total: 10000000
stop:
time_total_s: 3600
config:
# Use input produced by expert SAC algo.
input: ["~/halfcheetah_expert_sac.zip"]
actions_in_input_normalized: true
# Switch off input evaluation (data does not contain action probs).
input_evaluation: []
# num_gpus: 1
num_gpus: 1
# model:
# fcnet_activation: relu
# fcnet_hiddens: [256, 256, 256]
model:
fcnet_activation: relu
fcnet_hiddens: [256, 256, 256]
# evaluation_num_workers: 1
# evaluation_interval: 1
# evaluation_config:
# input: sampler
evaluation_num_workers: 1
evaluation_interval: 1
evaluation_config:
input: sampler
ppo-breakoutnoframeskip-v4:
env: BreakoutNoFrameskip-v4

View file

@ -189,7 +189,7 @@ class TestTrainer(unittest.TestCase):
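# No evaluation has actually been run at step 0, but the "evaluation"
# key must already be present in the results (its values may be nan).
# Step 3 should still have it as well, even though no eval was
# run during that step.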
self.assertFalse("evaluation" in r0)
self.assertTrue("evaluation" in r0)
self.assertTrue("evaluation" in r1)
self.assertTrue("evaluation" in r2)
self.assertTrue("evaluation" in r3)

View file

@ -724,10 +724,19 @@ class Trainer(Trainable):
self._episode_history = []
self._episodes_to_be_collected = []
# Evaluation WorkerSet.
# Evaluation WorkerSet and metrics last returned by `self.evaluate()`.
self.evaluation_workers: Optional[WorkerSet] = None
# Metrics most recently returned by `self.evaluate()`.
self.evaluation_metrics = {}
# Initialize common evaluation_metrics to nan, before they become
# available. We want to make sure the metrics are always present
# (although their values may be nan), so that Tune does not complain
# when we use these as stopping criteria.
self.evaluation_metrics = {
"evaluation": {
"episode_reward_max": np.nan,
"episode_reward_min": np.nan,
"episode_reward_mean": np.nan,
}
}
super().__init__(config, logger_creator, remote_checkpoint_dir,
sync_function_tpl)

View file

@ -74,7 +74,7 @@ class AssertEvalCallback(DefaultCallbacks):
# Make sure we always run exactly the given evaluation duration,
# no matter what the other settings are (such as
# `evaluation_num_workers` or `evaluation_parallel_to_training`).
if "evaluation" in result:
if "evaluation" in result and "hist_stats" in result["evaluation"]:
hist_stats = result["evaluation"]["hist_stats"]
# We count in episodes.
if trainer.config["evaluation_duration_unit"] == "episodes":

View file

@ -607,6 +607,13 @@ def run_learning_tests_from_yaml(
start_time = time.monotonic()
def should_check_eval(experiment):
    """Return whether the experiment's evaluation rewards should be checked.

    If we have evaluation workers, use their rewards. This is useful for
    offline learning tests, where we evaluate against an actual environment.

    Args:
        experiment: Experiment dict (as loaded from a yaml file). Its
            optional "config" sub-dict is inspected for the
            "evaluation_interval" key.

    Returns:
        True if "evaluation_interval" is set to a non-None value in the
        experiment's config, False otherwise (including when the experiment
        has no config section or an empty one).
    """
    # An experiment without a `config` section (or with an empty
    # `config:` entry, which yaml parses as None) has no evaluation setup.
    config = experiment.get("config") or {}
    return config.get("evaluation_interval") is not None
# Loop through all collected files and gather experiments.
# Augment all by `torch` framework.
for yaml_file in yaml_files:
@ -637,11 +644,13 @@ def run_learning_tests_from_yaml(
# create its trainer and run a first iteration.
e["stop"]["time_total_s"] = 0
else:
check_eval = should_check_eval(e)
episode_reward_key = ("episode_reward_mean" if not check_eval
else "evaluation/episode_reward_mean")
# We also stop early, once we reach the desired reward.
min_reward = e.get("pass_criteria",
{}).get("episode_reward_mean")
min_reward = e.get("pass_criteria", {}).get(episode_reward_key)
if min_reward is not None:
e["stop"]["episode_reward_mean"] = min_reward
e["stop"][episode_reward_key] = min_reward
# Generate `checks` dict for all experiments
# (tf, tf2 and/or torch).
@ -723,11 +732,7 @@ def run_learning_tests_from_yaml(
trials_for_experiment.append(t)
print(f" ... Trials: {trials_for_experiment}.")
# If we have evaluation workers, use their rewards.
# This is useful for offline learning tests, where
# we evaluate against an actual environment.
check_eval = experiments[experiment]["config"].get(
"evaluation_interval", None) is not None
check_eval = should_check_eval(experiments[experiment])
# Error: Increase failure count and repeat.
if any(t.status == "ERROR" for t in trials_for_experiment):