[RLlib] Trainer._validate_config idempotentcy correction (issue 13427) (#13556)

This commit is contained in:
Raoul Khouri
2021-02-02 07:11:57 -05:00
committed by GitHub
parent 0c93bb77cb
commit 714c367b9d
3 changed files with 38 additions and 1 deletions
+7
View File
@@ -1509,6 +1509,13 @@ py_test(
srcs = ["tests/test_timesteps.py"]
)
py_test(
name = "tests/test_trainer",
tags = ["tests_dir", "tests_dir_T"],
size = "small",
srcs = ["tests/test_trainer.py"]
)
# --------------------------------------------------------------------
# examples/ directory
#
+1 -1
View File
@@ -1094,7 +1094,7 @@ class Trainer(Trainable):
if model_config.get("_time_major"):
raise ValueError("`model._time_major` only supported "
"iff `_use_trajectory_view_api` is True!")
elif traj_view_framestacks != "auto":
elif traj_view_framestacks not in ["auto", 0]:
raise ValueError("`model.num_framestacks` only supported "
"iff `_use_trajectory_view_api` is True!")
model_config["num_framestacks"] = 0
+30
View File
@@ -0,0 +1,30 @@
"""Testing for trainer class"""
import copy
import unittest
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
class TestTrainer(unittest.TestCase):
def test_validate_config_idempotent(self):
"""
Asserts that validate_config run multiple
times on COMMON_CONFIG will be idempotent
"""
# Given
standard_config = copy.deepcopy(COMMON_CONFIG)
standard_config["_use_trajectory_view_api"] = False
# When (we validate config 2 times)
Trainer._validate_config(standard_config)
config_v1 = copy.deepcopy(standard_config)
Trainer._validate_config(standard_config)
config_v2 = copy.deepcopy(standard_config)
# Then
self.assertEqual(config_v1, config_v2)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))