Change training tasks to xray for Jenkins tests (#2567)

This commit is contained in:
Yuhong Guo
2018-08-07 04:35:26 +08:00
committed by Robert Nishihara
parent 85b8b2a395
commit 9825da7233
5 changed files with 71 additions and 55 deletions
+8 -2
View File
@@ -153,7 +153,10 @@ class WarpFrame(gym.ObservationWrapper):
self.width = dim # in rllib we use 80
self.height = dim
self.observation_space = spaces.Box(
low=0, high=255, shape=(self.height, self.width, 1))
low=0,
high=255,
shape=(self.height, self.width, 1),
dtype=np.float32)
def observation(self, frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
@@ -170,7 +173,10 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
low=0,
high=255,
shape=(shp[0], shp[1], shp[2] * k),
dtype=np.float32)
def reset(self):
ob = self.env.reset()
@@ -10,6 +10,7 @@ To try this out, in two separate shells run:
import os
from gym import spaces
import numpy as np
import ray
from ray.rllib.agents.dqn import DQNAgent
@@ -25,8 +26,9 @@ CHECKPOINT_FILE = "last_checkpoint.out"
class CartpoleServing(ServingEnv):
def __init__(self):
ServingEnv.__init__(self, spaces.Discrete(2),
spaces.Box(low=-10, high=10, shape=(4, )))
ServingEnv.__init__(
self, spaces.Discrete(2),
spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32))
def run(self):
print("Starting policy server at {}:{}".format(SERVER_ADDRESS,
+5 -1
View File
@@ -36,7 +36,11 @@ class TaskPool(object):
for worker, obj_id in self.completed():
plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
ray.worker.global_worker.plasma_client.fetch([plasma_id])
if not ray.global_state.use_raylet:
ray.worker.global_worker.plasma_client.fetch([plasma_id])
else:
(ray.worker.global_worker.local_scheduler_client.
reconstruct_objects([obj_id], True))
self._fetching.append((worker, obj_id))
remaining = []
+5 -1
View File
@@ -28,7 +28,11 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
def __init__(self):
ServingEnv.__init__(
self, spaces.Discrete(2),
spaces.Box(low=-10, high=10, shape=(4,)))
spaces.Box(
low=-10,
high=10,
shape=(4,),
dtype=np.float32))
def run(self):
server = PolicyServer(self, "localhost", 8900)
server.serve_forever()