[rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy/ (#4819)

This implements some of the renames proposed in #4813.
We leave behind backwards-compatibility aliases for *PolicyGraph and SampleBatch.
This commit is contained in:
Eric Liang
2019-05-20 16:46:05 -07:00
committed by Richard Liaw
parent 6cb5b90bd6
commit 02583a8598
91 changed files with 1955 additions and 1739 deletions
+6 -2
View File
@@ -10,12 +10,14 @@ from ray.tune.registry import register_trainable
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.sample_batch import SampleBatch
def _setup_logger():
@@ -43,7 +45,9 @@ _setup_logger()
_register_all()
__all__ = [
"Policy",
"PolicyGraph",
"TFPolicy",
"TFPolicyGraph",
"PolicyEvaluator",
"SampleBatch",
+3 -3
View File
@@ -1,9 +1,9 @@
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
from ray.rllib.agents.a3c.a2c import A2CTrainer
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
A2CAgent = renamed_class(A2CTrainer)
A3CAgent = renamed_class(A3CTrainer)
A2CAgent = renamed_agent(A2CTrainer)
A3CAgent = renamed_agent(A3CTrainer)
__all__ = [
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
+4 -4
View File
@@ -4,7 +4,7 @@ from __future__ import print_function
import time
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.optimizers import AsyncGradientsOptimizer
from ray.rllib.utils.annotations import override
@@ -43,16 +43,16 @@ class A3CTrainer(Trainer):
_name = "A3C"
_default_config = DEFAULT_CONFIG
_policy_graph = A3CPolicyGraph
_policy = A3CTFPolicy
@override(Trainer)
def _init(self, config, env_creator):
if config["use_pytorch"]:
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
from ray.rllib.agents.a3c.a3c_torch_policy import \
A3CTorchPolicy
policy_cls = A3CTorchPolicy
else:
policy_cls = self._policy_graph
policy_cls = self._policy
if config["entropy_coeff"] < 0:
raise DeprecationWarning("entropy_coeff must be >= 0")
@@ -1,4 +1,4 @@
"""Note: Keep in sync with changes to VTracePolicyGraph."""
"""Note: Keep in sync with changes to VTraceTFPolicy."""
from __future__ import absolute_import
from __future__ import division
@@ -8,13 +8,13 @@ import gym
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
@@ -47,13 +47,13 @@ class A3CLoss(object):
class A3CPostprocessing(object):
"""Adds the VF preds and advantages fields to the trajectory."""
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
TFPolicy.extra_compute_action_fetches(self),
**{SampleBatch.VF_PREDS: self.vf})
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -73,7 +73,7 @@ class A3CPostprocessing(object):
self.config["lambda"])
class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
self.config = config
@@ -114,7 +114,7 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
self.vf, self.config["vf_loss_coeff"],
self.config["entropy_coeff"])
# Initialize TFPolicyGraph
# Initialize TFPolicy
loss_in = [
(SampleBatch.CUR_OBS, self.observations),
(SampleBatch.ACTIONS, actions),
@@ -125,7 +125,7 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
observation_space,
action_space,
@@ -157,18 +157,18 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
self.sess.run(tf.global_variables_initializer())
@override(PolicyGraph)
@override(Policy)
def get_initial_state(self):
return self.model.state_init
@override(TFPolicyGraph)
@override(TFPolicy)
def gradients(self, optimizer, loss):
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return self.stats_fetches
@@ -9,8 +9,8 @@ from torch import nn
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
def actor_critic_loss(policy, batch_tensors):
+2 -2
View File
@@ -3,6 +3,6 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
Agent = renamed_class(Trainer)
Agent = renamed_agent(Trainer)
+2 -2
View File
@@ -1,6 +1,6 @@
from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
ARSAgent = renamed_class(ARSTrainer)
ARSAgent = renamed_agent(ARSTrainer)
__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]
+1 -1
View File
@@ -17,7 +17,7 @@ from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.ars import optimizers
from ray.rllib.agents.ars import policies
from ray.rllib.agents.ars import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
+3 -3
View File
@@ -5,10 +5,10 @@ from __future__ import print_function
from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ddpg.td3 import TD3Trainer
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
ApexDDPGAgent = renamed_class(ApexDDPGTrainer)
DDPGAgent = renamed_class(DDPGTrainer)
ApexDDPGAgent = renamed_agent(ApexDDPGTrainer)
DDPGAgent = renamed_agent(DDPGTrainer)
__all__ = [
"DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer",
+2 -2
View File
@@ -4,7 +4,7 @@ from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
@@ -163,7 +163,7 @@ class DDPGTrainer(DQNTrainer):
"""DDPG implementation in TensorFlow."""
_name = "DDPG"
_default_config = DEFAULT_CONFIG
_policy_graph = DDPGPolicyGraph
_policy = DDPGTFPolicy
@override(DQNTrainer)
def _train(self):
@@ -7,15 +7,15 @@ import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.agents.dqn.dqn_policy_graph import (
_huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.agents.dqn.dqn_policy import (_huber_loss, _minimize_and_clip,
_scope_vars, _postprocess_dqn)
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@@ -35,7 +35,7 @@ PRIO_WEIGHTS = "weights"
class DDPGPostprocessing(object):
"""Implements n-step learning and param noise adjustments."""
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -68,7 +68,7 @@ class DDPGPostprocessing(object):
return _postprocess_dqn(self, sample_batch)
class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Box):
@@ -281,7 +281,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
self.critic_loss = self.twin_q_model.custom_loss(
self.critic_loss, input_dict)
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
observation_space,
action_space,
@@ -301,12 +301,12 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
# Hard initial update
self.update_target(tau=1.0)
@override(TFPolicyGraph)
@override(TFPolicy)
def optimizer(self):
# we don't use this because we have two separate optimisers
return None
@override(TFPolicyGraph)
@override(TFPolicy)
def build_apply_op(self, optimizer, grads_and_vars):
# for policy gradient, update policy net one time v.s.
# update critic net `policy_delay` time(s)
@@ -327,7 +327,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
with tf.control_dependencies([tf.assign_add(self.global_step, 1)]):
return tf.group(actor_op, critic_op)
@override(TFPolicyGraph)
@override(TFPolicy)
def gradients(self, optimizer, loss):
if self.config["grad_norm_clipping"] is not None:
actor_grads_and_vars = _minimize_and_clip(
@@ -360,7 +360,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
+ self._critic_grads_and_vars
return grads_and_vars
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
return {
# FIXME: what about turning off exploration? Isn't that a good
@@ -370,31 +370,31 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
self.pure_exploration_phase: self.cur_pure_exploration_phase,
}
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return {
"td_error": self.td_error,
LEARNER_STATS_KEY: self.stats,
}
@override(TFPolicyGraph)
@override(TFPolicy)
def get_weights(self):
return self.variables.get_weights()
@override(TFPolicyGraph)
@override(TFPolicy)
def set_weights(self, weights):
self.variables.set_weights(weights)
@override(PolicyGraph)
@override(Policy)
def get_state(self):
return [
TFPolicyGraph.get_state(self), self.cur_noise_scale,
TFPolicy.get_state(self), self.cur_noise_scale,
self.cur_pure_exploration_phase
]
@override(PolicyGraph)
@override(Policy)
def set_state(self, state):
TFPolicyGraph.set_state(self, state[0])
TFPolicy.set_state(self, state[0])
self.set_epsilon(state[1])
self.set_pure_exploration_phase(state[2])
+3 -3
View File
@@ -4,10 +4,10 @@ from __future__ import print_function
from ray.rllib.agents.dqn.apex import ApexTrainer
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
DQNAgent = renamed_class(DQNTrainer)
ApexAgent = renamed_class(ApexTrainer)
DQNAgent = renamed_agent(DQNTrainer)
ApexAgent = renamed_agent(ApexTrainer)
__all__ = [
"DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG"
+5 -5
View File
@@ -8,9 +8,9 @@ import time
from ray import tune
from ray.rllib import optimizers
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
@@ -133,7 +133,7 @@ class DQNTrainer(Trainer):
_name = "DQN"
_default_config = DEFAULT_CONFIG
_policy_graph = DQNPolicyGraph
_policy = DQNTFPolicy
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
@override(Trainer)
@@ -197,10 +197,10 @@ class DQNTrainer(Trainer):
on_episode_end)
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy_graph)
env_creator, self._policy)
def create_remote_evaluators():
return self.make_remote_evaluators(env_creator, self._policy_graph,
return self.make_remote_evaluators(env_creator, self._policy,
config["num_workers"])
if config["optimizer_class"] != "AsyncReplayOptimizer":
@@ -7,13 +7,13 @@ import numpy as np
from scipy.stats import entropy
import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.models import ModelCatalog, Categorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.utils import try_import_tf
@@ -105,14 +105,14 @@ class QLoss(object):
class DQNPostprocessing(object):
"""Implements n-step learning and param noise adjustments."""
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self), **{
TFPolicy.extra_compute_action_fetches(self), **{
"q_values": self.q_values,
})
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -345,7 +345,7 @@ class QValuePolicy(object):
self.action_prob = None
class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
if not isinstance(action_space, Discrete):
@@ -446,7 +446,7 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
update_target_expr.append(var_target.assign(var))
self.update_target_expr = tf.group(*update_target_expr)
# initialize TFPolicyGraph
# initialize TFPolicy
self.sess = tf.get_default_session()
self.loss_inputs = [
(SampleBatch.CUR_OBS, self.obs_t),
@@ -459,7 +459,7 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
observation_space,
action_space,
@@ -477,12 +477,12 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
"cur_lr": tf.cast(self.cur_lr, tf.float64),
}, **self.loss.stats)
@override(TFPolicyGraph)
@override(TFPolicy)
def optimizer(self):
return tf.train.AdamOptimizer(
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
@override(TFPolicyGraph)
@override(TFPolicy)
def gradients(self, optimizer, loss):
if self.config["grad_norm_clipping"] is not None:
grads_and_vars = _minimize_and_clip(
@@ -496,27 +496,27 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
return grads_and_vars
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_feed_dict(self):
return {
self.stochastic: True,
self.eps: self.cur_epsilon,
}
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return {
"td_error": self.loss.td_error,
LEARNER_STATS_KEY: self.stats_fetches,
}
@override(PolicyGraph)
@override(Policy)
def get_state(self):
return [TFPolicyGraph.get_state(self), self.cur_epsilon]
return [TFPolicy.get_state(self), self.cur_epsilon]
@override(PolicyGraph)
@override(Policy)
def set_state(self, state):
TFPolicyGraph.set_state(self, state[0])
TFPolicy.set_state(self, state[0])
self.set_epsilon(state[1])
def _build_parameter_noise(self, pnet_params):
@@ -633,25 +633,25 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
rewards[i] += gamma**j * rewards[i + j]
def _postprocess_dqn(policy_graph, batch):
def _postprocess_dqn(policy, batch):
# N-step Q adjustments
if policy_graph.config["n_step"] > 1:
_adjust_nstep(policy_graph.config["n_step"],
policy_graph.config["gamma"], batch[SampleBatch.CUR_OBS],
batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS],
batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES])
if policy.config["n_step"] > 1:
_adjust_nstep(policy.config["n_step"], policy.config["gamma"],
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES])
if PRIO_WEIGHTS not in batch:
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
# Prioritize on the worker side
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
td_errors = policy_graph.compute_td_error(
if batch.count > 0 and policy.config["worker_side_prioritization"]:
td_errors = policy.compute_td_error(
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
new_priorities = (
np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
np.abs(td_errors) + policy.config["prioritized_replay_eps"])
batch.data[PRIO_WEIGHTS] = new_priorities
return batch
+2 -2
View File
@@ -1,6 +1,6 @@
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
ESAgent = renamed_class(ESTrainer)
ESAgent = renamed_agent(ESTrainer)
__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
+1 -1
View File
@@ -16,7 +16,7 @@ from ray.rllib.agents import Trainer, with_common_config
from ray.rllib.agents.es import optimizers
from ray.rllib.agents.es import policies
from ray.rllib.agents.es import utils
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.memory import ray_get_and_free
from ray.rllib.utils import FilterManager
+2 -2
View File
@@ -1,6 +1,6 @@
from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
ImpalaAgent = renamed_class(ImpalaTrainer)
ImpalaAgent = renamed_agent(ImpalaTrainer)
__all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"]
+7 -7
View File
@@ -4,8 +4,8 @@ from __future__ import print_function
import time
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.optimizers import AsyncSamplesOptimizer
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
@@ -105,14 +105,14 @@ class ImpalaTrainer(Trainer):
_name = "IMPALA"
_default_config = DEFAULT_CONFIG
_policy_graph = VTracePolicyGraph
_policy = VTraceTFPolicy
@override(Trainer)
def _init(self, config, env_creator):
for k in OPTIMIZER_SHARED_CONFIGS:
if k not in config["optimizer"]:
config["optimizer"][k] = config[k]
policy_cls = self._get_policy_graph()
policy_cls = self._get_policy()
self.local_evaluator = self.make_local_evaluator(
self.env_creator, policy_cls)
@@ -158,9 +158,9 @@ class ImpalaTrainer(Trainer):
prev_steps)
return result
def _get_policy_graph(self):
def _get_policy(self):
if self.config["vtrace"]:
policy_cls = self._policy_graph
policy_cls = self._policy
else:
policy_cls = A3CPolicyGraph
policy_cls = A3CTFPolicy
return policy_cls
@@ -1,6 +1,6 @@
"""Adapted from A3CPolicyGraph to add V-trace.
"""Adapted from A3CTFPolicy to add V-trace.
Keep in sync with changes to A3CPolicyGraph and VtraceSurrogatePolicyGraph."""
Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy."""
from __future__ import absolute_import
from __future__ import division
@@ -11,9 +11,9 @@ import ray
import numpy as np
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy, \
LearningRateSchedule
from ray.rllib.models.action_dist import MultiCategorical
from ray.rllib.models.catalog import ModelCatalog
@@ -110,13 +110,13 @@ class VTraceLoss(object):
class VTracePostprocessing(object):
"""Adds the policy logits to the trajectory."""
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
TFPolicy.extra_compute_action_fetches(self),
**{BEHAVIOUR_LOGITS: self.model.outputs})
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -126,8 +126,7 @@ class VTracePostprocessing(object):
return sample_batch
class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
TFPolicyGraph):
class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
def __init__(self,
observation_space,
action_space,
@@ -285,7 +284,7 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
"max_KL": tf.reduce_max(kls[0]),
}
# Initialize TFPolicyGraph
# Initialize TFPolicy
loss_in = [
(SampleBatch.ACTIONS, actions),
(SampleBatch.DONES, dones),
@@ -297,7 +296,7 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
]
LearningRateSchedule.__init__(self, self.config["lr"],
self.config["lr_schedule"])
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
observation_space,
action_space,
@@ -332,15 +331,15 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
}, **self.KL_stats),
}
@override(TFPolicyGraph)
@override(TFPolicy)
def copy(self, existing_inputs):
return VTracePolicyGraph(
return VTraceTFPolicy(
self.observation_space,
self.action_space,
self.config,
existing_inputs=existing_inputs)
@override(TFPolicyGraph)
@override(TFPolicy)
def optimizer(self):
if self.config["opt_type"] == "adam":
return tf.train.AdamOptimizer(self.cur_lr)
@@ -349,17 +348,17 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
self.config["momentum"],
self.config["epsilon"])
@override(TFPolicyGraph)
@override(TFPolicy)
def gradients(self, optimizer, loss):
grads = tf.gradients(loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
clipped_grads = list(zip(self.grads, self.var_list))
return clipped_grads
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return self.stats_fetches
@override(PolicyGraph)
@override(Policy)
def get_initial_state(self):
return self.model.state_init
+4 -4
View File
@@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph
from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer
from ray.rllib.utils.annotations import override
@@ -44,14 +44,14 @@ class MARWILTrainer(Trainer):
_name = "MARWIL"
_default_config = DEFAULT_CONFIG
_policy_graph = MARWILPolicyGraph
_policy = MARWILPolicy
@override(Trainer)
def _init(self, config, env_creator):
self.local_evaluator = self.make_local_evaluator(
env_creator, self._policy_graph)
env_creator, self._policy)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, self._policy_graph, config["num_workers"])
env_creator, self._policy, config["num_workers"])
self.optimizer = SyncBatchReplayOptimizer(
self.local_evaluator,
self.remote_evaluators,
@@ -6,12 +6,12 @@ import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.agents.dqn.dqn_policy import _scope_vars
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
@@ -59,7 +59,7 @@ class ReweightedImitationLoss(object):
class MARWILPostprocessing(object):
"""Adds the advantages field to the trajectory."""
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -79,7 +79,7 @@ class MARWILPostprocessing(object):
return batch
class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
class MARWILPolicy(MARWILPostprocessing, TFPolicy):
def __init__(self, observation_space, action_space, config):
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
self.config = config
@@ -127,14 +127,14 @@ class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
self.explained_variance = tf.reduce_mean(
explained_variance(self.cum_rew_t, state_values))
# initialize TFPolicyGraph
# initialize TFPolicy
self.sess = tf.get_default_session()
self.loss_inputs = [
(SampleBatch.CUR_OBS, self.obs_t),
(SampleBatch.ACTIONS, self.act_t),
(Postprocessing.ADVANTAGES, self.cum_rew_t),
]
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
observation_space,
action_space,
@@ -166,10 +166,10 @@ class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
return ReweightedImitationLoss(state_values, cum_rwds, logits, actions,
action_space, self.config["beta"])
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_grad_fetches(self):
return {LEARNER_STATS_KEY: self.stats_fetches}
@override(PolicyGraph)
@override(Policy)
def get_initial_state(self):
return self.model.state_init
+2 -2
View File
@@ -1,6 +1,6 @@
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
PGAgent = renamed_class(PGTrainer)
PGAgent = renamed_agent(PGTrainer)
__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"]
+2 -2
View File
@@ -4,7 +4,7 @@ from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
# yapf: disable
# __sphinx_doc_begin__
@@ -22,7 +22,7 @@ DEFAULT_CONFIG = with_common_config({
def get_policy_class(config):
if config["use_pytorch"]:
from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy
from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy
return PGTorchPolicy
else:
return PGTFPolicy
@@ -5,8 +5,8 @@ from __future__ import print_function
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
@@ -5,8 +5,8 @@ from __future__ import print_function
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
def pg_torch_loss(policy, batch_tensors):
+2 -2
View File
@@ -1,7 +1,7 @@
from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
from ray.rllib.agents.ppo.appo import APPOTrainer
from ray.rllib.utils import renamed_class
from ray.rllib.utils import renamed_agent
PPOAgent = renamed_class(PPOTrainer)
PPOAgent = renamed_agent(PPOTrainer)
__all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"]
+3 -3
View File
@@ -2,7 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy
from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
from ray.rllib.agents.trainer import with_base_config
from ray.rllib.agents import impala
from ray.rllib.utils.annotations import override
@@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer):
_name = "APPO"
_default_config = DEFAULT_CONFIG
_policy_graph = AsyncPPOTFPolicy
_policy = AsyncPPOTFPolicy
@override(impala.ImpalaTrainer)
def _get_policy_graph(self):
def _get_policy(self):
return AsyncPPOTFPolicy
@@ -1,6 +1,6 @@
"""Adapted from VTracePolicyGraph to use the PPO surrogate loss.
"""Adapted from VTraceTFPolicy to use the PPO surrogate loss.
Keep in sync with changes to VTracePolicyGraph."""
Keep in sync with changes to VTraceTFPolicy."""
from __future__ import absolute_import
from __future__ import division
@@ -13,9 +13,9 @@ import gym
import ray
from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.utils import try_import_tf
+2 -3
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import logging
from ray.rllib.agents import with_common_config
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
@@ -143,8 +143,7 @@ def validate_config(config):
raise ValueError(
"Episode truncation is not supported without a value "
"function. Consider setting batch_mode=complete_episodes.")
if (config["multiagent"]["policy_graphs"]
and not config["simple_optimizer"]):
if (config["multiagent"]["policies"] and not config["simple_optimizer"]):
logger.info(
"In multi-agent mode, policies will be optimized sequentially "
"by the multi-GPU optimizer. Consider setting "
@@ -7,9 +7,9 @@ import logging
import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
Postprocessing
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import LearningRateSchedule
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils import try_import_tf
+2 -2
View File
@@ -4,7 +4,7 @@ from __future__ import print_function
from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
# yapf: disable
# __sphinx_doc_begin__
@@ -95,7 +95,7 @@ class QMixTrainer(DQNTrainer):
_name = "QMIX"
_default_config = DEFAULT_CONFIG
_policy_graph = QMixPolicyGraph
_policy = QMixTorchPolicy
_optimizer_shared_configs = [
"learning_starts", "buffer_size", "train_batch_size"
]
@@ -14,8 +14,8 @@ import ray
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
from ray.rllib.agents.qmix.model import RNNModel, _get_size
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.action_dist import TupleActions
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.lstm import chop_into_sequences
@@ -130,7 +130,7 @@ class QMixLoss(nn.Module):
return loss, mask, masked_td_error, chosen_action_qvals, targets
class QMixPolicyGraph(PolicyGraph):
class QMixTorchPolicy(Policy):
"""QMix impl. Assumes homogeneous agents for now.
You must use MultiAgentEnv.with_agent_groups() to group agents
@@ -213,7 +213,7 @@ class QMixPolicyGraph(PolicyGraph):
alpha=config["optim_alpha"],
eps=config["optim_eps"])
@override(PolicyGraph)
@override(Policy)
def compute_actions(self,
obs_batch,
state_batches=None,
@@ -243,7 +243,7 @@ class QMixPolicyGraph(PolicyGraph):
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
@override(PolicyGraph)
@override(Policy)
def learn_on_batch(self, samples):
obs_batch, action_mask = self._unpack_observation(
samples[SampleBatch.CUR_OBS])
@@ -314,22 +314,22 @@ class QMixPolicyGraph(PolicyGraph):
}
return {LEARNER_STATS_KEY: stats}
@override(PolicyGraph)
@override(Policy)
def get_initial_state(self):
return [
s.expand([self.n_agents, -1]).numpy()
for s in self.model.state_init()
]
@override(PolicyGraph)
@override(Policy)
def get_weights(self):
return {"model": self.model.state_dict()}
@override(PolicyGraph)
@override(Policy)
def set_weights(self, weights):
self.model.load_state_dict(weights["model"])
@override(PolicyGraph)
@override(Policy)
def get_state(self):
return {
"model": self.model.state_dict(),
@@ -340,7 +340,7 @@ class QMixPolicyGraph(PolicyGraph):
"cur_epsilon": self.cur_epsilon,
}
@override(PolicyGraph)
@override(Policy)
def set_state(self, state):
self.model.load_state_dict(state["model"])
self.target_model.load_state_dict(state["target_model"])
+25 -24
View File
@@ -19,7 +19,7 @@ from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
_validate_multiagent_config
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
@@ -220,9 +220,9 @@ COMMON_CONFIG = {
# === Multiagent ===
"multiagent": {
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
# Map from policy ids to tuples of (policy_cls, obs_space,
# act_space, config). See policy_evaluator.py for more info.
"policy_graphs": {},
"policies": {},
# Function mapping agent ids to policy ids.
"policy_mapping_fn": None,
# Optional whitelist of policies to train, or None for all policies.
@@ -435,9 +435,7 @@ class Trainer(Trainable):
"using evaluation_config: {}".format(extra_config))
# Make local evaluation evaluators
self.evaluation_ev = self.make_local_evaluator(
self.env_creator,
self._policy_graph,
extra_config=extra_config)
self.env_creator, self._policy, extra_config=extra_config)
self.evaluation_metrics = self._evaluate()
@override(Trainable)
@@ -578,10 +576,10 @@ class Trainer(Trainable):
@PublicAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
"""Return policy for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
policy_id (str): id of policy to return.
"""
return self.local_evaluator.get_policy(policy_id)
@@ -606,16 +604,13 @@ class Trainer(Trainable):
self.local_evaluator.set_weights(weights)
@DeveloperAPI
def make_local_evaluator(self,
env_creator,
policy_graph,
extra_config=None):
def make_local_evaluator(self, env_creator, policy, extra_config=None):
"""Convenience method to return configured local evaluator."""
return self._make_evaluator(
PolicyEvaluator,
env_creator,
policy_graph,
policy,
0,
merge_dicts(
# important: allow local tf to use more CPUs for optimization
@@ -627,7 +622,7 @@ class Trainer(Trainable):
extra_config or {}))
@DeveloperAPI
def make_remote_evaluators(self, env_creator, policy_graph, count):
def make_remote_evaluators(self, env_creator, policy, count):
"""Convenience method to return a number of remote evaluators."""
remote_args = {
@@ -639,8 +634,8 @@ class Trainer(Trainable):
cls = PolicyEvaluator.as_remote(**remote_args).remote
return [
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
self.config) for i in range(count)
self._make_evaluator(cls, env_creator, policy, i + 1, self.config)
for i in range(count)
]
@DeveloperAPI
@@ -700,6 +695,13 @@ class Trainer(Trainable):
@staticmethod
def _validate_config(config):
if "policy_graphs" in config["multiagent"]:
logger.warning(
"The `policy_graphs` config has been renamed to `policies`.")
# Backwards compatibility
config["multiagent"]["policies"] = config["multiagent"][
"policy_graphs"]
del config["multiagent"]["policy_graphs"]
if "gpu" in config:
raise ValueError(
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
@@ -760,8 +762,7 @@ class Trainer(Trainable):
return hasattr(self, "optimizer") and isinstance(
self.optimizer, PolicyOptimizer)
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
config):
def _make_evaluator(self, cls, env_creator, policy, worker_index, config):
def session_creator():
logger.debug("Creating TF session {}".format(
config["tf_session_args"]))
@@ -803,18 +804,18 @@ class Trainer(Trainable):
else:
input_evaluation = config["input_evaluation"]
# Fill in the default policy graph if 'None' is specified in multiagent
if self.config["multiagent"]["policy_graphs"]:
tmp = self.config["multiagent"]["policy_graphs"]
# Fill in the default policy if 'None' is specified in multiagent
if self.config["multiagent"]["policies"]:
tmp = self.config["multiagent"]["policies"]
_validate_multiagent_config(tmp, allow_none_graph=True)
for k, v in tmp.items():
if v[0] is None:
tmp[k] = (policy_graph, v[1], v[2], v[3])
policy_graph = tmp
tmp[k] = (policy, v[1], v[2], v[3])
policy = tmp
return cls(
env_creator,
policy_graph,
policy,
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
policies_to_train=self.config["multiagent"]["policies_to_train"],
tf_session_creator=(session_creator
+7 -7
View File
@@ -21,7 +21,7 @@ def build_trainer(name,
Arguments:
name (str): name of the trainer (e.g., "PPO")
default_policy (cls): the default PolicyGraph class to use
default_policy (cls): the default Policy class to use
default_config (dict): the default config dict of the algorithm,
otherwise uses the Trainer default config
make_policy_optimizer (func): optional function that returns a
@@ -30,7 +30,7 @@ def build_trainer(name,
validate_config (func): optional callback that checks a given config
for correctness. It may mutate the config as needed.
get_policy_class (func): optional callback that takes a config and
returns the policy graph class to override the default with
returns the policy class to override the default with
before_train_step (func): optional callback to run before each train()
call. It takes the trainer instance as an argument.
after_optimizer_step (func): optional callback to run after each
@@ -51,19 +51,19 @@ def build_trainer(name,
class trainer_cls(Trainer):
_name = name
_default_config = default_config or Trainer.COMMON_CONFIG
_policy_graph = default_policy
_policy = default_policy
def _init(self, config, env_creator):
if validate_config:
validate_config(config)
if get_policy_class is None:
policy_graph = default_policy
policy = default_policy
else:
policy_graph = get_policy_class(config)
policy = get_policy_class(config)
self.local_evaluator = self.make_local_evaluator(
env_creator, policy_graph)
env_creator, policy)
self.remote_evaluators = self.make_remote_evaluators(
env_creator, policy_graph, config["num_workers"])
env_creator, policy, config["num_workers"])
if make_policy_optimizer:
self.optimizer = make_policy_optimizer(
self.local_evaluator, self.remote_evaluators, config)
+2 -2
View File
@@ -27,7 +27,7 @@ class MultiAgentEpisode(object):
user_data (dict): Dict that you can use for temporary storage.
Use case 1: Model-based rollouts in multi-agent:
A custom compute_actions() function in a policy graph can inspect the
A custom compute_actions() function in a policy can inspect the
current episode state and perform a number of rollouts based on the
policies and state of other agents in the environment.
@@ -80,7 +80,7 @@ class MultiAgentEpisode(object):
@DeveloperAPI
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the policy graph for the specified agent.
"""Returns the policy for the specified agent.
If the agent is new, the policy mapping fn will be called to bind the
agent to a policy for the duration of the episode.
+1 -1
View File
@@ -62,7 +62,7 @@ class EvaluatorInterface(object):
Returns:
(grads, info): A list of gradients that can be applied on a
compatible evaluator. In the multi-agent case, returns a dict
of gradients keyed by policy graph ids. An info dictionary of
of gradients keyed by policy ids. An info dictionary of
extra metadata is also returned.
Examples:
+3 -6
View File
@@ -7,21 +7,18 @@ import numpy as np
import collections
import ray
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
# By convention, metrics from optimizing the loss can be reported in the
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"
@DeveloperAPI
def get_learner_stats(grad_info):
"""Return optimization stats reported from the policy graph.
"""Return optimization stats reported from the policy.
Example:
>>> grad_info = evaluator.learn_on_batch(samples)
+41 -45
View File
@@ -15,11 +15,10 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.env.vector_env import VectorEnv
from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
@@ -52,9 +51,9 @@ def get_global_evaluator():
@DeveloperAPI
class PolicyEvaluator(EvaluatorInterface):
"""Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
"""Common ``PolicyEvaluator`` implementation that wraps a ``Policy``.
This class wraps a policy graph instance and an environment class to
This class wraps a policy instance and an environment class to
collect experiences from the environment. You can create many replicas of
this class as Ray actors to scale RL training.
@@ -65,7 +64,7 @@ class PolicyEvaluator(EvaluatorInterface):
>>> # Create a policy evaluator and use it to collect experiences.
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: gym.make("CartPole-v0"),
... policy_graph=PGTFPolicy)
... policy=PGTFPolicy)
>>> print(evaluator.sample())
SampleBatch({
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
@@ -76,7 +75,7 @@ class PolicyEvaluator(EvaluatorInterface):
... evaluator_cls=PolicyEvaluator,
... evaluator_args={
... "env_creator": lambda _: gym.make("CartPole-v0"),
... "policy_graph": PGTFPolicy,
... "policy": PGTFPolicy,
... },
... num_workers=10)
>>> for _ in range(10): optimizer.step()
@@ -84,7 +83,7 @@ class PolicyEvaluator(EvaluatorInterface):
>>> # Creating a multi-agent policy evaluator
>>> evaluator = PolicyEvaluator(
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
... policy_graphs={
... policies={
... # Use an ensemble of two policies for car agents
... "car_policy1":
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
@@ -113,7 +112,7 @@ class PolicyEvaluator(EvaluatorInterface):
@DeveloperAPI
def __init__(self,
env_creator,
policy_graph,
policy,
policy_mapping_fn=None,
policies_to_train=None,
tf_session_creator=None,
@@ -147,9 +146,9 @@ class PolicyEvaluator(EvaluatorInterface):
Arguments:
env_creator (func): Function that returns a gym.Env given an
EnvContext wrapped configuration.
policy_graph (class|dict): Either a class implementing
PolicyGraph, or a dictionary of policy id strings to
(PolicyGraph, obs_space, action_space, config) tuples. If a
policy (class|dict): Either a class implementing
Policy, or a dictionary of policy id strings to
(Policy, obs_space, action_space, config) tuples. If a
dict is specified, then we are in multi-agent mode and a
policy_mapping_fn should also be set.
policy_mapping_fn (func): A function that maps agent ids to
@@ -159,7 +158,7 @@ class PolicyEvaluator(EvaluatorInterface):
policies_to_train (list): Optional whitelist of policies to train,
or None for all policies.
tf_session_creator (func): A function that returns a TF session.
This is optional and only useful with TFPolicyGraph.
This is optional and only useful with TFPolicy.
batch_steps (int): The target number of env transitions to include
in each sample batch returned from this evaluator.
batch_mode (str): One of the following batch modes:
@@ -196,7 +195,7 @@ class PolicyEvaluator(EvaluatorInterface):
model_config (dict): Config to use when creating the policy model.
policy_config (dict): Config to pass to the policy. In the
multi-agent case, this config will be merged with the
per-policy configs specified by `policy_graph`.
per-policy configs specified by `policy`.
worker_index (int): For remote evaluators, this should be set to a
non-zero and unique value. This index is passed to created envs
through EnvContext so that envs can be configured per worker.
@@ -301,7 +300,7 @@ class PolicyEvaluator(EvaluatorInterface):
vector_index=vector_index, remote=remote_worker_envs)))
self.tf_sess = None
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
policy_dict = _validate_and_canonicalize(policy, self.env)
self.policies_to_train = policies_to_train or list(policy_dict.keys())
if _has_tensorflow_graph(policy_dict):
if (ray.is_initialized()
@@ -330,7 +329,7 @@ class PolicyEvaluator(EvaluatorInterface):
or isinstance(self.env, ExternalMultiAgentEnv))
or isinstance(self.env, BaseEnv)):
raise ValueError(
"Have multiple policy graphs {}, but the env ".format(
"Have multiple policies {}, but the env ".format(
self.policy_map) +
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
"ExternalMultiAgentEnv?".format(self.env))
@@ -608,17 +607,17 @@ class PolicyEvaluator(EvaluatorInterface):
@DeveloperAPI
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
"""Return policy graph for the specified id, or None.
"""Return policy for the specified id, or None.
Arguments:
policy_id (str): id of policy graph to return.
policy_id (str): id of policy to return.
"""
return self.policy_map.get(policy_id)
@DeveloperAPI
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
"""Apply the given function to the specified policy graph."""
"""Apply the given function to the specified policy."""
return func(self.policy_map[policy_id])
@@ -708,7 +707,7 @@ class PolicyEvaluator(EvaluatorInterface):
preprocessors = {}
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
logger.debug("Creating policy graph for {}".format(name))
logger.debug("Creating policy for {}".format(name))
merged_conf = merge_dicts(policy_config, conf)
if self.preprocessing_enabled:
preprocessor = ModelCatalog.get_preprocessor_for_space(
@@ -720,7 +719,7 @@ class PolicyEvaluator(EvaluatorInterface):
if isinstance(obs_space, gym.spaces.Dict) or \
isinstance(obs_space, gym.spaces.Tuple):
raise ValueError(
"Found raw Tuple|Dict space as input to policy graph. "
"Found raw Tuple|Dict space as input to policy. "
"Please preprocess these observations with a "
"Tuple|DictFlatteningPreprocessor.")
if tf:
@@ -738,12 +737,12 @@ class PolicyEvaluator(EvaluatorInterface):
self.sampler.shutdown = True
def _validate_and_canonicalize(policy_graph, env):
if isinstance(policy_graph, dict):
_validate_multiagent_config(policy_graph)
return policy_graph
elif not issubclass(policy_graph, PolicyGraph):
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
def _validate_and_canonicalize(policy, env):
if isinstance(policy, dict):
_validate_multiagent_config(policy)
return policy
elif not issubclass(policy, Policy):
raise ValueError("policy must be a rllib.Policy class")
else:
if (isinstance(env, MultiAgentEnv)
and not hasattr(env, "observation_space")):
@@ -751,38 +750,35 @@ def _validate_and_canonicalize(policy_graph, env):
"MultiAgentEnv must have observation_space defined if run "
"in a single-agent configuration.")
return {
DEFAULT_POLICY_ID: (policy_graph, env.observation_space,
DEFAULT_POLICY_ID: (policy, env.observation_space,
env.action_space, {})
}
def _validate_multiagent_config(policy_graph, allow_none_graph=False):
for k, v in policy_graph.items():
def _validate_multiagent_config(policy, allow_none_graph=False):
for k, v in policy.items():
if not isinstance(k, str):
raise ValueError("policy_graph keys must be strs, got {}".format(
raise ValueError("policy keys must be strs, got {}".format(
type(k)))
if not isinstance(v, tuple) or len(v) != 4:
raise ValueError(
"policy_graph values must be tuples of "
"policy values must be tuples of "
"(cls, obs_space, action_space, config), got {}".format(v))
if allow_none_graph and v[0] is None:
pass
elif not issubclass(v[0], PolicyGraph):
raise ValueError(
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
"class or None, got {}".format(v[0]))
elif not issubclass(v[0], Policy):
raise ValueError("policy tuple value 0 must be a rllib.Policy "
"class or None, got {}".format(v[0]))
if not isinstance(v[1], gym.Space):
raise ValueError(
"policy_graph tuple value 1 (observation_space) must be a "
"policy tuple value 1 (observation_space) must be a "
"gym.Space, got {}".format(type(v[1])))
if not isinstance(v[2], gym.Space):
raise ValueError(
"policy_graph tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
raise ValueError("policy tuple value 2 (action_space) must be a "
"gym.Space, got {}".format(type(v[2])))
if not isinstance(v[3], dict):
raise ValueError(
"policy_graph tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
raise ValueError("policy tuple value 3 (config) must be a dict, "
"got {}".format(type(v[3])))
def _validate_env(env):
@@ -805,6 +801,6 @@ def _monitor(env, path):
def _has_tensorflow_graph(policy_dict):
for policy, _, _, _ in policy_dict.values():
if issubclass(policy, TFPolicyGraph):
if issubclass(policy, TFPolicy):
return True
return False
+3 -282
View File
@@ -2,286 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import gym
from ray.rllib.policy.policy import Policy
from ray.rllib.utils import renamed_class
from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class PolicyGraph(object):
    """Base interface for an agent policy bundled with its loss.

    A policy graph describes how to act in an environment and which losses
    are used to improve the policy from its experiences. Policy and loss
    live in one object for convenience, even though the policy itself is
    logically separate from the loss.

    Any policy may subclass PolicyGraph directly. TensorFlow users may find
    TFPolicyGraph simpler to implement; TFPolicyGraph also lets RLlib apply
    TensorFlow-specific optimizations such as fusing multiple policy graphs
    and multi-GPU support.

    Attributes:
        observation_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
    """

    @DeveloperAPI
    def __init__(self, observation_space, action_space, config):
        """Standard constructor for policy graphs.

        The policy graph class handed to PolicyEvaluator is constructed
        with exactly these arguments.

        Args:
            observation_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data. This base
                class accepts it but does not store it.
        """
        self.action_space = action_space
        self.observation_space = observation_space

    @DeveloperAPI
    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Compute a batch of actions for the current policy.

        Subclasses must override this; the base implementation raises.

        Arguments:
            obs_batch (np.ndarray): batch of observations
            state_batches (list): list of RNN state input batches, if any
            prev_action_batch (np.ndarray): batch of previous action values
            prev_reward_batch (np.ndarray): batch of previous rewards
            info_batch (info): batch of info objects
            episodes (list): MultiAgentEpisode for each obs in obs_batch.
                Gives access to all of the internal episode state, which
                may be useful for model-based or multiagent algorithms.
            kwargs: forward compatibility placeholder

        Returns:
            actions (np.ndarray): batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (list): list of RNN state output batches, if any,
                with shape like [STATE_SIZE, BATCH_SIZE].
            info (dict): dictionary of extra feature batches, if any, with
                shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    @DeveloperAPI
    def compute_single_action(self,
                              obs,
                              state,
                              prev_action=None,
                              prev_reward=None,
                              info=None,
                              episode=None,
                              clip_actions=False,
                              **kwargs):
        """Unbatched convenience wrapper around compute_actions().

        Every input is wrapped into a batch of size one, compute_actions()
        is called, and the single result is unwrapped again.

        Arguments:
            obs (obj): single observation
            state (list): list of RNN state inputs, if any
            prev_action (obj): previous action value, if any
            prev_reward (int): previous reward, if any
            info (dict): info object, if any
            episode (MultiAgentEpisode): gives access to all of the
                internal episode state, which may be useful for
                model-based or multi-agent algorithms.
            clip_actions (bool): whether the action should be clipped to
                self.action_space
            kwargs: forward compatibility placeholder (not forwarded to
                compute_actions())

        Returns:
            actions (obj): single action
            state_outs (list): list of RNN state outputs, if any
            info (dict): dictionary of extra features, if any
        """
        # Optional inputs become one-element batches; absent ones stay None.
        prev_action_batch = None if prev_action is None else [prev_action]
        prev_reward_batch = None if prev_reward is None else [prev_reward]
        info_batch = None if info is None else [info]
        episodes = None if episode is None else [episode]

        batched_state = [[s] for s in state]
        # Unpacking into `[action]` requires exactly one action to come back.
        [action], state_out, extra = self.compute_actions(
            [obs],
            batched_state,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes)

        if clip_actions:
            action = clip_action(action, self.action_space)

        # Strip the batch dimension from the state and extra outputs.
        single_state = [s[0] for s in state_out]
        single_extra = {k: v[0] for k, v in extra.items()}
        return action, single_state, single_extra

    @DeveloperAPI
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Hook for algorithm-specific trajectory postprocessing.

        Called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to come from a single
        episode. The base implementation returns the batch unchanged.

        Arguments:
            sample_batch (SampleBatch): batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches (dict): In a multi-agent env, a mapping of
                agent ids to (policy_graph, agent_batch) tuples containing
                the policy graph and experiences of the other agent.
            episode (MultiAgentEpisode): gives access to all of the
                internal episode state, which may be useful for
                model-based or multi-agent algorithms.

        Returns:
            SampleBatch: postprocessed sample batch.
        """
        return sample_batch

    @DeveloperAPI
    def learn_on_batch(self, samples):
        """Compute gradients and apply them in one fused call.

        Subclasses must implement either this method or the pair
        compute_gradients() / apply_gradients().

        Returns:
            grad_info: dictionary of extra metadata from
                compute_gradients().

        Examples:
            >>> samples = ev.sample()
            >>> ev.learn_on_batch(samples)
        """
        gradients, grad_info = self.compute_gradients(samples)
        self.apply_gradients(gradients)
        return grad_info

    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch):
        """Compute gradients against a batch of experiences.

        Subclasses must implement either this or learn_on_batch().

        Returns:
            grads (list): List of gradient output values
            info (dict): Extra policy-specific values
        """
        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, gradients):
        """Apply previously computed gradients.

        Subclasses must implement either this or learn_on_batch().
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self):
        """Return the model weights.

        Returns:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights):
        """Set the model weights.

        Arguments:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_initial_state(self):
        """Return the initial RNN state for the current policy.

        The base implementation returns an empty list.
        """
        return []

    @DeveloperAPI
    def get_state(self):
        """Save all local state.

        The base implementation delegates to get_weights().

        Returns:
            state (obj): Serialized local state.
        """
        return self.get_weights()

    @DeveloperAPI
    def set_state(self, state):
        """Restore all local state.

        The base implementation delegates to set_weights().

        Arguments:
            state (obj): Serialized local state.
        """
        self.set_weights(state)

    @DeveloperAPI
    def on_global_var_update(self, global_vars):
        """Callback invoked on an update to global vars.

        The base implementation does nothing.

        Arguments:
            global_vars (dict): Global variables broadcast from the driver.
        """
        return None

    @DeveloperAPI
    def export_model(self, export_dir):
        """Export PolicyGraph to a local directory for serving.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError

    @DeveloperAPI
    def export_checkpoint(self, export_dir):
        """Export a PolicyGraph checkpoint to a local directory.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError
def clip_action(action, space):
    """Clip a single action to the bounds of the given action space.

    Arguments:
        action: Single (unbatched) action. For a Tuple space this must be
            a tuple or list holding one component per sub-space.
        space: Action space the action should lie in.

    Returns:
        For a Box space, the result of np.clip(action, low, high). For a
        Tuple space, a list with each component clipped against the
        matching sub-space (recursively). For any other space, the action
        unchanged.

    Raises:
        ValueError: If the space is a Tuple but the action is neither a
            tuple nor a list.
    """
    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)

    if isinstance(space, gym.spaces.Tuple):
        # isinstance, unlike an exact type() comparison, also accepts
        # tuple/list subclasses such as namedtuples.
        if not isinstance(action, (tuple, list)):
            raise ValueError("Expected tuple space for actions {}: {}".format(
                action, space))
        return [clip_action(a, s) for a, s in zip(action, space.spaces)]

    # Spaces without explicit bounds handled here are passed through.
    return action
PolicyGraph = renamed_class(Policy, old_name="PolicyGraph")
@@ -4,7 +4,7 @@ from __future__ import print_function
import numpy as np
import scipy.signal
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import DeveloperAPI
+6 -291
View File
@@ -2,295 +2,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import collections
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils import renamed_class
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned
# Default policy id for single-agent environments
DEFAULT_POLICY_ID = "default_policy"
@PublicAPI
class MultiAgentBatch(object):
    """Container for experiences collected from several policies.

    Attributes:
        policy_batches (dict): Maps each policy id to a regular SampleBatch
            of experiences. The per-policy batches may differ in length.
        count (int): Number of environment timesteps covered by this batch.
            This will be less than the total number of transitions held
            across all policies.
    """

    @PublicAPI
    def __init__(self, policy_batches, count):
        self.count = count
        self.policy_batches = policy_batches

    @staticmethod
    @PublicAPI
    def wrap_as_needed(batches, count):
        """Return the bare default-policy batch, or else a MultiAgentBatch.

        The bare SampleBatch is returned only when `batches` holds exactly
        one entry and that entry is keyed by DEFAULT_POLICY_ID.
        """
        only_default_policy = (len(batches) == 1
                               and DEFAULT_POLICY_ID in batches)
        if not only_default_policy:
            return MultiAgentBatch(batches, count)
        return batches[DEFAULT_POLICY_ID]

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Merge several MultiAgentBatches into one.

        Batches belonging to the same policy id are concatenated with
        SampleBatch.concat_samples(); the timestep counts are summed.
        """
        grouped = collections.defaultdict(list)
        timesteps = 0
        for sample in samples:
            assert isinstance(sample, MultiAgentBatch)
            timesteps += sample.count
            for pid, policy_batch in sample.policy_batches.items():
                grouped[pid].append(policy_batch)
        merged = {
            pid: SampleBatch.concat_samples(parts)
            for pid, parts in grouped.items()
        }
        return MultiAgentBatch(merged, timesteps)

    @PublicAPI
    def copy(self):
        """Return a new MultiAgentBatch holding a copy of each policy batch."""
        duplicated = {}
        for pid, policy_batch in self.policy_batches.items():
            duplicated[pid] = policy_batch.copy()
        return MultiAgentBatch(duplicated, self.count)

    @PublicAPI
    def total(self):
        """Return the sum of `count` over all per-policy batches."""
        return sum(b.count for b in self.policy_batches.values())

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Call compress() on every per-policy batch with the same options."""
        for policy_batch in self.policy_batches.values():
            policy_batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Call decompress_if_needed() on every per-policy batch."""
        for policy_batch in self.policy_batches.values():
            policy_batch.decompress_if_needed(columns)

    def __str__(self):
        batches_text = str(self.policy_batches)
        return "MultiAgentBatch({}, count={})".format(batches_text,
                                                      self.count)

    def __repr__(self):
        # Deliberately independent of __str__ so overriding one in a
        # subclass does not change the other.
        batches_text = str(self.policy_batches)
        return "MultiAgentBatch({}, count={})".format(batches_text,
                                                      self.count)
@PublicAPI
class SampleBatch(object):
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.

    Attributes:
        data (dict): Maps column name to a numpy array of column values.
        count (int): Number of rows, i.e. the common length of all columns
            at construction time.
    """

    # Outputs from interacting with the environment
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"

    # Uniquely identifies an episode
    EPS_ID = "eps_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy
    VF_PREDS = "vf_preds"

    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Every column value is converted to a numpy array (without copying
        when the value already is one).

        Raises:
            ValueError: If no columns were given.
            AssertionError: If a key is not a string or the columns differ
                in length.
        """
        self.data = dict(*args, **kwargs)
        lengths = []
        # Snapshot the items so values can be reassigned while looping.
        for k, v in list(self.data.items()):
            assert isinstance(k, six.string_types), self
            lengths.append(len(v))
            # np.asarray() only copies when it has to. The former
            # np.array(v, copy=False) raises ValueError on NumPy >= 2.0
            # whenever a copy is unavoidable (e.g. for plain lists).
            self.data[k] = np.asarray(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, "data columns must be same length"
        self.count = lengths[0]

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Concatenates a list of batches into one batch.

        If the first element is a MultiAgentBatch, the call is delegated to
        MultiAgentBatch.concat_samples(). Otherwise batches with a zero
        `count` are dropped and the remaining ones are concatenated column
        by column, using the columns of the first non-empty batch.

        Arguments:
            samples (list): Non-empty list of SampleBatch (or
                MultiAgentBatch) objects.
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        samples = [s for s in samples if s.count > 0]
        out = {
            k: concat_aligned([s[k] for s in samples])
            for k in samples[0].keys()
        }
        return SampleBatch(out)

    @PublicAPI
    def concat(self, other):
        """Returns a new SampleBatch with each data column concatenated.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """
        assert self.keys() == other.keys(), "must have same columns"
        out = {k: concat_aligned([self[k], other[k]]) for k in self.keys()}
        return SampleBatch(out)

    @PublicAPI
    def copy(self):
        """Returns a new SampleBatch whose columns are copies of this one's."""
        return SampleBatch(
            {k: np.array(v, copy=True)
             for (k, v) in self.data.items()})

    @PublicAPI
    def rows(self):
        """Returns an iterator over data rows, i.e. dicts with column values.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> for row in batch.rows():
                   print(row)
            {"a": 1, "b": 4}
            {"a": 2, "b": 5}
            {"a": 3, "b": 6}
        """
        for i in range(self.count):
            yield {k: self[k][i] for k in self.keys()}

    @PublicAPI
    def columns(self, keys):
        """Returns a list of just the specified columns.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> print(batch.columns(["a", "b"]))
            [[1], [2]]
        """
        return [self[k] for k in keys]

    @PublicAPI
    def shuffle(self):
        """Shuffles the rows of this batch in-place."""
        # One shared permutation keeps the rows aligned across columns.
        permutation = np.random.permutation(self.count)
        for key, val in self.items():
            self[key] = val[permutation]

    @PublicAPI
    def split_by_episode(self):
        """Splits this batch's data by `eps_id`.

        A new slice is started whenever the `eps_id` value changes between
        consecutive rows.

        Returns:
            list of SampleBatch, one per distinct episode.
        """
        slices = []
        cur_eps_id = self.data["eps_id"][0]
        offset = 0
        for i in range(self.count):
            next_eps_id = self.data["eps_id"][i]
            if next_eps_id != cur_eps_id:
                slices.append(self.slice(offset, i))
                offset = i
                cur_eps_id = next_eps_id
        slices.append(self.slice(offset, self.count))
        # Sanity checks: one episode id per slice, and no rows lost.
        for s in slices:
            slen = len(set(s["eps_id"]))
            assert slen == 1, (s, slen)
        assert sum(s.count for s in slices) == self.count, (slices, self.count)
        return slices

    @PublicAPI
    def slice(self, start, end):
        """Returns a slice of the row data of this batch.

        Arguments:
            start (int): Starting index.
            end (int): Ending index.

        Returns:
            SampleBatch which has a slice of this batch's data.
        """
        return SampleBatch({k: v[start:end] for k, v in self.data.items()})

    @PublicAPI
    def keys(self):
        """Returns the column names of this batch."""
        return self.data.keys()

    @PublicAPI
    def items(self):
        """Returns the (column name, column values) pairs of this batch."""
        return self.data.items()

    @PublicAPI
    def __getitem__(self, key):
        return self.data[key]

    @PublicAPI
    def __setitem__(self, key, item):
        # NOTE: the item is stored as-is; `count` is not re-validated here.
        self.data[key] = item

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Packs the given columns in-place via pack().

        Arguments:
            bulk (bool): If True, pack each column as a whole; otherwise
                pack every row of the column individually.
            columns (set): Names of the columns to compress. Names that are
                not present in this batch are ignored.
        """
        for key in columns:
            if key in self.data:
                if bulk:
                    self.data[key] = pack(self.data[key])
                else:
                    self.data[key] = np.array(
                        [pack(o) for o in self.data[key]])

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Unpacks the given columns in-place if they are compressed.

        Handles both columns packed as a whole and columns whose rows were
        packed individually; uncompressed columns are left untouched.
        """
        for key in columns:
            if key in self.data:
                arr = self.data[key]
                if is_compressed(arr):
                    self.data[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self.data[key] = np.array(
                        [unpack(o) for o in self.data[key]])

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

    def __repr__(self):
        return "SampleBatch({})".format(str(self.data))

    def __iter__(self):
        return self.data.__iter__()

    def __contains__(self, x):
        return x in self.data
# Backwards-compatibility aliases: keep SampleBatch and MultiAgentBatch
# available under their old rllib.evaluation names.
# NOTE(review): renamed_class presumably warns about the old name on use --
# confirm against ray.rllib.utils.
SampleBatch = renamed_class(
    SampleBatch, old_name="rllib.evaluation.SampleBatch")
MultiAgentBatch = renamed_class(
    MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch")
@@ -6,7 +6,7 @@ import collections
import logging
import numpy as np
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
@@ -79,7 +79,7 @@ class MultiAgentSampleBatchBuilder(object):
"""Initialize a MultiAgentSampleBatchBuilder.
Arguments:
policy_map (dict): Maps policy ids to policy graph instances.
policy_map (dict): Maps policy ids to policy instances.
clip_rewards (bool): Whether to clip rewards before postprocessing.
postp_callback: function to call on each postprocessed batch.
"""
+4 -4
View File
@@ -12,7 +12,7 @@ import time
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
from ray.rllib.evaluation.sample_batch_builder import \
MultiAgentSampleBatchBuilder
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
from ray.rllib.models.action_dist import TupleActions
@@ -20,7 +20,7 @@ from ray.rllib.offline import InputReader
from ray.rllib.utils.annotations import override
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.evaluation.policy_graph import clip_action
from ray.rllib.policy.policy import clip_action
logger = logging.getLogger(__name__)
@@ -236,7 +236,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
Args:
base_env (BaseEnv): env implementing BaseEnv.
extra_batch_callback (fn): function to send extra batch data to.
policies (dict): Map of policy ids to PolicyGraph instances.
policies (dict): Map of policy ids to Policy instances.
policy_mapping_fn (func): Function that maps agent ids to policy ids.
This is called when an agent first enters the environment. The
agent is then "bound" to the returned policy for the episode.
@@ -528,7 +528,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
policy = _get_or_raise(policies, policy_id)
if builder and (policy.compute_actions.__code__ is
TFPolicyGraph.compute_actions.__code__):
TFPolicy.compute_actions.__code__):
# TODO(ekl): how can we make info batch available to TF code?
pending_fetches[policy_id] = policy._build_compute_actions(
builder, [t.obs for t in eval_data],
+3 -509
View File
@@ -2,513 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import errno
import logging
import numpy as np
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import renamed_class
import ray
import ray.experimental.tf_utils
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
@DeveloperAPI
class TFPolicyGraph(PolicyGraph):
    """An agent policy and loss implemented in TensorFlow.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy graph, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        model (rllib.models.Model): RLlib model used for the policy.

    Examples:
        >>> policy = TFPolicyGraphSubclass(
            sess, obs_input, action_sampler, loss, loss_inputs)

        >>> print(policy.compute_actions([1, 0, 2]))
        (array([0, 1, 1]), [], {})

        >>> print(policy.postprocess_trajectory(SampleBatch({...})))
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    @DeveloperAPI
    def __init__(self,
                 observation_space,
                 action_space,
                 sess,
                 obs_input,
                 action_sampler,
                 loss,
                 loss_inputs,
                 model=None,
                 action_prob=None,
                 state_inputs=None,
                 state_outputs=None,
                 prev_action_input=None,
                 prev_reward_input=None,
                 seq_lens=None,
                 max_seq_len=20,
                 batch_divisibility_req=1,
                 update_ops=None):
        """Initialize the policy graph.

        Arguments:
            observation_space (gym.Space): Observation space of the env.
            action_space (gym.Space): Action space of the env.
            sess (Session): TensorFlow session to use.
            obs_input (Tensor): input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            action_sampler (Tensor): Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss (Tensor): scalar policy loss output tensor.
            loss_inputs (list): a (name, placeholder) tuple for each loss
                input argument. Each placeholder name must correspond to a
                SampleBatch column key returned by postprocess_trajectory(),
                and has shape [BATCH_SIZE, data...]. These keys will be read
                from postprocessed sample batches and fed into the specified
                placeholders during loss computation.
            model (rllib.models.Model): used to integrate custom losses and
                stats from user-defined RLlib models.
            action_prob (Tensor): probability of the sampled action.
            state_inputs (list): list of RNN state input Tensors.
            state_outputs (list): list of RNN state output Tensors.
            prev_action_input (Tensor): placeholder for previous actions
            prev_reward_input (Tensor): placeholder for previous rewards
            seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
                models/lstm.py for more information.
            max_seq_len (int): max sequence length for LSTM training.
            batch_divisibility_req (int): pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops (list): override the batchnorm update ops to run when
                applying gradients. Otherwise we run all update ops found in
                the current variable scope.

        Raises:
            ValueError: If the state inputs/outputs/initial state do not
                match in number, or if state inputs are given without a
                seq_lens tensor.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.model = model
        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampler = action_sampler
        self._is_training = self._get_is_training_placeholder()
        self._action_prob = action_prob
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len
        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._stats_fetches = {}

        # Without a loss the policy can only be used for inference.
        if loss is not None:
            self._initialize_loss(loss, loss_inputs)
        else:
            self._loss = None

        if len(self._state_inputs) != len(self._state_outputs):
            raise ValueError(
                "Number of state input and output tensors must match, got: "
                "{} vs {}".format(self._state_inputs, self._state_outputs))
        if len(self.get_initial_state()) != len(self._state_inputs):
            raise ValueError(
                "Length of initial state must match number of state inputs, "
                "got: {} vs {}".format(self.get_initial_state(),
                                       self._state_inputs))
        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined")

    def _initialize_loss(self, loss, loss_inputs):
        """Sets up loss, optimizer, gradients and the apply op.

        Also registers the RNN state placeholders as extra loss inputs and
        runs the global variables initializer on the session.
        """
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model:
            self._loss = self.model.custom_loss(loss, self._loss_input_dict)
            self._stats_fetches.update({"model": self.model.custom_stats()})
        else:
            self._loss = loss

        self._optimizer = self.optimizer()
        # Drop variables for which no gradient is available.
        self._grads_and_vars = [
            (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
            if g is not None
        ]
        self._grads = [g for (g, v) in self._grads_and_vars]
        self._variables = ray.experimental.tf_utils.TensorFlowVariables(
            self._loss, self._sess)

        # gather update ops for any batch norm layers
        if not self._update_ops:
            self._update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
        if self._update_ops:
            logger.debug("Update ops to run on apply gradient: {}".format(
                self._update_ops))
        with tf.control_dependencies(self._update_ops):
            self._apply_op = self.build_apply_op(self._optimizer,
                                                 self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss_fn:\n\n{}\n".format(
                    summarize(self._loss_input_dict)))

        self._sess.run(tf.global_variables_initializer())

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        # info_batch, episodes and kwargs are accepted but not used here.
        builder = TFRunBuilder(self._sess, "compute_actions")
        fetches = self._build_compute_actions(builder, obs_batch,
                                              state_batches, prev_action_batch,
                                              prev_reward_batch)
        return builder.get(fetches)

    @override(PolicyGraph)
    def compute_gradients(self, postprocessed_batch):
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(PolicyGraph)
    def apply_gradients(self, gradients):
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)

    @override(PolicyGraph)
    def learn_on_batch(self, postprocessed_batch):
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "learn_on_batch")
        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(PolicyGraph)
    def get_weights(self):
        return self._variables.get_flat()

    @override(PolicyGraph)
    def set_weights(self, weights):
        return self._variables.set_flat(weights)

    @override(PolicyGraph)
    def export_model(self, export_dir):
        """Export tensorflow graph to export_dir for serving."""
        with self._sess.graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            signature_def_map = self._build_signature_def()
            builder.add_meta_graph_and_variables(
                self._sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map)
            builder.save()

    @override(PolicyGraph)
    def export_checkpoint(self, export_dir, filename_prefix="model"):
        """Export tensorflow checkpoint to export_dir."""
        try:
            os.makedirs(export_dir)
        except OSError as e:
            # ignore error if export dir already exists
            if e.errno != errno.EEXIST:
                raise
        save_path = os.path.join(export_dir, filename_prefix)
        with self._sess.graph.as_default():
            saver = tf.train.Saver()
            saver.save(self._sess, save_path)

    @DeveloperAPI
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders.

        Optional, only required to work with the multi-GPU optimizer."""
        raise NotImplementedError

    @DeveloperAPI
    def extra_compute_action_feed_dict(self):
        """Extra dict to pass to the compute actions session run."""
        return {}

    @DeveloperAPI
    def extra_compute_action_fetches(self):
        """Extra values to fetch and return from compute_actions().

        By default we only return action probability info (if present).
        """
        if self._action_prob is not None:
            return {"action_prob": self._action_prob}
        else:
            return {}

    @DeveloperAPI
    def extra_compute_grad_feed_dict(self):
        """Extra dict to pass to the compute gradients session run."""
        return {}  # e.g, kl_coeff

    @DeveloperAPI
    def extra_compute_grad_fetches(self):
        """Extra values to fetch and return from compute_gradients()."""
        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

    @DeveloperAPI
    def optimizer(self):
        """TF optimizer to use for policy optimization."""
        if hasattr(self, "config"):
            return tf.train.AdamOptimizer(self.config["lr"])
        else:
            return tf.train.AdamOptimizer()

    @DeveloperAPI
    def gradients(self, optimizer, loss):
        """Override for custom gradient computation."""
        return optimizer.compute_gradients(loss)

    @DeveloperAPI
    def build_apply_op(self, optimizer, grads_and_vars):
        """Override for custom gradient apply computation.

        Arguments:
            optimizer: TF optimizer used to apply the gradients.
            grads_and_vars (list): (gradient, variable) pairs to apply.

        Returns:
            The op that applies the gradients.
        """
        # Apply the pairs that were passed in rather than always falling
        # back to self._grads_and_vars, so that callers handing over
        # modified pairs are honored. The default call path passes
        # self._grads_and_vars, so its behavior is unchanged.
        # specify global_step for TD3 which needs to count the num updates
        return optimizer.apply_gradients(
            grads_and_vars,
            global_step=tf.train.get_or_create_global_step())

    @DeveloperAPI
    def _get_is_training_placeholder(self):
        """Get the placeholder for _is_training, i.e., for batch norm layers.

        This can be called safely before __init__ has run.
        """
        if not hasattr(self, "_is_training"):
            self._is_training = tf.placeholder_with_default(False, ())
        return self._is_training

    def _extra_input_signature_def(self):
        """Extra input signatures to add when exporting tf model.

        Inferred from extra_compute_action_feed_dict()
        """
        feed_dict = self.extra_compute_action_feed_dict()
        return {
            k.name: tf.saved_model.utils.build_tensor_info(k)
            for k in feed_dict.keys()
        }

    def _extra_output_signature_def(self):
        """Extra output signatures to add when exporting tf model.

        Inferred from extra_compute_action_fetches()
        """
        fetches = self.extra_compute_action_fetches()
        return {
            k: tf.saved_model.utils.build_tensor_info(fetches[k])
            for k in fetches.keys()
        }

    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder.
        """
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = \
            tf.saved_model.utils.build_tensor_info(self._obs_input)

        if self._seq_lens is not None:
            input_signature["seq_lens"] = \
                tf.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_action_input)
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_reward_input)
        input_signature["is_training"] = \
            tf.saved_model.utils.build_tensor_info(self._is_training)

        for state_input in self._state_inputs:
            input_signature[state_input.name] = \
                tf.saved_model.utils.build_tensor_info(state_input)

        # build output signatures
        output_signature = self._extra_output_signature_def()
        output_signature["actions"] = \
            tf.saved_model.utils.build_tensor_info(self._sampler)
        for state_output in self._state_outputs:
            output_signature[state_output.name] = \
                tf.saved_model.utils.build_tensor_info(state_output)
        signature_def = (
            tf.saved_model.signature_def_utils.build_signature_def(
                input_signature, output_signature,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        signature_def_key = (tf.saved_model.signature_constants.
                             DEFAULT_SERVING_SIGNATURE_DEF_KEY)
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map

    def _build_compute_actions(self,
                               builder,
                               obs_batch,
                               state_batches=None,
                               prev_action_batch=None,
                               prev_reward_batch=None,
                               episodes=None):
        """Adds the feeds and fetches for action computation to `builder`.

        Returns:
            Tuple of (action fetch, list of state output fetches, extra
            action fetches).

        Raises:
            ValueError: If the number of state batches does not match the
                number of state input placeholders.
        """
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(self.extra_compute_action_feed_dict())
        builder.add_feed_dict({self._obs_input: obs_batch})
        if state_batches:
            # Each observation is treated as its own length-1 sequence.
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Test for presence via None/len() instead of truthiness: bool() of
        # a multi-element numpy array raises, and a single zero-valued
        # entry would wrongly count as "not given".
        if (self._prev_action_input is not None
                and prev_action_batch is not None
                and len(prev_action_batch) > 0):
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if (self._prev_reward_input is not None
                and prev_reward_batch is not None
                and len(prev_reward_batch) > 0):
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        builder.add_feed_dict({self._is_training: False})
        builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
        fetches = builder.add_fetches([self._sampler] + self._state_outputs +
                                      [self.extra_compute_action_fetches()])
        return fetches[0], fetches[1:-1], fetches[-1]

    def _build_compute_gradients(self, builder, postprocessed_batch):
        """Adds feeds/fetches for gradient computation; returns the fetches
        for (gradients, grad and stats info)."""
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        fetches = builder.add_fetches(
            [self._grads, self._get_grad_and_stats_fetches()])
        return fetches[0], fetches[1]

    def _build_apply_gradients(self, builder, gradients):
        """Feeds externally computed gradients and fetches the apply op.

        Raises:
            ValueError: If the number of gradients does not match the number
                of gradient tensors of this policy.
        """
        if len(gradients) != len(self._grads):
            raise ValueError(
                "Unexpected number of gradients to apply, got {} for {}".
                format(gradients, self._grads))
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches([self._apply_op])
        return fetches[0]

    def _build_learn_on_batch(self, builder, postprocessed_batch):
        """Adds feeds/fetches for one combined gradient compute + apply
        step; returns the fetch for the grad and stats info."""
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        builder.add_feed_dict({self._is_training: True})
        fetches = builder.add_fetches([
            self._apply_op,
            self._get_grad_and_stats_fetches(),
        ])
        return fetches[1]

    def _get_grad_and_stats_fetches(self):
        """Returns extra_compute_grad_fetches() merged with model stats.

        Raises:
            ValueError: If the extra grad fetches lack LEARNER_STATS_KEY.
        """
        fetches = self.extra_compute_grad_fetches()
        if LEARNER_STATS_KEY not in fetches:
            # Report the key that is actually checked for.
            raise ValueError(
                "Grad fetches should contain '{}': {{...}} entry".format(
                    LEARNER_STATS_KEY))
        if self._stats_fetches:
            # Entries from the extra grad fetches win over model stats.
            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
                                              **fetches[LEARNER_STATS_KEY])
        return fetches

    def _get_loss_inputs_dict(self, batch):
        """Builds the feed dict of loss inputs for a postprocessed batch.

        Batches are fed directly when there are no RNN state inputs and the
        divisibility requirement is met; otherwise they are chopped into
        padded sequences via chop_into_sequences().
        """
        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) %
                self._batch_divisibility_req == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
@DeveloperAPI
class LearningRateSchedule(object):
    """Mixin for TFPolicyGraph that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        """Creates the "lr" variable and the schedule that drives it.

        Arguments:
            lr: Initial learning rate; also used as the constant value when
                no schedule is given.
            lr_schedule: Either None or the schedule spec handed to
                PiecewiseSchedule. The last element of its last entry is
                used as the value outside of the schedule's range.
        """
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is not None:
            final_value = lr_schedule[-1][-1]
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=final_value)
        else:
            self.lr_schedule = ConstantSchedule(lr)

    @override(PolicyGraph)
    def on_global_var_update(self, global_vars):
        """Loads the scheduled value for the current timestep into lr."""
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        timestep = global_vars["timestep"]
        self.cur_lr.load(self.lr_schedule.value(timestep), session=self._sess)

    @override(TFPolicyGraph)
    def optimizer(self):
        """Returns an Adam optimizer reading the scheduled lr variable."""
        return tf.train.AdamOptimizer(self.cur_lr)
# Backwards-compatibility alias: exposes TFPolicy under its old name.
# NOTE(review): renamed_class presumably warns when the old name is used --
# confirm against ray.rllib.utils.
TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph")
@@ -2,9 +2,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -27,7 +27,7 @@ def build_tf_policy(name,
"""Helper function for creating a dynamic tf policy at runtime.
Arguments:
name (str): name of the graph (e.g., "PPOPolicy")
name (str): name of the policy (e.g., "PPOTFPolicy")
loss_fn (func): function that returns a loss tensor the policy,
and dict of experience tensor placeholders
get_default_config (func): optional function that returns the default
@@ -39,7 +39,7 @@ def build_tf_policy(name,
extra_action_fetches_fn (func): optional function that returns
a dict of TF fetches given the policy object
postprocess_fn (func): optional experience postprocessing function
that takes the same args as PolicyGraph.postprocess_trajectory()
that takes the same args as Policy.postprocess_trajectory()
optimizer_fn (func): optional function that returns a tf.Optimizer
given the policy and config
gradients_fn (func): optional function that returns a list of gradients
@@ -57,18 +57,18 @@ def build_tf_policy(name,
arguments
mixins (list): list of any class mixins for the returned policy class.
These mixins will be applied in order and will have higher
precedence than the DynamicTFPolicyGraph class
precedence than the DynamicTFPolicy class
get_batch_divisibility_req (func): optional function that returns
the divisibility requirement for sample batches
Returns:
a DynamicTFPolicyGraph instance that uses the specified args
a DynamicTFPolicy instance that uses the specified args
"""
if not name.endswith("TFPolicy"):
raise ValueError("Name should match *TFPolicy", name)
base = DynamicTFPolicyGraph
base = DynamicTFPolicy
while mixins:
class new_base(mixins.pop(), base):
@@ -76,7 +76,7 @@ def build_tf_policy(name,
base = new_base
class graph_cls(base):
class policy_cls(base):
def __init__(self,
obs_space,
action_space,
@@ -97,7 +97,7 @@ def build_tf_policy(name,
else:
self._extra_action_fetches = extra_action_fetches_fn(self)
DynamicTFPolicyGraph.__init__(
DynamicTFPolicy.__init__(
self,
obs_space,
action_space,
@@ -111,7 +111,7 @@ def build_tf_policy(name,
if after_init:
after_init(self, obs_space, action_space, config)
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -121,26 +121,26 @@ def build_tf_policy(name,
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
@override(TFPolicyGraph)
@override(TFPolicy)
def optimizer(self):
if optimizer_fn:
return optimizer_fn(self, self.config)
else:
return TFPolicyGraph.optimizer(self)
return TFPolicy.optimizer(self)
@override(TFPolicyGraph)
@override(TFPolicy)
def gradients(self, optimizer, loss):
if gradients_fn:
return gradients_fn(self, optimizer, loss)
else:
return TFPolicyGraph.gradients(self, optimizer, loss)
return TFPolicy.gradients(self, optimizer, loss)
@override(TFPolicyGraph)
@override(TFPolicy)
def extra_compute_action_fetches(self):
return dict(
TFPolicyGraph.extra_compute_action_fetches(self),
TFPolicy.extra_compute_action_fetches(self),
**self._extra_action_fetches)
graph_cls.__name__ = name
graph_cls.__qualname__ = name
return graph_cls
policy_cls.__name__ = name
policy_cls.__qualname__ = name
return policy_cls
@@ -2,173 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.utils import renamed_class
import numpy as np
from threading import Lock
try:
import torch
except ImportError:
pass # soft dep
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tracking_dict import UsageTrackingDict
class TorchPolicyGraph(PolicyGraph):
"""Template for a PyTorch policy and loss to use with RLlib.
This is similar to TFPolicyGraph, but for PyTorch.
Attributes:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
lock (Lock): Lock that must be held around PyTorch ops on this graph.
This is necessary when using the async sampler.
"""
def __init__(self, observation_space, action_space, model, loss,
action_distribution_cls):
"""Build a policy graph from policy and loss torch modules.
Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
is set. Only single GPU is supported for now.
Arguments:
observation_space (gym.Space): observation space of the policy.
action_space (gym.Space): action space of the policy.
model (nn.Module): PyTorch policy module. Given observations as
input, this module must return a list of outputs where the
first item is action logits, and the rest can be any value.
loss (func): Function that takes (policy_graph, batch_tensors)
and returns a single scalar loss.
action_distribution_cls (ActionDistribution): Class for action
distribution.
"""
self.observation_space = observation_space
self.action_space = action_space
self.lock = Lock()
self.device = (torch.device("cuda")
if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
else torch.device("cpu"))
self._model = model.to(self.device)
self._loss = loss
self._optimizer = self.optimizer()
self._action_dist_cls = action_distribution_cls
@override(PolicyGraph)
def compute_actions(self,
obs_batch,
state_batches=None,
prev_action_batch=None,
prev_reward_batch=None,
info_batch=None,
episodes=None,
**kwargs):
with self.lock:
with torch.no_grad():
ob = torch.from_numpy(np.array(obs_batch)) \
.float().to(self.device)
model_out = self._model({"obs": ob}, state_batches)
logits, _, vf, state = model_out
action_dist = self._action_dist_cls(logits)
actions = action_dist.sample()
return (actions.cpu().numpy(),
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_out = self._loss(self, batch_tensors)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
self._optimizer.step()
grad_info = self.extra_grad_info(batch_tensors)
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
batch_tensors = self._lazy_tensor_dict(postprocessed_batch)
with self.lock:
loss_out = self._loss(self, batch_tensors)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
# Note that return values are just references;
# calling zero_grad will modify the values
grads = []
for p in self._model.parameters():
if p.grad is not None:
grads.append(p.grad.data.cpu().numpy())
else:
grads.append(None)
grad_info = self.extra_grad_info(batch_tensors)
grad_info.update(grad_process_info)
return grads, {LEARNER_STATS_KEY: grad_info}
@override(PolicyGraph)
def apply_gradients(self, gradients):
    """Install externally computed gradients and take an optimizer step.

    Arguments:
        gradients (list): One entry per model parameter, in parameter
            order; None entries leave that parameter's `.grad` untouched.
    """
    with self.lock:
        for grad_np, param in zip(gradients, self._model.parameters()):
            if grad_np is None:
                continue
            param.grad = torch.from_numpy(grad_np).to(self.device)

        self._optimizer.step()
@override(PolicyGraph)
def get_weights(self):
    """Return the model's state dict with every tensor moved to CPU."""
    with self.lock:
        state = self._model.state_dict()
        return {name: tensor.cpu() for name, tensor in state.items()}
@override(PolicyGraph)
def set_weights(self, weights):
    """Load `weights` into the model while holding the policy lock.

    Arguments:
        weights (dict): State dict to pass to the model's
            `load_state_dict`.
    """
    with self.lock:
        model = self._model
        model.load_state_dict(weights)
@override(PolicyGraph)
def get_initial_state(self):
    """Return the model's initial RNN state as a list of numpy arrays.

    Each tensor from `self._model.state_init()` is moved to CPU before
    conversion: `Tensor.numpy()` only works on CPU tensors, and the model
    may live on a CUDA device. This matches how `compute_actions`
    converts its state outputs.
    """
    return [s.cpu().numpy() for s in self._model.state_init()]
def extra_grad_process(self):
    """Hook for subclasses to post-process gradients.

    Called after `backward()` by `learn_on_batch` and
    `compute_gradients`. The default does nothing.

    Returns:
        dict: Processing info to merge into the grad stats (empty here).
    """
    return dict()
def extra_action_out(self, model_out):
    """Hook returning extra info to include in the experience batch.

    The default returns an empty dict; subclasses may override it.

    Arguments:
        model_out (list): Outputs of the policy model module.

    Returns:
        dict: Extra per-action info (empty here).
    """
    return dict()
def extra_grad_info(self, batch_tensors):
    """Hook returning extra gradient stats for a training batch.

    The default returns a fresh empty dict. Callers in this class mutate
    the returned dict with `update()`, so overrides should return a dict
    that is safe to modify.

    Arguments:
        batch_tensors: The lazily converted batch passed to the loss.
    """
    return dict()
def optimizer(self):
    """Build the PyTorch optimizer used for training.

    Returns an Adam optimizer over the model parameters. When the policy
    has a `config` attribute, its "lr" entry is used as the learning
    rate; otherwise Adam's default learning rate applies. Subclasses may
    override this to supply a custom optimizer.
    """
    adam_kwargs = {}
    if hasattr(self, "config"):
        adam_kwargs["lr"] = self.config["lr"]
    return torch.optim.Adam(self._model.parameters(), **adam_kwargs)
def _lazy_tensor_dict(self, postprocessed_batch):
    """Wrap a batch so its values are converted to tensors on access.

    Registers a get-interceptor on a UsageTrackingDict that turns each
    numpy array into a torch tensor on `self.device`.
    """

    def _to_device_tensor(arr):
        return torch.from_numpy(arr).to(self.device)

    tracked = UsageTrackingDict(postprocessed_batch)
    tracked.set_get_interceptor(_to_device_tensor)
    return tracked
# Alias for TorchPolicy registered under its old name "TorchPolicyGraph";
# presumably kept so existing imports keep working after the
# PolicyGraph => Policy rename (see renamed_class for the exact behavior).
TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph")
@@ -209,7 +209,7 @@ if __name__ == "__main__":
"log_level": "INFO",
"entropy_coeff": 0.01,
"multiagent": {
"policy_graphs": {
"policies": {
"high_level_policy": (None, maze.observation_space,
Discrete(4), {
"gamma": 0.9
@@ -6,7 +6,7 @@ from __future__ import print_function
Control the number of agents and policies via --num-agents and --num-policies.
This works with hundreds of agents and policies, but note that initializing
many TF policy graphs will take some time.
many TF policies will take some time.
Also, TF evals might slow down with large numbers of policies. To debug TF
execution, set the TF_TIMELINE_DIR environment variable.
@@ -90,12 +90,12 @@ if __name__ == "__main__":
}
return (None, obs_space, act_space, config)
# Setup PPO with an ensemble of `num_policies` different policy graphs
policy_graphs = {
# Setup PPO with an ensemble of `num_policies` different policies
policies = {
"policy_{}".format(i): gen_policy(i)
for i in range(args.num_policies)
}
policy_ids = list(policy_graphs.keys())
policy_ids = list(policies.keys())
tune.run(
"PPO",
@@ -105,7 +105,7 @@ if __name__ == "__main__":
"log_level": "DEBUG",
"num_sgd_iter": 10,
"multiagent": {
"policy_graphs": policy_graphs,
"policies": policies,
"policy_mapping_fn": tune.function(
lambda agent_id: random.choice(policy_ids)),
},
@@ -22,7 +22,7 @@ import gym
import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph
from ray.rllib.policy import Policy
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.registry import register_env
@@ -30,7 +30,7 @@ parser = argparse.ArgumentParser()
parser.add_argument("--num-iters", type=int, default=20)
class RandomPolicy(PolicyGraph):
class RandomPolicy(Policy):
"""Hand-coded policy that returns random actions."""
def compute_actions(self,
@@ -65,7 +65,7 @@ if __name__ == "__main__":
config={
"env": "multi_cartpole",
"multiagent": {
"policy_graphs": {
"policies": {
"pg_policy": (None, obs_space, act_space, {}),
"random": (RandomPolicy, obs_space, act_space, {}),
},
@@ -16,9 +16,9 @@ import gym
import ray
from ray.rllib.agents.dqn.dqn import DQNTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.agents.ppo.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env
@@ -36,11 +36,11 @@ if __name__ == "__main__":
obs_space = single_env.observation_space
act_space = single_env.action_space
# You can also have multiple policy graphs per trainer, but here we just
# You can also have multiple policies per trainer, but here we just
# show one each for PPO and DQN.
policy_graphs = {
policies = {
"ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
"dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}),
"dqn_policy": (DQNTFPolicy, obs_space, act_space, {}),
}
def policy_mapping_fn(agent_id):
@@ -53,7 +53,7 @@ if __name__ == "__main__":
env="multi_cartpole",
config={
"multiagent": {
"policy_graphs": policy_graphs,
"policies": policies,
"policy_mapping_fn": policy_mapping_fn,
"policies_to_train": ["ppo_policy"],
},
@@ -66,7 +66,7 @@ if __name__ == "__main__":
env="multi_cartpole",
config={
"multiagent": {
"policy_graphs": policy_graphs,
"policies": policies,
"policy_mapping_fn": policy_mapping_fn,
"policies_to_train": ["dqn_policy"],
},
@@ -1,7 +1,7 @@
"""Example of using policy evaluator classes directly to implement training.
Instead of using the built-in Trainer classes provided by RLlib, here we define
a custom PolicyGraph class and manually coordinate distributed sample
a custom Policy class and manually coordinate distributed sample
collection and policy optimization.
"""
@@ -14,7 +14,8 @@ import gym
import ray
from ray import tune
from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch
from ray.rllib.evaluation.metrics import collect_metrics
parser = argparse.ArgumentParser()
@@ -23,15 +24,15 @@ parser.add_argument("--num-iters", type=int, default=20)
parser.add_argument("--num-workers", type=int, default=2)
class CustomPolicy(PolicyGraph):
"""Example of a custom policy graph written from scratch.
class CustomPolicy(Policy):
"""Example of a custom policy written from scratch.
You might find it more convenient to extend TF/TorchPolicyGraph instead
You might find it more convenient to extend TF/TorchPolicy instead
for a real policy.
"""
def __init__(self, observation_space, action_space, config):
PolicyGraph.__init__(self, observation_space, action_space, config)
Policy.__init__(self, observation_space, action_space, config)
# example parameter
self.w = 1.0
@@ -4,19 +4,19 @@ from __future__ import print_function
import numpy as np
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.policy.policy import Policy
def _sample(probs):
return [np.random.choice(len(pr), p=pr) for pr in probs]
class KerasPolicyGraph(PolicyGraph):
"""Initialize the Keras Policy Graph.
class KerasPolicy(Policy):
"""Initialize the Keras Policy.
This is a Policy Graph used for models with actor and critics.
This is a Policy used for models with actor and critics.
Note: This class is built for specific usage of Actor-Critic models,
and is less general compared to TFPolicyGraph and TorchPolicyGraphs.
and is less general compared to TFPolicy and TorchPolicies.
Args:
observation_space (gym.Space): Observation space of the policy.
@@ -32,7 +32,7 @@ class KerasPolicyGraph(PolicyGraph):
config,
actor=None,
critic=None):
PolicyGraph.__init__(self, observation_space, action_space, config)
Policy.__init__(self, observation_space, action_space, config)
self.actor = actor
self.critic = critic
self.models = [self.actor, self.critic]
+1 -1
View File
@@ -161,7 +161,7 @@ class Model(object):
You can find a runnable example in examples/custom_loss.py.
Arguments:
policy_loss (Tensor): scalar policy loss from the policy graph.
policy_loss (Tensor): scalar policy loss from the policy.
loss_inputs (dict): map of input placeholders for rollout data.
Returns:
+1 -1
View File
@@ -6,7 +6,7 @@ import logging
import numpy as np
import threading
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils import try_import_tf
+1 -1
View File
@@ -17,7 +17,7 @@ except ImportError:
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import unpack_if_needed
+1 -1
View File
@@ -15,7 +15,7 @@ try:
except ImportError:
smart_open = None
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
@@ -5,7 +5,7 @@ from __future__ import print_function
from collections import namedtuple
import logging
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import DeveloperAPI
logger = logging.getLogger(__name__)
@@ -23,7 +23,7 @@ class OffPolicyEstimator(object):
"""Creates an off-policy estimator.
Arguments:
policy (PolicyGraph): Policy graph to evaluate.
policy (Policy): Policy to evaluate.
gamma (float): Discount of the MDP.
"""
self.policy = policy
@@ -71,7 +71,7 @@ class OffPolicyEstimator(object):
raise ValueError(
"Off-policy estimation is not possible unless the policy "
"returns action probabilities when computing actions (i.e., "
"the 'action_prob' key is output by the policy graph). You "
"the 'action_prob' key is output by the policy). You "
"can set `input_evaluation: []` to resolve this.")
return info["action_prob"]
@@ -11,7 +11,7 @@ import math
from six.moves import queue
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.optimizers.aso_learner import LearnerThread
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
@@ -17,7 +17,7 @@ from six.moves import queue
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
@@ -48,7 +48,7 @@ class LocalSyncParallelOptimizer(object):
processed. If this is larger than the total data size, it will be
clipped.
build_graph: Function that takes the specified inputs and returns a
TF Policy Graph instance.
TF Policy instance.
"""
def __init__(self,
@@ -9,14 +9,14 @@ from collections import defaultdict
import ray
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.optimizers.rollout import collect_samples, \
collect_samples_straggler_mitigation
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils import try_import_tf
@@ -34,9 +34,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.
This optimizer is Tensorflow-specific and requires the underlying
PolicyGraph to be a TFPolicyGraph instance that support `.copy()`.
Policy to be a TFPolicy instance that supports `.copy()`.
Note that all replicas of the TFPolicyGraph will merge their
Note that all replicas of the TFPolicy will merge their
extra_compute_grad and apply_grad feed_dicts and fetches. This
may result in unexpected behavior.
"""
@@ -83,7 +83,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
logger.debug("Policies to train: {}".format(self.policies))
for policy_id, policy in self.policies.items():
if not isinstance(policy, TFPolicyGraph):
if not isinstance(policy, TFPolicy):
raise ValueError(
"Only TF policies are supported with multi-GPU. Try using "
"the simple optimizer instead.")
+1 -1
View File
@@ -5,7 +5,7 @@ from __future__ import print_function
import logging
import ray
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.memory import ray_get_and_free
logger = logging.getLogger(__name__)
@@ -7,7 +7,7 @@ import random
import ray
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
@@ -11,7 +11,7 @@ from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
PrioritizedReplayBuffer
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.compression import pack_if_needed
@@ -6,7 +6,7 @@ import ray
import logging
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.filter import RunningStat
from ray.rllib.utils.timer import TimerStat
+17
View File
@@ -0,0 +1,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.policy.tf_policy_template import build_tf_policy
__all__ = [
"Policy",
"TFPolicy",
"TorchPolicy",
"build_tf_policy",
"build_torch_policy",
]
@@ -6,9 +6,9 @@ from collections import OrderedDict
import logging
import numpy as np
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.sample_batch import SampleBatch
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils import try_import_tf
@@ -20,8 +20,8 @@ tf = try_import_tf()
logger = logging.getLogger(__name__)
class DynamicTFPolicyGraph(TFPolicyGraph):
"""A TFPolicyGraph that auto-defines placeholders dynamically at runtime.
class DynamicTFPolicy(TFPolicy):
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
Initialization of this class occurs in two phases.
* Phase 1: the model is created and model variables are initialized.
@@ -42,7 +42,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
make_action_sampler=None,
existing_inputs=None,
get_batch_divisibility_req=None):
"""Initialize a dynamic TF policy graph.
"""Initialize a dynamic TF policy.
Arguments:
observation_space (gym.Space): Observation space of the policy.
@@ -51,16 +51,16 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
loss_fn (func): function that returns a loss tensor the policy
graph, and dict of experience tensor placeholders
stats_fn (func): optional function that returns a dict of
TF fetches given the policy graph and batch input tensors
TF fetches given the policy and batch input tensors
grad_stats_fn (func): optional function that returns a dict of
TF fetches given the policy graph and loss gradient tensors
TF fetches given the policy and loss gradient tensors
before_loss_init (func): optional function to run prior to loss
init that takes the same arguments as __init__
make_action_sampler (func): optional function that returns a
tuple of action and action prob tensors. The function takes
(policy, input_dict, obs_space, action_space, config) as its
arguments
existing_inputs (OrderedDict): when copying a policy graph, this
existing_inputs (OrderedDict): when copying a policy, this
specifies an existing dict of placeholders to use instead of
defining new ones
get_batch_divisibility_req (func): optional function that returns
@@ -134,7 +134,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
batch_divisibility_req = get_batch_divisibility_req(self)
else:
batch_divisibility_req = 1
TFPolicyGraph.__init__(
TFPolicy.__init__(
self,
obs_space,
action_space,
@@ -158,7 +158,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
if not existing_inputs:
self._initialize_loss()
@override(TFPolicyGraph)
@override(TFPolicy)
def copy(self, existing_inputs):
"""Creates a copy of self using existing input placeholders."""
@@ -194,7 +194,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
if instance._stats_fn:
instance._stats_fetches.update(
instance._stats_fn(instance, input_dict))
TFPolicyGraph._initialize_loss(
TFPolicy._initialize_loss(
instance, loss, [(k, existing_inputs[i])
for i, (k, _) in enumerate(self._loss_inputs)])
if instance._grad_stats_fn:
@@ -202,7 +202,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
instance._grad_stats_fn(instance, instance._grads))
return instance
@override(PolicyGraph)
@override(Policy)
def get_initial_state(self):
if self.model:
return self.model.state_init
@@ -269,7 +269,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
for k in sorted(batch_tensors.accessed_keys):
loss_inputs.append((k, batch_tensors[k]))
TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
TFPolicy._initialize_loss(self, loss, loss_inputs)
if self._grad_stats_fn:
self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
self._sess.run(tf.global_variables_initializer())
+291
View File
@@ -0,0 +1,291 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import gym
from ray.rllib.utils.annotations import DeveloperAPI
# By convention, metrics from optimizing the loss can be reported in the
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
LEARNER_STATS_KEY = "learner_stats"
@DeveloperAPI
class Policy(object):
    """Base class pairing an agent's action logic with its training loss.

    A policy decides how to act in the environment and also defines the
    losses used to improve itself from collected experience. The two are
    bundled in one object for convenience, although the acting part is
    logically separate from the loss.

    Any policy may extend Policy directly. TensorFlow users may find
    TFPolicy simpler to implement; TFPolicy also enables RLlib to apply
    TensorFlow-specific optimizations such as fusing multiple policy
    graphs and multi-GPU support.

    Attributes:
        observation_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
    """

    @DeveloperAPI
    def __init__(self, observation_space, action_space, config):
        """Initialize the policy.

        This is the standard constructor signature for policies: the
        policy class passed into PolicyEvaluator is constructed with
        these arguments.

        Args:
            observation_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data. This base
                constructor does not store it.
        """
        self.observation_space = observation_space
        self.action_space = action_space

    @DeveloperAPI
    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Compute actions for a batch of observations.

        Must be implemented by subclasses.

        Arguments:
            obs_batch (np.ndarray): batch of observations
            state_batches (list): list of RNN state input batches, if any
            prev_action_batch (np.ndarray): batch of previous action values
            prev_reward_batch (np.ndarray): batch of previous rewards
            info_batch (info): batch of info objects
            episodes (list): MultiAgentEpisode for each obs in obs_batch.
                This provides access to all of the internal episode state,
                which may be useful for model-based or multiagent
                algorithms.
            kwargs: forward compatibility placeholder

        Returns:
            actions (np.ndarray): batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (list): list of RNN state output batches, if any,
                with shape like [STATE_SIZE, BATCH_SIZE].
            info (dict): dictionary of extra feature batches, if any, with
                shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    @DeveloperAPI
    def compute_single_action(self,
                              obs,
                              state,
                              prev_action=None,
                              prev_reward=None,
                              info=None,
                              episode=None,
                              clip_actions=False,
                              **kwargs):
        """Unbatched version of compute_actions.

        Wraps every input into a batch of size one, calls
        `compute_actions`, and unwraps the results again.

        Arguments:
            obs (obj): single observation
            state (list): list of RNN state inputs, if any
            prev_action (obj): previous action value, if any
            prev_reward (int): previous reward, if any
            info (dict): info object, if any
            episode (MultiAgentEpisode): this provides access to all of the
                internal episode state, which may be useful for model-based
                or multi-agent algorithms.
            clip_actions (bool): should the action be clipped
            kwargs: forward compatibility placeholder (not forwarded to
                compute_actions)

        Returns:
            actions (obj): single action
            state_outs (list): list of RNN state outputs, if any
            info (dict): dictionary of extra features, if any
        """
        # Optional inputs stay None when absent; otherwise each becomes a
        # one-element batch.
        prev_action_batch = (
            [prev_action] if prev_action is not None else None)
        prev_reward_batch = (
            [prev_reward] if prev_reward is not None else None)
        info_batch = [info] if info is not None else None
        episodes = [episode] if episode is not None else None

        action_batch, state_out, extras = self.compute_actions(
            [obs], [[s] for s in state],
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes)
        # Exactly one action is expected back for the single observation.
        [action] = action_batch

        if clip_actions:
            action = clip_action(action, self.action_space)

        single_state = [s[0] for s in state_out]
        single_extras = {k: v[0] for k, v in extras.items()}
        return action, single_state, single_extras

    @DeveloperAPI
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during
        policy evaluation. Each fragment is guaranteed to be only from one
        episode. The default implementation returns the batch unchanged.

        Arguments:
            sample_batch (SampleBatch): batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches (dict): In a multi-agent env, this contains
                a mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agent.
            episode (MultiAgentEpisode): this provides access to all of the
                internal episode state, which may be useful for model-based
                or multi-agent algorithms.

        Returns:
            SampleBatch: postprocessed sample batch.
        """
        return sample_batch

    @DeveloperAPI
    def learn_on_batch(self, samples):
        """Fused compute gradients and apply gradients call.

        Either this or the combination of compute/apply grads must be
        implemented by subclasses. The default calls `compute_gradients`
        followed by `apply_gradients`.

        Returns:
            grad_info: dictionary of extra metadata from
                compute_gradients().

        Examples:
            >>> batch = ev.sample()
            >>> ev.learn_on_batch(batch)
        """
        gradients, grad_info = self.compute_gradients(samples)
        self.apply_gradients(gradients)
        return grad_info

    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch):
        """Computes gradients against a batch of experiences.

        Either this or learn_on_batch() must be implemented by subclasses.

        Returns:
            grads (list): List of gradient output values
            info (dict): Extra policy-specific values
        """
        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, gradients):
        """Applies previously computed gradients.

        Either this or learn_on_batch() must be implemented by subclasses.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self):
        """Returns model weights.

        Returns:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights):
        """Sets model weights.

        Arguments:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_initial_state(self):
        """Returns initial RNN state for the current policy.

        The default is an empty list.
        """
        return list()

    @DeveloperAPI
    def get_state(self):
        """Saves all local state.

        The default returns the result of `get_weights()`.

        Returns:
            state (obj): Serialized local state.
        """
        weights = self.get_weights()
        return weights

    @DeveloperAPI
    def set_state(self, state):
        """Restores all local state.

        The default passes `state` to `set_weights()`.

        Arguments:
            state (obj): Serialized local state.
        """
        self.set_weights(state)

    @DeveloperAPI
    def on_global_var_update(self, global_vars):
        """Called on an update to global vars.

        The default does nothing.

        Arguments:
            global_vars (dict): Global variables broadcast from the driver.
        """
        return None

    @DeveloperAPI
    def export_model(self, export_dir):
        """Export Policy to local directory for serving.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError

    @DeveloperAPI
    def export_checkpoint(self, export_dir):
        """Export Policy checkpoint to local directory.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError
def clip_action(action, space):
    """Clip a single action to the range of the given action space.

    Box spaces are clipped element-wise to [space.low, space.high]. Tuple
    spaces are clipped component by component, recursively. Any other
    space type leaves the action unchanged.

    Arguments:
        action: Single action.
        space: Action space the action should be present in.

    Returns:
        The clipped action. For Tuple spaces this is a list with one
        clipped entry per component.

    Raises:
        ValueError: If `space` is a Tuple space but `action` is not a
            tuple or list.
    """
    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)

    if isinstance(space, gym.spaces.Tuple):
        # isinstance (rather than an exact type match) also accepts tuple
        # and list subclasses such as namedtuples.
        if not isinstance(action, (tuple, list)):
            raise ValueError("Expected tuple space for actions {}: {}".format(
                action, space))
        return [clip_action(a, s) for a, s in zip(action, space.spaces)]

    return action
+296
View File
@@ -0,0 +1,296 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import collections
import numpy as np
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
from ray.rllib.utils.compression import pack, unpack, is_compressed
from ray.rllib.utils.memory import concat_aligned
# Default policy id for single agent environments
DEFAULT_POLICY_ID = "default_policy"
@PublicAPI
class MultiAgentBatch(object):
    """A batch of experiences from multiple policies in the environment.

    Attributes:
        policy_batches (dict): Mapping from policy id to a normal
            SampleBatch of experiences. Note that these batches may be of
            different length.
        count (int): The number of environment timesteps this batch
            contains, as supplied by the creator. This is tracked
            separately from `total()`, which sums the per-policy batch
            counts.
    """

    @PublicAPI
    def __init__(self, policy_batches, count):
        self.policy_batches = policy_batches
        self.count = count

    @staticmethod
    @PublicAPI
    def wrap_as_needed(batches, count):
        """Return a plain batch when only the default policy is present.

        If `batches` holds exactly one entry keyed by DEFAULT_POLICY_ID,
        that batch is returned as-is; otherwise a MultiAgentBatch is
        built from `batches` and `count`.
        """
        only_default = len(batches) == 1 and DEFAULT_POLICY_ID in batches
        if not only_default:
            return MultiAgentBatch(batches, count)
        return batches[DEFAULT_POLICY_ID]

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Concatenate several MultiAgentBatch objects into one.

        Batches are grouped by policy id and concatenated per policy; the
        resulting count is the sum of the input counts.
        """
        grouped = collections.defaultdict(list)
        total_count = 0
        for sample in samples:
            assert isinstance(sample, MultiAgentBatch)
            for policy_id, batch in sample.policy_batches.items():
                grouped[policy_id].append(batch)
            total_count += sample.count

        merged = {
            policy_id: SampleBatch.concat_samples(batches)
            for policy_id, batches in grouped.items()
        }
        return MultiAgentBatch(merged, total_count)

    @PublicAPI
    def copy(self):
        """Return a new MultiAgentBatch holding copies of every batch."""
        copied = {}
        for policy_id, batch in self.policy_batches.items():
            copied[policy_id] = batch.copy()
        return MultiAgentBatch(copied, self.count)

    @PublicAPI
    def total(self):
        """Return the sum of `count` over all per-policy batches."""
        return sum(batch.count for batch in self.policy_batches.values())

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Call `compress` on every per-policy batch with these options."""
        for policy_batch in self.policy_batches.values():
            policy_batch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Call `decompress_if_needed` on every per-policy batch."""
        for policy_batch in self.policy_batches.values():
            policy_batch.decompress_if_needed(columns)

    def __str__(self):
        return "MultiAgentBatch({}, count={})".format(
            str(self.policy_batches), self.count)

    # repr and str render identically for this class.
    __repr__ = __str__
@PublicAPI
class SampleBatch(object):
"""Wrapper around a dictionary with string keys and array-like values.
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
samples, each with an "obs" and "reward" attribute.
"""
# Outputs from interacting with the environment
CUR_OBS = "obs"
NEXT_OBS = "new_obs"
ACTIONS = "actions"
REWARDS = "rewards"
PREV_ACTIONS = "prev_actions"
PREV_REWARDS = "prev_rewards"
DONES = "dones"
INFOS = "infos"
# Uniquely identifies an episode
EPS_ID = "eps_id"
# Uniquely identifies a sample batch. This is important to distinguish RNN
# sequences from the same episode when multiple sample batches are
# concatenated (fusing sequences across batches can be unsafe).
UNROLL_ID = "unroll_id"
# Uniquely identifies an agent within an episode
AGENT_INDEX = "agent_index"
# Value function predictions emitted by the behaviour policy
VF_PREDS = "vf_preds"
@PublicAPI
def __init__(self, *args, **kwargs):
"""Constructs a sample batch (same params as dict constructor)."""
self.data = dict(*args, **kwargs)
lengths = []
for k, v in self.data.copy().items():
assert isinstance(k, six.string_types), self
lengths.append(len(v))
self.data[k] = np.array(v, copy=False)
if not lengths:
raise ValueError("Empty sample batch")
assert len(set(lengths)) == 1, "data columns must be same length"
self.count = lengths[0]
@staticmethod
@PublicAPI
def concat_samples(samples):
if isinstance(samples[0], MultiAgentBatch):
return MultiAgentBatch.concat_samples(samples)
out = {}
samples = [s for s in samples if s.count > 0]
for k in samples[0].keys():
out[k] = concat_aligned([s[k] for s in samples])
return SampleBatch(out)
@PublicAPI
def concat(self, other):
"""Returns a new SampleBatch with each data column concatenated.
Examples:
>>> b1 = SampleBatch({"a": [1, 2]})
>>> b2 = SampleBatch({"a": [3, 4, 5]})
>>> print(b1.concat(b2))
{"a": [1, 2, 3, 4, 5]}
"""
assert self.keys() == other.keys(), "must have same columns"
out = {}
for k in self.keys():
out[k] = concat_aligned([self[k], other[k]])
return SampleBatch(out)
@PublicAPI
def copy(self):
return SampleBatch(
{k: np.array(v, copy=True)
for (k, v) in self.data.items()})
@PublicAPI
def rows(self):
"""Returns an iterator over data rows, i.e. dicts with column values.
Examples:
>>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> for row in batch.rows():
print(row)
{"a": 1, "b": 4}
{"a": 2, "b": 5}
{"a": 3, "b": 6}
"""
for i in range(self.count):
row = {}
for k in self.keys():
row[k] = self[k][i]
yield row
@PublicAPI
def columns(self, keys):
"""Returns a list of just the specified columns.
Examples:
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
"""
out = []
for k in keys:
out.append(self[k])
return out
@PublicAPI
def shuffle(self):
"""Shuffles the rows of this batch in-place."""
permutation = np.random.permutation(self.count)
for key, val in self.items():
self[key] = val[permutation]
@PublicAPI
def split_by_episode(self):
    """Splits this batch's data by `eps_id`.

    The batch is cut every time the value in the "eps_id" column changes
    from one row to the next, so each returned slice covers one
    contiguous run of a single episode id.

    Returns:
        list of SampleBatch, one per distinct episode.
    """
    pieces = []
    start = 0
    current_id = self.data["eps_id"][0]
    for idx in range(self.count):
        row_id = self.data["eps_id"][idx]
        if row_id != current_id:
            # Episode boundary: close the run that ended at idx - 1.
            pieces.append(self.slice(start, idx))
            start = idx
            current_id = row_id
    # Close the final run.
    pieces.append(self.slice(start, self.count))

    # Sanity checks: one episode id per piece, and no rows lost.
    for piece in pieces:
        distinct_ids = len(set(piece["eps_id"]))
        assert distinct_ids == 1, (piece, distinct_ids)
    total_rows = sum(piece.count for piece in pieces)
    assert total_rows == self.count, (pieces, self.count)
    return pieces
@PublicAPI
def slice(self, start, end):
    """Returns a slice of the row data of this batch.

    Each column is indexed with [start:end], so `end` is exclusive.

    Arguments:
        start (int): Starting index.
        end (int): Ending index.

    Returns:
        SampleBatch which has a slice of this batch's data.
    """
    sliced = {}
    for name, column in self.data.items():
        sliced[name] = column[start:end]
    return SampleBatch(sliced)
@PublicAPI
def keys(self):
    """Returns the column names, as given by the underlying dict's keys()."""
    return self.data.keys()
@PublicAPI
def items(self):
    """Returns the (column name, column data) pairs of the underlying dict."""
    return self.data.items()
@PublicAPI
def __getitem__(self, key):
    """Returns the column stored under `key` (KeyError if absent)."""
    return self.data[key]
@PublicAPI
def __setitem__(self, key, item):
    """Stores `item` as the column `key`.

    No length check or array conversion is applied here, and self.count
    is not updated.
    """
    self.data[key] = item
@DeveloperAPI
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
    """Replaces the given columns with their pack()-ed form, in place.

    Arguments:
        bulk (bool): if True, pack() is applied once to the whole column;
            otherwise each entry is packed on its own and the results
            are collected into a numpy array.
        columns (set): names of the columns to compress. Names that are
            not present in this batch are ignored.
    """
    for name in columns:
        if name not in self.data:
            continue
        column = self.data[name]
        if bulk:
            self.data[name] = pack(column)
        else:
            self.data[name] = np.array([pack(entry) for entry in column])
@DeveloperAPI
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
    """Undoes compress() on the given columns, in place, where applicable.

    A column is unpacked as a whole when is_compressed() reports it as
    compressed; otherwise, if it is non-empty and its first entry is
    compressed, every entry is unpacked individually. Columns that match
    neither case, or that are missing from the batch, are left untouched.

    Arguments:
        columns (set): names of the columns to inspect.
    """
    for name in columns:
        if name not in self.data:
            continue
        column = self.data[name]
        if is_compressed(column):
            # Bulk-compressed column.
            self.data[name] = unpack(column)
        elif len(column) > 0 and is_compressed(column[0]):
            # Entry-wise compressed column.
            self.data[name] = np.array([unpack(entry) for entry in column])
def __str__(self):
    """Returns "SampleBatch(...)" wrapping the str() of the data dict."""
    contents = str(self.data)
    return "SampleBatch({})".format(contents)
def __repr__(self):
    """Returns the same text as __str__: "SampleBatch(<data dict>)"."""
    contents = str(self.data)
    return "SampleBatch({})".format(contents)
def __iter__(self):
    """Iterates over the column names of this batch."""
    return iter(self.data)
def __contains__(self, x):
    """Returns True if `x` is a column name of this batch."""
    present = x in self.data
    return present
+513
View File
@@ -0,0 +1,513 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import errno
import logging
import numpy as np
import ray
import ray.experimental.tf_utils
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.lstm import chop_into_sequences
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.debug import log_once, summarize
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
from ray.rllib.utils.tf_run_builder import TFRunBuilder
from ray.rllib.utils import try_import_tf
tf = try_import_tf()
logger = logging.getLogger(__name__)
@DeveloperAPI
class TFPolicy(Policy):
    """An agent policy and loss implemented in TensorFlow.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        model (rllib.models.Model): RLlib model used for the policy.

    Examples:
        >>> policy = TFPolicySubclass(
            sess, obs_input, action_sampler, loss, loss_inputs)

        >>> print(policy.compute_actions([1, 0, 2]))
        (array([0, 1, 1]), [], {})

        >>> print(policy.postprocess_trajectory(SampleBatch({...})))
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    @DeveloperAPI
    def __init__(self,
                 observation_space,
                 action_space,
                 sess,
                 obs_input,
                 action_sampler,
                 loss,
                 loss_inputs,
                 model=None,
                 action_prob=None,
                 state_inputs=None,
                 state_outputs=None,
                 prev_action_input=None,
                 prev_reward_input=None,
                 seq_lens=None,
                 max_seq_len=20,
                 batch_divisibility_req=1,
                 update_ops=None):
        """Initialize the policy.

        Arguments:
            observation_space (gym.Space): Observation space of the env.
            action_space (gym.Space): Action space of the env.
            sess (Session): TensorFlow session to use.
            obs_input (Tensor): input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            action_sampler (Tensor): Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss (Tensor): scalar policy loss output tensor. If None, no
                loss, optimizer or gradient ops are built.
            loss_inputs (list): a (name, placeholder) tuple for each loss
                input argument. Each placeholder name must correspond to a
                SampleBatch column key returned by postprocess_trajectory(),
                and has shape [BATCH_SIZE, data...]. These keys will be read
                from postprocessed sample batches and fed into the specified
                placeholders during loss computation.
            model (rllib.models.Model): used to integrate custom losses and
                stats from user-defined RLlib models.
            action_prob (Tensor): probability of the sampled action.
            state_inputs (list): list of RNN state input Tensors.
            state_outputs (list): list of RNN state output Tensors.
            prev_action_input (Tensor): placeholder for previous actions
            prev_reward_input (Tensor): placeholder for previous rewards
            seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
                models/lstm.py for more information.
            max_seq_len (int): max sequence length for LSTM training.
            batch_divisibility_req (int): pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops (list): override the batchnorm update ops to run when
                applying gradients. Otherwise we run all update ops found in
                the current variable scope.

        Raises:
            ValueError: if the numbers of state inputs, state outputs and
                initial states disagree, or if state inputs are given
                without a seq_lens tensor.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.model = model
        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampler = action_sampler
        self._is_training = self._get_is_training_placeholder()
        self._action_prob = action_prob
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len
        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        self._stats_fetches = {}

        if loss is not None:
            self._initialize_loss(loss, loss_inputs)
        else:
            self._loss = None

        if len(self._state_inputs) != len(self._state_outputs):
            raise ValueError(
                "Number of state input and output tensors must match, got: "
                "{} vs {}".format(self._state_inputs, self._state_outputs))
        if len(self.get_initial_state()) != len(self._state_inputs):
            raise ValueError(
                "Length of initial state must match number of state inputs, "
                "got: {} vs {}".format(self.get_initial_state(),
                                       self._state_inputs))
        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined")

    def _initialize_loss(self, loss, loss_inputs):
        """Build the loss, optimizer, gradient and apply ops.

        Also registers the state input placeholders as loss inputs named
        "state_in_<i>" and runs the global variables initializer on the
        policy's session.
        """
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model:
            self._loss = self.model.custom_loss(loss, self._loss_input_dict)
            self._stats_fetches.update({"model": self.model.custom_stats()})
        else:
            self._loss = loss

        self._optimizer = self.optimizer()
        # Variables for which the optimizer produced no gradient are dropped.
        self._grads_and_vars = [
            (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
            if g is not None
        ]
        self._grads = [g for (g, v) in self._grads_and_vars]
        self._variables = ray.experimental.tf_utils.TensorFlowVariables(
            self._loss, self._sess)

        # gather update ops for any batch norm layers
        if not self._update_ops:
            self._update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
        if self._update_ops:
            logger.debug("Update ops to run on apply gradient: {}".format(
                self._update_ops))
        with tf.control_dependencies(self._update_ops):
            self._apply_op = self.build_apply_op(self._optimizer,
                                                 self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss_fn:\n\n{}\n".format(
                    summarize(self._loss_input_dict)))

        self._sess.run(tf.global_variables_initializer())

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Run the action sampler on a batch of observations.

        Returns:
            Tuple of (actions, list of state outputs, dict of extra
            action fetches).
        """
        builder = TFRunBuilder(self._sess, "compute_actions")
        fetches = self._build_compute_actions(builder, obs_batch,
                                              state_batches, prev_action_batch,
                                              prev_reward_batch)
        return builder.get(fetches)

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        """Compute gradients and stats fetches for a postprocessed batch."""
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    def apply_gradients(self, gradients):
        """Feed the given gradient values and run the apply op."""
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Run the apply op on a postprocessed batch; returns stats fetches."""
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "learn_on_batch")
        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    def get_weights(self):
        """Return the policy variables in flattened form."""
        return self._variables.get_flat()

    @override(Policy)
    def set_weights(self, weights):
        """Set the policy variables from a flattened representation."""
        return self._variables.set_flat(weights)

    @override(Policy)
    def export_model(self, export_dir):
        """Export tensorflow graph to export_dir for serving."""
        with self._sess.graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            signature_def_map = self._build_signature_def()
            builder.add_meta_graph_and_variables(
                self._sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map)
            builder.save()

    @override(Policy)
    def export_checkpoint(self, export_dir, filename_prefix="model"):
        """Export tensorflow checkpoint to export_dir."""
        try:
            os.makedirs(export_dir)
        except OSError as e:
            # ignore error if export dir already exists
            if e.errno != errno.EEXIST:
                raise
        save_path = os.path.join(export_dir, filename_prefix)
        with self._sess.graph.as_default():
            saver = tf.train.Saver()
            saver.save(self._sess, save_path)

    @DeveloperAPI
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders.

        Optional, only required to work with the multi-GPU optimizer."""
        raise NotImplementedError

    @DeveloperAPI
    def extra_compute_action_feed_dict(self):
        """Extra dict to pass to the compute actions session run."""
        return {}

    @DeveloperAPI
    def extra_compute_action_fetches(self):
        """Extra values to fetch and return from compute_actions().

        By default we only return action probability info (if present).
        """
        if self._action_prob is not None:
            return {"action_prob": self._action_prob}
        else:
            return {}

    @DeveloperAPI
    def extra_compute_grad_feed_dict(self):
        """Extra dict to pass to the compute gradients session run."""
        return {}  # e.g, kl_coeff

    @DeveloperAPI
    def extra_compute_grad_fetches(self):
        """Extra values to fetch and return from compute_gradients()."""
        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

    @DeveloperAPI
    def optimizer(self):
        """TF optimizer to use for policy optimization."""
        if hasattr(self, "config"):
            return tf.train.AdamOptimizer(self.config["lr"])
        else:
            return tf.train.AdamOptimizer()

    @DeveloperAPI
    def gradients(self, optimizer, loss):
        """Override for custom gradient computation."""
        return optimizer.compute_gradients(loss)

    @DeveloperAPI
    def build_apply_op(self, optimizer, grads_and_vars):
        """Override for custom gradient apply computation."""
        # Apply the gradients that were passed in (not the stored
        # self._grads_and_vars), so overrides that pre-process the list and
        # then delegate here are honored.
        # specify global_step for TD3 which needs to count the num updates
        return optimizer.apply_gradients(
            grads_and_vars,
            global_step=tf.train.get_or_create_global_step())

    @DeveloperAPI
    def _get_is_training_placeholder(self):
        """Get the placeholder for _is_training, i.e., for batch norm layers.

        This can be called safely before __init__ has run.
        """
        if not hasattr(self, "_is_training"):
            self._is_training = tf.placeholder_with_default(False, ())
        return self._is_training

    def _extra_input_signature_def(self):
        """Extra input signatures to add when exporting tf model.

        Inferred from extra_compute_action_feed_dict()
        """
        feed_dict = self.extra_compute_action_feed_dict()
        return {
            k.name: tf.saved_model.utils.build_tensor_info(k)
            for k in feed_dict.keys()
        }

    def _extra_output_signature_def(self):
        """Extra output signatures to add when exporting tf model.

        Inferred from extra_compute_action_fetches()
        """
        fetches = self.extra_compute_action_fetches()
        return {
            k: tf.saved_model.utils.build_tensor_info(fetches[k])
            for k in fetches.keys()
        }

    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder.
        """
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = \
            tf.saved_model.utils.build_tensor_info(self._obs_input)

        if self._seq_lens is not None:
            input_signature["seq_lens"] = \
                tf.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_action_input)
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_reward_input)
        input_signature["is_training"] = \
            tf.saved_model.utils.build_tensor_info(self._is_training)

        for state_input in self._state_inputs:
            input_signature[state_input.name] = \
                tf.saved_model.utils.build_tensor_info(state_input)

        # build output signatures
        output_signature = self._extra_output_signature_def()
        output_signature["actions"] = \
            tf.saved_model.utils.build_tensor_info(self._sampler)
        for state_output in self._state_outputs:
            output_signature[state_output.name] = \
                tf.saved_model.utils.build_tensor_info(state_output)
        signature_def = (
            tf.saved_model.signature_def_utils.build_signature_def(
                input_signature, output_signature,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        signature_def_key = (tf.saved_model.signature_constants.
                             DEFAULT_SERVING_SIGNATURE_DEF_KEY)
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map

    def _build_compute_actions(self,
                               builder,
                               obs_batch,
                               state_batches=None,
                               prev_action_batch=None,
                               prev_reward_batch=None,
                               episodes=None):
        """Register the feeds and fetches for action computation on builder.

        Returns:
            Tuple of fetch handles: (actions, state outputs, extra fetches).

        Raises:
            ValueError: if the number of state batches does not match the
                number of state input placeholders.
        """
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(self.extra_compute_action_feed_dict())
        builder.add_feed_dict({self._obs_input: obs_batch})
        if state_batches:
            # Each observation is fed as its own sequence of length 1.
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Compare against None explicitly: the truth value of a numpy array
        # with more than one element raises ValueError, and a single zero
        # action/reward would otherwise be silently skipped.
        if self._prev_action_input is not None and \
                prev_action_batch is not None:
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if self._prev_reward_input is not None and \
                prev_reward_batch is not None:
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        builder.add_feed_dict({self._is_training: False})
        builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
        fetches = builder.add_fetches([self._sampler] + self._state_outputs +
                                      [self.extra_compute_action_fetches()])
        return fetches[0], fetches[1:-1], fetches[-1]

    def _build_compute_gradients(self, builder, postprocessed_batch):
        """Register feeds/fetches for gradient computation on builder."""
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        fetches = builder.add_fetches(
            [self._grads, self._get_grad_and_stats_fetches()])
        return fetches[0], fetches[1]

    def _build_apply_gradients(self, builder, gradients):
        """Register feeds/fetches for applying given gradients on builder.

        Raises:
            ValueError: if len(gradients) differs from the number of
                gradient tensors of this policy.
        """
        if len(gradients) != len(self._grads):
            raise ValueError(
                "Unexpected number of gradients to apply, got {} for {}".
                format(gradients, self._grads))
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches([self._apply_op])
        return fetches[0]

    def _build_learn_on_batch(self, builder, postprocessed_batch):
        """Register feeds/fetches for one training step on builder.

        Only the stats fetch handle is returned; the apply op result is
        discarded.
        """
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        builder.add_feed_dict({self._is_training: True})
        fetches = builder.add_fetches([
            self._apply_op,
            self._get_grad_and_stats_fetches(),
        ])
        return fetches[1]

    def _get_grad_and_stats_fetches(self):
        """Return the extra grad fetches merged with model stats fetches.

        Entries returned by extra_compute_grad_fetches() take precedence
        over same-named entries in self._stats_fetches.

        Raises:
            ValueError: if extra_compute_grad_fetches() lacks the
                LEARNER_STATS_KEY entry.
        """
        fetches = self.extra_compute_grad_fetches()
        if LEARNER_STATS_KEY not in fetches:
            raise ValueError(
                "Grad fetches should contain '{}': {{...}} entry".format(
                    LEARNER_STATS_KEY))
        if self._stats_fetches:
            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
                                              **fetches[LEARNER_STATS_KEY])
        return fetches

    def _get_loss_inputs_dict(self, batch):
        """Build the feed dict mapping loss placeholders to batch data.

        When the policy has no state inputs and the batch satisfies the
        divisibility requirement, each loss input column is fed as-is.
        Otherwise the columns go through chop_into_sequences() and the
        resulting sequence lengths are fed to the seq_lens placeholder.
        """
        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) %
                self._batch_divisibility_req == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
@DeveloperAPI
class LearningRateSchedule(object):
    """Mixin for TFPolicy that adds a learning rate schedule.

    The current learning rate lives in the TF variable `cur_lr`, which is
    reloaded from the schedule on every global var update and is what the
    Adam optimizer returned by optimizer() reads.
    """

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        """Create the learning rate variable and the schedule object.

        Arguments:
            lr: initial value of the "lr" variable; also the constant
                rate used when no schedule is given.
            lr_schedule: None for a constant rate, otherwise the argument
                handed to PiecewiseSchedule. Its last entry's last element
                is used as the value outside of the schedule.
        """
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is not None:
            final_value = lr_schedule[-1][-1]
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=final_value)
        else:
            self.lr_schedule = ConstantSchedule(lr)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        """Load the rate scheduled for global_vars["timestep"] into cur_lr."""
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        scheduled_lr = self.lr_schedule.value(global_vars["timestep"])
        self.cur_lr.load(scheduled_lr, session=self._sess)

    @override(TFPolicy)
    def optimizer(self):
        """Return an Adam optimizer driven by the `cur_lr` variable."""
        return tf.train.AdamOptimizer(self.cur_lr)
@@ -0,0 +1,146 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils.annotations import override, DeveloperAPI
@DeveloperAPI
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    stats_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    postprocess_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_action_sampler=None,
                    mixins=None,
                    get_batch_divisibility_req=None):
    """Helper function for creating a dynamic tf policy at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor the policy,
            and dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of gradients
            given a tf optimizer and loss tensor. If not specified, this
            defaults to optimizer.compute_gradients(loss)
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicy class. The list is not
            modified.
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches

    Returns:
        a DynamicTFPolicy subclass (a class, not an instance) named `name`
        that uses the specified args

    Raises:
        ValueError: if `name` does not end with "TFPolicy".
    """

    if not name.endswith("TFPolicy"):
        raise ValueError("Name should match *TFPolicy", name)

    base = DynamicTFPolicy
    # Walk the mixins last-to-first so the first mixin ends up outermost
    # (highest precedence). Iterate over a copy instead of pop()-ing so the
    # caller's list is left intact and can be reused for another policy.
    for mixin in reversed(list(mixins or [])):

        class new_base(mixin, base):
            pass

        base = new_base

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_inputs=None):
            if get_default_config:
                # User-supplied entries override the defaults.
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                # The extra action fetches are computed here, i.e. in the
                # hook that DynamicTFPolicy is given as before_loss_init.
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                existing_inputs=existing_inputs)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicy.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicy.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            # Template-provided fetches override the TFPolicy defaults.
            return dict(
                TFPolicy.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
+173
View File
@@ -0,0 +1,173 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
from threading import Lock
try:
import torch
except ImportError:
pass # soft dep
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.utils.annotations import override
from ray.rllib.utils.tracking_dict import UsageTrackingDict
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        lock (Lock): Lock that must be held around PyTorch ops on this graph.
            This is necessary when using the async sampler.
    """

    def __init__(self, observation_space, action_space, model, loss,
                 action_distribution_cls):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, batch_tensors)
                and returns a single scalar loss.
            action_distribution_cls (ActionDistribution): Class for action
                distribution.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.lock = Lock()
        # Any non-empty CUDA_VISIBLE_DEVICES value selects the cuda device.
        self.device = (torch.device("cuda")
                       if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
                       else torch.device("cpu"))
        self._model = model.to(self.device)
        self._loss = loss
        self._optimizer = self.optimizer()
        self._action_dist_cls = action_distribution_cls

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample actions for a batch of observations without gradients.

        Returns:
            Tuple of (actions as numpy array, list of state arrays,
            dict from extra_action_out()).
        """
        with self.lock:
            with torch.no_grad():
                ob = torch.from_numpy(np.array(obs_batch)) \
                    .float().to(self.device)
                model_out = self._model({"obs": ob}, state_batches)
                # The model output is unpacked as a 4-tuple; only the
                # first (logits) and last (state) items are used here.
                logits, _, _, state = model_out
                action_dist = self._action_dist_cls(logits)
                actions = action_dist.sample()
                return (actions.cpu().numpy(),
                        [h.cpu().numpy() for h in state],
                        self.extra_action_out(model_out))

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Run one optimizer step on the batch and return learner stats."""
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()
            self._optimizer.step()

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        """Compute gradients for the batch without applying them.

        Returns:
            Tuple of (list with one gradient array, or None, per model
            parameter; dict holding the learner stats).
        """
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()

            # Note that return values are just references;
            # calling zero_grad will modify the values
            grads = []
            for p in self._model.parameters():
                if p.grad is not None:
                    grads.append(p.grad.data.cpu().numpy())
                else:
                    grads.append(None)

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return grads, {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def apply_gradients(self, gradients):
        """Assign the given gradients to the parameters and step.

        Parameters whose gradient entry is None keep their current grad.
        """
        with self.lock:
            for g, p in zip(gradients, self._model.parameters()):
                if g is not None:
                    p.grad = torch.from_numpy(g).to(self.device)
            self._optimizer.step()

    @override(Policy)
    def get_weights(self):
        """Return the model state dict with every value moved to CPU."""
        with self.lock:
            return {k: v.cpu() for k, v in self._model.state_dict().items()}

    @override(Policy)
    def set_weights(self, weights):
        """Load the given state dict into the model."""
        with self.lock:
            self._model.load_state_dict(weights)

    @override(Policy)
    def get_initial_state(self):
        """Return the model's initial state as a list of numpy arrays."""
        # Move to CPU first, as compute_actions() does: Tensor.numpy() is
        # only valid for CPU tensors, and .cpu() is a no-op for tensors
        # already on the CPU.
        return [s.cpu().numpy() for s in self._model.state_init()]

    def extra_grad_process(self):
        """Allow subclass to do extra processing on gradients and
           return processing info."""
        return {}

    def extra_action_out(self, model_out):
        """Returns dict of extra info to include in experience batch.

        Arguments:
            model_out (list): Outputs of the policy model module."""
        return {}

    def extra_grad_info(self, batch_tensors):
        """Return dict of extra grad info."""
        return {}

    def optimizer(self):
        """Custom PyTorch optimizer to use."""
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self._model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self._model.parameters())

    def _lazy_tensor_dict(self, postprocessed_batch):
        """Wrap the batch so values are converted to device tensors on read.

        The conversion is installed as the get-interceptor of a
        UsageTrackingDict built from the batch.
        """
        batch_tensors = UsageTrackingDict(postprocessed_batch)
        batch_tensors.set_get_interceptor(
            lambda arr: torch.from_numpy(arr).to(self.device))
        return batch_tensors
@@ -2,8 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -24,7 +24,7 @@ def build_torch_policy(name,
"""Helper function for creating a torch policy at runtime.
Arguments:
name (str): name of the graph (e.g., "PPOPolicy")
name (str): name of the policy (e.g., "PPOTorchPolicy")
loss_fn (func): function that returns a loss tensor the policy,
and dict of experience tensor placeholders
get_default_config (func): optional function that returns the default
@@ -32,7 +32,7 @@ def build_torch_policy(name,
stats_fn (func): optional function that returns a dict of
values given the policy and batch input tensors
postprocess_fn (func): optional experience postprocessing function
that takes the same args as PolicyGraph.postprocess_trajectory()
that takes the same args as Policy.postprocess_trajectory()
extra_action_out_fn (func): optional function that returns
a dict of extra values to include in experiences
extra_grad_process_fn (func): optional function that is called after
@@ -49,16 +49,16 @@ def build_torch_policy(name,
model and action dist from the catalog will be used
mixins (list): list of any class mixins for the returned policy class.
These mixins will be applied in order and will have higher
precedence than the TorchPolicyGraph class
precedence than the TorchPolicy class
Returns:
a TorchPolicyGraph instance that uses the specified args
a TorchPolicy instance that uses the specified args
"""
if not name.endswith("TorchPolicy"):
raise ValueError("Name should match *TorchPolicy", name)
base = TorchPolicyGraph
base = TorchPolicy
while mixins:
class new_base(mixins.pop(), base):
@@ -84,13 +84,13 @@ def build_torch_policy(name,
self.model = ModelCatalog.get_torch_model(
obs_space, logit_dim, self.config["model"])
TorchPolicyGraph.__init__(self, obs_space, action_space,
self.model, loss_fn, self.dist_class)
TorchPolicy.__init__(self, obs_space, action_space, self.model,
loss_fn, self.dist_class)
if after_init:
after_init(self, obs_space, action_space, config)
@override(PolicyGraph)
@override(Policy)
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
@@ -100,33 +100,33 @@ def build_torch_policy(name,
return postprocess_fn(self, sample_batch, other_agent_batches,
episode)
@override(TorchPolicyGraph)
@override(TorchPolicy)
def extra_grad_process(self):
if extra_grad_process_fn:
return extra_grad_process_fn(self)
else:
return TorchPolicyGraph.extra_grad_process(self)
return TorchPolicy.extra_grad_process(self)
@override(TorchPolicyGraph)
@override(TorchPolicy)
def extra_action_out(self, model_out):
if extra_action_out_fn:
return extra_action_out_fn(self, model_out)
else:
return TorchPolicyGraph.extra_action_out(self, model_out)
return TorchPolicy.extra_action_out(self, model_out)
@override(TorchPolicyGraph)
@override(TorchPolicy)
def optimizer(self):
if optimizer_fn:
return optimizer_fn(self, self.config)
else:
return TorchPolicyGraph.optimizer(self)
return TorchPolicy.optimizer(self)
@override(TorchPolicyGraph)
@override(TorchPolicy)
def extra_grad_info(self, batch_tensors):
if stats_fn:
return stats_fn(self, batch_tensors)
else:
return TorchPolicyGraph.extra_grad_info(self, batch_tensors)
return TorchPolicy.extra_grad_info(self, batch_tensors)
graph_cls.__name__ = name
graph_cls.__qualname__ = name
+1 -1
View File
@@ -15,7 +15,7 @@ import ray
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.tune.util import merge_dicts
EXAMPLE_USAGE = """
+1 -1
View File
@@ -7,7 +7,7 @@ import unittest
import ray
from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.a3c import A3CTrainer
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
from ray.rllib.agents.dqn.dqn_policy import _adjust_nstep
from ray.tune.registry import register_env
import gym
+7 -7
View File
@@ -13,8 +13,8 @@ from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_env import ExternalEnv
from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph,
MockPolicyGraph, MockEnv)
from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy,
MockEnv)
from ray.tune.registry import register_env
@@ -121,7 +121,7 @@ class TestExternalEnv(unittest.TestCase):
def testExternalEnvCompleteEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=40,
batch_mode="complete_episodes")
for _ in range(3):
@@ -131,7 +131,7 @@ class TestExternalEnv(unittest.TestCase):
def testExternalEnvTruncateEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=40,
batch_mode="truncate_episodes")
for _ in range(3):
@@ -141,7 +141,7 @@ class TestExternalEnv(unittest.TestCase):
def testExternalEnvOffPolicy(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=40,
batch_mode="complete_episodes")
for _ in range(3):
@@ -153,7 +153,7 @@ class TestExternalEnv(unittest.TestCase):
def testExternalEnvBadActions(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=BadPolicyGraph,
policy=BadPolicy,
sample_async=True,
batch_steps=40,
batch_mode="truncate_episodes")
@@ -198,7 +198,7 @@ class TestExternalEnv(unittest.TestCase):
def testExternalEnvHorizonNotSupported(self):
ev = PolicyEvaluator(
env_creator=lambda _: SimpleServing(MockEnv(25)),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
episode_horizon=20,
batch_steps=10,
batch_mode="complete_episodes")
@@ -8,11 +8,11 @@ import random
import unittest
import ray
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.optimizers import SyncSamplesOptimizer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
from ray.rllib.tests.test_policy_evaluator import MockPolicy
from ray.rllib.tests.test_external_env import make_simple_serving
from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole
from ray.rllib.evaluation.metrics import collect_metrics
@@ -25,7 +25,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=40,
batch_mode="complete_episodes")
for _ in range(3):
@@ -37,7 +37,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
agents = 4
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=40,
batch_mode="truncate_episodes")
for _ in range(3):
@@ -51,9 +51,9 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50)
@@ -72,7 +72,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [])
+3 -3
View File
@@ -15,7 +15,7 @@ import unittest
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.evaluation import SampleBatch
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
from ray.rllib.offline.json_writer import _to_json
@@ -167,7 +167,7 @@ class AgentIOTest(unittest.TestCase):
"num_workers": 0,
"output": self.test_dir,
"multiagent": {
"policy_graphs": {
"policies": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
@@ -188,7 +188,7 @@ class AgentIOTest(unittest.TestCase):
"input_evaluation": ["simulation"],
"train_batch_size": 2000,
"multiagent": {
"policy_graphs": {
"policies": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
+41 -37
View File
@@ -8,14 +8,14 @@ import unittest
import ray
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
AsyncGradientsOptimizer)
from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2,
MockPolicyGraph)
MockPolicy)
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -329,9 +329,9 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50)
@@ -347,9 +347,9 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50,
@@ -364,9 +364,9 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_steps=50,
@@ -380,9 +380,9 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: BasicMultiAgent(5),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
episode_horizon=10, # test with episode horizon set
@@ -395,9 +395,9 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(2)
ev = PolicyEvaluator(
env_creator=lambda _: EarlyDoneMultiAgent(),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
"p1": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
"p1": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
batch_mode="complete_episodes",
@@ -411,8 +411,8 @@ class TestMultiAgentEnv(unittest.TestCase):
obs_space = gym.spaces.Discrete(10)
ev = PolicyEvaluator(
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
policy_graph={
"p0": (MockPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (MockPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
batch_steps=50)
@@ -445,7 +445,7 @@ class TestMultiAgentEnv(unittest.TestCase):
def testCustomRNNStateValues(self):
h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}
class StatefulPolicyGraph(PolicyGraph):
class StatefulPolicy(Policy):
def compute_actions(self,
obs_batch,
state_batches,
@@ -460,7 +460,7 @@ class TestMultiAgentEnv(unittest.TestCase):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=StatefulPolicyGraph,
policy=StatefulPolicy,
batch_steps=5)
batch = ev.sample()
self.assertEqual(batch.count, 5)
@@ -470,7 +470,7 @@ class TestMultiAgentEnv(unittest.TestCase):
self.assertEqual(batch["state_out_0"][1], h)
def testReturningModelBasedRolloutsData(self):
class ModelBasedPolicyGraph(PGTFPolicy):
class ModelBasedPolicy(PGTFPolicy):
def compute_actions(self,
obs_batch,
state_batches,
@@ -505,9 +505,9 @@ class TestMultiAgentEnv(unittest.TestCase):
act_space = single_env.action_space
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(2),
policy_graph={
"p0": (ModelBasedPolicyGraph, obs_space, act_space, {}),
"p1": (ModelBasedPolicyGraph, obs_space, act_space, {}),
policy={
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
},
policy_mapping_fn=lambda agent_id: "p0",
batch_steps=5)
@@ -547,7 +547,7 @@ class TestMultiAgentEnv(unittest.TestCase):
config={
"num_workers": 0,
"multiagent": {
"policy_graphs": {
"policies": {
"policy_1": gen_policy(),
"policy_2": gen_policy(),
},
@@ -579,17 +579,17 @@ class TestMultiAgentEnv(unittest.TestCase):
# happen since the replay buffer doesn't encode extra fields like
# "advantages" that PG uses.
policies = {
"p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
"p1": (DQNTFPolicy, obs_space, act_space, dqn_config),
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
}
else:
policies = {
"p1": (PGTFPolicy, obs_space, act_space, {}),
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
}
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy=policies,
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
batch_steps=50)
if optimizer_cls == AsyncGradientsOptimizer:
@@ -600,7 +600,7 @@ class TestMultiAgentEnv(unittest.TestCase):
remote_evs = [
PolicyEvaluator.as_remote().remote(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy=policies,
policy_mapping_fn=policy_mapper,
batch_steps=50)
]
@@ -610,12 +610,16 @@ class TestMultiAgentEnv(unittest.TestCase):
for i in range(200):
ev.foreach_policy(lambda p, _: p.set_epsilon(
max(0.02, 1 - i * .02))
if isinstance(p, DQNPolicyGraph) else None)
if isinstance(p, DQNTFPolicy) else None)
optimizer.step()
result = collect_metrics(ev, remote_evs)
if i % 20 == 0:
ev.foreach_policy(lambda p, _: p.update_target() if isinstance(
p, DQNPolicyGraph) else None)
def do_update(p):
if isinstance(p, DQNTFPolicy):
p.update_target()
ev.foreach_policy(lambda p, _: do_update(p))
print("Iter {}, rew {}".format(i,
result["policy_reward_mean"]))
print("Total reward", result["episode_reward_mean"])
@@ -645,7 +649,7 @@ class TestMultiAgentEnv(unittest.TestCase):
policy_ids = list(policies.keys())
ev = PolicyEvaluator(
env_creator=lambda _: MultiCartpole(n),
policy_graph=policies,
policy=policies,
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
batch_steps=100)
optimizer = SyncSamplesOptimizer(ev, [])
+2 -2
View File
@@ -12,7 +12,7 @@ import unittest
import ray
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
from ray.rllib.env import MultiAgentEnv
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.env.vector_env import VectorEnv
@@ -331,7 +331,7 @@ class NestedSpacesTest(unittest.TestCase):
"sample_batch_size": 5,
"train_batch_size": 5,
"multiagent": {
"policy_graphs": {
"policies": {
"tuple_policy": (
PGTFPolicy, TUPLE_SPACE, act_space,
{"model": {"custom_model": "tuple_spy"}}),
+3 -3
View File
@@ -9,7 +9,7 @@ import unittest
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
from ray.rllib.evaluation import SampleBatch
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
@@ -240,12 +240,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
local = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=PPOTFPolicy,
policy=PPOTFPolicy,
tf_session_creator=make_sess)
remotes = [
PolicyEvaluator.as_remote().remote(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=PPOTFPolicy,
policy=PPOTFPolicy,
tf_session_creator=make_sess)
]
return local, remotes
+2 -2
View File
@@ -8,7 +8,7 @@ import unittest
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
from ray.rllib.tests.test_policy_evaluator import MockPolicy
class TestPerf(unittest.TestCase):
@@ -19,7 +19,7 @@ class TestPerf(unittest.TestCase):
for _ in range(20):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=100)
start = time.time()
count = 0
+22 -24
View File
@@ -14,14 +14,14 @@ from ray.rllib.agents.pg import PGTrainer
from ray.rllib.agents.a3c import A2CTrainer
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_metrics
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.policy.policy import Policy
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
from ray.rllib.env.vector_env import VectorEnv
from ray.tune.registry import register_env
class MockPolicyGraph(PolicyGraph):
class MockPolicy(Policy):
def compute_actions(self,
obs_batch,
state_batches,
@@ -39,7 +39,7 @@ class MockPolicyGraph(PolicyGraph):
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
class BadPolicyGraph(PolicyGraph):
class BadPolicy(Policy):
def compute_actions(self,
obs_batch,
state_batches,
@@ -132,8 +132,7 @@ class MockVectorEnv(VectorEnv):
class TestPolicyEvaluator(unittest.TestCase):
def testBasic(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph)
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
batch = ev.sample()
for key in [
"obs", "actions", "rewards", "dones", "advantages",
@@ -157,8 +156,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testBatchIds(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph)
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
batch1 = ev.sample()
batch2 = ev.sample()
self.assertEqual(len(set(batch1["unroll_id"])), 1)
@@ -229,7 +227,7 @@ class TestPolicyEvaluator(unittest.TestCase):
# clipping on
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv2(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
clip_rewards=True,
batch_mode="complete_episodes")
self.assertEqual(max(ev.sample()["rewards"]), 1)
@@ -239,7 +237,7 @@ class TestPolicyEvaluator(unittest.TestCase):
# clipping off
ev2 = PolicyEvaluator(
env_creator=lambda _: MockEnv2(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
clip_rewards=False,
batch_mode="complete_episodes")
self.assertEqual(max(ev2.sample()["rewards"]), 100)
@@ -249,7 +247,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testHardHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
@@ -263,7 +261,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testSoftHorizon(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="complete_episodes",
batch_steps=10,
episode_horizon=4,
@@ -277,11 +275,11 @@ class TestPolicyEvaluator(unittest.TestCase):
def testMetrics(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="complete_episodes")
remote_ev = PolicyEvaluator.as_remote().remote(
env_creator=lambda _: MockEnv(episode_length=10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="complete_episodes")
ev.sample()
ray.get(remote_ev.sample.remote())
@@ -293,7 +291,7 @@ class TestPolicyEvaluator(unittest.TestCase):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
sample_async=True,
policy_graph=MockPolicyGraph)
policy=MockPolicy)
batch = ev.sample()
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
self.assertIn(key, batch)
@@ -302,7 +300,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testAutoVectorization(self):
ev = PolicyEvaluator(
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="truncate_episodes",
batch_steps=2,
num_envs=8)
@@ -325,7 +323,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testBatchesLargerWhenVectorized(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(episode_length=8),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="truncate_episodes",
batch_steps=4,
num_envs=4)
@@ -340,7 +338,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testVectorEnvSupport(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_mode="truncate_episodes",
batch_steps=10)
for _ in range(8):
@@ -357,7 +355,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testTruncateEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=15,
batch_mode="truncate_episodes")
batch = ev.sample()
@@ -366,7 +364,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testCompleteEpisodes(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=5,
batch_mode="complete_episodes")
batch = ev.sample()
@@ -375,7 +373,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testCompleteEpisodesPacking(self):
ev = PolicyEvaluator(
env_creator=lambda _: MockEnv(10),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
batch_steps=15,
batch_mode="complete_episodes")
batch = ev.sample()
@@ -387,7 +385,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testFilterSync(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
time.sleep(2)
@@ -400,7 +398,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testGetFilters(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
self.sample_and_flush(ev)
@@ -415,7 +413,7 @@ class TestPolicyEvaluator(unittest.TestCase):
def testSyncFilter(self):
ev = PolicyEvaluator(
env_creator=lambda _: gym.make("CartPole-v0"),
policy_graph=MockPolicyGraph,
policy=MockPolicy,
sample_async=True,
observation_filter="ConcurrentMeanStdFilter")
obs_f = self.sample_and_flush(ev)
+19 -2
View File
@@ -10,13 +10,30 @@ from ray.tune.util import merge_dicts, deep_update
logger = logging.getLogger(__name__)
def renamed_class(cls):
def renamed_class(cls, old_name):
    """Return a subclass of ``cls`` that logs a rename warning when created.

    This is a helper function for leaving a backwards-compatibility alias
    behind after a class has been renamed.

    Arguments:
        cls (type): The class under its new name.
        old_name (str): The previous name; only used in the warning text.

    Returns:
        type: A subclass of ``cls`` whose ``__name__`` equals
            ``cls.__name__``. Its constructor logs the warning and then
            delegates to ``cls.__init__`` with the same arguments.
    """

    class DeprecationWrapper(cls):
        # note: **kw not supported for ray.remote classes
        def __init__(self, *args, **kw):
            new_name = cls.__module__ + "." + cls.__name__
            # Logger.warn is a deprecated alias; Logger.warning is the
            # supported spelling and emits the same record.
            logger.warning(
                "DeprecationWarning: {} has been renamed to {}. ".format(
                    old_name, new_name) +
                "This will raise an error in the future.")
            cls.__init__(self, *args, **kw)

    # Keep the wrapped class's short name on the returned subclass.
    DeprecationWrapper.__name__ = cls.__name__

    return DeprecationWrapper
def renamed_agent(cls):
"""Helper class for renaming Agent => Trainer with a warning."""
class DeprecationWrapper(cls):
def __init__(self, config=None, env=None, logger_creator=None):
old_name = cls.__name__.replace("Trainer", "Agent")
new_name = cls.__name__
new_name = cls.__module__ + "." + cls.__name__
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
format(old_name, new_name) +
"This will raise an error in the future.")
+1 -1
View File
@@ -6,7 +6,7 @@ import numpy as np
import pprint
import time
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
_logged = set()
_disabled = False