[rllib] Fix issues with PPO model restoration (#1018)

* fix filter

* add test

* lint

* fix

* commit

* Update a3c.py
This commit is contained in:
Eric Liang
2017-09-28 13:12:06 -07:00
committed by Philipp Moritz
parent 427dee511b
commit 19562f6ce5
8 changed files with 137 additions and 26 deletions
+2 -2
View File
@@ -156,5 +156,5 @@ class A3CAgent(Agent):
self.policy.set_weights(self.parameters)
def compute_action(self, observation):
actions = self.policy.compute_actions(observation)[0]
return actions.argmax()
actions = self.policy.compute_actions(observation)
return actions[0]
+33 -1
View File
@@ -12,6 +12,12 @@ class NoFilter(object):
def __call__(self, x, update=True):
return np.asarray(x)
def update(self, other):
pass
def copy(self):
return self
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
@@ -21,6 +27,13 @@ class RunningStat(object):
self._M = np.zeros(shape)
self._S = np.zeros(shape)
def copy(self):
other = RunningStat()
other._n = self._n
other._M = np.copy(self._M)
other._S = np.copy(self._S)
return other
def push(self, x):
x = np.asarray(x)
# Unvectorized update of the running statistics.
@@ -47,6 +60,10 @@ class RunningStat(object):
self._M = M
self._S = S
def __repr__(self):
return '(n={}, mean_mean={}, mean_std={})'.format(
self.n, np.mean(self.mean), np.mean(self.std))
@property
def n(self):
return self._n
@@ -70,12 +87,23 @@ class RunningStat(object):
class MeanStdFilter(object):
def __init__(self, shape, demean=True, destd=True, clip=10.0):
self.shape = shape
self.demean = demean
self.destd = destd
self.clip = clip
self.rs = RunningStat(shape)
def update(self, other):
self.rs.update(other.rs)
def copy(self):
other = MeanStdFilter(self.shape)
other.demean = self.demean
other.destd = self.destd
other.clip = self.clip
other.rs = self.rs.copy()
return other
def __call__(self, x, update=True):
x = np.asarray(x)
if update:
@@ -94,6 +122,10 @@ class MeanStdFilter(object):
x = np.clip(x, -self.clip, self.clip)
return x
def __repr__(self):
return 'MeanStdFilter({}, {}, {}, {}, {})'.format(
self.shape, self.demean, self.destd, self.clip, self.rs)
def test_running_stat():
for shp in ((), (3,), (3, 4)):
+3 -2
View File
@@ -116,7 +116,8 @@ class PPOAgent(Agent):
weights = ray.put(model.get_weights())
[a.load_weights.remote(weights) for a in agents]
trajectory, total_reward, traj_len_mean = collect_samples(
agents, config)
agents, config, self.model.observation_filter,
self.model.reward_filter)
print("total reward is ", total_reward)
print("trajectory length mean is ", traj_len_mean)
print("timesteps:", trajectory["dones"].shape[0])
@@ -269,5 +270,5 @@ class PPOAgent(Agent):
for (a, o) in zip(self.agents, extra_data[3])])
def compute_action(self, observation):
observation = self.model.observation_filter(observation)
observation = self.model.observation_filter(observation, update=False)
return self.model.common_policy.compute([observation])[0][0]
+10 -8
View File
@@ -5,12 +5,10 @@ from __future__ import print_function
import numpy as np
import ray
from ray.rllib.ppo.filter import NoFilter
from ray.rllib.ppo.utils import concatenate
def rollouts(policy, env, horizon, observation_filter=NoFilter(),
reward_filter=NoFilter()):
def rollouts(policy, env, horizon, observation_filter, reward_filter):
"""Perform a batch of rollouts of a policy in an environment.
Args:
@@ -98,8 +96,8 @@ def add_advantage_values(trajectory, gamma, lam, reward_filter):
def collect_samples(agents,
config,
observation_filter=NoFilter(),
reward_filter=NoFilter()):
observation_filter,
reward_filter):
num_timesteps_so_far = 0
trajectories = []
total_rewards = []
@@ -109,7 +107,8 @@ def collect_samples(agents,
# tasks here.
agent_dict = {agent.compute_steps.remote(
config["gamma"], config["lambda"],
config["horizon"], config["min_steps_per_task"]):
config["horizon"], config["min_steps_per_task"],
observation_filter, reward_filter):
agent for agent in agents}
while num_timesteps_so_far < config["timesteps_per_batch"]:
# TODO(pcm): Make wait support arbitrary iterators and remove the
@@ -120,12 +119,15 @@ def collect_samples(agents,
# Start task with next trajectory and record it in the dictionary.
agent_dict[agent.compute_steps.remote(
config["gamma"], config["lambda"],
config["horizon"], config["min_steps_per_task"])] = (
config["horizon"], config["min_steps_per_task"],
observation_filter, reward_filter)] = (
agent)
trajectory, rewards, lengths = ray.get(next_trajectory)
trajectory, rewards, lengths, obs_f, rew_f = ray.get(next_trajectory)
total_rewards.extend(rewards)
trajectory_lengths.extend(lengths)
num_timesteps_so_far += len(trajectory["dones"])
trajectories.append(trajectory)
observation_filter.update(obs_f)
reward_filter.update(rew_f)
return (concatenate(trajectories), np.mean(total_rewards),
np.mean(trajectory_lengths))
+17 -2
View File
@@ -210,7 +210,9 @@ class Runner(object):
add_return_values(trajectory, gamma, self.reward_filter)
return trajectory
def compute_steps(self, gamma, lam, horizon, min_steps_per_task=-1):
def compute_steps(
self, gamma, lam, horizon, min_steps_per_task,
observation_filter, reward_filter):
"""Compute multiple rollouts and concatenate the results.
Args:
@@ -219,12 +221,20 @@ class Runner(object):
horizon: Number of steps after which a rollout gets cut
min_steps_per_task: Lower bound on the number of states to be
collected.
observation_filter: Function that is applied to each of the
observations.
reward_filter: Function that is applied to each of the rewards.
Returns:
states: List of states.
total_rewards: Total rewards of the trajectories.
trajectory_lengths: Lengths of the trajectories.
"""
# Update our local filters
self.observation_filter = observation_filter.copy()
self.reward_filter = reward_filter.copy()
num_steps_so_far = 0
trajectories = []
total_rewards = []
@@ -247,7 +257,12 @@ class Runner(object):
trajectories.append(trajectory)
if num_steps_so_far >= min_steps_per_task:
break
return concatenate(trajectories), total_rewards, trajectory_lengths
return (
concatenate(trajectories),
total_rewards,
trajectory_lengths,
self.observation_filter,
self.reward_filter)
RemoteRunner = ray.remote(Runner)
@@ -4,27 +4,37 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import ray
import random
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
def get_mean_action(alg, obs):
out = []
for _ in range(2000):
out.append(float(alg.compute_action(obs)))
return np.mean(out)
ray.init()
for (cls, default_config) in [
(DQNAgent, DQN_CONFIG),
# TODO(ekl) this fails with multiple ES instances in a process
(ESAgent, ES_CONFIG),
(PPOAgent, PG_CONFIG),
# TODO(ekl) this fails with multiple ES instances in a process
# (ESAgent, ES_CONFIG),
(A3CAgent, A3C_CONFIG)]:
config = default_config.copy()
config["num_sgd_iter"] = 5
config["episodes_per_batch"] = 100
config["timesteps_per_batch"] = 1000
alg1 = cls('CartPole-v0', config)
alg2 = cls('CartPole-v0', config)
alg1 = cls("CartPole-v0", config)
alg2 = cls("CartPole-v0", config)
for _ in range(3):
res = alg1.train()
@@ -36,9 +46,7 @@ for (cls, default_config) in [
for _ in range(10):
obs = [
random.random(), random.random(), random.random(), random.random()]
a1 = alg1.compute_action(obs)
a2 = alg2.compute_action(obs)
print("Checking computed actions", obs, a1, a2)
# TODO(ekl) this fails for stochastic policies
assert(a1 == a2)
a1 = get_mean_action(alg1, obs)
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
assert(abs(a1-a2) < .05)
+52
View File
@@ -0,0 +1,52 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import gym
import ray
from ray.rllib.ppo import PPOAgent, DEFAULT_CONFIG
config = DEFAULT_CONFIG.copy()
config["num_workers"] = 3
config["num_sgd_iter"] = 6
config["sgd_batchsize"] = 128
config["timesteps_per_batch"] = 4000
ray.init()
# first train one agent
agent = PPOAgent("CartPole-v0", config)
for i in range(10):
result = agent.train()
print(result)
# checkpoint and restore in a copied agent
checkpoint_path = agent.save()
trained_config = config.copy()
test_agent = PPOAgent("CartPole-v0", trained_config)
test_agent.restore(checkpoint_path)
# evaluate on copied agent
results = []
env = gym.make("CartPole-v0")
for _ in range(20):
state = env.reset()
done = False
cumulative_reward = 0
while not done:
action = test_agent.compute_action(state)
state, reward, done, _ = env.step(action)
cumulative_reward += reward
results.append(cumulative_reward)
print("All results", results)
print("Mean result", np.mean(results))
assert(np.mean(results)) > 0.9 * result.episode_reward_mean