Files
ray/python/ray/rllib/ppo/env.py
T
Richard Liaw cb6dea94bc [rllib] Fix Preprocessor for ATARI (#1066)
* Removing squeeze, fix atari preprocessing

* nit comment

* comments

* jenkins

* Lint
2017-10-03 18:45:02 -07:00

52 lines
1.9 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gym
import numpy as np
from ray.rllib.models import ModelCatalog
class BatchedEnv(object):
    """This holds multiple gym envs and performs steps on all of them.

    All environments are created from the same gym id, share a single
    preprocessor and are stepped side by side. Once an environment reports
    ``done`` it is no longer stepped; until the next ``reset`` it contributes
    a zero observation and a reward of 0.0.
    """

    def __init__(self, name, batchsize, options):
        """Create ``batchsize`` copies of the gym environment ``name``.

        Args:
            name: gym environment id passed to ``gym.make``.
            batchsize: number of environments to run side by side.
            options: config dict. ``options["model"]`` is forwarded to
                ``ModelCatalog.get_preprocessor``; the optional
                ``options["extra_frameskip"]`` (default 1, must be >= 1)
                sets how many times each action is repeated per ``step``.
        """
        self.envs = [gym.make(name) for _ in range(batchsize)]
        self.observation_space = self.envs[0].observation_space
        self.action_space = self.envs[0].action_space
        self.batchsize = batchsize
        self.preprocessor = ModelCatalog.get_preprocessor(
            name, self.envs[0].observation_space.shape, options["model"])
        self.extra_frameskip = options.get("extra_frameskip", 1)
        assert self.extra_frameskip >= 1

    def reset(self):
        """Reset every environment and return the stacked observations.

        Returns:
            The preprocessed observations, each given a leading axis of
            length one and concatenated with ``np.vstack``, so the first
            axis indexes the environments.
        """
        observations = [
            self.preprocessor.transform(env.reset())[None]
            for env in self.envs]
        # Remember the per-env observation layout so that finished
        # environments can be padded in step() with zeros of the same
        # shape AND dtype as real observations.
        self.shape = observations[0].shape
        self.dtype = observations[0].dtype
        self.dones = [False for _ in range(self.batchsize)]
        return np.vstack(observations)

    def step(self, actions, render=False):
        """Apply ``actions[i]`` to environment ``i`` and return the results.

        Each action is repeated up to ``extra_frameskip`` times, stopping
        early if the environment reports ``done``, and the rewards of the
        repeated frames are summed. Environments that have finished since
        the last ``reset`` are not stepped; they yield a zero observation
        and a reward of 0.0. ``reset`` must be called before ``step``.

        Args:
            actions: one action per environment.
            render: if True, render the first environment (at most once per
                call, and only while it is still active).

        Returns:
            Tuple ``(observations, rewards, dones)``: the stacked
            observations, a float32 array of per-environment rewards and an
            array of the per-environment done flags.
        """
        observations = []
        rewards = []
        for i, action in enumerate(actions):
            if self.dones[i]:
                # Match the dtype of real observations so np.vstack does not
                # upcast the whole batch (np.zeros defaults to float64).
                observations.append(np.zeros(self.shape, dtype=self.dtype))
                rewards.append(0.0)
                continue
            reward = 0.0
            for _ in range(self.extra_frameskip):
                observation, r, done, _ = self.envs[i].step(action)
                reward += r
                if done:
                    break
            # Render the first environment once per call, right after it is
            # stepped, rather than once for every active environment.
            if render and i == 0:
                self.envs[0].render()
            observations.append(self.preprocessor.transform(observation)[None])
            rewards.append(reward)
            self.dones[i] = done
        return (np.vstack(observations), np.array(rewards, dtype="float32"),
                np.array(self.dones))