# ray/rllib/examples/autoregressive_action_dist.py
"""Example of specifying an autoregressive action distribution.
In an action space with multiple components (e.g., Tuple(a1, a2)), you might
want a2 to be sampled based on the sampled value of a1, i.e.,
a2_sampled ~ P(a2 | a1_sampled, obs). Normally, a1 and a2 would be sampled
independently.
To do this, you need both a custom model that implements the autoregressive
pattern, and a custom action distribution class that leverages that model.
This examples shows both.
"""

import argparse
import os

import ray
from ray import tune
from ray.rllib.examples.env.correlated_actions_env import CorrelatedActionsEnv
from ray.rllib.examples.models.autoregressive_action_model import \
    AutoregressiveActionModel, TorchAutoregressiveActionModel
from ray.rllib.examples.models.autoregressive_action_dist import \
    BinaryAutoregressiveDistribution, TorchBinaryAutoregressiveDistribution
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.test_utils import check_learning_achieved
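
# For intuition, here is a minimal, self-contained sketch (a hypothetical
# helper, NOT part of RLlib's API) of the sampling chain this example
# implements: a1 is drawn from P(a1 | obs), then a2 is drawn from
# P(a2 | a1, obs), so the two action components stay correlated. The real
# implementation lives in the custom model and distribution imported above.
import random


def _autoregressive_sample_sketch(a1_probs, a2_probs_given_a1):
    """Sample (a1, a2) with a2 conditioned on the sampled value of a1.

    Args:
        a1_probs: probabilities over a1's values, i.e. P(a1 | obs).
        a2_probs_given_a1: callable mapping a sampled a1 to probabilities
            over a2's values, i.e. P(a2 | a1, obs).
    """
    a1 = random.choices(range(len(a1_probs)), weights=a1_probs)[0]
    a2_probs = a2_probs_given_a1(a1)
    a2 = random.choices(range(len(a2_probs)), weights=a2_probs)[0]
    return a1, a2


# E.g., _autoregressive_sample_sketch([0.5, 0.5], lambda a1: [0.9, 0.1]
# if a1 == 0 else [0.1, 0.9]) makes a2 likely to equal a1; learning exactly
# this kind of dependency between action components is what
# CorrelatedActionsEnv tests.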

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run",
    type=str,
    default="PPO",
    help="The RLlib-registered algorithm to use.")
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.")
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.")
parser.add_argument(
    "--stop-iters",
    type=int,
    default=200,
    help="Number of iterations to train.")
parser.add_argument(
    "--stop-timesteps",
    type=int,
    default=100000,
    help="Number of timesteps to train.")
parser.add_argument(
    "--stop-reward",
    type=float,
    default=200.0,
    help="Reward at which we stop training.")

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(num_cpus=args.num_cpus or None)

    ModelCatalog.register_custom_model(
        "autoregressive_model", TorchAutoregressiveActionModel
        if args.framework == "torch" else AutoregressiveActionModel)
    ModelCatalog.register_custom_action_dist(
        "binary_autoreg_dist", TorchBinaryAutoregressiveDistribution
        if args.framework == "torch" else BinaryAutoregressiveDistribution)
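
    # The strings registered above ("autoregressive_model" and
    # "binary_autoreg_dist") are how the config below refers to the custom
    # classes; RLlib resolves them via the ModelCatalog when building the
    # policy.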
    config = {
        "env": CorrelatedActionsEnv,
        "gamma": 0.5,
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "model": {
            "custom_model": "autoregressive_model",
            "custom_action_dist": "binary_autoreg_dist",
        },
        "framework": args.framework,
    }
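
    # Tune stops a trial as soon as ANY of these criteria is met.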
    stop = {
        "training_iteration": args.stop_iters,
        "timesteps_total": args.stop_timesteps,
        "episode_reward_mean": args.stop_reward,
    }

    results = tune.run(args.run, stop=stop, config=config, verbose=1)

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
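
# Example invocation (flags are the ones defined above; adjust as needed):
#   python autoregressive_action_dist.py --framework torch --stop-reward 150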