mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:46:37 +08:00
[RLlib] Issue #13507: Fix MB-MPO CartPole Env's reward function as well as MB-MPO running into a traj. view API related issue. (#14037)
This commit is contained in:
+6
-6
@@ -542,12 +542,12 @@ py_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# MBMPOTrainer
|
# MBMPOTrainer
|
||||||
#py_test(
|
py_test(
|
||||||
# name = "test_mbmpo",
|
name = "test_mbmpo",
|
||||||
# tags = ["agents_dir"],
|
tags = ["agents_dir"],
|
||||||
# size = "medium",
|
size = "medium",
|
||||||
# srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
|
srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
|
||||||
#)
|
)
|
||||||
|
|
||||||
# PGTrainer
|
# PGTrainer
|
||||||
py_test(
|
py_test(
|
||||||
|
|||||||
@@ -200,6 +200,9 @@ class DynamicsEnsembleCustomModel(TorchModelV2, nn.Module):
|
|||||||
def fit(self):
|
def fit(self):
|
||||||
# Add env samples to Replay Buffer
|
# Add env samples to Replay Buffer
|
||||||
local_worker = get_global_worker()
|
local_worker = get_global_worker()
|
||||||
|
for pid, pol in local_worker.policy_map.items():
|
||||||
|
pol.view_requirements[
|
||||||
|
SampleBatch.NEXT_OBS].used_for_training = True
|
||||||
new_samples = local_worker.sample()
|
new_samples = local_worker.sample()
|
||||||
# Initial Exploration of 8000 timesteps
|
# Initial Exploration of 8000 timesteps
|
||||||
if not self.global_itr:
|
if not self.global_itr:
|
||||||
|
|||||||
Vendored
+41
-41
@@ -1,12 +1,12 @@
|
|||||||
import gym
|
|
||||||
from gym.envs.classic_control import PendulumEnv, CartPoleEnv
|
from gym.envs.classic_control import PendulumEnv, CartPoleEnv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# MuJoCo may not be installed.
|
# MuJoCo may not be installed.
|
||||||
HalfCheetahEnv = HopperEnv = None
|
HalfCheetahEnv = HopperEnv = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv
|
from gym.envs.mujoco import HalfCheetahEnv, HopperEnv
|
||||||
except (ImportError, gym.error.DependencyNotInstalled):
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -22,11 +22,12 @@ class CartPoleWrapper(CartPoleEnv):
|
|||||||
x = obs_next[:, 0]
|
x = obs_next[:, 0]
|
||||||
theta = obs_next[:, 2]
|
theta = obs_next[:, 2]
|
||||||
|
|
||||||
rew = (x < -self.x_threshold) | (x > self.x_threshold) | (
|
# 1.0 if we are still on, 0.0 if we are terminated due to bounds
|
||||||
theta < -self.theta_threshold_radians) | (
|
# (angular or x-axis) being breached.
|
||||||
theta > self.theta_threshold_radians)
|
rew = 1.0 - ((x < -self.x_threshold) | (x > self.x_threshold) |
|
||||||
|
(theta < -self.theta_threshold_radians) |
|
||||||
|
(theta > self.theta_threshold_radians)).astype(np.float32)
|
||||||
|
|
||||||
rew = rew.astype(float)
|
|
||||||
return rew
|
return rew
|
||||||
|
|
||||||
|
|
||||||
@@ -54,44 +55,43 @@ class PendulumWrapper(PendulumEnv):
|
|||||||
return (((x + np.pi) % (2 * np.pi)) - np.pi)
|
return (((x + np.pi) % (2 * np.pi)) - np.pi)
|
||||||
|
|
||||||
|
|
||||||
if HalfCheetahEnv:
|
class HalfCheetahWrapper(HalfCheetahEnv or object):
|
||||||
|
"""Wrapper for the MuJoCo HalfCheetah-v2 environment.
|
||||||
|
|
||||||
class HalfCheetahWrapper(HalfCheetahEnv):
|
Adds an additional `reward` method for some model-based RL algos (e.g.
|
||||||
"""Wrapper for the MuJoCo HalfCheetah-v2 environment.
|
MB-MPO).
|
||||||
|
"""
|
||||||
|
|
||||||
Adds an additional `reward` method for some model-based RL algos (e.g.
|
def reward(self, obs, action, obs_next):
|
||||||
MB-MPO).
|
if obs.ndim == 2 and action.ndim == 2:
|
||||||
"""
|
assert obs.shape == obs_next.shape
|
||||||
|
forward_vel = obs_next[:, 8]
|
||||||
def reward(self, obs, action, obs_next):
|
ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
|
||||||
if obs.ndim == 2 and action.ndim == 2:
|
reward = forward_vel - ctrl_cost
|
||||||
assert obs.shape == obs_next.shape
|
|
||||||
forward_vel = obs_next[:, 8]
|
|
||||||
ctrl_cost = 0.1 * np.sum(np.square(action), axis=1)
|
|
||||||
reward = forward_vel - ctrl_cost
|
|
||||||
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
|
||||||
else:
|
|
||||||
forward_vel = obs_next[8]
|
|
||||||
ctrl_cost = 0.1 * np.square(action).sum()
|
|
||||||
reward = forward_vel - ctrl_cost
|
|
||||||
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
|
||||||
|
|
||||||
class HopperWrapper(HopperEnv):
|
|
||||||
"""Wrapper for the MuJoCo Hopper-v2 environment.
|
|
||||||
|
|
||||||
Adds an additional `reward` method for some model-based RL algos (e.g.
|
|
||||||
MB-MPO).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def reward(self, obs, action, obs_next):
|
|
||||||
alive_bonus = 1.0
|
|
||||||
assert obs.ndim == 2 and action.ndim == 2
|
|
||||||
assert (obs.shape == obs_next.shape
|
|
||||||
and action.shape[0] == obs.shape[0])
|
|
||||||
vel = obs_next[:, 5]
|
|
||||||
ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
|
|
||||||
reward = vel + alive_bonus - ctrl_cost
|
|
||||||
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
||||||
|
else:
|
||||||
|
forward_vel = obs_next[8]
|
||||||
|
ctrl_cost = 0.1 * np.square(action).sum()
|
||||||
|
reward = forward_vel - ctrl_cost
|
||||||
|
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
||||||
|
|
||||||
|
|
||||||
|
class HopperWrapper(HopperEnv or object):
|
||||||
|
"""Wrapper for the MuJoCo Hopper-v2 environment.
|
||||||
|
|
||||||
|
Adds an additional `reward` method for some model-based RL algos (e.g.
|
||||||
|
MB-MPO).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reward(self, obs, action, obs_next):
|
||||||
|
alive_bonus = 1.0
|
||||||
|
assert obs.ndim == 2 and action.ndim == 2
|
||||||
|
assert (obs.shape == obs_next.shape
|
||||||
|
and action.shape[0] == obs.shape[0])
|
||||||
|
vel = obs_next[:, 5]
|
||||||
|
ctrl_cost = 1e-3 * np.sum(np.square(action), axis=1)
|
||||||
|
reward = vel + alive_bonus - ctrl_cost
|
||||||
|
return np.minimum(np.maximum(-1000.0, reward), 1000.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -580,10 +580,14 @@ class DynamicTFPolicy(TFPolicy):
|
|||||||
# Add those needed for postprocessing and training.
|
# Add those needed for postprocessing and training.
|
||||||
all_accessed_keys = train_batch.accessed_keys | \
|
all_accessed_keys = train_batch.accessed_keys | \
|
||||||
batch_for_postproc.accessed_keys
|
batch_for_postproc.accessed_keys
|
||||||
# Tag those only needed for post-processing.
|
# Tag those only needed for post-processing (with some exceptions).
|
||||||
for key in batch_for_postproc.accessed_keys:
|
for key in batch_for_postproc.accessed_keys:
|
||||||
if key not in train_batch.accessed_keys and \
|
if key not in train_batch.accessed_keys and \
|
||||||
key not in self.model.view_requirements:
|
key not in self.model.view_requirements and \
|
||||||
|
key not in [
|
||||||
|
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
|
||||||
|
SampleBatch.UNROLL_ID, SampleBatch.DONES,
|
||||||
|
SampleBatch.REWARDS, SampleBatch.INFOS]:
|
||||||
if key in self.view_requirements:
|
if key in self.view_requirements:
|
||||||
self.view_requirements[key].used_for_training = False
|
self.view_requirements[key].used_for_training = False
|
||||||
if key in self._loss_input_dict:
|
if key in self._loss_input_dict:
|
||||||
|
|||||||
@@ -668,11 +668,16 @@ class Policy(metaclass=ABCMeta):
|
|||||||
if key not in self.view_requirements:
|
if key not in self.view_requirements:
|
||||||
self.view_requirements[key] = ViewRequirement()
|
self.view_requirements[key] = ViewRequirement()
|
||||||
if self._loss:
|
if self._loss:
|
||||||
# Tag those only needed for post-processing.
|
# Tag those only needed for post-processing (with some
|
||||||
|
# exceptions).
|
||||||
for key in batch_for_postproc.accessed_keys:
|
for key in batch_for_postproc.accessed_keys:
|
||||||
if key not in train_batch.accessed_keys and \
|
if key not in train_batch.accessed_keys and \
|
||||||
key in self.view_requirements and \
|
key in self.view_requirements and \
|
||||||
key not in self.model.view_requirements:
|
key not in self.model.view_requirements and \
|
||||||
|
key not in [
|
||||||
|
SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
|
||||||
|
SampleBatch.UNROLL_ID, SampleBatch.DONES,
|
||||||
|
SampleBatch.REWARDS, SampleBatch.INFOS]:
|
||||||
self.view_requirements[key].used_for_training = False
|
self.view_requirements[key].used_for_training = False
|
||||||
# Remove those not needed at all (leave those that are needed
|
# Remove those not needed at all (leave those that are needed
|
||||||
# by Sampler to properly execute sample collection).
|
# by Sampler to properly execute sample collection).
|
||||||
|
|||||||
Reference in New Issue
Block a user