[RLlib] Schedule-classes multi-framework support. (#6926)

This commit is contained in:
Sven Mika
2020-01-28 20:07:55 +01:00
committed by Eric Liang
parent 26d749bc18
commit 4c97348cb6
14 changed files with 371 additions and 10 deletions
+7
View File
@@ -41,3 +41,10 @@ py_test(
# size = "small",
# srcs = ["models/tests/test_distributions.py"]
#)
# Schedules
py_test(
name = "test_schedules",
size = "small",
srcs = ["utils/schedules/tests/test_schedules.py"]
)
+7
View File
@@ -11,6 +11,8 @@ from ray.rllib.utils.numpy import sigmoid, softmax, relu, one_hot, fc, lstm, \
SMALL_NUMBER, LARGE_INTEGER
from ray.rllib.utils.policy_client import PolicyClient
from ray.rllib.utils.policy_server import PolicyServer
from ray.rllib.utils.schedules import LinearSchedule, PiecewiseSchedule, \
PolynomialSchedule, ExponentialSchedule, ConstantSchedule
from ray.rllib.utils.test_utils import check
from ray.tune.utils import merge_dicts, deep_update
@@ -75,12 +77,17 @@ __all__ = [
"try_import_tf",
"try_import_tfp",
"try_import_torch",
"ConstantSchedule",
"DeveloperAPI",
"ExponentialSchedule",
"Filter",
"FilterManager",
"LARGE_INTEGER",
"LinearSchedule",
"PiecewiseSchedule",
"PolicyClient",
"PolicyServer",
"PolynomialSchedule",
"PublicAPI",
"SMALL_NUMBER",
]
+41 -6
View File
@@ -4,8 +4,31 @@ import os
logger = logging.getLogger(__name__)
def try_import_tf():
def check_framework(framework="tf"):
"""
Checks, whether the given framework is "valid", meaning, whether all
necessary dependencies are installed. Errors otherwise.
Args:
framework (str): Once of "tf", "torch", or None.
Returns:
str: The input framework string.
"""
if framework == "tf":
try_import_tf(error=True)
elif framework == "torch":
try_import_torch(error=True)
else:
assert framework is None
return framework
def try_import_tf(error=False):
"""
Args:
error (bool): Whether to raise an error if tf cannot be imported.
Returns:
The tf module (either from tf2.0.compat.v1 OR as tf1.x.
"""
@@ -24,12 +47,17 @@ def try_import_tf():
try:
import tensorflow as tf
return tf
except ImportError:
except ImportError as e:
if error:
raise e
return None
def try_import_tfp():
def try_import_tfp(error=False):
"""
Args:
error (bool): Whether to raise an error if tfp cannot be imported.
Returns:
The tfp module.
"""
@@ -41,12 +69,17 @@ def try_import_tfp():
try:
import tensorflow_probability as tfp
return tfp
except ImportError:
except ImportError as e:
if error:
raise e
return None
def try_import_torch():
def try_import_torch(error=False):
"""
Args:
error (bool): Whether to raise an error if torch cannot be imported.
Returns:
tuple: torch AND torch.nn modules.
"""
@@ -58,5 +91,7 @@ def try_import_torch():
import torch
import torch.nn as nn
return torch, nn
except ImportError:
except ImportError as e:
if error:
raise e
return None, None
+5 -3
View File
@@ -92,7 +92,8 @@ def from_config(cls, config=None, **kwargs):
if type_ is None:
# We have a default constructor that was defined directly by cls
# (not by its children).
if cls is not None and cls.__default_constructor__ is not None and \
if cls is not None and hasattr(cls, "__default_constructor__") and \
cls.__default_constructor__ is not None and \
ctor_args == [] and \
(
not hasattr(cls.__bases__[0], "__default_constructor__")
@@ -199,11 +200,12 @@ def from_file(cls, filename, *args, **kwargs):
def lookup_type(cls, type_):
if cls is not None and isinstance(cls.__type_registry__, dict) and \
if cls is not None and hasattr(cls, "__type_registry__") and \
isinstance(cls.__type_registry__, dict) and \
(
type_ in cls.__type_registry__ or (
isinstance(type_, str) and
re.sub("[\W_]", "", type_.lower())
re.sub("[\\W_]", "", type_.lower())
in cls.__type_registry__
)
):
+11
View File
@@ -0,0 +1,11 @@
from ray.rllib.utils.schedules.schedule import Schedule
from ray.rllib.utils.schedules.constant_schedule import ConstantSchedule
from ray.rllib.utils.schedules.linear_schedule import LinearSchedule
from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
from ray.rllib.utils.schedules.exponential_schedule import ExponentialSchedule
__all__ = [
"ConstantSchedule", "ExponentialSchedule", "LinearSchedule", "Schedule",
"PiecewiseSchedule", "PolynomialSchedule"
]
@@ -0,0 +1,18 @@
from ray.rllib.utils.schedules.schedule import Schedule
class ConstantSchedule(Schedule):
"""
A Schedule where the value remains constant over time.
"""
def __init__(self, value, framework=None):
"""
Args:
value (float): The constant value to return, independently of time.
"""
super().__init__(framework=None)
self._v = value
def value(self, t=None):
return self._v
@@ -0,0 +1,37 @@
from ray.rllib.utils.schedules.schedule import Schedule
class ExponentialSchedule(Schedule):
def __init__(self,
schedule_timesteps,
initial_p=1.0,
decay_rate=0.1,
framework=None):
"""
Exponential decay schedule from initial_p to final_p over
schedule_timesteps. After this many time steps always `final_p` is
returned.
Agrs:
schedule_timesteps (int): Number of time steps for which to
linearly anneal initial_p to final_p
initial_p (float): Initial output value.
decay_rate (float): The percentage of the original value after
100% of the time has been reached (see formula above).
>0.0: The smaller the decay-rate, the stronger the decay.
1.0: No decay at all.
framework (Optional[str]): One of "tf", "torch", or None.
"""
super().__init__(framework=framework)
assert schedule_timesteps > 0
self.schedule_timesteps = schedule_timesteps
self.initial_p = initial_p
self.decay_rate = decay_rate
def value(self, t):
"""
Returns the result of:
initial_p * decay_rate ** (`t`/t_max)
"""
return self.initial_p * \
self.decay_rate ** (t / self.schedule_timesteps)
+13
View File
@@ -0,0 +1,13 @@
from ray.rllib.utils.schedules.polynomial_schedule import PolynomialSchedule
class LinearSchedule(PolynomialSchedule):
"""
Linear interpolation between `initial_p` and `final_p`. Simply
uses Polynomial with power=1.0.
final_p + (initial_p - final_p) * (1 - `t`/t_max)
"""
def __init__(self, **kwargs):
super().__init__(power=1.0, **kwargs)
@@ -0,0 +1,54 @@
from ray.rllib.utils.schedules.schedule import Schedule
def _linear_interpolation(l, r, alpha):
return l + alpha * (r - l)
class PiecewiseSchedule(Schedule):
def __init__(self,
endpoints,
interpolation=_linear_interpolation,
outside_value=None,
framework=None):
"""
Args:
endpoints (List[Tuple[int,float]]): A list of tuples
`(t, value)` such that the output
is an interpolation (given by the `interpolation` callable)
between two values.
E.g.
t=400 and endpoints=[(0, 20.0),(500, 30.0)]
output=20.0 + 0.8 * 10.0 = 28.0
NOTE: All the values for time must be sorted in an increasing
order.
interpolation (callable): A function that takes the left-value,
the right-value and an alpha interpolation parameter
(0.0=only left value, 1.0=only right value), which is the
fraction of distance from left endpoint to right endpoint.
outside_value (Optional[float]): If t_pct in call to `value` is
outside of all the intervals in `endpoints` this value is
returned. If None then an AssertionError is raised when outside
value is requested.
"""
# TODO(sven): support tf.
assert framework is None
super().__init__(framework=None)
idxes = [e[0] for e in endpoints]
assert idxes == sorted(idxes)
self.interpolation = interpolation
self.outside_value = outside_value
self.endpoints = endpoints
def value(self, t):
for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
if l_t <= t < r_t:
alpha = float(t - l_t) / (r_t - l_t)
return self.interpolation(l, r, alpha)
# t does not belong to any of the pieces, so doom.
assert self.outside_value is not None
return self.outside_value
@@ -0,0 +1,46 @@
from ray.rllib.utils.schedules.schedule import Schedule
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
class PolynomialSchedule(Schedule):
def __init__(self,
schedule_timesteps,
final_p,
initial_p=1.0,
power=2.0,
framework=None):
"""
Polynomial interpolation between initial_p and final_p over
schedule_timesteps. After this many time steps always `final_p` is
returned.
Agrs:
schedule_timesteps (int): Number of time steps for which to
linearly anneal initial_p to final_p
final_p (float): Final output value.
initial_p (float): Initial output value.
framework (Optional[str]): One of "tf", "torch", or None.
"""
super().__init__(framework=framework)
assert schedule_timesteps > 0
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
self.power = power
def value(self, t):
"""
Returns the result of:
final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
"""
if self.framework == "tf" and tf.executing_eagerly() is False:
return tf.train.polynomial_decay(
learning_rate=self.initial_p,
global_step=t,
decay_steps=self.schedule_timesteps,
end_learning_rate=self.final_p,
power=self.power)
return self.final_p + (self.initial_p - self.final_p) * (
1.0 - (t / self.schedule_timesteps))**self.power
+46
View File
@@ -0,0 +1,46 @@
from abc import ABCMeta, abstractmethod
from ray.rllib.utils.framework import check_framework
class Schedule(metaclass=ABCMeta):
"""
Schedule classes implement various time-dependent scheduling schemas, such
as:
- Constant behavior.
- Linear decay.
- Piecewise decay.
Useful for backend-agnostic rate/weight changes for learning rates,
exploration epsilons, beta parameters for prioritized replay, loss weights
decay, etc..
Each schedule can be called directly with the `t` (absolute time step)
value and returns the value dependent on the Schedule and the passed time.
"""
def __init__(self, framework=None):
# TODO(sven): replace with .tf_value() / torch_value() methods that
# can be applied late binding, so no need to set framework during
# construction.
self.framework = check_framework(framework)
@abstractmethod
def value(self, t):
"""
Returns the value based on a time value.
Args:
t (int): The time value (e.g. a time step).
NOTE: This could be a tf.Tensor.
Returns:
any: The calculated value depending on the schedule and `t`.
"""
raise NotImplementedError
def __call__(self, t):
"""
Simply calls `self.value(t)`.
"""
return self.value(t)
@@ -0,0 +1,85 @@
import unittest
from ray.rllib.utils.schedules import ConstantSchedule, \
LinearSchedule, ExponentialSchedule, PiecewiseSchedule
from ray.rllib.utils import check, try_import_tf
from ray.rllib.utils.from_config import from_config
tf = try_import_tf()
class TestSchedules(unittest.TestCase):
"""
Tests all time-step/time-percentage dependent Schedule classes.
"""
def test_constant_schedule(self):
value = 2.3
ts = [100, 0, 10, 2, 3, 4, 99, 56, 10000, 23, 234, 56]
for fw in ["tf", "torch", None]:
constant = from_config(ConstantSchedule,
dict(value=value, framework=fw))
for t in ts:
out = constant(t)
check(out, value)
def test_linear_schedule(self):
ts = [0, 50, 10, 100, 90, 2, 1, 99, 23]
for fw in ["tf", "torch", None]:
linear = from_config(
LinearSchedule, {
"schedule_timesteps": 100,
"initial_p": 2.1,
"final_p": 0.6,
"framework": fw
})
if fw == "tf":
tf.enable_eager_execution()
for t in ts:
out = linear(t)
check(out, 2.1 - (t / 100) * (2.1 - 0.6), decimals=4)
def test_polynomial_schedule(self):
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
for fw in ["tf", "torch", None]:
polynomial = from_config(
dict(
type="ray.rllib.utils.schedules.polynomial_schedule."
"PolynomialSchedule",
schedule_timesteps=100,
initial_p=2.0,
final_p=0.5,
power=2.0,
framework=fw))
if fw == "tf":
tf.enable_eager_execution()
for t in ts:
out = polynomial(t)
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
def test_exponential_schedule(self):
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
for fw in ["tf", "torch", None]:
exponential = from_config(
ExponentialSchedule,
dict(
initial_p=2.0,
decay_rate=0.99,
schedule_timesteps=100,
framework=fw))
for t in ts:
out = exponential(t)
check(out, 2.0 * 0.99**(t / 100), decimals=4)
def test_piecewise_schedule(self):
piecewise = from_config(
PiecewiseSchedule,
dict(
endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)],
outside_value=14.5))
ts = [0, 5, 10, 100, 90, 2, 1, 99, 27]
expected = [50.0, 60.0, 70.0, 14.5, 14.5, 54.0, 52.0, 14.5, 140.0]
for t, e in zip(ts, expected):
out = piecewise(t)
check(out, e, decimals=4)
+1 -1
View File
@@ -66,7 +66,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False):
assert bool(x) is bool(y), \
"ERROR: x ({}) is not y ({})!".format(x, y)
# Nones or primitives.
elif x is None or y is None or isinstance(x, (str, int, float)):
elif x is None or y is None or isinstance(x, (str, int)):
if false is True:
assert x != y, "ERROR: x ({}) is the same as y ({})!".format(x, y)
else: