diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index ddb298af7..afda95062 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -2,11 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pickle -import os import time -import ray from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.optimizers import AsyncGradientsOptimizer @@ -108,28 +105,3 @@ class A3CAgent(Agent): result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result - - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - agent_state = ray.get( - [a.save.remote() for a in self.remote_evaluators]) - extra_data = { - "remote_state": agent_state, - "local_state": self.local_evaluator.save() - } - pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) - return checkpoint_path - - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) - ray.get([ - a.restore.remote(o) - for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) - ]) - self.local_evaluator.restore(extra_data["local_state"]) diff --git a/python/ray/rllib/agents/agent.py b/python/ray/rllib/agents/agent.py index 12975d78f..030ae6424 100644 --- a/python/ray/rllib/agents/agent.py +++ b/python/ray/rllib/agents/agent.py @@ -4,13 +4,13 @@ from __future__ import print_function import copy import json -import numpy as np import os import pickle import tempfile from datetime import datetime - import tensorflow as tf + +import ray from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer from ray.rllib.utils import FilterManager, deep_update, merge_dicts @@ -324,95 +324,39 @@ class Agent(Trainable): """ self.local_evaluator.set_weights(weights) + def _stop(self): + # workaround for https://github.com/ray-project/ray/issues/1516 + if hasattr(self, "remote_evaluators"): + for ev in self.remote_evaluators: + ev.__ray_terminate__.remote() -class _MockAgent(Agent): - """Mock agent for use in tests""" + def __getstate__(self): + state = {} + if hasattr(self, "local_evaluator"): + state["evaluator"] = self.local_evaluator.save() + if hasattr(self, "optimizer") and hasattr(self.optimizer, "save"): + state["optimizer"] = self.optimizer.save() + return state - _agent_name = "MockAgent" - _default_config = { - "mock_error": False, - "persistent_error": False, - "test_variable": 1 - } - - def _init(self): - self.info = None - self.restored = False - - def _train(self): - if self.config["mock_error"] and self.iteration == 1 \ - and (self.config["persistent_error"] or not self.restored): - raise Exception("mock error") - return dict( - episode_reward_mean=10, - episode_len_mean=10, - timesteps_this_iter=10, - info={}) + def __setstate__(self, state): + if "evaluator" in state: + self.local_evaluator.restore(state["evaluator"]) + remote_state = ray.put(state["evaluator"]) + for r in self.remote_evaluators: + r.restore.remote(remote_state) + if "optimizer" in state: + self.optimizer.restore(state["optimizer"]) def _save(self, checkpoint_dir): - path = os.path.join(checkpoint_dir, "mock_agent.pkl") - with open(path, 'wb') as f: - pickle.dump(self.info, f) - return path + checkpoint_path = os.path.join(checkpoint_dir, + "checkpoint-{}".format(self.iteration)) + pickle.dump(self.__getstate__(), + open(checkpoint_path + ".agent_state", "wb")) + return checkpoint_path def _restore(self, checkpoint_path): - with open(checkpoint_path, 'rb') as f: - info = pickle.load(f) - self.info = info - self.restored = True - - def set_info(self, info): - self.info = info - return info - - def get_info(self): - return self.info - - -class _SigmoidFakeData(_MockAgent): - """Agent that returns sigmoid learning curves. - - This can be helpful for evaluating early stopping algorithms.""" - - _agent_name = "SigmoidFakeData" - _default_config = { - "width": 100, - "height": 100, - "offset": 0, - "iter_time": 10, - "iter_timesteps": 1, - } - - def _train(self): - i = max(0, self.iteration - self.config["offset"]) - v = np.tanh(float(i) / self.config["width"]) - v *= self.config["height"] - return dict( - episode_reward_mean=v, - episode_len_mean=v, - timesteps_this_iter=self.config["iter_timesteps"], - time_this_iter_s=self.config["iter_time"], - info={}) - - -class _ParameterTuningAgent(_MockAgent): - - _agent_name = "ParameterTuningAgent" - _default_config = { - "reward_amt": 10, - "dummy_param": 10, - "dummy_param2": 15, - "iter_time": 10, - "iter_timesteps": 1 - } - - def _train(self): - return dict( - episode_reward_mean=self.config["reward_amt"] * self.iteration, - episode_len_mean=self.config["reward_amt"], - timesteps_this_iter=self.config["iter_timesteps"], - time_this_iter_s=self.config["iter_time"], - info={}) + extra_data = pickle.load(open(checkpoint_path + ".agent_state", "rb")) + self.__setstate__(extra_data) def get_agent_class(alg): @@ -458,10 +402,13 @@ def get_agent_class(alg): from ray.tune import script_runner return script_runner.ScriptRunner elif alg == "__fake": + from ray.rllib.agents.mock import _MockAgent return _MockAgent elif alg == "__sigmoid_fake_data": + from ray.rllib.agents.mock import _SigmoidFakeData return _SigmoidFakeData elif alg == "__parameter_tuning": + from ray.rllib.agents.mock import _ParameterTuningAgent return _ParameterTuningAgent else: raise Exception(("Unknown algorithm {}.").format(alg)) diff --git a/python/ray/rllib/agents/ars/ars.py b/python/ray/rllib/agents/ars/ars.py index 69b85eeec..e1a945985 100644 --- a/python/ray/rllib/agents/ars/ars.py +++ b/python/ray/rllib/agents/ars/ars.py @@ -8,8 +8,6 @@ from __future__ import print_function from collections import namedtuple import numpy as np -import os -import pickle import time import ray @@ -201,7 +199,6 @@ class ARSAgent(Agent): ] self.episodes_so_far = 0 - self.timesteps_so_far = 0 self.reward_list = [] self.tstart = time.time() @@ -228,7 +225,6 @@ class ARSAgent(Agent): def _train(self): config = self.config - step_tstart = time.time() theta = self.policy.get_weights() assert theta.dtype == np.float32 @@ -259,7 +255,6 @@ class ARSAgent(Agent): len(all_training_lengths)) self.episodes_so_far += num_episodes - self.timesteps_so_far += num_timesteps # Assemble the results. eval_returns = np.array(all_eval_returns) @@ -301,8 +296,6 @@ class ARSAgent(Agent): if len(all_eval_returns) > 0: self.reward_list.append(eval_returns.mean()) - step_tend = time.time() - tlogger.record_tabular("NoisyEpRewMean", noisy_returns.mean()) tlogger.record_tabular("NoisyEpRewStd", noisy_returns.std()) tlogger.record_tabular("NoisyEpLenMean", noisy_lengths.mean()) @@ -319,9 +312,6 @@ class ARSAgent(Agent): "update_ratio": update_ratio, "episodes_this_iter": noisy_lengths.size, "episodes_so_far": self.episodes_so_far, - "timesteps_so_far": self.timesteps_so_far, - "time_elapsed_this_iter": step_tend - step_tstart, - "time_elapsed": step_tend - self.tstart } result = dict( episode_reward_mean=np.mean( @@ -337,19 +327,15 @@ class ARSAgent(Agent): for w in self.workers: w.__ray_terminate__.remote() - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - weights = self.policy.get_weights() - objects = [weights, self.episodes_so_far, self.timesteps_so_far] - pickle.dump(objects, open(checkpoint_path, "wb")) - return checkpoint_path + def __getstate__(self): + return { + "weights": self.policy.get_weights(), + "episodes_so_far": self.episodes_so_far, + } - def _restore(self, checkpoint_path): - objects = pickle.load(open(checkpoint_path, "rb")) - self.policy.set_weights(objects[0]) - self.episodes_so_far = objects[1] - self.timesteps_so_far = objects[2] + def __setstate__(self, state): + self.policy.set_weights(state["weights"]) + self.episodes_so_far = state["episodes_so_far"] def compute_action(self, observation): return self.policy.compute(observation, update=True)[0] diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 7ee8385c2..c945cdbc9 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -2,11 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pickle -import os import time -import ray from ray.rllib import optimizers from ray.rllib.agents.agent import Agent, with_common_config from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph @@ -249,30 +246,15 @@ class DQNAgent(Agent): }, **self.optimizer.stats())) return result - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() + def __getstate__(self): + state = Agent.__getstate__(self) + state.update({ + "num_target_updates": self.num_target_updates, + "last_target_update_ts": self.last_target_update_ts, + }) + return state - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - extra_data = [ - self.local_evaluator.save(), - ray.get([e.save.remote() for e in self.remote_evaluators]), - self.optimizer.save(), self.num_target_updates, - self.last_target_update_ts - ] - pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) - return checkpoint_path - - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) - self.local_evaluator.restore(extra_data[0]) - ray.get([ - e.restore.remote(d) - for (d, e) in zip(extra_data[1], self.remote_evaluators) - ]) - self.optimizer.restore(extra_data[2]) - self.num_target_updates = extra_data[3] - self.last_target_update_ts = extra_data[4] + def __setstate__(self, state): + Agent.__setstate__(self, state) + self.num_target_updates = state["num_target_updates"] + self.last_target_update_ts = state["last_target_update_ts"] diff --git a/python/ray/rllib/agents/es/es.py b/python/ray/rllib/agents/es/es.py index 452918f58..1ce219b7c 100644 --- a/python/ray/rllib/agents/es/es.py +++ b/python/ray/rllib/agents/es/es.py @@ -7,8 +7,6 @@ from __future__ import print_function from collections import namedtuple import numpy as np -import os -import pickle import time import ray @@ -180,7 +178,6 @@ class ESAgent(Agent): ] self.episodes_so_far = 0 - self.timesteps_so_far = 0 self.reward_list = [] self.tstart = time.time() @@ -207,7 +204,6 @@ class ESAgent(Agent): def _train(self): config = self.config - step_tstart = time.time() theta = self.policy.get_weights() assert theta.dtype == np.float32 @@ -238,7 +234,6 @@ class ESAgent(Agent): len(all_training_lengths)) self.episodes_so_far += num_episodes - self.timesteps_so_far += num_timesteps # Assemble the results. eval_returns = np.array(all_eval_returns) @@ -271,7 +266,6 @@ class ESAgent(Agent): if len(all_eval_returns) > 0: self.reward_list.append(np.mean(eval_returns)) - step_tend = time.time() tlogger.record_tabular("EvalEpRewStd", eval_returns.std()) tlogger.record_tabular("EvalEpLenMean", eval_lengths.mean()) @@ -285,11 +279,6 @@ class ESAgent(Agent): tlogger.record_tabular("EpisodesThisIter", noisy_lengths.size) tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far) - tlogger.record_tabular("TimestepsThisIter", noisy_lengths.sum()) - tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far) - - tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart) - tlogger.record_tabular("TimeElapsed", step_tend - self.tstart) tlogger.dump_tabular() info = { @@ -298,10 +287,6 @@ class ESAgent(Agent): "update_ratio": update_ratio, "episodes_this_iter": noisy_lengths.size, "episodes_so_far": self.episodes_so_far, - "timesteps_this_iter": noisy_lengths.sum(), - "timesteps_so_far": self.timesteps_so_far, - "time_elapsed_this_iter": step_tend - step_tstart, - "time_elapsed": step_tend - self.tstart } reward_mean = np.mean(self.reward_list[-self.report_length:]) @@ -318,19 +303,15 @@ class ESAgent(Agent): for w in self.workers: w.__ray_terminate__.remote() - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - weights = self.policy.get_weights() - objects = [weights, self.episodes_so_far, self.timesteps_so_far] - pickle.dump(objects, open(checkpoint_path, "wb")) - return checkpoint_path + def __getstate__(self): + return { + "weights": self.policy.get_weights(), + "episodes_so_far": self.episodes_so_far, + } - def _restore(self, checkpoint_path): - objects = pickle.load(open(checkpoint_path, "rb")) - self.policy.set_weights(objects[0]) - self.episodes_so_far = objects[1] - self.timesteps_so_far = objects[2] + def __setstate__(self, state): + self.policy.set_weights(state["weights"]) + self.episodes_so_far = state["episodes_so_far"] def compute_action(self, observation): return self.policy.compute(observation, update=False)[0] diff --git a/python/ray/rllib/agents/impala/impala.py b/python/ray/rllib/agents/impala/impala.py index eb16eca60..cfa55bd73 100644 --- a/python/ray/rllib/agents/impala/impala.py +++ b/python/ray/rllib/agents/impala/impala.py @@ -2,11 +2,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pickle -import os import time -import ray from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph from ray.rllib.agents.agent import Agent, with_common_config @@ -99,28 +96,3 @@ class ImpalaAgent(Agent): result.update(timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps) return result - - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - agent_state = ray.get( - [a.save.remote() for a in self.remote_evaluators]) - extra_data = { - "remote_state": agent_state, - "local_state": self.local_evaluator.save() - } - pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) - return checkpoint_path - - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) - ray.get([ - a.restore.remote(o) - for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) - ]) - self.local_evaluator.restore(extra_data["local_state"]) diff --git a/python/ray/rllib/agents/mock.py b/python/ray/rllib/agents/mock.py new file mode 100644 index 000000000..526ec146a --- /dev/null +++ b/python/ray/rllib/agents/mock.py @@ -0,0 +1,99 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pickle +import numpy as np + +from ray.rllib.agents.agent import Agent + + +class _MockAgent(Agent): + """Mock agent for use in tests""" + + _agent_name = "MockAgent" + _default_config = { + "mock_error": False, + "persistent_error": False, + "test_variable": 1 + } + + def _init(self): + self.info = None + self.restored = False + + def _train(self): + if self.config["mock_error"] and self.iteration == 1 \ + and (self.config["persistent_error"] or not self.restored): + raise Exception("mock error") + return dict( + episode_reward_mean=10, + episode_len_mean=10, + timesteps_this_iter=10, + info={}) + + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "mock_agent.pkl") + with open(path, 'wb') as f: + pickle.dump(self.info, f) + return path + + def _restore(self, checkpoint_path): + with open(checkpoint_path, 'rb') as f: + info = pickle.load(f) + self.info = info + self.restored = True + + def set_info(self, info): + self.info = info + return info + + def get_info(self): + return self.info + + +class _SigmoidFakeData(_MockAgent): + """Agent that returns sigmoid learning curves. + + This can be helpful for evaluating early stopping algorithms.""" + + _agent_name = "SigmoidFakeData" + _default_config = { + "width": 100, + "height": 100, + "offset": 0, + "iter_time": 10, + "iter_timesteps": 1, + } + + def _train(self): + i = max(0, self.iteration - self.config["offset"]) + v = np.tanh(float(i) / self.config["width"]) + v *= self.config["height"] + return dict( + episode_reward_mean=v, + episode_len_mean=v, + timesteps_this_iter=self.config["iter_timesteps"], + time_this_iter_s=self.config["iter_time"], + info={}) + + +class _ParameterTuningAgent(_MockAgent): + + _agent_name = "ParameterTuningAgent" + _default_config = { + "reward_amt": 10, + "dummy_param": 10, + "dummy_param2": 15, + "iter_time": 10, + "iter_timesteps": 1 + } + + def _train(self): + return dict( + episode_reward_mean=self.config["reward_amt"] * self.iteration, + episode_len_mean=self.config["reward_amt"], + timesteps_this_iter=self.config["iter_timesteps"], + time_this_iter_s=self.config["iter_time"], + info={}) diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index 41c6db5ba..f452f7893 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -2,10 +2,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import pickle - -import ray from ray.rllib.agents import Agent, with_common_config from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph from ray.rllib.utils import merge_dicts @@ -134,25 +130,3 @@ class PPOAgent(Agent): timesteps_this_iter=self.optimizer.num_steps_sampled - prev_steps, info=dict(fetches, **res.get("info", {}))) return res - - def _stop(self): - # workaround for https://github.com/ray-project/ray/issues/1516 - for ev in self.remote_evaluators: - ev.__ray_terminate__.remote() - - def _save(self, checkpoint_dir): - checkpoint_path = os.path.join(checkpoint_dir, - "checkpoint-{}".format(self.iteration)) - agent_state = ray.get( - [a.save.remote() for a in self.remote_evaluators]) - extra_data = [self.local_evaluator.save(), agent_state] - pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) - return checkpoint_path - - def _restore(self, checkpoint_path): - extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) - self.local_evaluator.restore(extra_data[0]) - ray.get([ - a.restore.remote(o) - for (a, o) in zip(self.remote_evaluators, extra_data[1]) - ]) diff --git a/python/ray/rllib/evaluation/metrics.py b/python/ray/rllib/evaluation/metrics.py index b2762f4f9..dc71c4ecd 100644 --- a/python/ray/rllib/evaluation/metrics.py +++ b/python/ray/rllib/evaluation/metrics.py @@ -64,5 +64,5 @@ def summarize_episodes(episodes, new_episodes): episode_reward_min=min_reward, episode_reward_mean=avg_reward, episode_len_mean=avg_length, - episodes=len(new_episodes), + episodes_this_iter=len(new_episodes), policy_reward_mean=dict(policy_rewards)) diff --git a/python/ray/rllib/test/test_policy_evaluator.py b/python/ray/rllib/test/test_policy_evaluator.py index b454c4461..cc189edbf 100644 --- a/python/ray/rllib/test/test_policy_evaluator.py +++ b/python/ray/rllib/test/test_policy_evaluator.py @@ -168,7 +168,7 @@ class TestPolicyEvaluator(unittest.TestCase): ev.sample() ray.get(remote_ev.sample.remote()) result = collect_metrics(ev, [remote_ev]) - self.assertEqual(result["episodes"], 20) + self.assertEqual(result["episodes_this_iter"], 20) self.assertEqual(result["episode_reward_mean"], 10) def testAsync(self): @@ -204,12 +204,12 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 0) + self.assertEqual(result["episodes_this_iter"], 0) for _ in range(8): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 8) + self.assertEqual(result["episodes_this_iter"], 8) indices = [] for env in ev.async_env.vector_env.envs: self.assertEqual(env.unwrapped.config.worker_index, 0) @@ -235,10 +235,10 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 16) result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 0) + self.assertEqual(result["episodes_this_iter"], 0) batch = ev.sample() result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 4) + self.assertEqual(result["episodes_this_iter"], 4) def testVectorEnvSupport(self): ev = PolicyEvaluator( @@ -250,12 +250,12 @@ class TestPolicyEvaluator(unittest.TestCase): batch = ev.sample() self.assertEqual(batch.count, 10) result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 0) + self.assertEqual(result["episodes_this_iter"], 0) for _ in range(8): batch = ev.sample() self.assertEqual(batch.count, 10) result = collect_metrics(ev, []) - self.assertEqual(result["episodes"], 8) + self.assertEqual(result["episodes_this_iter"], 8) def testTruncateEpisodes(self): ev = PolicyEvaluator( diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 41eecfd0f..ec307eaed 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -16,6 +16,12 @@ NODE_IP = "node_ip" # (Auto-filled) The pid of the training process. PID = "pid" +# Number of timesteps in this iteration. +EPISODES_THIS_ITER = "episodes_this_iter" + +# (Optional/Auto-filled) Accumulated time in seconds for this experiment. +EPISODES_TOTAL = "episodes_total" + # Number of timesteps in this iteration. TIMESTEPS_THIS_ITER = "timesteps_this_iter" diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 63ffec603..1e537d26d 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -17,7 +17,8 @@ import uuid import ray from ray.tune.logger import UnifiedLogger from ray.tune.result import (DEFAULT_RESULTS_DIR, TIME_THIS_ITER_S, - TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL) + TIMESTEPS_THIS_ITER, DONE, TIMESTEPS_TOTAL, + EPISODES_THIS_ITER, EPISODES_TOTAL) from ray.tune.trial import Resources logger = logging.getLogger(__name__) @@ -77,6 +78,7 @@ class Trainable(object): self._iteration = 0 self._time_total = 0.0 self._timesteps_total = None + self._episodes_total = None self._time_since_restore = 0.0 self._timesteps_since_restore = 0 self._iterations_since_restore = 0 @@ -162,8 +164,15 @@ class Trainable(object): self._timesteps_total += result[TIMESTEPS_THIS_ITER] self._timesteps_since_restore += result[TIMESTEPS_THIS_ITER] + # self._timesteps_total should only be tracked if increments provided + if result.get(EPISODES_THIS_ITER): + if self._episodes_total is None: + self._episodes_total = 0 + self._episodes_total += result[EPISODES_THIS_ITER] + # self._timesteps_total should not override user-provided total result.setdefault(TIMESTEPS_TOTAL, self._timesteps_total) + result.setdefault(EPISODES_TOTAL, self._episodes_total) # Provides auto-filled neg_mean_loss for avoiding regressions if result.get("mean_loss"): @@ -205,7 +214,7 @@ class Trainable(object): checkpoint_path = self._save(checkpoint_dir or self.logdir) pickle.dump([ self._experiment_id, self._iteration, self._timesteps_total, - self._time_total + self._time_total, self._episodes_total ], open(checkpoint_path + ".tune_metadata", "wb")) return checkpoint_path @@ -256,6 +265,7 @@ class Trainable(object): self._iteration = metadata[1] self._timesteps_total = metadata[2] self._time_total = metadata[3] + self._episodes_total = metadata[4] self._restored = True def restore_from_object(self, obj):