mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 18:14:55 +08:00
[rllib] Reduce concat memory usage, allow object store memory to be specified in init (#1529)
* c * stop agents * comment * Sat Feb 10 02:33:30 PST 2018 * Sat Feb 10 02:33:39 PST 2018 * Update sample_batch.py * Sun Feb 11 14:38:55 PST 2018 * add ppo config warn
This commit is contained in:
@@ -2,11 +2,19 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from functools import reduce
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def arrayify(s):
|
||||
if type(s) in [int, float, str, np.ndarray]:
|
||||
return s
|
||||
elif type(s) is list:
|
||||
# recursive call to convert LazyFrames to arrays
|
||||
return np.array([arrayify(x) for x in s])
|
||||
else:
|
||||
return np.array(s)
|
||||
|
||||
|
||||
class SampleBatch(object):
|
||||
"""Wrapper around a dictionary with string keys and array-like values.
|
||||
|
||||
@@ -27,7 +35,10 @@ class SampleBatch(object):
|
||||
|
||||
@staticmethod
|
||||
def concat_samples(samples):
|
||||
return reduce(lambda a, b: a.concat(b), samples)
|
||||
out = {}
|
||||
for k in samples[0].data.keys():
|
||||
out[k] = np.concatenate([arrayify(s.data[k]) for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
def concat(self, other):
|
||||
"""Returns a new SampleBatch with each data column concatenated.
|
||||
|
||||
Reference in New Issue
Block a user