"""Example of using rollout worker classes directly to implement training. Instead of using the built-in Trainer classes provided by RLlib, here we define a custom Policy class and manually coordinate distributed sample collection and policy optimization. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import gym import ray from ray import tune from ray.rllib.policy import Policy from ray.rllib.evaluation import RolloutWorker, SampleBatch from ray.rllib.evaluation.metrics import collect_metrics parser = argparse.ArgumentParser() parser.add_argument("--gpu", action="store_true") parser.add_argument("--num-iters", type=int, default=20) parser.add_argument("--num-workers", type=int, default=2) class CustomPolicy(Policy): """Example of a custom policy written from scratch. You might find it more convenient to extend TF/TorchPolicy instead for a real policy. """ def __init__(self, observation_space, action_space, config): Policy.__init__(self, observation_space, action_space, config) # example parameter self.w = 1.0 def compute_actions(self, obs_batch, state_batches, prev_action_batch=None, prev_reward_batch=None, info_batch=None, episodes=None, **kwargs): # return random actions return [self.action_space.sample() for _ in obs_batch], [], {} def learn_on_batch(self, samples): # implement your learning code here return {} def update_some_value(self, w): # can also call other methods on policies self.w = w def get_weights(self): return {"w": self.w} def set_weights(self, weights): self.w = weights["w"] def training_workflow(config, reporter): # Setup policy and policy evaluation actors env = gym.make("CartPole-v0") policy = CustomPolicy(env.observation_space, env.action_space, {}) workers = [ RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"), CustomPolicy) for _ in range(config["num_workers"]) ] for _ in range(config["num_iters"]): # Broadcast weights to the policy evaluation workers weights = ray.put({"default_policy": policy.get_weights()}) for w in workers: w.set_weights.remote(weights) # Gather a batch of samples T1 = SampleBatch.concat_samples( ray.get([w.sample.remote() for w in workers])) # Update the remote policy replicas and gather another batch of samples new_value = policy.w * 2.0 for w in workers: w.for_policy.remote(lambda p: p.update_some_value(new_value)) # Gather another batch of samples T2 = SampleBatch.concat_samples( ray.get([w.sample.remote() for w in workers])) # Improve the policy using the T1 batch policy.learn_on_batch(T1) # Do some arbitrary updates based on the T2 batch policy.update_some_value(sum(T2["rewards"])) reporter(**collect_metrics(remote_workers=workers)) if __name__ == "__main__": args = parser.parse_args() ray.init() tune.run( training_workflow, resources_per_trial={ "gpu": 1 if args.gpu else 0, "cpu": 1, "extra_cpu": args.num_workers, }, config={ "num_workers": args.num_workers, "num_iters": args.num_iters, }, )