mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 04:50:40 +08:00
[tune] distributed torch wrapper (#9550)
* changes * add-working * checkpoint * ccleanu * fix * ok * formatting * ok * tests * some-good-stuff * fix-torch * ddp-torch * torch-test * sessions * add-small-test * fix * remove * gpu-working * update-tests * ok * try-test * formgat * ok * ok
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import io
|
||||
import time
|
||||
import inspect
|
||||
import shutil
|
||||
@@ -87,6 +86,7 @@ class StatusReporter:
|
||||
def make_checkpoint_dir(self, step=None):
|
||||
checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index=step)
|
||||
logger.debug("Making checkpoint dir at %s", checkpoint_dir)
|
||||
return checkpoint_dir
|
||||
|
||||
def save_checkpoint(self, checkpoint):
|
||||
@@ -279,6 +279,9 @@ class FunctionRunner(Trainable):
|
||||
result[SHOULD_CHECKPOINT] = True
|
||||
return result
|
||||
|
||||
def execute(self, fn):
|
||||
return fn(self)
|
||||
|
||||
def create_default_checkpoint_dir(self):
|
||||
self.default_checkpoint_dir = TrainableUtil.make_checkpoint_dir(
|
||||
self.logdir, index="default")
|
||||
@@ -306,12 +309,8 @@ class FunctionRunner(Trainable):
|
||||
|
||||
def save_to_object(self):
|
||||
checkpoint_path = self.save()
|
||||
data_dict = TrainableUtil.pickle_checkpoint(checkpoint_path)
|
||||
out = io.BytesIO()
|
||||
if len(data_dict) > 10e6: # getting pretty large
|
||||
logger.info("Checkpoint size is {} bytes".format(len(data_dict)))
|
||||
out.write(data_dict)
|
||||
return out.getvalue()
|
||||
obj = TrainableUtil.checkpoint_to_object(checkpoint_path)
|
||||
return obj
|
||||
|
||||
def load_checkpoint(self, checkpoint):
|
||||
# This should be removed once Trainables are refactored.
|
||||
|
||||
Reference in New Issue
Block a user