mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00

* bulk rename * deprecation warn * update doc * update fig * line length * rename * make pytest compatible * fix test * fi sys * rename * wip * fix more * lint * update svg * comments * lint * fix use of batch steps
49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
"""Example of using training on CartPole."""
|
|
|
|
import argparse

import ray
from ray import tune
from ray.rllib.contrib.alpha_zero.models.custom_torch_models import DenseModel
from ray.rllib.contrib.alpha_zero.environments.cartpole import CartPole
from ray.rllib.models.catalog import ModelCatalog


def _parse_args():
    """Parse the command-line options for this example.

    Returns:
        argparse.Namespace with ``num_workers``, ``training_iteration`` and
        ``ray_num_cpus``. ``ray_num_cpus`` is always an int: when the flag is
        omitted it is derived from ``num_workers``.
    """
    parser = argparse.ArgumentParser(
        description="Train AlphaZero on CartPole with Tune.")
    parser.add_argument(
        "--num-workers",
        default=6,
        type=int,
        help="Value passed as the trainer's num_workers config entry.")
    parser.add_argument(
        "--training-iteration",
        default=10000,
        type=int,
        help="Stop the trial once it reaches this training_iteration.")
    parser.add_argument(
        "--ray-num-cpus",
        default=None,
        type=int,
        help="num_cpus for ray.init(); defaults to --num-workers + 1.")
    args = parser.parse_args()

    # Previously this was a fixed 7, which only matched the default of 6
    # workers. Deriving it keeps the two flags consistent when only
    # --num-workers is changed (presumably one CPU per rollout worker plus
    # one for the driver -- confirm against the trainer's resource request).
    if args.ray_num_cpus is None:
        args.ray_num_cpus = args.num_workers + 1
    return args


def main():
    """Start Ray, register the dense model and run AlphaZero through Tune."""
    args = _parse_args()
    ray.init(num_cpus=args.ray_num_cpus)

    # The config's "model" -> "custom_model" entry refers to this name.
    ModelCatalog.register_custom_model("dense_model", DenseModel)

    tune.run(
        "contrib/AlphaZero",
        stop={"training_iteration": args.training_iteration},
        # Do not retry the trial if it fails.
        max_failures=0,
        config={
            "env": CartPole,
            "num_workers": args.num_workers,
            "rollout_fragment_length": 50,
            "train_batch_size": 500,
            "sgd_minibatch_size": 64,
            "lr": 1e-4,
            "num_sgd_iter": 1,
            "mcts_config": {
                "puct_coefficient": 1.5,
                "num_simulations": 100,
                "temperature": 1.0,
                "dirichlet_epsilon": 0.20,
                "dirichlet_noise": 0.03,
                "argmax_tree_policy": False,
                "add_dirichlet_noise": True,
            },
            "ranked_rewards": {
                "enable": True,
            },
            "model": {
                "custom_model": "dense_model",
            },
        },
    )


if __name__ == "__main__":
    main()
|