from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pickle
import os
import time
import ray
from ray.rllib.agents.agent import Agent, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils import FilterManager, merge_dicts
from ray.tune.trial import Resources
DEFAULT_CONFIG = with_common_config({
    # Size of rollout batch
    "sample_batch_size": 10,
    # Use PyTorch as backend - no LSTM support
    "use_pytorch": False,
    # GAE(lambda) parameter
    "lambda": 1.0,
    # Max global norm for each gradient calculated by worker
    "grad_clip": 40.0,
    # Learning rate
    "lr": 0.0001,
    # Value Function Loss coefficient
    "vf_loss_coeff": 0.5,
    # Entropy coefficient (negative because the entropy term is added
    # directly to the loss, so this rewards higher-entropy policies)
    "entropy_coeff": -0.01,
    # Whether to place workers on GPUs
    "use_gpu_for_workers": False,
    # Whether to emit extra summary stats
    "summarize": False,
    # Min time per iteration
    "min_iter_time_s": 5,
    # Workers sample async. Note that this increases the effective
    # sample_batch_size by up to 5x due to async buffering of batches.
    "sample_async": True,
    # Model and preprocessor options
    "model": {
        # Use LSTM model. Requires TF.
        "use_lstm": False,
        # Max seq length for LSTM training.
        "max_seq_len": 20,
        # (Image statespace) - Converts image to grayscale (a single channel)
        "grayscale": True,
        # (Image statespace) - Rescales each pixel to [-1, 1] instead of [0, 1]
        "zero_mean": False,
        # (Image statespace) - Downscales image to (dim, dim, C)
        "dim": 84,
        # (Image statespace) - Converts image shape to (C, dim, dim)
        "channel_major": False,
    },
    # Configure TF for single-process operation
    "tf_session_args": {
        "intra_op_parallelism_threads": 1,
        "inter_op_parallelism_threads": 1,
        "gpu_options": {
            "allow_growth": True,
        },
    },
})
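
# The defaults above are merged with any user-supplied overrides when the
# agent is constructed. A minimal sketch (the env id, worker count, and
# override values below are illustrative only, not defaults of this module):
#
#     config = {
#         "num_workers": 2,
#         "sample_batch_size": 20,
#         "model": {"use_lstm": True},
#     }
#     agent = A3CAgent(config=config, env="CartPole-v0")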


class A3CAgent(Agent):
    """A3C implementations in TensorFlow and PyTorch."""

    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        """Return the CPU/GPU resources to reserve for this agent's trial."""
        cf = merge_dicts(cls._default_config, config)
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        """Build local and remote evaluators and the gradients optimizer."""
        if self.config["use_pytorch"]:
            from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
                A3CTorchPolicyGraph
            policy_cls = A3CTorchPolicyGraph
        else:
            from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
            policy_cls = A3CPolicyGraph

        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, policy_cls)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, policy_cls, self.config["num_workers"],
            {"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0})
        self.optimizer = self._make_optimizer()

    def _make_optimizer(self):
        return AsyncGradientsOptimizer(self.local_evaluator,
                                       self.remote_evaluators,
                                       self.config["optimizer"])
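
    # _make_optimizer is a separate method, which lets subclasses substitute a
    # different policy optimizer. A hypothetical sketch (SyncSamplesOptimizer
    # is another optimizer exported by ray.rllib.optimizers; pairing it with
    # A3C here is purely illustrative):
    #
    #     class SyncA3CAgent(A3CAgent):
    #         def _make_optimizer(self):
    #             from ray.rllib.optimizers import SyncSamplesOptimizer
    #             return SyncSamplesOptimizer(self.local_evaluator,
    #                                         self.remote_evaluators,
    #                                         self.config["optimizer"])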

    def _train(self):
        """Run optimizer steps for min_iter_time_s and collect metrics."""
        prev_steps = self.optimizer.num_steps_sampled
        start = time.time()
        while time.time() - start < self.config["min_iter_time_s"]:
            self.optimizer.step()
            FilterManager.synchronize(self.local_evaluator.filters,
                                      self.remote_evaluators)
        result = self.optimizer.collect_metrics()
        result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                      prev_steps)
        return result

    def _stop(self):
        # Workaround for https://github.com/ray-project/ray/issues/1516:
        # explicitly terminate the remote evaluator actors.
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        """Checkpoint local and remote evaluator state to checkpoint_dir."""
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Restore local and remote evaluator state from a checkpoint."""
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])
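

# Minimal end-to-end usage sketch. It assumes the Agent(config=..., env=...)
# constructor of this RLlib version; the environment id, worker count, and
# iteration count below are illustrative only.
if __name__ == "__main__":
    ray.init()
    agent = A3CAgent(config={"num_workers": 2}, env="CartPole-v0")
    for i in range(3):
        # _train() above returns a dict of metrics from the optimizer.
        result = agent.train()
        print("iteration", i, "episode_reward_mean",
              result.get("episode_reward_mean"))
    # The Trainable API wraps the _save()/_restore() methods defined above.
    checkpoint = agent.save()
    agent.restore(checkpoint)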