ray/rllib/examples/complex_struct_space.py
xwjiang2010 fcf897ee72
[air] update rllib example to use Tuner API. (#26987)
update rllib example to use Tuner API.

Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
2022-07-27 12:12:59 +01:00

58 lines
1.6 KiB
Python

"""Example of using variable-length Repeated / struct observation spaces.
This example shows:
- using a custom environment with Repeated / struct observations
- using a custom model to view the batched list observations
For PyTorch / TF eager mode, use the `--framework=[torch|tf2|tfe]` flag.
"""
import argparse
import os
import ray
from ray import air, tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG
from ray.rllib.examples.models.simple_rpg_model import (
CustomTorchRPGModel,
CustomTFRPGModel,
)
# Command-line interface for the example: a single `--framework` option that
# is read in the `__main__` section to choose which custom model gets
# registered and which framework is written into the algorithm config.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    default="tf2",
    choices=["tf", "tf2", "tfe", "torch"],
    help="The DL framework specifier.",
)
if __name__ == "__main__":
    # Parse the command line first so that `--help` or an invalid flag exits
    # before a Ray runtime is started.
    args = parser.parse_args()
    ray.init()

    # Register the custom model matching the chosen framework under the name
    # that the config's "model" -> "custom_model" entry refers to.
    if args.framework == "torch":
        ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
    else:
        ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)

    config = {
        "framework": args.framework,
        "env": SimpleRPG,
        "rollout_fragment_length": 1,
        "train_batch_size": 2,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_workers": 0,
        "model": {
            "custom_model": "my_model",
        },
        "_disable_preprocessor_api": False,
    }

    # Stopping criterion handed to the run config below.
    stop = {
        "timesteps_total": 1,
    }

    tuner = tune.Tuner(
        "PG", param_space=config, run_config=air.RunConfig(stop=stop, verbose=1)
    )
    # Building the Tuner only sets the experiment up; `fit()` is the call that
    # actually executes it. Without it the script would exit without training.
    tuner.fit()