ray/examples/a3c/driver.py
Richard Liaw b463d9e5c7 Initial A3C Example - PongDeterministic-v3 (#331)
* Initializing A3C code

* Modifications for Ray usage

* cleanup

* removing universe dependency

* fixes (not yet working)

* hack

* documentation

* Cleanup

* Preliminary Portion

Make sure to change when merging

* RL part

* Cleaning up Driver and Worker code

* Updating driver code

* instructions...

* fixed

* Minor changes.

* Fixing cmake issues

* ray instruction

* updating port to new universe

* Fix for env.configure

* redundant commands

* Revert scipy.misc -> cv2 and raise exception for wrong gym version.
2017-03-11 00:57:53 -08:00


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
import numpy as np

from runner import RunnerThread, process_rollout
from LSTM import LSTMPolicy
import tensorflow as tf
import six.moves.queue as queue
import gym
import sys
from datetime import datetime, timedelta
from misc import timestamp, Profiler
from envs import create_env

@ray.actor
class Runner(object):
    """Actor object to start running simulation on workers.

    Gradient computation is also executed from this object."""

    def __init__(self, env_name, actor_id, logdir="tmp/", start=True):
        env = create_env(env_name, None, None)
        self.id = actor_id
        num_actions = env.action_space.n
        self.policy = LSTMPolicy(env.observation_space.shape, num_actions, actor_id)
        self.runner = RunnerThread(env, self.policy, 20)
        self.env = env
        self.logdir = logdir
        if start:
            self.start()

    def pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner.

        Partial rollouts are concatenated until a terminal state is reached
        or the queue is empty."""
        rollout = self.runner.queue.get(timeout=600.0)
        while not rollout.terminal:
            try:
                rollout.extend(self.runner.queue.get_nowait())
            except queue.Empty:
                break
        return rollout

    def start(self):
        summary_writer = tf.summary.FileWriter(self.logdir + "test_1")
        self.summary_writer = summary_writer
        self.runner.start_runner(self.policy.sess, summary_writer)

    def compute_gradient(self, params):
        """Set the local weights, pull one rollout, and return its gradient."""
        self.policy.set_weights(params)
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
        gradient = self.policy.get_gradients(batch)
        info = {"id": self.id,
                "size": len(batch.a)}
        return gradient, info
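
# Note: because Runner is a Ray actor, a call such as
# agent.compute_gradient(parameters) in train() below returns an object ID
# immediately; the gradient itself is computed asynchronously on the worker
# and is fetched with ray.get() once ray.wait() reports it as ready.
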
def train(num_workers, env_name="PongDeterministic-v3"):
    env = create_env(env_name, None, None)
    policy = LSTMPolicy(env.observation_space.shape, env.action_space.n, 0)
    agents = [Runner(env_name, i) for i in range(num_workers)]
    parameters = policy.get_weights()
    gradient_list = [agent.compute_gradient(parameters) for agent in agents]
    steps = 0
    obs = 0
    # Asynchronous update loop: apply whichever gradient finishes first, then
    # hand the refreshed weights back to the actor that produced it. The loop
    # runs indefinitely, so the final return statement is never reached.
    while True:
        done_id, gradient_list = ray.wait(gradient_list)
        gradient, info = ray.get(done_id)[0]
        policy.model_update(gradient)
        parameters = policy.get_weights()
        steps += 1
        obs += info["size"]
        gradient_list.extend([agents[info["id"]].compute_gradient(parameters)])
    return policy

if __name__ == '__main__':
    if gym.__version__[:3] == '0.8':
        raise Exception("This example currently does not work with gym==0.8.0. "
                        "Please downgrade to gym==0.7.4.")
    NW = int(sys.argv[1])
    ray.init(num_workers=NW, num_cpus=NW)
    train(NW)
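
The file above targets Ray's early actor API (@ray.actor and implicit remote method calls). For readers on a current Ray release, the following is a minimal, self-contained sketch of the same wait-for-the-fastest-gradient driver pattern written against the @ray.remote / .remote() API, with a random numpy array standing in for the LSTMPolicy gradient; GradientWorker and its names are illustrative and not part of the original example.

import numpy as np
import ray

ray.init()


@ray.remote
class GradientWorker(object):
    """Stand-in for Runner: produces a fake gradient for the given weights."""

    def __init__(self, worker_id):
        self.worker_id = worker_id

    def compute_gradient(self, weights):
        # Pretend to run a rollout and differentiate a loss.
        gradient = np.random.randn(*weights.shape)
        return gradient, self.worker_id


weights = np.zeros(10)
workers = [GradientWorker.remote(i) for i in range(4)]
pending = [w.compute_gradient.remote(weights) for w in workers]

for step in range(100):
    # Block until at least one gradient is ready, apply it, then immediately
    # resubmit work to the actor that produced it using the fresh weights.
    [ready], pending = ray.wait(pending)
    gradient, worker_id = ray.get(ready)
    weights -= 0.01 * gradient
    pending.append(workers[worker_id].compute_gradient.remote(weights))

As in train() above, only one gradient is applied per iteration, so a slow actor never blocks the others; it simply contributes a gradient computed against slightly stale weights, which is the asynchrony that A3C relies on.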