"""Example of using training on CartPole.""" import argparse import ray from ray import tune from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel from ray.rllib.contrib.alpha_zero.environments.cartpole import CartPole from ray.rllib.models.catalog import ModelCatalog if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num-workers", default=6, type=int) parser.add_argument("--training-iteration", default=10000, type=int) parser.add_argument("--ray-num-cpus", default=7, type=int) args = parser.parse_args() ray.init(num_cpus=args.ray_num_cpus) ModelCatalog.register_custom_model("dense_model", DenseModel) tune.run( "contrib/AlphaZero", stop={"training_iteration": args.training_iteration}, max_failures=0, config={ "env": CartPole, "num_workers": args.num_workers, "rollout_fragment_length": 50, "train_batch_size": 500, "sgd_minibatch_size": 64, "lr": 1e-4, "num_sgd_iter": 1, "mcts_config": { "puct_coefficient": 1.5, "num_simulations": 100, "temperature": 1.0, "dirichlet_epsilon": 0.20, "dirichlet_noise": 0.03, "argmax_tree_policy": False, "add_dirichlet_noise": True, }, "ranked_rewards": { "enable": True, }, "model": { "custom_model": "dense_model", }, }, )