[RLlib] Offline Type Annotations (#9676)

* Offline Annotations

* Modifications

* Fixed circular dependencies

* Linter fix
This commit is contained in:
Michael Luo
2020-07-27 14:01:17 -07:00
committed by GitHub
parent 2e9d748100
commit b51ab2af66
11 changed files with 64 additions and 36 deletions
+6 -3
View File
@@ -5,6 +5,8 @@ import threading
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.framework import try_import_tf
from typing import Dict, List
from ray.rllib.utils.types import TensorType, SampleBatchType
tf1, tf, tfv = try_import_tf()
@@ -25,7 +27,7 @@ class InputReader:
raise NotImplementedError
@PublicAPI
def tf_input_ops(self, queue_size=1):
def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]:
"""Returns TensorFlow queue ops for reading inputs from this reader.
The main use of these ops is for integration into custom model losses.
@@ -90,7 +92,8 @@ class InputReader:
class _QueueRunner(threading.Thread):
"""Thread that feeds a TF queue from an InputReader."""
def __init__(self, input_reader, queue, keys, dtypes):
def __init__(self, input_reader: InputReader, queue: tf1.FIFOQueue,
keys: List[str], dtypes: "tf.dtypes.DType"):
threading.Thread.__init__(self)
self.sess = tf1.get_default_session()
self.daemon = True
@@ -100,7 +103,7 @@ class _QueueRunner(threading.Thread):
self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes]
self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
def enqueue(self, batch):
def enqueue(self, batch: SampleBatchType):
data = {
self.placeholders[i]: batch[key]
for i, key in enumerate(self.keys)
+7 -2
View File
@@ -1,6 +1,7 @@
import os
from ray.rllib.utils.annotations import PublicAPI
from typing import Any
@PublicAPI
@@ -18,12 +19,16 @@ class IOContext:
"""
@PublicAPI
def __init__(self, log_dir=None, config=None, worker_index=0, worker=None):
def __init__(self,
log_dir: str = None,
config: dict = None,
worker_index: int = 0,
worker: Any = None):
self.log_dir = log_dir or os.getcwd()
self.config = config or {}
self.worker_index = worker_index
self.worker = worker
@PublicAPI
def default_sampler_input(self):
def default_sampler_input(self) -> Any:
return self.worker.sampler
+2 -1
View File
@@ -1,6 +1,7 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.utils.annotations import override
from ray.rllib.utils.types import SampleBatchType
class ImportanceSamplingEstimator(OffPolicyEstimator):
@@ -9,7 +10,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
@override(OffPolicyEstimator)
def estimate(self, batch):
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
+10 -7
View File
@@ -16,6 +16,8 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \
DEFAULT_POLICY_ID
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.types import FileType, SampleBatchType
from typing import List
logger = logging.getLogger(__name__)
@@ -27,7 +29,7 @@ class JsonReader(InputReader):
The input files will be read from in a random order."""
@PublicAPI
def __init__(self, inputs, ioctx=None):
def __init__(self, inputs: List[str], ioctx: IOContext = None):
"""Initialize a JsonReader.
Arguments:
@@ -63,7 +65,7 @@ class JsonReader(InputReader):
self.cur_file = None
@override(InputReader)
def next(self):
def next(self) -> SampleBatchType:
batch = self._try_parse(self._next_line())
tries = 0
while not batch and tries < 100:
@@ -76,7 +78,8 @@ class JsonReader(InputReader):
self.cur_file))
return self._postprocess_if_needed(batch)
def _postprocess_if_needed(self, batch):
def _postprocess_if_needed(self,
batch: SampleBatchType) -> SampleBatchType:
if not self.ioctx.config.get("postprocess_inputs"):
return batch
@@ -92,7 +95,7 @@ class JsonReader(InputReader):
raise NotImplementedError(
"Postprocessing of multi-agent data not implemented yet.")
def _try_parse(self, line):
def _try_parse(self, line: str) -> SampleBatchType:
line = line.strip()
if not line:
return None
@@ -103,7 +106,7 @@ class JsonReader(InputReader):
self.cur_file, line))
return None
def _next_line(self):
def _next_line(self) -> str:
if not self.cur_file:
self.cur_file = self._next_file()
line = self.cur_file.readline()
@@ -121,7 +124,7 @@ class JsonReader(InputReader):
self.files))
return line
def _next_file(self):
def _next_file(self) -> FileType:
path = random.choice(self.files)
if urlparse(path).scheme not in ["", "c"]:
if smart_open is None:
@@ -133,7 +136,7 @@ class JsonReader(InputReader):
return open(path, "r")
def _from_json(batch):
def _from_json(batch: str) -> SampleBatchType:
if isinstance(batch, bytes): # smart_open S3 doesn't respect "r"
batch = batch.decode("utf-8")
data = json.loads(batch)
+10 -8
View File
@@ -16,6 +16,8 @@ from ray.rllib.offline.io_context import IOContext
from ray.rllib.offline.output_writer import OutputWriter
from ray.rllib.utils.annotations import override, PublicAPI
from ray.rllib.utils.compression import pack, compression_supported
from ray.rllib.utils.types import FileType, SampleBatchType
from typing import Any, List
logger = logging.getLogger(__name__)
@@ -26,10 +28,10 @@ class JsonWriter(OutputWriter):
@PublicAPI
def __init__(self,
path,
ioctx=None,
max_file_size=64 * 1024 * 1024,
compress_columns=frozenset(["obs", "new_obs"])):
path: str,
ioctx: IOContext = None,
max_file_size: int = 64 * 1024 * 1024,
compress_columns: List[str] = frozenset(["obs", "new_obs"])):
"""Initialize a JsonWriter.
Arguments:
@@ -59,7 +61,7 @@ class JsonWriter(OutputWriter):
self.cur_file = None
@override(OutputWriter)
def write(self, sample_batch):
def write(self, sample_batch: SampleBatchType):
start = time.time()
data = _to_json(sample_batch, self.compress_columns)
f = self._get_file()
@@ -72,7 +74,7 @@ class JsonWriter(OutputWriter):
len(data), f,
time.time() - start))
def _get_file(self):
def _get_file(self) -> FileType:
if not self.cur_file or self.bytes_written >= self.max_file_size:
if self.cur_file:
self.cur_file.close()
@@ -94,7 +96,7 @@ class JsonWriter(OutputWriter):
return self.cur_file
def _to_jsonable(v, compress):
def _to_jsonable(v, compress: bool) -> Any:
if compress and compression_supported():
return str(pack(v))
elif isinstance(v, np.ndarray):
@@ -102,7 +104,7 @@ def _to_jsonable(v, compress):
return v
def _to_json(batch, compress_columns):
def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
out = {}
if isinstance(batch, MultiAgentBatch):
out["type"] = "MultiAgentBatch"
+5 -2
View File
@@ -2,7 +2,10 @@ import numpy as np
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.offline.json_reader import JsonReader
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.types import SampleBatchType
from typing import Dict
@DeveloperAPI
@@ -18,7 +21,7 @@ class MixedInput(InputReader):
"""
@DeveloperAPI
def __init__(self, dist, ioctx):
def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
"""Initialize a MixedInput.
Arguments:
@@ -38,6 +41,6 @@ class MixedInput(InputReader):
self.p.append(v)
@override(InputReader)
def next(self):
def next(self) -> SampleBatchType:
source = np.random.choice(self.choices, p=self.p)
return source.next()
+11 -7
View File
@@ -2,7 +2,11 @@ from collections import namedtuple
import logging
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.offline.io_context import IOContext
from ray.rllib.utils.types import TensorType, SampleBatchType
from typing import List
logger = logging.getLogger(__name__)
@@ -15,7 +19,7 @@ class OffPolicyEstimator:
"""Interface for an off policy reward estimator."""
@DeveloperAPI
def __init__(self, policy, gamma):
def __init__(self, policy: Policy, gamma: float):
"""Creates an off-policy estimator.
Arguments:
@@ -27,7 +31,7 @@ class OffPolicyEstimator:
self.new_estimates = []
@classmethod
def create(cls, ioctx):
def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
"""Create an off-policy estimator from an IOContext."""
gamma = ioctx.worker.policy_config["gamma"]
# Grab a reference to the current model
@@ -40,7 +44,7 @@ class OffPolicyEstimator:
return cls(policy, gamma)
@DeveloperAPI
def estimate(self, batch):
def estimate(self, batch: SampleBatchType):
"""Returns an estimate for the given batch of experiences.
The batch will only contain data from one episode, but it may only be
@@ -49,7 +53,7 @@ class OffPolicyEstimator:
raise NotImplementedError
@DeveloperAPI
def action_prob(self, batch):
def action_prob(self, batch: SampleBatchType) -> TensorType:
"""Returns the probs for the batch actions for the current policy."""
num_state_inputs = 0
@@ -66,11 +70,11 @@ class OffPolicyEstimator:
return log_likelihoods
@DeveloperAPI
def process(self, batch):
def process(self, batch: SampleBatchType):
self.new_estimates.append(self.estimate(batch))
@DeveloperAPI
def check_can_estimate_for(self, batch):
def check_can_estimate_for(self, batch: SampleBatchType):
"""Returns whether we can support OPE for this batch."""
if isinstance(batch, MultiAgentBatch):
@@ -87,7 +91,7 @@ class OffPolicyEstimator:
"`input_evaluation: []` to disable estimation.")
@DeveloperAPI
def get_metrics(self):
def get_metrics(self) -> List[OffPolicyEstimate]:
"""Return a list of new episode metric estimates since the last call.
Returns:
+3 -2
View File
@@ -1,5 +1,6 @@
from ray.rllib.utils.annotations import override
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.types import SampleBatchType
@PublicAPI
@@ -7,7 +8,7 @@ class OutputWriter:
"""Writer object for saving experiences from policy evaluation."""
@PublicAPI
def write(self, sample_batch):
def write(self, sample_batch: SampleBatchType):
"""Save a batch of experiences.
Arguments:
@@ -20,5 +21,5 @@ class NoopOutput(OutputWriter):
"""Output writer that discards its outputs."""
@override(OutputWriter)
def write(self, sample_batch):
def write(self, sample_batch: SampleBatchType):
pass
+3 -2
View File
@@ -3,6 +3,7 @@ import random
from ray.rllib.offline.input_reader import InputReader
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.types import SampleBatchType
logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ class ShuffledInput(InputReader):
"""
@DeveloperAPI
def __init__(self, child, n=0):
def __init__(self, child: InputReader, n: int = 0):
"""Initialize a ShuffledInput.
Arguments:
@@ -28,7 +29,7 @@ class ShuffledInput(InputReader):
self.buffer = []
@override(InputReader)
def next(self):
def next(self) -> SampleBatchType:
if self.n <= 1:
return self.child.next()
if len(self.buffer) < self.n:
+4 -2
View File
@@ -1,6 +1,8 @@
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
OffPolicyEstimate
from ray.rllib.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.utils.types import SampleBatchType
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
@@ -8,13 +10,13 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
def __init__(self, policy, gamma):
def __init__(self, policy: Policy, gamma: float):
super().__init__(policy, gamma)
self.filter_values = []
self.filter_counts = []
@override(OffPolicyEstimator)
def estimate(self, batch):
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
self.check_can_estimate_for(batch)
rewards, old_prob = batch["rewards"], batch["action_prob"]
+3
View File
@@ -52,6 +52,9 @@ EnvActionType = Any
# Info dictionary returned by calling step() on gym envs. Commonly empty dict.
EnvInfoDict = dict
# Represents a File object
FileType = Any
# Represents the result dict returned by Trainer.train().
ResultDict = dict