[tune] using None as the parameter default value instead of mutable dict (#1501)

* do not use dict as default parameter

* Update trial.py
This commit is contained in:
the-sea 2018-02-03 13:47:51 +08:00 committed by Eric Liang
parent 369773d3e8
commit a936468f99
3 changed files with 15 additions and 12 deletions

View file

@ -63,7 +63,7 @@ class Agent(Trainable):
_allow_unknown_subkeys = []
def __init__(
self, config={}, env=None, registry=get_registry(),
self, config=None, env=None, registry=get_registry(),
logger_creator=None):
"""Initialize an RLLib agent.
@ -77,6 +77,8 @@ class Agent(Trainable):
object. If unspecified, a default logger is created.
"""
config = config or {}
# Agents allow env ids to be passed directly to the constructor.
self._env_id = env or config.get("env")
Trainable.__init__(self, config, registry, logger_creator)

View file

@ -48,7 +48,7 @@ class Trainable(object):
classes and objects by name.
"""
def __init__(self, config={}, registry=None, logger_creator=None):
def __init__(self, config=None, registry=None, logger_creator=None):
"""Initialize an Trainable.
Subclasses should prefer defining ``_setup()`` instead of overriding
@ -68,7 +68,7 @@ class Trainable(object):
self._initialize_ok = False
self._experiment_id = uuid.uuid4().hex
self.config = config
self.config = config or {}
self.registry = registry
if logger_creator:

View file

@ -75,9 +75,9 @@ class Trial(object):
ERROR = "ERROR"
def __init__(
self, trainable_name, config={}, local_dir=DEFAULT_RESULTS_DIR,
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
experiment_tag=None, resources=Resources(cpu=1, gpu=0),
stopping_criterion={}, checkpoint_freq=0,
stopping_criterion=None, checkpoint_freq=0,
restore_path=None, upload_dir=None):
"""Initialize a new trial.
@ -89,19 +89,20 @@ class Trial(object):
TRAINABLE_CLASS, trainable_name):
raise TuneError("Unknown trainable: " + trainable_name)
for k in stopping_criterion:
if k not in TrainingResult._fields:
raise TuneError(
"Stopping condition key `{}` must be one of {}".format(
k, TrainingResult._fields))
if stopping_criterion:
for k in stopping_criterion:
if k not in TrainingResult._fields:
raise TuneError(
"Stopping condition key `{}` must be one of {}".format(
k, TrainingResult._fields))
# Immutable config
self.trainable_name = trainable_name
self.config = config
self.config = config or {}
self.local_dir = local_dir
self.experiment_tag = experiment_tag
self.resources = resources
self.stopping_criterion = stopping_criterion
self.stopping_criterion = stopping_criterion or {}
self.checkpoint_freq = checkpoint_freq
self.upload_dir = upload_dir