mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:16:19 +08:00
[RLlib] Trainer._validate_config idempotentcy correction (issue 13427) (#13556)
This commit is contained in:
@@ -1509,6 +1509,13 @@ py_test(
|
||||
srcs = ["tests/test_timesteps.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tests/test_trainer",
|
||||
tags = ["tests_dir", "tests_dir_T"],
|
||||
size = "small",
|
||||
srcs = ["tests/test_trainer.py"]
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# examples/ directory
|
||||
#
|
||||
|
||||
@@ -1094,7 +1094,7 @@ class Trainer(Trainable):
|
||||
if model_config.get("_time_major"):
|
||||
raise ValueError("`model._time_major` only supported "
|
||||
"iff `_use_trajectory_view_api` is True!")
|
||||
elif traj_view_framestacks != "auto":
|
||||
elif traj_view_framestacks not in ["auto", 0]:
|
||||
raise ValueError("`model.num_framestacks` only supported "
|
||||
"iff `_use_trajectory_view_api` is True!")
|
||||
model_config["num_framestacks"] = 0
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Testing for trainer class"""
|
||||
import copy
|
||||
import unittest
|
||||
from ray.rllib.agents.trainer import Trainer, COMMON_CONFIG
|
||||
|
||||
|
||||
class TestTrainer(unittest.TestCase):
|
||||
def test_validate_config_idempotent(self):
|
||||
"""
|
||||
Asserts that validate_config run multiple
|
||||
times on COMMON_CONFIG will be idempotent
|
||||
"""
|
||||
# Given
|
||||
standard_config = copy.deepcopy(COMMON_CONFIG)
|
||||
standard_config["_use_trajectory_view_api"] = False
|
||||
|
||||
# When (we validate config 2 times)
|
||||
Trainer._validate_config(standard_config)
|
||||
config_v1 = copy.deepcopy(standard_config)
|
||||
Trainer._validate_config(standard_config)
|
||||
config_v2 = copy.deepcopy(standard_config)
|
||||
|
||||
# Then
|
||||
self.assertEqual(config_v1, config_v2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user