mirror of
https://github.com/wassname/world-models-pytorch.git
synced 2026-06-27 17:18:25 +08:00
18 lines
440 B
Python
18 lines
440 B
Python
from torch.utils.data.dataset import Dataset
|
|
from torch.utils.data.dataloader import DataLoader
|
|
from rollouts import load_rollouts
|
|
|
|
|
|
class VAEDataset(Dataset):
    """Map-style dataset over the ``'observations'`` entry of loaded rollouts.

    The rollouts are loaded once at construction time; indexing and ``len()``
    are then delegated directly to the stored observations container.
    """

    def __init__(self, rollout_path):
        """Load rollouts and keep their observations.

        Args:
            rollout_path: Passed unchanged to ``load_rollouts``.
                NOTE(review): presumably a filesystem location -- confirm
                against ``rollouts.load_rollouts``.
        """
        super(VAEDataset, self).__init__()
        rollouts = load_rollouts(rollout_path)
        # Only the 'observations' entry is retained; any other entries of the
        # loaded rollouts are dropped once __init__ returns.
        self.images = rollouts['observations']

    def __getitem__(self, index):
        """Return the observation stored at ``index``.

        Indexing semantics (negative indices, slices, the exception raised
        when out of range) are those of the underlying ``self.images``
        container.
        """
        # Returning the sample (rather than None) is what lets a DataLoader
        # collate items from this dataset into batches.
        return self.images[index]

    def __len__(self):
        """Return the number of stored observations."""
        return len(self.images)
|
|
|