From 94fcd43593a7b7675ff857cbfd6d189965899a65 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Thu, 16 Jul 2020 11:11:33 -0700 Subject: [PATCH] [rllib] MAML Transform (#9463) * MAML Transform * Moved Inner Adapt to Method in Execution Plan --- rllib/agents/maml/maml.py | 125 +++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 64 deletions(-) diff --git a/rllib/agents/maml/maml.py b/rllib/agents/maml/maml.py index 92d2d4329..1ff36c0cb 100644 --- a/rllib/agents/maml/maml.py +++ b/rllib/agents/maml/maml.py @@ -6,14 +6,13 @@ from ray.rllib.agents import with_common_config from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy from ray.rllib.agents.trainer_template import build_trainer -from typing import List from ray.rllib.evaluation.metrics import get_learner_stats from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \ STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.execution.metric_ops import CollectMetrics from ray.util.iter import from_actors -from ray.rllib.utils.types import SampleBatchType +from ray.rllib.evaluation.metrics import collect_metrics logger = logging.getLogger(__name__) @@ -70,65 +69,6 @@ def set_worker_tasks(workers): worker.foreach_env.remote(lambda env: env.set_task(tasks[i])) -class InnerAdaptationSteps: - def __init__(self, workers, inner_adaptation_steps, metric_gen): - self.workers = workers - self.n = inner_adaptation_steps - self.buffer = [] - self.split = [] - self.metrics = {} - self.metric_gen = metric_gen - - def __call__(self, samples: List[SampleBatchType]): - samples, split_lst = self.post_process_samples(samples) - self.buffer.extend(samples) - self.split.append(split_lst) - self.post_process_metrics() - if len(self.split) > self.n: - out = SampleBatch.concat_samples(self.buffer) - out["split"] = np.array(self.split) - self.buffer = [] - self.split = [] - - # Metrics Reporting - metrics = _get_shared_metrics() - metrics.counters[STEPS_SAMPLED_COUNTER] += out.count - - # Reporting Adaptation Rew Diff - ep_rew_pre = self.metrics["episode_reward_mean"] - ep_rew_post = self.metrics["episode_reward_mean_adapt_" + - str(self.n)] - self.metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre - return [(out, self.metrics)] - else: - self.inner_adaptation_step(samples) - return [] - - def post_process_samples(self, samples): - split_lst = [] - for sample in samples: - sample["advantages"] = standardized(sample["advantages"]) - split_lst.append(sample.count) - return samples, split_lst - - def inner_adaptation_step(self, samples): - for i, e in enumerate(self.workers.remote_workers()): - e.learn_on_batch.remote(samples[i]) - - def post_process_metrics(self): - # Obtain Current Dataset Metrics and filter out - name = "_adapt_" + str(len(self.split) - 1) if len( - self.split) > 1 else "" - res = self.metric_gen.__call__(None) - - self.metrics["episode_reward_max" + - str(name)] = res["episode_reward_max"] - self.metrics["episode_reward_mean" + - str(name)] = res["episode_reward_mean"] - self.metrics["episode_reward_min" + - str(name)] = res["episode_reward_min"] - - class MetaUpdate: def __init__(self, workers, maml_steps, metric_gen): self.workers = workers @@ -139,6 +79,10 @@ class MetaUpdate: # Metaupdate Step samples = data_tuple[0] adapt_metrics_dict = data_tuple[1] + + # Metric Updating + metrics = _get_shared_metrics() + metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count for i in range(self.maml_optimizer_steps): fetches = self.workers.local_worker().learn_on_batch(samples) fetches = get_learner_stats(fetches) @@ -172,6 +116,26 @@ class MetaUpdate: return res +def post_process_metrics(adapt_iter, workers, metrics): + # Obtain Current Dataset Metrics and filter out + name = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else "" + + # Only workers are collecting data + res = collect_metrics(remote_workers=workers.remote_workers()) + + metrics["episode_reward_max" + str(name)] = res["episode_reward_max"] + metrics["episode_reward_mean" + str(name)] = res["episode_reward_mean"] + metrics["episode_reward_min" + str(name)] = res["episode_reward_min"] + + return metrics + + +def inner_adaptation(workers, samples): + # Each worker performs one gradient descent + for i, e in enumerate(workers.remote_workers()): + e.learn_on_batch.remote(samples[i]) + + def execution_plan(workers, config): # Sync workers with meta policy workers.sync_weights() @@ -186,11 +150,44 @@ def execution_plan(workers, config): timeout_seconds=config["collect_metrics_timeout"]) # Iterator for Inner Adaptation Data gathering (from pre->post adaptation) + inner_steps = config["inner_adaptation_steps"] + + def inner_adaptation_steps(itr): + buf = [] + split = [] + metrics = {} + for samples in itr: + + # Processing Samples (Standardize Advantages) + split_lst = [] + for sample in samples: + sample["advantages"] = standardized(sample["advantages"]) + split_lst.append(sample.count) + + buf.extend(samples) + split.append(split_lst) + + adapt_iter = len(split) - 1 + metrics = post_process_metrics(adapt_iter, workers, metrics) + if len(split) > inner_steps: + out = SampleBatch.concat_samples(buf) + out["split"] = np.array(split) + buf = [] + split = [] + + # Reporting Adaptation Rew Diff + ep_rew_pre = metrics["episode_reward_mean"] + ep_rew_post = metrics["episode_reward_mean_adapt_" + + str(inner_steps)] + metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre + yield out, metrics + metrics = {} + else: + inner_adaptation(workers, samples) + rollouts = from_actors(workers.remote_workers()) rollouts = rollouts.batch_across_shards() - rollouts = rollouts.combine( - InnerAdaptationSteps(workers, config["inner_adaptation_steps"], - metric_collect)) + rollouts = rollouts.transform(inner_adaptation_steps) # Metaupdate Step train_op = rollouts.for_each(