mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 19:01:10 +08:00
16e82b43d1
* Changes for preprocessors * removed comments * Changes + push for lint * linted * adding dependency for travis * linting won't pass * reordering * needed for testing * added comments * pip it * pip dependencies
73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
class Preprocessor(object):
    """Abstract base class for observation preprocessors.

    Concrete subclasses implement ``transform_shape`` and ``transform``.
    They may also override ``_init``, which the constructor calls after
    storing ``options``, to read their settings.
    """

    def __init__(self, options):
        """Keep ``options`` on the instance, then run the ``_init`` hook."""
        self.options = options
        self._init()

    def _init(self):
        """Subclass setup hook; the base implementation does nothing."""
        return None

    def transform_shape(self, obs_shape):
        """Returns the preprocessed observation shape."""
        raise NotImplementedError()

    def transform(self, observation):
        """Returns the preprocessed observation."""
        raise NotImplementedError()
|
|
|
|
|
|
class AtariPixelPreprocessor(Preprocessor):
    """Crops, resizes and rescales Atari pixel observations.

    Options read from ``self.options``:
        grayscale (default False): average the color channels into a single
            channel.
        zero_mean (default True): rescale pixels with ``(x - 128) / 128``;
            otherwise rescale with ``x / 255``.
        dim (default 80): side length of the square output image.
    """

    def _init(self):
        self.grayscale = self.options.get("grayscale", False)
        self.zero_mean = self.options.get("zero_mean", True)
        self.dim = self.options.get("dim", 80)

    def transform_shape(self, obs_shape):
        """Returns ``(dim, dim, channels)``; ``obs_shape`` is ignored."""
        channels = 1 if self.grayscale else 3
        return (self.dim, self.dim, channels)

    # TODO(ekl) why does this need to return an extra size-1 dim (the [None])
    def transform(self, observation):
        """Crops, resizes and rescales a single frame.

        Removes 25 rows from the top and bottom of ``observation`` (a
        (210, 160, 3) frame becomes (160, 160, 3)), resizes it to
        ``(dim, dim)``, optionally averages it to one channel, and rescales
        the pixel values. Assumes an HxWxC image array -- confirm with callers.

        Returns:
            float32 array of shape ``(1,) + self.transform_shape(None)``.
            The leading size-1 axis is kept because existing callers
            receive it.
        """
        scaled = observation[25:-25, :, :]
        if self.dim < 80:
            # Below 80x80 the frame is first resized to 80x80, then to dim.
            scaled = cv2.resize(scaled, (80, 80))
        scaled = cv2.resize(scaled, (self.dim, self.dim))
        if self.grayscale:
            # Averaging drops the channel axis; restore it as size 1. Color
            # frames are already (dim, dim, 3) and must not be reshaped.
            scaled = scaled.mean(2).reshape(self.dim, self.dim, 1)
        scaled = scaled.astype(np.float32)
        scaled = scaled[None]
        if self.zero_mean:
            scaled = (scaled - 128) / 128
        else:
            scaled *= 1.0 / 255.0
        return scaled
|
|
|
|
|
|
# TODO(rliaw): Also should include the deepmind preprocessor
|
|
class AtariRamPreprocessor(Preprocessor):
    """Recenters and rescales a RAM observation with ``(x - 128) / 128``.

    Presumably the observation is the 128-byte Atari RAM (values 0..255),
    which this maps into [-1, 1) -- confirm against the environment.
    """

    def transform_shape(self, obs_shape):
        """Returns ``(128,)``; ``obs_shape`` is ignored."""
        return (128,)

    def transform(self, observation):
        """Returns ``(observation - 128) / 128`` computed in floating point.

        Non-floating inputs are promoted to float64 before the subtraction.
        With an unsigned dtype such as uint8, ``observation - 128`` would
        otherwise wrap around (e.g. 0 -> 128) instead of going negative.
        Floating-point inputs keep their dtype.
        """
        ram = np.asarray(observation)
        if not np.issubdtype(ram.dtype, np.floating):
            ram = ram.astype(np.float64)
        return (ram - 128) / 128
|
|
|
|
|
|
class NoPreprocessor(Preprocessor):
    """Identity preprocessor: shapes and observations pass through untouched."""

    def transform_shape(self, obs_shape):
        """Returns ``obs_shape`` unchanged."""
        unchanged_shape = obs_shape
        return unchanged_shape

    def transform(self, observation):
        """Returns ``observation`` itself (the same object, not a copy)."""
        passthrough = observation
        return passthrough
|