[rllib] Remove extra model config kwargs passed incorrectly for Torch models (#10055)

This commit is contained in:
Eric Liang 2020-08-17 11:12:20 -07:00 committed by GitHub
parent bd0b1488ef
commit ca133e2699
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 16 deletions

View file

@ -31,7 +31,7 @@ parser.add_argument(
type=str,
default=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"../tests/data/cartpole/small"))
"../tests/data/cartpole/small.json"))
if __name__ == "__main__":
ray.init()

View file

@ -82,8 +82,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
# === Options for custom models ===
# Name of a custom model to use
"custom_model": None,
# Extra options to pass to the custom classes.
# These will be available in the Model's
# Extra options to pass to the custom classes. These will be available to
# the Model's constructor in the model_config field. Also, they will be
# attempted to be passed as **kwargs to ModelV2 models. For an example,
# see rllib/models/[tf|torch]/attention_net.py.
"custom_model_config": {},
# Name of a custom action distribution to use.
"custom_action_dist": None,
@ -302,6 +304,11 @@ class ModelCatalog:
model_config["custom_model_config"] = \
model_config.pop("custom_options")
# Allow model kwargs to be overriden / augmented by
# custom_model_config.
customized_model_kwargs = dict(
model_kwargs, **model_config.get("custom_model_config", {}))
if isinstance(model_config["custom_model"], type):
model_cls = model_config["custom_model"]
else:
@ -329,19 +336,19 @@ class ModelCatalog:
# accept these as kwargs, not get them from
# config["custom_model_config"] anymore).
try:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name, **model_kwargs)
instance = model_cls(
obs_space, action_space, num_outputs,
model_config, name, **customized_model_kwargs)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name, **model_kwargs)
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_model_config']!")
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name)
# Other error -> re-raise.
else:
raise e
@ -361,9 +368,26 @@ class ModelCatalog:
else:
# PyTorch automatically tracks nn.Modules inside the parent
# nn.Module's constructor.
# TODO(sven): Do this for TF as well.
instance = model_cls(obs_space, action_space, num_outputs,
model_config, name, **model_kwargs)
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_model_config"] anymore).
try:
instance = model_cls(obs_space, action_space,
num_outputs, model_config, name,
**customized_model_kwargs)
except TypeError as e:
# Keyword error: Try old way w/o kwargs.
if "__init__() got an unexpected " in e.args[0]:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name, **model_kwargs)
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_model_config']!")
# Other error -> re-raise.
else:
raise e
return instance
# TODO(sven): Hard-deprecate Model(V1). This check will be
# superflous then.

View file

@ -200,8 +200,7 @@ class DynamicTFPolicy(TFPolicy):
action_space=action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework="tf",
**self.config["model"].get("custom_model_config", {}))
framework="tf")
# Create the Exploration object to use for this Policy.
self.exploration = self._create_exploration()

View file

@ -201,8 +201,7 @@ def build_torch_policy(
action_space=action_space,
num_outputs=logit_dim,
model_config=self.config["model"],
framework="torch",
**self.config["model"].get("custom_model_config", {}))
framework="torch")
# Make sure, we passed in a correct Model factory.
assert isinstance(self.model, TorchModelV2), \