2020-06-06 03:22:19 -07:00
|
|
|
"""Example of using variable-length Repeated / struct observation spaces.
|
|
|
|
|
|
|
|
This example shows:
|
|
|
|
- using a custom environment with Repeated / struct observations
|
|
|
|
- using a custom model to view the batched list observations
|
|
|
|
|
2021-05-18 13:18:12 +02:00
|
|
|
The default framework is `tf2` (TF eager mode); use `--framework=torch` for PyTorch, or `--framework=tf` / `--framework=tfe` for the other TF modes.
|
2020-06-06 03:22:19 -07:00
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
|
2020-10-02 23:07:44 +02:00
|
|
|
import os
|
2020-06-06 03:22:19 -07:00
|
|
|
|
2020-10-02 23:07:44 +02:00
|
|
|
import ray
|
2020-06-06 03:22:19 -07:00
|
|
|
from ray import tune
|
|
|
|
from ray.rllib.models import ModelCatalog
|
|
|
|
from ray.rllib.examples.env.simple_rpg import SimpleRPG
|
2022-01-29 18:41:57 -08:00
|
|
|
from ray.rllib.examples.models.simple_rpg_model import (
|
|
|
|
CustomTorchRPGModel,
|
|
|
|
CustomTFRPGModel,
|
|
|
|
)
|
2020-06-06 03:22:19 -07:00
|
|
|
|
|
|
|
# Command-line interface: a single flag choosing which deep-learning
# framework (and therefore which custom model class) the example runs with.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--framework",
    action="store",
    dest="framework",
    type=str,
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf2",
    help="The DL framework specifier.",
)
|
2020-06-06 03:22:19 -07:00
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the command line before starting Ray so that `--help` or an
    # invalid `--framework` value exits without spinning up a Ray instance.
    args = parser.parse_args()
    ray.init()

    # Register the framework-appropriate model under one name so the config
    # below does not need to know which framework was chosen.
    if args.framework == "torch":
        ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
    else:
        # "tf", "tf2" and "tfe" all use the TF implementation.
        ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)

    config = {
        "framework": args.framework,
        "env": SimpleRPG,
        # Deliberately tiny sampling/training sizes: together with the stop
        # criterion below, this keeps the run to a minimal smoke test.
        "rollout_fragment_length": 1,
        "train_batch_size": 2,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        # Sample in the driver process only (no remote rollout workers).
        "num_workers": 0,
        "model": {
            "custom_model": "my_model",
        },
        # Keeps RLlib's preprocessor API enabled. NOTE(review): presumably
        # the custom RPG models rely on preprocessed (flattened) Repeated
        # observations — confirm against simple_rpg_model.
        "_disable_preprocessor_api": False,
    }

    # Stop as soon as any timestep has been collected.
    stop = {
        "timesteps_total": 1,
    }

    try:
        tune.run("PG", config=config, stop=stop, verbose=1)
    finally:
        # Release the Ray resources started by `ray.init()` above, even if
        # training raised.
        ray.shutdown()
|