# ray/rllib/examples/rollout_worker_custom_workflow.py
# (Python, 123 lines, 3.8 KiB)

"""Example of using rollout worker classes directly to implement training.
Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom Policy class and manually coordinate distributed sample
collection and policy optimization.
"""
import argparse
import gym
import numpy as np
import ray
from ray import tune
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.tune.utils.placement_groups import PlacementGroupFactory
# Command-line options for this example script.
parser = argparse.ArgumentParser(
    description="Train a custom policy by driving RolloutWorkers directly.")
parser.add_argument(
    "--gpu",
    action="store_true",
    help="Request one GPU in the trial's first (driver) resource bundle.")
parser.add_argument(
    "--num-iters",
    type=int,
    default=20,
    help="Number of training iterations to run.")
parser.add_argument(
    "--num-workers",
    type=int,
    default=2,
    help="Number of remote RolloutWorker actors used for sampling.")
parser.add_argument(
    "--num-cpus",
    type=int,
    default=0,
    help="num_cpus passed to ray.init(); 0 passes None (Ray's default).")
class CustomPolicy(Policy):
    """Example of a custom policy written from scratch.

    You might find it more convenient to extend TF/TorchPolicy instead
    for a real policy.

    The policy's only state is the scalar example parameter ``w``, which is
    also everything exchanged through ``get_weights``/``set_weights``.
    Actions are sampled uniformly from ``action_space`` and
    ``learn_on_batch`` is a no-op placeholder.
    """

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        # This policy is plain Python/numpy: neither TF nor Torch.
        self.config["framework"] = None
        # example parameter
        self.w = 1.0

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample one random action per entry of ``obs_batch``.

        Every input other than the length of ``obs_batch`` is ignored.

        Returns:
            A 3-tuple ``(actions, [], {})`` where ``actions`` is a numpy
            array of sampled actions. The empty list/dict presumably fill
            RLlib's state-out and extra-info slots (no RNN state, no extra
            outputs) -- NOTE(review): confirm against
            ``Policy.compute_actions``.
        """
        # return random actions
        actions = np.array([self.action_space.sample() for _ in obs_batch])
        return actions, [], {}

    def learn_on_batch(self, samples):
        """Placeholder learning step; returns an empty stats dict."""
        # implement your learning code here
        return {}

    def update_some_value(self, w):
        """Overwrite the example parameter ``w``.

        Shows that arbitrary custom methods can be called on (remote)
        policies, e.g. through ``RolloutWorker.for_policy``.
        """
        # can also call other methods on policies
        self.w = w

    def get_weights(self):
        """Return the policy's weights: a dict holding the parameter ``w``."""
        return {"w": self.w}

    def set_weights(self, weights):
        """Restore ``w`` from a dict shaped like ``get_weights()`` output."""
        self.w = weights["w"]
def training_workflow(config, reporter):
    """Run a hand-written sample/optimize loop over remote RolloutWorkers.

    Args:
        config: Dict with keys ``"num_workers"`` (number of remote
            RolloutWorker actors to create) and ``"num_iters"`` (number of
            training iterations to run).
        reporter: Callable receiving each iteration's metrics as keyword
            arguments (this function is handed to ``tune.run`` as a
            function trainable).
    """
    # Setup policy and policy evaluation actors. The local env is only
    # needed to read the observation/action spaces, so close it right away.
    env = gym.make("CartPole-v0")
    try:
        policy = CustomPolicy(env.observation_space, env.action_space, {})
    finally:
        env.close()
    workers = [
        RolloutWorker.as_remote().remote(
            env_creator=lambda c: gym.make("CartPole-v0"), policy=CustomPolicy)
        for _ in range(config["num_workers"])
    ]

    for _ in range(config["num_iters"]):
        # Broadcast weights to the policy evaluation workers. ray.put is
        # called once so every worker receives the same object reference.
        weights = ray.put({DEFAULT_POLICY_ID: policy.get_weights()})
        for w in workers:
            w.set_weights.remote(weights)

        # Gather a batch of samples.
        # NOTE(review): the updates above are not awaited; this relies on
        # Ray running a single caller's actor tasks in submission order.
        T1 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Update the remote policy replicas and gather another batch of
        # samples. `new_value` is captured by the lambda before it is shipped.
        new_value = policy.w * 2.0
        for w in workers:
            w.for_policy.remote(lambda p: p.update_some_value(new_value))

        # Gather another batch of samples
        T2 = SampleBatch.concat_samples(
            ray.get([w.sample.remote() for w in workers]))

        # Improve the policy using the T1 batch
        policy.learn_on_batch(T1)

        # Do some arbitrary updates based on the T2 batch
        policy.update_some_value(float(np.sum(T2["rewards"])))

        reporter(**collect_metrics(remote_workers=workers))
if __name__ == "__main__":
    args = parser.parse_args()
    # --num-cpus 0 becomes None, leaving the CPU count to ray.init().
    ray.init(num_cpus=args.num_cpus or None)

    # First bundle: the trial itself (runs training_workflow and its local
    # CustomPolicy), optionally with a GPU. Then one single-CPU bundle per
    # remote RolloutWorker, presumably consumed by those actors.
    driver_bundle = {"CPU": 1, "GPU": 1 if args.gpu else 0}
    worker_bundles = [{"CPU": 1}] * args.num_workers
    tune.run(
        training_workflow,
        resources_per_trial=PlacementGroupFactory(
            [driver_bundle] + worker_bundles),
        config={
            "num_workers": args.num_workers,
            "num_iters": args.num_iters,
        },
        verbose=1,
    )