mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 23:46:50 +08:00
[rllib] Rename PolicyGraph => Policy, move from evaluation/ to policy/ (#4819)
This implements some of the renames proposed in #4813. We leave behind backwards-compatibility aliases for *PolicyGraph and SampleBatch.
This commit is contained in:
@@ -10,12 +10,14 @@ from ray.tune.registry import register_trainable
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
|
||||
|
||||
def _setup_logger():
|
||||
@@ -43,7 +45,9 @@ _setup_logger()
|
||||
_register_all()
|
||||
|
||||
__all__ = [
|
||||
"Policy",
|
||||
"PolicyGraph",
|
||||
"TFPolicy",
|
||||
"TFPolicyGraph",
|
||||
"PolicyEvaluator",
|
||||
"SampleBatch",
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.a3c.a2c import A2CTrainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
A2CAgent = renamed_class(A2CTrainer)
|
||||
A3CAgent = renamed_class(A3CTrainer)
|
||||
A2CAgent = renamed_agent(A2CTrainer)
|
||||
A3CAgent = renamed_agent(A3CTrainer)
|
||||
|
||||
__all__ = [
|
||||
"A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -43,16 +43,16 @@ class A3CTrainer(Trainer):
|
||||
|
||||
_name = "A3C"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = A3CPolicyGraph
|
||||
_policy = A3CTFPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy_graph import \
|
||||
from ray.rllib.agents.a3c.a3c_torch_policy import \
|
||||
A3CTorchPolicy
|
||||
policy_cls = A3CTorchPolicy
|
||||
else:
|
||||
policy_cls = self._policy_graph
|
||||
policy_cls = self._policy
|
||||
|
||||
if config["entropy_coeff"] < 0:
|
||||
raise DeprecationWarning("entropy_coeff must be >= 0")
|
||||
|
||||
+13
-13
@@ -1,4 +1,4 @@
|
||||
"""Note: Keep in sync with changes to VTracePolicyGraph."""
|
||||
"""Note: Keep in sync with changes to VTraceTFPolicy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@@ -8,13 +8,13 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -47,13 +47,13 @@ class A3CLoss(object):
|
||||
class A3CPostprocessing(object):
|
||||
"""Adds the VF preds and advantages fields to the trajectory."""
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**{SampleBatch.VF_PREDS: self.vf})
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -73,7 +73,7 @@ class A3CPostprocessing(object):
|
||||
self.config["lambda"])
|
||||
|
||||
|
||||
class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
|
||||
class A3CTFPolicy(LearningRateSchedule, A3CPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
@@ -114,7 +114,7 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
|
||||
self.vf, self.config["vf_loss_coeff"],
|
||||
self.config["entropy_coeff"])
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
# Initialize TFPolicy
|
||||
loss_in = [
|
||||
(SampleBatch.CUR_OBS, self.observations),
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
@@ -125,7 +125,7 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -157,18 +157,18 @@ class A3CPolicyGraph(LearningRateSchedule, A3CPostprocessing, TFPolicyGraph):
|
||||
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
+2
-2
@@ -9,8 +9,8 @@ from torch import nn
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
|
||||
|
||||
def actor_critic_loss(policy, batch_tensors):
|
||||
@@ -3,6 +3,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
Agent = renamed_class(Trainer)
|
||||
Agent = renamed_agent(Trainer)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.agents.ars.ars import (ARSTrainer, DEFAULT_CONFIG)
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
ARSAgent = renamed_class(ARSTrainer)
|
||||
ARSAgent = renamed_agent(ARSTrainer)
|
||||
|
||||
__all__ = ["ARSAgent", "ARSTrainer", "DEFAULT_CONFIG"]
|
||||
|
||||
@@ -17,7 +17,7 @@ from ray.rllib.agents import Trainer, with_common_config
|
||||
from ray.rllib.agents.ars import optimizers
|
||||
from ray.rllib.agents.ars import policies
|
||||
from ray.rllib.agents.ars import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
@@ -5,10 +5,10 @@ from __future__ import print_function
|
||||
from ray.rllib.agents.ddpg.apex import ApexDDPGTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg import DDPGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ddpg.td3 import TD3Trainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
ApexDDPGAgent = renamed_class(ApexDDPGTrainer)
|
||||
DDPGAgent = renamed_class(DDPGTrainer)
|
||||
ApexDDPGAgent = renamed_agent(ApexDDPGTrainer)
|
||||
DDPGAgent = renamed_agent(DDPGTrainer)
|
||||
|
||||
__all__ = [
|
||||
"DDPGAgent", "ApexDDPGAgent", "DDPGTrainer", "ApexDDPGTrainer",
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.ddpg.ddpg_policy_graph import DDPGPolicyGraph
|
||||
from ray.rllib.agents.ddpg.ddpg_policy import DDPGTFPolicy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
@@ -163,7 +163,7 @@ class DDPGTrainer(DQNTrainer):
|
||||
"""DDPG implementation in TensorFlow."""
|
||||
_name = "DDPG"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DDPGPolicyGraph
|
||||
_policy = DDPGTFPolicy
|
||||
|
||||
@override(DQNTrainer)
|
||||
def _train(self):
|
||||
|
||||
+19
-19
@@ -7,15 +7,15 @@ import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import (
|
||||
_huber_loss, _minimize_and_clip, _scope_vars, _postprocess_dqn)
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.agents.dqn.dqn_policy import (_huber_loss, _minimize_and_clip,
|
||||
_scope_vars, _postprocess_dqn)
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -35,7 +35,7 @@ PRIO_WEIGHTS = "weights"
|
||||
class DDPGPostprocessing(object):
|
||||
"""Implements n-step learning and param noise adjustments."""
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -68,7 +68,7 @@ class DDPGPostprocessing(object):
|
||||
return _postprocess_dqn(self, sample_batch)
|
||||
|
||||
|
||||
class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
class DDPGTFPolicy(DDPGPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Box):
|
||||
@@ -281,7 +281,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
self.critic_loss = self.twin_q_model.custom_loss(
|
||||
self.critic_loss, input_dict)
|
||||
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -301,12 +301,12 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
# Hard initial update
|
||||
self.update_target(tau=1.0)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
# we don't use this because we have two separate optimisers
|
||||
return None
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def build_apply_op(self, optimizer, grads_and_vars):
|
||||
# for policy gradient, update policy net one time v.s.
|
||||
# update critic net `policy_delay` time(s)
|
||||
@@ -327,7 +327,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
with tf.control_dependencies([tf.assign_add(self.global_step, 1)]):
|
||||
return tf.group(actor_op, critic_op)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if self.config["grad_norm_clipping"] is not None:
|
||||
actor_grads_and_vars = _minimize_and_clip(
|
||||
@@ -360,7 +360,7 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
+ self._critic_grads_and_vars
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
# FIXME: what about turning off exploration? Isn't that a good
|
||||
@@ -370,31 +370,31 @@ class DDPGPolicyGraph(DDPGPostprocessing, TFPolicyGraph):
|
||||
self.pure_exploration_phase: self.cur_pure_exploration_phase,
|
||||
}
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.td_error,
|
||||
LEARNER_STATS_KEY: self.stats,
|
||||
}
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
return [
|
||||
TFPolicyGraph.get_state(self), self.cur_noise_scale,
|
||||
TFPolicy.get_state(self), self.cur_noise_scale,
|
||||
self.cur_pure_exploration_phase
|
||||
]
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
TFPolicyGraph.set_state(self, state[0])
|
||||
TFPolicy.set_state(self, state[0])
|
||||
self.set_epsilon(state[1])
|
||||
self.set_pure_exploration_phase(state[2])
|
||||
|
||||
@@ -4,10 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.dqn.apex import ApexTrainer
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
DQNAgent = renamed_class(DQNTrainer)
|
||||
ApexAgent = renamed_class(ApexTrainer)
|
||||
DQNAgent = renamed_agent(DQNTrainer)
|
||||
ApexAgent = renamed_agent(ApexTrainer)
|
||||
|
||||
__all__ = [
|
||||
"DQNAgent", "ApexAgent", "ApexTrainer", "DQNTrainer", "DEFAULT_CONFIG"
|
||||
|
||||
@@ -8,9 +8,9 @@ import time
|
||||
from ray import tune
|
||||
from ray.rllib import optimizers
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule
|
||||
|
||||
@@ -133,7 +133,7 @@ class DQNTrainer(Trainer):
|
||||
|
||||
_name = "DQN"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = DQNPolicyGraph
|
||||
_policy = DQNTFPolicy
|
||||
_optimizer_shared_configs = OPTIMIZER_SHARED_CONFIGS
|
||||
|
||||
@override(Trainer)
|
||||
@@ -197,10 +197,10 @@ class DQNTrainer(Trainer):
|
||||
on_episode_end)
|
||||
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy_graph)
|
||||
env_creator, self._policy)
|
||||
|
||||
def create_remote_evaluators():
|
||||
return self.make_remote_evaluators(env_creator, self._policy_graph,
|
||||
return self.make_remote_evaluators(env_creator, self._policy,
|
||||
config["num_workers"])
|
||||
|
||||
if config["optimizer_class"] != "AsyncReplayOptimizer":
|
||||
|
||||
+26
-26
@@ -7,13 +7,13 @@ import numpy as np
|
||||
from scipy.stats import entropy
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.models import ModelCatalog, Categorical
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -105,14 +105,14 @@ class QLoss(object):
|
||||
class DQNPostprocessing(object):
|
||||
"""Implements n-step learning and param noise adjustments."""
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self), **{
|
||||
TFPolicy.extra_compute_action_fetches(self), **{
|
||||
"q_values": self.q_values,
|
||||
})
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -345,7 +345,7 @@ class QValuePolicy(object):
|
||||
self.action_prob = None
|
||||
|
||||
|
||||
class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
class DQNTFPolicy(LearningRateSchedule, DQNPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
if not isinstance(action_space, Discrete):
|
||||
@@ -446,7 +446,7 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
update_target_expr.append(var_target.assign(var))
|
||||
self.update_target_expr = tf.group(*update_target_expr)
|
||||
|
||||
# initialize TFPolicyGraph
|
||||
# initialize TFPolicy
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
@@ -459,7 +459,7 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -477,12 +477,12 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
}, **self.loss.stats)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
return tf.train.AdamOptimizer(
|
||||
learning_rate=self.cur_lr, epsilon=self.config["adam_epsilon"])
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if self.config["grad_norm_clipping"] is not None:
|
||||
grads_and_vars = _minimize_and_clip(
|
||||
@@ -496,27 +496,27 @@ class DQNPolicyGraph(LearningRateSchedule, DQNPostprocessing, TFPolicyGraph):
|
||||
grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]
|
||||
return grads_and_vars
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_feed_dict(self):
|
||||
return {
|
||||
self.stochastic: True,
|
||||
self.eps: self.cur_epsilon,
|
||||
}
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {
|
||||
"td_error": self.loss.td_error,
|
||||
LEARNER_STATS_KEY: self.stats_fetches,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
return [TFPolicyGraph.get_state(self), self.cur_epsilon]
|
||||
return [TFPolicy.get_state(self), self.cur_epsilon]
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
TFPolicyGraph.set_state(self, state[0])
|
||||
TFPolicy.set_state(self, state[0])
|
||||
self.set_epsilon(state[1])
|
||||
|
||||
def _build_parameter_noise(self, pnet_params):
|
||||
@@ -633,25 +633,25 @@ def _adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones):
|
||||
rewards[i] += gamma**j * rewards[i + j]
|
||||
|
||||
|
||||
def _postprocess_dqn(policy_graph, batch):
|
||||
def _postprocess_dqn(policy, batch):
|
||||
# N-step Q adjustments
|
||||
if policy_graph.config["n_step"] > 1:
|
||||
_adjust_nstep(policy_graph.config["n_step"],
|
||||
policy_graph.config["gamma"], batch[SampleBatch.CUR_OBS],
|
||||
batch[SampleBatch.ACTIONS], batch[SampleBatch.REWARDS],
|
||||
batch[SampleBatch.NEXT_OBS], batch[SampleBatch.DONES])
|
||||
if policy.config["n_step"] > 1:
|
||||
_adjust_nstep(policy.config["n_step"], policy.config["gamma"],
|
||||
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
|
||||
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
|
||||
batch[SampleBatch.DONES])
|
||||
|
||||
if PRIO_WEIGHTS not in batch:
|
||||
batch[PRIO_WEIGHTS] = np.ones_like(batch[SampleBatch.REWARDS])
|
||||
|
||||
# Prioritize on the worker side
|
||||
if batch.count > 0 and policy_graph.config["worker_side_prioritization"]:
|
||||
td_errors = policy_graph.compute_td_error(
|
||||
if batch.count > 0 and policy.config["worker_side_prioritization"]:
|
||||
td_errors = policy.compute_td_error(
|
||||
batch[SampleBatch.CUR_OBS], batch[SampleBatch.ACTIONS],
|
||||
batch[SampleBatch.REWARDS], batch[SampleBatch.NEXT_OBS],
|
||||
batch[SampleBatch.DONES], batch[PRIO_WEIGHTS])
|
||||
new_priorities = (
|
||||
np.abs(td_errors) + policy_graph.config["prioritized_replay_eps"])
|
||||
np.abs(td_errors) + policy.config["prioritized_replay_eps"])
|
||||
batch.data[PRIO_WEIGHTS] = new_priorities
|
||||
|
||||
return batch
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.agents.es.es import (ESTrainer, DEFAULT_CONFIG)
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
ESAgent = renamed_class(ESTrainer)
|
||||
ESAgent = renamed_agent(ESTrainer)
|
||||
|
||||
__all__ = ["ESAgent", "ESTrainer", "DEFAULT_CONFIG"]
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.agents import Trainer, with_common_config
|
||||
from ray.rllib.agents.es import optimizers
|
||||
from ray.rllib.agents.es import policies
|
||||
from ray.rllib.agents.es import utils
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import FilterManager
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.agents.impala.impala import ImpalaTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
ImpalaAgent = renamed_class(ImpalaTrainer)
|
||||
ImpalaAgent = renamed_agent(ImpalaTrainer)
|
||||
|
||||
__all__ = ["ImpalaAgent", "ImpalaTrainer", "DEFAULT_CONFIG"]
|
||||
|
||||
@@ -4,8 +4,8 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy_graph import A3CPolicyGraph
|
||||
from ray.rllib.agents.impala.vtrace_policy_graph import VTracePolicyGraph
|
||||
from ray.rllib.agents.a3c.a3c_tf_policy import A3CTFPolicy
|
||||
from ray.rllib.agents.impala.vtrace_policy import VTraceTFPolicy
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
@@ -105,14 +105,14 @@ class ImpalaTrainer(Trainer):
|
||||
|
||||
_name = "IMPALA"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = VTracePolicyGraph
|
||||
_policy = VTraceTFPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
for k in OPTIMIZER_SHARED_CONFIGS:
|
||||
if k not in config["optimizer"]:
|
||||
config["optimizer"][k] = config[k]
|
||||
policy_cls = self._get_policy_graph()
|
||||
policy_cls = self._get_policy()
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
self.env_creator, policy_cls)
|
||||
|
||||
@@ -158,9 +158,9 @@ class ImpalaTrainer(Trainer):
|
||||
prev_steps)
|
||||
return result
|
||||
|
||||
def _get_policy_graph(self):
|
||||
def _get_policy(self):
|
||||
if self.config["vtrace"]:
|
||||
policy_cls = self._policy_graph
|
||||
policy_cls = self._policy
|
||||
else:
|
||||
policy_cls = A3CPolicyGraph
|
||||
policy_cls = A3CTFPolicy
|
||||
return policy_cls
|
||||
|
||||
+17
-18
@@ -1,6 +1,6 @@
|
||||
"""Adapted from A3CPolicyGraph to add V-trace.
|
||||
"""Adapted from A3CTFPolicy to add V-trace.
|
||||
|
||||
Keep in sync with changes to A3CPolicyGraph and VtraceSurrogatePolicyGraph."""
|
||||
Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@@ -11,9 +11,9 @@ import ray
|
||||
import numpy as np
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
@@ -110,13 +110,13 @@ class VTraceLoss(object):
|
||||
class VTracePostprocessing(object):
|
||||
"""Adds the policy logits to the trajectory."""
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**{BEHAVIOUR_LOGITS: self.model.outputs})
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -126,8 +126,7 @@ class VTracePostprocessing(object):
|
||||
return sample_batch
|
||||
|
||||
|
||||
class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
TFPolicyGraph):
|
||||
class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -285,7 +284,7 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
"max_KL": tf.reduce_max(kls[0]),
|
||||
}
|
||||
|
||||
# Initialize TFPolicyGraph
|
||||
# Initialize TFPolicy
|
||||
loss_in = [
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
(SampleBatch.DONES, dones),
|
||||
@@ -297,7 +296,7 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
]
|
||||
LearningRateSchedule.__init__(self, self.config["lr"],
|
||||
self.config["lr_schedule"])
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -332,15 +331,15 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
}, **self.KL_stats),
|
||||
}
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def copy(self, existing_inputs):
|
||||
return VTracePolicyGraph(
|
||||
return VTraceTFPolicy(
|
||||
self.observation_space,
|
||||
self.action_space,
|
||||
self.config,
|
||||
existing_inputs=existing_inputs)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
if self.config["opt_type"] == "adam":
|
||||
return tf.train.AdamOptimizer(self.cur_lr)
|
||||
@@ -349,17 +348,17 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
self.config["momentum"],
|
||||
self.config["epsilon"])
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
grads = tf.gradients(loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"])
|
||||
clipped_grads = list(zip(self.grads, self.var_list))
|
||||
return clipped_grads
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return self.stats_fetches
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import Trainer, with_common_config
|
||||
from ray.rllib.agents.marwil.marwil_policy_graph import MARWILPolicyGraph
|
||||
from ray.rllib.agents.marwil.marwil_policy import MARWILPolicy
|
||||
from ray.rllib.optimizers import SyncBatchReplayOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
@@ -44,14 +44,14 @@ class MARWILTrainer(Trainer):
|
||||
|
||||
_name = "MARWIL"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = MARWILPolicyGraph
|
||||
_policy = MARWILPolicy
|
||||
|
||||
@override(Trainer)
|
||||
def _init(self, config, env_creator):
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, self._policy_graph)
|
||||
env_creator, self._policy)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, self._policy_graph, config["num_workers"])
|
||||
env_creator, self._policy, config["num_workers"])
|
||||
self.optimizer = SyncBatchReplayOptimizer(
|
||||
self.local_evaluator,
|
||||
self.remote_evaluators,
|
||||
|
||||
+10
-10
@@ -6,12 +6,12 @@ import ray
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import _scope_vars
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.agents.dqn.dqn_policy import _scope_vars
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -59,7 +59,7 @@ class ReweightedImitationLoss(object):
|
||||
class MARWILPostprocessing(object):
|
||||
"""Adds the advantages field to the trajectory."""
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -79,7 +79,7 @@ class MARWILPostprocessing(object):
|
||||
return batch
|
||||
|
||||
|
||||
class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
|
||||
class MARWILPolicy(MARWILPostprocessing, TFPolicy):
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
config = dict(ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG, **config)
|
||||
self.config = config
|
||||
@@ -127,14 +127,14 @@ class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
|
||||
self.explained_variance = tf.reduce_mean(
|
||||
explained_variance(self.cum_rew_t, state_values))
|
||||
|
||||
# initialize TFPolicyGraph
|
||||
# initialize TFPolicy
|
||||
self.sess = tf.get_default_session()
|
||||
self.loss_inputs = [
|
||||
(SampleBatch.CUR_OBS, self.obs_t),
|
||||
(SampleBatch.ACTIONS, self.act_t),
|
||||
(Postprocessing.ADVANTAGES, self.cum_rew_t),
|
||||
]
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
observation_space,
|
||||
action_space,
|
||||
@@ -166,10 +166,10 @@ class MARWILPolicyGraph(MARWILPostprocessing, TFPolicyGraph):
|
||||
return ReweightedImitationLoss(state_values, cum_rwds, logits, actions,
|
||||
action_space, self.config["beta"])
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_grad_fetches(self):
|
||||
return {LEARNER_STATS_KEY: self.stats_fetches}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return self.model.state_init
|
||||
@@ -1,6 +1,6 @@
|
||||
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
PGAgent = renamed_class(PGTrainer)
|
||||
PGAgent = renamed_agent(PGTrainer)
|
||||
|
||||
__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"]
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -22,7 +22,7 @@ DEFAULT_CONFIG = with_common_config({
|
||||
|
||||
def get_policy_class(config):
|
||||
if config["use_pytorch"]:
|
||||
from ray.rllib.agents.pg.torch_pg_policy_graph import PGTorchPolicy
|
||||
from ray.rllib.agents.pg.torch_pg_policy import PGTorchPolicy
|
||||
return PGTorchPolicy
|
||||
else:
|
||||
return PGTFPolicy
|
||||
|
||||
+2
-2
@@ -5,8 +5,8 @@ from __future__ import print_function
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
+2
-2
@@ -5,8 +5,8 @@ from __future__ import print_function
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
|
||||
|
||||
def pg_torch_loss(policy, batch_tensors):
|
||||
@@ -1,7 +1,7 @@
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer, DEFAULT_CONFIG
|
||||
from ray.rllib.agents.ppo.appo import APPOTrainer
|
||||
from ray.rllib.utils import renamed_class
|
||||
from ray.rllib.utils import renamed_agent
|
||||
|
||||
PPOAgent = renamed_class(PPOTrainer)
|
||||
PPOAgent = renamed_agent(PPOTrainer)
|
||||
|
||||
__all__ = ["PPOAgent", "APPOTrainer", "PPOTrainer", "DEFAULT_CONFIG"]
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.ppo.appo_policy_graph import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.ppo.appo_policy import AsyncPPOTFPolicy
|
||||
from ray.rllib.agents.trainer import with_base_config
|
||||
from ray.rllib.agents import impala
|
||||
from ray.rllib.utils.annotations import override
|
||||
@@ -57,8 +57,8 @@ class APPOTrainer(impala.ImpalaTrainer):
|
||||
|
||||
_name = "APPO"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = AsyncPPOTFPolicy
|
||||
_policy = AsyncPPOTFPolicy
|
||||
|
||||
@override(impala.ImpalaTrainer)
|
||||
def _get_policy_graph(self):
|
||||
def _get_policy(self):
|
||||
return AsyncPPOTFPolicy
|
||||
|
||||
+5
-5
@@ -1,6 +1,6 @@
|
||||
"""Adapted from VTracePolicyGraph to use the PPO surrogate loss.
|
||||
"""Adapted from VTraceTFPolicy to use the PPO surrogate loss.
|
||||
|
||||
Keep in sync with changes to VTracePolicyGraph."""
|
||||
Keep in sync with changes to VTraceTFPolicy."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@@ -13,9 +13,9 @@ import gym
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.utils import try_import_tf
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import logging
|
||||
|
||||
from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer, LocalMultiGPUOptimizer
|
||||
|
||||
@@ -143,8 +143,7 @@ def validate_config(config):
|
||||
raise ValueError(
|
||||
"Episode truncation is not supported without a value "
|
||||
"function. Consider setting batch_mode=complete_episodes.")
|
||||
if (config["multiagent"]["policy_graphs"]
|
||||
and not config["simple_optimizer"]):
|
||||
if (config["multiagent"]["policies"] and not config["simple_optimizer"]):
|
||||
logger.info(
|
||||
"In multi-agent mode, policies will be optimized sequentially "
|
||||
"by the multi-GPU optimizer. Consider setting "
|
||||
|
||||
+3
-3
@@ -7,9 +7,9 @@ import logging
|
||||
import ray
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages, \
|
||||
Postprocessing
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import LearningRateSchedule
|
||||
from ray.rllib.evaluation.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
from ray.rllib.agents.trainer import with_common_config
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.qmix.qmix_policy_graph import QMixPolicyGraph
|
||||
from ray.rllib.agents.qmix.qmix_policy import QMixTorchPolicy
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
@@ -95,7 +95,7 @@ class QMixTrainer(DQNTrainer):
|
||||
|
||||
_name = "QMIX"
|
||||
_default_config = DEFAULT_CONFIG
|
||||
_policy_graph = QMixPolicyGraph
|
||||
_policy = QMixTorchPolicy
|
||||
_optimizer_shared_configs = [
|
||||
"learning_starts", "buffer_size", "train_batch_size"
|
||||
]
|
||||
|
||||
+10
-10
@@ -14,8 +14,8 @@ import ray
|
||||
from ray.rllib.agents.qmix.mixers import VDNMixer, QMixer
|
||||
from ray.rllib.agents.qmix.model import RNNModel, _get_size
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
@@ -130,7 +130,7 @@ class QMixLoss(nn.Module):
|
||||
return loss, mask, masked_td_error, chosen_action_qvals, targets
|
||||
|
||||
|
||||
class QMixPolicyGraph(PolicyGraph):
|
||||
class QMixTorchPolicy(Policy):
|
||||
"""QMix impl. Assumes homogeneous agents for now.
|
||||
|
||||
You must use MultiAgentEnv.with_agent_groups() to group agents
|
||||
@@ -213,7 +213,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
alpha=config["optim_alpha"],
|
||||
eps=config["optim_eps"])
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
@@ -243,7 +243,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
|
||||
return TupleActions(list(actions.transpose([1, 0]))), hiddens, {}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def learn_on_batch(self, samples):
|
||||
obs_batch, action_mask = self._unpack_observation(
|
||||
samples[SampleBatch.CUR_OBS])
|
||||
@@ -314,22 +314,22 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
}
|
||||
return {LEARNER_STATS_KEY: stats}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
return [
|
||||
s.expand([self.n_agents, -1]).numpy()
|
||||
for s in self.model.state_init()
|
||||
]
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_weights(self):
|
||||
return {"model": self.model.state_dict()}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def set_weights(self, weights):
|
||||
self.model.load_state_dict(weights["model"])
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_state(self):
|
||||
return {
|
||||
"model": self.model.state_dict(),
|
||||
@@ -340,7 +340,7 @@ class QMixPolicyGraph(PolicyGraph):
|
||||
"cur_epsilon": self.cur_epsilon,
|
||||
}
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def set_state(self, state):
|
||||
self.model.load_state_dict(state["model"])
|
||||
self.target_model.load_state_dict(state["target_model"])
|
||||
@@ -19,7 +19,7 @@ from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \
|
||||
from ray.rllib.models import MODEL_DEFAULTS
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator, \
|
||||
_validate_multiagent_config
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils.annotations import override, PublicAPI, DeveloperAPI
|
||||
@@ -220,9 +220,9 @@ COMMON_CONFIG = {
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
# Map from policy ids to tuples of (policy_graph_cls, obs_space,
|
||||
# Map from policy ids to tuples of (policy_cls, obs_space,
|
||||
# act_space, config). See policy_evaluator.py for more info.
|
||||
"policy_graphs": {},
|
||||
"policies": {},
|
||||
# Function mapping agent ids to policy ids.
|
||||
"policy_mapping_fn": None,
|
||||
# Optional whitelist of policies to train, or None for all policies.
|
||||
@@ -435,9 +435,7 @@ class Trainer(Trainable):
|
||||
"using evaluation_config: {}".format(extra_config))
|
||||
# Make local evaluation evaluators
|
||||
self.evaluation_ev = self.make_local_evaluator(
|
||||
self.env_creator,
|
||||
self._policy_graph,
|
||||
extra_config=extra_config)
|
||||
self.env_creator, self._policy, extra_config=extra_config)
|
||||
self.evaluation_metrics = self._evaluate()
|
||||
|
||||
@override(Trainable)
|
||||
@@ -578,10 +576,10 @@ class Trainer(Trainable):
|
||||
|
||||
@PublicAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
"""Return policy for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
policy_id (str): id of policy to return.
|
||||
"""
|
||||
|
||||
return self.local_evaluator.get_policy(policy_id)
|
||||
@@ -606,16 +604,13 @@ class Trainer(Trainable):
|
||||
self.local_evaluator.set_weights(weights)
|
||||
|
||||
@DeveloperAPI
|
||||
def make_local_evaluator(self,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
extra_config=None):
|
||||
def make_local_evaluator(self, env_creator, policy, extra_config=None):
|
||||
"""Convenience method to return configured local evaluator."""
|
||||
|
||||
return self._make_evaluator(
|
||||
PolicyEvaluator,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
policy,
|
||||
0,
|
||||
merge_dicts(
|
||||
# important: allow local tf to use more CPUs for optimization
|
||||
@@ -627,7 +622,7 @@ class Trainer(Trainable):
|
||||
extra_config or {}))
|
||||
|
||||
@DeveloperAPI
|
||||
def make_remote_evaluators(self, env_creator, policy_graph, count):
|
||||
def make_remote_evaluators(self, env_creator, policy, count):
|
||||
"""Convenience method to return a number of remote evaluators."""
|
||||
|
||||
remote_args = {
|
||||
@@ -639,8 +634,8 @@ class Trainer(Trainable):
|
||||
cls = PolicyEvaluator.as_remote(**remote_args).remote
|
||||
|
||||
return [
|
||||
self._make_evaluator(cls, env_creator, policy_graph, i + 1,
|
||||
self.config) for i in range(count)
|
||||
self._make_evaluator(cls, env_creator, policy, i + 1, self.config)
|
||||
for i in range(count)
|
||||
]
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -700,6 +695,13 @@ class Trainer(Trainable):
|
||||
|
||||
@staticmethod
|
||||
def _validate_config(config):
|
||||
if "policy_graphs" in config["multiagent"]:
|
||||
logger.warning(
|
||||
"The `policy_graphs` config has been renamed to `policies`.")
|
||||
# Backwards compatibility
|
||||
config["multiagent"]["policies"] = config["multiagent"][
|
||||
"policy_graphs"]
|
||||
del config["multiagent"]["policy_graphs"]
|
||||
if "gpu" in config:
|
||||
raise ValueError(
|
||||
"The `gpu` config is deprecated, please use `num_gpus=0|1` "
|
||||
@@ -760,8 +762,7 @@ class Trainer(Trainable):
|
||||
return hasattr(self, "optimizer") and isinstance(
|
||||
self.optimizer, PolicyOptimizer)
|
||||
|
||||
def _make_evaluator(self, cls, env_creator, policy_graph, worker_index,
|
||||
config):
|
||||
def _make_evaluator(self, cls, env_creator, policy, worker_index, config):
|
||||
def session_creator():
|
||||
logger.debug("Creating TF session {}".format(
|
||||
config["tf_session_args"]))
|
||||
@@ -803,18 +804,18 @@ class Trainer(Trainable):
|
||||
else:
|
||||
input_evaluation = config["input_evaluation"]
|
||||
|
||||
# Fill in the default policy graph if 'None' is specified in multiagent
|
||||
if self.config["multiagent"]["policy_graphs"]:
|
||||
tmp = self.config["multiagent"]["policy_graphs"]
|
||||
# Fill in the default policy if 'None' is specified in multiagent
|
||||
if self.config["multiagent"]["policies"]:
|
||||
tmp = self.config["multiagent"]["policies"]
|
||||
_validate_multiagent_config(tmp, allow_none_graph=True)
|
||||
for k, v in tmp.items():
|
||||
if v[0] is None:
|
||||
tmp[k] = (policy_graph, v[1], v[2], v[3])
|
||||
policy_graph = tmp
|
||||
tmp[k] = (policy, v[1], v[2], v[3])
|
||||
policy = tmp
|
||||
|
||||
return cls(
|
||||
env_creator,
|
||||
policy_graph,
|
||||
policy,
|
||||
policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
|
||||
policies_to_train=self.config["multiagent"]["policies_to_train"],
|
||||
tf_session_creator=(session_creator
|
||||
|
||||
@@ -21,7 +21,7 @@ def build_trainer(name,
|
||||
|
||||
Arguments:
|
||||
name (str): name of the trainer (e.g., "PPO")
|
||||
default_policy (cls): the default PolicyGraph class to use
|
||||
default_policy (cls): the default Policy class to use
|
||||
default_config (dict): the default config dict of the algorithm,
|
||||
otherwise uses the Trainer default config
|
||||
make_policy_optimizer (func): optional function that returns a
|
||||
@@ -30,7 +30,7 @@ def build_trainer(name,
|
||||
validate_config (func): optional callback that checks a given config
|
||||
for correctness. It may mutate the config as needed.
|
||||
get_policy_class (func): optional callback that takes a config and
|
||||
returns the policy graph class to override the default with
|
||||
returns the policy class to override the default with
|
||||
before_train_step (func): optional callback to run before each train()
|
||||
call. It takes the trainer instance as an argument.
|
||||
after_optimizer_step (func): optional callback to run after each
|
||||
@@ -51,19 +51,19 @@ def build_trainer(name,
|
||||
class trainer_cls(Trainer):
|
||||
_name = name
|
||||
_default_config = default_config or Trainer.COMMON_CONFIG
|
||||
_policy_graph = default_policy
|
||||
_policy = default_policy
|
||||
|
||||
def _init(self, config, env_creator):
|
||||
if validate_config:
|
||||
validate_config(config)
|
||||
if get_policy_class is None:
|
||||
policy_graph = default_policy
|
||||
policy = default_policy
|
||||
else:
|
||||
policy_graph = get_policy_class(config)
|
||||
policy = get_policy_class(config)
|
||||
self.local_evaluator = self.make_local_evaluator(
|
||||
env_creator, policy_graph)
|
||||
env_creator, policy)
|
||||
self.remote_evaluators = self.make_remote_evaluators(
|
||||
env_creator, policy_graph, config["num_workers"])
|
||||
env_creator, policy, config["num_workers"])
|
||||
if make_policy_optimizer:
|
||||
self.optimizer = make_policy_optimizer(
|
||||
self.local_evaluator, self.remote_evaluators, config)
|
||||
|
||||
@@ -27,7 +27,7 @@ class MultiAgentEpisode(object):
|
||||
user_data (dict): Dict that you can use for temporary storage.
|
||||
|
||||
Use case 1: Model-based rollouts in multi-agent:
|
||||
A custom compute_actions() function in a policy graph can inspect the
|
||||
A custom compute_actions() function in a policy can inspect the
|
||||
current episode state and perform a number of rollouts based on the
|
||||
policies and state of other agents in the environment.
|
||||
|
||||
@@ -80,7 +80,7 @@ class MultiAgentEpisode(object):
|
||||
|
||||
@DeveloperAPI
|
||||
def policy_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the policy graph for the specified agent.
|
||||
"""Returns the policy for the specified agent.
|
||||
|
||||
If the agent is new, the policy mapping fn will be called to bind the
|
||||
agent to a policy for the duration of the episode.
|
||||
|
||||
@@ -62,7 +62,7 @@ class EvaluatorInterface(object):
|
||||
Returns:
|
||||
(grads, info): A list of gradients that can be applied on a
|
||||
compatible evaluator. In the multi-agent case, returns a dict
|
||||
of gradients keyed by policy graph ids. An info dictionary of
|
||||
of gradients keyed by policy ids. An info dictionary of
|
||||
extra metadata is also returned.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -7,21 +7,18 @@ import numpy as np
|
||||
import collections
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimate
|
||||
from ray.rllib.policy.policy import LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# By convention, metrics from optimizing the loss can be reported in the
|
||||
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
|
||||
LEARNER_STATS_KEY = "learner_stats"
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def get_learner_stats(grad_info):
|
||||
"""Return optimization stats reported from the policy graph.
|
||||
"""Return optimization stats reported from the policy.
|
||||
|
||||
Example:
|
||||
>>> grad_info = evaluator.learn_on_batch(samples)
|
||||
|
||||
@@ -15,11 +15,10 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.rllib.evaluation.interface import EvaluatorInterface
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader
|
||||
from ray.rllib.offline.is_estimator import ImportanceSamplingEstimator
|
||||
from ray.rllib.offline.wis_estimator import WeightedImportanceSamplingEstimator
|
||||
@@ -52,9 +51,9 @@ def get_global_evaluator():
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyEvaluator(EvaluatorInterface):
|
||||
"""Common ``PolicyEvaluator`` implementation that wraps a ``PolicyGraph``.
|
||||
"""Common ``PolicyEvaluator`` implementation that wraps a ``Policy``.
|
||||
|
||||
This class wraps a policy graph instance and an environment class to
|
||||
This class wraps a policy instance and an environment class to
|
||||
collect experiences from the environment. You can create many replicas of
|
||||
this class as Ray actors to scale RL training.
|
||||
|
||||
@@ -65,7 +64,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
>>> # Create a policy evaluator and use it to collect experiences.
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
... policy_graph=PGTFPolicy)
|
||||
... policy=PGTFPolicy)
|
||||
>>> print(evaluator.sample())
|
||||
SampleBatch({
|
||||
"obs": [[...]], "actions": [[...]], "rewards": [[...]],
|
||||
@@ -76,7 +75,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
... evaluator_cls=PolicyEvaluator,
|
||||
... evaluator_args={
|
||||
... "env_creator": lambda _: gym.make("CartPole-v0"),
|
||||
... "policy_graph": PGTFPolicy,
|
||||
... "policy": PGTFPolicy,
|
||||
... },
|
||||
... num_workers=10)
|
||||
>>> for _ in range(10): optimizer.step()
|
||||
@@ -84,7 +83,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
>>> # Creating a multi-agent policy evaluator
|
||||
>>> evaluator = PolicyEvaluator(
|
||||
... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25),
|
||||
... policy_graphs={
|
||||
... policies={
|
||||
... # Use an ensemble of two policies for car agents
|
||||
... "car_policy1":
|
||||
... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}),
|
||||
@@ -113,7 +112,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
env_creator,
|
||||
policy_graph,
|
||||
policy,
|
||||
policy_mapping_fn=None,
|
||||
policies_to_train=None,
|
||||
tf_session_creator=None,
|
||||
@@ -147,9 +146,9 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
Arguments:
|
||||
env_creator (func): Function that returns a gym.Env given an
|
||||
EnvContext wrapped configuration.
|
||||
policy_graph (class|dict): Either a class implementing
|
||||
PolicyGraph, or a dictionary of policy id strings to
|
||||
(PolicyGraph, obs_space, action_space, config) tuples. If a
|
||||
policy (class|dict): Either a class implementing
|
||||
Policy, or a dictionary of policy id strings to
|
||||
(Policy, obs_space, action_space, config) tuples. If a
|
||||
dict is specified, then we are in multi-agent mode and a
|
||||
policy_mapping_fn should also be set.
|
||||
policy_mapping_fn (func): A function that maps agent ids to
|
||||
@@ -159,7 +158,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
policies_to_train (list): Optional whitelist of policies to train,
|
||||
or None for all policies.
|
||||
tf_session_creator (func): A function that returns a TF session.
|
||||
This is optional and only useful with TFPolicyGraph.
|
||||
This is optional and only useful with TFPolicy.
|
||||
batch_steps (int): The target number of env transitions to include
|
||||
in each sample batch returned from this evaluator.
|
||||
batch_mode (str): One of the following batch modes:
|
||||
@@ -196,7 +195,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
model_config (dict): Config to use when creating the policy model.
|
||||
policy_config (dict): Config to pass to the policy. In the
|
||||
multi-agent case, this config will be merged with the
|
||||
per-policy configs specified by `policy_graph`.
|
||||
per-policy configs specified by `policy`.
|
||||
worker_index (int): For remote evaluators, this should be set to a
|
||||
non-zero and unique value. This index is passed to created envs
|
||||
through EnvContext so that envs can be configured per worker.
|
||||
@@ -301,7 +300,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
vector_index=vector_index, remote=remote_worker_envs)))
|
||||
|
||||
self.tf_sess = None
|
||||
policy_dict = _validate_and_canonicalize(policy_graph, self.env)
|
||||
policy_dict = _validate_and_canonicalize(policy, self.env)
|
||||
self.policies_to_train = policies_to_train or list(policy_dict.keys())
|
||||
if _has_tensorflow_graph(policy_dict):
|
||||
if (ray.is_initialized()
|
||||
@@ -330,7 +329,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
or isinstance(self.env, ExternalMultiAgentEnv))
|
||||
or isinstance(self.env, BaseEnv)):
|
||||
raise ValueError(
|
||||
"Have multiple policy graphs {}, but the env ".format(
|
||||
"Have multiple policies {}, but the env ".format(
|
||||
self.policy_map) +
|
||||
"{} is not a subclass of BaseEnv, MultiAgentEnv or "
|
||||
"ExternalMultiAgentEnv?".format(self.env))
|
||||
@@ -608,17 +607,17 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
|
||||
@DeveloperAPI
|
||||
def get_policy(self, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Return policy graph for the specified id, or None.
|
||||
"""Return policy for the specified id, or None.
|
||||
|
||||
Arguments:
|
||||
policy_id (str): id of policy graph to return.
|
||||
policy_id (str): id of policy to return.
|
||||
"""
|
||||
|
||||
return self.policy_map.get(policy_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def for_policy(self, func, policy_id=DEFAULT_POLICY_ID):
|
||||
"""Apply the given function to the specified policy graph."""
|
||||
"""Apply the given function to the specified policy."""
|
||||
|
||||
return func(self.policy_map[policy_id])
|
||||
|
||||
@@ -708,7 +707,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
preprocessors = {}
|
||||
for name, (cls, obs_space, act_space,
|
||||
conf) in sorted(policy_dict.items()):
|
||||
logger.debug("Creating policy graph for {}".format(name))
|
||||
logger.debug("Creating policy for {}".format(name))
|
||||
merged_conf = merge_dicts(policy_config, conf)
|
||||
if self.preprocessing_enabled:
|
||||
preprocessor = ModelCatalog.get_preprocessor_for_space(
|
||||
@@ -720,7 +719,7 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
if isinstance(obs_space, gym.spaces.Dict) or \
|
||||
isinstance(obs_space, gym.spaces.Tuple):
|
||||
raise ValueError(
|
||||
"Found raw Tuple|Dict space as input to policy graph. "
|
||||
"Found raw Tuple|Dict space as input to policy. "
|
||||
"Please preprocess these observations with a "
|
||||
"Tuple|DictFlatteningPreprocessor.")
|
||||
if tf:
|
||||
@@ -738,12 +737,12 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.sampler.shutdown = True
|
||||
|
||||
|
||||
def _validate_and_canonicalize(policy_graph, env):
|
||||
if isinstance(policy_graph, dict):
|
||||
_validate_multiagent_config(policy_graph)
|
||||
return policy_graph
|
||||
elif not issubclass(policy_graph, PolicyGraph):
|
||||
raise ValueError("policy_graph must be a rllib.PolicyGraph class")
|
||||
def _validate_and_canonicalize(policy, env):
|
||||
if isinstance(policy, dict):
|
||||
_validate_multiagent_config(policy)
|
||||
return policy
|
||||
elif not issubclass(policy, Policy):
|
||||
raise ValueError("policy must be a rllib.Policy class")
|
||||
else:
|
||||
if (isinstance(env, MultiAgentEnv)
|
||||
and not hasattr(env, "observation_space")):
|
||||
@@ -751,38 +750,35 @@ def _validate_and_canonicalize(policy_graph, env):
|
||||
"MultiAgentEnv must have observation_space defined if run "
|
||||
"in a single-agent configuration.")
|
||||
return {
|
||||
DEFAULT_POLICY_ID: (policy_graph, env.observation_space,
|
||||
DEFAULT_POLICY_ID: (policy, env.observation_space,
|
||||
env.action_space, {})
|
||||
}
|
||||
|
||||
|
||||
def _validate_multiagent_config(policy_graph, allow_none_graph=False):
|
||||
for k, v in policy_graph.items():
|
||||
def _validate_multiagent_config(policy, allow_none_graph=False):
|
||||
for k, v in policy.items():
|
||||
if not isinstance(k, str):
|
||||
raise ValueError("policy_graph keys must be strs, got {}".format(
|
||||
raise ValueError("policy keys must be strs, got {}".format(
|
||||
type(k)))
|
||||
if not isinstance(v, tuple) or len(v) != 4:
|
||||
raise ValueError(
|
||||
"policy_graph values must be tuples of "
|
||||
"policy values must be tuples of "
|
||||
"(cls, obs_space, action_space, config), got {}".format(v))
|
||||
if allow_none_graph and v[0] is None:
|
||||
pass
|
||||
elif not issubclass(v[0], PolicyGraph):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 0 must be a rllib.PolicyGraph "
|
||||
"class or None, got {}".format(v[0]))
|
||||
elif not issubclass(v[0], Policy):
|
||||
raise ValueError("policy tuple value 0 must be a rllib.Policy "
|
||||
"class or None, got {}".format(v[0]))
|
||||
if not isinstance(v[1], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 1 (observation_space) must be a "
|
||||
"policy tuple value 1 (observation_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[1])))
|
||||
if not isinstance(v[2], gym.Space):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
raise ValueError("policy tuple value 2 (action_space) must be a "
|
||||
"gym.Space, got {}".format(type(v[2])))
|
||||
if not isinstance(v[3], dict):
|
||||
raise ValueError(
|
||||
"policy_graph tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
raise ValueError("policy tuple value 3 (config) must be a dict, "
|
||||
"got {}".format(type(v[3])))
|
||||
|
||||
|
||||
def _validate_env(env):
|
||||
@@ -805,6 +801,6 @@ def _monitor(env, path):
|
||||
|
||||
def _has_tensorflow_graph(policy_dict):
|
||||
for policy, _, _, _ in policy_dict.values():
|
||||
if issubclass(policy, TFPolicyGraph):
|
||||
if issubclass(policy, TFPolicy):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -2,286 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class PolicyGraph(object):
|
||||
"""An agent policy and loss, i.e., a TFPolicyGraph or other subclass.
|
||||
|
||||
This object defines how to act in the environment, and also losses used to
|
||||
improve the policy based on its experiences. Note that both policy and
|
||||
loss are defined together for convenience, though the policy itself is
|
||||
logically separate.
|
||||
|
||||
All policies can directly extend PolicyGraph, however TensorFlow users may
|
||||
find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib
|
||||
to apply TensorFlow-specific optimizations such as fusing multiple policy
|
||||
graphs and multi-GPU support.
|
||||
|
||||
Attributes:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
"""Initialize the graph.
|
||||
|
||||
This is the standard constructor for policy graphs. The policy graph
|
||||
class you pass into PolicyEvaluator will be constructed with
|
||||
these arguments.
|
||||
|
||||
Args:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
config (dict): Policy-specific configuration data.
|
||||
"""
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
"""Compute actions for the current policy.
|
||||
|
||||
Arguments:
|
||||
obs_batch (np.ndarray): batch of observations
|
||||
state_batches (list): list of RNN state input batches, if any
|
||||
prev_action_batch (np.ndarray): batch of previous action values
|
||||
prev_reward_batch (np.ndarray): batch of previous rewards
|
||||
info_batch (info): batch of info objects
|
||||
episodes (list): MultiAgentEpisode for each obs in obs_batch.
|
||||
This provides access to all of the internal episode state,
|
||||
which may be useful for model-based or multiagent algorithms.
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (np.ndarray): batch of output actions, with shape like
|
||||
[BATCH_SIZE, ACTION_SHAPE].
|
||||
state_outs (list): list of RNN state output batches, if any, with
|
||||
shape like [STATE_SIZE, BATCH_SIZE].
|
||||
info (dict): dictionary of extra feature batches, if any, with
|
||||
shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_single_action(self,
|
||||
obs,
|
||||
state,
|
||||
prev_action=None,
|
||||
prev_reward=None,
|
||||
info=None,
|
||||
episode=None,
|
||||
clip_actions=False,
|
||||
**kwargs):
|
||||
"""Unbatched version of compute_actions.
|
||||
|
||||
Arguments:
|
||||
obs (obj): single observation
|
||||
state (list): list of RNN state inputs, if any
|
||||
prev_action (obj): previous action value, if any
|
||||
prev_reward (int): previous reward, if any
|
||||
info (dict): info object, if any
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
clip_actions (bool): should the action be clipped
|
||||
kwargs: forward compatibility placeholder
|
||||
|
||||
Returns:
|
||||
actions (obj): single action
|
||||
state_outs (list): list of RNN state outputs, if any
|
||||
info (dict): dictionary of extra features, if any
|
||||
"""
|
||||
|
||||
prev_action_batch = None
|
||||
prev_reward_batch = None
|
||||
info_batch = None
|
||||
episodes = None
|
||||
if prev_action is not None:
|
||||
prev_action_batch = [prev_action]
|
||||
if prev_reward is not None:
|
||||
prev_reward_batch = [prev_reward]
|
||||
if info is not None:
|
||||
info_batch = [info]
|
||||
if episode is not None:
|
||||
episodes = [episode]
|
||||
[action], state_out, info = self.compute_actions(
|
||||
[obs], [[s] for s in state],
|
||||
prev_action_batch=prev_action_batch,
|
||||
prev_reward_batch=prev_reward_batch,
|
||||
info_batch=info_batch,
|
||||
episodes=episodes)
|
||||
if clip_actions:
|
||||
action = clip_action(action, self.action_space)
|
||||
return action, [s[0] for s in state_out], \
|
||||
{k: v[0] for k, v in info.items()}
|
||||
|
||||
@DeveloperAPI
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
"""Implements algorithm-specific trajectory postprocessing.
|
||||
|
||||
This will be called on each trajectory fragment computed during policy
|
||||
evaluation. Each fragment is guaranteed to be only from one episode.
|
||||
|
||||
Arguments:
|
||||
sample_batch (SampleBatch): batch of experiences for the policy,
|
||||
which will contain at most one episode trajectory.
|
||||
other_agent_batches (dict): In a multi-agent env, this contains a
|
||||
mapping of agent ids to (policy_graph, agent_batch) tuples
|
||||
containing the policy graph and experiences of the other agent.
|
||||
episode (MultiAgentEpisode): this provides access to all of the
|
||||
internal episode state, which may be useful for model-based or
|
||||
multi-agent algorithms.
|
||||
|
||||
Returns:
|
||||
SampleBatch: postprocessed sample batch.
|
||||
"""
|
||||
return sample_batch
|
||||
|
||||
@DeveloperAPI
|
||||
def learn_on_batch(self, samples):
|
||||
"""Fused compute gradients and apply gradients call.
|
||||
|
||||
Either this or the combination of compute/apply grads must be
|
||||
implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
grad_info: dictionary of extra metadata from compute_gradients().
|
||||
|
||||
Examples:
|
||||
>>> batch = ev.sample()
|
||||
>>> ev.learn_on_batch(batch)
|
||||
"""
|
||||
|
||||
grads, grad_info = self.compute_gradients(samples)
|
||||
self.apply_gradients(grads)
|
||||
return grad_info
|
||||
|
||||
@DeveloperAPI
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
"""Computes gradients against a batch of experiences.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
grads (list): List of gradient output values
|
||||
info (dict): Extra policy-specific values
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def apply_gradients(self, gradients):
|
||||
"""Applies previously computed gradients.
|
||||
|
||||
Either this or learn_on_batch() must be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_weights(self):
|
||||
"""Returns model weights.
|
||||
|
||||
Returns:
|
||||
weights (obj): Serializable copy or view of model weights
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def set_weights(self, weights):
|
||||
"""Sets model weights.
|
||||
|
||||
Arguments:
|
||||
weights (obj): Serializable copy or view of model weights
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def get_initial_state(self):
|
||||
"""Returns initial RNN state for the current policy."""
|
||||
return []
|
||||
|
||||
@DeveloperAPI
|
||||
def get_state(self):
|
||||
"""Saves all local state.
|
||||
|
||||
Returns:
|
||||
state (obj): Serialized local state.
|
||||
"""
|
||||
return self.get_weights()
|
||||
|
||||
@DeveloperAPI
|
||||
def set_state(self, state):
|
||||
"""Restores all local state.
|
||||
|
||||
Arguments:
|
||||
state (obj): Serialized local state.
|
||||
"""
|
||||
self.set_weights(state)
|
||||
|
||||
@DeveloperAPI
|
||||
def on_global_var_update(self, global_vars):
|
||||
"""Called on an update to global vars.
|
||||
|
||||
Arguments:
|
||||
global_vars (dict): Global variables broadcast from the driver.
|
||||
"""
|
||||
pass
|
||||
|
||||
@DeveloperAPI
|
||||
def export_model(self, export_dir):
|
||||
"""Export PolicyGraph to local directory for serving.
|
||||
|
||||
Arguments:
|
||||
export_dir (str): Local writable directory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def export_checkpoint(self, export_dir):
|
||||
"""Export PolicyGraph checkpoint to local directory.
|
||||
|
||||
Arguments:
|
||||
export_dir (str): Local writable directory.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def clip_action(action, space):
|
||||
"""Called to clip an action to the bounds of the specified action space.
|
||||
|
||||
Arguments:
|
||||
action: Single action.
|
||||
space: Action space the actions should be present in.
|
||||
|
||||
Returns:
|
||||
The clipped action.
|
||||
"""
|
||||
|
||||
if isinstance(space, gym.spaces.Box):
|
||||
return np.clip(action, space.low, space.high)
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
if type(action) not in (tuple, list):
|
||||
raise ValueError("Expected tuple space for actions {}: {}".format(
|
||||
action, space))
|
||||
out = []
|
||||
for a, s in zip(action, space.spaces):
|
||||
out.append(clip_action(a, s))
|
||||
return out
|
||||
else:
|
||||
return action
|
||||
PolicyGraph = renamed_class(Policy, old_name="PolicyGraph")
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import scipy.signal
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
|
||||
|
||||
@@ -2,295 +2,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import collections
|
||||
import numpy as np
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack, unpack, is_compressed
|
||||
from ray.rllib.utils.memory import concat_aligned
|
||||
|
||||
# Default policy id for single-agent environments
|
||||
DEFAULT_POLICY_ID = "default_policy"
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class MultiAgentBatch(object):
|
||||
"""A batch of experiences from multiple policies in the environment.
|
||||
|
||||
Attributes:
|
||||
policy_batches (dict): Mapping from policy id to a normal SampleBatch
|
||||
of experiences. Note that these batches may be of different length.
|
||||
count (int): The number of timesteps in the environment this batch
|
||||
contains. This will be less than the number of transitions this
|
||||
batch contains across all policies in total.
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, policy_batches, count):
|
||||
self.policy_batches = policy_batches
|
||||
self.count = count
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def wrap_as_needed(batches, count):
|
||||
if len(batches) == 1 and DEFAULT_POLICY_ID in batches:
|
||||
return batches[DEFAULT_POLICY_ID]
|
||||
return MultiAgentBatch(batches, count)
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
policy_batches = collections.defaultdict(list)
|
||||
total_count = 0
|
||||
for s in samples:
|
||||
assert isinstance(s, MultiAgentBatch)
|
||||
for policy_id, batch in s.policy_batches.items():
|
||||
policy_batches[policy_id].append(batch)
|
||||
total_count += s.count
|
||||
out = {}
|
||||
for policy_id, batches in policy_batches.items():
|
||||
out[policy_id] = SampleBatch.concat_samples(batches)
|
||||
return MultiAgentBatch(out, total_count)
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
return MultiAgentBatch(
|
||||
{k: v.copy()
|
||||
for (k, v) in self.policy_batches.items()}, self.count)
|
||||
|
||||
@PublicAPI
|
||||
def total(self):
|
||||
ct = 0
|
||||
for batch in self.policy_batches.values():
|
||||
ct += batch.count
|
||||
return ct
|
||||
|
||||
@DeveloperAPI
|
||||
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
|
||||
for batch in self.policy_batches.values():
|
||||
batch.compress(bulk=bulk, columns=columns)
|
||||
|
||||
@DeveloperAPI
|
||||
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
|
||||
for batch in self.policy_batches.values():
|
||||
batch.decompress_if_needed(columns)
|
||||
|
||||
def __str__(self):
|
||||
return "MultiAgentBatch({}, count={})".format(
|
||||
str(self.policy_batches), self.count)
|
||||
|
||||
def __repr__(self):
|
||||
return "MultiAgentBatch({}, count={})".format(
|
||||
str(self.policy_batches), self.count)
|
||||
|
||||
|
||||
@PublicAPI
|
||||
class SampleBatch(object):
|
||||
"""Wrapper around a dictionary with string keys and array-like values.
|
||||
|
||||
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
|
||||
samples, each with an "obs" and "reward" attribute.
|
||||
"""
|
||||
|
||||
# Outputs from interacting with the environment
|
||||
CUR_OBS = "obs"
|
||||
NEXT_OBS = "new_obs"
|
||||
ACTIONS = "actions"
|
||||
REWARDS = "rewards"
|
||||
PREV_ACTIONS = "prev_actions"
|
||||
PREV_REWARDS = "prev_rewards"
|
||||
DONES = "dones"
|
||||
INFOS = "infos"
|
||||
|
||||
# Uniquely identifies an episode
|
||||
EPS_ID = "eps_id"
|
||||
|
||||
# Uniquely identifies a sample batch. This is important to distinguish RNN
|
||||
# sequences from the same episode when multiple sample batches are
|
||||
# concatenated (fusing sequences across batches can be unsafe).
|
||||
UNROLL_ID = "unroll_id"
|
||||
|
||||
# Uniquely identifies an agent within an episode
|
||||
AGENT_INDEX = "agent_index"
|
||||
|
||||
# Value function predictions emitted by the behaviour policy
|
||||
VF_PREDS = "vf_preds"
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Constructs a sample batch (same params as dict constructor)."""
|
||||
|
||||
self.data = dict(*args, **kwargs)
|
||||
lengths = []
|
||||
for k, v in self.data.copy().items():
|
||||
assert isinstance(k, six.string_types), self
|
||||
lengths.append(len(v))
|
||||
self.data[k] = np.array(v, copy=False)
|
||||
if not lengths:
|
||||
raise ValueError("Empty sample batch")
|
||||
assert len(set(lengths)) == 1, "data columns must be same length"
|
||||
self.count = lengths[0]
|
||||
|
||||
@staticmethod
|
||||
@PublicAPI
|
||||
def concat_samples(samples):
|
||||
if isinstance(samples[0], MultiAgentBatch):
|
||||
return MultiAgentBatch.concat_samples(samples)
|
||||
out = {}
|
||||
samples = [s for s in samples if s.count > 0]
|
||||
for k in samples[0].keys():
|
||||
out[k] = concat_aligned([s[k] for s in samples])
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def concat(self, other):
|
||||
"""Returns a new SampleBatch with each data column concatenated.
|
||||
|
||||
Examples:
|
||||
>>> b1 = SampleBatch({"a": [1, 2]})
|
||||
>>> b2 = SampleBatch({"a": [3, 4, 5]})
|
||||
>>> print(b1.concat(b2))
|
||||
{"a": [1, 2, 3, 4, 5]}
|
||||
"""
|
||||
|
||||
assert self.keys() == other.keys(), "must have same columns"
|
||||
out = {}
|
||||
for k in self.keys():
|
||||
out[k] = concat_aligned([self[k], other[k]])
|
||||
return SampleBatch(out)
|
||||
|
||||
@PublicAPI
|
||||
def copy(self):
|
||||
return SampleBatch(
|
||||
{k: np.array(v, copy=True)
|
||||
for (k, v) in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def rows(self):
|
||||
"""Returns an iterator over data rows, i.e. dicts with column values.
|
||||
|
||||
Examples:
|
||||
>>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
|
||||
>>> for row in batch.rows():
|
||||
print(row)
|
||||
{"a": 1, "b": 4}
|
||||
{"a": 2, "b": 5}
|
||||
{"a": 3, "b": 6}
|
||||
"""
|
||||
|
||||
for i in range(self.count):
|
||||
row = {}
|
||||
for k in self.keys():
|
||||
row[k] = self[k][i]
|
||||
yield row
|
||||
|
||||
@PublicAPI
|
||||
def columns(self, keys):
|
||||
"""Returns a list of just the specified columns.
|
||||
|
||||
Examples:
|
||||
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
|
||||
>>> print(batch.columns(["a", "b"]))
|
||||
[[1], [2]]
|
||||
"""
|
||||
|
||||
out = []
|
||||
for k in keys:
|
||||
out.append(self[k])
|
||||
return out
|
||||
|
||||
@PublicAPI
|
||||
def shuffle(self):
|
||||
"""Shuffles the rows of this batch in-place."""
|
||||
|
||||
permutation = np.random.permutation(self.count)
|
||||
for key, val in self.items():
|
||||
self[key] = val[permutation]
|
||||
|
||||
@PublicAPI
|
||||
def split_by_episode(self):
|
||||
"""Splits this batch's data by `eps_id`.
|
||||
|
||||
Returns:
|
||||
list of SampleBatch, one per distinct episode.
|
||||
"""
|
||||
|
||||
slices = []
|
||||
cur_eps_id = self.data["eps_id"][0]
|
||||
offset = 0
|
||||
for i in range(self.count):
|
||||
next_eps_id = self.data["eps_id"][i]
|
||||
if next_eps_id != cur_eps_id:
|
||||
slices.append(self.slice(offset, i))
|
||||
offset = i
|
||||
cur_eps_id = next_eps_id
|
||||
slices.append(self.slice(offset, self.count))
|
||||
for s in slices:
|
||||
slen = len(set(s["eps_id"]))
|
||||
assert slen == 1, (s, slen)
|
||||
assert sum(s.count for s in slices) == self.count, (slices, self.count)
|
||||
return slices
|
||||
|
||||
@PublicAPI
|
||||
def slice(self, start, end):
|
||||
"""Returns a slice of the row data of this batch.
|
||||
|
||||
Arguments:
|
||||
start (int): Starting index.
|
||||
end (int): Ending index.
|
||||
|
||||
Returns:
|
||||
SampleBatch which has a slice of this batch's data.
|
||||
"""
|
||||
|
||||
return SampleBatch({k: v[start:end] for k, v in self.data.items()})
|
||||
|
||||
@PublicAPI
|
||||
def keys(self):
|
||||
return self.data.keys()
|
||||
|
||||
@PublicAPI
|
||||
def items(self):
|
||||
return self.data.items()
|
||||
|
||||
@PublicAPI
|
||||
def __getitem__(self, key):
|
||||
return self.data[key]
|
||||
|
||||
@PublicAPI
|
||||
def __setitem__(self, key, item):
|
||||
self.data[key] = item
|
||||
|
||||
@DeveloperAPI
|
||||
def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
|
||||
for key in columns:
|
||||
if key in self.data:
|
||||
if bulk:
|
||||
self.data[key] = pack(self.data[key])
|
||||
else:
|
||||
self.data[key] = np.array(
|
||||
[pack(o) for o in self.data[key]])
|
||||
|
||||
@DeveloperAPI
|
||||
def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
|
||||
for key in columns:
|
||||
if key in self.data:
|
||||
arr = self.data[key]
|
||||
if is_compressed(arr):
|
||||
self.data[key] = unpack(arr)
|
||||
elif len(arr) > 0 and is_compressed(arr[0]):
|
||||
self.data[key] = np.array(
|
||||
[unpack(o) for o in self.data[key]])
|
||||
|
||||
def __str__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def __repr__(self):
|
||||
return "SampleBatch({})".format(str(self.data))
|
||||
|
||||
def __iter__(self):
|
||||
return self.data.__iter__()
|
||||
|
||||
def __contains__(self, x):
|
||||
return x in self.data
|
||||
SampleBatch = renamed_class(
|
||||
SampleBatch, old_name="rllib.evaluation.SampleBatch")
|
||||
MultiAgentBatch = renamed_class(
|
||||
MultiAgentBatch, old_name="rllib.evaluation.MultiAgentBatch")
|
||||
|
||||
@@ -6,7 +6,7 @@ import collections
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
|
||||
@@ -79,7 +79,7 @@ class MultiAgentSampleBatchBuilder(object):
|
||||
"""Initialize a MultiAgentSampleBatchBuilder.
|
||||
|
||||
Arguments:
|
||||
policy_map (dict): Maps policy ids to policy graph instances.
|
||||
policy_map (dict): Maps policy ids to policy instances.
|
||||
clip_rewards (bool): Whether to clip rewards before postprocessing.
|
||||
postp_callback: function to call on each postprocessed batch.
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,7 @@ import time
|
||||
from ray.rllib.evaluation.episode import MultiAgentEpisode, _flatten_action
|
||||
from ray.rllib.evaluation.sample_batch_builder import \
|
||||
MultiAgentSampleBatchBuilder
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.env.base_env import BaseEnv, ASYNC_RESET_RETURN
|
||||
from ray.rllib.env.atari_wrappers import get_wrapper_by_cls, MonitorEnv
|
||||
from ray.rllib.models.action_dist import TupleActions
|
||||
@@ -20,7 +20,7 @@ from ray.rllib.offline import InputReader
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.evaluation.policy_graph import clip_action
|
||||
from ray.rllib.policy.policy import clip_action
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -236,7 +236,7 @@ def _env_runner(base_env, extra_batch_callback, policies, policy_mapping_fn,
|
||||
Args:
|
||||
base_env (BaseEnv): env implementing BaseEnv.
|
||||
extra_batch_callback (fn): function to send extra batch data to.
|
||||
policies (dict): Map of policy ids to PolicyGraph instances.
|
||||
policies (dict): Map of policy ids to Policy instances.
|
||||
policy_mapping_fn (func): Function that maps agent ids to policy ids.
|
||||
This is called when an agent first enters the environment. The
|
||||
agent is then "bound" to the returned policy for the episode.
|
||||
@@ -528,7 +528,7 @@ def _do_policy_eval(tf_sess, to_eval, policies, active_episodes):
|
||||
rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data])
|
||||
policy = _get_or_raise(policies, policy_id)
|
||||
if builder and (policy.compute_actions.__code__ is
|
||||
TFPolicyGraph.compute_actions.__code__):
|
||||
TFPolicy.compute_actions.__code__):
|
||||
# TODO(ekl): how can we make info batch available to TF code?
|
||||
pending_fetches[policy_id] = policy._build_compute_actions(
|
||||
builder, [t.obs for t in eval_data],
|
||||
|
||||
@@ -2,513 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import errno
|
||||
import logging
|
||||
import numpy as np
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
class TFPolicyGraph(PolicyGraph):
|
||||
"""An agent policy and loss implemented in TensorFlow.
|
||||
|
||||
Extending this class enables RLlib to perform TensorFlow specific
|
||||
optimizations on the policy graph, e.g., parallelization across gpus or
|
||||
fusing multiple graphs together in the multi-agent setting.
|
||||
|
||||
Input tensors are typically shaped like [BATCH_SIZE, ...].
|
||||
|
||||
Attributes:
|
||||
observation_space (gym.Space): observation space of the policy.
|
||||
action_space (gym.Space): action space of the policy.
|
||||
model (rllib.models.Model): RLlib model used for the policy.
|
||||
|
||||
Examples:
|
||||
>>> policy = TFPolicyGraphSubclass(
|
||||
sess, obs_input, action_sampler, loss, loss_inputs)
|
||||
|
||||
>>> print(policy.compute_actions([1, 0, 2]))
|
||||
(array([0, 1, 1]), [], {})
|
||||
|
||||
>>> print(policy.postprocess_trajectory(SampleBatch({...})))
|
||||
SampleBatch({"action": ..., "advantages": ..., ...})
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self,
|
||||
observation_space,
|
||||
action_space,
|
||||
sess,
|
||||
obs_input,
|
||||
action_sampler,
|
||||
loss,
|
||||
loss_inputs,
|
||||
model=None,
|
||||
action_prob=None,
|
||||
state_inputs=None,
|
||||
state_outputs=None,
|
||||
prev_action_input=None,
|
||||
prev_reward_input=None,
|
||||
seq_lens=None,
|
||||
max_seq_len=20,
|
||||
batch_divisibility_req=1,
|
||||
update_ops=None):
|
||||
"""Initialize the policy graph.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): Observation space of the env.
|
||||
action_space (gym.Space): Action space of the env.
|
||||
sess (Session): TensorFlow session to use.
|
||||
obs_input (Tensor): input placeholder for observations, of shape
|
||||
[BATCH_SIZE, obs...].
|
||||
action_sampler (Tensor): Tensor for sampling an action, of shape
|
||||
[BATCH_SIZE, action...]
|
||||
loss (Tensor): scalar policy loss output tensor.
|
||||
loss_inputs (list): a (name, placeholder) tuple for each loss
|
||||
input argument. Each placeholder name must correspond to a
|
||||
SampleBatch column key returned by postprocess_trajectory(),
|
||||
and has shape [BATCH_SIZE, data...]. These keys will be read
|
||||
from postprocessed sample batches and fed into the specified
|
||||
placeholders during loss computation.
|
||||
model (rllib.models.Model): used to integrate custom losses and
|
||||
stats from user-defined RLlib models.
|
||||
action_prob (Tensor): probability of the sampled action.
|
||||
state_inputs (list): list of RNN state input Tensors.
|
||||
state_outputs (list): list of RNN state output Tensors.
|
||||
prev_action_input (Tensor): placeholder for previous actions
|
||||
prev_reward_input (Tensor): placeholder for previous rewards
|
||||
seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
|
||||
[NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
|
||||
models/lstm.py for more information.
|
||||
max_seq_len (int): max sequence length for LSTM training.
|
||||
batch_divisibility_req (int): pad all agent experiences batches to
|
||||
multiples of this value. This only has an effect if not using
|
||||
an LSTM model.
|
||||
update_ops (list): override the batchnorm update ops to run when
|
||||
applying gradients. Otherwise we run all update ops found in
|
||||
the current variable scope.
|
||||
"""
|
||||
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self.model = model
|
||||
self._sess = sess
|
||||
self._obs_input = obs_input
|
||||
self._prev_action_input = prev_action_input
|
||||
self._prev_reward_input = prev_reward_input
|
||||
self._sampler = action_sampler
|
||||
self._is_training = self._get_is_training_placeholder()
|
||||
self._action_prob = action_prob
|
||||
self._state_inputs = state_inputs or []
|
||||
self._state_outputs = state_outputs or []
|
||||
self._seq_lens = seq_lens
|
||||
self._max_seq_len = max_seq_len
|
||||
self._batch_divisibility_req = batch_divisibility_req
|
||||
self._update_ops = update_ops
|
||||
self._stats_fetches = {}
|
||||
|
||||
if loss is not None:
|
||||
self._initialize_loss(loss, loss_inputs)
|
||||
else:
|
||||
self._loss = None
|
||||
|
||||
if len(self._state_inputs) != len(self._state_outputs):
|
||||
raise ValueError(
|
||||
"Number of state input and output tensors must match, got: "
|
||||
"{} vs {}".format(self._state_inputs, self._state_outputs))
|
||||
if len(self.get_initial_state()) != len(self._state_inputs):
|
||||
raise ValueError(
|
||||
"Length of initial state must match number of state inputs, "
|
||||
"got: {} vs {}".format(self.get_initial_state(),
|
||||
self._state_inputs))
|
||||
if self._state_inputs and self._seq_lens is None:
|
||||
raise ValueError(
|
||||
"seq_lens tensor must be given if state inputs are defined")
|
||||
|
||||
def _initialize_loss(self, loss, loss_inputs):
|
||||
self._loss_inputs = loss_inputs
|
||||
self._loss_input_dict = dict(self._loss_inputs)
|
||||
for i, ph in enumerate(self._state_inputs):
|
||||
self._loss_input_dict["state_in_{}".format(i)] = ph
|
||||
|
||||
if self.model:
|
||||
self._loss = self.model.custom_loss(loss, self._loss_input_dict)
|
||||
self._stats_fetches.update({"model": self.model.custom_stats()})
|
||||
else:
|
||||
self._loss = loss
|
||||
|
||||
self._optimizer = self.optimizer()
|
||||
self._grads_and_vars = [
|
||||
(g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
|
||||
if g is not None
|
||||
]
|
||||
self._grads = [g for (g, v) in self._grads_and_vars]
|
||||
self._variables = ray.experimental.tf_utils.TensorFlowVariables(
|
||||
self._loss, self._sess)
|
||||
|
||||
# gather update ops for any batch norm layers
|
||||
if not self._update_ops:
|
||||
self._update_ops = tf.get_collection(
|
||||
tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
|
||||
if self._update_ops:
|
||||
logger.debug("Update ops to run on apply gradient: {}".format(
|
||||
self._update_ops))
|
||||
with tf.control_dependencies(self._update_ops):
|
||||
self._apply_op = self.build_apply_op(self._optimizer,
|
||||
self._grads_and_vars)
|
||||
|
||||
if log_once("loss_used"):
|
||||
logger.debug(
|
||||
"These tensors were used in the loss_fn:\n\n{}\n".format(
|
||||
summarize(self._loss_input_dict)))
|
||||
|
||||
self._sess.run(tf.global_variables_initializer())
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
info_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
builder = TFRunBuilder(self._sess, "compute_actions")
|
||||
fetches = self._build_compute_actions(builder, obs_batch,
|
||||
state_batches, prev_action_batch,
|
||||
prev_reward_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
builder = TFRunBuilder(self._sess, "compute_gradients")
|
||||
fetches = self._build_compute_gradients(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def apply_gradients(self, gradients):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
builder = TFRunBuilder(self._sess, "apply_gradients")
|
||||
fetches = self._build_apply_gradients(builder, gradients)
|
||||
builder.get(fetches)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def learn_on_batch(self, postprocessed_batch):
|
||||
assert self._loss is not None, "Loss not initialized"
|
||||
builder = TFRunBuilder(self._sess, "learn_on_batch")
|
||||
fetches = self._build_learn_on_batch(builder, postprocessed_batch)
|
||||
return builder.get(fetches)
|
||||
|
||||
@override(PolicyGraph)
|
||||
def get_weights(self):
|
||||
return self._variables.get_flat()
|
||||
|
||||
@override(PolicyGraph)
|
||||
def set_weights(self, weights):
|
||||
return self._variables.set_flat(weights)
|
||||
|
||||
@override(PolicyGraph)
def export_model(self, export_dir):
    """Export tensorflow graph to export_dir for serving.

    A SavedModelBuilder is created inside this policy's graph, the session
    is added under the SERVING tag with the signature map from
    _build_signature_def(), and the result is saved.
    """
    serving_tags = [tf.saved_model.tag_constants.SERVING]
    with self._sess.graph.as_default():
        saved_model = tf.saved_model.builder.SavedModelBuilder(export_dir)
        signatures = self._build_signature_def()
        saved_model.add_meta_graph_and_variables(
            self._sess, serving_tags, signature_def_map=signatures)
        saved_model.save()
|
||||
|
||||
@override(PolicyGraph)
def export_checkpoint(self, export_dir, filename_prefix="model"):
    """Export tensorflow checkpoint to export_dir.

    Arguments:
        export_dir (str): directory to write to; created if missing.
        filename_prefix (str): joined onto export_dir to form the save path.
    """
    try:
        os.makedirs(export_dir)
    except OSError as err:
        # An already existing export dir is fine; anything else is not.
        already_exists = err.errno == errno.EEXIST
        if not already_exists:
            raise
    checkpoint_path = os.path.join(export_dir, filename_prefix)
    with self._sess.graph.as_default():
        tf.train.Saver().save(self._sess, checkpoint_path)
|
||||
|
||||
@DeveloperAPI
def copy(self, existing_inputs):
    """Create a copy of this policy that reuses existing input placeholders.

    This hook is optional and only needed to work with the multi-GPU
    optimizer. The implementation here always raises NotImplementedError.
    """
    raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
def extra_compute_action_feed_dict(self):
    """Return extra feeds for the compute-actions session run.

    The default is an empty dict.
    """
    extra_feeds = {}
    return extra_feeds
|
||||
|
||||
@DeveloperAPI
def extra_compute_action_fetches(self):
    """Return extra values to fetch and return from compute_actions().

    By default only the action probability is included, and only when
    self._action_prob is set; otherwise the dict is empty.
    """
    if self._action_prob is None:
        return {}
    return {"action_prob": self._action_prob}
|
||||
|
||||
@DeveloperAPI
def extra_compute_grad_feed_dict(self):
    """Return extra feeds for the compute-gradients session run.

    The default is an empty dict; an override might add e.g. kl_coeff.
    """
    extra_feeds = {}
    return extra_feeds
|
||||
|
||||
@DeveloperAPI
def extra_compute_grad_fetches(self):
    """Return extra values to fetch and return from compute_gradients().

    The default holds only an empty dict under LEARNER_STATS_KEY; overrides
    might add e.g. stats or td error.
    """
    fetches = {}
    fetches[LEARNER_STATS_KEY] = {}
    return fetches
|
||||
|
||||
@DeveloperAPI
def optimizer(self):
    """Return the TF optimizer to use for policy optimization.

    An Adam optimizer is built with config["lr"] when this object has a
    `config` attribute, and with TensorFlow's default arguments otherwise.
    """
    if not hasattr(self, "config"):
        return tf.train.AdamOptimizer()
    return tf.train.AdamOptimizer(self.config["lr"])
|
||||
|
||||
@DeveloperAPI
def gradients(self, optimizer, loss):
    """Compute gradients of loss; override for custom gradient computation.

    The default delegates to optimizer.compute_gradients(loss).
    """
    grads_and_vars = optimizer.compute_gradients(loss)
    return grads_and_vars
|
||||
|
||||
@DeveloperAPI
def build_apply_op(self, optimizer, grads_and_vars):
    """Build the op that applies gradients; override for custom behavior.

    Arguments:
        optimizer: optimizer whose apply_gradients() builds the op.
        grads_and_vars: gradient/variable pairs to apply.

    Returns:
        The op returned by optimizer.apply_gradients().
    """
    # Use the argument rather than self._grads_and_vars so that callers and
    # overriding classes that pass different pairs are honored. The visible
    # call site passes self._grads_and_vars, so its behavior is unchanged.
    # global_step is specified for TD3, which needs to count the num updates.
    return optimizer.apply_gradients(
        grads_and_vars,
        global_step=tf.train.get_or_create_global_step())
|
||||
|
||||
@DeveloperAPI
def _get_is_training_placeholder(self):
    """Get the placeholder for _is_training, i.e., for batch norm layers.

    The placeholder is created on first use and cached on the instance, so
    this can be called safely before __init__ has run.
    """
    if hasattr(self, "_is_training"):
        return self._is_training
    self._is_training = tf.placeholder_with_default(False, ())
    return self._is_training
|
||||
|
||||
def _extra_input_signature_def(self):
    """Extra input signatures to add when exporting tf model.

    Inferred from extra_compute_action_feed_dict(): each key of that dict
    is mapped from its `.name` to its tensor info.
    """
    signatures = {}
    for placeholder in self.extra_compute_action_feed_dict().keys():
        name = placeholder.name
        signatures[name] = tf.saved_model.utils.build_tensor_info(placeholder)
    return signatures
|
||||
|
||||
def _extra_output_signature_def(self):
    """Extra output signatures to add when exporting tf model.

    Inferred from extra_compute_action_fetches(): each fetch key is mapped
    to the tensor info of its value.
    """
    extra_fetches = self.extra_compute_action_fetches()
    signatures = {}
    for key in extra_fetches.keys():
        signatures[key] = tf.saved_model.utils.build_tensor_info(
            extra_fetches[key])
    return signatures
|
||||
|
||||
def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.

    Returns:
        A dict with a single entry: the default serving signature key
        mapped to a PREDICT signature over this policy's inputs/outputs.
    """
    # Inputs: extras first, then observations and the optional placeholders.
    inputs = self._extra_input_signature_def()
    inputs["observations"] = tf.saved_model.utils.build_tensor_info(
        self._obs_input)
    if self._seq_lens is not None:
        inputs["seq_lens"] = tf.saved_model.utils.build_tensor_info(
            self._seq_lens)
    if self._prev_action_input is not None:
        inputs["prev_action"] = tf.saved_model.utils.build_tensor_info(
            self._prev_action_input)
    if self._prev_reward_input is not None:
        inputs["prev_reward"] = tf.saved_model.utils.build_tensor_info(
            self._prev_reward_input)
    inputs["is_training"] = tf.saved_model.utils.build_tensor_info(
        self._is_training)
    for state_in in self._state_inputs:
        inputs[state_in.name] = tf.saved_model.utils.build_tensor_info(
            state_in)

    # Outputs: extras first, then sampled actions and the state outputs.
    outputs = self._extra_output_signature_def()
    outputs["actions"] = tf.saved_model.utils.build_tensor_info(
        self._sampler)
    for state_out in self._state_outputs:
        outputs[state_out.name] = tf.saved_model.utils.build_tensor_info(
            state_out)

    signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs, outputs,
        tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
    default_key = (tf.saved_model.signature_constants.
                   DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    return {default_key: signature}
|
||||
|
||||
def _build_compute_actions(self,
|
||||
builder,
|
||||
obs_batch,
|
||||
state_batches=None,
|
||||
prev_action_batch=None,
|
||||
prev_reward_batch=None,
|
||||
episodes=None):
|
||||
state_batches = state_batches or []
|
||||
if len(self._state_inputs) != len(state_batches):
|
||||
raise ValueError(
|
||||
"Must pass in RNN state batches for placeholders {}, got {}".
|
||||
format(self._state_inputs, state_batches))
|
||||
builder.add_feed_dict(self.extra_compute_action_feed_dict())
|
||||
builder.add_feed_dict({self._obs_input: obs_batch})
|
||||
if state_batches:
|
||||
builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
|
||||
if self._prev_action_input is not None and prev_action_batch:
|
||||
builder.add_feed_dict({self._prev_action_input: prev_action_batch})
|
||||
if self._prev_reward_input is not None and prev_reward_batch:
|
||||
builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
|
||||
builder.add_feed_dict({self._is_training: False})
|
||||
builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
|
||||
fetches = builder.add_fetches([self._sampler] + self._state_outputs +
|
||||
[self.extra_compute_action_fetches()])
|
||||
return fetches[0], fetches[1:-1], fetches[-1]
|
||||
|
||||
def _build_compute_gradients(self, builder, postprocessed_batch):
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
fetches = builder.add_fetches(
|
||||
[self._grads, self._get_grad_and_stats_fetches()])
|
||||
return fetches[0], fetches[1]
|
||||
|
||||
def _build_apply_gradients(self, builder, gradients):
|
||||
if len(gradients) != len(self._grads):
|
||||
raise ValueError(
|
||||
"Unexpected number of gradients to apply, got {} for {}".
|
||||
format(gradients, self._grads))
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
builder.add_feed_dict(dict(zip(self._grads, gradients)))
|
||||
fetches = builder.add_fetches([self._apply_op])
|
||||
return fetches[0]
|
||||
|
||||
def _build_learn_on_batch(self, builder, postprocessed_batch):
|
||||
builder.add_feed_dict(self.extra_compute_grad_feed_dict())
|
||||
builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
|
||||
builder.add_feed_dict({self._is_training: True})
|
||||
fetches = builder.add_fetches([
|
||||
self._apply_op,
|
||||
self._get_grad_and_stats_fetches(),
|
||||
])
|
||||
return fetches[1]
|
||||
|
||||
def _get_grad_and_stats_fetches(self):
    """Return the extra grad fetches with self._stats_fetches merged in.

    When self._stats_fetches is non-empty it is merged under
    LEARNER_STATS_KEY; on a key clash the value from
    extra_compute_grad_fetches() wins.

    Raises:
        ValueError: if extra_compute_grad_fetches() lacks LEARNER_STATS_KEY.
    """
    fetches = self.extra_compute_grad_fetches()
    if LEARNER_STATS_KEY not in fetches:
        raise ValueError(
            "Grad fetches should contain 'stats': {...} entry")
    if not self._stats_fetches:
        return fetches
    merged_stats = dict(self._stats_fetches, **fetches[LEARNER_STATS_KEY])
    fetches[LEARNER_STATS_KEY] = merged_stats
    return fetches
|
||||
|
||||
def _get_loss_inputs_dict(self, batch):
    """Build the feed dict mapping loss input placeholders to batch data.

    Without state inputs, and when the batch meets the divisibility
    requirement, each loss input placeholder is fed its batch column
    directly. Otherwise the batch goes through chop_into_sequences() and
    the resulting features, initial states and sequence lengths are fed.

    Arguments:
        batch: indexable by the loss input keys and the SampleBatch keys
            used below.

    Returns:
        dict mapping placeholders to the values to feed.
    """
    divisor = self._batch_divisibility_req
    if divisor > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % divisor == 0
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if meets_divisibility_reqs and not self._state_inputs:
        direct_feed = {}
        for key, placeholder in self._loss_inputs:
            direct_feed[placeholder] = batch[key]
        return direct_feed

    # RNN or multi-agent case
    if self._state_inputs:
        max_seq_len, dynamic_max = self._max_seq_len, True
    else:
        max_seq_len, dynamic_max = divisor, False

    feature_keys = [key for key, _ in self._loss_inputs]
    state_keys = []
    for idx in range(len(self._state_inputs)):
        state_keys.append("state_in_{}".format(idx))

    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        [batch[key] for key in feature_keys],
        [batch[key] for key in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max)

    feed_dict = {}
    for key, sequence in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[key]] = sequence
    for key, state in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[key]] = state
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        padded_summary = summarize({
            "features": feature_sequences,
            "initial_states": initial_states,
            "seq_lens": seq_lens,
            "max_seq_len": max_seq_len,
        })
        logger.info("Padded input for RNN:\n\n{}\n".format(padded_summary))
    return feed_dict
|
||||
|
||||
|
||||
@DeveloperAPI
class LearningRateSchedule(object):
    """Mixin for TFPolicyGraph that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        """Create the lr variable and the schedule that drives it.

        Arguments:
            lr: initial value of the "lr" variable; also the constant value
                used when no schedule is given.
            lr_schedule: None, or a schedule passed to PiecewiseSchedule.
                Its last entry's last element is used as the value outside
                the schedule.
        """
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is not None:
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=lr_schedule[-1][-1])
        else:
            self.lr_schedule = ConstantSchedule(lr)

    @override(PolicyGraph)
    def on_global_var_update(self, global_vars):
        """Load the scheduled lr for global_vars["timestep"] into cur_lr."""
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        self.cur_lr.load(
            self.lr_schedule.value(global_vars["timestep"]),
            session=self._sess)

    @override(TFPolicyGraph)
    def optimizer(self):
        """Return an Adam optimizer that reads its lr from self.cur_lr."""
        adam = tf.train.AdamOptimizer(self.cur_lr)
        return adam
|
||||
# Backwards-compatibility alias: the old TFPolicyGraph name now refers to
# TFPolicy, wrapped by renamed_class() (defined elsewhere; presumably it
# flags use of the old name -- confirm against its definition).
TFPolicyGraph = renamed_class(TFPolicy, old_name="TFPolicyGraph")
|
||||
|
||||
@@ -2,9 +2,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.evaluation.dynamic_tf_policy_graph import DynamicTFPolicyGraph
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ def build_tf_policy(name,
|
||||
"""Helper function for creating a dynamic tf policy at runtime.
|
||||
|
||||
Arguments:
|
||||
name (str): name of the graph (e.g., "PPOPolicy")
|
||||
name (str): name of the policy (e.g., "PPOTFPolicy")
|
||||
loss_fn (func): function that returns a loss tensor the policy,
|
||||
and dict of experience tensor placeholders
|
||||
get_default_config (func): optional function that returns the default
|
||||
@@ -39,7 +39,7 @@ def build_tf_policy(name,
|
||||
extra_action_fetches_fn (func): optional function that returns
|
||||
a dict of TF fetches given the policy object
|
||||
postprocess_fn (func): optional experience postprocessing function
|
||||
that takes the same args as PolicyGraph.postprocess_trajectory()
|
||||
that takes the same args as Policy.postprocess_trajectory()
|
||||
optimizer_fn (func): optional function that returns a tf.Optimizer
|
||||
given the policy and config
|
||||
gradients_fn (func): optional function that returns a list of gradients
|
||||
@@ -57,18 +57,18 @@ def build_tf_policy(name,
|
||||
arguments
|
||||
mixins (list): list of any class mixins for the returned policy class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the DynamicTFPolicyGraph class
|
||||
precedence than the DynamicTFPolicy class
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
the divisibility requirement for sample batches
|
||||
|
||||
Returns:
|
||||
a DynamicTFPolicyGraph instance that uses the specified args
|
||||
a DynamicTFPolicy instance that uses the specified args
|
||||
"""
|
||||
|
||||
if not name.endswith("TFPolicy"):
|
||||
raise ValueError("Name should match *TFPolicy", name)
|
||||
|
||||
base = DynamicTFPolicyGraph
|
||||
base = DynamicTFPolicy
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
@@ -76,7 +76,7 @@ def build_tf_policy(name,
|
||||
|
||||
base = new_base
|
||||
|
||||
class graph_cls(base):
|
||||
class policy_cls(base):
|
||||
def __init__(self,
|
||||
obs_space,
|
||||
action_space,
|
||||
@@ -97,7 +97,7 @@ def build_tf_policy(name,
|
||||
else:
|
||||
self._extra_action_fetches = extra_action_fetches_fn(self)
|
||||
|
||||
DynamicTFPolicyGraph.__init__(
|
||||
DynamicTFPolicy.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
@@ -111,7 +111,7 @@ def build_tf_policy(name,
|
||||
if after_init:
|
||||
after_init(self, obs_space, action_space, config)
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -121,26 +121,26 @@ def build_tf_policy(name,
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def optimizer(self):
|
||||
if optimizer_fn:
|
||||
return optimizer_fn(self, self.config)
|
||||
else:
|
||||
return TFPolicyGraph.optimizer(self)
|
||||
return TFPolicy.optimizer(self)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if gradients_fn:
|
||||
return gradients_fn(self, optimizer, loss)
|
||||
else:
|
||||
return TFPolicyGraph.gradients(self, optimizer, loss)
|
||||
return TFPolicy.gradients(self, optimizer, loss)
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicyGraph.extra_compute_action_fetches(self),
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
**self._extra_action_fetches)
|
||||
|
||||
graph_cls.__name__ = name
|
||||
graph_cls.__qualname__ = name
|
||||
return graph_cls
|
||||
policy_cls.__name__ = name
|
||||
policy_cls.__qualname__ = name
|
||||
return policy_cls
|
||||
|
||||
@@ -2,173 +2,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.utils import renamed_class
|
||||
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
|
||||
class TorchPolicyGraph(PolicyGraph):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicyGraph, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        lock (Lock): Lock that must be held around PyTorch ops on this graph.
            This is necessary when using the async sampler.
    """

    def __init__(self, observation_space, action_space, model, loss,
                 action_distribution_cls):
        """Build a policy graph from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy_graph, batch_tensors)
                and returns a single scalar loss.
            action_distribution_cls (ActionDistribution): Class for action
                distribution.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.lock = Lock()
        # A non-empty CUDA_VISIBLE_DEVICES selects the GPU; unset or empty
        # falls back to the CPU.
        self.device = (torch.device("cuda")
                       if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
                       else torch.device("cpu"))
        self._model = model.to(self.device)
        self._loss = loss
        # optimizer() reads self._model, so it must come after the model.
        self._optimizer = self.optimizer()
        self._action_dist_cls = action_distribution_cls

    @override(PolicyGraph)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample actions for a batch of observations without gradients.

        Only obs_batch and state_batches are used here.

        Returns:
            Tuple of (actions as numpy array, list of state outputs as numpy
            arrays, extra_action_out(model_out)).
        """
        with self.lock:
            with torch.no_grad():
                ob = torch.from_numpy(np.array(obs_batch)) \
                    .float().to(self.device)
                model_out = self._model({"obs": ob}, state_batches)
                # The model output is unpacked as four items; only the
                # first (logits) and last (state) are used here.
                logits, _, _, state = model_out
                action_dist = self._action_dist_cls(logits)
                actions = action_dist.sample()
                return (actions.cpu().numpy(),
                        [h.cpu().numpy() for h in state],
                        self.extra_action_out(model_out))

    @override(PolicyGraph)
    def learn_on_batch(self, postprocessed_batch):
        """Compute the loss on a batch and take one optimizer step.

        Returns:
            dict with the grad info (extra_grad_info() updated with
            extra_grad_process()) under LEARNER_STATS_KEY.
        """
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            # Hook runs after backward() and before step().
            grad_process_info = self.extra_grad_process()
            self._optimizer.step()

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return {LEARNER_STATS_KEY: grad_info}

    @override(PolicyGraph)
    def compute_gradients(self, postprocessed_batch):
        """Compute the loss on a batch and return the parameter gradients.

        Returns:
            Tuple of (list with one numpy gradient or None per model
            parameter, dict with the grad info under LEARNER_STATS_KEY).
        """
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()

            # Note that return values are just references;
            # calling zero_grad will modify the values
            grads = []
            for p in self._model.parameters():
                if p.grad is not None:
                    grads.append(p.grad.data.cpu().numpy())
                else:
                    grads.append(None)

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return grads, {LEARNER_STATS_KEY: grad_info}

    @override(PolicyGraph)
    def apply_gradients(self, gradients):
        """Assign the given gradients to the parameters and step once.

        Entries that are None leave the parameter's current grad untouched.
        """
        with self.lock:
            for g, p in zip(gradients, self._model.parameters()):
                if g is not None:
                    p.grad = torch.from_numpy(g).to(self.device)
            self._optimizer.step()

    @override(PolicyGraph)
    def get_weights(self):
        """Return the model state dict with every value moved to CPU."""
        with self.lock:
            return {k: v.cpu() for k, v in self._model.state_dict().items()}

    @override(PolicyGraph)
    def set_weights(self, weights):
        """Load the given state dict into the model."""
        with self.lock:
            self._model.load_state_dict(weights)

    @override(PolicyGraph)
    def get_initial_state(self):
        """Return the model's initial states as numpy arrays."""
        # Move to CPU first, as compute_actions() does for state outputs:
        # .numpy() raises for tensors that live on a CUDA device, and
        # .cpu() returns the tensor itself when it is already on the CPU.
        return [s.cpu().numpy() for s in self._model.state_init()]

    def extra_grad_process(self):
        """Allow subclass to do extra processing on gradients and
           return processing info."""
        return {}

    def extra_action_out(self, model_out):
        """Returns dict of extra info to include in experience batch.

        Arguments:
            model_out (list): Outputs of the policy model module."""
        return {}

    def extra_grad_info(self, batch_tensors):
        """Return dict of extra grad info."""

        return {}

    def optimizer(self):
        """Custom PyTorch optimizer to use.

        Adam over the model parameters; lr comes from config["lr"] when
        this object has a `config` attribute.
        """
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self._model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self._model.parameters())

    def _lazy_tensor_dict(self, postprocessed_batch):
        """Wrap the batch so reads go through a numpy-to-tensor interceptor.

        The interceptor converts an array with torch.from_numpy() and moves
        it to self.device.
        """
        batch_tensors = UsageTrackingDict(postprocessed_batch)
        batch_tensors.set_get_interceptor(
            lambda arr: torch.from_numpy(arr).to(self.device))
        return batch_tensors
|
||||
# Backwards-compatibility alias: the old TorchPolicyGraph name now refers to
# TorchPolicy, wrapped by renamed_class() (defined elsewhere; presumably it
# flags use of the old name -- confirm against its definition).
TorchPolicyGraph = renamed_class(TorchPolicy, old_name="TorchPolicyGraph")
|
||||
|
||||
@@ -209,7 +209,7 @@ if __name__ == "__main__":
|
||||
"log_level": "INFO",
|
||||
"entropy_coeff": 0.01,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"high_level_policy": (None, maze.observation_space,
|
||||
Discrete(4), {
|
||||
"gamma": 0.9
|
||||
|
||||
@@ -6,7 +6,7 @@ from __future__ import print_function
|
||||
Control the number of agents and policies via --num-agents and --num-policies.
|
||||
|
||||
This works with hundreds of agents and policies, but note that initializing
|
||||
many TF policy graphs will take some time.
|
||||
many TF policies will take some time.
|
||||
|
||||
Also, TF evals might slow down with large numbers of policies. To debug TF
|
||||
execution, set the TF_TIMELINE_DIR environment variable.
|
||||
@@ -90,12 +90,12 @@ if __name__ == "__main__":
|
||||
}
|
||||
return (None, obs_space, act_space, config)
|
||||
|
||||
# Setup PPO with an ensemble of `num_policies` different policy graphs
|
||||
policy_graphs = {
|
||||
# Setup PPO with an ensemble of `num_policies` different policies
|
||||
policies = {
|
||||
"policy_{}".format(i): gen_policy(i)
|
||||
for i in range(args.num_policies)
|
||||
}
|
||||
policy_ids = list(policy_graphs.keys())
|
||||
policy_ids = list(policies.keys())
|
||||
|
||||
tune.run(
|
||||
"PPO",
|
||||
@@ -105,7 +105,7 @@ if __name__ == "__main__":
|
||||
"log_level": "DEBUG",
|
||||
"num_sgd_iter": 10,
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policies": policies,
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
|
||||
@@ -22,7 +22,7 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.evaluation import PolicyGraph
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
@@ -30,7 +30,7 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--num-iters", type=int, default=20)
|
||||
|
||||
|
||||
class RandomPolicy(PolicyGraph):
|
||||
class RandomPolicy(Policy):
|
||||
"""Hand-coded policy that returns random actions."""
|
||||
|
||||
def compute_actions(self,
|
||||
@@ -65,7 +65,7 @@ if __name__ == "__main__":
|
||||
config={
|
||||
"env": "multi_cartpole",
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"pg_policy": (None, obs_space, act_space, {}),
|
||||
"random": (RandomPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
|
||||
@@ -16,9 +16,9 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.dqn.dqn import DQNTrainer
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
|
||||
from ray.rllib.tests.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune.registry import register_env
|
||||
@@ -36,11 +36,11 @@ if __name__ == "__main__":
|
||||
obs_space = single_env.observation_space
|
||||
act_space = single_env.action_space
|
||||
|
||||
# You can also have multiple policy graphs per trainer, but here we just
|
||||
# You can also have multiple policies per trainer, but here we just
|
||||
# show one each for PPO and DQN.
|
||||
policy_graphs = {
|
||||
policies = {
|
||||
"ppo_policy": (PPOTFPolicy, obs_space, act_space, {}),
|
||||
"dqn_policy": (DQNPolicyGraph, obs_space, act_space, {}),
|
||||
"dqn_policy": (DQNTFPolicy, obs_space, act_space, {}),
|
||||
}
|
||||
|
||||
def policy_mapping_fn(agent_id):
|
||||
@@ -53,7 +53,7 @@ if __name__ == "__main__":
|
||||
env="multi_cartpole",
|
||||
config={
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policies": policies,
|
||||
"policy_mapping_fn": policy_mapping_fn,
|
||||
"policies_to_train": ["ppo_policy"],
|
||||
},
|
||||
@@ -66,7 +66,7 @@ if __name__ == "__main__":
|
||||
env="multi_cartpole",
|
||||
config={
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policies": policies,
|
||||
"policy_mapping_fn": policy_mapping_fn,
|
||||
"policies_to_train": ["dqn_policy"],
|
||||
},
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Example of using policy evaluator classes directly to implement training.
|
||||
|
||||
Instead of using the built-in Trainer classes provided by RLlib, here we define
|
||||
a custom PolicyGraph class and manually coordinate distributed sample
|
||||
a custom Policy class and manually coordinate distributed sample
|
||||
collection and policy optimization.
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,8 @@ import gym
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.evaluation import PolicyGraph, PolicyEvaluator, SampleBatch
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.evaluation import PolicyEvaluator, SampleBatch
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -23,15 +24,15 @@ parser.add_argument("--num-iters", type=int, default=20)
|
||||
parser.add_argument("--num-workers", type=int, default=2)
|
||||
|
||||
|
||||
class CustomPolicy(PolicyGraph):
|
||||
"""Example of a custom policy graph written from scratch.
|
||||
class CustomPolicy(Policy):
|
||||
"""Example of a custom policy written from scratch.
|
||||
|
||||
You might find it more convenient to extend TF/TorchPolicyGraph instead
|
||||
You might find it more convenient to extend TF/TorchPolicy instead
|
||||
for a real policy.
|
||||
"""
|
||||
|
||||
def __init__(self, observation_space, action_space, config):
|
||||
PolicyGraph.__init__(self, observation_space, action_space, config)
|
||||
Policy.__init__(self, observation_space, action_space, config)
|
||||
# example parameter
|
||||
self.w = 1.0
|
||||
|
||||
|
||||
+6
-6
@@ -4,19 +4,19 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
|
||||
|
||||
def _sample(probs):
|
||||
return [np.random.choice(len(pr), p=pr) for pr in probs]
|
||||
|
||||
|
||||
class KerasPolicyGraph(PolicyGraph):
|
||||
"""Initialize the Keras Policy Graph.
|
||||
class KerasPolicy(Policy):
|
||||
"""Initialize the Keras Policy.
|
||||
|
||||
This is a Policy Graph used for models with actor and critics.
|
||||
This is a Policy used for models with actor and critics.
|
||||
Note: This class is built for specific usage of Actor-Critic models,
|
||||
and is less general compared to TFPolicyGraph and TorchPolicyGraphs.
|
||||
and is less general compared to TFPolicy and TorchPolicies.
|
||||
|
||||
Args:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
@@ -32,7 +32,7 @@ class KerasPolicyGraph(PolicyGraph):
|
||||
config,
|
||||
actor=None,
|
||||
critic=None):
|
||||
PolicyGraph.__init__(self, observation_space, action_space, config)
|
||||
Policy.__init__(self, observation_space, action_space, config)
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.models = [self.actor, self.critic]
|
||||
@@ -161,7 +161,7 @@ class Model(object):
|
||||
You can find an runnable example in examples/custom_loss.py.
|
||||
|
||||
Arguments:
|
||||
policy_loss (Tensor): scalar policy loss from the policy graph.
|
||||
policy_loss (Tensor): scalar policy loss from the policy.
|
||||
loss_inputs (dict): map of input placeholders for rollout data.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
import numpy as np
|
||||
import threading
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ except ImportError:
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
|
||||
@@ -15,7 +15,7 @@ try:
|
||||
except ImportError:
|
||||
smart_open = None
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.output_writer import OutputWriter
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -23,7 +23,7 @@ class OffPolicyEstimator(object):
|
||||
"""Creates an off-policy estimator.
|
||||
|
||||
Arguments:
|
||||
policy (PolicyGraph): Policy graph to evaluate.
|
||||
policy (Policy): Policy to evaluate.
|
||||
gamma (float): Discount of the MDP.
|
||||
"""
|
||||
self.policy = policy
|
||||
@@ -71,7 +71,7 @@ class OffPolicyEstimator(object):
|
||||
raise ValueError(
|
||||
"Off-policy estimation is not possible unless the policy "
|
||||
"returns action probabilities when computing actions (i.e., "
|
||||
"the 'action_prob' key is output by the policy graph). You "
|
||||
"the 'action_prob' key is output by the policy). You "
|
||||
"can set `input_evaluation: []` to resolve this.")
|
||||
return info["action_prob"]
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import math
|
||||
from six.moves import queue
|
||||
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.optimizers.aso_learner import LearnerThread
|
||||
from ray.rllib.optimizers.aso_minibatch_buffer import MinibatchBuffer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
|
||||
@@ -17,7 +17,7 @@ from six.moves import queue
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.replay_buffer import PrioritizedReplayBuffer
|
||||
|
||||
@@ -48,7 +48,7 @@ class LocalSyncParallelOptimizer(object):
|
||||
processed. If this is larger than the total data size, it will be
|
||||
clipped.
|
||||
build_graph: Function that takes the specified inputs and returns a
|
||||
TF Policy Graph instance.
|
||||
TF Policy instance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -9,14 +9,14 @@ from collections import defaultdict
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.optimizers.rollout import collect_samples, \
|
||||
collect_samples_straggler_mitigation
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -34,9 +34,9 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
details, see `multi_gpu_impl.LocalSyncParallelOptimizer`.
|
||||
|
||||
This optimizer is Tensorflow-specific and require the underlying
|
||||
PolicyGraph to be a TFPolicyGraph instance that support `.copy()`.
|
||||
Policy to be a TFPolicy instance that support `.copy()`.
|
||||
|
||||
Note that all replicas of the TFPolicyGraph will merge their
|
||||
Note that all replicas of the TFPolicy will merge their
|
||||
extra_compute_grad and apply_grad feed_dicts and fetches. This
|
||||
may result in unexpected behavior.
|
||||
"""
|
||||
@@ -83,7 +83,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
self.local_evaluator.foreach_trainable_policy(lambda p, i: (i, p)))
|
||||
logger.debug("Policies to train: {}".format(self.policies))
|
||||
for policy_id, policy in self.policies.items():
|
||||
if not isinstance(policy, TFPolicyGraph):
|
||||
if not isinstance(policy, TFPolicy):
|
||||
raise ValueError(
|
||||
"Only TF policies are supported with multi-GPU. Try using "
|
||||
"the simple optimizer instead.")
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import print_function
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.memory import ray_get_and_free
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,7 @@ import random
|
||||
import ray
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
@@ -11,7 +11,7 @@ from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \
|
||||
PrioritizedReplayBuffer
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack_if_needed
|
||||
|
||||
@@ -6,7 +6,7 @@ import ray
|
||||
import logging
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.filter import RunningStat
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.policy.torch_policy_template import build_torch_policy
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
|
||||
__all__ = [
|
||||
"Policy",
|
||||
"TFPolicy",
|
||||
"TorchPolicy",
|
||||
"build_tf_policy",
|
||||
"build_torch_policy",
|
||||
]
|
||||
+14
-14
@@ -6,9 +6,9 @@ from collections import OrderedDict
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils import try_import_tf
|
||||
@@ -20,8 +20,8 @@ tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
"""A TFPolicyGraph that auto-defines placeholders dynamically at runtime.
|
||||
class DynamicTFPolicy(TFPolicy):
|
||||
"""A TFPolicy that auto-defines placeholders dynamically at runtime.
|
||||
|
||||
Initialization of this class occurs in two phases.
|
||||
* Phase 1: the model is created and model variables are initialized.
|
||||
@@ -42,7 +42,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
make_action_sampler=None,
|
||||
existing_inputs=None,
|
||||
get_batch_divisibility_req=None):
|
||||
"""Initialize a dynamic TF policy graph.
|
||||
"""Initialize a dynamic TF policy.
|
||||
|
||||
Arguments:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
@@ -51,16 +51,16 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
loss_fn (func): function that returns a loss tensor the policy
|
||||
graph, and dict of experience tensor placeholders
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy graph and batch input tensors
|
||||
TF fetches given the policy and batch input tensors
|
||||
grad_stats_fn (func): optional function that returns a dict of
|
||||
TF fetches given the policy graph and loss gradient tensors
|
||||
TF fetches given the policy and loss gradient tensors
|
||||
before_loss_init (func): optional function to run prior to loss
|
||||
init that takes the same arguments as __init__
|
||||
make_action_sampler (func): optional function that returns a
|
||||
tuple of action and action prob tensors. The function takes
|
||||
(policy, input_dict, obs_space, action_space, config) as its
|
||||
arguments
|
||||
existing_inputs (OrderedDict): when copying a policy graph, this
|
||||
existing_inputs (OrderedDict): when copying a policy, this
|
||||
specifies an existing dict of placeholders to use instead of
|
||||
defining new ones
|
||||
get_batch_divisibility_req (func): optional function that returns
|
||||
@@ -134,7 +134,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
batch_divisibility_req = get_batch_divisibility_req(self)
|
||||
else:
|
||||
batch_divisibility_req = 1
|
||||
TFPolicyGraph.__init__(
|
||||
TFPolicy.__init__(
|
||||
self,
|
||||
obs_space,
|
||||
action_space,
|
||||
@@ -158,7 +158,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
if not existing_inputs:
|
||||
self._initialize_loss()
|
||||
|
||||
@override(TFPolicyGraph)
|
||||
@override(TFPolicy)
|
||||
def copy(self, existing_inputs):
|
||||
"""Creates a copy of self using existing input placeholders."""
|
||||
|
||||
@@ -194,7 +194,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
if instance._stats_fn:
|
||||
instance._stats_fetches.update(
|
||||
instance._stats_fn(instance, input_dict))
|
||||
TFPolicyGraph._initialize_loss(
|
||||
TFPolicy._initialize_loss(
|
||||
instance, loss, [(k, existing_inputs[i])
|
||||
for i, (k, _) in enumerate(self._loss_inputs)])
|
||||
if instance._grad_stats_fn:
|
||||
@@ -202,7 +202,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
instance._grad_stats_fn(instance, instance._grads))
|
||||
return instance
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def get_initial_state(self):
|
||||
if self.model:
|
||||
return self.model.state_init
|
||||
@@ -269,7 +269,7 @@ class DynamicTFPolicyGraph(TFPolicyGraph):
|
||||
self._stats_fetches.update(self._stats_fn(self, batch_tensors))
|
||||
for k in sorted(batch_tensors.accessed_keys):
|
||||
loss_inputs.append((k, batch_tensors[k]))
|
||||
TFPolicyGraph._initialize_loss(self, loss, loss_inputs)
|
||||
TFPolicy._initialize_loss(self, loss, loss_inputs)
|
||||
if self._grad_stats_fn:
|
||||
self._stats_fetches.update(self._grad_stats_fn(self, self._grads))
|
||||
self._sess.run(tf.global_variables_initializer())
|
||||
@@ -0,0 +1,291 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
|
||||
# By convention, metrics from optimizing the loss can be reported in the
|
||||
# `grad_info` dict returned by learn_on_batch() / compute_grads() via this key.
|
||||
LEARNER_STATS_KEY = "learner_stats"
|
||||
|
||||
|
||||
@DeveloperAPI
class Policy(object):
    """An agent policy and loss, i.e., a TFPolicy or other subclass.

    This object defines how to act in the environment, and also losses used to
    improve the policy based on its experiences. Note that both policy and
    loss are defined together for convenience, though the policy itself is
    logically separate.

    All policies can directly extend Policy, however TensorFlow users may
    find TFPolicy simpler to implement. TFPolicy also enables RLlib
    to apply TensorFlow-specific optimizations such as fusing multiple policy
    graphs and multi-GPU support.

    Attributes:
        observation_space (gym.Space): Observation space of the policy.
        action_space (gym.Space): Action space of the policy.
    """

    @DeveloperAPI
    def __init__(self, observation_space, action_space, config):
        """Initialize the policy.

        This is the standard constructor for policies. The policy
        class you pass into PolicyEvaluator will be constructed with
        these arguments.

        Args:
            observation_space (gym.Space): Observation space of the policy.
            action_space (gym.Space): Action space of the policy.
            config (dict): Policy-specific configuration data. The base
                class does not store it; subclasses that need it must keep
                their own reference.
        """

        self.observation_space = observation_space
        self.action_space = action_space

    @DeveloperAPI
    def compute_actions(self,
                        obs_batch,
                        state_batches,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Compute actions for the current policy.

        Arguments:
            obs_batch (np.ndarray): batch of observations
            state_batches (list): list of RNN state input batches, if any
            prev_action_batch (np.ndarray): batch of previous action values
            prev_reward_batch (np.ndarray): batch of previous rewards
            info_batch (info): batch of info objects
            episodes (list): MultiAgentEpisode for each obs in obs_batch.
                This provides access to all of the internal episode state,
                which may be useful for model-based or multiagent algorithms.
            kwargs: forward compatibility placeholder

        Returns:
            actions (np.ndarray): batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (list): list of RNN state output batches, if any, with
                shape like [STATE_SIZE, BATCH_SIZE].
            info (dict): dictionary of extra feature batches, if any, with
                shape like {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        raise NotImplementedError

    @DeveloperAPI
    def compute_single_action(self,
                              obs,
                              state,
                              prev_action=None,
                              prev_reward=None,
                              info=None,
                              episode=None,
                              clip_actions=False,
                              **kwargs):
        """Unbatched version of compute_actions.

        Each input is wrapped into a batch of size one, passed to
        compute_actions(), and the single result is unwrapped again.

        Arguments:
            obs (obj): single observation
            state (list): list of RNN state inputs, if any
            prev_action (obj): previous action value, if any
            prev_reward (int): previous reward, if any
            info (dict): info object, if any
            episode (MultiAgentEpisode): this provides access to all of the
                internal episode state, which may be useful for model-based or
                multi-agent algorithms.
            clip_actions (bool): should the action be clipped
            kwargs: forward compatibility placeholder; accepted but not
                forwarded to compute_actions().

        Returns:
            actions (obj): single action
            state_outs (list): list of RNN state outputs, if any
            info (dict): dictionary of extra features, if any
        """

        # Optional inputs stay None when absent so that compute_actions()
        # sees its own defaults instead of a batch containing None.
        prev_action_batch = None
        prev_reward_batch = None
        info_batch = None
        episodes = None
        if prev_action is not None:
            prev_action_batch = [prev_action]
        if prev_reward is not None:
            prev_reward_batch = [prev_reward]
        if info is not None:
            info_batch = [info]
        if episode is not None:
            episodes = [episode]

        # The result is bound to `extra` rather than `info` so the `info`
        # argument is not shadowed by the returned feature dict.
        [action], state_out, extra = self.compute_actions(
            [obs], [[s] for s in state],
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes)
        if clip_actions:
            action = clip_action(action, self.action_space)

        # Strip the leading batch dimension from every output.
        return action, [s[0] for s in state_out], \
            {k: v[0] for k, v in extra.items()}

    @DeveloperAPI
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Implements algorithm-specific trajectory postprocessing.

        This will be called on each trajectory fragment computed during policy
        evaluation. Each fragment is guaranteed to be only from one episode.

        The base implementation returns the batch unchanged.

        Arguments:
            sample_batch (SampleBatch): batch of experiences for the policy,
                which will contain at most one episode trajectory.
            other_agent_batches (dict): In a multi-agent env, this contains a
                mapping of agent ids to (policy, agent_batch) tuples
                containing the policy and experiences of the other agent.
            episode (MultiAgentEpisode): this provides access to all of the
                internal episode state, which may be useful for model-based or
                multi-agent algorithms.

        Returns:
            SampleBatch: postprocessed sample batch.
        """
        return sample_batch

    @DeveloperAPI
    def learn_on_batch(self, samples):
        """Fused compute gradients and apply gradients call.

        Either this or the combination of compute/apply grads must be
        implemented by subclasses.

        Returns:
            grad_info: dictionary of extra metadata from compute_gradients().

        Examples:
            >>> batch = ev.sample()
            >>> ev.learn_on_batch(batch)
        """

        grads, grad_info = self.compute_gradients(samples)
        self.apply_gradients(grads)
        return grad_info

    @DeveloperAPI
    def compute_gradients(self, postprocessed_batch):
        """Computes gradients against a batch of experiences.

        Either this or learn_on_batch() must be implemented by subclasses.

        Returns:
            grads (list): List of gradient output values
            info (dict): Extra policy-specific values
        """
        raise NotImplementedError

    @DeveloperAPI
    def apply_gradients(self, gradients):
        """Applies previously computed gradients.

        Either this or learn_on_batch() must be implemented by subclasses.
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_weights(self):
        """Returns model weights.

        Returns:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def set_weights(self, weights):
        """Sets model weights.

        Arguments:
            weights (obj): Serializable copy or view of model weights
        """
        raise NotImplementedError

    @DeveloperAPI
    def get_initial_state(self):
        """Returns initial RNN state for the current policy.

        The base implementation returns an empty list.
        """
        return []

    @DeveloperAPI
    def get_state(self):
        """Saves all local state.

        The base implementation delegates to get_weights().

        Returns:
            state (obj): Serialized local state.
        """
        return self.get_weights()

    @DeveloperAPI
    def set_state(self, state):
        """Restores all local state.

        The base implementation delegates to set_weights().

        Arguments:
            state (obj): Serialized local state.
        """
        self.set_weights(state)

    @DeveloperAPI
    def on_global_var_update(self, global_vars):
        """Called on an update to global vars.

        The base implementation does nothing.

        Arguments:
            global_vars (dict): Global variables broadcast from the driver.
        """
        pass

    @DeveloperAPI
    def export_model(self, export_dir):
        """Export Policy to local directory for serving.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError

    @DeveloperAPI
    def export_checkpoint(self, export_dir):
        """Export Policy checkpoint to local directory.

        Arguments:
            export_dir (str): Local writable directory.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
def clip_action(action, space):
    """Clip a single action to the bounds of the given action space.

    Box spaces are clipped element-wise to [space.low, space.high]. Tuple
    spaces are clipped recursively, component by component. Actions for any
    other space type are returned unchanged.

    Arguments:
        action: Single action.
        space: Action space the actions should be present in.

    Returns:
        The clipped action. For Tuple spaces this is a list holding the
        clipped component actions.

    Raises:
        ValueError: if `space` is a Tuple space but `action` is not a tuple
            or list.
    """

    if isinstance(space, gym.spaces.Box):
        return np.clip(action, space.low, space.high)
    if isinstance(space, gym.spaces.Tuple):
        # isinstance (rather than an exact type() match) also accepts
        # tuple/list subclasses such as namedtuples.
        if not isinstance(action, (tuple, list)):
            raise ValueError("Expected tuple space for actions {}: {}".format(
                action, space))
        return [clip_action(a, s) for a, s in zip(action, space.spaces)]
    return action
|
||||
@@ -0,0 +1,296 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI, DeveloperAPI
|
||||
from ray.rllib.utils.compression import pack, unpack, is_compressed
|
||||
from ray.rllib.utils.memory import concat_aligned
|
||||
|
||||
# Default policy id for single agent environments
|
||||
DEFAULT_POLICY_ID = "default_policy"
|
||||
|
||||
|
||||
@PublicAPI
class MultiAgentBatch(object):
    """Experiences gathered from several policies acting in one environment.

    Attributes:
        policy_batches (dict): Mapping from policy id to a normal SampleBatch
            of experiences. Note that these batches may be of different length.
        count (int): The number of timesteps in the environment this batch
            contains. This will be less than the number of transitions this
            batch contains across all policies in total.
    """

    @PublicAPI
    def __init__(self, policy_batches, count):
        self.policy_batches = policy_batches
        self.count = count

    @staticmethod
    @PublicAPI
    def wrap_as_needed(batches, count):
        """Return the bare default-policy batch, or wrap `batches`.

        When `batches` holds exactly one entry and it is keyed by
        DEFAULT_POLICY_ID, that entry is returned directly; otherwise a
        MultiAgentBatch over `batches` is built.
        """
        only_default = len(batches) == 1 and DEFAULT_POLICY_ID in batches
        if only_default:
            return batches[DEFAULT_POLICY_ID]
        return MultiAgentBatch(batches, count)

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Merge several MultiAgentBatch objects into a single one.

        Per-policy batches are grouped by policy id and concatenated with
        SampleBatch.concat_samples(); the `count` values are summed.
        """
        grouped = collections.defaultdict(list)
        env_steps = 0
        for sample in samples:
            assert isinstance(sample, MultiAgentBatch)
            for pid, pbatch in sample.policy_batches.items():
                grouped[pid].append(pbatch)
            env_steps += sample.count
        merged = {
            pid: SampleBatch.concat_samples(pbatches)
            for pid, pbatches in grouped.items()
        }
        return MultiAgentBatch(merged, env_steps)

    @PublicAPI
    def copy(self):
        """Return a new MultiAgentBatch holding a copy of each policy batch."""
        duplicated = {}
        for pid, pbatch in self.policy_batches.items():
            duplicated[pid] = pbatch.copy()
        return MultiAgentBatch(duplicated, self.count)

    @PublicAPI
    def total(self):
        """Return the sum of `count` over all per-policy batches."""
        return sum(pbatch.count for pbatch in self.policy_batches.values())

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Call compress() with these arguments on every policy batch."""
        for pbatch in self.policy_batches.values():
            pbatch.compress(bulk=bulk, columns=columns)

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Call decompress_if_needed(columns) on every policy batch."""
        for pbatch in self.policy_batches.values():
            pbatch.decompress_if_needed(columns)

    def __str__(self):
        batches_text = str(self.policy_batches)
        return "MultiAgentBatch({}, count={})".format(batches_text, self.count)

    def __repr__(self):
        batches_text = str(self.policy_batches)
        return "MultiAgentBatch({}, count={})".format(batches_text, self.count)
|
||||
|
||||
|
||||
@PublicAPI
class SampleBatch(object):
    """Wrapper around a dictionary with string keys and array-like values.

    For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
    samples, each with an "obs" and "reward" attribute.

    Attributes:
        data (dict): Mapping from column name to a numpy array of values.
        count (int): Number of rows, i.e. the common length of all columns.
    """

    # Outputs from interacting with the environment
    CUR_OBS = "obs"
    NEXT_OBS = "new_obs"
    ACTIONS = "actions"
    REWARDS = "rewards"
    PREV_ACTIONS = "prev_actions"
    PREV_REWARDS = "prev_rewards"
    DONES = "dones"
    INFOS = "infos"

    # Uniquely identifies an episode
    EPS_ID = "eps_id"

    # Uniquely identifies a sample batch. This is important to distinguish RNN
    # sequences from the same episode when multiple sample batches are
    # concatenated (fusing sequences across batches can be unsafe).
    UNROLL_ID = "unroll_id"

    # Uniquely identifies an agent within an episode
    AGENT_INDEX = "agent_index"

    # Value function predictions emitted by the behaviour policy
    VF_PREDS = "vf_preds"

    @PublicAPI
    def __init__(self, *args, **kwargs):
        """Constructs a sample batch (same params as dict constructor).

        Every column is converted to a numpy array.

        Raises:
            ValueError: if no columns are given.
            AssertionError: if a key is not a string or the columns differ
                in length.
        """

        self.data = dict(*args, **kwargs)
        lengths = []
        # Iterate over a snapshot because values are replaced in the loop.
        for k, v in self.data.copy().items():
            assert isinstance(k, six.string_types), self
            lengths.append(len(v))
            # np.asarray only copies when required. The previous
            # np.array(v, copy=False) raises ValueError on NumPy >= 2.0
            # whenever a copy is needed (e.g. for plain Python lists).
            self.data[k] = np.asarray(v)
        if not lengths:
            raise ValueError("Empty sample batch")
        assert len(set(lengths)) == 1, "data columns must be same length"
        self.count = lengths[0]

    @staticmethod
    @PublicAPI
    def concat_samples(samples):
        """Concatenate a list of batches into a single batch.

        If the first element is a MultiAgentBatch the call is delegated to
        MultiAgentBatch.concat_samples(). Otherwise batches with a count of
        zero are dropped and the remaining ones are concatenated column by
        column, using the columns of the first remaining batch.
        """
        if isinstance(samples[0], MultiAgentBatch):
            return MultiAgentBatch.concat_samples(samples)
        samples = [s for s in samples if s.count > 0]
        out = {
            k: concat_aligned([s[k] for s in samples])
            for k in samples[0].keys()
        }
        return SampleBatch(out)

    @PublicAPI
    def concat(self, other):
        """Returns a new SampleBatch with each data column concatenated.

        Examples:
            >>> b1 = SampleBatch({"a": [1, 2]})
            >>> b2 = SampleBatch({"a": [3, 4, 5]})
            >>> print(b1.concat(b2))
            {"a": [1, 2, 3, 4, 5]}
        """

        assert self.keys() == other.keys(), "must have same columns"
        out = {k: concat_aligned([self[k], other[k]]) for k in self.keys()}
        return SampleBatch(out)

    @PublicAPI
    def copy(self):
        """Returns a new SampleBatch whose columns are copies of this one's."""
        return SampleBatch(
            {k: np.array(v, copy=True)
             for (k, v) in self.data.items()})

    @PublicAPI
    def rows(self):
        """Returns an iterator over data rows, i.e. dicts with column values.

        Examples:
            >>> batch = SampleBatch({"a": [1, 2, 3], "b": [4, 5, 6]})
            >>> for row in batch.rows():
            ...     print(row)
            {"a": 1, "b": 4}
            {"a": 2, "b": 5}
            {"a": 3, "b": 6}
        """

        for i in range(self.count):
            yield {k: self[k][i] for k in self.keys()}

    @PublicAPI
    def columns(self, keys):
        """Returns a list of just the specified columns.

        Examples:
            >>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
            >>> print(batch.columns(["a", "b"]))
            [[1], [2]]
        """

        return [self[k] for k in keys]

    @PublicAPI
    def shuffle(self):
        """Shuffles the rows of this batch in-place."""

        # One permutation is applied to every column so rows stay aligned.
        permutation = np.random.permutation(self.count)
        for key, val in self.items():
            self[key] = val[permutation]

    @PublicAPI
    def split_by_episode(self):
        """Splits this batch's data by `eps_id`.

        A new slice is started each time the episode id changes from one
        row to the next.

        Returns:
            list of SampleBatch, one per distinct episode.
        """

        eps_ids = self.data[SampleBatch.EPS_ID]
        slices = []
        cur_eps_id = eps_ids[0]
        offset = 0
        for i in range(self.count):
            next_eps_id = eps_ids[i]
            if next_eps_id != cur_eps_id:
                slices.append(self.slice(offset, i))
                offset = i
                cur_eps_id = next_eps_id
        # Flush the trailing episode.
        slices.append(self.slice(offset, self.count))
        # Sanity checks: each slice has one episode id, and no row is lost.
        for s in slices:
            slen = len(set(s[SampleBatch.EPS_ID]))
            assert slen == 1, (s, slen)
        assert sum(s.count for s in slices) == self.count, (slices, self.count)
        return slices

    @PublicAPI
    def slice(self, start, end):
        """Returns a slice of the row data of this batch.

        Arguments:
            start (int): Starting index.
            end (int): Ending index.

        Returns:
            SampleBatch which has a slice of this batch's data.
        """

        return SampleBatch({k: v[start:end] for k, v in self.data.items()})

    @PublicAPI
    def keys(self):
        """Returns the column names of this batch."""
        return self.data.keys()

    @PublicAPI
    def items(self):
        """Returns the (column name, column values) pairs of this batch."""
        return self.data.items()

    @PublicAPI
    def __getitem__(self, key):
        return self.data[key]

    @PublicAPI
    def __setitem__(self, key, item):
        self.data[key] = item

    @DeveloperAPI
    def compress(self, bulk=False, columns=frozenset(["obs", "new_obs"])):
        """Packs the given columns in-place; absent columns are skipped.

        Arguments:
            bulk (bool): If True, each column is packed as a whole with a
                single pack() call; otherwise every row is packed
                separately.
            columns (set): Names of the columns to pack.
        """
        for key in columns:
            if key in self.data:
                if bulk:
                    self.data[key] = pack(self.data[key])
                else:
                    self.data[key] = np.array(
                        [pack(o) for o in self.data[key]])

    @DeveloperAPI
    def decompress_if_needed(self, columns=frozenset(["obs", "new_obs"])):
        """Unpacks in-place any of the given columns that are compressed.

        Both a column packed as a whole and a column packed row by row are
        handled; columns that are absent or not compressed are left as-is.

        Arguments:
            columns (set): Names of the columns to check.
        """
        for key in columns:
            if key in self.data:
                arr = self.data[key]
                if is_compressed(arr):
                    self.data[key] = unpack(arr)
                elif len(arr) > 0 and is_compressed(arr[0]):
                    self.data[key] = np.array(
                        [unpack(o) for o in self.data[key]])

    def __str__(self):
        return "SampleBatch({})".format(str(self.data))

    def __repr__(self):
        return "SampleBatch({})".format(str(self.data))

    def __iter__(self):
        return self.data.__iter__()

    def __contains__(self, x):
        return x in self.data
|
||||
@@ -0,0 +1,513 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import errno
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.experimental.tf_utils
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.lstm import chop_into_sequences
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.debug import log_once, summarize
|
||||
from ray.rllib.utils.schedules import ConstantSchedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.tf_run_builder import TFRunBuilder
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@DeveloperAPI
class TFPolicy(Policy):
    """An agent policy and loss implemented in TensorFlow.

    Extending this class enables RLlib to perform TensorFlow specific
    optimizations on the policy, e.g., parallelization across gpus or
    fusing multiple graphs together in the multi-agent setting.

    Input tensors are typically shaped like [BATCH_SIZE, ...].

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        model (rllib.models.Model): RLlib model used for the policy.

    Examples:
        >>> policy = TFPolicySubclass(
            sess, obs_input, action_sampler, loss, loss_inputs)

        >>> print(policy.compute_actions([1, 0, 2]))
        (array([0, 1, 1]), [], {})

        >>> print(policy.postprocess_trajectory(SampleBatch({...})))
        SampleBatch({"action": ..., "advantages": ..., ...})
    """

    @DeveloperAPI
    def __init__(self,
                 observation_space,
                 action_space,
                 sess,
                 obs_input,
                 action_sampler,
                 loss,
                 loss_inputs,
                 model=None,
                 action_prob=None,
                 state_inputs=None,
                 state_outputs=None,
                 prev_action_input=None,
                 prev_reward_input=None,
                 seq_lens=None,
                 max_seq_len=20,
                 batch_divisibility_req=1,
                 update_ops=None):
        """Initialize the policy.

        Arguments:
            observation_space (gym.Space): Observation space of the env.
            action_space (gym.Space): Action space of the env.
            sess (Session): TensorFlow session to use.
            obs_input (Tensor): input placeholder for observations, of shape
                [BATCH_SIZE, obs...].
            action_sampler (Tensor): Tensor for sampling an action, of shape
                [BATCH_SIZE, action...]
            loss (Tensor): scalar policy loss output tensor. If None, no
                loss / optimizer is built and the learning methods will
                fail their "Loss not initialized" assertion.
            loss_inputs (list): a (name, placeholder) tuple for each loss
                input argument. Each placeholder name must correspond to a
                SampleBatch column key returned by postprocess_trajectory(),
                and has shape [BATCH_SIZE, data...]. These keys will be read
                from postprocessed sample batches and fed into the specified
                placeholders during loss computation.
            model (rllib.models.Model): used to integrate custom losses and
                stats from user-defined RLlib models.
            action_prob (Tensor): probability of the sampled action.
            state_inputs (list): list of RNN state input Tensors.
            state_outputs (list): list of RNN state output Tensors.
            prev_action_input (Tensor): placeholder for previous actions
            prev_reward_input (Tensor): placeholder for previous rewards
            seq_lens (Tensor): placeholder for RNN sequence lengths, of shape
                [NUM_SEQUENCES]. Note that NUM_SEQUENCES << BATCH_SIZE. See
                models/lstm.py for more information.
            max_seq_len (int): max sequence length for LSTM training.
            batch_divisibility_req (int): pad all agent experiences batches to
                multiples of this value. This only has an effect if not using
                a LSTM model.
            update_ops (list): override the batchnorm update ops to run when
                applying gradients. Otherwise we run all update ops found in
                the current variable scope.

        Raises:
            ValueError: if the number of state inputs and outputs differ,
                if the initial state length does not match the number of
                state inputs, or if state inputs are given without seq_lens.
        """

        self.observation_space = observation_space
        self.action_space = action_space
        self.model = model
        self._sess = sess
        self._obs_input = obs_input
        self._prev_action_input = prev_action_input
        self._prev_reward_input = prev_reward_input
        self._sampler = action_sampler
        self._is_training = self._get_is_training_placeholder()
        self._action_prob = action_prob
        self._state_inputs = state_inputs or []
        self._state_outputs = state_outputs or []
        self._seq_lens = seq_lens
        self._max_seq_len = max_seq_len
        self._batch_divisibility_req = batch_divisibility_req
        self._update_ops = update_ops
        # Must exist before _initialize_loss(), which may add model stats.
        self._stats_fetches = {}

        if loss is not None:
            self._initialize_loss(loss, loss_inputs)
        else:
            self._loss = None

        if len(self._state_inputs) != len(self._state_outputs):
            raise ValueError(
                "Number of state input and output tensors must match, got: "
                "{} vs {}".format(self._state_inputs, self._state_outputs))
        if len(self.get_initial_state()) != len(self._state_inputs):
            raise ValueError(
                "Length of initial state must match number of state inputs, "
                "got: {} vs {}".format(self.get_initial_state(),
                                       self._state_inputs))
        if self._state_inputs and self._seq_lens is None:
            raise ValueError(
                "seq_lens tensor must be given if state inputs are defined")

    def _initialize_loss(self, loss, loss_inputs):
        """Build the optimizer, gradients and apply op for the given loss.

        Also registers the RNN state placeholders as "state_in_{i}" loss
        inputs, lets the model (if any) wrap the loss via custom_loss(),
        and finally runs the global variables initializer in the session.

        Arguments:
            loss (Tensor): scalar policy loss output tensor.
            loss_inputs (list): list of (name, placeholder) tuples.
        """
        self._loss_inputs = loss_inputs
        self._loss_input_dict = dict(self._loss_inputs)
        for i, ph in enumerate(self._state_inputs):
            self._loss_input_dict["state_in_{}".format(i)] = ph

        if self.model:
            self._loss = self.model.custom_loss(loss, self._loss_input_dict)
            self._stats_fetches.update({"model": self.model.custom_stats()})
        else:
            self._loss = loss

        self._optimizer = self.optimizer()
        # Drop variables that receive no gradient from the loss.
        self._grads_and_vars = [
            (g, v) for (g, v) in self.gradients(self._optimizer, self._loss)
            if g is not None
        ]
        self._grads = [g for (g, v) in self._grads_and_vars]
        self._variables = ray.experimental.tf_utils.TensorFlowVariables(
            self._loss, self._sess)

        # gather update ops for any batch norm layers
        if not self._update_ops:
            self._update_ops = tf.get_collection(
                tf.GraphKeys.UPDATE_OPS, scope=tf.get_variable_scope().name)
        if self._update_ops:
            logger.debug("Update ops to run on apply gradient: {}".format(
                self._update_ops))
        with tf.control_dependencies(self._update_ops):
            self._apply_op = self.build_apply_op(self._optimizer,
                                                 self._grads_and_vars)

        if log_once("loss_used"):
            logger.debug(
                "These tensors were used in the loss_fn:\n\n{}\n".format(
                    summarize(self._loss_input_dict)))

        self._sess.run(tf.global_variables_initializer())

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Run the action sampler on a batch of observations.

        info_batch, episodes and extra kwargs are accepted but not used
        by this implementation.

        Returns:
            tuple of (actions, list of state outputs, extra fetches dict).
        """
        builder = TFRunBuilder(self._sess, "compute_actions")
        fetches = self._build_compute_actions(builder, obs_batch,
                                              state_batches, prev_action_batch,
                                              prev_reward_batch)
        return builder.get(fetches)

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        """Compute gradients of the loss on a postprocessed batch.

        Returns:
            tuple of (gradient values, grad/stats fetches dict).
        """
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "compute_gradients")
        fetches = self._build_compute_gradients(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    def apply_gradients(self, gradients):
        """Apply previously computed gradient values via the apply op."""
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "apply_gradients")
        fetches = self._build_apply_gradients(builder, gradients)
        builder.get(fetches)

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Run one optimization step on the batch and return its stats."""
        assert self._loss is not None, "Loss not initialized"
        builder = TFRunBuilder(self._sess, "learn_on_batch")
        fetches = self._build_learn_on_batch(builder, postprocessed_batch)
        return builder.get(fetches)

    @override(Policy)
    def get_weights(self):
        """Return the policy variables in flat form."""
        return self._variables.get_flat()

    @override(Policy)
    def set_weights(self, weights):
        """Load flat weights (as returned by get_weights()) into the graph."""
        return self._variables.set_flat(weights)

    @override(Policy)
    def export_model(self, export_dir):
        """Export tensorflow graph to export_dir for serving."""
        with self._sess.graph.as_default():
            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            signature_def_map = self._build_signature_def()
            builder.add_meta_graph_and_variables(
                self._sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map=signature_def_map)
            builder.save()

    @override(Policy)
    def export_checkpoint(self, export_dir, filename_prefix="model"):
        """Export tensorflow checkpoint to export_dir."""
        try:
            os.makedirs(export_dir)
        except OSError as e:
            # ignore error if export dir already exists
            if e.errno != errno.EEXIST:
                raise
        save_path = os.path.join(export_dir, filename_prefix)
        with self._sess.graph.as_default():
            saver = tf.train.Saver()
            saver.save(self._sess, save_path)

    @DeveloperAPI
    def copy(self, existing_inputs):
        """Creates a copy of self using existing input placeholders.

        Optional, only required to work with the multi-GPU optimizer."""
        raise NotImplementedError

    @DeveloperAPI
    def extra_compute_action_feed_dict(self):
        """Extra dict to pass to the compute actions session run."""
        return {}

    @DeveloperAPI
    def extra_compute_action_fetches(self):
        """Extra values to fetch and return from compute_actions().

        By default we only return action probability info (if present).
        """
        if self._action_prob is not None:
            return {"action_prob": self._action_prob}
        else:
            return {}

    @DeveloperAPI
    def extra_compute_grad_feed_dict(self):
        """Extra dict to pass to the compute gradients session run."""
        return {}  # e.g, kl_coeff

    @DeveloperAPI
    def extra_compute_grad_fetches(self):
        """Extra values to fetch and return from compute_gradients()."""
        return {LEARNER_STATS_KEY: {}}  # e.g, stats, td error, etc.

    @DeveloperAPI
    def optimizer(self):
        """TF optimizer to use for policy optimization."""
        if hasattr(self, "config"):
            return tf.train.AdamOptimizer(self.config["lr"])
        else:
            return tf.train.AdamOptimizer()

    @DeveloperAPI
    def gradients(self, optimizer, loss):
        """Override for custom gradient computation."""
        return optimizer.compute_gradients(loss)

    @DeveloperAPI
    def build_apply_op(self, optimizer, grads_and_vars):
        """Override for custom gradient apply computation.

        Arguments:
            optimizer: the optimizer returned by optimizer().
            grads_and_vars (list): (gradient, variable) pairs to apply.
        """

        # Use the argument rather than self._grads_and_vars so that callers
        # (and subclasses delegating here) can pass transformed gradients.
        # specify global_step for TD3 which needs to count the num updates
        return optimizer.apply_gradients(
            grads_and_vars,
            global_step=tf.train.get_or_create_global_step())

    @DeveloperAPI
    def _get_is_training_placeholder(self):
        """Get the placeholder for _is_training, i.e., for batch norm layers.

        This can be called safely before __init__ has run.
        """
        if not hasattr(self, "_is_training"):
            self._is_training = tf.placeholder_with_default(False, ())
        return self._is_training

    def _extra_input_signature_def(self):
        """Extra input signatures to add when exporting tf model.
        Inferred from extra_compute_action_feed_dict()
        """
        feed_dict = self.extra_compute_action_feed_dict()
        return {
            k.name: tf.saved_model.utils.build_tensor_info(k)
            for k in feed_dict
        }

    def _extra_output_signature_def(self):
        """Extra output signatures to add when exporting tf model.
        Inferred from extra_compute_action_fetches()
        """
        fetches = self.extra_compute_action_fetches()
        return {
            k: tf.saved_model.utils.build_tensor_info(v)
            for k, v in fetches.items()
        }

    def _build_signature_def(self):
        """Build signature def map for tensorflow SavedModelBuilder.
        """
        # build input signatures
        input_signature = self._extra_input_signature_def()
        input_signature["observations"] = \
            tf.saved_model.utils.build_tensor_info(self._obs_input)

        if self._seq_lens is not None:
            input_signature["seq_lens"] = \
                tf.saved_model.utils.build_tensor_info(self._seq_lens)
        if self._prev_action_input is not None:
            input_signature["prev_action"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_action_input)
        if self._prev_reward_input is not None:
            input_signature["prev_reward"] = \
                tf.saved_model.utils.build_tensor_info(self._prev_reward_input)
        input_signature["is_training"] = \
            tf.saved_model.utils.build_tensor_info(self._is_training)

        for state_input in self._state_inputs:
            input_signature[state_input.name] = \
                tf.saved_model.utils.build_tensor_info(state_input)

        # build output signatures
        output_signature = self._extra_output_signature_def()
        output_signature["actions"] = \
            tf.saved_model.utils.build_tensor_info(self._sampler)
        for state_output in self._state_outputs:
            output_signature[state_output.name] = \
                tf.saved_model.utils.build_tensor_info(state_output)
        signature_def = (
            tf.saved_model.signature_def_utils.build_signature_def(
                input_signature, output_signature,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        signature_def_key = (tf.saved_model.signature_constants.
                             DEFAULT_SERVING_SIGNATURE_DEF_KEY)
        signature_def_map = {signature_def_key: signature_def}
        return signature_def_map

    def _build_compute_actions(self,
                               builder,
                               obs_batch,
                               state_batches=None,
                               prev_action_batch=None,
                               prev_reward_batch=None,
                               episodes=None):
        """Register feeds/fetches for action computation on the builder.

        Returns:
            tuple of (action fetch, state output fetches, extra fetches).

        Raises:
            ValueError: if the number of state batches does not match the
                number of state input placeholders.
        """
        state_batches = state_batches or []
        if len(self._state_inputs) != len(state_batches):
            raise ValueError(
                "Must pass in RNN state batches for placeholders {}, got {}".
                format(self._state_inputs, state_batches))
        builder.add_feed_dict(self.extra_compute_action_feed_dict())
        builder.add_feed_dict({self._obs_input: obs_batch})
        if state_batches:
            # One timestep per batch row when stepping the RNN.
            builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))})
        # Check "is not None" + length instead of truthiness: truthiness of
        # a numpy array either raises (multiple elements) or is False for a
        # single zero-valued element, which would silently skip the feed.
        if (self._prev_action_input is not None
                and prev_action_batch is not None
                and len(prev_action_batch) > 0):
            builder.add_feed_dict({self._prev_action_input: prev_action_batch})
        if (self._prev_reward_input is not None
                and prev_reward_batch is not None
                and len(prev_reward_batch) > 0):
            builder.add_feed_dict({self._prev_reward_input: prev_reward_batch})
        builder.add_feed_dict({self._is_training: False})
        builder.add_feed_dict(dict(zip(self._state_inputs, state_batches)))
        fetches = builder.add_fetches([self._sampler] + self._state_outputs +
                                      [self.extra_compute_action_fetches()])
        return fetches[0], fetches[1:-1], fetches[-1]

    def _build_compute_gradients(self, builder, postprocessed_batch):
        """Register feeds/fetches for gradient computation on the builder."""
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict({self._is_training: True})
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        fetches = builder.add_fetches(
            [self._grads, self._get_grad_and_stats_fetches()])
        return fetches[0], fetches[1]

    def _build_apply_gradients(self, builder, gradients):
        """Register feeds/fetches for applying gradient values.

        Raises:
            ValueError: if the number of gradients does not match the
                number of gradient tensors of this policy.
        """
        if len(gradients) != len(self._grads):
            raise ValueError(
                "Unexpected number of gradients to apply, got {} for {}".
                format(gradients, self._grads))
        builder.add_feed_dict({self._is_training: True})
        # Feed the supplied values in place of the gradient tensors.
        builder.add_feed_dict(dict(zip(self._grads, gradients)))
        fetches = builder.add_fetches([self._apply_op])
        return fetches[0]

    def _build_learn_on_batch(self, builder, postprocessed_batch):
        """Register feeds/fetches for a combined compute + apply step.

        Returns only the grad/stats fetch; the apply op result is dropped.
        """
        builder.add_feed_dict(self.extra_compute_grad_feed_dict())
        builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch))
        builder.add_feed_dict({self._is_training: True})
        fetches = builder.add_fetches([
            self._apply_op,
            self._get_grad_and_stats_fetches(),
        ])
        return fetches[1]

    def _get_grad_and_stats_fetches(self):
        """Merge extra_compute_grad_fetches() with the policy stats fetches.

        Entries from extra_compute_grad_fetches() take precedence over
        self._stats_fetches on key collisions.

        Raises:
            ValueError: if the extra fetches lack the LEARNER_STATS_KEY entry.
        """
        fetches = self.extra_compute_grad_fetches()
        if LEARNER_STATS_KEY not in fetches:
            raise ValueError(
                "Grad fetches should contain '{}': {{...}} entry".format(
                    LEARNER_STATS_KEY))
        if self._stats_fetches:
            fetches[LEARNER_STATS_KEY] = dict(self._stats_fetches,
                                              **fetches[LEARNER_STATS_KEY])
        return fetches

    def _get_loss_inputs_dict(self, batch):
        """Build the feed dict mapping loss placeholders to batch data.

        Without RNN state and with the divisibility requirement met, batch
        columns are fed directly. Otherwise the batch is chopped into
        padded sequences via chop_into_sequences().
        """
        feed_dict = {}
        if self._batch_divisibility_req > 1:
            meets_divisibility_reqs = (
                len(batch[SampleBatch.CUR_OBS]) %
                self._batch_divisibility_req == 0
                and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
        else:
            meets_divisibility_reqs = True

        # Simple case: not RNN nor do we need to pad
        if not self._state_inputs and meets_divisibility_reqs:
            for k, ph in self._loss_inputs:
                feed_dict[ph] = batch[k]
            return feed_dict

        if self._state_inputs:
            max_seq_len = self._max_seq_len
            dynamic_max = True
        else:
            # Non-RNN: reuse the sequence chopping purely for padding.
            max_seq_len = self._batch_divisibility_req
            dynamic_max = False

        # RNN or multi-agent case
        feature_keys = [k for k, v in self._loss_inputs]
        state_keys = [
            "state_in_{}".format(i) for i in range(len(self._state_inputs))
        ]
        feature_sequences, initial_states, seq_lens = chop_into_sequences(
            batch[SampleBatch.EPS_ID],
            batch[SampleBatch.UNROLL_ID],
            batch[SampleBatch.AGENT_INDEX], [batch[k] for k in feature_keys],
            [batch[k] for k in state_keys],
            max_seq_len,
            dynamic_max=dynamic_max)
        for k, v in zip(feature_keys, feature_sequences):
            feed_dict[self._loss_input_dict[k]] = v
        for k, v in zip(state_keys, initial_states):
            feed_dict[self._loss_input_dict[k]] = v
        feed_dict[self._seq_lens] = seq_lens

        if log_once("rnn_feed_dict"):
            logger.info("Padded input for RNN:\n\n{}\n".format(
                summarize({
                    "features": feature_sequences,
                    "initial_states": initial_states,
                    "seq_lens": seq_lens,
                    "max_seq_len": max_seq_len,
                })))
        return feed_dict
|
||||
|
||||
|
||||
@DeveloperAPI
class LearningRateSchedule(object):
    """Mixin for TFPolicy that adds a learning rate schedule."""

    @DeveloperAPI
    def __init__(self, lr, lr_schedule):
        """Create the "lr" variable and the schedule that drives it.

        Arguments:
            lr: initial learning rate, used to initialize the TF variable
                and as the constant value when no schedule is given.
            lr_schedule: None for a constant rate, otherwise the schedule
                passed to PiecewiseSchedule (presumably a list of
                [timestep, value] entries -- confirm against callers).
        """
        self.cur_lr = tf.get_variable("lr", initializer=lr)
        if lr_schedule is not None:
            # Past the schedule's end, hold the last entry's last element.
            final_value = lr_schedule[-1][-1]
            self.lr_schedule = PiecewiseSchedule(
                lr_schedule, outside_value=final_value)
        else:
            self.lr_schedule = ConstantSchedule(lr)

    @override(Policy)
    def on_global_var_update(self, global_vars):
        """Load the scheduled rate for the current timestep into cur_lr."""
        super(LearningRateSchedule, self).on_global_var_update(global_vars)
        scheduled_lr = self.lr_schedule.value(global_vars["timestep"])
        self.cur_lr.load(scheduled_lr, session=self._sess)

    @override(TFPolicy)
    def optimizer(self):
        """Adam optimizer whose learning rate is the cur_lr variable."""
        return tf.train.AdamOptimizer(self.cur_lr)
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.tf_policy import TFPolicy
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
|
||||
@DeveloperAPI
def build_tf_policy(name,
                    loss_fn,
                    get_default_config=None,
                    stats_fn=None,
                    grad_stats_fn=None,
                    extra_action_fetches_fn=None,
                    postprocess_fn=None,
                    optimizer_fn=None,
                    gradients_fn=None,
                    before_init=None,
                    before_loss_init=None,
                    after_init=None,
                    make_action_sampler=None,
                    mixins=None,
                    get_batch_divisibility_req=None):
    """Helper function for creating a dynamic tf policy at runtime.

    Arguments:
        name (str): name of the policy (e.g., "PPOTFPolicy")
        loss_fn (func): function that returns a loss tensor the policy,
            and dict of experience tensor placeholders
        get_default_config (func): optional function that returns the default
            config to merge with any overrides
        stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and batch input tensors
        grad_stats_fn (func): optional function that returns a dict of
            TF fetches given the policy and loss gradient tensors
        extra_action_fetches_fn (func): optional function that returns
            a dict of TF fetches given the policy object
        postprocess_fn (func): optional experience postprocessing function
            that takes the same args as Policy.postprocess_trajectory()
        optimizer_fn (func): optional function that returns a tf.Optimizer
            given the policy and config
        gradients_fn (func): optional function that returns a list of gradients
            given a tf optimizer and loss tensor. If not specified, this
            defaults to optimizer.compute_gradients(loss)
        before_init (func): optional function to run at the beginning of
            policy init that takes the same arguments as the policy constructor
        before_loss_init (func): optional function to run prior to loss
            init that takes the same arguments as the policy constructor
        after_init (func): optional function to run at the end of policy init
            that takes the same arguments as the policy constructor
        make_action_sampler (func): optional function that returns a
            tuple of action and action prob tensors. The function takes
            (policy, input_dict, obs_space, action_space, config) as its
            arguments
        mixins (list): list of any class mixins for the returned policy class.
            These mixins will be applied in order and will have higher
            precedence than the DynamicTFPolicy class. The list passed in
            is not modified.
        get_batch_divisibility_req (func): optional function that returns
            the divisibility requirement for sample batches

    Returns:
        a DynamicTFPolicy instance that uses the specified args

    Raises:
        ValueError: if name does not end with "TFPolicy".
    """

    if not name.endswith("TFPolicy"):
        raise ValueError("Name should match *TFPolicy", name)

    # Work on a private copy: popping from the caller's list would leave it
    # empty, silently dropping the mixins if the same list is reused for
    # another build_tf_policy() call.
    remaining_mixins = list(mixins or [])
    base = DynamicTFPolicy
    while remaining_mixins:

        # Wrap from the last mixin inwards so that the first mixin ends up
        # earliest in the MRO.
        class new_base(remaining_mixins.pop(), base):
            pass

        base = new_base

    class policy_cls(base):
        def __init__(self,
                     obs_space,
                     action_space,
                     config,
                     existing_inputs=None):
            if get_default_config:
                # User-supplied config entries override the defaults.
                config = dict(get_default_config(), **config)

            if before_init:
                before_init(self, obs_space, action_space, config)

            def before_loss_init_wrapper(policy, obs_space, action_space,
                                         config):
                if before_loss_init:
                    before_loss_init(policy, obs_space, action_space, config)
                # Resolved here (not in __init__) so the fetches function
                # runs right after the user's before_loss_init hook.
                if extra_action_fetches_fn is None:
                    self._extra_action_fetches = {}
                else:
                    self._extra_action_fetches = extra_action_fetches_fn(self)

            DynamicTFPolicy.__init__(
                self,
                obs_space,
                action_space,
                config,
                loss_fn,
                stats_fn=stats_fn,
                grad_stats_fn=grad_stats_fn,
                before_loss_init=before_loss_init_wrapper,
                existing_inputs=existing_inputs)

            if after_init:
                after_init(self, obs_space, action_space, config)

        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)

        @override(TFPolicy)
        def optimizer(self):
            if optimizer_fn:
                return optimizer_fn(self, self.config)
            else:
                return TFPolicy.optimizer(self)

        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicy.gradients(self, optimizer, loss)

        @override(TFPolicy)
        def extra_compute_action_fetches(self):
            # User-defined fetches override the base ones on key collisions.
            return dict(
                TFPolicy.extra_compute_action_fetches(self),
                **self._extra_action_fetches)

    policy_cls.__name__ = name
    policy_cls.__qualname__ = name
    return policy_cls
|
||||
@@ -0,0 +1,173 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from threading import Lock
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
pass # soft dep
|
||||
|
||||
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.tracking_dict import UsageTrackingDict
|
||||
|
||||
|
||||
class TorchPolicy(Policy):
    """Template for a PyTorch policy and loss to use with RLlib.

    This is similar to TFPolicy, but for PyTorch.

    Attributes:
        observation_space (gym.Space): observation space of the policy.
        action_space (gym.Space): action space of the policy.
        lock (Lock): Lock that must be held around PyTorch ops on this graph.
            This is necessary when using the async sampler.
    """

    def __init__(self, observation_space, action_space, model, loss,
                 action_distribution_cls):
        """Build a policy from policy and loss torch modules.

        Note that model will be placed on GPU device if CUDA_VISIBLE_DEVICES
        is set. Only single GPU is supported for now.

        Arguments:
            observation_space (gym.Space): observation space of the policy.
            action_space (gym.Space): action space of the policy.
            model (nn.Module): PyTorch policy module. Given observations as
                input, this module must return a list of outputs where the
                first item is action logits, and the rest can be any value.
            loss (func): Function that takes (policy, batch_tensors)
                and returns a single scalar loss.
            action_distribution_cls (ActionDistribution): Class for action
                distribution.
        """
        self.observation_space = observation_space
        self.action_space = action_space
        self.lock = Lock()
        # An unset or empty CUDA_VISIBLE_DEVICES selects the CPU.
        self.device = (torch.device("cuda")
                       if bool(os.environ.get("CUDA_VISIBLE_DEVICES", None))
                       else torch.device("cpu"))
        self._model = model.to(self.device)
        self._loss = loss
        # optimizer() reads self._model, so it must be called after it is set.
        self._optimizer = self.optimizer()
        self._action_dist_cls = action_distribution_cls

    @override(Policy)
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        """Sample actions for a batch of observations without gradients.

        Only obs_batch and state_batches are used by this implementation.

        Returns:
            tuple of (actions array, list of state arrays, extra info dict
            from extra_action_out()).
        """
        with self.lock:
            with torch.no_grad():
                ob = torch.from_numpy(np.array(obs_batch)) \
                    .float().to(self.device)
                model_out = self._model({"obs": ob}, state_batches)
                # The model must return exactly four outputs; only the
                # logits and the state are used here.
                logits, _, _, state = model_out
                action_dist = self._action_dist_cls(logits)
                actions = action_dist.sample()
                return (actions.cpu().numpy(),
                        [h.cpu().numpy() for h in state],
                        self.extra_action_out(model_out))

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Compute the loss, backpropagate and take one optimizer step.

        Returns:
            dict mapping LEARNER_STATS_KEY to the merged grad info.
        """
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()
            self._optimizer.step()

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def compute_gradients(self, postprocessed_batch):
        """Compute the loss gradients without applying them.

        Returns:
            tuple of (list with one gradient array or None per model
            parameter, dict mapping LEARNER_STATS_KEY to grad info).
        """
        batch_tensors = self._lazy_tensor_dict(postprocessed_batch)

        with self.lock:
            loss_out = self._loss(self, batch_tensors)
            self._optimizer.zero_grad()
            loss_out.backward()

            grad_process_info = self.extra_grad_process()

            # Note that return values are just references;
            # calling zero_grad will modify the values
            grads = []
            for p in self._model.parameters():
                if p.grad is not None:
                    grads.append(p.grad.data.cpu().numpy())
                else:
                    grads.append(None)

            grad_info = self.extra_grad_info(batch_tensors)
            grad_info.update(grad_process_info)
            return grads, {LEARNER_STATS_KEY: grad_info}

    @override(Policy)
    def apply_gradients(self, gradients):
        """Set the given gradients on the parameters and step the optimizer.

        None entries are skipped, leaving that parameter's current .grad
        untouched.
        """
        with self.lock:
            for g, p in zip(gradients, self._model.parameters()):
                if g is not None:
                    p.grad = torch.from_numpy(g).to(self.device)
            self._optimizer.step()

    @override(Policy)
    def get_weights(self):
        """Return the model state dict with all tensors moved to CPU."""
        with self.lock:
            return {k: v.cpu() for k, v in self._model.state_dict().items()}

    @override(Policy)
    def set_weights(self, weights):
        """Load a state dict (as returned by get_weights()) into the model."""
        with self.lock:
            self._model.load_state_dict(weights)

    @override(Policy)
    def get_initial_state(self):
        """Return the model's initial RNN state as a list of numpy arrays."""
        # Move to CPU first, as compute_actions() does for state outputs:
        # Tensor.numpy() raises for tensors that live on a CUDA device.
        return [s.cpu().numpy() for s in self._model.state_init()]

    def extra_grad_process(self):
        """Allow subclass to do extra processing on gradients and
           return processing info."""
        return {}

    def extra_action_out(self, model_out):
        """Returns dict of extra info to include in experience batch.

        Arguments:
            model_out (list): Outputs of the policy model module."""
        return {}

    def extra_grad_info(self, batch_tensors):
        """Return dict of extra grad info."""

        return {}

    def optimizer(self):
        """Custom PyTorch optimizer to use."""
        if hasattr(self, "config"):
            return torch.optim.Adam(
                self._model.parameters(), lr=self.config["lr"])
        else:
            return torch.optim.Adam(self._model.parameters())

    def _lazy_tensor_dict(self, postprocessed_batch):
        """Wrap the batch so values are converted to device tensors on get.

        The conversion is installed as the get interceptor of a
        UsageTrackingDict built from the batch.
        """
        batch_tensors = UsageTrackingDict(postprocessed_batch)
        batch_tensors.set_get_interceptor(
            lambda arr: torch.from_numpy(arr).to(self.device))
        return batch_tensors
|
||||
+18
-18
@@ -2,8 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.evaluation.torch_policy_graph import TorchPolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.torch_policy import TorchPolicy
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
|
||||
@@ -24,7 +24,7 @@ def build_torch_policy(name,
|
||||
"""Helper function for creating a torch policy at runtime.
|
||||
|
||||
Arguments:
|
||||
name (str): name of the graph (e.g., "PPOPolicy")
|
||||
name (str): name of the policy (e.g., "PPOTorchPolicy")
|
||||
loss_fn (func): function that returns a loss tensor the policy,
|
||||
and dict of experience tensor placeholders
|
||||
get_default_config (func): optional function that returns the default
|
||||
@@ -32,7 +32,7 @@ def build_torch_policy(name,
|
||||
stats_fn (func): optional function that returns a dict of
|
||||
values given the policy and batch input tensors
|
||||
postprocess_fn (func): optional experience postprocessing function
|
||||
that takes the same args as PolicyGraph.postprocess_trajectory()
|
||||
that takes the same args as Policy.postprocess_trajectory()
|
||||
extra_action_out_fn (func): optional function that returns
|
||||
a dict of extra values to include in experiences
|
||||
extra_grad_process_fn (func): optional function that is called after
|
||||
@@ -49,16 +49,16 @@ def build_torch_policy(name,
|
||||
model and action dist from the catalog will be used
|
||||
mixins (list): list of any class mixins for the returned policy class.
|
||||
These mixins will be applied in order and will have higher
|
||||
precedence than the TorchPolicyGraph class
|
||||
precedence than the TorchPolicy class
|
||||
|
||||
Returns:
|
||||
a TorchPolicyGraph instance that uses the specified args
|
||||
a TorchPolicy instance that uses the specified args
|
||||
"""
|
||||
|
||||
if not name.endswith("TorchPolicy"):
|
||||
raise ValueError("Name should match *TorchPolicy", name)
|
||||
|
||||
base = TorchPolicyGraph
|
||||
base = TorchPolicy
|
||||
while mixins:
|
||||
|
||||
class new_base(mixins.pop(), base):
|
||||
@@ -84,13 +84,13 @@ def build_torch_policy(name,
|
||||
self.model = ModelCatalog.get_torch_model(
|
||||
obs_space, logit_dim, self.config["model"])
|
||||
|
||||
TorchPolicyGraph.__init__(self, obs_space, action_space,
|
||||
self.model, loss_fn, self.dist_class)
|
||||
TorchPolicy.__init__(self, obs_space, action_space, self.model,
|
||||
loss_fn, self.dist_class)
|
||||
|
||||
if after_init:
|
||||
after_init(self, obs_space, action_space, config)
|
||||
|
||||
@override(PolicyGraph)
|
||||
@override(Policy)
|
||||
def postprocess_trajectory(self,
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
@@ -100,33 +100,33 @@ def build_torch_policy(name,
|
||||
return postprocess_fn(self, sample_batch, other_agent_batches,
|
||||
episode)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
@override(TorchPolicy)
|
||||
def extra_grad_process(self):
|
||||
if extra_grad_process_fn:
|
||||
return extra_grad_process_fn(self)
|
||||
else:
|
||||
return TorchPolicyGraph.extra_grad_process(self)
|
||||
return TorchPolicy.extra_grad_process(self)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
@override(TorchPolicy)
|
||||
def extra_action_out(self, model_out):
|
||||
if extra_action_out_fn:
|
||||
return extra_action_out_fn(self, model_out)
|
||||
else:
|
||||
return TorchPolicyGraph.extra_action_out(self, model_out)
|
||||
return TorchPolicy.extra_action_out(self, model_out)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
@override(TorchPolicy)
|
||||
def optimizer(self):
|
||||
if optimizer_fn:
|
||||
return optimizer_fn(self, self.config)
|
||||
else:
|
||||
return TorchPolicyGraph.optimizer(self)
|
||||
return TorchPolicy.optimizer(self)
|
||||
|
||||
@override(TorchPolicyGraph)
|
||||
@override(TorchPolicy)
|
||||
def extra_grad_info(self, batch_tensors):
|
||||
if stats_fn:
|
||||
return stats_fn(self, batch_tensors)
|
||||
else:
|
||||
return TorchPolicyGraph.extra_grad_info(self, batch_tensors)
|
||||
return TorchPolicy.extra_grad_info(self, batch_tensors)
|
||||
|
||||
graph_cls.__name__ = name
|
||||
graph_cls.__qualname__ = name
|
||||
@@ -15,7 +15,7 @@ import ray
|
||||
from ray.rllib.agents.registry import get_agent_class
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import _DUMMY_AGENT_ID
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
|
||||
from ray.tune.util import merge_dicts
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
|
||||
@@ -7,7 +7,7 @@ import unittest
|
||||
import ray
|
||||
from ray.rllib.agents.dqn import DQNTrainer
|
||||
from ray.rllib.agents.a3c import A3CTrainer
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import _adjust_nstep
|
||||
from ray.rllib.agents.dqn.dqn_policy import _adjust_nstep
|
||||
from ray.tune.registry import register_env
|
||||
import gym
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ from ray.rllib.agents.dqn import DQNTrainer
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.external_env import ExternalEnv
|
||||
from ray.rllib.tests.test_policy_evaluator import (BadPolicyGraph,
|
||||
MockPolicyGraph, MockEnv)
|
||||
from ray.rllib.tests.test_policy_evaluator import (BadPolicy, MockPolicy,
|
||||
MockEnv)
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvCompleteEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
@@ -131,7 +131,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
batch_mode="truncate_episodes")
|
||||
for _ in range(3):
|
||||
@@ -141,7 +141,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvOffPolicy(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
@@ -153,7 +153,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvBadActions(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=BadPolicyGraph,
|
||||
policy=BadPolicy,
|
||||
sample_async=True,
|
||||
batch_steps=40,
|
||||
batch_mode="truncate_episodes")
|
||||
@@ -198,7 +198,7 @@ class TestExternalEnv(unittest.TestCase):
|
||||
def testExternalEnvHorizonNotSupported(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleServing(MockEnv(25)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
episode_horizon=20,
|
||||
batch_steps=10,
|
||||
batch_mode="complete_episodes")
|
||||
|
||||
@@ -8,11 +8,11 @@ import random
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.optimizers import SyncSamplesOptimizer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicy
|
||||
from ray.rllib.tests.test_external_env import make_simple_serving
|
||||
from ray.rllib.tests.test_multi_agent_env import BasicMultiAgent, MultiCartpole
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
@@ -25,7 +25,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
agents = 4
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
batch_mode="complete_episodes")
|
||||
for _ in range(3):
|
||||
@@ -37,7 +37,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
agents = 4
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=40,
|
||||
batch_mode="truncate_episodes")
|
||||
for _ in range(3):
|
||||
@@ -51,9 +51,9 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50)
|
||||
@@ -72,7 +72,7 @@ class TestExternalMultiAgentEnv(unittest.TestCase):
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
|
||||
@@ -15,7 +15,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
|
||||
from ray.rllib.offline.json_writer import _to_json
|
||||
@@ -167,7 +167,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
"num_workers": 0,
|
||||
"output": self.test_dir,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
},
|
||||
@@ -188,7 +188,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
"input_evaluation": ["simulation"],
|
||||
"train_batch_size": 2000,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
},
|
||||
|
||||
@@ -8,14 +8,14 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.dqn.dqn_policy_graph import DQNPolicyGraph
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.agents.dqn.dqn_policy import DQNTFPolicy
|
||||
from ray.rllib.optimizers import (SyncSamplesOptimizer, SyncReplayOptimizer,
|
||||
AsyncGradientsOptimizer)
|
||||
from ray.rllib.tests.test_policy_evaluator import (MockEnv, MockEnv2,
|
||||
MockPolicyGraph)
|
||||
MockPolicy)
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.env.base_env import _MultiAgentEnvToBaseEnv
|
||||
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
@@ -329,9 +329,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50)
|
||||
@@ -347,9 +347,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
@@ -364,9 +364,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_steps=50,
|
||||
@@ -380,9 +380,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: BasicMultiAgent(5),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
episode_horizon=10, # test with episode horizon set
|
||||
@@ -395,9 +395,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(2)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: EarlyDoneMultiAgent(),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
"p1": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
|
||||
batch_mode="complete_episodes",
|
||||
@@ -411,8 +411,8 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs_space = gym.spaces.Discrete(10)
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
|
||||
policy_graph={
|
||||
"p0": (MockPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (MockPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p0",
|
||||
batch_steps=50)
|
||||
@@ -445,7 +445,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
def testCustomRNNStateValues(self):
|
||||
h = {"some": {"arbitrary": "structure", "here": [1, 2, 3]}}
|
||||
|
||||
class StatefulPolicyGraph(PolicyGraph):
|
||||
class StatefulPolicy(Policy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
@@ -460,7 +460,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=StatefulPolicyGraph,
|
||||
policy=StatefulPolicy,
|
||||
batch_steps=5)
|
||||
batch = ev.sample()
|
||||
self.assertEqual(batch.count, 5)
|
||||
@@ -470,7 +470,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
self.assertEqual(batch["state_out_0"][1], h)
|
||||
|
||||
def testReturningModelBasedRolloutsData(self):
|
||||
class ModelBasedPolicyGraph(PGTFPolicy):
|
||||
class ModelBasedPolicy(PGTFPolicy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
@@ -505,9 +505,9 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
act_space = single_env.action_space
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(2),
|
||||
policy_graph={
|
||||
"p0": (ModelBasedPolicyGraph, obs_space, act_space, {}),
|
||||
"p1": (ModelBasedPolicyGraph, obs_space, act_space, {}),
|
||||
policy={
|
||||
"p0": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
"p1": (ModelBasedPolicy, obs_space, act_space, {}),
|
||||
},
|
||||
policy_mapping_fn=lambda agent_id: "p0",
|
||||
batch_steps=5)
|
||||
@@ -547,7 +547,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
config={
|
||||
"num_workers": 0,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"policy_1": gen_policy(),
|
||||
"policy_2": gen_policy(),
|
||||
},
|
||||
@@ -579,17 +579,17 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
# happen since the replay buffer doesn't encode extra fields like
|
||||
# "advantages" that PG uses.
|
||||
policies = {
|
||||
"p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
|
||||
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
|
||||
"p1": (DQNTFPolicy, obs_space, act_space, dqn_config),
|
||||
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
|
||||
}
|
||||
else:
|
||||
policies = {
|
||||
"p1": (PGTFPolicy, obs_space, act_space, {}),
|
||||
"p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
|
||||
"p2": (DQNTFPolicy, obs_space, act_space, dqn_config),
|
||||
}
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
|
||||
batch_steps=50)
|
||||
if optimizer_cls == AsyncGradientsOptimizer:
|
||||
@@ -600,7 +600,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
remote_evs = [
|
||||
PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy=policies,
|
||||
policy_mapping_fn=policy_mapper,
|
||||
batch_steps=50)
|
||||
]
|
||||
@@ -610,12 +610,16 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
for i in range(200):
|
||||
ev.foreach_policy(lambda p, _: p.set_epsilon(
|
||||
max(0.02, 1 - i * .02))
|
||||
if isinstance(p, DQNPolicyGraph) else None)
|
||||
if isinstance(p, DQNTFPolicy) else None)
|
||||
optimizer.step()
|
||||
result = collect_metrics(ev, remote_evs)
|
||||
if i % 20 == 0:
|
||||
ev.foreach_policy(lambda p, _: p.update_target() if isinstance(
|
||||
p, DQNPolicyGraph) else None)
|
||||
|
||||
def do_update(p):
|
||||
if isinstance(p, DQNTFPolicy):
|
||||
p.update_target()
|
||||
|
||||
ev.foreach_policy(lambda p, _: do_update(p))
|
||||
print("Iter {}, rew {}".format(i,
|
||||
result["policy_reward_mean"]))
|
||||
print("Total reward", result["episode_reward_mean"])
|
||||
@@ -645,7 +649,7 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
policy_ids = list(policies.keys())
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MultiCartpole(n),
|
||||
policy_graph=policies,
|
||||
policy=policies,
|
||||
policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
|
||||
batch_steps=100)
|
||||
optimizer = SyncSamplesOptimizer(ev, [])
|
||||
|
||||
@@ -12,7 +12,7 @@ import unittest
|
||||
import ray
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGTFPolicy
|
||||
from ray.rllib.agents.pg.pg_policy import PGTFPolicy
|
||||
from ray.rllib.env import MultiAgentEnv
|
||||
from ray.rllib.env.base_env import BaseEnv
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
@@ -331,7 +331,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
"sample_batch_size": 5,
|
||||
"train_batch_size": 5,
|
||||
"multiagent": {
|
||||
"policy_graphs": {
|
||||
"policies": {
|
||||
"tuple_policy": (
|
||||
PGTFPolicy, TUPLE_SPACE, act_space,
|
||||
{"model": {"custom_model": "tuple_spy"}}),
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.ppo import PPOTrainer
|
||||
from ray.rllib.agents.ppo.ppo_policy_graph import PPOTFPolicy
|
||||
from ray.rllib.agents.ppo.ppo_policy import PPOTFPolicy
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.optimizers import AsyncGradientsOptimizer, AsyncSamplesOptimizer
|
||||
@@ -240,12 +240,12 @@ class AsyncSamplesOptimizerTest(unittest.TestCase):
|
||||
|
||||
local = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=PPOTFPolicy,
|
||||
policy=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
remotes = [
|
||||
PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=PPOTFPolicy,
|
||||
policy=PPOTFPolicy,
|
||||
tf_session_creator=make_sess)
|
||||
]
|
||||
return local, remotes
|
||||
|
||||
@@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicyGraph
|
||||
from ray.rllib.tests.test_policy_evaluator import MockPolicy
|
||||
|
||||
|
||||
class TestPerf(unittest.TestCase):
|
||||
@@ -19,7 +19,7 @@ class TestPerf(unittest.TestCase):
|
||||
for _ in range(20):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=100)
|
||||
start = time.time()
|
||||
count = 0
|
||||
|
||||
@@ -14,14 +14,14 @@ from ray.rllib.agents.pg import PGTrainer
|
||||
from ray.rllib.agents.a3c import A2CTrainer
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
from ray.rllib.evaluation.policy_graph import PolicyGraph
|
||||
from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
from ray.rllib.evaluation.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
|
||||
class MockPolicyGraph(PolicyGraph):
|
||||
class MockPolicy(Policy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
@@ -39,7 +39,7 @@ class MockPolicyGraph(PolicyGraph):
|
||||
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
|
||||
|
||||
|
||||
class BadPolicyGraph(PolicyGraph):
|
||||
class BadPolicy(Policy):
|
||||
def compute_actions(self,
|
||||
obs_batch,
|
||||
state_batches,
|
||||
@@ -132,8 +132,7 @@ class MockVectorEnv(VectorEnv):
|
||||
class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testBasic(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph)
|
||||
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
|
||||
batch = ev.sample()
|
||||
for key in [
|
||||
"obs", "actions", "rewards", "dones", "advantages",
|
||||
@@ -157,8 +156,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
|
||||
def testBatchIds(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph)
|
||||
env_creator=lambda _: gym.make("CartPole-v0"), policy=MockPolicy)
|
||||
batch1 = ev.sample()
|
||||
batch2 = ev.sample()
|
||||
self.assertEqual(len(set(batch1["unroll_id"])), 1)
|
||||
@@ -229,7 +227,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
# clipping on
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
clip_rewards=True,
|
||||
batch_mode="complete_episodes")
|
||||
self.assertEqual(max(ev.sample()["rewards"]), 1)
|
||||
@@ -239,7 +237,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
# clipping off
|
||||
ev2 = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv2(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
clip_rewards=False,
|
||||
batch_mode="complete_episodes")
|
||||
self.assertEqual(max(ev2.sample()["rewards"]), 100)
|
||||
@@ -249,7 +247,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testHardHorizon(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=10,
|
||||
episode_horizon=4,
|
||||
@@ -263,7 +261,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testSoftHorizon(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes",
|
||||
batch_steps=10,
|
||||
episode_horizon=4,
|
||||
@@ -277,11 +275,11 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testMetrics(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes")
|
||||
remote_ev = PolicyEvaluator.as_remote().remote(
|
||||
env_creator=lambda _: MockEnv(episode_length=10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="complete_episodes")
|
||||
ev.sample()
|
||||
ray.get(remote_ev.sample.remote())
|
||||
@@ -293,7 +291,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
sample_async=True,
|
||||
policy_graph=MockPolicyGraph)
|
||||
policy=MockPolicy)
|
||||
batch = ev.sample()
|
||||
for key in ["obs", "actions", "rewards", "dones", "advantages"]:
|
||||
self.assertIn(key, batch)
|
||||
@@ -302,7 +300,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testAutoVectorization(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=2,
|
||||
num_envs=8)
|
||||
@@ -325,7 +323,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testBatchesLargerWhenVectorized(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(episode_length=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=4,
|
||||
num_envs=4)
|
||||
@@ -340,7 +338,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testVectorEnvSupport(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_mode="truncate_episodes",
|
||||
batch_steps=10)
|
||||
for _ in range(8):
|
||||
@@ -357,7 +355,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testTruncateEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
batch_mode="truncate_episodes")
|
||||
batch = ev.sample()
|
||||
@@ -366,7 +364,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testCompleteEpisodes(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=5,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
@@ -375,7 +373,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testCompleteEpisodesPacking(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
batch_steps=15,
|
||||
batch_mode="complete_episodes")
|
||||
batch = ev.sample()
|
||||
@@ -387,7 +385,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testFilterSync(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
observation_filter="ConcurrentMeanStdFilter")
|
||||
time.sleep(2)
|
||||
@@ -400,7 +398,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testGetFilters(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
observation_filter="ConcurrentMeanStdFilter")
|
||||
self.sample_and_flush(ev)
|
||||
@@ -415,7 +413,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
def testSyncFilter(self):
|
||||
ev = PolicyEvaluator(
|
||||
env_creator=lambda _: gym.make("CartPole-v0"),
|
||||
policy_graph=MockPolicyGraph,
|
||||
policy=MockPolicy,
|
||||
sample_async=True,
|
||||
observation_filter="ConcurrentMeanStdFilter")
|
||||
obs_f = self.sample_and_flush(ev)
|
||||
|
||||
@@ -10,13 +10,30 @@ from ray.tune.util import merge_dicts, deep_update
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def renamed_class(cls):
|
||||
def renamed_class(cls, old_name):
|
||||
"""Helper class for renaming classes with a warning."""
|
||||
|
||||
class DeprecationWrapper(cls):
|
||||
# note: **kw not supported for ray.remote classes
|
||||
def __init__(self, *args, **kw):
|
||||
new_name = cls.__module__ + "." + cls.__name__
|
||||
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
|
||||
format(old_name, new_name) +
|
||||
"This will raise an error in the future.")
|
||||
cls.__init__(self, *args, **kw)
|
||||
|
||||
DeprecationWrapper.__name__ = cls.__name__
|
||||
|
||||
return DeprecationWrapper
|
||||
|
||||
|
||||
def renamed_agent(cls):
|
||||
"""Helper class for renaming Agent => Trainer with a warning."""
|
||||
|
||||
class DeprecationWrapper(cls):
|
||||
def __init__(self, config=None, env=None, logger_creator=None):
|
||||
old_name = cls.__name__.replace("Trainer", "Agent")
|
||||
new_name = cls.__name__
|
||||
new_name = cls.__module__ + "." + cls.__name__
|
||||
logger.warn("DeprecationWarning: {} has been renamed to {}. ".
|
||||
format(old_name, new_name) +
|
||||
"This will raise an error in the future.")
|
||||
|
||||
@@ -6,7 +6,7 @@ import numpy as np
|
||||
import pprint
|
||||
import time
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, MultiAgentBatch
|
||||
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
|
||||
|
||||
_logged = set()
|
||||
_disabled = False
|
||||
|
||||
Reference in New Issue
Block a user