mirror of
https://github.com/vale981/ray
synced 2025-04-23 06:25:52 -04:00
[tune] using None as the parameter default value instead of mutable dict (#1501)
* do not use dict as default parameter * Update trial.py
This commit is contained in:
parent
369773d3e8
commit
a936468f99
3 changed files with 15 additions and 12 deletions
|
@ -63,7 +63,7 @@ class Agent(Trainable):
|
|||
_allow_unknown_subkeys = []
|
||||
|
||||
def __init__(
|
||||
self, config={}, env=None, registry=get_registry(),
|
||||
self, config=None, env=None, registry=get_registry(),
|
||||
logger_creator=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
|
@ -77,6 +77,8 @@ class Agent(Trainable):
|
|||
object. If unspecified, a default logger is created.
|
||||
"""
|
||||
|
||||
config = config or {}
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
Trainable.__init__(self, config, registry, logger_creator)
|
||||
|
|
|
@ -48,7 +48,7 @@ class Trainable(object):
|
|||
classes and objects by name.
|
||||
"""
|
||||
|
||||
def __init__(self, config={}, registry=None, logger_creator=None):
|
||||
def __init__(self, config=None, registry=None, logger_creator=None):
|
||||
"""Initialize an Trainable.
|
||||
|
||||
Subclasses should prefer defining ``_setup()`` instead of overriding
|
||||
|
@ -68,7 +68,7 @@ class Trainable(object):
|
|||
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
self.config = config
|
||||
self.config = config or {}
|
||||
self.registry = registry
|
||||
|
||||
if logger_creator:
|
||||
|
|
|
@ -75,9 +75,9 @@ class Trial(object):
|
|||
ERROR = "ERROR"
|
||||
|
||||
def __init__(
|
||||
self, trainable_name, config={}, local_dir=DEFAULT_RESULTS_DIR,
|
||||
self, trainable_name, config=None, local_dir=DEFAULT_RESULTS_DIR,
|
||||
experiment_tag=None, resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion={}, checkpoint_freq=0,
|
||||
stopping_criterion=None, checkpoint_freq=0,
|
||||
restore_path=None, upload_dir=None):
|
||||
"""Initialize a new trial.
|
||||
|
||||
|
@ -89,19 +89,20 @@ class Trial(object):
|
|||
TRAINABLE_CLASS, trainable_name):
|
||||
raise TuneError("Unknown trainable: " + trainable_name)
|
||||
|
||||
for k in stopping_criterion:
|
||||
if k not in TrainingResult._fields:
|
||||
raise TuneError(
|
||||
"Stopping condition key `{}` must be one of {}".format(
|
||||
k, TrainingResult._fields))
|
||||
if stopping_criterion:
|
||||
for k in stopping_criterion:
|
||||
if k not in TrainingResult._fields:
|
||||
raise TuneError(
|
||||
"Stopping condition key `{}` must be one of {}".format(
|
||||
k, TrainingResult._fields))
|
||||
|
||||
# Immutable config
|
||||
self.trainable_name = trainable_name
|
||||
self.config = config
|
||||
self.config = config or {}
|
||||
self.local_dir = local_dir
|
||||
self.experiment_tag = experiment_tag
|
||||
self.resources = resources
|
||||
self.stopping_criterion = stopping_criterion
|
||||
self.stopping_criterion = stopping_criterion or {}
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.upload_dir = upload_dir
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue