mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 19:32:38 +08:00
[rllib] Fix issues with PPO model restoration (#1018)
* fix filter
* add test
* lint
* fix
* commit
* Update a3c.py
This commit is contained in:
committed by
Philipp Moritz
parent
427dee511b
commit
19562f6ce5
@@ -156,5 +156,5 @@ class A3CAgent(Agent):
|
||||
self.policy.set_weights(self.parameters)
|
||||
|
||||
def compute_action(self, observation):
|
||||
actions = self.policy.compute_actions(observation)[0]
|
||||
return actions.argmax()
|
||||
actions = self.policy.compute_actions(observation)
|
||||
return actions[0]
|
||||
|
||||
@@ -12,6 +12,12 @@ class NoFilter(object):
|
||||
def __call__(self, x, update=True):
|
||||
return np.asarray(x)
|
||||
|
||||
def update(self, other):
    """Accept another filter's state and ignore it.

    This filter keeps no statistics, so there is nothing to merge; the
    method exists so callers can treat every filter uniformly.
    """
    return None
|
||||
|
||||
def copy(self):
    """Return this very object as its own "copy".

    No new instance is created: the caller receives the same object it
    called the method on.
    """
    same_filter = self
    return same_filter
|
||||
|
||||
|
||||
# http://www.johndcook.com/blog/standard_deviation/
|
||||
class RunningStat(object):
|
||||
@@ -21,6 +27,13 @@ class RunningStat(object):
|
||||
self._M = np.zeros(shape)
|
||||
self._S = np.zeros(shape)
|
||||
|
||||
def copy(self):
    """Return an independent duplicate of these running statistics.

    The sample count is carried over as-is, while the `_M` and `_S`
    accumulators are duplicated with `np.copy` so that later changes to
    one object's arrays do not show up in the other.
    """
    clone = RunningStat()
    clone._n = self._n
    clone._M, clone._S = np.copy(self._M), np.copy(self._S)
    return clone
|
||||
|
||||
def push(self, x):
|
||||
x = np.asarray(x)
|
||||
# Unvectorized update of the running statistics.
|
||||
@@ -47,6 +60,10 @@ class RunningStat(object):
|
||||
self._M = M
|
||||
self._S = S
|
||||
|
||||
def __repr__(self):
    """Summarise the statistics as a short, human-readable string.

    Shows the sample count plus the `np.mean` of the `mean` and `std`
    properties, so array-valued statistics collapse to single numbers.
    """
    count = self.n
    avg_of_mean = np.mean(self.mean)
    avg_of_std = np.mean(self.std)
    template = '(n={}, mean_mean={}, mean_std={})'
    return template.format(count, avg_of_mean, avg_of_std)
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return self._n
|
||||
@@ -70,12 +87,23 @@ class RunningStat(object):
|
||||
|
||||
class MeanStdFilter(object):
|
||||
def __init__(self, shape, demean=True, destd=True, clip=10.0):
    """Set up the filter's settings and a fresh `RunningStat`.

    Args:
        shape: Shape handed to `RunningStat`; also stored so that
            `copy()` can build another filter of the same shape.
        demean: Stored as `self.demean` (presumably toggles mean
            subtraction in `__call__` -- confirm there).
        destd: Stored as `self.destd` (presumably toggles division by
            the standard deviation in `__call__` -- confirm there).
        clip: Stored as `self.clip`.
    """
    # `shape` is recorded first so the attribute order stays unchanged.
    self.shape = shape
    self.demean, self.destd, self.clip = demean, destd, clip

    # Running statistics backing this filter.
    self.rs = RunningStat(shape)
|
||||
|
||||
def update(self, other):
    """Fold another filter's running statistics into this one.

    Delegates entirely to `self.rs.update`, passing `other.rs`.
    """
    incoming_stats = other.rs
    self.rs.update(incoming_stats)
|
||||
|
||||
def copy(self):
    """Return a new `MeanStdFilter` with the same settings and state.

    The clone is built from `self.shape`, then receives this filter's
    `demean`, `destd` and `clip` values and a copy of its running
    statistics (via `self.rs.copy()`).
    """
    clone = MeanStdFilter(self.shape)
    # Carry the settings over in the same order as before.
    for setting in ("demean", "destd", "clip"):
        setattr(clone, setting, getattr(self, setting))
    clone.rs = self.rs.copy()
    return clone
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
x = np.asarray(x)
|
||||
if update:
|
||||
@@ -94,6 +122,10 @@ class MeanStdFilter(object):
|
||||
x = np.clip(x, -self.clip, self.clip)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return 'MeanStdFilter({}, {}, {}, {}, {})'.format(
|
||||
self.shape, self.demean, self.destd, self.clip, self.rs)
|
||||
|
||||
|
||||
def test_running_stat():
|
||||
for shp in ((), (3,), (3, 4)):
|
||||
|
||||
@@ -116,7 +116,8 @@ class PPOAgent(Agent):
|
||||
weights = ray.put(model.get_weights())
|
||||
[a.load_weights.remote(weights) for a in agents]
|
||||
trajectory, total_reward, traj_len_mean = collect_samples(
|
||||
agents, config)
|
||||
agents, config, self.model.observation_filter,
|
||||
self.model.reward_filter)
|
||||
print("total reward is ", total_reward)
|
||||
print("trajectory length mean is ", traj_len_mean)
|
||||
print("timesteps:", trajectory["dones"].shape[0])
|
||||
@@ -269,5 +270,5 @@ class PPOAgent(Agent):
|
||||
for (a, o) in zip(self.agents, extra_data[3])])
|
||||
|
||||
def compute_action(self, observation):
|
||||
observation = self.model.observation_filter(observation)
|
||||
observation = self.model.observation_filter(observation, update=False)
|
||||
return self.model.common_policy.compute([observation])[0][0]
|
||||
|
||||
@@ -5,12 +5,10 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
from ray.rllib.ppo.filter import NoFilter
|
||||
from ray.rllib.ppo.utils import concatenate
|
||||
|
||||
|
||||
def rollouts(policy, env, horizon, observation_filter=NoFilter(),
|
||||
reward_filter=NoFilter()):
|
||||
def rollouts(policy, env, horizon, observation_filter, reward_filter):
|
||||
"""Perform a batch of rollouts of a policy in an environment.
|
||||
|
||||
Args:
|
||||
@@ -98,8 +96,8 @@ def add_advantage_values(trajectory, gamma, lam, reward_filter):
|
||||
|
||||
def collect_samples(agents,
|
||||
config,
|
||||
observation_filter=NoFilter(),
|
||||
reward_filter=NoFilter()):
|
||||
observation_filter,
|
||||
reward_filter):
|
||||
num_timesteps_so_far = 0
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
@@ -109,7 +107,8 @@ def collect_samples(agents,
|
||||
# tasks here.
|
||||
agent_dict = {agent.compute_steps.remote(
|
||||
config["gamma"], config["lambda"],
|
||||
config["horizon"], config["min_steps_per_task"]):
|
||||
config["horizon"], config["min_steps_per_task"],
|
||||
observation_filter, reward_filter):
|
||||
agent for agent in agents}
|
||||
while num_timesteps_so_far < config["timesteps_per_batch"]:
|
||||
# TODO(pcm): Make wait support arbitrary iterators and remove the
|
||||
@@ -120,12 +119,15 @@ def collect_samples(agents,
|
||||
# Start task with next trajectory and record it in the dictionary.
|
||||
agent_dict[agent.compute_steps.remote(
|
||||
config["gamma"], config["lambda"],
|
||||
config["horizon"], config["min_steps_per_task"])] = (
|
||||
config["horizon"], config["min_steps_per_task"],
|
||||
observation_filter, reward_filter)] = (
|
||||
agent)
|
||||
trajectory, rewards, lengths = ray.get(next_trajectory)
|
||||
trajectory, rewards, lengths, obs_f, rew_f = ray.get(next_trajectory)
|
||||
total_rewards.extend(rewards)
|
||||
trajectory_lengths.extend(lengths)
|
||||
num_timesteps_so_far += len(trajectory["dones"])
|
||||
trajectories.append(trajectory)
|
||||
observation_filter.update(obs_f)
|
||||
reward_filter.update(rew_f)
|
||||
return (concatenate(trajectories), np.mean(total_rewards),
|
||||
np.mean(trajectory_lengths))
|
||||
|
||||
@@ -210,7 +210,9 @@ class Runner(object):
|
||||
add_return_values(trajectory, gamma, self.reward_filter)
|
||||
return trajectory
|
||||
|
||||
def compute_steps(self, gamma, lam, horizon, min_steps_per_task=-1):
|
||||
def compute_steps(
|
||||
self, gamma, lam, horizon, min_steps_per_task,
|
||||
observation_filter, reward_filter):
|
||||
"""Compute multiple rollouts and concatenate the results.
|
||||
|
||||
Args:
|
||||
@@ -219,12 +221,20 @@ class Runner(object):
|
||||
horizon: Number of steps after which a rollout gets cut
|
||||
min_steps_per_task: Lower bound on the number of states to be
|
||||
collected.
|
||||
observation_filter: Function that is applied to each of the
|
||||
observations.
|
||||
reward_filter: Function that is applied to each of the rewards.
|
||||
|
||||
Returns:
|
||||
states: List of states.
|
||||
total_rewards: Total rewards of the trajectories.
|
||||
trajectory_lengths: Lengths of the trajectories.
|
||||
"""
|
||||
|
||||
# Update our local filters
|
||||
self.observation_filter = observation_filter.copy()
|
||||
self.reward_filter = reward_filter.copy()
|
||||
|
||||
num_steps_so_far = 0
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
@@ -247,7 +257,12 @@ class Runner(object):
|
||||
trajectories.append(trajectory)
|
||||
if num_steps_so_far >= min_steps_per_task:
|
||||
break
|
||||
return concatenate(trajectories), total_rewards, trajectory_lengths
|
||||
return (
|
||||
concatenate(trajectories),
|
||||
total_rewards,
|
||||
trajectory_lengths,
|
||||
self.observation_filter,
|
||||
self.reward_filter)
|
||||
|
||||
|
||||
RemoteRunner = ray.remote(Runner)
|
||||
|
||||
@@ -4,27 +4,37 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
import random
|
||||
|
||||
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
||||
from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
||||
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
||||
|
||||
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||
|
||||
|
||||
def get_mean_action(alg, obs, num_samples=2000):
    """Estimate the average action `alg` chooses for a fixed observation.

    `alg.compute_action(obs)` is called `num_samples` times, each result
    is converted to `float`, and the mean is returned. Averaging makes
    the value comparable between agents whose policies are stochastic.

    Args:
        alg: Object exposing `compute_action(obs)`.
        obs: Observation passed unchanged to every call.
        num_samples: Number of actions to draw (default 2000, the value
            that used to be hard-coded).

    Returns:
        The mean of the sampled actions, as computed by `np.mean`.

    Raises:
        ValueError: If `num_samples` is smaller than 1; the mean of zero
            samples would be NaN.
    """
    if num_samples < 1:
        raise ValueError(
            "num_samples must be at least 1, got {}".format(num_samples))
    samples = [float(alg.compute_action(obs)) for _ in range(num_samples)]
    return np.mean(samples)
|
||||
|
||||
|
||||
ray.init()
|
||||
for (cls, default_config) in [
|
||||
(DQNAgent, DQN_CONFIG),
|
||||
# TODO(ekl) this fails with multiple ES instances in a process
|
||||
(ESAgent, ES_CONFIG),
|
||||
(PPOAgent, PG_CONFIG),
|
||||
# TODO(ekl) this fails with multiple ES instances in a process
|
||||
# (ESAgent, ES_CONFIG),
|
||||
(A3CAgent, A3C_CONFIG)]:
|
||||
config = default_config.copy()
|
||||
config["num_sgd_iter"] = 5
|
||||
config["episodes_per_batch"] = 100
|
||||
config["timesteps_per_batch"] = 1000
|
||||
alg1 = cls('CartPole-v0', config)
|
||||
alg2 = cls('CartPole-v0', config)
|
||||
alg1 = cls("CartPole-v0", config)
|
||||
alg2 = cls("CartPole-v0", config)
|
||||
|
||||
for _ in range(3):
|
||||
res = alg1.train()
|
||||
@@ -36,9 +46,7 @@ for (cls, default_config) in [
|
||||
for _ in range(10):
|
||||
obs = [
|
||||
random.random(), random.random(), random.random(), random.random()]
|
||||
a1 = alg1.compute_action(obs)
|
||||
a2 = alg2.compute_action(obs)
|
||||
print("Checking computed actions", obs, a1, a2)
|
||||
|
||||
# TODO(ekl) this fails for stochastic policies
|
||||
assert(a1 == a2)
|
||||
a1 = get_mean_action(alg1, obs)
|
||||
a2 = get_mean_action(alg2, obs)
|
||||
print("Checking computed actions", alg1, obs, a1, a2)
|
||||
assert(abs(a1-a2) < .05)
|
||||
|
||||
Executable
+52
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.ppo import PPOAgent, DEFAULT_CONFIG
|
||||
|
||||
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = 3
|
||||
config["num_sgd_iter"] = 6
|
||||
config["sgd_batchsize"] = 128
|
||||
config["timesteps_per_batch"] = 4000
|
||||
|
||||
ray.init()
|
||||
|
||||
# first train one agent
|
||||
agent = PPOAgent("CartPole-v0", config)
|
||||
|
||||
for i in range(10):
|
||||
result = agent.train()
|
||||
print(result)
|
||||
|
||||
# checkpoint and restore in a copied agent
|
||||
checkpoint_path = agent.save()
|
||||
trained_config = config.copy()
|
||||
test_agent = PPOAgent("CartPole-v0", trained_config)
|
||||
test_agent.restore(checkpoint_path)
|
||||
|
||||
# evaluate on copied agent
|
||||
results = []
|
||||
env = gym.make("CartPole-v0")
|
||||
for _ in range(20):
|
||||
state = env.reset()
|
||||
done = False
|
||||
cumulative_reward = 0
|
||||
|
||||
while not done:
|
||||
action = test_agent.compute_action(state)
|
||||
state, reward, done, _ = env.step(action)
|
||||
cumulative_reward += reward
|
||||
|
||||
results.append(cumulative_reward)
|
||||
|
||||
print("All results", results)
|
||||
print("Mean result", np.mean(results))
|
||||
|
||||
assert(np.mean(results)) > 0.9 * result.episode_reward_mean
|
||||
Reference in New Issue
Block a user