[rllib] Cleanups: deep merge configs properly; enforce min iter time on APEX (#2500)
The deep dict merge prevents crashes when Tune tries to get resource requests for an agent and you have overridden a config subkey. The minimum iteration time keeps iterations from getting too short and incurring high per-iteration overhead, which is easy to run into on Ape-X since throughput can get very high. (A short illustrative sketch of both behaviors follows the change summary below.)
parent 62a52ee989
commit 38d00986a5
13 changed files with 83 additions and 64 deletions
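To make the two behaviors concrete, here is a minimal illustrative sketch. It uses the `merge_dicts` helper this commit adds to `ray.rllib.utils` and the new `min_iter_time_s` config key; the config values and the toy loop are hypothetical stand-ins, not the actual agent code (the real helper and the DQN training-loop change appear in the diff below).

import time

from ray.rllib.utils import merge_dicts  # helper added by this commit

# 1) Shallow vs. deep merge. With dict(defaults, **overrides), overriding a
#    single subkey replaces the whole nested dict, so sibling defaults such as
#    "num_replay_buffer_shards" vanish and resource requests crash.
defaults = {"optimizer": {"num_replay_buffer_shards": 4, "debug": False}}  # hypothetical defaults
overrides = {"optimizer": {"debug": True}}                                 # user override of one subkey

shallow = dict(defaults, **overrides)
assert "num_replay_buffer_shards" not in shallow["optimizer"]  # default lost

deep = merge_dicts(defaults, overrides)
assert deep["optimizer"]["num_replay_buffer_shards"] == 4      # default preserved

# 2) Minimum iteration time. The loop keeps stepping until both the timestep
#    budget and min_iter_time_s are satisfied, so a high-throughput setup like
#    Ape-X does not report many tiny, overhead-dominated iterations.
config = {"timesteps_per_iteration": 1000, "min_iter_time_s": 1}  # hypothetical values
timesteps = 0
start = time.time()
while (timesteps < config["timesteps_per_iteration"]
       or time.time() - start < config["min_iter_time_s"]):
    timesteps += 100  # stand-in for self.optimizer.step()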
@@ -8,7 +8,7 @@ import os
 import ray
 from ray.rllib.agents.agent import Agent, with_common_config
 from ray.rllib.optimizers import AsyncGradientsOptimizer
-from ray.rllib.utils import FilterManager
+from ray.rllib.utils import FilterManager, merge_dicts
 from ray.tune.trial import Resources

 DEFAULT_CONFIG = with_common_config({
@@ -71,7 +71,7 @@ class A3CAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(
             cpu=1,
             gpu=0,
@@ -10,6 +10,7 @@ import pickle
 import tensorflow as tf
 from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
+from ray.rllib.utils import deep_update
 from ray.tune.registry import ENV_CREATOR, _global_registry
 from ray.tune.result import TrainingResult
 from ray.tune.trainable import Trainable

@@ -35,6 +36,8 @@ COMMON_CONFIG = {
     "preprocessor_pref": "rllib",
     # Arguments to pass to the env creator
     "env_config": {},
+    # Environment name can also be passed via config
+    "env": None,
     # Arguments to pass to model
     "model": {
         "use_lstm": False,
@@ -79,34 +82,6 @@ def with_common_config(extra_config):
     return config


-def _deep_update(original, new_dict, new_keys_allowed, whitelist):
-    """Updates original dict with values from new_dict recursively.
-    If new key is introduced in new_dict, then if new_keys_allowed is not
-    True, an error will be thrown. Further, for sub-dicts, if the key is
-    in the whitelist, then new subkeys can be introduced.
-
-    Args:
-        original (dict): Dictionary with default values.
-        new_dict (dict): Dictionary with values to be updated
-        new_keys_allowed (bool): Whether new keys are allowed.
-        whitelist (list): List of keys that correspond to dict values
-            where new subkeys can be introduced. This is only at
-            the top level.
-    """
-    for k, value in new_dict.items():
-        if k not in original and k != "env":
-            if not new_keys_allowed:
-                raise Exception("Unknown config parameter `{}` ".format(k))
-        if type(original.get(k)) is dict:
-            if k in whitelist:
-                _deep_update(original[k], value, True, [])
-            else:
-                _deep_update(original[k], value, new_keys_allowed, [])
-        else:
-            original[k] = value
-    return original
-
-
 class Agent(Trainable):
     """All RLlib agents extend this base class.

@@ -205,9 +180,9 @@ class Agent(Trainable):

         # Merge the supplied config with the class default
         merged_config = self._default_config.copy()
-        merged_config = _deep_update(merged_config, self.config,
-                                     self._allow_unknown_configs,
-                                     self._allow_unknown_subkeys)
+        merged_config = deep_update(merged_config, self.config,
+                                    self._allow_unknown_configs,
+                                    self._allow_unknown_subkeys)
         self.config = merged_config

         # TODO(ekl) setting the graph is unnecessary for PyTorch agents
@@ -7,6 +7,7 @@ from ray.rllib.agents.agent import Agent
 from ray.rllib.agents.bc.bc_evaluator import BCEvaluator, \
     GPURemoteBCEvaluator, RemoteBCEvaluator
 from ray.rllib.optimizers import AsyncGradientsOptimizer
+from ray.rllib.utils import merge_dicts
 from ray.tune.result import TrainingResult
 from ray.tune.trial import Resources

@@ -51,7 +52,7 @@ class BCAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         if cf["use_gpu_for_workers"]:
             num_gpus_per_worker = 1
         else:
@@ -3,8 +3,8 @@ from __future__ import division
 from __future__ import print_function

 from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG
+from ray.rllib.utils import merge_dicts
 from ray.tune.trial import Resources
-from ray.utils import merge_dicts

 APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
     DDPG_CONFIG,
@@ -28,6 +28,7 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
         "timesteps_per_iteration": 25000,
         "per_worker_exploration": True,
         "worker_side_prioritization": True,
+        "min_iter_time_s": 30,
     },
 )

@@ -44,7 +45,7 @@ class ApexDDPGAgent(DDPGAgent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(
             cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
             gpu=cf["gpu"] and 1 or 0,
@@ -106,6 +106,8 @@ DEFAULT_CONFIG = with_common_config({
     "per_worker_exploration": False,
     # Whether to compute priorities on workers.
     "worker_side_prioritization": False,
+    # Prevent iterations from going lower than this time span
+    "min_iter_time_s": 1,
 })


@@ -3,8 +3,8 @@ from __future__ import division
 from __future__ import print_function

 from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
+from ray.rllib.utils import merge_dicts
 from ray.tune.trial import Resources
-from ray.utils import merge_dicts

 APEX_DEFAULT_CONFIG = merge_dicts(
     DQN_CONFIG,
@@ -27,6 +27,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
         "timesteps_per_iteration": 25000,
         "per_worker_exploration": True,
         "worker_side_prioritization": True,
+        "min_iter_time_s": 30,
     },
 )

@@ -43,7 +44,7 @@ class ApexAgent(DQNAgent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(
             cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
             gpu=cf["gpu"] and 1 or 0,
@@ -4,12 +4,14 @@ from __future__ import print_function

 import pickle
 import os
+import time

 import ray
 from ray.rllib import optimizers
 from ray.rllib.agents.agent import Agent, with_common_config
 from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
 from ray.rllib.evaluation.metrics import collect_metrics
+from ray.rllib.utils import merge_dicts
 from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
 from ray.tune.trial import Resources

@@ -96,6 +98,8 @@ DEFAULT_CONFIG = with_common_config({
     "per_worker_exploration": False,
     # Whether to compute priorities on workers.
     "worker_side_prioritization": False,
+    # Prevent iterations from going lower than this time span
+    "min_iter_time_s": 1,
 })


@@ -108,7 +112,7 @@ class DQNAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(
             cpu=1,
             gpu=cf["gpu"] and 1 or 0,
@@ -174,8 +178,10 @@ class DQNAgent(Agent):
     def _train(self):
         start_timestep = self.global_timestep

+        start = time.time()
         while (self.global_timestep - start_timestep <
-               self.config["timesteps_per_iteration"]):
+               self.config["timesteps_per_iteration"]
+               ) or time.time() - start < self.config["min_iter_time_s"]:
             self.optimizer.step()
             self.update_target_if_needed()

@@ -19,6 +19,7 @@ from ray.rllib.agents.es import optimizers
 from ray.rllib.agents.es import policies
 from ray.rllib.agents.es import tabular_logger as tlogger
 from ray.rllib.agents.es import utils
+from ray.rllib.utils import merge_dicts

 Result = namedtuple("Result", [
     "noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths",
@@ -26,17 +27,18 @@ Result = namedtuple("Result", [
 ])

 DEFAULT_CONFIG = {
-    'l2_coeff': 0.005,
-    'noise_stdev': 0.02,
-    'episodes_per_batch': 1000,
-    'timesteps_per_batch': 10000,
-    'eval_prob': 0.003,
-    'return_proc_mode': "centered_rank",
-    'num_workers': 10,
-    'stepsize': 0.01,
-    'observation_filter': "MeanStdFilter",
-    'noise_size': 250000000,
-    'env_config': {},
+    "l2_coeff": 0.005,
+    "noise_stdev": 0.02,
+    "episodes_per_batch": 1000,
+    "timesteps_per_batch": 10000,
+    "eval_prob": 0.003,
+    "return_proc_mode": "centered_rank",
+    "num_workers": 10,
+    "stepsize": 0.01,
+    "observation_filter": "MeanStdFilter",
+    "noise_size": 250000000,
+    "env": None,
+    "env_config": {},
 }


@@ -147,7 +149,7 @@ class ESAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])

     def _init(self):
@@ -5,6 +5,7 @@ from __future__ import print_function
 from ray.rllib.agents.agent import Agent, with_common_config
 from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
 from ray.rllib.optimizers import SyncSamplesOptimizer
+from ray.rllib.utils import merge_dicts
 from ray.tune.trial import Resources

 DEFAULT_CONFIG = with_common_config({
@@ -34,7 +35,7 @@ class PGAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])

     def _init(self):
@@ -8,7 +8,7 @@ import pickle
 import ray
 from ray.rllib.agents import Agent, with_common_config
 from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
-from ray.rllib.utils import FilterManager
+from ray.rllib.utils import FilterManager, merge_dicts
 from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
 from ray.tune.trial import Resources

@@ -66,7 +66,7 @@ class PPOAgent(Agent):

     @classmethod
     def default_resource_request(cls, config):
-        cf = dict(cls._default_config, **config)
+        cf = merge_dicts(cls._default_config, config)
         return Resources(
             cpu=1,
             gpu=cf["num_gpus"],
@@ -1,3 +1,5 @@
+import copy
+
 from ray.rllib.utils.filter_manager import FilterManager
 from ray.rllib.utils.filter import Filter
 from ray.rllib.utils.policy_client import PolicyClient
@@ -9,3 +11,38 @@ __all__ = [
     "PolicyClient",
     "PolicyServer",
 ]
+
+
+def merge_dicts(d1, d2):
+    """Returns a new dict that is d1 and d2 deep merged."""
+    merged = copy.deepcopy(d1)
+    deep_update(merged, d2, True, [])
+    return merged
+
+
+def deep_update(original, new_dict, new_keys_allowed, whitelist):
+    """Updates original dict with values from new_dict recursively.
+    If new key is introduced in new_dict, then if new_keys_allowed is not
+    True, an error will be thrown. Further, for sub-dicts, if the key is
+    in the whitelist, then new subkeys can be introduced.
+
+    Args:
+        original (dict): Dictionary with default values.
+        new_dict (dict): Dictionary with values to be updated
+        new_keys_allowed (bool): Whether new keys are allowed.
+        whitelist (list): List of keys that correspond to dict values
+            where new subkeys can be introduced. This is only at
+            the top level.
+    """
+    for k, value in new_dict.items():
+        if k not in original:
+            if not new_keys_allowed:
+                raise Exception("Unknown config parameter `{}` ".format(k))
+        if type(original.get(k)) is dict:
+            if k in whitelist:
+                deep_update(original[k], value, True, [])
+            else:
+                deep_update(original[k], value, new_keys_allowed, [])
+        else:
+            original[k] = value
+    return original
@@ -267,13 +267,6 @@ def resources_from_resource_arguments(default_num_cpus, default_num_gpus,
     return resources


-def merge_dicts(d1, d2):
-    """Merge two dicts and return a new dict that's their union."""
-    d = d1.copy()
-    d.update(d2)
-    return d
-
-
 def check_oversized_pickle(pickled, name, obj_type, worker):
     """Send a warning message if the pickled object is too large.

@@ -85,7 +85,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     --env CartPole-v0 \
     --run APEX \
     --stop '{"training_iteration": 2}' \
-    --config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false}'
+    --config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false, "min_iter_time_s": 1}'

 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     python /ray/python/ray/rllib/train.py \
@@ -197,7 +197,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     --run APEX_DDPG \
     --ray-num-cpus 8 \
     --stop '{"training_iteration": 2}' \
-    --config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100}'
+    --config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}'

 docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
     sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh