[rllib] MAML Transform (#9463)

* MAML Transform

* Moved Inner Adapt to Method in Execution Plan
This commit is contained in:
Michael Luo
2020-07-16 11:11:33 -07:00
committed by GitHub
parent baf4be245d
commit 94fcd43593
+61 -64
View File
@@ -6,14 +6,13 @@ from ray.rllib.agents import with_common_config
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
from ray.rllib.agents.trainer_template import build_trainer
from typing import List
from ray.rllib.evaluation.metrics import get_learner_stats
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.execution.metric_ops import CollectMetrics
from ray.util.iter import from_actors
from ray.rllib.utils.types import SampleBatchType
from ray.rllib.evaluation.metrics import collect_metrics
logger = logging.getLogger(__name__)
@@ -70,65 +69,6 @@ def set_worker_tasks(workers):
worker.foreach_env.remote(lambda env: env.set_task(tasks[i]))
class InnerAdaptationSteps:
    """Accumulate per-step sample batches until a full adaptation cycle.

    Instances are callable. Each call receives the list of sample batches
    gathered for one step, buffers it, and records reward metrics. While
    no more than ``inner_adaptation_steps`` steps are buffered, the remote
    workers are told to learn on the new samples and an empty list is
    returned. Once more than ``inner_adaptation_steps`` steps are
    buffered, everything is concatenated into a single batch and returned
    (together with the metrics dict) as a one-element list.
    """

    def __init__(self, workers, inner_adaptation_steps, metric_gen):
        """Store collaborators and start with empty buffers.

        Args:
            workers: Object exposing ``remote_workers()``.
            inner_adaptation_steps: Number of adaptation steps to run
                before a concatenated batch is emitted.
            metric_gen: Callable invoked with ``None`` to obtain a dict
                of episode reward statistics.
        """
        self.workers = workers
        self.n = inner_adaptation_steps
        self.metric_gen = metric_gen
        # Batches collected so far in the current cycle.
        self.buffer = []
        # One list of per-batch counts for every buffered step.
        self.split = []
        # Reward statistics keyed by name (+ "_adapt_<k>" suffix).
        self.metrics = {}

    def __call__(self, samples: List[SampleBatchType]):
        """Buffer one step of samples; emit a batch when the cycle ends.

        Returns:
            ``[]`` while adaptation is still in progress, otherwise
            ``[(batch, metrics)]`` where ``batch`` is the concatenation
            of all buffered samples with an extra ``"split"`` entry.
        """
        samples, step_counts = self.post_process_samples(samples)
        self.buffer.extend(samples)
        self.split.append(step_counts)
        self.post_process_metrics()

        # Still adapting: push the fresh samples to the workers and wait
        # for the next step.
        if len(self.split) <= self.n:
            self.inner_adaptation_step(samples)
            return []

        out = SampleBatch.concat_samples(self.buffer)
        out["split"] = np.array(self.split)
        self.buffer = []
        self.split = []

        # Metrics Reporting
        shared_metrics = _get_shared_metrics()
        shared_metrics.counters[STEPS_SAMPLED_COUNTER] += out.count

        # Reporting Adaptation Rew Diff
        # NOTE(review): when self.n == 0 no "_adapt_0" key is ever written
        # (the suffix is empty for the first step), so this lookup would
        # raise KeyError -- confirm 0 is not a supported setting.
        reward_before = self.metrics["episode_reward_mean"]
        reward_after = self.metrics["episode_reward_mean_adapt_" +
                                    str(self.n)]
        self.metrics["adaptation_delta"] = reward_after - reward_before
        return [(out, self.metrics)]

    def post_process_samples(self, samples):
        """Standardize advantages in place and collect batch sizes.

        Returns:
            The same ``samples`` object and a list holding ``count`` of
            every batch, in order.
        """
        counts = []
        for batch in samples:
            batch["advantages"] = standardized(batch["advantages"])
            counts.append(batch.count)
        return samples, counts

    def inner_adaptation_step(self, samples):
        """Send ``samples[i]`` to the i-th remote worker for learning."""
        for idx, worker in enumerate(self.workers.remote_workers()):
            worker.learn_on_batch.remote(samples[idx])

    def post_process_metrics(self):
        """Copy the current max/mean/min episode rewards into metrics.

        Keys are unsuffixed for the first buffered step and carry an
        ``"_adapt_<k>"`` suffix for step ``k`` afterwards.
        """
        steps_buffered = len(self.split)
        if steps_buffered > 1:
            suffix = "_adapt_" + str(steps_buffered - 1)
        else:
            suffix = ""
        res = self.metric_gen.__call__(None)
        reward_keys = (
            "episode_reward_max",
            "episode_reward_mean",
            "episode_reward_min",
        )
        for key in reward_keys:
            self.metrics[key + suffix] = res[key]
class MetaUpdate:
def __init__(self, workers, maml_steps, metric_gen):
self.workers = workers
@@ -139,6 +79,10 @@ class MetaUpdate:
# Metaupdate Step
samples = data_tuple[0]
adapt_metrics_dict = data_tuple[1]
# Metric Updating
metrics = _get_shared_metrics()
metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count
for i in range(self.maml_optimizer_steps):
fetches = self.workers.local_worker().learn_on_batch(samples)
fetches = get_learner_stats(fetches)
@@ -172,6 +116,26 @@ class MetaUpdate:
return res
def post_process_metrics(adapt_iter, workers, metrics):
    """Record the current episode reward statistics into ``metrics``.

    Args:
        adapt_iter: Index of the adaptation iteration. For values greater
            than 0 the written keys get an ``"_adapt_<adapt_iter>"``
            suffix; otherwise the plain key names are used.
        workers: Object exposing ``remote_workers()``; only the remote
            workers are passed to ``collect_metrics``.
        metrics: Dict that is updated in place.

    Returns:
        The same ``metrics`` dict, after the max/mean/min episode reward
        entries have been written.
    """
    suffix = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else ""
    # Only workers are collecting data
    res = collect_metrics(remote_workers=workers.remote_workers())
    reward_keys = (
        "episode_reward_max",
        "episode_reward_mean",
        "episode_reward_min",
    )
    for key in reward_keys:
        metrics[key + suffix] = res[key]
    return metrics
def inner_adaptation(workers, samples):
    """Ask every remote worker to learn on its own slice of ``samples``.

    The i-th entry of ``workers.remote_workers()`` receives
    ``samples[i]``. The calls are issued through ``.remote`` and their
    results are neither awaited nor returned here.

    Args:
        workers: Object exposing ``remote_workers()``.
        samples: Indexable collection with one batch per remote worker.
    """
    # Each worker performs one gradient descent
    for idx, remote_worker in enumerate(workers.remote_workers()):
        remote_worker.learn_on_batch.remote(samples[idx])
def execution_plan(workers, config):
# Sync workers with meta policy
workers.sync_weights()
@@ -186,11 +150,44 @@ def execution_plan(workers, config):
timeout_seconds=config["collect_metrics_timeout"])
# Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
inner_steps = config["inner_adaptation_steps"]
def inner_adaptation_steps(itr):
    """Group per-step samples into one batch per adaptation cycle.

    Generator that consumes ``itr`` (each item is a list of sample
    batches for one step). It reads ``workers`` and ``inner_steps`` from
    the enclosing scope. For each item the advantages are standardized in
    place, the batches are buffered and reward metrics are recorded.
    While no more than ``inner_steps`` items are buffered the workers are
    asked to adapt on the new samples; once that threshold is exceeded a
    ``(batch, metrics)`` tuple is yielded and all state is reset.

    Yields:
        Tuple of the concatenated batch (with an added ``"split"`` entry
        holding the per-step batch counts) and the metrics dict.
    """
    pending = []
    counts_per_step = []
    stats = {}
    for samples in itr:
        # Processing Samples (Standardize Advantages)
        step_counts = []
        for batch in samples:
            batch["advantages"] = standardized(batch["advantages"])
            step_counts.append(batch.count)

        pending.extend(samples)
        counts_per_step.append(step_counts)

        current_iter = len(counts_per_step) - 1
        stats = post_process_metrics(current_iter, workers, stats)

        # Cycle not finished yet: adapt on the fresh samples and move on.
        if len(counts_per_step) <= inner_steps:
            inner_adaptation(workers, samples)
            continue

        out = SampleBatch.concat_samples(pending)
        out["split"] = np.array(counts_per_step)
        pending = []
        counts_per_step = []

        # Reporting Adaptation Rew Diff
        # NOTE(review): when inner_steps == 0 no "_adapt_0" key is ever
        # written (post_process_metrics uses an empty suffix for
        # iteration 0), so this lookup would raise KeyError -- confirm 0
        # is not a supported setting.
        reward_before = stats["episode_reward_mean"]
        reward_after = stats["episode_reward_mean_adapt_" +
                             str(inner_steps)]
        stats["adaptation_delta"] = reward_after - reward_before
        yield out, stats
        # Fresh dict so the next cycle does not mutate the yielded one.
        stats = {}
rollouts = from_actors(workers.remote_workers())
rollouts = rollouts.batch_across_shards()
rollouts = rollouts.combine(
InnerAdaptationSteps(workers, config["inner_adaptation_steps"],
metric_collect))
rollouts = rollouts.transform(inner_adaptation_steps)
# Metaupdate Step
train_op = rollouts.for_each(