From b51ab2af66d6e6af890542df386ebe78e1b3cb65 Mon Sep 17 00:00:00 2001 From: Michael Luo Date: Mon, 27 Jul 2020 14:01:17 -0700 Subject: [PATCH] [RLlib] Offline Type Annotations (#9676) * Offline Annotations * Modifications * Fixed circular dependencies * Linter fix --- rllib/offline/input_reader.py | 9 ++++++--- rllib/offline/io_context.py | 9 +++++++-- rllib/offline/is_estimator.py | 3 ++- rllib/offline/json_reader.py | 17 ++++++++++------- rllib/offline/json_writer.py | 18 ++++++++++-------- rllib/offline/mixed_input.py | 7 +++++-- rllib/offline/off_policy_estimator.py | 18 +++++++++++------- rllib/offline/output_writer.py | 5 +++-- rllib/offline/shuffled_input.py | 5 +++-- rllib/offline/wis_estimator.py | 6 ++++-- rllib/utils/types.py | 3 +++ 11 files changed, 64 insertions(+), 36 deletions(-) diff --git a/rllib/offline/input_reader.py b/rllib/offline/input_reader.py index c0eeb11da..6b3c9efa1 100644 --- a/rllib/offline/input_reader.py +++ b/rllib/offline/input_reader.py @@ -5,6 +5,8 @@ import threading from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI from ray.rllib.utils.framework import try_import_tf +from typing import Dict, List +from ray.rllib.utils.types import TensorType, SampleBatchType tf1, tf, tfv = try_import_tf() @@ -25,7 +27,7 @@ class InputReader: raise NotImplementedError @PublicAPI - def tf_input_ops(self, queue_size=1): + def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]: """Returns TensorFlow queue ops for reading inputs from this reader. The main use of these ops is for integration into custom model losses. @@ -90,7 +92,8 @@ class InputReader: class _QueueRunner(threading.Thread): """Thread that feeds a TF queue from a InputReader.""" - def __init__(self, input_reader, queue, keys, dtypes): + def __init__(self, input_reader: InputReader, queue: tf1.FIFOQueue, + keys: List[str], dtypes: "tf.dtypes.DType"): threading.Thread.__init__(self) self.sess = tf1.get_default_session() self.daemon = True @@ -100,7 +103,7 @@ class _QueueRunner(threading.Thread): self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes] self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders))) - def enqueue(self, batch): + def enqueue(self, batch: SampleBatchType): data = { self.placeholders[i]: batch[key] for i, key in enumerate(self.keys) diff --git a/rllib/offline/io_context.py b/rllib/offline/io_context.py index 3b4083426..4f36bce97 100644 --- a/rllib/offline/io_context.py +++ b/rllib/offline/io_context.py @@ -1,6 +1,7 @@ import os from ray.rllib.utils.annotations import PublicAPI +from typing import Any @PublicAPI @@ -18,12 +19,16 @@ class IOContext: """ @PublicAPI - def __init__(self, log_dir=None, config=None, worker_index=0, worker=None): + def __init__(self, + log_dir: str = None, + config: dict = None, + worker_index: int = 0, + worker: Any = None): self.log_dir = log_dir or os.getcwd() self.config = config or {} self.worker_index = worker_index self.worker = worker @PublicAPI - def default_sampler_input(self): + def default_sampler_input(self) -> Any: return self.worker.sampler diff --git a/rllib/offline/is_estimator.py b/rllib/offline/is_estimator.py index 6c4ce9192..58b24c691 100644 --- a/rllib/offline/is_estimator.py +++ b/rllib/offline/is_estimator.py @@ -1,6 +1,7 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ OffPolicyEstimate from ray.rllib.utils.annotations import override +from ray.rllib.utils.types import SampleBatchType class ImportanceSamplingEstimator(OffPolicyEstimator): @@ -9,7 +10,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator): Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf""" @override(OffPolicyEstimator) - def estimate(self, batch): + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: self.check_can_estimate_for(batch) rewards, old_prob = batch["rewards"], batch["action_prob"] diff --git a/rllib/offline/json_reader.py b/rllib/offline/json_reader.py index 372349cfd..40ba37bd5 100644 --- a/rllib/offline/json_reader.py +++ b/rllib/offline/json_reader.py @@ -16,6 +16,8 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \ DEFAULT_POLICY_ID from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import unpack_if_needed +from ray.rllib.utils.types import FileType, SampleBatchType +from typing import List logger = logging.getLogger(__name__) @@ -27,7 +29,7 @@ class JsonReader(InputReader): The input files will be read from in an random order.""" @PublicAPI - def __init__(self, inputs, ioctx=None): + def __init__(self, inputs: List[str], ioctx: IOContext = None): """Initialize a JsonReader. Arguments: @@ -63,7 +65,7 @@ class JsonReader(InputReader): self.cur_file = None @override(InputReader) - def next(self): + def next(self) -> SampleBatchType: batch = self._try_parse(self._next_line()) tries = 0 while not batch and tries < 100: @@ -76,7 +78,8 @@ class JsonReader(InputReader): self.cur_file)) return self._postprocess_if_needed(batch) - def _postprocess_if_needed(self, batch): + def _postprocess_if_needed(self, + batch: SampleBatchType) -> SampleBatchType: if not self.ioctx.config.get("postprocess_inputs"): return batch @@ -92,7 +95,7 @@ class JsonReader(InputReader): raise NotImplementedError( "Postprocessing of multi-agent data not implemented yet.") - def _try_parse(self, line): + def _try_parse(self, line: str) -> SampleBatchType: line = line.strip() if not line: return None @@ -103,7 +106,7 @@ class JsonReader(InputReader): self.cur_file, line)) return None - def _next_line(self): + def _next_line(self) -> str: if not self.cur_file: self.cur_file = self._next_file() line = self.cur_file.readline() @@ -121,7 +124,7 @@ class JsonReader(InputReader): self.files)) return line - def _next_file(self): + def _next_file(self) -> FileType: path = random.choice(self.files) if urlparse(path).scheme not in ["", "c"]: if smart_open is None: @@ -133,7 +136,7 @@ class JsonReader(InputReader): return open(path, "r") -def _from_json(batch): +def _from_json(batch: str) -> SampleBatchType: if isinstance(batch, bytes): # smart_open S3 doesn't respect "r" batch = batch.decode("utf-8") data = json.loads(batch) diff --git a/rllib/offline/json_writer.py b/rllib/offline/json_writer.py index f9700eb44..c2e3cd942 100644 --- a/rllib/offline/json_writer.py +++ b/rllib/offline/json_writer.py @@ -16,6 +16,8 @@ from ray.rllib.offline.io_context import IOContext from ray.rllib.offline.output_writer import OutputWriter from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.compression import pack, compression_supported +from ray.rllib.utils.types import FileType, SampleBatchType +from typing import Any, List logger = logging.getLogger(__name__) @@ -26,10 +28,10 @@ class JsonWriter(OutputWriter): @PublicAPI def __init__(self, - path, - ioctx=None, - max_file_size=64 * 1024 * 1024, - compress_columns=frozenset(["obs", "new_obs"])): + path: str, + ioctx: IOContext = None, + max_file_size: int = 64 * 1024 * 1024, + compress_columns: List[str] = frozenset(["obs", "new_obs"])): """Initialize a JsonWriter. Arguments: @@ -59,7 +61,7 @@ class JsonWriter(OutputWriter): self.cur_file = None @override(OutputWriter) - def write(self, sample_batch): + def write(self, sample_batch: SampleBatchType): start = time.time() data = _to_json(sample_batch, self.compress_columns) f = self._get_file() @@ -72,7 +74,7 @@ class JsonWriter(OutputWriter): len(data), f, time.time() - start)) - def _get_file(self): + def _get_file(self) -> FileType: if not self.cur_file or self.bytes_written >= self.max_file_size: if self.cur_file: self.cur_file.close() @@ -94,7 +96,7 @@ class JsonWriter(OutputWriter): return self.cur_file -def _to_jsonable(v, compress): +def _to_jsonable(v, compress: bool) -> Any: if compress and compression_supported(): return str(pack(v)) elif isinstance(v, np.ndarray): @@ -102,7 +104,7 @@ def _to_jsonable(v, compress): return v -def _to_json(batch, compress_columns): +def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str: out = {} if isinstance(batch, MultiAgentBatch): out["type"] = "MultiAgentBatch" diff --git a/rllib/offline/mixed_input.py b/rllib/offline/mixed_input.py index 6c00b043c..45e4aa41a 100644 --- a/rllib/offline/mixed_input.py +++ b/rllib/offline/mixed_input.py @@ -2,7 +2,10 @@ import numpy as np from ray.rllib.offline.input_reader import InputReader from ray.rllib.offline.json_reader import JsonReader +from ray.rllib.offline.io_context import IOContext from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.types import SampleBatchType +from typing import Dict @DeveloperAPI @@ -18,7 +21,7 @@ class MixedInput(InputReader): """ @DeveloperAPI - def __init__(self, dist, ioctx): + def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext): """Initialize a MixedInput. Arguments: @@ -38,6 +41,6 @@ class MixedInput(InputReader): self.p.append(v) @override(InputReader) - def next(self): + def next(self) -> SampleBatchType: source = np.random.choice(self.choices, p=self.p) return source.next() diff --git a/rllib/offline/off_policy_estimator.py b/rllib/offline/off_policy_estimator.py index 9030b2c8d..c0c1fa849 100644 --- a/rllib/offline/off_policy_estimator.py +++ b/rllib/offline/off_policy_estimator.py @@ -2,7 +2,11 @@ from collections import namedtuple import logging from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch +from ray.rllib.policy import Policy from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.offline.io_context import IOContext +from ray.rllib.utils.types import TensorType, SampleBatchType +from typing import List logger = logging.getLogger(__name__) @@ -15,7 +19,7 @@ class OffPolicyEstimator: """Interface for an off policy reward estimator.""" @DeveloperAPI - def __init__(self, policy, gamma): + def __init__(self, policy: Policy, gamma: float): """Creates an off-policy estimator. Arguments: @@ -27,7 +31,7 @@ class OffPolicyEstimator: self.new_estimates = [] @classmethod - def create(cls, ioctx): + def create(cls, ioctx: IOContext) -> "OffPolicyEstimator": """Create an off-policy estimator from a IOContext.""" gamma = ioctx.worker.policy_config["gamma"] # Grab a reference to the current model @@ -40,7 +44,7 @@ class OffPolicyEstimator: return cls(policy, gamma) @DeveloperAPI - def estimate(self, batch): + def estimate(self, batch: SampleBatchType): """Returns an estimate for the given batch of experiences. The batch will only contain data from one episode, but it may only be @@ -49,7 +53,7 @@ class OffPolicyEstimator: raise NotImplementedError @DeveloperAPI - def action_prob(self, batch): + def action_prob(self, batch: SampleBatchType) -> TensorType: """Returns the probs for the batch actions for the current policy.""" num_state_inputs = 0 @@ -66,11 +70,11 @@ class OffPolicyEstimator: return log_likelihoods @DeveloperAPI - def process(self, batch): + def process(self, batch: SampleBatchType): self.new_estimates.append(self.estimate(batch)) @DeveloperAPI - def check_can_estimate_for(self, batch): + def check_can_estimate_for(self, batch: SampleBatchType): """Returns whether we can support OPE for this batch.""" if isinstance(batch, MultiAgentBatch): @@ -87,7 +91,7 @@ class OffPolicyEstimator: "`input_evaluation: []` to disable estimation.") @DeveloperAPI - def get_metrics(self): + def get_metrics(self) -> List[OffPolicyEstimate]: """Return a list of new episode metric estimates since the last call. Returns: diff --git a/rllib/offline/output_writer.py b/rllib/offline/output_writer.py index af7b92b66..cf9d0cc80 100644 --- a/rllib/offline/output_writer.py +++ b/rllib/offline/output_writer.py @@ -1,5 +1,6 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.types import SampleBatchType @PublicAPI @@ -7,7 +8,7 @@ class OutputWriter: """Writer object for saving experiences from policy evaluation.""" @PublicAPI - def write(self, sample_batch): + def write(self, sample_batch: SampleBatchType): """Save a batch of experiences. Arguments: @@ -20,5 +21,5 @@ class NoopOutput(OutputWriter): """Output writer that discards its outputs.""" @override(OutputWriter) - def write(self, sample_batch): + def write(self, sample_batch: SampleBatchType): pass diff --git a/rllib/offline/shuffled_input.py b/rllib/offline/shuffled_input.py index a29f03e1c..10dfc8cb7 100644 --- a/rllib/offline/shuffled_input.py +++ b/rllib/offline/shuffled_input.py @@ -3,6 +3,7 @@ import random from ray.rllib.offline.input_reader import InputReader from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.types import SampleBatchType logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class ShuffledInput(InputReader): """ @DeveloperAPI - def __init__(self, child, n=0): + def __init__(self, child: InputReader, n: int = 0): """Initialize a MixedInput. Arguments: @@ -28,7 +29,7 @@ class ShuffledInput(InputReader): self.buffer = [] @override(InputReader) - def next(self): + def next(self) -> SampleBatchType: if self.n <= 1: return self.child.next() if len(self.buffer) < self.n: diff --git a/rllib/offline/wis_estimator.py b/rllib/offline/wis_estimator.py index d4d36490c..e1bb156bb 100644 --- a/rllib/offline/wis_estimator.py +++ b/rllib/offline/wis_estimator.py @@ -1,6 +1,8 @@ from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \ OffPolicyEstimate +from ray.rllib.policy import Policy from ray.rllib.utils.annotations import override +from ray.rllib.utils.types import SampleBatchType class WeightedImportanceSamplingEstimator(OffPolicyEstimator): @@ -8,13 +10,13 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator): Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf""" - def __init__(self, policy, gamma): + def __init__(self, policy: Policy, gamma: float): super().__init__(policy, gamma) self.filter_values = [] self.filter_counts = [] @override(OffPolicyEstimator) - def estimate(self, batch): + def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate: self.check_can_estimate_for(batch) rewards, old_prob = batch["rewards"], batch["action_prob"] diff --git a/rllib/utils/types.py b/rllib/utils/types.py index 3f2e02f89..9bba67b1a 100644 --- a/rllib/utils/types.py +++ b/rllib/utils/types.py @@ -52,6 +52,9 @@ EnvActionType = Any # Info dictionary returned by calling step() on gym envs. Commonly empty dict. EnvInfoDict = dict +# Represents a File object +FileType = Any + # Represents the result dict returned by Trainer.train(). ResultDict = dict