mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 22:12:41 +08:00
Switch Python indentation from 2 spaces to 4 spaces. (#726)
* 4 space indentation for actor.py.
* 4 space indentation for worker.py.
* 4 space indentation for more files.
* 4 space indentation for some test files.
* Check indentation in Travis.
* 4 space indentation for some rl files.
* Fix failure test.
* Fix multi_node_test.
* 4 space indentation for more files.
* 4 space indentation for remaining files.
* Fixes.
This commit is contained in:
committed by
Philipp Moritz
parent
310ba82131
commit
e0867c8845
@@ -21,9 +21,9 @@ __all__ = ["register_class", "error_info", "init", "connect", "disconnect",
|
||||
import ctypes
|
||||
# Windows only
|
||||
if hasattr(ctypes, "windll"):
|
||||
# Makes sure that all child processes die when we die. Also makes sure that
|
||||
# fatal crashes result in process termination rather than an error dialog
|
||||
# (the latter is annoying since we have a lot of processes). This is done by
|
||||
# associating all child processes with a "job" object that imposes this
|
||||
# behavior.
|
||||
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
|
||||
# Makes sure that all child processes die when we die. Also makes sure that
|
||||
# fatal crashes result in process termination rather than an error dialog
|
||||
# (the latter is annoying since we have a lot of processes). This is done
|
||||
# by associating all child processes with a "job" object that imposes this
|
||||
# behavior.
|
||||
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
|
||||
|
||||
+338
-325
@@ -18,406 +18,419 @@ from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary,
|
||||
|
||||
|
||||
def random_actor_id():
|
||||
return ray.local_scheduler.ObjectID(random_string())
|
||||
return ray.local_scheduler.ObjectID(random_string())
|
||||
|
||||
|
||||
def random_actor_class_id():
|
||||
return random_string()
|
||||
return random_string()
|
||||
|
||||
|
||||
def get_actor_method_function_id(attr):
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
"""Get the function ID corresponding to an actor method.
|
||||
|
||||
Args:
|
||||
attr (str): The attribute name of the method.
|
||||
Args:
|
||||
attr (str): The attribute name of the method.
|
||||
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == 20
|
||||
return ray.local_scheduler.ObjectID(function_id)
|
||||
Returns:
|
||||
Function ID corresponding to the method.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash.update(attr.encode("ascii"))
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == 20
|
||||
return ray.local_scheduler.ObjectID(function_id)
|
||||
|
||||
|
||||
def fetch_and_register_actor(actor_class_key, worker):
|
||||
"""Import an actor.
|
||||
"""Import an actor.
|
||||
|
||||
This will be called by the worker's import thread when the worker receives
|
||||
the actor_class export, assuming that the worker is an actor for that class.
|
||||
"""
|
||||
actor_id_str = worker.actor_id
|
||||
(driver_id, class_id, class_name,
|
||||
module, pickled_class, actor_method_names) = worker.redis_client.hmget(
|
||||
actor_class_key, ["driver_id", "class_id", "class_name", "module",
|
||||
"class", "actor_method_names"])
|
||||
This will be called by the worker's import thread when the worker receives
|
||||
the actor_class export, assuming that the worker is an actor for that
|
||||
class.
|
||||
"""
|
||||
actor_id_str = worker.actor_id
|
||||
(driver_id, class_id, class_name,
|
||||
module, pickled_class, actor_method_names) = worker.redis_client.hmget(
|
||||
actor_class_key, ["driver_id", "class_id", "class_name", "module",
|
||||
"class", "actor_method_names"])
|
||||
|
||||
actor_name = class_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
actor_method_names = json.loads(actor_method_names.decode("ascii"))
|
||||
actor_name = class_name.decode("ascii")
|
||||
module = module.decode("ascii")
|
||||
actor_method_names = json.loads(actor_method_names.decode("ascii"))
|
||||
|
||||
# Create a temporary actor with some temporary methods so that if the actor
|
||||
# fails to be unpickled, the temporary actor can be used (just to produce
|
||||
# error messages and to prevent the driver from hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
# Create a temporary actor with some temporary methods so that if the actor
|
||||
# fails to be unpickled, the temporary actor can be used (just to produce
|
||||
# error messages and to prevent the driver from hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
worker.actors[actor_id_str] = TemporaryActor()
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(actor_name))
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
temporary_actor_method)
|
||||
worker.function_properties[driver_id][function_id] = FunctionProperties(
|
||||
num_return_vals=1,
|
||||
num_cpus=1,
|
||||
num_gpus=0,
|
||||
max_calls=0)
|
||||
worker.num_task_executions[driver_id][function_id] = 0
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception("The actor with name {} failed to be imported, and so "
|
||||
"cannot execute this method".format(actor_name))
|
||||
for actor_method_name in actor_method_names:
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.functions[driver_id][function_id] = (actor_method_name,
|
||||
temporary_actor_method)
|
||||
worker.function_properties[driver_id][function_id] = (
|
||||
FunctionProperties(num_return_vals=1,
|
||||
num_cpus=1,
|
||||
num_gpus=0,
|
||||
max_calls=0))
|
||||
worker.num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
unpickled_class = pickle.loads(pickled_class)
|
||||
except Exception:
|
||||
# If an exception was thrown when the actor was imported, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.worker.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
worker.push_error_to_driver(driver_id, "register_actor", traceback_str,
|
||||
data={"actor_id": actor_id_str})
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
for (k, v) in inspect.getmembers(
|
||||
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x)))):
|
||||
function_id = get_actor_method_function_id(k).id()
|
||||
worker.functions[driver_id][function_id] = (k, v)
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new tasks for
|
||||
# the actor.
|
||||
try:
|
||||
unpickled_class = pickle.loads(pickled_class)
|
||||
except Exception:
|
||||
# If an exception was thrown when the actor was imported, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.worker.format_error_message(traceback.format_exc())
|
||||
# Log the error message.
|
||||
worker.push_error_to_driver(driver_id, "register_actor", traceback_str,
|
||||
data={"actor_id": actor_id_str})
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
|
||||
for (k, v) in inspect.getmembers(
|
||||
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x)))):
|
||||
function_id = get_actor_method_function_id(k).id()
|
||||
worker.functions[driver_id][function_id] = (k, v)
|
||||
# We do not set worker.function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new tasks
|
||||
# for the actor.
|
||||
|
||||
|
||||
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
|
||||
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
|
||||
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
|
||||
|
||||
Args:
|
||||
num_gpus: The number of GPUs to acquire.
|
||||
driver_id: The ID of the driver responsible for creating the actor.
|
||||
local_scheduler: Information about the local scheduler.
|
||||
Args:
|
||||
num_gpus: The number of GPUs to acquire.
|
||||
driver_id: The ID of the driver responsible for creating the actor.
|
||||
local_scheduler: Information about the local scheduler.
|
||||
|
||||
Returns:
|
||||
True if the GPUs were successfully reserved and false otherwise.
|
||||
"""
|
||||
assert num_gpus != 0
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
|
||||
Returns:
|
||||
True if the GPUs were successfully reserved and false otherwise.
|
||||
"""
|
||||
assert num_gpus != 0
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
|
||||
|
||||
success = False
|
||||
success = False
|
||||
|
||||
# Attempt to acquire GPU IDs atomically.
|
||||
with worker.redis_client.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction below (the multi/exec
|
||||
# block), then the transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
# Attempt to acquire GPU IDs atomically.
|
||||
with worker.redis_client.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction below (the
|
||||
# multi/exec block), then the transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
|
||||
# Figure out which GPUs are currently in use.
|
||||
result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(
|
||||
result.decode("ascii"))
|
||||
num_gpus_in_use = 0
|
||||
for key in gpus_in_use:
|
||||
num_gpus_in_use += gpus_in_use[key]
|
||||
assert num_gpus_in_use <= local_scheduler_total_gpus
|
||||
# Figure out which GPUs are currently in use.
|
||||
result = worker.redis_client.hget(local_scheduler_id,
|
||||
"gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(
|
||||
result.decode("ascii"))
|
||||
num_gpus_in_use = 0
|
||||
for key in gpus_in_use:
|
||||
num_gpus_in_use += gpus_in_use[key]
|
||||
assert num_gpus_in_use <= local_scheduler_total_gpus
|
||||
|
||||
pipe.multi()
|
||||
pipe.multi()
|
||||
|
||||
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
|
||||
# There are enough available GPUs, so try to reserve some. We use the
|
||||
# hex driver ID in hex as a dictionary key so that the dictionary is
|
||||
# JSON serializable.
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex not in gpus_in_use:
|
||||
gpus_in_use[driver_id_hex] = 0
|
||||
gpus_in_use[driver_id_hex] += num_gpus
|
||||
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
|
||||
# There are enough available GPUs, so try to reserve some.
|
||||
# We use the hex driver ID in hex as a dictionary key so
|
||||
# that the dictionary is JSON serializable.
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex not in gpus_in_use:
|
||||
gpus_in_use[driver_id_hex] = 0
|
||||
gpus_in_use[driver_id_hex] += num_gpus
|
||||
|
||||
# Stick the updated GPU IDs back in Redis
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use))
|
||||
success = True
|
||||
# Stick the updated GPU IDs back in Redis
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use",
|
||||
json.dumps(gpus_in_use))
|
||||
success = True
|
||||
|
||||
pipe.execute()
|
||||
# If a WatchError is not raised, then the operations should have gone
|
||||
# through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key between the time we
|
||||
# started WATCHing it and the pipeline's execution. We should just
|
||||
# retry.
|
||||
success = False
|
||||
continue
|
||||
pipe.execute()
|
||||
# If a WatchError is not raised, then the operations should
|
||||
# have gone through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key between the
|
||||
# time we started WATCHing it and the pipeline's execution. We
|
||||
# should just retry.
|
||||
success = False
|
||||
continue
|
||||
|
||||
return success
|
||||
return success
|
||||
|
||||
|
||||
def select_local_scheduler(local_schedulers, num_gpus, worker):
|
||||
"""Select a local scheduler to assign this actor to.
|
||||
"""Select a local scheduler to assign this actor to.
|
||||
|
||||
Args:
|
||||
local_schedulers: A list of dictionaries of information about the local
|
||||
schedulers.
|
||||
num_gpus (int): The number of GPUs that must be reserved for this actor.
|
||||
Args:
|
||||
local_schedulers: A list of dictionaries of information about the local
|
||||
schedulers.
|
||||
num_gpus (int): The number of GPUs that must be reserved for this
|
||||
actor.
|
||||
|
||||
Returns:
|
||||
The ID of the local scheduler that has been chosen.
|
||||
Returns:
|
||||
The ID of the local scheduler that has been chosen.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if no local scheduler can be found with
|
||||
sufficient resources.
|
||||
"""
|
||||
driver_id = worker.task_driver_id.id()
|
||||
Raises:
|
||||
Exception: An exception is raised if no local scheduler can be found
|
||||
with sufficient resources.
|
||||
"""
|
||||
driver_id = worker.task_driver_id.id()
|
||||
|
||||
local_scheduler_id = None
|
||||
# Loop through all of the local schedulers in a random order.
|
||||
local_schedulers = np.random.permutation(local_schedulers)
|
||||
for local_scheduler in local_schedulers:
|
||||
if local_scheduler["NumCPUs"] < 1:
|
||||
continue
|
||||
if local_scheduler["NumGPUs"] < num_gpus:
|
||||
continue
|
||||
if num_gpus == 0:
|
||||
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
|
||||
break
|
||||
else:
|
||||
# Try to reserve enough GPUs on this local scheduler.
|
||||
success = attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler,
|
||||
worker)
|
||||
if success:
|
||||
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
|
||||
break
|
||||
local_scheduler_id = None
|
||||
# Loop through all of the local schedulers in a random order.
|
||||
local_schedulers = np.random.permutation(local_schedulers)
|
||||
for local_scheduler in local_schedulers:
|
||||
if local_scheduler["NumCPUs"] < 1:
|
||||
continue
|
||||
if local_scheduler["NumGPUs"] < num_gpus:
|
||||
continue
|
||||
if num_gpus == 0:
|
||||
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
|
||||
break
|
||||
else:
|
||||
# Try to reserve enough GPUs on this local scheduler.
|
||||
success = attempt_to_reserve_gpus(num_gpus, driver_id,
|
||||
local_scheduler, worker)
|
||||
if success:
|
||||
local_scheduler_id = hex_to_binary(
|
||||
local_scheduler["DBClientID"])
|
||||
break
|
||||
|
||||
if local_scheduler_id is None:
|
||||
raise Exception("Could not find a node with enough GPUs or other "
|
||||
"resources to create this actor. The local scheduler "
|
||||
"information is {}.".format(local_schedulers))
|
||||
if local_scheduler_id is None:
|
||||
raise Exception("Could not find a node with enough GPUs or other "
|
||||
"resources to create this actor. The local scheduler "
|
||||
"information is {}.".format(local_schedulers))
|
||||
|
||||
return local_scheduler_id
|
||||
return local_scheduler_id
|
||||
|
||||
|
||||
def export_actor_class(class_id, Class, actor_method_names, worker):
|
||||
if worker.mode is None:
|
||||
raise NotImplemented("TODO(pcm): Cache actors")
|
||||
key = b"ActorClass:" + class_id
|
||||
d = {"driver_id": worker.task_driver_id.id(),
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))}
|
||||
worker.redis_client.hmset(key, d)
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
if worker.mode is None:
|
||||
raise NotImplemented("TODO(pcm): Cache actors")
|
||||
key = b"ActorClass:" + class_id
|
||||
d = {"driver_id": worker.task_driver_id.id(),
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))}
|
||||
worker.redis_client.hmset(key, d)
|
||||
worker.redis_client.rpush("Exports", key)
|
||||
|
||||
|
||||
def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
|
||||
worker):
|
||||
"""Export an actor to redis.
|
||||
"""Export an actor to redis.
|
||||
|
||||
Args:
|
||||
actor_id: The ID of the actor.
|
||||
actor_method_names (list): A list of the names of this actor's methods.
|
||||
num_cpus (int): The number of CPUs that this actor requires.
|
||||
num_gpus (int): The number of GPUs that this actor requires.
|
||||
"""
|
||||
ray.worker.check_main_thread()
|
||||
if worker.mode is None:
|
||||
raise Exception("Actors cannot be created before Ray has been started. "
|
||||
"You can start Ray with 'ray.init()'.")
|
||||
key = b"Actor:" + actor_id.id()
|
||||
Args:
|
||||
actor_id: The ID of the actor.
|
||||
actor_method_names (list): A list of the names of this actor's methods.
|
||||
num_cpus (int): The number of CPUs that this actor requires.
|
||||
num_gpus (int): The number of GPUs that this actor requires.
|
||||
"""
|
||||
ray.worker.check_main_thread()
|
||||
if worker.mode is None:
|
||||
raise Exception("Actors cannot be created before Ray has been "
|
||||
"started. You can start Ray with 'ray.init()'.")
|
||||
key = b"Actor:" + actor_id.id()
|
||||
|
||||
# For now, all actor methods have 1 return value.
|
||||
driver_id = worker.task_driver_id.id()
|
||||
for actor_method_name in actor_method_names:
|
||||
# TODO(rkn): When we create a second actor, we are probably overwriting
|
||||
# the values from the first actor here. This may or may not be a problem.
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.function_properties[driver_id][function_id] = FunctionProperties(
|
||||
num_return_vals=1,
|
||||
num_cpus=1,
|
||||
num_gpus=0,
|
||||
max_calls=0)
|
||||
# For now, all actor methods have 1 return value.
|
||||
driver_id = worker.task_driver_id.id()
|
||||
for actor_method_name in actor_method_names:
|
||||
# TODO(rkn): When we create a second actor, we are probably overwriting
|
||||
# the values from the first actor here. This may or may not be a
|
||||
# problem.
|
||||
function_id = get_actor_method_function_id(actor_method_name).id()
|
||||
worker.function_properties[driver_id][function_id] = (
|
||||
FunctionProperties(num_return_vals=1,
|
||||
num_cpus=1,
|
||||
num_gpus=0,
|
||||
max_calls=0))
|
||||
|
||||
# Get a list of the local schedulers from the client table.
|
||||
client_table = ray.global_state.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, clients in client_table.items():
|
||||
for client in clients:
|
||||
if client["ClientType"] == "local_scheduler" and not client["Deleted"]:
|
||||
local_schedulers.append(client)
|
||||
# Select a local scheduler for the actor.
|
||||
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
|
||||
worker)
|
||||
assert local_scheduler_id is not None
|
||||
# Get a list of the local schedulers from the client table.
|
||||
client_table = ray.global_state.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, clients in client_table.items():
|
||||
for client in clients:
|
||||
if (client["ClientType"] == "local_scheduler" and
|
||||
not client["Deleted"]):
|
||||
local_schedulers.append(client)
|
||||
# Select a local scheduler for the actor.
|
||||
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
|
||||
worker)
|
||||
assert local_scheduler_id is not None
|
||||
|
||||
# We must put the actor information in Redis before publishing the actor
|
||||
# notification so that when the newly created actor attempts to fetch the
|
||||
# information from Redis, it is already there.
|
||||
worker.redis_client.hmset(key, {"class_id": class_id,
|
||||
"num_gpus": num_gpus})
|
||||
# We must put the actor information in Redis before publishing the actor
|
||||
# notification so that when the newly created actor attempts to fetch the
|
||||
# information from Redis, it is already there.
|
||||
worker.redis_client.hmset(key, {"class_id": class_id,
|
||||
"num_gpus": num_gpus})
|
||||
|
||||
# Really we should encode this message as a flatbuffer object. However, we're
|
||||
# having trouble getting that to work. It almost works, but in Python 2.7,
|
||||
# builder.CreateString fails on byte strings that contain characters outside
|
||||
# range(128).
|
||||
# Really we should encode this message as a flatbuffer object. However,
|
||||
# we're having trouble getting that to work. It almost works, but in Python
|
||||
# 2.7, builder.CreateString fails on byte strings that contain characters
|
||||
# outside range(128).
|
||||
|
||||
# TODO(rkn): There is actually no guarantee that the local scheduler that we
|
||||
# are publishing to has already subscribed to the actor_notifications
|
||||
# channel. Therefore, this message may be missed and the workload will hang.
|
||||
# This is a bug.
|
||||
worker.redis_client.publish("actor_notifications",
|
||||
actor_id.id() + driver_id + local_scheduler_id)
|
||||
# TODO(rkn): There is actually no guarantee that the local scheduler that
|
||||
# we are publishing to has already subscribed to the actor_notifications
|
||||
# channel. Therefore, this message may be missed and the workload will
|
||||
# hang. This is a bug.
|
||||
worker.redis_client.publish("actor_notifications",
|
||||
actor_id.id() + driver_id + local_scheduler_id)
|
||||
|
||||
|
||||
def actor(*args, **kwargs):
|
||||
raise Exception("The @ray.actor decorator is deprecated. Instead, please "
|
||||
"use @ray.remote.")
|
||||
raise Exception("The @ray.actor decorator is deprecated. Instead, please "
|
||||
"use @ray.remote.")
|
||||
|
||||
|
||||
def make_actor(cls, num_cpus, num_gpus):
|
||||
# Modify the class to have an additional method that will be used for
|
||||
# terminating the worker.
|
||||
class Class(cls):
|
||||
def __ray_terminate__(self):
|
||||
ray.worker.global_worker.local_scheduler_client.disconnect()
|
||||
import os
|
||||
os._exit(0)
|
||||
# Modify the class to have an additional method that will be used for
|
||||
# terminating the worker.
|
||||
class Class(cls):
|
||||
def __ray_terminate__(self):
|
||||
ray.worker.global_worker.local_scheduler_client.disconnect()
|
||||
import os
|
||||
os._exit(0)
|
||||
|
||||
Class.__module__ = cls.__module__
|
||||
Class.__name__ = cls.__name__
|
||||
Class.__module__ = cls.__module__
|
||||
Class.__name__ = cls.__name__
|
||||
|
||||
class_id = random_actor_class_id()
|
||||
# The list exported will have length 0 if the class has not been exported
|
||||
# yet, and length one if it has. This is just implementing a bool, but we
|
||||
# don't use a bool because we need to modify it inside of the NewClass
|
||||
# constructor.
|
||||
exported = []
|
||||
class_id = random_actor_class_id()
|
||||
# The list exported will have length 0 if the class has not been exported
|
||||
# yet, and length one if it has. This is just implementing a bool, but we
|
||||
# don't use a bool because we need to modify it inside of the NewClass
|
||||
# constructor.
|
||||
exported = []
|
||||
|
||||
# The function actor_method_call gets called if somebody tries to call a
|
||||
# method on their local actor stub object.
|
||||
def actor_method_call(actor_id, attr, function_signature, *args, **kwargs):
|
||||
ray.worker.check_connected()
|
||||
ray.worker.check_main_thread()
|
||||
args = signature.extend_args(function_signature, args, kwargs)
|
||||
# The function actor_method_call gets called if somebody tries to call a
|
||||
# method on their local actor stub object.
|
||||
def actor_method_call(actor_id, attr, function_signature, *args, **kwargs):
|
||||
ray.worker.check_connected()
|
||||
ray.worker.check_main_thread()
|
||||
args = signature.extend_args(function_signature, args, kwargs)
|
||||
|
||||
function_id = get_actor_method_function_id(attr)
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
|
||||
actor_id=actor_id)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
function_id = get_actor_method_function_id(attr)
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "",
|
||||
args,
|
||||
actor_id=actor_id)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
elif len(object_ids) > 1:
|
||||
return object_ids
|
||||
|
||||
class ActorMethod(object):
|
||||
def __init__(self, method_name, actor_id, method_signature):
|
||||
self.method_name = method_name
|
||||
self.actor_id = actor_id
|
||||
self.method_signature = method_signature
|
||||
class ActorMethod(object):
|
||||
def __init__(self, method_name, actor_id, method_signature):
|
||||
self.method_name = method_name
|
||||
self.actor_id = actor_id
|
||||
self.method_signature = method_signature
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actor methods cannot be called directly. Instead "
|
||||
"of running 'object.{}()', try 'object.{}.remote()'."
|
||||
.format(self.method_name, self.method_name))
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise Exception("Actor methods cannot be called directly. Instead "
|
||||
"of running 'object.{}()', try "
|
||||
"'object.{}.remote()'."
|
||||
.format(self.method_name, self.method_name))
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return actor_method_call(self.actor_id, self.method_name,
|
||||
self.method_signature, *args, **kwargs)
|
||||
def remote(self, *args, **kwargs):
|
||||
return actor_method_call(self.actor_id, self.method_name,
|
||||
self.method_signature, *args, **kwargs)
|
||||
|
||||
class NewClass(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("Actor classes cannot be instantiated directly. "
|
||||
"Instead of running '{}()', try '{}.remote()'."
|
||||
.format(Class.__name__, Class.__name__))
|
||||
class NewClass(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise Exception("Actor classes cannot be instantiated directly. "
|
||||
"Instead of running '{}()', try '{}.remote()'."
|
||||
.format(Class.__name__, Class.__name__))
|
||||
|
||||
@classmethod
|
||||
def remote(cls, *args, **kwargs):
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object._manual_init(*args, **kwargs)
|
||||
return actor_object
|
||||
@classmethod
|
||||
def remote(cls, *args, **kwargs):
|
||||
actor_object = cls.__new__(cls)
|
||||
actor_object._manual_init(*args, **kwargs)
|
||||
return actor_object
|
||||
|
||||
def _manual_init(self, *args, **kwargs):
|
||||
self._ray_actor_id = random_actor_id()
|
||||
self._ray_actor_methods = {
|
||||
k: v for (k, v) in inspect.getmembers(
|
||||
Class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x))))}
|
||||
# Extract the signatures of each of the methods. This will be used to
|
||||
# catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
self._ray_method_signatures = dict()
|
||||
for k, v in self._ray_actor_methods.items():
|
||||
# Print a warning message if the method signature is not supported.
|
||||
# We don't raise an exception because if the actor inherits from a
|
||||
# class that has a method whose signature we don't support, we
|
||||
# there may not be much the user can do about it.
|
||||
signature.check_signature_supported(v, warn=True)
|
||||
self._ray_method_signatures[k] = signature.extract_signature(
|
||||
v, ignore_first=True)
|
||||
def _manual_init(self, *args, **kwargs):
|
||||
self._ray_actor_id = random_actor_id()
|
||||
self._ray_actor_methods = {
|
||||
k: v for (k, v) in inspect.getmembers(
|
||||
Class, predicate=(lambda x: (inspect.isfunction(x) or
|
||||
inspect.ismethod(x))))}
|
||||
# Extract the signatures of each of the methods. This will be used
|
||||
# to catch some errors if the methods are called with inappropriate
|
||||
# arguments.
|
||||
self._ray_method_signatures = dict()
|
||||
for k, v in self._ray_actor_methods.items():
|
||||
# Print a warning message if the method signature is not
|
||||
# supported. We don't raise an exception because if the actor
|
||||
# inherits from a class that has a method whose signature we
|
||||
# don't support, we there may not be much the user can do about
|
||||
# it.
|
||||
signature.check_signature_supported(v, warn=True)
|
||||
self._ray_method_signatures[k] = signature.extract_signature(
|
||||
v, ignore_first=True)
|
||||
|
||||
# Create objects to wrap method invocations. This is done so that we
|
||||
# can invoke methods with actor.method.remote() instead of
|
||||
# actor.method().
|
||||
self._actor_method_invokers = dict()
|
||||
for k, v in self._ray_actor_methods.items():
|
||||
self._actor_method_invokers[k] = ActorMethod(
|
||||
k, self._ray_actor_id, self._ray_method_signatures[k])
|
||||
# Create objects to wrap method invocations. This is done so that
|
||||
# we can invoke methods with actor.method.remote() instead of
|
||||
# actor.method().
|
||||
self._actor_method_invokers = dict()
|
||||
for k, v in self._ray_actor_methods.items():
|
||||
self._actor_method_invokers[k] = ActorMethod(
|
||||
k, self._ray_actor_id, self._ray_method_signatures[k])
|
||||
|
||||
# Export the actor class if it has not been exported yet.
|
||||
if len(exported) == 0:
|
||||
export_actor_class(class_id, Class, self._ray_actor_methods.keys(),
|
||||
ray.worker.global_worker)
|
||||
exported.append(0)
|
||||
# Export the actor.
|
||||
export_actor(self._ray_actor_id, class_id,
|
||||
self._ray_actor_methods.keys(), num_cpus, num_gpus,
|
||||
ray.worker.global_worker)
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in self._ray_actor_methods.keys():
|
||||
actor_method_call(self._ray_actor_id, "__init__",
|
||||
self._ray_method_signatures["__init__"],
|
||||
*args, **kwargs)
|
||||
else:
|
||||
print("WARNING: this object has no __init__ method.")
|
||||
# Export the actor class if it has not been exported yet.
|
||||
if len(exported) == 0:
|
||||
export_actor_class(class_id, Class,
|
||||
self._ray_actor_methods.keys(),
|
||||
ray.worker.global_worker)
|
||||
exported.append(0)
|
||||
# Export the actor.
|
||||
export_actor(self._ray_actor_id, class_id,
|
||||
self._ray_actor_methods.keys(), num_cpus, num_gpus,
|
||||
ray.worker.global_worker)
|
||||
# Call __init__ as a remote function.
|
||||
if "__init__" in self._ray_actor_methods.keys():
|
||||
actor_method_call(self._ray_actor_id, "__init__",
|
||||
self._ray_method_signatures["__init__"],
|
||||
*args, **kwargs)
|
||||
else:
|
||||
print("WARNING: this object has no __init__ method.")
|
||||
|
||||
# Make tab completion work.
|
||||
def __dir__(self):
|
||||
return self._ray_actor_methods
|
||||
# Make tab completion work.
|
||||
def __dir__(self):
|
||||
return self._ray_actor_methods
|
||||
|
||||
def __getattribute__(self, attr):
|
||||
# The following is needed so we can still access self.actor_methods.
|
||||
if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods",
|
||||
"_actor_method_invokers", "_ray_method_signatures"]:
|
||||
return object.__getattribute__(self, attr)
|
||||
if attr in self._ray_actor_methods.keys():
|
||||
return self._actor_method_invokers[attr]
|
||||
# There is no method with this name, so raise an exception.
|
||||
raise AttributeError("'{}' Actor object has no attribute '{}'"
|
||||
.format(Class, attr))
|
||||
def __getattribute__(self, attr):
|
||||
# The following is needed so we can still access
|
||||
# self.actor_methods.
|
||||
if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods",
|
||||
"_actor_method_invokers", "_ray_method_signatures"]:
|
||||
return object.__getattribute__(self, attr)
|
||||
if attr in self._ray_actor_methods.keys():
|
||||
return self._actor_method_invokers[attr]
|
||||
# There is no method with this name, so raise an exception.
|
||||
raise AttributeError("'{}' Actor object has no attribute '{}'"
|
||||
.format(Class, attr))
|
||||
|
||||
def __repr__(self):
|
||||
return "Actor(" + self._ray_actor_id.hex() + ")"
|
||||
def __repr__(self):
|
||||
return "Actor(" + self._ray_actor_id.hex() + ")"
|
||||
|
||||
def __reduce__(self):
|
||||
raise Exception("Actor objects cannot be pickled.")
|
||||
def __reduce__(self):
|
||||
raise Exception("Actor objects cannot be pickled.")
|
||||
|
||||
def __del__(self):
|
||||
"""Kill the worker that is running this actor."""
|
||||
if ray.worker.global_worker.connected:
|
||||
actor_method_call(self._ray_actor_id, "__ray_terminate__",
|
||||
self._ray_method_signatures["__ray_terminate__"])
|
||||
def __del__(self):
|
||||
"""Kill the worker that is running this actor."""
|
||||
if ray.worker.global_worker.connected:
|
||||
actor_method_call(
|
||||
self._ray_actor_id, "__ray_terminate__",
|
||||
self._ray_method_signatures["__ray_terminate__"])
|
||||
|
||||
return NewClass
|
||||
return NewClass
|
||||
|
||||
|
||||
ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor
|
||||
|
||||
@@ -23,423 +23,436 @@ OBJECT_CHANNEL_PREFIX = "OC:"
|
||||
|
||||
|
||||
def integerToAsciiHex(num, numbytes):
|
||||
retstr = b""
|
||||
# Support 32 and 64 bit architecture.
|
||||
assert(numbytes == 4 or numbytes == 8)
|
||||
for i in range(numbytes):
|
||||
curbyte = num & 0xff
|
||||
if sys.version_info >= (3, 0):
|
||||
retstr += bytes([curbyte])
|
||||
else:
|
||||
retstr += chr(curbyte)
|
||||
num = num >> 8
|
||||
retstr = b""
|
||||
# Support 32 and 64 bit architecture.
|
||||
assert(numbytes == 4 or numbytes == 8)
|
||||
for i in range(numbytes):
|
||||
curbyte = num & 0xff
|
||||
if sys.version_info >= (3, 0):
|
||||
retstr += bytes([curbyte])
|
||||
else:
|
||||
retstr += chr(curbyte)
|
||||
num = num >> 8
|
||||
|
||||
return retstr
|
||||
return retstr
|
||||
|
||||
|
||||
def get_next_message(pubsub_client, timeout_seconds=10):
|
||||
"""Block until the next message is available on the pubsub channel."""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
message = pubsub_client.get_message()
|
||||
if message is not None:
|
||||
return message
|
||||
time.sleep(0.1)
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
raise Exception("Timed out while waiting for next message.")
|
||||
"""Block until the next message is available on the pubsub channel."""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
message = pubsub_client.get_message()
|
||||
if message is not None:
|
||||
return message
|
||||
time.sleep(0.1)
|
||||
if time.time() - start_time > timeout_seconds:
|
||||
raise Exception("Timed out while waiting for next message.")
|
||||
|
||||
|
||||
class TestGlobalStateStore(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
def setUp(self):
|
||||
redis_port, _ = ray.services.start_redis_instance()
|
||||
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
|
||||
|
||||
def tearDown(self):
|
||||
ray.services.cleanup()
|
||||
def tearDown(self):
|
||||
ray.services.cleanup()
|
||||
|
||||
def testInvalidObjectTableAdd(self):
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
|
||||
# with the wrong arguments.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one",
|
||||
"hash2", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
|
||||
"hash2", "manager_id1", "extra argument")
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
|
||||
# object ID that is already present with a different hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1"})
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
# Check that the second manager was added, even though the hash was
|
||||
# mismatched.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that it is fine if we add the same object ID multiple times with
|
||||
# the most recent hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
|
||||
"hash2", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
def testInvalidObjectTableAdd(self):
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
|
||||
# with the wrong arguments.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
|
||||
"one", "hash2", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
|
||||
"hash2", "manager_id1",
|
||||
"extra argument")
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
|
||||
# object ID that is already present with a different hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1"})
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
# Check that the second manager was added, even though the hash was
|
||||
# mismatched.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that it is fine if we add the same object ID multiple times
|
||||
# with the most recent hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
|
||||
"hash2", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
|
||||
def testObjectTableAddAndLookup(self):
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
|
||||
# added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Add a manager that already exists again and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that we properly handle NULL characters. In the past, NULL
|
||||
# characters were handled improperly causing a "hash mismatch" error if two
|
||||
# object IDs that agreed up to the NULL character were inserted with
|
||||
# different hashes.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
|
||||
"hash2", "manager_id1")
|
||||
# Check that NULL characters in the hash are handled properly.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash1", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash2", "manager_id1")
|
||||
def testObjectTableAddAndLookup(self):
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
|
||||
# been added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Add a manager that already exists again and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that we properly handle NULL characters. In the past, NULL
|
||||
# characters were handled improperly causing a "hash mismatch" error if
|
||||
# two object IDs that agreed up to the NULL character were inserted
|
||||
# with different hashes.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
|
||||
"hash2", "manager_id1")
|
||||
# Check that NULL characters in the hash are handled properly.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash1", "manager_id1")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
|
||||
"\x00hash2", "manager_id1")
|
||||
|
||||
def testObjectTableAddAndRemove(self):
|
||||
# Try removing a manager from an object ID that has not been added yet.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
|
||||
# added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that doesn't exist, and make sure we still have the same
|
||||
# set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that does exist. Make sure it gets removed the first
|
||||
# time and does nothing the second time.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
# Remove the last manager, and make sure we have an empty set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
# Remove a manager from an empty set, and make sure we now have an empty
|
||||
# set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
def testObjectTableAddAndRemove(self):
|
||||
# Try removing a manager from an object ID that has not been added yet.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
|
||||
# been added yet.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(response, None)
|
||||
# Add some managers and try again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that doesn't exist, and make sure we still have the
|
||||
# same set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Remove a manager that does exist. Make sure it gets removed the first
|
||||
# time and does nothing the second time.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id2"})
|
||||
# Remove the last manager, and make sure we have an empty set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
# Remove a manager from an empty set, and make sure we now have an
|
||||
# empty set.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
|
||||
"manager_id3")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertEqual(set(response), set())
|
||||
|
||||
def testObjectTableSubscribeToNotifications(self):
|
||||
# Define a helper method for checking the contents of object notifications.
|
||||
def check_object_notification(notification_message, object_id, object_size,
|
||||
manager_ids):
|
||||
notification_object = (SubscribeToNotificationsReply
|
||||
.GetRootAsSubscribeToNotificationsReply(
|
||||
notification_message, 0))
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(),
|
||||
len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
|
||||
def testObjectTableSubscribeToNotifications(self):
|
||||
# Define a helper method for checking the contents of object
|
||||
# notifications.
|
||||
def check_object_notification(notification_message, object_id,
|
||||
object_size, manager_ids):
|
||||
notification_object = (SubscribeToNotificationsReply
|
||||
.GetRootAsSubscribeToNotificationsReply(
|
||||
notification_message, 0))
|
||||
self.assertEqual(notification_object.ObjectId(), object_id)
|
||||
self.assertEqual(notification_object.ObjectSize(), object_size)
|
||||
self.assertEqual(notification_object.ManagerIdsLength(),
|
||||
len(manager_ids))
|
||||
for i in range(len(manager_ids)):
|
||||
self.assertEqual(notification_object.ManagerIds(i),
|
||||
manager_ids[i])
|
||||
|
||||
data_size = 0xf1f0
|
||||
p = self.redis.pubsub()
|
||||
# Subscribe to an object ID.
|
||||
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size,
|
||||
"hash1", "manager_id2")
|
||||
# Receive the acknowledgement message.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
# Request a notification and receive the data.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id1")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
[b"manager_id2"])
|
||||
data_size = 0xf1f0
|
||||
p = self.redis.pubsub()
|
||||
# Subscribe to an object ID.
|
||||
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
|
||||
data_size, "hash1", "manager_id2")
|
||||
# Receive the acknowledgement message.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
# Request a notification and receive the data.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id1")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id1",
|
||||
data_size,
|
||||
[b"manager_id2"])
|
||||
|
||||
# Request a notification for an object that isn't there. Then add the
|
||||
# object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD
|
||||
# should trigger notifications.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id2", "object_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
|
||||
"hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size,
|
||||
"hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1", b"manager_id2", b"manager_id3"])
|
||||
# Request a notification for an object that isn't there. Then add the
|
||||
# object and receive the data. Only the first call to
|
||||
# RAY.OBJECT_TABLE_ADD should trigger notifications.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id2", "object_id3")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
|
||||
data_size, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
|
||||
data_size, "hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1"])
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
|
||||
data_size, "hash1", "manager_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id2",
|
||||
data_size,
|
||||
[b"manager_id3"])
|
||||
# Request notifications for object_id3 again.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
|
||||
"manager_id1", "object_id3")
|
||||
# Verify that the notification is correct.
|
||||
check_object_notification(get_next_message(p)["data"],
|
||||
b"object_id3",
|
||||
data_size,
|
||||
[b"manager_id1", b"manager_id2",
|
||||
b"manager_id3"])
|
||||
|
||||
def testResultTableAddAndLookup(self):
|
||||
def check_result_table_entry(message, task_id, is_put):
|
||||
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message,
|
||||
0)
|
||||
self.assertEqual(result_table_reply.TaskId(), task_id)
|
||||
self.assertEqual(result_table_reply.IsPut(), is_put)
|
||||
def testResultTableAddAndLookup(self):
|
||||
def check_result_table_entry(message, task_id, is_put):
|
||||
result_table_reply = ResultTableReply.GetRootAsResultTableReply(
|
||||
message, 0)
|
||||
self.assertEqual(result_table_reply.TaskId(), task_id)
|
||||
self.assertEqual(result_table_reply.IsPut(), is_put)
|
||||
|
||||
# Try looking up something in the result table before anything is added.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Adding the object to the object table should have no effect.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Add the result to the result table. The lookup now returns the task ID.
|
||||
task_id = b"task_id1"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id,
|
||||
0)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Doing it again should still work.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Try another result table lookup. This should succeed.
|
||||
task_id = b"task_id2"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id,
|
||||
1)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id2")
|
||||
check_result_table_entry(response, task_id, True)
|
||||
# Try looking up something in the result table before anything is
|
||||
# added.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Adding the object to the object table should have no effect.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
|
||||
"hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
self.assertIsNone(response)
|
||||
# Add the result to the result table. The lookup now returns the task
|
||||
# ID.
|
||||
task_id = b"task_id1"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1",
|
||||
task_id, 0)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Doing it again should still work.
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id1")
|
||||
check_result_table_entry(response, task_id, False)
|
||||
# Try another result table lookup. This should succeed.
|
||||
task_id = b"task_id2"
|
||||
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2",
|
||||
task_id, 1)
|
||||
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
|
||||
"object_id2")
|
||||
check_result_table_entry(response, task_id, True)
|
||||
|
||||
def testInvalidTaskTableAdd(self):
|
||||
# Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with
|
||||
# the wrong arguments.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Non-integer scheduling states should not be added.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
"invalid_state", "node_id", "task_spec")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Should not be able to update a non-existent task.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
|
||||
"node_id")
|
||||
def testInvalidTaskTableAdd(self):
|
||||
# Check that Redis returns an error when RAY.TASK_TABLE_ADD is called
|
||||
# with the wrong arguments.
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3,
|
||||
"node_id")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Non-integer scheduling states should not be added.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
"invalid_state", "node_id", "task_spec")
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
# Should not be able to update a non-existent task.
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
|
||||
"node_id")
|
||||
|
||||
def testTaskTableAddAndLookup(self):
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
def testTaskTableAddAndLookup(self):
|
||||
TASK_STATUS_WAITING = 1
|
||||
TASK_STATUS_SCHEDULED = 2
|
||||
TASK_STATUS_QUEUED = 4
|
||||
|
||||
# make sure somebody will get a notification (checked in the redis module)
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# make sure somebody will get a notification (checked in the redis
|
||||
# module)
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
|
||||
def check_task_reply(message, task_args, updated=False):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(task_reply_object.State(), task_status)
|
||||
self.assertEqual(task_reply_object.LocalSchedulerId(),
|
||||
local_scheduler_id)
|
||||
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
|
||||
self.assertEqual(task_reply_object.Updated(), updated)
|
||||
def check_task_reply(message, task_args, updated=False):
|
||||
task_status, local_scheduler_id, task_spec = task_args
|
||||
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(task_reply_object.State(), task_status)
|
||||
self.assertEqual(task_reply_object.LocalSchedulerId(),
|
||||
local_scheduler_id)
|
||||
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
|
||||
self.assertEqual(task_reply_object.Updated(), updated)
|
||||
|
||||
# Check that task table adds, updates, and lookups work correctly.
|
||||
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
*task_args)
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(response, task_args)
|
||||
# Check that task table adds, updates, and lookups work correctly.
|
||||
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
|
||||
*task_args)
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
task_args[0] = TASK_STATUS_SCHEDULED
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
|
||||
*task_args[:2])
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(response, task_args)
|
||||
task_args[0] = TASK_STATUS_SCHEDULED
|
||||
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
|
||||
*task_args[:2])
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(response, task_args)
|
||||
|
||||
# If the current value, test value, and set value are all the same, the
|
||||
# update happens, and the response is still the same task.
|
||||
task_args = [task_args[0]] + task_args
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the task entry is still the same.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
# If the current value, test value, and set value are all the same, the
|
||||
# update happens, and the response is still the same task.
|
||||
task_args = [task_args[0]] + task_args
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the task entry is still the same.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
"task_id")
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
# If the current value is the same as the test value, and the set value is
|
||||
# different, the update happens, and the response is the entire task.
|
||||
task_args[1] = TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the update happened.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
# If the current value is the same as the test value, and the set value
|
||||
# is different, the update happens, and the response is the entire
|
||||
# task.
|
||||
task_args[1] = TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# Check that the update happened.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
"task_id")
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
# If the current value is no longer the same as the test value, the
|
||||
# response is the same task as before the test-and-set.
|
||||
new_task_args = task_args[:]
|
||||
new_task_args[1] = TASK_STATUS_WAITING
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertEqual(get_response2, get_response)
|
||||
# If the current value is no longer the same as the test value, the
|
||||
# response is the same task as before the test-and-set.
|
||||
new_task_args = task_args[:]
|
||||
new_task_args[1] = TASK_STATUS_WAITING
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
"task_id")
|
||||
self.assertEqual(get_response2, get_response)
|
||||
|
||||
# If the test value is a bitmask that matches the current value, the update
|
||||
# happens.
|
||||
task_args = new_task_args
|
||||
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
# If the test value is a bitmask that matches the current value, the
|
||||
# update happens.
|
||||
task_args = new_task_args
|
||||
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=True)
|
||||
|
||||
# If the test value is a bitmask that does not match the current value, the
|
||||
# update does not happen, and the response is the same task as before the
|
||||
# test-and-set.
|
||||
new_task_args = task_args[:]
|
||||
new_task_args[0] = TASK_STATUS_SCHEDULED
|
||||
old_response = response
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
|
||||
self.assertNotEqual(get_response, old_response)
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
# If the test value is a bitmask that does not match the current value,
|
||||
# the update does not happen, and the response is the same task as
|
||||
# before the test-and-set.
|
||||
new_task_args = task_args[:]
|
||||
new_task_args[0] = TASK_STATUS_SCHEDULED
|
||||
old_response = response
|
||||
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
|
||||
"task_id",
|
||||
*new_task_args[:3])
|
||||
check_task_reply(response, task_args[1:], updated=False)
|
||||
# Check that the update did not happen.
|
||||
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
|
||||
"task_id")
|
||||
self.assertNotEqual(get_response, old_response)
|
||||
check_task_reply(get_response, task_args[1:])
|
||||
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
|
||||
task_args = [b"task_id", scheduling_state,
|
||||
local_scheduler_id.encode("ascii"), b"task_spec"]
|
||||
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
|
||||
# Receive the data.
|
||||
message = get_next_message(p)["data"]
|
||||
# Check that the notification object is correct.
|
||||
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
|
||||
self.assertEqual(notification_object.TaskId(), b"task_id")
|
||||
self.assertEqual(notification_object.State(), scheduling_state)
|
||||
self.assertEqual(notification_object.LocalSchedulerId(),
|
||||
local_scheduler_id.encode("ascii"))
|
||||
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
|
||||
|
||||
def testTaskTableSubscribe(self):
|
||||
scheduling_state = 1
|
||||
local_scheduler_id = "local_scheduler_id"
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
# unsubscribe to make sure there is only one subscriber at a given time
|
||||
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
def testTaskTableSubscribe(self):
|
||||
scheduling_state = 1
|
||||
local_scheduler_id = "local_scheduler_id"
|
||||
# Subscribe to the task table.
|
||||
p = self.redis.pubsub()
|
||||
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
# unsubscribe to make sure there is only one subscriber at a given time
|
||||
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
p.psubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}*:{state}".format(
|
||||
prefix=TASK_PREFIX, state=scheduling_state))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 1)
|
||||
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
|
||||
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
|
||||
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
|
||||
# Receive acknowledgment.
|
||||
self.assertEqual(get_next_message(p)["data"], 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
+102
-102
@@ -13,19 +13,19 @@ ID_SIZE = 20
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
BASE_SIMPLE_OBJECTS = [
|
||||
@@ -33,7 +33,7 @@ BASE_SIMPLE_OBJECTS = [
|
||||
990 * u"h"]
|
||||
|
||||
if sys.version_info < (3, 0):
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
|
||||
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
|
||||
|
||||
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
|
||||
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
|
||||
@@ -51,8 +51,8 @@ l.append(l)
|
||||
|
||||
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(),
|
||||
@@ -70,118 +70,118 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
|
||||
|
||||
class TestSerialization(unittest.TestCase):
|
||||
|
||||
def test_serialize_by_value(self):
|
||||
def test_serialize_by_value(self):
|
||||
|
||||
for val in SIMPLE_OBJECTS:
|
||||
self.assertTrue(local_scheduler.check_simple_value(val))
|
||||
for val in COMPLEX_OBJECTS:
|
||||
self.assertFalse(local_scheduler.check_simple_value(val))
|
||||
for val in SIMPLE_OBJECTS:
|
||||
self.assertTrue(local_scheduler.check_simple_value(val))
|
||||
for val in COMPLEX_OBJECTS:
|
||||
self.assertFalse(local_scheduler.check_simple_value(val))
|
||||
|
||||
|
||||
class TestObjectID(unittest.TestCase):
|
||||
|
||||
def test_create_object_id(self):
|
||||
random_object_id()
|
||||
def test_create_object_id(self):
|
||||
random_object_id()
|
||||
|
||||
def test_cannot_pickle_object_ids(self):
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
def test_cannot_pickle_object_ids(self):
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
|
||||
def f():
|
||||
return object_ids
|
||||
def f():
|
||||
return object_ids
|
||||
|
||||
def g(val=object_ids):
|
||||
return 1
|
||||
def g(val=object_ids):
|
||||
return 1
|
||||
|
||||
def h():
|
||||
object_ids[0]
|
||||
return 1
|
||||
# Make sure that object IDs cannot be pickled (including functions that
|
||||
# close over object IDs).
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(f))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(g))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(h))
|
||||
def h():
|
||||
object_ids[0]
|
||||
return 1
|
||||
# Make sure that object IDs cannot be pickled (including functions that
|
||||
# close over object IDs).
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(f))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(g))
|
||||
self.assertRaises(Exception, lambda: pickle.dumps(h))
|
||||
|
||||
def test_equality_comparisons(self):
|
||||
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
|
||||
x2 = local_scheduler.ObjectID(ID_SIZE * b"a")
|
||||
y1 = local_scheduler.ObjectID(ID_SIZE * b"b")
|
||||
y2 = local_scheduler.ObjectID(ID_SIZE * b"b")
|
||||
self.assertEqual(x1, x2)
|
||||
self.assertEqual(y1, y2)
|
||||
self.assertNotEqual(x1, y1)
|
||||
def test_equality_comparisons(self):
|
||||
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
|
||||
x2 = local_scheduler.ObjectID(ID_SIZE * b"a")
|
||||
y1 = local_scheduler.ObjectID(ID_SIZE * b"b")
|
||||
y2 = local_scheduler.ObjectID(ID_SIZE * b"b")
|
||||
self.assertEqual(x1, x2)
|
||||
self.assertEqual(y1, y2)
|
||||
self.assertNotEqual(x1, y1)
|
||||
|
||||
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
|
||||
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
|
||||
for i in range(256)]
|
||||
self.assertEqual(len(set(object_ids1)), 256)
|
||||
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
|
||||
self.assertEqual(set(object_ids1), set(object_ids2))
|
||||
|
||||
def test_hashability(self):
|
||||
x = random_object_id()
|
||||
y = random_object_id()
|
||||
{x: y}
|
||||
set([x, y])
|
||||
def test_hashability(self):
|
||||
x = random_object_id()
|
||||
y = random_object_id()
|
||||
{x: y}
|
||||
set([x, y])
|
||||
|
||||
|
||||
class TestTask(unittest.TestCase):
|
||||
|
||||
def check_task(self, task, function_id, num_return_vals, args):
|
||||
self.assertEqual(function_id.id(), task.function_id().id())
|
||||
retrieved_args = task.arguments()
|
||||
self.assertEqual(num_return_vals, len(task.returns()))
|
||||
self.assertEqual(len(args), len(retrieved_args))
|
||||
for i in range(len(retrieved_args)):
|
||||
if isinstance(retrieved_args[i], local_scheduler.ObjectID):
|
||||
self.assertEqual(retrieved_args[i].id(), args[i].id())
|
||||
else:
|
||||
self.assertEqual(retrieved_args[i], args[i])
|
||||
def check_task(self, task, function_id, num_return_vals, args):
|
||||
self.assertEqual(function_id.id(), task.function_id().id())
|
||||
retrieved_args = task.arguments()
|
||||
self.assertEqual(num_return_vals, len(task.returns()))
|
||||
self.assertEqual(len(args), len(retrieved_args))
|
||||
for i in range(len(retrieved_args)):
|
||||
if isinstance(retrieved_args[i], local_scheduler.ObjectID):
|
||||
self.assertEqual(retrieved_args[i].id(), args[i].id())
|
||||
else:
|
||||
self.assertEqual(retrieved_args[i], args[i])
|
||||
|
||||
def test_create_and_serialize_task(self):
|
||||
# TODO(rkn): The function ID should be a FunctionID object, not an
|
||||
# ObjectID.
|
||||
driver_id = random_driver_id()
|
||||
parent_id = random_task_id()
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
args_list = [
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(driver_id, function_id, args,
|
||||
num_return_vals, parent_id, 0)
|
||||
self.check_task(task, function_id, num_return_vals, args)
|
||||
data = local_scheduler.task_to_string(task)
|
||||
task2 = local_scheduler.task_from_string(data)
|
||||
self.check_task(task2, function_id, num_return_vals, args)
|
||||
def test_create_and_serialize_task(self):
|
||||
# TODO(rkn): The function ID should be a FunctionID object, not an
|
||||
# ObjectID.
|
||||
driver_id = random_driver_id()
|
||||
parent_id = random_task_id()
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for _ in range(256)]
|
||||
args_list = [
|
||||
[],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids]
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(driver_id, function_id, args,
|
||||
num_return_vals, parent_id, 0)
|
||||
self.check_task(task, function_id, num_return_vals, args)
|
||||
data = local_scheduler.task_to_string(task)
|
||||
task2 = local_scheduler.task_from_string(data)
|
||||
self.check_task(task2, function_id, num_return_vals, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -10,271 +10,277 @@ BLOCK_SIZE = 10
|
||||
|
||||
|
||||
class DistArray(object):
|
||||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
|
||||
if objectids is not None:
|
||||
self.objectids = objectids
|
||||
else:
|
||||
self.objectids = np.empty(self.num_blocks, dtype=object)
|
||||
if self.num_blocks != list(self.objectids.shape):
|
||||
raise Exception("The fields `num_blocks` and `objectids` are "
|
||||
"inconsistent, `num_blocks` is {} and `objectids` has "
|
||||
"shape {}".format(self.num_blocks,
|
||||
list(self.objectids.shape)))
|
||||
def __init__(self, shape, objectids=None):
|
||||
self.shape = shape
|
||||
self.ndim = len(shape)
|
||||
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
|
||||
for a in self.shape]
|
||||
if objectids is not None:
|
||||
self.objectids = objectids
|
||||
else:
|
||||
self.objectids = np.empty(self.num_blocks, dtype=object)
|
||||
if self.num_blocks != list(self.objectids.shape):
|
||||
raise Exception("The fields `num_blocks` and `objectids` are "
|
||||
"inconsistent, `num_blocks` is {} and `objectids` "
|
||||
"has shape {}".format(self.num_blocks,
|
||||
list(self.objectids.shape)))
|
||||
|
||||
@staticmethod
|
||||
def compute_block_lower(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the same "
|
||||
"length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
return [elem * BLOCK_SIZE for elem in index]
|
||||
@staticmethod
|
||||
def compute_block_lower(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the "
|
||||
"same length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
return [elem * BLOCK_SIZE for elem in index]
|
||||
|
||||
@staticmethod
|
||||
def compute_block_upper(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the same "
|
||||
"length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
upper = []
|
||||
for i in range(len(shape)):
|
||||
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
|
||||
return upper
|
||||
@staticmethod
|
||||
def compute_block_upper(index, shape):
|
||||
if len(index) != len(shape):
|
||||
raise Exception("The fields `index` and `shape` must have the "
|
||||
"same length, but `index` is {} and `shape` is "
|
||||
"{}.".format(index, shape))
|
||||
upper = []
|
||||
for i in range(len(shape)):
|
||||
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
|
||||
return upper
|
||||
|
||||
@staticmethod
|
||||
def compute_block_shape(index, shape):
|
||||
lower = DistArray.compute_block_lower(index, shape)
|
||||
upper = DistArray.compute_block_upper(index, shape)
|
||||
return [u - l for (l, u) in zip(lower, upper)]
|
||||
@staticmethod
|
||||
def compute_block_shape(index, shape):
|
||||
lower = DistArray.compute_block_lower(index, shape)
|
||||
upper = DistArray.compute_block_upper(index, shape)
|
||||
return [u - l for (l, u) in zip(lower, upper)]
|
||||
|
||||
@staticmethod
|
||||
def compute_num_blocks(shape):
|
||||
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
|
||||
@staticmethod
|
||||
def compute_num_blocks(shape):
|
||||
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
|
||||
|
||||
def assemble(self):
|
||||
"""Assemble an array from a distributed array of object IDs."""
|
||||
first_block = ray.get(self.objectids[(0,) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, self.shape)
|
||||
upper = DistArray.compute_block_upper(index, self.shape)
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
|
||||
self.objectids[index])
|
||||
return result
|
||||
def assemble(self):
|
||||
"""Assemble an array from a distributed array of object IDs."""
|
||||
first_block = ray.get(self.objectids[(0,) * self.ndim])
|
||||
dtype = first_block.dtype
|
||||
result = np.zeros(self.shape, dtype=dtype)
|
||||
for index in np.ndindex(*self.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, self.shape)
|
||||
upper = DistArray.compute_block_upper(index, self.shape)
|
||||
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
|
||||
self.objectids[index])
|
||||
return result
|
||||
|
||||
def __getitem__(self, sliced):
|
||||
# TODO(rkn): Fix this, this is just a placeholder that should work but is
|
||||
# inefficient.
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
def __getitem__(self, sliced):
|
||||
# TODO(rkn): Fix this, this is just a placeholder that should work but
|
||||
# is inefficient.
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
return a.assemble()
|
||||
|
||||
|
||||
# TODO(rkn): What should we call this method?
|
||||
@ray.remote
|
||||
def numpy_to_dist(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
|
||||
in zip(lower, upper)]])
|
||||
return result
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
lower = DistArray.compute_block_lower(index, a.shape)
|
||||
upper = DistArray.compute_block_upper(index, a.shape)
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
|
||||
in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(
|
||||
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def copy(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
# We don't need to actually copy the objects because remote objects are
|
||||
# immutable.
|
||||
result.objectids[index] = a.objectids[index]
|
||||
return result
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
# We don't need to actually copy the objects because remote objects are
|
||||
# immutable.
|
||||
result.objectids[index] = a.objectids[index]
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
shape = [dim1, dim2]
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1],
|
||||
dtype_name=dtype_name)
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape,
|
||||
dtype_name=dtype_name)
|
||||
return result
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
shape = [dim1, dim2]
|
||||
result = DistArray(shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
block_shape = DistArray.compute_block_shape([i, j], shape)
|
||||
if i == j:
|
||||
result.objectids[i, j] = ra.eye.remote(block_shape[0],
|
||||
block_shape[1],
|
||||
dtype_name=dtype_name)
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape,
|
||||
dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
|
||||
def triu(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is "
|
||||
"{}.".format(a.ndim))
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i < j:
|
||||
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
|
||||
elif i == j:
|
||||
result.objectids[i, j] = ra.triu.remote(a.objectids[i, j])
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is "
|
||||
"{}.".format(a.ndim))
|
||||
result = DistArray(a.shape)
|
||||
for (i, j) in np.ndindex(*result.num_blocks):
|
||||
if i < j:
|
||||
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
|
||||
elif i == j:
|
||||
result.objectids[i, j] = ra.triu.remote(a.objectids[i, j])
|
||||
else:
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote
def tril(a):
    """Return the lower-triangular part of a 2-D distributed array.

    Blocks below the block diagonal are copied, blocks on the diagonal are
    passed through ``ra.tril``, and blocks above the diagonal are replaced
    with ``ra.zeros_like`` of the corresponding input block.

    Args:
        a: A two-dimensional DistArray.

    Returns:
        A new DistArray constructed with the same shape as ``a``.

    Raises:
        Exception: If ``a.ndim`` is not 2.
    """
    if a.ndim != 2:
        raise Exception("Input must have 2 dimensions, but a.ndim is "
                        "{}.".format(a.ndim))
    lower = DistArray(a.shape)
    for row, col in np.ndindex(*lower.num_blocks):
        source_block = a.objectids[row, col]
        if row == col:
            block = ra.tril.remote(source_block)
        elif row > col:
            block = ra.copy.remote(source_block)
        else:
            block = ra.zeros_like.remote(source_block)
        lower.objectids[row, col] = block
    return lower
|
||||
|
||||
|
||||
@ray.remote
def blockwise_dot(*matrices):
    """Sum the products of paired blocks.

    The arguments are treated as two halves laid end to end: the i-th
    matrix of the first half is multiplied with the i-th matrix of the
    second half, and all of the products are accumulated.

    Args:
        *matrices: An even number of 2-D arrays.

    Returns:
        The accumulated sum of ``np.dot(left, right)`` over all pairs.

    Raises:
        Exception: If an odd number of arguments is passed.
    """
    n = len(matrices)
    if n % 2 != 0:
        raise Exception("blockwise_dot expects an even number of arguments, "
                        "but len(matrices) is {}.".format(n))
    half = n // 2
    # Rows come from the first left operand, columns from the first right
    # operand.
    shape = (matrices[0].shape[0], matrices[half].shape[1])
    total = np.zeros(shape)
    for left, right in zip(matrices[:half], matrices[half:]):
        total += np.dot(left, right)
    return total
|
||||
|
||||
|
||||
@ray.remote
def dot(a, b):
    """Multiply two 2-D distributed arrays block by block.

    Each output block (i, j) is produced by one ``blockwise_dot`` task that
    receives block row i of ``a`` followed by block column j of ``b``.

    Args:
        a: A two-dimensional DistArray.
        b: A two-dimensional DistArray with ``b.shape[0] == a.shape[1]``.

    Returns:
        A DistArray of shape ``[a.shape[0], b.shape[1]]``.

    Raises:
        Exception: If either argument is not 2-D or the inner dimensions
            do not match.
    """
    if a.ndim != 2:
        raise Exception("dot expects its arguments to be 2-dimensional, but "
                        "a.ndim = {}.".format(a.ndim))
    if b.ndim != 2:
        raise Exception("dot expects its arguments to be 2-dimensional, but "
                        "b.ndim = {}.".format(b.ndim))
    if a.shape[1] != b.shape[0]:
        raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
                        "a.shape = {} and b.shape = {}.".format(a.shape,
                                                                b.shape))
    product = DistArray([a.shape[0], b.shape[1]])
    for row, col in np.ndindex(*product.num_blocks):
        row_blocks = list(a.objectids[row, :])
        col_blocks = list(b.objectids[:, col])
        product.objectids[row, col] = blockwise_dot.remote(
            *(row_blocks + col_blocks))
    return product
|
||||
|
||||
|
||||
@ray.remote
def subblocks(a, *ranges):
    """Build a distributed array from a subset of the blocks of `a`.

    The result and `a` will have the same number of dimensions. For
    example,
        subblocks(a, [0, 1], [2, 4])
    will produce a DistArray whose objectids are
        [[a.objectids[0, 2], a.objectids[0, 4]],
         [a.objectids[1, 2], a.objectids[1, 4]]]
    We allow the user to pass in an empty list [] to indicate the full range.

    Args:
        a: The DistArray to take blocks from.
        *ranges: One sorted sequence of block indices per dimension of `a`.

    Returns:
        A DistArray whose blocks are the selected blocks of `a`.

    Raises:
        Exception: If the number of ranges differs from a.ndim, a range is
            not sorted, or a range contains an out-of-bounds block index.
    """
    ranges = list(ranges)
    if len(ranges) != a.ndim:
        raise Exception("sub_blocks expects to receive a number of ranges "
                        "equal to a.ndim, but it received {} ranges and "
                        "a.ndim = {}.".format(len(ranges), a.ndim))
    for i in range(len(ranges)):
        # We allow the user to pass in an empty list to indicate the full
        # range.
        if ranges[i] == []:
            ranges[i] = range(a.num_blocks[i])
        # np.all is used because np.alltrue was removed in NumPy 2.0.
        if not np.all(ranges[i] == np.sort(ranges[i])):
            raise Exception("Ranges passed to sub_blocks must be sorted, but "
                            "the {}th range is {}.".format(i, ranges[i]))
        # The range is sorted, so only its two ends need a bounds check.
        if ranges[i][0] < 0:
            raise Exception("Values in the ranges passed to sub_blocks must "
                            "be at least 0, but the {}th range is {}."
                            .format(i, ranges[i]))
        if ranges[i][-1] >= a.num_blocks[i]:
            raise Exception("Values in the ranges passed to sub_blocks must "
                            "be less than the relevant number of blocks, but "
                            "the {}th range is {}, and a.num_blocks = {}."
                            .format(i, ranges[i], a.num_blocks))
    last_index = [r[-1] for r in ranges]
    last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
    # Every selected block but the last along a dimension contributes
    # BLOCK_SIZE; the last one contributes its own extent.
    shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
             for i in range(a.ndim)]
    result = DistArray(shape)
    for index in np.ndindex(*result.num_blocks):
        result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
                                                     for i in range(a.ndim)])]
    return result
|
||||
|
||||
|
||||
@ray.remote
def transpose(a):
    """Transpose a 2-D distributed array.

    Output block (i, j) is the ``ra.transpose`` of input block (j, i).

    Args:
        a: A two-dimensional DistArray.

    Returns:
        A DistArray of shape ``[a.shape[1], a.shape[0]]``.

    Raises:
        Exception: If ``a.ndim`` is not 2.
    """
    if a.ndim != 2:
        raise Exception("transpose expects its argument to be 2-dimensional, "
                        "but a.ndim = {}, a.shape = {}.".format(a.ndim,
                                                                a.shape))
    flipped = DistArray([a.shape[1], a.shape[0]])
    num_rows, num_cols = flipped.num_blocks[0], flipped.num_blocks[1]
    for row in range(num_rows):
        for col in range(num_cols):
            flipped.objectids[row, col] = ra.transpose.remote(
                a.objectids[col, row])
    return flipped
|
||||
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote
def add(x1, x2):
    """Add two distributed arrays of identical shape, block by block.

    Args:
        x1: A DistArray.
        x2: A DistArray with the same shape as ``x1``.

    Returns:
        A DistArray whose blocks are ``ra.add`` of the matching input blocks.

    Raises:
        Exception: If the shapes differ (broadcasting is not supported).
    """
    if x1.shape != x2.shape:
        raise Exception("add expects arguments `x1` and `x2` to have the same "
                        "shape, but x1.shape = {}, and x2.shape = {}."
                        .format(x1.shape, x2.shape))
    total = DistArray(x1.shape)
    for block_index in np.ndindex(*total.num_blocks):
        left = x1.objectids[block_index]
        right = x2.objectids[block_index]
        total.objectids[block_index] = ra.add.remote(left, right)
    return total
|
||||
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote
def subtract(x1, x2):
    """Subtract two distributed arrays of identical shape, block by block.

    Args:
        x1: A DistArray.
        x2: A DistArray with the same shape as ``x1``.

    Returns:
        A DistArray whose blocks are ``ra.subtract`` of the matching input
        blocks.

    Raises:
        Exception: If the shapes differ (broadcasting is not supported).
    """
    if x1.shape != x2.shape:
        raise Exception("subtract expects arguments `x1` and `x2` to have the "
                        "same shape, but x1.shape = {}, and x2.shape = {}."
                        .format(x1.shape, x2.shape))
    difference = DistArray(x1.shape)
    for block_index in np.ndindex(*difference.num_blocks):
        left = x1.objectids[block_index]
        right = x2.objectids[block_index]
        difference.objectids[block_index] = ra.subtract.remote(left, right)
    return difference
|
||||
|
||||
@@ -13,74 +13,75 @@ __all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
|
||||
|
||||
@ray.remote(num_return_vals=2)
def tsqr(a):
    """Perform a QR decomposition of a tall-skinny matrix.

    Args:
        a: A distributed matrix with shape MxN (suppose K = min(M, N)).

    Returns:
        A tuple of q (a DistArray) and r (a numpy array) satisfying the
        following.
            - If q_full = ray.get(DistArray, q).assemble(), then
              q_full.shape == (M, K).
            - np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
            - If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
            - np.allclose(r, np.triu(r)) == True.

    Raises:
        Exception: If a is not two-dimensional or does not consist of
            exactly one column of blocks.
    """
    if len(a.shape) != 2:
        raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
                        "{}".format(a.shape))
    if a.num_blocks[1] != 1:
        raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks "
                        "is {}".format(a.num_blocks))

    num_blocks = a.num_blocks[0]
    # Number of levels in the reduction tree, including the leaf level.
    depth = int(np.ceil(np.log2(num_blocks))) + 1
    q_tree = np.empty((num_blocks, depth), dtype=object)

    # Leaf level: factor every block on its own.
    current_rs = []
    for i in range(num_blocks):
        q, r = ra.linalg.qr.remote(a.objectids[i, 0])
        q_tree[i, 0] = q
        current_rs.append(r)

    # Upper levels: stack neighbouring R factors pairwise and factor again.
    for level in range(1, depth):
        next_rs = []
        num_pairs = int(np.ceil(1.0 * len(current_rs) / 2))
        for i in range(num_pairs):
            pair = current_rs[(2 * i):(2 * i + 2)]
            stacked_rs = ra.vstack.remote(*pair)
            q, r = ra.linalg.qr.remote(stacked_rs)
            q_tree[i, level] = q
            next_rs.append(r)
        current_rs = next_rs
    assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))

    # handle the special case in which the whole DistArray "a" fits in one
    # block and has fewer rows than columns, this is a bit ugly so think about
    # how to remove it
    if a.shape[0] >= a.shape[1]:
        q_shape = a.shape
    else:
        q_shape = [a.shape[0], a.shape[0]]
    q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
    q_objectids = np.empty(q_num_blocks, dtype=object)
    q_result = core.DistArray(q_shape, q_objectids)

    # reconstruct output
    num_cols = a.shape[1]
    for i in range(num_blocks):
        q_block_current = q_tree[i, 0]
        ith_index = i
        for level in range(1, depth):
            # An even index takes the top slice of the parent's Q, an odd
            # index the slice directly below it.
            if np.mod(ith_index, 2) == 0:
                lower = [0, 0]
                upper = [num_cols, core.BLOCK_SIZE]
            else:
                lower = [num_cols, 0]
                upper = [2 * num_cols, core.BLOCK_SIZE]
            ith_index //= 2
            parent_slice = ra.subarray.remote(q_tree[ith_index, level],
                                              lower, upper)
            q_block_current = ra.dot.remote(q_block_current, parent_slice)
        q_result.objectids[i] = q_block_current
    r = current_rs[0]
    return q_result, ray.get(r)
|
||||
|
||||
|
||||
# TODO(rkn): This is unoptimized, we really want a block version of this.
|
||||
@@ -88,76 +89,77 @@ def tsqr(a):
|
||||
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
|
||||
@ray.remote(num_return_vals=3)
def modified_lu(q):
    """Perform a modified LU decomposition of a matrix.

    This takes a matrix q with orthonormal columns, returns l, u, s such that
    q - s = l * u.

    Args:
        q: A two dimensional orthonormal matrix q.

    Returns:
        A tuple of a lower triangular matrix l, an upper triangular matrix u,
        and a vector representing a diagonal matrix s such that
        q - s = l * u.
    """
    q = q.assemble()
    m, b = q.shape[0], q.shape[1]
    S = np.zeros(b)

    q_work = np.copy(q)

    for i in range(b):
        below = slice(i + 1, m)
        right = slice(i + 1, b)
        S[i] = -1 * np.sign(q_work[i, i])
        q_work[i, i] -= S[i]
        # Scale ith column of L by diagonal element.
        q_work[below, i] /= q_work[i, i]
        # Perform Schur complement update.
        q_work[below, right] -= np.outer(q_work[below, i],
                                         q_work[i, right])

    L = np.tril(q_work)
    # L is unit lower triangular: overwrite its diagonal with ones.
    for i in range(b):
        L[i, i] = 1
    U = np.triu(q_work)[:b, :]
    # TODO(rkn): Get rid of the put below.
    l_dist = ray.get(core.numpy_to_dist.remote(ray.put(L)))
    return l_dist, U, S
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
    """Compute the t factor used by tsqr_hr.

    Args:
        u: The matrix multiplied on the left of the product.
        s: A vector that is expanded to a diagonal matrix with np.diag.
        y_top_block: An array whose leading b x b corner is used as y_top.
        b: The size of the leading corner to extract.

    Returns:
        A tuple (t, y_top), where
        t = -1 * u . diag(s) . inv(y_top).T.
    """
    y_top = y_top_block[:b, :b]
    s_full = np.diag(s)
    y_top_inv_t = np.linalg.inv(y_top).T
    scaled = np.dot(s_full, y_top_inv_t)
    t = -1 * np.dot(u, scaled)
    return t, y_top
|
||||
|
||||
|
||||
@ray.remote
def tsqr_hr_helper2(s, r_temp):
    """Return diag(s) multiplied by r_temp.

    Args:
        s: A vector that is expanded to a diagonal matrix with np.diag.
        r_temp: The matrix to be multiplied on the left by diag(s).

    Returns:
        The product np.dot(np.diag(s), r_temp).
    """
    diagonal = np.diag(s)
    return np.dot(diagonal, r_temp)
|
||||
|
||||
|
||||
# This is Algorithm 6 from
|
||||
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
|
||||
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
    """Chain tsqr, modified_lu and the two tsqr_hr helpers.

    Args:
        a: The distributed matrix passed to tsqr.

    Returns:
        A tuple (y, t, y_top, r) of the fetched results: y from
        modified_lu, t and y_top from tsqr_hr_helper1, and r from
        tsqr_hr_helper2.
    """
    q, r_temp = tsqr.remote(a)
    y, u, s = modified_lu.remote(q)
    # y has to be fetched here because the helper needs its top-left block.
    y_blocked = ray.get(y)
    top_left_block = y_blocked.objectids[0, 0]
    num_cols = a.shape[1]
    t, y_top = tsqr_hr_helper1.remote(u, s, top_left_block, num_cols)
    r = tsqr_hr_helper2.remote(s, r_temp)
    return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r)
|
||||
|
||||
|
||||
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
    """Return a_rc - y_ri . (t.T . W_c)."""
    correction = np.dot(t.T, W_c)
    return a_rc - np.dot(y_ri, correction)
|
||||
|
||||
|
||||
@ray.remote
def qr_helper2(y_ri, a_rc):
    """Return y_ri.T . a_rc."""
    y_ri_t = y_ri.T
    return np.dot(y_ri_t, a_rc)
|
||||
|
||||
|
||||
# This is Algorithm 7 from
|
||||
@@ -165,60 +167,63 @@ def qr_helper2(y_ri, a_rc):
|
||||
@ray.remote(num_return_vals=2)
def qr(a):
    """Compute a blocked QR decomposition of a distributed matrix.

    Each block column is factored with tsqr_hr, the blocks to its right are
    updated, and q is finally assembled from the collected y and t factors.

    Args:
        a: A two-dimensional DistArray.

    Returns:
        A tuple (q, r) of DistArrays; q is fetched with ray.get before it
        is returned.
    """
    m, n = a.shape[0], a.shape[1]
    k = min(m, n)
    num_block_rows, num_block_cols = a.num_blocks[0], a.num_blocks[1]

    # we will store our scratch work in a_work
    a_work = core.DistArray(a.shape, np.copy(a.objectids))

    # Factor the first block locally just to learn the result dtype.
    result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
    # TODO(rkn): It would be preferable not to get this right after creating
    # it.
    r_res = ray.get(core.zeros.remote([k, n], result_dtype))
    # TODO(rkn): It would be preferable not to get this right after creating
    # it.
    y_res = ray.get(core.zeros.remote([m, k], result_dtype))
    Ts = []

    # The for loop differs from the paper, which says
    # "for i in range(a.num_blocks[1])", but that doesn't seem to make any
    # sense when a.num_blocks[1] > a.num_blocks[0].
    for i in range(min(num_block_rows, num_block_cols)):
        panel_rows = list(range(i, a_work.num_blocks[0]))
        sub_dist_array = core.subblocks.remote(a_work, panel_rows, [i])
        y, t, _, R = tsqr_hr.remote(sub_dist_array)
        y_val = ray.get(y)

        for j in range(i, num_block_rows):
            y_res.objectids[j, i] = y_val.objectids[j - i, 0]
        if a.shape[0] > a.shape[1]:
            # in this case, R needs to be square
            R_shape = ray.get(ra.shape.remote(R))
            eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
                                     dtype_name=result_dtype)
            r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
        else:
            r_res.objectids[i, i] = R
        Ts.append(core.numpy_to_dist.remote(t))

        # Update every block column to the right of the current panel.
        for c in range(i + 1, num_block_cols):
            W_rcs = []
            for r in range(i, num_block_rows):
                y_ri = y_val.objectids[r - i, 0]
                W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c]))
            W_c = ra.sum_list.remote(*W_rcs)
            for r in range(i, num_block_rows):
                y_ri = y_val.objectids[r - i, 0]
                A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c)
                a_work.objectids[r, c] = A_rc
            r_res.objectids[i, c] = a_work.objectids[i, c]

    # construct q_res from Ys and Ts
    q = core.eye.remote(m, k, dtype_name=result_dtype)
    for i in reversed(range(len(Ts))):
        y_col_block = core.subblocks.remote(y_res, [], [i])
        yt_q = core.dot.remote(core.transpose.remote(y_col_block), q)
        t_yt_q = core.dot.remote(Ts[i], yt_q)
        q = core.subtract.remote(q, core.dot.remote(y_col_block, t_yt_q))

    return ray.get(q), r_res
|
||||
|
||||
@@ -11,10 +11,10 @@ from .core import DistArray
|
||||
|
||||
@ray.remote
def normal(shape):
    """Create a distributed array whose blocks come from ra.random.normal.

    Args:
        shape: The shape of the full distributed array.

    Returns:
        A DistArray of the given shape, with one remote
        ``ra.random.normal`` task per block.
    """
    num_blocks = DistArray.compute_num_blocks(shape)
    block_ids = np.empty(num_blocks, dtype=object)
    for block_index in np.ndindex(*num_blocks):
        block_shape = DistArray.compute_block_shape(block_index, shape)
        block_ids[block_index] = ra.random.normal.remote(block_shape)
    return DistArray(shape, block_ids)
|
||||
|
||||
@@ -8,94 +8,94 @@ import ray
|
||||
|
||||
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
    """Remote wrapper around np.zeros; the dtype is given by name."""
    dtype = np.dtype(dtype_name)
    return np.zeros(shape, dtype=dtype, order=order)
|
||||
|
||||
|
||||
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
    """Remote wrapper around np.zeros_like.

    The string "None" for dtype_name stands for "no dtype override".
    """
    if dtype_name == "None":
        dtype_val = None
    else:
        dtype_val = np.dtype(dtype_name)
    return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
|
||||
@ray.remote
def ones(shape, dtype_name="float", order="C"):
    """Remote wrapper around np.ones; the dtype is given by name."""
    dtype = np.dtype(dtype_name)
    return np.ones(shape, dtype=dtype, order=order)
|
||||
|
||||
|
||||
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
    """Remote wrapper around np.eye.

    M == -1 is a sentinel meaning "use N columns"; the dtype is given by
    name.
    """
    if M == -1:
        M = N
    dtype = np.dtype(dtype_name)
    return np.eye(N, M=M, k=k, dtype=dtype)
|
||||
|
||||
|
||||
@ray.remote
def dot(a, b):
    """Remote wrapper around np.dot."""
    product = np.dot(a, b)
    return product
|
||||
|
||||
|
||||
@ray.remote
def vstack(*xs):
    """Remote wrapper around np.vstack applied to all of the arguments."""
    stacked = np.vstack(xs)
    return stacked
|
||||
|
||||
|
||||
@ray.remote
def hstack(*xs):
    """Remote wrapper around np.hstack applied to all of the arguments."""
    stacked = np.hstack(xs)
    return stacked
|
||||
|
||||
|
||||
# TODO(rkn): Instead of this, consider implementing slicing.
|
||||
# TODO(rkn): Be consistent about using "index" versus "indices".
|
||||
@ray.remote
def subarray(a, lower_indices, upper_indices):
    """Return the sub-array a[l0:u0, l1:u1, ...].

    Args:
        a: The array to slice.
        lower_indices: The inclusive start index for each leading dimension.
        upper_indices: The exclusive stop index for each leading dimension.

    Returns:
        The slice of ``a`` selected by pairing the lower and upper indices.
    """
    # A multi-dimensional index must be a tuple: indexing with a list of
    # slices was deprecated by NumPy and is now an error.
    index = tuple(slice(lower, upper)
                  for (lower, upper) in zip(lower_indices, upper_indices))
    return a[index]
|
||||
|
||||
|
||||
@ray.remote
def copy(a, order="K"):
    """Remote wrapper around np.copy."""
    duplicate = np.copy(a, order=order)
    return duplicate
|
||||
|
||||
|
||||
@ray.remote
def tril(m, k=0):
    """Remote wrapper around np.tril."""
    lower = np.tril(m, k=k)
    return lower
|
||||
|
||||
|
||||
@ray.remote
def triu(m, k=0):
    """Remote wrapper around np.triu."""
    upper = np.triu(m, k=k)
    return upper
|
||||
|
||||
|
||||
@ray.remote
def diag(v, k=0):
    """Remote wrapper around np.diag."""
    result = np.diag(v, k=k)
    return result
|
||||
|
||||
|
||||
@ray.remote
def transpose(a, axes=[]):
    """Remote wrapper around np.transpose.

    An empty list for axes is a sentinel meaning "no explicit axes"; the
    default list is only compared, never mutated.
    """
    if axes == []:
        axes = None
    return np.transpose(a, axes=axes)
|
||||
|
||||
|
||||
@ray.remote
def add(x1, x2):
    """Remote wrapper around np.add."""
    total = np.add(x1, x2)
    return total
|
||||
|
||||
|
||||
@ray.remote
def subtract(x1, x2):
    """Remote wrapper around np.subtract."""
    difference = np.subtract(x1, x2)
    return difference
|
||||
|
||||
|
||||
@ray.remote
def sum(x, axis=-1):
    """Remote wrapper around np.sum.

    axis == -1 is a sentinel meaning "sum over all axes", so it does not
    select the last axis here.
    """
    reduce_axis = None if axis == -1 else axis
    return np.sum(x, axis=reduce_axis)
|
||||
|
||||
|
||||
@ray.remote
def shape(a):
    """Remote wrapper around np.shape."""
    dims = np.shape(a)
    return dims
|
||||
|
||||
|
||||
@ray.remote
def sum_list(*xs):
    """Sum all of the arguments elementwise (np.sum along axis 0)."""
    total = np.sum(xs, axis=0)
    return total
|
||||
|
||||
@@ -13,99 +13,99 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
|
||||
@ray.remote
def matrix_power(M, n):
    """Remote wrapper around np.linalg.matrix_power."""
    result = np.linalg.matrix_power(M, n)
    return result
|
||||
|
||||
|
||||
@ray.remote
def solve(a, b):
    """Remote wrapper around np.linalg.solve."""
    solution = np.linalg.solve(a, b)
    return solution
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def tensorsolve(a):
    """Placeholder; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def tensorinv(a):
    """Placeholder; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
def inv(a):
    """Remote wrapper around np.linalg.inv."""
    inverse = np.linalg.inv(a)
    return inverse
|
||||
|
||||
|
||||
@ray.remote
def cholesky(a):
    """Remote wrapper around np.linalg.cholesky."""
    factor = np.linalg.cholesky(a)
    return factor
|
||||
|
||||
|
||||
@ray.remote
def eigvals(a):
    """Remote wrapper around np.linalg.eigvals."""
    values = np.linalg.eigvals(a)
    return values
|
||||
|
||||
|
||||
@ray.remote
def eigvalsh(a):
    """Placeholder; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
def pinv(a):
    """Remote wrapper around np.linalg.pinv."""
    pseudo_inverse = np.linalg.pinv(a)
    return pseudo_inverse
|
||||
|
||||
|
||||
@ray.remote
def slogdet(a):
    """Placeholder; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
|
||||
@ray.remote
def det(a):
    """Remote wrapper around np.linalg.det."""
    determinant = np.linalg.det(a)
    return determinant
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=3)
def svd(a):
    """Remote wrapper around np.linalg.svd (three return values)."""
    factors = np.linalg.svd(a)
    return factors
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def eig(a):
    """Remote wrapper around np.linalg.eig (two return values)."""
    decomposition = np.linalg.eig(a)
    return decomposition
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def eigh(a):
    """Remote wrapper around np.linalg.eigh (two return values)."""
    decomposition = np.linalg.eigh(a)
    return decomposition
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=4)
def lstsq(a, b):
    """Remote wrapper around np.linalg.lstsq.

    Args:
        a: The coefficient matrix.
        b: The ordinate (right-hand side) values.

    Returns:
        The four values returned by ``np.linalg.lstsq(a, b)``.
    """
    # b must be forwarded: np.linalg.lstsq requires both operands, so the
    # call without it raised a TypeError.
    return np.linalg.lstsq(a, b)
|
||||
|
||||
|
||||
@ray.remote
def norm(x):
    """Remote wrapper around np.linalg.norm with its default arguments."""
    magnitude = np.linalg.norm(x)
    return magnitude
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=2)
def qr(a):
    """Remote wrapper around np.linalg.qr (two return values)."""
    factors = np.linalg.qr(a)
    return factors
|
||||
|
||||
|
||||
@ray.remote
def cond(x):
    """Remote wrapper around np.linalg.cond with its default arguments."""
    condition_number = np.linalg.cond(x)
    return condition_number
|
||||
|
||||
|
||||
@ray.remote
def matrix_rank(M):
    """Remote wrapper around np.linalg.matrix_rank."""
    rank = np.linalg.matrix_rank(M)
    return rank
|
||||
|
||||
|
||||
@ray.remote
def multi_dot(*a):
    """Placeholder; not implemented yet."""
    raise NotImplementedError
|
||||
|
||||
@@ -8,4 +8,4 @@ import ray
|
||||
|
||||
@ray.remote
def normal(shape):
    """Remote wrapper around np.random.normal; draws an array of `shape`."""
    samples = np.random.normal(size=shape)
    return samples
|
||||
|
||||
+455
-439
@@ -49,507 +49,523 @@ TASK_STATUS_MAPPING = {
|
||||
|
||||
|
||||
class GlobalState(object):
|
||||
"""A class used to interface with the Ray control state.
|
||||
"""A class used to interface with the Ray control state.
|
||||
|
||||
Attributes:
|
||||
redis_client: The redis client used to query the redis server.
|
||||
"""
|
||||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
self.redis_client = None
|
||||
|
||||
def _check_connected(self):
|
||||
"""Check that the object has been initialized before it is used.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if ray.init() has not been called yet.
|
||||
Attributes:
|
||||
redis_client: The redis client used to query the redis server.
|
||||
"""
|
||||
if self.redis_client is None:
|
||||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
def __init__(self):
|
||||
"""Create a GlobalState object."""
|
||||
self.redis_client = None
|
||||
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
def _check_connected(self):
|
||||
"""Check that the object has been initialized before it is used.
|
||||
|
||||
Args:
|
||||
redis_ip_address: The IP address of the node that the Redis server lives
|
||||
on.
|
||||
redis_port: The port that the Redis server is listening on.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_clients = []
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
raise Exception("No entry found for NumRedisShards")
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
Raises:
|
||||
Exception: An exception is raised if ray.init() has not been called
|
||||
yet.
|
||||
"""
|
||||
if self.redis_client is None:
|
||||
raise Exception("The ray.global_state API cannot be used before "
|
||||
"ray.init has been called.")
|
||||
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
raise Exception("Expected {} Redis shard addresses, found "
|
||||
"{}".format(num_redis_shards, len(ip_address_ports)))
|
||||
def _initialize_global_state(self, redis_ip_address, redis_port):
|
||||
"""Initialize the GlobalState object by connecting to Redis.
|
||||
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
Args:
|
||||
redis_ip_address: The IP address of the node that the Redis server
|
||||
lives on.
|
||||
redis_port: The port that the Redis server is listening on.
|
||||
"""
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.redis_clients = []
|
||||
num_redis_shards = self.redis_client.get("NumRedisShards")
|
||||
if num_redis_shards is None:
|
||||
raise Exception("No entry found for NumRedisShards")
|
||||
num_redis_shards = int(num_redis_shards)
|
||||
if (num_redis_shards < 1):
|
||||
raise Exception("Expected at least one Redis shard, found "
|
||||
"{}.".format(num_redis_shards))
|
||||
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
|
||||
end=-1)
|
||||
if len(ip_address_ports) != num_redis_shards:
|
||||
raise Exception("Expected {} Redis shard addresses, found "
|
||||
"{}".format(num_redis_shards,
|
||||
len(ip_address_ports)))
|
||||
|
||||
Args:
|
||||
key: The object ID or the task ID that the query is about.
|
||||
args: The command to run.
|
||||
for ip_address_port in ip_address_ports:
|
||||
shard_address, shard_port = ip_address_port.split(b":")
|
||||
self.redis_clients.append(redis.StrictRedis(host=shard_address,
|
||||
port=shard_port))
|
||||
|
||||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
def _execute_command(self, key, *args):
|
||||
"""Execute a Redis command on the appropriate Redis shard based on key.
|
||||
|
||||
def _keys(self, pattern):
|
||||
"""Execute the KEYS command on all Redis shards.
|
||||
Args:
|
||||
key: The object ID or the task ID that the query is about.
|
||||
args: The command to run.
|
||||
|
||||
Args:
|
||||
pattern: The KEYS pattern to query.
|
||||
Returns:
|
||||
The value returned by the Redis command.
|
||||
"""
|
||||
client = self.redis_clients[key.redis_shard_hash() %
|
||||
len(self.redis_clients)]
|
||||
return client.execute_command(*args)
|
||||
|
||||
Returns:
|
||||
The concatenated list of results from all shards.
|
||||
"""
|
||||
result = []
|
||||
for client in self.redis_clients:
|
||||
result.extend(client.keys(pattern))
|
||||
return result
|
||||
def _keys(self, pattern):
|
||||
"""Execute the KEYS command on all Redis shards.
|
||||
|
||||
def _object_table(self, object_id):
|
||||
"""Fetch and parse the object table information for a single object ID.
|
||||
Args:
|
||||
pattern: The KEYS pattern to query.
|
||||
|
||||
Args:
|
||||
object_id_binary: A string of bytes with the object ID to get information
|
||||
about.
|
||||
Returns:
|
||||
The concatenated list of results from all shards.
|
||||
"""
|
||||
result = []
|
||||
for client in self.redis_clients:
|
||||
result.extend(client.keys(pattern))
|
||||
return result
|
||||
|
||||
Returns:
|
||||
A dictionary with information about the object ID in question.
|
||||
"""
|
||||
# Allow the argument to be either an ObjectID or a hex string.
|
||||
if not isinstance(object_id, ray.local_scheduler.ObjectID):
|
||||
object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id))
|
||||
def _object_table(self, object_id):
|
||||
"""Fetch and parse the object table information for a single object ID.
|
||||
|
||||
# Return information about a single object ID.
|
||||
object_locations = self._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
else:
|
||||
manager_ids = None
|
||||
Args:
|
||||
object_id_binary: A string of bytes with the object ID to get
|
||||
information about.
|
||||
|
||||
result_table_response = self._execute_command(object_id,
|
||||
"RAY.RESULT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
Returns:
|
||||
A dictionary with information about the object ID in question.
|
||||
"""
|
||||
# Allow the argument to be either an ObjectID or a hex string.
|
||||
if not isinstance(object_id, ray.local_scheduler.ObjectID):
|
||||
object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id))
|
||||
|
||||
result = {"ManagerIDs": manager_ids,
|
||||
"TaskID": binary_to_hex(result_table_message.TaskId()),
|
||||
"IsPut": bool(result_table_message.IsPut()),
|
||||
"DataSize": result_table_message.DataSize(),
|
||||
"Hash": binary_to_hex(result_table_message.Hash())}
|
||||
# Return information about a single object ID.
|
||||
object_locations = self._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_LOOKUP",
|
||||
object_id.id())
|
||||
if object_locations is not None:
|
||||
manager_ids = [binary_to_hex(manager_id)
|
||||
for manager_id in object_locations]
|
||||
else:
|
||||
manager_ids = None
|
||||
|
||||
return result
|
||||
result_table_response = self._execute_command(
|
||||
object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
|
||||
result_table_message = ResultTableReply.GetRootAsResultTableReply(
|
||||
result_table_response, 0)
|
||||
|
||||
def object_table(self, object_id=None):
|
||||
"""Fetch and parse the object table information for one or more object IDs.
|
||||
result = {"ManagerIDs": manager_ids,
|
||||
"TaskID": binary_to_hex(result_table_message.TaskId()),
|
||||
"IsPut": bool(result_table_message.IsPut()),
|
||||
"DataSize": result_table_message.DataSize(),
|
||||
"Hash": binary_to_hex(result_table_message.Hash())}
|
||||
|
||||
Args:
|
||||
object_id: An object ID to fetch information about. If this is None, then
|
||||
the entire object table is fetched.
|
||||
return result
|
||||
|
||||
def object_table(self, object_id=None):
|
||||
"""Fetch and parse the object table info for one or more object IDs.
|
||||
|
||||
Args:
|
||||
object_id: An object ID to fetch information about. If this is
|
||||
None, then the entire object table is fetched.
|
||||
|
||||
|
||||
Returns:
|
||||
Information from the object table.
|
||||
"""
|
||||
self._check_connected()
|
||||
if object_id is not None:
|
||||
# Return information about a single object ID.
|
||||
return self._object_table(object_id)
|
||||
else:
|
||||
# Return the entire object table.
|
||||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = self._object_table(
|
||||
binary_to_object_id(object_id_binary))
|
||||
return results
|
||||
Returns:
|
||||
Information from the object table.
|
||||
"""
|
||||
self._check_connected()
|
||||
if object_id is not None:
|
||||
# Return information about a single object ID.
|
||||
return self._object_table(object_id)
|
||||
else:
|
||||
# Return the entire object table.
|
||||
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
|
||||
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
|
||||
object_ids_binary = set(
|
||||
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
|
||||
[key[len(OBJECT_LOCATION_PREFIX):]
|
||||
for key in object_location_keys])
|
||||
results = {}
|
||||
for object_id_binary in object_ids_binary:
|
||||
results[binary_to_object_id(object_id_binary)] = (
|
||||
self._object_table(binary_to_object_id(object_id_binary)))
|
||||
return results
|
||||
|
||||
def _task_table(self, task_id):
|
||||
"""Fetch and parse the task table information for a single object task ID.
|
||||
def _task_table(self, task_id):
|
||||
"""Fetch and parse the task table information for a single task ID.
|
||||
|
||||
Args:
|
||||
task_id_binary: A string of bytes with the task ID to get information
|
||||
about.
|
||||
Args:
|
||||
task_id_binary: A string of bytes with the task ID to get
|
||||
information about.
|
||||
|
||||
Returns:
|
||||
A dictionary with information about the task ID in question.
|
||||
TASK_STATUS_MAPPING should be used to parse the "State" field into a
|
||||
human-readable string.
|
||||
"""
|
||||
task_table_response = self._execute_command(task_id,
|
||||
"RAY.TASK_TABLE_GET",
|
||||
task_id.id())
|
||||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task table."
|
||||
.format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
|
||||
args = []
|
||||
for i in range(task_spec_message.ArgsLength()):
|
||||
arg = task_spec_message.Args(i)
|
||||
if len(arg.ObjectId()) != 0:
|
||||
args.append(binary_to_object_id(arg.ObjectId()))
|
||||
else:
|
||||
args.append(pickle.loads(arg.Data()))
|
||||
assert task_spec_message.RequiredResourcesLength() == 2
|
||||
required_resources = {"CPUs": task_spec_message.RequiredResources(0),
|
||||
"GPUs": task_spec_message.RequiredResources(1)}
|
||||
task_spec_info = {
|
||||
"DriverID": binary_to_hex(task_spec_message.DriverId()),
|
||||
"TaskID": binary_to_hex(task_spec_message.TaskId()),
|
||||
"ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
|
||||
"ParentCounter": task_spec_message.ParentCounter(),
|
||||
"ActorID": binary_to_hex(task_spec_message.ActorId()),
|
||||
"ActorCounter": task_spec_message.ActorCounter(),
|
||||
"FunctionID": binary_to_hex(task_spec_message.FunctionId()),
|
||||
"Args": args,
|
||||
"ReturnObjectIDs": [binary_to_object_id(task_spec_message.Returns(i))
|
||||
for i in range(task_spec_message.ReturnsLength())],
|
||||
"RequiredResources": required_resources}
|
||||
Returns:
|
||||
A dictionary with information about the task ID in question.
|
||||
TASK_STATUS_MAPPING should be used to parse the "State" field
|
||||
into a human-readable string.
|
||||
"""
|
||||
task_table_response = self._execute_command(task_id,
|
||||
"RAY.TASK_TABLE_GET",
|
||||
task_id.id())
|
||||
if task_table_response is None:
|
||||
raise Exception("There is no entry for task ID {} in the task "
|
||||
"table.".format(binary_to_hex(task_id.id())))
|
||||
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
|
||||
0)
|
||||
task_spec = task_table_message.TaskSpec()
|
||||
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
|
||||
args = []
|
||||
for i in range(task_spec_message.ArgsLength()):
|
||||
arg = task_spec_message.Args(i)
|
||||
if len(arg.ObjectId()) != 0:
|
||||
args.append(binary_to_object_id(arg.ObjectId()))
|
||||
else:
|
||||
args.append(pickle.loads(arg.Data()))
|
||||
assert task_spec_message.RequiredResourcesLength() == 2
|
||||
required_resources = {"CPUs": task_spec_message.RequiredResources(0),
|
||||
"GPUs": task_spec_message.RequiredResources(1)}
|
||||
task_spec_info = {
|
||||
"DriverID": binary_to_hex(task_spec_message.DriverId()),
|
||||
"TaskID": binary_to_hex(task_spec_message.TaskId()),
|
||||
"ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
|
||||
"ParentCounter": task_spec_message.ParentCounter(),
|
||||
"ActorID": binary_to_hex(task_spec_message.ActorId()),
|
||||
"ActorCounter": task_spec_message.ActorCounter(),
|
||||
"FunctionID": binary_to_hex(task_spec_message.FunctionId()),
|
||||
"Args": args,
|
||||
"ReturnObjectIDs": [binary_to_object_id(
|
||||
task_spec_message.Returns(i))
|
||||
for i in range(
|
||||
task_spec_message.ReturnsLength())],
|
||||
"RequiredResources": required_resources}
|
||||
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
"TaskSpec": task_spec_info}
|
||||
return {"State": task_table_message.State(),
|
||||
"LocalSchedulerID": binary_to_hex(
|
||||
task_table_message.LocalSchedulerId()),
|
||||
"TaskSpec": task_spec_info}
|
||||
|
||||
def task_table(self, task_id=None):
|
||||
"""Fetch and parse the task table information for one or more task IDs.
|
||||
def task_table(self, task_id=None):
|
||||
"""Fetch and parse the task table information for one or more task IDs.
|
||||
|
||||
Args:
|
||||
task_id: A hex string of the task ID to fetch information about. If this
|
||||
is None, then the task object table is fetched.
|
||||
Args:
|
||||
task_id: A hex string of the task ID to fetch information about. If
|
||||
this is None, then the task object table is fetched.
|
||||
|
||||
|
||||
Returns:
|
||||
Information from the task table.
|
||||
"""
|
||||
self._check_connected()
|
||||
if task_id is not None:
|
||||
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
return self._task_table(task_id)
|
||||
else:
|
||||
task_table_keys = self._keys(TASK_PREFIX + "*")
|
||||
results = {}
|
||||
for key in task_table_keys:
|
||||
task_id_binary = key[len(TASK_PREFIX):]
|
||||
results[binary_to_hex(task_id_binary)] = self._task_table(
|
||||
ray.local_scheduler.ObjectID(task_id_binary))
|
||||
return results
|
||||
Returns:
|
||||
Information from the task table.
|
||||
"""
|
||||
self._check_connected()
|
||||
if task_id is not None:
|
||||
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
return self._task_table(task_id)
|
||||
else:
|
||||
task_table_keys = self._keys(TASK_PREFIX + "*")
|
||||
results = {}
|
||||
for key in task_table_keys:
|
||||
task_id_binary = key[len(TASK_PREFIX):]
|
||||
results[binary_to_hex(task_id_binary)] = self._task_table(
|
||||
ray.local_scheduler.ObjectID(task_id_binary))
|
||||
return results
|
||||
|
||||
def function_table(self, function_id=None):
|
||||
"""Fetch and parse the function table.
|
||||
def function_table(self, function_id=None):
|
||||
"""Fetch and parse the function table.
|
||||
|
||||
Returns:
|
||||
A dictionary that maps function IDs to information about the function.
|
||||
"""
|
||||
self._check_connected()
|
||||
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
|
||||
results = {}
|
||||
for key in function_table_keys:
|
||||
info = self.redis_client.hgetall(key)
|
||||
function_info_parsed = {
|
||||
"DriverID": binary_to_hex(info[b"driver_id"]),
|
||||
"Module": decode(info[b"module"]),
|
||||
"Name": decode(info[b"name"])
|
||||
}
|
||||
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
|
||||
return results
|
||||
Returns:
|
||||
A dictionary that maps function IDs to information about the
|
||||
function.
|
||||
"""
|
||||
self._check_connected()
|
||||
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
|
||||
results = {}
|
||||
for key in function_table_keys:
|
||||
info = self.redis_client.hgetall(key)
|
||||
function_info_parsed = {
|
||||
"DriverID": binary_to_hex(info[b"driver_id"]),
|
||||
"Module": decode(info[b"module"]),
|
||||
"Name": decode(info[b"name"])}
|
||||
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
|
||||
return results
|
||||
|
||||
def client_table(self):
|
||||
"""Fetch and parse the Redis DB client table.
|
||||
def client_table(self):
|
||||
"""Fetch and parse the Redis DB client table.
|
||||
|
||||
Returns:
|
||||
Information about the Ray clients in the cluster.
|
||||
"""
|
||||
self._check_connected()
|
||||
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
|
||||
node_info = dict()
|
||||
for key in db_client_keys:
|
||||
client_info = self.redis_client.hgetall(key)
|
||||
node_ip_address = decode(client_info[b"node_ip_address"])
|
||||
if node_ip_address not in node_info:
|
||||
node_info[node_ip_address] = []
|
||||
client_info_parsed = {
|
||||
"ClientType": decode(client_info[b"client_type"]),
|
||||
"Deleted": bool(int(decode(client_info[b"deleted"]))),
|
||||
"DBClientID": binary_to_hex(client_info[b"ray_client_id"])
|
||||
}
|
||||
if b"aux_address" in client_info:
|
||||
client_info_parsed["AuxAddress"] = decode(client_info[b"aux_address"])
|
||||
if b"num_cpus" in client_info:
|
||||
client_info_parsed["NumCPUs"] = float(decode(client_info[b"num_cpus"]))
|
||||
if b"num_gpus" in client_info:
|
||||
client_info_parsed["NumGPUs"] = float(decode(client_info[b"num_gpus"]))
|
||||
if b"local_scheduler_socket_name" in client_info:
|
||||
client_info_parsed["LocalSchedulerSocketName"] = decode(
|
||||
client_info[b"local_scheduler_socket_name"])
|
||||
node_info[node_ip_address].append(client_info_parsed)
|
||||
Returns:
|
||||
Information about the Ray clients in the cluster.
|
||||
"""
|
||||
self._check_connected()
|
||||
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
|
||||
node_info = dict()
|
||||
for key in db_client_keys:
|
||||
client_info = self.redis_client.hgetall(key)
|
||||
node_ip_address = decode(client_info[b"node_ip_address"])
|
||||
if node_ip_address not in node_info:
|
||||
node_info[node_ip_address] = []
|
||||
client_info_parsed = {
|
||||
"ClientType": decode(client_info[b"client_type"]),
|
||||
"Deleted": bool(int(decode(client_info[b"deleted"]))),
|
||||
"DBClientID": binary_to_hex(client_info[b"ray_client_id"])
|
||||
}
|
||||
if b"aux_address" in client_info:
|
||||
client_info_parsed["AuxAddress"] = decode(
|
||||
client_info[b"aux_address"])
|
||||
if b"num_cpus" in client_info:
|
||||
client_info_parsed["NumCPUs"] = float(
|
||||
decode(client_info[b"num_cpus"]))
|
||||
if b"num_gpus" in client_info:
|
||||
client_info_parsed["NumGPUs"] = float(
|
||||
decode(client_info[b"num_gpus"]))
|
||||
if b"local_scheduler_socket_name" in client_info:
|
||||
client_info_parsed["LocalSchedulerSocketName"] = decode(
|
||||
client_info[b"local_scheduler_socket_name"])
|
||||
node_info[node_ip_address].append(client_info_parsed)
|
||||
|
||||
return node_info
|
||||
return node_info
|
||||
|
||||
def log_files(self):
|
||||
"""Fetch and return a dictionary of log file names to outputs.
|
||||
def log_files(self):
|
||||
"""Fetch and return a dictionary of log file names to outputs.
|
||||
|
||||
Returns:
|
||||
IP address to log file name to log file contents mappings.
|
||||
"""
|
||||
relevant_files = self.redis_client.keys("LOGFILE*")
|
||||
Returns:
|
||||
IP address to log file name to log file contents mappings.
|
||||
"""
|
||||
relevant_files = self.redis_client.keys("LOGFILE*")
|
||||
|
||||
ip_filename_file = dict()
|
||||
ip_filename_file = dict()
|
||||
|
||||
for filename in relevant_files:
|
||||
filename = filename.decode("ascii")
|
||||
filename_components = filename.split(":")
|
||||
ip_addr = filename_components[1]
|
||||
for filename in relevant_files:
|
||||
filename = filename.decode("ascii")
|
||||
filename_components = filename.split(":")
|
||||
ip_addr = filename_components[1]
|
||||
|
||||
file = self.redis_client.lrange(filename, 0, -1)
|
||||
file_str = []
|
||||
for x in file:
|
||||
y = x.decode("ascii")
|
||||
file_str.append(y)
|
||||
file = self.redis_client.lrange(filename, 0, -1)
|
||||
file_str = []
|
||||
for x in file:
|
||||
y = x.decode("ascii")
|
||||
file_str.append(y)
|
||||
|
||||
if ip_addr not in ip_filename_file:
|
||||
ip_filename_file[ip_addr] = dict()
|
||||
if ip_addr not in ip_filename_file:
|
||||
ip_filename_file[ip_addr] = dict()
|
||||
|
||||
ip_filename_file[ip_addr][filename] = file_str
|
||||
ip_filename_file[ip_addr][filename] = file_str
|
||||
|
||||
return ip_filename_file
|
||||
return ip_filename_file
|
||||
|
||||
def task_profiles(self, start=None, end=None, num=None):
|
||||
"""Fetch and return a list of task profiles.
|
||||
def task_profiles(self, start=None, end=None, num=None):
|
||||
"""Fetch and return a list of task profiles.
|
||||
|
||||
Args:
|
||||
start: The start point of the time window that is queried for tasks.
|
||||
end: The end point in time of the time window that is queried for tasks.
|
||||
num: A limit on the number of tasks that task_profiles will return.
|
||||
Args:
|
||||
start: The start point of the time window that is queried for
|
||||
tasks.
|
||||
end: The end point in time of the time window that is queried for
|
||||
tasks.
|
||||
num: A limit on the number of tasks that task_profiles will return.
|
||||
|
||||
Returns:
|
||||
A tuple of two elements. The first element is a dictionary mapping the
|
||||
task ID of a task to a list of the profiling information for all of the
|
||||
executions of that task. The second element is a list of profiling
|
||||
information for tasks where the events have no task ID.
|
||||
"""
|
||||
if start is None:
|
||||
start = 0
|
||||
if num is None:
|
||||
num = sys.maxsize
|
||||
Returns:
|
||||
A tuple of two elements. The first element is a dictionary mapping
|
||||
the task ID of a task to a list of the profiling information
|
||||
for all of the executions of that task. The second element is a
|
||||
list of profiling information for tasks where the events have
|
||||
no task ID.
|
||||
"""
|
||||
if start is None:
|
||||
start = 0
|
||||
if num is None:
|
||||
num = sys.maxsize
|
||||
|
||||
task_info = dict()
|
||||
event_log_sets = self.redis_client.keys("event_log*")
|
||||
task_info = dict()
|
||||
event_log_sets = self.redis_client.keys("event_log*")
|
||||
|
||||
# The heap is used to maintain the set of x tasks that occurred the most
|
||||
# recently across all of the workers, where x is defined as the function
|
||||
# parameter num. The key is the start time of the "get_task" component of
|
||||
# each task. Calling heappop will result in the taks with the earliest
|
||||
# "get_task_start" to be removed from the heap.
|
||||
# The heap is used to maintain the set of x tasks that occurred the
|
||||
# most recently across all of the workers, where x is defined as the
|
||||
# function parameter num. The key is the start time of the "get_task"
|
||||
# component of each task. Calling heappop will result in the taks with
|
||||
# the earliest "get_task_start" to be removed from the heap.
|
||||
|
||||
heap = []
|
||||
heapq.heapify(heap)
|
||||
heap_size = 0
|
||||
# Parse through event logs to determine task start and end points.
|
||||
for i in range(len(event_log_sets)):
|
||||
event_list = self.redis_client.zrangebyscore(event_log_sets[i],
|
||||
min=start,
|
||||
max=end,
|
||||
start=start,
|
||||
num=num)
|
||||
for event in event_list:
|
||||
event_dict = json.loads(event)
|
||||
task_id = ""
|
||||
for event in event_dict:
|
||||
if "task_id" in event[3]:
|
||||
task_id = event[3]["task_id"]
|
||||
task_info[task_id] = dict()
|
||||
for event in event_dict:
|
||||
if event[1] == "ray:get_task" and event[2] == 1:
|
||||
task_info[task_id]["get_task_start"] = event[0]
|
||||
# Add task to min heap by its start point.
|
||||
heapq.heappush(heap,
|
||||
(task_info[task_id]["get_task_start"], task_id))
|
||||
heap_size += 1
|
||||
if event[1] == "ray:get_task" and event[2] == 2:
|
||||
task_info[task_id]["get_task_end"] = event[0]
|
||||
if event[1] == "ray:import_remote_function" and event[2] == 1:
|
||||
task_info[task_id]["import_remote_start"] = event[0]
|
||||
if event[1] == "ray:import_remote_function" and event[2] == 2:
|
||||
task_info[task_id]["import_remote_end"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 1:
|
||||
task_info[task_id]["acquire_lock_start"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 2:
|
||||
task_info[task_id]["acquire_lock_end"] = event[0]
|
||||
if event[1] == "ray:task:get_arguments" and event[2] == 1:
|
||||
task_info[task_id]["get_arguments_start"] = event[0]
|
||||
if event[1] == "ray:task:get_arguments" and event[2] == 2:
|
||||
task_info[task_id]["get_arguments_end"] = event[0]
|
||||
if event[1] == "ray:task:execute" and event[2] == 1:
|
||||
task_info[task_id]["execute_start"] = event[0]
|
||||
if event[1] == "ray:task:execute" and event[2] == 2:
|
||||
task_info[task_id]["execute_end"] = event[0]
|
||||
if event[1] == "ray:task:store_outputs" and event[2] == 1:
|
||||
task_info[task_id]["store_outputs_start"] = event[0]
|
||||
if event[1] == "ray:task:store_outputs" and event[2] == 2:
|
||||
task_info[task_id]["store_outputs_end"] = event[0]
|
||||
if "worker_id" in event[3]:
|
||||
task_info[task_id]["worker_id"] = event[3]["worker_id"]
|
||||
if "function_name" in event[3]:
|
||||
task_info[task_id]["function_name"] = event[3]["function_name"]
|
||||
if heap_size > num:
|
||||
min_task, task_id_hex = heapq.heappop(heap)
|
||||
del task_info[task_id_hex]
|
||||
heap_size -= 1
|
||||
return task_info
|
||||
heap = []
|
||||
heapq.heapify(heap)
|
||||
heap_size = 0
|
||||
# Parse through event logs to determine task start and end points.
|
||||
for i in range(len(event_log_sets)):
|
||||
event_list = self.redis_client.zrangebyscore(event_log_sets[i],
|
||||
min=start,
|
||||
max=end,
|
||||
start=start,
|
||||
num=num)
|
||||
for event in event_list:
|
||||
event_dict = json.loads(event)
|
||||
task_id = ""
|
||||
for event in event_dict:
|
||||
if "task_id" in event[3]:
|
||||
task_id = event[3]["task_id"]
|
||||
task_info[task_id] = dict()
|
||||
for event in event_dict:
|
||||
if event[1] == "ray:get_task" and event[2] == 1:
|
||||
task_info[task_id]["get_task_start"] = event[0]
|
||||
# Add task to min heap by its start point.
|
||||
heapq.heappush(heap,
|
||||
(task_info[task_id]["get_task_start"],
|
||||
task_id))
|
||||
heap_size += 1
|
||||
if event[1] == "ray:get_task" and event[2] == 2:
|
||||
task_info[task_id]["get_task_end"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 1):
|
||||
task_info[task_id]["import_remote_start"] = event[0]
|
||||
if (event[1] == "ray:import_remote_function" and
|
||||
event[2] == 2):
|
||||
task_info[task_id]["import_remote_end"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 1:
|
||||
task_info[task_id]["acquire_lock_start"] = event[0]
|
||||
if event[1] == "ray:acquire_lock" and event[2] == 2:
|
||||
task_info[task_id]["acquire_lock_end"] = event[0]
|
||||
if event[1] == "ray:task:get_arguments" and event[2] == 1:
|
||||
task_info[task_id]["get_arguments_start"] = event[0]
|
||||
if event[1] == "ray:task:get_arguments" and event[2] == 2:
|
||||
task_info[task_id]["get_arguments_end"] = event[0]
|
||||
if event[1] == "ray:task:execute" and event[2] == 1:
|
||||
task_info[task_id]["execute_start"] = event[0]
|
||||
if event[1] == "ray:task:execute" and event[2] == 2:
|
||||
task_info[task_id]["execute_end"] = event[0]
|
||||
if event[1] == "ray:task:store_outputs" and event[2] == 1:
|
||||
task_info[task_id]["store_outputs_start"] = event[0]
|
||||
if event[1] == "ray:task:store_outputs" and event[2] == 2:
|
||||
task_info[task_id]["store_outputs_end"] = event[0]
|
||||
if "worker_id" in event[3]:
|
||||
task_info[task_id]["worker_id"] = event[3]["worker_id"]
|
||||
if "function_name" in event[3]:
|
||||
task_info[task_id]["function_name"] = (
|
||||
event[3]["function_name"])
|
||||
if heap_size > num:
|
||||
min_task, task_id_hex = heapq.heappop(heap)
|
||||
del task_info[task_id_hex]
|
||||
heap_size -= 1
|
||||
return task_info
|
||||
|
||||
def dump_catapult_trace(self, path, start=None, end=None, num=None):
|
||||
"""Dump task profiling information to a file.
|
||||
def dump_catapult_trace(self, path, start=None, end=None, num=None):
|
||||
"""Dump task profiling information to a file.
|
||||
|
||||
This information can be viewed as a timeline of profiling information by
|
||||
going to chrome://tracing in the chrome web browser and loading the
|
||||
appropriate file.
|
||||
This information can be viewed as a timeline of profiling information
|
||||
by going to chrome://tracing in the chrome web browser and loading the
|
||||
appropriate file.
|
||||
|
||||
Args:
|
||||
path: The filepath to dump the profiling information to.
|
||||
"""
|
||||
if end is None:
|
||||
end = time.time()
|
||||
task_info = self.task_profiles(start=start, end=end, num=num)
|
||||
workers = self.workers()
|
||||
start_time = None
|
||||
for info in task_info.values():
|
||||
task_start = min(self._get_times(info))
|
||||
if not start_time or task_start < start_time:
|
||||
start_time = task_start
|
||||
Args:
|
||||
path: The filepath to dump the profiling information to.
|
||||
"""
|
||||
if end is None:
|
||||
end = time.time()
|
||||
task_info = self.task_profiles(start=start, end=end, num=num)
|
||||
workers = self.workers()
|
||||
start_time = None
|
||||
for info in task_info.values():
|
||||
task_start = min(self._get_times(info))
|
||||
if not start_time or task_start < start_time:
|
||||
start_time = task_start
|
||||
|
||||
def micros(ts):
|
||||
return int(1e6 * (ts - start_time))
|
||||
def micros(ts):
|
||||
return int(1e6 * (ts - start_time))
|
||||
|
||||
full_trace = []
|
||||
for task_id, info in task_info.items():
|
||||
task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
task_data = self._task_table(task_id_hex)
|
||||
parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"])
|
||||
times = self._get_times(info)
|
||||
worker = workers[info["worker_id"]]
|
||||
if parent_info:
|
||||
parent_worker = workers[parent_info["worker_id"]]
|
||||
parent_times = self._get_times(parent_info)
|
||||
parent_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(parent_worker["node_ip_address"]),
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros(min(parent_times)),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(parent_trace)
|
||||
full_trace = []
|
||||
for task_id, info in task_info.items():
|
||||
task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
|
||||
task_data = self._task_table(task_id_hex)
|
||||
parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"])
|
||||
times = self._get_times(info)
|
||||
worker = workers[info["worker_id"]]
|
||||
if parent_info:
|
||||
parent_worker = workers[parent_info["worker_id"]]
|
||||
parent_times = self._get_times(parent_info)
|
||||
parent_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(parent_worker["node_ip_address"]),
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros(min(parent_times)),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(parent_trace)
|
||||
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(parent_worker["node_ip_address"]),
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros(min(parent_times)),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(parent)
|
||||
parent = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(parent_worker["node_ip_address"]),
|
||||
"tid": parent_info["worker_id"],
|
||||
"ts": micros(min(parent_times)),
|
||||
"ph": "s",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(parent)
|
||||
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(worker["node_ip_address"]),
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros(min(times)),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
task_trace = {
|
||||
"cat": "submit_task",
|
||||
"pid": "Node " + str(worker["node_ip_address"]),
|
||||
"tid": info["worker_id"],
|
||||
"ts": micros(min(times)),
|
||||
"ph": "f",
|
||||
"name": "SubmitTask",
|
||||
"args": {},
|
||||
"id": str(worker)
|
||||
}
|
||||
full_trace.append(task_trace)
|
||||
|
||||
task = {
|
||||
"name": info["function_name"],
|
||||
"cat": "ray_task",
|
||||
"ph": "X",
|
||||
"ts": micros(min(times)),
|
||||
"dur": micros(max(times)) - micros(min(times)),
|
||||
"pid": "Node " + str(worker["node_ip_address"]),
|
||||
"tid": info["worker_id"],
|
||||
"args": info
|
||||
}
|
||||
full_trace.append(task)
|
||||
task = {
|
||||
"name": info["function_name"],
|
||||
"cat": "ray_task",
|
||||
"ph": "X",
|
||||
"ts": micros(min(times)),
|
||||
"dur": micros(max(times)) - micros(min(times)),
|
||||
"pid": "Node " + str(worker["node_ip_address"]),
|
||||
"tid": info["worker_id"],
|
||||
"args": info
|
||||
}
|
||||
full_trace.append(task)
|
||||
|
||||
with open(path, "w") as outfile:
|
||||
json.dump(full_trace, outfile)
|
||||
with open(path, "w") as outfile:
|
||||
json.dump(full_trace, outfile)
|
||||
|
||||
def _get_times(self, data):
|
||||
"""Extract the numerical times from a task profile.
|
||||
def _get_times(self, data):
|
||||
"""Extract the numerical times from a task profile.
|
||||
|
||||
This is a helper method for dump_catapult_trace.
|
||||
This is a helper method for dump_catapult_trace.
|
||||
|
||||
Args:
|
||||
data: This must be a value in the dictionary returned by the
|
||||
task_profiles function.
|
||||
"""
|
||||
all_times = []
|
||||
all_times.append(data["acquire_lock_start"])
|
||||
all_times.append(data["acquire_lock_end"])
|
||||
all_times.append(data["get_arguments_start"])
|
||||
all_times.append(data["get_arguments_end"])
|
||||
all_times.append(data["execute_start"])
|
||||
all_times.append(data["execute_end"])
|
||||
all_times.append(data["store_outputs_start"])
|
||||
all_times.append(data["store_outputs_end"])
|
||||
return all_times
|
||||
Args:
|
||||
data: This must be a value in the dictionary returned by the
|
||||
task_profiles function.
|
||||
"""
|
||||
all_times = []
|
||||
all_times.append(data["acquire_lock_start"])
|
||||
all_times.append(data["acquire_lock_end"])
|
||||
all_times.append(data["get_arguments_start"])
|
||||
all_times.append(data["get_arguments_end"])
|
||||
all_times.append(data["execute_start"])
|
||||
all_times.append(data["execute_end"])
|
||||
all_times.append(data["store_outputs_start"])
|
||||
all_times.append(data["store_outputs_end"])
|
||||
return all_times
|
||||
|
||||
def workers(self):
|
||||
"""Get a dictionary mapping worker ID to worker information."""
|
||||
worker_keys = self.redis_client.keys("Worker*")
|
||||
workers_data = dict()
|
||||
def workers(self):
|
||||
"""Get a dictionary mapping worker ID to worker information."""
|
||||
worker_keys = self.redis_client.keys("Worker*")
|
||||
workers_data = dict()
|
||||
|
||||
for worker_key in worker_keys:
|
||||
worker_info = self.redis_client.hgetall(worker_key)
|
||||
worker_id = binary_to_hex(worker_key[len("Workers:"):])
|
||||
for worker_key in worker_keys:
|
||||
worker_info = self.redis_client.hgetall(worker_key)
|
||||
worker_id = binary_to_hex(worker_key[len("Workers:"):])
|
||||
|
||||
workers_data[worker_id] = {
|
||||
"local_scheduler_socket": (worker_info[b"local_scheduler_socket"]
|
||||
.decode("ascii")),
|
||||
"node_ip_address": (worker_info[b"node_ip_address"]
|
||||
.decode("ascii")),
|
||||
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
|
||||
.decode("ascii")),
|
||||
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
|
||||
.decode("ascii")),
|
||||
"stderr_file": worker_info[b"stderr_file"].decode("ascii"),
|
||||
"stdout_file": worker_info[b"stdout_file"].decode("ascii")
|
||||
}
|
||||
return workers_data
|
||||
workers_data[worker_id] = {
|
||||
"local_scheduler_socket": (
|
||||
worker_info[b"local_scheduler_socket"].decode("ascii")),
|
||||
"node_ip_address": (
|
||||
worker_info[b"node_ip_address"].decode("ascii")),
|
||||
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
|
||||
.decode("ascii")),
|
||||
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
|
||||
.decode("ascii")),
|
||||
"stderr_file": worker_info[b"stderr_file"].decode("ascii"),
|
||||
"stdout_file": worker_info[b"stdout_file"].decode("ascii")
|
||||
}
|
||||
return workers_data
|
||||
|
||||
@@ -6,114 +6,116 @@ from collections import deque, OrderedDict
|
||||
|
||||
|
||||
def unflatten(vector, shapes):
|
||||
i = 0
|
||||
arrays = []
|
||||
for shape in shapes:
|
||||
size = np.prod(shape)
|
||||
array = vector[i:(i + size)].reshape(shape)
|
||||
arrays.append(array)
|
||||
i += size
|
||||
assert len(vector) == i, "Passed weight does not have the correct shape."
|
||||
return arrays
|
||||
i = 0
|
||||
arrays = []
|
||||
for shape in shapes:
|
||||
size = np.prod(shape)
|
||||
array = vector[i:(i + size)].reshape(shape)
|
||||
arrays.append(array)
|
||||
i += size
|
||||
assert len(vector) == i, "Passed weight does not have the correct shape."
|
||||
return arrays
|
||||
|
||||
|
||||
class TensorFlowVariables(object):
|
||||
"""An object used to extract variables from a loss function.
|
||||
"""An object used to extract variables from a loss function.
|
||||
|
||||
This object also provides methods for getting and setting the weights of the
|
||||
relevant variables.
|
||||
This object also provides methods for getting and setting the weights of
|
||||
the relevant variables.
|
||||
|
||||
Attributes:
|
||||
sess (tf.Session): The tensorflow session used to run assignment.
|
||||
loss: The loss function passed in by the user.
|
||||
variables (List[tf.Variable]): Extracted variables from the loss.
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
|
||||
passed to.
|
||||
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
"""
|
||||
def __init__(self, loss, sess=None):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
self.loss = loss
|
||||
queue = deque([loss])
|
||||
variable_names = []
|
||||
explored_inputs = set([loss])
|
||||
Attributes:
|
||||
sess (tf.Session): The tensorflow session used to run assignment.
|
||||
loss: The loss function passed in by the user.
|
||||
variables (List[tf.Variable]): Extracted variables from the loss.
|
||||
assignment_placeholders (List[tf.placeholders]): The nodes that weights
|
||||
get passed to.
|
||||
assignment _nodes (List[tf.Tensor]): The nodes that assign the weights.
|
||||
"""
|
||||
def __init__(self, loss, sess=None):
|
||||
"""Creates a TensorFlowVariables instance."""
|
||||
import tensorflow as tf
|
||||
self.sess = sess
|
||||
self.loss = loss
|
||||
queue = deque([loss])
|
||||
variable_names = []
|
||||
explored_inputs = set([loss])
|
||||
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
tf_obj = queue.popleft()
|
||||
# We do a BFS on the dependency graph of the input function to find
|
||||
# the variables.
|
||||
while len(queue) != 0:
|
||||
tf_obj = queue.popleft()
|
||||
|
||||
# The object put into the queue is not necessarily an operation, so we
|
||||
# want the op attribute to get the operation underlying the object.
|
||||
# Only operations contain the inputs that we can explore.
|
||||
if hasattr(tf_obj, "op"):
|
||||
tf_obj = tf_obj.op
|
||||
for input_op in tf_obj.inputs:
|
||||
if input_op not in explored_inputs:
|
||||
queue.append(input_op)
|
||||
explored_inputs.add(input_op)
|
||||
# Tensorflow control inputs can be circular, so we keep track of
|
||||
# explored operations.
|
||||
for control in tf_obj.control_inputs:
|
||||
if control not in explored_inputs:
|
||||
queue.append(control)
|
||||
explored_inputs.add(control)
|
||||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
self.placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
# The object put into the queue is not necessarily an operation, so
|
||||
# we want the op attribute to get the operation underlying the
|
||||
# object. Only operations contain the inputs that we can explore.
|
||||
if hasattr(tf_obj, "op"):
|
||||
tf_obj = tf_obj.op
|
||||
for input_op in tf_obj.inputs:
|
||||
if input_op not in explored_inputs:
|
||||
queue.append(input_op)
|
||||
explored_inputs.add(input_op)
|
||||
# Tensorflow control inputs can be circular, so we keep track of
|
||||
# explored operations.
|
||||
for control in tf_obj.control_inputs:
|
||||
if control not in explored_inputs:
|
||||
queue.append(control)
|
||||
explored_inputs.add(control)
|
||||
if "Variable" in tf_obj.node_def.op:
|
||||
variable_names.append(tf_obj.node_def.name)
|
||||
self.variables = OrderedDict()
|
||||
for v in [v for v in tf.global_variables()
|
||||
if v.op.node_def.name in variable_names]:
|
||||
self.variables[v.op.node_def.name] = v
|
||||
self.placeholders = dict()
|
||||
self.assignment_nodes = []
|
||||
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.placeholders[k]))
|
||||
# Create new placeholders to put in custom weights.
|
||||
for k, var in self.variables.items():
|
||||
self.placeholders[k] = tf.placeholder(var.value().dtype,
|
||||
var.get_shape().as_list())
|
||||
self.assignment_nodes.append(var.assign(self.placeholders[k]))
|
||||
|
||||
def set_session(self, sess):
|
||||
"""Modifies the current session used by the class."""
|
||||
self.sess = sess
|
||||
def set_session(self, sess):
|
||||
"""Modifies the current session used by the class."""
|
||||
self.sess = sess
|
||||
|
||||
def get_flat_size(self):
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
def get_flat_size(self):
|
||||
return sum([np.prod(v.get_shape().as_list())
|
||||
for v in self.variables.values()])
|
||||
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
assert self.sess is not None, ("The session is not set. Set the session "
|
||||
"either by passing it into the "
|
||||
"TensorFlowVariables constructor or by "
|
||||
"calling set_session(sess).")
|
||||
def _check_sess(self):
|
||||
"""Checks if the session is set, and if not throw an error message."""
|
||||
assert self.sess is not None, ("The session is not set. Set the "
|
||||
"session either by passing it into the "
|
||||
"TensorFlowVariables constructor or by "
|
||||
"calling set_session(sess).")
|
||||
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
def get_flat(self):
|
||||
"""Gets the weights and returns them as a flat array."""
|
||||
self._check_sess()
|
||||
return np.concatenate([v.eval(session=self.sess).flatten()
|
||||
for v in self.variables.values()])
|
||||
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k] for k, v in self.variables.items()]
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
def set_flat(self, new_weights):
|
||||
"""Sets the weights to new_weights, converting from a flat array."""
|
||||
self._check_sess()
|
||||
shapes = [v.get_shape().as_list() for v in self.variables.values()]
|
||||
arrays = unflatten(new_weights, shapes)
|
||||
placeholders = [self.placeholders[k]
|
||||
for k, v in self.variables.items()]
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict=dict(zip(placeholders, arrays)))
|
||||
|
||||
def get_weights(self):
|
||||
"""Returns the weights of the variables of the loss function in a list."""
|
||||
self._check_sess()
|
||||
return {k: v.eval(session=self.sess) for k, v in self.variables.items()}
|
||||
def get_weights(self):
|
||||
"""Returns a list of the weights of the loss function variables."""
|
||||
self._check_sess()
|
||||
return {k: v.eval(session=self.sess)
|
||||
for k, v in self.variables.items()}
|
||||
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
self._check_sess()
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
def set_weights(self, new_weights):
|
||||
"""Sets the weights to new_weights."""
|
||||
self._check_sess()
|
||||
self.sess.run(self.assignment_nodes,
|
||||
feed_dict={self.placeholders[name]: value
|
||||
for (name, value) in new_weights.items()
|
||||
if name in self.placeholders})
|
||||
|
||||
@@ -11,70 +11,72 @@ import ray
|
||||
|
||||
|
||||
def tarred_directory_as_bytes(source_dir):
|
||||
"""Tar a directory and return it as a byte string.
|
||||
"""Tar a directory and return it as a byte string.
|
||||
|
||||
Args:
|
||||
source_dir (str): The name of the directory to tar.
|
||||
Args:
|
||||
source_dir (str): The name of the directory to tar.
|
||||
|
||||
Returns:
|
||||
A byte string representing the tarred file.
|
||||
"""
|
||||
# Get a BytesIO object.
|
||||
string_file = io.BytesIO()
|
||||
# Create an in-memory tarfile of the source directory.
|
||||
with tarfile.open(mode="w:gz", fileobj=string_file) as tar:
|
||||
tar.add(source_dir, arcname=os.path.basename(source_dir))
|
||||
string_file.seek(0)
|
||||
return string_file.read()
|
||||
Returns:
|
||||
A byte string representing the tarred file.
|
||||
"""
|
||||
# Get a BytesIO object.
|
||||
string_file = io.BytesIO()
|
||||
# Create an in-memory tarfile of the source directory.
|
||||
with tarfile.open(mode="w:gz", fileobj=string_file) as tar:
|
||||
tar.add(source_dir, arcname=os.path.basename(source_dir))
|
||||
string_file.seek(0)
|
||||
return string_file.read()
|
||||
|
||||
|
||||
def tarred_bytes_to_directory(tarred_bytes, target_dir):
|
||||
"""Take a byte string and untar it.
|
||||
"""Take a byte string and untar it.
|
||||
|
||||
Args:
|
||||
tarred_bytes (str): A byte string representing the tarred file. This should
|
||||
be the output of tarred_directory_as_bytes.
|
||||
target_dir (str): The directory to create the untarred files in.
|
||||
"""
|
||||
string_file = io.BytesIO(tarred_bytes)
|
||||
with tarfile.open(fileobj=string_file) as tar:
|
||||
tar.extractall(path=target_dir)
|
||||
Args:
|
||||
tarred_bytes (str): A byte string representing the tarred file. This
|
||||
should be the output of tarred_directory_as_bytes.
|
||||
target_dir (str): The directory to create the untarred files in.
|
||||
"""
|
||||
string_file = io.BytesIO(tarred_bytes)
|
||||
with tarfile.open(fileobj=string_file) as tar:
|
||||
tar.extractall(path=target_dir)
|
||||
|
||||
|
||||
def copy_directory(source_dir, target_dir=None):
|
||||
"""Copy a local directory to each machine in the Ray cluster.
|
||||
"""Copy a local directory to each machine in the Ray cluster.
|
||||
|
||||
Note that both source_dir and target_dir must have the same basename). For
|
||||
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case,
|
||||
the directory /d/e will be added to the Python path of each worker.
|
||||
Note that both source_dir and target_dir must have the same basename). For
|
||||
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this
|
||||
case, the directory /d/e will be added to the Python path of each worker.
|
||||
|
||||
Note that this method is not completely safe to use. For example, workers
|
||||
that do not do the copying and only set their paths (only one worker per node
|
||||
does the copying) may try to execute functions that use the files in the
|
||||
directory being copied before the directory being copied has finished
|
||||
untarring.
|
||||
Note that this method is not completely safe to use. For example, workers
|
||||
that do not do the copying and only set their paths (only one worker per
|
||||
node does the copying) may try to execute functions that use the files in
|
||||
the directory being copied before the directory being copied has finished
|
||||
untarring.
|
||||
|
||||
Args:
|
||||
source_dir (str): The directory to copy.
|
||||
target_dir (str): The location to copy it to on the other machines. If this
|
||||
is not provided, the source_dir will be used. If it is provided and is
|
||||
different from source_dir, the source_dir also be copied to the
|
||||
target_dir location on this machine.
|
||||
"""
|
||||
target_dir = source_dir if target_dir is None else target_dir
|
||||
source_dir = os.path.abspath(source_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
source_basename = os.path.basename(source_dir)
|
||||
target_basename = os.path.basename(target_dir)
|
||||
if source_basename != target_basename:
|
||||
raise Exception("The source_dir and target_dir must have the same base "
|
||||
"name, {} != {}".format(source_basename, target_basename))
|
||||
tarred_bytes = tarred_directory_as_bytes(source_dir)
|
||||
Args:
|
||||
source_dir (str): The directory to copy.
|
||||
target_dir (str): The location to copy it to on the other machines. If
|
||||
this is not provided, the source_dir will be used. If it is
|
||||
provided and is different from source_dir, the source_dir also be
|
||||
copied to the target_dir location on this machine.
|
||||
"""
|
||||
target_dir = source_dir if target_dir is None else target_dir
|
||||
source_dir = os.path.abspath(source_dir)
|
||||
target_dir = os.path.abspath(target_dir)
|
||||
source_basename = os.path.basename(source_dir)
|
||||
target_basename = os.path.basename(target_dir)
|
||||
if source_basename != target_basename:
|
||||
raise Exception("The source_dir and target_dir must have the same "
|
||||
"base name, {} != {}".format(source_basename,
|
||||
target_basename))
|
||||
tarred_bytes = tarred_directory_as_bytes(source_dir)
|
||||
|
||||
def f(worker_info):
|
||||
if worker_info["counter"] == 0:
|
||||
tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir))
|
||||
sys.path.append(os.path.dirname(target_dir))
|
||||
# Run this function on all workers to copy the directory to all nodes and to
|
||||
# add the directory to the Python path of each worker.
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
def f(worker_info):
|
||||
if worker_info["counter"] == 0:
|
||||
tarred_bytes_to_directory(tarred_bytes,
|
||||
os.path.dirname(target_dir))
|
||||
sys.path.append(os.path.dirname(target_dir))
|
||||
# Run this function on all workers to copy the directory to all nodes and
|
||||
# to add the directory to the Python path of each worker.
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
@@ -10,45 +10,45 @@ import time
|
||||
def start_global_scheduler(redis_address, node_ip_address,
|
||||
use_valgrind=False, use_profiler=False,
|
||||
stdout_file=None, stderr_file=None):
|
||||
"""Start a global scheduler process.
|
||||
"""Start a global scheduler process.
|
||||
|
||||
Args:
|
||||
redis_address (str): The address of the Redis instance.
|
||||
node_ip_address: The IP address of the node that this scheduler will run
|
||||
on.
|
||||
use_valgrind (bool): True if the global scheduler should be started inside
|
||||
of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the global scheduler should be started inside
|
||||
a profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
redirection should happen, then this should be None.
|
||||
Args:
|
||||
redis_address (str): The address of the Redis instance.
|
||||
node_ip_address: The IP address of the node that this scheduler will
|
||||
run on.
|
||||
use_valgrind (bool): True if the global scheduler should be started
|
||||
inside of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the global scheduler should be started
|
||||
inside a profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If
|
||||
no redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If
|
||||
no redirection should happen, then this should be None.
|
||||
|
||||
Return:
|
||||
The process ID of the global scheduler process.
|
||||
"""
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
global_scheduler_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/global_scheduler/global_scheduler")
|
||||
command = [global_scheduler_executable,
|
||||
"-r", redis_address,
|
||||
"-h", node_ip_address]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
return pid
|
||||
Return:
|
||||
The process ID of the global scheduler process.
|
||||
"""
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
global_scheduler_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/global_scheduler/global_scheduler")
|
||||
command = [global_scheduler_executable,
|
||||
"-r", redis_address,
|
||||
"-h", node_ip_address]
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
return pid
|
||||
|
||||
@@ -33,275 +33,285 @@ TASK_PREFIX = "TT:"
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def new_port():
|
||||
return random.randint(10000, 65535)
|
||||
return random.randint(10000, 65535)
|
||||
|
||||
|
||||
class TestGlobalScheduler(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
redis_address, redis_shards = services.start_redis(self.node_ip_address)
|
||||
redis_port = services.get_port(redis_address)
|
||||
time.sleep(0.1)
|
||||
# Create a client for the global state store.
|
||||
self.state = state.GlobalState()
|
||||
self.state._initialize_global_state(self.node_ip_address, redis_port)
|
||||
def setUp(self):
|
||||
# Start one Redis server and N pairs of (plasma, local_scheduler)
|
||||
self.node_ip_address = "127.0.0.1"
|
||||
redis_address, redis_shards = services.start_redis(
|
||||
self.node_ip_address)
|
||||
redis_port = services.get_port(redis_address)
|
||||
time.sleep(0.1)
|
||||
# Create a client for the global state store.
|
||||
self.state = state.GlobalState()
|
||||
self.state._initialize_global_state(self.node_ip_address, redis_port)
|
||||
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(
|
||||
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
self.plasma_clients = []
|
||||
self.local_scheduler_clients = []
|
||||
# Start one global scheduler.
|
||||
self.p1 = global_scheduler.start_global_scheduler(
|
||||
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
|
||||
self.plasma_store_pids = []
|
||||
self.plasma_manager_pids = []
|
||||
self.local_scheduler_pids = []
|
||||
self.plasma_clients = []
|
||||
self.local_scheduler_clients = []
|
||||
|
||||
for i in range(NUM_CLUSTER_NODES):
|
||||
# Start the Plasma store. Plasma store name is randomly generated.
|
||||
plasma_store_name, p2 = plasma.start_plasma_store()
|
||||
self.plasma_store_pids.append(p2)
|
||||
# Start the Plasma manager.
|
||||
# Assumption: Plasma manager name and port are randomly generated by the
|
||||
# plasma module.
|
||||
manager_info = plasma.start_plasma_manager(plasma_store_name,
|
||||
redis_address)
|
||||
plasma_manager_name, p3, plasma_manager_port = manager_info
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(self.node_ip_address,
|
||||
plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
# Start the local scheduler.
|
||||
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
|
||||
plasma_store_name,
|
||||
plasma_manager_name=plasma_manager_name,
|
||||
plasma_address=plasma_address,
|
||||
redis_address=redis_address,
|
||||
static_resource_list=[10, 0])
|
||||
# Connect to the scheduler.
|
||||
local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
|
||||
self.local_scheduler_clients.append(local_scheduler_client)
|
||||
self.local_scheduler_pids.append(p4)
|
||||
for i in range(NUM_CLUSTER_NODES):
|
||||
# Start the Plasma store. Plasma store name is randomly generated.
|
||||
plasma_store_name, p2 = plasma.start_plasma_store()
|
||||
self.plasma_store_pids.append(p2)
|
||||
# Start the Plasma manager.
|
||||
# Assumption: Plasma manager name and port are randomly generated
|
||||
# by the plasma module.
|
||||
manager_info = plasma.start_plasma_manager(plasma_store_name,
|
||||
redis_address)
|
||||
plasma_manager_name, p3, plasma_manager_port = manager_info
|
||||
self.plasma_manager_pids.append(p3)
|
||||
plasma_address = "{}:{}".format(self.node_ip_address,
|
||||
plasma_manager_port)
|
||||
plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
plasma_manager_name)
|
||||
self.plasma_clients.append(plasma_client)
|
||||
# Start the local scheduler.
|
||||
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
|
||||
plasma_store_name,
|
||||
plasma_manager_name=plasma_manager_name,
|
||||
plasma_address=plasma_address,
|
||||
redis_address=redis_address,
|
||||
static_resource_list=[10, 0])
|
||||
# Connect to the scheduler.
|
||||
local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
|
||||
self.local_scheduler_clients.append(local_scheduler_client)
|
||||
self.local_scheduler_pids.append(p4)
|
||||
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
self.assertEqual(self.p1.poll(), None)
|
||||
for p2 in self.plasma_store_pids:
|
||||
self.assertEqual(p2.poll(), None)
|
||||
for p3 in self.plasma_manager_pids:
|
||||
self.assertEqual(p3.poll(), None)
|
||||
for p4 in self.local_scheduler_pids:
|
||||
self.assertEqual(p4.poll(), None)
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
self.assertEqual(self.p1.poll(), None)
|
||||
for p2 in self.plasma_store_pids:
|
||||
self.assertEqual(p2.poll(), None)
|
||||
for p3 in self.plasma_manager_pids:
|
||||
self.assertEqual(p3.poll(), None)
|
||||
for p4 in self.local_scheduler_pids:
|
||||
self.assertEqual(p4.poll(), None)
|
||||
|
||||
redis_processes = services.all_processes[
|
||||
services.PROCESS_TYPE_REDIS_SERVER]
|
||||
for redis_process in redis_processes:
|
||||
self.assertEqual(redis_process.poll(), None)
|
||||
redis_processes = services.all_processes[
|
||||
services.PROCESS_TYPE_REDIS_SERVER]
|
||||
for redis_process in redis_processes:
|
||||
self.assertEqual(redis_process.poll(), None)
|
||||
|
||||
# Kill the global scheduler.
|
||||
if USE_VALGRIND:
|
||||
self.p1.send_signal(signal.SIGTERM)
|
||||
self.p1.wait()
|
||||
if self.p1.returncode != 0:
|
||||
os._exit(-1)
|
||||
else:
|
||||
self.p1.kill()
|
||||
# Kill local schedulers, plasma managers, and plasma stores.
|
||||
for p2 in self.local_scheduler_pids:
|
||||
p2.kill()
|
||||
for p3 in self.plasma_manager_pids:
|
||||
p3.kill()
|
||||
for p4 in self.plasma_store_pids:
|
||||
p4.kill()
|
||||
# Kill Redis. In the event that we are using valgrind, this needs to happen
|
||||
# after we kill the global scheduler.
|
||||
while redis_processes:
|
||||
redis_process = redis_processes.pop()
|
||||
redis_process.kill()
|
||||
|
||||
def get_plasma_manager_id(self):
|
||||
"""Get the db_client_id with client_type equal to plasma_manager.
|
||||
|
||||
Iterates over all the client table keys, gets the db_client_id for the
|
||||
client with client_type matching plasma_manager. Strips the client table
|
||||
prefix. TODO(atumanov): write a separate function to get all plasma manager
|
||||
client IDs.
|
||||
|
||||
Returns:
|
||||
The db_client_id if one is found and otherwise None.
|
||||
"""
|
||||
db_client_id = None
|
||||
|
||||
client_list = self.state.client_table()[self.node_ip_address]
|
||||
for client in client_list:
|
||||
if client["ClientType"] == "plasma_manager":
|
||||
db_client_id = client["DBClientID"]
|
||||
break
|
||||
|
||||
return db_client_id
|
||||
|
||||
def test_task_default_resources(self):
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(), 0)
|
||||
self.assertEqual(task1.required_resources(), [1.0, 0.0])
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(), 0,
|
||||
local_scheduler.ObjectID(NIL_ACTOR_ID), 0,
|
||||
[1.0, 2.0])
|
||||
self.assertEqual(task2.required_resources(), [1.0, 2.0])
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
"""
|
||||
Tests global scheduler functionality by interacting with Redis and checking
|
||||
task state transitions in Redis only. TODO(atumanov): implement.
|
||||
"""
|
||||
# Check precondition for this test:
|
||||
# There should be 2n+1 db clients: the global scheduler + one local
|
||||
# scheduler and one plasma per node.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
|
||||
def test_integration_single_task(self):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
# Insert the object into Redis.
|
||||
data_size = 0xf1f0
|
||||
metadata_size = 0x40
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(
|
||||
plasma_client, data_size, metadata_size, seal=True)
|
||||
|
||||
# Sleep before submitting task to local scheduler.
|
||||
time.sleep(0.1)
|
||||
# Submit a task to Redis.
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
time.sleep(0.1)
|
||||
# There should now be a task in Redis, and it should get assigned to the
|
||||
# local scheduler
|
||||
num_retries = 10
|
||||
while num_retries > 0:
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), 1)
|
||||
if len(task_entries) == 1:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
# Kill the global scheduler.
|
||||
if USE_VALGRIND:
|
||||
self.p1.send_signal(signal.SIGTERM)
|
||||
self.p1.wait()
|
||||
if self.p1.returncode != 0:
|
||||
os._exit(-1)
|
||||
else:
|
||||
print(task_status)
|
||||
print("The task has not been scheduled yet, trying again.")
|
||||
num_retries -= 1
|
||||
time.sleep(1)
|
||||
self.p1.kill()
|
||||
# Kill local schedulers, plasma managers, and plasma stores.
|
||||
for p2 in self.local_scheduler_pids:
|
||||
p2.kill()
|
||||
for p3 in self.plasma_manager_pids:
|
||||
p3.kill()
|
||||
for p4 in self.plasma_store_pids:
|
||||
p4.kill()
|
||||
# Kill Redis. In the event that we are using valgrind, this needs to
|
||||
# happen after we kill the global scheduler.
|
||||
while redis_processes:
|
||||
redis_process = redis_processes.pop()
|
||||
redis_process.kill()
|
||||
|
||||
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
|
||||
# Failed to submit and schedule a single task -- bail.
|
||||
self.tearDown()
|
||||
sys.exit(1)
|
||||
def get_plasma_manager_id(self):
|
||||
"""Get the db_client_id with client_type equal to plasma_manager.
|
||||
|
||||
def integration_many_tasks_helper(self, timesync=True):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
Iterates over all the client table keys, gets the db_client_id for the
|
||||
client with client_type matching plasma_manager. Strips the client
|
||||
table prefix. TODO(atumanov): write a separate function to get all
|
||||
plasma manager client IDs.
|
||||
|
||||
# Submit a bunch of tasks to Redis.
|
||||
num_tasks = 1000
|
||||
for _ in range(num_tasks):
|
||||
# Create a new object for each task.
|
||||
data_size = np.random.randint(1 << 12)
|
||||
metadata_size = np.random.randint(1 << 9)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
# Check that there are the correct number of tasks in Redis and that they
|
||||
# all get assigned to the local scheduler.
|
||||
num_retries = 10
|
||||
num_tasks_done = 0
|
||||
while num_retries > 0:
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done, num_retries))
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
time.sleep(0.1)
|
||||
Returns:
|
||||
The db_client_id if one is found and otherwise None.
|
||||
"""
|
||||
db_client_id = None
|
||||
|
||||
self.assertEqual(num_tasks_done, num_tasks)
|
||||
client_list = self.state.client_table()[self.node_ip_address]
|
||||
for client in client_list:
|
||||
if client["ClientType"] == "plasma_manager":
|
||||
db_client_id = client["DBClientID"]
|
||||
break
|
||||
|
||||
def test_integration_many_tasks_handler_sync(self):
|
||||
self.integration_many_tasks_helper(timesync=True)
|
||||
return db_client_id
|
||||
|
||||
def test_integration_many_tasks(self):
|
||||
# More realistic case: should handle out of order object and task
|
||||
# notifications.
|
||||
self.integration_many_tasks_helper(timesync=False)
|
||||
def test_task_default_resources(self):
|
||||
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0)
|
||||
self.assertEqual(task1.required_resources(), [1.0, 0.0])
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
0, [1.0, 2.0])
|
||||
self.assertEqual(task2.required_resources(), [1.0, 2.0])
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
# Tests global scheduler functionality by interacting with Redis and
|
||||
# checking task state transitions in Redis only. TODO(atumanov):
|
||||
# implement.
|
||||
|
||||
# Check precondition for this test:
|
||||
# There should be 2n+1 db clients: the global scheduler + one local
|
||||
# scheduler and one plasma per node.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
db_client_id = self.get_plasma_manager_id()
|
||||
assert(db_client_id is not None)
|
||||
|
||||
def test_integration_single_task(self):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
# Insert the object into Redis.
|
||||
data_size = 0xf1f0
|
||||
metadata_size = 0x40
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(
|
||||
plasma_client, data_size, metadata_size, seal=True)
|
||||
|
||||
# Sleep before submitting task to local scheduler.
|
||||
time.sleep(0.1)
|
||||
# Submit a task to Redis.
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(), 0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
time.sleep(0.1)
|
||||
# There should now be a task in Redis, and it should get assigned to
|
||||
# the local scheduler
|
||||
num_retries = 10
|
||||
while num_retries > 0:
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), 1)
|
||||
if len(task_entries) == 1:
|
||||
task_id, task = task_entries.popitem()
|
||||
task_status = task["State"]
|
||||
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED])
|
||||
if task_status == state.TASK_STATUS_QUEUED:
|
||||
break
|
||||
else:
|
||||
print(task_status)
|
||||
print("The task has not been scheduled yet, trying again.")
|
||||
num_retries -= 1
|
||||
time.sleep(1)
|
||||
|
||||
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
|
||||
# Failed to submit and schedule a single task -- bail.
|
||||
self.tearDown()
|
||||
sys.exit(1)
|
||||
|
||||
def integration_many_tasks_helper(self, timesync=True):
|
||||
# There should be three db clients, the global scheduler, the local
|
||||
# scheduler, and the plasma manager.
|
||||
self.assertEqual(
|
||||
len(self.state.client_table()[self.node_ip_address]),
|
||||
2 * NUM_CLUSTER_NODES + 1)
|
||||
num_return_vals = [0, 1, 2, 3, 5, 10]
|
||||
|
||||
# Submit a bunch of tasks to Redis.
|
||||
num_tasks = 1000
|
||||
for _ in range(num_tasks):
|
||||
# Create a new object for each task.
|
||||
data_size = np.random.randint(1 << 12)
|
||||
metadata_size = np.random.randint(1 << 9)
|
||||
plasma_client = self.plasma_clients[0]
|
||||
object_dep, memory_buffer, metadata = create_object(plasma_client,
|
||||
data_size,
|
||||
metadata_size,
|
||||
seal=True)
|
||||
if timesync:
|
||||
# Give 10ms for object info handler to fire (long enough to
|
||||
# yield CPU).
|
||||
time.sleep(0.010)
|
||||
task = local_scheduler.Task(random_driver_id(),
|
||||
random_function_id(),
|
||||
[local_scheduler.ObjectID(object_dep)],
|
||||
num_return_vals[0], random_task_id(),
|
||||
0)
|
||||
self.local_scheduler_clients[0].submit(task)
|
||||
# Check that there are the correct number of tasks in Redis and that
|
||||
# they all get assigned to the local scheduler.
|
||||
num_retries = 10
|
||||
num_tasks_done = 0
|
||||
while num_retries > 0:
|
||||
task_entries = self.state.task_table()
|
||||
self.assertLessEqual(len(task_entries), num_tasks)
|
||||
# First, check if all tasks made it to Redis.
|
||||
if len(task_entries) == num_tasks:
|
||||
task_statuses = [task_entry["State"] for task_entry in
|
||||
task_entries.values()]
|
||||
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
|
||||
state.TASK_STATUS_SCHEDULED,
|
||||
state.TASK_STATUS_QUEUED]
|
||||
for status in task_statuses]))
|
||||
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
|
||||
num_tasks_scheduled = task_statuses.count(
|
||||
state.TASK_STATUS_SCHEDULED)
|
||||
num_tasks_waiting = task_statuses.count(
|
||||
state.TASK_STATUS_WAITING)
|
||||
print("tasks in Redis = {}, tasks waiting = {}, "
|
||||
"tasks scheduled = {}, "
|
||||
"tasks queued = {}, retries left = {}"
|
||||
.format(len(task_entries), num_tasks_waiting,
|
||||
num_tasks_scheduled, num_tasks_done,
|
||||
num_retries))
|
||||
if all([status == state.TASK_STATUS_QUEUED for status in
|
||||
task_statuses]):
|
||||
# We're done, so pass.
|
||||
break
|
||||
num_retries -= 1
|
||||
time.sleep(0.1)
|
||||
|
||||
self.assertEqual(num_tasks_done, num_tasks)
|
||||
|
||||
def test_integration_many_tasks_handler_sync(self):
|
||||
self.integration_many_tasks_helper(timesync=True)
|
||||
|
||||
def test_integration_many_tasks(self):
|
||||
# More realistic case: should handle out of order object and task
|
||||
# notifications.
|
||||
self.integration_many_tasks_helper(timesync=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# Pop the argument so we don't mess with unittest's own argument parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
print("Using valgrind for tests")
|
||||
unittest.main(verbosity=2)
|
||||
if len(sys.argv) > 1:
|
||||
# Pop the argument so we don't mess with unittest's own argument
|
||||
# parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
print("Using valgrind for tests")
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -9,7 +9,7 @@ import time
|
||||
|
||||
|
||||
def random_name():
|
||||
return str(random.randint(0, 99999999))
|
||||
return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def start_local_scheduler(plasma_store_name,
|
||||
@@ -24,95 +24,99 @@ def start_local_scheduler(plasma_store_name,
|
||||
stderr_file=None,
|
||||
static_resource_list=None,
|
||||
num_workers=0):
|
||||
"""Start a local scheduler process.
|
||||
"""Start a local scheduler process.
|
||||
|
||||
Args:
|
||||
plasma_store_name (str): The name of the plasma store socket to connect to.
|
||||
plasma_manager_name (str): The name of the plasma manager to connect to.
|
||||
This does not need to be provided, but if it is, then the Redis address
|
||||
must be provided as well.
|
||||
worker_path (str): The path of the worker script to use when the local
|
||||
scheduler starts up new workers.
|
||||
plasma_address (str): The address of the plasma manager to connect to. This
|
||||
is only used by the global scheduler to figure out which plasma managers
|
||||
are connected to which local schedulers.
|
||||
node_ip_address (str): The address of the node that this local scheduler is
|
||||
running on.
|
||||
redis_address (str): The address of the Redis instance to connect to. If
|
||||
this is not provided, then the local scheduler will not connect to Redis.
|
||||
use_valgrind (bool): True if the local scheduler should be started inside
|
||||
of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the local scheduler should be started inside a
|
||||
profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
redirection should happen, then this should be None.
|
||||
static_resource_list (list): A list of integers specifying the local
|
||||
scheduler's resource capacities. The resources should appear in an order
|
||||
matching the order defined in task.h.
|
||||
num_workers (int): The number of workers that the local scheduler should
|
||||
start.
|
||||
Args:
|
||||
plasma_store_name (str): The name of the plasma store socket to connect
|
||||
to.
|
||||
plasma_manager_name (str): The name of the plasma manager to connect
|
||||
to. This does not need to be provided, but if it is, then the Redis
|
||||
address must be provided as well.
|
||||
worker_path (str): The path of the worker script to use when the local
|
||||
scheduler starts up new workers.
|
||||
plasma_address (str): The address of the plasma manager to connect to.
|
||||
This is only used by the global scheduler to figure out which
|
||||
plasma managers are connected to which local schedulers.
|
||||
node_ip_address (str): The address of the node that this local
|
||||
scheduler is running on.
|
||||
redis_address (str): The address of the Redis instance to connect to.
|
||||
If this is not provided, then the local scheduler will not connect
|
||||
to Redis.
|
||||
use_valgrind (bool): True if the local scheduler should be started
|
||||
inside of valgrind. If this is True, use_profiler must be False.
|
||||
use_profiler (bool): True if the local scheduler should be started
|
||||
inside a profiler. If this is True, use_valgrind must be False.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If
|
||||
no redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If
|
||||
no redirection should happen, then this should be None.
|
||||
static_resource_list (list): A list of integers specifying the local
|
||||
scheduler's resource capacities. The resources should appear in an
|
||||
order matching the order defined in task.h.
|
||||
num_workers (int): The number of workers that the local scheduler
|
||||
should start.
|
||||
|
||||
Return:
|
||||
A tuple of the name of the local scheduler socket and the process ID of the
|
||||
local scheduler process.
|
||||
"""
|
||||
if (plasma_manager_name is None) != (redis_address is None):
|
||||
raise Exception("If one of the plasma_manager_name and the redis_address "
|
||||
"is provided, then both must be provided.")
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
"../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
|
||||
command = [local_scheduler_executable,
|
||||
"-s", local_scheduler_name,
|
||||
"-p", plasma_store_name,
|
||||
"-h", node_ip_address,
|
||||
"-n", str(num_workers)]
|
||||
if plasma_manager_name is not None:
|
||||
command += ["-m", plasma_manager_name]
|
||||
if worker_path is not None:
|
||||
assert plasma_store_name is not None
|
||||
assert plasma_manager_name is not None
|
||||
assert redis_address is not None
|
||||
start_worker_command = ("python {} "
|
||||
"--node-ip-address={} "
|
||||
"--object-store-name={} "
|
||||
"--object-store-manager-name={} "
|
||||
"--local-scheduler-name={} "
|
||||
"--redis-address={}").format(worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
plasma_manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address)
|
||||
command += ["-w", start_worker_command]
|
||||
if redis_address is not None:
|
||||
command += ["-r", redis_address]
|
||||
if plasma_address is not None:
|
||||
command += ["-a", plasma_address]
|
||||
if static_resource_list is not None:
|
||||
assert all([isinstance(resource, int) or isinstance(resource, float)
|
||||
for resource in static_resource_list])
|
||||
command += ["-c", ",".join([str(resource) for resource
|
||||
in static_resource_list])]
|
||||
Return:
|
||||
A tuple of the name of the local scheduler socket and the process ID of
|
||||
the local scheduler process.
|
||||
"""
|
||||
if (plasma_manager_name is None) != (redis_address is None):
|
||||
raise Exception("If one of the plasma_manager_name and the "
|
||||
"redis_address is provided, then both must be "
|
||||
"provided.")
|
||||
if use_valgrind and use_profiler:
|
||||
raise Exception("Cannot use valgrind and profiler at the same time.")
|
||||
local_scheduler_executable = os.path.join(os.path.dirname(
|
||||
os.path.abspath(__file__)),
|
||||
"../core/src/local_scheduler/local_scheduler")
|
||||
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
|
||||
command = [local_scheduler_executable,
|
||||
"-s", local_scheduler_name,
|
||||
"-p", plasma_store_name,
|
||||
"-h", node_ip_address,
|
||||
"-n", str(num_workers)]
|
||||
if plasma_manager_name is not None:
|
||||
command += ["-m", plasma_manager_name]
|
||||
if worker_path is not None:
|
||||
assert plasma_store_name is not None
|
||||
assert plasma_manager_name is not None
|
||||
assert redis_address is not None
|
||||
start_worker_command = ("python {} "
|
||||
"--node-ip-address={} "
|
||||
"--object-store-name={} "
|
||||
"--object-store-manager-name={} "
|
||||
"--local-scheduler-name={} "
|
||||
"--redis-address={}"
|
||||
.format(worker_path,
|
||||
node_ip_address,
|
||||
plasma_store_name,
|
||||
plasma_manager_name,
|
||||
local_scheduler_name,
|
||||
redis_address))
|
||||
command += ["-w", start_worker_command]
|
||||
if redis_address is not None:
|
||||
command += ["-r", redis_address]
|
||||
if plasma_address is not None:
|
||||
command += ["-a", plasma_address]
|
||||
if static_resource_list is not None:
|
||||
assert all([isinstance(resource, int) or isinstance(resource, float)
|
||||
for resource in static_resource_list])
|
||||
command += ["-c", ",".join([str(resource) for resource
|
||||
in static_resource_list])]
|
||||
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
return local_scheduler_name, pid
|
||||
if use_valgrind:
|
||||
pid = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
elif use_profiler:
|
||||
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(1.0)
|
||||
else:
|
||||
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
|
||||
time.sleep(0.1)
|
||||
return local_scheduler_name, pid
|
||||
|
||||
@@ -21,195 +21,202 @@ NIL_ACTOR_ID = 20 * b"\xff"
|
||||
|
||||
|
||||
def random_object_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_driver_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_task_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
def random_function_id():
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
|
||||
|
||||
|
||||
class TestLocalSchedulerClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Start Plasma store.
|
||||
plasma_store_name, self.p1 = plasma.start_plasma_store()
|
||||
self.plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
release_delay=0)
|
||||
# Start a local scheduler.
|
||||
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
|
||||
plasma_store_name, use_valgrind=USE_VALGRIND)
|
||||
# Connect to the scheduler.
|
||||
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
|
||||
def setUp(self):
|
||||
# Start Plasma store.
|
||||
plasma_store_name, self.p1 = plasma.start_plasma_store()
|
||||
self.plasma_client = plasma.PlasmaClient(plasma_store_name,
|
||||
release_delay=0)
|
||||
# Start a local scheduler.
|
||||
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
|
||||
plasma_store_name, use_valgrind=USE_VALGRIND)
|
||||
# Connect to the scheduler.
|
||||
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
|
||||
scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
|
||||
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
self.assertEqual(self.p1.poll(), None)
|
||||
self.assertEqual(self.p2.poll(), None)
|
||||
def tearDown(self):
|
||||
# Check that the processes are still alive.
|
||||
self.assertEqual(self.p1.poll(), None)
|
||||
self.assertEqual(self.p2.poll(), None)
|
||||
|
||||
# Kill Plasma.
|
||||
self.p1.kill()
|
||||
# Kill the local scheduler.
|
||||
if USE_VALGRIND:
|
||||
self.p2.send_signal(signal.SIGTERM)
|
||||
self.p2.wait()
|
||||
if self.p2.returncode != 0:
|
||||
os._exit(-1)
|
||||
else:
|
||||
self.p2.kill()
|
||||
# Kill Plasma.
|
||||
self.p1.kill()
|
||||
# Kill the local scheduler.
|
||||
if USE_VALGRIND:
|
||||
self.p2.send_signal(signal.SIGTERM)
|
||||
self.p2.wait()
|
||||
if self.p2.returncode != 0:
|
||||
os._exit(-1)
|
||||
else:
|
||||
self.p2.kill()
|
||||
|
||||
def test_submit_and_get_task(self):
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for i in range(256)]
|
||||
# Create and seal the objects in the object store so that we can schedule
|
||||
# all of the subsequent tasks.
|
||||
for object_id in object_ids:
|
||||
self.plasma_client.create(object_id.id(), 0)
|
||||
self.plasma_client.seal(object_id.id())
|
||||
# Define some arguments to use for the tasks.
|
||||
args_list = [
|
||||
[],
|
||||
[{}],
|
||||
[()],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
def test_submit_and_get_task(self):
|
||||
function_id = random_function_id()
|
||||
object_ids = [random_object_id() for i in range(256)]
|
||||
# Create and seal the objects in the object store so that we can
|
||||
# schedule all of the subsequent tasks.
|
||||
for object_id in object_ids:
|
||||
self.plasma_client.create(object_id.id(), 0)
|
||||
self.plasma_client.seal(object_id.id())
|
||||
# Define some arguments to use for the tasks.
|
||||
args_list = [
|
||||
[],
|
||||
[{}],
|
||||
[()],
|
||||
1 * [1],
|
||||
10 * [1],
|
||||
100 * [1],
|
||||
1000 * [1],
|
||||
1 * ["a"],
|
||||
10 * ["a"],
|
||||
100 * ["a"],
|
||||
1000 * ["a"],
|
||||
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
|
||||
object_ids[:1],
|
||||
object_ids[:2],
|
||||
object_ids[:3],
|
||||
object_ids[:4],
|
||||
object_ids[:5],
|
||||
object_ids[:10],
|
||||
object_ids[:100],
|
||||
object_ids[:256],
|
||||
[1, object_ids[0]],
|
||||
[object_ids[0], "a"],
|
||||
[1, object_ids[0], "a"],
|
||||
[object_ids[0], 1, object_ids[1], "a"],
|
||||
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
|
||||
object_ids + 100 * ["a"] + object_ids
|
||||
]
|
||||
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args,
|
||||
num_return_vals, random_task_id(), 0)
|
||||
# Submit a task.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id,
|
||||
args, num_return_vals,
|
||||
random_task_id(), 0)
|
||||
# Submit a task.
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get the task.
|
||||
new_task = self.local_scheduler_client.get_task()
|
||||
self.assertEqual(task.function_id().id(),
|
||||
new_task.function_id().id())
|
||||
retrieved_args = new_task.arguments()
|
||||
returns = new_task.returns()
|
||||
self.assertEqual(len(args), len(retrieved_args))
|
||||
self.assertEqual(num_return_vals, len(returns))
|
||||
for i in range(len(retrieved_args)):
|
||||
if isinstance(args[i], local_scheduler.ObjectID):
|
||||
self.assertEqual(args[i].id(), retrieved_args[i].id())
|
||||
else:
|
||||
self.assertEqual(args[i], retrieved_args[i])
|
||||
|
||||
# Submit all of the tasks.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id,
|
||||
args, num_return_vals,
|
||||
random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get all of the tasks.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
new_task = self.local_scheduler_client.get_task()
|
||||
|
||||
def test_scheduling_when_objects_ready(self):
|
||||
# Create a task and submit it.
|
||||
object_id = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id], 0, random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get the task.
|
||||
new_task = self.local_scheduler_client.get_task()
|
||||
self.assertEqual(task.function_id().id(), new_task.function_id().id())
|
||||
retrieved_args = new_task.arguments()
|
||||
returns = new_task.returns()
|
||||
self.assertEqual(len(args), len(retrieved_args))
|
||||
self.assertEqual(num_return_vals, len(returns))
|
||||
for i in range(len(retrieved_args)):
|
||||
if isinstance(args[i], local_scheduler.ObjectID):
|
||||
self.assertEqual(args[i].id(), retrieved_args[i].id())
|
||||
else:
|
||||
self.assertEqual(args[i], retrieved_args[i])
|
||||
|
||||
# Submit all of the tasks.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
task = local_scheduler.Task(random_driver_id(), function_id, args,
|
||||
num_return_vals, random_task_id(), 0)
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
# Sleep to give the thread time to call get_task.
|
||||
time.sleep(0.1)
|
||||
# Create and seal the object ID in the object store. This should
|
||||
# trigger a scheduling event.
|
||||
self.plasma_client.create(object_id.id(), 0)
|
||||
self.plasma_client.seal(object_id.id())
|
||||
# Wait until the thread finishes so that we know the task was
|
||||
# scheduled.
|
||||
t.join()
|
||||
|
||||
def test_scheduling_when_objects_evicted(self):
|
||||
# Create a task with two dependencies and submit it.
|
||||
object_id1 = random_object_id()
|
||||
object_id2 = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id1, object_id2], 0,
|
||||
random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Get all of the tasks.
|
||||
for args in args_list:
|
||||
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
|
||||
new_task = self.local_scheduler_client.get_task()
|
||||
|
||||
def test_scheduling_when_objects_ready(self):
|
||||
# Create a task and submit it.
|
||||
object_id = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id], 0, random_task_id(), 0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
# Sleep to give the thread time to call get_task.
|
||||
time.sleep(0.1)
|
||||
# Create and seal the object ID in the object store. This should trigger a
|
||||
# scheduling event.
|
||||
self.plasma_client.create(object_id.id(), 0)
|
||||
self.plasma_client.seal(object_id.id())
|
||||
# Wait until the thread finishes so that we know the task was scheduled.
|
||||
t.join()
|
||||
# Make one of the dependencies available.
|
||||
buf = self.plasma_client.create(object_id1.id(), 1)
|
||||
self.plasma_client.seal(object_id1.id())
|
||||
# Release the object.
|
||||
del buf
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
# Force eviction of the first dependency.
|
||||
self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY)
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
# Check that the first object dependency was evicted.
|
||||
object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0)
|
||||
self.assertEqual(object1, [None])
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
|
||||
def test_scheduling_when_objects_evicted(self):
|
||||
# Create a task with two dependencies and submit it.
|
||||
object_id1 = random_object_id()
|
||||
object_id2 = random_object_id()
|
||||
task = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[object_id1, object_id2], 0, random_task_id(),
|
||||
0)
|
||||
self.local_scheduler_client.submit(task)
|
||||
# Create the second dependency.
|
||||
self.plasma_client.create(object_id2.id(), 1)
|
||||
self.plasma_client.seal(object_id2.id())
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
|
||||
# Launch a thread to get the task.
|
||||
def get_task():
|
||||
self.local_scheduler_client.get_task()
|
||||
t = threading.Thread(target=get_task)
|
||||
t.start()
|
||||
# Create the first dependency again. Both dependencies are now
|
||||
# available.
|
||||
self.plasma_client.create(object_id1.id(), 1)
|
||||
self.plasma_client.seal(object_id1.id())
|
||||
|
||||
# Make one of the dependencies available.
|
||||
buf = self.plasma_client.create(object_id1.id(), 1)
|
||||
self.plasma_client.seal(object_id1.id())
|
||||
# Release the object.
|
||||
del buf
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
# Force eviction of the first dependency.
|
||||
self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY)
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
# Check that the first object dependency was evicted.
|
||||
object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0)
|
||||
self.assertEqual(object1, [None])
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
|
||||
# Create the second dependency.
|
||||
self.plasma_client.create(object_id2.id(), 1)
|
||||
self.plasma_client.seal(object_id2.id())
|
||||
# Check that the thread is still waiting for a task.
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(t.is_alive())
|
||||
|
||||
# Create the first dependency again. Both dependencies are now available.
|
||||
self.plasma_client.create(object_id1.id(), 1)
|
||||
self.plasma_client.seal(object_id1.id())
|
||||
|
||||
# Wait until the thread finishes so that we know the task was scheduled.
|
||||
t.join()
|
||||
# Wait until the thread finishes so that we know the task was
|
||||
# scheduled.
|
||||
t.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# Pop the argument so we don't mess with unittest's own argument parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
print("Using valgrind for tests")
|
||||
unittest.main(verbosity=2)
|
||||
if len(sys.argv) > 1:
|
||||
# Pop the argument so we don't mess with unittest's own argument
|
||||
# parser.
|
||||
if sys.argv[-1] == "valgrind":
|
||||
arg = sys.argv.pop()
|
||||
USE_VALGRIND = True
|
||||
print("Using valgrind for tests")
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
+92
-84
@@ -12,94 +12,102 @@ from ray.services import get_port
|
||||
|
||||
|
||||
class LogMonitor(object):
|
||||
"""A monitor process for monitoring Ray log files.
|
||||
"""A monitor process for monitoring Ray log files.
|
||||
|
||||
Attributes:
|
||||
node_ip_address: The IP address of the node that the log monitor process is
|
||||
running on. This will be used to determine which log files to track.
|
||||
redis_client: A client used to communicate with the Redis server.
|
||||
log_filenames: A list of the names of the log files that this monitor
|
||||
process is monitoring.
|
||||
log_files: A dictionary mapping the name of a log file to a list of strings
|
||||
representing its contents.
|
||||
log_file_handles: A dictionary mapping the name of a log file to a file
|
||||
handle for that file.
|
||||
"""
|
||||
def __init__(self, redis_ip_address, redis_port, node_ip_address):
|
||||
"""Initialize the log monitor object."""
|
||||
self.node_ip_address = node_ip_address
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.log_files = {}
|
||||
self.log_file_handles = {}
|
||||
|
||||
def update_log_filenames(self):
|
||||
"""Get the most up-to-date list of log files to monitor from Redis."""
|
||||
num_current_log_files = len(self.log_files)
|
||||
new_log_filenames = self.redis_client.lrange(
|
||||
"LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
for log_filename in new_log_filenames:
|
||||
print("Beginning to track file {}".format(log_filename))
|
||||
assert log_filename not in self.log_files
|
||||
self.log_files[log_filename] = []
|
||||
|
||||
def check_log_files_and_push_updates(self):
|
||||
"""Get any changes to the log files and push updates to Redis."""
|
||||
for log_filename in self.log_files:
|
||||
if log_filename in self.log_file_handles:
|
||||
# Get any updates to the file.
|
||||
new_lines = []
|
||||
while True:
|
||||
current_position = self.log_file_handles[log_filename].tell()
|
||||
next_line = self.log_file_handles[log_filename].readline()
|
||||
if next_line != "":
|
||||
new_lines.append(next_line)
|
||||
else:
|
||||
self.log_file_handles[log_filename].seek(current_position)
|
||||
break
|
||||
|
||||
# If there are any new lines, cache them and also push them to Redis.
|
||||
if len(new_lines) > 0:
|
||||
self.log_files[log_filename] += new_lines
|
||||
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address,
|
||||
log_filename.decode("ascii"))
|
||||
self.redis_client.rpush(redis_key, *new_lines)
|
||||
else:
|
||||
try:
|
||||
self.log_file_handles[log_filename] = open(log_filename, "r")
|
||||
except IOError as e:
|
||||
if e.errno == os.errno.EMFILE:
|
||||
print("Warning: Some files are not being logged because there are "
|
||||
"too many open files.")
|
||||
elif e.errno == os.errno.ENOENT:
|
||||
print("Warning: The file {} was not found.".format(log_filename))
|
||||
else:
|
||||
raise e
|
||||
|
||||
def run(self):
|
||||
"""Run the log monitor.
|
||||
|
||||
This will query Redis once every second to check if there are new log files
|
||||
to monitor. It will also store those log files in Redis.
|
||||
Attributes:
|
||||
node_ip_address: The IP address of the node that the log monitor
|
||||
process is running on. This will be used to determine which log
|
||||
files to track.
|
||||
redis_client: A client used to communicate with the Redis server.
|
||||
log_filenames: A list of the names of the log files that this monitor
|
||||
process is monitoring.
|
||||
log_files: A dictionary mapping the name of a log file to a list of
|
||||
strings representing its contents.
|
||||
log_file_handles: A dictionary mapping the name of a log file to a file
|
||||
handle for that file.
|
||||
"""
|
||||
while True:
|
||||
self.update_log_filenames()
|
||||
self.check_log_files_and_push_updates()
|
||||
time.sleep(1)
|
||||
def __init__(self, redis_ip_address, redis_port, node_ip_address):
|
||||
"""Initialize the log monitor object."""
|
||||
self.node_ip_address = node_ip_address
|
||||
self.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=redis_port)
|
||||
self.log_files = {}
|
||||
self.log_file_handles = {}
|
||||
|
||||
def update_log_filenames(self):
|
||||
"""Get the most up-to-date list of log files to monitor from Redis."""
|
||||
num_current_log_files = len(self.log_files)
|
||||
new_log_filenames = self.redis_client.lrange(
|
||||
"LOG_FILENAMES:{}".format(self.node_ip_address),
|
||||
num_current_log_files, -1)
|
||||
for log_filename in new_log_filenames:
|
||||
print("Beginning to track file {}".format(log_filename))
|
||||
assert log_filename not in self.log_files
|
||||
self.log_files[log_filename] = []
|
||||
|
||||
def check_log_files_and_push_updates(self):
|
||||
"""Get any changes to the log files and push updates to Redis."""
|
||||
for log_filename in self.log_files:
|
||||
if log_filename in self.log_file_handles:
|
||||
# Get any updates to the file.
|
||||
new_lines = []
|
||||
while True:
|
||||
current_position = (
|
||||
self.log_file_handles[log_filename].tell())
|
||||
next_line = self.log_file_handles[log_filename].readline()
|
||||
if next_line != "":
|
||||
new_lines.append(next_line)
|
||||
else:
|
||||
self.log_file_handles[log_filename].seek(
|
||||
current_position)
|
||||
break
|
||||
|
||||
# If there are any new lines, cache them and also push them to
|
||||
# Redis.
|
||||
if len(new_lines) > 0:
|
||||
self.log_files[log_filename] += new_lines
|
||||
redis_key = "LOGFILE:{}:{}".format(
|
||||
self.node_ip_address, log_filename.decode("ascii"))
|
||||
self.redis_client.rpush(redis_key, *new_lines)
|
||||
else:
|
||||
try:
|
||||
self.log_file_handles[log_filename] = open(log_filename,
|
||||
"r")
|
||||
except IOError as e:
|
||||
if e.errno == os.errno.EMFILE:
|
||||
print("Warning: Some files are not being logged "
|
||||
"because there are too many open files.")
|
||||
elif e.errno == os.errno.ENOENT:
|
||||
print("Warning: The file {} was not "
|
||||
"found.".format(log_filename))
|
||||
else:
|
||||
raise e
|
||||
|
||||
def run(self):
|
||||
"""Run the log monitor.
|
||||
|
||||
This will query Redis once every second to check if there are new log
|
||||
files to monitor. It will also store those log files in Redis.
|
||||
"""
|
||||
while True:
|
||||
self.update_log_filenames()
|
||||
self.check_log_files_and_push_updates()
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"log monitor to connect to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="The address to use for Redis.")
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
help="The IP address of the node this process is on.")
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"log monitor to connect "
|
||||
"to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="The address to use for Redis.")
|
||||
parser.add_argument("--node-ip-address", required=True, type=str,
|
||||
help="The IP address of the node this process is on.")
|
||||
args = parser.parse_args()
|
||||
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
|
||||
log_monitor = LogMonitor(redis_ip_address, redis_port, args.node_ip_address)
|
||||
log_monitor.run()
|
||||
log_monitor = LogMonitor(redis_ip_address, redis_port,
|
||||
args.node_ip_address)
|
||||
log_monitor.run()
|
||||
|
||||
+318
-305
@@ -48,354 +48,367 @@ log.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class Monitor(object):
|
||||
"""A monitor for Ray processes.
|
||||
"""A monitor for Ray processes.
|
||||
|
||||
The monitor is in charge of cleaning up the tables in the global state after
|
||||
processes have died. The monitor is currently not responsible for detecting
|
||||
component failures.
|
||||
The monitor is in charge of cleaning up the tables in the global state
|
||||
after processes have died. The monitor is currently not responsible for
|
||||
detecting component failures.
|
||||
|
||||
Attributes:
|
||||
redis: A connection to the Redis server.
|
||||
subscribe_client: A pubsub client for the Redis server. This is used to
|
||||
receive notifications about failed components.
|
||||
subscribed: A dictionary mapping channel names (str) to whether or not the
|
||||
subscription to that channel has succeeded yet (bool).
|
||||
dead_local_schedulers: A set of the local scheduler IDs of all of the local
|
||||
schedulers that were up at one point and have died since then.
|
||||
live_plasma_managers: A counter mapping live plasma manager IDs to the
|
||||
number of heartbeats that have passed since we last heard from that
|
||||
plasma manager. A plasma manager is live if we received a heartbeat from
|
||||
it at any point, and if it has not timed out.
|
||||
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
|
||||
managers that were up at one point and have died since then.
|
||||
"""
|
||||
def __init__(self, redis_address, redis_port):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
|
||||
# TODO(swang): Update pubsub client to use ray.experimental.state once
|
||||
# subscriptions are implemented there.
|
||||
self.subscribe_client = self.redis.pubsub()
|
||||
self.subscribed = {}
|
||||
# Initialize data structures to keep track of the active database clients.
|
||||
self.dead_local_schedulers = set()
|
||||
self.live_plasma_managers = Counter()
|
||||
self.dead_plasma_managers = set()
|
||||
|
||||
def subscribe(self, channel):
|
||||
"""Subscribe to the given channel.
|
||||
|
||||
Args:
|
||||
channel (str): The channel to subscribe to.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the subscription fails.
|
||||
Attributes:
|
||||
redis: A connection to the Redis server.
|
||||
subscribe_client: A pubsub client for the Redis server. This is used to
|
||||
receive notifications about failed components.
|
||||
subscribed: A dictionary mapping channel names (str) to whether or not
|
||||
the subscription to that channel has succeeded yet (bool).
|
||||
dead_local_schedulers: A set of the local scheduler IDs of all of the
|
||||
local schedulers that were up at one point and have died since
|
||||
then.
|
||||
live_plasma_managers: A counter mapping live plasma manager IDs to the
|
||||
number of heartbeats that have passed since we last heard from that
|
||||
plasma manager. A plasma manager is live if we received a heartbeat
|
||||
from it at any point, and if it has not timed out.
|
||||
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
|
||||
managers that were up at one point and have died since then.
|
||||
"""
|
||||
self.subscribe_client.subscribe(channel)
|
||||
self.subscribed[channel] = False
|
||||
def __init__(self, redis_address, redis_port):
|
||||
# Initialize the Redis clients.
|
||||
self.state = ray.experimental.state.GlobalState()
|
||||
self.state._initialize_global_state(redis_address, redis_port)
|
||||
self.redis = redis.StrictRedis(host=redis_address, port=redis_port,
|
||||
db=0)
|
||||
# TODO(swang): Update pubsub client to use ray.experimental.state once
|
||||
# subscriptions are implemented there.
|
||||
self.subscribe_client = self.redis.pubsub()
|
||||
self.subscribed = {}
|
||||
# Initialize data structures to keep track of the active database
|
||||
# clients.
|
||||
self.dead_local_schedulers = set()
|
||||
self.live_plasma_managers = Counter()
|
||||
self.dead_plasma_managers = set()
|
||||
|
||||
def cleanup_task_table(self):
|
||||
"""Clean up global state for failed local schedulers.
|
||||
def subscribe(self, channel):
|
||||
"""Subscribe to the given channel.
|
||||
|
||||
This marks any tasks that were scheduled on dead local schedulers as
|
||||
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
|
||||
self.dead_local_schedulers.
|
||||
"""
|
||||
tasks = self.state.task_table()
|
||||
num_tasks_updated = 0
|
||||
for task_id, task in tasks.items():
|
||||
# See if the corresponding local scheduler is alive.
|
||||
if task["LocalSchedulerID"] in self.dead_local_schedulers:
|
||||
# If the task is scheduled on a dead local scheduler, mark the task as
|
||||
# lost.
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
if num_tasks_updated > 0:
|
||||
log.warn("Marked {} tasks as lost.".format(num_tasks_updated))
|
||||
Args:
|
||||
channel (str): The channel to subscribe to.
|
||||
|
||||
def cleanup_object_table(self):
|
||||
"""Clean up global state for failed plasma managers.
|
||||
Raises:
|
||||
Exception: An exception is raised if the subscription fails.
|
||||
"""
|
||||
self.subscribe_client.subscribe(channel)
|
||||
self.subscribed[channel] = False
|
||||
|
||||
This removes dead plasma managers from any location entries in the object
|
||||
table. A plasma manager is deemed dead if it is in
|
||||
self.dead_plasma_managers.
|
||||
"""
|
||||
# TODO(swang): Also kill the associated plasma store, since it's no longer
|
||||
# reachable without a plasma manager.
|
||||
objects = self.state.object_table()
|
||||
num_objects_removed = 0
|
||||
for object_id, obj in objects.items():
|
||||
manager_ids = obj["ManagerIDs"]
|
||||
if manager_ids is None:
|
||||
continue
|
||||
for manager in manager_ids:
|
||||
if manager in self.dead_plasma_managers:
|
||||
# If the object was on a dead plasma manager, remove that location
|
||||
# entry.
|
||||
ok = self.state._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_REMOVE",
|
||||
object_id.id(),
|
||||
hex_to_binary(manager))
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to remove object location for dead plasma "
|
||||
"manager.")
|
||||
num_objects_removed += 1
|
||||
if num_objects_removed > 0:
|
||||
log.warn("Marked {} objects as lost.".format(num_objects_removed))
|
||||
def cleanup_task_table(self):
|
||||
"""Clean up global state for failed local schedulers.
|
||||
|
||||
def scan_db_client_table(self):
|
||||
"""Scan the database client table for dead clients.
|
||||
This marks any tasks that were scheduled on dead local schedulers as
|
||||
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
|
||||
self.dead_local_schedulers.
|
||||
"""
|
||||
tasks = self.state.task_table()
|
||||
num_tasks_updated = 0
|
||||
for task_id, task in tasks.items():
|
||||
# See if the corresponding local scheduler is alive.
|
||||
if task["LocalSchedulerID"] in self.dead_local_schedulers:
|
||||
# If the task is scheduled on a dead local scheduler, mark the
|
||||
# task as lost.
|
||||
key = binary_to_object_id(hex_to_binary(task_id))
|
||||
ok = self.state._execute_command(
|
||||
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
|
||||
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to update lost task for dead scheduler.")
|
||||
num_tasks_updated += 1
|
||||
if num_tasks_updated > 0:
|
||||
log.warn("Marked {} tasks as lost.".format(num_tasks_updated))
|
||||
|
||||
After subscribing to the client table, it's necessary to call this before
|
||||
reading any messages from the subscription channel. This ensures that we do
|
||||
not miss any notifications for deleted clients that occurred before we
|
||||
subscribed.
|
||||
"""
|
||||
clients = self.state.client_table()
|
||||
for node_ip_address, node_clients in clients.items():
|
||||
for client in node_clients:
|
||||
db_client_id = client["DBClientID"]
|
||||
client_type = client["ClientType"]
|
||||
if client["Deleted"]:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
def cleanup_object_table(self):
|
||||
"""Clean up global state for failed plasma managers.
|
||||
|
||||
def subscribe_handler(self, channel, data):
|
||||
"""Handle a subscription success message from Redis.
|
||||
"""
|
||||
log.debug("Subscribed to {}, data was {}".format(channel, data))
|
||||
self.subscribed[channel] = True
|
||||
This removes dead plasma managers from any location entries in the
|
||||
object table. A plasma manager is deemed dead if it is in
|
||||
self.dead_plasma_managers.
|
||||
"""
|
||||
# TODO(swang): Also kill the associated plasma store, since it's no
|
||||
# longer reachable without a plasma manager.
|
||||
objects = self.state.object_table()
|
||||
num_objects_removed = 0
|
||||
for object_id, obj in objects.items():
|
||||
manager_ids = obj["ManagerIDs"]
|
||||
if manager_ids is None:
|
||||
continue
|
||||
for manager in manager_ids:
|
||||
if manager in self.dead_plasma_managers:
|
||||
# If the object was on a dead plasma manager, remove that
|
||||
# location entry.
|
||||
ok = self.state._execute_command(object_id,
|
||||
"RAY.OBJECT_TABLE_REMOVE",
|
||||
object_id.id(),
|
||||
hex_to_binary(manager))
|
||||
if ok != b"OK":
|
||||
log.warn("Failed to remove object location for dead "
|
||||
"plasma manager.")
|
||||
num_objects_removed += 1
|
||||
if num_objects_removed > 0:
|
||||
log.warn("Marked {} objects as lost.".format(num_objects_removed))
|
||||
|
||||
def db_client_notification_handler(self, channel, data):
|
||||
"""Handle a notification from the db_client table from Redis.
|
||||
def scan_db_client_table(self):
|
||||
"""Scan the database client table for dead clients.
|
||||
|
||||
This handler processes notifications from the db_client table.
|
||||
Notifications should be parsed using the SubscribeToDBClientTableReply
|
||||
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the
|
||||
associated state in the state tables should be handled by the caller.
|
||||
"""
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data, 0))
|
||||
db_client_id = binary_to_hex(notification_object.DbClientId())
|
||||
client_type = notification_object.ClientType()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
After subscribing to the client table, it's necessary to call this
|
||||
before reading any messages from the subscription channel. This ensures
|
||||
that we do not miss any notifications for deleted clients that occurred
|
||||
before we subscribed.
|
||||
"""
|
||||
clients = self.state.client_table()
|
||||
for node_ip_address, node_clients in clients.items():
|
||||
for client in node_clients:
|
||||
db_client_id = client["DBClientID"]
|
||||
client_type = client["ClientType"]
|
||||
if client["Deleted"]:
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
|
||||
# If the update was an insertion, we ignore it.
|
||||
if is_insertion:
|
||||
return
|
||||
def subscribe_handler(self, channel, data):
|
||||
"""Handle a subscription success message from Redis."""
|
||||
log.debug("Subscribed to {}, data was {}".format(channel, data))
|
||||
self.subscribed[channel] = True
|
||||
|
||||
# If the update was a deletion, add them to our accounting for dead
|
||||
# local schedulers and plasma managers.
|
||||
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_local_schedulers:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_plasma_managers:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
# Stop tracking this plasma manager's heartbeats, since it's
|
||||
# already dead.
|
||||
del self.live_plasma_managers[db_client_id]
|
||||
def db_client_notification_handler(self, channel, data):
|
||||
"""Handle a notification from the db_client table from Redis.
|
||||
|
||||
def plasma_manager_heartbeat_handler(self, channel, data):
|
||||
"""Handle a plasma manager heartbeat from Redis.
|
||||
This handler processes notifications from the db_client table.
|
||||
Notifications should be parsed using the SubscribeToDBClientTableReply
|
||||
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of
|
||||
the associated state in the state tables should be handled by the
|
||||
caller.
|
||||
"""
|
||||
notification_object = (SubscribeToDBClientTableReply
|
||||
.GetRootAsSubscribeToDBClientTableReply(data,
|
||||
0))
|
||||
db_client_id = binary_to_hex(notification_object.DbClientId())
|
||||
client_type = notification_object.ClientType()
|
||||
is_insertion = notification_object.IsInsertion()
|
||||
|
||||
This resets the number of heartbeats that we've missed from this plasma
|
||||
manager.
|
||||
"""
|
||||
# The first DB_CLIENT_ID_SIZE characters are the client ID.
|
||||
db_client_id = data[:DB_CLIENT_ID_SIZE]
|
||||
# Reset the number of heartbeats that we've missed from this plasma
|
||||
# manager.
|
||||
self.live_plasma_managers[db_client_id] = 0
|
||||
# If the update was an insertion, we ignore it.
|
||||
if is_insertion:
|
||||
return
|
||||
|
||||
def driver_removed_handler(self, channel, data):
|
||||
"""Handle a notification that a driver has been removed.
|
||||
# If the update was a deletion, add them to our accounting for dead
|
||||
# local schedulers and plasma managers.
|
||||
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
|
||||
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_local_schedulers:
|
||||
self.dead_local_schedulers.add(db_client_id)
|
||||
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
|
||||
if db_client_id not in self.dead_plasma_managers:
|
||||
self.dead_plasma_managers.add(db_client_id)
|
||||
# Stop tracking this plasma manager's heartbeats, since it's
|
||||
# already dead.
|
||||
del self.live_plasma_managers[db_client_id]
|
||||
|
||||
This releases any GPU resources that were reserved for that driver in
|
||||
Redis.
|
||||
"""
|
||||
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
|
||||
driver_id = message.DriverId()
|
||||
log.info("Driver {} has been removed.".format(binary_to_hex(driver_id)))
|
||||
def plasma_manager_heartbeat_handler(self, channel, data):
|
||||
"""Handle a plasma manager heartbeat from Redis.
|
||||
|
||||
# Get a list of the local schedulers.
|
||||
client_table = ray.global_state.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, clients in client_table.items():
|
||||
for client in clients:
|
||||
if client["ClientType"] == "local_scheduler":
|
||||
local_schedulers.append(client)
|
||||
This resets the number of heartbeats that we've missed from this plasma
|
||||
manager.
|
||||
"""
|
||||
# The first DB_CLIENT_ID_SIZE characters are the client ID.
|
||||
db_client_id = data[:DB_CLIENT_ID_SIZE]
|
||||
# Reset the number of heartbeats that we've missed from this plasma
|
||||
# manager.
|
||||
self.live_plasma_managers[db_client_id] = 0
|
||||
|
||||
# Release any GPU resources that have been reserved for this driver in
|
||||
# Redis.
|
||||
for local_scheduler in local_schedulers:
|
||||
if int(local_scheduler["NumGPUs"]) > 0:
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
def driver_removed_handler(self, channel, data):
|
||||
"""Handle a notification that a driver has been removed.
|
||||
|
||||
num_gpus_returned = 0
|
||||
This releases any GPU resources that were reserved for that driver in
|
||||
Redis.
|
||||
"""
|
||||
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
|
||||
driver_id = message.DriverId()
|
||||
log.info("Driver {} has been removed."
|
||||
.format(binary_to_hex(driver_id)))
|
||||
|
||||
# Perform a transaction to return the GPUs.
|
||||
with self.redis.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction below (the
|
||||
# multi/exec block), then the transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
# Get a list of the local schedulers.
|
||||
client_table = ray.global_state.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, clients in client_table.items():
|
||||
for client in clients:
|
||||
if client["ClientType"] == "local_scheduler":
|
||||
local_schedulers.append(client)
|
||||
|
||||
result = pipe.hget(local_scheduler_id, "gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(result)
|
||||
# Release any GPU resources that have been reserved for this driver in
|
||||
# Redis.
|
||||
for local_scheduler in local_schedulers:
|
||||
if int(local_scheduler["NumGPUs"]) > 0:
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex in gpus_in_use:
|
||||
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
|
||||
num_gpus_returned = 0
|
||||
|
||||
pipe.multi()
|
||||
# Perform a transaction to return the GPUs.
|
||||
with self.redis.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction
|
||||
# below (the multi/exec block), then the
|
||||
# transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use",
|
||||
json.dumps(gpus_in_use))
|
||||
result = pipe.hget(local_scheduler_id,
|
||||
"gpus_in_use")
|
||||
gpus_in_use = (dict() if result is None
|
||||
else json.loads(result))
|
||||
|
||||
pipe.execute()
|
||||
# If a WatchError is not raise, then the operations should have
|
||||
# gone through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key between the
|
||||
# time we started WATCHing it and the pipeline's execution. We
|
||||
# should just retry.
|
||||
continue
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex in gpus_in_use:
|
||||
num_gpus_returned = gpus_in_use.pop(
|
||||
driver_id_hex)
|
||||
|
||||
log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
|
||||
.format(driver_id, num_gpus_returned, local_scheduler_id))
|
||||
pipe.multi()
|
||||
|
||||
def process_messages(self):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use",
|
||||
json.dumps(gpus_in_use))
|
||||
|
||||
This reads messages from the subscription channels and calls the
|
||||
appropriate handlers until there are no messages left.
|
||||
"""
|
||||
while True:
|
||||
message = self.subscribe_client.get_message()
|
||||
if message is None:
|
||||
return
|
||||
pipe.execute()
|
||||
# If a WatchError is not raise, then the operations
|
||||
# should have gone through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key
|
||||
# between the time we started WATCHing it and the
|
||||
# pipeline's execution. We should just retry.
|
||||
continue
|
||||
|
||||
# Parse the message.
|
||||
channel = message["channel"]
|
||||
data = message["data"]
|
||||
log.info("Driver {} is returning GPU IDs {} to local "
|
||||
"scheduler {}.".format(driver_id, num_gpus_returned,
|
||||
local_scheduler_id))
|
||||
|
||||
# Determine the appropriate message handler.
|
||||
message_handler = None
|
||||
if not self.subscribed[channel]:
|
||||
# If the data was an integer, then the message was a response to an
|
||||
# initial subscription request.
|
||||
message_handler = self.subscribe_handler
|
||||
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a heartbeat from a plasma manager.
|
||||
message_handler = self.plasma_manager_heartbeat_handler
|
||||
elif channel == DB_CLIENT_TABLE_NAME:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification from the db_client table.
|
||||
message_handler = self.db_client_notification_handler
|
||||
elif channel == DRIVER_DEATH_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification that a driver was removed.
|
||||
message_handler = self.driver_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
def process_messages(self):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
|
||||
# Call the handler.
|
||||
assert(message_handler is not None)
|
||||
message_handler(channel, data)
|
||||
This reads messages from the subscription channels and calls the
|
||||
appropriate handlers until there are no messages left.
|
||||
"""
|
||||
while True:
|
||||
message = self.subscribe_client.get_message()
|
||||
if message is None:
|
||||
return
|
||||
|
||||
def run(self):
|
||||
"""Run the monitor.
|
||||
# Parse the message.
|
||||
channel = message["channel"]
|
||||
data = message["data"]
|
||||
|
||||
This function loops forever, checking for messages about dead database
|
||||
clients and cleaning up state accordingly.
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(DB_CLIENT_TABLE_NAME)
|
||||
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(DRIVER_DEATH_CHANNEL)
|
||||
# Determine the appropriate message handler.
|
||||
message_handler = None
|
||||
if not self.subscribed[channel]:
|
||||
# If the data was an integer, then the message was a response
|
||||
# to an initial subscription request.
|
||||
message_handler = self.subscribe_handler
|
||||
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a heartbeat from a plasma manager.
|
||||
message_handler = self.plasma_manager_heartbeat_handler
|
||||
elif channel == DB_CLIENT_TABLE_NAME:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification from the db_client table.
|
||||
message_handler = self.db_client_notification_handler
|
||||
elif channel == DRIVER_DEATH_CHANNEL:
|
||||
assert(self.subscribed[channel])
|
||||
# The message was a notification that a driver was removed.
|
||||
message_handler = self.driver_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
# Scan the database table for dead database clients. NOTE: This must be
|
||||
# called before reading any messages from the subscription channel. This
|
||||
# ensures that we start in a consistent state, since we may have missed
|
||||
# notifications that were sent before we connected to the subscription
|
||||
# channel.
|
||||
self.scan_db_client_table()
|
||||
# If there were any dead clients at startup, clean up the associated state
|
||||
# in the state tables.
|
||||
if len(self.dead_local_schedulers) > 0:
|
||||
self.cleanup_task_table()
|
||||
if len(self.dead_plasma_managers) > 0:
|
||||
self.cleanup_object_table()
|
||||
log.debug("{} dead local schedulers, {} plasma managers total, {} dead "
|
||||
"plasma managers".format(len(self.dead_local_schedulers),
|
||||
(len(self.live_plasma_managers) +
|
||||
len(self.dead_plasma_managers)),
|
||||
len(self.dead_plasma_managers)))
|
||||
# Call the handler.
|
||||
assert(message_handler is not None)
|
||||
message_handler(channel, data)
|
||||
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
# Record how many dead local schedulers and plasma managers we had at the
|
||||
# beginning of this round.
|
||||
num_dead_local_schedulers = len(self.dead_local_schedulers)
|
||||
num_dead_plasma_managers = len(self.dead_plasma_managers)
|
||||
# Process a round of messages.
|
||||
self.process_messages()
|
||||
# If any new local schedulers or plasma managers were marked as dead in
|
||||
# this round, clean up the associated state.
|
||||
if len(self.dead_local_schedulers) > num_dead_local_schedulers:
|
||||
self.cleanup_task_table()
|
||||
if len(self.dead_plasma_managers) > num_dead_plasma_managers:
|
||||
self.cleanup_object_table()
|
||||
def run(self):
|
||||
"""Run the monitor.
|
||||
|
||||
# Handle plasma managers that timed out during this round.
|
||||
plasma_manager_ids = list(self.live_plasma_managers.keys())
|
||||
for plasma_manager_id in plasma_manager_ids:
|
||||
if ((self.live_plasma_managers
|
||||
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
|
||||
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
|
||||
# Remove the plasma manager from the managers whose heartbeats we're
|
||||
# tracking.
|
||||
del self.live_plasma_managers[plasma_manager_id]
|
||||
# Remove the plasma manager from the db_client table. The
|
||||
# corresponding state in the object table will be cleaned up once we
|
||||
# receive the notification for this db_client deletion.
|
||||
self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id)
|
||||
This function loops forever, checking for messages about dead database
|
||||
clients and cleaning up state accordingly.
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(DB_CLIENT_TABLE_NAME)
|
||||
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
|
||||
self.subscribe(DRIVER_DEATH_CHANNEL)
|
||||
|
||||
# Increment the number of heartbeats that we've missed from each plasma
|
||||
# manager.
|
||||
for plasma_manager_id in self.live_plasma_managers:
|
||||
self.live_plasma_managers[plasma_manager_id] += 1
|
||||
# Scan the database table for dead database clients. NOTE: This must be
|
||||
# called before reading any messages from the subscription channel.
|
||||
# This ensures that we start in a consistent state, since we may have
|
||||
# missed notifications that were sent before we connected to the
|
||||
# subscription channel.
|
||||
self.scan_db_client_table()
|
||||
# If there were any dead clients at startup, clean up the associated
|
||||
# state in the state tables.
|
||||
if len(self.dead_local_schedulers) > 0:
|
||||
self.cleanup_task_table()
|
||||
if len(self.dead_plasma_managers) > 0:
|
||||
self.cleanup_object_table()
|
||||
log.debug("{} dead local schedulers, {} plasma managers total, {} "
|
||||
"dead plasma managers".format(
|
||||
len(self.dead_local_schedulers),
|
||||
(len(self.live_plasma_managers) +
|
||||
len(self.dead_plasma_managers)),
|
||||
len(self.dead_plasma_managers)))
|
||||
|
||||
# Wait for a heartbeat interval before processing the next round of
|
||||
# messages.
|
||||
time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3)
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
# Record how many dead local schedulers and plasma managers we had
|
||||
# at the beginning of this round.
|
||||
num_dead_local_schedulers = len(self.dead_local_schedulers)
|
||||
num_dead_plasma_managers = len(self.dead_plasma_managers)
|
||||
# Process a round of messages.
|
||||
self.process_messages()
|
||||
# If any new local schedulers or plasma managers were marked as
|
||||
# dead in this round, clean up the associated state.
|
||||
if len(self.dead_local_schedulers) > num_dead_local_schedulers:
|
||||
self.cleanup_task_table()
|
||||
if len(self.dead_plasma_managers) > num_dead_plasma_managers:
|
||||
self.cleanup_object_table()
|
||||
|
||||
# Handle plasma managers that timed out during this round.
|
||||
plasma_manager_ids = list(self.live_plasma_managers.keys())
|
||||
for plasma_manager_id in plasma_manager_ids:
|
||||
if ((self.live_plasma_managers
|
||||
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
|
||||
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
|
||||
# Remove the plasma manager from the managers whose
|
||||
# heartbeats we're tracking.
|
||||
del self.live_plasma_managers[plasma_manager_id]
|
||||
# Remove the plasma manager from the db_client table. The
|
||||
# corresponding state in the object table will be cleaned
|
||||
# up once we receive the notification for this db_client
|
||||
# deletion.
|
||||
self.redis.execute_command("RAY.DISCONNECT",
|
||||
plasma_manager_id)
|
||||
|
||||
# Increment the number of heartbeats that we've missed from each
|
||||
# plasma manager.
|
||||
for plasma_manager_id in self.live_plasma_managers:
|
||||
self.live_plasma_managers[plasma_manager_id] += 1
|
||||
|
||||
# Wait for a heartbeat interval before processing the next round of
|
||||
# messages.
|
||||
time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="the address to use for Redis")
|
||||
args = parser.parse_args()
|
||||
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
|
||||
"monitor to connect to."))
|
||||
parser.add_argument("--redis-address", required=True, type=str,
|
||||
help="the address to use for Redis")
|
||||
args = parser.parse_args()
|
||||
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
redis_ip_address = get_ip_address(args.redis_address)
|
||||
redis_port = get_port(args.redis_address)
|
||||
|
||||
# Initialize the global state.
|
||||
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
|
||||
# Initialize the global state.
|
||||
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
|
||||
|
||||
monitor = Monitor(redis_ip_address, redis_port)
|
||||
monitor.run()
|
||||
monitor = Monitor(redis_ip_address, redis_port)
|
||||
monitor.run()
|
||||
|
||||
@@ -16,28 +16,26 @@ __all__ = ["deserialize_list", "numbuf_error",
|
||||
"store_list", "write_to_buffer"]
|
||||
|
||||
try:
|
||||
from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error,
|
||||
numbuf_plasma_object_exists_error,
|
||||
read_from_buffer,
|
||||
register_callbacks, retrieve_list,
|
||||
serialize_list, store_list,
|
||||
write_to_buffer)
|
||||
from ray.core.src.numbuf.libnumbuf import (
|
||||
deserialize_list, numbuf_error, numbuf_plasma_object_exists_error,
|
||||
read_from_buffer, register_callbacks, retrieve_list, serialize_list,
|
||||
store_list, write_to_buffer)
|
||||
except ImportError as e:
|
||||
if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or
|
||||
"CXX" in e.msg)):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
raise
|
||||
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
|
||||
("libstdc++" in e.msg or "CXX" in e.msg))):
|
||||
# This code path should be taken with Python 3.
|
||||
e.msg += helpful_message
|
||||
elif (hasattr(e, "message") and isinstance(e.message, str) and
|
||||
("libstdc++" in e.message or "CXX" in e.message)):
|
||||
# This code path should be taken with Python 2.
|
||||
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
|
||||
len(e.args) == 1 and isinstance(e.args[0], str))
|
||||
if condition:
|
||||
e.args = (e.args[0] + helpful_message,)
|
||||
else:
|
||||
if not hasattr(e, "args"):
|
||||
e.args = ()
|
||||
elif not isinstance(e.args, tuple):
|
||||
e.args = (e.args,)
|
||||
e.args += (helpful_message,)
|
||||
raise
|
||||
|
||||
+344
-329
@@ -21,342 +21,355 @@ PLASMA_WAIT_TIMEOUT = 2 ** 30
|
||||
|
||||
|
||||
class PlasmaBuffer(object):
|
||||
"""This is the type of objects returned by calls to get with a PlasmaClient.
|
||||
"""This is the type returned by calls to get with a PlasmaClient.
|
||||
|
||||
We define our own class instead of directly returning a buffer object so that
|
||||
we can add a custom destructor which notifies Plasma that the object is no
|
||||
longer being used, so the memory in the Plasma store backing the object can
|
||||
potentially be freed.
|
||||
We define our own class instead of directly returning a buffer object so
|
||||
that we can add a custom destructor which notifies Plasma that the object
|
||||
is no longer being used, so the memory in the Plasma store backing the
|
||||
object can potentially be freed.
|
||||
|
||||
Attributes:
|
||||
buffer (buffer): A buffer containing an object in the Plasma store.
|
||||
plasma_id (PlasmaID): The ID of the object in the buffer.
|
||||
plasma_client (PlasmaClient): The PlasmaClient that we use to communicate
|
||||
with the store and manager.
|
||||
"""
|
||||
def __init__(self, buff, plasma_id, plasma_client):
|
||||
"""Initialize a PlasmaBuffer."""
|
||||
self.buffer = buff
|
||||
self.plasma_id = plasma_id
|
||||
self.plasma_client = plasma_client
|
||||
|
||||
def __del__(self):
|
||||
"""Notify Plasma that the object is no longer needed.
|
||||
|
||||
If the plasma client has been shut down, then don't do anything.
|
||||
Attributes:
|
||||
buffer (buffer): A buffer containing an object in the Plasma store.
|
||||
plasma_id (PlasmaID): The ID of the object in the buffer.
|
||||
plasma_client (PlasmaClient): The PlasmaClient that we use to communicate
|
||||
with the store and manager.
|
||||
"""
|
||||
if self.plasma_client.alive:
|
||||
libplasma.release(self.plasma_client.conn, self.plasma_id)
|
||||
def __init__(self, buff, plasma_id, plasma_client):
|
||||
"""Initialize a PlasmaBuffer."""
|
||||
self.buffer = buff
|
||||
self.plasma_id = plasma_id
|
||||
self.plasma_client = plasma_client
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
|
||||
# We currently don't allow slicing plasma buffers. We should handle this
|
||||
# better, but it requires some care because the slice may be backed by the
|
||||
# same memory in the object store, but the original plasma buffer may go
|
||||
# out of scope causing the memory to no longer be accessible.
|
||||
assert not isinstance(index, slice)
|
||||
value = self.buffer[index]
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
value = chr(value)
|
||||
return value
|
||||
def __del__(self):
|
||||
"""Notify Plasma that the object is no longer needed.
|
||||
|
||||
def __setitem__(self, index, value):
|
||||
"""Write to the PlasmaBuffer as if it were just a regular buffer.
|
||||
If the plasma client has been shut down, then don't do anything.
|
||||
"""
|
||||
if self.plasma_client.alive:
|
||||
libplasma.release(self.plasma_client.conn, self.plasma_id)
|
||||
|
||||
This should fail because the buffer should be read only.
|
||||
"""
|
||||
# We currently don't allow slicing plasma buffers. We should handle this
|
||||
# better, but it requires some care because the slice may be backed by the
|
||||
# same memory in the object store, but the original plasma buffer may go
|
||||
# out of scope causing the memory to no longer be accessible.
|
||||
assert not isinstance(index, slice)
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
value = ord(value)
|
||||
self.buffer[index] = value
|
||||
def __getitem__(self, index):
|
||||
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
|
||||
# We currently don't allow slicing plasma buffers. We should handle
|
||||
# this better, but it requires some care because the slice may be
|
||||
# backed by the same memory in the object store, but the original
|
||||
# plasma buffer may go out of scope causing the memory to no longer be
|
||||
# accessible.
|
||||
assert not isinstance(index, slice)
|
||||
value = self.buffer[index]
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
value = chr(value)
|
||||
return value
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the buffer."""
|
||||
return len(self.buffer)
|
||||
def __setitem__(self, index, value):
|
||||
"""Write to the PlasmaBuffer as if it were just a regular buffer.
|
||||
|
||||
This should fail because the buffer should be read only.
|
||||
"""
|
||||
# We currently don't allow slicing plasma buffers. We should handle
|
||||
# this better, but it requires some care because the slice may be
|
||||
# backed by the same memory in the object store, but the original
|
||||
# plasma buffer may go out of scope causing the memory to no longer be
|
||||
# accessible.
|
||||
assert not isinstance(index, slice)
|
||||
if sys.version_info >= (3, 0) and not isinstance(index, slice):
|
||||
value = ord(value)
|
||||
self.buffer[index] = value
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the buffer."""
|
||||
return len(self.buffer)
|
||||
|
||||
|
||||
def buffers_equal(buff1, buff2):
|
||||
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
|
||||
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
|
||||
|
||||
This method should only be used in the tests. We implement a special helper
|
||||
method for doing this because doing comparisons by slicing is much faster,
|
||||
but we don't want to expose slicing of PlasmaBuffer objects because it
|
||||
currently is not safe.
|
||||
"""
|
||||
buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1
|
||||
buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2
|
||||
return buff1_to_compare[:] == buff2_to_compare[:]
|
||||
This method should only be used in the tests. We implement a special helper
|
||||
method for doing this because doing comparisons by slicing is much faster,
|
||||
but we don't want to expose slicing of PlasmaBuffer objects because it
|
||||
currently is not safe.
|
||||
"""
|
||||
buff1_to_compare = (buff1.buffer if isinstance(buff1, PlasmaBuffer)
|
||||
else buff1)
|
||||
buff2_to_compare = (buff2.buffer if isinstance(buff2, PlasmaBuffer)
|
||||
else buff2)
|
||||
return buff1_to_compare[:] == buff2_to_compare[:]
|
||||
|
||||
|
||||
class PlasmaClient(object):
|
||||
"""The PlasmaClient is used to interface with a plasma store and manager.
|
||||
"""The PlasmaClient is used to interface with a plasma store and manager.
|
||||
|
||||
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
|
||||
buffer, and get a buffer. Buffers are referred to by object IDs, which are
|
||||
strings.
|
||||
"""
|
||||
|
||||
def __init__(self, store_socket_name, manager_socket_name=None,
|
||||
release_delay=64):
|
||||
"""Initialize the PlasmaClient.
|
||||
|
||||
Args:
|
||||
store_socket_name (str): Name of the socket the plasma store is listening
|
||||
at.
|
||||
manager_socket_name (str): Name of the socket the plasma manager is
|
||||
listening at.
|
||||
release_delay (int): The maximum number of objects that the client will
|
||||
keep and delay releasing (for caching reasons).
|
||||
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
|
||||
buffer, and get a buffer. Buffers are referred to by object IDs, which are
|
||||
strings.
|
||||
"""
|
||||
self.store_socket_name = store_socket_name
|
||||
self.manager_socket_name = manager_socket_name
|
||||
self.alive = True
|
||||
|
||||
if manager_socket_name is not None:
|
||||
self.conn = libplasma.connect(store_socket_name, manager_socket_name,
|
||||
release_delay)
|
||||
else:
|
||||
self.conn = libplasma.connect(store_socket_name, "", release_delay)
|
||||
def __init__(self, store_socket_name, manager_socket_name=None,
|
||||
release_delay=64):
|
||||
"""Initialize the PlasmaClient.
|
||||
|
||||
def shutdown(self):
|
||||
"""Shutdown the client so that it does not send messages.
|
||||
Args:
|
||||
store_socket_name (str): Name of the socket the plasma store is
|
||||
listening at.
|
||||
manager_socket_name (str): Name of the socket the plasma manager is
|
||||
listening at.
|
||||
release_delay (int): The maximum number of objects that the client
|
||||
will keep and delay releasing (for caching reasons).
|
||||
"""
|
||||
self.store_socket_name = store_socket_name
|
||||
self.manager_socket_name = manager_socket_name
|
||||
self.alive = True
|
||||
|
||||
If we kill the Plasma store and Plasma manager that this client is
|
||||
connected to, then we can use this method to prevent the client from trying
|
||||
to send messages to the killed processes.
|
||||
"""
|
||||
if self.alive:
|
||||
libplasma.disconnect(self.conn)
|
||||
self.alive = False
|
||||
if manager_socket_name is not None:
|
||||
self.conn = libplasma.connect(store_socket_name,
|
||||
manager_socket_name,
|
||||
release_delay)
|
||||
else:
|
||||
self.conn = libplasma.connect(store_socket_name, "", release_delay)
|
||||
|
||||
def create(self, object_id, size, metadata=None):
|
||||
"""Create a new buffer in the PlasmaStore for a particular object ID.
|
||||
def shutdown(self):
|
||||
"""Shutdown the client so that it does not send messages.
|
||||
|
||||
The returned buffer is mutable until seal is called.
|
||||
If we kill the Plasma store and Plasma manager that this client is
|
||||
connected to, then we can use this method to prevent the client from
|
||||
trying to send messages to the killed processes.
|
||||
"""
|
||||
if self.alive:
|
||||
libplasma.disconnect(self.conn)
|
||||
self.alive = False
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
size (int): The size in bytes of the created buffer.
|
||||
metadata (buffer): An optional buffer encoding whatever metadata the user
|
||||
wishes to encode.
|
||||
def create(self, object_id, size, metadata=None):
|
||||
"""Create a new buffer in the PlasmaStore for a particular object ID.
|
||||
|
||||
Raises:
|
||||
plasma_object_exists_error: This exception is raised if the object could
|
||||
not be created because there already is an object with the same ID in
|
||||
the plasma store.
|
||||
plasma_out_of_memory_error: This exception is raised if the object could
|
||||
not be created because the plasma store is unable to evict enough
|
||||
objects to create room for it.
|
||||
"""
|
||||
# Turn the metadata into the right type.
|
||||
metadata = bytearray(b"") if metadata is None else metadata
|
||||
buff = libplasma.create(self.conn, object_id, size, metadata)
|
||||
return PlasmaBuffer(buff, object_id, self)
|
||||
The returned buffer is mutable until seal is called.
|
||||
|
||||
def get(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
size (int): The size in bytes of the created buffer.
|
||||
metadata (buffer): An optional buffer encoding whatever metadata the
|
||||
user wishes to encode.
|
||||
|
||||
If the object has not been sealed yet, this call will block. The retrieved
|
||||
buffer is immutable.
|
||||
Raises:
|
||||
plasma_object_exists_error: This exception is raised if the object
|
||||
could not be created because there already is an object with the
|
||||
same ID in the plasma store.
|
||||
plasma_out_of_memory_error: This exception is raised if the object
|
||||
could not be created because the plasma store is unable to evict
|
||||
enough objects to create room for it.
|
||||
"""
|
||||
# Turn the metadata into the right type.
|
||||
metadata = bytearray(b"") if metadata is None else metadata
|
||||
buff = libplasma.create(self.conn, object_id, size, metadata)
|
||||
return PlasmaBuffer(buff, object_id, self)
|
||||
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify some objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning. Pass -1 if the call should block
|
||||
and 0 if the call should return immediately.
|
||||
"""
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][0], object_ids[i], self))
|
||||
return returns
|
||||
def get(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
|
||||
def get_metadata(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
If the object has not been sealed yet, this call will block. The
|
||||
retrieved buffer is immutable.
|
||||
|
||||
If the object has not been sealed yet, this call will block until the
|
||||
object has been sealed. The retrieved buffer is immutable.
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify some
|
||||
objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning. Pass -1 if the call should
|
||||
block and 0 if the call should return immediately.
|
||||
"""
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][0], object_ids[i],
|
||||
self))
|
||||
return returns
|
||||
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify some objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning. Pass -1 if the call should block
|
||||
and 0 if the call should return immediately.
|
||||
"""
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][1], object_ids[i], self))
|
||||
return returns
|
||||
def get_metadata(self, object_ids, timeout_ms=-1):
|
||||
"""Create a buffer from the PlasmaStore based on object ID.
|
||||
|
||||
def contains(self, object_id):
|
||||
"""Check if the object is present and has been sealed in the PlasmaStore.
|
||||
If the object has not been sealed yet, this call will block until the
|
||||
object has been sealed. The retrieved buffer is immutable.
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
return libplasma.contains(self.conn, object_id)
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify some
|
||||
objects.
|
||||
timeout_ms (int): The number of milliseconds that the get call should
|
||||
block before timing out and returning. Pass -1 if the call should
|
||||
block and 0 if the call should return immediately.
|
||||
"""
|
||||
results = libplasma.get(self.conn, object_ids, timeout_ms)
|
||||
assert len(object_ids) == len(results)
|
||||
returns = []
|
||||
for i in range(len(object_ids)):
|
||||
if results[i] is None:
|
||||
returns.append(None)
|
||||
else:
|
||||
returns.append(PlasmaBuffer(results[i][1], object_ids[i],
|
||||
self))
|
||||
return returns
|
||||
|
||||
def hash(self, object_id):
|
||||
"""Compute the hash of an object in the object store.
|
||||
def contains(self, object_id):
|
||||
"""Check if the object is present and has been sealed.
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
return libplasma.contains(self.conn, object_id)
|
||||
|
||||
Returns:
|
||||
A digest string object's SHA256 hash. If the object isn't in the object
|
||||
store, the string will have length zero.
|
||||
"""
|
||||
return libplasma.hash(self.conn, object_id)
|
||||
def hash(self, object_id):
|
||||
"""Compute the hash of an object in the object store.
|
||||
|
||||
def seal(self, object_id):
|
||||
"""Seal the buffer in the PlasmaStore for a particular object ID.
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
|
||||
Once a buffer has been sealed, the buffer is immutable and can only be
|
||||
accessed through get.
|
||||
Returns:
|
||||
A digest string object's SHA256 hash. If the object isn't in the
|
||||
object store, the string will have length zero.
|
||||
"""
|
||||
return libplasma.hash(self.conn, object_id)
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
libplasma.seal(self.conn, object_id)
|
||||
def seal(self, object_id):
|
||||
"""Seal the buffer in the PlasmaStore for a particular object ID.
|
||||
|
||||
def delete(self, object_id):
|
||||
"""Delete the buffer in the PlasmaStore for a particular object ID.
|
||||
Once a buffer has been sealed, the buffer is immutable and can only be
|
||||
accessed through get.
|
||||
|
||||
Once a buffer has been deleted, the buffer is no longer accessible.
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
libplasma.seal(self.conn, object_id)
|
||||
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
libplasma.delete(self.conn, object_id)
|
||||
def delete(self, object_id):
|
||||
"""Delete the buffer in the PlasmaStore for a particular object ID.
|
||||
|
||||
def evict(self, num_bytes):
|
||||
"""Evict some objects until to recover some bytes.
|
||||
Once a buffer has been deleted, the buffer is no longer accessible.
|
||||
|
||||
Recover at least num_bytes bytes if possible.
|
||||
Args:
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
libplasma.delete(self.conn, object_id)
|
||||
|
||||
Args:
|
||||
num_bytes (int): The number of bytes to attempt to recover.
|
||||
"""
|
||||
return libplasma.evict(self.conn, num_bytes)
|
||||
def evict(self, num_bytes):
|
||||
"""Evict some objects until to recover some bytes.
|
||||
|
||||
def transfer(self, addr, port, object_id):
|
||||
"""Transfer local object with id object_id to another plasma instance
|
||||
Recover at least num_bytes bytes if possible.
|
||||
|
||||
Args:
|
||||
addr (str): IPv4 address of the plasma instance the object is sent to.
|
||||
port (int): Port number of the plasma instance the object is sent to.
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
return libplasma.transfer(self.conn, object_id, addr, port)
|
||||
Args:
|
||||
num_bytes (int): The number of bytes to attempt to recover.
|
||||
"""
|
||||
return libplasma.evict(self.conn, num_bytes)
|
||||
|
||||
def fetch(self, object_ids):
|
||||
"""Fetch the objects with the given IDs from other plasma manager instances.
|
||||
def transfer(self, addr, port, object_id):
|
||||
"""Transfer local object with id object_id to another plasma instance
|
||||
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify the objects.
|
||||
"""
|
||||
return libplasma.fetch(self.conn, object_ids)
|
||||
Args:
|
||||
addr (str): IPv4 address of the plasma instance the object is sent
|
||||
to.
|
||||
port (int): Port number of the plasma instance the object is sent to.
|
||||
object_id (str): A string used to identify an object.
|
||||
"""
|
||||
return libplasma.transfer(self.conn, object_id, addr, port)
|
||||
|
||||
def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1):
|
||||
"""Wait until num_returns objects in object_ids are ready.
|
||||
def fetch(self, object_ids):
|
||||
"""Fetch the objects with the given IDs from other plasma managers.
|
||||
|
||||
Currently, the object ID arguments to wait must be unique.
|
||||
Args:
|
||||
object_ids (List[str]): A list of strings used to identify the
|
||||
objects.
|
||||
"""
|
||||
return libplasma.fetch(self.conn, object_ids)
|
||||
|
||||
Args:
|
||||
object_ids (List[str]): List of object IDs to wait for.
|
||||
timeout (int): Return to the caller after timeout milliseconds.
|
||||
num_returns (int): We are waiting for this number of objects to be ready.
|
||||
def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1):
|
||||
"""Wait until num_returns objects in object_ids are ready.
|
||||
|
||||
Returns:
|
||||
ready_ids, waiting_ids (List[str], List[str]): List of object IDs that
|
||||
are ready and list of object IDs we might still wait on respectively.
|
||||
"""
|
||||
# Check that the object ID arguments are unique. The plasma manager
|
||||
# currently crashes if given duplicate object IDs.
|
||||
if len(object_ids) != len(set(object_ids)):
|
||||
raise Exception("Wait requires a list of unique object IDs.")
|
||||
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
|
||||
num_returns)
|
||||
return ready_ids, list(waiting_ids)
|
||||
Currently, the object ID arguments to wait must be unique.
|
||||
|
||||
def subscribe(self):
|
||||
"""Subscribe to notifications about sealed objects."""
|
||||
self.notification_fd = libplasma.subscribe(self.conn)
|
||||
Args:
|
||||
object_ids (List[str]): List of object IDs to wait for.
|
||||
timeout (int): Return to the caller after timeout milliseconds.
|
||||
num_returns (int): We are waiting for this number of objects to be
|
||||
ready.
|
||||
|
||||
def get_next_notification(self):
|
||||
"""Get the next notification from the notification socket."""
|
||||
return libplasma.receive_notification(self.notification_fd)
|
||||
Returns:
|
||||
ready_ids, waiting_ids (List[str], List[str]): List of object IDs
|
||||
that are ready and list of object IDs we might still wait on
|
||||
respectively.
|
||||
"""
|
||||
# Check that the object ID arguments are unique. The plasma manager
|
||||
# currently crashes if given duplicate object IDs.
|
||||
if len(object_ids) != len(set(object_ids)):
|
||||
raise Exception("Wait requires a list of unique object IDs.")
|
||||
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
|
||||
num_returns)
|
||||
return ready_ids, list(waiting_ids)
|
||||
|
||||
def subscribe(self):
|
||||
"""Subscribe to notifications about sealed objects."""
|
||||
self.notification_fd = libplasma.subscribe(self.conn)
|
||||
|
||||
def get_next_notification(self):
|
||||
"""Get the next notification from the notification socket."""
|
||||
return libplasma.receive_notification(self.notification_fd)
|
||||
|
||||
|
||||
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
|
||||
|
||||
|
||||
def random_name():
    """Return a random decimal string.

    The result is used as the suffix of the "/tmp/plasma_store..." and
    "/tmp/plasma_manager..." socket names built elsewhere in this file.

    Returns:
        str: The decimal representation of an integer drawn from
            random.randint(0, 99999999).
    """
    return str(random.randint(0, 99999999))
|
||||
|
||||
|
||||
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
                       use_valgrind=False, use_profiler=False,
                       stdout_file=None, stderr_file=None):
    """Start a plasma store process.

    Args:
        plasma_store_memory (int): Value passed to the store executable's
            "-m" option. Presumably the store capacity in bytes; confirm
            against the plasma_store binary.
        use_valgrind (bool): True if the plasma store should be started
            inside of valgrind. If this is True, use_profiler must be False.
        use_profiler (bool): True if the plasma store should be started
            inside a profiler. If this is True, use_valgrind must be False.
        stdout_file: A file handle opened for writing to redirect stdout to.
            If no redirection should happen, then this should be None.
        stderr_file: A file handle opened for writing to redirect stderr to.
            If no redirection should happen, then this should be None.

    Returns:
        A tuple of the name of the plasma store socket and the
        subprocess.Popen object for the plasma store process (a process
        handle, not a numeric PID).

    Raises:
        Exception: If use_valgrind and use_profiler are both True.
    """
    if use_valgrind and use_profiler:
        raise Exception("Cannot use valgrind and profiler at the same time.")
    plasma_store_executable = os.path.join(
        os.path.abspath(os.path.dirname(__file__)),
        "../core/src/plasma/plasma_store")
    plasma_store_name = "/tmp/plasma_store{}".format(random_name())
    command = [plasma_store_executable,
               "-s", plasma_store_name,
               "-m", str(plasma_store_memory)]
    if use_valgrind:
        pid = subprocess.Popen(["valgrind",
                                "--track-origins=yes",
                                "--leak-check=full",
                                "--show-leak-kinds=all",
                                "--leak-check-heuristics=stdstring",
                                "--error-exitcode=1"] + command,
                               stdout=stdout_file, stderr=stderr_file)
        # NOTE(review): presumably gives the (slower) valgrind-wrapped store
        # time to come up before callers connect -- confirm.
        time.sleep(1.0)
    elif use_profiler:
        pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
                               stdout=stdout_file, stderr=stderr_file)
        time.sleep(1.0)
    else:
        pid = subprocess.Popen(command, stdout=stdout_file,
                               stderr=stderr_file)
        time.sleep(0.1)
    return plasma_store_name, pid
|
||||
|
||||
|
||||
def new_port():
    """Return a random port number.

    Returns:
        int: An integer drawn from random.randint(10000, 65535). Nothing
            here checks whether the port is actually free.
    """
    return random.randint(10000, 65535)
|
||||
|
||||
|
||||
def start_plasma_manager(store_name, redis_address,
|
||||
@@ -364,69 +377,71 @@ def start_plasma_manager(store_name, redis_address,
|
||||
num_retries=20, use_valgrind=False,
|
||||
run_profiler=False, stdout_file=None,
|
||||
stderr_file=None):
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
"""Start a plasma manager and return the ports it listens on.
|
||||
|
||||
Args:
|
||||
store_name (str): The name of the plasma store socket.
|
||||
redis_address (str): The address of the Redis server.
|
||||
node_ip_address (str): The IP address of the node.
|
||||
plasma_manager_port (int): The port to use for the plasma manager. If this
|
||||
is not provided, a port will be generated at random.
|
||||
use_valgrind (bool): True if the Plasma manager should be started inside of
|
||||
valgrind and False otherwise.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If no
|
||||
redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If no
|
||||
redirection should happen, then this should be None.
|
||||
Args:
|
||||
store_name (str): The name of the plasma store socket.
|
||||
redis_address (str): The address of the Redis server.
|
||||
node_ip_address (str): The IP address of the node.
|
||||
plasma_manager_port (int): The port to use for the plasma manager. If
|
||||
this is not provided, a port will be generated at random.
|
||||
use_valgrind (bool): True if the Plasma manager should be started inside
|
||||
of valgrind and False otherwise.
|
||||
stdout_file: A file handle opened for writing to redirect stdout to. If
|
||||
no redirection should happen, then this should be None.
|
||||
stderr_file: A file handle opened for writing to redirect stderr to. If
|
||||
no redirection should happen, then this should be None.
|
||||
|
||||
Returns:
|
||||
A tuple of the Plasma manager socket name, the process ID of the Plasma
|
||||
manager process, and the port that the manager is listening on.
|
||||
Returns:
|
||||
A tuple of the Plasma manager socket name, the process ID of the Plasma
|
||||
manager process, and the port that the manager is listening on.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the manager could not be started.
|
||||
"""
|
||||
plasma_manager_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_manager")
|
||||
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
|
||||
if plasma_manager_port is not None:
|
||||
if num_retries != 1:
|
||||
raise Exception("num_retries must be 1 if port is specified.")
|
||||
else:
|
||||
plasma_manager_port = new_port()
|
||||
process = None
|
||||
counter = 0
|
||||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Plasma manager failed to start, retrying now.")
|
||||
command = [plasma_manager_executable,
|
||||
"-s", store_name,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
Raises:
|
||||
Exception: An exception is raised if the manager could not be started.
|
||||
"""
|
||||
plasma_manager_executable = os.path.join(
|
||||
os.path.abspath(os.path.dirname(__file__)),
|
||||
"../core/src/plasma/plasma_manager")
|
||||
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
|
||||
if plasma_manager_port is not None:
|
||||
if num_retries != 1:
|
||||
raise Exception("num_retries must be 1 if port is specified.")
|
||||
else:
|
||||
process = subprocess.Popen(command, stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
# This sleep is critical. If the plasma_manager fails to start because the
|
||||
# port is already in use, then we need it to fail within 0.1 seconds.
|
||||
time.sleep(0.1)
|
||||
# See if the process has terminated
|
||||
if process.poll() is None:
|
||||
return plasma_manager_name, process, plasma_manager_port
|
||||
# Generate a new port and try again.
|
||||
plasma_manager_port = new_port()
|
||||
counter += 1
|
||||
raise Exception("Couldn't start plasma manager.")
|
||||
plasma_manager_port = new_port()
|
||||
process = None
|
||||
counter = 0
|
||||
while counter < num_retries:
|
||||
if counter > 0:
|
||||
print("Plasma manager failed to start, retrying now.")
|
||||
command = [plasma_manager_executable,
|
||||
"-s", store_name,
|
||||
"-m", plasma_manager_name,
|
||||
"-h", node_ip_address,
|
||||
"-p", str(plasma_manager_port),
|
||||
"-r", redis_address,
|
||||
]
|
||||
if use_valgrind:
|
||||
process = subprocess.Popen(["valgrind",
|
||||
"--track-origins=yes",
|
||||
"--leak-check=full",
|
||||
"--show-leak-kinds=all",
|
||||
"--error-exitcode=1"] + command,
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
elif run_profiler:
|
||||
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
|
||||
command),
|
||||
stdout=stdout_file, stderr=stderr_file)
|
||||
else:
|
||||
process = subprocess.Popen(command, stdout=stdout_file,
|
||||
stderr=stderr_file)
|
||||
# This sleep is critical. If the plasma_manager fails to start because
|
||||
# the port is already in use, then we need it to fail within 0.1
|
||||
# seconds.
|
||||
time.sleep(0.1)
|
||||
# See if the process has terminated
|
||||
if process.poll() is None:
|
||||
return plasma_manager_name, process, plasma_manager_port
|
||||
# Generate a new port and try again.
|
||||
plasma_manager_port = new_port()
|
||||
counter += 1
|
||||
raise Exception("Couldn't start plasma manager.")
|
||||
|
||||
+791
-753
File diff suppressed because it is too large
Load Diff
+25
-23
@@ -7,39 +7,41 @@ import random
|
||||
|
||||
|
||||
def random_object_id():
    """Return a random object ID.

    Returns:
        bytes: 20 random bytes produced by np.random.bytes.
    """
    return np.random.bytes(20)
|
||||
|
||||
|
||||
def generate_metadata(length):
    """Create a metadata buffer of the given length with random content.

    The buffer starts out zero-filled. For a non-empty buffer the first and
    last bytes are randomized, and then 100 randomly chosen positions are
    overwritten with random byte values.

    Args:
        length (int): The number of bytes in the returned buffer.

    Returns:
        bytearray: A buffer of `length` bytes. It is empty (and no random
            writes happen) when `length` is 0.
    """
    metadata_buffer = bytearray(length)
    if length > 0:
        metadata_buffer[0] = random.randint(0, 255)
        metadata_buffer[-1] = random.randint(0, 255)
        # The random fill must stay inside the guard: randint(0, length - 1)
        # is invalid for an empty buffer.
        for _ in range(100):
            metadata_buffer[random.randint(0, length - 1)] = (
                random.randint(0, 255))
    return metadata_buffer
|
||||
|
||||
|
||||
def write_to_data_buffer(buff, length):
    """Write random one-character values into a buffer in place.

    For a non-empty buffer the first and last positions are set, and then
    100 randomly chosen positions in [0, length - 1] are overwritten.

    NOTE(review): values are written as chr(...) one-character strings, so
    `buff` must accept such assignments (e.g. a Python 2 writable buffer).
    A Python 3 bytearray/memoryview would need ints instead -- confirm the
    target interpreter before changing this.

    Args:
        buff: A mutable, indexable buffer to write into.
        length (int): The number of valid positions in `buff`. Nothing is
            written when this is 0.
    """
    if length > 0:
        buff[0] = chr(random.randint(0, 255))
        buff[-1] = chr(random.randint(0, 255))
        for _ in range(100):
            buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
|
||||
|
||||
|
||||
def create_object_with_id(client, object_id, data_size, metadata_size,
                          seal=True):
    """Create an object with random contents under a given object ID.

    Args:
        client: The plasma client used to create (and optionally seal) the
            object; it must provide create() and seal().
        object_id: The ID to create the object under.
        data_size (int): The size of the data buffer to create.
        metadata_size (int): The size of the random metadata to generate.
        seal (bool): If True, seal the object after writing to it.

    Returns:
        A tuple of the buffer returned by client.create and the generated
        metadata bytearray.
    """
    metadata = generate_metadata(metadata_size)
    memory_buffer = client.create(object_id, data_size, metadata)
    write_to_data_buffer(memory_buffer, data_size)
    if seal:
        client.seal(object_id)
    return memory_buffer, metadata
|
||||
|
||||
|
||||
def create_object(client, data_size, metadata_size, seal=True):
    """Create an object with random contents under a fresh random ID.

    Args:
        client: The plasma client forwarded to create_object_with_id.
        data_size (int): The size of the data buffer to create.
        metadata_size (int): The size of the random metadata to generate.
        seal (bool): If True, seal the object after writing to it.

    Returns:
        A tuple of the generated object ID, the object's data buffer and the
        generated metadata.
    """
    object_id = random_object_id()
    memory_buffer, metadata = create_object_with_id(client, object_id,
                                                    data_size, metadata_size,
                                                    seal=seal)
    return object_id, memory_buffer, metadata
|
||||
|
||||
@@ -16,97 +16,98 @@ use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
|
||||
|
||||
|
||||
class LSTMPolicy(Policy):
    """Policy built from four conv layers, an LSTM and two linear heads.

    The "action" head produces the logits actions are sampled from and the
    "value" head produces the value estimate.
    """

    def setup_graph(self, ob_space, ac_space):
        """Setup model used for Policy.

        In this A3C implementation, both the Critic and the Actor share the
        model.

        Args:
            ob_space: Observation shape; list(ob_space) is appended to a
                leading None dimension to form the input placeholder shape.
            ac_space: Passed as the output size of the "action" linear layer
                and to categorical_sample.
        """
        self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))

        for i in range(4):
            x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
        # Introduce a "fake" batch dimension of 1 after flatten so that we can
        # do LSTM over the time dim.
        x = tf.expand_dims(flatten(x), [0])

        size = 256
        # The LSTM cell classes are reached through a different path
        # depending on the installed TensorFlow version.
        if use_tf100_api:
            lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
        else:
            lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
        self.state_size = lstm.state_size
        step_size = tf.shape(self.x)[:1]

        c_init = np.zeros((1, lstm.state_size.c), np.float32)
        h_init = np.zeros((1, lstm.state_size.h), np.float32)
        self.state_init = [c_init, h_init]
        c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
        h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
        self.state_in = [c_in, h_in]

        if use_tf100_api:
            state_in = rnn.LSTMStateTuple(c_in, h_in)
        else:
            state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
        lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
            lstm, x, initial_state=state_in, sequence_length=step_size,
            time_major=False)
        lstm_c, lstm_h = lstm_state
        x = tf.reshape(lstm_outputs, [-1, size])
        self.logits = linear(x, ac_space, "action",
                             normalized_columns_initializer(0.01))
        self.vf = tf.reshape(linear(x, 1, "value",
                                    normalized_columns_initializer(1.0)), [-1])
        self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
        self.sample = categorical_sample(self.logits, ac_space)[0, :]
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)
        self.global_step = tf.get_variable(
            "global_step", [], tf.int32,
            initializer=tf.constant_initializer(0, dtype=tf.int32),
            trainable=False)

    def get_gradients(self, batch):
        """Computing the gradient is actually model-dependent.

        The LSTM needs its hidden states in order to compute the gradient
        accurately.

        Args:
            batch: Object providing si, a, adv, r and features (the two LSTM
                input states), which are fed to the graph placeholders.

        Returns:
            The result of evaluating self.grads for the batch.
        """
        feed_dict = {
            self.x: batch.si,
            self.ac: batch.a,
            self.adv: batch.adv,
            self.r: batch.r,
            self.state_in[0]: batch.features[0],
            self.state_in[1]: batch.features[1]
        }
        self.local_steps += 1
        return self.sess.run(self.grads, feed_dict=feed_dict)

    def act(self, ob, c, h):
        """Evaluate the policy for a single observation.

        Args:
            ob: The observation; it is wrapped in a one-element list before
                being fed to the input placeholder.
            c: Value fed to the first LSTM state placeholder.
            h: Value fed to the second LSTM state placeholder.

        Returns:
            The list returned by sess.run for the sampled action, the value
            estimate and the two output LSTM states, in that order.
        """
        return self.sess.run([self.sample, self.vf] + self.state_out,
                             {self.x: [ob],
                              self.state_in[0]: c,
                              self.state_in[1]: h})

    def value(self, ob, c, h):
        """Return the value estimate for a single observation.

        Args:
            ob: The observation; it is wrapped in a one-element list before
                being fed to the input placeholder.
            c: Value fed to the first LSTM state placeholder.
            h: Value fed to the second LSTM state placeholder.

        Returns:
            The first element of the evaluated value head.
        """
        return self.sess.run(self.vf, {self.x: [ob],
                                       self.state_in[0]: c,
                                       self.state_in[1]: h})[0]

    def get_initial_features(self):
        """Return the zero-initialized [c, h] LSTM state from setup_graph."""
        return self.state_init
|
||||
|
||||
|
||||
class RawLSTMPolicy(LSTMPolicy):
    """LSTMPolicy variant that keeps its weights in a local dict.

    The weights are cached in self._weights and updated directly with a
    fixed step size instead of going through the graph.
    """

    def get_weights(self):
        """Return the cached weights, fetching them on first use.

        Returns:
            The cached weights; on the first call they are read from
            self.variables.get_weights().
        """
        if not hasattr(self, "_weights"):
            self._weights = self.variables.get_weights()
        return self._weights

    def set_weights(self, weights):
        """Replace the cached weights with `weights`."""
        self._weights = weights

    def model_update(self, grads):
        """Apply one gradient step of size 1e-4 to the cached weights.

        Args:
            grads: Gradients, paired positionally with self.var_list.
        """
        # Go through get_weights() so the cache is populated even if neither
        # get_weights nor set_weights has been called yet.
        weights = self.get_weights()
        for var, grad in zip(self.var_list, grads):
            # var.name[:-2] drops the last two characters of the variable
            # name -- presumably the ":0" output suffix; confirm against the
            # keys used by self.variables.get_weights().
            weights[var.name[:-2]] -= 1e-4 * grad
|
||||
|
||||
+97
-96
@@ -22,108 +22,109 @@ DEFAULT_CONFIG = {
|
||||
|
||||
@ray.remote
class Runner(object):
    """Actor object to start running simulation on workers.

    The gradient computation is also executed from this object.
    """

    def __init__(self, env_name, actor_id, logdir, start=True):
        """Create the environment, the policy and the runner thread.

        Args:
            env_name: Name passed to create_env.
            actor_id: Identifier of this actor; also passed to the policy
                and used in the summary directory name.
            logdir: Directory under which summaries are written.
            start (bool): If True, call start() immediately.
        """
        env = create_env(env_name)
        self.id = actor_id
        num_actions = env.action_space.n
        self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
                                 actor_id)
        self.runner = RunnerThread(env, self.policy, 20)
        self.env = env
        self.logdir = logdir
        if start:
            self.start()

    def pull_batch_from_queue(self):
        """Take a rollout from the queue of the thread runner.

        Blocks for up to 600 seconds for the first item, then keeps
        extending the rollout with whatever is already queued until the
        rollout is terminal or the queue is empty.

        Returns:
            The (possibly extended) rollout.

        Raises:
            BaseException: Any exception object found on the queue is
                re-raised here.
            queue.Empty: If no first item arrives within the timeout.
        """
        rollout = self.runner.queue.get(timeout=600.0)
        if isinstance(rollout, BaseException):
            raise rollout
        while not rollout.terminal:
            try:
                part = self.runner.queue.get_nowait()
            except queue.Empty:
                break
            if isinstance(part, BaseException):
                # Raise the exception that was actually dequeued; `rollout`
                # is known not to be an exception at this point, so raising
                # it would only produce a TypeError.
                raise part
            rollout.extend(part)
        return rollout

    def get_completed_rollout_metrics(self):
        """Returns metrics on previously completed rollouts.

        Calling this clears the queue of completed rollout metrics.
        """
        completed = []
        while True:
            try:
                completed.append(self.runner.metrics_queue.get_nowait())
            except queue.Empty:
                break
        return completed

    def start(self):
        """Create the summary writer and start the runner thread."""
        summary_writer = tf.summary.FileWriter(
            os.path.join(self.logdir, "agent_%d" % self.id))
        self.summary_writer = summary_writer
        self.runner.start_runner(self.policy.sess, summary_writer)

    def compute_gradient(self, params):
        """Compute a gradient from the next rollout using the given weights.

        Args:
            params: Weights to load into the policy before computing.

        Returns:
            A tuple of the gradient and an info dict with this actor's "id"
            and the batch "size" (len(batch.a)).
        """
        self.policy.set_weights(params)
        rollout = self.pull_batch_from_queue()
        batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
        gradient = self.policy.get_gradients(batch)
        info = {"id": self.id,
                "size": len(batch.a)}
        return gradient, info
|
||||
|
||||
|
||||
class A3C(Algorithm):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "A3C"})
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
self.env = create_env(env_name)
|
||||
self.policy = LSTMPolicy(
|
||||
self.env.observation_space.shape, self.env.action_space.n, 0)
|
||||
self.agents = [
|
||||
Runner.remote(env_name, i, self.logdir)
|
||||
for i in range(config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
self.iteration = 0
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "A3C"})
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
self.env = create_env(env_name)
|
||||
self.policy = LSTMPolicy(
|
||||
self.env.observation_space.shape, self.env.action_space.n, 0)
|
||||
self.agents = [
|
||||
Runner.remote(env_name, i, self.logdir)
|
||||
for i in range(config["num_workers"])]
|
||||
self.parameters = self.policy.get_weights()
|
||||
self.iteration = 0
|
||||
|
||||
def train(self):
|
||||
gradient_list = [
|
||||
agent.compute_gradient.remote(self.parameters)
|
||||
for agent in self.agents]
|
||||
max_batches = self.config["num_batches_per_iteration"]
|
||||
batches_so_far = len(gradient_list)
|
||||
while gradient_list:
|
||||
done_id, gradient_list = ray.wait(gradient_list)
|
||||
gradient, info = ray.get(done_id)[0]
|
||||
self.policy.model_update(gradient)
|
||||
self.parameters = self.policy.get_weights()
|
||||
if batches_so_far < max_batches:
|
||||
batches_so_far += 1
|
||||
gradient_list.extend(
|
||||
[self.agents[info["id"]].compute_gradient.remote(self.parameters)])
|
||||
res = self.fetch_metrics_from_workers()
|
||||
self.iteration += 1
|
||||
return res
|
||||
def train(self):
|
||||
gradient_list = [
|
||||
agent.compute_gradient.remote(self.parameters)
|
||||
for agent in self.agents]
|
||||
max_batches = self.config["num_batches_per_iteration"]
|
||||
batches_so_far = len(gradient_list)
|
||||
while gradient_list:
|
||||
done_id, gradient_list = ray.wait(gradient_list)
|
||||
gradient, info = ray.get(done_id)[0]
|
||||
self.policy.model_update(gradient)
|
||||
self.parameters = self.policy.get_weights()
|
||||
if batches_so_far < max_batches:
|
||||
batches_so_far += 1
|
||||
gradient_list.extend(
|
||||
[self.agents[info["id"]].compute_gradient.remote(
|
||||
self.parameters)])
|
||||
res = self.fetch_metrics_from_workers()
|
||||
self.iteration += 1
|
||||
return res
|
||||
|
||||
def fetch_metrics_from_workers(self):
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
metric_lists = [
|
||||
a.get_completed_rollout_metrics.remote() for a in self.agents]
|
||||
for metrics in metric_lists:
|
||||
for episode in ray.get(metrics):
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.iteration,
|
||||
np.mean(episode_rewards), np.mean(episode_lengths), dict())
|
||||
return res
|
||||
def fetch_metrics_from_workers(self):
|
||||
episode_rewards = []
|
||||
episode_lengths = []
|
||||
metric_lists = [
|
||||
a.get_completed_rollout_metrics.remote() for a in self.agents]
|
||||
for metrics in metric_lists:
|
||||
for episode in ray.get(metrics):
|
||||
episode_lengths.append(episode.episode_length)
|
||||
episode_rewards.append(episode.episode_reward)
|
||||
res = TrainingResult(
|
||||
self.experiment_id.hex, self.iteration,
|
||||
np.mean(episode_rewards), np.mean(episode_lengths), dict())
|
||||
return res
|
||||
|
||||
@@ -14,94 +14,96 @@ logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_env(env_id):
|
||||
env = gym.make(env_id)
|
||||
env = AtariProcessing(env)
|
||||
env = Diagnostic(env)
|
||||
return env
|
||||
env = gym.make(env_id)
|
||||
env = AtariProcessing(env)
|
||||
env = Diagnostic(env)
|
||||
return env
|
||||
|
||||
|
||||
def _process_frame42(frame):
|
||||
frame = frame[34:(34 + 160), :160]
|
||||
# Resize by half, then down to 42x42 (essentially mipmapping). If we resize
|
||||
# directly we lose pixels that, when mapped to 42x42, aren't close enough to
|
||||
# the pixel boundary.
|
||||
frame = cv2.resize(frame, (80, 80))
|
||||
frame = cv2.resize(frame, (42, 42))
|
||||
frame = frame.mean(2)
|
||||
frame = frame.astype(np.float32)
|
||||
frame *= (1.0 / 255.0)
|
||||
frame = np.reshape(frame, [42, 42, 1])
|
||||
return frame
|
||||
frame = frame[34:(34 + 160), :160]
|
||||
# Resize by half, then down to 42x42 (essentially mipmapping). If we resize
|
||||
# directly we lose pixels that, when mapped to 42x42, aren't close enough
|
||||
# to the pixel boundary.
|
||||
frame = cv2.resize(frame, (80, 80))
|
||||
frame = cv2.resize(frame, (42, 42))
|
||||
frame = frame.mean(2)
|
||||
frame = frame.astype(np.float32)
|
||||
frame *= (1.0 / 255.0)
|
||||
frame = np.reshape(frame, [42, 42, 1])
|
||||
return frame
|
||||
|
||||
|
||||
class AtariProcessing(gym.ObservationWrapper):
|
||||
def __init__(self, env=None):
|
||||
super(AtariProcessing, self).__init__(env)
|
||||
self.observation_space = Box(0.0, 1.0, [42, 42, 1])
|
||||
def __init__(self, env=None):
|
||||
super(AtariProcessing, self).__init__(env)
|
||||
self.observation_space = Box(0.0, 1.0, [42, 42, 1])
|
||||
|
||||
def _observation(self, observation):
|
||||
return _process_frame42(observation)
|
||||
def _observation(self, observation):
|
||||
return _process_frame42(observation)
|
||||
|
||||
|
||||
class Diagnostic(gym.Wrapper):
|
||||
def __init__(self, env=None):
|
||||
super(Diagnostic, self).__init__(env)
|
||||
self.diagnostics = DiagnosticsLogger()
|
||||
def __init__(self, env=None):
|
||||
super(Diagnostic, self).__init__(env)
|
||||
self.diagnostics = DiagnosticsLogger()
|
||||
|
||||
def _reset(self):
|
||||
observation = self.env.reset()
|
||||
return self.diagnostics._after_reset(observation)
|
||||
def _reset(self):
|
||||
observation = self.env.reset()
|
||||
return self.diagnostics._after_reset(observation)
|
||||
|
||||
def _step(self, action):
|
||||
results = self.env.step(action)
|
||||
return self.diagnostics._after_step(*results)
|
||||
def _step(self, action):
|
||||
results = self.env.step(action)
|
||||
return self.diagnostics._after_step(*results)
|
||||
|
||||
|
||||
class DiagnosticsLogger(object):
|
||||
def __init__(self, log_interval=503):
|
||||
self._episode_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._local_t = 0
|
||||
self._log_interval = log_interval
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
self._last_episode_id = -1
|
||||
def __init__(self, log_interval=503):
|
||||
self._episode_time = time.time()
|
||||
self._last_time = time.time()
|
||||
self._local_t = 0
|
||||
self._log_interval = log_interval
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
self._last_episode_id = -1
|
||||
|
||||
def _after_reset(self, observation):
|
||||
logger.info("Resetting environment")
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
return observation
|
||||
def _after_reset(self, observation):
|
||||
logger.info("Resetting environment")
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
return observation
|
||||
|
||||
def _after_step(self, observation, reward, done, info):
|
||||
to_log = {}
|
||||
if self._episode_length == 0:
|
||||
self._episode_time = time.time()
|
||||
def _after_step(self, observation, reward, done, info):
|
||||
to_log = {}
|
||||
if self._episode_length == 0:
|
||||
self._episode_time = time.time()
|
||||
|
||||
self._local_t += 1
|
||||
self._local_t += 1
|
||||
|
||||
if self._local_t % self._log_interval == 0:
|
||||
cur_time = time.time()
|
||||
self._last_time = cur_time
|
||||
if self._local_t % self._log_interval == 0:
|
||||
cur_time = time.time()
|
||||
self._last_time = cur_time
|
||||
|
||||
if reward is not None:
|
||||
self._episode_reward += reward
|
||||
if observation is not None:
|
||||
self._episode_length += 1
|
||||
self._all_rewards.append(reward)
|
||||
if reward is not None:
|
||||
self._episode_reward += reward
|
||||
if observation is not None:
|
||||
self._episode_length += 1
|
||||
self._all_rewards.append(reward)
|
||||
|
||||
if done:
|
||||
logger.info("Episode terminating: episode_reward=%s episode_length=%s",
|
||||
self._episode_reward, self._episode_length)
|
||||
total_time = time.time() - self._episode_time
|
||||
to_log["global/episode_reward"] = self._episode_reward
|
||||
to_log["global/episode_length"] = self._episode_length
|
||||
to_log["global/episode_time"] = total_time
|
||||
to_log["global/reward_per_time"] = self._episode_reward / total_time
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
if done:
|
||||
logger.info("Episode terminating: episode_reward=%s "
|
||||
"episode_length=%s",
|
||||
self._episode_reward, self._episode_length)
|
||||
total_time = time.time() - self._episode_time
|
||||
to_log["global/episode_reward"] = self._episode_reward
|
||||
to_log["global/episode_length"] = self._episode_length
|
||||
to_log["global/episode_time"] = total_time
|
||||
to_log["global/reward_per_time"] = (self._episode_reward /
|
||||
total_time)
|
||||
self._episode_reward = 0
|
||||
self._episode_length = 0
|
||||
self._all_rewards = []
|
||||
|
||||
return observation, reward, done, to_log
|
||||
return observation, reward, done, to_log
|
||||
|
||||
@@ -11,22 +11,22 @@ from ray.rllib.a3c import A3C, DEFAULT_CONFIG
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
|
||||
parser.add_argument("--environment", default="PongDeterministic-v3",
|
||||
type=str, help="The gym environment to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--num-workers", default=4, type=int,
|
||||
help="The number of A3C workers to use>")
|
||||
parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
|
||||
parser.add_argument("--environment", default="PongDeterministic-v3",
|
||||
type=str, help="The gym environment to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--num-workers", default=4, type=int,
|
||||
help="The number of A3C workers to use>")
|
||||
|
||||
args = parser.parse_args()
|
||||
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)
|
||||
args = parser.parse_args()
|
||||
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)
|
||||
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = args.num_workers
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["num_workers"] = args.num_workers
|
||||
|
||||
a3c = A3C(args.environment, config)
|
||||
a3c = A3C(args.environment, config)
|
||||
|
||||
while True:
|
||||
res = a3c.train()
|
||||
print("current status: {}".format(res))
|
||||
while True:
|
||||
res = a3c.train()
|
||||
print("current status: {}".format(res))
|
||||
|
||||
+101
-100
@@ -8,138 +8,139 @@ import ray
|
||||
|
||||
|
||||
class Policy(object):
|
||||
"""The policy base class."""
|
||||
def __init__(self, ob_space, ac_space, task, name="local"):
|
||||
self.local_steps = 0
|
||||
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
self.g = tf.Graph()
|
||||
with self.g.as_default(), tf.device(worker_device):
|
||||
with tf.variable_scope(name):
|
||||
self.setup_graph(ob_space, ac_space)
|
||||
assert all([hasattr(self, attr)
|
||||
for attr in ["vf", "logits", "x", "var_list"]])
|
||||
print("Setting up loss")
|
||||
self.setup_loss(ac_space)
|
||||
self.initialize()
|
||||
"""The policy base class."""
|
||||
def __init__(self, ob_space, ac_space, task, name="local"):
|
||||
self.local_steps = 0
|
||||
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
|
||||
self.g = tf.Graph()
|
||||
with self.g.as_default(), tf.device(worker_device):
|
||||
with tf.variable_scope(name):
|
||||
self.setup_graph(ob_space, ac_space)
|
||||
assert all([hasattr(self, attr)
|
||||
for attr in ["vf", "logits", "x", "var_list"]])
|
||||
print("Setting up loss")
|
||||
self.setup_loss(ac_space)
|
||||
self.initialize()
|
||||
|
||||
def setup_graph(self):
|
||||
raise NotImplementedError
|
||||
def setup_graph(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_loss(self, num_actions, summarize=True):
|
||||
self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
|
||||
self.adv = tf.placeholder(tf.float32, [None], name="adv")
|
||||
self.r = tf.placeholder(tf.float32, [None], name="r")
|
||||
def setup_loss(self, num_actions, summarize=True):
|
||||
self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
|
||||
self.adv = tf.placeholder(tf.float32, [None], name="adv")
|
||||
self.r = tf.placeholder(tf.float32, [None], name="r")
|
||||
|
||||
log_prob_tf = tf.nn.log_softmax(self.logits)
|
||||
prob_tf = tf.nn.softmax(self.logits)
|
||||
log_prob_tf = tf.nn.log_softmax(self.logits)
|
||||
prob_tf = tf.nn.softmax(self.logits)
|
||||
|
||||
# The "policy gradients" loss: its derivative is precisely the policy
|
||||
# gradient. Notice that self.ac is a placeholder that is provided
|
||||
# externally. adv will contain the advantages, as calculated in
|
||||
# process_rollout.
|
||||
pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
|
||||
[1]) * self.adv)
|
||||
# The "policy gradients" loss: its derivative is precisely the policy
|
||||
# gradient. Notice that self.ac is a placeholder that is provided
|
||||
# externally. adv will contain the advantages, as calculated in
|
||||
# process_rollout.
|
||||
pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
|
||||
[1]) * self.adv)
|
||||
|
||||
# loss of value function
|
||||
vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
|
||||
vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
|
||||
entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
|
||||
# loss of value function
|
||||
vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
|
||||
vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
|
||||
entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
|
||||
|
||||
bs = tf.to_float(tf.shape(self.x)[0])
|
||||
self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
|
||||
bs = tf.to_float(tf.shape(self.x)[0])
|
||||
self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
|
||||
|
||||
grads = tf.gradients(self.loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
|
||||
grads = tf.gradients(self.loss, self.var_list)
|
||||
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
|
||||
|
||||
grads_and_vars = list(zip(self.grads, self.var_list))
|
||||
opt = tf.train.AdamOptimizer(1e-4)
|
||||
self._apply_gradients = opt.apply_gradients(grads_and_vars)
|
||||
grads_and_vars = list(zip(self.grads, self.var_list))
|
||||
opt = tf.train.AdamOptimizer(1e-4)
|
||||
self._apply_gradients = opt.apply_gradients(grads_and_vars)
|
||||
|
||||
if summarize:
|
||||
tf.summary.scalar("model/policy_loss", pi_loss / bs)
|
||||
tf.summary.scalar("model/value_loss", vf_loss / bs)
|
||||
tf.summary.scalar("model/entropy", entropy / bs)
|
||||
tf.summary.image("model/state", self.x)
|
||||
self.summary_op = tf.summary.merge_all()
|
||||
if summarize:
|
||||
tf.summary.scalar("model/policy_loss", pi_loss / bs)
|
||||
tf.summary.scalar("model/value_loss", vf_loss / bs)
|
||||
tf.summary.scalar("model/entropy", entropy / bs)
|
||||
tf.summary.image("model/state", self.x)
|
||||
self.summary_op = tf.summary.merge_all()
|
||||
|
||||
def initialize(self):
|
||||
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
|
||||
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
|
||||
self.variables = ray.experimental.TensorFlowVariables(self.loss, self.sess)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
def initialize(self):
|
||||
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
|
||||
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
|
||||
self.variables = ray.experimental.TensorFlowVariables(self.loss,
|
||||
self.sess)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def model_update(self, grads):
|
||||
feed_dict = {self.grads[i]: grads[i]
|
||||
for i in range(len(grads))}
|
||||
self.sess.run(self._apply_gradients, feed_dict=feed_dict)
|
||||
def model_update(self, grads):
|
||||
feed_dict = {self.grads[i]: grads[i]
|
||||
for i in range(len(grads))}
|
||||
self.sess.run(self._apply_gradients, feed_dict=feed_dict)
|
||||
|
||||
def get_weights(self):
|
||||
weights = self.variables.get_weights()
|
||||
return weights
|
||||
def get_weights(self):
|
||||
weights = self.variables.get_weights()
|
||||
return weights
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
def set_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
def get_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
def get_gradients(self, batch):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vf_loss(self):
|
||||
raise NotImplementedError
|
||||
def get_vf_loss(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def act(self, ob):
|
||||
raise NotImplementedError
|
||||
def act(self, ob):
|
||||
raise NotImplementedError
|
||||
|
||||
def value(self, ob):
|
||||
raise NotImplementedError
|
||||
def value(self, ob):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def normalized_columns_initializer(std=1.0):
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
|
||||
|
||||
def flatten(x):
|
||||
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
|
||||
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
|
||||
|
||||
|
||||
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
|
||||
dtype=tf.float32, collections=None):
|
||||
with tf.variable_scope(name):
|
||||
stride_shape = [1, stride[0], stride[1], 1]
|
||||
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
|
||||
num_filters]
|
||||
with tf.variable_scope(name):
|
||||
stride_shape = [1, stride[0], stride[1], 1]
|
||||
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
|
||||
num_filters]
|
||||
|
||||
# There are "num input feature maps * filter height * filter width"
|
||||
# inputs to each hidden unit.
|
||||
fan_in = np.prod(filter_shape[:3])
|
||||
# Each unit in the lower layer receives a gradient from:
|
||||
# "num output feature maps * filter height * filter width" / pooling size.
|
||||
fan_out = np.prod(filter_shape[:2]) * num_filters
|
||||
# Initialize weights with random weights.
|
||||
w_bound = np.sqrt(6 / (fan_in + fan_out))
|
||||
# There are "num input feature maps * filter height * filter width"
|
||||
# inputs to each hidden unit.
|
||||
fan_in = np.prod(filter_shape[:3])
|
||||
# Each unit in the lower layer receives a gradient from: "num output
|
||||
# feature maps * filter height * filter width" / pooling size.
|
||||
fan_out = np.prod(filter_shape[:2]) * num_filters
|
||||
# Initialize weights with random weights.
|
||||
w_bound = np.sqrt(6 / (fan_in + fan_out))
|
||||
|
||||
w = tf.get_variable("W", filter_shape, dtype,
|
||||
tf.random_uniform_initializer(-w_bound, w_bound),
|
||||
collections=collections)
|
||||
b = tf.get_variable("b", [1, 1, 1, num_filters],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
collections=collections)
|
||||
return tf.nn.conv2d(x, w, stride_shape, pad) + b
|
||||
w = tf.get_variable("W", filter_shape, dtype,
|
||||
tf.random_uniform_initializer(-w_bound, w_bound),
|
||||
collections=collections)
|
||||
b = tf.get_variable("b", [1, 1, 1, num_filters],
|
||||
initializer=tf.constant_initializer(0.0),
|
||||
collections=collections)
|
||||
return tf.nn.conv2d(x, w, stride_shape, pad) + b
|
||||
|
||||
|
||||
def linear(x, size, name, initializer=None, bias_init=0):
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
|
||||
initializer=initializer)
|
||||
b = tf.get_variable(name + "/b", [size],
|
||||
initializer=tf.constant_initializer(bias_init))
|
||||
return tf.matmul(x, w) + b
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
|
||||
initializer=initializer)
|
||||
b = tf.get_variable(name + "/b", [size],
|
||||
initializer=tf.constant_initializer(bias_init))
|
||||
return tf.matmul(x, w) + b
|
||||
|
||||
|
||||
def categorical_sample(logits, d):
|
||||
value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1],
|
||||
keep_dims=True),
|
||||
1), [1])
|
||||
return tf.one_hot(value, d)
|
||||
value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1],
|
||||
keep_dims=True),
|
||||
1), [1])
|
||||
return tf.one_hot(value, d)
|
||||
|
||||
+132
-131
@@ -11,26 +11,26 @@ import threading
|
||||
|
||||
|
||||
def discount(x, gamma):
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
|
||||
|
||||
|
||||
def process_rollout(rollout, gamma, lambda_=1.0):
|
||||
"""Given a rollout, compute its returns and the advantage."""
|
||||
batch_si = np.asarray(rollout.states)
|
||||
batch_a = np.asarray(rollout.actions)
|
||||
rewards = np.asarray(rollout.rewards)
|
||||
vpred_t = np.asarray(rollout.values + [rollout.r])
|
||||
"""Given a rollout, compute its returns and the advantage."""
|
||||
batch_si = np.asarray(rollout.states)
|
||||
batch_a = np.asarray(rollout.actions)
|
||||
rewards = np.asarray(rollout.rewards)
|
||||
vpred_t = np.asarray(rollout.values + [rollout.r])
|
||||
|
||||
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
|
||||
batch_r = discount(rewards_plus_v, gamma)[:-1]
|
||||
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
# This formula for the advantage comes "Generalized Advantage Estimation":
|
||||
# https://arxiv.org/abs/1506.02438
|
||||
batch_adv = discount(delta_t, gamma * lambda_)
|
||||
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
|
||||
batch_r = discount(rewards_plus_v, gamma)[:-1]
|
||||
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
|
||||
# This formula for the advantage comes "Generalized Advantage Estimation":
|
||||
# https://arxiv.org/abs/1506.02438
|
||||
batch_adv = discount(delta_t, gamma * lambda_)
|
||||
|
||||
features = rollout.features[0]
|
||||
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
|
||||
features)
|
||||
features = rollout.features[0]
|
||||
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
|
||||
features)
|
||||
|
||||
|
||||
Batch = namedtuple(
|
||||
@@ -41,142 +41,143 @@ CompletedRollout = namedtuple(
|
||||
|
||||
|
||||
class PartialRollout(object):
|
||||
"""A piece of a complete rollout.
|
||||
"""A piece of a complete rollout.
|
||||
|
||||
We run our agent, and process its experience once it has processed enough
|
||||
steps.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
self.rewards = []
|
||||
self.values = []
|
||||
self.r = 0.0
|
||||
self.terminal = False
|
||||
self.features = []
|
||||
We run our agent, and process its experience once it has processed enough
|
||||
steps.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
self.rewards = []
|
||||
self.values = []
|
||||
self.r = 0.0
|
||||
self.terminal = False
|
||||
self.features = []
|
||||
|
||||
def add(self, state, action, reward, value, terminal, features):
|
||||
self.states += [state]
|
||||
self.actions += [action]
|
||||
self.rewards += [reward]
|
||||
self.values += [value]
|
||||
self.terminal = terminal
|
||||
self.features += [features]
|
||||
def add(self, state, action, reward, value, terminal, features):
|
||||
self.states += [state]
|
||||
self.actions += [action]
|
||||
self.rewards += [reward]
|
||||
self.values += [value]
|
||||
self.terminal = terminal
|
||||
self.features += [features]
|
||||
|
||||
def extend(self, other):
|
||||
assert not self.terminal
|
||||
self.states.extend(other.states)
|
||||
self.actions.extend(other.actions)
|
||||
self.rewards.extend(other.rewards)
|
||||
self.values.extend(other.values)
|
||||
self.r = other.r
|
||||
self.terminal = other.terminal
|
||||
self.features.extend(other.features)
|
||||
def extend(self, other):
|
||||
assert not self.terminal
|
||||
self.states.extend(other.states)
|
||||
self.actions.extend(other.actions)
|
||||
self.rewards.extend(other.rewards)
|
||||
self.values.extend(other.values)
|
||||
self.r = other.r
|
||||
self.terminal = other.terminal
|
||||
self.features.extend(other.features)
|
||||
|
||||
|
||||
class RunnerThread(threading.Thread):
|
||||
"""This thread interacts with the environment and tells it what to do."""
|
||||
def __init__(self, env, policy, num_local_steps, visualise=False):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.env = env
|
||||
self.last_features = None
|
||||
self.policy = policy
|
||||
self.daemon = True
|
||||
self.sess = None
|
||||
self.summary_writer = None
|
||||
self.visualise = visualise
|
||||
"""This thread interacts with the environment and tells it what to do."""
|
||||
def __init__(self, env, policy, num_local_steps, visualise=False):
|
||||
threading.Thread.__init__(self)
|
||||
self.queue = queue.Queue(5)
|
||||
self.metrics_queue = queue.Queue()
|
||||
self.num_local_steps = num_local_steps
|
||||
self.env = env
|
||||
self.last_features = None
|
||||
self.policy = policy
|
||||
self.daemon = True
|
||||
self.sess = None
|
||||
self.summary_writer = None
|
||||
self.visualise = visualise
|
||||
|
||||
def start_runner(self, sess, summary_writer):
|
||||
self.sess = sess
|
||||
self.summary_writer = summary_writer
|
||||
self.start()
|
||||
def start_runner(self, sess, summary_writer):
|
||||
self.sess = sess
|
||||
self.summary_writer = summary_writer
|
||||
self.start()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
with self.sess.as_default():
|
||||
self._run()
|
||||
except BaseException as e:
|
||||
self.queue.put(e)
|
||||
raise e
|
||||
def run(self):
|
||||
try:
|
||||
with self.sess.as_default():
|
||||
self._run()
|
||||
except BaseException as e:
|
||||
self.queue.put(e)
|
||||
raise e
|
||||
|
||||
def _run(self):
|
||||
rollout_provider = env_runner(
|
||||
self.env, self.policy, self.num_local_steps,
|
||||
self.summary_writer, self.visualise)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker dies, the
|
||||
# other workers won't die with it, unless the timeout is set to some
|
||||
# large number. This is an empirical observation.
|
||||
item = next(rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
def _run(self):
|
||||
rollout_provider = env_runner(
|
||||
self.env, self.policy, self.num_local_steps,
|
||||
self.summary_writer, self.visualise)
|
||||
while True:
|
||||
# The timeout variable exists because apparently, if one worker
|
||||
# dies, the other workers won't die with it, unless the timeout is
|
||||
# set to some large number. This is an empirical observation.
|
||||
item = next(rollout_provider)
|
||||
if isinstance(item, CompletedRollout):
|
||||
self.metrics_queue.put(item)
|
||||
else:
|
||||
self.queue.put(item, timeout=600.0)
|
||||
|
||||
|
||||
def env_runner(env, policy, num_local_steps, summary_writer, render):
|
||||
"""This implements the logic of the thread runner.
|
||||
"""This implements the logic of the thread runner.
|
||||
|
||||
It continually runs the policy, and as long as the rollout exceeds a certain
|
||||
length, the thread runner appends the policy to the queue.
|
||||
"""
|
||||
last_state = env.reset()
|
||||
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
last_features = policy.get_initial_features()
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
It continually runs the policy, and as long as the rollout exceeds a
|
||||
certain length, the thread runner appends the policy to the queue.
|
||||
"""
|
||||
last_state = env.reset()
|
||||
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
last_features = policy.get_initial_features()
|
||||
length = 0
|
||||
rewards = 0
|
||||
rollout_number = 0
|
||||
|
||||
while True:
|
||||
terminal_end = False
|
||||
rollout = PartialRollout()
|
||||
while True:
|
||||
terminal_end = False
|
||||
rollout = PartialRollout()
|
||||
|
||||
for _ in range(num_local_steps):
|
||||
fetched = policy.act(last_state, *last_features)
|
||||
action, value_, features = fetched[0], fetched[1], fetched[2:]
|
||||
# Argmax to convert from one-hot.
|
||||
state, reward, terminal, info = env.step(action.argmax())
|
||||
if render:
|
||||
env.render()
|
||||
for _ in range(num_local_steps):
|
||||
fetched = policy.act(last_state, *last_features)
|
||||
action, value_, features = fetched[0], fetched[1], fetched[2:]
|
||||
# Argmax to convert from one-hot.
|
||||
state, reward, terminal, info = env.step(action.argmax())
|
||||
if render:
|
||||
env.render()
|
||||
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= timestep_limit:
|
||||
terminal = True
|
||||
length += 1
|
||||
rewards += reward
|
||||
if length >= timestep_limit:
|
||||
terminal = True
|
||||
|
||||
# Collect the experience.
|
||||
rollout.add(last_state, action, reward, value_, terminal, last_features)
|
||||
# Collect the experience.
|
||||
rollout.add(last_state, action, reward, value_, terminal,
|
||||
last_features)
|
||||
|
||||
last_state = state
|
||||
last_features = features
|
||||
last_state = state
|
||||
last_features = features
|
||||
|
||||
if info:
|
||||
summary = tf.Summary()
|
||||
for k, v in info.items():
|
||||
summary.value.add(tag=k, simple_value=float(v))
|
||||
summary_writer.add_summary(summary, rollout_number)
|
||||
summary_writer.flush()
|
||||
if info:
|
||||
summary = tf.Summary()
|
||||
for k, v in info.items():
|
||||
summary.value.add(tag=k, simple_value=float(v))
|
||||
summary_writer.add_summary(summary, rollout_number)
|
||||
summary_writer.flush()
|
||||
|
||||
if terminal:
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
if terminal:
|
||||
terminal_end = True
|
||||
yield CompletedRollout(length, rewards)
|
||||
|
||||
if length >= timestep_limit or not env.metadata.get("semantics"
|
||||
".autoreset"):
|
||||
last_state = env.reset()
|
||||
last_features = policy.get_initial_features()
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
break
|
||||
if (length >= timestep_limit or
|
||||
not env.metadata.get("semantics.autoreset")):
|
||||
last_state = env.reset()
|
||||
last_features = policy.get_initial_features()
|
||||
rollout_number += 1
|
||||
length = 0
|
||||
rewards = 0
|
||||
break
|
||||
|
||||
if not terminal_end:
|
||||
rollout.r = policy.value(last_state, *last_features)
|
||||
if not terminal_end:
|
||||
rollout.r = policy.value(last_state, *last_features)
|
||||
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
yield rollout
|
||||
# Once we have enough experience, yield it, and have the ThreadRunner
|
||||
# place it on a queue.
|
||||
yield rollout
|
||||
|
||||
+67
-66
@@ -9,39 +9,39 @@ import tempfile
|
||||
import uuid
|
||||
import smart_open
|
||||
if sys.version_info[0] == 2:
|
||||
import cStringIO as StringIO
|
||||
import cStringIO as StringIO
|
||||
elif sys.version_info[0] == 3:
|
||||
import io as StringIO
|
||||
import io as StringIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class RLLibEncoder(json.JSONEncoder):
    """JSON encoder that can serialize NumPy scalar values."""

    def default(self, value):
        """Convert a value the stock JSON encoder cannot serialize.

        Args:
            value: Object that ``json.JSONEncoder`` does not handle natively.

        Returns:
            A plain ``float``, ``int`` or ``bool`` for NumPy scalars. ``None``
            (encoded as JSON ``null``) for NaN and for any other value.
        """
        if isinstance(value, np.floating):
            # NaN is not valid JSON, so it is written out as null.
            return None if np.isnan(value) else float(value)
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.bool_):
            return bool(value)
        # Unrecognized types were always encoded as null by this encoder;
        # that best-effort behaviour is kept so logging a config never fails.
        return None
|
||||
|
||||
|
||||
class RLLibLogger(object):
    """Writing small amounts of data to S3 with real-time updates.

    Everything written so far is kept in an in-memory buffer, and the whole
    buffer is written to ``uri`` again on every call to ``write``.
    """

    def __init__(self, uri):
        """Create a logger that writes to ``uri``.

        Args:
            uri: Destination passed to ``smart_open.smart_open``.
        """
        self.result_buffer = StringIO.StringIO()
        self.uri = uri

    def write(self, b):
        """Append ``b`` to the buffer and rewrite the full output to ``uri``.

        Args:
            b: Text to append to the log.
        """
        # TODO(pcm): At the moment we are writing the whole results output from
        # the beginning in each iteration. This will write O(n^2) bytes where n
        # is the number of bytes printed so far. Fix this! This should at least
        # only write the last 5MBs (S3 chunksize).
        with smart_open.smart_open(self.uri, "w") as f:
            self.result_buffer.write(b)
            f.write(self.result_buffer.getvalue())
|
||||
|
||||
|
||||
TrainingResult = namedtuple("TrainingResult", [
|
||||
@@ -54,54 +54,55 @@ TrainingResult = namedtuple("TrainingResult", [
|
||||
|
||||
|
||||
class Algorithm(object):
    """All RLlib algorithms extend this base class.

    Algorithm objects retain internal model state between calls to train(), so
    you should create a new algorithm instance for each training session.

    Attributes:
        env_name (str): Name of the OpenAI gym environment to train against.
        config (obj): Algorithm-specific configuration data.
        logdir (str): Directory in which training outputs should be placed.

    TODO(ekl): support checkpoint / restore of training state.
    """

    def __init__(self, env_name, config, upload_dir="file:///tmp/ray"):
        """Initialize an RLLib algorithm.

        Besides storing the arguments, this adds ``experiment_id`` and
        ``env_name`` entries to ``config`` (mutating the object passed in),
        picks the log directory and writes ``config.json`` into it.

        Args:
            env_name (str): The name of the OpenAI gym environment to use.
            config (obj): Algorithm-specific configuration data.
            upload_dir (str): Root directory into which the output directory
                should be placed. Can be local like file:///tmp/ray/ or on S3
                like s3://bucketname/.
        """
        self.experiment_id = uuid.uuid4()
        self.env_name = env_name
        self.config = config
        self.config.update({"experiment_id": self.experiment_id.hex})
        self.config.update({"env_name": env_name})
        prefix = "{}_{}_{}".format(
            env_name,
            self.__class__.__name__,
            datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
        if upload_dir.startswith("file"):
            # NOTE(review): any "file" URI ends up under /tmp/ray; the path
            # part of upload_dir itself is not used here.
            self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix,
                                                       dir="/tmp/ray")
        else:
            self.logdir = os.path.join(upload_dir, prefix)
        log_path = os.path.join(self.logdir, "config.json")
        with smart_open.smart_open(log_path, "w") as f:
            json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
        logger.info(
            "%s algorithm created with logdir '%s'",
            self.__class__.__name__, self.logdir)

    def train(self):
        """Runs one logical iteration of training.

        Returns:
            A TrainingResult that describes training progress.

        Raises:
            NotImplementedError: Always, in this base class.
        """

        raise NotImplementedError
|
||||
|
||||
+177
-170
@@ -80,198 +80,205 @@ from ray.rllib.dqn.common import tf_util as U
|
||||
|
||||
|
||||
def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
    """Creates the act function:

    Parameters
    ----------
    make_obs_ph: str -> tf.placeholder or TfInput
        a function that take a name and creates a placeholder of input with
        that name
    q_func: (tf.Variable, int, str, bool) -> tf.Variable
        the model that takes the following inputs:
            observation_in: object
                the output of observation placeholder
            num_actions: int
                number of actions
            scope: str
            reuse: bool
                should be passed to outer variable scope
        and returns a tensor of shape (batch_size, num_actions) with values of
        every action.
    num_actions: int
        number of actions.
    scope: str or VariableScope
        optional scope for variable_scope.
    reuse: bool or None
        whether or not the variables should be reused. To be able to reuse the
        scope must be given.

    Returns
    -------
    act: (tf.Variable, bool, float) -> tf.Variable
        function to select an action given observation.
        See the top of the file for details.
    """
    with tf.variable_scope(scope, reuse=reuse):
        observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
        stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
        update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")

        # Exploration rate, stored as a graph variable initialised to 0.
        eps = tf.get_variable(
            "eps", (), initializer=tf.constant_initializer(0))

        q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
        deterministic_actions = tf.argmax(q_values, axis=1)

        # Epsilon-greedy: per batch entry, pick a uniformly random action with
        # probability eps, otherwise the greedy one.
        batch_size = tf.shape(observations_ph.get())[0]
        random_actions = tf.random_uniform(
            tf.stack([batch_size]), minval=0, maxval=num_actions,
            dtype=tf.int64)
        chose_random = tf.random_uniform(
            tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
        stochastic_actions = tf.where(
            chose_random, random_actions, deterministic_actions)

        output_actions = tf.cond(
            stochastic_ph, lambda: stochastic_actions,
            lambda: deterministic_actions)
        # A negative update_eps leaves eps unchanged.
        update_eps_expr = eps.assign(
            tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))

        act = U.function(
            inputs=[observations_ph, stochastic_ph, update_eps_ph],
            outputs=output_actions,
            givens={update_eps_ph: -1.0, stochastic_ph: True},
            updates=[update_eps_expr])
        return act
|
||||
|
||||
|
||||
def build_train(
        make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None,
        gamma=1.0, double_q=True, scope="deepq", reuse=None):
    """Creates the train function:

    Parameters
    ----------
    make_obs_ph: str -> tf.placeholder or TfInput
        a function that takes a name and creates a placeholder of input with
        that name
    q_func: (tf.Variable, int, str, bool) -> tf.Variable
        the model that takes the following inputs:
            observation_in: object
                the output of observation placeholder
            num_actions: int
                number of actions
            scope: str
            reuse: bool
                should be passed to outer variable scope
        and returns a tensor of shape (batch_size, num_actions) with values of
        every action.
    num_actions: int
        number of actions
    optimizer: tf.train.Optimizer
        optimizer to use for the Q-learning objective.
    grad_norm_clipping: float or None
        clip gradient norms to this value. If None no clipping is performed.
    gamma: float
        discount rate.
    double_q: bool
        if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
        In general it is a good idea to keep it enabled.
    scope: str or VariableScope
        optional scope for variable_scope.
    reuse: bool or None
        whether or not the variables should be reused. To be able to reuse the
        scope must be given.

    Returns
    -------
    act: (tf.Variable, bool, float) -> tf.Variable
        function to select an action given observation.
        See the top of the file for details.
    train: (object, np.array, np.array, object, np.array, np.array) -> np.array
        optimize the error in Bellman's equation.
        See the top of the file for details.
    update_target: () -> ()
        copy the parameters from optimized Q function to the target Q function.
        See the top of the file for details.
    debug: {str: function}
        a bunch of functions to print debug data like q_values.
    """
    act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope,
                      reuse=reuse)

    with tf.variable_scope(scope, reuse=reuse):
        # set up placeholders
        obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
        act_t_ph = tf.placeholder(tf.int32, [None], name="action")
        rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
        obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
        done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
        importance_weights_ph = tf.placeholder(tf.float32, [None],
                                               name="weight")

        # q network evaluation
        q_t = q_func(
            obs_t_input.get(), num_actions, scope="q_func",
            reuse=True)  # reuse parameters from act
        q_func_vars = U.scope_vars(U.absolute_scope_name("q_func"))

        # target q network evaluation
        q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
        target_q_func_vars = U.scope_vars(
            U.absolute_scope_name("target_q_func"))

        # q scores for actions which we know were selected in the given state.
        q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions),
                                     1)

        # compute estimate of best possible value starting from state at t + 1
        if double_q:
            # Double Q: the online network picks the action, the target
            # network supplies its value.
            q_tp1_using_online_net = q_func(
                obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
            # tf.argmax replaces the deprecated tf.arg_max alias and matches
            # the call used in build_act.
            q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net,
                                                    axis=1)
            q_tp1_best = tf.reduce_sum(
                q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions),
                1)
        else:
            q_tp1_best = tf.reduce_max(q_tp1, 1)
        # Terminal transitions (done == 1) contribute no future value.
        q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best

        # compute RHS of bellman equation
        q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked

        # compute the error (potentially clipped)
        td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
        errors = U.huber_loss(td_error)
        weighted_error = tf.reduce_mean(importance_weights_ph * errors)
        # compute optimization op (potentially with gradient clipping)
        if grad_norm_clipping is not None:
            optimize_expr = U.minimize_and_clip(
                optimizer, weighted_error, var_list=q_func_vars,
                clip_val=grad_norm_clipping)
        else:
            optimize_expr = optimizer.minimize(weighted_error,
                                               var_list=q_func_vars)

        # update_target_fn will be called periodically to copy Q network to
        # target Q network
        update_target_expr = []
        # Both variable lists are sorted by name so that they pair up.
        for var, var_target in zip(
                sorted(q_func_vars, key=lambda v: v.name),
                sorted(target_q_func_vars, key=lambda v: v.name)):
            update_target_expr.append(var_target.assign(var))
        update_target_expr = tf.group(*update_target_expr)

        # Create callable functions
        train = U.function(
            inputs=[
                obs_t_input,
                act_t_ph,
                rew_t_ph,
                obs_tp1_input,
                done_mask_ph,
                importance_weights_ph
            ],
            outputs=td_error,
            updates=[optimize_expr])
        update_target = U.function([], [], updates=[update_target_expr])

        q_values = U.function([obs_t_input], q_t)

        return act_f, train, update_target, {'q_values': q_values}
|
||||
|
||||
@@ -11,236 +11,240 @@ from gym import spaces
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
    """Wrapper that starts every episode with a random number of no-ops."""

    def __init__(self, env=None, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.

        Args:
            env: Environment to wrap; its action 0 must be named 'NOOP'.
            noop_max: Upper bound (inclusive) on the number of no-op steps.
        """
        super(NoopResetEnv, self).__init__(env)
        self.noop_max = noop_max
        # When set, this fixed count is used instead of a random draw.
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def _reset(self):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = np.random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                # The episode ended during the no-ops: start a fresh one.
                obs = self.env.reset()
        return obs
|
||||
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
    """Wrapper that takes the start-up actions on every reset."""

    def __init__(self, env=None):
        """For environments where the user need to press FIRE for the game to
        start.

        Args:
            env: Environment to wrap; its action 1 must be named 'FIRE' and
                it must expose at least three actions.
        """
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def _reset(self):
        """Reset the environment, then take action 1 (FIRE) and action 2.

        Returns:
            The observation of the state the environment is left in.
        """
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the observation of the freshly reset episode rather than
            # the stale one from the step that ended the previous episode.
            obs = self.env.reset()
        return obs
|
||||
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
    """Wrapper that reports the loss of a life as the end of an episode."""

    def __init__(self, env=None):
        """Make end-of-life == end-of-episode, but only reset on true game
        over. Done by DeepMind for the DQN and co. since it helps value
        estimation.
        """
        super(EpisodicLifeEnv, self).__init__(env)
        self.lives = 0
        # True when the wrapped environment itself reported done.
        self.was_real_done = True
        # True when the last _reset() really reset the wrapped environment.
        self.was_real_reset = False

    def _step(self, action):
        """Step the wrapped environment, flagging a lost life as done."""
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few
            # frames so its important to keep lives > 0, so that we only reset
            # once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def _reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
            self.was_real_reset = True
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
            self.was_real_reset = False
        self.lives = self.env.unwrapped.ale.lives()
        return obs
|
||||
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
    """Wrapper that repeats each action and max-pools the last frames."""

    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def _step(self, action):
        """Repeat ``action`` up to ``skip`` times, stopping early on done.

        Returns:
            Tuple ``(max_frame, total_reward, done, info)``: the element-wise
            maximum of the buffered observations, the summed reward, and the
            ``done`` and ``info`` values of the last inner step.
        """
        total_reward = 0.0
        done = None
        # Initialised like `done` so a non-positive skip cannot leave it
        # unbound at the return below.
        info = {}
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break

        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def _reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
|
||||
|
||||
|
||||
class ProcessFrame84(gym.ObservationWrapper):
    """Observation wrapper producing 84x84 single-channel uint8 frames."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))

    def _observation(self, obs):
        """Convert one raw observation using ``process``."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Turn a 210x160x3 or 250x160x3 frame into an 84x84x1 uint8 array.

        Args:
            frame: Array whose total size matches one of the two supported
                resolutions.

        Returns:
            Array of shape (84, 84, 1) and dtype uint8.

        Raises:
            AssertionError: If the frame size matches neither resolution.
        """
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Raised explicitly (same exception type as the former
            # `assert False`) so the check is not stripped under `python -O`.
            raise AssertionError("Unknown resolution.")
        # Weighted sum of the three colour channels (grayscale conversion).
        img = (img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 +
               img[:, :, 2] * 0.114)
        resized_screen = cv2.resize(img, (84, 110),
                                    interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 of the resized image to get 84 rows.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)
|
||||
|
||||
|
||||
class ClippedRewardsWrapper(gym.RewardWrapper):
    """Reward wrapper that keeps only the sign of each reward."""

    def _reward(self, reward):
        """Change all the positive rewards to 1, negative to -1 and keep
        zero."""
        return np.sign(reward)
|
||||
|
||||
|
||||
class LazyFrames(object):
    """Holds a list of frames and concatenates them only on conversion."""

    def __init__(self, frames):
        """This object ensures that common frames between the observations are
        only stored once. It exists purely to optimize memory usage which can
        be huge for DQN's 1M frames replay buffers.

        This object should only be converted to numpy array before being passed
        to the model.

        You'd not believe how complex the previous solution was."""
        self._frames = frames

    def __array__(self, dtype=None):
        """Concatenate the stored frames along axis 2.

        Args:
            dtype: Optional dtype to cast the result to.

        Returns:
            A new numpy array containing all frames.
        """
        out = np.concatenate(self._frames, axis=2)
        if dtype is not None:
            out = out.astype(dtype)
        return out
|
||||
|
||||
|
||||
class FrameStack(gym.Wrapper):
    """Wrapper whose observations are the ``k`` most recent frames."""

    def __init__(self, env, k):
        """Stack k last frames.

        Observations are returned as ``LazyFrames`` objects, which are much
        more memory efficient than eagerly concatenated arrays.

        See Also
        --------
        ray.rllib.dqn.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque(maxlen=k)
        obs_shape = env.observation_space.shape
        # The third dimension grows by a factor of k.
        stacked_shape = (obs_shape[0], obs_shape[1], obs_shape[2] * k)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=stacked_shape)

    def _reset(self):
        """Reset the env and fill the whole stack with the first frame."""
        first_frame = self.env.reset()
        self.frames.extend([first_frame] * self.k)
        return self._get_ob()

    def _step(self, action):
        """Step the env, push the new frame, and return the stacked view."""
        frame, reward, done, info = self.env.step(action)
        self.frames.append(frame)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Wrap a snapshot of the current frames in ``LazyFrames``."""
        assert len(self.frames) == self.k
        snapshot = list(self.frames)
        return LazyFrames(snapshot)
|
||||
|
||||
|
||||
class ScaledFloatFrame(gym.ObservationWrapper):
    """Observation wrapper that returns ``float32`` observations divided by
    255."""

    def _observation(self, obs):
        """Materialise ``obs`` as a ``float32`` array and divide by 255.0."""
        # careful! Converting to a dense array undoes the memory
        # optimization, so use this with smaller replay buffers only.
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0
|
||||
|
||||
|
||||
def wrap_dqn(env):
    """Wrap an Atari env in the usual chain of DQN preprocessing wrappers.

    The env id must contain 'NoFrameskip'. Wrappers are applied in this
    order: episodic life, no-op reset, max-and-skip, optional fire reset
    (only when the game has a 'FIRE' action), 84x84 frame processing,
    4-frame stacking, and reward clipping.
    """
    assert 'NoFrameskip' in env.spec.id
    env = MaxAndSkipEnv(
        NoopResetEnv(EpisodicLifeEnv(env), noop_max=30), skip=4)
    action_meanings = env.unwrapped.get_action_meanings()
    if 'FIRE' in action_meanings:
        env = FireResetEnv(env)
    stacked = FrameStack(ProcessFrame84(env), 4)
    return ClippedRewardsWrapper(stacked)
|
||||
|
||||
|
||||
class A2cProcessFrame(gym.Wrapper):
    """Wrapper that turns every observation into an 84x84x1 grayscale
    frame."""

    def __init__(self, env):
        """Wrap ``env`` and declare the (84, 84, 1) observation space."""
        gym.Wrapper.__init__(self, env)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 1))

    def _step(self, action):
        """Step the env and return the processed observation with the
        untouched reward, done flag, and info."""
        raw_ob, reward, done, info = self.env.step(action)
        processed = A2cProcessFrame.process(raw_ob)
        return processed, reward, done, info

    def _reset(self):
        """Reset the env and return the processed first observation."""
        first_ob = self.env.reset()
        return A2cProcessFrame.process(first_ob)

    @staticmethod
    def process(frame):
        """Convert ``frame`` from RGB to grayscale, resize it to 84x84 and
        add a trailing channel axis."""
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        small = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        return small.reshape(84, 84, 1)
|
||||
|
||||
@@ -13,95 +13,96 @@ from __future__ import print_function
|
||||
|
||||
|
||||
class Schedule(object):
    """Base class for schedules; ``value`` is left unimplemented here."""

    def value(self, t):
        """Return the value of the schedule at time ``t``."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class ConstantSchedule(object):
    """Schedule that returns the same value at every time step."""

    def __init__(self, value):
        """Value remains constant over time.

        Parameters
        ----------
        value: float
            The constant returned by ``value(t)`` for any ``t``.
        """
        self._v = value

    def value(self, t):
        """See Schedule.value; ``t`` is ignored."""
        constant = self._v
        return constant
|
||||
|
||||
|
||||
def linear_interpolation(l, r, alpha):
    """Return the point a fraction ``alpha`` of the way from ``l`` to ``r``.

    ``alpha == 0`` yields ``l``.
    """
    span = r - l
    return l + alpha * span
|
||||
|
||||
|
||||
class PiecewiseSchedule(object):
    """Schedule defined by interpolating between consecutive endpoints."""

    def __init__(
            self, endpoints, interpolation=linear_interpolation,
            outside_value=None):
        """Piecewise schedule.

        endpoints: [(int, int)]
            list of pairs `(time, value)` meaning that the schedule should
            output `value` when `t==time`. The times must already be in
            sorted order (this is asserted). When t lies between two
            endpoints `(time_a, value_a)` and `(time_b, value_b)` such that
            `time_a <= t < time_b`, the schedule outputs
            `interpolation(value_a, value_b, alpha)` where alpha is the
            fraction of the interval from `time_a` to `time_b` that `t`
            has covered.
        interpolation: lambda float, float, float: float
            a function that takes the value to the left and to the right of
            t according to the `endpoints`, plus alpha as described above.
            See linear_interpolation for an example.
        outside_value: float
            returned when the requested time falls outside all the
            intervals specified in `endpoints`. If None, an AssertionError
            is raised when an outside value is requested.
        """
        times = [point[0] for point in endpoints]
        assert times == sorted(times)
        self._endpoints = endpoints
        self._interpolation = interpolation
        self._outside_value = outside_value

    def value(self, t):
        """See Schedule.value"""
        segments = zip(self._endpoints[:-1], self._endpoints[1:])
        for (left_t, left_v), (right_t, right_v) in segments:
            if left_t <= t < right_t:
                alpha = float(t - left_t) / (right_t - left_t)
                return self._interpolation(left_v, right_v, alpha)

        # t is not inside any of the pieces: fall back to the outside value,
        # which must have been provided.
        assert self._outside_value is not None
        return self._outside_value
|
||||
|
||||
|
||||
class LinearSchedule(object):
    """Schedule that moves linearly from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        """See Schedule.value

        Returns ``final_p`` once ``t`` reaches ``schedule_timesteps``. A
        non-positive ``schedule_timesteps`` is treated as an already
        finished schedule, so ``final_p`` is returned for every ``t``.
        """
        if self.schedule_timesteps <= 0:
            # Nothing to anneal over; also avoids dividing by zero below.
            return self.final_p
        # Fraction of the schedule completed, capped at 1.0.
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)
|
||||
|
||||
@@ -6,146 +6,148 @@ import operator
|
||||
|
||||
|
||||
class SegmentTree(object):
    """Array-like container supporting fast range reductions."""

    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.

        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as regular array, but with two
        important differences:

            a) setting item's value is slightly slower.
               It is O(lg capacity) instead of O(1).
            b) user has access to an efficient `reduce`
               operation which reduces `operation` over
               a contiguous subsequence of items in the
               array.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (eg. sum, max). Partial
            results of sub-ranges are combined with it, so it has to be
            associative.
        neutral_element: obj
            neutral element for the operation above. eg. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, \
            "capacity must be positive and a power of 2."
        self._capacity = capacity
        # Node 1 is the root, node i has children 2i and 2i + 1, and the
        # leaves occupy indices capacity .. 2 * capacity - 1.
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        """Reduce the inclusive range [start, end], which must lie inside
        the inclusive range [node_start, node_end] covered by `node`."""
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            # Entirely inside the left child.
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                # Entirely inside the right child.
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
                                           node_end)
            else:
                # Straddles both children: combine the two halves.
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
                                        node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(
                arr[start], operation(arr[start+1], operation(... arr[end])))

        Parameters
        ----------
        start: int
            beginning of the subsequence
        end: int
            end of the subsequence (exclusive). None means the capacity;
            a negative value is counted back from the capacity.

        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of
            array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        # The helper works with inclusive bounds.
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """Store `val` at position `idx` and refresh every ancestor node.

        Raises AssertionError when `idx` is outside [0, capacity).
        """
        # Same bounds check as __getitem__: without it a negative index
        # would land on an internal node and silently corrupt the tree.
        assert 0 <= idx < self._capacity
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx):
        """Return the value stored at position `idx` (0 <= idx < capacity)."""
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]
|
||||
|
||||
|
||||
class SumSegmentTree(SegmentTree):
    """Segment tree specialised to addition, with prefix-sum search."""

    def __init__(self, capacity):
        """Create a sum tree of the given capacity, initially all 0.0."""
        super(SumSegmentTree, self).__init__(capacity, operator.add, 0.0)

    def sum(self, start=0, end=None):
        """Returns the sum of the array over the range given by `start`
        and `end`, which are forwarded unchanged to `reduce`."""
        total = super(SumSegmentTree, self).reduce(start, end)
        return total

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum

        if array values are probabilities, this function
        allows to sample indexes according to the discrete
        probability efficiently.

        Parameters
        ----------
        prefixsum: float
            upperbound on the sum of array prefix

        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        remaining = prefixsum
        node = 1
        # Walk down from the root until a leaf is reached.
        while node < self._capacity:
            left = 2 * node
            left_total = self._value[left]
            if left_total > remaining:
                node = left
            else:
                # Skip the whole left subtree and continue on the right.
                remaining -= left_total
                node = left + 1
        return node - self._capacity
|
||||
|
||||
|
||||
class MinSegmentTree(SegmentTree):
    """Segment tree specialised to the built-in `min`."""

    def __init__(self, capacity):
        """Create a min tree of the given capacity, initially all +inf."""
        super(MinSegmentTree, self).__init__(capacity, min, float('inf'))

    def min(self, start=0, end=None):
        """Returns the minimum of the array over the range given by `start`
        and `end`, which are forwarded unchanged to `reduce`."""
        smallest = super(MinSegmentTree, self).reduce(start, end)
        return smallest
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+128
-119
@@ -88,132 +88,141 @@ DEFAULT_CONFIG = dict(
|
||||
|
||||
|
||||
class DQN(Algorithm):
    """DQN algorithm: one environment, a replay buffer, and the training
    functions returned by ``build_train``."""

    def __init__(self, env_name, config, upload_dir=None):
        """Build the environment, model functions, replay buffer and
        schedules.

        Note that ``config`` is updated in place with ``"alg": "DQN"``.

        Parameters
        ----------
        env_name: str
            Name passed to ``gym.make``.
        config: dict
            Hyperparameters; the keys read here and in ``train`` must be
            present.
        upload_dir:
            Forwarded unchanged to ``Algorithm.__init__``.
        """
        config.update({"alg": "DQN"})
        Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
        env = gym.make(env_name)
        env = ScaledFloatFrame(wrap_dqn(env))
        self.env = env
        model = models.cnn_to_mlp(
            convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
            hiddens=[256], dueling=True)
        # The session is entered here and never exited in this class.
        sess = U.make_session(num_cpu=config["num_cpu"])
        sess.__enter__()

        def make_obs_ph(name):
            return U.BatchInput(env.observation_space.shape, name=name)

        self.act, self.optimize, self.update_target, self.debug = build_train(
            make_obs_ph=make_obs_ph,
            q_func=model,
            num_actions=env.action_space.n,
            optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]),
            gamma=config["gamma"],
            grad_norm_clipping=10)
        # Create the replay buffer
        if config["prioritized_replay"]:
            self.replay_buffer = PrioritizedReplayBuffer(
                config["buffer_size"],
                alpha=config["prioritized_replay_alpha"])
            prioritized_replay_beta_iters = (
                config["prioritized_replay_beta_iters"])
            if prioritized_replay_beta_iters is None:
                prioritized_replay_beta_iters = (
                    config["schedule_max_timesteps"])
            self.beta_schedule = LinearSchedule(
                prioritized_replay_beta_iters,
                initial_p=config["prioritized_replay_beta0"],
                final_p=1.0)
        else:
            self.replay_buffer = ReplayBuffer(config["buffer_size"])
            self.beta_schedule = None
        # Create the schedule for exploration starting from 1.
        self.exploration = LinearSchedule(
            schedule_timesteps=int(
                config["exploration_fraction"] *
                config["schedule_max_timesteps"]),
            initial_p=1.0,
            final_p=config["exploration_final_eps"])

        # Initialize the parameters and copy them to the target network.
        U.initialize()
        self.update_target()

        self.episode_rewards = [0.0]
        self.episode_lengths = [0.0]
        self.saved_mean_reward = None
        self.obs = self.env.reset()
        self.num_timesteps = 0
        self.num_iterations = 0

    def train(self):
        """Run ``config["timesteps_per_iteration"]`` environment steps,
        training along the way, and return a ``TrainingResult``.

        The exploration and beta schedules are evaluated at the global step
        counter ``self.num_timesteps`` so that annealing continues across
        successive calls instead of restarting with every call.
        """
        config = self.config
        sample_time, learn_time = 0, 0

        for _ in range(config["timesteps_per_iteration"]):
            self.num_timesteps += 1
            dt = time.time()
            # Take action and update exploration to the newest value
            action = self.act(
                np.array(self.obs)[None],
                update_eps=self.exploration.value(self.num_timesteps))[0]
            new_obs, rew, done, _ = self.env.step(action)
            # Store transition in the replay buffer.
            self.replay_buffer.add(self.obs, action, rew, new_obs, float(done))
            self.obs = new_obs

            self.episode_rewards[-1] += rew
            self.episode_lengths[-1] += 1
            if done:
                self.obs = self.env.reset()
                self.episode_rewards.append(0.0)
                self.episode_lengths.append(0.0)
            sample_time += time.time() - dt

            if self.num_timesteps > config["learning_starts"] and \
                    self.num_timesteps % config["train_freq"] == 0:
                dt = time.time()
                # Minimize the error in Bellman's equation on a batch sampled
                # from replay buffer.
                if config["prioritized_replay"]:
                    experience = self.replay_buffer.sample(
                        config["batch_size"],
                        beta=self.beta_schedule.value(self.num_timesteps))
                    (obses_t, actions, rewards, obses_tp1,
                     dones, _, batch_idxes) = experience
                else:
                    obses_t, actions, rewards, obses_tp1, dones = \
                        self.replay_buffer.sample(config["batch_size"])
                    batch_idxes = None
                # NOTE(review): the sixth element of the prioritized sample is
                # discarded above and all-ones are passed here instead --
                # confirm this is intentional.
                td_errors = self.optimize(
                    obses_t, actions, rewards, obses_tp1, dones,
                    np.ones_like(rewards))
                if config["prioritized_replay"]:
                    new_priorities = (np.abs(td_errors) +
                                      config["prioritized_replay_eps"])
                    self.replay_buffer.update_priorities(batch_idxes,
                                                         new_priorities)
                learn_time += (time.time() - dt)

            if (self.num_timesteps > config["learning_starts"] and
                    self.num_timesteps %
                    config["target_network_update_freq"] == 0):
                # Update target network periodically.
                self.update_target()

        # The last list entry is the episode still in progress, so it is
        # excluded from the averages.
        mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
        mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
        num_episodes = len(self.episode_rewards)
        # Evaluated at the global step so this is defined even when the loop
        # above ran zero times.
        exploration_pct = int(
            100 * self.exploration.value(self.num_timesteps))

        info = {
            "sample_time": sample_time,
            "learn_time": learn_time,
            "steps": self.num_timesteps,
            "episodes": num_episodes,
            "exploration": exploration_pct
        }

        logger.record_tabular("sample_time", sample_time)
        logger.record_tabular("learn_time", learn_time)
        logger.record_tabular("steps", self.num_timesteps)
        logger.record_tabular("episodes", num_episodes)
        logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
        logger.record_tabular("% time spent exploring", exploration_pct)
        logger.dump_tabular()

        res = TrainingResult(
            self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
            mean_100ep_length, info)
        self.num_iterations += 1
        return res
|
||||
|
||||
@@ -24,8 +24,8 @@ def main():
|
||||
dqn = DQN("PongNoFrameskip-v4", config)
|
||||
|
||||
while True:
|
||||
res = dqn.train()
|
||||
print("current status: {}".format(res))
|
||||
res = dqn.train()
|
||||
print("current status: {}".format(res))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
+173
-172
@@ -29,88 +29,88 @@ DISABLED = 50
|
||||
|
||||
|
||||
class OutputFormat(object):
    """Base class for logger output backends.

    Subclasses must implement ``writekvs``; ``writeseq`` and ``close`` are
    optional no-ops by default.
    """

    def writekvs(self, kvs):
        """
        Write key-value pairs
        """
        raise NotImplementedError

    def writeseq(self, args):
        """
        Write a sequence of other data (e.g. a logging message)
        """
        pass

    def close(self):
        """Release any resources held by the format (no-op by default)."""
        return
|
||||
|
||||
|
||||
class HumanOutputFormat(OutputFormat):
    """Writes key-value pairs as an ASCII table and messages as plain text."""

    def __init__(self, file):
        # `file` is any object exposing write() and flush().
        self.file = file

    def writekvs(self, kvs):
        """Render `kvs` as a two-column table and write it to the file.

        Values exposing ``__float__`` are formatted with ``%-8.3g``; anything
        else is converted with ``str`` so that non-string values do not break
        the width computation. Nothing is written when `kvs` is empty.
        """
        # Create strings for printing
        key2str = OrderedDict()
        for (key, val) in kvs.items():
            if hasattr(val, '__float__'):
                valstr = '%-8.3g' % (val,)
            else:
                valstr = str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # max() below raises ValueError on an empty sequence, so there is
        # nothing sensible to draw without at least one entry.
        if not key2str:
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in key2str.items():
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        """Shorten strings longer than 23 characters to 20 plus '...'."""
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, args):
        """Write each element of `args` back to back, then a newline."""
        for arg in args:
            self.file.write(arg)
        self.file.write('\n')
        self.file.flush()
|
||||
|
||||
|
||||
class JSONOutputFormat(OutputFormat):
    """Writes each batch of key-value pairs as one JSON object per line."""

    def __init__(self, file):
        # `file` is any object exposing write() and flush().
        self.file = file

    def writekvs(self, kvs):
        """Serialize `kvs` to a single JSON line.

        Values that carry a ``dtype`` attribute (numpy scalars / 0-d arrays)
        are converted to plain floats first. The conversion is done on a copy:
        the caller's mapping is handed to every output format in turn, so it
        must not be modified here.
        """
        record = kvs.copy()
        for k, v in kvs.items():
            if hasattr(v, 'dtype'):
                record[k] = float(v.tolist())
        self.file.write(json.dumps(record) + '\n')
        self.file.flush()
|
||||
|
||||
|
||||
def make_output_format(format, ev_dir):
    """Create the output format object named by `format`.

    `ev_dir` is created if missing. Supported names are 'stdout' (table on
    sys.stdout), 'log' (table in ``log.txt``) and 'json' (``progress.json``);
    any other name raises ValueError.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    elif format == 'log':
        log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
        return HumanOutputFormat(log_file)
    elif format == 'json':
        json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
        return JSONOutputFormat(json_file)
    else:
        raise ValueError('Unknown format specified: %s' % (format,))
|
||||
|
||||
# ================================================================
|
||||
# API
|
||||
@@ -118,21 +118,21 @@ def make_output_format(format, ev_dir):
|
||||
|
||||
|
||||
def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    """
    Logger.CURRENT.logkv(key, val)
|
||||
|
||||
|
||||
def dumpkvs():
    """
    Write all of the diagnostics from the current iteration

    Forwards to the current logger, which hands the collected key-value
    pairs to each of its output formats and then clears them.
    """
    Logger.CURRENT.dumpkvs()
|
||||
|
||||
|
||||
# for backwards compatibility
|
||||
@@ -141,49 +141,50 @@ dump_tabular = dumpkvs
|
||||
|
||||
|
||||
def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output
    files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
|
||||
def debug(*args):
    """Log `args` at DEBUG level."""
    log(*args, level=DEBUG)
|
||||
|
||||
|
||||
def info(*args):
    """Log `args` at INFO level."""
    log(*args, level=INFO)
|
||||
|
||||
|
||||
def warn(*args):
    """Log `args` at WARN level."""
    log(*args, level=WARN)
|
||||
|
||||
|
||||
def error(*args):
    """Log `args` at ERROR level."""
    log(*args, level=ERROR)
|
||||
|
||||
|
||||
def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)
|
||||
|
||||
|
||||
def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call
    start)
    """
    return Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
def get_expt_dir():
    """Deprecated alias of get_dir(); writes a notice to stderr first."""
    sys.stderr.write(
        "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" %
        (get_dir(),))
    return get_dir()
|
||||
|
||||
|
||||
# ================================================================
|
||||
@@ -192,50 +193,50 @@ def get_expt_dir():
|
||||
|
||||
|
||||
class Logger(object):
    """Collects key-value diagnostics and forwards them to output formats."""

    # A logger with no output files. (See right below class definition)
    # So that you can still log to the terminal without setting up any output
    DEFAULT = None

    # Current logger being used by the free functions above
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.name2val = OrderedDict()  # values this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record `val` under `key` for the current iteration."""
        self.name2val[key] = val

    def dumpkvs(self):
        """Pass the recorded values to every output format, then reset."""
        for fmt in self.output_formats:
            fmt.writekvs(self.name2val)
        self.name2val.clear()

    def log(self, *args, level=INFO):
        """Write `args` if `level` is at or above the logger's threshold."""
        if self.level <= level:
            self._do_log(args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for fmt in self.output_formats:
            fmt.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for fmt in self.output_formats:
            fmt.writeseq(args)
|
||||
|
||||
|
||||
# ================================================================
|
||||
@@ -246,60 +247,60 @@ Logger.CURRENT = Logger.DEFAULT
|
||||
|
||||
|
||||
class session(object):
    """
    Context manager that sets up the loggers for an experiment.
    """

    CURRENT = None  # Set to a LoggerContext object using enter/exit or cm

    def __init__(self, dir, format_strs=None):
        """Install a logger writing to `dir`.

        `format_strs` lists the output format names to use; when None the
        module-level LOG_OUTPUT_FORMATS is used.
        """
        self.dir = dir
        if format_strs is None:
            format_strs = LOG_OUTPUT_FORMATS
        # Kept so that __enter__ builds the same formats that were requested.
        self.format_strs = list(format_strs)
        output_formats = [
            make_output_format(f, dir) for f in self.format_strs]
        Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)

    def __enter__(self):
        os.makedirs(self.evaluation_dir(), exist_ok=True)
        output_formats = [
            make_output_format(
                f, self.evaluation_dir()) for f in self.format_strs]
        Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
        # Returned so that `with session(...) as s:` binds the session.
        return self

    def __exit__(self, *args):
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT

    def evaluation_dir(self):
        """Return the directory that log files are written to."""
        return self.dir
|
||||
|
||||
|
||||
# ================================================================
|
||||
|
||||
|
||||
def _demo():
    """Exercise the logging API against a scratch directory."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    # Named log_dir rather than dir to avoid shadowing the builtin.
    log_dir = "/tmp/testlogging"
    if os.path.exists(log_dir):
        shutil.rmtree(log_dir)
    with session(dir=log_dir):
        record_tabular("a", 3)
        record_tabular("b", 2.5)
        dump_tabular()
        record_tabular("b", -2.5)
        record_tabular("a", 5.5)
        dump_tabular()
        info("^^^ should see a = 5.5")

    record_tabular("b", -2.5)
    dump_tabular()

    record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
    dump_tabular()
|
||||
|
||||
|
||||
# Run the logging demo when this module is executed as a script.
if __name__ == "__main__":
    _demo()
|
||||
|
||||
@@ -7,91 +7,92 @@ import tensorflow.contrib.layers as layers
|
||||
|
||||
|
||||
def _mlp(hiddens, inpt, num_actions, scope, reuse=False):
    """Build a fully connected network under variable scope `scope`.

    One ReLU layer is added per entry of `hiddens`, followed by a linear
    layer with `num_actions` outputs, whose output is returned.
    """
    with tf.variable_scope(scope, reuse=reuse):
        out = inpt
        for hidden in hiddens:
            out = layers.fully_connected(
                out, num_outputs=hidden, activation_fn=tf.nn.relu)
        out = layers.fully_connected(
            out, num_outputs=num_actions, activation_fn=None)
        return out
|
||||
|
||||
|
||||
def mlp(hiddens=()):
    """This model takes as input an observation and returns values of all
    actions.

    Parameters
    ----------
    hiddens: [int]
        list of sizes of hidden layers (only iterated, so any sequence of
        ints works; the default is an empty tuple rather than a shared
        mutable list)

    Returns
    -------
    q_func: function
        q_function for DQN algorithm.
    """
    return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
|
||||
|
||||
|
||||
def _cnn_to_mlp(
        convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
    """Build a convolutional network followed by fully connected heads.

    `convs` is iterated as (num_outputs, kernel_size, stride) triples. The
    flattened convolutional features feed an "action_value" head; when
    `dueling` is true a second "state_value" head is built and the returned
    value is the state score plus the mean-centered action scores, otherwise
    the raw action scores are returned.
    """
    with tf.variable_scope(scope, reuse=reuse):
        out = inpt
        with tf.variable_scope("convnet"):
            for num_outputs, kernel_size, stride in convs:
                out = layers.convolution2d(
                    out,
                    num_outputs=num_outputs,
                    kernel_size=kernel_size,
                    stride=stride,
                    activation_fn=tf.nn.relu)
        out = layers.flatten(out)
        with tf.variable_scope("action_value"):
            action_out = out
            for hidden in hiddens:
                action_out = layers.fully_connected(
                    action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
            action_scores = layers.fully_connected(
                action_out, num_outputs=num_actions, activation_fn=None)

        if not dueling:
            return action_scores

        with tf.variable_scope("state_value"):
            state_out = out
            for hidden in hiddens:
                state_out = layers.fully_connected(
                    state_out, num_outputs=hidden,
                    activation_fn=tf.nn.relu)
            state_score = layers.fully_connected(
                state_out, num_outputs=1, activation_fn=None)
        # Center the action scores around their mean over axis 1 before
        # adding the state score.
        action_scores_mean = tf.reduce_mean(action_scores, 1)
        action_scores_centered = action_scores - tf.expand_dims(
            action_scores_mean, 1)
        return state_score + action_scores_centered
|
||||
|
||||
|
||||
def cnn_to_mlp(convs, hiddens, dueling=False):
    """This model takes an observation and returns values for all actions.

    Parameters
    ----------
    convs: [(int, int, int)]
        list of convolutional layers in form of
        (num_outputs, kernel_size, stride)
    hiddens: [int]
        list of sizes of hidden layers
    dueling: bool
        if true double the output MLP to compute a baseline
        for action scores

    Returns
    -------
    q_func: function
        q_function for DQN algorithm.
    """

    return lambda *args, **kwargs: _cnn_to_mlp(
        convs, hiddens, dueling, *args, **kwargs)
|
||||
|
||||
@@ -9,188 +9,189 @@ from ray.rllib.dqn.common.segment_tree import SumSegmentTree, MinSegmentTree
|
||||
|
||||
|
||||
class ReplayBuffer(object):
    def __init__(self, size):
        """Create Replay buffer.

        Parameters
        ----------
        size: int
          Max number of transitions to store in the buffer. When the buffer
          overflows the old memories are dropped.
        """
        self._storage = []
        self._maxsize = size
        self._next_idx = 0

    def __len__(self):
        return len(self._storage)

    def add(self, obs_t, action, reward, obs_tp1, done):
        """Store one transition, overwriting the oldest slot once full."""
        data = (obs_t, action, reward, obs_tp1, done)

        if self._next_idx >= len(self._storage):
            self._storage.append(data)
        else:
            self._storage[self._next_idx] = data
        self._next_idx = (self._next_idx + 1) % self._maxsize

    def _encode_sample(self, idxes):
        """Stack the transitions at `idxes` into five parallel arrays."""
        obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
        for i in idxes:
            data = self._storage[i]
            obs_t, action, reward, obs_tp1, done = data
            # np.asarray avoids a copy when possible but, unlike
            # np.array(..., copy=False) on NumPy >= 2, does not raise when a
            # copy is required (e.g. for list or scalar inputs).
            obses_t.append(np.asarray(obs_t))
            actions.append(np.asarray(action))
            rewards.append(reward)
            obses_tp1.append(np.asarray(obs_tp1))
            dones.append(done)
        return (np.array(obses_t), np.array(actions), np.array(rewards),
                np.array(obses_tp1), np.array(dones))

    def sample(self, batch_size):
        """Sample a batch of experiences.

        Indices are drawn uniformly with replacement.

        Parameters
        ----------
        batch_size: int
          How many transitions to sample.

        Returns
        -------
        obs_batch: np.array
          batch of observations
        act_batch: np.array
          batch of actions executed given obs_batch
        rew_batch: np.array
          rewards received as results of executing act_batch
        next_obs_batch: np.array
          next set of observations seen after executing act_batch
        done_mask: np.array
          done_mask[i] = 1 if executing act_batch[i] resulted in
          the end of an episode and 0 otherwise.
        """
        idxes = [random.randint(0, len(self._storage) - 1)
                 for _ in range(batch_size)]
        return self._encode_sample(idxes)
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, size, alpha):
        """Create Prioritized Replay buffer.

        Parameters
        ----------
        size: int
          Max number of transitions to store in the buffer. When the buffer
          overflows the old memories are dropped.
        alpha: float
          how much prioritization is used
          (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayBuffer.__init__
        """
        super(PrioritizedReplayBuffer, self).__init__(size)
        assert alpha > 0
        self._alpha = alpha

        # Segment tree capacity: smallest power of two that is >= size.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def add(self, *args, **kwargs):
        """See ReplayBuffer.add"""
        # Capture the slot before the parent advances _next_idx.
        idx = self._next_idx
        super().add(*args, **kwargs)
        # New transitions start at the largest priority seen so far.
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        """Draw `batch_size` indices via prefix-sum lookup in the sum tree."""
        res = []
        for _ in range(batch_size):
            # TODO(szymon): should we ensure no repeats?
            # NOTE(review): whether the upper bound len - 1 is inclusive
            # depends on SumSegmentTree.sum -- confirm the newest entry is
            # not excluded from the mass.
            mass = random.random() * self._it_sum.sum(0,
                                                      len(self._storage) - 1)
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        compared to ReplayBuffer.sample
        it also returns importance weights and idxes
        of sampled experiences.


        Parameters
        ----------
        batch_size: int
          How many transitions to sample.
        beta: float
          To what degree to use importance weights
          (0 - no corrections, 1 - full correction)

        Returns
        -------
        obs_batch: np.array
          batch of observations
        act_batch: np.array
          batch of actions executed given obs_batch
        rew_batch: np.array
          rewards received as results of executing act_batch
        next_obs_batch: np.array
          next set of observations seen after executing act_batch
        done_mask: np.array
          done_mask[i] = 1 if executing act_batch[i] resulted in
          the end of an episode and 0 otherwise.
        weights: np.array
          Array of shape (batch_size,)
          denoting importance weight of each sampled transition
        idxes: [int]
          List of length batch_size with the
          indexes in buffer of sampled experiences
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        # The tree total does not change while sampling, so read it once.
        total = self._it_sum.sum()
        p_min = self._it_min.min() / total
        max_weight = (p_min * len(self._storage)) ** (-beta)

        weights = []
        for idx in idxes:
            p_sample = self._it_sum[idx] / total
            weight = (p_sample * len(self._storage)) ** (-beta)
            # Normalize by the largest possible weight.
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = self._encode_sample(idxes)
        return tuple(list(encoded_sample) + [weights, idxes])

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        sets priority of transition at index idxes[i] in buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
          List of idxes of sampled transitions
        priorities: [float]
          List of updated priorities corresponding to
          transitions at the sampled idxes denoted by
          variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self._storage)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
|
||||
|
||||
@@ -43,256 +43,267 @@ DEFAULT_CONFIG = dict(
|
||||
|
||||
@ray.remote
def create_shared_noise():
    """Create a large array of noise to be shared by all workers."""
    # Fixed seed, so every call produces the same table.
    seed = 123
    count = 250000000
    noise = np.random.RandomState(seed).randn(count).astype(np.float32)
    return noise
|
||||
|
||||
|
||||
class SharedNoiseTable(object):
|
||||
def __init__(self, noise):
|
||||
self.noise = noise
|
||||
assert self.noise.dtype == np.float32
|
||||
def __init__(self, noise):
|
||||
self.noise = noise
|
||||
assert self.noise.dtype == np.float32
|
||||
|
||||
def get(self, i, dim):
|
||||
return self.noise[i:i + dim]
|
||||
def get(self, i, dim):
|
||||
return self.noise[i:i + dim]
|
||||
|
||||
def sample_index(self, stream, dim):
|
||||
return stream.randint(0, len(self.noise) - dim + 1)
|
||||
def sample_index(self, stream, dim):
|
||||
return stream.randint(0, len(self.noise) - dim + 1)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Worker(object):
|
||||
def __init__(self, config, policy_params, env_name, noise,
|
||||
min_task_runtime=0.2):
|
||||
self.min_task_runtime = min_task_runtime
|
||||
self.config = config
|
||||
self.policy_params = policy_params
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
def __init__(self, config, policy_params, env_name, noise,
|
||||
min_task_runtime=0.2):
|
||||
self.min_task_runtime = min_task_runtime
|
||||
self.config = config
|
||||
self.policy_params = policy_params
|
||||
self.noise = SharedNoiseTable(noise)
|
||||
|
||||
self.env = gym.make(env_name)
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.MujocoPolicy(self.env.observation_space,
|
||||
self.env.action_space,
|
||||
**policy_params)
|
||||
tf_util.initialize()
|
||||
self.env = gym.make(env_name)
|
||||
self.sess = utils.make_session(single_threaded=True)
|
||||
self.policy = policies.MujocoPolicy(self.env.observation_space,
|
||||
self.env.action_space,
|
||||
**policy_params)
|
||||
tf_util.initialize()
|
||||
|
||||
self.rs = np.random.RandomState()
|
||||
self.rs = np.random.RandomState()
|
||||
|
||||
assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] != 0)
|
||||
assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] !=
|
||||
0)
|
||||
|
||||
def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
|
||||
if (self.policy.needs_ob_stat and self.config["calc_obstat_prob"] != 0 and
|
||||
self.rs.rand() < self.config["calc_obstat_prob"]):
|
||||
rollout_rews, rollout_len, obs = self.policy.rollout(
|
||||
self.env, timestep_limit=timestep_limit, save_obs=True,
|
||||
random_stream=self.rs)
|
||||
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
|
||||
len(obs))
|
||||
else:
|
||||
rollout_rews, rollout_len = self.policy.rollout(
|
||||
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
|
||||
return rollout_rews, rollout_len
|
||||
def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
|
||||
if (self.policy.needs_ob_stat and
|
||||
self.config["calc_obstat_prob"] != 0 and
|
||||
self.rs.rand() < self.config["calc_obstat_prob"]):
|
||||
rollout_rews, rollout_len, obs = self.policy.rollout(
|
||||
self.env, timestep_limit=timestep_limit, save_obs=True,
|
||||
random_stream=self.rs)
|
||||
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
|
||||
len(obs))
|
||||
else:
|
||||
rollout_rews, rollout_len = self.policy.rollout(
|
||||
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
|
||||
return rollout_rews, rollout_len
|
||||
|
||||
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
|
||||
# Set the network weights.
|
||||
self.policy.set_trainable_flat(params)
|
||||
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
|
||||
# Set the network weights.
|
||||
self.policy.set_trainable_flat(params)
|
||||
|
||||
if self.policy.needs_ob_stat:
|
||||
self.policy.set_ob_stat(ob_mean, ob_std)
|
||||
if self.policy.needs_ob_stat:
|
||||
self.policy.set_ob_stat(ob_mean, ob_std)
|
||||
|
||||
if self.config["eval_prob"] != 0:
|
||||
raise NotImplementedError("Eval rollouts are not implemented.")
|
||||
if self.config["eval_prob"] != 0:
|
||||
raise NotImplementedError("Eval rollouts are not implemented.")
|
||||
|
||||
noise_inds, returns, sign_returns, lengths = [], [], [], []
|
||||
# We set eps=0 because we're incrementing only.
|
||||
task_ob_stat = utils.RunningStat(self.env.observation_space.shape, eps=0)
|
||||
noise_inds, returns, sign_returns, lengths = [], [], [], []
|
||||
# We set eps=0 because we're incrementing only.
|
||||
task_ob_stat = utils.RunningStat(self.env.observation_space.shape,
|
||||
eps=0)
|
||||
|
||||
# Perform some rollouts with noise.
|
||||
task_tstart = time.time()
|
||||
while (len(noise_inds) == 0 or
|
||||
time.time() - task_tstart < self.min_task_runtime):
|
||||
noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
|
||||
perturbation = self.config["noise_stdev"] * self.noise.get(
|
||||
noise_idx, self.policy.num_params)
|
||||
# Perform some rollouts with noise.
|
||||
task_tstart = time.time()
|
||||
while (len(noise_inds) == 0 or
|
||||
time.time() - task_tstart < self.min_task_runtime):
|
||||
noise_idx = self.noise.sample_index(self.rs,
|
||||
self.policy.num_params)
|
||||
perturbation = self.config["noise_stdev"] * self.noise.get(
|
||||
noise_idx, self.policy.num_params)
|
||||
|
||||
# These two sampling steps could be done in parallel on different actors
|
||||
# letting us update twice as frequently.
|
||||
self.policy.set_trainable_flat(params + perturbation)
|
||||
rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
|
||||
task_ob_stat)
|
||||
# These two sampling steps could be done in parallel on different
|
||||
# actors letting us update twice as frequently.
|
||||
self.policy.set_trainable_flat(params + perturbation)
|
||||
rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
|
||||
task_ob_stat)
|
||||
|
||||
self.policy.set_trainable_flat(params - perturbation)
|
||||
rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
|
||||
task_ob_stat)
|
||||
self.policy.set_trainable_flat(params - perturbation)
|
||||
rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
|
||||
task_ob_stat)
|
||||
|
||||
noise_inds.append(noise_idx)
|
||||
returns.append([rews_pos.sum(), rews_neg.sum()])
|
||||
sign_returns.append([np.sign(rews_pos).sum(), np.sign(rews_neg).sum()])
|
||||
lengths.append([len_pos, len_neg])
|
||||
noise_inds.append(noise_idx)
|
||||
returns.append([rews_pos.sum(), rews_neg.sum()])
|
||||
sign_returns.append([np.sign(rews_pos).sum(),
|
||||
np.sign(rews_neg).sum()])
|
||||
lengths.append([len_pos, len_neg])
|
||||
|
||||
return Result(
|
||||
noise_inds_n=np.array(noise_inds),
|
||||
returns_n2=np.array(returns, dtype=np.float32),
|
||||
sign_returns_n2=np.array(sign_returns, dtype=np.float32),
|
||||
lengths_n2=np.array(lengths, dtype=np.int32),
|
||||
eval_return=None,
|
||||
eval_length=None,
|
||||
ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
|
||||
ob_sumsq=(None if task_ob_stat.count == 0 else task_ob_stat.sumsq),
|
||||
ob_count=task_ob_stat.count)
|
||||
return Result(
|
||||
noise_inds_n=np.array(noise_inds),
|
||||
returns_n2=np.array(returns, dtype=np.float32),
|
||||
sign_returns_n2=np.array(sign_returns, dtype=np.float32),
|
||||
lengths_n2=np.array(lengths, dtype=np.int32),
|
||||
eval_return=None,
|
||||
eval_length=None,
|
||||
ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
|
||||
ob_sumsq=(None if task_ob_stat.count == 0
|
||||
else task_ob_stat.sumsq),
|
||||
ob_count=task_ob_stat.count)
|
||||
|
||||
|
||||
class EvolutionStrategies(Algorithm):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "EvolutionStrategies"})
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "EvolutionStrategies"})
|
||||
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
policy_params = {
|
||||
"ac_bins": "continuous:",
|
||||
"ac_noise_std": 0.01,
|
||||
"nonlin_type": "tanh",
|
||||
"hidden_dims": [256, 256],
|
||||
"connection_type": "ff"
|
||||
}
|
||||
policy_params = {
|
||||
"ac_bins": "continuous:",
|
||||
"ac_noise_std": 0.01,
|
||||
"nonlin_type": "tanh",
|
||||
"hidden_dims": [256, 256],
|
||||
"connection_type": "ff"
|
||||
}
|
||||
|
||||
# Create the shared noise table.
|
||||
print("Creating shared noise table.")
|
||||
noise_id = create_shared_noise.remote()
|
||||
self.noise = SharedNoiseTable(ray.get(noise_id))
|
||||
# Create the shared noise table.
|
||||
print("Creating shared noise table.")
|
||||
noise_id = create_shared_noise.remote()
|
||||
self.noise = SharedNoiseTable(ray.get(noise_id))
|
||||
|
||||
# Create the actors.
|
||||
print("Creating actors.")
|
||||
self.workers = [Worker.remote(config, policy_params, env_name, noise_id)
|
||||
for _ in range(config["num_workers"])]
|
||||
# Create the actors.
|
||||
print("Creating actors.")
|
||||
self.workers = [Worker.remote(config, policy_params, env_name,
|
||||
noise_id)
|
||||
for _ in range(config["num_workers"])]
|
||||
|
||||
env = gym.make(env_name)
|
||||
utils.make_session(single_threaded=False)
|
||||
self.policy = policies.MujocoPolicy(
|
||||
env.observation_space, env.action_space, **policy_params)
|
||||
tf_util.initialize()
|
||||
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
|
||||
self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
|
||||
env = gym.make(env_name)
|
||||
utils.make_session(single_threaded=False)
|
||||
self.policy = policies.MujocoPolicy(
|
||||
env.observation_space, env.action_space, **policy_params)
|
||||
tf_util.initialize()
|
||||
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
|
||||
self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
|
||||
|
||||
self.episodes_so_far = 0
|
||||
self.timesteps_so_far = 0
|
||||
self.tstart = time.time()
|
||||
self.iteration = 0
|
||||
self.episodes_so_far = 0
|
||||
self.timesteps_so_far = 0
|
||||
self.tstart = time.time()
|
||||
self.iteration = 0
|
||||
|
||||
def train(self):
|
||||
config = self.config
|
||||
def train(self):
|
||||
config = self.config
|
||||
|
||||
step_tstart = time.time()
|
||||
theta = self.policy.get_trainable_flat()
|
||||
assert theta.dtype == np.float32
|
||||
step_tstart = time.time()
|
||||
theta = self.policy.get_trainable_flat()
|
||||
assert theta.dtype == np.float32
|
||||
|
||||
# Put the current policy weights in the object store.
|
||||
theta_id = ray.put(theta)
|
||||
# Use the actors to do rollouts, note that we pass in the ID of the policy
|
||||
# weights.
|
||||
rollout_ids = [worker.do_rollouts.remote(
|
||||
theta_id,
|
||||
self.ob_stat.mean if self.policy.needs_ob_stat else None,
|
||||
self.ob_stat.std if self.policy.needs_ob_stat else None)
|
||||
for worker in self.workers]
|
||||
# Put the current policy weights in the object store.
|
||||
theta_id = ray.put(theta)
|
||||
# Use the actors to do rollouts, note that we pass in the ID of the
|
||||
# policy weights.
|
||||
rollout_ids = [worker.do_rollouts.remote(
|
||||
theta_id,
|
||||
self.ob_stat.mean if self.policy.needs_ob_stat else None,
|
||||
self.ob_stat.std if self.policy.needs_ob_stat else None)
|
||||
for worker in self.workers]
|
||||
|
||||
# Get the results of the rollouts.
|
||||
results = ray.get(rollout_ids)
|
||||
# Get the results of the rollouts.
|
||||
results = ray.get(rollout_ids)
|
||||
|
||||
curr_task_results = []
|
||||
ob_count_this_batch = 0
|
||||
# Loop over the results
|
||||
for result in results:
|
||||
assert result.eval_length is None, "We aren't doing eval rollouts."
|
||||
assert result.noise_inds_n.ndim == 1
|
||||
assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
|
||||
assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
|
||||
assert result.returns_n2.dtype == np.float32
|
||||
curr_task_results = []
|
||||
ob_count_this_batch = 0
|
||||
# Loop over the results
|
||||
for result in results:
|
||||
assert result.eval_length is None, "We aren't doing eval rollouts."
|
||||
assert result.noise_inds_n.ndim == 1
|
||||
assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
|
||||
assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
|
||||
assert result.returns_n2.dtype == np.float32
|
||||
|
||||
result_num_eps = result.lengths_n2.size
|
||||
result_num_timesteps = result.lengths_n2.sum()
|
||||
self.episodes_so_far += result_num_eps
|
||||
self.timesteps_so_far += result_num_timesteps
|
||||
result_num_eps = result.lengths_n2.size
|
||||
result_num_timesteps = result.lengths_n2.sum()
|
||||
self.episodes_so_far += result_num_eps
|
||||
self.timesteps_so_far += result_num_timesteps
|
||||
|
||||
curr_task_results.append(result)
|
||||
# Update ob stats.
|
||||
if self.policy.needs_ob_stat and result.ob_count > 0:
|
||||
self.ob_stat.increment(result.ob_sum, result.ob_sumsq, result.ob_count)
|
||||
ob_count_this_batch += result.ob_count
|
||||
curr_task_results.append(result)
|
||||
# Update ob stats.
|
||||
if self.policy.needs_ob_stat and result.ob_count > 0:
|
||||
self.ob_stat.increment(result.ob_sum, result.ob_sumsq,
|
||||
result.ob_count)
|
||||
ob_count_this_batch += result.ob_count
|
||||
|
||||
# Assemble the results.
|
||||
noise_inds_n = np.concatenate([r.noise_inds_n for
|
||||
r in curr_task_results])
|
||||
returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
|
||||
lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
|
||||
assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0]
|
||||
# Process the returns.
|
||||
if config["return_proc_mode"] == "centered_rank":
|
||||
proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
|
||||
else:
|
||||
raise NotImplementedError(config["return_proc_mode"])
|
||||
# Assemble the results.
|
||||
noise_inds_n = np.concatenate([r.noise_inds_n for
|
||||
r in curr_task_results])
|
||||
returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
|
||||
lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
|
||||
assert (noise_inds_n.shape[0] == returns_n2.shape[0] ==
|
||||
lengths_n2.shape[0])
|
||||
# Process the returns.
|
||||
if config["return_proc_mode"] == "centered_rank":
|
||||
proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
|
||||
else:
|
||||
raise NotImplementedError(config["return_proc_mode"])
|
||||
|
||||
# Compute and take a step.
|
||||
g, count = utils.batched_weighted_sum(
|
||||
proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
|
||||
(self.noise.get(idx, self.policy.num_params) for idx in noise_inds_n),
|
||||
batch_size=500)
|
||||
g /= returns_n2.size
|
||||
assert (g.shape == (self.policy.num_params,) and g.dtype == np.float32 and
|
||||
count == len(noise_inds_n))
|
||||
update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta)
|
||||
# Compute and take a step.
|
||||
g, count = utils.batched_weighted_sum(
|
||||
proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
|
||||
(self.noise.get(idx, self.policy.num_params)
|
||||
for idx in noise_inds_n),
|
||||
batch_size=500)
|
||||
g /= returns_n2.size
|
||||
assert (g.shape == (self.policy.num_params,) and
|
||||
g.dtype == np.float32 and
|
||||
count == len(noise_inds_n))
|
||||
update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta)
|
||||
|
||||
# Update ob stat (we're never running the policy in the master, but we
|
||||
# might be snapshotting the policy).
|
||||
if self.policy.needs_ob_stat:
|
||||
self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std)
|
||||
# Update ob stat (we're never running the policy in the master, but we
|
||||
# might be snapshotting the policy).
|
||||
if self.policy.needs_ob_stat:
|
||||
self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std)
|
||||
|
||||
step_tend = time.time()
|
||||
tlogger.record_tabular("EpRewMean", returns_n2.mean())
|
||||
tlogger.record_tabular("EpRewStd", returns_n2.std())
|
||||
tlogger.record_tabular("EpLenMean", lengths_n2.mean())
|
||||
step_tend = time.time()
|
||||
tlogger.record_tabular("EpRewMean", returns_n2.mean())
|
||||
tlogger.record_tabular("EpRewStd", returns_n2.std())
|
||||
tlogger.record_tabular("EpLenMean", lengths_n2.mean())
|
||||
|
||||
tlogger.record_tabular(
|
||||
"Norm", float(np.square(self.policy.get_trainable_flat()).sum()))
|
||||
tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
|
||||
tlogger.record_tabular("UpdateRatio", float(update_ratio))
|
||||
tlogger.record_tabular(
|
||||
"Norm", float(np.square(self.policy.get_trainable_flat()).sum()))
|
||||
tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
|
||||
tlogger.record_tabular("UpdateRatio", float(update_ratio))
|
||||
|
||||
tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
|
||||
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
|
||||
tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
|
||||
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
|
||||
tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
|
||||
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
|
||||
tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
|
||||
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
|
||||
|
||||
tlogger.record_tabular("ObCount", ob_count_this_batch)
|
||||
tlogger.record_tabular("ObCount", ob_count_this_batch)
|
||||
|
||||
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
|
||||
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
|
||||
tlogger.dump_tabular()
|
||||
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
|
||||
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
|
||||
tlogger.dump_tabular()
|
||||
|
||||
if (config["snapshot_freq"] != 0 and
|
||||
self.iteration % config["snapshot_freq"] == 0):
|
||||
filename = os.path.join(
|
||||
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
|
||||
assert not os.path.exists(filename)
|
||||
self.policy.save(filename)
|
||||
tlogger.log("Saved snapshot {}".format(filename))
|
||||
if (config["snapshot_freq"] != 0 and
|
||||
self.iteration % config["snapshot_freq"] == 0):
|
||||
filename = os.path.join(
|
||||
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
|
||||
assert not os.path.exists(filename)
|
||||
self.policy.save(filename)
|
||||
tlogger.log("Saved snapshot {}".format(filename))
|
||||
|
||||
info = {
|
||||
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
|
||||
"grad_norm": np.square(g).sum(),
|
||||
"update_ratio": update_ratio,
|
||||
"episodes_this_iter": lengths_n2.size,
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
"timesteps_this_iter": lengths_n2.sum(),
|
||||
"timesteps_so_far": self.timesteps_so_far,
|
||||
"ob_count": ob_count_this_batch,
|
||||
"time_elapsed_this_iter": step_tend - step_tstart,
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
res = TrainingResult(self.experiment_id.hex, self.iteration,
|
||||
returns_n2.mean(), lengths_n2.mean(), info)
|
||||
info = {
|
||||
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
|
||||
"grad_norm": np.square(g).sum(),
|
||||
"update_ratio": update_ratio,
|
||||
"episodes_this_iter": lengths_n2.size,
|
||||
"episodes_so_far": self.episodes_so_far,
|
||||
"timesteps_this_iter": lengths_n2.sum(),
|
||||
"timesteps_so_far": self.timesteps_so_far,
|
||||
"ob_count": ob_count_this_batch,
|
||||
"time_elapsed_this_iter": step_tend - step_tstart,
|
||||
"time_elapsed": step_tend - self.tstart
|
||||
}
|
||||
res = TrainingResult(self.experiment_id.hex, self.iteration,
|
||||
returns_n2.mean(), lengths_n2.mean(), info)
|
||||
|
||||
self.iteration += 1
|
||||
self.iteration += 1
|
||||
|
||||
return res
|
||||
return res
|
||||
|
||||
@@ -11,30 +11,30 @@ from ray.rllib.evolution_strategies import EvolutionStrategies, DEFAULT_CONFIG
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Train an RL agent on Pong.")
|
||||
parser.add_argument("--num-workers", default=10, type=int,
|
||||
help=("The number of actors to create in aggregate "
|
||||
"across the cluster."))
|
||||
parser.add_argument("--env-name", default="Pendulum-v0", type=str,
|
||||
help="The name of the gym environment to use.")
|
||||
parser.add_argument("--stepsize", default=0.01, type=float,
|
||||
help="The stepsize to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser = argparse.ArgumentParser(description="Train an RL agent on Pong.")
|
||||
parser.add_argument("--num-workers", default=10, type=int,
|
||||
help=("The number of actors to create in aggregate "
|
||||
"across the cluster."))
|
||||
parser.add_argument("--env-name", default="Pendulum-v0", type=str,
|
||||
help="The name of the gym environment to use.")
|
||||
parser.add_argument("--stepsize", default=0.01, type=float,
|
||||
help="The stepsize to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
|
||||
args = parser.parse_args()
|
||||
num_workers = args.num_workers
|
||||
env_name = args.env_name
|
||||
stepsize = args.stepsize
|
||||
args = parser.parse_args()
|
||||
num_workers = args.num_workers
|
||||
env_name = args.env_name
|
||||
stepsize = args.stepsize
|
||||
|
||||
ray.init(redis_address=args.redis_address,
|
||||
num_workers=(0 if args.redis_address is None else None))
|
||||
ray.init(redis_address=args.redis_address,
|
||||
num_workers=(0 if args.redis_address is None else None))
|
||||
|
||||
config = DEFAULT_CONFIG._replace(
|
||||
num_workers=num_workers,
|
||||
stepsize=stepsize)
|
||||
config = DEFAULT_CONFIG._replace(
|
||||
num_workers=num_workers,
|
||||
stepsize=stepsize)
|
||||
|
||||
alg = EvolutionStrategies(env_name, config)
|
||||
while True:
|
||||
result = alg.train()
|
||||
print("current status: {}".format(result))
|
||||
alg = EvolutionStrategies(env_name, config)
|
||||
while True:
|
||||
result = alg.train()
|
||||
print("current status: {}".format(result))
|
||||
|
||||
@@ -9,49 +9,49 @@ import numpy as np
|
||||
|
||||
|
||||
class Optimizer(object):
|
||||
def __init__(self, pi):
|
||||
self.pi = pi
|
||||
self.dim = pi.num_params
|
||||
self.t = 0
|
||||
def __init__(self, pi):
|
||||
self.pi = pi
|
||||
self.dim = pi.num_params
|
||||
self.t = 0
|
||||
|
||||
def update(self, globalg):
|
||||
self.t += 1
|
||||
step = self._compute_step(globalg)
|
||||
theta = self.pi.get_trainable_flat()
|
||||
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
|
||||
self.pi.set_trainable_flat(theta + step)
|
||||
return ratio
|
||||
def update(self, globalg):
|
||||
self.t += 1
|
||||
step = self._compute_step(globalg)
|
||||
theta = self.pi.get_trainable_flat()
|
||||
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
|
||||
self.pi.set_trainable_flat(theta + step)
|
||||
return ratio
|
||||
|
||||
def _compute_step(self, globalg):
|
||||
raise NotImplementedError
|
||||
def _compute_step(self, globalg):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
def __init__(self, pi, stepsize, momentum=0.9):
|
||||
Optimizer.__init__(self, pi)
|
||||
self.v = np.zeros(self.dim, dtype=np.float32)
|
||||
self.stepsize, self.momentum = stepsize, momentum
|
||||
def __init__(self, pi, stepsize, momentum=0.9):
|
||||
Optimizer.__init__(self, pi)
|
||||
self.v = np.zeros(self.dim, dtype=np.float32)
|
||||
self.stepsize, self.momentum = stepsize, momentum
|
||||
|
||||
def _compute_step(self, globalg):
|
||||
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
|
||||
step = -self.stepsize * self.v
|
||||
return step
|
||||
def _compute_step(self, globalg):
|
||||
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
|
||||
step = -self.stepsize * self.v
|
||||
return step
|
||||
|
||||
|
||||
class Adam(Optimizer):
|
||||
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
|
||||
Optimizer.__init__(self, pi)
|
||||
self.stepsize = stepsize
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
self.epsilon = epsilon
|
||||
self.m = np.zeros(self.dim, dtype=np.float32)
|
||||
self.v = np.zeros(self.dim, dtype=np.float32)
|
||||
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
|
||||
Optimizer.__init__(self, pi)
|
||||
self.stepsize = stepsize
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
self.epsilon = epsilon
|
||||
self.m = np.zeros(self.dim, dtype=np.float32)
|
||||
self.v = np.zeros(self.dim, dtype=np.float32)
|
||||
|
||||
def _compute_step(self, globalg):
|
||||
a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
|
||||
(1 - self.beta1 ** self.t))
|
||||
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
|
||||
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
|
||||
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
|
||||
return step
|
||||
def _compute_step(self, globalg):
|
||||
a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
|
||||
(1 - self.beta1 ** self.t))
|
||||
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
|
||||
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
|
||||
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
|
||||
return step
|
||||
|
||||
@@ -18,224 +18,235 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Policy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args, self.kwargs = args, kwargs
|
||||
self.scope = self._initialize(*args, **kwargs)
|
||||
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
|
||||
self.scope.name)
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args, self.kwargs = args, kwargs
|
||||
self.scope = self._initialize(*args, **kwargs)
|
||||
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
|
||||
self.scope.name)
|
||||
|
||||
self.trainable_variables = tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
|
||||
self.num_params = sum(int(np.prod(v.get_shape().as_list()))
|
||||
for v in self.trainable_variables)
|
||||
self._setfromflat = U.SetFromFlat(self.trainable_variables)
|
||||
self._getflat = U.GetFlat(self.trainable_variables)
|
||||
self.trainable_variables = tf.get_collection(
|
||||
tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
|
||||
self.num_params = sum(int(np.prod(v.get_shape().as_list()))
|
||||
for v in self.trainable_variables)
|
||||
self._setfromflat = U.SetFromFlat(self.trainable_variables)
|
||||
self._getflat = U.GetFlat(self.trainable_variables)
|
||||
|
||||
logger.info('Trainable variables ({} parameters)'.format(self.num_params))
|
||||
for v in self.trainable_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
|
||||
logger.info('All variables')
|
||||
for v in self.all_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
|
||||
logger.info('Trainable variables ({} parameters)'
|
||||
.format(self.num_params))
|
||||
for v in self.trainable_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp,
|
||||
np.prod(shp)))
|
||||
logger.info('All variables')
|
||||
for v in self.all_variables:
|
||||
shp = v.get_shape().as_list()
|
||||
logger.info('- {} shape:{} size:{}'.format(v.name, shp,
|
||||
np.prod(shp)))
|
||||
|
||||
placeholders = [tf.placeholder(v.value().dtype, v.get_shape().as_list())
|
||||
for v in self.all_variables]
|
||||
self.set_all_vars = U.function(
|
||||
inputs=placeholders,
|
||||
outputs=[],
|
||||
updates=[tf.group(*[v.assign(p) for v, p
|
||||
in zip(self.all_variables, placeholders)])]
|
||||
)
|
||||
placeholders = [tf.placeholder(v.value().dtype,
|
||||
v.get_shape().as_list())
|
||||
for v in self.all_variables]
|
||||
self.set_all_vars = U.function(
|
||||
inputs=placeholders,
|
||||
outputs=[],
|
||||
updates=[tf.group(*[v.assign(p) for v, p
|
||||
in zip(self.all_variables, placeholders)])]
|
||||
)
|
||||
|
||||
def _initialize(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
def _initialize(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, filename):
|
||||
assert filename.endswith('.h5')
|
||||
with h5py.File(filename, 'w') as f:
|
||||
for v in self.all_variables:
|
||||
f[v.name] = v.eval()
|
||||
# TODO: It would be nice to avoid pickle, but it's convenient to pass
|
||||
# Python objects to _initialize (like Gym spaces or numpy arrays).
|
||||
f.attrs['name'] = type(self).__name__
|
||||
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args,
|
||||
self.kwargs),
|
||||
protocol=-1))
|
||||
def save(self, filename):
|
||||
assert filename.endswith('.h5')
|
||||
with h5py.File(filename, 'w') as f:
|
||||
for v in self.all_variables:
|
||||
f[v.name] = v.eval()
|
||||
# TODO: It would be nice to avoid pickle, but it's convenient to
|
||||
# pass Python objects to _initialize (like Gym spaces or numpy
|
||||
# arrays).
|
||||
f.attrs['name'] = type(self).__name__
|
||||
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args,
|
||||
self.kwargs),
|
||||
protocol=-1))
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filename, extra_kwargs=None):
|
||||
with h5py.File(filename, 'r') as f:
|
||||
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
policy = cls(*args, **kwargs)
|
||||
policy.set_all_vars(*[f[v.name][...] for v in policy.all_variables])
|
||||
return policy
|
||||
@classmethod
|
||||
def Load(cls, filename, extra_kwargs=None):
|
||||
with h5py.File(filename, 'r') as f:
|
||||
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
policy = cls(*args, **kwargs)
|
||||
policy.set_all_vars(*[f[v.name][...]
|
||||
for v in policy.all_variables])
|
||||
return policy
|
||||
|
||||
# === Rollouts/training ===
|
||||
# === Rollouts/training ===
|
||||
|
||||
def rollout(self, env, render=False, timestep_limit=None, save_obs=False,
|
||||
random_stream=None):
|
||||
"""Do a rollout.
|
||||
def rollout(self, env, render=False, timestep_limit=None, save_obs=False,
|
||||
random_stream=None):
|
||||
"""Do a rollout.
|
||||
|
||||
If random_stream is provided, the rollout will take noisy actions with
|
||||
noise drawn from that stream. Otherwise, no action noise will be added.
|
||||
"""
|
||||
env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
timestep_limit = (env_timestep_limit if timestep_limit is None
|
||||
else min(timestep_limit, env_timestep_limit))
|
||||
rews = []
|
||||
t = 0
|
||||
if save_obs:
|
||||
obs = []
|
||||
ob = env.reset()
|
||||
for _ in range(timestep_limit):
|
||||
ac = self.act(ob[None], random_stream=random_stream)[0]
|
||||
if save_obs:
|
||||
obs.append(ob)
|
||||
ob, rew, done, _ = env.step(ac)
|
||||
rews.append(rew)
|
||||
t += 1
|
||||
if render:
|
||||
env.render()
|
||||
if done:
|
||||
break
|
||||
rews = np.array(rews, dtype=np.float32)
|
||||
if save_obs:
|
||||
return rews, t, np.array(obs)
|
||||
return rews, t
|
||||
If random_stream is provided, the rollout will take noisy actions with
|
||||
noise drawn from that stream. Otherwise, no action noise will be added.
|
||||
"""
|
||||
env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
|
||||
".max_episode_steps")
|
||||
timestep_limit = (env_timestep_limit if timestep_limit is None
|
||||
else min(timestep_limit, env_timestep_limit))
|
||||
rews = []
|
||||
t = 0
|
||||
if save_obs:
|
||||
obs = []
|
||||
ob = env.reset()
|
||||
for _ in range(timestep_limit):
|
||||
ac = self.act(ob[None], random_stream=random_stream)[0]
|
||||
if save_obs:
|
||||
obs.append(ob)
|
||||
ob, rew, done, _ = env.step(ac)
|
||||
rews.append(rew)
|
||||
t += 1
|
||||
if render:
|
||||
env.render()
|
||||
if done:
|
||||
break
|
||||
rews = np.array(rews, dtype=np.float32)
|
||||
if save_obs:
|
||||
return rews, t, np.array(obs)
|
||||
return rews, t
|
||||
|
||||
def act(self, ob, random_stream=None):
|
||||
raise NotImplementedError
|
||||
def act(self, ob, random_stream=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_trainable_flat(self, x):
|
||||
self._setfromflat(x)
|
||||
def set_trainable_flat(self, x):
|
||||
self._setfromflat(x)
|
||||
|
||||
def get_trainable_flat(self):
|
||||
return self._getflat()
|
||||
def get_trainable_flat(self):
|
||||
return self._getflat()
|
||||
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
raise NotImplementedError
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
raise NotImplementedError
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def bins(x, dim, num_bins, name):
|
||||
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
|
||||
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
|
||||
return tf.argmax(scores_nab, 2)
|
||||
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
|
||||
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
|
||||
return tf.argmax(scores_nab, 2)
|
||||
|
||||
|
||||
class MujocoPolicy(Policy):
|
||||
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, nonlin_type,
|
||||
hidden_dims, connection_type):
|
||||
self.ac_space = ac_space
|
||||
self.ac_bins = ac_bins
|
||||
self.ac_noise_std = ac_noise_std
|
||||
self.hidden_dims = hidden_dims
|
||||
self.connection_type = connection_type
|
||||
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std,
|
||||
nonlin_type, hidden_dims, connection_type):
|
||||
self.ac_space = ac_space
|
||||
self.ac_bins = ac_bins
|
||||
self.ac_noise_std = ac_noise_std
|
||||
self.hidden_dims = hidden_dims
|
||||
self.connection_type = connection_type
|
||||
|
||||
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
|
||||
assert (np.all(np.isfinite(self.ac_space.low)) and
|
||||
np.all(np.isfinite(self.ac_space.high))), "Action bounds required"
|
||||
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
|
||||
assert (np.all(np.isfinite(self.ac_space.low)) and
|
||||
np.all(np.isfinite(self.ac_space.high))), ("Action bounds "
|
||||
"required")
|
||||
|
||||
self.nonlin = {'tanh': tf.tanh,
|
||||
'relu': tf.nn.relu,
|
||||
'lrelu': U.lrelu,
|
||||
'elu': tf.nn.elu}[nonlin_type]
|
||||
self.nonlin = {'tanh': tf.tanh,
|
||||
'relu': tf.nn.relu,
|
||||
'lrelu': U.lrelu,
|
||||
'elu': tf.nn.elu}[nonlin_type]
|
||||
|
||||
with tf.variable_scope(type(self).__name__) as scope:
|
||||
# Observation normalization.
|
||||
ob_mean = tf.get_variable(
|
||||
'ob_mean', ob_space.shape, tf.float32,
|
||||
tf.constant_initializer(np.nan), trainable=False)
|
||||
ob_std = tf.get_variable(
|
||||
'ob_std', ob_space.shape, tf.float32,
|
||||
tf.constant_initializer(np.nan), trainable=False)
|
||||
in_mean = tf.placeholder(tf.float32, ob_space.shape)
|
||||
in_std = tf.placeholder(tf.float32, ob_space.shape)
|
||||
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
|
||||
tf.assign(ob_mean, in_mean),
|
||||
tf.assign(ob_std, in_std),
|
||||
])
|
||||
with tf.variable_scope(type(self).__name__) as scope:
|
||||
# Observation normalization.
|
||||
ob_mean = tf.get_variable(
|
||||
'ob_mean', ob_space.shape, tf.float32,
|
||||
tf.constant_initializer(np.nan), trainable=False)
|
||||
ob_std = tf.get_variable(
|
||||
'ob_std', ob_space.shape, tf.float32,
|
||||
tf.constant_initializer(np.nan), trainable=False)
|
||||
in_mean = tf.placeholder(tf.float32, ob_space.shape)
|
||||
in_std = tf.placeholder(tf.float32, ob_space.shape)
|
||||
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
|
||||
tf.assign(ob_mean, in_mean),
|
||||
tf.assign(ob_std, in_std),
|
||||
])
|
||||
|
||||
# Policy network.
|
||||
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, -5.0, 5.0))
|
||||
self._act = U.function([o], a)
|
||||
return scope
|
||||
# Policy network.
|
||||
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
|
||||
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std,
|
||||
-5.0, 5.0))
|
||||
self._act = U.function([o], a)
|
||||
return scope
|
||||
|
||||
def _make_net(self, o):
|
||||
# Process observation.
|
||||
if self.connection_type == 'ff':
|
||||
x = o
|
||||
for ilayer, hd in enumerate(self.hidden_dims):
|
||||
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer),
|
||||
U.normc_initializer(1.0)))
|
||||
else:
|
||||
raise NotImplementedError(self.connection_type)
|
||||
def _make_net(self, o):
|
||||
# Process observation.
|
||||
if self.connection_type == 'ff':
|
||||
x = o
|
||||
for ilayer, hd in enumerate(self.hidden_dims):
|
||||
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer),
|
||||
U.normc_initializer(1.0)))
|
||||
else:
|
||||
raise NotImplementedError(self.connection_type)
|
||||
|
||||
# Map to action.
|
||||
adim = self.ac_space.shape[0]
|
||||
ahigh = self.ac_space.high
|
||||
alow = self.ac_space.low
|
||||
assert isinstance(self.ac_bins, str)
|
||||
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
|
||||
# Map to action.
|
||||
adim = self.ac_space.shape[0]
|
||||
ahigh = self.ac_space.high
|
||||
alow = self.ac_space.low
|
||||
assert isinstance(self.ac_bins, str)
|
||||
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
|
||||
|
||||
if ac_bin_mode == 'uniform':
|
||||
# Uniformly spaced bins, from ac_space.low to ac_space.high.
|
||||
num_ac_bins = int(ac_bin_arg)
|
||||
aidx_na = bins(x, adim, num_ac_bins, 'out')
|
||||
ac_range_1a = (ahigh - alow)[None, :]
|
||||
a = (1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a +
|
||||
alow[None, :])
|
||||
if ac_bin_mode == 'uniform':
|
||||
# Uniformly spaced bins, from ac_space.low to ac_space.high.
|
||||
num_ac_bins = int(ac_bin_arg)
|
||||
aidx_na = bins(x, adim, num_ac_bins, 'out')
|
||||
ac_range_1a = (ahigh - alow)[None, :]
|
||||
a = (1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a +
|
||||
alow[None, :])
|
||||
|
||||
elif ac_bin_mode == 'custom':
|
||||
# Custom bins specified as a list of values from -1 to 1.
|
||||
# The bins are rescaled to ac_space.low to ac_space.high.
|
||||
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))),
|
||||
dtype=np.float32)
|
||||
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x)
|
||||
for x in acvals_k))
|
||||
assert acvals_k.ndim == 1 and acvals_k[0] == -1 and acvals_k[-1] == 1
|
||||
acvals_ak = ((ahigh - alow)[:, None] / (acvals_k[-1] - acvals_k[0]) *
|
||||
(acvals_k - acvals_k[0])[None, :] + alow[:, None])
|
||||
elif ac_bin_mode == 'custom':
|
||||
# Custom bins specified as a list of values from -1 to 1.
|
||||
# The bins are rescaled to ac_space.low to ac_space.high.
|
||||
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))),
|
||||
dtype=np.float32)
|
||||
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x)
|
||||
for x in acvals_k))
|
||||
assert (acvals_k.ndim == 1 and acvals_k[0] == -1 and
|
||||
acvals_k[-1] == 1)
|
||||
acvals_ak = ((ahigh - alow)[:, None] /
|
||||
(acvals_k[-1] - acvals_k[0]) *
|
||||
(acvals_k - acvals_k[0])[None, :] + alow[:, None])
|
||||
|
||||
aidx_na = bins(x, adim, len(acvals_k), 'out') # Values in [0, k-1].
|
||||
a = tf.gather_nd(
|
||||
acvals_ak,
|
||||
tf.concat([
|
||||
tf.tile(np.arange(adim)[None, :, None],
|
||||
[tf.shape(aidx_na)[0], 1, 1]),
|
||||
2,
|
||||
tf.expand_dims(aidx_na, -1)
|
||||
]) # (n, a, 2)
|
||||
) # (n, a)
|
||||
elif ac_bin_mode == 'continuous':
|
||||
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
|
||||
else:
|
||||
raise NotImplementedError(ac_bin_mode)
|
||||
aidx_na = bins(x, adim, len(acvals_k),
|
||||
'out') # Values in [0, k-1].
|
||||
a = tf.gather_nd(
|
||||
acvals_ak,
|
||||
tf.concat([
|
||||
tf.tile(np.arange(adim)[None, :, None],
|
||||
[tf.shape(aidx_na)[0], 1, 1]),
|
||||
2,
|
||||
tf.expand_dims(aidx_na, -1)
|
||||
]) # (n, a, 2)
|
||||
) # (n, a)
|
||||
elif ac_bin_mode == 'continuous':
|
||||
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
|
||||
else:
|
||||
raise NotImplementedError(ac_bin_mode)
|
||||
|
||||
return a
|
||||
return a
|
||||
|
||||
def act(self, ob, random_stream=None):
|
||||
a = self._act(ob)
|
||||
if random_stream is not None and self.ac_noise_std != 0:
|
||||
a += random_stream.randn(*a.shape) * self.ac_noise_std
|
||||
return a
|
||||
def act(self, ob, random_stream=None):
|
||||
a = self._act(ob)
|
||||
if random_stream is not None and self.ac_noise_std != 0:
|
||||
a += random_stream.randn(*a.shape) * self.ac_noise_std
|
||||
return a
|
||||
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
return True
|
||||
@property
|
||||
def needs_ob_stat(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def needs_ref_batch(self):
|
||||
return False
|
||||
@property
|
||||
def needs_ref_batch(self):
|
||||
return False
|
||||
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
self._set_ob_mean_std(ob_mean, ob_std)
|
||||
def set_ob_stat(self, ob_mean, ob_std):
|
||||
self._set_ob_mean_std(ob_mean, ob_std)
|
||||
|
||||
@@ -24,199 +24,201 @@ DISABLED = 50
|
||||
|
||||
|
||||
class TbWriter(object):
|
||||
"""Based on SummaryWriter, but changed to allow for a different prefix."""
|
||||
def __init__(self, dir, prefix):
|
||||
self.dir = dir
|
||||
# Start at 1, because EvWriter automatically generates an object with
|
||||
# step = 0.
|
||||
self.step = 1
|
||||
self.evwriter = pywrap_tensorflow.EventsWriter(
|
||||
compat.as_bytes(os.path.join(dir, prefix)))
|
||||
"""Based on SummaryWriter, but changed to allow for a different prefix."""
|
||||
def __init__(self, dir, prefix):
|
||||
self.dir = dir
|
||||
# Start at 1, because EvWriter automatically generates an object with
|
||||
# step = 0.
|
||||
self.step = 1
|
||||
self.evwriter = pywrap_tensorflow.EventsWriter(
|
||||
compat.as_bytes(os.path.join(dir, prefix)))
|
||||
|
||||
def write_values(self, key2val):
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=float(v))
|
||||
for (k, v) in key2val.items()])
|
||||
event = event_pb2.Event(wall_time=time.time(), summary=summary)
|
||||
event.step = self.step
|
||||
self.evwriter.WriteEvent(event)
|
||||
self.evwriter.Flush()
|
||||
self.step += 1
|
||||
def write_values(self, key2val):
|
||||
summary = tf.Summary(value=[tf.Summary.Value(tag=k,
|
||||
simple_value=float(v))
|
||||
for (k, v) in key2val.items()])
|
||||
event = event_pb2.Event(wall_time=time.time(), summary=summary)
|
||||
event.step = self.step
|
||||
self.evwriter.WriteEvent(event)
|
||||
self.evwriter.Flush()
|
||||
self.step += 1
|
||||
|
||||
def close(self):
|
||||
self.evwriter.Close()
|
||||
def close(self):
|
||||
self.evwriter.Close()
|
||||
|
||||
# API
|
||||
|
||||
|
||||
def start(dir):
|
||||
if _Logger.CURRENT is not _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but you "
|
||||
"never stopped the previous logger (dir=%s)."
|
||||
"\n" % (dir, _Logger.CURRENT.dir))
|
||||
_Logger.CURRENT = _Logger(dir=dir)
|
||||
if _Logger.CURRENT is not _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but "
|
||||
"you never stopped the previous logger (dir=%s)."
|
||||
"\n" % (dir, _Logger.CURRENT.dir))
|
||||
_Logger.CURRENT = _Logger(dir=dir)
|
||||
|
||||
|
||||
def stop():
|
||||
if _Logger.CURRENT is _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to stop logging, but you never "
|
||||
"started any previous logger."
|
||||
"\n" % (dir, _Logger.CURRENT.dir))
|
||||
return
|
||||
_Logger.CURRENT.close()
|
||||
_Logger.CURRENT = _Logger.DEFAULT
|
||||
if _Logger.CURRENT is _Logger.DEFAULT:
|
||||
sys.stderr.write("WARNING: You asked to stop logging, but you never "
|
||||
"started any previous logger."
|
||||
"\n" % (dir, _Logger.CURRENT.dir))
|
||||
return
|
||||
_Logger.CURRENT.close()
|
||||
_Logger.CURRENT = _Logger.DEFAULT
|
||||
|
||||
|
||||
def record_tabular(key, val):
|
||||
"""Log a value of some diagnostic.
|
||||
"""Log a value of some diagnostic.
|
||||
|
||||
Call this once for each diagnostic quantity, each iteration.
|
||||
"""
|
||||
_Logger.CURRENT.record_tabular(key, val)
|
||||
Call this once for each diagnostic quantity, each iteration.
|
||||
"""
|
||||
_Logger.CURRENT.record_tabular(key, val)
|
||||
|
||||
|
||||
def dump_tabular():
|
||||
"""Write all of the diagnostics from the current iteration."""
|
||||
_Logger.CURRENT.dump_tabular()
|
||||
"""Write all of the diagnostics from the current iteration."""
|
||||
_Logger.CURRENT.dump_tabular()
|
||||
|
||||
|
||||
def log(*args, **kwargs):
|
||||
"""Write the sequence of args, with no separators.
|
||||
"""Write the sequence of args, with no separators.
|
||||
|
||||
This is written to the console and output files (if you've configured an
|
||||
output file).
|
||||
"""
|
||||
level = kwargs['level'] if 'level' in kwargs else INFO
|
||||
_Logger.CURRENT.log(*args, level=level)
|
||||
This is written to the console and output files (if you've configured an
|
||||
output file).
|
||||
"""
|
||||
level = kwargs['level'] if 'level' in kwargs else INFO
|
||||
_Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
|
||||
def debug(*args):
|
||||
log(*args, level=DEBUG)
|
||||
log(*args, level=DEBUG)
|
||||
|
||||
|
||||
def info(*args):
|
||||
log(*args, level=INFO)
|
||||
log(*args, level=INFO)
|
||||
|
||||
|
||||
def warn(*args):
|
||||
log(*args, level=WARN)
|
||||
log(*args, level=WARN)
|
||||
|
||||
|
||||
def error(*args):
|
||||
log(*args, level=ERROR)
|
||||
log(*args, level=ERROR)
|
||||
|
||||
|
||||
def set_level(level):
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
_Logger.CURRENT.set_level(level)
|
||||
"""
|
||||
Set logging threshold on current logger.
|
||||
"""
|
||||
_Logger.CURRENT.set_level(level)
|
||||
|
||||
|
||||
def get_dir():
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call start)
|
||||
"""
|
||||
return _Logger.CURRENT.get_dir()
|
||||
"""
|
||||
Get directory that log files are being written to.
|
||||
will be None if there is no output directory (i.e., if you didn't call
|
||||
start)
|
||||
"""
|
||||
return _Logger.CURRENT.get_dir()
|
||||
|
||||
|
||||
def get_expt_dir():
|
||||
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
|
||||
return get_dir()
|
||||
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
|
||||
return get_dir()
|
||||
|
||||
# Backend
|
||||
|
||||
|
||||
class _Logger(object):
|
||||
# A logger with no output files. (See right below class definition) so that
|
||||
# you can still log to the terminal without setting up any output files.
|
||||
DEFAULT = None
|
||||
# Current logger being used by the free functions above.
|
||||
CURRENT = None
|
||||
# A logger with no output files. (See right below class definition) so that
|
||||
# you can still log to the terminal without setting up any output files.
|
||||
DEFAULT = None
|
||||
# Current logger being used by the free functions above.
|
||||
CURRENT = None
|
||||
|
||||
def __init__(self, dir=None):
|
||||
self.name2val = OrderedDict() # Values this iteration.
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.text_outputs = [sys.stdout]
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
|
||||
self.tbwriter = TbWriter(dir=dir, prefix="events")
|
||||
else:
|
||||
self.tbwriter = None
|
||||
def __init__(self, dir=None):
|
||||
self.name2val = OrderedDict() # Values this iteration.
|
||||
self.level = INFO
|
||||
self.dir = dir
|
||||
self.text_outputs = [sys.stdout]
|
||||
if dir is not None:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
|
||||
self.tbwriter = TbWriter(dir=dir, prefix="events")
|
||||
else:
|
||||
self.tbwriter = None
|
||||
|
||||
# Logging API, forwarded
|
||||
# Logging API, forwarded
|
||||
|
||||
def record_tabular(self, key, val):
|
||||
self.name2val[key] = val
|
||||
def record_tabular(self, key, val):
|
||||
self.name2val[key] = val
|
||||
|
||||
def dump_tabular(self):
|
||||
# Create strings for printing.
|
||||
key2str = OrderedDict()
|
||||
for (key, val) in self.name2val.items():
|
||||
if hasattr(val, "__float__"):
|
||||
valstr = "%-8.3g" % val
|
||||
else:
|
||||
valstr = val
|
||||
key2str[self._truncate(key)] = self._truncate(valstr)
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
# Write to all text outputs
|
||||
self._write_text("-" * (keywidth + valwidth + 7), "\n")
|
||||
for (key, val) in key2str.items():
|
||||
self._write_text("| ", key, " " * (keywidth - len(key)), " | ", val,
|
||||
" " * (valwidth - len(val)), " |\n")
|
||||
self._write_text("-" * (keywidth + valwidth + 7), "\n")
|
||||
for f in self.text_outputs:
|
||||
try:
|
||||
f.flush()
|
||||
except OSError:
|
||||
sys.stderr.write('Warning! OSError when flushing.\n')
|
||||
# Write to tensorboard
|
||||
if self.tbwriter is not None:
|
||||
self.tbwriter.write_values(self.name2val)
|
||||
self.name2val.clear()
|
||||
def dump_tabular(self):
|
||||
# Create strings for printing.
|
||||
key2str = OrderedDict()
|
||||
for (key, val) in self.name2val.items():
|
||||
if hasattr(val, "__float__"):
|
||||
valstr = "%-8.3g" % val
|
||||
else:
|
||||
valstr = val
|
||||
key2str[self._truncate(key)] = self._truncate(valstr)
|
||||
keywidth = max(map(len, key2str.keys()))
|
||||
valwidth = max(map(len, key2str.values()))
|
||||
# Write to all text outputs
|
||||
self._write_text("-" * (keywidth + valwidth + 7), "\n")
|
||||
for (key, val) in key2str.items():
|
||||
self._write_text("| ", key, " " * (keywidth - len(key)),
|
||||
" | ", val, " " * (valwidth - len(val)), " |\n")
|
||||
self._write_text("-" * (keywidth + valwidth + 7), "\n")
|
||||
for f in self.text_outputs:
|
||||
try:
|
||||
f.flush()
|
||||
except OSError:
|
||||
sys.stderr.write('Warning! OSError when flushing.\n')
|
||||
# Write to tensorboard
|
||||
if self.tbwriter is not None:
|
||||
self.tbwriter.write_values(self.name2val)
|
||||
self.name2val.clear()
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
level = kwargs['level'] if 'level' in kwargs else INFO
|
||||
if self.level <= level:
|
||||
self._do_log(*args)
|
||||
def log(self, *args, **kwargs):
|
||||
level = kwargs['level'] if 'level' in kwargs else INFO
|
||||
if self.level <= level:
|
||||
self._do_log(*args)
|
||||
|
||||
# Configuration
|
||||
# Configuration
|
||||
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
def set_level(self, level):
|
||||
self.level = level
|
||||
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
def get_dir(self):
|
||||
return self.dir
|
||||
|
||||
def close(self):
|
||||
for f in self.text_outputs[1:]:
|
||||
f.close()
|
||||
if self.tbwriter:
|
||||
self.tbwriter.close()
|
||||
def close(self):
|
||||
for f in self.text_outputs[1:]:
|
||||
f.close()
|
||||
if self.tbwriter:
|
||||
self.tbwriter.close()
|
||||
|
||||
# Misc
|
||||
# Misc
|
||||
|
||||
def _do_log(self, *args):
|
||||
self._write_text(*args + ('\n',))
|
||||
for f in self.text_outputs:
|
||||
try:
|
||||
f.flush()
|
||||
except OSError:
|
||||
print('Warning! OSError when flushing.')
|
||||
def _do_log(self, *args):
|
||||
self._write_text(*args + ('\n',))
|
||||
for f in self.text_outputs:
|
||||
try:
|
||||
f.flush()
|
||||
except OSError:
|
||||
print('Warning! OSError when flushing.')
|
||||
|
||||
def _write_text(self, *strings):
|
||||
for f in self.text_outputs:
|
||||
for string in strings:
|
||||
f.write(string)
|
||||
def _write_text(self, *strings):
|
||||
for f in self.text_outputs:
|
||||
for string in strings:
|
||||
f.write(string)
|
||||
|
||||
def _truncate(self, s):
|
||||
if len(s) > 33:
|
||||
return s[:30] + "..."
|
||||
else:
|
||||
return s
|
||||
def _truncate(self, s):
|
||||
if len(s) > 33:
|
||||
return s[:30] + "..."
|
||||
else:
|
||||
return s
|
||||
|
||||
|
||||
_Logger.DEFAULT = _Logger()
|
||||
|
||||
@@ -12,8 +12,8 @@ import os
|
||||
|
||||
# Tensorflow must be at least version 1.0.0 for the example to work.
|
||||
if int(tf.__version__.split(".")[0]) < 1:
|
||||
raise Exception("Your Tensorflow version is less than 1.0.0. Please update "
|
||||
"Tensorflow to the latest version.")
|
||||
raise Exception("Your Tensorflow version is less than 1.0.0. Please "
|
||||
"update Tensorflow to the latest version.")
|
||||
|
||||
# ================================================================
|
||||
# Import all names into common namespace
|
||||
@@ -25,160 +25,163 @@ clip = tf.clip_by_value
|
||||
|
||||
|
||||
def sum(x, axis=None, keepdims=False):
|
||||
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
|
||||
|
||||
def mean(x, axis=None, keepdims=False):
|
||||
return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
return tf.reduce_mean(x, reduction_indices=(None if axis is None
|
||||
else [axis]),
|
||||
keep_dims=keepdims)
|
||||
|
||||
|
||||
def var(x, axis=None, keepdims=False):
|
||||
meanx = mean(x, axis=axis, keepdims=keepdims)
|
||||
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
|
||||
meanx = mean(x, axis=axis, keepdims=keepdims)
|
||||
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
|
||||
|
||||
|
||||
def std(x, axis=None, keepdims=False):
|
||||
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
|
||||
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False):
|
||||
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False):
|
||||
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis],
|
||||
keep_dims=keepdims)
|
||||
|
||||
|
||||
def concatenate(arrs, axis=0):
|
||||
return tf.concat(arrs, axis)
|
||||
return tf.concat(arrs, axis)
|
||||
|
||||
|
||||
def argmax(x, axis=None):
|
||||
return tf.argmax(x, dimension=axis)
|
||||
return tf.argmax(x, dimension=axis)
|
||||
|
||||
# Extras
|
||||
|
||||
|
||||
def l2loss(params):
|
||||
if len(params) == 0:
|
||||
return tf.constant(0.0)
|
||||
else:
|
||||
return tf.add_n([sum(tf.square(p)) for p in params])
|
||||
if len(params) == 0:
|
||||
return tf.constant(0.0)
|
||||
else:
|
||||
return tf.add_n([sum(tf.square(p)) for p in params])
|
||||
|
||||
|
||||
def lrelu(x, leak=0.2):
|
||||
f1 = 0.5 * (1 + leak)
|
||||
f2 = 0.5 * (1 - leak)
|
||||
return f1 * x + f2 * abs(x)
|
||||
f1 = 0.5 * (1 + leak)
|
||||
f2 = 0.5 * (1 - leak)
|
||||
return f1 * x + f2 * abs(x)
|
||||
|
||||
|
||||
def categorical_sample_logits(X):
|
||||
# https://github.com/tensorflow/tensorflow/issues/456
|
||||
U = tf.random_uniform(tf.shape(X))
|
||||
return argmax(X - tf.log(-tf.log(U)), axis=1)
|
||||
# https://github.com/tensorflow/tensorflow/issues/456
|
||||
U = tf.random_uniform(tf.shape(X))
|
||||
return argmax(X - tf.log(-tf.log(U)), axis=1)
|
||||
|
||||
# Global session
|
||||
|
||||
|
||||
def get_session():
|
||||
return tf.get_default_session()
|
||||
return tf.get_default_session()
|
||||
|
||||
|
||||
def single_threaded_session():
|
||||
tf_config = tf.ConfigProto(inter_op_parallelism_threads=1,
|
||||
intra_op_parallelism_threads=1)
|
||||
return tf.Session(config=tf_config)
|
||||
tf_config = tf.ConfigProto(inter_op_parallelism_threads=1,
|
||||
intra_op_parallelism_threads=1)
|
||||
return tf.Session(config=tf_config)
|
||||
|
||||
|
||||
ALREADY_INITIALIZED = set()
|
||||
|
||||
|
||||
def initialize():
|
||||
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
|
||||
get_session().run(tf.variables_initializer(new_variables))
|
||||
ALREADY_INITIALIZED.update(new_variables)
|
||||
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
|
||||
get_session().run(tf.variables_initializer(new_variables))
|
||||
ALREADY_INITIALIZED.update(new_variables)
|
||||
|
||||
|
||||
def eval(expr, feed_dict=None):
|
||||
if feed_dict is None:
|
||||
feed_dict = {}
|
||||
return get_session().run(expr, feed_dict=feed_dict)
|
||||
if feed_dict is None:
|
||||
feed_dict = {}
|
||||
return get_session().run(expr, feed_dict=feed_dict)
|
||||
|
||||
|
||||
def set_value(v, val):
|
||||
get_session().run(v.assign(val))
|
||||
get_session().run(v.assign(val))
|
||||
|
||||
|
||||
def load_state(fname):
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(get_session(), fname)
|
||||
saver = tf.train.Saver()
|
||||
saver.restore(get_session(), fname)
|
||||
|
||||
|
||||
def save_state(fname):
|
||||
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
||||
saver = tf.train.Saver()
|
||||
saver.save(get_session(), fname)
|
||||
os.makedirs(os.path.dirname(fname), exist_ok=True)
|
||||
saver = tf.train.Saver()
|
||||
saver.save(get_session(), fname)
|
||||
|
||||
# Model components
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
|
||||
|
||||
def dense(x, size, name, weight_init=None, bias=True):
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
|
||||
initializer=weight_init)
|
||||
ret = tf.matmul(x, w)
|
||||
if bias:
|
||||
b = tf.get_variable(name + "/b", [size],
|
||||
initializer=tf.zeros_initializer())
|
||||
return ret + b
|
||||
else:
|
||||
return ret
|
||||
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
|
||||
initializer=weight_init)
|
||||
ret = tf.matmul(x, w)
|
||||
if bias:
|
||||
b = tf.get_variable(name + "/b", [size],
|
||||
initializer=tf.zeros_initializer())
|
||||
return ret + b
|
||||
else:
|
||||
return ret
|
||||
|
||||
# Basic Stuff
|
||||
|
||||
|
||||
def function(inputs, outputs, updates=None, givens=None):
|
||||
if isinstance(outputs, list):
|
||||
return _Function(inputs, outputs, updates, givens=givens)
|
||||
elif isinstance(outputs, dict):
|
||||
f = _Function(inputs, outputs.values(), updates, givens=givens)
|
||||
return lambda *inputs: dict(zip(outputs.keys(), f(*inputs)))
|
||||
else:
|
||||
f = _Function(inputs, [outputs], updates, givens=givens)
|
||||
return lambda *inputs: f(*inputs)[0]
|
||||
if isinstance(outputs, list):
|
||||
return _Function(inputs, outputs, updates, givens=givens)
|
||||
elif isinstance(outputs, dict):
|
||||
f = _Function(inputs, outputs.values(), updates, givens=givens)
|
||||
return lambda *inputs: dict(zip(outputs.keys(), f(*inputs)))
|
||||
else:
|
||||
f = _Function(inputs, [outputs], updates, givens=givens)
|
||||
return lambda *inputs: f(*inputs)[0]
|
||||
|
||||
|
||||
class _Function(object):
|
||||
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
|
||||
assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should all be "
|
||||
"placeholders")
|
||||
self.inputs = inputs
|
||||
updates = updates or []
|
||||
self.update_group = tf.group(*updates)
|
||||
self.outputs_update = list(outputs) + [self.update_group]
|
||||
self.givens = {} if givens is None else givens
|
||||
self.check_nan = check_nan
|
||||
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
|
||||
assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should "
|
||||
"all be "
|
||||
"placeholders")
|
||||
self.inputs = inputs
|
||||
updates = updates or []
|
||||
self.update_group = tf.group(*updates)
|
||||
self.outputs_update = list(outputs) + [self.update_group]
|
||||
self.givens = {} if givens is None else givens
|
||||
self.check_nan = check_nan
|
||||
|
||||
def __call__(self, *inputvals):
|
||||
assert len(inputvals) == len(self.inputs)
|
||||
feed_dict = dict(zip(self.inputs, inputvals))
|
||||
feed_dict.update(self.givens)
|
||||
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
||||
if self.check_nan:
|
||||
if any(np.isnan(r).any() for r in results):
|
||||
raise RuntimeError("Nan detected")
|
||||
return results
|
||||
def __call__(self, *inputvals):
|
||||
assert len(inputvals) == len(self.inputs)
|
||||
feed_dict = dict(zip(self.inputs, inputvals))
|
||||
feed_dict.update(self.givens)
|
||||
results = get_session().run(self.outputs_update,
|
||||
feed_dict=feed_dict)[:-1]
|
||||
if self.check_nan:
|
||||
if any(np.isnan(r).any() for r in results):
|
||||
raise RuntimeError("Nan detected")
|
||||
return results
|
||||
|
||||
|
||||
# Graph traversal
|
||||
@@ -189,71 +192,72 @@ VARIABLES = {}
|
||||
|
||||
|
||||
def var_shape(x):
|
||||
out = [k.value for k in x.get_shape()]
|
||||
assert all(isinstance(a, int) for a in out), ("shape function assumes that "
|
||||
"shape is fully known")
|
||||
return out
|
||||
out = [k.value for k in x.get_shape()]
|
||||
assert all(isinstance(a, int) for a in out), ("shape function assumes "
|
||||
"that shape is fully known")
|
||||
return out
|
||||
|
||||
|
||||
def numel(x):
|
||||
return intprod(var_shape(x))
|
||||
return intprod(var_shape(x))
|
||||
|
||||
|
||||
def intprod(x):
|
||||
return int(np.prod(x))
|
||||
return int(np.prod(x))
|
||||
|
||||
|
||||
def flatgrad(loss, var_list):
|
||||
grads = tf.gradients(loss, var_list)
|
||||
return tf.concat([tf.reshape(grad, [numel(v)], 0)
|
||||
for (v, grad) in zip(var_list, grads)])
|
||||
grads = tf.gradients(loss, var_list)
|
||||
return tf.concat([tf.reshape(grad, [numel(v)], 0)
|
||||
for (v, grad) in zip(var_list, grads)])
|
||||
|
||||
|
||||
class SetFromFlat(object):
|
||||
def __init__(self, var_list, dtype=tf.float32):
|
||||
assigns = []
|
||||
shapes = list(map(var_shape, var_list))
|
||||
total_size = np.sum([intprod(shape) for shape in shapes])
|
||||
def __init__(self, var_list, dtype=tf.float32):
|
||||
assigns = []
|
||||
shapes = list(map(var_shape, var_list))
|
||||
total_size = np.sum([intprod(shape) for shape in shapes])
|
||||
|
||||
self.theta = theta = tf.placeholder(dtype, [total_size])
|
||||
start = 0
|
||||
assigns = []
|
||||
for (shape, v) in zip(shapes, var_list):
|
||||
size = intprod(shape)
|
||||
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size],
|
||||
shape)))
|
||||
start += size
|
||||
assert start == total_size
|
||||
self.op = tf.group(*assigns)
|
||||
self.theta = theta = tf.placeholder(dtype, [total_size])
|
||||
start = 0
|
||||
assigns = []
|
||||
for (shape, v) in zip(shapes, var_list):
|
||||
size = intprod(shape)
|
||||
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size],
|
||||
shape)))
|
||||
start += size
|
||||
assert start == total_size
|
||||
self.op = tf.group(*assigns)
|
||||
|
||||
def __call__(self, theta):
|
||||
get_session().run(self.op, feed_dict={self.theta: theta})
|
||||
def __call__(self, theta):
|
||||
get_session().run(self.op, feed_dict={self.theta: theta})
|
||||
|
||||
|
||||
class GetFlat(object):
    """Callable that reads a list of variables as one flat vector.

    ``self.op`` concatenates every variable, reshaped to 1-D, along axis 0
    in the order of ``var_list``.
    """

    def __init__(self, var_list):
        flattened = [tf.reshape(v, [numel(v)]) for v in var_list]
        self.op = tf.concat(flattened, 0)

    def __call__(self):
        """Evaluate the concatenation op in the current session."""
        session = get_session()
        return session.run(self.op)
|
||||
|
||||
# Misc
|
||||
|
||||
|
||||
def scope_vars(scope, trainable_only):
    """Get variables inside a scope. The scope can be specified as a string.

    Args:
        scope: Either a scope name (``str``) or an object with a ``name``
            attribute.
        trainable_only: If true, look in the TRAINABLE_VARIABLES collection;
            otherwise in GLOBAL_VARIABLES.

    Returns:
        The result of ``tf.get_collection`` for that collection and scope.
    """
    if trainable_only:
        collection = tf.GraphKeys.TRAINABLE_VARIABLES
    else:
        collection = tf.GraphKeys.GLOBAL_VARIABLES
    scope_name = scope if isinstance(scope, str) else scope.name
    return tf.get_collection(collection, scope=scope_name)
|
||||
|
||||
|
||||
def in_session(f):
    """Decorator that runs ``f`` inside a fresh ``tf.Session`` context.

    Args:
        f: The function to wrap.

    Returns:
        A wrapper with the same metadata as ``f`` (via ``functools.wraps``)
        that opens a ``tf.Session`` context, calls ``f`` with the given
        arguments, and returns whatever ``f`` returns.
    """
    @functools.wraps(f)
    def newfunc(*args, **kwargs):
        with tf.Session():
            # Propagate the result; it used to be silently dropped.
            return f(*args, **kwargs)
    return newfunc
|
||||
|
||||
|
||||
# A mapping from name -> (placeholder, dtype, shape).
|
||||
@@ -261,28 +265,28 @@ _PLACEHOLDER_CACHE = {}
|
||||
|
||||
|
||||
def get_placeholder(name, dtype, shape):
    """Return the placeholder registered under ``name``, creating it if new.

    A new placeholder is stored in ``_PLACEHOLDER_CACHE`` together with its
    dtype and shape. Requesting an existing name with a different dtype or
    shape trips an assertion.
    """
    print("calling get_placeholder", name)
    if name not in _PLACEHOLDER_CACHE:
        placeholder = tf.placeholder(dtype=dtype, shape=shape, name=name)
        _PLACEHOLDER_CACHE[name] = (placeholder, dtype, shape)
        return placeholder

    placeholder, cached_dtype, cached_shape = _PLACEHOLDER_CACHE[name]
    assert cached_dtype == dtype and cached_shape == shape
    return placeholder
|
||||
|
||||
|
||||
def get_placeholder_cached(name):
    """Return the cached placeholder for ``name``.

    Raises ``KeyError`` if nothing is cached under that name.
    """
    entry = _PLACEHOLDER_CACHE[name]
    return entry[0]
|
||||
|
||||
|
||||
def flattenallbut0(x):
    """Reshape ``x`` to 2-D, keeping axis 0 and merging all other axes."""
    trailing_dims = x.get_shape().as_list()[1:]
    return tf.reshape(x, [-1, intprod(trailing_dims)])
|
||||
|
||||
|
||||
def reset():
    """Clear the module-level caches and reset the default TF graph."""
    global _PLACEHOLDER_CACHE, VARIABLES
    _PLACEHOLDER_CACHE = {}
    VARIABLES = {}
    tf.reset_default_graph()
|
||||
|
||||
@@ -10,77 +10,78 @@ import tensorflow as tf
|
||||
|
||||
|
||||
def compute_ranks(x):
    """Returns ranks in [0, len(x))

    Note: This is different from scipy.stats.rankdata, which returns ranks in
    [1, len(x)].

    Args:
        x: 1-D array (asserted).

    Returns:
        Integer array where entry ``i`` is the position of ``x[i]`` in the
        argsort order of ``x``.
    """
    assert x.ndim == 1
    n = len(x)
    order = x.argsort()
    ranks = np.empty(n, dtype=int)
    ranks[order] = np.arange(n)
    return ranks
|
||||
|
||||
|
||||
def compute_centered_ranks(x):
    """Return the ranks of ``x`` rescaled to the interval [-0.5, 0.5].

    Ranks are computed over the flattened array, reshaped back to
    ``x.shape`` as float32, divided by ``x.size - 1`` and shifted by 0.5.
    """
    flat_ranks = compute_ranks(x.ravel())
    y = flat_ranks.reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y
|
||||
|
||||
|
||||
def make_session(single_threaded):
    """Create a ``tf.InteractiveSession``.

    When ``single_threaded`` is true, the session is configured with one
    inter-op and one intra-op parallelism thread; otherwise the default
    configuration is used.
    """
    if single_threaded:
        config = tf.ConfigProto(inter_op_parallelism_threads=1,
                                intra_op_parallelism_threads=1)
        return tf.InteractiveSession(config=config)
    return tf.InteractiveSession()
|
||||
|
||||
|
||||
def itergroups(items, group_size):
    """Yield consecutive tuples of up to ``group_size`` items from ``items``.

    All groups have exactly ``group_size`` elements except possibly the
    last one, which holds whatever is left over. Nothing is yielded for an
    empty input.
    """
    assert group_size >= 1
    pending = []
    for item in items:
        pending.append(item)
        if len(pending) == group_size:
            yield tuple(pending)
            pending = []
    if pending:
        yield tuple(pending)
|
||||
|
||||
|
||||
def batched_weighted_sum(weights, vecs, batch_size):
    """Compute ``sum_i weights[i] * vecs[i]`` in batches.

    ``weights`` and ``vecs`` are consumed in lockstep in groups of at most
    ``batch_size``; each group is converted to float32 arrays and reduced
    with ``np.dot``.

    Returns:
        A pair ``(total, num_items_summed)``.
    """
    total = 0
    num_items_summed = 0
    weight_batches = itergroups(weights, batch_size)
    vec_batches = itergroups(vecs, batch_size)
    for batch_weights, batch_vecs in zip(weight_batches, vec_batches):
        assert len(batch_weights) == len(batch_vecs) <= batch_size
        w = np.asarray(batch_weights, dtype=np.float32)
        v = np.asarray(batch_vecs, dtype=np.float32)
        total += np.dot(w, v)
        num_items_summed += len(batch_weights)
    return total, num_items_summed
|
||||
|
||||
|
||||
class RunningStat(object):
    """Accumulates a sum, a sum of squares and a count.

    ``sum`` starts at zero, while ``sumsq`` and ``count`` start at ``eps``.
    ``mean`` and ``std`` are derived from the accumulated values.
    """

    def __init__(self, shape, eps):
        self.sum = np.zeros(shape, dtype=np.float32)
        self.sumsq = np.full(shape, eps, dtype=np.float32)
        self.count = eps

    def increment(self, s, ssq, c):
        """Add ``s`` to the sum, ``ssq`` to the sum of squares, ``c`` to the count."""
        self.sum += s
        self.sumsq += ssq
        self.count += c

    @property
    def mean(self):
        """Accumulated sum divided by the count."""
        return self.sum / self.count

    @property
    def std(self):
        """Square root of the variance estimate, which is floored at 1e-2."""
        variance = self.sumsq / self.count - np.square(self.mean)
        return np.sqrt(np.maximum(variance, 1e-2))

    def set_from_init(self, init_mean, init_std, init_count):
        """Overwrite the accumulators to match a given mean, std and count."""
        second_moment = np.square(init_mean) + np.square(init_std)
        self.sum[:] = init_mean * init_count
        self.sumsq[:] = second_moment * init_count
        self.count = init_count
|
||||
|
||||
@@ -11,32 +11,33 @@ import click
|
||||
@click.option("--stochastic", is_flag=True)
|
||||
@click.option("--extra_kwargs")
|
||||
def main(env_id, policy_file, record, stochastic, extra_kwargs):
    """Load a saved policy and render rollouts of it in a Gym environment.

    Args:
        env_id: Environment id passed to ``gym.make``.
        policy_file: Path passed to ``MujocoPolicy.Load``.
        record: If true, wrap the environment in ``gym.wrappers.Monitor``
            (writing to a fresh directory under /tmp), run a single
            rollout, close the environment and return.
        stochastic: If true, pass ``np.random`` as the rollout's
            ``random_stream``; otherwise pass ``None``.
        extra_kwargs: Optional JSON string, decoded and forwarded to
            ``MujocoPolicy.Load``.
    """
    import gym
    from gym import wrappers
    import tensorflow as tf
    from policies import MujocoPolicy
    import numpy as np

    env = gym.make(env_id)
    if record:
        # Only wrap the environment here. Closing it and returning at this
        # point would exit before any rollout is performed.
        import uuid
        env = wrappers.Monitor(env, "/tmp/" + str(uuid.uuid4()), force=True)

    if extra_kwargs:
        import json
        extra_kwargs = json.loads(extra_kwargs)

    with tf.Session():
        pi = MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
        while True:
            rews, t = pi.rollout(env, render=True,
                                 random_stream=(np.random if stochastic
                                                else None))
            print("return={:.4f} len={}".format(rews.sum(), t))

            # When recording, stop after the first rollout.
            if record:
                env.close()
                return
|
||||
|
||||
|
||||
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
||||
|
||||
+10
-10
@@ -11,15 +11,15 @@ import ray.rllib.policy_gradient as pg
|
||||
|
||||
|
||||
if __name__ == "__main__":
    ray.init()

    # TODO(ekl): get the algorithms working on a common set of envs
    env_name = "CartPole-v0"
    es_alg = es.EvolutionStrategies(env_name, es.DEFAULT_CONFIG)
    pg_alg = pg.PolicyGradient(env_name, pg.DEFAULT_CONFIG)

    # Train both algorithms one step at a time, forever, printing the
    # result of each step.
    while True:
        es_result = es_alg.train()
        pg_result = pg_alg.train()
        print("evolution strategies: {}".format(es_result))
        print("policy gradient: {}".format(pg_result))
|
||||
|
||||
+198
-194
@@ -10,189 +10,192 @@ import tensorflow as tf
|
||||
|
||||
|
||||
class LocalSyncParallelOptimizer(object):
    """Optimizer that runs in parallel across multiple local devices.

    LocalSyncParallelOptimizer automatically splits up and loads training data
    onto specified local devices (e.g. GPUs) with `load_data()`. During a call
    to `optimize()`, the devices compute gradients over slices of the data in
    parallel. The gradients are then averaged and applied to the shared
    weights.

    The data loaded is pinned in device memory until the next call to
    `load_data`, so you can make multiple passes (possibly in randomized order)
    over the same data once loaded.

    This is similar to tf.train.SyncReplicasOptimizer, but works within a
    single TensorFlow graph, i.e. implements in-graph replicated training:

      https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer

    Args:
        optimizer: Delegate TensorFlow optimizer object.
        devices: List of the names of TensorFlow devices to parallelize over.
        input_placeholders: List of inputs for the loss function. Tensors of
            these shapes will be passed to build_loss() in order to define the
            per-device loss ops.
        per_device_batch_size: Number of tuples to optimize over at a time per
            device. In each call to `optimize()`,
            `len(devices) * per_device_batch_size` tuples of data will be
            processed.
        build_loss: Function that takes the specified inputs and returns an
            object with a 'loss' property that is a scalar Tensor. For example,
            ray.rllib.policy_gradient.ProximalPolicyLoss.
        logdir: Directory to place debugging output in.
    """

    def __init__(self, optimizer, devices, input_placeholders,
                 per_device_batch_size, build_loss, logdir):
        self.optimizer = optimizer
        self.devices = devices
        self.batch_size = per_device_batch_size * len(devices)
        self.per_device_batch_size = per_device_batch_size
        self.input_placeholders = input_placeholders
        self.build_loss = build_loss
        self.logdir = logdir

        # First initialize the shared loss network
        with tf.variable_scope("tower"):
            self._shared_loss = build_loss(*input_placeholders)

        # Then setup the per-device loss graphs that use the shared weights
        self._batch_index = tf.placeholder(tf.int32)
        data_splits = zip(
            *[tf.split(ph, len(devices)) for ph in input_placeholders])
        self._towers = []
        for device, device_placeholders in zip(self.devices, data_splits):
            self._towers.append(self._setup_device(device,
                                                   device_placeholders))

        avg = average_gradients([t.grads for t in self._towers])
        self._train_op = self.optimizer.apply_gradients(avg)

    def load_data(self, sess, inputs, full_trace=False):
        """Bulk loads the specified inputs into device memory.

        The shape of the inputs must conform to the shapes of the input
        placeholders this optimizer was constructed with.

        The data is split equally across all the devices. If the data is not
        evenly divisible by the batch size, excess data will be discarded.

        Args:
            sess: TensorFlow session.
            inputs: List of Tensors matching the input placeholders specified
                at construction time of this optimizer.
            full_trace: Whether to profile data loading. If set, a Chrome
                trace is written to "timeline-load.json" in the log
                directory.

        Returns:
            The number of tuples loaded per device.
        """

        feed_dict = {}
        assert len(self.input_placeholders) == len(inputs)
        for ph, arr in zip(self.input_placeholders, inputs):
            truncated_arr = make_divisible_by(arr, self.batch_size)
            feed_dict[ph] = truncated_arr
            truncated_len = len(truncated_arr)

        if full_trace:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        else:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
        run_metadata = tf.RunMetadata()

        sess.run(
            [t.init_op for t in self._towers],
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=run_metadata)
        if full_trace:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            # Use a context manager so the trace file is flushed and closed.
            with open(os.path.join(self.logdir, "timeline-load.json"),
                      "w") as trace_file:
                trace_file.write(trace.generate_chrome_trace_format())

        # truncated_len is a multiple of batch_size, which is itself a
        # multiple of len(devices), so floor division is exact and keeps
        # the tuple count an integer.
        tuples_per_device = truncated_len // len(self.devices)
        assert tuples_per_device % self.per_device_batch_size == 0
        return tuples_per_device

    def optimize(self, sess, batch_index, extra_ops=None,
                 extra_feed_dict=None, file_writer=None):
        """Run a single step of SGD.

        Runs a SGD step over a slice of the preloaded batch with size given by
        self.per_device_batch_size and offset given by the batch_index
        argument.

        Updates shared model weights based on the averaged per-device
        gradients.

        Args:
            sess: TensorFlow session.
            batch_index: Offset into the preloaded data. This value must be
                between `0` and `tuples_per_device`. The amount of data to
                process is always fixed to `per_device_batch_size`.
            extra_ops: Extra ops to run with this step (e.g. for metrics).
                Defaults to no extra ops.
            extra_feed_dict: Extra args to feed into this session run.
                Defaults to no extra feeds.
            file_writer: If specified, tf metrics will be written out using
                this, and a Chrome trace is written to "timeline-sgd.json"
                in the log directory.

        Returns:
            The outputs of extra_ops evaluated over the batch.
        """
        # None stands in for "empty" to avoid mutable default arguments.
        if extra_ops is None:
            extra_ops = []
        if extra_feed_dict is None:
            extra_feed_dict = {}

        if file_writer:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        else:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
        run_metadata = tf.RunMetadata()

        feed_dict = {self._batch_index: batch_index}
        feed_dict.update(extra_feed_dict)
        outs = sess.run(
            [self._train_op] + extra_ops,
            feed_dict=feed_dict,
            options=run_options,
            run_metadata=run_metadata)

        if file_writer:
            trace = timeline.Timeline(step_stats=run_metadata.step_stats)
            # Use a context manager so the trace file is flushed and closed.
            with open(os.path.join(self.logdir, "timeline-sgd.json"),
                      "w") as trace_file:
                trace_file.write(trace.generate_chrome_trace_format())
            file_writer.add_run_metadata(
                run_metadata, "sgd_train_{}".format(batch_index))

        # The first output belongs to the train op; callers only get the
        # results of their extra ops.
        return outs[1:]

    def get_common_loss(self):
        """Return the loss object built on the original input placeholders."""
        return self._shared_loss

    def get_device_losses(self):
        """Return the per-device loss objects, one per tower."""
        return [t.loss_object for t in self._towers]

    def _setup_device(self, device, device_input_placeholders):
        """Build the loss and gradient graph for one device.

        Each placeholder is copied into a non-trainable variable, and the
        loss is built on a slice of that variable starting at
        ``self._batch_index`` with ``per_device_batch_size`` rows.

        Returns:
            A Tower holding the grouped variable initializers, the device
            gradients and the device loss object.
        """
        with tf.device(device):
            with tf.variable_scope("tower", reuse=True):
                device_input_batches = []
                device_input_slices = []
                for ph in device_input_placeholders:
                    current_batch = tf.Variable(
                        ph, trainable=False, validate_shape=False,
                        collections=[])
                    device_input_batches.append(current_batch)
                    current_slice = tf.slice(
                        current_batch,
                        [self._batch_index] + [0] * len(ph.shape[1:]),
                        ([self.per_device_batch_size] + [-1] *
                         len(ph.shape[1:])))
                    current_slice.set_shape(ph.shape)
                    device_input_slices.append(current_slice)
                device_loss_obj = self.build_loss(*device_input_slices)
                device_grads = self.optimizer.compute_gradients(
                    device_loss_obj.loss, colocate_gradients_with_ops=True)
            return Tower(
                tf.group(*[batch.initializer
                           for batch in device_input_batches]),
                device_grads,
                device_loss_obj)
|
||||
|
||||
|
||||
# Each tower is a copy of the loss graph pinned to a specific device.
|
||||
@@ -200,50 +203,51 @@ Tower = namedtuple("Tower", ["init_op", "grads", "loss_object"])
|
||||
|
||||
|
||||
def make_divisible_by(array, n):
    """Drop trailing rows of ``array`` so its first dimension divides by ``n``."""
    length = array.shape[0]
    remainder = length % n
    return array[:length - remainder]
|
||||
|
||||
|
||||
def average_gradients(tower_grads):
    """Averages gradients across towers.

    Calculate the average gradient for each shared variable across all towers.
    Note that this function provides a synchronization point across all towers.

    Args:
        tower_grads: One list per tower, each a list of (gradient, variable)
            tuples. The lists are zipped together, so entry ``i`` of every
            tower is treated as belonging to the same variable.

    Returns:
        List of pairs of (gradient, variable) where the gradient has been
        averaged across all towers.

    TODO(ekl): We could use NCCL if this becomes a bottleneck.
    """
    averaged = []
    for grad_and_vars in zip(*tower_grads):
        # Note that each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        # Give every non-None gradient a leading 'tower' axis so that the
        # gradients can be concatenated and averaged over that axis.
        expanded = [tf.expand_dims(g, 0)
                    for g, _ in grad_and_vars if g is not None]
        stacked = tf.concat(axis=0, values=expanded)
        mean_grad = tf.reduce_mean(stacked, 0)

        # Keep in mind that the Variables are redundant because they are
        # shared across towers. So .. we will just return the first tower's
        # pointer to the Variable.
        shared_var = grad_and_vars[0][1]
        averaged.append((mean_grad, shared_var))

    return averaged
|
||||
|
||||
@@ -25,133 +25,141 @@ from ray.rllib.policy_gradient.rollout import rollouts, add_advantage_values
|
||||
|
||||
|
||||
class Agent(object):
|
||||
"""
|
||||
Agent class that holds the simulator environment and the policy.
|
||||
"""
|
||||
Agent class that holds the simulator environment and the policy.
|
||||
|
||||
Initializes the tensorflow graphs for both training and evaluation.
|
||||
One common policy graph is initialized on '/cpu:0' and holds all the shared
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
Initializes the tensorflow graphs for both training and evaluation.
|
||||
One common policy graph is initialized on '/cpu:0' and holds all the shared
|
||||
network weights. When run as a remote agent, only this graph is used.
|
||||
"""
|
||||
|
||||
def __init__(self, name, batchsize, preprocessor, config, logdir, is_remote):
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
devices = ["/cpu:0"]
|
||||
else:
|
||||
devices = config["devices"]
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
|
||||
if preprocessor.shape is None:
|
||||
preprocessor.shape = self.env.observation_space.shape
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
config_proto = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.preprocessor = preprocessor
|
||||
self.sess = tf.Session(config=config_proto)
|
||||
if config["use_tf_debugger"] and not is_remote:
|
||||
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
|
||||
self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
|
||||
def __init__(self, name, batchsize, preprocessor, config, logdir,
|
||||
is_remote):
|
||||
if is_remote:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
||||
devices = ["/cpu:0"]
|
||||
else:
|
||||
devices = config["devices"]
|
||||
self.devices = devices
|
||||
self.config = config
|
||||
self.logdir = logdir
|
||||
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
|
||||
if preprocessor.shape is None:
|
||||
preprocessor.shape = self.env.observation_space.shape
|
||||
if is_remote:
|
||||
config_proto = tf.ConfigProto()
|
||||
else:
|
||||
config_proto = tf.ConfigProto(**config["tf_session_args"])
|
||||
self.preprocessor = preprocessor
|
||||
self.sess = tf.Session(config=config_proto)
|
||||
if config["use_tf_debugger"] and not is_remote:
|
||||
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
|
||||
self.sess.add_tensor_filter("has_inf_or_nan",
|
||||
tf_debug.has_inf_or_nan)
|
||||
|
||||
# Defines the training inputs.
|
||||
self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32)
|
||||
self.observations = tf.placeholder(tf.float32,
|
||||
shape=(None,) + preprocessor.shape)
|
||||
self.advantages = tf.placeholder(tf.float32, shape=(None,))
|
||||
# Defines the training inputs.
|
||||
self.kl_coeff = tf.placeholder(name="newkl", shape=(),
|
||||
dtype=tf.float32)
|
||||
self.observations = tf.placeholder(tf.float32,
|
||||
shape=(None,) + preprocessor.shape)
|
||||
self.advantages = tf.placeholder(tf.float32, shape=(None,))
|
||||
|
||||
action_space = self.env.action_space
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
# The first half of the dimensions are the means, the second half are the
|
||||
# standard deviations.
|
||||
self.action_dim = action_space.shape[0]
|
||||
self.action_shape = (self.action_dim,)
|
||||
self.logit_dim = 2 * self.action_dim
|
||||
self.actions = tf.placeholder(tf.float32, shape=(None, self.action_dim))
|
||||
self.distribution_class = DiagGaussian
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
self.action_dim = action_space.n
|
||||
self.action_shape = ()
|
||||
self.logit_dim = self.action_dim
|
||||
self.actions = tf.placeholder(tf.int64, shape=(None,))
|
||||
self.distribution_class = Categorical
|
||||
else:
|
||||
raise NotImplemented("action space" + str(type(action_space)) +
|
||||
"currently not supported")
|
||||
self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim))
|
||||
action_space = self.env.action_space
|
||||
if isinstance(action_space, gym.spaces.Box):
|
||||
# The first half of the dimensions are the means, the second half
|
||||
# are the standard deviations.
|
||||
self.action_dim = action_space.shape[0]
|
||||
self.action_shape = (self.action_dim,)
|
||||
self.logit_dim = 2 * self.action_dim
|
||||
self.actions = tf.placeholder(tf.float32,
|
||||
shape=(None, self.action_dim))
|
||||
self.distribution_class = DiagGaussian
|
||||
elif isinstance(action_space, gym.spaces.Discrete):
|
||||
self.action_dim = action_space.n
|
||||
self.action_shape = ()
|
||||
self.logit_dim = self.action_dim
|
||||
self.actions = tf.placeholder(tf.int64, shape=(None,))
|
||||
self.distribution_class = Categorical
|
||||
else:
|
||||
raise NotImplemented("action space" + str(type(action_space)) +
|
||||
"currently not supported")
|
||||
self.prev_logits = tf.placeholder(tf.float32,
|
||||
shape=(None, self.logit_dim))
|
||||
|
||||
assert config["sgd_batchsize"] % len(devices) == 0, \
|
||||
"Batch size must be evenly divisible by devices"
|
||||
if is_remote:
|
||||
self.batch_size = 1
|
||||
self.per_device_batch_size = 1
|
||||
else:
|
||||
self.batch_size = config["sgd_batchsize"]
|
||||
self.per_device_batch_size = int(self.batch_size / len(devices))
|
||||
assert config["sgd_batchsize"] % len(devices) == 0, \
|
||||
"Batch size must be evenly divisible by devices"
|
||||
if is_remote:
|
||||
self.batch_size = 1
|
||||
self.per_device_batch_size = 1
|
||||
else:
|
||||
self.batch_size = config["sgd_batchsize"]
|
||||
self.per_device_batch_size = int(self.batch_size / len(devices))
|
||||
|
||||
def build_loss(obs, advs, acts, plog):
|
||||
return ProximalPolicyLoss(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
obs, advs, acts, plog, self.logit_dim,
|
||||
self.kl_coeff, self.distribution_class, self.config, self.sess)
|
||||
def build_loss(obs, advs, acts, plog):
|
||||
return ProximalPolicyLoss(
|
||||
self.env.observation_space, self.env.action_space,
|
||||
obs, advs, acts, plog, self.logit_dim,
|
||||
self.kl_coeff, self.distribution_class, self.config, self.sess)
|
||||
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
|
||||
self.devices,
|
||||
[self.observations, self.advantages, self.actions, self.prev_logits],
|
||||
self.per_device_batch_size,
|
||||
build_loss,
|
||||
self.logdir)
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
|
||||
self.devices,
|
||||
[self.observations, self.advantages, self.actions,
|
||||
self.prev_logits],
|
||||
self.per_device_batch_size,
|
||||
build_loss,
|
||||
self.logdir)
|
||||
|
||||
# Metric ops
|
||||
with tf.name_scope("test_outputs"):
|
||||
policies = self.par_opt.get_device_losses()
|
||||
self.mean_loss = tf.reduce_mean(
|
||||
tf.stack(values=[policy.loss for policy in policies]), 0)
|
||||
self.mean_kl = tf.reduce_mean(
|
||||
tf.stack(values=[policy.mean_kl for policy in policies]), 0)
|
||||
self.mean_entropy = tf.reduce_mean(
|
||||
tf.stack(values=[policy.mean_entropy for policy in policies]), 0)
|
||||
# Metric ops
|
||||
with tf.name_scope("test_outputs"):
|
||||
policies = self.par_opt.get_device_losses()
|
||||
self.mean_loss = tf.reduce_mean(
|
||||
tf.stack(values=[policy.loss for policy in policies]), 0)
|
||||
self.mean_kl = tf.reduce_mean(
|
||||
tf.stack(values=[policy.mean_kl for policy in policies]), 0)
|
||||
self.mean_entropy = tf.reduce_mean(
|
||||
tf.stack(values=[policy.mean_entropy for policy in policies]),
|
||||
0)
|
||||
|
||||
# References to the model weights
|
||||
self.common_policy = self.par_opt.get_common_loss()
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.common_policy.loss,
|
||||
self.sess)
|
||||
self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None)
|
||||
self.reward_filter = MeanStdFilter((), clip=5.0)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
# References to the model weights
|
||||
self.common_policy = self.par_opt.get_common_loss()
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.common_policy.loss,
|
||||
self.sess)
|
||||
self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None)
|
||||
self.reward_filter = MeanStdFilter((), clip=5.0)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
def load_data(self, trajectories, full_trace):
|
||||
return self.par_opt.load_data(
|
||||
self.sess,
|
||||
[trajectories["observations"],
|
||||
trajectories["advantages"],
|
||||
trajectories["actions"].squeeze(),
|
||||
trajectories["logprobs"]],
|
||||
full_trace=full_trace)
|
||||
def load_data(self, trajectories, full_trace):
|
||||
return self.par_opt.load_data(
|
||||
self.sess,
|
||||
[trajectories["observations"],
|
||||
trajectories["advantages"],
|
||||
trajectories["actions"].squeeze(),
|
||||
trajectories["logprobs"]],
|
||||
full_trace=full_trace)
|
||||
|
||||
def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer):
|
||||
return self.par_opt.optimize(
|
||||
self.sess,
|
||||
batch_index,
|
||||
extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy],
|
||||
extra_feed_dict={self.kl_coeff: kl_coeff},
|
||||
file_writer=file_writer if full_trace else None)
|
||||
def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace,
|
||||
file_writer):
|
||||
return self.par_opt.optimize(
|
||||
self.sess,
|
||||
batch_index,
|
||||
extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy],
|
||||
extra_feed_dict={self.kl_coeff: kl_coeff},
|
||||
file_writer=file_writer if full_trace else None)
|
||||
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
def get_weights(self):
|
||||
return self.variables.get_weights()
|
||||
|
||||
def load_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
def load_weights(self, weights):
|
||||
self.variables.set_weights(weights)
|
||||
|
||||
def compute_trajectory(self, gamma, lam, horizon):
|
||||
trajectory = rollouts(
|
||||
self.common_policy,
|
||||
self.env, horizon, self.observation_filter, self.reward_filter)
|
||||
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
|
||||
return trajectory
|
||||
def compute_trajectory(self, gamma, lam, horizon):
|
||||
trajectory = rollouts(
|
||||
self.common_policy,
|
||||
self.env, horizon, self.observation_filter, self.reward_filter)
|
||||
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
|
||||
return trajectory
|
||||
|
||||
|
||||
RemoteAgent = ray.remote(Agent)
|
||||
|
||||
@@ -7,63 +7,63 @@ import numpy as np
|
||||
|
||||
|
||||
class Categorical(object):
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
|
||||
def logp(self, x):
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
|
||||
labels=x)
|
||||
def logp(self, x):
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=self.logits, labels=x)
|
||||
|
||||
def entropy(self):
|
||||
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
|
||||
p0 = ea0 / z0
|
||||
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
|
||||
def entropy(self):
|
||||
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
|
||||
p0 = ea0 / z0
|
||||
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
|
||||
|
||||
def kl(self, other):
|
||||
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
ea1 = tf.exp(a1)
|
||||
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
|
||||
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
|
||||
p0 = ea0 / z0
|
||||
return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)),
|
||||
reduction_indices=[1])
|
||||
def kl(self, other):
|
||||
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1],
|
||||
keep_dims=True)
|
||||
ea0 = tf.exp(a0)
|
||||
ea1 = tf.exp(a1)
|
||||
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
|
||||
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
|
||||
p0 = ea0 / z0
|
||||
return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)),
|
||||
reduction_indices=[1])
|
||||
|
||||
def sample(self):
|
||||
return tf.multinomial(self.logits, 1)
|
||||
def sample(self):
|
||||
return tf.multinomial(self.logits, 1)
|
||||
|
||||
|
||||
class DiagGaussian(object):
|
||||
def __init__(self, flat):
|
||||
self.flat = flat
|
||||
mean, logstd = tf.split(flat, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.logstd = logstd
|
||||
self.std = tf.exp(logstd)
|
||||
def __init__(self, flat):
|
||||
self.flat = flat
|
||||
mean, logstd = tf.split(flat, 2, axis=1)
|
||||
self.mean = mean
|
||||
self.logstd = logstd
|
||||
self.std = tf.exp(logstd)
|
||||
|
||||
def logp(self, x):
|
||||
return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
|
||||
reduction_indices=[1]) -
|
||||
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
|
||||
tf.reduce_sum(self.logstd, reduction_indices=[1]))
|
||||
def logp(self, x):
|
||||
return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
|
||||
reduction_indices=[1]) -
|
||||
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
|
||||
tf.reduce_sum(self.logstd, reduction_indices=[1]))
|
||||
|
||||
def kl(self, other):
|
||||
assert isinstance(other, DiagGaussian)
|
||||
return tf.reduce_sum(other.logstd - self.logstd +
|
||||
(tf.square(self.std) +
|
||||
tf.square(self.mean - other.mean)) /
|
||||
(2.0 * tf.square(other.std)) - 0.5,
|
||||
reduction_indices=[1])
|
||||
def kl(self, other):
|
||||
assert isinstance(other, DiagGaussian)
|
||||
return tf.reduce_sum(other.logstd - self.logstd +
|
||||
(tf.square(self.std) +
|
||||
tf.square(self.mean - other.mean)) /
|
||||
(2.0 * tf.square(other.std)) - 0.5,
|
||||
reduction_indices=[1])
|
||||
|
||||
def entropy(self):
|
||||
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
|
||||
reduction_indices=[1])
|
||||
def entropy(self):
|
||||
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
|
||||
reduction_indices=[1])
|
||||
|
||||
def sample(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
def sample(self):
|
||||
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
|
||||
|
||||
@@ -7,59 +7,60 @@ import numpy as np
|
||||
|
||||
|
||||
class AtariPixelPreprocessor(object):
|
||||
def __init__(self):
|
||||
self.shape = (80, 80, 3)
|
||||
def __init__(self):
|
||||
self.shape = (80, 80, 3)
|
||||
|
||||
def __call__(self, observation):
|
||||
"Convert images from (210, 160, 3) to (3, 80, 80) by downsampling."
|
||||
return (observation[25:-25:2, ::2, :][None] - 128) / 128
|
||||
def __call__(self, observation):
|
||||
"Convert images from (210, 160, 3) to (3, 80, 80) by downsampling."
|
||||
return (observation[25:-25:2, ::2, :][None] - 128) / 128
|
||||
|
||||
|
||||
class AtariRamPreprocessor(object):
|
||||
def __init__(self):
|
||||
self.shape = (128,)
|
||||
def __init__(self):
|
||||
self.shape = (128,)
|
||||
|
||||
def __call__(self, observation):
|
||||
return (observation - 128) / 128
|
||||
def __call__(self, observation):
|
||||
return (observation - 128) / 128
|
||||
|
||||
|
||||
class NoPreprocessor(object):
|
||||
def __init__(self):
|
||||
self.shape = None
|
||||
def __init__(self):
|
||||
self.shape = None
|
||||
|
||||
def __call__(self, observation):
|
||||
return observation
|
||||
def __call__(self, observation):
|
||||
return observation
|
||||
|
||||
|
||||
class BatchedEnv(object):
|
||||
"""This holds multiple gym enviroments and performs steps on all of them."""
|
||||
def __init__(self, name, batchsize, preprocessor=None):
|
||||
self.envs = [gym.make(name) for _ in range(batchsize)]
|
||||
self.observation_space = self.envs[0].observation_space
|
||||
self.action_space = self.envs[0].action_space
|
||||
self.batchsize = batchsize
|
||||
self.preprocessor = preprocessor if preprocessor else lambda obs: obs[None]
|
||||
"""This holds multiple gym envs and performs steps on all of them."""
|
||||
def __init__(self, name, batchsize, preprocessor=None):
|
||||
self.envs = [gym.make(name) for _ in range(batchsize)]
|
||||
self.observation_space = self.envs[0].observation_space
|
||||
self.action_space = self.envs[0].action_space
|
||||
self.batchsize = batchsize
|
||||
self.preprocessor = (preprocessor if preprocessor
|
||||
else lambda obs: obs[None])
|
||||
|
||||
def reset(self):
|
||||
observations = [self.preprocessor(env.reset()) for env in self.envs]
|
||||
self.shape = observations[0].shape
|
||||
self.dones = [False for _ in range(self.batchsize)]
|
||||
return np.vstack(observations)
|
||||
def reset(self):
|
||||
observations = [self.preprocessor(env.reset()) for env in self.envs]
|
||||
self.shape = observations[0].shape
|
||||
self.dones = [False for _ in range(self.batchsize)]
|
||||
return np.vstack(observations)
|
||||
|
||||
def step(self, actions, render=False):
|
||||
observations = []
|
||||
rewards = []
|
||||
for i, action in enumerate(actions):
|
||||
if self.dones[i]:
|
||||
observations.append(np.zeros(self.shape))
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
observation, reward, done, info = self.envs[i].step(
|
||||
action if len(action) > 1 else action[0])
|
||||
if render:
|
||||
self.envs[0].render()
|
||||
observations.append(self.preprocessor(observation))
|
||||
rewards.append(reward)
|
||||
self.dones[i] = done
|
||||
return (np.vstack(observations), np.array(rewards, dtype="float32"),
|
||||
np.array(self.dones))
|
||||
def step(self, actions, render=False):
|
||||
observations = []
|
||||
rewards = []
|
||||
for i, action in enumerate(actions):
|
||||
if self.dones[i]:
|
||||
observations.append(np.zeros(self.shape))
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
observation, reward, done, info = self.envs[i].step(
|
||||
action if len(action) > 1 else action[0])
|
||||
if render:
|
||||
self.envs[0].render()
|
||||
observations.append(self.preprocessor(observation))
|
||||
rewards.append(reward)
|
||||
self.dones[i] = done
|
||||
return (np.vstack(observations), np.array(rewards, dtype="float32"),
|
||||
np.array(self.dones))
|
||||
|
||||
@@ -11,28 +11,28 @@ from ray.rllib.policy_gradient import PolicyGradient, DEFAULT_CONFIG
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the policy gradient "
|
||||
"algorithm.")
|
||||
parser.add_argument("--environment", default="Pong-v0", type=str,
|
||||
help="The gym environment to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--use-tf-debugger", default=False, type=bool,
|
||||
help="Run the script inside of tf-dbg.")
|
||||
parser.add_argument("--load-checkpoint", default=None, type=str,
|
||||
help="Continue training from a checkpoint.")
|
||||
parser = argparse.ArgumentParser(description="Run the policy gradient "
|
||||
"algorithm.")
|
||||
parser.add_argument("--environment", default="Pong-v0", type=str,
|
||||
help="The gym environment to use.")
|
||||
parser.add_argument("--redis-address", default=None, type=str,
|
||||
help="The Redis address of the cluster.")
|
||||
parser.add_argument("--use-tf-debugger", default=False, type=bool,
|
||||
help="Run the script inside of tf-dbg.")
|
||||
parser.add_argument("--load-checkpoint", default=None, type=str,
|
||||
help="Continue training from a checkpoint.")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["use_tf_debugger"] = args.use_tf_debugger
|
||||
if args.load_checkpoint:
|
||||
config["load_checkpoint"] = args.load_checkpoint
|
||||
args = parser.parse_args()
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
config["use_tf_debugger"] = args.use_tf_debugger
|
||||
if args.load_checkpoint:
|
||||
config["load_checkpoint"] = args.load_checkpoint
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
ray.init(redis_address=args.redis_address)
|
||||
|
||||
alg = PolicyGradient(args.environment, config)
|
||||
result = alg.train()
|
||||
while result.training_iteration < config["max_iterations"]:
|
||||
print("\n== iteration", result.training_iteration)
|
||||
alg = PolicyGradient(args.environment, config)
|
||||
result = alg.train()
|
||||
print("current status: {}".format(result))
|
||||
while result.training_iteration < config["max_iterations"]:
|
||||
print("\n== iteration", result.training_iteration)
|
||||
result = alg.train()
|
||||
print("current status: {}".format(result))
|
||||
|
||||
@@ -6,127 +6,127 @@ import numpy as np
|
||||
|
||||
|
||||
class NoFilter(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
return np.asarray(x)
|
||||
def __call__(self, x, update=True):
|
||||
return np.asarray(x)
|
||||
|
||||
|
||||
# http://www.johndcook.com/blog/standard_deviation/
|
||||
class RunningStat(object):
|
||||
|
||||
def __init__(self, shape=None):
|
||||
self._n = 0
|
||||
self._M = np.zeros(shape)
|
||||
self._S = np.zeros(shape)
|
||||
def __init__(self, shape=None):
|
||||
self._n = 0
|
||||
self._M = np.zeros(shape)
|
||||
self._S = np.zeros(shape)
|
||||
|
||||
def push(self, x):
|
||||
x = np.asarray(x)
|
||||
# Unvectorized update of the running statistics.
|
||||
assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
|
||||
.format(x.shape, self._M.shape))
|
||||
n1 = self._n
|
||||
self._n += 1
|
||||
if self._n == 1:
|
||||
self._M[...] = x
|
||||
else:
|
||||
delta = x - self._M
|
||||
self._M[...] += delta / self._n
|
||||
self._S[...] += delta * delta * n1 / self._n
|
||||
def push(self, x):
|
||||
x = np.asarray(x)
|
||||
# Unvectorized update of the running statistics.
|
||||
assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
|
||||
.format(x.shape, self._M.shape))
|
||||
n1 = self._n
|
||||
self._n += 1
|
||||
if self._n == 1:
|
||||
self._M[...] = x
|
||||
else:
|
||||
delta = x - self._M
|
||||
self._M[...] += delta / self._n
|
||||
self._S[...] += delta * delta * n1 / self._n
|
||||
|
||||
def update(self, other):
|
||||
n1 = self._n
|
||||
n2 = other._n
|
||||
n = n1 + n2
|
||||
delta = self._M - other._M
|
||||
delta2 = delta * delta
|
||||
M = (n1 * self._M + n2 * other._M) / n
|
||||
S = self._S + other._S + delta2 * n1 * n2 / n
|
||||
self._n = n
|
||||
self._M = M
|
||||
self._S = S
|
||||
def update(self, other):
|
||||
n1 = self._n
|
||||
n2 = other._n
|
||||
n = n1 + n2
|
||||
delta = self._M - other._M
|
||||
delta2 = delta * delta
|
||||
M = (n1 * self._M + n2 * other._M) / n
|
||||
S = self._S + other._S + delta2 * n1 * n2 / n
|
||||
self._n = n
|
||||
self._M = M
|
||||
self._S = S
|
||||
|
||||
@property
|
||||
def n(self):
|
||||
return self._n
|
||||
@property
|
||||
def n(self):
|
||||
return self._n
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._M
|
||||
@property
|
||||
def mean(self):
|
||||
return self._M
|
||||
|
||||
@property
|
||||
def var(self):
|
||||
return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
|
||||
@property
|
||||
def var(self):
|
||||
return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
|
||||
|
||||
@property
|
||||
def std(self):
|
||||
return np.sqrt(self.var)
|
||||
@property
|
||||
def std(self):
|
||||
return np.sqrt(self.var)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._M.shape
|
||||
@property
|
||||
def shape(self):
|
||||
return self._M.shape
|
||||
|
||||
|
||||
class MeanStdFilter(object):
|
||||
def __init__(self, shape, demean=True, destd=True, clip=10.0):
|
||||
self.demean = demean
|
||||
self.destd = destd
|
||||
self.clip = clip
|
||||
def __init__(self, shape, demean=True, destd=True, clip=10.0):
|
||||
self.demean = demean
|
||||
self.destd = destd
|
||||
self.clip = clip
|
||||
|
||||
self.rs = RunningStat(shape)
|
||||
self.rs = RunningStat(shape)
|
||||
|
||||
def __call__(self, x, update=True):
|
||||
x = np.asarray(x)
|
||||
if update:
|
||||
if len(x.shape) == len(self.rs.shape) + 1:
|
||||
# The vectorized case.
|
||||
for i in range(x.shape[0]):
|
||||
self.rs.push(x[i])
|
||||
else:
|
||||
# The unvectorized case.
|
||||
self.rs.push(x)
|
||||
if self.demean:
|
||||
x = x - self.rs.mean
|
||||
if self.destd:
|
||||
x = x / (self.rs.std + 1e-8)
|
||||
if self.clip:
|
||||
x = np.clip(x, -self.clip, self.clip)
|
||||
return x
|
||||
def __call__(self, x, update=True):
|
||||
x = np.asarray(x)
|
||||
if update:
|
||||
if len(x.shape) == len(self.rs.shape) + 1:
|
||||
# The vectorized case.
|
||||
for i in range(x.shape[0]):
|
||||
self.rs.push(x[i])
|
||||
else:
|
||||
# The unvectorized case.
|
||||
self.rs.push(x)
|
||||
if self.demean:
|
||||
x = x - self.rs.mean
|
||||
if self.destd:
|
||||
x = x / (self.rs.std + 1e-8)
|
||||
if self.clip:
|
||||
x = np.clip(x, -self.clip, self.clip)
|
||||
return x
|
||||
|
||||
|
||||
def test_running_stat():
|
||||
for shp in ((), (3,), (3, 4)):
|
||||
li = []
|
||||
rs = RunningStat(shp)
|
||||
for _ in range(5):
|
||||
val = np.random.randn(*shp)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
m = np.mean(li, axis=0)
|
||||
assert np.allclose(rs.mean, m)
|
||||
v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
|
||||
assert np.allclose(rs.var, v)
|
||||
for shp in ((), (3,), (3, 4)):
|
||||
li = []
|
||||
rs = RunningStat(shp)
|
||||
for _ in range(5):
|
||||
val = np.random.randn(*shp)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
m = np.mean(li, axis=0)
|
||||
assert np.allclose(rs.mean, m)
|
||||
v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
|
||||
assert np.allclose(rs.var, v)
|
||||
|
||||
|
||||
def test_combining_stat():
|
||||
for shape in [(), (3,), (3, 4)]:
|
||||
li = []
|
||||
rs1 = RunningStat(shape)
|
||||
rs2 = RunningStat(shape)
|
||||
rs = RunningStat(shape)
|
||||
for _ in range(5):
|
||||
val = np.random.randn(*shape)
|
||||
rs1.push(val)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
for _ in range(9):
|
||||
rs2.push(val)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
rs1.update(rs2)
|
||||
assert np.allclose(rs.mean, rs1.mean)
|
||||
assert np.allclose(rs.std, rs1.std)
|
||||
for shape in [(), (3,), (3, 4)]:
|
||||
li = []
|
||||
rs1 = RunningStat(shape)
|
||||
rs2 = RunningStat(shape)
|
||||
rs = RunningStat(shape)
|
||||
for _ in range(5):
|
||||
val = np.random.randn(*shape)
|
||||
rs1.push(val)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
for _ in range(9):
|
||||
rs2.push(val)
|
||||
rs.push(val)
|
||||
li.append(val)
|
||||
rs1.update(rs2)
|
||||
assert np.allclose(rs.mean, rs1.mean)
|
||||
assert np.allclose(rs.std, rs1.std)
|
||||
|
||||
|
||||
test_running_stat()
|
||||
|
||||
@@ -10,43 +10,43 @@ from ray.rllib.policy_gradient.models.fcnet import fc_net
|
||||
|
||||
class ProximalPolicyLoss(object):
|
||||
|
||||
def __init__(
|
||||
self, observation_space, action_space,
|
||||
observations, advantages, actions, prev_logits, logit_dim,
|
||||
kl_coeff, distribution_class, config, sess):
|
||||
assert (isinstance(action_space, gym.spaces.Discrete) or
|
||||
isinstance(action_space, gym.spaces.Box))
|
||||
self.prev_dist = distribution_class(prev_logits)
|
||||
def __init__(
|
||||
self, observation_space, action_space,
|
||||
observations, advantages, actions, prev_logits, logit_dim,
|
||||
kl_coeff, distribution_class, config, sess):
|
||||
assert (isinstance(action_space, gym.spaces.Discrete) or
|
||||
isinstance(action_space, gym.spaces.Box))
|
||||
self.prev_dist = distribution_class(prev_logits)
|
||||
|
||||
# Saved so that we can compute actions given different observations
|
||||
self.observations = observations
|
||||
# Saved so that we can compute actions given different observations
|
||||
self.observations = observations
|
||||
|
||||
if len(observation_space.shape) > 1:
|
||||
self.curr_logits = vision_net(observations, num_classes=logit_dim)
|
||||
else:
|
||||
assert len(observation_space.shape) == 1
|
||||
self.curr_logits = fc_net(observations, num_classes=logit_dim)
|
||||
self.curr_dist = distribution_class(self.curr_logits)
|
||||
self.sampler = self.curr_dist.sample()
|
||||
if len(observation_space.shape) > 1:
|
||||
self.curr_logits = vision_net(observations, num_classes=logit_dim)
|
||||
else:
|
||||
assert len(observation_space.shape) == 1
|
||||
self.curr_logits = fc_net(observations, num_classes=logit_dim)
|
||||
self.curr_dist = distribution_class(self.curr_logits)
|
||||
self.sampler = self.curr_dist.sample()
|
||||
|
||||
# Make loss functions.
|
||||
self.ratio = tf.exp(self.curr_dist.logp(actions) -
|
||||
self.prev_dist.logp(actions))
|
||||
self.kl = self.prev_dist.kl(self.curr_dist)
|
||||
self.mean_kl = tf.reduce_mean(self.kl)
|
||||
self.entropy = self.curr_dist.entropy()
|
||||
self.mean_entropy = tf.reduce_mean(self.entropy)
|
||||
self.surr1 = self.ratio * advantages
|
||||
self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"],
|
||||
1 + config["clip_param"]) * advantages
|
||||
self.surr = tf.minimum(self.surr1, self.surr2)
|
||||
self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl -
|
||||
config["entropy_coeff"] * self.entropy)
|
||||
self.sess = sess
|
||||
# Make loss functions.
|
||||
self.ratio = tf.exp(self.curr_dist.logp(actions) -
|
||||
self.prev_dist.logp(actions))
|
||||
self.kl = self.prev_dist.kl(self.curr_dist)
|
||||
self.mean_kl = tf.reduce_mean(self.kl)
|
||||
self.entropy = self.curr_dist.entropy()
|
||||
self.mean_entropy = tf.reduce_mean(self.entropy)
|
||||
self.surr1 = self.ratio * advantages
|
||||
self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"],
|
||||
1 + config["clip_param"]) * advantages
|
||||
self.surr = tf.minimum(self.surr1, self.surr2)
|
||||
self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl -
|
||||
config["entropy_coeff"] * self.entropy)
|
||||
self.sess = sess
|
||||
|
||||
def compute_actions(self, observations):
|
||||
return self.sess.run([self.sampler, self.curr_logits],
|
||||
feed_dict={self.observations: observations})
|
||||
def compute_actions(self, observations):
|
||||
return self.sess.run([self.sampler, self.curr_logits],
|
||||
feed_dict={self.observations: observations})
|
||||
|
||||
def loss(self):
|
||||
return self.loss
|
||||
def loss(self):
|
||||
return self.loss
|
||||
|
||||
@@ -9,30 +9,30 @@ import numpy as np
|
||||
|
||||
|
||||
def normc_initializer(std=1.0):
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
def _initializer(shape, dtype=None, partition_info=None):
|
||||
out = np.random.randn(*shape).astype(np.float32)
|
||||
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
|
||||
return tf.constant(out)
|
||||
return _initializer
|
||||
|
||||
|
||||
def fc_net(inputs, num_classes=10, logstd=False):
|
||||
with tf.name_scope("fc_net"):
|
||||
fc1 = slim.fully_connected(inputs, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc1")
|
||||
fc2 = slim.fully_connected(fc1, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc2")
|
||||
fc3 = slim.fully_connected(fc2, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc3")
|
||||
fc4 = slim.fully_connected(fc3, num_classes,
|
||||
weights_initializer=normc_initializer(0.01),
|
||||
activation_fn=None, scope="fc4")
|
||||
if logstd:
|
||||
logstd = tf.get_variable(name="logstd", shape=[num_classes],
|
||||
initializer=tf.zeros_initializer)
|
||||
return tf.concat(1, [fc4, logstd])
|
||||
else:
|
||||
return fc4
|
||||
with tf.name_scope("fc_net"):
|
||||
fc1 = slim.fully_connected(inputs, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc1")
|
||||
fc2 = slim.fully_connected(fc1, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc2")
|
||||
fc3 = slim.fully_connected(fc2, 128,
|
||||
weights_initializer=normc_initializer(1.0),
|
||||
scope="fc3")
|
||||
fc4 = slim.fully_connected(fc3, num_classes,
|
||||
weights_initializer=normc_initializer(0.01),
|
||||
activation_fn=None, scope="fc4")
|
||||
if logstd:
|
||||
logstd = tf.get_variable(name="logstd", shape=[num_classes],
|
||||
initializer=tf.zeros_initializer)
|
||||
return tf.concat(1, [fc4, logstd])
|
||||
else:
|
||||
return fc4
|
||||
|
||||
@@ -7,10 +7,10 @@ import tensorflow.contrib.slim as slim
|
||||
|
||||
|
||||
def vision_net(inputs, num_classes=10):
|
||||
with tf.name_scope("vision_net"):
|
||||
conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
|
||||
conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
|
||||
fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1")
|
||||
fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope="fc2")
|
||||
return tf.squeeze(fc2, [1, 2])
|
||||
with tf.name_scope("vision_net"):
|
||||
conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
|
||||
conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
|
||||
fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1")
|
||||
fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None,
|
||||
normalizer_fn=None, scope="fc2")
|
||||
return tf.squeeze(fc2, [1, 2])
|
||||
|
||||
@@ -43,165 +43,170 @@ DEFAULT_CONFIG = {
|
||||
|
||||
|
||||
class PolicyGradient(Algorithm):
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "PolicyGradient"})
|
||||
def __init__(self, env_name, config, upload_dir=None):
|
||||
config.update({"alg": "PolicyGradient"})
|
||||
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
|
||||
|
||||
# TODO(ekl) the preprocessor should be associated with the env elsewhere
|
||||
if self.env_name == "Pong-v0":
|
||||
preprocessor = AtariPixelPreprocessor()
|
||||
elif self.env_name == "Pong-ram-v3":
|
||||
preprocessor = AtariRamPreprocessor()
|
||||
elif self.env_name == "CartPole-v0":
|
||||
preprocessor = NoPreprocessor()
|
||||
elif self.env_name == "Walker2d-v1":
|
||||
preprocessor = NoPreprocessor()
|
||||
else:
|
||||
preprocessor = AtariPixelPreprocessor()
|
||||
# TODO(ekl): The preprocessor should be associated with the env
|
||||
# elsewhere.
|
||||
if self.env_name == "Pong-v0":
|
||||
preprocessor = AtariPixelPreprocessor()
|
||||
elif self.env_name == "Pong-ram-v3":
|
||||
preprocessor = AtariRamPreprocessor()
|
||||
elif self.env_name == "CartPole-v0":
|
||||
preprocessor = NoPreprocessor()
|
||||
elif self.env_name == "Walker2d-v1":
|
||||
preprocessor = NoPreprocessor()
|
||||
else:
|
||||
preprocessor = AtariPixelPreprocessor()
|
||||
|
||||
self.preprocessor = preprocessor
|
||||
self.global_step = 0
|
||||
self.j = 0
|
||||
self.kl_coeff = config["kl_coeff"]
|
||||
self.model = Agent(
|
||||
self.env_name, 1, self.preprocessor, self.config, self.logdir, False)
|
||||
self.agents = [
|
||||
RemoteAgent.remote(
|
||||
self.env_name, 1, self.preprocessor, self.config,
|
||||
self.logdir, True)
|
||||
for _ in range(config["num_agents"])]
|
||||
self.preprocessor = preprocessor
|
||||
self.global_step = 0
|
||||
self.j = 0
|
||||
self.kl_coeff = config["kl_coeff"]
|
||||
self.model = Agent(
|
||||
self.env_name, 1, self.preprocessor, self.config, self.logdir,
|
||||
False)
|
||||
self.agents = [
|
||||
RemoteAgent.remote(
|
||||
self.env_name, 1, self.preprocessor, self.config,
|
||||
self.logdir, True)
|
||||
for _ in range(config["num_agents"])]
|
||||
|
||||
def train(self):
|
||||
agents = self.agents
|
||||
config = self.config
|
||||
model = self.model
|
||||
j = self.j
|
||||
self.j += 1
|
||||
def train(self):
|
||||
agents = self.agents
|
||||
config = self.config
|
||||
model = self.model
|
||||
j = self.j
|
||||
self.j += 1
|
||||
|
||||
saver = tf.train.Saver(max_to_keep=None)
|
||||
if "load_checkpoint" in config:
|
||||
saver.restore(model.sess, config["load_checkpoint"])
|
||||
saver = tf.train.Saver(max_to_keep=None)
|
||||
if "load_checkpoint" in config:
|
||||
saver.restore(model.sess, config["load_checkpoint"])
|
||||
|
||||
# TF does not support to write logs to S3 at the moment
|
||||
write_tf_logs = self.logdir.startswith("file")
|
||||
iter_start = time.time()
|
||||
if write_tf_logs:
|
||||
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
|
||||
if config["model_checkpoint_file"]:
|
||||
checkpoint_path = saver.save(
|
||||
model.sess,
|
||||
os.path.join(self.logdir, config["model_checkpoint_file"] % j))
|
||||
print("Checkpoint saved in file: %s" % checkpoint_path)
|
||||
checkpointing_end = time.time()
|
||||
weights = ray.put(model.get_weights())
|
||||
[a.load_weights.remote(weights) for a in agents]
|
||||
trajectory, total_reward, traj_len_mean = collect_samples(
|
||||
agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
|
||||
print("total reward is ", total_reward)
|
||||
print("trajectory length mean is ", traj_len_mean)
|
||||
print("timesteps:", trajectory["dones"].shape[0])
|
||||
if write_tf_logs:
|
||||
traj_stats = tf.Summary(value=[
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/mean_reward",
|
||||
simple_value=total_reward),
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/traj_len_mean",
|
||||
simple_value=traj_len_mean)])
|
||||
file_writer.add_summary(traj_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
trajectory["advantages"] = ((trajectory["advantages"] -
|
||||
trajectory["advantages"].mean()) /
|
||||
trajectory["advantages"].std())
|
||||
rollouts_end = time.time()
|
||||
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
|
||||
", stepsize=" + str(config["sgd_stepsize"]) + "):")
|
||||
names = ["iter", "loss", "kl", "entropy"]
|
||||
print(("{:>15}" * len(names)).format(*names))
|
||||
trajectory = shuffle(trajectory)
|
||||
shuffle_end = time.time()
|
||||
tuples_per_device = model.load_data(
|
||||
trajectory, j == 0 and config["full_trace_data_load"])
|
||||
load_end = time.time()
|
||||
checkpointing_time = checkpointing_end - iter_start
|
||||
rollouts_time = rollouts_end - checkpointing_end
|
||||
shuffle_time = shuffle_end - rollouts_end
|
||||
load_time = load_end - shuffle_end
|
||||
sgd_time = 0
|
||||
for i in range(config["num_sgd_iter"]):
|
||||
sgd_start = time.time()
|
||||
batch_index = 0
|
||||
num_batches = int(tuples_per_device) // int(model.per_device_batch_size)
|
||||
loss, kl, entropy = [], [], []
|
||||
permutation = np.random.permutation(num_batches)
|
||||
while batch_index < num_batches:
|
||||
full_trace = (
|
||||
i == 0 and j == 0 and
|
||||
batch_index == config["full_trace_nth_sgd_batch"])
|
||||
batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch(
|
||||
permutation[batch_index] * model.per_device_batch_size,
|
||||
self.kl_coeff, full_trace,
|
||||
file_writer if write_tf_logs else None)
|
||||
loss.append(batch_loss)
|
||||
kl.append(batch_kl)
|
||||
entropy.append(batch_entropy)
|
||||
batch_index += 1
|
||||
loss = np.mean(loss)
|
||||
kl = np.mean(kl)
|
||||
entropy = np.mean(entropy)
|
||||
sgd_end = time.time()
|
||||
print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, entropy))
|
||||
# TF does not support to write logs to S3 at the moment
|
||||
write_tf_logs = self.logdir.startswith("file")
|
||||
iter_start = time.time()
|
||||
if write_tf_logs:
|
||||
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
|
||||
if config["model_checkpoint_file"]:
|
||||
checkpoint_path = saver.save(
|
||||
model.sess,
|
||||
os.path.join(self.logdir,
|
||||
config["model_checkpoint_file"] % j))
|
||||
print("Checkpoint saved in file: %s" % checkpoint_path)
|
||||
checkpointing_end = time.time()
|
||||
weights = ray.put(model.get_weights())
|
||||
[a.load_weights.remote(weights) for a in agents]
|
||||
trajectory, total_reward, traj_len_mean = collect_samples(
|
||||
agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
|
||||
print("total reward is ", total_reward)
|
||||
print("trajectory length mean is ", traj_len_mean)
|
||||
print("timesteps:", trajectory["dones"].shape[0])
|
||||
if write_tf_logs:
|
||||
traj_stats = tf.Summary(value=[
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/mean_reward",
|
||||
simple_value=total_reward),
|
||||
tf.Summary.Value(
|
||||
tag="policy_gradient/rollouts/traj_len_mean",
|
||||
simple_value=traj_len_mean)])
|
||||
file_writer.add_summary(traj_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
trajectory["advantages"] = ((trajectory["advantages"] -
|
||||
trajectory["advantages"].mean()) /
|
||||
trajectory["advantages"].std())
|
||||
rollouts_end = time.time()
|
||||
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
|
||||
", stepsize=" + str(config["sgd_stepsize"]) + "):")
|
||||
names = ["iter", "loss", "kl", "entropy"]
|
||||
print(("{:>15}" * len(names)).format(*names))
|
||||
trajectory = shuffle(trajectory)
|
||||
shuffle_end = time.time()
|
||||
tuples_per_device = model.load_data(
|
||||
trajectory, j == 0 and config["full_trace_data_load"])
|
||||
load_end = time.time()
|
||||
checkpointing_time = checkpointing_end - iter_start
|
||||
rollouts_time = rollouts_end - checkpointing_end
|
||||
shuffle_time = shuffle_end - rollouts_end
|
||||
load_time = load_end - shuffle_end
|
||||
sgd_time = 0
|
||||
for i in range(config["num_sgd_iter"]):
|
||||
sgd_start = time.time()
|
||||
batch_index = 0
|
||||
num_batches = (int(tuples_per_device) //
|
||||
int(model.per_device_batch_size))
|
||||
loss, kl, entropy = [], [], []
|
||||
permutation = np.random.permutation(num_batches)
|
||||
while batch_index < num_batches:
|
||||
full_trace = (
|
||||
i == 0 and j == 0 and
|
||||
batch_index == config["full_trace_nth_sgd_batch"])
|
||||
batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch(
|
||||
permutation[batch_index] * model.per_device_batch_size,
|
||||
self.kl_coeff, full_trace,
|
||||
file_writer if write_tf_logs else None)
|
||||
loss.append(batch_loss)
|
||||
kl.append(batch_kl)
|
||||
entropy.append(batch_entropy)
|
||||
batch_index += 1
|
||||
loss = np.mean(loss)
|
||||
kl = np.mean(kl)
|
||||
entropy = np.mean(entropy)
|
||||
sgd_end = time.time()
|
||||
print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl,
|
||||
entropy))
|
||||
|
||||
values = []
|
||||
if i == config["num_sgd_iter"] - 1:
|
||||
metric_prefix = "policy_gradient/sgd/final_iter/"
|
||||
values.append(tf.Summary.Value(
|
||||
tag=metric_prefix + "kl_coeff",
|
||||
simple_value=self.kl_coeff))
|
||||
else:
|
||||
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
|
||||
values.extend([
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_entropy",
|
||||
simple_value=entropy),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_loss",
|
||||
simple_value=loss),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_kl",
|
||||
simple_value=kl)])
|
||||
if write_tf_logs:
|
||||
sgd_stats = tf.Summary(value=values)
|
||||
file_writer.add_summary(sgd_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
sgd_time += sgd_end - sgd_start
|
||||
if kl > 2.0 * config["kl_target"]:
|
||||
self.kl_coeff *= 1.5
|
||||
elif kl < 0.5 * config["kl_target"]:
|
||||
self.kl_coeff *= 0.5
|
||||
values = []
|
||||
if i == config["num_sgd_iter"] - 1:
|
||||
metric_prefix = "policy_gradient/sgd/final_iter/"
|
||||
values.append(tf.Summary.Value(
|
||||
tag=metric_prefix + "kl_coeff",
|
||||
simple_value=self.kl_coeff))
|
||||
else:
|
||||
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
|
||||
values.extend([
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_entropy",
|
||||
simple_value=entropy),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_loss",
|
||||
simple_value=loss),
|
||||
tf.Summary.Value(
|
||||
tag=metric_prefix + "mean_kl",
|
||||
simple_value=kl)])
|
||||
if write_tf_logs:
|
||||
sgd_stats = tf.Summary(value=values)
|
||||
file_writer.add_summary(sgd_stats, self.global_step)
|
||||
self.global_step += 1
|
||||
sgd_time += sgd_end - sgd_start
|
||||
if kl > 2.0 * config["kl_target"]:
|
||||
self.kl_coeff *= 1.5
|
||||
elif kl < 0.5 * config["kl_target"]:
|
||||
self.kl_coeff *= 0.5
|
||||
|
||||
info = {
|
||||
"kl_divergence": kl,
|
||||
"kl_coefficient": self.kl_coeff,
|
||||
"checkpointing_time": checkpointing_time,
|
||||
"rollouts_time": rollouts_time,
|
||||
"shuffle_time": shuffle_time,
|
||||
"load_time": load_time,
|
||||
"sgd_time": sgd_time,
|
||||
"sample_throughput": len(trajectory["observations"]) / sgd_time
|
||||
}
|
||||
info = {
|
||||
"kl_divergence": kl,
|
||||
"kl_coefficient": self.kl_coeff,
|
||||
"checkpointing_time": checkpointing_time,
|
||||
"rollouts_time": rollouts_time,
|
||||
"shuffle_time": shuffle_time,
|
||||
"load_time": load_time,
|
||||
"sgd_time": sgd_time,
|
||||
"sample_throughput": len(trajectory["observations"]) / sgd_time
|
||||
}
|
||||
|
||||
print("kl div:", kl)
|
||||
print("kl coeff:", self.kl_coeff)
|
||||
print("checkpointing time:", checkpointing_time)
|
||||
print("rollouts time:", rollouts_time)
|
||||
print("shuffle time:", shuffle_time)
|
||||
print("load time:", load_time)
|
||||
print("sgd time:", sgd_time)
|
||||
print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
|
||||
print("kl div:", kl)
|
||||
print("kl coeff:", self.kl_coeff)
|
||||
print("checkpointing time:", checkpointing_time)
|
||||
print("rollouts time:", rollouts_time)
|
||||
print("shuffle time:", shuffle_time)
|
||||
print("load time:", load_time)
|
||||
print("sgd time:", sgd_time)
|
||||
print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
|
||||
|
||||
result = TrainingResult(
|
||||
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
|
||||
result = TrainingResult(
|
||||
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
@@ -11,93 +11,95 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate
|
||||
|
||||
def rollouts(policy, env, horizon, observation_filter=NoFilter(),
|
||||
reward_filter=NoFilter()):
|
||||
"""Perform a batch of rollouts of a policy in an environment.
|
||||
"""Perform a batch of rollouts of a policy in an environment.
|
||||
|
||||
Args:
|
||||
policy: The policy that will be rollout out. Can be an arbitrary object
|
||||
that supports a compute_actions(observation) function.
|
||||
env: The environment the rollout is computed in. Needs to support the
|
||||
OpenAI gym API and needs to support batches of data.
|
||||
horizon: Upper bound for the number of timesteps for each rollout in the
|
||||
batch.
|
||||
observation_filter: Function that is applied to each of the observations.
|
||||
reward_filter: Function that is applied to each of the rewards.
|
||||
Args:
|
||||
policy: The policy that will be rollout out. Can be an arbitrary object
|
||||
that supports a compute_actions(observation) function.
|
||||
env: The environment the rollout is computed in. Needs to support the
|
||||
OpenAI gym API and needs to support batches of data.
|
||||
horizon: Upper bound for the number of timesteps for each rollout in
|
||||
the batch.
|
||||
observation_filter: Function that is applied to each of the
|
||||
observations.
|
||||
reward_filter: Function that is applied to each of the rewards.
|
||||
|
||||
Returns:
|
||||
A trajectory, which is a dictionary with keys "observations", "rewards",
|
||||
"orig_rewards", "actions", "logprobs", "dones". Each value is an array of
|
||||
shape (num_timesteps, env.batchsize, shape).
|
||||
"""
|
||||
Returns:
|
||||
A trajectory, which is a dictionary with keys "observations",
|
||||
"rewards", "orig_rewards", "actions", "logprobs", "dones". Each
|
||||
value is an array of shape (num_timesteps, env.batchsize, shape).
|
||||
"""
|
||||
|
||||
observation = observation_filter(env.reset())
|
||||
done = np.array(env.batchsize * [False])
|
||||
t = 0
|
||||
observations = []
|
||||
raw_rewards = [] # Empirical rewards
|
||||
actions = []
|
||||
logprobs = []
|
||||
dones = []
|
||||
observation = observation_filter(env.reset())
|
||||
done = np.array(env.batchsize * [False])
|
||||
t = 0
|
||||
observations = []
|
||||
raw_rewards = [] # Empirical rewards
|
||||
actions = []
|
||||
logprobs = []
|
||||
dones = []
|
||||
|
||||
while not done.all() and t < horizon:
|
||||
action, logprob = policy.compute_actions(observation)
|
||||
observations.append(observation[None])
|
||||
actions.append(action[None])
|
||||
logprobs.append(logprob[None])
|
||||
observation, raw_reward, done = env.step(action)
|
||||
observation = observation_filter(observation)
|
||||
raw_rewards.append(raw_reward[None])
|
||||
dones.append(done[None])
|
||||
t += 1
|
||||
while not done.all() and t < horizon:
|
||||
action, logprob = policy.compute_actions(observation)
|
||||
observations.append(observation[None])
|
||||
actions.append(action[None])
|
||||
logprobs.append(logprob[None])
|
||||
observation, raw_reward, done = env.step(action)
|
||||
observation = observation_filter(observation)
|
||||
raw_rewards.append(raw_reward[None])
|
||||
dones.append(done[None])
|
||||
t += 1
|
||||
|
||||
return {"observations": np.vstack(observations),
|
||||
"raw_rewards": np.vstack(raw_rewards),
|
||||
"actions": np.vstack(actions),
|
||||
"logprobs": np.vstack(logprobs),
|
||||
"dones": np.vstack(dones)}
|
||||
return {"observations": np.vstack(observations),
|
||||
"raw_rewards": np.vstack(raw_rewards),
|
||||
"actions": np.vstack(actions),
|
||||
"logprobs": np.vstack(logprobs),
|
||||
"dones": np.vstack(dones)}
|
||||
|
||||
|
||||
def add_advantage_values(trajectory, gamma, lam, reward_filter):
|
||||
rewards = trajectory["raw_rewards"]
|
||||
dones = trajectory["dones"]
|
||||
advantages = np.zeros_like(rewards)
|
||||
last_advantage = np.zeros(rewards.shape[1], dtype="float32")
|
||||
rewards = trajectory["raw_rewards"]
|
||||
dones = trajectory["dones"]
|
||||
advantages = np.zeros_like(rewards)
|
||||
last_advantage = np.zeros(rewards.shape[1], dtype="float32")
|
||||
|
||||
for t in reversed(range(len(rewards))):
|
||||
delta = rewards[t, :] * (1 - dones[t, :])
|
||||
last_advantage = delta + gamma * lam * last_advantage
|
||||
advantages[t, :] = last_advantage
|
||||
reward_filter(advantages[t, :])
|
||||
for t in reversed(range(len(rewards))):
|
||||
delta = rewards[t, :] * (1 - dones[t, :])
|
||||
last_advantage = delta + gamma * lam * last_advantage
|
||||
advantages[t, :] = last_advantage
|
||||
reward_filter(advantages[t, :])
|
||||
|
||||
trajectory["advantages"] = advantages
|
||||
trajectory["advantages"] = advantages
|
||||
|
||||
|
||||
@ray.remote
|
||||
def compute_trajectory(policy, env, gamma, lam, horizon, observation_filter,
|
||||
reward_filter):
|
||||
trajectory = rollouts(policy, env, horizon, observation_filter,
|
||||
reward_filter)
|
||||
add_advantage_values(trajectory, gamma, lam, reward_filter)
|
||||
return trajectory
|
||||
trajectory = rollouts(policy, env, horizon, observation_filter,
|
||||
reward_filter)
|
||||
add_advantage_values(trajectory, gamma, lam, reward_filter)
|
||||
return trajectory
|
||||
|
||||
|
||||
def collect_samples(agents, num_timesteps, gamma, lam, horizon,
|
||||
observation_filter=NoFilter(), reward_filter=NoFilter()):
|
||||
num_timesteps_so_far = 0
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
traj_len_means = []
|
||||
while num_timesteps_so_far < num_timesteps:
|
||||
trajectory_batch = ray.get(
|
||||
[agent.compute_trajectory.remote(gamma, lam, horizon)
|
||||
for agent in agents])
|
||||
trajectory = concatenate(trajectory_batch)
|
||||
trajectory = flatten(trajectory)
|
||||
not_done = np.logical_not(trajectory["dones"])
|
||||
total_rewards.append(
|
||||
trajectory["raw_rewards"][not_done].sum(axis=0).mean() / len(agents))
|
||||
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
|
||||
trajectory = {key: val[not_done] for key, val in trajectory.items()}
|
||||
num_timesteps_so_far += len(trajectory["dones"])
|
||||
trajectories.append(trajectory)
|
||||
return (concatenate(trajectories), np.mean(total_rewards),
|
||||
np.mean(traj_len_means))
|
||||
num_timesteps_so_far = 0
|
||||
trajectories = []
|
||||
total_rewards = []
|
||||
traj_len_means = []
|
||||
while num_timesteps_so_far < num_timesteps:
|
||||
trajectory_batch = ray.get(
|
||||
[agent.compute_trajectory.remote(gamma, lam, horizon)
|
||||
for agent in agents])
|
||||
trajectory = concatenate(trajectory_batch)
|
||||
trajectory = flatten(trajectory)
|
||||
not_done = np.logical_not(trajectory["dones"])
|
||||
total_rewards.append(
|
||||
trajectory["raw_rewards"][not_done].sum(axis=0).mean() /
|
||||
len(agents))
|
||||
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
|
||||
trajectory = {key: val[not_done] for key, val in trajectory.items()}
|
||||
num_timesteps_so_far += len(trajectory["dones"])
|
||||
trajectories.append(trajectory)
|
||||
return (concatenate(trajectories), np.mean(total_rewards),
|
||||
np.mean(traj_len_means))
|
||||
|
||||
@@ -13,49 +13,49 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate
|
||||
|
||||
class DistibutionsTest(unittest.TestCase):
|
||||
|
||||
def testCategorical(self):
|
||||
num_samples = 100000
|
||||
logits = tf.placeholder(tf.float32, shape=(None, 10))
|
||||
z = 8 * (np.random.rand(10) - 0.5)
|
||||
data = np.tile(z, (num_samples, 1))
|
||||
c = Categorical(logits)
|
||||
sample_op = c.sample()
|
||||
sess = tf.Session()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
samples = sess.run(sample_op, feed_dict={logits: data})
|
||||
counts = np.zeros(10)
|
||||
for sample in samples:
|
||||
counts[sample] += 1.0
|
||||
probs = np.exp(z) / np.sum(np.exp(z))
|
||||
self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)
|
||||
def testCategorical(self):
|
||||
num_samples = 100000
|
||||
logits = tf.placeholder(tf.float32, shape=(None, 10))
|
||||
z = 8 * (np.random.rand(10) - 0.5)
|
||||
data = np.tile(z, (num_samples, 1))
|
||||
c = Categorical(logits)
|
||||
sample_op = c.sample()
|
||||
sess = tf.Session()
|
||||
sess.run(tf.global_variables_initializer())
|
||||
samples = sess.run(sample_op, feed_dict={logits: data})
|
||||
counts = np.zeros(10)
|
||||
for sample in samples:
|
||||
counts[sample] += 1.0
|
||||
probs = np.exp(z) / np.sum(np.exp(z))
|
||||
self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)
|
||||
|
||||
|
||||
class UtilsTest(unittest.TestCase):
|
||||
|
||||
def testFlatten(self):
|
||||
d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
|
||||
"a": np.array([[[5], [-5]], [[6], [-6]]])}
|
||||
flat = flatten(d.copy(), start=0, stop=2)
|
||||
assert_allclose(d["s"][0][0][:], flat["s"][0][:])
|
||||
assert_allclose(d["s"][0][1][:], flat["s"][1][:])
|
||||
assert_allclose(d["s"][1][0][:], flat["s"][2][:])
|
||||
assert_allclose(d["s"][1][1][:], flat["s"][3][:])
|
||||
assert_allclose(d["a"][0][0], flat["a"][0])
|
||||
assert_allclose(d["a"][0][1], flat["a"][1])
|
||||
assert_allclose(d["a"][1][0], flat["a"][2])
|
||||
assert_allclose(d["a"][1][1], flat["a"][3])
|
||||
def testFlatten(self):
|
||||
d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
|
||||
"a": np.array([[[5], [-5]], [[6], [-6]]])}
|
||||
flat = flatten(d.copy(), start=0, stop=2)
|
||||
assert_allclose(d["s"][0][0][:], flat["s"][0][:])
|
||||
assert_allclose(d["s"][0][1][:], flat["s"][1][:])
|
||||
assert_allclose(d["s"][1][0][:], flat["s"][2][:])
|
||||
assert_allclose(d["s"][1][1][:], flat["s"][3][:])
|
||||
assert_allclose(d["a"][0][0], flat["a"][0])
|
||||
assert_allclose(d["a"][0][1], flat["a"][1])
|
||||
assert_allclose(d["a"][1][0], flat["a"][2])
|
||||
assert_allclose(d["a"][1][1], flat["a"][3])
|
||||
|
||||
def testConcatenate(self):
|
||||
d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
|
||||
d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
|
||||
d = concatenate([d1, d2])
|
||||
assert_allclose(d["s"], np.array([0, 1, 4, 5]))
|
||||
assert_allclose(d["a"], np.array([2, 3, 6, 7]))
|
||||
def testConcatenate(self):
|
||||
d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
|
||||
d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
|
||||
d = concatenate([d1, d2])
|
||||
assert_allclose(d["s"], np.array([0, 1, 4, 5]))
|
||||
assert_allclose(d["a"], np.array([2, 3, 6, 7]))
|
||||
|
||||
D = concatenate([d])
|
||||
assert_allclose(D["s"], np.array([0, 1, 4, 5]))
|
||||
assert_allclose(D["a"], np.array([2, 3, 6, 7]))
|
||||
D = concatenate([d])
|
||||
assert_allclose(D["s"], np.array([0, 1, 4, 5]))
|
||||
assert_allclose(D["a"], np.array([2, 3, 6, 7]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=2)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -6,31 +6,31 @@ import numpy as np
|
||||
|
||||
|
||||
def flatten(weights, start=0, stop=2):
|
||||
"""This methods reshapes all values in a dictionary.
|
||||
"""This methods reshapes all values in a dictionary.
|
||||
|
||||
The indices from start to stop will be flattened into a single index.
|
||||
The indices from start to stop will be flattened into a single index.
|
||||
|
||||
Args:
|
||||
weights: A dictionary mapping keys to numpy arrays.
|
||||
start: The starting index.
|
||||
stop: The ending index.
|
||||
"""
|
||||
for key, val in weights.items():
|
||||
new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
|
||||
weights[key] = val.reshape(new_shape)
|
||||
return weights
|
||||
Args:
|
||||
weights: A dictionary mapping keys to numpy arrays.
|
||||
start: The starting index.
|
||||
stop: The ending index.
|
||||
"""
|
||||
for key, val in weights.items():
|
||||
new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
|
||||
weights[key] = val.reshape(new_shape)
|
||||
return weights
|
||||
|
||||
|
||||
def concatenate(weights_list):
|
||||
keys = weights_list[0].keys()
|
||||
result = {}
|
||||
for key in keys:
|
||||
result[key] = np.concatenate([l[key] for l in weights_list])
|
||||
return result
|
||||
keys = weights_list[0].keys()
|
||||
result = {}
|
||||
for key in keys:
|
||||
result[key] = np.concatenate([l[key] for l in weights_list])
|
||||
return result
|
||||
|
||||
|
||||
def shuffle(trajectory):
|
||||
permutation = np.random.permutation(trajectory["dones"].shape[0])
|
||||
for key, val in trajectory.items():
|
||||
trajectory[key] = val[permutation]
|
||||
return trajectory
|
||||
permutation = np.random.permutation(trajectory["dones"].shape[0])
|
||||
for key, val in trajectory.items():
|
||||
trajectory[key] = val[permutation]
|
||||
return trajectory
|
||||
|
||||
+28
-28
@@ -22,36 +22,36 @@ parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args()
|
||||
|
||||
ray.init()
|
||||
ray.init()
|
||||
|
||||
env_name = args.env
|
||||
if args.alg == "PolicyGradient":
|
||||
alg = pg.PolicyGradient(
|
||||
env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "EvolutionStrategies":
|
||||
alg = es.EvolutionStrategies(
|
||||
env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "DQN":
|
||||
alg = dqn.DQN(
|
||||
env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "A3C":
|
||||
alg = a3c.A3C(
|
||||
env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
else:
|
||||
assert False, ("Unknown algorithm, check --alg argument. Valid choices "
|
||||
"are PolicyGradientPolicyGradient, EvolutionStrategies, "
|
||||
"DQN and A3C.")
|
||||
env_name = args.env
|
||||
if args.alg == "PolicyGradient":
|
||||
alg = pg.PolicyGradient(
|
||||
env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "EvolutionStrategies":
|
||||
alg = es.EvolutionStrategies(
|
||||
env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "DQN":
|
||||
alg = dqn.DQN(
|
||||
env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
elif args.alg == "A3C":
|
||||
alg = a3c.A3C(
|
||||
env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
|
||||
else:
|
||||
assert False, ("Unknown algorithm, check --alg argument. Valid "
|
||||
"choices are PolicyGradientPolicyGradient, "
|
||||
"EvolutionStrategies, DQN and A3C.")
|
||||
|
||||
result_logger = ray.rllib.common.RLLibLogger(
|
||||
os.path.join(alg.logdir, "result.json"))
|
||||
result_logger = ray.rllib.common.RLLibLogger(
|
||||
os.path.join(alg.logdir, "result.json"))
|
||||
|
||||
while True:
|
||||
result = alg.train()
|
||||
while True:
|
||||
result = alg.train()
|
||||
|
||||
# We need to use a custom json serializer class so that NaNs get encoded
|
||||
# as null as required by Athena.
|
||||
json.dump(result._asdict(), result_logger,
|
||||
cls=ray.rllib.common.RLLibEncoder)
|
||||
result_logger.write("\n")
|
||||
# We need to use a custom json serializer class so that NaNs get
|
||||
# encoded as null as required by Athena.
|
||||
json.dump(result._asdict(), result_logger,
|
||||
cls=ray.rllib.common.RLLibEncoder)
|
||||
result_logger.write("\n")
|
||||
|
||||
+128
-125
@@ -10,34 +10,35 @@ import ray.services as services
|
||||
|
||||
|
||||
def check_no_existing_redis_clients(node_ip_address, redis_address):
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
# The client table prefix must be kept in sync with the file
|
||||
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
|
||||
REDIS_CLIENT_TABLE_PREFIX = "CL:"
|
||||
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
|
||||
# Filter to clients on the same node and do some basic checking.
|
||||
for key in client_keys:
|
||||
info = redis_client.hgetall(key)
|
||||
assert b"ray_client_id" in info
|
||||
assert b"node_ip_address" in info
|
||||
assert b"client_type" in info
|
||||
assert b"deleted" in info
|
||||
# Clients that ran on the same node but that are marked dead can be
|
||||
# ignored.
|
||||
deleted = info[b"deleted"]
|
||||
deleted = bool(int(deleted))
|
||||
if deleted:
|
||||
continue
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
# The client table prefix must be kept in sync with the file
|
||||
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
|
||||
REDIS_CLIENT_TABLE_PREFIX = "CL:"
|
||||
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
|
||||
# Filter to clients on the same node and do some basic checking.
|
||||
for key in client_keys:
|
||||
info = redis_client.hgetall(key)
|
||||
assert b"ray_client_id" in info
|
||||
assert b"node_ip_address" in info
|
||||
assert b"client_type" in info
|
||||
assert b"deleted" in info
|
||||
# Clients that ran on the same node but that are marked dead can be
|
||||
# ignored.
|
||||
deleted = info[b"deleted"]
|
||||
deleted = bool(int(deleted))
|
||||
if deleted:
|
||||
continue
|
||||
|
||||
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
|
||||
raise Exception("This Redis instance is already connected to clients "
|
||||
"with this IP address.")
|
||||
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
|
||||
raise Exception("This Redis instance is already connected to "
|
||||
"clients with this IP address.")
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -64,120 +65,122 @@ def cli():
|
||||
help="provide this argument to block forever in this command")
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
object_manager_port, num_workers, num_cpus, num_gpus, head, block):
|
||||
# Note that we redirect stdout and stderr to /dev/null because otherwise
|
||||
# attempts to print may cause exceptions if a process is started inside of an
|
||||
# SSH connection and the SSH connection dies. TODO(rkn): This is a temporary
|
||||
# fix. We should actually redirect stdout and stderr to Redis in some way.
|
||||
# Note that we redirect stdout and stderr to /dev/null because otherwise
|
||||
# attempts to print may cause exceptions if a process is started inside of
|
||||
# an SSH connection and the SSH connection dies. TODO(rkn): This is a
|
||||
# temporary fix. We should actually redirect stdout and stderr to Redis in
|
||||
# some way.
|
||||
|
||||
if head:
|
||||
# Start Ray on the head node.
|
||||
if redis_address is not None:
|
||||
raise Exception("If --head is passed in, a Redis server will be "
|
||||
"started, so a Redis address should not be provided.")
|
||||
if head:
|
||||
# Start Ray on the head node.
|
||||
if redis_address is not None:
|
||||
raise Exception("If --head is passed in, a Redis server will be "
|
||||
"started, so a Redis address should not be "
|
||||
"provided.")
|
||||
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address()
|
||||
print("Using IP address {} for this node.".format(node_ip_address))
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address()
|
||||
print("Using IP address {} for this node.".format(node_ip_address))
|
||||
|
||||
address_info = {}
|
||||
# Use the provided object manager port if there is one.
|
||||
if object_manager_port is not None:
|
||||
address_info["object_manager_ports"] = [object_manager_port]
|
||||
if address_info == {}:
|
||||
address_info = None
|
||||
address_info = {}
|
||||
# Use the provided object manager port if there is one.
|
||||
if object_manager_port is not None:
|
||||
address_info["object_manager_ports"] = [object_manager_port]
|
||||
if address_info == {}:
|
||||
address_info = None
|
||||
|
||||
address_info = services.start_ray_head(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
redis_port=redis_port,
|
||||
num_workers=num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. You can add additional nodes to the "
|
||||
"cluster by calling\n\n"
|
||||
" ray start --redis-address {}\n\n"
|
||||
"from the node you wish to add. You can connect a driver to the "
|
||||
"cluster from Python by running\n\n"
|
||||
" import ray\n"
|
||||
" ray.init(redis_address=\"{}\")\n\n"
|
||||
"If you have trouble connecting from a different machine, check "
|
||||
"that your firewall is configured properly. If you wish to "
|
||||
"terminate the processes that have been started, run\n\n"
|
||||
" ray stop".format(address_info["redis_address"],
|
||||
address_info["redis_address"]))
|
||||
else:
|
||||
# Start Ray on a non-head node.
|
||||
if redis_port is not None:
|
||||
raise Exception("If --head is not passed in, --redis-port is not "
|
||||
"allowed")
|
||||
if redis_address is None:
|
||||
raise Exception("If --head is not passed in, --redis-address must be "
|
||||
"provided.")
|
||||
if num_redis_shards is not None:
|
||||
raise Exception("If --head is not passed in, --num-redis-shards must "
|
||||
"not be provided.")
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# Wait for the Redis server to be started. And throw an exception if we
|
||||
# can't connect to it.
|
||||
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
print("Using IP address {} for this node.".format(node_ip_address))
|
||||
# Check that there aren't already Redis clients with the same IP address
|
||||
# connected with this Redis instance. This raises an exception if the Redis
|
||||
# server already has clients on this node.
|
||||
check_no_existing_redis_clients(node_ip_address, redis_address)
|
||||
address_info = services.start_ray_node(
|
||||
node_ip_address=node_ip_address,
|
||||
redis_address=redis_address,
|
||||
object_manager_ports=[object_manager_port],
|
||||
num_workers=num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. If you wish to terminate the processes "
|
||||
"that have been started, run\n\n"
|
||||
" ray stop")
|
||||
address_info = services.start_ray_head(
|
||||
address_info=address_info,
|
||||
node_ip_address=node_ip_address,
|
||||
redis_port=redis_port,
|
||||
num_workers=num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus,
|
||||
num_redis_shards=num_redis_shards)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. You can add additional nodes to "
|
||||
"the cluster by calling\n\n"
|
||||
" ray start --redis-address {}\n\n"
|
||||
"from the node you wish to add. You can connect a driver to the "
|
||||
"cluster from Python by running\n\n"
|
||||
" import ray\n"
|
||||
" ray.init(redis_address=\"{}\")\n\n"
|
||||
"If you have trouble connecting from a different machine, check "
|
||||
"that your firewall is configured properly. If you wish to "
|
||||
"terminate the processes that have been started, run\n\n"
|
||||
" ray stop".format(address_info["redis_address"],
|
||||
address_info["redis_address"]))
|
||||
else:
|
||||
# Start Ray on a non-head node.
|
||||
if redis_port is not None:
|
||||
raise Exception("If --head is not passed in, --redis-port is not "
|
||||
"allowed")
|
||||
if redis_address is None:
|
||||
raise Exception("If --head is not passed in, --redis-address must "
|
||||
"be provided.")
|
||||
if num_redis_shards is not None:
|
||||
raise Exception("If --head is not passed in, --num-redis-shards "
|
||||
"must not be provided.")
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# Wait for the Redis server to be started. And throw an exception if we
|
||||
# can't connect to it.
|
||||
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
|
||||
# Get the node IP address if one is not provided.
|
||||
if node_ip_address is None:
|
||||
node_ip_address = services.get_node_ip_address(redis_address)
|
||||
print("Using IP address {} for this node.".format(node_ip_address))
|
||||
# Check that there aren't already Redis clients with the same IP
|
||||
# address connected with this Redis instance. This raises an exception
|
||||
# if the Redis server already has clients on this node.
|
||||
check_no_existing_redis_clients(node_ip_address, redis_address)
|
||||
address_info = services.start_ray_node(
|
||||
node_ip_address=node_ip_address,
|
||||
redis_address=redis_address,
|
||||
object_manager_ports=[object_manager_port],
|
||||
num_workers=num_workers,
|
||||
cleanup=False,
|
||||
redirect_output=True,
|
||||
num_cpus=num_cpus,
|
||||
num_gpus=num_gpus)
|
||||
print(address_info)
|
||||
print("\nStarted Ray on this node. If you wish to terminate the "
|
||||
"processes that have been started, run\n\n"
|
||||
" ray stop")
|
||||
|
||||
if block:
|
||||
import time
|
||||
while True:
|
||||
time.sleep(30)
|
||||
if block:
|
||||
import time
|
||||
while True:
|
||||
time.sleep(30)
|
||||
|
||||
|
||||
@click.command()
|
||||
def stop():
|
||||
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler"], shell=True)
|
||||
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
|
||||
"local_scheduler"], shell=True)
|
||||
|
||||
# Find the PID of the monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
# Find the PID of the monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
|
||||
# Find the PID of the Redis process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
# Find the PID of the Redis process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
|
||||
# Find the PIDs of the worker processes and kill them.
|
||||
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
|
||||
shell=True)
|
||||
# Find the PIDs of the worker processes and kill them.
|
||||
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
|
||||
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
|
||||
shell=True)
|
||||
|
||||
# Find the PID of the Ray log monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
# Find the PID of the Ray log monitor process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
|
||||
# Find the PID of the jupyter process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
# Find the PID of the jupyter process and kill it.
|
||||
subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | "
|
||||
"awk '{ print $2 }') 2> /dev/null"], shell=True)
|
||||
|
||||
|
||||
cli.add_command(start)
|
||||
@@ -185,8 +188,8 @@ cli.add_command(stop)
|
||||
|
||||
|
||||
def main():
|
||||
return cli()
|
||||
return cli()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
+147
-142
@@ -8,58 +8,61 @@ import ray.numbuf
|
||||
|
||||
|
||||
class RaySerializationException(Exception):
|
||||
def __init__(self, message, example_object):
|
||||
Exception.__init__(self, message)
|
||||
self.example_object = example_object
|
||||
def __init__(self, message, example_object):
|
||||
Exception.__init__(self, message)
|
||||
self.example_object = example_object
|
||||
|
||||
|
||||
class RayDeserializationException(Exception):
|
||||
def __init__(self, message, class_id):
|
||||
Exception.__init__(self, message)
|
||||
self.class_id = class_id
|
||||
def __init__(self, message, class_id):
|
||||
Exception.__init__(self, message)
|
||||
self.class_id = class_id
|
||||
|
||||
|
||||
class RayNotDictionarySerializable(Exception):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def check_serializable(cls):
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
"""Throws an exception if Ray cannot serialize this class efficiently.
|
||||
|
||||
Args:
|
||||
cls (type): The class to be serialized.
|
||||
Args:
|
||||
cls (type): The class to be serialized.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if Ray cannot serialize this class
|
||||
efficiently.
|
||||
"""
|
||||
if is_named_tuple(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
print("The class {} does not have a '__new__' attribute and is probably "
|
||||
"an old-stye class. Please make it a new-style class by inheriting "
|
||||
"from 'object'.")
|
||||
raise RayNotDictionarySerializable("The class {} does not have a "
|
||||
"'__new__' attribute and is probably "
|
||||
"an old-style class. We do not support "
|
||||
"this. Please make it a new-style "
|
||||
"class by inheriting from 'object'."
|
||||
.format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except:
|
||||
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
|
||||
", so Ray may not be able to serialize "
|
||||
"it efficiently.".format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
|
||||
"'__dict__' attribute, so Ray cannot "
|
||||
"serialize it efficiently.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
|
||||
"may not be able to serialize it "
|
||||
"efficiently.".format(cls))
|
||||
Raises:
|
||||
Exception: An exception is raised if Ray cannot serialize this class
|
||||
efficiently.
|
||||
"""
|
||||
if is_named_tuple(cls):
|
||||
# This case works.
|
||||
return
|
||||
if not hasattr(cls, "__new__"):
|
||||
print("The class {} does not have a '__new__' attribute and is "
|
||||
"probably an old-stye class. Please make it a new-style class "
|
||||
"by inheriting from 'object'.")
|
||||
raise RayNotDictionarySerializable("The class {} does not have a "
|
||||
"'__new__' attribute and is "
|
||||
"probably an old-style class. We "
|
||||
"do not support this. Please make "
|
||||
"it a new-style class by "
|
||||
"inheriting from 'object'."
|
||||
.format(cls))
|
||||
try:
|
||||
obj = cls.__new__(cls)
|
||||
except:
|
||||
raise RayNotDictionarySerializable("The class {} has overridden "
|
||||
"'__new__', so Ray may not be able "
|
||||
"to serialize it efficiently."
|
||||
.format(cls))
|
||||
if not hasattr(obj, "__dict__"):
|
||||
raise RayNotDictionarySerializable("Objects of the class {} do not "
|
||||
"have a '__dict__' attribute, so "
|
||||
"Ray cannot serialize it "
|
||||
"efficiently.".format(cls))
|
||||
if hasattr(obj, "__slots__"):
|
||||
raise RayNotDictionarySerializable("The class {} uses '__slots__', so "
|
||||
"Ray may not be able to serialize "
|
||||
"it efficiently.".format(cls))
|
||||
|
||||
|
||||
# This field keeps track of a whitelisted set of classes that Ray will
|
||||
@@ -72,134 +75,136 @@ custom_deserializers = dict()
|
||||
|
||||
|
||||
def is_named_tuple(cls):
|
||||
"""Return True if cls is a namedtuple and False otherwise."""
|
||||
b = cls.__bases__
|
||||
if len(b) != 1 or b[0] != tuple:
|
||||
return False
|
||||
f = getattr(cls, "_fields", None)
|
||||
if not isinstance(f, tuple):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
"""Return True if cls is a namedtuple and False otherwise."""
|
||||
b = cls.__bases__
|
||||
if len(b) != 1 or b[0] != tuple:
|
||||
return False
|
||||
f = getattr(cls, "_fields", None)
|
||||
if not isinstance(f, tuple):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
|
||||
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
|
||||
custom_deserializer=None):
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
"""Add cls to the list of classes that we can serialize.
|
||||
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
class_id: A string of bytes used to identify the class.
|
||||
pickle (bool): True if the serialization should be done with pickle. False
|
||||
if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
serialize objects of the class in a particular way.
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
type_to_class_id[cls] = class_id
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
if custom_serializer is not None:
|
||||
custom_serializers[class_id] = custom_serializer
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
Args:
|
||||
cls (type): The class that we can serialize.
|
||||
class_id: A string of bytes used to identify the class.
|
||||
pickle (bool): True if the serialization should be done with pickle.
|
||||
False if it should be done efficiently with Ray.
|
||||
custom_serializer: This argument is optional, but can be provided to
|
||||
serialize objects of the class in a particular way.
|
||||
custom_deserializer: This argument is optional, but can be provided to
|
||||
deserialize objects of the class in a particular way.
|
||||
"""
|
||||
type_to_class_id[cls] = class_id
|
||||
whitelisted_classes[class_id] = cls
|
||||
if pickle:
|
||||
classes_to_pickle.add(class_id)
|
||||
if custom_serializer is not None:
|
||||
custom_serializers[class_id] = custom_serializer
|
||||
custom_deserializers[class_id] = custom_deserializer
|
||||
|
||||
|
||||
def serialize(obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf does not know how to serialize an object, it will call this method.
|
||||
If numbuf does not know how to serialize an object, it will call this
|
||||
method.
|
||||
|
||||
Args:
|
||||
obj (object): A Python object.
|
||||
Args:
|
||||
obj (object): A Python object.
|
||||
|
||||
Returns:
|
||||
A dictionary that has the key "_pyttype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
if type(obj) not in type_to_class_id:
|
||||
raise RaySerializationException("Ray does not know how to serialize "
|
||||
"objects of type {}.".format(type(obj)),
|
||||
obj)
|
||||
class_id = type_to_class_id[type(obj)]
|
||||
Returns:
|
||||
A dictionary that has the key "_pyttype_" to identify the class, and
|
||||
contains all information needed to reconstruct the object.
|
||||
"""
|
||||
if type(obj) not in type_to_class_id:
|
||||
raise RaySerializationException("Ray does not know how to serialize "
|
||||
"objects of type {}."
|
||||
.format(type(obj)),
|
||||
obj)
|
||||
class_id = type_to_class_id[type(obj)]
|
||||
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickle.dumps(obj),
|
||||
"pickle": True}
|
||||
elif class_id in custom_serializers:
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
# Handle the namedtuple case.
|
||||
if is_named_tuple(type(obj)):
|
||||
serialized_obj = {}
|
||||
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
if class_id in classes_to_pickle:
|
||||
serialized_obj = {"data": pickle.dumps(obj),
|
||||
"pickle": True}
|
||||
elif class_id in custom_serializers:
|
||||
serialized_obj = {"data": custom_serializers[class_id](obj)}
|
||||
else:
|
||||
raise RaySerializationException("We do not know how to serialize the "
|
||||
"object '{}'".format(obj), obj)
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
# Handle the namedtuple case.
|
||||
if is_named_tuple(type(obj)):
|
||||
serialized_obj = {}
|
||||
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
|
||||
elif hasattr(obj, "__dict__"):
|
||||
serialized_obj = obj.__dict__
|
||||
else:
|
||||
raise RaySerializationException("We do not know how to serialize "
|
||||
"the object '{}'".format(obj), obj)
|
||||
result = dict(serialized_obj, **{"_pytype_": class_id})
|
||||
return result
|
||||
|
||||
|
||||
def deserialize(serialized_obj):
|
||||
"""This is the callback that will be used by numbuf.
|
||||
"""This is the callback that will be used by numbuf.
|
||||
|
||||
If numbuf encounters a dictionary that contains the key "_pytype_" during
|
||||
If numbuf encounters a dictionary that contains the key "_pytype_" during
|
||||
deserialization, it will ask this callback to deserialize the object.
|
||||
|
||||
Args:
|
||||
serialized_obj (object): A dictionary that contains the key "_pytype_".
|
||||
Args:
|
||||
serialized_obj (object): A dictionary that contains the key "_pytype_".
|
||||
|
||||
Returns:
|
||||
A Python object.
|
||||
Returns:
|
||||
A Python object.
|
||||
|
||||
Raises:
|
||||
An exception is raised if we do not know how to deserialize the object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
Raises:
|
||||
An exception is raised if we do not know how to deserialize the object.
|
||||
"""
|
||||
class_id = serialized_obj["_pytype_"]
|
||||
|
||||
if "pickle" in serialized_obj:
|
||||
# The object was pickled, so unpickle it.
|
||||
obj = pickle.loads(serialized_obj["data"])
|
||||
else:
|
||||
assert class_id not in classes_to_pickle
|
||||
if class_id not in whitelisted_classes:
|
||||
# If this happens, that means that the call to _register_class, which
|
||||
# should have added the class to the list of whitelisted classes, has not
|
||||
# yet propagated to this worker. It should happen if we wait a little
|
||||
# longer.
|
||||
raise RayDeserializationException("The class {} is not one of the "
|
||||
"whitelisted classes."
|
||||
.format(class_id), class_id)
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in custom_deserializers:
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
if "pickle" in serialized_obj:
|
||||
# The object was pickled, so unpickle it.
|
||||
obj = pickle.loads(serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
assert class_id not in classes_to_pickle
|
||||
if class_id not in whitelisted_classes:
|
||||
# If this happens, that means that the call to _register_class,
|
||||
# which should have added the class to the list of whitelisted
|
||||
# classes, has not yet propagated to this worker. It should happen
|
||||
# if we wait a little longer.
|
||||
raise RayDeserializationException("The class {} is not one of the "
|
||||
"whitelisted classes."
|
||||
.format(class_id), class_id)
|
||||
cls = whitelisted_classes[class_id]
|
||||
if class_id in custom_deserializers:
|
||||
obj = custom_deserializers[class_id](serialized_obj["data"])
|
||||
else:
|
||||
# In this case, serialized_obj should just be the __dict__ field.
|
||||
if "_ray_getnewargs_" in serialized_obj:
|
||||
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
|
||||
else:
|
||||
obj = cls.__new__(cls)
|
||||
serialized_obj.pop("_pytype_")
|
||||
obj.__dict__.update(serialized_obj)
|
||||
return obj
|
||||
|
||||
|
||||
def set_callbacks():
|
||||
"""Register the custom callbacks with numbuf.
|
||||
"""Register the custom callbacks with numbuf.
|
||||
|
||||
The serialize callback is used to serialize objects that numbuf does not know
|
||||
how to serialize (for example custom Python classes). The deserialize
|
||||
callback is used to serialize objects that were serialized by the serialize
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
The serialize callback is used to serialize objects that numbuf does not
|
||||
know how to serialize (for example custom Python classes). The deserialize
|
||||
callback is used to serialize objects that were serialized by the serialize
|
||||
callback.
|
||||
"""
|
||||
ray.numbuf.register_callbacks(serialize, deserialize)
|
||||
|
||||
|
||||
def clear_state():
|
||||
type_to_class_id.clear()
|
||||
whitelisted_classes.clear()
|
||||
classes_to_pickle.clear()
|
||||
custom_serializers.clear()
|
||||
custom_deserializers.clear()
|
||||
type_to_class_id.clear()
|
||||
whitelisted_classes.clear()
|
||||
classes_to_pickle.clear()
|
||||
custom_serializers.clear()
|
||||
custom_deserializers.clear()
|
||||
|
||||
+888
-860
File diff suppressed because it is too large
Load Diff
+133
-128
@@ -13,160 +13,165 @@ FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
|
||||
"""This class is used to represent a function signature.
|
||||
|
||||
Attributes:
|
||||
keyword_names: The names of the functions keyword arguments. This is used to
|
||||
test if an incorrect keyword argument has been passed to the function.
|
||||
arg_defaults: A dictionary mapping from argument name to argument default
|
||||
value. If the argument is not a keyword argument, the default value will be
|
||||
funcsigs._empty.
|
||||
arg_is_positionals: A dictionary mapping from argument name to a bool. The
|
||||
bool will be true if the argument is a *args argument. Otherwise it will be
|
||||
false.
|
||||
function_name: The name of the function whose signature is being inspected.
|
||||
This is used for printing better error messages.
|
||||
keyword_names: The names of the functions keyword arguments. This is used
|
||||
to test if an incorrect keyword argument has been passed to the
|
||||
function.
|
||||
arg_defaults: A dictionary mapping from argument name to argument default
|
||||
value. If the argument is not a keyword argument, the default value
|
||||
will be funcsigs._empty.
|
||||
arg_is_positionals: A dictionary mapping from argument name to a bool. The
|
||||
bool will be true if the argument is a *args argument. Otherwise it
|
||||
will be false.
|
||||
function_name: The name of the function whose signature is being
|
||||
inspected. This is used for printing better error messages.
|
||||
"""
|
||||
|
||||
|
||||
def check_signature_supported(func, warn=False):
|
||||
"""Check if we support the signature of this function.
|
||||
"""Check if we support the signature of this function.
|
||||
|
||||
We currently do not allow remote functions to have **kwargs. We also do not
|
||||
support keyword arguments in conjunction with a *args argument.
|
||||
We currently do not allow remote functions to have **kwargs. We also do not
|
||||
support keyword arguments in conjunction with a *args argument.
|
||||
|
||||
Args:
|
||||
func: The function whose signature should be checked.
|
||||
warn: If this is true, a warning will be printed if the signature is not
|
||||
supported. If it is false, an exception will be raised if the signature
|
||||
is not supported.
|
||||
Args:
|
||||
func: The function whose signature should be checked.
|
||||
warn: If this is true, a warning will be printed if the signature is
|
||||
not supported. If it is false, an exception will be raised if the
|
||||
signature is not supported.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the signature is not supported.
|
||||
"""
|
||||
function_name = func.__name__
|
||||
sig_params = [(k, v) for k, v
|
||||
in funcsigs.signature(func).parameters.items()]
|
||||
Raises:
|
||||
Exception: An exception is raised if the signature is not supported.
|
||||
"""
|
||||
function_name = func.__name__
|
||||
sig_params = [(k, v) for k, v
|
||||
in funcsigs.signature(func).parameters.items()]
|
||||
|
||||
has_vararg_param = False
|
||||
has_kwargs_param = False
|
||||
has_keyword_arg = False
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.kind == parameter.VAR_KEYWORD:
|
||||
has_kwargs_param = True
|
||||
if parameter.kind == parameter.VAR_POSITIONAL:
|
||||
has_vararg_param = True
|
||||
if parameter.default != funcsigs._empty:
|
||||
has_keyword_arg = True
|
||||
has_vararg_param = False
|
||||
has_kwargs_param = False
|
||||
has_keyword_arg = False
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.kind == parameter.VAR_KEYWORD:
|
||||
has_kwargs_param = True
|
||||
if parameter.kind == parameter.VAR_POSITIONAL:
|
||||
has_vararg_param = True
|
||||
if parameter.default != funcsigs._empty:
|
||||
has_keyword_arg = True
|
||||
|
||||
if has_kwargs_param:
|
||||
message = ("The function {} has a **kwargs argument, which is "
|
||||
"currently not supported.".format(function_name))
|
||||
if warn:
|
||||
print(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
# Check if the user specified a variable number of arguments and any keyword
|
||||
# arguments.
|
||||
if has_vararg_param and has_keyword_arg:
|
||||
message = ("Function {} has a *args argument as well as a keyword "
|
||||
"argument, which is currently not supported."
|
||||
.format(function_name))
|
||||
if warn:
|
||||
print(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
if has_kwargs_param:
|
||||
message = ("The function {} has a **kwargs argument, which is "
|
||||
"currently not supported.".format(function_name))
|
||||
if warn:
|
||||
print(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
# Check if the user specified a variable number of arguments and any
|
||||
# keyword arguments.
|
||||
if has_vararg_param and has_keyword_arg:
|
||||
message = ("Function {} has a *args argument as well as a keyword "
|
||||
"argument, which is currently not supported."
|
||||
.format(function_name))
|
||||
if warn:
|
||||
print(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
|
||||
|
||||
def extract_signature(func, ignore_first=False):
|
||||
"""Extract the function signature from the function.
|
||||
"""Extract the function signature from the function.
|
||||
|
||||
Args:
|
||||
func: The function whose signature should be extracted.
|
||||
ignore_first: True if the first argument should be ignored. This should be
|
||||
used when func is a method of a class.
|
||||
Args:
|
||||
func: The function whose signature should be extracted.
|
||||
ignore_first: True if the first argument should be ignored. This should
|
||||
be used when func is a method of a class.
|
||||
|
||||
Returns:
|
||||
A function signature object, which includes the names of the keyword
|
||||
arguments as well as their default values.
|
||||
"""
|
||||
sig_params = [(k, v) for k, v
|
||||
in funcsigs.signature(func).parameters.items()]
|
||||
Returns:
|
||||
A function signature object, which includes the names of the keyword
|
||||
arguments as well as their default values.
|
||||
"""
|
||||
sig_params = [(k, v) for k, v
|
||||
in funcsigs.signature(func).parameters.items()]
|
||||
|
||||
if ignore_first:
|
||||
if len(sig_params) == 0:
|
||||
raise Exception("Methods must take a 'self' argument, but the method "
|
||||
"'{}' does not have one.".format(func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
if ignore_first:
|
||||
if len(sig_params) == 0:
|
||||
raise Exception("Methods must take a 'self' argument, but the "
|
||||
"method '{}' does not have one."
|
||||
.format(func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
|
||||
# Extract the names of the keyword arguments.
|
||||
keyword_names = set()
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.default != funcsigs._empty:
|
||||
keyword_names.add(keyword_name)
|
||||
# Extract the names of the keyword arguments.
|
||||
keyword_names = set()
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.default != funcsigs._empty:
|
||||
keyword_names.add(keyword_name)
|
||||
|
||||
# Construct the argument default values and other argument information.
|
||||
arg_names = []
|
||||
arg_defaults = []
|
||||
arg_is_positionals = []
|
||||
for keyword_name, parameter in sig_params:
|
||||
arg_names.append(keyword_name)
|
||||
arg_defaults.append(parameter.default)
|
||||
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
|
||||
# Construct the argument default values and other argument information.
|
||||
arg_names = []
|
||||
arg_defaults = []
|
||||
arg_is_positionals = []
|
||||
for keyword_name, parameter in sig_params:
|
||||
arg_names.append(keyword_name)
|
||||
arg_defaults.append(parameter.default)
|
||||
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
|
||||
|
||||
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
|
||||
keyword_names, func.__name__)
|
||||
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
|
||||
keyword_names, func.__name__)
|
||||
|
||||
|
||||
def extend_args(function_signature, args, kwargs):
|
||||
"""Extend the arguments that were passed into a function.
|
||||
"""Extend the arguments that were passed into a function.
|
||||
|
||||
This extends the arguments that were passed into a function with the default
|
||||
arguments provided in the function definition.
|
||||
This extends the arguments that were passed into a function with the
|
||||
default arguments provided in the function definition.
|
||||
|
||||
Args:
|
||||
function_signature: The function signature of the function being called.
|
||||
args: The non-keyword arguments passed into the function.
|
||||
kwargs: The keyword arguments passed into the function.
|
||||
Args:
|
||||
function_signature: The function signature of the function being
|
||||
called.
|
||||
args: The non-keyword arguments passed into the function.
|
||||
kwargs: The keyword arguments passed into the function.
|
||||
|
||||
Returns:
|
||||
An extended list of arguments to pass into the function.
|
||||
Returns:
|
||||
An extended list of arguments to pass into the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception may be raised if the function cannot be called with
|
||||
these arguments.
|
||||
"""
|
||||
arg_names = function_signature.arg_names
|
||||
arg_defaults = function_signature.arg_defaults
|
||||
arg_is_positionals = function_signature.arg_is_positionals
|
||||
keyword_names = function_signature.keyword_names
|
||||
function_name = function_signature.function_name
|
||||
Raises:
|
||||
Exception: An exception may be raised if the function cannot be called
|
||||
with these arguments.
|
||||
"""
|
||||
arg_names = function_signature.arg_names
|
||||
arg_defaults = function_signature.arg_defaults
|
||||
arg_is_positionals = function_signature.arg_is_positionals
|
||||
keyword_names = function_signature.keyword_names
|
||||
function_name = function_signature.function_name
|
||||
|
||||
args = list(args)
|
||||
args = list(args)
|
||||
|
||||
for keyword_name in kwargs:
|
||||
if keyword_name not in keyword_names:
|
||||
raise Exception("The name '{}' is not a valid keyword argument for the "
|
||||
"function '{}'.".format(keyword_name, function_name))
|
||||
for keyword_name in kwargs:
|
||||
if keyword_name not in keyword_names:
|
||||
raise Exception("The name '{}' is not a valid keyword argument "
|
||||
"for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
|
||||
# Fill in the remaining arguments.
|
||||
zipped_info = list(zip(arg_names, arg_defaults,
|
||||
arg_is_positionals))[len(args):]
|
||||
for keyword_name, default_value, is_positional in zipped_info:
|
||||
if keyword_name in kwargs:
|
||||
args.append(kwargs[keyword_name])
|
||||
else:
|
||||
if default_value != funcsigs._empty:
|
||||
args.append(default_value)
|
||||
else:
|
||||
# This means that there is a missing argument. Unless this is the last
|
||||
# argument and it is a *args argument in which case it can be omitted.
|
||||
if not is_positional:
|
||||
raise Exception("No value was provided for the argument '{}' for "
|
||||
"the function '{}'.".format(keyword_name,
|
||||
function_name))
|
||||
# Fill in the remaining arguments.
|
||||
zipped_info = list(zip(arg_names, arg_defaults,
|
||||
arg_is_positionals))[len(args):]
|
||||
for keyword_name, default_value, is_positional in zipped_info:
|
||||
if keyword_name in kwargs:
|
||||
args.append(kwargs[keyword_name])
|
||||
else:
|
||||
if default_value != funcsigs._empty:
|
||||
args.append(default_value)
|
||||
else:
|
||||
# This means that there is a missing argument. Unless this is
|
||||
# the last argument and it is a *args argument in which case it
|
||||
# can be omitted.
|
||||
if not is_positional:
|
||||
raise Exception("No value was provided for the argument "
|
||||
"'{}' for the function '{}'."
|
||||
.format(keyword_name, function_name))
|
||||
|
||||
too_many_arguments = (len(args) > len(arg_names) and
|
||||
(len(arg_is_positionals) == 0 or
|
||||
not arg_is_positionals[-1]))
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
return args
|
||||
too_many_arguments = (len(args) > len(arg_names) and
|
||||
(len(arg_is_positionals) == 0 or
|
||||
not arg_is_positionals[-1]))
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
return args
|
||||
|
||||
@@ -11,99 +11,99 @@ import numpy as np
|
||||
|
||||
@ray.remote(num_return_vals=2)
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
return a + 1, b + 1
|
||||
|
||||
# Test timing
|
||||
|
||||
|
||||
@ray.remote
|
||||
def empty_function():
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
@ray.remote
|
||||
def trivial_function():
|
||||
return 1
|
||||
return 1
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct1(a, b="hello"):
|
||||
return "{} {}".format(a, b)
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct2(a="hello", b="world"):
|
||||
return "{} {}".format(a, b)
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
# Test variable numbers of arguments
|
||||
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
return " ".join(map(str, a))
|
||||
return " ".join(map(str, a))
|
||||
|
||||
|
||||
@ray.remote
|
||||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
return " ".join(map(str, b))
|
||||
|
||||
|
||||
try:
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
kwargs_exception_thrown = False
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
kwargs_exception_thrown = False
|
||||
except:
|
||||
kwargs_exception_thrown = True
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
try:
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
except:
|
||||
varargs_and_kwargs_exception_thrown = True
|
||||
varargs_and_kwargs_exception_thrown = True
|
||||
|
||||
# test throwing an exception
|
||||
|
||||
|
||||
@ray.remote
|
||||
def throw_exception_fct1():
|
||||
raise Exception("Test function 1 intentionally failed.")
|
||||
raise Exception("Test function 1 intentionally failed.")
|
||||
|
||||
|
||||
@ray.remote
|
||||
def throw_exception_fct2():
|
||||
raise Exception("Test function 2 intentionally failed.")
|
||||
raise Exception("Test function 2 intentionally failed.")
|
||||
|
||||
|
||||
@ray.remote(num_return_vals=3)
|
||||
def throw_exception_fct3(x):
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
|
||||
# test Python mode
|
||||
|
||||
|
||||
@ray.remote
|
||||
def python_mode_f():
|
||||
return np.array([0, 0])
|
||||
return np.array([0, 0])
|
||||
|
||||
|
||||
@ray.remote
|
||||
def python_mode_g(x):
|
||||
x[0] = 1
|
||||
return x
|
||||
x[0] = 1
|
||||
return x
|
||||
|
||||
# test no return values
|
||||
|
||||
|
||||
@ray.remote
|
||||
def no_op():
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -15,119 +15,124 @@ EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
|
||||
|
||||
|
||||
def _wait_for_nodes_to_join(num_nodes, timeout=20):
|
||||
"""Wait until the nodes have joined the cluster.
|
||||
"""Wait until the nodes have joined the cluster.
|
||||
|
||||
This will wait until exactly num_nodes have joined the cluster and each node
|
||||
has a local scheduler and a plasma manager.
|
||||
This will wait until exactly num_nodes have joined the cluster and each
|
||||
node has a local scheduler and a plasma manager.
|
||||
|
||||
Args:
|
||||
num_nodes: The number of nodes to wait for.
|
||||
timeout: The amount of time in seconds to wait before failing.
|
||||
Args:
|
||||
num_nodes: The number of nodes to wait for.
|
||||
timeout: The amount of time in seconds to wait before failing.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if too many nodes join the cluster or if
|
||||
the timeout expires while we are waiting.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
client_table = ray.global_state.client_table()
|
||||
num_ready_nodes = len(client_table)
|
||||
if num_ready_nodes == num_nodes:
|
||||
ready = True
|
||||
# Check that for each node, a local scheduler and a plasma manager are
|
||||
# present.
|
||||
for ip_address, clients in client_table.items():
|
||||
client_types = [client["ClientType"] for client in clients]
|
||||
if "local_scheduler" not in client_types:
|
||||
ready = False
|
||||
if "plasma_manager" not in client_types:
|
||||
ready = False
|
||||
if ready:
|
||||
return
|
||||
if num_ready_nodes > num_nodes:
|
||||
# Too many nodes have joined. Something must be wrong.
|
||||
raise Exception("{} nodes have joined the cluster, but we were "
|
||||
"expecting {} nodes.".format(num_ready_nodes, num_nodes))
|
||||
time.sleep(0.1)
|
||||
Raises:
|
||||
Exception: An exception is raised if too many nodes join the cluster or
|
||||
if the timeout expires while we are waiting.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
client_table = ray.global_state.client_table()
|
||||
num_ready_nodes = len(client_table)
|
||||
if num_ready_nodes == num_nodes:
|
||||
ready = True
|
||||
# Check that for each node, a local scheduler and a plasma manager
|
||||
# are present.
|
||||
for ip_address, clients in client_table.items():
|
||||
client_types = [client["ClientType"] for client in clients]
|
||||
if "local_scheduler" not in client_types:
|
||||
ready = False
|
||||
if "plasma_manager" not in client_types:
|
||||
ready = False
|
||||
if ready:
|
||||
return
|
||||
if num_ready_nodes > num_nodes:
|
||||
# Too many nodes have joined. Something must be wrong.
|
||||
raise Exception("{} nodes have joined the cluster, but we were "
|
||||
"expecting {} nodes.".format(num_ready_nodes,
|
||||
num_nodes))
|
||||
time.sleep(0.1)
|
||||
|
||||
# If we get here then we timed out.
|
||||
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
|
||||
"nodes have joined so far.".format(num_ready_nodes,
|
||||
num_nodes))
|
||||
# If we get here then we timed out.
|
||||
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
|
||||
"nodes have joined so far.".format(num_ready_nodes,
|
||||
num_nodes))
|
||||
|
||||
|
||||
def _broadcast_event(event_name, redis_address, data=None):
|
||||
"""Broadcast an event.
|
||||
"""Broadcast an event.
|
||||
|
||||
This is used to synchronize drivers for the multi-node tests.
|
||||
This is used to synchronize drivers for the multi-node tests.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to wait for.
|
||||
redis_address: The address of the Redis server to use for synchronization.
|
||||
data: Extra data to include in the broadcast (this will be returned by the
|
||||
corresponding _wait_for_event call). This data must be json serializable.
|
||||
"""
|
||||
redis_host, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
|
||||
payload = json.dumps((event_name, data))
|
||||
redis_client.rpush(EVENT_KEY, payload)
|
||||
Args:
|
||||
event_name: The name of the event to wait for.
|
||||
redis_address: The address of the Redis server to use for
|
||||
synchronization.
|
||||
data: Extra data to include in the broadcast (this will be returned by
|
||||
the corresponding _wait_for_event call). This data must be json
|
||||
serializable.
|
||||
"""
|
||||
redis_host, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
|
||||
payload = json.dumps((event_name, data))
|
||||
redis_client.rpush(EVENT_KEY, payload)
|
||||
|
||||
|
||||
def _wait_for_event(event_name, redis_address, extra_buffer=0):
|
||||
"""Block until an event has been broadcast.
|
||||
"""Block until an event has been broadcast.
|
||||
|
||||
This is used to synchronize drivers for the multi-node tests.
|
||||
This is used to synchronize drivers for the multi-node tests.
|
||||
|
||||
Args:
|
||||
event_name: The name of the event to wait for.
|
||||
redis_address: The address of the Redis server to use for synchronization.
|
||||
extra_buffer: An amount of time in seconds to wait after the event.
|
||||
Args:
|
||||
event_name: The name of the event to wait for.
|
||||
redis_address: The address of the Redis server to use for
|
||||
synchronization.
|
||||
extra_buffer: An amount of time in seconds to wait after the event.
|
||||
|
||||
Returns:
|
||||
The data that was passed into the corresponding _broadcast_event call.
|
||||
"""
|
||||
redis_host, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
|
||||
while True:
|
||||
event_infos = redis_client.lrange(EVENT_KEY, 0, -1)
|
||||
events = dict()
|
||||
for event_info in event_infos:
|
||||
name, data = json.loads(event_info)
|
||||
if name in events:
|
||||
raise Exception("The same event {} was broadcast twice.".format(name))
|
||||
events[name] = data
|
||||
if event_name in events:
|
||||
# Potentially sleep a little longer and then return the event data.
|
||||
time.sleep(extra_buffer)
|
||||
return events[event_name]
|
||||
time.sleep(0.1)
|
||||
Returns:
|
||||
The data that was passed into the corresponding _broadcast_event call.
|
||||
"""
|
||||
redis_host, redis_port = redis_address.split(":")
|
||||
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
|
||||
while True:
|
||||
event_infos = redis_client.lrange(EVENT_KEY, 0, -1)
|
||||
events = dict()
|
||||
for event_info in event_infos:
|
||||
name, data = json.loads(event_info)
|
||||
if name in events:
|
||||
raise Exception("The same event {} was broadcast twice."
|
||||
.format(name))
|
||||
events[name] = data
|
||||
if event_name in events:
|
||||
# Potentially sleep a little longer and then return the event data.
|
||||
time.sleep(extra_buffer)
|
||||
return events[event_name]
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def _pid_alive(pid):
|
||||
"""Check if the process with this PID is alive or not.
|
||||
"""Check if the process with this PID is alive or not.
|
||||
|
||||
Args:
|
||||
pid: The pid to check.
|
||||
Args:
|
||||
pid: The pid to check.
|
||||
|
||||
Returns:
|
||||
This returns false if the process is dead or defunct. Otherwise, it returns
|
||||
true.
|
||||
"""
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except OSError:
|
||||
return False
|
||||
else:
|
||||
if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE:
|
||||
return False
|
||||
Returns:
|
||||
This returns false if the process is dead or defunct. Otherwise, it
|
||||
returns true.
|
||||
"""
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except OSError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def wait_for_pid_to_exit(pid, timeout=20):
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if not _pid_alive(pid):
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for process to exit.")
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if not _pid_alive(pid):
|
||||
return
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timed out while waiting for process to exit.")
|
||||
|
||||
+31
-30
@@ -11,51 +11,52 @@ import ray.local_scheduler
|
||||
|
||||
|
||||
def random_string():
|
||||
"""Generate a random string to use as an ID.
|
||||
"""Generate a random string to use as an ID.
|
||||
|
||||
Note that users may seed numpy, which could cause this function to generate
|
||||
duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't
|
||||
interfere with the state of the user's random number generator, so we extract
|
||||
the state of the random number generator and reset it after we are done.
|
||||
Note that users may seed numpy, which could cause this function to generate
|
||||
duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't
|
||||
interfere with the state of the user's random number generator, so we
|
||||
extract the state of the random number generator and reset it after we are
|
||||
done.
|
||||
|
||||
TODO(rkn): If we want to later guarantee that these are generated in a
|
||||
deterministic manner, then we will need to make some changes here.
|
||||
TODO(rkn): If we want to later guarantee that these are generated in a
|
||||
deterministic manner, then we will need to make some changes here.
|
||||
|
||||
Returns:
|
||||
A random byte string of length 20.
|
||||
"""
|
||||
# Get the state of the numpy random number generator.
|
||||
numpy_state = np.random.get_state()
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
# Generate the random ID.
|
||||
random_id = np.random.bytes(20)
|
||||
# Reset the state of the numpy random number generator.
|
||||
np.random.set_state(numpy_state)
|
||||
return random_id
|
||||
Returns:
|
||||
A random byte string of length 20.
|
||||
"""
|
||||
# Get the state of the numpy random number generator.
|
||||
numpy_state = np.random.get_state()
|
||||
# Try to use true randomness.
|
||||
np.random.seed(None)
|
||||
# Generate the random ID.
|
||||
random_id = np.random.bytes(20)
|
||||
# Reset the state of the numpy random number generator.
|
||||
np.random.set_state(numpy_state)
|
||||
return random_id
|
||||
|
||||
|
||||
def decode(byte_str):
|
||||
"""Make this unicode in Python 3, otherwise leave it as bytes."""
|
||||
if sys.version_info >= (3, 0):
|
||||
return byte_str.decode("ascii")
|
||||
else:
|
||||
return byte_str
|
||||
"""Make this unicode in Python 3, otherwise leave it as bytes."""
|
||||
if sys.version_info >= (3, 0):
|
||||
return byte_str.decode("ascii")
|
||||
else:
|
||||
return byte_str
|
||||
|
||||
|
||||
def binary_to_object_id(binary_object_id):
|
||||
return ray.local_scheduler.ObjectID(binary_object_id)
|
||||
return ray.local_scheduler.ObjectID(binary_object_id)
|
||||
|
||||
|
||||
def binary_to_hex(identifier):
|
||||
hex_identifier = binascii.hexlify(identifier)
|
||||
if sys.version_info >= (3, 0):
|
||||
hex_identifier = hex_identifier.decode()
|
||||
return hex_identifier
|
||||
hex_identifier = binascii.hexlify(identifier)
|
||||
if sys.version_info >= (3, 0):
|
||||
hex_identifier = hex_identifier.decode()
|
||||
return hex_identifier
|
||||
|
||||
|
||||
def hex_to_binary(hex_identifier):
|
||||
return binascii.unhexlify(hex_identifier)
|
||||
return binascii.unhexlify(hex_identifier)
|
||||
|
||||
|
||||
FunctionProperties = collections.namedtuple("FunctionProperties",
|
||||
|
||||
+1802
-1730
File diff suppressed because it is too large
Load Diff
@@ -27,59 +27,61 @@ parser.add_argument("--actor-id", required=False, type=str,
|
||||
|
||||
|
||||
def random_string():
|
||||
return np.random.bytes(20)
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name}
|
||||
args = parser.parse_args()
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name}
|
||||
|
||||
if args.actor_id is not None:
|
||||
actor_id = binascii.unhexlify(args.actor_id)
|
||||
else:
|
||||
actor_id = ray.worker.NIL_ACTOR_ID
|
||||
if args.actor_id is not None:
|
||||
actor_id = binascii.unhexlify(args.actor_id)
|
||||
else:
|
||||
actor_id = ray.worker.NIL_ACTOR_ID
|
||||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker crashed
|
||||
in an unanticipated way causing the main_loop to throw an exception, which is
|
||||
being caught in "python/ray/workers/default_worker.py".
|
||||
"""
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
crashed in an unanticipated way causing the main_loop to throw an exception,
|
||||
which is being caught in "python/ray/workers/default_worker.py".
|
||||
"""
|
||||
|
||||
while True:
|
||||
try:
|
||||
# This call to main_loop should never return if things are working. Most
|
||||
# exceptions that are thrown (e.g., inside the execution of a task)
|
||||
# should be caught and handled inside of the call to main_loop. If an
|
||||
# exception is thrown here, then that means that there is some error that
|
||||
# we didn't anticipate.
|
||||
ray.worker.main_loop()
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
redis_ip_address, redis_port = args.redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine as
|
||||
# Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
"message": traceback_str,
|
||||
"note": ("This error is unexpected and "
|
||||
"should not have happened.")})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing a
|
||||
# task, the any worker or driver that is blocking in a get call and
|
||||
# waiting for the output of that task will hang. We need to address this.
|
||||
while True:
|
||||
try:
|
||||
# This call to main_loop should never return if things are working.
|
||||
# Most exceptions that are thrown (e.g., inside the execution of a
|
||||
# task) should be caught and handled inside of the call to
|
||||
# main_loop. If an exception is thrown here, then that means that
|
||||
# there is some error that we didn't anticipate.
|
||||
ray.worker.main_loop()
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
redis_ip_address, redis_port = args.redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine
|
||||
# as Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
"message": traceback_str,
|
||||
"note": ("This error is unexpected "
|
||||
"and should not have "
|
||||
"happened.")})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, the any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
# address this.
|
||||
|
||||
# After putting the error message in Redis, this worker will attempt to
|
||||
# reenter the main loop. TODO(rkn): We should probably reset it's state and
|
||||
# call connect again.
|
||||
# After putting the error message in Redis, this worker will attempt to
|
||||
# reenter the main loop. TODO(rkn): We should probably reset it's state
|
||||
# and call connect again.
|
||||
|
||||
+29
-27
@@ -11,32 +11,34 @@ import setuptools.command.build_ext as _build_ext
|
||||
|
||||
|
||||
class build_ext(_build_ext.build_ext):
|
||||
def run(self):
|
||||
subprocess.check_call(["../build.sh"])
|
||||
# Ideally, we could include these files by putting them in a MANIFEST.in or
|
||||
# using the package_data argument to setup, but the MANIFEST.in gets
|
||||
# applied at the very beginning when setup.py runs before these files have
|
||||
# been created, so we have to move the files manually.
|
||||
for filename in files_to_include:
|
||||
self.move_file(filename)
|
||||
# Copy over the autogenerated flatbuffer Python bindings.
|
||||
generated_python_directory = "ray/core/generated"
|
||||
for filename in os.listdir(generated_python_directory):
|
||||
if filename[-3:] == ".py":
|
||||
self.move_file(os.path.join(generated_python_directory, filename))
|
||||
def run(self):
|
||||
subprocess.check_call(["../build.sh"])
|
||||
# Ideally, we could include these files by putting them in a
|
||||
# MANIFEST.in or using the package_data argument to setup, but the
|
||||
# MANIFEST.in gets applied at the very beginning when setup.py runs
|
||||
# before these files have been created, so we have to move the files
|
||||
# manually.
|
||||
for filename in files_to_include:
|
||||
self.move_file(filename)
|
||||
# Copy over the autogenerated flatbuffer Python bindings.
|
||||
generated_python_directory = "ray/core/generated"
|
||||
for filename in os.listdir(generated_python_directory):
|
||||
if filename[-3:] == ".py":
|
||||
self.move_file(os.path.join(generated_python_directory,
|
||||
filename))
|
||||
|
||||
def move_file(self, filename):
|
||||
# TODO(rkn): This feels very brittle. It may not handle all cases. See
|
||||
# https://github.com/apache/arrow/blob/master/python/setup.py for an
|
||||
# example.
|
||||
source = filename
|
||||
destination = os.path.join(self.build_lib, filename)
|
||||
# Create the target directory if it doesn't already exist.
|
||||
parent_directory = os.path.dirname(destination)
|
||||
if not os.path.exists(parent_directory):
|
||||
os.makedirs(parent_directory)
|
||||
print("Copying {} to {}.".format(source, destination))
|
||||
shutil.copy(source, destination)
|
||||
def move_file(self, filename):
|
||||
# TODO(rkn): This feels very brittle. It may not handle all cases. See
|
||||
# https://github.com/apache/arrow/blob/master/python/setup.py for an
|
||||
# example.
|
||||
source = filename
|
||||
destination = os.path.join(self.build_lib, filename)
|
||||
# Create the target directory if it doesn't already exist.
|
||||
parent_directory = os.path.dirname(destination)
|
||||
if not os.path.exists(parent_directory):
|
||||
os.makedirs(parent_directory)
|
||||
print("Copying {} to {}.".format(source, destination))
|
||||
shutil.copy(source, destination)
|
||||
|
||||
|
||||
files_to_include = [
|
||||
@@ -54,8 +56,8 @@ files_to_include = [
|
||||
|
||||
|
||||
class BinaryDistribution(Distribution):
|
||||
def has_ext_modules(self):
|
||||
return True
|
||||
def has_ext_modules(self):
|
||||
return True
|
||||
|
||||
|
||||
setup(name="ray",
|
||||
|
||||
Reference in New Issue
Block a user