From 587f207c2fadc02c25ddf1dedbca4cbaf3163d48 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Thu, 21 Jan 2021 07:43:55 -0800 Subject: [PATCH] [RLlib] Support for D4RL + Semi-working CQL Benchmark (#13550) --- rllib/agents/cql/cql.py | 2 + rllib/evaluation/worker_set.py | 5 +- rllib/offline/__init__.py | 2 + rllib/offline/d4rl_reader.py | 52 +++++++++++++++++++ rllib/tuned_examples/cql/halfcheetah-cql.yaml | 1 + 5 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 rllib/offline/d4rl_reader.py diff --git a/rllib/agents/cql/cql.py b/rllib/agents/cql/cql.py index 04a63be72..30bbe89d4 100644 --- a/rllib/agents/cql/cql.py +++ b/rllib/agents/cql/cql.py @@ -15,6 +15,8 @@ CQL_DEFAULT_CONFIG = merge_dicts( SAC_CONFIG, { # You should override this to point to an offline dataset. "input": "sampler", + # Offline RL does not need IS estimators + "input_evaluation": [], # Number of iterations with Behavior Cloning Pretraining "bc_iters": 20000, # CQL Loss Temperature diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 80cf617bb..8361e0af8 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -8,7 +8,7 @@ from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.evaluation.rollout_worker import RolloutWorker, \ _validate_multiagent_config from ray.rllib.offline import NoopOutput, JsonReader, MixedInput, JsonWriter, \ - ShuffledInput + ShuffledInput, D4RLReader from ray.rllib.env.env_context import EnvContext from ray.rllib.policy import Policy from ray.rllib.utils import merge_dicts @@ -266,6 +266,9 @@ class WorkerSet: input_creator = ( lambda ioctx: ShuffledInput(MixedInput(config["input"], ioctx), config["shuffle_buffer_size"])) + elif "d4rl" in config["input"]: + env_name = config["input"].split(".")[1] + input_creator = (lambda ioctx: D4RLReader(env_name, ioctx)) else: input_creator = ( lambda ioctx: ShuffledInput(JsonReader(config["input"], ioctx), diff --git a/rllib/offline/__init__.py b/rllib/offline/__init__.py index 69b07c657..540151cc2 100644 --- a/rllib/offline/__init__.py +++ b/rllib/offline/__init__.py @@ -5,6 +5,7 @@ from ray.rllib.offline.output_writer import OutputWriter, NoopOutput from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.mixed_input import MixedInput from ray.rllib.offline.shuffled_input import ShuffledInput +from ray.rllib.offline.d4rl_reader import D4RLReader __all__ = [ "IOContext", @@ -15,4 +16,5 @@ __all__ = [ "InputReader", "MixedInput", "ShuffledInput", + "D4RLReader", ] diff --git a/rllib/offline/d4rl_reader.py b/rllib/offline/d4rl_reader.py new file mode 100644 index 000000000..2c02af088 --- /dev/null +++ b/rllib/offline/d4rl_reader.py @@ -0,0 +1,52 @@ +import logging +import gym + +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.io_context import IOContext +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.typing import SampleBatchType +from typing import Dict + +logger = logging.getLogger(__name__) + + +@PublicAPI +class D4RLReader(InputReader): + """Reader object that loads the dataset from the D4RL dataset.""" + + @PublicAPI + def __init__(self, inputs: str, ioctx: IOContext = None): + """Initialize a D4RLReader. + + Args: + inputs (str): String corresponding to D4RL environment name + ioctx (IOContext): Current IO context object. + """ + import d4rl + self.env = gym.make(inputs) + self.dataset = convert_to_batch(d4rl.qlearning_dataset(self.env)) + assert self.dataset.count >= 1 + self.dataset.shuffle() + self.counter = 0 + + @override(InputReader) + def next(self) -> SampleBatchType: + if self.counter >= self.dataset.count: + self.counter = 0 + self.dataset.shuffle() + + self.counter += 1 + return self.dataset.slice(start=self.counter, end=self.counter + 1) + + +def convert_to_batch(dataset: Dict) -> SampleBatchType: + # Converts D4RL dataset to SampleBatch + d = {} + d[SampleBatch.OBS] = dataset["observations"] + d[SampleBatch.ACTIONS] = dataset["actions"] + d[SampleBatch.NEXT_OBS] = dataset["next_observations"] + d[SampleBatch.REWARDS] = dataset["rewards"] + d[SampleBatch.DONES] = dataset["terminals"] + + return SampleBatch(d) diff --git a/rllib/tuned_examples/cql/halfcheetah-cql.yaml b/rllib/tuned_examples/cql/halfcheetah-cql.yaml index 5bab20751..9a5fa9982 100644 --- a/rllib/tuned_examples/cql/halfcheetah-cql.yaml +++ b/rllib/tuned_examples/cql/halfcheetah-cql.yaml @@ -5,6 +5,7 @@ halfcheetah_cql: episode_reward_mean: 9000 config: # SAC Configs + input: d4rl.halfcheetah-medium-v0 framework: torch horizon: 1000 soft_horizon: false