diff --git a/rllib/utils/schedules/constant_schedule.py b/rllib/utils/schedules/constant_schedule.py index 34deeb1d9..20cdcda04 100644 --- a/rllib/utils/schedules/constant_schedule.py +++ b/rllib/utils/schedules/constant_schedule.py @@ -1,6 +1,9 @@ from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.schedules.schedule import Schedule +tf1, tf, tfv = try_import_tf() + class ConstantSchedule(Schedule): """ @@ -18,3 +21,7 @@ class ConstantSchedule(Schedule): @override(Schedule) def _value(self, t): return self._v + + @override(Schedule) + def _tf_value_op(self, t): + return tf.constant(self._v) diff --git a/rllib/utils/schedules/exponential_schedule.py b/rllib/utils/schedules/exponential_schedule.py index 0616cd8ef..aaef5101f 100644 --- a/rllib/utils/schedules/exponential_schedule.py +++ b/rllib/utils/schedules/exponential_schedule.py @@ -1,6 +1,9 @@ from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.schedules.schedule import Schedule +torch, _ = try_import_torch() + class ExponentialSchedule(Schedule): def __init__(self, @@ -33,5 +36,7 @@ class ExponentialSchedule(Schedule): def _value(self, t): """Returns the result of: initial_p * decay_rate ** (`t`/t_max) """ + if self.framework == "torch" and torch and isinstance(t, torch.Tensor): + t = t.float() return self.initial_p * \ self.decay_rate ** (t / self.schedule_timesteps) diff --git a/rllib/utils/schedules/polynomial_schedule.py b/rllib/utils/schedules/polynomial_schedule.py index b6402da80..ba54ec542 100644 --- a/rllib/utils/schedules/polynomial_schedule.py +++ b/rllib/utils/schedules/polynomial_schedule.py @@ -1,7 +1,12 @@ +from typing import Union + +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.schedules.schedule import Schedule -from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.types import TensorType tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() class PolynomialSchedule(Schedule): @@ -13,7 +18,7 @@ class PolynomialSchedule(Schedule): power=2.0): """ Polynomial interpolation between initial_p and final_p over - schedule_timesteps. After this many time steps always `final_p` is + schedule_timesteps. After this many time steps, always `final_p` is returned. Agrs: @@ -30,11 +35,19 @@ class PolynomialSchedule(Schedule): self.initial_p = initial_p self.power = power - def _value(self, t): - """ - Returns the result of: + @override(Schedule) + def _value(self, t: Union[int, TensorType]): + """Returns the result of: final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power """ + if self.framework == "torch" and torch and isinstance(t, torch.Tensor): + t = t.float() t = min(t, self.schedule_timesteps) return self.final_p + (self.initial_p - self.final_p) * ( 1.0 - (t / self.schedule_timesteps))**self.power + + @override(Schedule) + def _tf_value_op(self, t: Union[int, TensorType]): + t = tf.math.minimum(t, self.schedule_timesteps) + return self.final_p + (self.initial_p - self.final_p) * ( + 1.0 - (t / self.schedule_timesteps))**self.power diff --git a/rllib/utils/schedules/schedule.py b/rllib/utils/schedules/schedule.py index ca21bd595..ea73fe4b0 100644 --- a/rllib/utils/schedules/schedule.py +++ b/rllib/utils/schedules/schedule.py @@ -35,7 +35,7 @@ class Schedule(metaclass=ABCMeta): Returns: any: The calculated value depending on the schedule and `t`. """ - if self.framework in ["tf", "tfe"]: + if self.framework in ["tf2", "tf", "tfe"]: return self._tf_value_op(t) return self._value(t) @@ -71,4 +71,4 @@ class Schedule(metaclass=ABCMeta): """ # By default (most of the time), tf should work with python code. # Override only if necessary. - return tf.constant(self._value(t)) + return self._value(t) diff --git a/rllib/utils/schedules/tests/test_schedules.py b/rllib/utils/schedules/tests/test_schedules.py index 0fed37092..16b52006b 100644 --- a/rllib/utils/schedules/tests/test_schedules.py +++ b/rllib/utils/schedules/tests/test_schedules.py @@ -2,16 +2,16 @@ import unittest from ray.rllib.utils.schedules import ConstantSchedule, \ LinearSchedule, ExponentialSchedule, PiecewiseSchedule -from ray.rllib.utils import check, framework_iterator, try_import_tf +from ray.rllib.utils import check, framework_iterator, try_import_tf, \ + try_import_torch from ray.rllib.utils.from_config import from_config tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() class TestSchedules(unittest.TestCase): - """ - Tests all time-step dependent Schedule classes. - """ + """Tests all time-step dependent Schedule classes.""" def test_constant_schedule(self): value = 2.3 @@ -19,26 +19,41 @@ class TestSchedules(unittest.TestCase): config = {"value": value} - for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]): - fw_ = fw if fw != "tfe" else "tf" - constant = from_config(ConstantSchedule, config, framework=fw_) + for fw in framework_iterator( + frameworks=["tf2", "tf", "tfe", "torch", None]): + constant = from_config(ConstantSchedule, config, framework=fw) for t in ts: out = constant(t) check(out, value) + ts_as_tensors = self._get_framework_tensors(ts, fw) + for t in ts_as_tensors: + out = constant(t) + assert fw != "tf" or isinstance(out, tf.Tensor) + check(out, value, decimals=4) + def test_linear_schedule(self): ts = [0, 50, 10, 100, 90, 2, 1, 99, 23, 1000] + expected = [2.1 - (min(t, 100) / 100) * (2.1 - 0.6) for t in ts] config = {"schedule_timesteps": 100, "initial_p": 2.1, "final_p": 0.6} - for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]): - fw_ = fw if fw != "tfe" else "tf" - linear = from_config(LinearSchedule, config, framework=fw_) - for t in ts: + for fw in framework_iterator( + frameworks=["tf2", "tf", "tfe", "torch", None]): + linear = from_config(LinearSchedule, config, framework=fw) + for t, e in zip(ts, expected): out = linear(t) - check(out, 2.1 - (min(t, 100) / 100) * (2.1 - 0.6), decimals=4) + check(out, e, decimals=4) + + ts_as_tensors = self._get_framework_tensors(ts, fw) + for t, e in zip(ts_as_tensors, expected): + out = linear(t) + assert fw != "tf" or isinstance(out, tf.Tensor) + check(out, e, decimals=4) def test_polynomial_schedule(self): ts = [0, 5, 10, 100, 90, 2, 1, 99, 23, 1000] + expected = [ + 0.5 + (2.0 - 0.5) * (1.0 - min(t, 100) / 100)**2 for t in ts] config = dict( type="ray.rllib.utils.schedules.polynomial_schedule." "PolynomialSchedule", @@ -47,25 +62,39 @@ class TestSchedules(unittest.TestCase): final_p=0.5, power=2.0) - for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]): - fw_ = fw if fw != "tfe" else "tf" - polynomial = from_config(config, framework=fw_) - for t in ts: + for fw in framework_iterator( + frameworks=["tf2", "tf", "tfe", "torch", None]): + polynomial = from_config(config, framework=fw) + for t, e in zip(ts, expected): out = polynomial(t) - t = min(t, 100) - check(out, 0.5 + (2.0 - 0.5) * (1.0 - t / 100)**2, decimals=4) + check(out, e, decimals=4) + + ts_as_tensors = self._get_framework_tensors(ts, fw) + for t, e in zip(ts_as_tensors, expected): + out = polynomial(t) + assert fw != "tf" or isinstance(out, tf.Tensor) + check(out, e, decimals=4) def test_exponential_schedule(self): + decay_rate = 0.2 ts = [0, 5, 10, 100, 90, 2, 1, 99, 23] - config = dict(initial_p=2.0, decay_rate=0.99, schedule_timesteps=100) + expected = [2.0 * decay_rate**(t / 100) for t in ts] + config = dict( + initial_p=2.0, decay_rate=decay_rate, schedule_timesteps=100) - for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]): - fw_ = fw if fw != "tfe" else "tf" + for fw in framework_iterator( + frameworks=["tf2", "tf", "tfe", "torch", None]): exponential = from_config( - ExponentialSchedule, config, framework=fw_) - for t in ts: + ExponentialSchedule, config, framework=fw) + for t, e in zip(ts, expected): out = exponential(t) - check(out, 2.0 * 0.99**(t / 100), decimals=4) + check(out, e, decimals=4) + + ts_as_tensors = self._get_framework_tensors(ts, fw) + for t, e in zip(ts_as_tensors, expected): + out = exponential(t) + assert fw != "tf" or isinstance(out, tf.Tensor) + check(out, e, decimals=4) def test_piecewise_schedule(self): ts = [0, 5, 10, 100, 90, 2, 1, 99, 27] @@ -74,13 +103,27 @@ class TestSchedules(unittest.TestCase): endpoints=[(0, 50.0), (25, 100.0), (30, 200.0)], outside_value=14.5) - for fw in framework_iterator(frameworks=["tf", "tfe", "torch", None]): - fw_ = fw if fw != "tfe" else "tf" - piecewise = from_config(PiecewiseSchedule, config, framework=fw_) + for fw in framework_iterator( + frameworks=["tf2", "tf", "tfe", "torch", None]): + piecewise = from_config(PiecewiseSchedule, config, framework=fw) for t, e in zip(ts, expected): out = piecewise(t) check(out, e, decimals=4) + ts_as_tensors = self._get_framework_tensors(ts, fw) + for t, e in zip(ts_as_tensors, expected): + out = piecewise(t) + assert fw != "tf" or isinstance(out, tf.Tensor) + check(out, e, decimals=4) + + @staticmethod + def _get_framework_tensors(ts, fw): + if fw == "torch": + ts = [torch.tensor(t, dtype=torch.int32) for t in ts] + elif fw is not None and "tf" in fw: + ts = [tf.constant(t, dtype=tf.int32) for t in ts] + return ts + if __name__ == "__main__": import pytest