[RLlib] Fix flaky test_a3c, test_maml, test_apex_dqn. (#19035)

Sven Mika 2021-10-04 13:23:51 +02:00 committed by GitHub
parent 7588bfd315
commit 73f5c4039b
3 changed files with 38 additions and 22 deletions


@@ -68,6 +68,8 @@ DEFAULT_CONFIG = dqn.DQNTrainer.merge_trainer_configs(
     # === Hyperparameters from the paper [1] ===
     # Size of the replay buffer (in sequences, not timesteps).
     "buffer_size": 100000,
+    # If True, a prioritized replay buffer will be used.
+    "prioritized_replay": False,
     # Set automatically: The number of contiguous environment steps to
     # replay at once. Will be calculated via
     # model->max_seq_len + burn_in.
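For orientation, a minimal sketch of how this new default surfaces in user code (the trainer class and import path follow RLlib's 1.x agents.dqn layout; treat the exact names as assumptions): with "prioritized_replay" now defaulting to False, R2D2 runs a plain replay buffer unless the user explicitly opts back in.

    import ray
    from ray.rllib.agents.dqn.r2d2 import R2D2Trainer  # assumed 1.x path

    ray.init()
    trainer = R2D2Trainer(
        env="CartPole-v0",
        config={
            "model": {"use_lstm": True},  # R2D2 expects a recurrent model
            "prioritized_replay": True,   # opt back in (new default: False)
        },
    )
    print(trainer.train()["episode_reward_mean"])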


@@ -1,9 +1,13 @@
+from typing import Tuple
+
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.utils.typing import PartialTrainerConfigDict
+from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \
+    PartialTrainerConfigDict


-def check_multi_agent(config: PartialTrainerConfigDict):
+def check_multi_agent(config: PartialTrainerConfigDict) -> \
+        Tuple[MultiAgentPolicyConfigDict, bool]:
     """Checks, whether a (partial) config defines a multi-agent setup.

     Args:
@@ -11,18 +15,25 @@ def check_multi_agent(config: PartialTrainerConfigDict):
         to check for multi-agent.

     Returns:
-        Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all
-            fixed) multi-agent policy dict and whether we have a
-            multi-agent setup or not.
+        The resulting (all fixed) multi-agent policy dict and whether we
+        have a multi-agent setup or not.
     """
     multiagent_config = config["multiagent"]
     policies = multiagent_config.get("policies")
+
+    # Nothing specified in config dict -> Assume simple single agent setup
+    # with DEFAULT_POLICY_ID as only policy.
     if not policies:
         policies = {DEFAULT_POLICY_ID}
+    # Policies given as set (of PolicyIDs) -> Setup each policy automatically
+    # via empty PolicySpec (will make RLlib infer obs- and action spaces
+    # as well as the Policy's class).
     if isinstance(policies, set):
         policies = multiagent_config["policies"] = {
             pid: PolicySpec()
             for pid in policies
         }
+    # Is this a multi-agent setup? True, unless DEFAULT_POLICY_ID is the
+    # only PolicyID found in the policies dict.
     is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
     return policies, is_multiagent
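A quick usage sketch of the changed signature (the module path is an assumption inferred from the imports above; the policy IDs are invented):

    from ray.rllib.utils.multi_agent import check_multi_agent  # assumed path

    # Set-of-IDs shorthand: expanded into empty PolicySpecs, flagged
    # as multi-agent.
    config = {"multiagent": {"policies": {"pol1", "pol2"}}}
    policies, is_multiagent = check_multi_agent(config)
    assert is_multiagent and set(policies) == {"pol1", "pol2"}

    # Nothing specified: falls back to the single DEFAULT_POLICY_ID setup.
    policies, is_multiagent = check_multi_agent({"multiagent": {}})
    assert not is_multiagent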


@@ -1,6 +1,6 @@
 from collections import Counter
 import copy
-import gym
+from gym.spaces import Box
 import logging
 import numpy as np
 import random
@@ -297,7 +297,7 @@ def check_compute_single_action(trainer,
         call_kwargs["full_fetch"] = full_fetch

     obs = obs_space.sample()
-    if isinstance(obs_space, gym.spaces.Box):
+    if isinstance(obs_space, Box):
         obs = np.clip(obs, -1.0, 1.0)
     state_in = None
     if include_state:
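The clip above matters because Box.sample() can draw values far outside [-1, 1] from unbounded spaces, which would feed unrealistic observations into the model under test. A standalone illustration:

    import numpy as np
    from gym.spaces import Box

    # Unbounded Box: sample() may return arbitrarily large values.
    obs_space = Box(-np.inf, np.inf, shape=(3,), dtype=np.float32)
    obs = np.clip(obs_space.sample(), -1.0, 1.0)
    assert ((obs >= -1.0) & (obs <= 1.0)).all()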
@@ -368,23 +368,24 @@ def check_compute_single_action(trainer,
             for si, so in zip(state_in, state_out):
                 check(list(si.shape), so.shape)

-        # Test whether unsquash/clipping works: Both should push the action
-        # to certainly be within the space's bounds.
-        if not action_space.contains(action):
-            if clip or unsquash or not isinstance(action_space,
-                                                  gym.spaces.Box):
+        # Test whether unsquash/clipping works on the Trainer's
+        # compute_single_action method: Both flags should force the action
+        # to be within the space's bounds.
+        if method_to_test == "single" and what == trainer:
+            if not action_space.contains(action) and \
+                    (clip or unsquash or not isinstance(action_space, Box)):
                 raise ValueError(
                     f"Returned action ({action}) of trainer/policy {what} "
                     f"not in Env's action_space {action_space}")
-        # We are operating in normalized space: Expect only smaller action
-        # values.
-        if isinstance(action_space, gym.spaces.Box) and not unsquash and \
-                what.config.get("normalize_actions") and \
-                np.any(np.abs(action) > 10.0):
-            raise ValueError(
-                f"Returned action ({action}) of trainer/policy {what} "
-                "should be in normalized space, but seems too large/small for "
-                "that!")
+            # We are operating in normalized space: Expect only smaller action
+            # values.
+            if isinstance(action_space, Box) and not unsquash and \
+                    what.config.get("normalize_actions") and \
+                    np.any(np.abs(action) > 3.0):
+                raise ValueError(
+                    f"Returned action ({action}) of trainer/policy {what} "
+                    "should be in normalized space, but seems too large/small "
+                    "for that!")

     # Loop through: Policy vs Trainer; Different API methods to calculate
     # actions; unsquash option; clip option; full fetch or not.
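For intuition on the two behaviors this hunk tests: "unsquash" linearly rescales a normalized action from [-1, 1] into the Box's bounds, while "clip" hard-cuts it at the bounds; either way the result lands inside the space. A self-contained sketch (the helper names are mine, not RLlib's):

    import numpy as np
    from gym.spaces import Box

    space = Box(low=-2.0, high=2.0, shape=(2,), dtype=np.float32)

    def unsquash(a, space):
        # Hypothetical helper: map [-1, 1] linearly onto [low, high].
        out = space.low + (a + 1.0) * (space.high - space.low) / 2.0
        return out.astype(space.dtype)

    def clip(a, space):
        # Hypothetical helper: hard-cut at the space's bounds.
        return np.clip(a, space.low, space.high)

    assert space.contains(unsquash(np.array([0.5, -1.0], np.float32), space))
    assert space.contains(clip(np.array([3.0, -0.5], np.float32), space))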
@@ -501,7 +502,9 @@ def check_train_results(train_results):

     # Make sure we have a default_policy key if we are not in a
     # multi-agent setup.
     if not is_multi_agent:
-        assert DEFAULT_POLICY_ID in learner_info, \
+        # APEX algos sometimes have an empty learner info dict (no metrics
+        # collected yet).
+        assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
             f"'{DEFAULT_POLICY_ID}' not found in " \
             f"train_results['infos']['learner'] ({learner_info})!"