[RLlib] Add 2 Transformer learning test cases on StatelessCartPole (PPO and IMPALA). (#8624)

This commit is contained in:
Sven Mika 2020-05-27 10:19:47 +02:00 committed by GitHub
parent b0bb0584fb
commit 0422e9c5a8
14 changed files with 133 additions and 33 deletions

View file

@ -71,7 +71,8 @@ Once implemented, the model can then be registered and used in place of a built-
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
"model": {
"custom_model": "my_model",
"custom_options": {}, # extra options to pass to your model
# Extra kwargs to be passed to your model's c'tor.
"custom_model_config": {},
},
})
@ -132,7 +133,8 @@ Once implemented, the model can then be registered and used in place of a built-
"use_pytorch": True,
"model": {
"custom_model": "my_model",
"custom_options": {}, # extra options to pass to your model
# Extra kwargs to be passed to your model's c'tor.
"custom_model_config": {},
},
})
@ -165,7 +167,8 @@ Custom preprocessors should subclass the RLlib `preprocessor class <https://gith
trainer = ppo.PPOTrainer(env="CartPole-v0", config={
"model": {
"custom_preprocessor": "my_prep",
"custom_options": {}, # extra options to pass to your preprocessor
# Extra kwargs to be passed to your model's c'tor.
"custom_model_config": {},
},
})
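
For context, a minimal sketch of the constructor side of this change (the model class, its registration name, and the `hidden_size` kwarg are illustrative, not part of the diff): entries under `custom_model_config` end up as keyword arguments of the custom model's constructor, which is what the catalog change further down expects.

from ray.rllib.models import ModelCatalog
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class MyModel(TFModelV2):
    """Skeleton custom model (forward() / value_function() omitted)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, hidden_size=64):
        # `hidden_size` comes from "custom_model_config" (passed as **kwargs).
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.hidden_size = hidden_size


ModelCatalog.register_custom_model("my_model", MyModel)

config = {
    "model": {
        "custom_model": "my_model",
        # Forwarded to MyModel.__init__ as hidden_size=128.
        "custom_model_config": {"hidden_size": 128},
    },
}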

View file

@ -99,7 +99,7 @@ py_test(
name = "run_regression_tests_cartpole_appo_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = [
"tuned_examples/ppo/cartpole-appo.yaml",
@ -184,7 +184,7 @@ py_test(
name = "run_regression_tests_cartpole_es_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/es/cartpole-es.yaml"],
args = ["--yaml-dir=tuned_examples/es"]
@ -194,7 +194,7 @@ py_test(
name = "run_regression_tests_cartpole_es_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/es/cartpole-es.yaml"],
args = ["--yaml-dir=tuned_examples/es", "--torch"]
@ -205,7 +205,7 @@ py_test(
name = "run_regression_tests_cartpole_impala_tf",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_tf", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/impala/cartpole-impala.yaml"],
args = ["--yaml-dir=tuned_examples/impala"]
@ -215,7 +215,7 @@ py_test(
name = "run_regression_tests_cartpole_impala_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/impala/cartpole-impala.yaml"],
args = ["--yaml-dir=tuned_examples/impala", "--torch"]
@ -298,7 +298,7 @@ py_test(
name = "run_regression_tests_cartpole_sac_torch",
main = "tests/run_regression_tests.py",
tags = ["learning_tests_torch", "learning_tests_cartpole"],
size = "medium",
size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/sac/cartpole-sac.yaml"],
args = ["--yaml-dir=tuned_examples/sac", "--torch"]
@ -1306,6 +1306,13 @@ py_test(
# for `tests/test_all_stuff.py`.
# --------------------------------------------------------------------
py_test(
name = "tests/test_attention_net_learning",
tags = ["tests_dir", "tests_dir_A"],
size = "large",
srcs = ["tests/test_attention_net_learning.py"]
)
py_test(
name = "tests/test_avail_actions_qmix",
tags = ["tests_dir", "tests_dir_A"],

View file

@ -73,8 +73,8 @@ class ConvNetModel(ActorCriticModel):
ActorCriticModel.__init__(self, obs_space, action_space, num_outputs,
model_config, name)
in_channels = model_config["custom_options"]["in_channels"]
feature_dim = model_config["custom_options"]["feature_dim"]
in_channels = model_config["custom_model_config"]["in_channels"]
feature_dim = model_config["custom_model_config"]["feature_dim"]
self.shared_layers = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=4, stride=2),
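
The hunk above now reads its sizes from `custom_model_config`; a config fragment that would supply them could look like this (the registration name and concrete values are illustrative only):

config = {
    "model": {
        "custom_model": "conv_net",  # hypothetical registration name
        "custom_model_config": {
            # Read in ConvNetModel.__init__ via
            # model_config["custom_model_config"].
            "in_channels": 4,
            "feature_dim": 512,
        },
    },
}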

View file

@ -50,7 +50,7 @@ if __name__ == "__main__":
"model": {
"custom_model": GTrXLNet,
"max_seq_len": 50,
"custom_options": {
"custom_model_config": {
"num_transformer_units": 1,
"attn_dim": 64,
"num_heads": 2,

View file

@ -53,7 +53,7 @@ if __name__ == "__main__":
"num_workers": 0,
"model": {
"custom_model": "custom_loss",
"custom_options": {
"custom_model_config": {
"input_files": args.input_files,
},
},

View file

@ -36,7 +36,7 @@ if __name__ == "__main__":
"model": {
"custom_model": "my_model",
# Extra config passed to the custom model's c'tor as kwargs.
"custom_options": {
"custom_model_config": {
"cnn_shape": cnn_shape_torch if args.torch else cnn_shape,
},
"max_seq_len": 20,

View file

@ -38,7 +38,8 @@ class CustomLossModel(TFModelV2):
@override(ModelV2)
def custom_loss(self, policy_loss, loss_inputs):
# Create a new input reader per worker.
reader = JsonReader(self.model_config["custom_options"]["input_files"])
reader = JsonReader(
self.model_config["custom_model_config"]["input_files"])
input_ops = reader.tf_input_ops()
# Define a secondary loss by building a graph copy with weight sharing.
@ -80,7 +81,7 @@ class DeprecatedCustomLossModelV1(Model):
def custom_loss(self, policy_loss, loss_inputs):
# create a new input reader per worker
reader = JsonReader(self.options["custom_options"]["input_files"])
reader = JsonReader(self.options["custom_model_config"]["input_files"])
input_ops = reader.tf_input_ops()
# define a secondary loss by building a graph copy with weight sharing

View file

@ -23,7 +23,7 @@ from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchMultiActionDistribution, TorchMultiCategorical
from ray.rllib.utils import try_import_tf, try_import_tree
from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI
from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.space_utils import flatten_space
@ -81,14 +81,17 @@ MODEL_DEFAULTS = {
# === Options for custom models ===
# Name of a custom model to use
"custom_model": None,
# Extra options to pass to the custom classes. These will be available
# in the Model's constructor as **kwargs.
"custom_model_config": {},
# Name of a custom action distribution to use.
"custom_action_dist": None,
# Extra options to pass to the custom classes
"custom_options": {},
# Custom preprocessors are deprecated. Please use a wrapper class around
# your environment instead to preprocess observations.
"custom_preprocessor": None,
# Deprecated config keys.
"custom_options": DEPRECATED_VALUE,
}
# __sphinx_doc_end__
# yapf: enable
@ -280,6 +283,16 @@ class ModelCatalog:
"""
if model_config.get("custom_model"):
if "custom_options" in model_config and \
model_config["custom_options"] != DEPRECATED_VALUE:
deprecation_warning(
"model.custom_options",
"model.custom_model_config",
error=False)
model_config["custom_model_config"] = \
model_config.pop("custom_options")
if isinstance(model_config["custom_model"], type):
model_cls = model_config["custom_model"]
else:
@ -304,7 +317,7 @@ class ModelCatalog:
with tf.variable_creator_scope(track_var_creation):
# Try calling with kwargs first (custom ModelV2 should
# accept these as kwargs, not get them from
# config["custom_options"] anymore)
# config["custom_model_config"] anymore).
try:
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
@ -315,7 +328,7 @@ class ModelCatalog:
logger.warning(
"Custom ModelV2 should accept all custom "
"options as **kwargs, instead of expecting"
" them in config['custom_options']!")
" them in config['custom_model_config']!")
instance = model_cls(obs_space, action_space,
num_outputs, model_config,
name)
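
In practice, the shim above means an old-style config keeps working but is migrated on the fly. Roughly (the file path is a placeholder; the "input_files" key mirrors the custom-loss example earlier in this diff):

# Old-style config: still accepted, but ModelCatalog.get_model_v2() logs a
# deprecation warning and moves the value under the new key.
old_style = {
    "custom_model": "custom_loss",
    "custom_options": {"input_files": "/tmp/out.json"},  # deprecated key
}

# Equivalent new-style config.
new_style = {
    "custom_model": "custom_loss",
    "custom_model_config": {"input_files": "/tmp/out.json"},
}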

View file

@ -146,7 +146,7 @@ class GTrXLNet(RecurrentNetwork):
Examples:
>> config["model"]["custom_model"] = GTrXLNet
>> config["model"]["max_seq_len"] = 10
>> config["model"]["custom_options"] = {
>> config["model"]["custom_model_config"] = {
>> num_transformer_units=1,
>> attn_dim=32,
>> num_heads=2,

View file

@ -161,7 +161,7 @@ class DynamicTFPolicy(TFPolicy):
num_outputs=logit_dim,
model_config=self.config["model"],
framework="tf",
**self.config["model"].get("custom_options", {}))
**self.config["model"].get("custom_model_config", {}))
# Create the Exploration object to use for this Policy.
self.exploration = self._create_exploration()

View file

@ -117,7 +117,7 @@ def build_torch_policy(name,
num_outputs=logit_dim,
model_config=self.config["model"],
framework="torch",
**self.config["model"].get("custom_options", {}))
**self.config["model"].get("custom_model_config", {}))
# Make sure, we passed in a correct Model factory.
assert isinstance(self.model, TorchModelV2), \

View file

@ -0,0 +1,74 @@
import unittest

from ray import tune
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.tf.attention_net import GTrXLNet


class TestAttentionNetLearning(unittest.TestCase):

    config = {
        "env": StatelessCartPole,
        "gamma": 0.99,
        "num_envs_per_worker": 20,
        # "framework": "tf",
    }

    stop = {
        "episode_reward_mean": 180.0,
        "timesteps_total": 5000000,
    }

    def test_ppo_attention_net_learning(self):
        ModelCatalog.register_custom_model("attention_net", GTrXLNet)
        config = dict(
            self.config, **{
                "num_workers": 0,
                "entropy_coeff": 0.001,
                "vf_loss_coeff": 1e-5,
                "num_sgd_iter": 5,
                "model": {
                    "custom_model": "attention_net",
                    "max_seq_len": 10,
                    "custom_model_config": {
                        "num_transformer_units": 1,
                        "attn_dim": 32,
                        "num_heads": 1,
                        "memory_tau": 5,
                        "head_dim": 32,
                        "ff_hidden_dim": 32,
                    },
                },
            })
        tune.run("PPO", config=config, stop=self.stop, verbose=1)

    def test_impala_attention_net_learning(self):
        ModelCatalog.register_custom_model("attention_net", GTrXLNet)
        config = dict(
            self.config, **{
                "num_workers": 4,
                "num_gpus": 0,
                "entropy_coeff": 0.01,
                "vf_loss_coeff": 0.001,
                "lr": 0.0008,
                "model": {
                    "custom_model": "attention_net",
                    "max_seq_len": 65,
                    "custom_model_config": {
                        "num_transformer_units": 1,
                        "attn_dim": 64,
                        "num_heads": 1,
                        "memory_tau": 10,
                        "head_dim": 32,
                        "ff_hidden_dim": 32,
                    },
                },
            })
        tune.run("IMPALA", config=config, stop=self.stop, verbose=1)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))
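
If run outside of Bazel, a single case from the new module can also be selected directly via pytest's node-id syntax (path shown relative to the rllib/ directory; adjust as needed):

import sys

import pytest

# Run only the PPO case; drop the "::..." suffix to run both tests.
sys.exit(pytest.main([
    "-v",
    "tests/test_attention_net_learning.py"
    "::TestAttentionNetLearning::test_ppo_attention_net_learning",
]))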

View file

@ -4,8 +4,8 @@ import unittest
import ray
from ray.tune import register_env
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.agents.qmix import QMixTrainer
from ray.rllib.env.multi_agent_env import MultiAgentEnv
class AvailActionsTestEnv(MultiAgentEnv):

View file

@ -35,19 +35,21 @@ class CustomModel(TFModelV2):
class CustomActionDistribution(TFActionDistribution):
def __init__(self, inputs, model):
# Store our output shape.
custom_options = model.model_config["custom_options"]
if "output_dim" in custom_options:
custom_model_config = model.model_config["custom_model_config"]
if "output_dim" in custom_model_config:
self.output_shape = tf.concat(
[tf.shape(inputs)[:1], custom_options["output_dim"]], axis=0)
[tf.shape(inputs)[:1], custom_model_config["output_dim"]],
axis=0)
else:
self.output_shape = tf.shape(inputs)
super().__init__(inputs, model)
@staticmethod
def required_model_output_shape(action_space, model_config=None):
custom_options = model_config["custom_options"] or {}
if custom_options is not None and custom_options.get("output_dim"):
return custom_options.get("output_dim")
custom_model_config = model_config["custom_model_config"] or {}
if custom_model_config is not None and \
custom_model_config.get("output_dim"):
return custom_model_config.get("output_dim")
return action_space.shape
@override(TFActionDistribution)
@ -157,7 +159,7 @@ class ModelCatalogTest(unittest.TestCase):
dist.entropy()
# test passing the options to it
model_config["custom_options"].update({"output_dim": (3, )})
model_config["custom_model_config"].update({"output_dim": (3, )})
dist_cls, param_shape = ModelCatalog.get_action_dist(
action_space, model_config)
self.assertEqual(param_shape, (3, ))
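
To round this off, a sketch of how the distribution tested above would be wired into a model config (the registration name "my_dist" is hypothetical; ModelCatalog.register_custom_action_dist("my_dist", CustomActionDistribution) would have to be called beforehand):

config = {
    "model": {
        "custom_action_dist": "my_dist",  # hypothetical registration name
        # Read by required_model_output_shape() and __init__() above.
        "custom_model_config": {"output_dim": (3, )},
    },
}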