mirror of
https://github.com/vale981/ray
synced 2025-03-04 17:41:43 -05:00

update rllib example to use Tuner API. Signed-off-by: xwjiang2010 <xwjiang2010@gmail.com>
58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
"""Example of using variable-length Repeated / struct observation spaces.

This example shows:
  - using a custom environment with Repeated / struct observations
  - using a custom model to view the batched list observations

For PyTorch / TF eager mode, use the `--framework=[torch|tf2|tfe]` flag.
"""
|
|
|
|
import argparse
|
|
import os
|
|
|
|
import ray
|
|
from ray import air, tune
|
|
from ray.rllib.models import ModelCatalog
|
|
from ray.rllib.examples.env.simple_rpg import SimpleRPG
|
|
from ray.rllib.examples.models.simple_rpg_model import (
|
|
CustomTorchRPGModel,
|
|
CustomTFRPGModel,
|
|
)
|
|
|
|
# Command-line interface of this example script.
parser = argparse.ArgumentParser()
# `--framework` picks the deep-learning framework the example runs with.
# argparse rejects any value not listed in `choices`; "tf2" is used when the
# flag is omitted.
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf2",
    help="The DL framework specifier.",
)
|
|
|
|
if __name__ == "__main__":
|
|
ray.init()
|
|
args = parser.parse_args()
|
|
if args.framework == "torch":
|
|
ModelCatalog.register_custom_model("my_model", CustomTorchRPGModel)
|
|
else:
|
|
ModelCatalog.register_custom_model("my_model", CustomTFRPGModel)
|
|
|
|
config = {
|
|
"framework": args.framework,
|
|
"env": SimpleRPG,
|
|
"rollout_fragment_length": 1,
|
|
"train_batch_size": 2,
|
|
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
|
|
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
|
|
"num_workers": 0,
|
|
"model": {
|
|
"custom_model": "my_model",
|
|
},
|
|
"_disable_preprocessor_api": False,
|
|
}
|
|
|
|
stop = {
|
|
"timesteps_total": 1,
|
|
}
|
|
|
|
tuner = tune.Tuner(
|
|
"PG", param_space=config, run_config=air.RunConfig(stop=stop, verbose=1)
|
|
)
|