mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 19:28:27 +08:00
[rllib] Reduce concat memory usage, allow object store memory to be specified in init (#1529)
* c
* stop agents
* comment
* Sat Feb 10 02:33:30 PST 2018
* Sat Feb 10 02:33:39 PST 2018
* Update sample_batch.py
* Sun Feb 11 14:38:55 PST 2018
* add ppo config warn
This commit is contained in:
@@ -114,6 +114,11 @@ class A3CAgent(Agent):
|
||||
|
||||
return result
|
||||
|
||||
def _stop(self):
    """Invoke ``__ray_terminate__`` remotely on every remote evaluator.

    Workaround for https://github.com/ray-project/ray/issues/1516: each
    entry of ``self.remote_evaluators`` is sent a terminate call that is
    passed the entry's own actor id.
    """
    for evaluator in self.remote_evaluators:
        terminate = evaluator.__ray_terminate__
        terminate.remote(evaluator._ray_actor_id.id())
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(
|
||||
checkpoint_dir, "checkpoint-{}".format(self.iteration))
|
||||
|
||||
@@ -218,6 +218,11 @@ class DQNAgent(Agent):
|
||||
else:
|
||||
self.local_evaluator.sample(no_replay=True)
|
||||
|
||||
def _stop(self):
    """Invoke ``__ray_terminate__`` remotely on every remote evaluator.

    Workaround for https://github.com/ray-project/ray/issues/1516: each
    entry of ``self.remote_evaluators`` is sent a terminate call that is
    passed the entry's own actor id.
    """
    for evaluator in self.remote_evaluators:
        terminate = evaluator.__ray_terminate__
        terminate.remote(evaluator._ray_actor_id.id())
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.local_evaluator.sess,
|
||||
|
||||
@@ -300,6 +300,11 @@ class ESAgent(Agent):
|
||||
|
||||
return result
|
||||
|
||||
def _stop(self):
    """Invoke ``__ray_terminate__`` remotely on every worker.

    Workaround for https://github.com/ray-project/ray/issues/1516: each
    entry of ``self.workers`` is sent a terminate call that is passed the
    entry's own actor id.
    """
    for worker in self.workers:
        terminate = worker.__ray_terminate__
        terminate.remote(worker._ray_actor_id.id())
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(
|
||||
checkpoint_dir, "checkpoint-{}".format(self.iteration))
|
||||
|
||||
@@ -20,7 +20,6 @@ class FullyConnectedNetwork(Model):
|
||||
activation = tf.nn.tanh
|
||||
elif fcnet_activation == "relu":
|
||||
activation = tf.nn.relu
|
||||
print("Constructing fcnet {} {}".format(hiddens, activation))
|
||||
|
||||
with tf.name_scope("fc_net"):
|
||||
i = 1
|
||||
|
||||
@@ -2,11 +2,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def arrayify(s):
    """Convert ``s`` into something ``np.concatenate`` can consume.

    Args:
        s: A scalar (int, float, str), a numpy array, a list/tuple of
            such values (possibly nested), or any other object that
            ``np.array`` knows how to convert.

    Returns:
        ``s`` itself when it is already a scalar or a numpy array;
        otherwise a new ``np.ndarray`` built from ``s``.
    """
    # isinstance (rather than an exact type() match) also lets subclasses
    # such as bool or np.float64 pass through untouched.
    if isinstance(s, (int, float, str, np.ndarray)):
        return s
    if isinstance(s, (list, tuple)):
        # recursive call to convert LazyFrames to arrays
        return np.array([arrayify(x) for x in s])
    return np.array(s)
|
||||
|
||||
|
||||
class SampleBatch(object):
|
||||
"""Wrapper around a dictionary with string keys and array-like values.
|
||||
|
||||
@@ -27,7 +35,10 @@ class SampleBatch(object):
|
||||
|
||||
@staticmethod
def concat_samples(samples):
    """Concatenate a sequence of SampleBatch objects column by column.

    Each column is joined with a single ``np.concatenate`` call over all
    batches instead of folding pairwise ``concat`` calls, so no
    intermediate batches are built along the way.

    Args:
        samples: Non-empty sequence of SampleBatch objects. The column
            names are taken from ``samples[0]``; every batch must provide
            each of those columns in its ``data``.

    Returns:
        A new SampleBatch holding the concatenated columns.
    """
    columns = samples[0].data
    out = {
        k: np.concatenate([arrayify(s.data[k]) for s in samples])
        for k in columns
    }
    return SampleBatch(out)
|
||||
|
||||
def concat(self, other):
|
||||
"""Returns a new SampleBatch with each data column concatenated.
|
||||
|
||||
@@ -116,6 +116,14 @@ class PPOAgent(Agent):
|
||||
config = self.config
|
||||
model = self.local_evaluator
|
||||
|
||||
if (config["num_workers"] * config["min_steps_per_task"] >
|
||||
config["timesteps_per_batch"]):
|
||||
print(
|
||||
"WARNING: num_workers * min_steps_per_task > "
|
||||
"timesteps_per_batch. This means that the output of some "
|
||||
"tasks will be wasted. Consider decreasing "
|
||||
"min_steps_per_task or increasing timesteps_per_batch.")
|
||||
|
||||
print("===> iteration", self.iteration)
|
||||
|
||||
iter_start = time.time()
|
||||
@@ -244,6 +252,11 @@ class PPOAgent(Agent):
|
||||
|
||||
return result
|
||||
|
||||
def _stop(self):
    """Invoke ``__ray_terminate__`` remotely on every remote evaluator.

    Workaround for https://github.com/ray-project/ray/issues/1516: each
    entry of ``self.remote_evaluators`` is sent a terminate call that is
    passed the entry's own actor id.
    """
    for evaluator in self.remote_evaluators:
        terminate = evaluator.__ray_terminate__
        terminate.remote(evaluator._ray_actor_id.id())
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = self.saver.save(
|
||||
self.local_evaluator.sess,
|
||||
|
||||
@@ -4,9 +4,11 @@ from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.test.mock_evaluator import _MockEvaluator
|
||||
from ray.rllib.optimizers import AsyncOptimizer
|
||||
from ray.rllib.optimizers import AsyncOptimizer, SampleBatch
|
||||
|
||||
|
||||
class AsyncOptimizerTest(unittest.TestCase):
|
||||
@@ -25,5 +27,18 @@ class AsyncOptimizerTest(unittest.TestCase):
|
||||
self.assertTrue(all(local.get_weights() == 0))
|
||||
|
||||
|
||||
class SampleBatchTest(unittest.TestCase):
    """Exercises pairwise ``concat`` and multi-way ``concat_samples``."""

    def testConcat(self):
        first = SampleBatch({
            "a": np.array([1, 2, 3]),
            "b": np.array([4, 5, 6]),
        })
        second = SampleBatch({"a": np.array([1]), "b": np.array([4])})
        third = SampleBatch({"a": np.array([1]), "b": np.array([5])})

        # Pairwise concatenation appends the second batch's rows.
        pair = first.concat(second)
        self.assertEqual(pair.data["a"].tolist(), [1, 2, 3, 1])
        self.assertEqual(pair.data["b"].tolist(), [4, 5, 6, 4])

        # Multi-way concatenation keeps the batches in list order.
        merged = SampleBatch.concat_samples([first, second, third])
        self.assertEqual(merged.data["a"].tolist(), [1, 2, 3, 1, 1])
        self.assertEqual(merged.data["b"].tolist(), [4, 5, 6, 4, 5])
|
||||
|
||||
|
||||
# Run the test cases in this module with verbose output when the file is
# executed directly.
if __name__ == '__main__':
    unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user