mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Add 2 Transformer learning test cases on StatelessCartPole (PPO and IMPALA). (#8624)
This commit is contained in:
parent
b0bb0584fb
commit
0422e9c5a8
14 changed files with 133 additions and 33 deletions
|
@ -71,7 +71,8 @@ Once implemented, the model can then be registered and used in place of a built-
|
|||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
"custom_options": {}, # extra options to pass to your model
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -132,7 +133,8 @@ Once implemented, the model can then be registered and used in place of a built-
|
|||
"use_pytorch": True,
|
||||
"model": {
|
||||
"custom_model": "my_model",
|
||||
"custom_options": {}, # extra options to pass to your model
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
})
|
||||
|
||||
|
@ -165,7 +167,8 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
|
|||
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
|
||||
"model": {
|
||||
"custom_preprocessor": "my_prep",
|
||||
"custom_options": {}, # extra options to pass to your preprocessor
|
||||
# Extra kwargs to be passed to your model's c'tor.
|
||||
"custom_model_config": {},
|
||||
},
|
||||
})
|
||||
|
||||
|
|
19
rllib/BUILD
19
rllib/BUILD
|
@ -99,7 +99,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_appo_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_torch", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = [
|
||||
"tuned_examples/ppo/cartpole-appo.yaml",
|
||||
|
@ -184,7 +184,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_es_tf",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_tf", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/es/cartpole-es.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/es"]
|
||||
|
@ -194,7 +194,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_es_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_torch", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/es/cartpole-es.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/es", "--torch"]
|
||||
|
@ -205,7 +205,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_impala_tf",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_tf", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/impala/cartpole-impala.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/impala"]
|
||||
|
@ -215,7 +215,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_impala_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_torch", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/impala/cartpole-impala.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/impala", "--torch"]
|
||||
|
@ -298,7 +298,7 @@ py_test(
|
|||
name = "run_regression_tests_cartpole_sac_torch",
|
||||
main = "tests/run_regression_tests.py",
|
||||
tags = ["learning_tests_torch", "learning_tests_cartpole"],
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["tests/run_regression_tests.py"],
|
||||
data = ["tuned_examples/sac/cartpole-sac.yaml"],
|
||||
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
|
||||
|
@ -1306,6 +1306,13 @@ py_test(
|
|||
# for `tests/test_all_stuff.py`.
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
py_test(
|
||||
name = "tests/test_attention_net_learning",
|
||||
tags = ["tests_dir", "tests_dir_A"],
|
||||
size = "large",
|
||||
srcs = ["tests/test_attention_net_learning.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_avail_actions_qmix",
|
||||
tags = ["tests_dir", "tests_dir_A"],
|
||||
|
|
|
@ -73,8 +73,8 @@ class ConvNetModel(ActorCriticModel):
|
|||
ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
in_channels = model_config["custom_options"]["in_channels"]
|
||||
feature_dim = model_config["custom_options"]["feature_dim"]
|
||||
in_channels = model_config["custom_model_config"]["in_channels"]
|
||||
feature_dim = model_config["custom_model_config"]["feature_dim"]
|
||||
|
||||
self.shared_layers = nn.Sequential(
|
||||
nn.Conv2d(in_channels, 32, kernel_size=4, stride=2),
|
||||
|
|
|
@ -50,7 +50,7 @@ if __name__ == "__main__":
|
|||
"model": {
|
||||
"custom_model": GTrXLNet,
|
||||
"max_seq_len": 50,
|
||||
"custom_options": {
|
||||
"custom_model_config": {
|
||||
"num_transformer_units": 1,
|
||||
"attn_dim": 64,
|
||||
"num_heads": 2,
|
||||
|
|
|
@ -53,7 +53,7 @@ if __name__ == "__main__":
|
|||
"num_workers": 0,
|
||||
"model": {
|
||||
"custom_model": "custom_loss",
|
||||
"custom_options": {
|
||||
"custom_model_config": {
|
||||
"input_files": args.input_files,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -36,7 +36,7 @@ if __name__ == "__main__":
|
|||
"model": {
|
||||
"custom_model": "my_model",
|
||||
# Extra config passed to the custom model's c'tor as kwargs.
|
||||
"custom_options": {
|
||||
"custom_model_config": {
|
||||
"cnn_shape": cnn_shape_torch if args.torch else cnn_shape,
|
||||
},
|
||||
"max_seq_len": 20,
|
||||
|
|
|
@ -38,7 +38,8 @@ class CustomLossModel(TFModelV2):
|
|||
@override(ModelV2)
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
# Create a new input reader per worker.
|
||||
reader = JsonReader(self.model_config["custom_options"]["input_files"])
|
||||
reader = JsonReader(
|
||||
self.model_config["custom_model_config"]["input_files"])
|
||||
input_ops = reader.tf_input_ops()
|
||||
|
||||
# Define a secondary loss by building a graph copy with weight sharing.
|
||||
|
@ -80,7 +81,7 @@ class DeprecatedCustomLossModelV1(Model):
|
|||
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
# create a new input reader per worker
|
||||
reader = JsonReader(self.options["custom_options"]["input_files"])
|
||||
reader = JsonReader(self.options["custom_model_config"]["input_files"])
|
||||
input_ops = reader.tf_input_ops()
|
||||
|
||||
# define a secondary loss by building a graph copy with weight sharing
|
||||
|
|
|
@ -23,7 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
|
|||
TorchMultiActionDistribution, TorchMultiCategorical
|
||||
from ray.rllib.utils import try_import_tf, try_import_tree
|
||||
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
|
||||
from ray.rllib.utils.deprecation import deprecation_warning
|
||||
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.space_utils import flatten_space
|
||||
|
||||
|
@ -81,14 +81,17 @@ MODEL_DEFAULTS = {
|
|||
# === Options for custom models ===
|
||||
# Name of a custom model to use
|
||||
"custom_model": None,
|
||||
# Extra options to pass to the custom classes.
|
||||
# These will be available in the Model's
|
||||
"custom_model_config": {},
|
||||
# Name of a custom action distribution to use.
|
||||
"custom_action_dist": None,
|
||||
|
||||
# Extra options to pass to the custom classes
|
||||
"custom_options": {},
|
||||
# Custom preprocessors are deprecated. Please use a wrapper class around
|
||||
# your environment instead to preprocess observations.
|
||||
"custom_preprocessor": None,
|
||||
|
||||
# Deprecated config keys.
|
||||
"custom_options": DEPRECATED_VALUE,
|
||||
}
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
@ -280,6 +283,16 @@ class ModelCatalog:
|
|||
"""
|
||||
|
||||
if model_config.get("custom_model"):
|
||||
|
||||
if "custom_options" in model_config and \
|
||||
model_config["custom_options"] != DEPRECATED_VALUE:
|
||||
deprecation_warning(
|
||||
"model.custom_options",
|
||||
"model.custom_model_config",
|
||||
error=False)
|
||||
model_config["custom_model_config"] = \
|
||||
model_config.pop("custom_options")
|
||||
|
||||
if isinstance(model_config["custom_model"], type):
|
||||
model_cls = model_config["custom_model"]
|
||||
else:
|
||||
|
@ -304,7 +317,7 @@ class ModelCatalog:
|
|||
with tf.variable_creator_scope(track_var_creation):
|
||||
# Try calling with kwargs first (custom ModelV2 should
|
||||
# accept these as kwargs, not get them from
|
||||
# config["custom_options"] anymore)
|
||||
# config["custom_model_config"] anymore).
|
||||
try:
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
|
@ -315,7 +328,7 @@ class ModelCatalog:
|
|||
logger.warning(
|
||||
"Custom ModelV2 should accept all custom "
|
||||
"options as **kwargs, instead of expecting"
|
||||
" them in config['custom_options']!")
|
||||
" them in config['custom_model_config']!")
|
||||
instance = model_cls(obs_space, action_space,
|
||||
num_outputs, model_config,
|
||||
name)
|
||||
|
|
|
@ -146,7 +146,7 @@ class GTrXLNet(RecurrentNetwork):
|
|||
Examples:
|
||||
>> config["model"]["custom_model"] = GTrXLNet
|
||||
>> config["model"]["max_seq_len"] = 10
|
||||
>> config["model"]["custom_options"] = {
|
||||
>> config["model"]["custom_model_config"] = {
|
||||
>> num_transformer_units=1,
|
||||
>> attn_dim=32,
|
||||
>> num_heads=2,
|
||||
|
|
|
@ -161,7 +161,7 @@ class DynamicTFPolicy(TFPolicy):
|
|||
num_outputs=logit_dim,
|
||||
model_config=self.config["model"],
|
||||
framework="tf",
|
||||
**self.config["model"].get("custom_options", {}))
|
||||
**self.config["model"].get("custom_model_config", {}))
|
||||
|
||||
# Create the Exploration object to use for this Policy.
|
||||
self.exploration = self._create_exploration()
|
||||
|
|
|
@ -117,7 +117,7 @@ def build_torch_policy(name,
|
|||
num_outputs=logit_dim,
|
||||
model_config=self.config["model"],
|
||||
framework="torch",
|
||||
**self.config["model"].get("custom_options", {}))
|
||||
**self.config["model"].get("custom_model_config", {}))
|
||||
|
||||
# Make sure, we passed in a correct Model factory.
|
||||
assert isinstance(self.model, TorchModelV2), \
|
||||
|
|
74
rllib/tests/test_attention_net_learning.py
Normal file
74
rllib/tests/test_attention_net_learning.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import unittest
|
||||
|
||||
from ray import tune
|
||||
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.tf.attention_net import GTrXLNet
|
||||
|
||||
|
||||
class TestAttentionNetLearning(unittest.TestCase):
|
||||
|
||||
config = {
|
||||
"env": StatelessCartPole,
|
||||
"gamma": 0.99,
|
||||
"num_envs_per_worker": 20,
|
||||
# "framework": "tf",
|
||||
}
|
||||
|
||||
stop = {
|
||||
"episode_reward_mean": 180.0,
|
||||
"timesteps_total": 5000000,
|
||||
}
|
||||
|
||||
def test_ppo_attention_net_learning(self):
|
||||
ModelCatalog.register_custom_model("attention_net", GTrXLNet)
|
||||
config = dict(
|
||||
self.config, **{
|
||||
"num_workers": 0,
|
||||
"entropy_coeff": 0.001,
|
||||
"vf_loss_coeff": 1e-5,
|
||||
"num_sgd_iter": 5,
|
||||
"model": {
|
||||
"custom_model": "attention_net",
|
||||
"max_seq_len": 10,
|
||||
"custom_model_config": {
|
||||
"num_transformer_units": 1,
|
||||
"attn_dim": 32,
|
||||
"num_heads": 1,
|
||||
"memory_tau": 5,
|
||||
"head_dim": 32,
|
||||
"ff_hidden_dim": 32,
|
||||
},
|
||||
},
|
||||
})
|
||||
tune.run("PPO", config=config, stop=self.stop, verbose=1)
|
||||
|
||||
def test_impala_attention_net_learning(self):
|
||||
ModelCatalog.register_custom_model("attention_net", GTrXLNet)
|
||||
config = dict(
|
||||
self.config, **{
|
||||
"num_workers": 4,
|
||||
"num_gpus": 0,
|
||||
"entropy_coeff": 0.01,
|
||||
"vf_loss_coeff": 0.001,
|
||||
"lr": 0.0008,
|
||||
"model": {
|
||||
"custom_model": "attention_net",
|
||||
"max_seq_len": 65,
|
||||
"custom_model_config": {
|
||||
"num_transformer_units": 1,
|
||||
"attn_dim": 64,
|
||||
"num_heads": 1,
|
||||
"memory_tau": 10,
|
||||
"head_dim": 32,
|
||||
"ff_hidden_dim": 32,
|
||||
},
|
||||
},
|
||||
})
|
||||
tune.run("IMPALA", config=config, stop=self.stop, verbose=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
|
@ -4,8 +4,8 @@ import unittest
|
|||
|
||||
import ray
|
||||
from ray.tune import register_env
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.agents.qmix import QMixTrainer
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
|
||||
class AvailActionsTestEnv(MultiAgentEnv):
|
||||
|
|
|
@ -35,19 +35,21 @@ class CustomModel(TFModelV2):
|
|||
class CustomActionDistribution(TFActionDistribution):
|
||||
def __init__(self, inputs, model):
|
||||
# Store our output shape.
|
||||
custom_options = model.model_config["custom_options"]
|
||||
if "output_dim" in custom_options:
|
||||
custom_model_config = model.model_config["custom_model_config"]
|
||||
if "output_dim" in custom_model_config:
|
||||
self.output_shape = tf.concat(
|
||||
[tf.shape(inputs)[:1], custom_options["output_dim"]], axis=0)
|
||||
[tf.shape(inputs)[:1], custom_model_config["output_dim"]],
|
||||
axis=0)
|
||||
else:
|
||||
self.output_shape = tf.shape(inputs)
|
||||
super().__init__(inputs, model)
|
||||
|
||||
@staticmethod
|
||||
def required_model_output_shape(action_space, model_config=None):
|
||||
custom_options = model_config["custom_options"] or {}
|
||||
if custom_options is not None and custom_options.get("output_dim"):
|
||||
return custom_options.get("output_dim")
|
||||
custom_model_config = model_config["custom_model_config"] or {}
|
||||
if custom_model_config is not None and \
|
||||
custom_model_config.get("output_dim"):
|
||||
return custom_model_config.get("output_dim")
|
||||
return action_space.shape
|
||||
|
||||
@override(TFActionDistribution)
|
||||
|
@ -157,7 +159,7 @@ class ModelCatalogTest(unittest.TestCase):
|
|||
dist.entropy()
|
||||
|
||||
# test passing the options to it
|
||||
model_config["custom_options"].update({"output_dim": (3, )})
|
||||
model_config["custom_model_config"].update({"output_dim": (3, )})
|
||||
dist_cls, param_shape = ModelCatalog.get_action_dist(
|
||||
action_space, model_config)
|
||||
self.assertEqual(param_shape, (3, ))
|
||||
|
|
Loading…
Add table
Reference in a new issue