mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 03:08:48 +08:00
[RLlib] Schedule-classes multi-framework support. (#6926)
This commit is contained in:
@@ -41,3 +41,10 @@ py_test(
|
||||
# size = "small",
|
||||
# srcs = ["models/tests/test_distributions.py"]
|
||||
#)
|
||||
|
||||
# Schedules
|
||||
py_test(
|
||||
name = "test_schedules",
|
||||
size = "small",
|
||||
srcs = ["utils/schedules/tests/test_schedules.py"]
|
||||
)
|
||||
|
||||
@@ -11,6 +11,8 @@ from ray.rllib.utils.numpy import sigmoid, softmax, relu, one_hot, fc, lstm, \
|
||||
SMALL_NUMBER, LARGE_INTEGER
|
||||
from ray.rllib.utils.policy_client import PolicyClient
|
||||
from ray.rllib.utils.policy_server import PolicyServer
|
||||
from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
|
||||
PolynomialSchedule, ExponentialSchedule, ConstantSchedule
|
||||
from ray.rllib.utils.test_utils import check
|
||||
from ray.tune.utils import merge_dicts, deep_update
|
||||
|
||||
@@ -75,12 +77,17 @@ __all__ = [
|
||||
"try_import_tf",
|
||||
"try_import_tfp",
|
||||
"try_import_torch",
|
||||
"ConstantSchedule",
|
||||
"DeveloperAPI",
|
||||
"ExponentialSchedule",
|
||||
"Filter",
|
||||
"FilterManager",
|
||||
"LARGE_INTEGER",
|
||||
"LinearSchedule",
|
||||
"PiecewiseSchedule",
|
||||
"PolicyClient",
|
||||
"PolicyServer",
|
||||
"PolynomialSchedule",
|
||||
"PublicAPI",
|
||||
"SMALL_NUMBER",
|
||||
]
|
||||
|
||||
@@ -4,8 +4,31 @@ import os
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def try_import_tf():
|
||||
def check_framework(framework="tf"):
|
||||
"""
|
||||
Checks, whether the given framework is "valid", meaning, whether all
|
||||
necessary dependencies are installed. Errors otherwise.
|
||||
|
||||
Args:
|
||||
framework (str): Once of "tf", "torch", or None.
|
||||
|
||||
Returns:
|
||||
str: The input framework string.
|
||||
"""
|
||||
if framework == "tf":
|
||||
try_import_tf(error=True)
|
||||
elif framework == "torch":
|
||||
try_import_torch(error=True)
|
||||
else:
|
||||
assert framework is None
|
||||
return framework
|
||||
|
||||
|
||||
def try_import_tf(error=False):
|
||||
"""
|
||||
Args:
|
||||
error (bool): Whether to raise an error if tf cannot be imported.
|
||||
|
||||
Returns:
|
||||
The tf module (either from tf2.0.compat.v1 OR as tf1.x.
|
||||
"""
|
||||
@@ -24,12 +47,17 @@ def try_import_tf():
|
||||
try:
|
||||
import tensorflow as tf
|
||||
return tf
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
if error:
|
||||
raise e
|
||||
return None
|
||||
|
||||
|
||||
def try_import_tfp():
|
||||
def try_import_tfp(error=False):
|
||||
"""
|
||||
Args:
|
||||
error (bool): Whether to raise an error if tfp cannot be imported.
|
||||
|
||||
Returns:
|
||||
The tfp module.
|
||||
"""
|
||||
@@ -41,12 +69,17 @@ def try_import_tfp():
|
||||
try:
|
||||
import tensorflow_probability as tfp
|
||||
return tfp
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
if error:
|
||||
raise e
|
||||
return None
|
||||
|
||||
|
||||
def try_import_torch():
|
||||
def try_import_torch(error=False):
|
||||
"""
|
||||
Args:
|
||||
error (bool): Whether to raise an error if torch cannot be imported.
|
||||
|
||||
Returns:
|
||||
tuple: torch AND torch.nn modules.
|
||||
"""
|
||||
@@ -58,5 +91,7 @@ def try_import_torch():
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
return torch, nn
|
||||
except ImportError:
|
||||
except ImportError as e:
|
||||
if error:
|
||||
raise e
|
||||
return None, None
|
||||
|
||||
@@ -92,7 +92,8 @@ def from_config(cls, config=None, **kwargs):
|
||||
if type_ is None:
|
||||
# We have a default constructor that was defined directly by cls
|
||||
# (not by its children).
|
||||
if cls is not None and cls.__default_constructor__ is not None and \
|
||||
if cls is not None and hasattr(cls, "__default_constructor__") and \
|
||||
cls.__default_constructor__ is not None and \
|
||||
ctor_args == [] and \
|
||||
(
|
||||
not hasattr(cls.__bases__[0], "__default_constructor__")
|
||||
@@ -199,11 +200,12 @@ def from_file(cls, filename, *args, **kwargs):
|
||||
|
||||
|
||||
def lookup_type(cls, type_):
|
||||
if cls is not None and isinstance(cls.__type_registry__, dict) and \
|
||||
if cls is not None and hasattr(cls, "__type_registry__") and \
|
||||
isinstance(cls.__type_registry__, dict) and \
|
||||
(
|
||||
type_ in cls.__type_registry__ or (
|
||||
isinstance(type_, str) and
|
||||
re.sub("[\W_]", "", type_.lower())
|
||||
re.sub("[\\W_]", "", type_.lower())
|
||||
in cls.__type_registry__
|
||||
)
|
||||
):
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
from ray.rllib.utils.schedules.constant_schedule import ConstantSchedule
|
||||
from ray.rllib.utils.schedules.linear_schedule import LinearSchedule
|
||||
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
|
||||
from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
|
||||
from ray.rllib.utils.schedules.exponential_schedule import ExponentialSchedule
|
||||
|
||||
__all__ = [
|
||||
"ConstantSchedule", "ExponentialSchedule", "LinearSchedule", "Schedule",
|
||||
"PiecewiseSchedule", "PolynomialSchedule"
|
||||
]
|
||||
@@ -0,0 +1,18 @@
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
|
||||
|
||||
class ConstantSchedule(Schedule):
|
||||
"""
|
||||
A Schedule where the value remains constant over time.
|
||||
"""
|
||||
|
||||
def __init__(self, value, framework=None):
|
||||
"""
|
||||
Args:
|
||||
value (float): The constant value to return, independently of time.
|
||||
"""
|
||||
super().__init__(framework=None)
|
||||
self._v = value
|
||||
|
||||
def value(self, t=None):
|
||||
return self._v
|
||||
@@ -0,0 +1,37 @@
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
|
||||
|
||||
class ExponentialSchedule(Schedule):
|
||||
def __init__(self,
|
||||
schedule_timesteps,
|
||||
initial_p=1.0,
|
||||
decay_rate=0.1,
|
||||
framework=None):
|
||||
"""
|
||||
Exponential decay schedule from initial_p to final_p over
|
||||
schedule_timesteps. After this many time steps always `final_p` is
|
||||
returned.
|
||||
|
||||
Agrs:
|
||||
schedule_timesteps (int): Number of time steps for which to
|
||||
linearly anneal initial_p to final_p
|
||||
initial_p (float): Initial output value.
|
||||
decay_rate (float): The percentage of the original value after
|
||||
100% of the time has been reached (see formula above).
|
||||
>0.0: The smaller the decay-rate, the stronger the decay.
|
||||
1.0: No decay at all.
|
||||
framework (Optional[str]): One of "tf", "torch", or None.
|
||||
"""
|
||||
super().__init__(framework=framework)
|
||||
assert schedule_timesteps > 0
|
||||
self.schedule_timesteps = schedule_timesteps
|
||||
self.initial_p = initial_p
|
||||
self.decay_rate = decay_rate
|
||||
|
||||
def value(self, t):
|
||||
"""
|
||||
Returns the result of:
|
||||
initial_p * decay_rate ** (`t`/t_max)
|
||||
"""
|
||||
return self.initial_p * \
|
||||
self.decay_rate ** (t / self.schedule_timesteps)
|
||||
@@ -0,0 +1,13 @@
|
||||
from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
|
||||
|
||||
|
||||
class LinearSchedule(PolynomialSchedule):
|
||||
"""
|
||||
Linear interpolation between `initial_p` and `final_p`. Simply
|
||||
uses Polynomial with power=1.0.
|
||||
|
||||
final_p + (initial_p - final_p) * (1 - `t`/t_max)
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(power=1.0, **kwargs)
|
||||
@@ -0,0 +1,54 @@
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
|
||||
|
||||
def _linear_interpolation(l, r, alpha):
|
||||
return l + alpha * (r - l)
|
||||
|
||||
|
||||
class PiecewiseSchedule(Schedule):
|
||||
def __init__(self,
|
||||
endpoints,
|
||||
interpolation=_linear_interpolation,
|
||||
outside_value=None,
|
||||
framework=None):
|
||||
"""
|
||||
Args:
|
||||
endpoints (List[Tuple[int,float]]): A list of tuples
|
||||
`(t, value)` such that the output
|
||||
is an interpolation (given by the `interpolation` callable)
|
||||
between two values.
|
||||
E.g.
|
||||
t=400 and endpoints=[(0, 20.0),(500, 30.0)]
|
||||
output=20.0 + 0.8 * 10.0 = 28.0
|
||||
NOTE: All the values for time must be sorted in an increasing
|
||||
order.
|
||||
|
||||
interpolation (callable): A function that takes the left-value,
|
||||
the right-value and an alpha interpolation parameter
|
||||
(0.0=only left value, 1.0=only right value), which is the
|
||||
fraction of distance from left endpoint to right endpoint.
|
||||
|
||||
outside_value (Optional[float]): If t_pct in call to `value` is
|
||||
outside of all the intervals in `endpoints` this value is
|
||||
returned. If None then an AssertionError is raised when outside
|
||||
value is requested.
|
||||
"""
|
||||
# TODO(sven): support tf.
|
||||
assert framework is None
|
||||
super().__init__(framework=None)
|
||||
|
||||
idxes = [e[0] for e in endpoints]
|
||||
assert idxes == sorted(idxes)
|
||||
self.interpolation = interpolation
|
||||
self.outside_value = outside_value
|
||||
self.endpoints = endpoints
|
||||
|
||||
def value(self, t):
|
||||
for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
|
||||
if l_t <= t < r_t:
|
||||
alpha = float(t - l_t) / (r_t - l_t)
|
||||
return self.interpolation(l, r, alpha)
|
||||
|
||||
# t does not belong to any of the pieces, so doom.
|
||||
assert self.outside_value is not None
|
||||
return self.outside_value
|
||||
@@ -0,0 +1,46 @@
|
||||
from ray.rllib.utils.schedules.schedule import Schedule
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class PolynomialSchedule(Schedule):
|
||||
def __init__(self,
|
||||
schedule_timesteps,
|
||||
final_p,
|
||||
initial_p=1.0,
|
||||
power=2.0,
|
||||
framework=None):
|
||||
"""
|
||||
Polynomial interpolation between initial_p and final_p over
|
||||
schedule_timesteps. After this many time steps always `final_p` is
|
||||
returned.
|
||||
|
||||
Agrs:
|
||||
schedule_timesteps (int): Number of time steps for which to
|
||||
linearly anneal initial_p to final_p
|
||||
final_p (float): Final output value.
|
||||
initial_p (float): Initial output value.
|
||||
framework (Optional[str]): One of "tf", "torch", or None.
|
||||
"""
|
||||
super().__init__(framework=framework)
|
||||
assert schedule_timesteps > 0
|
||||
self.schedule_timesteps = schedule_timesteps
|
||||
self.final_p = final_p
|
||||
self.initial_p = initial_p
|
||||
self.power = power
|
||||
|
||||
def value(self, t):
|
||||
"""
|
||||
Returns the result of:
|
||||
final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
|
||||
"""
|
||||
if self.framework == "tf" and tf.executing_eagerly() is False:
|
||||
return tf.train.polynomial_decay(
|
||||
learning_rate=self.initial_p,
|
||||
global_step=t,
|
||||
decay_steps=self.schedule_timesteps,
|
||||
end_learning_rate=self.final_p,
|
||||
power=self.power)
|
||||
return self.final_p + (self.initial_p - self.final_p) * (
|
||||
1.0 - (t / self.schedule_timesteps))**self.power
|
||||
@@ -0,0 +1,46 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
from ray.rllib.utils.framework import check_framework
|
||||
|
||||
|
||||
class Schedule(metaclass=ABCMeta):
|
||||
"""
|
||||
Schedule classes implement various time-dependent scheduling schemas, such
|
||||
as:
|
||||
- Constant behavior.
|
||||
- Linear decay.
|
||||
- Piecewise decay.
|
||||
|
||||
Useful for backend-agnostic rate/weight changes for learning rates,
|
||||
exploration epsilons, beta parameters for prioritized replay, loss weights
|
||||
decay, etc..
|
||||
|
||||
Each schedule can be called directly with the `t` (absolute time step)
|
||||
value and returns the value dependent on the Schedule and the passed time.
|
||||
"""
|
||||
|
||||
def __init__(self, framework=None):
|
||||
# TODO(sven): replace with .tf_value() / torch_value() methods that
|
||||
# can be applied late binding, so no need to set framework during
|
||||
# construction.
|
||||
self.framework = check_framework(framework)
|
||||
|
||||
@abstractmethod
|
||||
def value(self, t):
|
||||
"""
|
||||
Returns the value based on a time value.
|
||||
|
||||
Args:
|
||||
t (int): The time value (e.g. a time step).
|
||||
NOTE: This could be a tf.Tensor.
|
||||
|
||||
Returns:
|
||||
any: The calculated value depending on the schedule and `t`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, t):
|
||||
"""
|
||||
Simply calls `self.value(t)`.
|
||||
"""
|
||||
return self.value(t)
|
||||
@@ -0,0 +1,85 @@
|
||||
import unittest
|
||||
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, \
|
||||
LinearSchedule, ExponentialSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils import check, try_import_tf
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class TestSchedules(unittest.TestCase):
|
||||
"""
|
||||
Tests all time-step/time-percentage dependent Schedule classes.
|
||||
"""
|
||||
|
||||
def test_constant_schedule(self):
|
||||
value = 2.3
|
||||
ts = [100, 0, 10, 2, 3, 4, 99, 56, 10000, 23, 234, 56]
|
||||
|
||||
for fw in ["tf", "torch", None]:
|
||||
constant = from_config(ConstantSchedule,
|
||||
dict(value=value, framework=fw))
|
||||
for t in ts:
|
||||
out = constant(t)
|
||||
check(out, value)
|
||||
|
||||
def test_linear_schedule(self):
|
||||
ts = [0, 50, 10, 100, 90, 2, 1, 99, 23]
|
||||
for fw in ["tf", "torch", None]:
|
||||
linear = from_config(
|
||||
LinearSchedule, {
|
||||
"schedule_timesteps": 100,
|
||||
"initial_p": 2.1,
|
||||
"final_p": 0.6,
|
||||
"framework": fw
|
||||
})
|
||||
if fw == "tf":
|
||||
tf.enable_eager_execution()
|
||||
for t in ts:
|
||||
out = linear(t)
|
||||
check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
|
||||
|
||||
def test_polynomial_schedule(self):
|
||||
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
|
||||
for fw in ["tf", "torch", None]:
|
||||
polynomial = from_config(
|
||||
dict(
|
||||
type="ray.rllib.utils.schedules.polynomial_schedule."
|
||||
"PolynomialSchedule",
|
||||
schedule_timesteps=100,
|
||||
initial_p=2.0,
|
||||
final_p=0.5,
|
||||
power=2.0,
|
||||
framework=fw))
|
||||
if fw == "tf":
|
||||
tf.enable_eager_execution()
|
||||
for t in ts:
|
||||
out = polynomial(t)
|
||||
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
|
||||
|
||||
def test_exponential_schedule(self):
|
||||
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
|
||||
for fw in ["tf", "torch", None]:
|
||||
exponential = from_config(
|
||||
ExponentialSchedule,
|
||||
dict(
|
||||
initial_p=2.0,
|
||||
decay_rate=0.99,
|
||||
schedule_timesteps=100,
|
||||
framework=fw))
|
||||
for t in ts:
|
||||
out = exponential(t)
|
||||
check(out, 2.0 * 0.99**(t / 100), decimals=4)
|
||||
|
||||
def test_piecewise_schedule(self):
|
||||
piecewise = from_config(
|
||||
PiecewiseSchedule,
|
||||
dict(
|
||||
endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)],
|
||||
outside_value=14.5))
|
||||
ts = [0, 5, 10, 100, 90, 2, 1, 99, 27]
|
||||
expected = [50.0, 60.0, 70.0, 14.5, 14.5, 54.0, 52.0, 14.5, 140.0]
|
||||
for t, e in zip(ts, expected):
|
||||
out = piecewise(t)
|
||||
check(out, e, decimals=4)
|
||||
@@ -66,7 +66,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
|
||||
assert bool(x) is bool(y), \
|
||||
"ERROR: x ({}) is not y ({})!".format(x, y)
|
||||
# Nones or primitives.
|
||||
elif x is None or y is None or isinstance(x, (str, int, float)):
|
||||
elif x is None or y is None or isinstance(x, (str, int)):
|
||||
if false is True:
|
||||
assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user