Files
Run-Skeleton-Run/baselines/baselines_common/mpi_saver.py
Kolesnikov Sergey 7401266fe7 pytorch version
2017-11-15 22:18:46 +03:00

36 lines
1.0 KiB
Python

from mpi4py import MPI
import baselines.baselines_common.tf_util as U
import tensorflow as tf
class MpiSaver(object):
def __init__(self, var_list=None, *,
comm=None,
log_prefix="/tmp"):
self.var_list = var_list
self.t = 0
self.saver = tf.train.Saver(
var_list=var_list,
max_to_keep=100,
keep_checkpoint_every_n_hours=0.25,
pad_step_number=True,
save_relative_paths=True)
self.log_prefix = log_prefix
self.comm = MPI.COMM_WORLD if comm is None else comm
def restore(self, restore_from=None):
if restore_from is not None:
self.saver.restore(U.get_session(), restore_from)
self.t += int(restore_from.split("-")[-1])
self.sync()
def sync(self):
if self.comm.Get_rank() == 0: # this is root
self.saver.save(
U.get_session(),
"{}/model.ckpt".format(self.log_prefix),
global_step=self.t)
self.t += 1