Remove the `from X import Y` convention in RLlib ES. (#1774)

This commit is contained in:
Robert Nishihara
2018-03-23 05:54:31 -07:00
committed by Philipp Moritz
parent 13b3df9321
commit 10dabce4d7
2 changed files with 10 additions and 8 deletions
+1 -1
View File
@@ -220,7 +220,7 @@ class _ParameterTuningAgent(_MockAgent):
def get_agent_class(alg):
"""Returns the class of an known agent given its name."""
"""Returns the class of a known agent given its name."""
if alg == "PPO":
from ray.rllib import ppo
+9 -7
View File
@@ -12,14 +12,12 @@ import pickle
import time
import ray
from ray.rllib.agent import Agent
from ray.rllib.models import ModelCatalog
from ray.rllib import agent
from ray.rllib.es import optimizers
from ray.rllib.es import policies
from ray.rllib.es import tabular_logger as tlogger
from ray.rllib.es import utils
from ray.tune.result import TrainingResult
Result = namedtuple("Result", [
@@ -72,7 +70,9 @@ class Worker(object):
self.noise = SharedNoiseTable(noise)
self.env = env_creator(config["env_config"])
self.preprocessor = ModelCatalog.get_preprocessor(registry, self.env)
from ray.rllib import models
self.preprocessor = models.ModelCatalog.get_preprocessor(
registry, self.env)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.GenericPolicy(
@@ -133,7 +133,7 @@ class Worker(object):
eval_lengths=eval_lengths)
class ESAgent(Agent):
class ESAgent(agent.Agent):
_agent_name = "ES"
_default_config = DEFAULT_CONFIG
_allow_unknown_subkeys = ["env_config"]
@@ -144,7 +144,9 @@ class ESAgent(Agent):
}
env = self.env_creator(self.config["env_config"])
preprocessor = ModelCatalog.get_preprocessor(self.registry, env)
from ray.rllib import models
preprocessor = models.ModelCatalog.get_preprocessor(
self.registry, env)
self.sess = utils.make_session(single_threaded=False)
self.policy = policies.GenericPolicy(
@@ -292,7 +294,7 @@ class ESAgent(Agent):
"time_elapsed": step_tend - self.tstart
}
result = TrainingResult(
result = ray.tune.result.TrainingResult(
episode_reward_mean=eval_returns.mean(),
episode_len_mean=eval_lengths.mean(),
timesteps_this_iter=noisy_lengths.sum(),