ray/rllib/examples/custom_tf_policy.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.postprocessing import discount
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
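
# try_import_tf() returns the TensorFlow module if it is installed, or None
# otherwise, so importing this example never fails on a missing TF install.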
tf = try_import_tf()
parser = argparse.ArgumentParser()
parser.add_argument("--iters", type=int, default=200)
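

# REINFORCE-style policy gradient loss: the negative mean log-probability of
# the sampled actions, weighted by the discounted returns that
# calculate_advantages() (below) adds to the batch.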
def policy_gradient_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    action_dist = dist_class(logits, model)
    return -tf.reduce_mean(
        action_dist.logp(train_batch["actions"]) * train_batch["returns"])
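

# Trajectory postprocessing: add a "returns" column holding the discounted
# (gamma=0.99) sum of future rewards for each timestep.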
def calculate_advantages(policy,
                         sample_batch,
                         other_agent_batches=None,
                         episode=None):
    sample_batch["returns"] = discount(sample_batch["rewards"], 0.99)
    return sample_batch
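

# build_tf_policy() assembles a policy class from the loss and postprocessing
# functions defined above.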
# <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=policy_gradient_loss,
    postprocess_fn=calculate_advantages,
)
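
# build_trainer() wraps the policy in a Trainer class that can be passed
# directly to tune.run().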
# <class 'ray.rllib.agents.trainer_template.MyCustomTrainer'>
MyTrainer = build_trainer(
    name="MyCustomTrainer",
    default_policy=MyTFPolicy,
)
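
# Example usage (trains on CartPole-v0 for --iters training iterations):
#     python custom_tf_policy.py --iters 50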
if __name__ == "__main__":
    ray.init()
    args = parser.parse_args()
    tune.run(
        MyTrainer,
        stop={"training_iteration": args.iters},
        config={
            "env": "CartPole-v0",
            "num_workers": 2,
        })