Files
pytorch-a2c-ppo-acktr/envs.py
T
2017-10-05 15:57:11 -04:00

37 lines
1.0 KiB
Python
Executable File

import os
import gym
from gym.spaces.box import Box
from baselines import bench
from baselines.common.atari_wrappers import wrap_deepmind
def make_env(env_id, seed, rank, log_dir):
    """Return a zero-argument factory that builds one seeded, monitored env.

    A factory is returned instead of the env itself, presumably so each
    parallel worker can construct its own copy (confirm against the caller).

    Args:
        env_id: Gym environment id. Ids containing ``'Bullet'`` are created
            through ``pybullet_envs``; all others through ``gym.make``.
        seed: Base random seed. The env is seeded with ``seed + rank`` so
            workers with different ranks get different seeds.
        rank: Index of this worker; also used to name its monitor file.
        log_dir: Directory in which the ``"<rank>.monitor.json"`` path passed
            to ``bench.Monitor`` is placed.

    Returns:
        A callable taking no arguments that returns the wrapped env.
    """
    def _thunk():
        if 'Bullet' in env_id:
            # Imported lazily so pybullet is only needed for Bullet envs.
            import pybullet_envs
            env = pybullet_envs.make(env_id)
        else:
            env = gym.make(env_id)
        env.seed(seed + rank)
        env = bench.Monitor(env,
                            os.path.join(log_dir,
                                         "{}.monitor.json".format(rank)))

        # Only Atari games should get the DeepMind preprocessing and the
        # HWC->CHW transpose. A discrete action space alone is not enough:
        # classic-control envs such as CartPole are discrete too, and the
        # Atari-specific wrappers are not meant for them. Assumes Atari envs
        # expose an `ale` attribute on the unwrapped env (true for gym's
        # AtariEnv) -- NOTE(review): confirm for other Atari backends.
        is_atari = (isinstance(env.action_space, gym.spaces.Discrete)
                    and hasattr(env.unwrapped, 'ale'))
        if is_atari:
            env = wrap_deepmind(env)
            env = WrapPyTorch(env)
        return env

    return _thunk
class WrapPyTorch(gym.ObservationWrapper):
    """Reorder image observations from HWC to CHW layout for PyTorch.

    The wrapped env must produce rank-3 ``(height, width, channels)``
    observations (in this file it wraps the output of ``wrap_deepmind``).
    Each observation is returned as ``(channels, height, width)``.
    """

    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        inner_space = self.observation_space
        if len(inner_space.shape) != 3:
            raise ValueError(
                "WrapPyTorch expects rank-3 (H, W, C) observations, got shape "
                "{}".format(inner_space.shape))
        # Derive the space from the wrapped env rather than hard-coding
        # Box(0, 1, [1, 84, 84]). The advertised shape and per-element bounds
        # then always match what observation() returns, whatever the frame
        # size, channel count or value range of the wrapped env.
        self.observation_space = Box(inner_space.low.transpose(2, 0, 1),
                                     inner_space.high.transpose(2, 0, 1))

    def observation(self, observation):
        """Return *observation* with its axes reordered from HWC to CHW."""
        return observation.transpose(2, 0, 1)

    def _observation(self, observation):
        # Older gym releases call the underscore-prefixed hook, newer ones
        # call observation(); implementing both keeps the wrapper working
        # on either.
        return self.observation(observation)