[RLlib] Add tensor-based tests for Schedules and fix some bugs related to using Schedules with tensor time input. (#9782)

This commit is contained in:
Sven Mika
2020-07-30 12:49:32 +02:00
committed by GitHub
parent 372114b4ed
commit f6bd12eb18
5 changed files with 102 additions and 34 deletions
@@ -1,6 +1,9 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.schedules.schedule import Schedule
tf1, tf, tfv = try_import_tf()
class ConstantSchedule(Schedule):
"""
@@ -18,3 +21,7 @@ class ConstantSchedule(Schedule):
@override(Schedule)
def _value(self, t):
    """Returns the fixed value held by this schedule.

    Args:
        t: The time step. Not used; the stored value is returned as-is.

    Returns:
        The value stored in `self._v`.
    """
    return self._v
@override(Schedule)
def _tf_value_op(self, t):
    """Returns the fixed value of this schedule as a `tf.constant`.

    Args:
        t: The time step. Not used; the op does not depend on it.

    Returns:
        The result of `tf.constant(self._v)`.
    """
    constant_value = self._v
    return tf.constant(constant_value)
@@ -1,6 +1,9 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.schedules.schedule import Schedule
torch, _ = try_import_torch()
class ExponentialSchedule(Schedule):
def __init__(self,
@@ -33,5 +36,7 @@ class ExponentialSchedule(Schedule):
def _value(self, t):
"""Returns the result of: initial_p * decay_rate ** (`t`/t_max)
"""
if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
t = t.float()
return self.initial_p * \
self.decay_rate ** (t / self.schedule_timesteps)
+18 -5
View File
@@ -1,7 +1,12 @@
from typing import Union
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.schedules.schedule import Schedule
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.types import TensorType
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
class PolynomialSchedule(Schedule):
@@ -13,7 +18,7 @@ class PolynomialSchedule(Schedule):
power=2.0):
"""
Polynomial interpolation between initial_p and final_p over
schedule_timesteps. After this many time steps always `final_p` is
schedule_timesteps. After this many time steps, always `final_p` is
returned.
Args:
@@ -30,11 +35,19 @@ class PolynomialSchedule(Schedule):
self.initial_p = initial_p
self.power = power
@override(Schedule)
def _value(self, t: Union[int, TensorType]):
    """Returns the result of:
    final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power

    Args:
        t: The time step, either a plain int or a framework tensor.

    Returns:
        The interpolated value at `t`, where `t` is capped at
        `schedule_timesteps` before the interpolation is evaluated.
    """
    if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
        # Convert to a float tensor before the division below.
        t = t.float()
    # Cap the time step so that the (1 - t/t_max) term bottoms out at 0.
    t = min(t, self.schedule_timesteps)
    return self.final_p + (self.initial_p - self.final_p) * (
        1.0 - (t / self.schedule_timesteps))**self.power
@override(Schedule)
def _tf_value_op(self, t: Union[int, TensorType]):
    """TensorFlow version of the polynomial interpolation.

    Builds: final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power,
    with `t` capped at `schedule_timesteps` via `tf.math.minimum`.

    Args:
        t: The time step (int or tensor).

    Returns:
        The interpolated schedule value for `t`.
    """
    capped_t = tf.math.minimum(t, self.schedule_timesteps)
    remaining = 1.0 - (capped_t / self.schedule_timesteps)
    span = self.initial_p - self.final_p
    return self.final_p + span * remaining**self.power
+2 -2
View File
@@ -35,7 +35,7 @@ class Schedule(metaclass=ABCMeta):
Returns:
any: The calculated value depending on the schedule and `t`.
"""
if self.framework in ["tf", "tfe"]:
if self.framework in ["tf2", "tf", "tfe"]:
return self._tf_value_op(t)
return self._value(t)
@@ -71,4 +71,4 @@ class Schedule(metaclass=ABCMeta):
"""
# By default (most of the time), tf should work with python code.
# Override only if necessary.
return tf.constant(self._value(t))
return self._value(t)
+70 -27
View File
@@ -2,16 +2,16 @@ import unittest
from ray.rllib.utils.schedules import ConstantSchedule, \
LinearSchedule, ExponentialSchedule, PiecewiseSchedule
from ray.rllib.utils import check, framework_iterator, try_import_tf
from ray.rllib.utils import check, framework_iterator, try_import_tf, \
try_import_torch
from ray.rllib.utils.from_config import from_config
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
class TestSchedules(unittest.TestCase):
"""
Tests all time-step dependent Schedule classes.
"""
"""Tests all time-step dependent Schedule classes."""
def test_constant_schedule(self):
value = 2.3
@@ -19,26 +19,41 @@ class TestSchedules(unittest.TestCase):
config = {"value": value}
for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]):
fw_ = fw if fw != "tfe" else "tf"
constant = from_config(ConstantSchedule, config, framework=fw_)
for fw in framework_iterator(
frameworks=["tf2", "tf", "tfe", "torch", None]):
constant = from_config(ConstantSchedule, config, framework=fw)
for t in ts:
out = constant(t)
check(out, value)
ts_as_tensors = self._get_framework_tensors(ts, fw)
for t in ts_as_tensors:
out = constant(t)
assert fw != "tf" or isinstance(out, tf.Tensor)
check(out, value, decimals=4)
def test_linear_schedule(self):
    """Checks LinearSchedule outputs for plain ints and framework tensors.

    Expected values are computed once up front; time steps beyond
    `schedule_timesteps` (100) are expected to yield `final_p`.
    """
    ts = [0, 50, 10, 100, 90, 2, 1, 99, 23, 1000]
    expected = [2.1 - (min(t, 100) / 100) * (2.1 - 0.6) for t in ts]
    config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6}

    for fw in framework_iterator(
            frameworks=["tf2", "tf", "tfe", "torch", None]):
        linear = from_config(LinearSchedule, config, framework=fw)
        # Plain python ints as time steps.
        for t, e in zip(ts, expected):
            out = linear(t)
            check(out, e, decimals=4)

        # Same time steps, passed in as framework-specific tensors.
        ts_as_tensors = self._get_framework_tensors(ts, fw)
        for t, e in zip(ts_as_tensors, expected):
            out = linear(t)
            assert fw != "tf" or isinstance(out, tf.Tensor)
            check(out, e, decimals=4)
def test_polynomial_schedule(self):
ts = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000]
expected = [
0.5 + (2.0 - 0.5) * (1.0 - min(t, 100) / 100)**2 for t in ts]
config = dict(
type="ray.rllib.utils.schedules.polynomial_schedule."
"PolynomialSchedule",
@@ -47,25 +62,39 @@ class TestSchedules(unittest.TestCase):
final_p=0.5,
power=2.0)
for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]):
fw_ = fw if fw != "tfe" else "tf"
polynomial = from_config(config, framework=fw_)
for t in ts:
for fw in framework_iterator(
frameworks=["tf2", "tf", "tfe", "torch", None]):
polynomial = from_config(config, framework=fw)
for t, e in zip(ts, expected):
out = polynomial(t)
t = min(t, 100)
check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4)
check(out, e, decimals=4)
ts_as_tensors = self._get_framework_tensors(ts, fw)
for t, e in zip(ts_as_tensors, expected):
out = polynomial(t)
assert fw != "tf" or isinstance(out, tf.Tensor)
check(out, e, decimals=4)
def test_exponential_schedule(self):
    """Checks ExponentialSchedule for plain ints and framework tensors.

    The same `decay_rate` is used for both the schedule config and the
    expected values, so the two cannot drift apart.
    """
    decay_rate = 0.2
    ts = [0, 5, 10, 100, 90, 2, 1, 99, 23]
    expected = [2.0 * decay_rate**(t / 100) for t in ts]
    config = dict(
        initial_p=2.0, decay_rate=decay_rate, schedule_timesteps=100)

    for fw in framework_iterator(
            frameworks=["tf2", "tf", "tfe", "torch", None]):
        exponential = from_config(
            ExponentialSchedule, config, framework=fw)
        # Plain python ints as time steps.
        for t, e in zip(ts, expected):
            out = exponential(t)
            check(out, e, decimals=4)

        # Same time steps, passed in as framework-specific tensors.
        ts_as_tensors = self._get_framework_tensors(ts, fw)
        for t, e in zip(ts_as_tensors, expected):
            out = exponential(t)
            assert fw != "tf" or isinstance(out, tf.Tensor)
            check(out, e, decimals=4)
def test_piecewise_schedule(self):
ts = [0, 5, 10, 100, 90, 2, 1, 99, 27]
@@ -74,13 +103,27 @@ class TestSchedules(unittest.TestCase):
endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)],
outside_value=14.5)
for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]):
fw_ = fw if fw != "tfe" else "tf"
piecewise = from_config(PiecewiseSchedule, config, framework=fw_)
for fw in framework_iterator(
frameworks=["tf2", "tf", "tfe", "torch", None]):
piecewise = from_config(PiecewiseSchedule, config, framework=fw)
for t, e in zip(ts, expected):
out = piecewise(t)
check(out, e, decimals=4)
ts_as_tensors = self._get_framework_tensors(ts, fw)
for t, e in zip(ts_as_tensors, expected):
out = piecewise(t)
assert fw != "tf" or isinstance(out, tf.Tensor)
check(out, e, decimals=4)
@staticmethod
def _get_framework_tensors(ts, fw):
if fw == "torch":
ts = [torch.tensor(t, dtype=torch.int32) for t in ts]
elif fw is not None and "tf" in fw:
ts = [tf.constant(t, dtype=tf.int32) for t in ts]
return ts
if __name__ == "__main__":
import pytest