[tune] clean up logs before logging to wandb (#10654)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
Kai Fricke 2020-09-08 22:48:18 +01:00 committed by GitHub
parent dcb9e03fde
commit 69c1a9dd08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,5 +1,5 @@
import os
import pickle
from multiprocessing import Process, Queue
from numbers import Number
@ -19,6 +19,22 @@ WANDB_ENV_VAR = "WANDB_API_KEY"
_WANDB_QUEUE_END = (None, )
def _clean_log(obj):
# Fixes https://github.com/ray-project/ray/issues/10631
if isinstance(obj, dict):
return {k: _clean_log(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_clean_log(v) for v in obj]
# Else
try:
pickle.dumps(obj)
return obj
except Exception:
# give up, similar to _SafeFallBackEncoder
return str(obj)
def wandb_mixin(func):
"""wandb_mixin
@ -304,6 +320,9 @@ class WandbLogger(Logger):
# Grouping
wandb_group = wandb_config.pop("group", self.trial.trainable_name)
# remove unpickleable items!
config = _clean_log(config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
@ -324,6 +343,7 @@ class WandbLogger(Logger):
self._wandb.start()
def on_result(self, result):
result = _clean_log(result)
self._queue.put(result)
def close(self):
@ -374,6 +394,9 @@ class WandbTrainableMixin:
default_group = type(self).__name__
wandb_group = wandb_config.pop("group", default_group)
# remove unpickleable items!
_config = _clean_log(_config)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,