[tune] distributed torch wrapper (#9550)

* changes

* add-working

* checkpoint

* cleanup

* fix

* ok

* formatting

* ok

* tests

* some-good-stuff

* fix-torch

* ddp-torch

* torch-test

* sessions

* add-small-test

* fix

* remove

* gpu-working

* update-tests

* ok

* try-test

* format

* ok

* ok
This commit is contained in:
Richard Liaw
2020-07-26 09:37:22 -07:00
committed by GitHub
parent c6a7b3ac68
commit f3fdb5c5db
13 changed files with 515 additions and 52 deletions
+6 -7
View File
@@ -1,6 +1,5 @@
import logging
import os
import io
import time
import inspect
import shutil
@@ -87,6 +86,7 @@ class StatusReporter:
def make_checkpoint_dir(self, step=None):
    """Obtain a checkpoint directory for this reporter's log directory.

    Delegates to ``TrainableUtil.make_checkpoint_dir`` with ``self.logdir``
    and passes ``step`` through as the ``index`` keyword.

    Args:
        step: Value forwarded as ``index``; defaults to ``None``.

    Returns:
        Whatever ``TrainableUtil.make_checkpoint_dir`` returns.
    """
    new_dir = TrainableUtil.make_checkpoint_dir(self.logdir, index=step)
    # Lazy %-style arguments: the message is only rendered when debug
    # logging is enabled.
    logger.debug("Making checkpoint dir at %s", new_dir)
    return new_dir
def save_checkpoint(self, checkpoint):
@@ -279,6 +279,9 @@ class FunctionRunner(Trainable):
result[SHOULD_CHECKPOINT] = True
return result
def execute(self, fn):
    """Invoke ``fn`` with this object as its only argument.

    Args:
        fn: A callable accepting a single positional argument.

    Returns:
        The value produced by ``fn(self)``.
    """
    outcome = fn(self)
    return outcome
def create_default_checkpoint_dir(self):
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
self.logdir, index="default")
@@ -306,12 +309,8 @@ class FunctionRunner(Trainable):
def save_to_object(self):
# NOTE(review): this listing looks like a rendered diff whose +/- markers and
# indentation were stripped, so removed and added lines appear interleaved
# below -- confirm against the actual file before relying on this body.
checkpoint_path = self.save()
# Read as written: the result of TrainableUtil.pickle_checkpoint is copied
# into an in-memory BytesIO buffer and the buffer's contents are returned.
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
out = io.BytesIO()
# 10e6 is compared against len(data_dict); payloads above that length are
# reported at info level but still written out.
if len(data_dict) > 10e6: # getting pretty large
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
out.write(data_dict)
return out.getvalue()
# Read as written, nothing below is reachable because of the return above.
# Presumably these two lines are the replacement for the buffer-based path
# (delegating to TrainableUtil.checkpoint_to_object) -- TODO confirm.
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
return obj
def load_checkpoint(self, checkpoint):
# This should be removed once Trainables are refactored.