mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 15:40:09 +08:00
[rllib] Fix DQN checkpoint/restore and enable test in jenkins (#1063)
* fix dqn restore and add test * Update .gitignore * Update test_checkpoint_restore.py * add checkpoint restore
This commit is contained in:
@@ -224,7 +224,8 @@ class Actor(object):
|
||||
self.episode_rewards,
|
||||
self.episode_lengths,
|
||||
self.saved_mean_reward,
|
||||
self.obs]
|
||||
self.obs,
|
||||
self.replay_buffer]
|
||||
|
||||
def restore(self, data):
|
||||
self.beta_schedule = data[0]
|
||||
@@ -233,6 +234,7 @@ class Actor(object):
|
||||
self.episode_lengths = data[3]
|
||||
self.saved_mean_reward = data[4]
|
||||
self.obs = data[5]
|
||||
self.replay_buffer = data[6]
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -367,7 +369,7 @@ class DQNAgent(Agent):
|
||||
global_step=self.num_iterations)
|
||||
extra_data = [
|
||||
self.actor.save(),
|
||||
self.replay_buffer,
|
||||
ray.get([w.save.remote() for w in self.workers]),
|
||||
self.cur_timestep,
|
||||
self.num_iterations,
|
||||
self.num_target_updates,
|
||||
@@ -376,10 +378,12 @@ class DQNAgent(Agent):
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
self.saver.restore(self.sess, checkpoint_path)
|
||||
self.saver.restore(self.actor.sess, checkpoint_path)
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.actor.restore(extra_data[0])
|
||||
self.replay_buffer = extra_data[1]
|
||||
ray.get([
|
||||
w.restore.remote(d) for (d, w)
|
||||
in zip(extra_data[1], self.workers)])
|
||||
self.cur_timestep = extra_data[2]
|
||||
self.num_iterations = extra_data[3]
|
||||
self.num_target_updates = extra_data[4]
|
||||
|
||||
@@ -11,7 +11,6 @@ import random
|
||||
from ray.rllib.dqn import (DQNAgent, DEFAULT_CONFIG as DQN_CONFIG)
|
||||
from ray.rllib.ppo import (PPOAgent, DEFAULT_CONFIG as PG_CONFIG)
|
||||
from ray.rllib.a3c import (A3CAgent, DEFAULT_CONFIG as A3C_CONFIG)
|
||||
|
||||
# from ray.rllib.es import (ESAgent, DEFAULT_CONFIG as ES_CONFIG)
|
||||
|
||||
|
||||
@@ -26,11 +25,13 @@ ray.init()
|
||||
for (cls, default_config) in [
|
||||
(DQNAgent, DQN_CONFIG),
|
||||
(PPOAgent, PG_CONFIG),
|
||||
# TODO(ekl) this fails with multiple ES instances in a process
|
||||
(A3CAgent, A3C_CONFIG),
|
||||
# https://github.com/ray-project/ray/issues/1062
|
||||
# (ESAgent, ES_CONFIG),
|
||||
(A3CAgent, A3C_CONFIG)]:
|
||||
]:
|
||||
config = default_config.copy()
|
||||
config["num_sgd_iter"] = 5
|
||||
config["use_lstm"] = False # for a3c
|
||||
config["episodes_per_batch"] = 100
|
||||
config["timesteps_per_batch"] = 1000
|
||||
alg1 = cls("CartPole-v0", config)
|
||||
@@ -49,4 +50,4 @@ for (cls, default_config) in [
|
||||
a1 = get_mean_action(alg1, obs)
|
||||
a2 = get_mean_action(alg2, obs)
|
||||
print("Checking computed actions", alg1, obs, a1, a2)
|
||||
assert(abs(a1-a2) < .05)
|
||||
assert abs(a1-a2) < .1, (a1, a2)
|
||||
|
||||
Reference in New Issue
Block a user