mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[rllib] Documentation for I/O API and multi-agent support / cleanup (#3650)
This commit is contained in:
@@ -130,7 +130,8 @@ COMMON_CONFIG = {
|
||||
# Drop metric batches from unresponsive workers after this many seconds
|
||||
"collect_metrics_timeout": 180,
|
||||
|
||||
# === Offline Data Input / Output (Experimental) ===
|
||||
# === Offline Data Input / Output ===
|
||||
# __sphinx_doc_input_begin__
|
||||
# Specify how to generate experiences:
|
||||
# - "sampler": generate experiences via online simulation (default)
|
||||
# - a local directory or file glob expression (e.g., "/tmp/*.json")
|
||||
@@ -146,9 +147,14 @@ COMMON_CONFIG = {
|
||||
# metrics will be NaN if using offline data.
|
||||
# - "simulation": run the environment in the background, but use
|
||||
# this data for evaluation only and not for learning.
|
||||
# - "counterfactual": use counterfactual policy evaluation to estimate
|
||||
# performance (this option is not implemented yet).
|
||||
"input_evaluation": None,
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Note that postprocessing will be done using the *current*
|
||||
# policy, not the *behaviour* policy, which is typically undesirable for
|
||||
# on-policy algorithms.
|
||||
"postprocess_inputs": False,
|
||||
# __sphinx_doc_input_end__
|
||||
# __sphinx_doc_output_begin__
|
||||
# Specify where experiences should be saved:
|
||||
# - None: don't save any experiences
|
||||
# - "logdir" to save to the agent log dir
|
||||
@@ -159,10 +165,7 @@ COMMON_CONFIG = {
|
||||
"output_compress_columns": ["obs", "new_obs"],
|
||||
# Max output file size before rolling over to a new file.
|
||||
"output_max_file_size": 64 * 1024 * 1024,
|
||||
# Whether to run postprocess_trajectory() on the trajectory fragments from
|
||||
# offline inputs. Whether this makes sense is algorithm-specific.
|
||||
# TODO(ekl) implement this and multi-agent batch handling
|
||||
# "postprocess_inputs": False,
|
||||
# __sphinx_doc_output_end__
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
@@ -503,9 +506,9 @@ class Agent(Trainable):
|
||||
elif config["input"] == "sampler":
|
||||
input_creator = (lambda ioctx: ioctx.default_sampler_input())
|
||||
elif isinstance(config["input"], dict):
|
||||
input_creator = (lambda ioctx: MixedInput(ioctx, config["input"]))
|
||||
input_creator = (lambda ioctx: MixedInput(config["input"], ioctx))
|
||||
else:
|
||||
input_creator = (lambda ioctx: JsonReader(ioctx, config["input"]))
|
||||
input_creator = (lambda ioctx: JsonReader(config["input"], ioctx))
|
||||
|
||||
if isinstance(config["output"], FunctionType):
|
||||
output_creator = config["output"]
|
||||
@@ -513,14 +516,14 @@ class Agent(Trainable):
|
||||
output_creator = (lambda ioctx: NoopOutput())
|
||||
elif config["output"] == "logdir":
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
ioctx,
|
||||
ioctx.log_dir,
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
else:
|
||||
output_creator = (lambda ioctx: JsonWriter(
|
||||
ioctx,
|
||||
config["output"],
|
||||
ioctx,
|
||||
max_file_size=config["output_max_file_size"],
|
||||
compress_columns=config["output_compress_columns"]))
|
||||
|
||||
|
||||
@@ -187,8 +187,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
other metrics will be NaN.
|
||||
- "simulation": run the environment in the background, but
|
||||
use this data for evaluation only and never for learning.
|
||||
- "counterfactual": use counterfactual policy evaluation to
|
||||
estimate performance.
|
||||
output_creator (func): Function that returns an OutputWriter object
|
||||
for saving generated experiences.
|
||||
"""
|
||||
@@ -309,8 +307,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
"Requested 'simulation' input evaluation method: "
|
||||
"will discard all sampler outputs and keep only metrics.")
|
||||
sample_async = True
|
||||
elif input_evaluation_method == "counterfactual":
|
||||
raise NotImplementedError
|
||||
elif input_evaluation_method is None:
|
||||
pass
|
||||
else:
|
||||
@@ -388,6 +384,10 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
"samples": batch
|
||||
})
|
||||
|
||||
# Always do writes prior to compression for consistency and to allow
|
||||
# for better compression inside the writer.
|
||||
self.output_writer.write(batch)
|
||||
|
||||
if self.compress_observations:
|
||||
if isinstance(batch, MultiAgentBatch):
|
||||
for data in batch.policy_batches.values():
|
||||
@@ -397,7 +397,6 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
batch["obs"] = [pack(o) for o in batch["obs"]]
|
||||
batch["new_obs"] = [pack(o) for o in batch["new_obs"]]
|
||||
|
||||
self.output_writer.write(batch)
|
||||
return batch
|
||||
|
||||
@ray.method(num_return_vals=2)
|
||||
|
||||
@@ -306,10 +306,48 @@ class SampleBatch(object):
|
||||
return out
|
||||
|
||||
def shuffle(self):
    """Shuffles the rows of this batch in-place.

    One random permutation of length ``self.count`` is drawn and applied to
    every column, so rows stay aligned across columns.
    """
    # NOTE(review): assumes every column supports integer-array indexing
    # (e.g. numpy arrays) -- plain Python lists would fail here.
    permutation = np.random.permutation(self.count)
    for key, val in self.items():
        self[key] = val[permutation]
|
||||
|
||||
def split_by_episode(self):
    """Splits this batch's data by `eps_id`.

    Rows are grouped into maximal runs of consecutive rows that share the
    same `eps_id`; a new slice starts whenever the id changes from one row
    to the next.

    Returns:
        list of SampleBatch, one per run of consecutive rows with equal
        `eps_id`. An empty batch yields an empty list.
    """
    # Guard: indexing eps_id[0] on an empty batch would raise IndexError.
    if self.count == 0:
        return []

    eps_ids = self.data["eps_id"]
    slices = []
    offset = 0
    for i in range(1, self.count):
        # Within a run every id equals the run's first id, so comparing
        # against the previous row detects the same boundaries.
        if eps_ids[i] != eps_ids[i - 1]:
            slices.append(self.slice(offset, i))
            offset = i
    # Flush the final run, which always ends at the last row.
    slices.append(self.slice(offset, self.count))
    return slices
|
||||
|
||||
def slice(self, start, end):
    """Returns a slice of the row data of this batch.

    Arguments:
        start (int): Starting index.
        end (int): Ending index.

    Returns:
        SampleBatch which has a slice of this batch's data.
    """
    # Every column is cut with the same [start:end] range so that rows stay
    # aligned across columns in the new batch.
    return SampleBatch({k: v[start:end] for k, v in self.data.items()})
|
||||
|
||||
def __getitem__(self, key):
    """Returns the column stored under `key` in this batch's data."""
    return self.data[key]
|
||||
|
||||
|
||||
@@ -175,10 +175,10 @@ if __name__ == "__main__":
|
||||
}
|
||||
elif args.run == "DQN":
|
||||
cfg = {
|
||||
"hiddens": [], # don't postprocess the action scores
|
||||
"hiddens": [], # important: don't postprocess the action scores
|
||||
}
|
||||
else:
|
||||
cfg = {}
|
||||
cfg = {} # PG, IMPALA, A2C, etc.
|
||||
run_experiments({
|
||||
"parametric_cartpole": {
|
||||
"run": args.run,
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
"""Simple example of writing experiences to a file using JsonWriter."""
|
||||
|
||||
# __sphinx_doc_begin__
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatchBuilder
|
||||
from ray.rllib.offline.json_writer import JsonWriter
|
||||
|
||||
if __name__ == "__main__":
    batch_builder = SampleBatchBuilder()  # or MultiAgentSampleBatchBuilder
    writer = JsonWriter("/tmp/demo-out")

    # You normally wouldn't want to manually create sample batches if a
    # simulator is available, but let's do it anyways for example purposes:
    env = gym.make("CartPole-v0")

    for eps_id in range(100):
        obs = env.reset()
        # Zero-valued placeholder with the same shape/dtype as a real action,
        # used as the "previous action" for the first step of the episode.
        prev_action = np.zeros_like(env.action_space.sample())
        prev_reward = 0
        done = False
        t = 0
        while not done:
            # Random policy: sample an action and record the transition.
            action = env.action_space.sample()
            new_obs, rew, done, info = env.step(action)
            batch_builder.add_values(
                t=t,
                eps_id=eps_id,
                agent_index=0,
                obs=obs,
                actions=action,
                rewards=rew,
                prev_actions=prev_action,
                prev_rewards=prev_reward,
                dones=done,
                infos=info,
                new_obs=new_obs)
            obs = new_obs
            prev_action = action
            prev_reward = rew
            t += 1
        # Write one batch per finished episode and reset the builder.
        writer.write(batch_builder.build_and_reset())
|
||||
# __sphinx_doc_end__
|
||||
@@ -9,8 +9,11 @@ class InputReader(object):
|
||||
"""Input object for loading experiences in policy evaluation."""
|
||||
|
||||
def next(self):
    """Return the next batch of experiences read.

    Returns:
        SampleBatch or MultiAgentBatch read.
    """
    # Base-class placeholder: concrete readers must provide this.
    raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.rllib.offline.input_reader import SamplerInput
|
||||
|
||||
|
||||
@@ -18,9 +20,13 @@ class IOContext(object):
|
||||
evaluator (PolicyEvaluator): policy evaluator object reference.
|
||||
"""
|
||||
|
||||
def __init__(self,
             log_dir=None,
             config=None,
             worker_index=0,
             evaluator=None):
    """Initialize an IOContext; every argument is optional.

    Arguments:
        log_dir (str): log directory; falls back to the current working
            directory when not given (or empty).
        config (dict): configuration dict; falls back to an empty dict when
            not given (or empty).
        worker_index (int): index stored on the context as-is.
        evaluator (PolicyEvaluator): policy evaluator object reference,
            stored as-is (may be None).
    """
    self.log_dir = log_dir or os.getcwd()
    self.config = config or {}
    self.worker_index = worker_index
    self.evaluator = evaluator
|
||||
|
||||
|
||||
@@ -16,7 +16,9 @@ except ImportError:
|
||||
smart_open = None
|
||||
|
||||
from ray.rllib.offline.input_reader import InputReader
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, SampleBatch, \
|
||||
DEFAULT_POLICY_ID
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import unpack_if_needed
|
||||
|
||||
@@ -28,17 +30,17 @@ class JsonReader(InputReader):
|
||||
|
||||
The input files will be read from in a random order."""
|
||||
|
||||
def __init__(self, ioctx, inputs):
|
||||
def __init__(self, inputs, ioctx=None):
|
||||
"""Initialize a JsonReader.
|
||||
|
||||
Arguments:
|
||||
ioctx (IOContext): current IO context object.
|
||||
inputs (str|list): either a glob expression for files, e.g.,
|
||||
"/tmp/**/*.json", or a list of single file paths or URIs, e.g.,
|
||||
["s3://bucket/file.json", "s3://bucket/file2.json"].
|
||||
ioctx (IOContext): current IO context object.
|
||||
"""
|
||||
|
||||
self.ioctx = ioctx
|
||||
self.ioctx = ioctx or IOContext()
|
||||
if isinstance(inputs, six.string_types):
|
||||
if os.path.isdir(inputs):
|
||||
inputs = os.path.join(inputs, "*.json")
|
||||
@@ -74,7 +76,23 @@ class JsonReader(InputReader):
|
||||
raise ValueError(
|
||||
"Failed to read valid experience batch from file: {}".format(
|
||||
self.cur_file))
|
||||
return batch
|
||||
return self._postprocess_if_needed(batch)
|
||||
|
||||
def _postprocess_if_needed(self, batch):
    """Optionally run trajectory postprocessing on a batch that was read.

    When the IO context's config has a falsy or missing "postprocess_inputs"
    entry, the batch is returned untouched. Otherwise a SampleBatch is split
    by episode, each piece is passed to postprocess_trajectory() of the
    evaluator's DEFAULT_POLICY_ID policy, and the results are concatenated.

    Arguments:
        batch (SampleBatch|MultiAgentBatch): batch that was just read.

    Returns:
        The original batch, or the concatenated postprocessed SampleBatch.

    Raises:
        NotImplementedError: if postprocessing is enabled and the batch is
            not a SampleBatch (multi-agent data is not supported yet).
    """
    if not self.ioctx.config.get("postprocess_inputs"):
        return batch

    if isinstance(batch, SampleBatch):
        out = []
        for sub_batch in batch.split_by_episode():
            out.append(self.ioctx.evaluator.policy_map[DEFAULT_POLICY_ID]
                       .postprocess_trajectory(sub_batch))
        return SampleBatch.concat_samples(out)
    else:
        # TODO(ekl) this is trickier since the alignments between agent
        # trajectories in the episode are not available any more.
        raise NotImplementedError(
            "Postprocessing of multi-agent data not implemented yet.")
|
||||
|
||||
def _try_parse(self, line):
|
||||
line = line.strip()
|
||||
@@ -121,6 +139,25 @@ def _from_json(batch):
|
||||
if isinstance(batch, bytes): # smart_open S3 doesn't respect "r"
|
||||
batch = batch.decode("utf-8")
|
||||
data = json.loads(batch)
|
||||
for k, v in data.items():
|
||||
data[k] = [unpack_if_needed(x) for x in unpack_if_needed(v)]
|
||||
return SampleBatch(data)
|
||||
|
||||
if "type" in data:
|
||||
data_type = data.pop("type")
|
||||
else:
|
||||
raise ValueError("JSON record missing 'type' field")
|
||||
|
||||
if data_type == "SampleBatch":
|
||||
for k, v in data.items():
|
||||
data[k] = unpack_if_needed(v)
|
||||
return SampleBatch(data)
|
||||
elif data_type == "MultiAgentBatch":
|
||||
policy_batches = {}
|
||||
for policy_id, policy_batch in data["policy_batches"].items():
|
||||
inner = {}
|
||||
for k, v in policy_batch.items():
|
||||
inner[k] = unpack_if_needed(v)
|
||||
policy_batches[policy_id] = SampleBatch(inner)
|
||||
return MultiAgentBatch(policy_batches, data["count"])
|
||||
else:
|
||||
raise ValueError(
|
||||
"Type field must be one of ['SampleBatch', 'MultiAgentBatch']",
|
||||
data_type)
|
||||
|
||||
@@ -15,6 +15,8 @@ try:
|
||||
except ImportError:
|
||||
smart_open = None
|
||||
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
from ray.rllib.offline.io_context import IOContext
|
||||
from ray.rllib.offline.output_writer import OutputWriter
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.compression import pack
|
||||
@@ -26,21 +28,21 @@ class JsonWriter(OutputWriter):
|
||||
"""Writer object that saves experiences in JSON file chunks."""
|
||||
|
||||
def __init__(self,
|
||||
ioctx,
|
||||
path,
|
||||
ioctx=None,
|
||||
max_file_size=64 * 1024 * 1024,
|
||||
compress_columns=frozenset(["obs", "new_obs"])):
|
||||
"""Initialize a JsonWriter.
|
||||
|
||||
Arguments:
|
||||
ioctx (IOContext): current IO context object.
|
||||
path (str): a path/URI of the output directory to save files in.
|
||||
ioctx (IOContext): current IO context object.
|
||||
max_file_size (int): max size of single files before rolling over.
|
||||
compress_columns (list): list of sample batch columns to compress.
|
||||
"""
|
||||
|
||||
self.ioctx = ioctx
|
||||
self.path = path
|
||||
self.ioctx = ioctx or IOContext()
|
||||
self.max_file_size = max_file_size
|
||||
self.compress_columns = compress_columns
|
||||
if urlparse(path).scheme:
|
||||
@@ -102,7 +104,19 @@ def _to_jsonable(v, compress):
|
||||
|
||||
|
||||
def _to_json(batch, compress_columns):
    """Serialize a batch to a JSON string tagged with its batch type.

    A MultiAgentBatch is written as {"type": "MultiAgentBatch", "count": ...,
    "policy_batches": {policy_id: {column: value}}}. Anything else is
    treated as a SampleBatch and written as {"type": "SampleBatch",
    column: value, ...}.

    Arguments:
        batch (SampleBatch|MultiAgentBatch): batch to serialize.
        compress_columns (list): names of columns for which _to_jsonable()
            is called with compress=True.

    Returns:
        str: the JSON-encoded record.
    """
    out = {}
    if isinstance(batch, MultiAgentBatch):
        out["type"] = "MultiAgentBatch"
        out["count"] = batch.count
        policy_batches = {}
        for policy_id, sub_batch in batch.policy_batches.items():
            policy_batches[policy_id] = {}
            for k, v in sub_batch.data.items():
                policy_batches[policy_id][k] = _to_jsonable(
                    v, compress=k in compress_columns)
        out["policy_batches"] = policy_batches
    else:
        out["type"] = "SampleBatch"
        for k, v in batch.data.items():
            out[k] = _to_jsonable(v, compress=k in compress_columns)
    return json.dumps(out)
|
||||
|
||||
@@ -13,20 +13,20 @@ class MixedInput(InputReader):
|
||||
"""Mixes input from a number of other input sources.
|
||||
|
||||
Examples:
|
||||
>>> MixedInput(ioctx, {
|
||||
>>> MixedInput({
|
||||
"sampler": 0.4,
|
||||
"/tmp/experiences/*.json": 0.4,
|
||||
"s3://bucket/expert.json": 0.2,
|
||||
})
|
||||
}, ioctx)
|
||||
"""
|
||||
|
||||
def __init__(self, ioctx, dist):
|
||||
def __init__(self, dist, ioctx):
|
||||
"""Initialize a MixedInput.
|
||||
|
||||
Arguments:
|
||||
ioctx (IOContext): current IO context object.
|
||||
dist (dict): dict mapping JSONReader paths or "sampler" to
|
||||
probabilities. The probabilities must sum to 1.0.
|
||||
ioctx (IOContext): current IO context object.
|
||||
"""
|
||||
if sum(dist.values()) != 1.0:
|
||||
raise ValueError("Values must sum to 1.0: {}".format(dist))
|
||||
@@ -36,7 +36,7 @@ class MixedInput(InputReader):
|
||||
if k == "sampler":
|
||||
self.choices.append(ioctx.default_sampler_input())
|
||||
else:
|
||||
self.choices.append(JsonReader(ioctx, k))
|
||||
self.choices.append(JsonReader(k))
|
||||
self.p.append(v)
|
||||
|
||||
@override(InputReader)
|
||||
|
||||
@@ -21,7 +21,7 @@ if __name__ == '__main__':
|
||||
print(yaml.dump(experiments))
|
||||
|
||||
for i in range(3):
|
||||
trials = run_experiments(experiments)
|
||||
trials = run_experiments(experiments, resume=False)
|
||||
|
||||
num_failures = 0
|
||||
for t in trials:
|
||||
|
||||
@@ -3,8 +3,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import glob
|
||||
import gym
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
@@ -12,13 +15,17 @@ import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg import PGAgent
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.evaluation import SampleBatch
|
||||
from ray.rllib.offline import IOContext, JsonWriter, JsonReader
|
||||
from ray.rllib.offline.json_writer import _to_json
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
SAMPLES = SampleBatch({
|
||||
"actions": np.array([1, 2, 3]),
|
||||
"obs": np.array([4, 5, 6])
|
||||
"actions": np.array([1, 2, 3, 4]),
|
||||
"obs": np.array([4, 5, 6, 7]),
|
||||
"eps_id": [1, 1, 2, 3],
|
||||
})
|
||||
|
||||
|
||||
@@ -49,8 +56,7 @@ class AgentIOTest(unittest.TestCase):
|
||||
def testAgentOutputOk(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 1)
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
reader = JsonReader(ioctx, self.test_dir + "/*.json")
|
||||
reader = JsonReader(self.test_dir + "/*.json")
|
||||
reader.next()
|
||||
|
||||
def testAgentOutputLogdir(self):
|
||||
@@ -69,6 +75,40 @@ class AgentIOTest(unittest.TestCase):
|
||||
self.assertEqual(result["timesteps_total"], 250) # read from input
|
||||
self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testSplitByEpisode(self):
    """split_by_episode() yields one slice per run of equal eps_id."""
    # SAMPLES has eps_id [1, 1, 2, 3] -> slices of sizes 2, 1 and 1.
    splits = SAMPLES.split_by_episode()
    self.assertEqual(len(splits), 3)
    self.assertEqual(splits[0].count, 2)
    self.assertEqual(splits[1].count, 1)
    self.assertEqual(splits[2].count, 1)
|
||||
|
||||
def testAgentInputPostprocessingEnabled(self):
    """Training from offline input works when postprocessing is enabled.

    The recorded files are rewritten without the "advantages" and
    "value_targets" columns; "postprocess_inputs" is then expected to add
    'advantages' back when the data is read.
    """
    self.writeOutputs(self.test_dir)

    # Rewrite the files to drop advantages and value_targets for testing
    for path in glob.glob(self.test_dir + "/*.json"):
        out = []
        # Use a context manager so the read handle is closed before the
        # same path is reopened for writing.
        with open(path) as f:
            for line in f:
                data = json.loads(line)
                del data["advantages"]
                del data["value_targets"]
                out.append(data)
        with open(path, "w") as f:
            for data in out:
                f.write(json.dumps(data))
                # One record per line; without the newline several records
                # would be glued together into a single unparsable line.
                f.write("\n")

    agent = PGAgent(
        env="CartPole-v0",
        config={
            "input": self.test_dir,
            "input_evaluation": None,
            "postprocess_inputs": True,  # adds back 'advantages'
        })

    result = agent.train()
    self.assertEqual(result["timesteps_total"], 250)  # read from input
    self.assertTrue(np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testAgentInputEvalSim(self):
|
||||
self.writeOutputs(self.test_dir)
|
||||
agent = PGAgent(
|
||||
@@ -112,6 +152,58 @@ class AgentIOTest(unittest.TestCase):
|
||||
result = agent.train()
|
||||
self.assertTrue(not np.isnan(result["episode_reward_mean"]))
|
||||
|
||||
def testMultiAgent(self):
    """Multi-agent experiences can be written and then used as input.

    A first agent trains once with "output" set to the test directory; a
    second agent reads that directory as "input" with "simulation" input
    evaluation and is trained until a non-NaN episode reward shows up.
    """
    register_env("multi_cartpole", lambda _: MultiCartpole(10))
    single_env = gym.make("CartPole-v0")

    def gen_policy():
        obs_space = single_env.observation_space
        act_space = single_env.action_space
        return (PGPolicyGraph, obs_space, act_space, {})

    pg = PGAgent(
        env="multi_cartpole",
        config={
            "num_workers": 0,
            "output": self.test_dir,
            "multiagent": {
                "policy_graphs": {
                    "policy_1": gen_policy(),
                    "policy_2": gen_policy(),
                },
                "policy_mapping_fn": (
                    lambda agent_id: random.choice(
                        ["policy_1", "policy_2"])),
            },
        })
    pg.train()
    self.assertEqual(len(os.listdir(self.test_dir)), 1)

    pg.stop()
    pg = PGAgent(
        env="multi_cartpole",
        config={
            "num_workers": 0,
            "input": self.test_dir,
            "input_evaluation": "simulation",
            "train_batch_size": 2000,
            "multiagent": {
                "policy_graphs": {
                    "policy_1": gen_policy(),
                    "policy_2": gen_policy(),
                },
                "policy_mapping_fn": (
                    lambda agent_id: random.choice(
                        ["policy_1", "policy_2"])),
            },
        })
    # Simulation metrics arrive asynchronously, so poll a bounded number of
    # training iterations until a real (non-NaN) reward is reported.
    for _ in range(50):
        result = pg.train()
        if not np.isnan(result["episode_reward_mean"]):
            return  # simulation ok
        time.sleep(0.1)
    # self.fail() instead of `assert False`: asserts are stripped under -O.
    self.fail("did not see any simulation results")
|
||||
|
||||
|
||||
class JsonIOTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -123,7 +215,7 @@ class JsonIOTest(unittest.TestCase):
|
||||
def testWriteSimple(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
writer = JsonWriter(
|
||||
ioctx, self.test_dir, max_file_size=1000, compress_columns=["obs"])
|
||||
self.test_dir, ioctx, max_file_size=1000, compress_columns=["obs"])
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 0)
|
||||
writer.write(SAMPLES)
|
||||
writer.write(SAMPLES)
|
||||
@@ -132,8 +224,8 @@ class JsonIOTest(unittest.TestCase):
|
||||
def testWriteFileURI(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
writer = JsonWriter(
|
||||
ioctx,
|
||||
"file:" + self.test_dir,
|
||||
ioctx,
|
||||
max_file_size=1000,
|
||||
compress_columns=["obs"])
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 0)
|
||||
@@ -144,7 +236,7 @@ class JsonIOTest(unittest.TestCase):
|
||||
def testWritePaginate(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
writer = JsonWriter(
|
||||
ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
|
||||
self.test_dir, ioctx, max_file_size=5000, compress_columns=["obs"])
|
||||
self.assertEqual(len(os.listdir(self.test_dir)), 0)
|
||||
for _ in range(100):
|
||||
writer.write(SAMPLES)
|
||||
@@ -153,10 +245,10 @@ class JsonIOTest(unittest.TestCase):
|
||||
def testReadWrite(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
writer = JsonWriter(
|
||||
ioctx, self.test_dir, max_file_size=5000, compress_columns=["obs"])
|
||||
self.test_dir, ioctx, max_file_size=5000, compress_columns=["obs"])
|
||||
for i in range(100):
|
||||
writer.write(make_sample_batch(i))
|
||||
reader = JsonReader(ioctx, self.test_dir + "/*.json")
|
||||
reader = JsonReader(self.test_dir + "/*.json")
|
||||
seen_a = set()
|
||||
seen_o = set()
|
||||
for i in range(1000):
|
||||
@@ -169,7 +261,6 @@ class JsonIOTest(unittest.TestCase):
|
||||
self.assertLess(len(seen_o), 101)
|
||||
|
||||
def testSkipsOverEmptyLinesAndFiles(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
open(self.test_dir + "/empty", "w").close()
|
||||
with open(self.test_dir + "/f1", "w") as f:
|
||||
f.write("\n")
|
||||
@@ -178,7 +269,7 @@ class JsonIOTest(unittest.TestCase):
|
||||
with open(self.test_dir + "/f2", "w") as f:
|
||||
f.write(_to_json(make_sample_batch(1), []))
|
||||
f.write("\n")
|
||||
reader = JsonReader(ioctx, [
|
||||
reader = JsonReader([
|
||||
self.test_dir + "/empty",
|
||||
self.test_dir + "/f1",
|
||||
"file:" + self.test_dir + "/f2",
|
||||
@@ -190,7 +281,6 @@ class JsonIOTest(unittest.TestCase):
|
||||
self.assertEqual(len(seen_a), 2)
|
||||
|
||||
def testSkipsOverCorruptedLines(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
with open(self.test_dir + "/f1", "w") as f:
|
||||
f.write(_to_json(make_sample_batch(0), []))
|
||||
f.write("\n")
|
||||
@@ -201,7 +291,7 @@ class JsonIOTest(unittest.TestCase):
|
||||
f.write(_to_json(make_sample_batch(3), []))
|
||||
f.write("\n")
|
||||
f.write("{..corrupted_json_record")
|
||||
reader = JsonReader(ioctx, [
|
||||
reader = JsonReader([
|
||||
self.test_dir + "/f1",
|
||||
])
|
||||
seen_a = set()
|
||||
@@ -211,9 +301,8 @@ class JsonIOTest(unittest.TestCase):
|
||||
self.assertEqual(len(seen_a), 4)
|
||||
|
||||
def testAbortOnAllEmptyInputs(self):
|
||||
ioctx = IOContext(self.test_dir, {}, 0, None)
|
||||
open(self.test_dir + "/empty", "w").close()
|
||||
reader = JsonReader(ioctx, [
|
||||
reader = JsonReader([
|
||||
self.test_dir + "/empty",
|
||||
])
|
||||
self.assertRaises(ValueError, lambda: reader.next())
|
||||
@@ -223,7 +312,7 @@ class JsonIOTest(unittest.TestCase):
|
||||
with open(self.test_dir + "/empty2", "w") as f:
|
||||
for _ in range(100):
|
||||
f.write("\n")
|
||||
reader = JsonReader(ioctx, [
|
||||
reader = JsonReader([
|
||||
self.test_dir + "/empty1",
|
||||
self.test_dir + "/empty2",
|
||||
])
|
||||
|
||||
@@ -104,5 +104,5 @@ class LinearSchedule(object):
|
||||
|
||||
def value(self, t):
    """See Schedule.value

    Linearly interpolates from `initial_p` to `final_p` as `t` goes from 0
    to `schedule_timesteps`, and stays at `final_p` afterwards.
    """
    # max(1, ...) guards against division by zero when schedule_timesteps
    # is 0; the fraction is capped at 1.0 once the schedule has finished.
    fraction = min(float(t) / max(1, self.schedule_timesteps), 1.0)
    return self.initial_p + fraction * (self.final_p - self.initial_p)
|
||||
|
||||
Reference in New Issue
Block a user