"""Example of using rollout worker classes directly to implement training.
|
|
|
|
Instead of using the built-in Trainer classes provided by RLlib, here we define
|
|
a custom Policy class and manually coordinate distributed sample
|
|
collection and policy optimization.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import argparse
|
|
import gym
|
|
|
|
import ray
|
|
from ray import tune
|
|
from ray.rllib.policy import Policy
|
|
from ray.rllib.evaluation import RolloutWorker, SampleBatch
|
|
from ray.rllib.evaluation.metrics import collect_metrics
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--gpu", action="store_true")
|
|
parser.add_argument("--num-iters", type=int, default=20)
|
|
parser.add_argument("--num-workers", type=int, default=2)
|
|
|
|
|
|


class CustomPolicy(Policy):
    """Example of a custom policy written from scratch.

    You might find it more convenient to extend TF/TorchPolicy instead
    for a real policy.
    """

    def __init__(self, observation_space, action_space, config):
        Policy.__init__(self, observation_space, action_space, config)
        # example parameter
        self.w = 1.0

    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # Return random actions, no RNN state outputs, and no extra fetches.
        return [self.action_space.sample() for _ in obs_batch], [], {}

    def learn_on_batch(self, samples):
        # implement your learning code here
        return {}

    def update_some_value(self, w):
        # can also call other methods on policies
        self.w = w

    def get_weights(self):
        return {"w": self.w}

    def set_weights(self, weights):
        self.w = weights["w"]
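

# A minimal standalone sketch of exercising CustomPolicy outside of the
# RolloutWorker-based workflow below (illustration only, not executed here):
#
#     env = gym.make("CartPole-v0")
#     policy = CustomPolicy(env.observation_space, env.action_space, {})
#     actions, _, _ = policy.compute_actions([env.reset()], state_batches=[])
#     policy.set_weights({"w": 2.0})
#     assert policy.get_weights() == {"w": 2.0}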


def training_workflow(config, reporter):
    # Set up the local policy and the remote policy evaluation actors.
    # RolloutWorker.as_remote() wraps RolloutWorker as a Ray actor class; each
    # replica builds its own env and its own copy of CustomPolicy.
    env = gym.make("CartPole-v0")
    policy = CustomPolicy(env.observation_space, env.action_space, {})
    workers = [
        RolloutWorker.as_remote().remote(lambda c: gym.make("CartPole-v0"),
                                         CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers. ray.put stores
        # the weights once in the object store so every worker fetches the
        # same copy instead of re-serializing them per worker.
        weights = ray.put({"default_policy": policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(sum(T2["rewards"]))

        reporter(**collect_metrics(remote_workers=workers))
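

# For quick local testing, training_workflow can also be called directly,
# bypassing Tune (a rough sketch; any callable accepting keyword arguments
# works as the reporter):
#
#     ray.init()
#     training_workflow(
#         {"num_workers": 2, "num_iters": 1}, lambda **metrics: print(metrics))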


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    tune.run(
        training_workflow,
        resources_per_trial={
            "gpu": 1 if args.gpu else 0,
            "cpu": 1,
            "extra_cpu": args.num_workers,
        },
        config={
            "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
    )
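
# Example command-line invocation (the file name here is assumed for
# illustration):
#
#     python rollout_worker_custom_workflow.py --num-workers=2 --num-iters=20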