mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:32:11 +08:00
[RLlib] Offline Type Annotations (#9676)
* Offline Annotations * Modifications * Fixed circular dependencies * Linter fix
This commit is contained in:
@@ -5,6 +5,8 @@ import threading
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from typing import Dict, List
|
||||
from ray.rllib.utils.types import TensorType, SampleBatchType
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
|
||||
@@ -25,7 +27,7 @@ class InputReader:
|
||||
raise NotImplementedError
|
||||
|
||||
@PublicAPI
|
||||
def tf_input_ops(self, queue_size=1):
|
||||
def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]:
|
||||
"""Returns TensorFlow queue ops for reading inputs from this reader.
|
||||
|
||||
The main use of these ops is for integration into custom model losses.
|
||||
@@ -90,7 +92,8 @@ class InputReader:
|
||||
class _QueueRunner(threading.Thread):
|
||||
"""Thread that feeds a TF queue from an InputReader."""
|
||||
|
||||
def __init__(self, input_reader, queue, keys, dtypes):
|
||||
def __init__(self, input_reader: InputReader, queue: tf1.FIFOQueue,
|
||||
keys: List[str], dtypes: "tf.dtypes.DType"):
|
||||
threading.Thread.__init__(self)
|
||||
self.sess = tf1.get_default_session()
|
||||
self.daemon = True
|
||||
@@ -100,7 +103,7 @@ class _QueueRunner(threading.Thread):
|
||||
self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes]
|
||||
self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))
|
||||
|
||||
def enqueue(self, batch):
|
||||
def enqueue(self, batch: SampleBatchType):
|
||||
data = {
|
||||
self.placeholders[i]: batch[key]
|
||||
for i, key in enumerate(self.keys)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from typing import Any
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@@ -18,12 +19,16 @@ class IOContext:
|
||||
"""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, log_dir=None, config=None, worker_index=0, worker=None):
|
||||
def __init__(self,
|
||||
log_dir: str = None,
|
||||
config: dict = None,
|
||||
worker_index: int = 0,
|
||||
worker: Any = None):
|
||||
self.log_dir = log_dir or os.getcwd()
|
||||
self.config = config or {}
|
||||
self.worker_index = worker_index
|
||||
self.worker = worker
|
||||
|
||||
@PublicAPI
|
||||
def default_sampler_input(self):
|
||||
def default_sampler_input(self) -> Any:
|
||||
return self.worker.sampler
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
|
||||
|
||||
class ImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
@@ -9,7 +10,7 @@ class ImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
Step-wise IS estimator described in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(self, batch):
|
||||
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
|
||||
self.check_can_estimate_for(batch)
|
||||
|
||||
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
||||
|
||||
@@ -16,6 +16,8 @@ from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
from ray.rllib.utils.types import FileType, SampleBatchType
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,7 +29,7 @@ class JsonReader(InputReader):
|
||||
The input files will be read from in a random order."""
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self, inputs, ioctx=None):
|
||||
def __init__(self, inputs: List[str], ioctx: IOContext = None):
|
||||
"""Initialize a JsonReader.
|
||||
|
||||
Arguments:
|
||||
@@ -63,7 +65,7 @@ class JsonReader(InputReader):
|
||||
self.cur_file = None
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
def next(self) -> SampleBatchType:
|
||||
batch = self._try_parse(self._next_line())
|
||||
tries = 0
|
||||
while not batch and tries < 100:
|
||||
@@ -76,7 +78,8 @@ class JsonReader(InputReader):
|
||||
self.cur_file))
|
||||
return self._postprocess_if_needed(batch)
|
||||
|
||||
def _postprocess_if_needed(self, batch):
|
||||
def _postprocess_if_needed(self,
|
||||
batch: SampleBatchType) -> SampleBatchType:
|
||||
if not self.ioctx.config.get("postprocess_inputs"):
|
||||
return batch
|
||||
|
||||
@@ -92,7 +95,7 @@ class JsonReader(InputReader):
|
||||
raise NotImplementedError(
|
||||
"Postprocessing of multi-agent data not implemented yet.")
|
||||
|
||||
def _try_parse(self, line):
|
||||
def _try_parse(self, line: str) -> SampleBatchType:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
return None
|
||||
@@ -103,7 +106,7 @@ class JsonReader(InputReader):
|
||||
self.cur_file, line))
|
||||
return None
|
||||
|
||||
def _next_line(self):
|
||||
def _next_line(self) -> str:
|
||||
if not self.cur_file:
|
||||
self.cur_file = self._next_file()
|
||||
line = self.cur_file.readline()
|
||||
@@ -121,7 +124,7 @@ class JsonReader(InputReader):
|
||||
self.files))
|
||||
return line
|
||||
|
||||
def _next_file(self):
|
||||
def _next_file(self) -> FileType:
|
||||
path = random.choice(self.files)
|
||||
if urlparse(path).scheme not in ["", "c"]:
|
||||
if smart_open is None:
|
||||
@@ -133,7 +136,7 @@ class JsonReader(InputReader):
|
||||
return open(path, "r")
|
||||
|
||||
|
||||
def _from_json(batch):
|
||||
def _from_json(batch: str) -> SampleBatchType:
|
||||
if isinstance(batch, bytes): # smart_open S3 doesn't respect "r"
|
||||
batch = batch.decode("utf-8")
|
||||
data = json.loads(batch)
|
||||
|
||||
@@ -16,6 +16,8 @@ from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.output_writer import OutputWriter
|
||||
from ray.rllib.utils.annotations import override, PublicAPI
|
||||
from ray.rllib.utils.compression import pack, compression_supported
|
||||
from ray.rllib.utils.types import FileType, SampleBatchType
|
||||
from typing import Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,10 +28,10 @@ class JsonWriter(OutputWriter):
|
||||
|
||||
@PublicAPI
|
||||
def __init__(self,
|
||||
path,
|
||||
ioctx=None,
|
||||
max_file_size=64 * 1024 * 1024,
|
||||
compress_columns=frozenset(["obs", "new_obs"])):
|
||||
path: str,
|
||||
ioctx: IOContext = None,
|
||||
max_file_size: int = 64 * 1024 * 1024,
|
||||
compress_columns: List[str] = frozenset(["obs", "new_obs"])):
|
||||
"""Initialize a JsonWriter.
|
||||
|
||||
Arguments:
|
||||
@@ -59,7 +61,7 @@ class JsonWriter(OutputWriter):
|
||||
self.cur_file = None
|
||||
|
||||
@override(OutputWriter)
|
||||
def write(self, sample_batch):
|
||||
def write(self, sample_batch: SampleBatchType):
|
||||
start = time.time()
|
||||
data = _to_json(sample_batch, self.compress_columns)
|
||||
f = self._get_file()
|
||||
@@ -72,7 +74,7 @@ class JsonWriter(OutputWriter):
|
||||
len(data), f,
|
||||
time.time() - start))
|
||||
|
||||
def _get_file(self):
|
||||
def _get_file(self) -> FileType:
|
||||
if not self.cur_file or self.bytes_written >= self.max_file_size:
|
||||
if self.cur_file:
|
||||
self.cur_file.close()
|
||||
@@ -94,7 +96,7 @@ class JsonWriter(OutputWriter):
|
||||
return self.cur_file
|
||||
|
||||
|
||||
def _to_jsonable(v, compress):
|
||||
def _to_jsonable(v, compress: bool) -> Any:
|
||||
if compress and compression_supported():
|
||||
return str(pack(v))
|
||||
elif isinstance(v, np.ndarray):
|
||||
@@ -102,7 +104,7 @@ def _to_jsonable(v, compress):
|
||||
return v
|
||||
|
||||
|
||||
def _to_json(batch, compress_columns):
|
||||
def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
|
||||
out = {}
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
out["type"] = "MultiAgentBatch"
|
||||
|
||||
@@ -2,7 +2,10 @@ import numpy as np
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.offline.json_reader import JsonReader
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -18,7 +21,7 @@ class MixedInput(InputReader):
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, dist, ioctx):
|
||||
def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
Arguments:
|
||||
@@ -38,6 +41,6 @@ class MixedInput(InputReader):
|
||||
self.p.append(v)
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
def next(self) -> SampleBatchType:
|
||||
source = np.random.choice(self.choices, p=self.p)
|
||||
return source.next()
|
||||
|
||||
@@ -2,7 +2,11 @@ from collections import namedtuple
|
||||
import logging
|
||||
|
||||
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import DeveloperAPI
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.utils.types import TensorType, SampleBatchType
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -15,7 +19,7 @@ class OffPolicyEstimator:
|
||||
"""Interface for an off policy reward estimator."""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, policy, gamma):
|
||||
def __init__(self, policy: Policy, gamma: float):
|
||||
"""Creates an off-policy estimator.
|
||||
|
||||
Arguments:
|
||||
@@ -27,7 +31,7 @@ class OffPolicyEstimator:
|
||||
self.new_estimates = []
|
||||
|
||||
@classmethod
|
||||
def create(cls, ioctx):
|
||||
def create(cls, ioctx: IOContext) -> "OffPolicyEstimator":
|
||||
"""Create an off-policy estimator from an IOContext."""
|
||||
gamma = ioctx.worker.policy_config["gamma"]
|
||||
# Grab a reference to the current model
|
||||
@@ -40,7 +44,7 @@ class OffPolicyEstimator:
|
||||
return cls(policy, gamma)
|
||||
|
||||
@DeveloperAPI
|
||||
def estimate(self, batch):
|
||||
def estimate(self, batch: SampleBatchType):
|
||||
"""Returns an estimate for the given batch of experiences.
|
||||
|
||||
The batch will only contain data from one episode, but it may only be
|
||||
@@ -49,7 +53,7 @@ class OffPolicyEstimator:
|
||||
raise NotImplementedError
|
||||
|
||||
@DeveloperAPI
|
||||
def action_prob(self, batch):
|
||||
def action_prob(self, batch: SampleBatchType) -> TensorType:
|
||||
"""Returns the probs for the batch actions for the current policy."""
|
||||
|
||||
num_state_inputs = 0
|
||||
@@ -66,11 +70,11 @@ class OffPolicyEstimator:
|
||||
return log_likelihoods
|
||||
|
||||
@DeveloperAPI
|
||||
def process(self, batch):
|
||||
def process(self, batch: SampleBatchType):
|
||||
self.new_estimates.append(self.estimate(batch))
|
||||
|
||||
@DeveloperAPI
|
||||
def check_can_estimate_for(self, batch):
|
||||
def check_can_estimate_for(self, batch: SampleBatchType):
|
||||
"""Returns whether we can support OPE for this batch."""
|
||||
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
@@ -87,7 +91,7 @@ class OffPolicyEstimator:
|
||||
"`input_evaluation: []` to disable estimation.")
|
||||
|
||||
@DeveloperAPI
|
||||
def get_metrics(self):
|
||||
def get_metrics(self) -> List[OffPolicyEstimate]:
|
||||
"""Return a list of new episode metric estimates since the last call.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
|
||||
|
||||
@PublicAPI
|
||||
@@ -7,7 +8,7 @@ class OutputWriter:
|
||||
"""Writer object for saving experiences from policy evaluation."""
|
||||
|
||||
@PublicAPI
|
||||
def write(self, sample_batch):
|
||||
def write(self, sample_batch: SampleBatchType):
|
||||
"""Save a batch of experiences.
|
||||
|
||||
Arguments:
|
||||
@@ -20,5 +21,5 @@ class NoopOutput(OutputWriter):
|
||||
"""Output writer that discards its outputs."""
|
||||
|
||||
@override(OutputWriter)
|
||||
def write(self, sample_batch):
|
||||
def write(self, sample_batch: SampleBatchType):
|
||||
pass
|
||||
|
||||
@@ -3,6 +3,7 @@ import random
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.utils.annotations import override, DeveloperAPI
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,7 +17,7 @@ class ShuffledInput(InputReader):
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
def __init__(self, child, n=0):
|
||||
def __init__(self, child: InputReader, n: int = 0):
|
||||
"""Initialize a ShuffledInput.
|
||||
|
||||
Arguments:
|
||||
@@ -28,7 +29,7 @@ class ShuffledInput(InputReader):
|
||||
self.buffer = []
|
||||
|
||||
@override(InputReader)
|
||||
def next(self):
|
||||
def next(self) -> SampleBatchType:
|
||||
if self.n <= 1:
|
||||
return self.child.next()
|
||||
if len(self.buffer) < self.n:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from ray.rllib.offline.off_policy_estimator import OffPolicyEstimator, \
|
||||
OffPolicyEstimate
|
||||
from ray.rllib.policy import Policy
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.types import SampleBatchType
|
||||
|
||||
|
||||
class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
@@ -8,13 +10,13 @@ class WeightedImportanceSamplingEstimator(OffPolicyEstimator):
|
||||
|
||||
Step-wise WIS estimator in https://arxiv.org/pdf/1511.03722.pdf"""
|
||||
|
||||
def __init__(self, policy, gamma):
|
||||
def __init__(self, policy: Policy, gamma: float):
|
||||
super().__init__(policy, gamma)
|
||||
self.filter_values = []
|
||||
self.filter_counts = []
|
||||
|
||||
@override(OffPolicyEstimator)
|
||||
def estimate(self, batch):
|
||||
def estimate(self, batch: SampleBatchType) -> OffPolicyEstimate:
|
||||
self.check_can_estimate_for(batch)
|
||||
|
||||
rewards, old_prob = batch["rewards"], batch["action_prob"]
|
||||
|
||||
@@ -52,6 +52,9 @@ EnvActionType = Any
|
||||
# Info dictionary returned by calling step() on gym envs. Commonly empty dict.
|
||||
EnvInfoDict = dict
|
||||
|
||||
# Represents a File object
|
||||
FileType = Any
|
||||
|
||||
# Represents the result dict returned by Trainer.train().
|
||||
ResultDict = dict
|
||||
|
||||
|
||||
Reference in New Issue
Block a user