mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[rllib] Refactor save() / restore() code of agents and avoid O(n_workers) save size (#2982)
This commit is contained in:
@@ -2,11 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
@@ -108,28 +105,3 @@ class A3CAgent(Agent):
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
return result
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
agent_state = ray.get(
|
||||
[a.save.remote() for a in self.remote_evaluators])
|
||||
extra_data = {
|
||||
"remote_state": agent_state,
|
||||
"local_state": self.local_evaluator.save()
|
||||
}
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
ray.get([
|
||||
a.restore.remote(o)
|
||||
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
|
||||
])
|
||||
self.local_evaluator.restore(extra_data["local_state"])
|
||||
|
||||
@@ -4,13 +4,13 @@ from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
@@ -324,95 +324,39 @@ class Agent(Trainable):
|
||||
"""
|
||||
self.local_evaluator.set_weights(weights)
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
if hasattr(self, "remote_evaluators"):
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
def __getstate__(self):
|
||||
state = {}
|
||||
if hasattr(self, "local_evaluator"):
|
||||
state["evaluator"] = self.local_evaluator.save()
|
||||
if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"):
|
||||
state["optimizer"] = self.optimizer.save()
|
||||
return state
|
||||
|
||||
_agent_name = "MockAgent"
|
||||
_default_config = {
|
||||
"mock_error": False,
|
||||
"persistent_error": False,
|
||||
"test_variable": 1
|
||||
}
|
||||
|
||||
def _init(self):
|
||||
self.info = None
|
||||
self.restored = False
|
||||
|
||||
def _train(self):
|
||||
if self.config["mock_error"] and self.iteration == 1 \
|
||||
and (self.config["persistent_error"] or not self.restored):
|
||||
raise Exception("mock error")
|
||||
return dict(
|
||||
episode_reward_mean=10,
|
||||
episode_len_mean=10,
|
||||
timesteps_this_iter=10,
|
||||
info={})
|
||||
def __setstate__(self, state):
|
||||
if "evaluator" in state:
|
||||
self.local_evaluator.restore(state["evaluator"])
|
||||
remote_state = ray.put(state["evaluator"])
|
||||
for r in self.remote_evaluators:
|
||||
r.restore.remote(remote_state)
|
||||
if "optimizer" in state:
|
||||
self.optimizer.restore(state["optimizer"])
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self.info, f)
|
||||
return path
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
pickle.dump(self.__getstate__(),
|
||||
open(checkpoint_path + ".agent_state", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, 'rb') as f:
|
||||
info = pickle.load(f)
|
||||
self.info = info
|
||||
self.restored = True
|
||||
|
||||
def set_info(self, info):
|
||||
self.info = info
|
||||
return info
|
||||
|
||||
def get_info(self):
|
||||
return self.info
|
||||
|
||||
|
||||
class _SigmoidFakeData(_MockAgent):
|
||||
"""Agent that returns sigmoid learning curves.
|
||||
|
||||
This can be helpful for evaluating early stopping algorithms."""
|
||||
|
||||
_agent_name = "SigmoidFakeData"
|
||||
_default_config = {
|
||||
"width": 100,
|
||||
"height": 100,
|
||||
"offset": 0,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1,
|
||||
}
|
||||
|
||||
def _train(self):
|
||||
i = max(0, self.iteration - self.config["offset"])
|
||||
v = np.tanh(float(i) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
return dict(
|
||||
episode_reward_mean=v,
|
||||
episode_len_mean=v,
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
|
||||
|
||||
class _ParameterTuningAgent(_MockAgent):
|
||||
|
||||
_agent_name = "ParameterTuningAgent"
|
||||
_default_config = {
|
||||
"reward_amt": 10,
|
||||
"dummy_param": 10,
|
||||
"dummy_param2": 15,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1
|
||||
}
|
||||
|
||||
def _train(self):
|
||||
return dict(
|
||||
episode_reward_mean=self.config["reward_amt"] * self.iteration,
|
||||
episode_len_mean=self.config["reward_amt"],
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb"))
|
||||
self.__setstate__(extra_data)
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
@@ -458,10 +402,13 @@ def get_agent_class(alg):
|
||||
from ray.tune import script_runner
|
||||
return script_runner.ScriptRunner
|
||||
elif alg == "__fake":
|
||||
from ray.rllib.agents.mock import _MockAgent
|
||||
return _MockAgent
|
||||
elif alg == "__sigmoid_fake_data":
|
||||
from ray.rllib.agents.mock import _SigmoidFakeData
|
||||
return _SigmoidFakeData
|
||||
elif alg == "__parameter_tuning":
|
||||
from ray.rllib.agents.mock import _ParameterTuningAgent
|
||||
return _ParameterTuningAgent
|
||||
else:
|
||||
raise Exception(("Unknown algorithm {}.").format(alg))
|
||||
|
||||
@@ -8,8 +8,6 @@ from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import ray
|
||||
@@ -201,7 +199,6 @@ class ARSAgent(Agent):
|
||||
]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
self.timesteps_so_far = 0
|
||||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@@ -228,7 +225,6 @@ class ARSAgent(Agent):
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
||||
step_tstart = time.time()
|
||||
theta = self.policy.get_weights()
|
||||
assert theta.dtype == np.float32
|
||||
|
||||
@@ -259,7 +255,6 @@ class ARSAgent(Agent):
|
||||
len(all_training_lengths))
|
||||
|
||||
self.episodes_so_far += num_episodes
|
||||
self.timesteps_so_far += num_timesteps
|
||||
|
||||
# Assemble the results.
|
||||
eval_returns = np.array(all_eval_returns)
|
||||
@@ -301,8 +296,6 @@ class ARSAgent(Agent):
|
||||
if len(all_eval_returns) > 0:
|
||||
self.reward_list.append(eval_returns.mean())
|
||||
|
||||
step_tend = time.time()
|
||||
|
||||
tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean())
|
||||
tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std())
|
||||
tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean())
|
||||
@@ -319,9 +312,6 @@ class ARSAgent(Agent):
|
||||
"update_ratio": update_ratio,
|
||||
"episodes_this_iter": noisy_lengths.size,
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
"timesteps_so_far": self.timesteps_so_far,
|
||||
"time_elapsed_this_iter": step_tend - step_tstart,
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
result = dict(
|
||||
episode_reward_mean=np.mean(
|
||||
@@ -337,19 +327,15 @@ class ARSAgent(Agent):
|
||||
for w in self.workers:
|
||||
w.__ray_terminate__.remote()
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
weights = self.policy.get_weights()
|
||||
objects = [weights, self.episodes_so_far, self.timesteps_so_far]
|
||||
pickle.dump(objects, open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
def __getstate__(self):
|
||||
return {
|
||||
"weights": self.policy.get_weights(),
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
objects = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.policy.set_weights(objects[0])
|
||||
self.episodes_so_far = objects[1]
|
||||
self.timesteps_so_far = objects[2]
|
||||
def __setstate__(self, state):
|
||||
self.policy.set_weights(state["weights"])
|
||||
self.episodes_so_far = state["episodes_so_far"]
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.policy.compute(observation, update=True)[0]
|
||||
|
||||
@@ -2,11 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib import optimizers
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
@@ -249,30 +246,15 @@ class DQNAgent(Agent):
|
||||
}, **self.optimizer.stats()))
|
||||
return result
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
def __getstate__(self):
|
||||
state = Agent.__getstate__(self)
|
||||
state.update({
|
||||
"num_target_updates": self.num_target_updates,
|
||||
"last_target_update_ts": self.last_target_update_ts,
|
||||
})
|
||||
return state
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
extra_data = [
|
||||
self.local_evaluator.save(),
|
||||
ray.get([e.save.remote() for e in self.remote_evaluators]),
|
||||
self.optimizer.save(), self.num_target_updates,
|
||||
self.last_target_update_ts
|
||||
]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.local_evaluator.restore(extra_data[0])
|
||||
ray.get([
|
||||
e.restore.remote(d)
|
||||
for (d, e) in zip(extra_data[1], self.remote_evaluators)
|
||||
])
|
||||
self.optimizer.restore(extra_data[2])
|
||||
self.num_target_updates = extra_data[3]
|
||||
self.last_target_update_ts = extra_data[4]
|
||||
def __setstate__(self, state):
|
||||
Agent.__setstate__(self, state)
|
||||
self.num_target_updates = state["num_target_updates"]
|
||||
self.last_target_update_ts = state["last_target_update_ts"]
|
||||
|
||||
@@ -7,8 +7,6 @@ from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
|
||||
import ray
|
||||
@@ -180,7 +178,6 @@ class ESAgent(Agent):
|
||||
]
|
||||
|
||||
self.episodes_so_far = 0
|
||||
self.timesteps_so_far = 0
|
||||
self.reward_list = []
|
||||
self.tstart = time.time()
|
||||
|
||||
@@ -207,7 +204,6 @@ class ESAgent(Agent):
|
||||
def _train(self):
|
||||
config = self.config
|
||||
|
||||
step_tstart = time.time()
|
||||
theta = self.policy.get_weights()
|
||||
assert theta.dtype == np.float32
|
||||
|
||||
@@ -238,7 +234,6 @@ class ESAgent(Agent):
|
||||
len(all_training_lengths))
|
||||
|
||||
self.episodes_so_far += num_episodes
|
||||
self.timesteps_so_far += num_timesteps
|
||||
|
||||
# Assemble the results.
|
||||
eval_returns = np.array(all_eval_returns)
|
||||
@@ -271,7 +266,6 @@ class ESAgent(Agent):
|
||||
if len(all_eval_returns) > 0:
|
||||
self.reward_list.append(np.mean(eval_returns))
|
||||
|
||||
step_tend = time.time()
|
||||
tlogger.record_tabular("EvalEpRewStd", eval_returns.std())
|
||||
tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean())
|
||||
|
||||
@@ -285,11 +279,6 @@ class ESAgent(Agent):
|
||||
|
||||
tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size)
|
||||
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
|
||||
tlogger.record_tabular("TimestepsThisIter", noisy_lengths.sum())
|
||||
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
|
||||
|
||||
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
|
||||
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
|
||||
tlogger.dump_tabular()
|
||||
|
||||
info = {
|
||||
@@ -298,10 +287,6 @@ class ESAgent(Agent):
|
||||
"update_ratio": update_ratio,
|
||||
"episodes_this_iter": noisy_lengths.size,
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
"timesteps_this_iter": noisy_lengths.sum(),
|
||||
"timesteps_so_far": self.timesteps_so_far,
|
||||
"time_elapsed_this_iter": step_tend - step_tstart,
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
|
||||
reward_mean = np.mean(self.reward_list[-self.report_length:])
|
||||
@@ -318,19 +303,15 @@ class ESAgent(Agent):
|
||||
for w in self.workers:
|
||||
w.__ray_terminate__.remote()
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
weights = self.policy.get_weights()
|
||||
objects = [weights, self.episodes_so_far, self.timesteps_so_far]
|
||||
pickle.dump(objects, open(checkpoint_path, "wb"))
|
||||
return checkpoint_path
|
||||
def __getstate__(self):
|
||||
return {
|
||||
"weights": self.policy.get_weights(),
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
}
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
objects = pickle.load(open(checkpoint_path, "rb"))
|
||||
self.policy.set_weights(objects[0])
|
||||
self.episodes_so_far = objects[1]
|
||||
self.timesteps_so_far = objects[2]
|
||||
def __setstate__(self, state):
|
||||
self.policy.set_weights(state["weights"])
|
||||
self.episodes_so_far = state["episodes_so_far"]
|
||||
|
||||
def compute_action(self, observation):
|
||||
return self.policy.compute(observation, update=False)[0]
|
||||
|
||||
@@ -2,11 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pickle
|
||||
import os
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
|
||||
from ray.rllib.agents.agent import Agent, with_common_config
|
||||
@@ -99,28 +96,3 @@ class ImpalaAgent(Agent):
|
||||
result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
|
||||
prev_steps)
|
||||
return result
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
agent_state = ray.get(
|
||||
[a.save.remote() for a in self.remote_evaluators])
|
||||
extra_data = {
|
||||
"remote_state": agent_state,
|
||||
"local_state": self.local_evaluator.save()
|
||||
}
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
ray.get([
|
||||
a.restore.remote(o)
|
||||
for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
|
||||
])
|
||||
self.local_evaluator.restore(extra_data["local_state"])
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.agents.agent import Agent
|
||||
|
||||
|
||||
class _MockAgent(Agent):
|
||||
"""Mock agent for use in tests"""
|
||||
|
||||
_agent_name = "MockAgent"
|
||||
_default_config = {
|
||||
"mock_error": False,
|
||||
"persistent_error": False,
|
||||
"test_variable": 1
|
||||
}
|
||||
|
||||
def _init(self):
|
||||
self.info = None
|
||||
self.restored = False
|
||||
|
||||
def _train(self):
|
||||
if self.config["mock_error"] and self.iteration == 1 \
|
||||
and (self.config["persistent_error"] or not self.restored):
|
||||
raise Exception("mock error")
|
||||
return dict(
|
||||
episode_reward_mean=10,
|
||||
episode_len_mean=10,
|
||||
timesteps_this_iter=10,
|
||||
info={})
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
path = os.path.join(checkpoint_dir, "mock_agent.pkl")
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self.info, f)
|
||||
return path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
with open(checkpoint_path, 'rb') as f:
|
||||
info = pickle.load(f)
|
||||
self.info = info
|
||||
self.restored = True
|
||||
|
||||
def set_info(self, info):
|
||||
self.info = info
|
||||
return info
|
||||
|
||||
def get_info(self):
|
||||
return self.info
|
||||
|
||||
|
||||
class _SigmoidFakeData(_MockAgent):
|
||||
"""Agent that returns sigmoid learning curves.
|
||||
|
||||
This can be helpful for evaluating early stopping algorithms."""
|
||||
|
||||
_agent_name = "SigmoidFakeData"
|
||||
_default_config = {
|
||||
"width": 100,
|
||||
"height": 100,
|
||||
"offset": 0,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1,
|
||||
}
|
||||
|
||||
def _train(self):
|
||||
i = max(0, self.iteration - self.config["offset"])
|
||||
v = np.tanh(float(i) / self.config["width"])
|
||||
v *= self.config["height"]
|
||||
return dict(
|
||||
episode_reward_mean=v,
|
||||
episode_len_mean=v,
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
|
||||
|
||||
class _ParameterTuningAgent(_MockAgent):
|
||||
|
||||
_agent_name = "ParameterTuningAgent"
|
||||
_default_config = {
|
||||
"reward_amt": 10,
|
||||
"dummy_param": 10,
|
||||
"dummy_param2": 15,
|
||||
"iter_time": 10,
|
||||
"iter_timesteps": 1
|
||||
}
|
||||
|
||||
def _train(self):
|
||||
return dict(
|
||||
episode_reward_mean=self.config["reward_amt"] * self.iteration,
|
||||
episode_len_mean=self.config["reward_amt"],
|
||||
timesteps_this_iter=self.config["iter_timesteps"],
|
||||
time_this_iter_s=self.config["iter_time"],
|
||||
info={})
|
||||
@@ -2,10 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents import Agent, with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
|
||||
from ray.rllib.utils import merge_dicts
|
||||
@@ -134,25 +130,3 @@ class PPOAgent(Agent):
|
||||
timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps,
|
||||
info=dict(fetches, **res.get("info", {})))
|
||||
return res
|
||||
|
||||
def _stop(self):
|
||||
# workaround for https://github.com/ray-project/ray/issues/1516
|
||||
for ev in self.remote_evaluators:
|
||||
ev.__ray_terminate__.remote()
|
||||
|
||||
def _save(self, checkpoint_dir):
|
||||
checkpoint_path = os.path.join(checkpoint_dir,
|
||||
"checkpoint-{}".format(self.iteration))
|
||||
agent_state = ray.get(
|
||||
[a.save.remote() for a in self.remote_evaluators])
|
||||
extra_data = [self.local_evaluator.save(), agent_state]
|
||||
pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
def _restore(self, checkpoint_path):
|
||||
extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
|
||||
self.local_evaluator.restore(extra_data[0])
|
||||
ray.get([
|
||||
a.restore.remote(o)
|
||||
for (a, o) in zip(self.remote_evaluators, extra_data[1])
|
||||
])
|
||||
|
||||
@@ -64,5 +64,5 @@ def summarize_episodes(episodes, new_episodes):
|
||||
episode_reward_min=min_reward,
|
||||
episode_reward_mean=avg_reward,
|
||||
episode_len_mean=avg_length,
|
||||
episodes=len(new_episodes),
|
||||
episodes_this_iter=len(new_episodes),
|
||||
policy_reward_mean=dict(policy_rewards))
|
||||
|
||||
@@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
ev.sample()
|
||||
ray.get(remote_ev.sample.remote())
|
||||
result = collect_metrics(ev, [remote_ev])
|
||||
self.assertEqual(result["episodes"], 20)
|
||||
self.assertEqual(result["episodes_this_iter"], 20)
|
||||
self.assertEqual(result["episode_reward_mean"], 10)
|
||||
|
||||
def testAsync(self):
|
||||
@@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
self.assertEqual(result["episodes_this_iter"], 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 8)
|
||||
self.assertEqual(result["episodes_this_iter"], 8)
|
||||
indices = []
|
||||
for env in ev.async_env.vector_env.envs:
|
||||
self.assertEqual(env.unwrapped.config.worker_index, 0)
|
||||
@@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 16)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
self.assertEqual(result["episodes_this_iter"], 0)
|
||||
batch = ev.sample()
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 4)
|
||||
self.assertEqual(result["episodes_this_iter"], 4)
|
||||
|
||||
def testVectorEnvSupport(self):
|
||||
ev = PolicyEvaluator(
|
||||
@@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 0)
|
||||
self.assertEqual(result["episodes_this_iter"], 0)
|
||||
for _ in range(8):
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 10)
|
||||
result = collect_metrics(ev, [])
|
||||
self.assertEqual(result["episodes"], 8)
|
||||
self.assertEqual(result["episodes_this_iter"], 8)
|
||||
|
||||
def testTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
|
||||
@@ -16,6 +16,12 @@ NODE_IP = "node_ip"
|
||||
# (Auto-filled) The pid of the training process.
|
||||
PID = "pid"
|
||||
|
||||
# Number of timesteps in this iteration.
|
||||
EPISODES_THIS_ITER = "episodes_this_iter"
|
||||
|
||||
# (Optional/Auto-filled) Accumulated time in seconds for this experiment.
|
||||
EPISODES_TOTAL = "episodes_total"
|
||||
|
||||
# Number of timesteps in this iteration.
|
||||
TIMESTEPS_THIS_ITER = "timesteps_this_iter"
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ import uuid
|
||||
import ray
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S,
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL)
|
||||
TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL,
|
||||
EPISODES_THIS_ITER, EPISODES_TOTAL)
|
||||
from ray.tune.trial import Resources
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -77,6 +78,7 @@ class Trainable(object):
|
||||
self._iteration = 0
|
||||
self._time_total = 0.0
|
||||
self._timesteps_total = None
|
||||
self._episodes_total = None
|
||||
self._time_since_restore = 0.0
|
||||
self._timesteps_since_restore = 0
|
||||
self._iterations_since_restore = 0
|
||||
@@ -162,8 +164,15 @@ class Trainable(object):
|
||||
self._timesteps_total += result[TIMESTEPS_THIS_ITER]
|
||||
self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER]
|
||||
|
||||
# self._timesteps_total should only be tracked if increments provided
|
||||
if result.get(EPISODES_THIS_ITER):
|
||||
if self._episodes_total is None:
|
||||
self._episodes_total = 0
|
||||
self._episodes_total += result[EPISODES_THIS_ITER]
|
||||
|
||||
# self._timesteps_total should not override user-provided total
|
||||
result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total)
|
||||
result.setdefault(EPISODES_TOTAL, self._episodes_total)
|
||||
|
||||
# Provides auto-filled neg_mean_loss for avoiding regressions
|
||||
if result.get("mean_loss"):
|
||||
@@ -205,7 +214,7 @@ class Trainable(object):
|
||||
checkpoint_path = self._save(checkpoint_dir or self.logdir)
|
||||
pickle.dump([
|
||||
self._experiment_id, self._iteration, self._timesteps_total,
|
||||
self._time_total
|
||||
self._time_total, self._episodes_total
|
||||
], open(checkpoint_path + ".tune_metadata", "wb"))
|
||||
return checkpoint_path
|
||||
|
||||
@@ -256,6 +265,7 @@ class Trainable(object):
|
||||
self._iteration = metadata[1]
|
||||
self._timesteps_total = metadata[2]
|
||||
self._time_total = metadata[3]
|
||||
self._episodes_total = metadata[4]
|
||||
self._restored = True
|
||||
|
||||
def restore_from_object(self, obj):
|
||||
|
||||
Reference in New Issue
Block a user