"""Example of a custom gym environment and model. Run this for a demo. This example shows: - using a custom environment - using a custom model - using Tune for grid search You can visualize experiment results in ~/ray_results using TensorBoard. """ import argparse import gym from gym.spaces import Discrete, Box import numpy as np import os import ray from ray import tune from ray.tune import grid_search from ray.rllib.models import ModelCatalog from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.test_utils import check_learning_achieved tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() parser = argparse.ArgumentParser() parser.add_argument("--run", type=str, default="PPO") parser.add_argument("--torch", action="store_true") parser.add_argument("--as-test", action="store_true") parser.add_argument("--stop-iters", type=int, default=50) parser.add_argument("--stop-timesteps", type=int, default=100000) parser.add_argument("--stop-reward", type=float, default=0.1) class SimpleCorridor(gym.Env): """Example of a custom env in which you have to walk down a corridor. You can configure the length of the corridor via the env config.""" def __init__(self, config): self.end_pos = config["corridor_length"] self.cur_pos = 0 self.action_space = Discrete(2) self.observation_space = Box( 0.0, self.end_pos, shape=(1, ), dtype=np.float32) def reset(self): self.cur_pos = 0 return [self.cur_pos] def step(self, action): assert action in [0, 1], action if action == 0 and self.cur_pos > 0: self.cur_pos -= 1 elif action == 1: self.cur_pos += 1 done = self.cur_pos >= self.end_pos return [self.cur_pos], 1.0 if done else -0.1, done, {} class CustomModel(TFModelV2): """Example of a keras custom model that just delegates to an fc-net.""" def __init__(self, obs_space, action_space, num_outputs, model_config, name): super(CustomModel, self).__init__(obs_space, action_space, num_outputs, model_config, name) self.model = FullyConnectedNetwork(obs_space, action_space, num_outputs, model_config, name) def forward(self, input_dict, state, seq_lens): return self.model.forward(input_dict, state, seq_lens) def value_function(self): return self.model.value_function() class TorchCustomModel(TorchModelV2, nn.Module): """Example of a PyTorch custom model that just delegates to a fc-net.""" def __init__(self, obs_space, action_space, num_outputs, model_config, name): TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name) nn.Module.__init__(self) self.torch_sub_model = TorchFC(obs_space, action_space, num_outputs, model_config, name) def forward(self, input_dict, state, seq_lens): input_dict["obs"] = input_dict["obs"].float() fc_out, _ = self.torch_sub_model(input_dict, state, seq_lens) return fc_out, [] def value_function(self): return torch.reshape(self.torch_sub_model.value_function(), [-1]) if __name__ == "__main__": args = parser.parse_args() ray.init() # Can also register the env creator function explicitly with: # register_env("corridor", lambda config: SimpleCorridor(config)) ModelCatalog.register_custom_model( "my_model", TorchCustomModel if args.torch else CustomModel) config = { "env": SimpleCorridor, # or "corridor" if registered above "env_config": { "corridor_length": 5, }, # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")), "model": { "custom_model": "my_model", }, "vf_share_layers": True, "lr": grid_search([1e-2, 1e-4, 1e-6]), # try different lrs "num_workers": 1, # parallelism "framework": "torch" if args.torch else "tf", } stop = { "training_iteration": args.stop_iters, "timesteps_total": args.stop_timesteps, "episode_reward_mean": args.stop_reward, } results = tune.run(args.run, config=config, stop=stop) if args.as_test: check_learning_achieved(results, args.stop_reward) ray.shutdown()