
* Unifying the code for PGTrainer/Policy wrt tf vs torch. Adding loss function test cases for the PGAgent (confirm equivalence of tf and torch).
* Fix LINT line-len errors.
* Fix LINT errors.
* Fix `tf_pg_policy` imports (formerly: `pg_policy`).
* Rename tf_pg_... into pg_tf_... following <alg>_<framework>_... convention, where ...=policy/loss/agent/trainer. Retire `PGAgent` class (use PGTrainer instead).
* Move PG test into agents/pg/tests directory. All test cases will be located near the classes that are tested and then built into the Bazel/Travis test suite.
* Moved post_process_advantages into pg.py (from pg_tf_policy.py), b/c the function is not a tf-specific one.
* Fix remaining import errors for agents/pg/...
* Fix circular dependency in pg imports.
* Add pg tests to Jenkins test suite.
67 lines
2 KiB
Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import AsyncGradientsOptimizer

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # Size of rollout batch
    "sample_batch_size": 10,
    # Use PyTorch as framework - no LSTM support
    "use_pytorch": False,
    # GAE(gamma) parameter
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Learning rate schedule
    "lr_schedule": None,
    # Value Function Loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient
    "entropy_coeff": 0.01,
    # Min time per iteration
    "min_iter_time_s": 5,
    # Workers sample async. Note that this increases the effective
    # sample_batch_size by up to 5x due to async buffering of batches.
    "sample_async": True,
})
# __sphinx_doc_end__
# yapf: enable


def get_policy_class(config):
    if config["use_pytorch"]:
        from ray.rllib.agents.a3c.a3c_torch_policy import \
            A3CTorchPolicy
        return A3CTorchPolicy
    else:
        return A3CTFPolicy


def validate_config(config):
    if config["entropy_coeff"] < 0:
        raise DeprecationWarning("entropy_coeff must be >= 0")
    if config["sample_async"] and config["use_pytorch"]:
        raise ValueError(
            "The sample_async option is not supported with use_pytorch: "
            "Multithreading can lead to crashes if used with pytorch.")


def make_async_optimizer(workers, config):
    return AsyncGradientsOptimizer(workers, **config["optimizer"])


A3CTrainer = build_trainer(
    name="A3C",
    default_config=DEFAULT_CONFIG,
    default_policy=A3CTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
    make_policy_optimizer=make_async_optimizer)
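A minimal usage sketch (not part of the file above), assuming an RLlib release from this era and an installed Gym environment such as "CartPole-v0": the trainer produced by build_trainer() is constructed with a partial config dict that overrides DEFAULT_CONFIG entries; num_workers comes from the common trainer config rather than this file and is included only as an illustration.

    import ray
    from ray.rllib.agents.a3c import A3CTrainer

    ray.init()

    # Override a few DEFAULT_CONFIG keys; anything not listed keeps the defaults above.
    config = {
        "sample_batch_size": 20,  # larger rollout batches than the default of 10
        "use_pytorch": False,     # keep the TF policy; True would select A3CTorchPolicy
        "num_workers": 2,         # common-config option (assumed available here)
    }

    trainer = A3CTrainer(config=config, env="CartPole-v0")

    # Each train() call runs one iteration (at least min_iter_time_s seconds)
    # and returns a result dict with metrics such as episode_reward_mean.
    for _ in range(3):
        result = trainer.train()
        print(result["episode_reward_mean"])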