[rllib] Cleanups: deep merge configs properly; enforce min iter time on APEX (#2500)

The deep dict merge prevents crashes when Tune queries resource requests for an agent and the user config overrides a nested config subkey. The minimum iteration time keeps iterations from becoming too short, which incurs high per-iteration overhead; this is easy to run into on Ape-X, since its throughput can get very high.
Eric Liang 2018-07-30 13:25:35 -07:00 committed by GitHub
parent 62a52ee989
commit 38d00986a5
13 changed files with 83 additions and 64 deletions
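
For context, a minimal sketch (not part of the commit) of the failure mode the deep merge fixes, assuming this revision of Ray is installed. The nested "optimizer" dict and its keys are illustrative; the point is that dict(defaults, **overrides) replaces whole sub-dicts, while the new merge_dicts recurses into them:

from ray.rllib.utils import merge_dicts  # helper added in this commit

# Illustrative defaults with a nested "optimizer" sub-dict.
DEFAULT_CONFIG = {
    "num_workers": 32,
    "optimizer": {"num_replay_buffer_shards": 4, "debug": False},
}
overrides = {"optimizer": {"debug": True}}  # user overrides one subkey only

# Old behavior: a shallow merge drops the sibling defaults of the sub-dict,
# so default_resource_request() later crashes on the missing key.
shallow = dict(DEFAULT_CONFIG, **overrides)
print("num_replay_buffer_shards" in shallow["optimizer"])  # False

# New behavior: merge_dicts() deep-copies the defaults and updates recursively.
deep = merge_dicts(DEFAULT_CONFIG, overrides)
print(deep["optimizer"])  # {'num_replay_buffer_shards': 4, 'debug': True}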

View file

@@ -8,7 +8,7 @@ import os
import ray
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils import FilterManager
from ray.rllib.utils import FilterManager, merge_dicts
from ray.tune.trial import Resources
DEFAULT_CONFIG = with_common_config({
@@ -71,7 +71,7 @@ class A3CAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(
cpu=1,
gpu=0,

View file

@@ -10,6 +10,7 @@ import pickle
import tensorflow as tf
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.utils import deep_update
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
@@ -35,6 +36,8 @@ COMMON_CONFIG = {
"preprocessor_pref": "rllib",
# Arguments to pass to the env creator
"env_config": {},
# Environment name can also be passed via config
"env": None,
# Arguments to pass to model
"model": {
"use_lstm": False,
@@ -79,34 +82,6 @@ def with_common_config(extra_config):
return config
def _deep_update(original, new_dict, new_keys_allowed, whitelist):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
"""
for k, value in new_dict.items():
if k not in original and k != "env":
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if k in whitelist:
_deep_update(original[k], value, True, [])
else:
_deep_update(original[k], value, new_keys_allowed, [])
else:
original[k] = value
return original
class Agent(Trainable):
"""All RLlib agents extend this base class.
@@ -205,9 +180,9 @@ class Agent(Trainable):
# Merge the supplied config with the class default
merged_config = self._default_config.copy()
merged_config = _deep_update(merged_config, self.config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
merged_config = deep_update(merged_config, self.config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
self.config = merged_config
# TODO(ekl) setting the graph is unnecessary for PyTorch agents

View file

@@ -7,6 +7,7 @@ from ray.rllib.agents.agent import Agent
from ray.rllib.agents.bc.bc_evaluator import BCEvaluator, \
GPURemoteBCEvaluator, RemoteBCEvaluator
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils import merge_dicts
from ray.tune.result import TrainingResult
from ray.tune.trial import Resources
@@ -51,7 +52,7 @@ class BCAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
if cf["use_gpu_for_workers"]:
num_gpus_per_worker = 1
else:

View file

@@ -3,8 +3,8 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ddpg.ddpg import DDPGAgent, DEFAULT_CONFIG as DDPG_CONFIG
from ray.rllib.utils import merge_dicts
from ray.tune.trial import Resources
from ray.utils import merge_dicts
APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
DDPG_CONFIG,
@@ -28,6 +28,7 @@ APEX_DDPG_DEFAULT_CONFIG = merge_dicts(
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
"worker_side_prioritization": True,
"min_iter_time_s": 30,
},
)
@@ -44,7 +45,7 @@ class ApexDDPGAgent(DDPGAgent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
gpu=cf["gpu"] and 1 or 0,

View file

@@ -106,6 +106,8 @@ DEFAULT_CONFIG = with_common_config({
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
})

View file

@@ -3,8 +3,8 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.dqn.dqn import DQNAgent, DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.utils import merge_dicts
from ray.tune.trial import Resources
from ray.utils import merge_dicts
APEX_DEFAULT_CONFIG = merge_dicts(
DQN_CONFIG,
@@ -27,6 +27,7 @@ APEX_DEFAULT_CONFIG = merge_dicts(
"timesteps_per_iteration": 25000,
"per_worker_exploration": True,
"worker_side_prioritization": True,
"min_iter_time_s": 30,
},
)
@@ -43,7 +44,7 @@ class ApexAgent(DQNAgent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
gpu=cf["gpu"] and 1 or 0,

View file

@@ -4,12 +4,14 @@ from __future__ import print_function
import pickle
import os
import time
import ray
from ray.rllib import optimizers
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
from ray.tune.trial import Resources
@@ -96,6 +98,8 @@ DEFAULT_CONFIG = with_common_config({
"per_worker_exploration": False,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
})
@@ -108,7 +112,7 @@ class DQNAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(
cpu=1,
gpu=cf["gpu"] and 1 or 0,
@@ -174,8 +178,10 @@
def _train(self):
start_timestep = self.global_timestep
start = time.time()
while (self.global_timestep - start_timestep <
self.config["timesteps_per_iteration"]):
self.config["timesteps_per_iteration"]
) or time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
self.update_target_if_needed()
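
The resulting _train loop keeps stepping until both the timestep budget and the wall-clock floor are satisfied. Below is a standalone sketch of that control flow; the fake step function and its timings are illustrative, not RLlib code:

import time

def fake_optimizer_step():
    time.sleep(0.05)  # pretend one optimizer step takes ~50 ms
    return 100        # and processes 100 timesteps

def run_one_iteration(step_fn, timesteps_per_iteration, min_iter_time_s):
    # Mirrors the new loop condition: the iteration only ends once BOTH
    # enough timesteps were taken AND min_iter_time_s seconds have elapsed.
    timesteps = 0
    start = time.time()
    while (timesteps < timesteps_per_iteration
           or time.time() - start < min_iter_time_s):
        timesteps += step_fn()
    return timesteps, time.time() - start

# Even with a tiny timestep budget, the iteration spans at least 1 second,
# amortizing per-iteration overhead (metric collection, logging, etc.).
steps, elapsed = run_one_iteration(fake_optimizer_step, 100, min_iter_time_s=1)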

View file

@@ -19,6 +19,7 @@ from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
from ray.rllib.agents.es import tabular_logger as tlogger
from ray.rllib.agents.es import utils
from ray.rllib.utils import merge_dicts
Result = namedtuple("Result", [
"noise_indices", "noisy_returns", "sign_noisy_returns", "noisy_lengths",
@@ -26,17 +27,18 @@ Result = namedtuple("Result", [
])
DEFAULT_CONFIG = {
'l2_coeff': 0.005,
'noise_stdev': 0.02,
'episodes_per_batch': 1000,
'timesteps_per_batch': 10000,
'eval_prob': 0.003,
'return_proc_mode': "centered_rank",
'num_workers': 10,
'stepsize': 0.01,
'observation_filter': "MeanStdFilter",
'noise_size': 250000000,
'env_config': {},
"l2_coeff": 0.005,
"noise_stdev": 0.02,
"episodes_per_batch": 1000,
"timesteps_per_batch": 10000,
"eval_prob": 0.003,
"return_proc_mode": "centered_rank",
"num_workers": 10,
"stepsize": 0.01,
"observation_filter": "MeanStdFilter",
"noise_size": 250000000,
"env": None,
"env_config": {},
}
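
With "env" now a first-class config key (here and in COMMON_CONFIG above), the environment name can travel inside the config dict. A hedged sketch of what that enables, assuming the agent constructor of this era accepts a config dict and that ESAgent is importable from ray.rllib.agents.es (neither is shown in this diff):

import ray
from ray.rllib.agents.es import ESAgent  # import path assumed

ray.init()
# "env" rides along in the config instead of being a separate argument,
# which also lets tools like Tune see it when they only hand over a config.
agent = ESAgent(config={"env": "CartPole-v0", "num_workers": 2})
result = agent.train()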
@@ -147,7 +149,7 @@ class ESAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])
def _init(self):

View file

@@ -5,6 +5,7 @@ from __future__ import print_function
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.utils import merge_dicts
from ray.tune.trial import Resources
DEFAULT_CONFIG = with_common_config({
@@ -34,7 +35,7 @@ class PGAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(cpu=1, gpu=0, extra_cpu=cf["num_workers"])
def _init(self):

View file

@@ -8,7 +8,7 @@ import pickle
import ray
from ray.rllib.agents import Agent, with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.utils import FilterManager
from ray.rllib.utils import FilterManager, merge_dicts
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
from ray.tune.trial import Resources
@@ -66,7 +66,7 @@ class PPOAgent(Agent):
@classmethod
def default_resource_request(cls, config):
cf = dict(cls._default_config, **config)
cf = merge_dicts(cls._default_config, config)
return Resources(
cpu=1,
gpu=cf["num_gpus"],

View file

@@ -1,3 +1,5 @@
import copy
from ray.rllib.utils.filter_manager import FilterManager
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.policy_client import PolicyClient
@@ -9,3 +11,38 @@ __all__ = [
"PolicyClient",
"PolicyServer",
]
def merge_dicts(d1, d2):
"""Returns a new dict that is d1 and d2 deep merged."""
merged = copy.deepcopy(d1)
deep_update(merged, d2, True, [])
return merged
def deep_update(original, new_dict, new_keys_allowed, whitelist):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
"""
for k, value in new_dict.items():
if k not in original:
if not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if k in whitelist:
deep_update(original[k], value, True, [])
else:
deep_update(original[k], value, new_keys_allowed, [])
else:
original[k] = value
return original
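
A small usage sketch of the two helpers above; the "optimizer" whitelist entry is illustrative, while in practice agents pass their own _allow_unknown_subkeys list:

from ray.rllib.utils import deep_update, merge_dicts

defaults = {"lr": 0.001, "optimizer": {"buffer_size": 50000, "debug": False}}

# merge_dicts: deep-copies the defaults, then updates recursively.
merged = merge_dicts(defaults, {"optimizer": {"debug": True}})
assert merged["optimizer"] == {"buffer_size": 50000, "debug": True}
assert defaults["optimizer"]["debug"] is False  # defaults left untouched

# deep_update with new_keys_allowed=False rejects unknown top-level keys...
try:
    deep_update(dict(defaults), {"typo_key": 1}, False, [])
except Exception as e:
    print(e)  # Unknown config parameter `typo_key`

# ...while keys in the whitelist may still introduce new subkeys.
deep_update(defaults, {"optimizer": {"new_subkey": 1}}, False, ["optimizer"])
assert defaults["optimizer"]["new_subkey"] == 1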

View file

@@ -267,13 +267,6 @@ def resources_from_resource_arguments(default_num_cpus, default_num_gpus,
return resources
def merge_dicts(d1, d2):
"""Merge two dicts and return a new dict that's their union."""
d = d1.copy()
d.update(d2)
return d
def check_oversized_pickle(pickled, name, obj_type, worker):
"""Send a warning message if the pickled object is too large.

View file

@@ -85,7 +85,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--env CartPole-v0 \
--run APEX \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false}'
--config '{"num_workers": 2, "timesteps_per_iteration": 1000, "gpu": false, "min_iter_time_s": 1}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
python /ray/python/ray/rllib/train.py \
@@ -197,7 +197,7 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
--run APEX_DDPG \
--ray-num-cpus 8 \
--stop '{"training_iteration": 2}' \
--config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100}'
--config '{"num_workers": 2, "optimizer": {"num_replay_buffer_shards": 1}, "learning_starts": 100, "min_iter_time_s": 1}'
docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \
sh /ray/test/jenkins_tests/multi_node_tests/test_rllib_eval.sh