Files
ray/python/ray/util/sgd/utils.py
T
Richard Liaw d192ef0611 [raysgd] Cleanup User API (#7384)
* Init fp16

* fp16 and schedulers

* scheduler linking and fp16

* to fp16

* loss scaling and documentation

* more documentation

* add tests, refactor config

* moredocs

* more docs

* fix logo, add test mode, add fp16 flag

* fix tests

* fix scheduler

* fix apex

* improve safety

* fix tests

* fix tests

* remove pin memory default

* rm

* fix

* Update doc/examples/doc_code/raysgd_torch_signatures.py

* fix

* migrate changes from other PR

* ok thanks

* pass

* signatures

* lint'

* Update python/ray/experimental/sgd/pytorch/utils.py

* Apply suggestions from code review

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* should address most comments

* comments

* fix this ci

* first_pass

* add overrides

* override

* fixing up operators

* format

* sgd

* constants

* rm

* revert

* save

* failures

* fixes

* trainer

* run test

* operator

* code

* op

* ok done

* operator

* sgd test fixes

* ok

* trainer

* format

* Apply suggestions from code review

Co-Authored-By: Edward Oakes <ed.nmi.oakes@gmail.com>

* Update doc/source/raysgd/raysgd_pytorch.rst

* docstring

* dcgan

* doc

* commits

* nit

* testing

* revert

* Start renaming pytorch to torch

* Rename PyTorchTrainer to TorchTrainer

* Rename PyTorch runners to Torch runners

* Finish renaming API

* Rename to torch in tests

* Finish renaming docs + tests

* Run format + fix DeprecationWarning

* fix

* move tests up

* benchmarks

* rename

* remove some args

* better metrics output

* fix up the benchmark

* benchmark-yaml

* horovod-benchmark

* benchmarks

* Remove benchmark code for cleanups

* makedatacreator

* relax

* metrics

* autosetsampler

* profile

* movements

* OK

* smoothen

* fix

* nitdocs

* loss

* comments

* fix

* fix

* runner_tests

* codes

* example

* fix_test

* fix

* tests

Co-authored-by: Edward Oakes <ed.nmi.oakes@gmail.com>
Co-authored-by: Maksim Smolin <maximsmol@gmail.com>
2020-03-10 08:41:42 -07:00

228 lines
5.7 KiB
Python

import collections
from contextlib import closing, contextmanager
import logging
import numpy as np
import socket
import time
import ray
from ray.exceptions import RayActorError
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Keys under which AverageMeterCollection.summary() reports the number of
# update() calls and the accumulated sample count, respectively.
BATCH_COUNT = "batch_count"
NUM_SAMPLES = "num_samples"
# NOTE(review): not referenced in this module; presumably a metrics key read
# by callers to learn the per-batch sample count -- confirm against users.
BATCH_SIZE = "*batch_size"
class TimerStat:
    """A running stat for conveniently logging the duration of a code block.

    Durations are kept in a sliding window holding the most recent
    ``window_size`` samples; the statistics properties are computed over
    that window only. ``count`` tracks every sample pushed since
    construction or the last ``reset()``.

    Note that this class is *not* thread-safe.

    Examples:
        Time a call to 'time.sleep'.

        >>> import time
        >>> sleep_timer = TimerStat()
        >>> with sleep_timer:
        ...     time.sleep(1)
        >>> round(sleep_timer.mean)
        1
    """

    def __init__(self, window_size=10):
        """Creates an empty timer.

        Args:
            window_size (int): Maximum number of recent samples (and
                units-processed entries) to retain.
        """
        self._window_size = window_size
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0

    def __enter__(self):
        """Starts timing and returns this timer.

        Returning ``self`` lets callers write ``with TimerStat() as t:``
        and get the timer rather than ``None``.
        """
        assert self._start_time is None, "concurrent updates not supported"
        self._start_time = time.time()
        return self

    def __exit__(self, type, value, tb):
        """Stops timing and records the elapsed wall-clock seconds.

        The sample is recorded whether or not the block raised; nothing is
        returned, so an exception from the block is not suppressed.
        """
        assert self._start_time is not None
        time_delta = time.time() - self._start_time
        self.push(time_delta)
        self._start_time = None

    def push(self, time_delta):
        """Records one duration, evicting the oldest if the window is full.

        Args:
            time_delta (float): Duration to record.
        """
        self._samples.append(time_delta)
        if len(self._samples) > self._window_size:
            self._samples.pop(0)
        self.count += 1
        self._total_time += time_delta

    def push_units_processed(self, n):
        """Records a units-processed entry, evicting the oldest if full.

        Args:
            n: Number of units handled, used by the throughput statistics.
        """
        self._units_processed.append(n)
        if len(self._units_processed) > self._window_size:
            self._units_processed.pop(0)

    @property
    def mean(self):
        """Mean of the durations currently in the window."""
        return np.mean(self._samples)

    @property
    def median(self):
        """Median of the durations currently in the window."""
        return np.median(self._samples)

    @property
    def sum(self):
        """Sum of the durations currently in the window."""
        return np.sum(self._samples)

    @property
    def max(self):
        """Largest duration currently in the window."""
        return np.max(self._samples)

    @property
    def first(self):
        """Oldest duration in the window, or None if it is empty."""
        return self._samples[0] if self._samples else None

    @property
    def last(self):
        """Newest duration in the window, or None if it is empty."""
        return self._samples[-1] if self._samples else None

    @property
    def size(self):
        """Number of durations currently in the window."""
        return len(self._samples)

    @property
    def mean_units_processed(self):
        """Mean of the retained units-processed entries, as a float."""
        return float(np.mean(self._units_processed))

    @property
    def mean_throughput(self):
        """Retained units divided by retained time; 0.0 if no time elapsed."""
        # ``sum`` here is the builtin: method bodies do not see the class
        # namespace, so the ``sum`` property above is not picked up.
        time_total = sum(self._samples)
        if not time_total:
            return 0.0
        return sum(self._units_processed) / time_total

    def reset(self):
        """Drops all samples and counters, keeping the window size."""
        self._samples = []
        self._units_processed = []
        self._start_time = None
        self._total_time = 0.0
        self.count = 0
@contextmanager
def _nullcontext(enter_result=None):
"""Used for mocking timer context."""
yield enter_result
class TimerCollection:
    """A keyed group of timers that can be switched on and off together."""

    def __init__(self):
        # Recording is on by default; timers are created lazily per key.
        self._enabled = True
        self._timers = collections.defaultdict(TimerStat)

    def disable(self):
        """Makes subsequent record() calls return a no-op context."""
        self._enabled = False

    def enable(self):
        """Makes subsequent record() calls return real timers again."""
        self._enabled = True

    def reset(self):
        """Resets every timer created so far."""
        for stat in self._timers.values():
            stat.reset()

    def record(self, key):
        """Returns a context manager timing the block filed under ``key``.

        While the collection is disabled, a no-op context is returned and
        no timer is created or updated.
        """
        if not self._enabled:
            return _nullcontext()
        return self._timers[key]

    def stats(self, mean=True, last=False):
        """Summarizes every timer that has recorded at least one sample.

        Args:
            mean (bool): Include ``mean_<key>_s`` entries.
            last (bool): Include ``last_<key>_s`` entries.

        Returns:
            Dict mapping the entry names above to the timer values.
        """
        report = {}
        for name, stat in self._timers.items():
            if not stat.count > 0:
                # Never-used timers are left out of the report.
                continue
            if mean:
                report["mean_%s_s" % name] = stat.mean
            if last:
                report["last_%s_s" % name] = stat.last
        return report
def find_free_port():
    """Returns a TCP port number that the OS reported as free.

    A temporary IPv4 stream socket is bound to port 0 so the OS picks the
    port; the socket is closed before returning, so nothing reserves the
    port for the caller afterwards.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("", 0))
        # Kept after bind(), exactly as the original ordering does.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        address = sock.getsockname()
        return address[1]
    finally:
        sock.close()
class AverageMeter:
    """Computes and stores the average and current value.

    Attributes are public and read directly by callers:
        val: The value passed to the most recent update().
        sum: Weighted sum of all values since the last reset().
        count: Total weight (sum of ``n``) since the last reset().
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clears the meter back to its initial all-zero state."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Folds a new observation into the running average.

        Args:
            val: Observed value (e.g. a per-batch mean).
            n (int): Weight of the observation, typically the number of
                samples ``val`` was computed over.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. an empty first batch with n=0) has no
        # defined mean; keep the reset value instead of dividing by zero.
        self.avg = self.sum / self.count if self.count else 0
class AverageMeterCollection:
    """A keyed group of AverageMeters updated together once per batch."""

    def __init__(self):
        # One meter per metric name, created on first use.
        self._meters = collections.defaultdict(AverageMeter)
        self._batch_count = 0
        self.n = 0

    def update(self, metrics, n=1):
        """Records one batch worth of metric values.

        Args:
            metrics (dict): Maps metric name to its value for this batch.
            n (int): Weight applied to every value, and added to the
                running sample count.
        """
        self._batch_count += 1
        self.n += n
        for name, value in metrics.items():
            meter = self._meters[name]
            meter.update(value, n=n)

    def summary(self):
        """Returns batch/sample counts plus mean and last value per metric."""
        report = {BATCH_COUNT: self._batch_count, NUM_SAMPLES: self.n}
        for name, meter in self._meters.items():
            label = str(name)
            report["mean_" + label] = meter.avg
            report["last_" + label] = meter.val
        return report
def check_for_failure(remote_values):
    """Waits on remote values and reports whether all of them succeeded.

    Args:
        remote_values (list): List of object IDs representing functions
            that may fail in the middle of execution. For example, running
            a SGD training loop in multiple parallel actor calls.

    Returns:
        True once every value has been waited on and fetched; False if a
        RayActorError surfaced first (it is logged, not re-raised).
    """
    pending = remote_values
    try:
        while len(pending) > 0:
            done, pending = ray.wait(pending)
            # Fetch only to surface a failure; the results are discarded.
            ray.get(done)
    except RayActorError as exc:
        logger.exception(str(exc))
        return False
    return True
def override(interface_class):
    """Decorator factory marking a method as overriding ``interface_class``.

    The decorated method is returned unchanged; the only effect is an
    assertion, at decoration time, that ``interface_class`` exposes an
    attribute with the same name as the method.
    """

    def overrider(method):
        known_names = dir(interface_class)
        assert method.__name__ in known_names
        return method

    return overrider