From 714c367b9d4b7d220ef6f427882f8cdec0d3348b Mon Sep 17 00:00:00 2001 From: Raoul Khouri <69156393+raoul-khour-ts@users.noreply.github.com> Date: Tue, 2 Feb 2021 07:11:57 -0500 Subject: [PATCH] [RLlib] Trainer._validate_config idempotentcy correction (issue 13427) (#13556) --- rllib/BUILD | 7 +++++++ rllib/agents/trainer.py | 2 +- rllib/tests/test_trainer.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 rllib/tests/test_trainer.py diff --git a/rllib/BUILD b/rllib/BUILD index dd1d4c163..9658983ab 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1509,6 +1509,13 @@ py_test( srcs = ["tests/test_timesteps.py"] ) +py_test( + name = "tests/test_trainer", + tags = ["tests_dir", "tests_dir_T"], + size = "small", + srcs = ["tests/test_trainer.py"] +) + # -------------------------------------------------------------------- # examples/ directory # diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 47e637f6d..65e315a1d 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -1094,7 +1094,7 @@ class Trainer(Trainable): if model_config.get("_time_major"): raise ValueError("`model._time_major` only supported " "iff `_use_trajectory_view_api` is True!") - elif traj_view_framestacks != "auto": + elif traj_view_framestacks not in ["auto", 0]: raise ValueError("`model.num_framestacks` only supported " "iff `_use_trajectory_view_api` is True!") model_config["num_framestacks"] = 0 diff --git a/rllib/tests/test_trainer.py b/rllib/tests/test_trainer.py new file mode 100644 index 000000000..7555c27c5 --- /dev/null +++ b/rllib/tests/test_trainer.py @@ -0,0 +1,30 @@ +"""Testing for trainer class""" +import copy +import unittest +from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG + + +class TestTrainer(unittest.TestCase): + def test_validate_config_idempotent(self): + """ + Asserts that validate_config run multiple + times on COMMON_CONFIG will be idempotent + """ + # Given + standard_config = copy.deepcopy(COMMON_CONFIG) + standard_config["_use_trajectory_view_api"] = False + + # When (we validate config 2 times) + Trainer._validate_config(standard_config) + config_v1 = copy.deepcopy(standard_config) + Trainer._validate_config(standard_config) + config_v2 = copy.deepcopy(standard_config) + + # Then + self.assertEqual(config_v1, config_v2) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__]))