Files
ray/python/ray/rllib/agent.py
T
Eric Liang ecb811c26e [rllib] Ape-X implementation and DQN refactor to handle replay in policy optimizer (#1604)
* minimal apex checkin

* cleanup dqn options

* actor utils

* Sun Feb 25 17:39:54 PST 2018

* update

* compression refactor

* fix

* add test

* fix models

* Sun Feb 25 21:46:27 PST 2018

* Wed Feb 28 10:26:34 PST 2018

* Wed Feb 28 10:28:09 PST 2018

* Wed Feb 28 10:42:59 PST 2018

* refactor

* Wed Feb 28 11:17:19 PST 2018

* Wed Feb 28 11:42:08 PST 2018

* Wed Feb 28 11:42:13 PST 2018

* Wed Feb 28 11:59:02 PST 2018

* Wed Feb 28 11:59:58 PST 2018

* Wed Feb 28 12:00:08 PST 2018

* Wed Feb 28 12:02:19 PST 2018

* Wed Feb 28 13:44:31 PST 2018

* Wed Feb 28 17:01:20 PST 2018

* Sat Mar  3 14:55:59 PST 2018

* make optimizer construction explicit

* Sat Mar  3 18:23:08 PST 2018

* Sat Mar  3 18:24:28 PST 2018

* Sat Mar  3 18:49:28 PST 2018

* Sat Mar  3 18:50:42 PST 2018

* Sat Mar  3 18:56:10 PST 2018
2018-03-04 12:25:25 -08:00

258 lines
8.0 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import os
import pickle

import numpy as np
import tensorflow as tf

from ray.tune.registry import ENV_CREATOR
from ray.tune.result import TrainingResult
from ray.tune.trainable import Trainable
# Module-level logger named after this module. Its level is set to INFO
# explicitly, so INFO and above pass this logger's own level check.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _deep_update(original, new_dict, new_keys_allowed, whitelist):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the whitelist, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
whitelist (list): List of keys that correspond to dict values
where new subkeys can be introduced. This is only at
the top level.
"""
for k, value in new_dict.items():
if k not in original and k != "env":
if not new_keys_allowed:
raise Exception(
"Unknown config parameter `{}` ".format(k))
if type(original.get(k)) is dict:
if k in whitelist:
_deep_update(original[k], value, True, [])
else:
_deep_update(original[k], value, new_keys_allowed, [])
else:
original[k] = value
return original
class Agent(Trainable):
    """All RLlib agents extend this base class.

    Agent objects retain internal model state between calls to train(), so
    you should create a new agent instance for each training session.

    Attributes:
        env_creator (func): Function that creates a new training env.
        config (obj): Algorithm-specific configuration data.
        logdir (str): Directory in which training outputs should be placed.
        registry (obj): Tune object registry which holds user-registered
            classes and objects by name.
    """

    # Whether config keys missing from `_default_config` are accepted.
    _allow_unknown_configs = False
    # Top-level config keys whose sub-dicts may receive new subkeys.
    _allow_unknown_subkeys = []

    def __init__(
            self, config=None, env=None, registry=None,
            logger_creator=None):
        """Initialize an RLLib agent.

        Args:
            config (dict): Algorithm-specific configuration data.
            env (str): Name of the environment to use. Note that this can also
                be specified as the `env` key in config.
            registry (obj): Object registry for user-defined envs, models, etc.
                If unspecified, the default registry will be used.
            logger_creator (func): Function that creates a ray.tune.Logger
                object. If unspecified, a default logger is created.
        """
        config = config or {}

        # Agents allow env ids to be passed directly to the constructor.
        # An explicit `env` argument takes precedence over config["env"].
        self._env_id = env or config.get("env")
        Trainable.__init__(self, config, registry, logger_creator)

    def _setup(self):
        """Resolve the env creator, merge the config, and call `_init()`.

        Reads `self.config` and `self.registry`, which are presumably set
        by `Trainable.__init__` before this hook runs (TODO confirm).

        Raises:
            Exception: If the supplied config contains a key that is not
                permitted by `_default_config` / `_allow_unknown_*`.
        """
        env = self._env_id
        if env:
            self.config["env"] = env
            if self.registry and self.registry.contains(ENV_CREATOR, env):
                self.env_creator = self.registry.get(ENV_CREATOR, env)
            else:
                import gym  # soft dependency
                self.env_creator = lambda env_config: gym.make(env)
        else:
            self.env_creator = lambda env_config: None

        # Merge the supplied config with the class default. The defaults are
        # deep-copied because `_deep_update` mutates nested dicts in place; a
        # shallow copy would leak this instance's overrides into the shared
        # class-level `_default_config` and hence into later agents.
        merged_config = copy.deepcopy(self._default_config)
        merged_config = _deep_update(merged_config, self.config,
                                     self._allow_unknown_configs,
                                     self._allow_unknown_subkeys)
        self.config = merged_config

        # TODO(ekl) setting the graph is unnecessary for PyTorch agents
        with tf.Graph().as_default():
            self._init()

    def _init(self):
        """Subclasses should override this for custom initialization."""
        raise NotImplementedError

    def compute_action(self, observation):
        """Computes an action using the current trained policy."""
        raise NotImplementedError

    @property
    def iteration(self):
        """Current training iter, auto-incremented with each train() call."""
        return self._iteration

    @property
    def _agent_name(self):
        """Subclasses should override this to declare their name."""
        raise NotImplementedError

    @property
    def _default_config(self):
        """Subclasses should override this to declare their default config."""
        raise NotImplementedError
class _MockAgent(Agent):
    """Test-only agent that can fail on demand and checkpoints `info`.

    With `mock_error` enabled, `_train()` raises when `iteration` equals 1.
    After a restore it keeps raising only if `persistent_error` is set.
    """

    _agent_name = "MockAgent"
    _default_config = {
        "mock_error": False,
        "persistent_error": False,
    }

    def _init(self):
        """Reset the restore flag and the user-supplied payload."""
        self.restored = False
        self.info = None

    def _train(self):
        """Return a fixed training result, or raise the configured error."""
        if self.config["mock_error"] and self.iteration == 1:
            # A restored agent only fails again when the error is persistent.
            if self.config["persistent_error"] or not self.restored:
                raise Exception("mock error")
        return TrainingResult(
            episode_reward_mean=10,
            episode_len_mean=10,
            timesteps_this_iter=10,
            info={})

    def _save(self, checkpoint_dir):
        """Pickle `self.info` into `checkpoint_dir` and return the file path."""
        checkpoint_file = os.path.join(checkpoint_dir, "mock_agent.pkl")
        with open(checkpoint_file, "wb") as out:
            pickle.dump(self.info, out)
        return checkpoint_file

    def _restore(self, checkpoint_path):
        """Load `self.info` from a pickle file and mark the agent restored."""
        with open(checkpoint_path, "rb") as src:
            loaded = pickle.load(src)
        self.info = loaded
        self.restored = True

    def set_info(self, info):
        """Store `info` on the agent and return it unchanged."""
        self.info = info
        return info

    def get_info(self):
        """Return the currently stored `info` payload."""
        return self.info
class _SigmoidFakeData(_MockAgent):
    """Agent that reports a saturating (tanh-shaped) learning curve.

    This can be helpful for evaluating early stopping algorithms.
    """

    _agent_name = "SigmoidFakeData"
    _default_config = {
        "width": 100,
        "height": 100,
        "offset": 0,
        "iter_time": 10,
        "iter_timesteps": 1,
    }

    def _train(self):
        """Return a result whose reward follows height * tanh(i / width)."""
        cfg = self.config
        # Iterations before `offset` are clamped to the start of the curve.
        steps = max(0, self.iteration - cfg["offset"])
        progress = float(steps) / cfg["width"]
        value = np.tanh(progress) * cfg["height"]
        return TrainingResult(
            episode_reward_mean=value,
            episode_len_mean=value,
            timesteps_this_iter=cfg["iter_timesteps"],
            time_this_iter_s=cfg["iter_time"],
            info={})
class _ParameterTuningAgent(_MockAgent):
    """Mock agent whose reported reward grows linearly with the iteration.

    The mean reward is `reward_amt * iteration`; the mean episode length is
    `reward_amt`.
    """

    _agent_name = "ParameterTuningAgent"
    _default_config = {
        "reward_amt": 10,
        "dummy_param": 10,
        "dummy_param2": 15,
        "iter_time": 10,
        "iter_timesteps": 1,
    }

    def _train(self):
        """Return a result derived only from the config and the iteration."""
        cfg = self.config
        reward = cfg["reward_amt"] * self.iteration
        return TrainingResult(
            episode_reward_mean=reward,
            episode_len_mean=cfg["reward_amt"],
            timesteps_this_iter=cfg["iter_timesteps"],
            time_this_iter_s=cfg["iter_time"],
            info={})
def get_agent_class(alg):
    """Return the agent class registered under the name `alg`.

    Algorithm modules are imported lazily, so only the requested one is
    loaded. Names starting with a double underscore map to the mock agents
    defined in this module.

    Args:
        alg (str): Name of the algorithm, e.g. "PPO" or "DQN".

    Returns:
        The class implementing the requested algorithm.

    Raises:
        Exception: If `alg` does not match any known algorithm name.
    """
    if alg == "PPO":
        from ray.rllib import ppo
        return ppo.PPOAgent
    if alg == "ES":
        from ray.rllib import es
        return es.ESAgent
    if alg == "DQN":
        from ray.rllib import dqn
        return dqn.DQNAgent
    if alg == "APEX":
        from ray.rllib import dqn
        return dqn.ApexAgent
    if alg == "A3C":
        from ray.rllib import a3c
        return a3c.A3CAgent
    if alg == "BC":
        from ray.rllib import bc
        return bc.BCAgent
    if alg == "PG":
        from ray.rllib import pg
        return pg.PGAgent
    if alg == "script":
        from ray.tune import script_runner
        return script_runner.ScriptRunner

    # Mock agents used by tests; these live in this module.
    if alg == "__fake":
        return _MockAgent
    if alg == "__sigmoid_fake_data":
        return _SigmoidFakeData
    if alg == "__parameter_tuning":
        return _ParameterTuningAgent

    raise Exception(
        ("Unknown algorithm {}.").format(alg))