Files
ray/python/ray/rllib/models/preprocessors.py
T
Richard Liaw 16e82b43d1 [rllib] Changes for preprocessors (#1033)
* Changes for preprocessors

* removed comments

* Changes + push for lint

* linted

* adding dependency for travis

* linting won't pass

* reordering

* needed for testing

* added comments

* pip it

* pip dependencies
2017-09-30 13:11:20 -07:00

73 lines
2.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2
import numpy as np
class Preprocessor(object):
    """Abstract base for observation preprocessors.

    Construction stores ``options`` and then invokes the ``_init`` hook, so
    subclasses can read their settings from ``self.options`` without
    overriding ``__init__``. Concrete subclasses must implement
    ``transform_shape`` and ``transform``.
    """

    def __init__(self, options):
        """Saves the configuration and runs the subclass setup hook.

        Args:
            options: configuration object kept verbatim as ``self.options``
                (presumably a dict of preprocessor settings -- confirm with
                callers).
        """
        self.options = options
        self._init()

    def _init(self):
        """Subclass setup hook; the base implementation does nothing."""

    def transform_shape(self, obs_shape):
        """Returns the preprocessed observation shape.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def transform(self, observation):
        """Returns the preprocessed observation.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
class AtariPixelPreprocessor(Preprocessor):
    """Crops, resizes and normalizes Atari image observations.

    Options (read from ``self.options`` in ``_init``):
        grayscale (default False): average the colour channels into one.
        zero_mean (default True): normalize with ``(x - 128) / 128``;
            otherwise with ``x / 255``.
        dim (default 80): side length of the square output frame.
    """

    def _init(self):
        # Whether to collapse the colour channels into a single channel.
        self.grayscale = self.options.get("grayscale", False)
        # Whether to centre pixel values around zero instead of [0, 1].
        self.zero_mean = self.options.get("zero_mean", True)
        # Output frame is dim x dim.
        self.dim = self.options.get("dim", 80)

    def transform_shape(self, obs_shape):
        """Returns ``(dim, dim, channels)``; ``obs_shape`` is ignored.

        ``channels`` is 1 when ``grayscale`` is set, else 3. Note that
        ``transform`` additionally prepends a size-1 leading axis, which is
        not included here.
        """
        if self.grayscale:
            return (self.dim, self.dim, 1)
        else:
            return (self.dim, self.dim, 3)

    # TODO(ekl) why does this need to return an extra size-1 dim (the [None])
    def transform(self, observation):
        """Downsamples images from (210, 160, 3) by the configured factor.

        Args:
            observation: an H x W x C image array. Presumably a (210, 160, 3)
                uint8 Atari frame -- TODO confirm against the environment.

        Returns:
            A float32 array of shape ``(1, dim, dim, channels)``, where
            ``channels`` matches ``transform_shape``.
        """
        # Drop 25 rows from the top and the bottom of the frame.
        scaled = observation[25:-25, :, :]
        if self.dim < 80:
            # Intermediate 80x80 resize before shrinking to the final size.
            scaled = cv2.resize(scaled, (80, 80))
        scaled = cv2.resize(scaled, (self.dim, self.dim))
        if self.grayscale:
            # Average the channels, then restore a trailing channel axis so
            # the frame matches transform_shape(). The reshape must only
            # happen here: an RGB frame cannot be reshaped to one channel.
            scaled = scaled.mean(2)
            scaled = np.reshape(scaled, [self.dim, self.dim, 1])
        # Cast on every path before normalizing. On an integer frame,
        # ``scaled - 128`` would wrap around and the in-place float multiply
        # below would raise a casting error.
        scaled = scaled.astype(np.float32)
        # Leading size-1 axis kept for existing callers (see TODO above).
        scaled = scaled[None]
        if self.zero_mean:
            scaled = (scaled - 128) / 128
        else:
            scaled *= 1.0 / 255.0
        return scaled
# TODO(rliaw): Also should include the deepmind preprocessor
class AtariRamPreprocessor(Preprocessor):
    """Normalizes a 128-element Atari RAM observation around zero."""

    def transform_shape(self, obs_shape):
        """Returns ``(128,)``; ``obs_shape`` is ignored."""
        return (128,)

    def transform(self, observation):
        """Maps each RAM value ``x`` to ``(x - 128) / 128``.

        The input is converted to float64 before the subtraction. Subtracting
        128 directly from an unsigned integer array (e.g. uint8) wraps around,
        so a value of 0 would come out as 1.0 instead of -1.0. For values in
        0..255 the result lies in [-1, 127/128].

        Args:
            observation: array-like of RAM values (presumably 128 uint8
                bytes -- TODO confirm against the environment).

        Returns:
            A float64 numpy array with the same shape as ``observation``.
        """
        ram = np.asarray(observation, dtype=np.float64)
        return (ram - 128) / 128
class NoPreprocessor(Preprocessor):
    """Pass-through preprocessor: shapes and observations are unchanged."""

    def transform_shape(self, obs_shape):
        """Returns ``obs_shape`` itself (same object, not a copy)."""
        return obs_shape

    def transform(self, observation):
        """Returns ``observation`` itself (same object, not a copy)."""
        return observation