[rllib] Add magic methods for rollouts (#2024)

This commit is contained in:
Alok Singh
2018-05-16 22:59:46 -07:00
committed by Richard Liaw
parent 7549209aea
commit c0e4c9d3d1
4 changed files with 83 additions and 47 deletions
+24 -9
View File
@@ -36,8 +36,8 @@ class SampleBatch(object):
@staticmethod
def concat_samples(samples):
out = {}
for k in samples[0].data.keys():
out[k] = np.concatenate([s.data[k] for s in samples])
for k in samples[0].keys():
out[k] = np.concatenate([s[k] for s in samples])
return SampleBatch(out)
def concat(self, other):
@@ -50,10 +50,10 @@ class SampleBatch(object):
{"a": [1, 2, 3, 4, 5]}
"""
assert self.data.keys() == other.data.keys(), "must have same columns"
assert self.keys() == other.keys(), "must have same columns"
out = {}
for k in self.data.keys():
out[k] = np.concatenate([self.data[k], other.data[k]])
for k in self.keys():
out[k] = np.concatenate([self[k], other[k]])
return SampleBatch(out)
def rows(self):
@@ -70,7 +70,7 @@ class SampleBatch(object):
for i in range(self.count):
row = {}
for k in self.data.keys():
for k in self.keys():
row[k] = self[k][i]
yield row
@@ -85,19 +85,34 @@ class SampleBatch(object):
out = []
for k in keys:
out.append(self.data[k])
out.append(self[k])
return out
def shuffle(self):
permutation = np.random.permutation(self.count)
for key, val in self.data.items():
self.data[key] = val[permutation]
for key, val in self.items():
self[key] = val[permutation]
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, item):
self.data[key] = item
def __str__(self):
return "SampleBatch({})".format(str(self.data))
def __repr__(self):
return "SampleBatch({})".format(str(self.data))
def keys(self):
return self.data.keys()
def items(self):
return self.data.items()
def __iter__(self):
return self.data.__iter__()
def __contains__(self, x):
return x in self.data