From 7e998db65665f2f4b4d656c64a18870ec3a34eef Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 11 Feb 2018 19:14:51 -0800 Subject: [PATCH] [rllib] Reduce concat memory usage, allow object store memory to be specified in init (#1529) * c * stop agents * comment * Sat Feb 10 02:33:30 PST 2018 * Sat Feb 10 02:33:39 PST 2018 * Update sample_batch.py * Sun Feb 11 14:38:55 PST 2018 * add ppo config warn --- python/ray/rllib/a3c/a3c.py | 5 +++++ python/ray/rllib/dqn/dqn.py | 5 +++++ python/ray/rllib/es/es.py | 5 +++++ python/ray/rllib/models/fcnet.py | 1 - python/ray/rllib/optimizers/sample_batch.py | 17 ++++++++++++++--- python/ray/rllib/ppo/ppo.py | 13 +++++++++++++ python/ray/rllib/test/test_optimizers.py | 17 ++++++++++++++++- python/ray/worker.py | 7 +++++-- 8 files changed, 63 insertions(+), 7 deletions(-) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index f356c789d..6e349be38 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -114,6 +114,11 @@ class A3CAgent(Agent): return result + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote(ev._ray_actor_id.id()) + def _save(self, checkpoint_dir): checkpoint_path = os.path.join( checkpoint_dir, "checkpoint-{}".format(self.iteration)) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 825cba2a7..a0a5dbc01 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -218,6 +218,11 @@ class DQNAgent(Agent): else: self.local_evaluator.sample(no_replay=True) + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote(ev._ray_actor_id.id()) + def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index e0e0517c3..003edba26 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -300,6 +300,11 @@ class ESAgent(Agent): return result + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for w in self.workers: + w.__ray_terminate__.remote(w._ray_actor_id.id()) + def _save(self, checkpoint_dir): checkpoint_path = os.path.join( checkpoint_dir, "checkpoint-{}".format(self.iteration)) diff --git a/python/ray/rllib/models/fcnet.py b/python/ray/rllib/models/fcnet.py index 0bcbd68b0..ab40a6c6b 100644 --- a/python/ray/rllib/models/fcnet.py +++ b/python/ray/rllib/models/fcnet.py @@ -20,7 +20,6 @@ class FullyConnectedNetwork(Model): activation = tf.nn.tanh elif fcnet_activation == "relu": activation = tf.nn.relu - print("Constructing fcnet {} {}".format(hiddens, activation)) with tf.name_scope("fc_net"): i = 1 diff --git a/python/ray/rllib/optimizers/sample_batch.py b/python/ray/rllib/optimizers/sample_batch.py index d93fcdce2..510234171 100644 --- a/python/ray/rllib/optimizers/sample_batch.py +++ b/python/ray/rllib/optimizers/sample_batch.py @@ -2,11 +2,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from functools import reduce - import numpy as np +def arrayify(s): + if type(s) in [int, float, str, np.ndarray]: + return s + elif type(s) is list: + # recursive call to convert LazyFrames to arrays + return np.array([arrayify(x) for x in s]) + else: + return np.array(s) + + class SampleBatch(object): """Wrapper around a dictionary with string keys and array-like values. @@ -27,7 +35,10 @@ class SampleBatch(object): @staticmethod def concat_samples(samples): - return reduce(lambda a, b: a.concat(b), samples) + out = {} + for k in samples[0].data.keys(): + out[k] = np.concatenate([arrayify(s.data[k]) for s in samples]) + return SampleBatch(out) def concat(self, other): """Returns a new SampleBatch with each data column concatenated. diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index ad1773ead..46a43299f 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -116,6 +116,14 @@ class PPOAgent(Agent): config = self.config model = self.local_evaluator + if (config["num_workers"] * config["min_steps_per_task"] > + config["timesteps_per_batch"]): + print( + "WARNING: num_workers * min_steps_per_task > " + "timesteps_per_batch. This means that the output of some " + "tasks will be wasted. Consider decreasing " + "min_steps_per_task or increasing timesteps_per_batch.") + print("===> iteration", self.iteration) iter_start = time.time() @@ -244,6 +252,11 @@ class PPOAgent(Agent): return result + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote(ev._ray_actor_id.id()) + def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, diff --git a/python/ray/rllib/test/test_optimizers.py b/python/ray/rllib/test/test_optimizers.py index 15879ea0d..cfb606101 100644 --- a/python/ray/rllib/test/test_optimizers.py +++ b/python/ray/rllib/test/test_optimizers.py @@ -4,9 +4,11 @@ from __future__ import print_function import unittest +import numpy as np + import ray from ray.rllib.test.mock_evaluator import _MockEvaluator -from ray.rllib.optimizers import AsyncOptimizer +from ray.rllib.optimizers import AsyncOptimizer, SampleBatch class AsyncOptimizerTest(unittest.TestCase): @@ -25,5 +27,18 @@ class AsyncOptimizerTest(unittest.TestCase): self.assertTrue(all(local.get_weights() == 0)) +class SampleBatchTest(unittest.TestCase): + def testConcat(self): + b1 = SampleBatch({"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}) + b2 = SampleBatch({"a": np.array([1]), "b": np.array([4])}) + b3 = SampleBatch({"a": np.array([1]), "b": np.array([5])}) + b12 = b1.concat(b2) + self.assertEqual(b12.data["a"].tolist(), [1, 2, 3, 1]) + self.assertEqual(b12.data["b"].tolist(), [4, 5, 6, 4]) + b = SampleBatch.concat_samples([b1, b2, b3]) + self.assertEqual(b.data["a"].tolist(), [1, 2, 3, 1, 1]) + self.assertEqual(b.data["b"].tolist(), [4, 5, 6, 4, 5]) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/python/ray/worker.py b/python/ray/worker.py index 19ddad8ca..cd68e3db8 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1390,7 +1390,7 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, num_cpus=None, num_gpus=None, resources=None, num_custom_resource=None, num_redis_shards=None, redis_max_clients=None, plasma_directory=None, - huge_pages=False, include_webui=True): + huge_pages=False, include_webui=True, object_store_memory=None): """Connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -1430,6 +1430,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which is a Jupyter notebook. + object_store_memory: The amount of memory (in bytes) to start the + object store with. Returns: Address information about the started processes. @@ -1454,7 +1456,8 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, redis_max_clients=redis_max_clients, plasma_directory=plasma_directory, huge_pages=huge_pages, - include_webui=include_webui) + include_webui=include_webui, + object_store_memory=object_store_memory) def cleanup(worker=global_worker):