mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 13:23:53 +08:00
Change training tasks to xray for Jenkins tests (#2567)
This commit is contained in:
committed by
Robert Nishihara
parent
85b8b2a395
commit
9825da7233
+8
-2
@@ -153,7 +153,10 @@ class WarpFrame(gym.ObservationWrapper):
|
||||
self.width = dim # in rllib we use 80
|
||||
self.height = dim
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(self.height, self.width, 1))
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(self.height, self.width, 1),
|
||||
dtype=np.float32)
|
||||
|
||||
def observation(self, frame):
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
@@ -170,7 +173,10 @@ class FrameStack(gym.Wrapper):
|
||||
self.frames = deque([], maxlen=k)
|
||||
shp = env.observation_space.shape
|
||||
self.observation_space = spaces.Box(
|
||||
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(shp[0], shp[1], shp[2] * k),
|
||||
dtype=np.float32)
|
||||
|
||||
def reset(self):
|
||||
ob = self.env.reset()
|
||||
|
||||
@@ -10,6 +10,7 @@ To try this out, in two separate shells run:
|
||||
|
||||
import os
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNAgent
|
||||
@@ -25,8 +26,9 @@ CHECKPOINT_FILE = "last_checkpoint.out"
|
||||
|
||||
class CartpoleServing(ServingEnv):
|
||||
def __init__(self):
|
||||
ServingEnv.__init__(self, spaces.Discrete(2),
|
||||
spaces.Box(low=-10, high=10, shape=(4, )))
|
||||
ServingEnv.__init__(
|
||||
self, spaces.Discrete(2),
|
||||
spaces.Box(low=-10, high=10, shape=(4, ), dtype=np.float32))
|
||||
|
||||
def run(self):
|
||||
print("Starting policy server at {}:{}".format(SERVER_ADDRESS,
|
||||
|
||||
@@ -36,7 +36,11 @@ class TaskPool(object):
|
||||
|
||||
for worker, obj_id in self.completed():
|
||||
plasma_id = ray.pyarrow.plasma.ObjectID(obj_id.id())
|
||||
ray.worker.global_worker.plasma_client.fetch([plasma_id])
|
||||
if not ray.global_state.use_raylet:
|
||||
ray.worker.global_worker.plasma_client.fetch([plasma_id])
|
||||
else:
|
||||
(ray.worker.global_worker.local_scheduler_client.
|
||||
reconstruct_objects([obj_id], True))
|
||||
self._fetching.append((worker, obj_id))
|
||||
|
||||
remaining = []
|
||||
|
||||
@@ -28,7 +28,11 @@ class PolicyServer(ThreadingMixIn, HTTPServer):
|
||||
def __init__(self):
|
||||
ServingEnv.__init__(
|
||||
self, spaces.Discrete(2),
|
||||
spaces.Box(low=-10, high=10, shape=(4,)))
|
||||
spaces.Box(
|
||||
low=-10,
|
||||
high=10,
|
||||
shape=(4,),
|
||||
dtype=np.float32))
|
||||
def run(self):
|
||||
server = PolicyServer(self, "localhost", 8900)
|
||||
server.serve_forever()
|
||||
|
||||
Reference in New Issue
Block a user