Files
world-models-pytorch/rollouts.py
T

36 lines
836 B
Python

import gym
import torch
def random_rollouts(env, num_rollouts, render=False):
"""
This function collects random rollouts from a given environment.
"""
obs = env.reset()
num_obs = 0
rollouts = {'observations': [obs], 'actions': ['0']}
while num_obs != num_rollouts:
if render:
env.render()
num_obs += 1
action = env.action_space.sample()
obs, reward, done, _ = env.step(action)
rollouts['observations'].append(obs)
rollouts['actions'].append(action)
if done:
obs = env.reset()
rollouts['observations'].append(obs)
rollouts['actions'].append('0')
return rollouts
def save_rollouts(rollouts):
torch.save(rollouts, 'rollouts.data')
def load_rollouts(fname):
return torch.load(fname)