Mirror of https://github.com/vale981/ray, synced 2025-04-23 06:25:52 -04:00
[rllib] Reduce concat memory usage, allow object store memory to be specified in init (#1529)
* c
* stop agents
* comment
* Sat Feb 10 02:33:30 PST 2018
* Sat Feb 10 02:33:39 PST 2018
* Update sample_batch.py
* Sun Feb 11 14:38:55 PST 2018
* add ppo config warn
parent b6a06b81ed
commit 7e998db656

8 changed files with 63 additions and 7 deletions
@@ -114,6 +114,11 @@ class A3CAgent(Agent):

         return result

+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
             checkpoint_dir, "checkpoint-{}".format(self.iteration))
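All four agents touched by this commit (A3C, DQN, ES, PPO) gain the same `_stop` hook: explicitly terminate each remote evaluator actor so its process and resources are released when the agent shuts down, as a workaround for ray-project/ray#1516. Below is a minimal standalone sketch of the pattern, assuming the Ray 0.3-era private actor internals used in the diff (`__ray_terminate__` and `_ray_actor_id`); the `Evaluator` actor is hypothetical, and later Ray versions expose `ray.kill(actor)` for this instead:

    import ray

    ray.init()


    @ray.remote
    class Evaluator(object):
        # Hypothetical stand-in for an rllib remote evaluator.
        def sample(self):
            return 42


    ev = Evaluator.remote()
    print(ray.get(ev.sample.remote()))

    # Mirror of the _stop() workaround: ask the actor to terminate itself.
    # NOTE: __ray_terminate__/_ray_actor_id are Ray 0.3-era private APIs.
    ev.__ray_terminate__.remote(ev._ray_actor_id.id())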
@@ -218,6 +218,11 @@ class DQNAgent(Agent):
         else:
             self.local_evaluator.sample(no_replay=True)

+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
@@ -300,6 +300,11 @@ class ESAgent(Agent):

         return result

+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for w in self.workers:
+            w.__ray_terminate__.remote(w._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = os.path.join(
             checkpoint_dir, "checkpoint-{}".format(self.iteration))
@@ -20,7 +20,6 @@ class FullyConnectedNetwork(Model):
             activation = tf.nn.tanh
         elif fcnet_activation == "relu":
             activation = tf.nn.relu
-        print("Constructing fcnet {} {}".format(hiddens, activation))

         with tf.name_scope("fc_net"):
             i = 1
@@ -2,11 +2,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from functools import reduce
-
 import numpy as np


+def arrayify(s):
+    if type(s) in [int, float, str, np.ndarray]:
+        return s
+    elif type(s) is list:
+        # recursive call to convert LazyFrames to arrays
+        return np.array([arrayify(x) for x in s])
+    else:
+        return np.array(s)
+
+
 class SampleBatch(object):
     """Wrapper around a dictionary with string keys and array-like values.
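`arrayify` normalizes a SampleBatch column before concatenation: scalars, strings, and ndarrays pass through unchanged, while lists are converted element by element and stacked into one ndarray, which handles columns holding wrapper objects such as gym's LazyFrames. A small sketch of the behavior; `FakeLazyFrame` is a hypothetical array-convertible stand-in, and `arrayify` is copied from the diff above:

    import numpy as np


    def arrayify(s):
        # copied from sample_batch.py as added in this commit
        if type(s) in [int, float, str, np.ndarray]:
            return s
        elif type(s) is list:
            return np.array([arrayify(x) for x in s])
        else:
            return np.array(s)


    class FakeLazyFrame(object):
        # Hypothetical stand-in for gym's LazyFrames: anything np.array
        # can consume via the __array__ protocol.
        def __init__(self, frames):
            self.frames = frames

        def __array__(self, dtype=None):
            return np.asarray(self.frames, dtype=dtype)


    assert arrayify(3) == 3  # scalars pass through untouched
    out = arrayify([FakeLazyFrame([1, 2]), FakeLazyFrame([3, 4])])
    assert out.shape == (2, 2)  # list of frames -> one stacked ndarray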
@@ -27,7 +35,10 @@ class SampleBatch(object):

     @staticmethod
     def concat_samples(samples):
-        return reduce(lambda a, b: a.concat(b), samples)
+        out = {}
+        for k in samples[0].data.keys():
+            out[k] = np.concatenate([arrayify(s.data[k]) for s in samples])
+        return SampleBatch(out)

     def concat(self, other):
         """Returns a new SampleBatch with each data column concatenated.
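This rewrite is the memory reduction named in the commit title. The old reduce builds the result pairwise, so with n batches the data from the first batch is recopied on every step, roughly n²/2 element copies plus a chain of intermediate SampleBatch objects; a single np.concatenate per column allocates once and copies each element exactly once. A sketch of the difference on raw arrays:

    from functools import reduce

    import numpy as np

    batches = [np.ones(1000) for _ in range(100)]

    # Old approach: pairwise concatenation; each step reallocates and
    # recopies everything accumulated so far (~n**2/2 element copies).
    pairwise = reduce(lambda a, b: np.concatenate([a, b]), batches)

    # New approach: one allocation, each element copied exactly once.
    single = np.concatenate(batches)

    assert np.array_equal(pairwise, single)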
@@ -116,6 +116,14 @@ class PPOAgent(Agent):
         config = self.config
         model = self.local_evaluator

+        if (config["num_workers"] * config["min_steps_per_task"] >
+                config["timesteps_per_batch"]):
+            print(
+                "WARNING: num_workers * min_steps_per_task > "
+                "timesteps_per_batch. This means that the output of some "
+                "tasks will be wasted. Consider decreasing "
+                "min_steps_per_task or increasing timesteps_per_batch.")
+
         print("===> iteration", self.iteration)

         iter_start = time.time()
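The new warning is plain arithmetic: each of the num_workers sampling tasks returns at least min_steps_per_task steps, so whenever num_workers * min_steps_per_task exceeds timesteps_per_batch, some of the sampled steps cannot fit in the training batch and are thrown away. A worked example with hypothetical config values:

    # Hypothetical values, for illustration only.
    config = {
        "num_workers": 8,
        "min_steps_per_task": 200,
        "timesteps_per_batch": 1000,
    }

    # 8 workers * 200 steps = 1600 steps sampled per round, but only
    # 1000 are consumed, so at least 600 steps of work are wasted.
    sampled = config["num_workers"] * config["min_steps_per_task"]
    if sampled > config["timesteps_per_batch"]:
        print("WARNING: sampling %d steps but using only %d"
              % (sampled, config["timesteps_per_batch"]))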
@@ -244,6 +252,11 @@ class PPOAgent(Agent):

         return result

+    def _stop(self):
+        # workaround for https://github.com/ray-project/ray/issues/1516
+        for ev in self.remote_evaluators:
+            ev.__ray_terminate__.remote(ev._ray_actor_id.id())
+
     def _save(self, checkpoint_dir):
         checkpoint_path = self.saver.save(
             self.local_evaluator.sess,
@@ -4,9 +4,11 @@ from __future__ import print_function

 import unittest

+import numpy as np
+
 import ray
 from ray.rllib.test.mock_evaluator import _MockEvaluator
-from ray.rllib.optimizers import AsyncOptimizer
+from ray.rllib.optimizers import AsyncOptimizer, SampleBatch


 class AsyncOptimizerTest(unittest.TestCase):
@@ -25,5 +27,18 @@ class AsyncOptimizerTest(unittest.TestCase):
         self.assertTrue(all(local.get_weights() == 0))


+class SampleBatchTest(unittest.TestCase):
+    def testConcat(self):
+        b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])})
+        b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])})
+        b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])})
+        b12 = b1.concat(b2)
+        self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1])
+        self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4])
+        b = SampleBatch.concat_samples([b1, b2, b3])
+        self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1])
+        self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5])
+
+
 if __name__ == '__main__':
     unittest.main(verbosity=2)
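The new SampleBatchTest exercises both changed code paths: pairwise concatenation via b1.concat(b2), and the rewritten concat_samples over batches of unequal length.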
@@ -1390,7 +1390,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
              num_cpus=None, num_gpus=None, resources=None,
              num_custom_resource=None, num_redis_shards=None,
              redis_max_clients=None, plasma_directory=None,
-             huge_pages=False, include_webui=True):
+             huge_pages=False, include_webui=True, object_store_memory=None):
     """Connect to an existing Ray cluster or start one and connect to it.

     This method handles two cases. Either a Ray cluster already exists and we
@@ -1430,6 +1430,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
             Store with hugetlbfs support. Requires plasma_directory.
         include_webui: Boolean flag indicating whether to start the web
             UI, which is a Jupyter notebook.
+        object_store_memory: The amount of memory (in bytes) to start the
+            object store with.

     Returns:
         Address information about the started processes.
@@ -1454,7 +1456,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None,
         redis_max_clients=redis_max_clients,
         plasma_directory=plasma_directory,
         huge_pages=huge_pages,
-        include_webui=include_webui)
+        include_webui=include_webui,
+        object_store_memory=object_store_memory)


 def cleanup(worker=global_worker):
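With the keyword threaded from ray.init through to process startup, callers can now size the plasma object store explicitly instead of relying on the default heuristic. A minimal usage sketch (the 2 GB figure is an arbitrary example):

    import ray

    # Start Ray with a 2 GB object store.
    ray.init(object_store_memory=2 * 1024 ** 3)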