mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 09:12:56 +08:00
[rllib] Fix tune.run(Agent class) (#4630)
* update * Update __init__.py
This commit is contained in:
committed by
Devin Petersohn
parent
56a78baf67
commit
3e234fe937
@@ -0,0 +1,15 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ppo import PPOAgent
|
||||
from ray import tune
|
||||
import ray
|
||||
|
||||
if __name__ == "__main__":
|
||||
ray.init()
|
||||
# Test legacy *Agent classes work (renamed to Trainer)
|
||||
tune.run(
|
||||
PPOAgent,
|
||||
config={"env": "CartPole-v0"},
|
||||
stop={"training_iteration": 2})
|
||||
@@ -10,14 +10,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def renamed_class(cls):
|
||||
"""Helper class for renaming Agent => Trainer with a warning."""
|
||||
|
||||
class DeprecationWrapper(cls):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
old_name = cls.__name__.replace("Trainer", "Agent")
|
||||
new_name = cls.__name__
|
||||
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
|
||||
format(old_name, new_name) +
|
||||
"This will raise an error in the future.")
|
||||
cls.__init__(self, *args, **kwargs)
|
||||
cls.__init__(self, config, env, logger_creator)
|
||||
|
||||
DeprecationWrapper.__name__ = cls.__name__
|
||||
|
||||
return DeprecationWrapper
|
||||
|
||||
|
||||
Reference in New Issue
Block a user