mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 06:14:06 +08:00
[rllib] MAML Transform (#9463)
* MAML Transform * Moved Inner Adapt to Method in Execution Plan
This commit is contained in:
+61
-64
@@ -6,14 +6,13 @@ from ray.rllib.agents import with_common_config
|
||||
from ray.rllib.agents.maml.maml_tf_policy import MAMLTFPolicy
|
||||
from ray.rllib.agents.maml.maml_torch_policy import MAMLTorchPolicy
|
||||
from ray.rllib.agents.trainer_template import build_trainer
|
||||
from typing import List
|
||||
from ray.rllib.evaluation.metrics import get_learner_stats
|
||||
from ray.rllib.execution.common import STEPS_SAMPLED_COUNTER, \
|
||||
STEPS_TRAINED_COUNTER, LEARNER_INFO, _get_shared_metrics
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.execution.metric_ops import CollectMetrics
|
||||
from ray.util.iter import from_actors
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from ray.rllib.evaluation.metrics import collect_metrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,65 +69,6 @@ def set_worker_tasks(workers):
|
||||
worker.foreach_env.remote(lambda env: env.set_task(tasks[i]))
|
||||
|
||||
|
||||
class InnerAdaptationSteps:
|
||||
def __init__(self, workers, inner_adaptation_steps, metric_gen):
|
||||
self.workers = workers
|
||||
self.n = inner_adaptation_steps
|
||||
self.buffer = []
|
||||
self.split = []
|
||||
self.metrics = {}
|
||||
self.metric_gen = metric_gen
|
||||
|
||||
def __call__(self, samples: List[SampleBatchType]):
|
||||
samples, split_lst = self.post_process_samples(samples)
|
||||
self.buffer.extend(samples)
|
||||
self.split.append(split_lst)
|
||||
self.post_process_metrics()
|
||||
if len(self.split) > self.n:
|
||||
out = SampleBatch.concat_samples(self.buffer)
|
||||
out["split"] = np.array(self.split)
|
||||
self.buffer = []
|
||||
self.split = []
|
||||
|
||||
# Metrics Reporting
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += out.count
|
||||
|
||||
# Reporting Adaptation Rew Diff
|
||||
ep_rew_pre = self.metrics["episode_reward_mean"]
|
||||
ep_rew_post = self.metrics["episode_reward_mean_adapt_" +
|
||||
str(self.n)]
|
||||
self.metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
|
||||
return [(out, self.metrics)]
|
||||
else:
|
||||
self.inner_adaptation_step(samples)
|
||||
return []
|
||||
|
||||
def post_process_samples(self, samples):
|
||||
split_lst = []
|
||||
for sample in samples:
|
||||
sample["advantages"] = standardized(sample["advantages"])
|
||||
split_lst.append(sample.count)
|
||||
return samples, split_lst
|
||||
|
||||
def inner_adaptation_step(self, samples):
|
||||
for i, e in enumerate(self.workers.remote_workers()):
|
||||
e.learn_on_batch.remote(samples[i])
|
||||
|
||||
def post_process_metrics(self):
|
||||
# Obtain Current Dataset Metrics and filter out
|
||||
name = "_adapt_" + str(len(self.split) - 1) if len(
|
||||
self.split) > 1 else ""
|
||||
res = self.metric_gen.__call__(None)
|
||||
|
||||
self.metrics["episode_reward_max" +
|
||||
str(name)] = res["episode_reward_max"]
|
||||
self.metrics["episode_reward_mean" +
|
||||
str(name)] = res["episode_reward_mean"]
|
||||
self.metrics["episode_reward_min" +
|
||||
str(name)] = res["episode_reward_min"]
|
||||
|
||||
|
||||
class MetaUpdate:
|
||||
def __init__(self, workers, maml_steps, metric_gen):
|
||||
self.workers = workers
|
||||
@@ -139,6 +79,10 @@ class MetaUpdate:
|
||||
# Metaupdate Step
|
||||
samples = data_tuple[0]
|
||||
adapt_metrics_dict = data_tuple[1]
|
||||
|
||||
# Metric Updating
|
||||
metrics = _get_shared_metrics()
|
||||
metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count
|
||||
for i in range(self.maml_optimizer_steps):
|
||||
fetches = self.workers.local_worker().learn_on_batch(samples)
|
||||
fetches = get_learner_stats(fetches)
|
||||
@@ -172,6 +116,26 @@ class MetaUpdate:
|
||||
return res
|
||||
|
||||
|
||||
def post_process_metrics(adapt_iter, workers, metrics):
|
||||
# Obtain Current Dataset Metrics and filter out
|
||||
name = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else ""
|
||||
|
||||
# Only workers are collecting data
|
||||
res = collect_metrics(remote_workers=workers.remote_workers())
|
||||
|
||||
metrics["episode_reward_max" + str(name)] = res["episode_reward_max"]
|
||||
metrics["episode_reward_mean" + str(name)] = res["episode_reward_mean"]
|
||||
metrics["episode_reward_min" + str(name)] = res["episode_reward_min"]
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def inner_adaptation(workers, samples):
|
||||
# Each worker performs one gradient descent
|
||||
for i, e in enumerate(workers.remote_workers()):
|
||||
e.learn_on_batch.remote(samples[i])
|
||||
|
||||
|
||||
def execution_plan(workers, config):
|
||||
# Sync workers with meta policy
|
||||
workers.sync_weights()
|
||||
@@ -186,11 +150,44 @@ def execution_plan(workers, config):
|
||||
timeout_seconds=config["collect_metrics_timeout"])
|
||||
|
||||
# Iterator for Inner Adaptation Data gathering (from pre->post adaptation)
|
||||
inner_steps = config["inner_adaptation_steps"]
|
||||
|
||||
def inner_adaptation_steps(itr):
|
||||
buf = []
|
||||
split = []
|
||||
metrics = {}
|
||||
for samples in itr:
|
||||
|
||||
# Processing Samples (Standardize Advantages)
|
||||
split_lst = []
|
||||
for sample in samples:
|
||||
sample["advantages"] = standardized(sample["advantages"])
|
||||
split_lst.append(sample.count)
|
||||
|
||||
buf.extend(samples)
|
||||
split.append(split_lst)
|
||||
|
||||
adapt_iter = len(split) - 1
|
||||
metrics = post_process_metrics(adapt_iter, workers, metrics)
|
||||
if len(split) > inner_steps:
|
||||
out = SampleBatch.concat_samples(buf)
|
||||
out["split"] = np.array(split)
|
||||
buf = []
|
||||
split = []
|
||||
|
||||
# Reporting Adaptation Rew Diff
|
||||
ep_rew_pre = metrics["episode_reward_mean"]
|
||||
ep_rew_post = metrics["episode_reward_mean_adapt_" +
|
||||
str(inner_steps)]
|
||||
metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre
|
||||
yield out, metrics
|
||||
metrics = {}
|
||||
else:
|
||||
inner_adaptation(workers, samples)
|
||||
|
||||
rollouts = from_actors(workers.remote_workers())
|
||||
rollouts = rollouts.batch_across_shards()
|
||||
rollouts = rollouts.combine(
|
||||
InnerAdaptationSteps(workers, config["inner_adaptation_steps"],
|
||||
metric_collect))
|
||||
rollouts = rollouts.transform(inner_adaptation_steps)
|
||||
|
||||
# Metaupdate Step
|
||||
train_op = rollouts.for_each(
|
||||
|
||||
Reference in New Issue
Block a user