mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 15:40:00 +08:00
* WIP. * Fix. * Fix. * Fix.
This commit is contained in:
@@ -274,7 +274,7 @@ DDPGTorchPolicy = build_torch_policy(
|
||||
optimizer_fn=make_ddpg_optimizers,
|
||||
validate_spaces=validate_spaces,
|
||||
before_init=before_init_fn,
|
||||
after_init=setup_late_mixins,
|
||||
before_loss_init=setup_late_mixins,
|
||||
action_distribution_fn=get_distribution_inputs_and_class,
|
||||
make_model_and_action_dist=build_ddpg_models_and_action_dist,
|
||||
apply_gradients_fn=apply_gradients_fn,
|
||||
|
||||
@@ -317,9 +317,9 @@ def setup_early_mixins(policy: Policy, obs_space, action_space,
|
||||
LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])
|
||||
|
||||
|
||||
def after_init(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
def before_loss_init(policy: Policy, obs_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
config: TrainerConfigDict) -> None:
|
||||
ComputeTDErrorMixin.__init__(policy)
|
||||
TargetNetworkMixin.__init__(policy, obs_space, action_space, config)
|
||||
# Move target net to device (this is done automatically for the
|
||||
@@ -397,7 +397,7 @@ DQNTorchPolicy = build_torch_policy(
|
||||
extra_learn_fetches_fn=lambda policy: {"td_error": policy.q_loss.td_error},
|
||||
extra_action_out_fn=extra_action_out_fn,
|
||||
before_init=setup_early_mixins,
|
||||
after_init=after_init,
|
||||
before_loss_init=before_loss_init,
|
||||
mixins=[
|
||||
TargetNetworkMixin,
|
||||
ComputeTDErrorMixin,
|
||||
|
||||
@@ -81,5 +81,5 @@ MARWILTorchPolicy = build_torch_policy(
|
||||
get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
|
||||
stats_fn=stats,
|
||||
postprocess_fn=postprocess_advantages,
|
||||
after_init=setup_mixins,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[ValueNetworkMixin])
|
||||
|
||||
@@ -331,7 +331,7 @@ AsyncPPOTorchPolicy = build_torch_policy(
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
optimizer_fn=choose_optimizer,
|
||||
before_init=setup_early_mixins,
|
||||
after_init=setup_late_mixins,
|
||||
before_loss_init=setup_late_mixins,
|
||||
make_model=make_appo_model,
|
||||
mixins=[
|
||||
LearningRateSchedule, KLCoeffMixin, TargetNetworkMixin,
|
||||
|
||||
@@ -47,7 +47,12 @@ def ppo_surrogate_loss(
|
||||
|
||||
# RNN case: Mask away 0-padded chunks at end of time axis.
|
||||
if state:
|
||||
max_seq_len = tf.reduce_max(train_batch["seq_lens"])
|
||||
# Derive max_seq_len from the data itself, not from the seq_lens
|
||||
# tensor. This is in case e.g. seq_lens=[2, 3], but the data is still
|
||||
# 0-padded up to T=5 (as is the case for attention nets).
|
||||
B = tf.shape(train_batch["seq_lens"])[0]
|
||||
max_seq_len = tf.shape(logits)[0] // B
|
||||
|
||||
mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len)
|
||||
mask = tf.reshape(mask, [-1])
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ PPOTorchPolicy = build_torch_policy(
|
||||
postprocess_fn=postprocess_ppo_gae,
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
before_init=setup_config,
|
||||
after_init=setup_mixins,
|
||||
before_loss_init=setup_mixins,
|
||||
mixins=[
|
||||
LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin,
|
||||
ValueNetworkMixin
|
||||
|
||||
@@ -24,6 +24,7 @@ class TestAPPO(unittest.TestCase):
|
||||
for _ in framework_iterator(config):
|
||||
print("w/o v-trace")
|
||||
_config = config.copy()
|
||||
_config["vtrace"] = False
|
||||
trainer = ppo.APPOTrainer(config=_config, env="CartPole-v0")
|
||||
for i in range(num_iterations):
|
||||
print(trainer.train())
|
||||
|
||||
@@ -489,7 +489,7 @@ SACTorchPolicy = build_torch_policy(
|
||||
extra_grad_process_fn=apply_grad_clipping,
|
||||
optimizer_fn=optimizer_fn,
|
||||
validate_spaces=validate_spaces,
|
||||
after_init=setup_late_mixins,
|
||||
before_loss_init=setup_late_mixins,
|
||||
make_model_and_action_dist=build_sac_model_and_action_dist,
|
||||
mixins=[TargetNetworkMixin, ComputeTDErrorMixin],
|
||||
action_distribution_fn=action_distribution_fn,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import ray
|
||||
from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep
|
||||
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.policy.policy import Policy
|
||||
|
||||
@@ -157,6 +157,10 @@ class _AgentCollector:
|
||||
batch = SampleBatch(batch_data)
|
||||
|
||||
if SampleBatch.UNROLL_ID not in batch.data:
|
||||
# TODO: (sven) Once we have the additional
|
||||
# model.preprocess_train_batch in place (attention net PR), we
|
||||
# should not even need UNROLL_ID anymore:
|
||||
# Add "if SampleBatch.UNROLL_ID in view_requirements:" here.
|
||||
batch.data[SampleBatch.UNROLL_ID] = np.repeat(
|
||||
_AgentCollector._next_unroll_id, batch.count)
|
||||
_AgentCollector._next_unroll_id += 1
|
||||
@@ -238,7 +242,7 @@ class _PolicyCollector:
|
||||
"""
|
||||
for view_col, data in batch.items():
|
||||
# Skip columns that are not used for training.
|
||||
if view_col in view_requirements and \
|
||||
if view_col not in view_requirements or \
|
||||
not view_requirements[view_col].used_for_training:
|
||||
continue
|
||||
self.buffers[view_col].extend(data)
|
||||
@@ -465,8 +469,7 @@ class _SimpleListCollector(_SampleCollector):
|
||||
pre_batch = collector.build(policy.view_requirements)
|
||||
pre_batches[agent_id] = (policy, pre_batch)
|
||||
|
||||
# Apply postprocessor.
|
||||
post_batches = {}
|
||||
# Apply reward clipping before calling postprocessing functions.
|
||||
if self.clip_rewards is True:
|
||||
for _, (_, pre_batch) in pre_batches.items():
|
||||
pre_batch["rewards"] = np.sign(pre_batch["rewards"])
|
||||
@@ -477,6 +480,7 @@ class _SimpleListCollector(_SampleCollector):
|
||||
a_min=-self.clip_rewards,
|
||||
a_max=self.clip_rewards)
|
||||
|
||||
post_batches = {}
|
||||
for agent_id, (_, pre_batch) in pre_batches.items():
|
||||
# Entire episode is said to be done.
|
||||
# Error if no DONE at end of this agent's trajectory.
|
||||
|
||||
@@ -257,7 +257,8 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
directory if specified.
|
||||
log_dir (str): Directory where logs can be placed.
|
||||
log_level (str): Set the root log level on creation.
|
||||
callbacks (DefaultCallbacks): Custom training callbacks.
|
||||
callbacks (Type[DefaultCallbacks]): Custom sub-class of
|
||||
DefaultCallbacks for training/policy/rollout-worker callbacks.
|
||||
input_creator (Callable[[IOContext], InputReader]): Function that
|
||||
returns an InputReader object for loading previously generated
|
||||
experiences.
|
||||
@@ -340,7 +341,7 @@ class RolloutWorker(ParallelIteratorWorker):
|
||||
self.callbacks: "DefaultCallbacks" = callbacks()
|
||||
else:
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
self.callbacks: "DefaultCallbacks" = DefaultCallbacks()
|
||||
self.callbacks: DefaultCallbacks = DefaultCallbacks()
|
||||
self.worker_index: int = worker_index
|
||||
self.num_workers: int = num_workers
|
||||
model_config: ModelConfigDict = model_config or {}
|
||||
|
||||
@@ -1033,7 +1033,9 @@ def _process_observations_w_trajectory_view_api(
|
||||
agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_raw_obs(agent_id, raw_obs)
|
||||
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
|
||||
# Infos from the environment.
|
||||
agent_infos = infos[env_id].get(agent_id, {})
|
||||
episode._set_last_info(agent_id, agent_infos)
|
||||
|
||||
# Record transition info if applicable.
|
||||
if last_observation is None:
|
||||
@@ -1058,15 +1060,20 @@ def _process_observations_w_trajectory_view_api(
|
||||
"new_obs": filtered_obs,
|
||||
}
|
||||
# Add extra-action-fetches to collectors.
|
||||
values_dict.update(**episode.last_pi_info_for(agent_id))
|
||||
pol = policies[policy_id]
|
||||
for key, value in episode.last_pi_info_for(agent_id).items():
|
||||
values_dict[key] = value
|
||||
# Env infos for this agent.
|
||||
if "infos" in pol.view_requirements:
|
||||
values_dict["infos"] = agent_infos
|
||||
_sample_collector.add_action_reward_next_obs(
|
||||
episode.episode_id, agent_id, env_id, policy_id,
|
||||
agent_done, values_dict)
|
||||
|
||||
if not agent_done:
|
||||
item = PolicyEvalData(
|
||||
env_id, agent_id, filtered_obs, infos[env_id].get(
|
||||
agent_id, {}), None if last_observation is None else
|
||||
env_id, agent_id, filtered_obs, agent_infos, None
|
||||
if last_observation is None else
|
||||
episode.rnn_state_for(agent_id), None
|
||||
if last_observation is None else
|
||||
episode.last_action_for(agent_id),
|
||||
|
||||
@@ -10,7 +10,7 @@ import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv
|
||||
from ray.rllib.evaluation.rollout_worker import RolloutWorker
|
||||
from ray.rllib.examples.policy.episode_env_aware_policy import \
|
||||
EpisodeEnvAwarePolicy
|
||||
EpisodeEnvAwareLSTMPolicy
|
||||
from ray.rllib.policy.rnn_sequencing import pad_batch_to_sequences_of_same_size
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
@@ -121,7 +121,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
obs_space = Box(-1.0, 1.0, shape=(700, ))
|
||||
|
||||
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
|
||||
|
||||
from ray.tune import register_env
|
||||
register_env("ma_env", lambda c: RandomMultiAgentEnv({
|
||||
"num_agents": 2,
|
||||
@@ -147,7 +146,6 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
"policy_mapping_fn": policy_fn,
|
||||
}
|
||||
num_iterations = 2
|
||||
# Only works in torch so far.
|
||||
for _ in framework_iterator(config, frameworks="torch"):
|
||||
print("w/ traj. view API")
|
||||
config["_use_trajectory_view_api"] = True
|
||||
@@ -253,7 +251,7 @@ class TestTrajectoryViewAPI(unittest.TestCase):
|
||||
rollout_fragment_length = 200
|
||||
assert rollout_fragment_length % max_seq_len == 0
|
||||
policies = {
|
||||
"pol0": (EpisodeEnvAwarePolicy, obs_space, action_space, {}),
|
||||
"pol0": (EpisodeEnvAwareLSTMPolicy, obs_space, action_space, {}),
|
||||
}
|
||||
|
||||
def policy_fn(agent_id):
|
||||
@@ -316,8 +314,8 @@ def analyze_rnn_batch(batch, max_seq_len):
|
||||
state_in_1 = batch["state_in_1"][idx]
|
||||
|
||||
# Check postprocessing outputs.
|
||||
if "postprocessed_column" in batch:
|
||||
postprocessed_col_t = batch["postprocessed_column"][idx]
|
||||
if "2xobs" in batch:
|
||||
postprocessed_col_t = batch["2xobs"][idx]
|
||||
assert (obs_t == postprocessed_col_t / 2.0).all()
|
||||
|
||||
# Check state-in/out and next-obs values.
|
||||
@@ -386,8 +384,8 @@ def analyze_rnn_batch(batch, max_seq_len):
|
||||
r_t = batch["rewards"][k]
|
||||
|
||||
# Check postprocessing outputs.
|
||||
if "postprocessed_column" in batch:
|
||||
postprocessed_col_t = batch["postprocessed_column"][k]
|
||||
if "2xobs" in batch:
|
||||
postprocessed_col_t = batch["2xobs"][k]
|
||||
assert (obs_t == postprocessed_col_t / 2.0).all()
|
||||
|
||||
# Check state-in/out and next-obs values.
|
||||
|
||||
@@ -8,7 +8,7 @@ from ray.rllib.policy.view_requirement import ViewRequirement
|
||||
from ray.rllib.utils.annotations import override
|
||||
|
||||
|
||||
class EpisodeEnvAwarePolicy(RandomPolicy):
|
||||
class EpisodeEnvAwareLSTMPolicy(RandomPolicy):
|
||||
"""A Policy that always knows the current EpisodeID and EnvID and
|
||||
returns these in its actions."""
|
||||
|
||||
@@ -78,5 +78,67 @@ class EpisodeEnvAwarePolicy(RandomPolicy):
|
||||
sample_batch,
|
||||
other_agent_batches=None,
|
||||
episode=None):
|
||||
sample_batch["postprocessed_column"] = sample_batch["obs"] * 2.0
|
||||
sample_batch["2xobs"] = sample_batch["obs"] * 2.0
|
||||
return sample_batch
|
||||
|
||||
|
||||
class EpisodeEnvAwareAttentionPolicy(RandomPolicy):
    """A test Policy whose actions echo the agent index, episode ID and env ID.

    Each action row is `[agent_index, eps_id, env_id]`, read from the
    input dict, so that collected batches can be checked for which
    agent/episode/env produced them. The single state output is the current
    timestep "t" of each row.

    The "state_in_0" view requirement requests the "state_out_0" values at
    relative positions -50 to -1, mimicking an attention-style memory input.
    """

    def __init__(self, *args, **kwargs):
        """Sets up the state space, a fake model and the view requirements.

        All positional and keyword args are passed on unchanged to the
        parent class' constructor.
        """
        super().__init__(*args, **kwargs)
        self.state_space = Box(-1.0, 1.0, (1, ))
        # Note: This replaces whatever "model" config was passed in.
        self.config["model"] = {"max_seq_len": 50}

        # Minimal stand-in object that only carries the
        # `inference_view_requirements` attribute set below.
        class _fake_model:
            pass

        self.model = _fake_model()
        self.model.inference_view_requirements = {
            SampleBatch.AGENT_INDEX: ViewRequirement(),
            SampleBatch.EPS_ID: ViewRequirement(),
            "env_id": ViewRequirement(),
            "t": ViewRequirement(),
            SampleBatch.OBS: ViewRequirement(),
            "state_in_0": ViewRequirement(
                "state_out_0",
                # Provide state outs -50 to -1 as "state-in".
                data_rel_pos="-50:-1",
                # Repeat the incoming state every n time steps (usually max
                # seq len).
                batch_repeat_value=self.config["model"]["max_seq_len"],
                space=self.state_space)
        }

        # The model's inference requirements override the Policy defaults
        # for keys that appear in both.
        self.view_requirements = dict(super()._get_default_view_requirements(),
                                      **self.model.inference_view_requirements)

    @override(Policy)
    def is_recurrent(self):
        """Returns True: This policy always produces state outputs."""
        return True

    @override(Policy)
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        **kwargs):
        """Computes one `[agent_index, eps_id, env_id]` action per obs row.

        Args:
            input_dict: Must provide the keys SampleBatch.AGENT_INDEX,
                SampleBatch.EPS_ID, "env_id", "t" and "obs", each indexable
                by row.
            explore: Ignored.
            timestep: Ignored.
            **kwargs: Ignored.

        Returns:
            Tuple of (actions array, list holding one state-out array with
            one `[t]` entry per row, empty extra-fetches dict).
        """
        ts = input_dict["t"]
        # Always return [agentIndex, episodeID, envID] as actions.
        actions = np.array([[
            input_dict[SampleBatch.AGENT_INDEX][i],
            input_dict[SampleBatch.EPS_ID][i], input_dict["env_id"][i]
        ] for i, _ in enumerate(input_dict["obs"])])
        # The state-out of each row is simply that row's timestep.
        states = [
            np.array([[ts[i]] for i, _ in enumerate(input_dict["obs"])])
        ]
        self.global_timestep += 1
        return actions, states, {}

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        """Adds a "3xobs" column (observations times 3.0) to the batch.

        Args:
            sample_batch: The batch to postprocess; changed in place.
            other_agent_batches: Ignored.
            episode: Ignored.

        Returns:
            The same `sample_batch` object, now with the "3xobs" column.
        """
        sample_batch["3xobs"] = sample_batch["obs"] * 3.0
        return sample_batch
|
||||
|
||||
@@ -572,6 +572,7 @@ class Policy(metaclass=ABCMeta):
|
||||
SampleBatch.INFOS: ViewRequirement(),
|
||||
SampleBatch.EPS_ID: ViewRequirement(),
|
||||
SampleBatch.AGENT_INDEX: ViewRequirement(),
|
||||
SampleBatch.UNROLL_ID: ViewRequirement(),
|
||||
"t": ViewRequirement(),
|
||||
}
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class SampleBatch:
|
||||
if self.seq_lens is not None and len(self.seq_lens) > 0:
|
||||
self.count = sum(self.seq_lens)
|
||||
else:
|
||||
self.count = len(self.data[k])
|
||||
self.count = len(next(iter(self.data.values())))
|
||||
|
||||
# Keeps track of new columns added after initial ones.
|
||||
self.new_columns = []
|
||||
|
||||
@@ -354,6 +354,8 @@ class TorchPolicy(Policy):
|
||||
)
|
||||
|
||||
train_batch = self._lazy_tensor_dict(postprocessed_batch)
|
||||
|
||||
# Calculate the actual policy loss.
|
||||
loss_out = force_list(
|
||||
self._loss(self, self.model, self.dist_class, train_batch))
|
||||
|
||||
@@ -369,6 +371,7 @@ class TorchPolicy(Policy):
|
||||
|
||||
assert len(loss_out) == len(self._optimizers)
|
||||
|
||||
# assert not any(torch.isnan(l) for l in loss_out)
|
||||
fetches = self.extra_compute_grad_fetches()
|
||||
|
||||
# Loop through all optimizers.
|
||||
@@ -376,6 +379,7 @@ class TorchPolicy(Policy):
|
||||
|
||||
all_grads = []
|
||||
for i, opt in enumerate(self._optimizers):
|
||||
# Erase gradients in all vars of this optimizer.
|
||||
opt.zero_grad()
|
||||
# Recompute gradients of loss over all variables.
|
||||
loss_out[i].backward(retain_graph=(i < len(self._optimizers) - 1))
|
||||
|
||||
@@ -184,7 +184,7 @@ class Exploration:
|
||||
Policy's own loss function and maybe the Model's custom loss.
|
||||
train_batch (SampleBatch): The training data to calculate the
|
||||
loss(es) for. This train data has already gone through
|
||||
this Exploration's `preprocess_train_batch()` method.
|
||||
this Exploration's `postprocess_trajectory()` method.
|
||||
|
||||
Returns:
|
||||
List[TensorType]: The updated list of loss terms.
|
||||
|
||||
+1
-1
@@ -66,7 +66,7 @@ def minibatches(samples, sgd_minibatch_size):
|
||||
# Replace with `if samples.seq_lens` check.
|
||||
if "state_in_0" in samples.data or "state_out_0" in samples.data:
|
||||
if log_once("not_shuffling_rnn_data_in_simple_mode"):
|
||||
logger.warning("Not shuffling RNN data for SGD in simple mode")
|
||||
logger.warning("Not time-shuffling RNN data for SGD.")
|
||||
else:
|
||||
samples.shuffle()
|
||||
|
||||
|
||||
+28
-9
@@ -37,14 +37,14 @@ def explained_variance(y, pred):
|
||||
return tf.maximum(-1.0, 1 - (diff_var / y_var))
|
||||
|
||||
|
||||
def get_placeholder(*, space=None, value=None, name=None):
|
||||
def get_placeholder(*, space=None, value=None, name=None, time_axis=False):
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
|
||||
if space is not None:
|
||||
if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
|
||||
return ModelCatalog.get_action_placeholder(space, None)
|
||||
return tf1.placeholder(
|
||||
shape=(None, ) + space.shape,
|
||||
shape=(None, ) + ((None, ) if time_axis else ()) + space.shape,
|
||||
dtype=tf.float32 if space.dtype == np.float64 else space.dtype,
|
||||
name=name,
|
||||
)
|
||||
@@ -52,8 +52,9 @@ def get_placeholder(*, space=None, value=None, name=None):
|
||||
assert value is not None
|
||||
shape = value.shape[1:]
|
||||
return tf1.placeholder(
|
||||
shape=(None, ) + (shape if isinstance(shape, tuple) else tuple(
|
||||
shape.as_list())),
|
||||
shape=(None, ) + ((None, )
|
||||
if time_axis else ()) + (shape if isinstance(
|
||||
shape, tuple) else tuple(shape.as_list())),
|
||||
dtype=tf.float32 if value.dtype == np.float64 else value.dtype,
|
||||
name=name,
|
||||
)
|
||||
@@ -132,10 +133,11 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
|
||||
|
||||
def make_wrapper(fn):
|
||||
if session_or_none:
|
||||
placeholders = []
|
||||
args_placeholders = []
|
||||
kwargs_placeholders = {}
|
||||
symbolic_out = [None]
|
||||
|
||||
def call(*args):
|
||||
def call(*args, **kwargs):
|
||||
args_flat = []
|
||||
for a in args:
|
||||
if type(a) is list:
|
||||
@@ -153,13 +155,30 @@ def make_tf_callable(session_or_none, dynamic_shape=False):
|
||||
shape = ()
|
||||
else:
|
||||
shape = v.shape
|
||||
placeholders.append(
|
||||
args_placeholders.append(
|
||||
tf1.placeholder(
|
||||
dtype=v.dtype,
|
||||
shape=shape,
|
||||
name="arg_{}".format(i)))
|
||||
symbolic_out[0] = fn(*placeholders)
|
||||
feed_dict = dict(zip(placeholders, args))
|
||||
for k, v in kwargs.items():
|
||||
if dynamic_shape:
|
||||
if len(v.shape) > 0:
|
||||
shape = (None, ) + v.shape[1:]
|
||||
else:
|
||||
shape = ()
|
||||
else:
|
||||
shape = v.shape
|
||||
kwargs_placeholders[k] = \
|
||||
tf1.placeholder(
|
||||
dtype=v.dtype,
|
||||
shape=shape,
|
||||
name="kwarg_{}".format(k))
|
||||
symbolic_out[0] = fn(*args_placeholders,
|
||||
**kwargs_placeholders)
|
||||
feed_dict = dict(zip(args_placeholders, args))
|
||||
feed_dict.update(
|
||||
{kwargs_placeholders[k]: kwargs[k]
|
||||
for k in kwargs.keys()})
|
||||
ret = session_or_none.run(symbolic_out[0], feed_dict)
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user