mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 11:10:02 +08:00
[rllib] Fix LSTM regression on truncated sequences and add regression test (#2898)
* fix * add test * yapf * yapf * fix space * Oops that should be lstm: True * Update cartpole_lstm.py
This commit is contained in:
+4
@@ -7,6 +7,10 @@ cv2.ocl.setUseOpenCL(False)
|
||||
|
||||
|
||||
def is_atari(env):
|
||||
if (hasattr(env.observation_space, "shape")
|
||||
and env.observation_space.shape is not None
|
||||
and len(env.observation_space.shape) <= 2):
|
||||
return False
|
||||
return hasattr(env, "unwrapped") and hasattr(env.unwrapped, "ale")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
"""Partially observed variant of the CartPole gym environment.
|
||||
|
||||
https://github.com/openai/gym/blob/master/gym/envs/classic_control/cartpole.py
|
||||
|
||||
We delete the velocity component of the state, so that it can only be solved
|
||||
by a LSTM policy."""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import gym
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
import numpy as np
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--stop", type=int, default=200)
|
||||
|
||||
|
||||
class CartPoleStatelessEnv(gym.Env):
|
||||
metadata = {
|
||||
'render.modes': ['human', 'rgb_array'],
|
||||
'video.frames_per_second': 60
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.gravity = 9.8
|
||||
self.masscart = 1.0
|
||||
self.masspole = 0.1
|
||||
self.total_mass = (self.masspole + self.masscart)
|
||||
self.length = 0.5 # actually half the pole's length
|
||||
self.polemass_length = (self.masspole * self.length)
|
||||
self.force_mag = 10.0
|
||||
self.tau = 0.02 # seconds between state updates
|
||||
|
||||
# Angle at which to fail the episode
|
||||
self.theta_threshold_radians = 12 * 2 * math.pi / 360
|
||||
self.x_threshold = 2.4
|
||||
|
||||
high = np.array([
|
||||
self.x_threshold * 2,
|
||||
self.theta_threshold_radians * 2,
|
||||
])
|
||||
|
||||
self.action_space = spaces.Discrete(2)
|
||||
self.observation_space = spaces.Box(-high, high)
|
||||
|
||||
self.seed()
|
||||
self.viewer = None
|
||||
self.state = None
|
||||
|
||||
self.steps_beyond_done = None
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def step(self, action):
|
||||
assert self.action_space.contains(
|
||||
action), "%r (%s) invalid" % (action, type(action))
|
||||
state = self.state
|
||||
x, x_dot, theta, theta_dot = state
|
||||
force = self.force_mag if action == 1 else -self.force_mag
|
||||
costheta = math.cos(theta)
|
||||
sintheta = math.sin(theta)
|
||||
temp = (force + self.polemass_length * theta_dot * theta_dot * sintheta
|
||||
) / self.total_mass
|
||||
thetaacc = (self.gravity * sintheta - costheta * temp) / (
|
||||
self.length *
|
||||
(4.0 / 3.0 - self.masspole * costheta * costheta / self.total_mass)
|
||||
)
|
||||
xacc = (temp -
|
||||
self.polemass_length * thetaacc * costheta / self.total_mass)
|
||||
x = x + self.tau * x_dot
|
||||
x_dot = x_dot + self.tau * xacc
|
||||
theta = theta + self.tau * theta_dot
|
||||
theta_dot = theta_dot + self.tau * thetaacc
|
||||
self.state = (x, x_dot, theta, theta_dot)
|
||||
done = (x < -self.x_threshold or x > self.x_threshold
|
||||
or theta < -self.theta_threshold_radians
|
||||
or theta > self.theta_threshold_radians)
|
||||
done = bool(done)
|
||||
|
||||
if not done:
|
||||
reward = 1.0
|
||||
elif self.steps_beyond_done is None:
|
||||
# Pole just fell!
|
||||
self.steps_beyond_done = 0
|
||||
reward = 1.0
|
||||
else:
|
||||
self.steps_beyond_done += 1
|
||||
reward = 0.0
|
||||
|
||||
rv = np.r_[self.state[0], self.state[2]]
|
||||
return rv, reward, done, {}
|
||||
|
||||
def reset(self):
|
||||
self.state = self.np_random.uniform(low=-0.05, high=0.05, size=(4, ))
|
||||
self.steps_beyond_done = None
|
||||
|
||||
rv = np.r_[self.state[0], self.state[2]]
|
||||
return rv
|
||||
|
||||
def render(self, mode='human'):
|
||||
screen_width = 600
|
||||
screen_height = 400
|
||||
|
||||
world_width = self.x_threshold * 2
|
||||
scale = screen_width / world_width
|
||||
carty = 100 # TOP OF CART
|
||||
polewidth = 10.0
|
||||
polelen = scale * 1.0
|
||||
cartwidth = 50.0
|
||||
cartheight = 30.0
|
||||
|
||||
if self.viewer is None:
|
||||
from gym.envs.classic_control import rendering
|
||||
self.viewer = rendering.Viewer(screen_width, screen_height)
|
||||
l, r, t, b = (-cartwidth / 2, cartwidth / 2, cartheight / 2,
|
||||
-cartheight / 2)
|
||||
axleoffset = cartheight / 4.0
|
||||
cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
|
||||
self.carttrans = rendering.Transform()
|
||||
cart.add_attr(self.carttrans)
|
||||
self.viewer.add_geom(cart)
|
||||
l, r, t, b = (-polewidth / 2, polewidth / 2,
|
||||
polelen - polewidth / 2, -polewidth / 2)
|
||||
pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
|
||||
pole.set_color(.8, .6, .4)
|
||||
self.poletrans = rendering.Transform(translation=(0, axleoffset))
|
||||
pole.add_attr(self.poletrans)
|
||||
pole.add_attr(self.carttrans)
|
||||
self.viewer.add_geom(pole)
|
||||
self.axle = rendering.make_circle(polewidth / 2)
|
||||
self.axle.add_attr(self.poletrans)
|
||||
self.axle.add_attr(self.carttrans)
|
||||
self.axle.set_color(.5, .5, .8)
|
||||
self.viewer.add_geom(self.axle)
|
||||
self.track = rendering.Line((0, carty), (screen_width, carty))
|
||||
self.track.set_color(0, 0, 0)
|
||||
self.viewer.add_geom(self.track)
|
||||
|
||||
if self.state is None:
|
||||
return None
|
||||
|
||||
x = self.state
|
||||
cartx = x[0] * scale + screen_width / 2.0 # MIDDLE OF CART
|
||||
self.carttrans.set_translation(cartx, carty)
|
||||
self.poletrans.set_rotation(-x[2])
|
||||
|
||||
return self.viewer.render(return_rgb_array=mode == 'rgb_array')
|
||||
|
||||
def close(self):
|
||||
if self.viewer:
|
||||
self.viewer.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import ray
|
||||
from ray import tune
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
tune.register_env("cartpole_stateless", lambda _: CartPoleStatelessEnv())
|
||||
|
||||
ray.init()
|
||||
tune.run_experiments({
|
||||
"test": {
|
||||
"env": "cartpole_stateless",
|
||||
"run": "PG",
|
||||
"stop": {
|
||||
"episode_reward_mean": args.stop
|
||||
},
|
||||
"config": {
|
||||
"model": {
|
||||
"use_lstm": True,
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
@@ -156,9 +156,11 @@ class LSTM(Model):
|
||||
self.state_in = [c_in, h_in]
|
||||
|
||||
# Setup LSTM outputs
|
||||
state_in = rnn.LSTMStateTuple(c_in, h_in)
|
||||
lstm_out, lstm_state = tf.nn.dynamic_rnn(
|
||||
lstm,
|
||||
last_layer,
|
||||
initial_state=state_in,
|
||||
sequence_length=self.seq_lens,
|
||||
time_major=False,
|
||||
dtype=tf.float32)
|
||||
|
||||
@@ -36,7 +36,7 @@ def _parse_results(res_path):
|
||||
for line in f:
|
||||
pass
|
||||
res_dict = _flatten_dict(json.loads(line.strip()))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Importing %s failed...Perhaps empty?" % res_path)
|
||||
return res_dict
|
||||
|
||||
@@ -45,7 +45,7 @@ def _parse_configs(cfg_path):
|
||||
try:
|
||||
with open(cfg_path) as f:
|
||||
cfg_dict = _flatten_dict(json.load(f))
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Config parsing failed.")
|
||||
return cfg_dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user