ray/rllib/contrib/alpha_zero/examples/train_cartpole.py
Eric Liang dd70720578
[rllib] Rename sample_batch_size => rollout_fragment_length (#7503)
* bulk rename

* deprecation warn

* update doc

* update fig

* line length

* rename

* make pytest compatible

* fix test

* fix sys

* rename

* wip

* fix more

* lint

* update svg

* comments

* lint

* fix use of batch steps
2020-03-14 12:05:04 -07:00

49 lines
1.6 KiB
Python

"""Example of using training on CartPole."""
import argparse
import ray
from ray import tune
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel
from ray.rllib.contrib.alpha_zero.environments.cartpole import CartPole
from ray.rllib.models.catalog import ModelCatalog
if __name__ == "__main__":
    # Integer command-line options and their defaults. By default Ray is
    # given one more CPU (7) than the configured number of workers (6).
    cli = argparse.ArgumentParser()
    for flag, default in (
        ("--num-workers", 6),
        ("--training-iteration", 10000),
        ("--ray-num-cpus", 7),
    ):
        cli.add_argument(flag, default=default, type=int)
    args = cli.parse_args()

    # Start Ray with the requested CPU budget, then make DenseModel
    # available under the name referenced by the "model" config below.
    ray.init(num_cpus=args.ray_num_cpus)
    ModelCatalog.register_custom_model("dense_model", DenseModel)

    # Settings passed through under the "mcts_config" key.
    mcts_settings = {
        "puct_coefficient": 1.5,
        "num_simulations": 100,
        "temperature": 1.0,
        "dirichlet_epsilon": 0.20,
        "dirichlet_noise": 0.03,
        "argmax_tree_policy": False,
        "add_dirichlet_noise": True,
    }

    # Full trainer configuration handed to tune.run.
    trainer_config = {
        "env": CartPole,
        "num_workers": args.num_workers,
        "rollout_fragment_length": 50,
        "train_batch_size": 500,
        "sgd_minibatch_size": 64,
        "lr": 1e-4,
        "num_sgd_iter": 1,
        "mcts_config": mcts_settings,
        "ranked_rewards": {"enable": True},
        "model": {"custom_model": "dense_model"},
    }

    # Stop once the requested number of training iterations is reached.
    stop_criteria = {"training_iteration": args.training_iteration}

    tune.run(
        "contrib/AlphaZero",
        stop=stop_criteria,
        max_failures=0,
        config=trainer_config,
    )