mirror of
https://github.com/vale981/ray
synced 2025-03-09 12:56:46 -04:00
82 lines
2.8 KiB
Python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import ray
import numpy as np
from runner import RunnerThread, process_rollout
from LSTM import LSTMPolicy
import tensorflow as tf
import six.moves.queue as queue
import gym
import sys
from datetime import datetime, timedelta
from misc import timestamp, Profiler
from envs import create_env

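# Runner is a Ray actor: each instance lives in its own worker process, its
# methods execute remotely, and calling them returns object IDs (futures)
# that the driver fetches with ray.get or waits on with ray.wait.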
@ray.actor
class Runner(object):
    """Actor object to start running simulation on workers.

    Gradient computation is also executed from this object."""

    def __init__(self, env_name, actor_id, logdir="tmp/", start=True):
        env = create_env(env_name, None, None)
        self.id = actor_id
        num_actions = env.action_space.n
        self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
                                 actor_id)
        self.runner = RunnerThread(env, self.policy, 20)
        self.env = env
        self.logdir = logdir
        if start:
            self.start()

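    # The RunnerThread fills self.runner.queue from a background thread;
    # partial rollouts are concatenated until a terminal state is reached or
    # the queue is empty.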
||
|
    def pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner."""
        rollout = self.runner.queue.get(timeout=600.0)
        while not rollout.terminal:
            try:
                rollout.extend(self.runner.queue.get_nowait())
            except queue.Empty:
                break
        return rollout

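    # Sets up the TensorFlow summary writer and launches the background
    # thread that steps the environment and produces rollouts.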
||
|
    def start(self):
        summary_writer = tf.summary.FileWriter(self.logdir + "test_1")
        self.summary_writer = summary_writer
        self.runner.start_runner(self.policy.sess, summary_writer)

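    # One worker step of A3C: sync to the latest driver weights, pull a
    # rollout, convert it into a training batch via process_rollout
    # (discounting with gamma and, presumably, generalized advantage
    # estimation with lambda_), and return the gradients plus metadata the
    # driver uses to resubmit work to this actor.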
||
|
    def compute_gradient(self, params):
        self.policy.set_weights(params)
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
        gradient = self.policy.get_gradients(batch)
        info = {"id": self.id,
                "size": len(batch.a)}
        return gradient, info


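# Driver-side training loop: one compute_gradient task per worker is kept in
# flight. ray.wait returns as soon as any task finishes; the driver applies
# that gradient, refreshes the broadcast weights, and resubmits a new task to
# the same worker, which is the asynchronous update scheme of A3C.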
||
|
def train(num_workers, env_name="PongDeterministic-v3"):
    env = create_env(env_name, None, None)
    policy = LSTMPolicy(env.observation_space.shape, env.action_space.n, 0)
    agents = [Runner(env_name, i) for i in range(num_workers)]
    parameters = policy.get_weights()
    gradient_list = [agent.compute_gradient(parameters) for agent in agents]
    steps = 0
    obs = 0
    while True:
        done_id, gradient_list = ray.wait(gradient_list)
        gradient, info = ray.get(done_id)[0]
        policy.model_update(gradient)
        parameters = policy.get_weights()
        steps += 1
        obs += info["size"]
        gradient_list.extend([agents[info["id"]].compute_gradient(parameters)])
    return policy


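# Run with the number of workers as the only command-line argument, for
# example `python <this script> 4`.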
||
|
if __name__ == '__main__':
    if gym.__version__[:3] == '0.8':
        raise Exception("This example currently does not work with gym==0.8.0. "
                        "Please downgrade to gym==0.7.4.")
    NW = int(sys.argv[1])
    ray.init(num_workers=NW, num_cpus=NW)
    train(NW)