[rllib] Add some debug logs during agent setup (#3247)
parent cf9e838326
commit 43df405d07
6 changed files with 31 additions and 5 deletions
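Most of the new messages are emitted at DEBUG level through a module-level `logging.getLogger(__name__)`, so they stay hidden under the default log level. A minimal sketch of how a user script could surface them; the `"ray.rllib"` logger name is an assumption based on the package path implied by `__name__`, not something stated in this commit:

import logging

# Raise verbosity only for the rllib loggers touched by this commit.
# The logger name "ray.rllib" is assumed from the module paths.
logging.basicConfig()
logging.getLogger("ray.rllib").setLevel(logging.DEBUG)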
@@ -20,6 +20,8 @@ from ray.tune.trainable import Trainable
 from ray.tune.logger import UnifiedLogger
 from ray.tune.result import DEFAULT_RESULTS_DIR
 
+logger = logging.getLogger(__name__)
+
 # yapf: disable
 # __sphinx_doc_begin__
 COMMON_CONFIG = {
@@ -252,6 +254,7 @@ class Agent(Trainable):
             self.optimizer.local_evaluator.set_global_vars(self.global_vars)
             for ev in self.optimizer.remote_evaluators:
                 ev.set_global_vars.remote(self.global_vars)
+            logger.debug("updated global vars: {}".format(self.global_vars))
 
         if (self.config.get("observation_filter", "NoFilter") != "NoFilter"
                 and hasattr(self, "local_evaluator")):
@@ -259,6 +262,8 @@ class Agent(Trainable):
                 self.local_evaluator.filters,
                 self.remote_evaluators,
                 update_remote=self.config["synchronize_filters"])
+            logger.debug("synchronized filters: {}".format(
+                self.local_evaluator.filters))
 
         return Trainable.train(self)
 
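The filter-synchronization branch above only runs when the agent config asks for an observation filter. A hedged sketch of the two config keys the diff references; the filter name and the surrounding dict are illustrative, not taken from this commit:

# Illustrative agent config fragment (values are assumptions).
config = {
    "observation_filter": "MeanStdFilter",  # anything other than "NoFilter" enables the branch
    "synchronize_filters": True,            # forwarded as update_remote=... above
}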
@@ -24,6 +24,8 @@ from ray.rllib.utils.compression import pack
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
 
+logger = logging.getLogger(__name__)
+
 
 class PolicyEvaluator(EvaluatorInterface):
     """Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
@@ -301,6 +303,9 @@ class PolicyEvaluator(EvaluatorInterface):
                 pack=pack_episodes,
                 tf_sess=self.tf_sess)
 
+        logger.debug("Created evaluator with env {} ({}), policies {}".format(
+            self.async_env, self.env, self.policy_map))
+
     def _build_policy_map(self, policy_dict, policy_config):
         policy_map = {}
         for name, (cls, obs_space, act_space,
@@ -2,6 +2,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import logging
 import tensorflow as tf
 import numpy as np
 
@@ -11,6 +12,8 @@ from ray.rllib.models.lstm import chop_into_sequences
 from ray.rllib.utils.tf_run_builder import TFRunBuilder
 from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
 
+logger = logging.getLogger(__name__)
+
 
 class TFPolicyGraph(PolicyGraph):
     """An agent policy and loss implemented in TensorFlow.
@@ -116,6 +119,9 @@ class TFPolicyGraph(PolicyGraph):
             raise ValueError(
                 "seq_lens tensor must be given if state inputs are defined")
 
+        logger.debug("Created {} with loss inputs: {}".format(
+            self, self._loss_input_dict))
+
     def build_compute_actions(self,
                               builder,
                               obs_batch,
@@ -203,6 +203,9 @@ class ModelCatalog(object):
             model = LSTM(copy, obs_space, num_outputs, options, state_in,
                          seq_lens)
 
+        logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
+            model, input_dict, obs_space, state_in, seq_lens, model.outputs,
+            model.state_out))
         return model
 
     @staticmethod
@@ -282,11 +285,15 @@ class ModelCatalog(object):
         if options.get("custom_preprocessor"):
             preprocessor = options["custom_preprocessor"]
             logger.info("Using custom preprocessor {}".format(preprocessor))
-            return _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
+            prep = _global_registry.get(RLLIB_PREPROCESSOR, preprocessor)(
                 env.observation_space, options)
+        else:
+            cls = get_preprocessor(env.observation_space)
+            prep = cls(env.observation_space, options)
 
-        preprocessor = get_preprocessor(env.observation_space)
-        return preprocessor(env.observation_space, options)
+        logger.debug("Created preprocessor {}: {} -> {}".format(
+            prep, env.observation_space, prep.shape))
+        return prep
 
     @staticmethod
     def get_preprocessor_as_wrapper(env, options=None):
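The refactored branch above now logs which preprocessor it built. Its `custom_preprocessor` path looks the name up in the global registry, which a user populates roughly as in the sketch below; the class body, the registered name, and the `_init` hook are assumptions for illustration, not part of this commit:

from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor


class MyPreprocessor(Preprocessor):
    # Hypothetical custom preprocessor; only the .shape and .transform()
    # usage is implied by the diff, the hooks below are assumed.
    def _init(self):
        self.shape = self._obs_space.shape  # reported by the new debug line

    def transform(self, observation):
        return observation  # identity transform, illustration only


ModelCatalog.register_custom_preprocessor("my_prep", MyPreprocessor)
# Selected via the model config, e.g. {"model": {"custom_preprocessor": "my_prep"}}.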
@@ -132,7 +132,7 @@ class TupleFlatteningPreprocessor(Preprocessor):
         self.preprocessors = []
         for i in range(len(self._obs_space.spaces)):
             space = self._obs_space.spaces[i]
-            logger.info("Creating sub-preprocessor for {}".format(space))
+            logger.debug("Creating sub-preprocessor for {}".format(space))
             preprocessor = get_preprocessor(space)(space, self._options)
             self.preprocessors.append(preprocessor)
             size += preprocessor.size
@@ -157,7 +157,7 @@ class DictFlatteningPreprocessor(Preprocessor):
         size = 0
         self.preprocessors = []
         for space in self._obs_space.spaces.values():
-            logger.info("Creating sub-preprocessor for {}".format(space))
+            logger.debug("Creating sub-preprocessor for {}".format(space))
             preprocessor = get_preprocessor(space)(space, self._options)
             self.preprocessors.append(preprocessor)
             size += preprocessor.size
@@ -57,6 +57,9 @@ class PolicyOptimizer(object):
         self.num_steps_trained = 0
         self.num_steps_sampled = 0
 
+        logger.debug("Created policy optimizer with {}: {}".format(
+            config, self))
+
     def _init(self):
         """Subclasses should prefer overriding this instead of __init__."""
 