Files

18 lines
440 B
Python

from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from rollouts import load_rollouts
class VAEDataset(Dataset):
    """Map-style dataset over the ``'observations'`` entry of loaded rollouts.

    The observations are loaded once in ``__init__`` via ``load_rollouts`` and
    served by direct indexing; no transforms or conversions are applied here.
    """

    def __init__(self, rollout_path):
        """Load rollouts from *rollout_path* and keep their observations.

        Args:
            rollout_path: Passed unchanged to ``load_rollouts``.
        """
        super().__init__()
        rollouts = load_rollouts(rollout_path)
        # NOTE(review): assumes rollouts['observations'] is an indexable,
        # sized collection with one frame per sample (e.g. a list or an
        # array/tensor whose first axis is the sample axis) - confirm
        # against load_rollouts.
        self.images = rollouts['observations']

    def __getitem__(self, index):
        """Return the observation stored at position *index*.

        Indexing is delegated to ``self.images``, so the accepted index
        forms (negative ints, etc.) follow that container's rules.
        """
        return self.images[index]

    def __len__(self):
        """Return the number of observations available."""
        return len(self.images)