mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 23:46:50 +08:00
Remove from X import Y convention in RLlib ES. (#1774)
This commit is contained in:
committed by
Philipp Moritz
parent
13b3df9321
commit
10dabce4d7
@@ -220,7 +220,7 @@ class _ParameterTuningAgent(_MockAgent):
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of an known agent given its name."""
|
||||
"""Returns the class of a known agent given its name."""
|
||||
|
||||
if alg == "PPO":
|
||||
from ray.rllib import ppo
|
||||
|
||||
@@ -12,14 +12,12 @@ import pickle
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.rllib.agent import Agent
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib import agent
|
||||
|
||||
from ray.rllib.es import optimizers
|
||||
from ray.rllib.es import policies
|
||||
from ray.rllib.es import tabular_logger as tlogger
|
||||
from ray.rllib.es import utils
|
||||
from ray.tune.result import TrainingResult
|
||||
|
||||
|
||||
Result = namedtuple("Result", [
|
||||
@@ -72,7 +70,9 @@ class Worker(object):
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = env_creator(config["env_config"])
|
||||
self.preprocessor = ModelCatalog.get_preprocessor(registry, self.env)
|
||||
from ray.rllib import models
|
||||
self.preprocessor = models.ModelCatalog.get_preprocessor(
|
||||
registry, self.env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.GenericPolicy(
|
||||
@@ -133,7 +133,7 @@ class Worker(object):
|
||||
eval_lengths=eval_lengths)
|
||||
|
||||
|
||||
class ESAgent(Agent):
|
||||
class ESAgent(agent.Agent):
|
||||
_agent_name = "ES"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_allow_unknown_subkeys = ["env_config"]
|
||||
@@ -144,7 +144,9 @@ class ESAgent(Agent):
|
||||
}
|
||||
|
||||
env = self.env_creator(self.config["env_config"])
|
||||
preprocessor = ModelCatalog.get_preprocessor(self.registry, env)
|
||||
from ray.rllib import models
|
||||
preprocessor = models.ModelCatalog.get_preprocessor(
|
||||
self.registry, env)
|
||||
|
||||
self.sess = utils.make_session(single_threaded=False)
|
||||
self.policy = policies.GenericPolicy(
|
||||
@@ -292,7 +294,7 @@ class ESAgent(Agent):
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
|
||||
result = TrainingResult(
|
||||
result = ray.tune.result.TrainingResult(
|
||||
episode_reward_mean=eval_returns.mean(),
|
||||
episode_len_mean=eval_lengths.mean(),
|
||||
timesteps_this_iter=noisy_lengths.sum(),
|
||||
|
||||
Reference in New Issue
Block a user