mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Remove extra model config kwargs passed incorrectly for Torch models (#10055)
This commit is contained in:
parent
bd0b1488ef
commit
ca133e2699
4 changed files with 38 additions and 16 deletions
|
@ -31,7 +31,7 @@ parser.add_argument(
|
|||
type=str,
|
||||
default=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../tests/data/cartpole/small"))
|
||||
"../tests/data/cartpole/small.json"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
|
|
@ -82,8 +82,10 @@ MODEL_DEFAULTS: ModelConfigDict = {
|
|||
# === Options for custom models ===
|
||||
# Name of a custom model to use
|
||||
"custom_model": None,
|
||||
# Extra options to pass to the custom classes.
|
||||
# These will be available in the Model's
|
||||
# Extra options to pass to the custom classes. These will be available to
|
||||
# the Model's constructor in the model_config field. Also, they will be
|
||||
# attempted to be passed as **kwargs to ModelV2 models. For an example,
|
||||
# see rllib/models/[tf|torch]/attention_net.py.
|
||||
"custom_model_config": {},
|
||||
# Name of a custom action distribution to use.
|
||||
"custom_action_dist": None,
|
||||
|
@ -302,6 +304,11 @@ class ModelCatalog:
|
|||
model_config["custom_model_config"] = \
|
||||
model_config.pop("custom_options")
|
||||
|
||||
# Allow model kwargs to be overriden / augmented by
|
||||
# custom_model_config.
|
||||
customized_model_kwargs = dict(
|
||||
model_kwargs, **model_config.get("custom_model_config", {}))
|
||||
|
||||
if isinstance(model_config["custom_model"], type):
|
||||
model_cls = model_config["custom_model"]
|
||||
else:
|
||||
|
@ -329,19 +336,19 @@ class ModelCatalog:
|
|||
# accept these as kwargs, not get them from
|
||||
# config["custom_model_config"] anymore).
|
||||
try:
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
name, **model_kwargs)
|
||||
instance = model_cls(
|
||||
obs_space, action_space, num_outputs,
|
||||
model_config, name, **customized_model_kwargs)
|
||||
except TypeError as e:
|
||||
# Keyword error: Try old way w/o kwargs.
|
||||
if "__init__() got an unexpected " in e.args[0]:
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
name, **model_kwargs)
|
||||
logger.warning(
|
||||
"Custom ModelV2 should accept all custom "
|
||||
"options as **kwargs, instead of expecting"
|
||||
" them in config['custom_model_config']!")
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
name)
|
||||
# Other error -> re-raise.
|
||||
else:
|
||||
raise e
|
||||
|
@ -361,9 +368,26 @@ class ModelCatalog:
|
|||
else:
|
||||
# PyTorch automatically tracks nn.Modules inside the parent
|
||||
# nn.Module's constructor.
|
||||
# TODO(sven): Do this for TF as well.
|
||||
instance = model_cls(obs_space, action_space, num_outputs,
|
||||
model_config, name, **model_kwargs)
|
||||
# Try calling with kwargs first (custom ModelV2 should
|
||||
# accept these as kwargs, not get them from
|
||||
# config["custom_model_config"] anymore).
|
||||
try:
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config, name,
|
||||
**customized_model_kwargs)
|
||||
except TypeError as e:
|
||||
# Keyword error: Try old way w/o kwargs.
|
||||
if "__init__() got an unexpected " in e.args[0]:
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
name, **model_kwargs)
|
||||
logger.warning(
|
||||
"Custom ModelV2 should accept all custom "
|
||||
"options as **kwargs, instead of expecting"
|
||||
" them in config['custom_model_config']!")
|
||||
# Other error -> re-raise.
|
||||
else:
|
||||
raise e
|
||||
return instance
|
||||
# TODO(sven): Hard-deprecate Model(V1). This check will be
|
||||
# superflous then.
|
||||
|
|
|
@ -200,8 +200,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
action_space=action_space,
|
||||
num_outputs=logit_dim,
|
||||
model_config=self.config["model"],
|
||||
framework="tf",
|
||||
**self.config["model"].get("custom_model_config", {}))
|
||||
framework="tf")
|
||||
|
||||
# Create the Exploration object to use for this Policy.
|
||||
self.exploration = self._create_exploration()
|
||||
|
|
|
@ -201,8 +201,7 @@ def build_torch_policy(
|
|||
action_space=action_space,
|
||||
num_outputs=logit_dim,
|
||||
model_config=self.config["model"],
|
||||
framework="torch",
|
||||
**self.config["model"].get("custom_model_config", {}))
|
||||
framework="torch")
|
||||
|
||||
# Make sure, we passed in a correct Model factory.
|
||||
assert isinstance(self.model, TorchModelV2), \
|
||||
|
|
Loading…
Add table
Reference in a new issue