[RLlib] Fix flakey test_a3c, test_maml, test_apex_dqn. (#19035)
parent 7588bfd315
commit 73f5c4039b

3 changed files with 38 additions and 22 deletions
@@ -68,6 +68,8 @@ DEFAULT_CONFIG = dqn.DQNTrainer.merge_trainer_configs(
     # === Hyperparameters from the paper [1] ===
     # Size of the replay buffer (in sequences, not timesteps).
     "buffer_size": 100000,
+    # If True prioritized replay buffer will be used.
+    "prioritized_replay": False,
     # Set automatically: The number of contiguous environment steps to
     # replay at once. Will be calculated via
     # model->max_seq_len + burn_in.
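The two added lines give this DEFAULT_CONFIG an explicit prioritized-replay switch (off by default), next to the sequence-based buffer_size. For orientation, a minimal usage sketch (not part of this commit) of overriding those keys through the dict-style trainer config of this RLlib generation; the R2D2Trainer import path, environment, and values are illustrative assumptions.

# Illustrative sketch only; trainer class path, env, and values are assumptions.
import ray
from ray.rllib.agents.dqn.r2d2 import R2D2Trainer

config = {
    "env": "CartPole-v0",
    "num_workers": 0,
    # R2D2 replays whole sequences, so the model must be recurrent.
    "model": {"use_lstm": True, "max_seq_len": 20},
    # Counted in sequences (max_seq_len + burn_in steps each), not timesteps.
    "buffer_size": 50000,
    # Default from the diff: no prioritized replay.
    "prioritized_replay": False,
}

ray.init()
trainer = R2D2Trainer(config=config)
print(trainer.train()["episode_reward_mean"])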
@@ -1,9 +1,13 @@
+from typing import Tuple
+
 from ray.rllib.policy.policy import PolicySpec
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
-from ray.rllib.utils.typing import PartialTrainerConfigDict
+from ray.rllib.utils.typing import MultiAgentPolicyConfigDict, \
+    PartialTrainerConfigDict
 
 
-def check_multi_agent(config: PartialTrainerConfigDict):
+def check_multi_agent(config: PartialTrainerConfigDict) -> \
+        Tuple[MultiAgentPolicyConfigDict, bool]:
     """Checks, whether a (partial) config defines a multi-agent setup.
 
     Args:
@@ -11,18 +15,25 @@ def check_multi_agent(config: PartialTrainerConfigDict):
             to check for multi-agent.
 
     Returns:
-        Tuple[MultiAgentPolicyConfigDict, bool]: The resulting (all
-            fixed) multi-agent policy dict and whether we have a
-            multi-agent setup or not.
+        The resulting (all fixed) multi-agent policy dict and whether we
+            have a multi-agent setup or not.
     """
     multiagent_config = config["multiagent"]
     policies = multiagent_config.get("policies")
+
+    # Nothing specified in config dict -> Assume simple single agent setup
+    # with DEFAULT_POLICY_ID as only policy.
     if not policies:
         policies = {DEFAULT_POLICY_ID}
+    # Policies given as set (of PolicyIDs) -> Setup each policy automatically
+    # via empty PolicySpec (will make RLlib infer obs- and action spaces
+    # as well as the Policy's class).
     if isinstance(policies, set):
         policies = multiagent_config["policies"] = {
             pid: PolicySpec()
             for pid in policies
         }
+    # Is this a multi-agent setup? True, iff more than one policy is defined
+    # or the only policy is not DEFAULT_POLICY_ID.
     is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies
     return policies, is_multiagent
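With the new signature, check_multi_agent() both normalizes the "policies" entry of the multiagent config and reports whether the setup is multi-agent. A small usage sketch, not from this commit; the module path is an assumption (the file name is not preserved on this page) and the outputs in the comments are approximate.

# Usage sketch; module path and config dicts are illustrative assumptions.
from ray.rllib.utils.multi_agent import check_multi_agent

# No "policies" given -> single-agent: only DEFAULT_POLICY_ID remains.
policies, is_multiagent = check_multi_agent({"multiagent": {"policies": None}})
print(policies, is_multiagent)   # dict with the single key "default_policy"; False

# A set of policy IDs is expanded into empty PolicySpecs -> multi-agent.
policies, is_multiagent = check_multi_agent(
    {"multiagent": {"policies": {"agent_a", "agent_b"}}})
print(is_multiagent)             # True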
@@ -1,6 +1,6 @@
 from collections import Counter
 import copy
-import gym
+from gym.spaces import Box
 import logging
 import numpy as np
 import random
@@ -297,7 +297,7 @@ def check_compute_single_action(trainer,
             call_kwargs["full_fetch"] = full_fetch
 
         obs = obs_space.sample()
-        if isinstance(obs_space, gym.spaces.Box):
+        if isinstance(obs_space, Box):
             obs = np.clip(obs, -1.0, 1.0)
         state_in = None
         if include_state:
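The clipping line above (now gated on the imported Box class rather than gym.spaces.Box) exists because Box.sample() can return values well outside [-1, 1] when the space's bounds are wide or infinite; clipping keeps the probe observation in a range any model can digest. A tiny illustration, assuming plain gym and a 4-dim unbounded Box (not part of the commit):

# Illustration only; not from this commit.
import numpy as np
from gym.spaces import Box

obs_space = Box(-float("inf"), float("inf"), shape=(4,), dtype=np.float32)
obs = obs_space.sample()        # unbounded dims are sampled from a normal
obs = np.clip(obs, -1.0, 1.0)   # same guard as in check_compute_single_action
assert np.all(np.abs(obs) <= 1.0)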
@@ -368,23 +368,24 @@ def check_compute_single_action(trainer,
             for si, so in zip(state_in, state_out):
                 check(list(si.shape), so.shape)
 
-        # Test whether unsquash/clipping works: Both should push the action
-        # to certainly be within the space's bounds.
-        if not action_space.contains(action):
-            if clip or unsquash or not isinstance(action_space,
-                                                  gym.spaces.Box):
+        # Test whether unsquash/clipping works on the Trainer's
+        # compute_single_action method: Both flags should force the action
+        # to be within the space's bounds.
+        if method_to_test == "single" and what == trainer:
+            if not action_space.contains(action) and \
+                    (clip or unsquash or not isinstance(action_space, Box)):
                 raise ValueError(
                     f"Returned action ({action}) of trainer/policy {what} "
                     f"not in Env's action_space {action_space}")
-        # We are operating in normalized space: Expect only smaller action
-        # values.
-        if isinstance(action_space, gym.spaces.Box) and not unsquash and \
-                what.config.get("normalize_actions") and \
-                np.any(np.abs(action) > 10.0):
-            raise ValueError(
-                f"Returned action ({action}) of trainer/policy {what} "
-                "should be in normalized space, but seems too large/small for "
-                "that!")
+            # We are operating in normalized space: Expect only smaller action
+            # values.
+            if isinstance(action_space, Box) and not unsquash and \
+                    what.config.get("normalize_actions") and \
+                    np.any(np.abs(action) > 3.0):
+                raise ValueError(
+                    f"Returned action ({action}) of trainer/policy {what} "
+                    "should be in normalized space, but seems too large/small "
+                    "for that!")
 
     # Loop through: Policy vs Trainer; Different API methods to calculate
     # actions; unsquash option; clip option; full fetch or not.
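The rewritten block only enforces action-space containment on the Trainer's compute_single_action path (method_to_test == "single") and tightens the normalized-action sanity bound from |a| > 10.0 to |a| > 3.0. A hedged sketch of the call pattern this guards, assuming the unsquash_action / clip_action keyword arguments of Trainer.compute_single_action in this RLlib generation; the algorithm, env, and observation are illustrative only.

# Illustrative sketch, not from this commit.
import numpy as np
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(config={
    "env": "Pendulum-v1",        # any env with a Box (continuous) action space
    "num_workers": 0,
    "normalize_actions": True,
})
obs = np.array([1.0, 0.0, 0.0], dtype=np.float32)  # cos(theta), sin(theta), theta_dot

# Unsquashing maps the action back into the env's Box bounds, which is what
# the containment check above verifies on the Trainer path.
action = trainer.compute_single_action(obs, unsquash_action=True)

# Without unsquashing/clipping and with normalize_actions=True, actions stay
# in normalized space; the test now flags anything with |a| > 3.0.
raw = trainer.compute_single_action(obs, unsquash_action=False, clip_action=False)
print(action, raw)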
@@ -501,7 +502,9 @@ def check_train_results(train_results):
         # Make sure we have a default_policy key if we are not in a
         # multi-agent setup.
         if not is_multi_agent:
-            assert DEFAULT_POLICY_ID in learner_info, \
+            # APEX algos sometimes have an empty learner info dict (no metrics
+            # collected yet).
+            assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info, \
                 f"'{DEFAULT_POLICY_ID}' not found in " \
                 f"train_results['infos']['learner'] ({learner_info})!"
 
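The relaxed assertion accepts an empty learner-info dict because APEX-style trainers can return their first result before the asynchronous learner has produced any metrics. A sketch of the same condition applied to a trainer's train() output; the ApexTrainer path and the results["info"]["learner"] layout are assumptions based on the message string in the diff.

# Illustrative sketch, not from this commit.
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

trainer = ApexTrainer(config={"env": "CartPole-v0", "num_workers": 3})
results = trainer.train()
learner_info = results["info"]["learner"]

# Early on, the async learner may not have reported anything yet; otherwise
# the single (default) policy's stats must be present.
assert len(learner_info) == 0 or DEFAULT_POLICY_ID in learner_info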