[rllib] Reduce concat memory usage, allow object store memory to be specified in init (#1529)

* c

* stop agents

* comment

* Sat Feb 10 02:33:30 PST 2018

* Sat Feb 10 02:33:39 PST 2018

* Update sample_batch.py

* Sun Feb 11 14:38:55 PST 2018

* add ppo config warn
Eric Liang 2018-02-11 19:14:51 -08:00 committed by GitHub
parent b6a06b81ed
commit 7e998db656
8 changed files with 63 additions and 7 deletions

View file

@@ -114,6 +114,11 @@ class A3CAgent(Agent):
 
         return result
 
+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
             checkpoint_dir, "checkpoint-{}".format(self.iteration))

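The same _stop workaround recurs in the DQN, ES, and PPO diffs below: each remote actor is told to terminate itself explicitly rather than waiting on garbage collection (the subject of issue #1516). A minimal standalone sketch of the pattern, where MyEvaluator is a hypothetical stand-in for a remote evaluator class, and __ray_terminate__ / _ray_actor_id are the internal Ray APIs of this era taken from the diff above:

    import ray

    ray.init()

    @ray.remote
    class MyEvaluator(object):  # hypothetical actor class, for illustration
        def sample(self):
            return 42

    evaluators = [MyEvaluator.remote() for _ in range(4)]

    # Explicitly ask each actor process to exit, as _stop() does above.
    for ev in evaluators:
        ev.__ray_terminate__.remote(ev._ray_actor_id.id())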
View file

@@ -218,6 +218,11 @@ class DQNAgent(Agent):
         else:
             self.local_evaluator.sample(no_replay=True)
 
+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,

View file

@@ -300,6 +300,11 @@ class ESAgent(Agent):
 
         return result
 
+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for w in self.workers:
+            w.__ray_terminate__.remote(w._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
             checkpoint_dir, "checkpoint-{}".format(self.iteration))

View file

@@ -20,7 +20,6 @@ class FullyConnectedNetwork(Model):
             activation = tf.nn.tanh
         elif fcnet_activation == "relu":
             activation = tf.nn.relu
-        print("Constructing fcnet {} {}".format(hiddens, activation))
 
         with tf.name_scope("fc_net"):
             i = 1

View file

@@ -2,11 +2,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from functools import reduce
+import numpy as np
+
+
+def arrayify(s):
+    if type(s) in [int, float, str, np.ndarray]:
+        return s
+    elif type(s) is list:
+        # recursive call to convert LazyFrames to arrays
+        return np.array([arrayify(x) for x in s])
+    else:
+        return np.array(s)
 
 
 class SampleBatch(object):
     """Wrapper around a dictionary with string keys and array-like values.
@@ -27,7 +35,10 @@ class SampleBatch(object):
 
     @staticmethod
     def concat_samples(samples):
-        return reduce(lambda a, b: a.concat(b), samples)
+        out = {}
+        for k in samples[0].data.keys():
+            out[k] = np.concatenate([arrayify(s.data[k]) for s in samples])
+        return SampleBatch(out)
 
     def concat(self, other):
         """Returns a new SampleBatch with each data column concatenated.

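The change above is where the concat memory reduction of the commit title happens: the old reduce-based concat_samples materialized a fresh intermediate array at every pairwise step, re-copying everything accumulated so far, while the new version calls np.concatenate once per column and allocates each output buffer exactly once. A small self-contained illustration of the two access patterns, in plain NumPy rather than the SampleBatch API:

    import numpy as np
    from functools import reduce

    batches = [np.arange(1000) for _ in range(100)]

    # Old shape: each reduction step copies the full accumulated prefix
    # into a new array, so total allocation grows quadratically.
    pairwise = reduce(lambda a, b: np.concatenate([a, b]), batches)

    # New shape: one call sizes and fills the output buffer once.
    single = np.concatenate(batches)

    assert (pairwise == single).all()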
View file

@@ -116,6 +116,14 @@ class PPOAgent(Agent):
         config = self.config
         model = self.local_evaluator
 
+        if (config["num_workers"] * config["min_steps_per_task"] >
+                config["timesteps_per_batch"]):
+            print(
+                "WARNING: num_workers * min_steps_per_task > "
+                "timesteps_per_batch. This means that the output of some "
+                "tasks will be wasted. Consider decreasing "
+                "min_steps_per_task or increasing timesteps_per_batch.")
+
         print("===> iteration", self.iteration)
 
         iter_start = time.time()
@@ -244,6 +252,11 @@ class PPOAgent(Agent):
 
         return result
 
+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,

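A hypothetical worked example of the condition this warning guards (the numbers are illustrative, not RLlib defaults): with num_workers=8 and min_steps_per_task=100, every sampling round yields at least 800 timesteps, so timesteps_per_batch=512 would discard at least 288 collected steps per round.

    # Illustrative values only, not defaults.
    num_workers = 8
    min_steps_per_task = 100
    timesteps_per_batch = 512

    produced = num_workers * min_steps_per_task  # >= 800 steps per round
    wasted = produced - timesteps_per_batch      # >= 288 steps discarded
    assert produced > timesteps_per_batch        # triggers the warning above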
View file

@@ -4,9 +4,11 @@ from __future__ import print_function
 
 import unittest
 
+import numpy as np
+
 import ray
 from ray.rllib.test.mock_evaluator import _MockEvaluator
-from ray.rllib.optimizers import AsyncOptimizer
+from ray.rllib.optimizers import AsyncOptimizer, SampleBatch
 
 
 class AsyncOptimizerTest(unittest.TestCase):
@@ -25,5 +27,18 @@ class AsyncOptimizerTest(unittest.TestCase):
         self.assertTrue(all(local.get_weights() == 0))
 
 
+class SampleBatchTest(unittest.TestCase):
+    def testConcat(self):
+        b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
+        b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
+        b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
+        b12 = b1.concat(b2)
+        self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
+        self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
+        b = SampleBatch.concat_samples([b1, b2, b3])
+        self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
+        self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
+
+
 if __name__ == '__main__':
     unittest.main(verbosity=2)

View file

@@ -1390,7 +1390,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
          num_cpus=None, num_gpus=None, resources=None,
          num_custom_resource=None, num_redis_shards=None,
          redis_max_clients=None, plasma_directory=None,
-         huge_pages=False, include_webui=True):
+         huge_pages=False, include_webui=True, object_store_memory=None):
     """Connect to an existing Ray cluster or start one and connect to it.
 
     This method handles two cases. Either a Ray cluster already exists and we
@@ -1430,6 +1430,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
             Store with hugetlbfs support. Requires plasma_directory.
         include_webui: Boolean flag indicating whether to start the web
             UI, which is a Jupyter notebook.
+        object_store_memory: The amount of memory (in bytes) to start the
+            object store with.
 
     Returns:
         Address information about the started processes.
@@ -1454,7 +1456,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
         redis_max_clients=redis_max_clients,
         plasma_directory=plasma_directory,
         huge_pages=huge_pages,
-        include_webui=include_webui)
+        include_webui=include_webui,
+        object_store_memory=object_store_memory)
 
 
 def cleanup(worker=global_worker):
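A usage sketch for the parameter added above; the 1 GB figure is an arbitrary example value, and per the new docstring the argument is given in bytes:

    import ray

    # Start a local Ray instance with a ~1 GB object store.
    ray.init(object_store_memory=10**9)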