mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 22:06:00 +08:00
7c3274e65b
## What do these changes do?
This is a re-implementation of the `FunctionRunner` which enforces some synchronicity between the thread running the training function and the thread running the Trainable which logs results. The main purpose is to make logging consistent across APIs in anticipation of a new function API which will be generator based (through `yield` statements). Without these changes, it will be impossible for the (possibly soon to be) deprecated reporter based API to behave the same as the generator based API.
This new implementation provides additional guarantees to prevent results from being dropped. This makes the logging behavior more intuitive and consistent with how results are handled in custom subclasses of Trainable.
New guarantees for the tune function API:
- Every reported result, i.e., `reporter(**kwargs)` calls, is forwarded to the appropriate loggers instead of being dropped if not enough time has elapsed since the last results.
- The wrapped function only runs if the `FunctionRunner` expects a result, i.e., when `FunctionRunner._train()` has been called. This removes the possibility that a result will be generated by the function but never logged.
- The wrapped function is not called until the first `_train()` call. Currently, the wrapped function is started during the setup phase which could result in dropped results if the trial is cancelled between `_setup()` and the first `_train()` call.
- Exceptions raised by the wrapped function won't be propagated until all results are logged to prevent dropped results.
- The thread running the wrapped function is explicitly stopped when the `FunctionRunner` is stopped with `_stop()`.
- If the wrapped function terminates without reporting `done=True`, a duplicate result with `{"done": True}`, is reported to explicitly terminate the trial, and components will be notified with a duplicate of the last reported result, but this duplicate will not be logged.
## Related issue number
Closes #3956.
#3949
#3834
119 lines
3.8 KiB
Python
119 lines
3.8 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import logging
|
|
from types import FunctionType
|
|
|
|
import ray
|
|
import ray.cloudpickle as pickle
|
|
from ray.experimental.internal_kv import _internal_kv_initialized, \
|
|
_internal_kv_get, _internal_kv_put
|
|
|
|
TRAINABLE_CLASS = "trainable_class"
|
|
ENV_CREATOR = "env_creator"
|
|
RLLIB_MODEL = "rllib_model"
|
|
RLLIB_PREPROCESSOR = "rllib_preprocessor"
|
|
KNOWN_CATEGORIES = [
|
|
TRAINABLE_CLASS, ENV_CREATOR, RLLIB_MODEL, RLLIB_PREPROCESSOR
|
|
]
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def register_trainable(name, trainable):
|
|
"""Register a trainable function or class.
|
|
|
|
Args:
|
|
name (str): Name to register.
|
|
trainable (obj): Function or tune.Trainable class. Functions must
|
|
take (config, status_reporter) as arguments and will be
|
|
automatically converted into a class during registration.
|
|
"""
|
|
|
|
from ray.tune.trainable import Trainable
|
|
from ray.tune.function_runner import wrap_function
|
|
|
|
if isinstance(trainable, type):
|
|
logger.debug("Detected class for trainable.")
|
|
elif isinstance(trainable, FunctionType):
|
|
logger.debug("Detected function for trainable.")
|
|
trainable = wrap_function(trainable)
|
|
elif callable(trainable):
|
|
logger.warning(
|
|
"Detected unknown callable for trainable. Converting to class.")
|
|
trainable = wrap_function(trainable)
|
|
|
|
if not issubclass(trainable, Trainable):
|
|
raise TypeError("Second argument must be convertable to Trainable",
|
|
trainable)
|
|
_global_registry.register(TRAINABLE_CLASS, name, trainable)
|
|
|
|
|
|
def register_env(name, env_creator):
|
|
"""Register a custom environment for use with RLlib.
|
|
|
|
Args:
|
|
name (str): Name to register.
|
|
env_creator (obj): Function that creates an env.
|
|
"""
|
|
|
|
if not isinstance(env_creator, FunctionType):
|
|
raise TypeError("Second argument must be a function.", env_creator)
|
|
_global_registry.register(ENV_CREATOR, name, env_creator)
|
|
|
|
|
|
def _make_key(category, key):
|
|
"""Generate a binary key for the given category and key.
|
|
|
|
Args:
|
|
category (str): The category of the item
|
|
key (str): The unique identifier for the item
|
|
|
|
Returns:
|
|
The key to use for storing a the value.
|
|
"""
|
|
return (b"TuneRegistry:" + category.encode("ascii") + b"/" +
|
|
key.encode("ascii"))
|
|
|
|
|
|
class _Registry(object):
|
|
def __init__(self):
|
|
self._to_flush = {}
|
|
|
|
def register(self, category, key, value):
|
|
if category not in KNOWN_CATEGORIES:
|
|
from ray.tune import TuneError
|
|
raise TuneError("Unknown category {} not among {}".format(
|
|
category, KNOWN_CATEGORIES))
|
|
self._to_flush[(category, key)] = pickle.dumps(value)
|
|
if _internal_kv_initialized():
|
|
self.flush_values()
|
|
|
|
def contains(self, category, key):
|
|
if _internal_kv_initialized():
|
|
value = _internal_kv_get(_make_key(category, key))
|
|
return value is not None
|
|
else:
|
|
return (category, key) in self._to_flush
|
|
|
|
def get(self, category, key):
|
|
if _internal_kv_initialized():
|
|
value = _internal_kv_get(_make_key(category, key))
|
|
if value is None:
|
|
raise ValueError(
|
|
"Registry value for {}/{} doesn't exist.".format(
|
|
category, key))
|
|
return pickle.loads(value)
|
|
else:
|
|
return pickle.loads(self._to_flush[(category, key)])
|
|
|
|
def flush_values(self):
|
|
for (category, key), value in self._to_flush.items():
|
|
_internal_kv_put(_make_key(category, key), value, overwrite=True)
|
|
self._to_flush.clear()
|
|
|
|
|
|
_global_registry = _Registry()
|
|
ray.worker._post_init_hooks.append(_global_registry.flush_values)
|