mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 06:14:06 +08:00
[rllib] Add magic methods for rollouts (#2024)
This commit is contained in:
@@ -36,8 +36,8 @@ class SampleBatch(object):
|
||||
@staticmethod
|
||||
def concat_samples(samples):
|
||||
out = {}
|
||||
for k in samples[0].data.keys():
|
||||
out[k] = np.concatenate([s.data[k] for s in samples])
|
||||
for k in samples[0].keys():
|
||||
out[k] = np.concatenate([s[k] for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
def concat(self, other):
|
||||
@@ -50,10 +50,10 @@ class SampleBatch(object):
|
||||
{"a": [1, 2, 3, 4, 5]}
|
||||
"""
|
||||
|
||||
assert self.data.keys() == other.data.keys(), "must have same columns"
|
||||
assert self.keys() == other.keys(), "must have same columns"
|
||||
out = {}
|
||||
for k in self.data.keys():
|
||||
out[k] = np.concatenate([self.data[k], other.data[k]])
|
||||
for k in self.keys():
|
||||
out[k] = np.concatenate([self[k], other[k]])
|
||||
return SampleBatch(out)
|
||||
|
||||
def rows(self):
|
||||
@@ -70,7 +70,7 @@ class SampleBatch(object):
|
||||
|
||||
for i in range(self.count):
|
||||
row = {}
|
||||
for k in self.data.keys():
|
||||
for k in self.keys():
|
||||
row[k] = self[k][i]
|
||||
yield row
|
||||
|
||||
@@ -85,19 +85,34 @@ class SampleBatch(object):
|
||||
|
||||
out = []
|
||||
for k in keys:
|
||||
out.append(self.data[k])
|
||||
out.append(self[k])
|
||||
return out
|
||||
|
||||
def shuffle(self):
|
||||
permutation = np.random.permutation(self.count)
|
||||
for key, val in self.data.items():
|
||||
self.data[key] = val[permutation]
|
||||
for key, val in self.items():
|
||||
self[key] = val[permutation]
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
self.data[key] = item
|
||||
|
||||
def __str__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def __repr__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
def __iter__(self):
|
||||
return self.data.__iter__()
|
||||
|
||||
def __contains__(self, x):
|
||||
return x in self.data
|
||||
|
||||
Reference in New Issue
Block a user