Switch Python indentation from 2 spaces to 4 spaces. (#726)

* 4 space indentation for actor.py.

* 4 space indentation for worker.py.

* 4 space indentation for more files.

* 4 space indentation for some test files.

* Check indentation in Travis.

* 4 space indentation for some rl files.

* Fix failure test.

* Fix multi_node_test.

* 4 space indentation for more files.

* 4 space indentation for remaining files.

* Fixes.
This commit is contained in:
Robert Nishihara
2017-07-13 14:53:57 -07:00
committed by Philipp Moritz
parent 310ba82131
commit e0867c8845
100 changed files with 16686 additions and 16189 deletions
+6 -6
View File
@@ -21,9 +21,9 @@ __all__ = ["register_class", "error_info", "init", "connect", "disconnect",
import ctypes
# Windows only
if hasattr(ctypes, "windll"):
# Makes sure that all child processes die when we die. Also makes sure that
# fatal crashes result in process termination rather than an error dialog
# (the latter is annoying since we have a lot of processes). This is done by
# associating all child processes with a "job" object that imposes this
# behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
# Makes sure that all child processes die when we die. Also makes sure that
# fatal crashes result in process termination rather than an error dialog
# (the latter is annoying since we have a lot of processes). This is done
# by associating all child processes with a "job" object that imposes this
# behavior.
(lambda kernel32: (lambda job: (lambda n: kernel32.SetInformationJobObject(job, 9, "\0" * 17 + chr(0x8 | 0x4 | 0x20) + "\0" * (n - 18), n))(0x90 if ctypes.sizeof(ctypes.c_void_p) > ctypes.sizeof(ctypes.c_int) else 0x70) and kernel32.AssignProcessToJobObject(job, ctypes.c_void_p(kernel32.GetCurrentProcess())))(ctypes.c_void_p(kernel32.CreateJobObjectW(None, None))) if kernel32 is not None else None)(ctypes.windll.kernel32) # noqa: E501
+338 -325
View File
@@ -18,406 +18,419 @@ from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary,
def random_actor_id():
return ray.local_scheduler.ObjectID(random_string())
return ray.local_scheduler.ObjectID(random_string())
def random_actor_class_id():
return random_string()
return random_string()
def get_actor_method_function_id(attr):
"""Get the function ID corresponding to an actor method.
"""Get the function ID corresponding to an actor method.
Args:
attr (str): The attribute name of the method.
Args:
attr (str): The attribute name of the method.
Returns:
Function ID corresponding to the method.
"""
function_id_hash = hashlib.sha1()
function_id_hash.update(attr.encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == 20
return ray.local_scheduler.ObjectID(function_id)
Returns:
Function ID corresponding to the method.
"""
function_id_hash = hashlib.sha1()
function_id_hash.update(attr.encode("ascii"))
function_id = function_id_hash.digest()
assert len(function_id) == 20
return ray.local_scheduler.ObjectID(function_id)
def fetch_and_register_actor(actor_class_key, worker):
"""Import an actor.
"""Import an actor.
This will be called by the worker's import thread when the worker receives
the actor_class export, assuming that the worker is an actor for that class.
"""
actor_id_str = worker.actor_id
(driver_id, class_id, class_name,
module, pickled_class, actor_method_names) = worker.redis_client.hmget(
actor_class_key, ["driver_id", "class_id", "class_name", "module",
"class", "actor_method_names"])
This will be called by the worker's import thread when the worker receives
the actor_class export, assuming that the worker is an actor for that
class.
"""
actor_id_str = worker.actor_id
(driver_id, class_id, class_name,
module, pickled_class, actor_method_names) = worker.redis_client.hmget(
actor_class_key, ["driver_id", "class_id", "class_name", "module",
"class", "actor_method_names"])
actor_name = class_name.decode("ascii")
module = module.decode("ascii")
actor_method_names = json.loads(actor_method_names.decode("ascii"))
actor_name = class_name.decode("ascii")
module = module.decode("ascii")
actor_method_names = json.loads(actor_method_names.decode("ascii"))
# Create a temporary actor with some temporary methods so that if the actor
# fails to be unpickled, the temporary actor can be used (just to produce
# error messages and to prevent the driver from hanging).
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
# Create a temporary actor with some temporary methods so that if the actor
# fails to be unpickled, the temporary actor can be used (just to produce
# error messages and to prevent the driver from hanging).
class TemporaryActor(object):
pass
worker.actors[actor_id_str] = TemporaryActor()
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_actor_method)
worker.function_properties[driver_id][function_id] = FunctionProperties(
num_return_vals=1,
num_cpus=1,
num_gpus=0,
max_calls=0)
worker.num_task_executions[driver_id][function_id] = 0
def temporary_actor_method(*xs):
raise Exception("The actor with name {} failed to be imported, and so "
"cannot execute this method".format(actor_name))
for actor_method_name in actor_method_names:
function_id = get_actor_method_function_id(actor_method_name).id()
worker.functions[driver_id][function_id] = (actor_method_name,
temporary_actor_method)
worker.function_properties[driver_id][function_id] = (
FunctionProperties(num_return_vals=1,
num_cpus=1,
num_gpus=0,
max_calls=0))
worker.num_task_executions[driver_id][function_id] = 0
try:
unpickled_class = pickle.loads(pickled_class)
except Exception:
# If an exception was thrown when the actor was imported, we record the
# traceback and notify the scheduler of the failure.
traceback_str = ray.worker.format_error_message(traceback.format_exc())
# Log the error message.
worker.push_error_to_driver(driver_id, "register_actor", traceback_str,
data={"actor_id": actor_id_str})
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
for (k, v) in inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x)))):
function_id = get_actor_method_function_id(k).id()
worker.functions[driver_id][function_id] = (k, v)
# We do not set worker.function_properties[driver_id][function_id]
# because we currently do need the actor worker to submit new tasks for
# the actor.
try:
unpickled_class = pickle.loads(pickled_class)
except Exception:
# If an exception was thrown when the actor was imported, we record the
# traceback and notify the scheduler of the failure.
traceback_str = ray.worker.format_error_message(traceback.format_exc())
# Log the error message.
worker.push_error_to_driver(driver_id, "register_actor", traceback_str,
data={"actor_id": actor_id_str})
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
worker.actors[actor_id_str] = unpickled_class.__new__(unpickled_class)
for (k, v) in inspect.getmembers(
unpickled_class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x)))):
function_id = get_actor_method_function_id(k).id()
worker.functions[driver_id][function_id] = (k, v)
# We do not set worker.function_properties[driver_id][function_id]
# because we currently do need the actor worker to submit new tasks
# for the actor.
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
Args:
num_gpus: The number of GPUs to acquire.
driver_id: The ID of the driver responsible for creating the actor.
local_scheduler: Information about the local scheduler.
Args:
num_gpus: The number of GPUs to acquire.
driver_id: The ID of the driver responsible for creating the actor.
local_scheduler: Information about the local scheduler.
Returns:
True if the GPUs were successfully reserved and false otherwise.
"""
assert num_gpus != 0
local_scheduler_id = local_scheduler["DBClientID"]
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
Returns:
True if the GPUs were successfully reserved and false otherwise.
"""
assert num_gpus != 0
local_scheduler_id = local_scheduler["DBClientID"]
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
success = False
success = False
# Attempt to acquire GPU IDs atomically.
with worker.redis_client.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction below (the multi/exec
# block), then the transaction will not take place.
pipe.watch(local_scheduler_id)
# Attempt to acquire GPU IDs atomically.
with worker.redis_client.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction below (the
# multi/exec block), then the transaction will not take place.
pipe.watch(local_scheduler_id)
# Figure out which GPUs are currently in use.
result = worker.redis_client.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(
result.decode("ascii"))
num_gpus_in_use = 0
for key in gpus_in_use:
num_gpus_in_use += gpus_in_use[key]
assert num_gpus_in_use <= local_scheduler_total_gpus
# Figure out which GPUs are currently in use.
result = worker.redis_client.hget(local_scheduler_id,
"gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(
result.decode("ascii"))
num_gpus_in_use = 0
for key in gpus_in_use:
num_gpus_in_use += gpus_in_use[key]
assert num_gpus_in_use <= local_scheduler_total_gpus
pipe.multi()
pipe.multi()
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
# There are enough available GPUs, so try to reserve some. We use the
# hex driver ID in hex as a dictionary key so that the dictionary is
# JSON serializable.
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex not in gpus_in_use:
gpus_in_use[driver_id_hex] = 0
gpus_in_use[driver_id_hex] += num_gpus
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
# There are enough available GPUs, so try to reserve some.
# We use the hex driver ID in hex as a dictionary key so
# that the dictionary is JSON serializable.
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex not in gpus_in_use:
gpus_in_use[driver_id_hex] = 0
gpus_in_use[driver_id_hex] += num_gpus
# Stick the updated GPU IDs back in Redis
pipe.hset(local_scheduler_id, "gpus_in_use", json.dumps(gpus_in_use))
success = True
# Stick the updated GPU IDs back in Redis
pipe.hset(local_scheduler_id, "gpus_in_use",
json.dumps(gpus_in_use))
success = True
pipe.execute()
# If a WatchError is not raised, then the operations should have gone
# through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key between the time we
# started WATCHing it and the pipeline's execution. We should just
# retry.
success = False
continue
pipe.execute()
# If a WatchError is not raised, then the operations should
# have gone through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key between the
# time we started WATCHing it and the pipeline's execution. We
# should just retry.
success = False
continue
return success
return success
def select_local_scheduler(local_schedulers, num_gpus, worker):
"""Select a local scheduler to assign this actor to.
"""Select a local scheduler to assign this actor to.
Args:
local_schedulers: A list of dictionaries of information about the local
schedulers.
num_gpus (int): The number of GPUs that must be reserved for this actor.
Args:
local_schedulers: A list of dictionaries of information about the local
schedulers.
num_gpus (int): The number of GPUs that must be reserved for this
actor.
Returns:
The ID of the local scheduler that has been chosen.
Returns:
The ID of the local scheduler that has been chosen.
Raises:
Exception: An exception is raised if no local scheduler can be found with
sufficient resources.
"""
driver_id = worker.task_driver_id.id()
Raises:
Exception: An exception is raised if no local scheduler can be found
with sufficient resources.
"""
driver_id = worker.task_driver_id.id()
local_scheduler_id = None
# Loop through all of the local schedulers in a random order.
local_schedulers = np.random.permutation(local_schedulers)
for local_scheduler in local_schedulers:
if local_scheduler["NumCPUs"] < 1:
continue
if local_scheduler["NumGPUs"] < num_gpus:
continue
if num_gpus == 0:
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
break
else:
# Try to reserve enough GPUs on this local scheduler.
success = attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler,
worker)
if success:
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
break
local_scheduler_id = None
# Loop through all of the local schedulers in a random order.
local_schedulers = np.random.permutation(local_schedulers)
for local_scheduler in local_schedulers:
if local_scheduler["NumCPUs"] < 1:
continue
if local_scheduler["NumGPUs"] < num_gpus:
continue
if num_gpus == 0:
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
break
else:
# Try to reserve enough GPUs on this local scheduler.
success = attempt_to_reserve_gpus(num_gpus, driver_id,
local_scheduler, worker)
if success:
local_scheduler_id = hex_to_binary(
local_scheduler["DBClientID"])
break
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs or other "
"resources to create this actor. The local scheduler "
"information is {}.".format(local_schedulers))
if local_scheduler_id is None:
raise Exception("Could not find a node with enough GPUs or other "
"resources to create this actor. The local scheduler "
"information is {}.".format(local_schedulers))
return local_scheduler_id
return local_scheduler_id
def export_actor_class(class_id, Class, actor_method_names, worker):
if worker.mode is None:
raise NotImplemented("TODO(pcm): Cache actors")
key = b"ActorClass:" + class_id
d = {"driver_id": worker.task_driver_id.id(),
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickle.dumps(Class),
"actor_method_names": json.dumps(list(actor_method_names))}
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
if worker.mode is None:
raise NotImplemented("TODO(pcm): Cache actors")
key = b"ActorClass:" + class_id
d = {"driver_id": worker.task_driver_id.id(),
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickle.dumps(Class),
"actor_method_names": json.dumps(list(actor_method_names))}
worker.redis_client.hmset(key, d)
worker.redis_client.rpush("Exports", key)
def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
worker):
"""Export an actor to redis.
"""Export an actor to redis.
Args:
actor_id: The ID of the actor.
actor_method_names (list): A list of the names of this actor's methods.
num_cpus (int): The number of CPUs that this actor requires.
num_gpus (int): The number of GPUs that this actor requires.
"""
ray.worker.check_main_thread()
if worker.mode is None:
raise Exception("Actors cannot be created before Ray has been started. "
"You can start Ray with 'ray.init()'.")
key = b"Actor:" + actor_id.id()
Args:
actor_id: The ID of the actor.
actor_method_names (list): A list of the names of this actor's methods.
num_cpus (int): The number of CPUs that this actor requires.
num_gpus (int): The number of GPUs that this actor requires.
"""
ray.worker.check_main_thread()
if worker.mode is None:
raise Exception("Actors cannot be created before Ray has been "
"started. You can start Ray with 'ray.init()'.")
key = b"Actor:" + actor_id.id()
# For now, all actor methods have 1 return value.
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
# TODO(rkn): When we create a second actor, we are probably overwriting
# the values from the first actor here. This may or may not be a problem.
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = FunctionProperties(
num_return_vals=1,
num_cpus=1,
num_gpus=0,
max_calls=0)
# For now, all actor methods have 1 return value.
driver_id = worker.task_driver_id.id()
for actor_method_name in actor_method_names:
# TODO(rkn): When we create a second actor, we are probably overwriting
# the values from the first actor here. This may or may not be a
# problem.
function_id = get_actor_method_function_id(actor_method_name).id()
worker.function_properties[driver_id][function_id] = (
FunctionProperties(num_return_vals=1,
num_cpus=1,
num_gpus=0,
max_calls=0))
# Get a list of the local schedulers from the client table.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler" and not client["Deleted"]:
local_schedulers.append(client)
# Select a local scheduler for the actor.
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
worker)
assert local_scheduler_id is not None
# Get a list of the local schedulers from the client table.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if (client["ClientType"] == "local_scheduler" and
not client["Deleted"]):
local_schedulers.append(client)
# Select a local scheduler for the actor.
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
worker)
assert local_scheduler_id is not None
# We must put the actor information in Redis before publishing the actor
# notification so that when the newly created actor attempts to fetch the
# information from Redis, it is already there.
worker.redis_client.hmset(key, {"class_id": class_id,
"num_gpus": num_gpus})
# We must put the actor information in Redis before publishing the actor
# notification so that when the newly created actor attempts to fetch the
# information from Redis, it is already there.
worker.redis_client.hmset(key, {"class_id": class_id,
"num_gpus": num_gpus})
# Really we should encode this message as a flatbuffer object. However, we're
# having trouble getting that to work. It almost works, but in Python 2.7,
# builder.CreateString fails on byte strings that contain characters outside
# range(128).
# Really we should encode this message as a flatbuffer object. However,
# we're having trouble getting that to work. It almost works, but in Python
# 2.7, builder.CreateString fails on byte strings that contain characters
# outside range(128).
# TODO(rkn): There is actually no guarantee that the local scheduler that we
# are publishing to has already subscribed to the actor_notifications
# channel. Therefore, this message may be missed and the workload will hang.
# This is a bug.
worker.redis_client.publish("actor_notifications",
actor_id.id() + driver_id + local_scheduler_id)
# TODO(rkn): There is actually no guarantee that the local scheduler that
# we are publishing to has already subscribed to the actor_notifications
# channel. Therefore, this message may be missed and the workload will
# hang. This is a bug.
worker.redis_client.publish("actor_notifications",
actor_id.id() + driver_id + local_scheduler_id)
def actor(*args, **kwargs):
raise Exception("The @ray.actor decorator is deprecated. Instead, please "
"use @ray.remote.")
raise Exception("The @ray.actor decorator is deprecated. Instead, please "
"use @ray.remote.")
def make_actor(cls, num_cpus, num_gpus):
# Modify the class to have an additional method that will be used for
# terminating the worker.
class Class(cls):
def __ray_terminate__(self):
ray.worker.global_worker.local_scheduler_client.disconnect()
import os
os._exit(0)
# Modify the class to have an additional method that will be used for
# terminating the worker.
class Class(cls):
def __ray_terminate__(self):
ray.worker.global_worker.local_scheduler_client.disconnect()
import os
os._exit(0)
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
Class.__module__ = cls.__module__
Class.__name__ = cls.__name__
class_id = random_actor_class_id()
# The list exported will have length 0 if the class has not been exported
# yet, and length one if it has. This is just implementing a bool, but we
# don't use a bool because we need to modify it inside of the NewClass
# constructor.
exported = []
class_id = random_actor_class_id()
# The list exported will have length 0 if the class has not been exported
# yet, and length one if it has. This is just implementing a bool, but we
# don't use a bool because we need to modify it inside of the NewClass
# constructor.
exported = []
# The function actor_method_call gets called if somebody tries to call a
# method on their local actor stub object.
def actor_method_call(actor_id, attr, function_signature, *args, **kwargs):
ray.worker.check_connected()
ray.worker.check_main_thread()
args = signature.extend_args(function_signature, args, kwargs)
# The function actor_method_call gets called if somebody tries to call a
# method on their local actor stub object.
def actor_method_call(actor_id, attr, function_signature, *args, **kwargs):
ray.worker.check_connected()
ray.worker.check_main_thread()
args = signature.extend_args(function_signature, args, kwargs)
function_id = get_actor_method_function_id(attr)
object_ids = ray.worker.global_worker.submit_task(function_id, "", args,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
function_id = get_actor_method_function_id(attr)
object_ids = ray.worker.global_worker.submit_task(function_id, "",
args,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
elif len(object_ids) > 1:
return object_ids
class ActorMethod(object):
def __init__(self, method_name, actor_id, method_signature):
self.method_name = method_name
self.actor_id = actor_id
self.method_signature = method_signature
class ActorMethod(object):
def __init__(self, method_name, actor_id, method_signature):
self.method_name = method_name
self.actor_id = actor_id
self.method_signature = method_signature
def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead "
"of running 'object.{}()', try 'object.{}.remote()'."
.format(self.method_name, self.method_name))
def __call__(self, *args, **kwargs):
raise Exception("Actor methods cannot be called directly. Instead "
"of running 'object.{}()', try "
"'object.{}.remote()'."
.format(self.method_name, self.method_name))
def remote(self, *args, **kwargs):
return actor_method_call(self.actor_id, self.method_name,
self.method_signature, *args, **kwargs)
def remote(self, *args, **kwargs):
return actor_method_call(self.actor_id, self.method_name,
self.method_signature, *args, **kwargs)
class NewClass(object):
def __init__(self, *args, **kwargs):
raise Exception("Actor classes cannot be instantiated directly. "
"Instead of running '{}()', try '{}.remote()'."
.format(Class.__name__, Class.__name__))
class NewClass(object):
def __init__(self, *args, **kwargs):
raise Exception("Actor classes cannot be instantiated directly. "
"Instead of running '{}()', try '{}.remote()'."
.format(Class.__name__, Class.__name__))
@classmethod
def remote(cls, *args, **kwargs):
actor_object = cls.__new__(cls)
actor_object._manual_init(*args, **kwargs)
return actor_object
@classmethod
def remote(cls, *args, **kwargs):
actor_object = cls.__new__(cls)
actor_object._manual_init(*args, **kwargs)
return actor_object
def _manual_init(self, *args, **kwargs):
self._ray_actor_id = random_actor_id()
self._ray_actor_methods = {
k: v for (k, v) in inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x))))}
# Extract the signatures of each of the methods. This will be used to
# catch some errors if the methods are called with inappropriate
# arguments.
self._ray_method_signatures = dict()
for k, v in self._ray_actor_methods.items():
# Print a warning message if the method signature is not supported.
# We don't raise an exception because if the actor inherits from a
# class that has a method whose signature we don't support, we
# there may not be much the user can do about it.
signature.check_signature_supported(v, warn=True)
self._ray_method_signatures[k] = signature.extract_signature(
v, ignore_first=True)
def _manual_init(self, *args, **kwargs):
self._ray_actor_id = random_actor_id()
self._ray_actor_methods = {
k: v for (k, v) in inspect.getmembers(
Class, predicate=(lambda x: (inspect.isfunction(x) or
inspect.ismethod(x))))}
# Extract the signatures of each of the methods. This will be used
# to catch some errors if the methods are called with inappropriate
# arguments.
self._ray_method_signatures = dict()
for k, v in self._ray_actor_methods.items():
# Print a warning message if the method signature is not
# supported. We don't raise an exception because if the actor
# inherits from a class that has a method whose signature we
# don't support, we there may not be much the user can do about
# it.
signature.check_signature_supported(v, warn=True)
self._ray_method_signatures[k] = signature.extract_signature(
v, ignore_first=True)
# Create objects to wrap method invocations. This is done so that we
# can invoke methods with actor.method.remote() instead of
# actor.method().
self._actor_method_invokers = dict()
for k, v in self._ray_actor_methods.items():
self._actor_method_invokers[k] = ActorMethod(
k, self._ray_actor_id, self._ray_method_signatures[k])
# Create objects to wrap method invocations. This is done so that
# we can invoke methods with actor.method.remote() instead of
# actor.method().
self._actor_method_invokers = dict()
for k, v in self._ray_actor_methods.items():
self._actor_method_invokers[k] = ActorMethod(
k, self._ray_actor_id, self._ray_method_signatures[k])
# Export the actor class if it has not been exported yet.
if len(exported) == 0:
export_actor_class(class_id, Class, self._ray_actor_methods.keys(),
ray.worker.global_worker)
exported.append(0)
# Export the actor.
export_actor(self._ray_actor_id, class_id,
self._ray_actor_methods.keys(), num_cpus, num_gpus,
ray.worker.global_worker)
# Call __init__ as a remote function.
if "__init__" in self._ray_actor_methods.keys():
actor_method_call(self._ray_actor_id, "__init__",
self._ray_method_signatures["__init__"],
*args, **kwargs)
else:
print("WARNING: this object has no __init__ method.")
# Export the actor class if it has not been exported yet.
if len(exported) == 0:
export_actor_class(class_id, Class,
self._ray_actor_methods.keys(),
ray.worker.global_worker)
exported.append(0)
# Export the actor.
export_actor(self._ray_actor_id, class_id,
self._ray_actor_methods.keys(), num_cpus, num_gpus,
ray.worker.global_worker)
# Call __init__ as a remote function.
if "__init__" in self._ray_actor_methods.keys():
actor_method_call(self._ray_actor_id, "__init__",
self._ray_method_signatures["__init__"],
*args, **kwargs)
else:
print("WARNING: this object has no __init__ method.")
# Make tab completion work.
def __dir__(self):
return self._ray_actor_methods
# Make tab completion work.
def __dir__(self):
return self._ray_actor_methods
def __getattribute__(self, attr):
# The following is needed so we can still access self.actor_methods.
if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods",
"_actor_method_invokers", "_ray_method_signatures"]:
return object.__getattribute__(self, attr)
if attr in self._ray_actor_methods.keys():
return self._actor_method_invokers[attr]
# There is no method with this name, so raise an exception.
raise AttributeError("'{}' Actor object has no attribute '{}'"
.format(Class, attr))
def __getattribute__(self, attr):
# The following is needed so we can still access
# self.actor_methods.
if attr in ["_manual_init", "_ray_actor_id", "_ray_actor_methods",
"_actor_method_invokers", "_ray_method_signatures"]:
return object.__getattribute__(self, attr)
if attr in self._ray_actor_methods.keys():
return self._actor_method_invokers[attr]
# There is no method with this name, so raise an exception.
raise AttributeError("'{}' Actor object has no attribute '{}'"
.format(Class, attr))
def __repr__(self):
return "Actor(" + self._ray_actor_id.hex() + ")"
def __repr__(self):
return "Actor(" + self._ray_actor_id.hex() + ")"
def __reduce__(self):
raise Exception("Actor objects cannot be pickled.")
def __reduce__(self):
raise Exception("Actor objects cannot be pickled.")
def __del__(self):
"""Kill the worker that is running this actor."""
if ray.worker.global_worker.connected:
actor_method_call(self._ray_actor_id, "__ray_terminate__",
self._ray_method_signatures["__ray_terminate__"])
def __del__(self):
"""Kill the worker that is running this actor."""
if ray.worker.global_worker.connected:
actor_method_call(
self._ray_actor_id, "__ray_terminate__",
self._ray_method_signatures["__ray_terminate__"])
return NewClass
return NewClass
ray.worker.global_worker.fetch_and_register_actor = fetch_and_register_actor
+398 -385
View File
@@ -23,423 +23,436 @@ OBJECT_CHANNEL_PREFIX = "OC:"
def integerToAsciiHex(num, numbytes):
retstr = b""
# Support 32 and 64 bit architecture.
assert(numbytes == 4 or numbytes == 8)
for i in range(numbytes):
curbyte = num & 0xff
if sys.version_info >= (3, 0):
retstr += bytes([curbyte])
else:
retstr += chr(curbyte)
num = num >> 8
retstr = b""
# Support 32 and 64 bit architecture.
assert(numbytes == 4 or numbytes == 8)
for i in range(numbytes):
curbyte = num & 0xff
if sys.version_info >= (3, 0):
retstr += bytes([curbyte])
else:
retstr += chr(curbyte)
num = num >> 8
return retstr
return retstr
def get_next_message(pubsub_client, timeout_seconds=10):
"""Block until the next message is available on the pubsub channel."""
start_time = time.time()
while True:
message = pubsub_client.get_message()
if message is not None:
return message
time.sleep(0.1)
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for next message.")
"""Block until the next message is available on the pubsub channel."""
start_time = time.time()
while True:
message = pubsub_client.get_message()
if message is not None:
return message
time.sleep(0.1)
if time.time() - start_time > timeout_seconds:
raise Exception("Timed out while waiting for next message.")
class TestGlobalStateStore(unittest.TestCase):
def setUp(self):
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
def setUp(self):
redis_port, _ = ray.services.start_redis_instance()
self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0)
def tearDown(self):
ray.services.cleanup()
def tearDown(self):
ray.services.cleanup()
def testInvalidObjectTableAdd(self):
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", "one",
"hash2", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
"hash2", "manager_id1", "extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
# object ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1"})
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
# Check that the second manager was added, even though the hash was
# mismatched.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that it is fine if we add the same object ID multiple times with
# the most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
"hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
def testInvalidObjectTableAdd(self):
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
"one", "hash2", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", 1,
"hash2", "manager_id1",
"extra argument")
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an
# object ID that is already present with a different hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1"})
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
# Check that the second manager was added, even though the hash was
# mismatched.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that it is fine if we add the same object ID multiple times
# with the most recent hash.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash2", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2,
"hash2", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
def testObjectTableAddAndLookup(self):
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Add a manager that already exists again and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that we properly handle NULL characters. In the past, NULL
# characters were handled improperly causing a "hash mismatch" error if two
# object IDs that agreed up to the NULL character were inserted with
# different hashes.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
"hash2", "manager_id1")
# Check that NULL characters in the hash are handled properly.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash1", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash2", "manager_id1")
def testObjectTableAddAndLookup(self):
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
# been added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Add a manager that already exists again and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Check that we properly handle NULL characters. In the past, NULL
# characters were handled improperly causing a "hash mismatch" error if
# two object IDs that agreed up to the NULL character were inserted
# with different hashes.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id3", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "\x00object_id4", 1,
"hash2", "manager_id1")
# Check that NULL characters in the hash are handled properly.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash1", "manager_id1")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", 1,
"\x00hash2", "manager_id1")
def testObjectTableAddAndRemove(self):
# Try removing a manager from an object ID that has not been added yet.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
# added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that doesn't exist, and make sure we still have the same
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that does exist. Make sure it gets removed the first
# time and does nothing the second time.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
# Remove the last manager, and make sure we have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
# Remove a manager from an empty set, and make sure we now have an empty
# set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
def testObjectTableAddAndRemove(self):
# Try removing a manager from an object ID that has not been added yet.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not
# been added yet.
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(response, None)
# Add some managers and try again.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that doesn't exist, and make sure we still have the
# same set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
# Remove a manager that does exist. Make sure it gets removed the first
# time and does nothing the second time.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id1")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), {b"manager_id2"})
# Remove the last manager, and make sure we have an empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id2")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
# Remove a manager from an empty set, and make sure we now have an
# empty set.
self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1",
"manager_id3")
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP",
"object_id1")
self.assertEqual(set(response), set())
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object notifications.
def check_object_notification(notification_message, object_id, object_size,
manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i), manager_ids[i])
def testObjectTableSubscribeToNotifications(self):
# Define a helper method for checking the contents of object
# notifications.
def check_object_notification(notification_message, object_id,
object_size, manager_ids):
notification_object = (SubscribeToNotificationsReply
.GetRootAsSubscribeToNotificationsReply(
notification_message, 0))
self.assertEqual(notification_object.ObjectId(), object_id)
self.assertEqual(notification_object.ObjectSize(), object_size)
self.assertEqual(notification_object.ManagerIdsLength(),
len(manager_ids))
for i in range(len(manager_ids)):
self.assertEqual(notification_object.ManagerIds(i),
manager_ids[i])
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", data_size,
"hash1", "manager_id2")
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
data_size = 0xf1f0
p = self.redis.pubsub()
# Subscribe to an object ID.
p.psubscribe("{}manager_id1".format(OBJECT_CHANNEL_PREFIX))
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1",
data_size, "hash1", "manager_id2")
# Receive the acknowledgement message.
self.assertEqual(get_next_message(p)["data"], 1)
# Request a notification and receive the data.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id1")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id1",
data_size,
[b"manager_id2"])
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to RAY.OBJECT_TABLE_ADD
# should trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2", data_size,
"hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2", b"manager_id3"])
# Request a notification for an object that isn't there. Then add the
# object and receive the data. Only the first call to
# RAY.OBJECT_TABLE_ADD should trigger notifications.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id2", "object_id3")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id1")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id2")
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id3",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1"])
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id2",
data_size, "hash1", "manager_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id2",
data_size,
[b"manager_id3"])
# Request notifications for object_id3 again.
self.redis.execute_command("RAY.OBJECT_TABLE_REQUEST_NOTIFICATIONS",
"manager_id1", "object_id3")
# Verify that the notification is correct.
check_object_notification(get_next_message(p)["data"],
b"object_id3",
data_size,
[b"manager_id1", b"manager_id2",
b"manager_id3"])
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(message,
0)
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
def testResultTableAddAndLookup(self):
def check_result_table_entry(message, task_id, is_put):
result_table_reply = ResultTableReply.GetRootAsResultTableReply(
message, 0)
self.assertEqual(result_table_reply.TaskId(), task_id)
self.assertEqual(result_table_reply.IsPut(), is_put)
# Try looking up something in the result table before anything is added.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Adding the object to the object table should have no effect.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Add the result to the result table. The lookup now returns the task ID.
task_id = b"task_id1"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id,
0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Doing it again should still work.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Try another result table lookup. This should succeed.
task_id = b"task_id2"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id,
1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id2")
check_result_table_entry(response, task_id, True)
# Try looking up something in the result table before anything is
# added.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Adding the object to the object table should have no effect.
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1,
"hash1", "manager_id1")
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
self.assertIsNone(response)
# Add the result to the result table. The lookup now returns the task
# ID.
task_id = b"task_id1"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1",
task_id, 0)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Doing it again should still work.
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id1")
check_result_table_entry(response, task_id, False)
# Try another result table lookup. This should succeed.
task_id = b"task_id2"
self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2",
task_id, 1)
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP",
"object_id2")
check_result_table_entry(response, task_id, True)
def testInvalidTaskTableAdd(self):
# Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with
# the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3, "node_id")
with self.assertRaises(redis.ResponseError):
# Non-integer scheduling states should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
"invalid_state", "node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Should not be able to update a non-existent task.
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
"node_id")
def testInvalidTaskTableAdd(self):
# Check that Redis returns an error when RAY.TASK_TABLE_ADD is called
# with the wrong arguments.
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD", "hello")
with self.assertRaises(redis.ResponseError):
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id", 3,
"node_id")
with self.assertRaises(redis.ResponseError):
# Non-integer scheduling states should not be added.
self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
"invalid_state", "node_id", "task_spec")
with self.assertRaises(redis.ResponseError):
# Should not be able to update a non-existent task.
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id", 10,
"node_id")
def testTaskTableAddAndLookup(self):
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
def testTaskTableAddAndLookup(self):
TASK_STATUS_WAITING = 1
TASK_STATUS_SCHEDULED = 2
TASK_STATUS_QUEUED = 4
# make sure somebody will get a notification (checked in the redis module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# make sure somebody will get a notification (checked in the redis
# module)
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
def check_task_reply(message, task_args, updated=False):
task_status, local_scheduler_id, task_spec = task_args
task_reply_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(task_reply_object.State(), task_status)
self.assertEqual(task_reply_object.LocalSchedulerId(),
local_scheduler_id)
self.assertEqual(task_reply_object.TaskSpec(), task_spec)
self.assertEqual(task_reply_object.Updated(), updated)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
*task_args)
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
# Check that task table adds, updates, and lookups work correctly.
task_args = [TASK_STATUS_WAITING, b"node_id", b"task_spec"]
response = self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id",
*task_args)
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
*task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
task_args[0] = TASK_STATUS_SCHEDULED
self.redis.execute_command("RAY.TASK_TABLE_UPDATE", "task_id",
*task_args[:2])
response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(response, task_args)
# If the current value, test value, and set value are all the same, the
# update happens, and the response is still the same task.
task_args = [task_args[0]] + task_args
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(get_response, task_args[1:])
# If the current value, test value, and set value are all the same, the
# update happens, and the response is still the same task.
task_args = [task_args[0]] + task_args
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the task entry is still the same.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
"task_id")
check_task_reply(get_response, task_args[1:])
# If the current value is the same as the test value, and the set value is
# different, the update happens, and the response is the entire task.
task_args[1] = TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
check_task_reply(get_response, task_args[1:])
# If the current value is the same as the test value, and the set value
# is different, the update happens, and the response is the entire
# task.
task_args[1] = TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# Check that the update happened.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
"task_id")
check_task_reply(get_response, task_args[1:])
# If the current value is no longer the same as the test value, the
# response is the same task as before the test-and-set.
new_task_args = task_args[:]
new_task_args[1] = TASK_STATUS_WAITING
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertEqual(get_response2, get_response)
# If the current value is no longer the same as the test value, the
# response is the same task as before the test-and-set.
new_task_args = task_args[:]
new_task_args[1] = TASK_STATUS_WAITING
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET",
"task_id")
self.assertEqual(get_response2, get_response)
# If the test value is a bitmask that matches the current value, the update
# happens.
task_args = new_task_args
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# If the test value is a bitmask that matches the current value, the
# update happens.
task_args = new_task_args
task_args[0] = TASK_STATUS_SCHEDULED | TASK_STATUS_QUEUED
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*task_args[:3])
check_task_reply(response, task_args[1:], updated=True)
# If the test value is a bitmask that does not match the current value, the
# update does not happen, and the response is the same task as before the
# test-and-set.
new_task_args = task_args[:]
new_task_args[0] = TASK_STATUS_SCHEDULED
old_response = response
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id")
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])
# If the test value is a bitmask that does not match the current value,
# the update does not happen, and the response is the same task as
# before the test-and-set.
new_task_args = task_args[:]
new_task_args[0] = TASK_STATUS_SCHEDULED
old_response = response
response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE",
"task_id",
*new_task_args[:3])
check_task_reply(response, task_args[1:], updated=False)
# Check that the update did not happen.
get_response = self.redis.execute_command("RAY.TASK_TABLE_GET",
"task_id")
self.assertNotEqual(get_response, old_response)
check_task_reply(get_response, task_args[1:])
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
def check_task_subscription(self, p, scheduling_state, local_scheduler_id):
task_args = [b"task_id", scheduling_state,
local_scheduler_id.encode("ascii"), b"task_spec"]
self.redis.execute_command("RAY.TASK_TABLE_ADD", *task_args)
# Receive the data.
message = get_next_message(p)["data"]
# Check that the notification object is correct.
notification_object = TaskReply.GetRootAsTaskReply(message, 0)
self.assertEqual(notification_object.TaskId(), b"task_id")
self.assertEqual(notification_object.State(), scheduling_state)
self.assertEqual(notification_object.LocalSchedulerId(),
local_scheduler_id.encode("ascii"))
self.assertEqual(notification_object.TaskSpec(), b"task_spec")
def testTaskTableSubscribe(self):
scheduling_state = 1
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
def testTaskTableSubscribe(self):
scheduling_state = 1
local_scheduler_id = "local_scheduler_id"
# Subscribe to the task table.
p = self.redis.pubsub()
p.psubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
# unsubscribe to make sure there is only one subscriber at a given time
p.punsubscribe("{prefix}*:*".format(prefix=TASK_PREFIX))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}*:{state}".format(
prefix=TASK_PREFIX, state=scheduling_state))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
p.psubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 1)
self.check_task_subscription(p, scheduling_state, local_scheduler_id)
p.punsubscribe("{prefix}{local_scheduler_id}:*".format(
prefix=TASK_PREFIX, local_scheduler_id=local_scheduler_id))
# Receive acknowledgment.
self.assertEqual(get_next_message(p)["data"], 0)
if __name__ == "__main__":
unittest.main(verbosity=2)
unittest.main(verbosity=2)
+102 -102
View File
@@ -13,19 +13,19 @@ ID_SIZE = 20
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
BASE_SIMPLE_OBJECTS = [
@@ -33,7 +33,7 @@ BASE_SIMPLE_OBJECTS = [
990 * u"h"]
if sys.version_info < (3, 0):
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
BASE_SIMPLE_OBJECTS += [long(0), long(1), long(100000), long(1 << 100)] # noqa: E501,F821
LIST_SIMPLE_OBJECTS = [[obj] for obj in BASE_SIMPLE_OBJECTS]
TUPLE_SIMPLE_OBJECTS = [(obj,) for obj in BASE_SIMPLE_OBJECTS]
@@ -51,8 +51,8 @@ l.append(l)
class Foo(object):
def __init__(self):
pass
def __init__(self):
pass
BASE_COMPLEX_OBJECTS = [999 * "h", 999 * u"h", l, Foo(),
@@ -70,118 +70,118 @@ COMPLEX_OBJECTS = (BASE_COMPLEX_OBJECTS +
class TestSerialization(unittest.TestCase):
def test_serialize_by_value(self):
def test_serialize_by_value(self):
for val in SIMPLE_OBJECTS:
self.assertTrue(local_scheduler.check_simple_value(val))
for val in COMPLEX_OBJECTS:
self.assertFalse(local_scheduler.check_simple_value(val))
for val in SIMPLE_OBJECTS:
self.assertTrue(local_scheduler.check_simple_value(val))
for val in COMPLEX_OBJECTS:
self.assertFalse(local_scheduler.check_simple_value(val))
class TestObjectID(unittest.TestCase):
def test_create_object_id(self):
random_object_id()
def test_create_object_id(self):
random_object_id()
def test_cannot_pickle_object_ids(self):
object_ids = [random_object_id() for _ in range(256)]
def test_cannot_pickle_object_ids(self):
object_ids = [random_object_id() for _ in range(256)]
def f():
return object_ids
def f():
return object_ids
def g(val=object_ids):
return 1
def g(val=object_ids):
return 1
def h():
object_ids[0]
return 1
# Make sure that object IDs cannot be pickled (including functions that
# close over object IDs).
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
self.assertRaises(Exception, lambda: pickle.dumps(f))
self.assertRaises(Exception, lambda: pickle.dumps(g))
self.assertRaises(Exception, lambda: pickle.dumps(h))
def h():
object_ids[0]
return 1
# Make sure that object IDs cannot be pickled (including functions that
# close over object IDs).
self.assertRaises(Exception, lambda: pickle.dumps(object_ids[0]))
self.assertRaises(Exception, lambda: pickle.dumps(object_ids))
self.assertRaises(Exception, lambda: pickle.dumps(f))
self.assertRaises(Exception, lambda: pickle.dumps(g))
self.assertRaises(Exception, lambda: pickle.dumps(h))
def test_equality_comparisons(self):
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
x2 = local_scheduler.ObjectID(ID_SIZE * b"a")
y1 = local_scheduler.ObjectID(ID_SIZE * b"b")
y2 = local_scheduler.ObjectID(ID_SIZE * b"b")
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)
self.assertNotEqual(x1, y1)
def test_equality_comparisons(self):
x1 = local_scheduler.ObjectID(ID_SIZE * b"a")
x2 = local_scheduler.ObjectID(ID_SIZE * b"a")
y1 = local_scheduler.ObjectID(ID_SIZE * b"b")
y2 = local_scheduler.ObjectID(ID_SIZE * b"b")
self.assertEqual(x1, x2)
self.assertEqual(y1, y2)
self.assertNotEqual(x1, y1)
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
random_strings = [np.random.bytes(ID_SIZE) for _ in range(256)]
object_ids1 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
object_ids2 = [local_scheduler.ObjectID(random_strings[i])
for i in range(256)]
self.assertEqual(len(set(object_ids1)), 256)
self.assertEqual(len(set(object_ids1 + object_ids2)), 256)
self.assertEqual(set(object_ids1), set(object_ids2))
def test_hashability(self):
x = random_object_id()
y = random_object_id()
{x: y}
set([x, y])
def test_hashability(self):
x = random_object_id()
y = random_object_id()
{x: y}
set([x, y])
class TestTask(unittest.TestCase):
def check_task(self, task, function_id, num_return_vals, args):
self.assertEqual(function_id.id(), task.function_id().id())
retrieved_args = task.arguments()
self.assertEqual(num_return_vals, len(task.returns()))
self.assertEqual(len(args), len(retrieved_args))
for i in range(len(retrieved_args)):
if isinstance(retrieved_args[i], local_scheduler.ObjectID):
self.assertEqual(retrieved_args[i].id(), args[i].id())
else:
self.assertEqual(retrieved_args[i], args[i])
def check_task(self, task, function_id, num_return_vals, args):
self.assertEqual(function_id.id(), task.function_id().id())
retrieved_args = task.arguments()
self.assertEqual(num_return_vals, len(task.returns()))
self.assertEqual(len(args), len(retrieved_args))
for i in range(len(retrieved_args)):
if isinstance(retrieved_args[i], local_scheduler.ObjectID):
self.assertEqual(retrieved_args[i].id(), args[i].id())
else:
self.assertEqual(retrieved_args[i], args[i])
def test_create_and_serialize_task(self):
# TODO(rkn): The function ID should be a FunctionID object, not an
# ObjectID.
driver_id = random_driver_id()
parent_id = random_task_id()
function_id = random_function_id()
object_ids = [random_object_id() for _ in range(256)]
args_list = [
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(driver_id, function_id, args,
num_return_vals, parent_id, 0)
self.check_task(task, function_id, num_return_vals, args)
data = local_scheduler.task_to_string(task)
task2 = local_scheduler.task_from_string(data)
self.check_task(task2, function_id, num_return_vals, args)
def test_create_and_serialize_task(self):
# TODO(rkn): The function ID should be a FunctionID object, not an
# ObjectID.
driver_id = random_driver_id()
parent_id = random_task_id()
function_id = random_function_id()
object_ids = [random_object_id() for _ in range(256)]
args_list = [
[],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 2, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(driver_id, function_id, args,
num_return_vals, parent_id, 0)
self.check_task(task, function_id, num_return_vals, args)
data = local_scheduler.task_to_string(task)
task2 = local_scheduler.task_from_string(data)
self.check_task(task2, function_id, num_return_vals, args)
if __name__ == "__main__":
unittest.main(verbosity=2)
unittest.main(verbosity=2)
+209 -203
View File
@@ -10,271 +10,277 @@ BLOCK_SIZE = 10
class DistArray(object):
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in self.shape]
if objectids is not None:
self.objectids = objectids
else:
self.objectids = np.empty(self.num_blocks, dtype=object)
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are "
"inconsistent, `num_blocks` is {} and `objectids` has "
"shape {}".format(self.num_blocks,
list(self.objectids.shape)))
def __init__(self, shape, objectids=None):
self.shape = shape
self.ndim = len(shape)
self.num_blocks = [int(np.ceil(1.0 * a / BLOCK_SIZE))
for a in self.shape]
if objectids is not None:
self.objectids = objectids
else:
self.objectids = np.empty(self.num_blocks, dtype=object)
if self.num_blocks != list(self.objectids.shape):
raise Exception("The fields `num_blocks` and `objectids` are "
"inconsistent, `num_blocks` is {} and `objectids` "
"has shape {}".format(self.num_blocks,
list(self.objectids.shape)))
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the same "
"length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
return [elem * BLOCK_SIZE for elem in index]
@staticmethod
def compute_block_lower(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the "
"same length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
return [elem * BLOCK_SIZE for elem in index]
@staticmethod
def compute_block_upper(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the same "
"length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
upper = []
for i in range(len(shape)):
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
return upper
@staticmethod
def compute_block_upper(index, shape):
if len(index) != len(shape):
raise Exception("The fields `index` and `shape` must have the "
"same length, but `index` is {} and `shape` is "
"{}.".format(index, shape))
upper = []
for i in range(len(shape)):
upper.append(min((index[i] + 1) * BLOCK_SIZE, shape[i]))
return upper
@staticmethod
def compute_block_shape(index, shape):
lower = DistArray.compute_block_lower(index, shape)
upper = DistArray.compute_block_upper(index, shape)
return [u - l for (l, u) in zip(lower, upper)]
@staticmethod
def compute_block_shape(index, shape):
lower = DistArray.compute_block_lower(index, shape)
upper = DistArray.compute_block_upper(index, shape)
return [u - l for (l, u) in zip(lower, upper)]
@staticmethod
def compute_num_blocks(shape):
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
@staticmethod
def compute_num_blocks(shape):
return [int(np.ceil(1.0 * a / BLOCK_SIZE)) for a in shape]
def assemble(self):
"""Assemble an array from a distributed array of object IDs."""
first_block = ray.get(self.objectids[(0,) * self.ndim])
dtype = first_block.dtype
result = np.zeros(self.shape, dtype=dtype)
for index in np.ndindex(*self.num_blocks):
lower = DistArray.compute_block_lower(index, self.shape)
upper = DistArray.compute_block_upper(index, self.shape)
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
self.objectids[index])
return result
def assemble(self):
"""Assemble an array from a distributed array of object IDs."""
first_block = ray.get(self.objectids[(0,) * self.ndim])
dtype = first_block.dtype
result = np.zeros(self.shape, dtype=dtype)
for index in np.ndindex(*self.num_blocks):
lower = DistArray.compute_block_lower(index, self.shape)
upper = DistArray.compute_block_upper(index, self.shape)
result[[slice(l, u) for (l, u) in zip(lower, upper)]] = ray.get(
self.objectids[index])
return result
def __getitem__(self, sliced):
# TODO(rkn): Fix this, this is just a placeholder that should work but is
# inefficient.
a = self.assemble()
return a[sliced]
def __getitem__(self, sliced):
# TODO(rkn): Fix this, this is just a placeholder that should work but
# is inefficient.
a = self.assemble()
return a[sliced]
@ray.remote
def assemble(a):
return a.assemble()
return a.assemble()
# TODO(rkn): What should we call this method?
@ray.remote
def numpy_to_dist(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
lower = DistArray.compute_block_lower(index, a.shape)
upper = DistArray.compute_block_upper(index, a.shape)
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
in zip(lower, upper)]])
return result
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
lower = DistArray.compute_block_lower(index, a.shape)
upper = DistArray.compute_block_upper(index, a.shape)
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u)
in zip(lower, upper)]])
return result
@ray.remote
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.zeros.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.zeros.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.ones.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.ones.remote(
DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote
def copy(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
# We don't need to actually copy the objects because remote objects are
# immutable.
result.objectids[index] = a.objectids[index]
return result
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
# We don't need to actually copy the objects because remote objects are
# immutable.
result.objectids[index] = a.objectids[index]
return result
@ray.remote
def eye(dim1, dim2=-1, dtype_name="float"):
dim2 = dim1 if dim2 == -1 else dim2
shape = [dim1, dim2]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0], block_shape[1],
dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
return result
dim2 = dim1 if dim2 == -1 else dim2
shape = [dim1, dim2]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
block_shape = DistArray.compute_block_shape([i, j], shape)
if i == j:
result.objectids[i, j] = ra.eye.remote(block_shape[0],
block_shape[1],
dtype_name=dtype_name)
else:
result.objectids[i, j] = ra.zeros.remote(block_shape,
dtype_name=dtype_name)
return result
@ray.remote
def triu(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i < j:
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
elif i == j:
result.objectids[i, j] = ra.triu.remote(a.objectids[i, j])
else:
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i < j:
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
elif i == j:
result.objectids[i, j] = ra.triu.remote(a.objectids[i, j])
else:
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def tril(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i > j:
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
elif i == j:
result.objectids[i, j] = ra.tril.remote(a.objectids[i, j])
else:
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is "
"{}.".format(a.ndim))
result = DistArray(a.shape)
for (i, j) in np.ndindex(*result.num_blocks):
if i > j:
result.objectids[i, j] = ra.copy.remote(a.objectids[i, j])
elif i == j:
result.objectids[i, j] = ra.tril.remote(a.objectids[i, j])
else:
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote
def blockwise_dot(*matrices):
n = len(matrices)
if n % 2 != 0:
raise Exception("blockwise_dot expects an even number of arguments, but "
"len(matrices) is {}.".format(n))
shape = (matrices[0].shape[0], matrices[n // 2].shape[1])
result = np.zeros(shape)
for i in range(n // 2):
result += np.dot(matrices[i], matrices[n // 2 + i])
return result
n = len(matrices)
if n % 2 != 0:
raise Exception("blockwise_dot expects an even number of arguments, "
"but len(matrices) is {}.".format(n))
shape = (matrices[0].shape[0], matrices[n // 2].shape[1])
result = np.zeros(shape)
for i in range(n // 2):
result += np.dot(matrices[i], matrices[n // 2 + i])
return result
@ray.remote
def dot(a, b):
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but "
"a.ndim = {}.".format(a.ndim))
if b.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but "
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but a.shape "
"= {} and b.shape = {}.".format(a.shape, b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
args = list(a.objectids[i, :]) + list(b.objectids[:, j])
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but "
"a.ndim = {}.".format(a.ndim))
if b.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but "
"b.ndim = {}.".format(b.ndim))
if a.shape[1] != b.shape[0]:
raise Exception("dot expects a.shape[1] to equal b.shape[0], but "
"a.shape = {} and b.shape = {}.".format(a.shape,
b.shape))
shape = [a.shape[0], b.shape[1]]
result = DistArray(shape)
for (i, j) in np.ndindex(*result.num_blocks):
args = list(a.objectids[i, :]) + list(b.objectids[:, j])
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the
`a`. The result and `a` will have the same number of dimensions.For example,
subblocks(a, [0, 1], [2, 4])
will produce a DistArray whose objectids are
[[a.objectids[0, 2], a.objectids[0, 4]],
[a.objectids[1, 2], a.objectids[1, 4]]]
We allow the user to pass in an empty list [] to indicate the full range.
"""
ranges = list(ranges)
if len(ranges) != a.ndim:
raise Exception("sub_blocks expects to receive a number of ranges equal "
"to a.ndim, but it received {} ranges and a.ndim = "
"{}.".format(len(ranges), a.ndim))
for i in range(len(ranges)):
# We allow the user to pass in an empty list to indicate the full range.
if ranges[i] == []:
ranges[i] = range(a.num_blocks[i])
if not np.alltrue(ranges[i] == np.sort(ranges[i])):
raise Exception("Ranges passed to sub_blocks must be sorted, but the "
"{}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must be at "
"least 0, but the {}th range is {}.".format(i,
ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must be less "
"than the relevant number of blocks, but the {}th range "
"is {}, and a.num_blocks = {}.".format(i, ranges[i],
a.num_blocks))
last_index = [r[-1] for r in ranges]
last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
return result
"""
This function produces a distributed array from a subset of the blocks in
the `a`. The result and `a` will have the same number of dimensions. For
example,
subblocks(a, [0, 1], [2, 4])
will produce a DistArray whose objectids are
[[a.objectids[0, 2], a.objectids[0, 4]],
[a.objectids[1, 2], a.objectids[1, 4]]]
We allow the user to pass in an empty list [] to indicate the full range.
"""
ranges = list(ranges)
if len(ranges) != a.ndim:
raise Exception("sub_blocks expects to receive a number of ranges "
"equal to a.ndim, but it received {} ranges and "
"a.ndim = {}.".format(len(ranges), a.ndim))
for i in range(len(ranges)):
# We allow the user to pass in an empty list to indicate the full
# range.
if ranges[i] == []:
ranges[i] = range(a.num_blocks[i])
if not np.alltrue(ranges[i] == np.sort(ranges[i])):
raise Exception("Ranges passed to sub_blocks must be sorted, but "
"the {}th range is {}.".format(i, ranges[i]))
if ranges[i][0] < 0:
raise Exception("Values in the ranges passed to sub_blocks must "
"be at least 0, but the {}th range is {}."
.format(i, ranges[i]))
if ranges[i][-1] >= a.num_blocks[i]:
raise Exception("Values in the ranges passed to sub_blocks must "
"be less than the relevant number of blocks, but "
"the {}th range is {}, and a.num_blocks = {}."
.format(i, ranges[i], a.num_blocks))
last_index = [r[-1] for r in ranges]
last_block_shape = DistArray.compute_block_shape(last_index, a.shape)
shape = [(len(ranges[i]) - 1) * BLOCK_SIZE + last_block_shape[i]
for i in range(a.ndim)]
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]]
for i in range(a.ndim)])]
return result
@ray.remote
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, but "
"a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i])
return result
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, "
"but a.ndim = {}, a.shape = {}.".format(a.ndim,
a.shape))
result = DistArray([a.shape[1], a.shape[0]])
for i in range(result.num_blocks[0]):
for j in range(result.num_blocks[1]):
result.objectids[i, j] = ra.transpose.remote(a.objectids[j, i])
return result
# TODO(rkn): support broadcasting?
@ray.remote
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same "
"shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.add.remote(x1.objectids[index],
x2.objectids[index])
return result
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same "
"shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.add.remote(x1.objectids[index],
x2.objectids[index])
return result
# TODO(rkn): support broadcasting?
@ray.remote
def subtract(x1, x2):
if x1.shape != x2.shape:
raise Exception("subtract expects arguments `x1` and `x2` to have the "
"same shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.subtract.remote(x1.objectids[index],
x2.objectids[index])
return result
if x1.shape != x2.shape:
raise Exception("subtract expects arguments `x1` and `x2` to have the "
"same shape, but x1.shape = {}, and x2.shape = {}."
.format(x1.shape, x2.shape))
result = DistArray(x1.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.subtract.remote(x1.objectids[index],
x2.objectids[index])
return result
@@ -13,74 +13,75 @@ __all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote(num_return_vals=2)
def tsqr(a):
"""Perform a QR decomposition of a tall-skinny matrix.
"""Perform a QR decomposition of a tall-skinny matrix.
Args:
a: A distributed matrix with shape MxN (suppose K = min(M, N)).
Args:
a: A distributed matrix with shape MxN (suppose K = min(M, N)).
Returns:
A tuple of q (a DistArray) and r (a numpy array) satisfying the following.
- If q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K).
- np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
- If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
- np.allclose(r, np.triu(r)) == True.
"""
if len(a.shape) != 2:
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
"{}".format(a.shape))
if a.num_blocks[1] != 1:
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks is "
"{}".format(a.num_blocks))
Returns:
A tuple of q (a DistArray) and r (a numpy array) satisfying the
following.
- If q_full = ray.get(DistArray, q).assemble(), then
q_full.shape == (M, K).
- np.allclose(np.dot(q_full.T, q_full), np.eye(K)) == True.
- If r_val = ray.get(np.ndarray, r), then r_val.shape == (K, N).
- np.allclose(r, np.triu(r)) == True.
"""
if len(a.shape) != 2:
raise Exception("tsqr requires len(a.shape) == 2, but a.shape is "
"{}".format(a.shape))
if a.num_blocks[1] != 1:
raise Exception("tsqr requires a.num_blocks[1] == 1, but a.num_blocks "
"is {}".format(a.num_blocks))
num_blocks = a.num_blocks[0]
K = int(np.ceil(np.log2(num_blocks))) + 1
q_tree = np.empty((num_blocks, K), dtype=object)
current_rs = []
for i in range(num_blocks):
block = a.objectids[i, 0]
q, r = ra.linalg.qr.remote(block)
q_tree[i, 0] = q
current_rs.append(r)
for j in range(1, K):
new_rs = []
for i in range(int(np.ceil(1.0 * len(current_rs) / 2))):
stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)])
q, r = ra.linalg.qr.remote(stacked_rs)
q_tree[i, j] = q
new_rs.append(r)
current_rs = new_rs
assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))
# handle the special case in which the whole DistArray "a" fits in one block
# and has fewer rows than columns, this is a bit ugly so think about how to
# remove it
if a.shape[0] >= a.shape[1]:
q_shape = a.shape
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result = core.DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
q_block_current = q_tree[i, 0]
ith_index = i
num_blocks = a.num_blocks[0]
K = int(np.ceil(np.log2(num_blocks))) + 1
q_tree = np.empty((num_blocks, K), dtype=object)
current_rs = []
for i in range(num_blocks):
block = a.objectids[i, 0]
q, r = ra.linalg.qr.remote(block)
q_tree[i, 0] = q
current_rs.append(r)
for j in range(1, K):
if np.mod(ith_index, 2) == 0:
lower = [0, 0]
upper = [a.shape[1], core.BLOCK_SIZE]
else:
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(q_block_current,
ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
new_rs = []
for i in range(int(np.ceil(1.0 * len(current_rs) / 2))):
stacked_rs = ra.vstack.remote(*current_rs[(2 * i):(2 * i + 2)])
q, r = ra.linalg.qr.remote(stacked_rs)
q_tree[i, j] = q
new_rs.append(r)
current_rs = new_rs
assert len(current_rs) == 1, "len(current_rs) = " + str(len(current_rs))
# handle the special case in which the whole DistArray "a" fits in one
# block and has fewer rows than columns, this is a bit ugly so think about
# how to remove it
if a.shape[0] >= a.shape[1]:
q_shape = a.shape
else:
q_shape = [a.shape[0], a.shape[0]]
q_num_blocks = core.DistArray.compute_num_blocks(q_shape)
q_objectids = np.empty(q_num_blocks, dtype=object)
q_result = core.DistArray(q_shape, q_objectids)
# reconstruct output
for i in range(num_blocks):
q_block_current = q_tree[i, 0]
ith_index = i
for j in range(1, K):
if np.mod(ith_index, 2) == 0:
lower = [0, 0]
upper = [a.shape[1], core.BLOCK_SIZE]
else:
lower = [a.shape[1], 0]
upper = [2 * a.shape[1], core.BLOCK_SIZE]
ith_index //= 2
q_block_current = ra.dot.remote(
q_block_current, ra.subarray.remote(q_tree[ith_index, j],
lower, upper))
q_result.objectids[i] = q_block_current
r = current_rs[0]
return q_result, ray.get(r)
# TODO(rkn): This is unoptimized, we really want a block version of this.
@@ -88,76 +89,77 @@ def tsqr(a):
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=3)
def modified_lu(q):
"""Perform a modified LU decomposition of a matrix.
"""Perform a modified LU decomposition of a matrix.
This takes a matrix q with orthonormal columns, returns l, u, s such that
q - s = l * u.
This takes a matrix q with orthonormal columns, returns l, u, s such that
q - s = l * u.
Args:
q: A two dimensional orthonormal matrix q.
Args:
q: A two dimensional orthonormal matrix q.
Returns:
A tuple of a lower triangular matrix l, an upper triangular matrix u, and a
a vector representing a diagonal matrix s such that q - s = l * u.
"""
q = q.assemble()
m, b = q.shape[0], q.shape[1]
S = np.zeros(b)
Returns:
A tuple of a lower triangular matrix l, an upper triangular matrix u,
and a a vector representing a diagonal matrix s such that
q - s = l * u.
"""
q = q.assemble()
m, b = q.shape[0], q.shape[1]
S = np.zeros(b)
q_work = np.copy(q)
q_work = np.copy(q)
for i in range(b):
S[i] = -1 * np.sign(q_work[i, i])
q_work[i, i] -= S[i]
# Scale ith column of L by diagonal element.
q_work[(i + 1):m, i] /= q_work[i, i]
# Perform Schur complement update.
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i],
q_work[i, (i + 1):b])
for i in range(b):
S[i] = -1 * np.sign(q_work[i, i])
q_work[i, i] -= S[i]
# Scale ith column of L by diagonal element.
q_work[(i + 1):m, i] /= q_work[i, i]
# Perform Schur complement update.
q_work[(i + 1):m, (i + 1):b] -= np.outer(q_work[(i + 1):m, i],
q_work[i, (i + 1):b])
L = np.tril(q_work)
for i in range(b):
L[i, i] = 1
U = np.triu(q_work)[:b, :]
# TODO(rkn): Get rid of the put below.
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
L = np.tril(q_work)
for i in range(b):
L[i, i] = 1
U = np.triu(q_work)[:b, :]
# TODO(rkn): Get rid of the put below.
return ray.get(core.numpy_to_dist.remote(ray.put(L))), U, S
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
y_top = y_top_block[:b, :b]
s_full = np.diag(s)
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
y_top = y_top_block[:b, :b]
s_full = np.diag(s)
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
@ray.remote
def tsqr_hr_helper2(s, r_temp):
s_full = np.diag(s)
return np.dot(s_full, r_temp)
s_full = np.diag(s)
return np.dot(s_full, r_temp)
# This is Algorithm 6 from
# http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf.
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
q, r_temp = tsqr.remote(a)
y, u, s = modified_lu.remote(q)
y_blocked = ray.get(y)
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0],
a.shape[1])
r = tsqr_hr_helper2.remote(s, r_temp)
return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r)
q, r_temp = tsqr.remote(a)
y, u, s = modified_lu.remote(q)
y_blocked = ray.get(y)
t, y_top = tsqr_hr_helper1.remote(u, s, y_blocked.objectids[0, 0],
a.shape[1])
r = tsqr_hr_helper2.remote(s, r_temp)
return ray.get(y), ray.get(t), ray.get(y_top), ray.get(r)
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
@ray.remote
def qr_helper2(y_ri, a_rc):
return np.dot(y_ri.T, a_rc)
return np.dot(y_ri.T, a_rc)
# This is Algorithm 7 from
@@ -165,60 +167,63 @@ def qr_helper2(y_ri, a_rc):
@ray.remote(num_return_vals=2)
def qr(a):
m, n = a.shape[0], a.shape[1]
k = min(m, n)
m, n = a.shape[0], a.shape[1]
k = min(m, n)
# we will store our scratch work in a_work
a_work = core.DistArray(a.shape, np.copy(a.objectids))
# we will store our scratch work in a_work
a_work = core.DistArray(a.shape, np.copy(a.objectids))
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
# TODO(rkn): It would be preferable not to get this right after creating it.
r_res = ray.get(core.zeros.remote([k, n], result_dtype))
# TODO(rkn): It would be preferable not to get this right after creating it.
y_res = ray.get(core.zeros.remote([m, k], result_dtype))
Ts = []
result_dtype = np.linalg.qr(ray.get(a.objectids[0, 0]))[0].dtype.name
# TODO(rkn): It would be preferable not to get this right after creating
# it.
r_res = ray.get(core.zeros.remote([k, n], result_dtype))
# TODO(rkn): It would be preferable not to get this right after creating
# it.
y_res = ray.get(core.zeros.remote([m, k], result_dtype))
Ts = []
# The for loop differs from the paper, which says
# "for i in range(a.num_blocks[1])", but that doesn't seem to make any sense
# when a.num_blocks[1] > a.num_blocks[0].
for i in range(min(a.num_blocks[0], a.num_blocks[1])):
sub_dist_array = core.subblocks.remote(
a_work, list(range(i, a_work.num_blocks[0])), [i])
y, t, _, R = tsqr_hr.remote(sub_dist_array)
y_val = ray.get(y)
# The for loop differs from the paper, which says
# "for i in range(a.num_blocks[1])", but that doesn't seem to make any
# sense when a.num_blocks[1] > a.num_blocks[0].
for i in range(min(a.num_blocks[0], a.num_blocks[1])):
sub_dist_array = core.subblocks.remote(
a_work, list(range(i, a_work.num_blocks[0])), [i])
y, t, _, R = tsqr_hr.remote(sub_dist_array)
y_val = ray.get(y)
for j in range(i, a.num_blocks[0]):
y_res.objectids[j, i] = y_val.objectids[j - i, 0]
if a.shape[0] > a.shape[1]:
# in this case, R needs to be square
R_shape = ray.get(ra.shape.remote(R))
eye_temp = ra.eye.remote(R_shape[1], R_shape[0], dtype_name=result_dtype)
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
Ts.append(core.numpy_to_dist.remote(t))
for j in range(i, a.num_blocks[0]):
y_res.objectids[j, i] = y_val.objectids[j - i, 0]
if a.shape[0] > a.shape[1]:
# in this case, R needs to be square
R_shape = ray.get(ra.shape.remote(R))
eye_temp = ra.eye.remote(R_shape[1], R_shape[0],
dtype_name=result_dtype)
r_res.objectids[i, i] = ra.dot.remote(eye_temp, R)
else:
r_res.objectids[i, i] = R
Ts.append(core.numpy_to_dist.remote(t))
for c in range(i + 1, a.num_blocks[1]):
W_rcs = []
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objectids[r - i, 0]
W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c]))
W_c = ra.sum_list.remote(*W_rcs)
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objectids[r - i, 0]
A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c)
a_work.objectids[r, c] = A_rc
r_res.objectids[i, c] = a_work.objectids[i, c]
for c in range(i + 1, a.num_blocks[1]):
W_rcs = []
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objectids[r - i, 0]
W_rcs.append(qr_helper2.remote(y_ri, a_work.objectids[r, c]))
W_c = ra.sum_list.remote(*W_rcs)
for r in range(i, a.num_blocks[0]):
y_ri = y_val.objectids[r - i, 0]
A_rc = qr_helper1.remote(a_work.objectids[r, c], y_ri, t, W_c)
a_work.objectids[r, c] = A_rc
r_res.objectids[i, c] = a_work.objectids[i, c]
# construct q_res from Ys and Ts
q = core.eye.remote(m, k, dtype_name=result_dtype)
for i in range(len(Ts))[::-1]:
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(Ts[i],
core.dot.remote(core.transpose.remote(y_col_block),
q))))
# construct q_res from Ys and Ts
q = core.eye.remote(m, k, dtype_name=result_dtype)
for i in range(len(Ts))[::-1]:
y_col_block = core.subblocks.remote(y_res, [], [i])
q = core.subtract.remote(
q, core.dot.remote(
y_col_block,
core.dot.remote(
Ts[i],
core.dot.remote(core.transpose.remote(y_col_block), q))))
return ray.get(q), r_res
return ray.get(q), r_res
@@ -11,10 +11,10 @@ from .core import DistArray
@ray.remote
def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objectids = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objectids[index] = ra.random.normal.remote(
DistArray.compute_block_shape(index, shape))
result = DistArray(shape, objectids)
return result
num_blocks = DistArray.compute_num_blocks(shape)
objectids = np.empty(num_blocks, dtype=object)
for index in np.ndindex(*num_blocks):
objectids[index] = ra.random.normal.remote(
DistArray.compute_block_shape(index, shape))
result = DistArray(shape, objectids)
return result
+21 -21
View File
@@ -8,94 +8,94 @@ import ray
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote
def ones(shape, dtype_name="float", order="C"):
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
M = N if M == -1 else M
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
M = N if M == -1 else M
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
@ray.remote
def dot(a, b):
return np.dot(a, b)
return np.dot(a, b)
@ray.remote
def vstack(*xs):
return np.vstack(xs)
return np.vstack(xs)
@ray.remote
def hstack(*xs):
return np.hstack(xs)
return np.hstack(xs)
# TODO(rkn): Instead of this, consider implementing slicing.
# TODO(rkn): Be consistent about using "index" versus "indices".
@ray.remote
def subarray(a, lower_indices, upper_indices):
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@ray.remote
def copy(a, order="K"):
return np.copy(a, order=order)
return np.copy(a, order=order)
@ray.remote
def tril(m, k=0):
return np.tril(m, k=k)
return np.tril(m, k=k)
@ray.remote
def triu(m, k=0):
return np.triu(m, k=k)
return np.triu(m, k=k)
@ray.remote
def diag(v, k=0):
return np.diag(v, k=k)
return np.diag(v, k=k)
@ray.remote
def transpose(a, axes=[]):
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
@ray.remote
def add(x1, x2):
return np.add(x1, x2)
return np.add(x1, x2)
@ray.remote
def subtract(x1, x2):
return np.subtract(x1, x2)
return np.subtract(x1, x2)
@ray.remote
def sum(x, axis=-1):
return np.sum(x, axis=axis if axis != -1 else None)
return np.sum(x, axis=axis if axis != -1 else None)
@ray.remote
def shape(a):
return np.shape(a)
return np.shape(a)
@ray.remote
def sum_list(*xs):
return np.sum(xs, axis=0)
return np.sum(xs, axis=0)
+20 -20
View File
@@ -13,99 +13,99 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
@ray.remote
def matrix_power(M, n):
return np.linalg.matrix_power(M, n)
return np.linalg.matrix_power(M, n)
@ray.remote
def solve(a, b):
return np.linalg.solve(a, b)
return np.linalg.solve(a, b)
@ray.remote(num_return_vals=2)
def tensorsolve(a):
raise NotImplementedError
raise NotImplementedError
@ray.remote(num_return_vals=2)
def tensorinv(a):
raise NotImplementedError
raise NotImplementedError
@ray.remote
def inv(a):
return np.linalg.inv(a)
return np.linalg.inv(a)
@ray.remote
def cholesky(a):
return np.linalg.cholesky(a)
return np.linalg.cholesky(a)
@ray.remote
def eigvals(a):
return np.linalg.eigvals(a)
return np.linalg.eigvals(a)
@ray.remote
def eigvalsh(a):
raise NotImplementedError
raise NotImplementedError
@ray.remote
def pinv(a):
return np.linalg.pinv(a)
return np.linalg.pinv(a)
@ray.remote
def slogdet(a):
raise NotImplementedError
raise NotImplementedError
@ray.remote
def det(a):
return np.linalg.det(a)
return np.linalg.det(a)
@ray.remote(num_return_vals=3)
def svd(a):
return np.linalg.svd(a)
return np.linalg.svd(a)
@ray.remote(num_return_vals=2)
def eig(a):
return np.linalg.eig(a)
return np.linalg.eig(a)
@ray.remote(num_return_vals=2)
def eigh(a):
return np.linalg.eigh(a)
return np.linalg.eigh(a)
@ray.remote(num_return_vals=4)
def lstsq(a, b):
return np.linalg.lstsq(a)
return np.linalg.lstsq(a)
@ray.remote
def norm(x):
return np.linalg.norm(x)
return np.linalg.norm(x)
@ray.remote(num_return_vals=2)
def qr(a):
return np.linalg.qr(a)
return np.linalg.qr(a)
@ray.remote
def cond(x):
return np.linalg.cond(x)
return np.linalg.cond(x)
@ray.remote
def matrix_rank(M):
return np.linalg.matrix_rank(M)
return np.linalg.matrix_rank(M)
@ray.remote
def multi_dot(*a):
raise NotImplementedError
raise NotImplementedError
@@ -8,4 +8,4 @@ import ray
@ray.remote
def normal(shape):
return np.random.normal(size=shape)
return np.random.normal(size=shape)
+455 -439
View File
@@ -49,507 +49,523 @@ TASK_STATUS_MAPPING = {
class GlobalState(object):
"""A class used to interface with the Ray control state.
"""A class used to interface with the Ray control state.
Attributes:
redis_client: The redis client used to query the redis server.
"""
def __init__(self):
"""Create a GlobalState object."""
self.redis_client = None
def _check_connected(self):
"""Check that the object has been initialized before it is used.
Raises:
Exception: An exception is raised if ray.init() has not been called yet.
Attributes:
redis_client: The redis client used to query the redis server.
"""
if self.redis_client is None:
raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.")
def __init__(self):
"""Create a GlobalState object."""
self.redis_client = None
def _initialize_global_state(self, redis_ip_address, redis_port):
"""Initialize the GlobalState object by connecting to Redis.
def _check_connected(self):
"""Check that the object has been initialized before it is used.
Args:
redis_ip_address: The IP address of the node that the Redis server lives
on.
redis_port: The port that the Redis server is listening on.
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
Raises:
Exception: An exception is raised if ray.init() has not been called
yet.
"""
if self.redis_client is None:
raise Exception("The ray.global_state API cannot be used before "
"ray.init has been called.")
ip_address_ports = self.redis_client.lrange("RedisShards", start=0, end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards, len(ip_address_ports)))
def _initialize_global_state(self, redis_ip_address, redis_port):
"""Initialize the GlobalState object by connecting to Redis.
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
Args:
redis_ip_address: The IP address of the node that the Redis server
lives on.
redis_port: The port that the Redis server is listening on.
"""
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.redis_clients = []
num_redis_shards = self.redis_client.get("NumRedisShards")
if num_redis_shards is None:
raise Exception("No entry found for NumRedisShards")
num_redis_shards = int(num_redis_shards)
if (num_redis_shards < 1):
raise Exception("Expected at least one Redis shard, found "
"{}.".format(num_redis_shards))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
ip_address_ports = self.redis_client.lrange("RedisShards", start=0,
end=-1)
if len(ip_address_ports) != num_redis_shards:
raise Exception("Expected {} Redis shard addresses, found "
"{}".format(num_redis_shards,
len(ip_address_ports)))
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(redis.StrictRedis(host=shard_address,
port=shard_port))
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
def _keys(self, pattern):
"""Execute the KEYS command on all Redis shards.
Args:
key: The object ID or the task ID that the query is about.
args: The command to run.
Args:
pattern: The KEYS pattern to query.
Returns:
The value returned by the Redis command.
"""
client = self.redis_clients[key.redis_shard_hash() %
len(self.redis_clients)]
return client.execute_command(*args)
Returns:
The concatenated list of results from all shards.
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
return result
def _keys(self, pattern):
"""Execute the KEYS command on all Redis shards.
def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
Args:
pattern: The KEYS pattern to query.
Args:
object_id_binary: A string of bytes with the object ID to get information
about.
Returns:
The concatenated list of results from all shards.
"""
result = []
for client in self.redis_clients:
result.extend(client.keys(pattern))
return result
Returns:
A dictionary with information about the object ID in question.
"""
# Allow the argument to be either an ObjectID or a hex string.
if not isinstance(object_id, ray.local_scheduler.ObjectID):
object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id))
def _object_table(self, object_id):
"""Fetch and parse the object table information for a single object ID.
# Return information about a single object ID.
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None
Args:
object_id_binary: A string of bytes with the object ID to get
information about.
result_table_response = self._execute_command(object_id,
"RAY.RESULT_TABLE_LOOKUP",
object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
Returns:
A dictionary with information about the object ID in question.
"""
# Allow the argument to be either an ObjectID or a hex string.
if not isinstance(object_id, ray.local_scheduler.ObjectID):
object_id = ray.local_scheduler.ObjectID(hex_to_binary(object_id))
result = {"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())}
# Return information about a single object ID.
object_locations = self._execute_command(object_id,
"RAY.OBJECT_TABLE_LOOKUP",
object_id.id())
if object_locations is not None:
manager_ids = [binary_to_hex(manager_id)
for manager_id in object_locations]
else:
manager_ids = None
return result
result_table_response = self._execute_command(
object_id, "RAY.RESULT_TABLE_LOOKUP", object_id.id())
result_table_message = ResultTableReply.GetRootAsResultTableReply(
result_table_response, 0)
def object_table(self, object_id=None):
"""Fetch and parse the object table information for one or more object IDs.
result = {"ManagerIDs": manager_ids,
"TaskID": binary_to_hex(result_table_message.TaskId()),
"IsPut": bool(result_table_message.IsPut()),
"DataSize": result_table_message.DataSize(),
"Hash": binary_to_hex(result_table_message.Hash())}
Args:
object_id: An object ID to fetch information about. If this is None, then
the entire object table is fetched.
return result
def object_table(self, object_id=None):
"""Fetch and parse the object table info for one or more object IDs.
Args:
object_id: An object ID to fetch information about. If this is
None, then the entire object table is fetched.
Returns:
Information from the object table.
"""
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):] for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = self._object_table(
binary_to_object_id(object_id_binary))
return results
Returns:
Information from the object table.
"""
self._check_connected()
if object_id is not None:
# Return information about a single object ID.
return self._object_table(object_id)
else:
# Return the entire object table.
object_info_keys = self._keys(OBJECT_INFO_PREFIX + "*")
object_location_keys = self._keys(OBJECT_LOCATION_PREFIX + "*")
object_ids_binary = set(
[key[len(OBJECT_INFO_PREFIX):] for key in object_info_keys] +
[key[len(OBJECT_LOCATION_PREFIX):]
for key in object_location_keys])
results = {}
for object_id_binary in object_ids_binary:
results[binary_to_object_id(object_id_binary)] = (
self._object_table(binary_to_object_id(object_id_binary)))
return results
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single object task ID.
def _task_table(self, task_id):
"""Fetch and parse the task table information for a single task ID.
Args:
task_id_binary: A string of bytes with the task ID to get information
about.
Args:
task_id_binary: A string of bytes with the task ID to get
information about.
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field into a
human-readable string.
"""
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task table."
.format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response, 0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
args = []
for i in range(task_spec_message.ArgsLength()):
arg = task_spec_message.Args(i)
if len(arg.ObjectId()) != 0:
args.append(binary_to_object_id(arg.ObjectId()))
else:
args.append(pickle.loads(arg.Data()))
assert task_spec_message.RequiredResourcesLength() == 2
required_resources = {"CPUs": task_spec_message.RequiredResources(0),
"GPUs": task_spec_message.RequiredResources(1)}
task_spec_info = {
"DriverID": binary_to_hex(task_spec_message.DriverId()),
"TaskID": binary_to_hex(task_spec_message.TaskId()),
"ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
"ParentCounter": task_spec_message.ParentCounter(),
"ActorID": binary_to_hex(task_spec_message.ActorId()),
"ActorCounter": task_spec_message.ActorCounter(),
"FunctionID": binary_to_hex(task_spec_message.FunctionId()),
"Args": args,
"ReturnObjectIDs": [binary_to_object_id(task_spec_message.Returns(i))
for i in range(task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}
Returns:
A dictionary with information about the task ID in question.
TASK_STATUS_MAPPING should be used to parse the "State" field
into a human-readable string.
"""
task_table_response = self._execute_command(task_id,
"RAY.TASK_TABLE_GET",
task_id.id())
if task_table_response is None:
raise Exception("There is no entry for task ID {} in the task "
"table.".format(binary_to_hex(task_id.id())))
task_table_message = TaskReply.GetRootAsTaskReply(task_table_response,
0)
task_spec = task_table_message.TaskSpec()
task_spec_message = TaskInfo.GetRootAsTaskInfo(task_spec, 0)
args = []
for i in range(task_spec_message.ArgsLength()):
arg = task_spec_message.Args(i)
if len(arg.ObjectId()) != 0:
args.append(binary_to_object_id(arg.ObjectId()))
else:
args.append(pickle.loads(arg.Data()))
assert task_spec_message.RequiredResourcesLength() == 2
required_resources = {"CPUs": task_spec_message.RequiredResources(0),
"GPUs": task_spec_message.RequiredResources(1)}
task_spec_info = {
"DriverID": binary_to_hex(task_spec_message.DriverId()),
"TaskID": binary_to_hex(task_spec_message.TaskId()),
"ParentTaskID": binary_to_hex(task_spec_message.ParentTaskId()),
"ParentCounter": task_spec_message.ParentCounter(),
"ActorID": binary_to_hex(task_spec_message.ActorId()),
"ActorCounter": task_spec_message.ActorCounter(),
"FunctionID": binary_to_hex(task_spec_message.FunctionId()),
"Args": args,
"ReturnObjectIDs": [binary_to_object_id(
task_spec_message.Returns(i))
for i in range(
task_spec_message.ReturnsLength())],
"RequiredResources": required_resources}
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
return {"State": task_table_message.State(),
"LocalSchedulerID": binary_to_hex(
task_table_message.LocalSchedulerId()),
"TaskSpec": task_spec_info}
def task_table(self, task_id=None):
"""Fetch and parse the task table information for one or more task IDs.
def task_table(self, task_id=None):
"""Fetch and parse the task table information for one or more task IDs.
Args:
task_id: A hex string of the task ID to fetch information about. If this
is None, then the task object table is fetched.
Args:
task_id: A hex string of the task ID to fetch information about. If
this is None, then the task object table is fetched.
Returns:
Information from the task table.
"""
self._check_connected()
if task_id is not None:
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
ray.local_scheduler.ObjectID(task_id_binary))
return results
Returns:
Information from the task table.
"""
self._check_connected()
if task_id is not None:
task_id = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
return self._task_table(task_id)
else:
task_table_keys = self._keys(TASK_PREFIX + "*")
results = {}
for key in task_table_keys:
task_id_binary = key[len(TASK_PREFIX):]
results[binary_to_hex(task_id_binary)] = self._task_table(
ray.local_scheduler.ObjectID(task_id_binary))
return results
def function_table(self, function_id=None):
"""Fetch and parse the function table.
def function_table(self, function_id=None):
"""Fetch and parse the function table.
Returns:
A dictionary that maps function IDs to information about the function.
"""
self._check_connected()
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
results = {}
for key in function_table_keys:
info = self.redis_client.hgetall(key)
function_info_parsed = {
"DriverID": binary_to_hex(info[b"driver_id"]),
"Module": decode(info[b"module"]),
"Name": decode(info[b"name"])
}
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
return results
Returns:
A dictionary that maps function IDs to information about the
function.
"""
self._check_connected()
function_table_keys = self.redis_client.keys(FUNCTION_PREFIX + "*")
results = {}
for key in function_table_keys:
info = self.redis_client.hgetall(key)
function_info_parsed = {
"DriverID": binary_to_hex(info[b"driver_id"]),
"Module": decode(info[b"module"]),
"Name": decode(info[b"name"])}
results[binary_to_hex(info[b"function_id"])] = function_info_parsed
return results
def client_table(self):
"""Fetch and parse the Redis DB client table.
def client_table(self):
"""Fetch and parse the Redis DB client table.
Returns:
Information about the Ray clients in the cluster.
"""
self._check_connected()
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
node_info = dict()
for key in db_client_keys:
client_info = self.redis_client.hgetall(key)
node_ip_address = decode(client_info[b"node_ip_address"])
if node_ip_address not in node_info:
node_info[node_ip_address] = []
client_info_parsed = {
"ClientType": decode(client_info[b"client_type"]),
"Deleted": bool(int(decode(client_info[b"deleted"]))),
"DBClientID": binary_to_hex(client_info[b"ray_client_id"])
}
if b"aux_address" in client_info:
client_info_parsed["AuxAddress"] = decode(client_info[b"aux_address"])
if b"num_cpus" in client_info:
client_info_parsed["NumCPUs"] = float(decode(client_info[b"num_cpus"]))
if b"num_gpus" in client_info:
client_info_parsed["NumGPUs"] = float(decode(client_info[b"num_gpus"]))
if b"local_scheduler_socket_name" in client_info:
client_info_parsed["LocalSchedulerSocketName"] = decode(
client_info[b"local_scheduler_socket_name"])
node_info[node_ip_address].append(client_info_parsed)
Returns:
Information about the Ray clients in the cluster.
"""
self._check_connected()
db_client_keys = self.redis_client.keys(DB_CLIENT_PREFIX + "*")
node_info = dict()
for key in db_client_keys:
client_info = self.redis_client.hgetall(key)
node_ip_address = decode(client_info[b"node_ip_address"])
if node_ip_address not in node_info:
node_info[node_ip_address] = []
client_info_parsed = {
"ClientType": decode(client_info[b"client_type"]),
"Deleted": bool(int(decode(client_info[b"deleted"]))),
"DBClientID": binary_to_hex(client_info[b"ray_client_id"])
}
if b"aux_address" in client_info:
client_info_parsed["AuxAddress"] = decode(
client_info[b"aux_address"])
if b"num_cpus" in client_info:
client_info_parsed["NumCPUs"] = float(
decode(client_info[b"num_cpus"]))
if b"num_gpus" in client_info:
client_info_parsed["NumGPUs"] = float(
decode(client_info[b"num_gpus"]))
if b"local_scheduler_socket_name" in client_info:
client_info_parsed["LocalSchedulerSocketName"] = decode(
client_info[b"local_scheduler_socket_name"])
node_info[node_ip_address].append(client_info_parsed)
return node_info
return node_info
def log_files(self):
"""Fetch and return a dictionary of log file names to outputs.
def log_files(self):
"""Fetch and return a dictionary of log file names to outputs.
Returns:
IP address to log file name to log file contents mappings.
"""
relevant_files = self.redis_client.keys("LOGFILE*")
Returns:
IP address to log file name to log file contents mappings.
"""
relevant_files = self.redis_client.keys("LOGFILE*")
ip_filename_file = dict()
ip_filename_file = dict()
for filename in relevant_files:
filename = filename.decode("ascii")
filename_components = filename.split(":")
ip_addr = filename_components[1]
for filename in relevant_files:
filename = filename.decode("ascii")
filename_components = filename.split(":")
ip_addr = filename_components[1]
file = self.redis_client.lrange(filename, 0, -1)
file_str = []
for x in file:
y = x.decode("ascii")
file_str.append(y)
file = self.redis_client.lrange(filename, 0, -1)
file_str = []
for x in file:
y = x.decode("ascii")
file_str.append(y)
if ip_addr not in ip_filename_file:
ip_filename_file[ip_addr] = dict()
if ip_addr not in ip_filename_file:
ip_filename_file[ip_addr] = dict()
ip_filename_file[ip_addr][filename] = file_str
ip_filename_file[ip_addr][filename] = file_str
return ip_filename_file
return ip_filename_file
def task_profiles(self, start=None, end=None, num=None):
"""Fetch and return a list of task profiles.
def task_profiles(self, start=None, end=None, num=None):
"""Fetch and return a list of task profiles.
Args:
start: The start point of the time window that is queried for tasks.
end: The end point in time of the time window that is queried for tasks.
num: A limit on the number of tasks that task_profiles will return.
Args:
start: The start point of the time window that is queried for
tasks.
end: The end point in time of the time window that is queried for
tasks.
num: A limit on the number of tasks that task_profiles will return.
Returns:
A tuple of two elements. The first element is a dictionary mapping the
task ID of a task to a list of the profiling information for all of the
executions of that task. The second element is a list of profiling
information for tasks where the events have no task ID.
"""
if start is None:
start = 0
if num is None:
num = sys.maxsize
Returns:
A tuple of two elements. The first element is a dictionary mapping
the task ID of a task to a list of the profiling information
for all of the executions of that task. The second element is a
list of profiling information for tasks where the events have
no task ID.
"""
if start is None:
start = 0
if num is None:
num = sys.maxsize
task_info = dict()
event_log_sets = self.redis_client.keys("event_log*")
task_info = dict()
event_log_sets = self.redis_client.keys("event_log*")
# The heap is used to maintain the set of x tasks that occurred the most
# recently across all of the workers, where x is defined as the function
# parameter num. The key is the start time of the "get_task" component of
# each task. Calling heappop will result in the taks with the earliest
# "get_task_start" to be removed from the heap.
# The heap is used to maintain the set of x tasks that occurred the
# most recently across all of the workers, where x is defined as the
# function parameter num. The key is the start time of the "get_task"
# component of each task. Calling heappop will result in the taks with
# the earliest "get_task_start" to be removed from the heap.
heap = []
heapq.heapify(heap)
heap_size = 0
# Parse through event logs to determine task start and end points.
for i in range(len(event_log_sets)):
event_list = self.redis_client.zrangebyscore(event_log_sets[i],
min=start,
max=end,
start=start,
num=num)
for event in event_list:
event_dict = json.loads(event)
task_id = ""
for event in event_dict:
if "task_id" in event[3]:
task_id = event[3]["task_id"]
task_info[task_id] = dict()
for event in event_dict:
if event[1] == "ray:get_task" and event[2] == 1:
task_info[task_id]["get_task_start"] = event[0]
# Add task to min heap by its start point.
heapq.heappush(heap,
(task_info[task_id]["get_task_start"], task_id))
heap_size += 1
if event[1] == "ray:get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if event[1] == "ray:import_remote_function" and event[2] == 1:
task_info[task_id]["import_remote_start"] = event[0]
if event[1] == "ray:import_remote_function" and event[2] == 2:
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 2:
task_info[task_id]["acquire_lock_end"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 1:
task_info[task_id]["get_arguments_start"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 2:
task_info[task_id]["get_arguments_end"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 1:
task_info[task_id]["execute_start"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 2:
task_info[task_id]["execute_end"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 1:
task_info[task_id]["store_outputs_start"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 2:
task_info[task_id]["store_outputs_end"] = event[0]
if "worker_id" in event[3]:
task_info[task_id]["worker_id"] = event[3]["worker_id"]
if "function_name" in event[3]:
task_info[task_id]["function_name"] = event[3]["function_name"]
if heap_size > num:
min_task, task_id_hex = heapq.heappop(heap)
del task_info[task_id_hex]
heap_size -= 1
return task_info
heap = []
heapq.heapify(heap)
heap_size = 0
# Parse through event logs to determine task start and end points.
for i in range(len(event_log_sets)):
event_list = self.redis_client.zrangebyscore(event_log_sets[i],
min=start,
max=end,
start=start,
num=num)
for event in event_list:
event_dict = json.loads(event)
task_id = ""
for event in event_dict:
if "task_id" in event[3]:
task_id = event[3]["task_id"]
task_info[task_id] = dict()
for event in event_dict:
if event[1] == "ray:get_task" and event[2] == 1:
task_info[task_id]["get_task_start"] = event[0]
# Add task to min heap by its start point.
heapq.heappush(heap,
(task_info[task_id]["get_task_start"],
task_id))
heap_size += 1
if event[1] == "ray:get_task" and event[2] == 2:
task_info[task_id]["get_task_end"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 1):
task_info[task_id]["import_remote_start"] = event[0]
if (event[1] == "ray:import_remote_function" and
event[2] == 2):
task_info[task_id]["import_remote_end"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 1:
task_info[task_id]["acquire_lock_start"] = event[0]
if event[1] == "ray:acquire_lock" and event[2] == 2:
task_info[task_id]["acquire_lock_end"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 1:
task_info[task_id]["get_arguments_start"] = event[0]
if event[1] == "ray:task:get_arguments" and event[2] == 2:
task_info[task_id]["get_arguments_end"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 1:
task_info[task_id]["execute_start"] = event[0]
if event[1] == "ray:task:execute" and event[2] == 2:
task_info[task_id]["execute_end"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 1:
task_info[task_id]["store_outputs_start"] = event[0]
if event[1] == "ray:task:store_outputs" and event[2] == 2:
task_info[task_id]["store_outputs_end"] = event[0]
if "worker_id" in event[3]:
task_info[task_id]["worker_id"] = event[3]["worker_id"]
if "function_name" in event[3]:
task_info[task_id]["function_name"] = (
event[3]["function_name"])
if heap_size > num:
min_task, task_id_hex = heapq.heappop(heap)
del task_info[task_id_hex]
heap_size -= 1
return task_info
def dump_catapult_trace(self, path, start=None, end=None, num=None):
"""Dump task profiling information to a file.
def dump_catapult_trace(self, path, start=None, end=None, num=None):
"""Dump task profiling information to a file.
This information can be viewed as a timeline of profiling information by
going to chrome://tracing in the chrome web browser and loading the
appropriate file.
This information can be viewed as a timeline of profiling information
by going to chrome://tracing in the chrome web browser and loading the
appropriate file.
Args:
path: The filepath to dump the profiling information to.
"""
if end is None:
end = time.time()
task_info = self.task_profiles(start=start, end=end, num=num)
workers = self.workers()
start_time = None
for info in task_info.values():
task_start = min(self._get_times(info))
if not start_time or task_start < start_time:
start_time = task_start
Args:
path: The filepath to dump the profiling information to.
"""
if end is None:
end = time.time()
task_info = self.task_profiles(start=start, end=end, num=num)
workers = self.workers()
start_time = None
for info in task_info.values():
task_start = min(self._get_times(info))
if not start_time or task_start < start_time:
start_time = task_start
def micros(ts):
return int(1e6 * (ts - start_time))
def micros(ts):
return int(1e6 * (ts - start_time))
full_trace = []
for task_id, info in task_info.items():
task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
task_data = self._task_table(task_id_hex)
parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"])
times = self._get_times(info)
worker = workers[info["worker_id"]]
if parent_info:
parent_worker = workers[parent_info["worker_id"]]
parent_times = self._get_times(parent_info)
parent_trace = {
"cat": "submit_task",
"pid": "Node " + str(parent_worker["node_ip_address"]),
"tid": parent_info["worker_id"],
"ts": micros(min(parent_times)),
"ph": "s",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(parent_trace)
full_trace = []
for task_id, info in task_info.items():
task_id_hex = ray.local_scheduler.ObjectID(hex_to_binary(task_id))
task_data = self._task_table(task_id_hex)
parent_info = task_info.get(task_data["TaskSpec"]["ParentTaskID"])
times = self._get_times(info)
worker = workers[info["worker_id"]]
if parent_info:
parent_worker = workers[parent_info["worker_id"]]
parent_times = self._get_times(parent_info)
parent_trace = {
"cat": "submit_task",
"pid": "Node " + str(parent_worker["node_ip_address"]),
"tid": parent_info["worker_id"],
"ts": micros(min(parent_times)),
"ph": "s",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(parent_trace)
parent = {
"cat": "submit_task",
"pid": "Node " + str(parent_worker["node_ip_address"]),
"tid": parent_info["worker_id"],
"ts": micros(min(parent_times)),
"ph": "s",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(parent)
parent = {
"cat": "submit_task",
"pid": "Node " + str(parent_worker["node_ip_address"]),
"tid": parent_info["worker_id"],
"ts": micros(min(parent_times)),
"ph": "s",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(parent)
task_trace = {
"cat": "submit_task",
"pid": "Node " + str(worker["node_ip_address"]),
"tid": info["worker_id"],
"ts": micros(min(times)),
"ph": "f",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(task_trace)
task_trace = {
"cat": "submit_task",
"pid": "Node " + str(worker["node_ip_address"]),
"tid": info["worker_id"],
"ts": micros(min(times)),
"ph": "f",
"name": "SubmitTask",
"args": {},
"id": str(worker)
}
full_trace.append(task_trace)
task = {
"name": info["function_name"],
"cat": "ray_task",
"ph": "X",
"ts": micros(min(times)),
"dur": micros(max(times)) - micros(min(times)),
"pid": "Node " + str(worker["node_ip_address"]),
"tid": info["worker_id"],
"args": info
}
full_trace.append(task)
task = {
"name": info["function_name"],
"cat": "ray_task",
"ph": "X",
"ts": micros(min(times)),
"dur": micros(max(times)) - micros(min(times)),
"pid": "Node " + str(worker["node_ip_address"]),
"tid": info["worker_id"],
"args": info
}
full_trace.append(task)
with open(path, "w") as outfile:
json.dump(full_trace, outfile)
with open(path, "w") as outfile:
json.dump(full_trace, outfile)
def _get_times(self, data):
"""Extract the numerical times from a task profile.
def _get_times(self, data):
"""Extract the numerical times from a task profile.
This is a helper method for dump_catapult_trace.
This is a helper method for dump_catapult_trace.
Args:
data: This must be a value in the dictionary returned by the
task_profiles function.
"""
all_times = []
all_times.append(data["acquire_lock_start"])
all_times.append(data["acquire_lock_end"])
all_times.append(data["get_arguments_start"])
all_times.append(data["get_arguments_end"])
all_times.append(data["execute_start"])
all_times.append(data["execute_end"])
all_times.append(data["store_outputs_start"])
all_times.append(data["store_outputs_end"])
return all_times
Args:
data: This must be a value in the dictionary returned by the
task_profiles function.
"""
all_times = []
all_times.append(data["acquire_lock_start"])
all_times.append(data["acquire_lock_end"])
all_times.append(data["get_arguments_start"])
all_times.append(data["get_arguments_end"])
all_times.append(data["execute_start"])
all_times.append(data["execute_end"])
all_times.append(data["store_outputs_start"])
all_times.append(data["store_outputs_end"])
return all_times
def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
worker_keys = self.redis_client.keys("Worker*")
workers_data = dict()
def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
worker_keys = self.redis_client.keys("Worker*")
workers_data = dict()
for worker_key in worker_keys:
worker_info = self.redis_client.hgetall(worker_key)
worker_id = binary_to_hex(worker_key[len("Workers:"):])
for worker_key in worker_keys:
worker_info = self.redis_client.hgetall(worker_key)
worker_id = binary_to_hex(worker_key[len("Workers:"):])
workers_data[worker_id] = {
"local_scheduler_socket": (worker_info[b"local_scheduler_socket"]
.decode("ascii")),
"node_ip_address": (worker_info[b"node_ip_address"]
.decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
.decode("ascii")),
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
.decode("ascii")),
"stderr_file": worker_info[b"stderr_file"].decode("ascii"),
"stdout_file": worker_info[b"stdout_file"].decode("ascii")
}
return workers_data
workers_data[worker_id] = {
"local_scheduler_socket": (
worker_info[b"local_scheduler_socket"].decode("ascii")),
"node_ip_address": (
worker_info[b"node_ip_address"].decode("ascii")),
"plasma_manager_socket": (worker_info[b"plasma_manager_socket"]
.decode("ascii")),
"plasma_store_socket": (worker_info[b"plasma_store_socket"]
.decode("ascii")),
"stderr_file": worker_info[b"stderr_file"].decode("ascii"),
"stdout_file": worker_info[b"stdout_file"].decode("ascii")
}
return workers_data
+98 -96
View File
@@ -6,114 +6,116 @@ from collections import deque, OrderedDict
def unflatten(vector, shapes):
i = 0
arrays = []
for shape in shapes:
size = np.prod(shape)
array = vector[i:(i + size)].reshape(shape)
arrays.append(array)
i += size
assert len(vector) == i, "Passed weight does not have the correct shape."
return arrays
i = 0
arrays = []
for shape in shapes:
size = np.prod(shape)
array = vector[i:(i + size)].reshape(shape)
arrays.append(array)
i += size
assert len(vector) == i, "Passed weight does not have the correct shape."
return arrays
class TensorFlowVariables(object):
"""An object used to extract variables from a loss function.
"""An object used to extract variables from a loss function.
This object also provides methods for getting and setting the weights of the
relevant variables.
This object also provides methods for getting and setting the weights of
the relevant variables.
Attributes:
sess (tf.Session): The tensorflow session used to run assignment.
loss: The loss function passed in by the user.
variables (List[tf.Variable]): Extracted variables from the loss.
assignment_placeholders (List[tf.placeholders]): The nodes that weights get
passed to.
assignment_nodes (List[tf.Tensor]): The nodes that assign the weights.
"""
def __init__(self, loss, sess=None):
"""Creates a TensorFlowVariables instance."""
import tensorflow as tf
self.sess = sess
self.loss = loss
queue = deque([loss])
variable_names = []
explored_inputs = set([loss])
Attributes:
sess (tf.Session): The tensorflow session used to run assignment.
loss: The loss function passed in by the user.
variables (List[tf.Variable]): Extracted variables from the loss.
assignment_placeholders (List[tf.placeholders]): The nodes that weights
get passed to.
assignment _nodes (List[tf.Tensor]): The nodes that assign the weights.
"""
def __init__(self, loss, sess=None):
"""Creates a TensorFlowVariables instance."""
import tensorflow as tf
self.sess = sess
self.loss = loss
queue = deque([loss])
variable_names = []
explored_inputs = set([loss])
# We do a BFS on the dependency graph of the input function to find
# the variables.
while len(queue) != 0:
tf_obj = queue.popleft()
# We do a BFS on the dependency graph of the input function to find
# the variables.
while len(queue) != 0:
tf_obj = queue.popleft()
# The object put into the queue is not necessarily an operation, so we
# want the op attribute to get the operation underlying the object.
# Only operations contain the inputs that we can explore.
if hasattr(tf_obj, "op"):
tf_obj = tf_obj.op
for input_op in tf_obj.inputs:
if input_op not in explored_inputs:
queue.append(input_op)
explored_inputs.add(input_op)
# Tensorflow control inputs can be circular, so we keep track of
# explored operations.
for control in tf_obj.control_inputs:
if control not in explored_inputs:
queue.append(control)
explored_inputs.add(control)
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]:
self.variables[v.op.node_def.name] = v
self.placeholders = dict()
self.assignment_nodes = []
# The object put into the queue is not necessarily an operation, so
# we want the op attribute to get the operation underlying the
# object. Only operations contain the inputs that we can explore.
if hasattr(tf_obj, "op"):
tf_obj = tf_obj.op
for input_op in tf_obj.inputs:
if input_op not in explored_inputs:
queue.append(input_op)
explored_inputs.add(input_op)
# Tensorflow control inputs can be circular, so we keep track of
# explored operations.
for control in tf_obj.control_inputs:
if control not in explored_inputs:
queue.append(control)
explored_inputs.add(control)
if "Variable" in tf_obj.node_def.op:
variable_names.append(tf_obj.node_def.name)
self.variables = OrderedDict()
for v in [v for v in tf.global_variables()
if v.op.node_def.name in variable_names]:
self.variables[v.op.node_def.name] = v
self.placeholders = dict()
self.assignment_nodes = []
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list())
self.assignment_nodes.append(var.assign(self.placeholders[k]))
# Create new placeholders to put in custom weights.
for k, var in self.variables.items():
self.placeholders[k] = tf.placeholder(var.value().dtype,
var.get_shape().as_list())
self.assignment_nodes.append(var.assign(self.placeholders[k]))
def set_session(self, sess):
"""Modifies the current session used by the class."""
self.sess = sess
def set_session(self, sess):
"""Modifies the current session used by the class."""
self.sess = sess
def get_flat_size(self):
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
def get_flat_size(self):
return sum([np.prod(v.get_shape().as_list())
for v in self.variables.values()])
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, ("The session is not set. Set the session "
"either by passing it into the "
"TensorFlowVariables constructor or by "
"calling set_session(sess).")
def _check_sess(self):
"""Checks if the session is set, and if not throw an error message."""
assert self.sess is not None, ("The session is not set. Set the "
"session either by passing it into the "
"TensorFlowVariables constructor or by "
"calling set_session(sess).")
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
def get_flat(self):
"""Gets the weights and returns them as a flat array."""
self._check_sess()
return np.concatenate([v.eval(session=self.sess).flatten()
for v in self.variables.values()])
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k] for k, v in self.variables.items()]
self.sess.run(self.assignment_nodes,
feed_dict=dict(zip(placeholders, arrays)))
def set_flat(self, new_weights):
"""Sets the weights to new_weights, converting from a flat array."""
self._check_sess()
shapes = [v.get_shape().as_list() for v in self.variables.values()]
arrays = unflatten(new_weights, shapes)
placeholders = [self.placeholders[k]
for k, v in self.variables.items()]
self.sess.run(self.assignment_nodes,
feed_dict=dict(zip(placeholders, arrays)))
def get_weights(self):
"""Returns the weights of the variables of the loss function in a list."""
self._check_sess()
return {k: v.eval(session=self.sess) for k, v in self.variables.items()}
def get_weights(self):
"""Returns a list of the weights of the loss function variables."""
self._check_sess()
return {k: v.eval(session=self.sess)
for k, v in self.variables.items()}
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
self._check_sess()
self.sess.run(self.assignment_nodes,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
def set_weights(self, new_weights):
"""Sets the weights to new_weights."""
self._check_sess()
self.sess.run(self.assignment_nodes,
feed_dict={self.placeholders[name]: value
for (name, value) in new_weights.items()
if name in self.placeholders})
+56 -54
View File
@@ -11,70 +11,72 @@ import ray
def tarred_directory_as_bytes(source_dir):
"""Tar a directory and return it as a byte string.
"""Tar a directory and return it as a byte string.
Args:
source_dir (str): The name of the directory to tar.
Args:
source_dir (str): The name of the directory to tar.
Returns:
A byte string representing the tarred file.
"""
# Get a BytesIO object.
string_file = io.BytesIO()
# Create an in-memory tarfile of the source directory.
with tarfile.open(mode="w:gz", fileobj=string_file) as tar:
tar.add(source_dir, arcname=os.path.basename(source_dir))
string_file.seek(0)
return string_file.read()
Returns:
A byte string representing the tarred file.
"""
# Get a BytesIO object.
string_file = io.BytesIO()
# Create an in-memory tarfile of the source directory.
with tarfile.open(mode="w:gz", fileobj=string_file) as tar:
tar.add(source_dir, arcname=os.path.basename(source_dir))
string_file.seek(0)
return string_file.read()
def tarred_bytes_to_directory(tarred_bytes, target_dir):
"""Take a byte string and untar it.
"""Take a byte string and untar it.
Args:
tarred_bytes (str): A byte string representing the tarred file. This should
be the output of tarred_directory_as_bytes.
target_dir (str): The directory to create the untarred files in.
"""
string_file = io.BytesIO(tarred_bytes)
with tarfile.open(fileobj=string_file) as tar:
tar.extractall(path=target_dir)
Args:
tarred_bytes (str): A byte string representing the tarred file. This
should be the output of tarred_directory_as_bytes.
target_dir (str): The directory to create the untarred files in.
"""
string_file = io.BytesIO(tarred_bytes)
with tarfile.open(fileobj=string_file) as tar:
tar.extractall(path=target_dir)
def copy_directory(source_dir, target_dir=None):
"""Copy a local directory to each machine in the Ray cluster.
"""Copy a local directory to each machine in the Ray cluster.
Note that both source_dir and target_dir must have the same basename). For
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this case,
the directory /d/e will be added to the Python path of each worker.
Note that both source_dir and target_dir must have the same basename). For
example, source_dir can be /a/b/c and target_dir can be /d/e/c. In this
case, the directory /d/e will be added to the Python path of each worker.
Note that this method is not completely safe to use. For example, workers
that do not do the copying and only set their paths (only one worker per node
does the copying) may try to execute functions that use the files in the
directory being copied before the directory being copied has finished
untarring.
Note that this method is not completely safe to use. For example, workers
that do not do the copying and only set their paths (only one worker per
node does the copying) may try to execute functions that use the files in
the directory being copied before the directory being copied has finished
untarring.
Args:
source_dir (str): The directory to copy.
target_dir (str): The location to copy it to on the other machines. If this
is not provided, the source_dir will be used. If it is provided and is
different from source_dir, the source_dir also be copied to the
target_dir location on this machine.
"""
target_dir = source_dir if target_dir is None else target_dir
source_dir = os.path.abspath(source_dir)
target_dir = os.path.abspath(target_dir)
source_basename = os.path.basename(source_dir)
target_basename = os.path.basename(target_dir)
if source_basename != target_basename:
raise Exception("The source_dir and target_dir must have the same base "
"name, {} != {}".format(source_basename, target_basename))
tarred_bytes = tarred_directory_as_bytes(source_dir)
Args:
source_dir (str): The directory to copy.
target_dir (str): The location to copy it to on the other machines. If
this is not provided, the source_dir will be used. If it is
provided and is different from source_dir, the source_dir also be
copied to the target_dir location on this machine.
"""
target_dir = source_dir if target_dir is None else target_dir
source_dir = os.path.abspath(source_dir)
target_dir = os.path.abspath(target_dir)
source_basename = os.path.basename(source_dir)
target_basename = os.path.basename(target_dir)
if source_basename != target_basename:
raise Exception("The source_dir and target_dir must have the same "
"base name, {} != {}".format(source_basename,
target_basename))
tarred_bytes = tarred_directory_as_bytes(source_dir)
def f(worker_info):
if worker_info["counter"] == 0:
tarred_bytes_to_directory(tarred_bytes, os.path.dirname(target_dir))
sys.path.append(os.path.dirname(target_dir))
# Run this function on all workers to copy the directory to all nodes and to
# add the directory to the Python path of each worker.
ray.worker.global_worker.run_function_on_all_workers(f)
def f(worker_info):
if worker_info["counter"] == 0:
tarred_bytes_to_directory(tarred_bytes,
os.path.dirname(target_dir))
sys.path.append(os.path.dirname(target_dir))
# Run this function on all workers to copy the directory to all nodes and
# to add the directory to the Python path of each worker.
ray.worker.global_worker.run_function_on_all_workers(f)
@@ -10,45 +10,45 @@ import time
def start_global_scheduler(redis_address, node_ip_address,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
"""Start a global scheduler process.
"""Start a global scheduler process.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will run
on.
use_valgrind (bool): True if the global scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started inside
a profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
redirection should happen, then this should be None.
Args:
redis_address (str): The address of the Redis instance.
node_ip_address: The IP address of the node that this scheduler will
run on.
use_valgrind (bool): True if the global scheduler should be started
inside of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the global scheduler should be started
inside a profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If
no redirection should happen, then this should be None.
Return:
The process ID of the global scheduler process.
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return pid
Return:
The process ID of the global scheduler process.
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
global_scheduler_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/global_scheduler/global_scheduler")
command = [global_scheduler_executable,
"-r", redis_address,
"-h", node_ip_address]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return pid
+251 -241
View File
@@ -33,275 +33,285 @@ TASK_PREFIX = "TT:"
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def new_port():
return random.randint(10000, 65535)
return random.randint(10000, 65535)
class TestGlobalScheduler(unittest.TestCase):
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
self.node_ip_address = "127.0.0.1"
redis_address, redis_shards = services.start_redis(self.node_ip_address)
redis_port = services.get_port(redis_address)
time.sleep(0.1)
# Create a client for the global state store.
self.state = state.GlobalState()
self.state._initialize_global_state(self.node_ip_address, redis_port)
def setUp(self):
# Start one Redis server and N pairs of (plasma, local_scheduler)
self.node_ip_address = "127.0.0.1"
redis_address, redis_shards = services.start_redis(
self.node_ip_address)
redis_port = services.get_port(redis_address)
time.sleep(0.1)
# Create a client for the global state store.
self.state = state.GlobalState()
self.state._initialize_global_state(self.node_ip_address, redis_port)
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
self.plasma_clients = []
self.local_scheduler_clients = []
# Start one global scheduler.
self.p1 = global_scheduler.start_global_scheduler(
redis_address, self.node_ip_address, use_valgrind=USE_VALGRIND)
self.plasma_store_pids = []
self.plasma_manager_pids = []
self.local_scheduler_pids = []
self.plasma_clients = []
self.local_scheduler_clients = []
for i in range(NUM_CLUSTER_NODES):
# Start the Plasma store. Plasma store name is randomly generated.
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated by the
# plasma module.
manager_info = plasma.start_plasma_manager(plasma_store_name,
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(self.node_ip_address,
plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
plasma_store_name,
plasma_manager_name=plasma_manager_name,
plasma_address=plasma_address,
redis_address=redis_address,
static_resource_list=[10, 0])
# Connect to the scheduler.
local_scheduler_client = local_scheduler.LocalSchedulerClient(
local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
self.local_scheduler_clients.append(local_scheduler_client)
self.local_scheduler_pids.append(p4)
for i in range(NUM_CLUSTER_NODES):
# Start the Plasma store. Plasma store name is randomly generated.
plasma_store_name, p2 = plasma.start_plasma_store()
self.plasma_store_pids.append(p2)
# Start the Plasma manager.
# Assumption: Plasma manager name and port are randomly generated
# by the plasma module.
manager_info = plasma.start_plasma_manager(plasma_store_name,
redis_address)
plasma_manager_name, p3, plasma_manager_port = manager_info
self.plasma_manager_pids.append(p3)
plasma_address = "{}:{}".format(self.node_ip_address,
plasma_manager_port)
plasma_client = plasma.PlasmaClient(plasma_store_name,
plasma_manager_name)
self.plasma_clients.append(plasma_client)
# Start the local scheduler.
local_scheduler_name, p4 = local_scheduler.start_local_scheduler(
plasma_store_name,
plasma_manager_name=plasma_manager_name,
plasma_address=plasma_address,
redis_address=redis_address,
static_resource_list=[10, 0])
# Connect to the scheduler.
local_scheduler_client = local_scheduler.LocalSchedulerClient(
local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
self.local_scheduler_clients.append(local_scheduler_client)
self.local_scheduler_pids.append(p4)
def tearDown(self):
# Check that the processes are still alive.
self.assertEqual(self.p1.poll(), None)
for p2 in self.plasma_store_pids:
self.assertEqual(p2.poll(), None)
for p3 in self.plasma_manager_pids:
self.assertEqual(p3.poll(), None)
for p4 in self.local_scheduler_pids:
self.assertEqual(p4.poll(), None)
def tearDown(self):
# Check that the processes are still alive.
self.assertEqual(self.p1.poll(), None)
for p2 in self.plasma_store_pids:
self.assertEqual(p2.poll(), None)
for p3 in self.plasma_manager_pids:
self.assertEqual(p3.poll(), None)
for p4 in self.local_scheduler_pids:
self.assertEqual(p4.poll(), None)
redis_processes = services.all_processes[
services.PROCESS_TYPE_REDIS_SERVER]
for redis_process in redis_processes:
self.assertEqual(redis_process.poll(), None)
redis_processes = services.all_processes[
services.PROCESS_TYPE_REDIS_SERVER]
for redis_process in redis_processes:
self.assertEqual(redis_process.poll(), None)
# Kill the global scheduler.
if USE_VALGRIND:
self.p1.send_signal(signal.SIGTERM)
self.p1.wait()
if self.p1.returncode != 0:
os._exit(-1)
else:
self.p1.kill()
# Kill local schedulers, plasma managers, and plasma stores.
for p2 in self.local_scheduler_pids:
p2.kill()
for p3 in self.plasma_manager_pids:
p3.kill()
for p4 in self.plasma_store_pids:
p4.kill()
# Kill Redis. In the event that we are using valgrind, this needs to happen
# after we kill the global scheduler.
while redis_processes:
redis_process = redis_processes.pop()
redis_process.kill()
def get_plasma_manager_id(self):
"""Get the db_client_id with client_type equal to plasma_manager.
Iterates over all the client table keys, gets the db_client_id for the
client with client_type matching plasma_manager. Strips the client table
prefix. TODO(atumanov): write a separate function to get all plasma manager
client IDs.
Returns:
The db_client_id if one is found and otherwise None.
"""
db_client_id = None
client_list = self.state.client_table()[self.node_ip_address]
for client in client_list:
if client["ClientType"] == "plasma_manager":
db_client_id = client["DBClientID"]
break
return db_client_id
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0)
self.assertEqual(task1.required_resources(), [1.0, 0.0])
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(), 0,
local_scheduler.ObjectID(NIL_ACTOR_ID), 0,
[1.0, 2.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0])
def test_redis_only_single_task(self):
"""
Tests global scheduler functionality by interacting with Redis and checking
task state transitions in Redis only. TODO(atumanov): implement.
"""
# Check precondition for this test:
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Insert the object into Redis.
data_size = 0xf1f0
metadata_size = 0x40
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to local scheduler.
time.sleep(0.1)
# Submit a task to Redis.
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
time.sleep(0.1)
# There should now be a task in Redis, and it should get assigned to the
# local scheduler
num_retries = 10
while num_retries > 0:
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), 1)
if len(task_entries) == 1:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
if task_status == state.TASK_STATUS_QUEUED:
break
# Kill the global scheduler.
if USE_VALGRIND:
self.p1.send_signal(signal.SIGTERM)
self.p1.wait()
if self.p1.returncode != 0:
os._exit(-1)
else:
print(task_status)
print("The task has not been scheduled yet, trying again.")
num_retries -= 1
time.sleep(1)
self.p1.kill()
# Kill local schedulers, plasma managers, and plasma stores.
for p2 in self.local_scheduler_pids:
p2.kill()
for p3 in self.plasma_manager_pids:
p3.kill()
for p4 in self.plasma_store_pids:
p4.kill()
# Kill Redis. In the event that we are using valgrind, this needs to
# happen after we kill the global scheduler.
while redis_processes:
redis_process = redis_processes.pop()
redis_process.kill()
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
# Failed to submit and schedule a single task -- bail.
self.tearDown()
sys.exit(1)
def get_plasma_manager_id(self):
"""Get the db_client_id with client_type equal to plasma_manager.
def integration_many_tasks_helper(self, timesync=True):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
Iterates over all the client table keys, gets the db_client_id for the
client with client_type matching plasma_manager. Strips the client
table prefix. TODO(atumanov): write a separate function to get all
plasma manager client IDs.
# Submit a bunch of tasks to Redis.
num_tasks = 1000
for _ in range(num_tasks):
# Create a new object for each task.
data_size = np.random.randint(1 << 12)
metadata_size = np.random.randint(1 << 9)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that they
# all get assigned to the local scheduler.
num_retries = 10
num_tasks_done = 0
while num_retries > 0:
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(state.TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done, num_retries))
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
# We're done, so pass.
break
num_retries -= 1
time.sleep(0.1)
Returns:
The db_client_id if one is found and otherwise None.
"""
db_client_id = None
self.assertEqual(num_tasks_done, num_tasks)
client_list = self.state.client_table()[self.node_ip_address]
for client in client_list:
if client["ClientType"] == "plasma_manager":
db_client_id = client["DBClientID"]
break
def test_integration_many_tasks_handler_sync(self):
self.integration_many_tasks_helper(timesync=True)
return db_client_id
def test_integration_many_tasks(self):
# More realistic case: should handle out of order object and task
# notifications.
self.integration_many_tasks_helper(timesync=False)
def test_task_default_resources(self):
task1 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0)
self.assertEqual(task1.required_resources(), [1.0, 0.0])
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
0, [1.0, 2.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0])
def test_redis_only_single_task(self):
# Tests global scheduler functionality by interacting with Redis and
# checking task state transitions in Redis only. TODO(atumanov):
# implement.
# Check precondition for this test:
# There should be 2n+1 db clients: the global scheduler + one local
# scheduler and one plasma per node.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
db_client_id = self.get_plasma_manager_id()
assert(db_client_id is not None)
def test_integration_single_task(self):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Insert the object into Redis.
data_size = 0xf1f0
metadata_size = 0x40
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(
plasma_client, data_size, metadata_size, seal=True)
# Sleep before submitting task to local scheduler.
time.sleep(0.1)
# Submit a task to Redis.
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(), 0)
self.local_scheduler_clients[0].submit(task)
time.sleep(0.1)
# There should now be a task in Redis, and it should get assigned to
# the local scheduler
num_retries = 10
while num_retries > 0:
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), 1)
if len(task_entries) == 1:
task_id, task = task_entries.popitem()
task_status = task["State"]
self.assertTrue(task_status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED])
if task_status == state.TASK_STATUS_QUEUED:
break
else:
print(task_status)
print("The task has not been scheduled yet, trying again.")
num_retries -= 1
time.sleep(1)
if num_retries <= 0 and task_status != state.TASK_STATUS_QUEUED:
# Failed to submit and schedule a single task -- bail.
self.tearDown()
sys.exit(1)
def integration_many_tasks_helper(self, timesync=True):
# There should be three db clients, the global scheduler, the local
# scheduler, and the plasma manager.
self.assertEqual(
len(self.state.client_table()[self.node_ip_address]),
2 * NUM_CLUSTER_NODES + 1)
num_return_vals = [0, 1, 2, 3, 5, 10]
# Submit a bunch of tasks to Redis.
num_tasks = 1000
for _ in range(num_tasks):
# Create a new object for each task.
data_size = np.random.randint(1 << 12)
metadata_size = np.random.randint(1 << 9)
plasma_client = self.plasma_clients[0]
object_dep, memory_buffer, metadata = create_object(plasma_client,
data_size,
metadata_size,
seal=True)
if timesync:
# Give 10ms for object info handler to fire (long enough to
# yield CPU).
time.sleep(0.010)
task = local_scheduler.Task(random_driver_id(),
random_function_id(),
[local_scheduler.ObjectID(object_dep)],
num_return_vals[0], random_task_id(),
0)
self.local_scheduler_clients[0].submit(task)
# Check that there are the correct number of tasks in Redis and that
# they all get assigned to the local scheduler.
num_retries = 10
num_tasks_done = 0
while num_retries > 0:
task_entries = self.state.task_table()
self.assertLessEqual(len(task_entries), num_tasks)
# First, check if all tasks made it to Redis.
if len(task_entries) == num_tasks:
task_statuses = [task_entry["State"] for task_entry in
task_entries.values()]
self.assertTrue(all([status in [state.TASK_STATUS_WAITING,
state.TASK_STATUS_SCHEDULED,
state.TASK_STATUS_QUEUED]
for status in task_statuses]))
num_tasks_done = task_statuses.count(state.TASK_STATUS_QUEUED)
num_tasks_scheduled = task_statuses.count(
state.TASK_STATUS_SCHEDULED)
num_tasks_waiting = task_statuses.count(
state.TASK_STATUS_WAITING)
print("tasks in Redis = {}, tasks waiting = {}, "
"tasks scheduled = {}, "
"tasks queued = {}, retries left = {}"
.format(len(task_entries), num_tasks_waiting,
num_tasks_scheduled, num_tasks_done,
num_retries))
if all([status == state.TASK_STATUS_QUEUED for status in
task_statuses]):
# We're done, so pass.
break
num_retries -= 1
time.sleep(0.1)
self.assertEqual(num_tasks_done, num_tasks)
def test_integration_many_tasks_handler_sync(self):
self.integration_many_tasks_helper(timesync=True)
def test_integration_many_tasks(self):
# More realistic case: should handle out of order object and task
# notifications.
self.integration_many_tasks_helper(timesync=False)
if __name__ == "__main__":
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
print("Using valgrind for tests")
unittest.main(verbosity=2)
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument
# parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
print("Using valgrind for tests")
unittest.main(verbosity=2)
@@ -9,7 +9,7 @@ import time
def random_name():
return str(random.randint(0, 99999999))
return str(random.randint(0, 99999999))
def start_local_scheduler(plasma_store_name,
@@ -24,95 +24,99 @@ def start_local_scheduler(plasma_store_name,
stderr_file=None,
static_resource_list=None,
num_workers=0):
"""Start a local scheduler process.
"""Start a local scheduler process.
Args:
plasma_store_name (str): The name of the plasma store socket to connect to.
plasma_manager_name (str): The name of the plasma manager to connect to.
This does not need to be provided, but if it is, then the Redis address
must be provided as well.
worker_path (str): The path of the worker script to use when the local
scheduler starts up new workers.
plasma_address (str): The address of the plasma manager to connect to. This
is only used by the global scheduler to figure out which plasma managers
are connected to which local schedulers.
node_ip_address (str): The address of the node that this local scheduler is
running on.
redis_address (str): The address of the Redis instance to connect to. If
this is not provided, then the local scheduler will not connect to Redis.
use_valgrind (bool): True if the local scheduler should be started inside
of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the local scheduler should be started inside a
profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
redirection should happen, then this should be None.
static_resource_list (list): A list of integers specifying the local
scheduler's resource capacities. The resources should appear in an order
matching the order defined in task.h.
num_workers (int): The number of workers that the local scheduler should
start.
Args:
plasma_store_name (str): The name of the plasma store socket to connect
to.
plasma_manager_name (str): The name of the plasma manager to connect
to. This does not need to be provided, but if it is, then the Redis
address must be provided as well.
worker_path (str): The path of the worker script to use when the local
scheduler starts up new workers.
plasma_address (str): The address of the plasma manager to connect to.
This is only used by the global scheduler to figure out which
plasma managers are connected to which local schedulers.
node_ip_address (str): The address of the node that this local
scheduler is running on.
redis_address (str): The address of the Redis instance to connect to.
If this is not provided, then the local scheduler will not connect
to Redis.
use_valgrind (bool): True if the local scheduler should be started
inside of valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the local scheduler should be started
inside a profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If
no redirection should happen, then this should be None.
static_resource_list (list): A list of integers specifying the local
scheduler's resource capacities. The resources should appear in an
order matching the order defined in task.h.
num_workers (int): The number of workers that the local scheduler
should start.
Return:
A tuple of the name of the local scheduler socket and the process ID of the
local scheduler process.
"""
if (plasma_manager_name is None) != (redis_address is None):
raise Exception("If one of the plasma_manager_name and the redis_address "
"is provided, then both must be provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
"-p", plasma_store_name,
"-h", node_ip_address,
"-n", str(num_workers)]
if plasma_manager_name is not None:
command += ["-m", plasma_manager_name]
if worker_path is not None:
assert plasma_store_name is not None
assert plasma_manager_name is not None
assert redis_address is not None
start_worker_command = ("python {} "
"--node-ip-address={} "
"--object-store-name={} "
"--object-store-manager-name={} "
"--local-scheduler-name={} "
"--redis-address={}").format(worker_path,
node_ip_address,
plasma_store_name,
plasma_manager_name,
local_scheduler_name,
redis_address)
command += ["-w", start_worker_command]
if redis_address is not None:
command += ["-r", redis_address]
if plasma_address is not None:
command += ["-a", plasma_address]
if static_resource_list is not None:
assert all([isinstance(resource, int) or isinstance(resource, float)
for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource
in static_resource_list])]
Return:
A tuple of the name of the local scheduler socket and the process ID of
the local scheduler process.
"""
if (plasma_manager_name is None) != (redis_address is None):
raise Exception("If one of the plasma_manager_name and the "
"redis_address is provided, then both must be "
"provided.")
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
local_scheduler_executable = os.path.join(os.path.dirname(
os.path.abspath(__file__)),
"../core/src/local_scheduler/local_scheduler")
local_scheduler_name = "/tmp/scheduler{}".format(random_name())
command = [local_scheduler_executable,
"-s", local_scheduler_name,
"-p", plasma_store_name,
"-h", node_ip_address,
"-n", str(num_workers)]
if plasma_manager_name is not None:
command += ["-m", plasma_manager_name]
if worker_path is not None:
assert plasma_store_name is not None
assert plasma_manager_name is not None
assert redis_address is not None
start_worker_command = ("python {} "
"--node-ip-address={} "
"--object-store-name={} "
"--object-store-manager-name={} "
"--local-scheduler-name={} "
"--redis-address={}"
.format(worker_path,
node_ip_address,
plasma_store_name,
plasma_manager_name,
local_scheduler_name,
redis_address))
command += ["-w", start_worker_command]
if redis_address is not None:
command += ["-r", redis_address]
if plasma_address is not None:
command += ["-a", plasma_address]
if static_resource_list is not None:
assert all([isinstance(resource, int) or isinstance(resource, float)
for resource in static_resource_list])
command += ["-c", ",".join([str(resource) for resource
in static_resource_list])]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return local_scheduler_name, pid
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return local_scheduler_name, pid
+171 -164
View File
@@ -21,195 +21,202 @@ NIL_ACTOR_ID = 20 * b"\xff"
def random_object_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_driver_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_task_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
def random_function_id():
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
return local_scheduler.ObjectID(np.random.bytes(ID_SIZE))
class TestLocalSchedulerClient(unittest.TestCase):
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p1 = plasma.start_plasma_store()
self.plasma_client = plasma.PlasmaClient(plasma_store_name,
release_delay=0)
# Start a local scheduler.
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
plasma_store_name, use_valgrind=USE_VALGRIND)
# Connect to the scheduler.
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
def setUp(self):
# Start Plasma store.
plasma_store_name, self.p1 = plasma.start_plasma_store()
self.plasma_client = plasma.PlasmaClient(plasma_store_name,
release_delay=0)
# Start a local scheduler.
scheduler_name, self.p2 = local_scheduler.start_local_scheduler(
plasma_store_name, use_valgrind=USE_VALGRIND)
# Connect to the scheduler.
self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False, 0)
def tearDown(self):
# Check that the processes are still alive.
self.assertEqual(self.p1.poll(), None)
self.assertEqual(self.p2.poll(), None)
def tearDown(self):
# Check that the processes are still alive.
self.assertEqual(self.p1.poll(), None)
self.assertEqual(self.p2.poll(), None)
# Kill Plasma.
self.p1.kill()
# Kill the local scheduler.
if USE_VALGRIND:
self.p2.send_signal(signal.SIGTERM)
self.p2.wait()
if self.p2.returncode != 0:
os._exit(-1)
else:
self.p2.kill()
# Kill Plasma.
self.p1.kill()
# Kill the local scheduler.
if USE_VALGRIND:
self.p2.send_signal(signal.SIGTERM)
self.p2.wait()
if self.p2.returncode != 0:
os._exit(-1)
else:
self.p2.kill()
def test_submit_and_get_task(self):
function_id = random_function_id()
object_ids = [random_object_id() for i in range(256)]
# Create and seal the objects in the object store so that we can schedule
# all of the subsequent tasks.
for object_id in object_ids:
self.plasma_client.create(object_id.id(), 0)
self.plasma_client.seal(object_id.id())
# Define some arguments to use for the tasks.
args_list = [
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
def test_submit_and_get_task(self):
function_id = random_function_id()
object_ids = [random_object_id() for i in range(256)]
# Create and seal the objects in the object store so that we can
# schedule all of the subsequent tasks.
for object_id in object_ids:
self.plasma_client.create(object_id.id(), 0)
self.plasma_client.seal(object_id.id())
# Define some arguments to use for the tasks.
args_list = [
[],
[{}],
[()],
1 * [1],
10 * [1],
100 * [1],
1000 * [1],
1 * ["a"],
10 * ["a"],
100 * ["a"],
1000 * ["a"],
[1, 1.3, 1 << 100, "hi", u"hi", [1, 2]],
object_ids[:1],
object_ids[:2],
object_ids[:3],
object_ids[:4],
object_ids[:5],
object_ids[:10],
object_ids[:100],
object_ids[:256],
[1, object_ids[0]],
[object_ids[0], "a"],
[1, object_ids[0], "a"],
[object_ids[0], 1, object_ids[1], "a"],
object_ids[:3] + [1, "hi", 2.3] + object_ids[:5],
object_ids + 100 * ["a"] + object_ids
]
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
# Submit a task.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id,
args, num_return_vals,
random_task_id(), 0)
# Submit a task.
self.local_scheduler_client.submit(task)
# Get the task.
new_task = self.local_scheduler_client.get_task()
self.assertEqual(task.function_id().id(),
new_task.function_id().id())
retrieved_args = new_task.arguments()
returns = new_task.returns()
self.assertEqual(len(args), len(retrieved_args))
self.assertEqual(num_return_vals, len(returns))
for i in range(len(retrieved_args)):
if isinstance(args[i], local_scheduler.ObjectID):
self.assertEqual(args[i].id(), retrieved_args[i].id())
else:
self.assertEqual(args[i], retrieved_args[i])
# Submit all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id,
args, num_return_vals,
random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Get all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
new_task = self.local_scheduler_client.get_task()
def test_scheduling_when_objects_ready(self):
# Create a task and submit it.
object_id = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id], 0, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Get the task.
new_task = self.local_scheduler_client.get_task()
self.assertEqual(task.function_id().id(), new_task.function_id().id())
retrieved_args = new_task.arguments()
returns = new_task.returns()
self.assertEqual(len(args), len(retrieved_args))
self.assertEqual(num_return_vals, len(returns))
for i in range(len(retrieved_args)):
if isinstance(args[i], local_scheduler.ObjectID):
self.assertEqual(args[i].id(), retrieved_args[i].id())
else:
self.assertEqual(args[i], retrieved_args[i])
# Submit all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
task = local_scheduler.Task(random_driver_id(), function_id, args,
num_return_vals, random_task_id(), 0)
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Sleep to give the thread time to call get_task.
time.sleep(0.1)
# Create and seal the object ID in the object store. This should
# trigger a scheduling event.
self.plasma_client.create(object_id.id(), 0)
self.plasma_client.seal(object_id.id())
# Wait until the thread finishes so that we know the task was
# scheduled.
t.join()
def test_scheduling_when_objects_evicted(self):
# Create a task with two dependencies and submit it.
object_id1 = random_object_id()
object_id2 = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id1, object_id2], 0,
random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Get all of the tasks.
for args in args_list:
for num_return_vals in [0, 1, 2, 3, 5, 10, 100]:
new_task = self.local_scheduler_client.get_task()
def test_scheduling_when_objects_ready(self):
# Create a task and submit it.
object_id = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id], 0, random_task_id(), 0)
self.local_scheduler_client.submit(task)
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Sleep to give the thread time to call get_task.
time.sleep(0.1)
# Create and seal the object ID in the object store. This should trigger a
# scheduling event.
self.plasma_client.create(object_id.id(), 0)
self.plasma_client.seal(object_id.id())
# Wait until the thread finishes so that we know the task was scheduled.
t.join()
# Make one of the dependencies available.
buf = self.plasma_client.create(object_id1.id(), 1)
self.plasma_client.seal(object_id1.id())
# Release the object.
del buf
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Force eviction of the first dependency.
self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY)
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Check that the first object dependency was evicted.
object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0)
self.assertEqual(object1, [None])
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
def test_scheduling_when_objects_evicted(self):
# Create a task with two dependencies and submit it.
object_id1 = random_object_id()
object_id2 = random_object_id()
task = local_scheduler.Task(random_driver_id(), random_function_id(),
[object_id1, object_id2], 0, random_task_id(),
0)
self.local_scheduler_client.submit(task)
# Create the second dependency.
self.plasma_client.create(object_id2.id(), 1)
self.plasma_client.seal(object_id2.id())
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Launch a thread to get the task.
def get_task():
self.local_scheduler_client.get_task()
t = threading.Thread(target=get_task)
t.start()
# Create the first dependency again. Both dependencies are now
# available.
self.plasma_client.create(object_id1.id(), 1)
self.plasma_client.seal(object_id1.id())
# Make one of the dependencies available.
buf = self.plasma_client.create(object_id1.id(), 1)
self.plasma_client.seal(object_id1.id())
# Release the object.
del buf
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Force eviction of the first dependency.
self.plasma_client.evict(plasma.DEFAULT_PLASMA_STORE_MEMORY)
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Check that the first object dependency was evicted.
object1 = self.plasma_client.get([object_id1.id()], timeout_ms=0)
self.assertEqual(object1, [None])
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Create the second dependency.
self.plasma_client.create(object_id2.id(), 1)
self.plasma_client.seal(object_id2.id())
# Check that the thread is still waiting for a task.
time.sleep(0.1)
self.assertTrue(t.is_alive())
# Create the first dependency again. Both dependencies are now available.
self.plasma_client.create(object_id1.id(), 1)
self.plasma_client.seal(object_id1.id())
# Wait until the thread finishes so that we know the task was scheduled.
t.join()
# Wait until the thread finishes so that we know the task was
# scheduled.
t.join()
if __name__ == "__main__":
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
print("Using valgrind for tests")
unittest.main(verbosity=2)
if len(sys.argv) > 1:
# Pop the argument so we don't mess with unittest's own argument
# parser.
if sys.argv[-1] == "valgrind":
arg = sys.argv.pop()
USE_VALGRIND = True
print("Using valgrind for tests")
unittest.main(verbosity=2)
+92 -84
View File
@@ -12,94 +12,102 @@ from ray.services import get_port
class LogMonitor(object):
"""A monitor process for monitoring Ray log files.
"""A monitor process for monitoring Ray log files.
Attributes:
node_ip_address: The IP address of the node that the log monitor process is
running on. This will be used to determine which log files to track.
redis_client: A client used to communicate with the Redis server.
log_filenames: A list of the names of the log files that this monitor
process is monitoring.
log_files: A dictionary mapping the name of a log file to a list of strings
representing its contents.
log_file_handles: A dictionary mapping the name of a log file to a file
handle for that file.
"""
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.log_files = {}
self.log_file_handles = {}
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
self.log_files[log_filename] = []
def check_log_files_and_push_updates(self):
"""Get any changes to the log files and push updates to Redis."""
for log_filename in self.log_files:
if log_filename in self.log_file_handles:
# Get any updates to the file.
new_lines = []
while True:
current_position = self.log_file_handles[log_filename].tell()
next_line = self.log_file_handles[log_filename].readline()
if next_line != "":
new_lines.append(next_line)
else:
self.log_file_handles[log_filename].seek(current_position)
break
# If there are any new lines, cache them and also push them to Redis.
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(self.node_ip_address,
log_filename.decode("ascii"))
self.redis_client.rpush(redis_key, *new_lines)
else:
try:
self.log_file_handles[log_filename] = open(log_filename, "r")
except IOError as e:
if e.errno == os.errno.EMFILE:
print("Warning: Some files are not being logged because there are "
"too many open files.")
elif e.errno == os.errno.ENOENT:
print("Warning: The file {} was not found.".format(log_filename))
else:
raise e
def run(self):
"""Run the log monitor.
This will query Redis once every second to check if there are new log files
to monitor. It will also store those log files in Redis.
Attributes:
node_ip_address: The IP address of the node that the log monitor
process is running on. This will be used to determine which log
files to track.
redis_client: A client used to communicate with the Redis server.
log_filenames: A list of the names of the log files that this monitor
process is monitoring.
log_files: A dictionary mapping the name of a log file to a list of
strings representing its contents.
log_file_handles: A dictionary mapping the name of a log file to a file
handle for that file.
"""
while True:
self.update_log_filenames()
self.check_log_files_and_push_updates()
time.sleep(1)
def __init__(self, redis_ip_address, redis_port, node_ip_address):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(host=redis_ip_address,
port=redis_port)
self.log_files = {}
self.log_file_handles = {}
def update_log_filenames(self):
"""Get the most up-to-date list of log files to monitor from Redis."""
num_current_log_files = len(self.log_files)
new_log_filenames = self.redis_client.lrange(
"LOG_FILENAMES:{}".format(self.node_ip_address),
num_current_log_files, -1)
for log_filename in new_log_filenames:
print("Beginning to track file {}".format(log_filename))
assert log_filename not in self.log_files
self.log_files[log_filename] = []
def check_log_files_and_push_updates(self):
"""Get any changes to the log files and push updates to Redis."""
for log_filename in self.log_files:
if log_filename in self.log_file_handles:
# Get any updates to the file.
new_lines = []
while True:
current_position = (
self.log_file_handles[log_filename].tell())
next_line = self.log_file_handles[log_filename].readline()
if next_line != "":
new_lines.append(next_line)
else:
self.log_file_handles[log_filename].seek(
current_position)
break
# If there are any new lines, cache them and also push them to
# Redis.
if len(new_lines) > 0:
self.log_files[log_filename] += new_lines
redis_key = "LOGFILE:{}:{}".format(
self.node_ip_address, log_filename.decode("ascii"))
self.redis_client.rpush(redis_key, *new_lines)
else:
try:
self.log_file_handles[log_filename] = open(log_filename,
"r")
except IOError as e:
if e.errno == os.errno.EMFILE:
print("Warning: Some files are not being logged "
"because there are too many open files.")
elif e.errno == os.errno.ENOENT:
print("Warning: The file {} was not "
"found.".format(log_filename))
else:
raise e
def run(self):
"""Run the log monitor.
This will query Redis once every second to check if there are new log
files to monitor. It will also store those log files in Redis.
"""
while True:
self.update_log_filenames()
self.check_log_files_and_push_updates()
time.sleep(1)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect to."))
parser.add_argument("--redis-address", required=True, type=str,
help="The address to use for Redis.")
parser.add_argument("--node-ip-address", required=True, type=str,
help="The IP address of the node this process is on.")
args = parser.parse_args()
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"log monitor to connect "
"to."))
parser.add_argument("--redis-address", required=True, type=str,
help="The address to use for Redis.")
parser.add_argument("--node-ip-address", required=True, type=str,
help="The IP address of the node this process is on.")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
log_monitor = LogMonitor(redis_ip_address, redis_port, args.node_ip_address)
log_monitor.run()
log_monitor = LogMonitor(redis_ip_address, redis_port,
args.node_ip_address)
log_monitor.run()
+318 -305
View File
@@ -48,354 +48,367 @@ log.setLevel(logging.INFO)
class Monitor(object):
"""A monitor for Ray processes.
"""A monitor for Ray processes.
The monitor is in charge of cleaning up the tables in the global state after
processes have died. The monitor is currently not responsible for detecting
component failures.
The monitor is in charge of cleaning up the tables in the global state
after processes have died. The monitor is currently not responsible for
detecting component failures.
Attributes:
redis: A connection to the Redis server.
subscribe_client: A pubsub client for the Redis server. This is used to
receive notifications about failed components.
subscribed: A dictionary mapping channel names (str) to whether or not the
subscription to that channel has succeeded yet (bool).
dead_local_schedulers: A set of the local scheduler IDs of all of the local
schedulers that were up at one point and have died since then.
live_plasma_managers: A counter mapping live plasma manager IDs to the
number of heartbeats that have passed since we last heard from that
plasma manager. A plasma manager is live if we received a heartbeat from
it at any point, and if it has not timed out.
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
managers that were up at one point and have died since then.
"""
def __init__(self, redis_address, redis_port):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
# TODO(swang): Update pubsub client to use ray.experimental.state once
# subscriptions are implemented there.
self.subscribe_client = self.redis.pubsub()
self.subscribed = {}
# Initialize data structures to keep track of the active database clients.
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
def subscribe(self, channel):
"""Subscribe to the given channel.
Args:
channel (str): The channel to subscribe to.
Raises:
Exception: An exception is raised if the subscription fails.
Attributes:
redis: A connection to the Redis server.
subscribe_client: A pubsub client for the Redis server. This is used to
receive notifications about failed components.
subscribed: A dictionary mapping channel names (str) to whether or not
the subscription to that channel has succeeded yet (bool).
dead_local_schedulers: A set of the local scheduler IDs of all of the
local schedulers that were up at one point and have died since
then.
live_plasma_managers: A counter mapping live plasma manager IDs to the
number of heartbeats that have passed since we last heard from that
plasma manager. A plasma manager is live if we received a heartbeat
from it at any point, and if it has not timed out.
dead_plasma_managers: A set of the plasma manager IDs of all the plasma
managers that were up at one point and have died since then.
"""
self.subscribe_client.subscribe(channel)
self.subscribed[channel] = False
def __init__(self, redis_address, redis_port):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.redis = redis.StrictRedis(host=redis_address, port=redis_port,
db=0)
# TODO(swang): Update pubsub client to use ray.experimental.state once
# subscriptions are implemented there.
self.subscribe_client = self.redis.pubsub()
self.subscribed = {}
# Initialize data structures to keep track of the active database
# clients.
self.dead_local_schedulers = set()
self.live_plasma_managers = Counter()
self.dead_plasma_managers = set()
def cleanup_task_table(self):
"""Clean up global state for failed local schedulers.
def subscribe(self, channel):
"""Subscribe to the given channel.
This marks any tasks that were scheduled on dead local schedulers as
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
tasks = self.state.task_table()
num_tasks_updated = 0
for task_id, task in tasks.items():
# See if the corresponding local scheduler is alive.
if task["LocalSchedulerID"] in self.dead_local_schedulers:
# If the task is scheduled on a dead local scheduler, mark the task as
# lost.
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
if num_tasks_updated > 0:
log.warn("Marked {} tasks as lost.".format(num_tasks_updated))
Args:
channel (str): The channel to subscribe to.
def cleanup_object_table(self):
"""Clean up global state for failed plasma managers.
Raises:
Exception: An exception is raised if the subscription fails.
"""
self.subscribe_client.subscribe(channel)
self.subscribed[channel] = False
This removes dead plasma managers from any location entries in the object
table. A plasma manager is deemed dead if it is in
self.dead_plasma_managers.
"""
# TODO(swang): Also kill the associated plasma store, since it's no longer
# reachable without a plasma manager.
objects = self.state.object_table()
num_objects_removed = 0
for object_id, obj in objects.items():
manager_ids = obj["ManagerIDs"]
if manager_ids is None:
continue
for manager in manager_ids:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that location
# entry.
ok = self.state._execute_command(object_id,
"RAY.OBJECT_TABLE_REMOVE",
object_id.id(),
hex_to_binary(manager))
if ok != b"OK":
log.warn("Failed to remove object location for dead plasma "
"manager.")
num_objects_removed += 1
if num_objects_removed > 0:
log.warn("Marked {} objects as lost.".format(num_objects_removed))
def cleanup_task_table(self):
"""Clean up global state for failed local schedulers.
def scan_db_client_table(self):
"""Scan the database client table for dead clients.
This marks any tasks that were scheduled on dead local schedulers as
TASK_STATUS_LOST. A local scheduler is deemed dead if it is in
self.dead_local_schedulers.
"""
tasks = self.state.task_table()
num_tasks_updated = 0
for task_id, task in tasks.items():
# See if the corresponding local scheduler is alive.
if task["LocalSchedulerID"] in self.dead_local_schedulers:
# If the task is scheduled on a dead local scheduler, mark the
# task as lost.
key = binary_to_object_id(hex_to_binary(task_id))
ok = self.state._execute_command(
key, "RAY.TASK_TABLE_UPDATE", hex_to_binary(task_id),
ray.experimental.state.TASK_STATUS_LOST, NIL_ID)
if ok != b"OK":
log.warn("Failed to update lost task for dead scheduler.")
num_tasks_updated += 1
if num_tasks_updated > 0:
log.warn("Marked {} tasks as lost.".format(num_tasks_updated))
After subscribing to the client table, it's necessary to call this before
reading any messages from the subscription channel. This ensures that we do
not miss any notifications for deleted clients that occurred before we
subscribed.
"""
clients = self.state.client_table()
for node_ip_address, node_clients in clients.items():
for client in node_clients:
db_client_id = client["DBClientID"]
client_type = client["ClientType"]
if client["Deleted"]:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
def cleanup_object_table(self):
"""Clean up global state for failed plasma managers.
def subscribe_handler(self, channel, data):
"""Handle a subscription success message from Redis.
"""
log.debug("Subscribed to {}, data was {}".format(channel, data))
self.subscribed[channel] = True
This removes dead plasma managers from any location entries in the
object table. A plasma manager is deemed dead if it is in
self.dead_plasma_managers.
"""
# TODO(swang): Also kill the associated plasma store, since it's no
# longer reachable without a plasma manager.
objects = self.state.object_table()
num_objects_removed = 0
for object_id, obj in objects.items():
manager_ids = obj["ManagerIDs"]
if manager_ids is None:
continue
for manager in manager_ids:
if manager in self.dead_plasma_managers:
# If the object was on a dead plasma manager, remove that
# location entry.
ok = self.state._execute_command(object_id,
"RAY.OBJECT_TABLE_REMOVE",
object_id.id(),
hex_to_binary(manager))
if ok != b"OK":
log.warn("Failed to remove object location for dead "
"plasma manager.")
num_objects_removed += 1
if num_objects_removed > 0:
log.warn("Marked {} objects as lost.".format(num_objects_removed))
def db_client_notification_handler(self, channel, data):
"""Handle a notification from the db_client table from Redis.
def scan_db_client_table(self):
"""Scan the database client table for dead clients.
This handler processes notifications from the db_client table.
Notifications should be parsed using the SubscribeToDBClientTableReply
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the
associated state in the state tables should be handled by the caller.
"""
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data, 0))
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
is_insertion = notification_object.IsInsertion()
After subscribing to the client table, it's necessary to call this
before reading any messages from the subscription channel. This ensures
that we do not miss any notifications for deleted clients that occurred
before we subscribed.
"""
clients = self.state.client_table()
for node_ip_address, node_clients in clients.items():
for client in node_clients:
db_client_id = client["DBClientID"]
client_type = client["ClientType"]
if client["Deleted"]:
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
self.dead_plasma_managers.add(db_client_id)
# If the update was an insertion, we ignore it.
if is_insertion:
return
def subscribe_handler(self, channel, data):
"""Handle a subscription success message from Redis."""
log.debug("Subscribed to {}, data was {}".format(channel, data))
self.subscribed[channel] = True
# If the update was a deletion, add them to our accounting for dead
# local schedulers and plasma managers.
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
if db_client_id not in self.dead_local_schedulers:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
if db_client_id not in self.dead_plasma_managers:
self.dead_plasma_managers.add(db_client_id)
# Stop tracking this plasma manager's heartbeats, since it's
# already dead.
del self.live_plasma_managers[db_client_id]
def db_client_notification_handler(self, channel, data):
"""Handle a notification from the db_client table from Redis.
def plasma_manager_heartbeat_handler(self, channel, data):
"""Handle a plasma manager heartbeat from Redis.
This handler processes notifications from the db_client table.
Notifications should be parsed using the SubscribeToDBClientTableReply
flatbuffer. Deletions are processed, insertions are ignored. Cleanup of
the associated state in the state tables should be handled by the
caller.
"""
notification_object = (SubscribeToDBClientTableReply
.GetRootAsSubscribeToDBClientTableReply(data,
0))
db_client_id = binary_to_hex(notification_object.DbClientId())
client_type = notification_object.ClientType()
is_insertion = notification_object.IsInsertion()
This resets the number of heartbeats that we've missed from this plasma
manager.
"""
# The first DB_CLIENT_ID_SIZE characters are the client ID.
db_client_id = data[:DB_CLIENT_ID_SIZE]
# Reset the number of heartbeats that we've missed from this plasma
# manager.
self.live_plasma_managers[db_client_id] = 0
# If the update was an insertion, we ignore it.
if is_insertion:
return
def driver_removed_handler(self, channel, data):
"""Handle a notification that a driver has been removed.
# If the update was a deletion, add them to our accounting for dead
# local schedulers and plasma managers.
log.warn("Removed {}, client ID {}".format(client_type, db_client_id))
if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
if db_client_id not in self.dead_local_schedulers:
self.dead_local_schedulers.add(db_client_id)
elif client_type == PLASMA_MANAGER_CLIENT_TYPE:
if db_client_id not in self.dead_plasma_managers:
self.dead_plasma_managers.add(db_client_id)
# Stop tracking this plasma manager's heartbeats, since it's
# already dead.
del self.live_plasma_managers[db_client_id]
This releases any GPU resources that were reserved for that driver in
Redis.
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed.".format(binary_to_hex(driver_id)))
def plasma_manager_heartbeat_handler(self, channel, data):
"""Handle a plasma manager heartbeat from Redis.
# Get a list of the local schedulers.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
This resets the number of heartbeats that we've missed from this plasma
manager.
"""
# The first DB_CLIENT_ID_SIZE characters are the client ID.
db_client_id = data[:DB_CLIENT_ID_SIZE]
# Reset the number of heartbeats that we've missed from this plasma
# manager.
self.live_plasma_managers[db_client_id] = 0
# Release any GPU resources that have been reserved for this driver in
# Redis.
for local_scheduler in local_schedulers:
if int(local_scheduler["NumGPUs"]) > 0:
local_scheduler_id = local_scheduler["DBClientID"]
def driver_removed_handler(self, channel, data):
"""Handle a notification that a driver has been removed.
num_gpus_returned = 0
This releases any GPU resources that were reserved for that driver in
Redis.
"""
message = DriverTableMessage.GetRootAsDriverTableMessage(data, 0)
driver_id = message.DriverId()
log.info("Driver {} has been removed."
.format(binary_to_hex(driver_id)))
# Perform a transaction to return the GPUs.
with self.redis.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction below (the
# multi/exec block), then the transaction will not take place.
pipe.watch(local_scheduler_id)
# Get a list of the local schedulers.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if client["ClientType"] == "local_scheduler":
local_schedulers.append(client)
result = pipe.hget(local_scheduler_id, "gpus_in_use")
gpus_in_use = dict() if result is None else json.loads(result)
# Release any GPU resources that have been reserved for this driver in
# Redis.
for local_scheduler in local_schedulers:
if int(local_scheduler["NumGPUs"]) > 0:
local_scheduler_id = local_scheduler["DBClientID"]
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
num_gpus_returned = gpus_in_use.pop(driver_id_hex)
num_gpus_returned = 0
pipe.multi()
# Perform a transaction to return the GPUs.
with self.redis.pipeline() as pipe:
while True:
try:
# If this key is changed before the transaction
# below (the multi/exec block), then the
# transaction will not take place.
pipe.watch(local_scheduler_id)
pipe.hset(local_scheduler_id, "gpus_in_use",
json.dumps(gpus_in_use))
result = pipe.hget(local_scheduler_id,
"gpus_in_use")
gpus_in_use = (dict() if result is None
else json.loads(result))
pipe.execute()
# If a WatchError is not raise, then the operations should have
# gone through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key between the
# time we started WATCHing it and the pipeline's execution. We
# should just retry.
continue
driver_id_hex = binary_to_hex(driver_id)
if driver_id_hex in gpus_in_use:
num_gpus_returned = gpus_in_use.pop(
driver_id_hex)
log.info("Driver {} is returning GPU IDs {} to local scheduler {}."
.format(driver_id, num_gpus_returned, local_scheduler_id))
pipe.multi()
def process_messages(self):
"""Process all messages ready in the subscription channels.
pipe.hset(local_scheduler_id, "gpus_in_use",
json.dumps(gpus_in_use))
This reads messages from the subscription channels and calls the
appropriate handlers until there are no messages left.
"""
while True:
message = self.subscribe_client.get_message()
if message is None:
return
pipe.execute()
# If a WatchError is not raise, then the operations
# should have gone through atomically.
break
except redis.WatchError:
# Another client must have changed the watched key
# between the time we started WATCHing it and the
# pipeline's execution. We should just retry.
continue
# Parse the message.
channel = message["channel"]
data = message["data"]
log.info("Driver {} is returning GPU IDs {} to local "
"scheduler {}.".format(driver_id, num_gpus_returned,
local_scheduler_id))
# Determine the appropriate message handler.
message_handler = None
if not self.subscribed[channel]:
# If the data was an integer, then the message was a response to an
# initial subscription request.
message_handler = self.subscribe_handler
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
assert(self.subscribed[channel])
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == DB_CLIENT_TABLE_NAME:
assert(self.subscribed[channel])
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
assert(self.subscribed[channel])
# The message was a notification that a driver was removed.
message_handler = self.driver_removed_handler
else:
raise Exception("This code should be unreachable.")
def process_messages(self):
"""Process all messages ready in the subscription channels.
# Call the handler.
assert(message_handler is not None)
message_handler(channel, data)
This reads messages from the subscription channels and calls the
appropriate handlers until there are no messages left.
"""
while True:
message = self.subscribe_client.get_message()
if message is None:
return
def run(self):
"""Run the monitor.
# Parse the message.
channel = message["channel"]
data = message["data"]
This function loops forever, checking for messages about dead database
clients and cleaning up state accordingly.
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
# Determine the appropriate message handler.
message_handler = None
if not self.subscribed[channel]:
# If the data was an integer, then the message was a response
# to an initial subscription request.
message_handler = self.subscribe_handler
elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL:
assert(self.subscribed[channel])
# The message was a heartbeat from a plasma manager.
message_handler = self.plasma_manager_heartbeat_handler
elif channel == DB_CLIENT_TABLE_NAME:
assert(self.subscribed[channel])
# The message was a notification from the db_client table.
message_handler = self.db_client_notification_handler
elif channel == DRIVER_DEATH_CHANNEL:
assert(self.subscribed[channel])
# The message was a notification that a driver was removed.
message_handler = self.driver_removed_handler
else:
raise Exception("This code should be unreachable.")
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel. This
# ensures that we start in a consistent state, since we may have missed
# notifications that were sent before we connected to the subscription
# channel.
self.scan_db_client_table()
# If there were any dead clients at startup, clean up the associated state
# in the state tables.
if len(self.dead_local_schedulers) > 0:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > 0:
self.cleanup_object_table()
log.debug("{} dead local schedulers, {} plasma managers total, {} dead "
"plasma managers".format(len(self.dead_local_schedulers),
(len(self.live_plasma_managers) +
len(self.dead_plasma_managers)),
len(self.dead_plasma_managers)))
# Call the handler.
assert(message_handler is not None)
message_handler(channel, data)
# Handle messages from the subscription channels.
while True:
# Record how many dead local schedulers and plasma managers we had at the
# beginning of this round.
num_dead_local_schedulers = len(self.dead_local_schedulers)
num_dead_plasma_managers = len(self.dead_plasma_managers)
# Process a round of messages.
self.process_messages()
# If any new local schedulers or plasma managers were marked as dead in
# this round, clean up the associated state.
if len(self.dead_local_schedulers) > num_dead_local_schedulers:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > num_dead_plasma_managers:
self.cleanup_object_table()
def run(self):
"""Run the monitor.
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if ((self.live_plasma_managers
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose heartbeats we're
# tracking.
del self.live_plasma_managers[plasma_manager_id]
# Remove the plasma manager from the db_client table. The
# corresponding state in the object table will be cleaned up once we
# receive the notification for this db_client deletion.
self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id)
This function loops forever, checking for messages about dead database
clients and cleaning up state accordingly.
"""
# Initialize the subscription channel.
self.subscribe(DB_CLIENT_TABLE_NAME)
self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL)
self.subscribe(DRIVER_DEATH_CHANNEL)
# Increment the number of heartbeats that we've missed from each plasma
# manager.
for plasma_manager_id in self.live_plasma_managers:
self.live_plasma_managers[plasma_manager_id] += 1
# Scan the database table for dead database clients. NOTE: This must be
# called before reading any messages from the subscription channel.
# This ensures that we start in a consistent state, since we may have
# missed notifications that were sent before we connected to the
# subscription channel.
self.scan_db_client_table()
# If there were any dead clients at startup, clean up the associated
# state in the state tables.
if len(self.dead_local_schedulers) > 0:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > 0:
self.cleanup_object_table()
log.debug("{} dead local schedulers, {} plasma managers total, {} "
"dead plasma managers".format(
len(self.dead_local_schedulers),
(len(self.live_plasma_managers) +
len(self.dead_plasma_managers)),
len(self.dead_plasma_managers)))
# Wait for a heartbeat interval before processing the next round of
# messages.
time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3)
# Handle messages from the subscription channels.
while True:
# Record how many dead local schedulers and plasma managers we had
# at the beginning of this round.
num_dead_local_schedulers = len(self.dead_local_schedulers)
num_dead_plasma_managers = len(self.dead_plasma_managers)
# Process a round of messages.
self.process_messages()
# If any new local schedulers or plasma managers were marked as
# dead in this round, clean up the associated state.
if len(self.dead_local_schedulers) > num_dead_local_schedulers:
self.cleanup_task_table()
if len(self.dead_plasma_managers) > num_dead_plasma_managers:
self.cleanup_object_table()
# Handle plasma managers that timed out during this round.
plasma_manager_ids = list(self.live_plasma_managers.keys())
for plasma_manager_id in plasma_manager_ids:
if ((self.live_plasma_managers
[plasma_manager_id]) >= NUM_HEARTBEATS_TIMEOUT):
log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE))
# Remove the plasma manager from the managers whose
# heartbeats we're tracking.
del self.live_plasma_managers[plasma_manager_id]
# Remove the plasma manager from the db_client table. The
# corresponding state in the object table will be cleaned
# up once we receive the notification for this db_client
# deletion.
self.redis.execute_command("RAY.DISCONNECT",
plasma_manager_id)
# Increment the number of heartbeats that we've missed from each
# plasma manager.
for plasma_manager_id in self.live_plasma_managers:
self.live_plasma_managers[plasma_manager_id] += 1
# Wait for a heartbeat interval before processing the next round of
# messages.
time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
args = parser.parse_args()
parser = argparse.ArgumentParser(description=("Parse Redis server for the "
"monitor to connect to."))
parser.add_argument("--redis-address", required=True, type=str,
help="the address to use for Redis")
args = parser.parse_args()
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
# Initialize the global state.
ray.global_state._initialize_global_state(redis_ip_address, redis_port)
monitor = Monitor(redis_ip_address, redis_port)
monitor.run()
monitor = Monitor(redis_ip_address, redis_port)
monitor.run()
+22 -24
View File
@@ -16,28 +16,26 @@ __all__ = ["deserialize_list", "numbuf_error",
"store_list", "write_to_buffer"]
try:
from ray.core.src.numbuf.libnumbuf import (deserialize_list, numbuf_error,
numbuf_plasma_object_exists_error,
read_from_buffer,
register_callbacks, retrieve_list,
serialize_list, store_list,
write_to_buffer)
from ray.core.src.numbuf.libnumbuf import (
deserialize_list, numbuf_error, numbuf_plasma_object_exists_error,
read_from_buffer, register_callbacks, retrieve_list, serialize_list,
store_list, write_to_buffer)
except ImportError as e:
if (hasattr(e, "msg") and isinstance(e.msg, str) and ("libstdc++" in e.msg or
"CXX" in e.msg)):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
else:
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args,)
e.args += (helpful_message,)
raise
if ((hasattr(e, "msg") and isinstance(e.msg, str) and
("libstdc++" in e.msg or "CXX" in e.msg))):
# This code path should be taken with Python 3.
e.msg += helpful_message
elif (hasattr(e, "message") and isinstance(e.message, str) and
("libstdc++" in e.message or "CXX" in e.message)):
# This code path should be taken with Python 2.
condition = (hasattr(e, "args") and isinstance(e.args, tuple) and
len(e.args) == 1 and isinstance(e.args[0], str))
if condition:
e.args = (e.args[0] + helpful_message,)
else:
if not hasattr(e, "args"):
e.args = ()
elif not isinstance(e.args, tuple):
e.args = (e.args,)
e.args += (helpful_message,)
raise
+344 -329
View File
@@ -21,342 +21,355 @@ PLASMA_WAIT_TIMEOUT = 2 ** 30
class PlasmaBuffer(object):
"""This is the type of objects returned by calls to get with a PlasmaClient.
"""This is the type returned by calls to get with a PlasmaClient.
We define our own class instead of directly returning a buffer object so that
we can add a custom destructor which notifies Plasma that the object is no
longer being used, so the memory in the Plasma store backing the object can
potentially be freed.
We define our own class instead of directly returning a buffer object so
that we can add a custom destructor which notifies Plasma that the object
is no longer being used, so the memory in the Plasma store backing the
object can potentially be freed.
Attributes:
buffer (buffer): A buffer containing an object in the Plasma store.
plasma_id (PlasmaID): The ID of the object in the buffer.
plasma_client (PlasmaClient): The PlasmaClient that we use to communicate
with the store and manager.
"""
def __init__(self, buff, plasma_id, plasma_client):
"""Initialize a PlasmaBuffer."""
self.buffer = buff
self.plasma_id = plasma_id
self.plasma_client = plasma_client
def __del__(self):
"""Notify Plasma that the object is no longer needed.
If the plasma client has been shut down, then don't do anything.
Attributes:
buffer (buffer): A buffer containing an object in the Plasma store.
plasma_id (PlasmaID): The ID of the object in the buffer.
plasma_client (PlasmaClient): The PlasmaClient that we use to communicate
with the store and manager.
"""
if self.plasma_client.alive:
libplasma.release(self.plasma_client.conn, self.plasma_id)
def __init__(self, buff, plasma_id, plasma_client):
"""Initialize a PlasmaBuffer."""
self.buffer = buff
self.plasma_id = plasma_id
self.plasma_client = plasma_client
def __getitem__(self, index):
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
value = self.buffer[index]
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = chr(value)
return value
def __del__(self):
"""Notify Plasma that the object is no longer needed.
def __setitem__(self, index, value):
"""Write to the PlasmaBuffer as if it were just a regular buffer.
If the plasma client has been shut down, then don't do anything.
"""
if self.plasma_client.alive:
libplasma.release(self.plasma_client.conn, self.plasma_id)
This should fail because the buffer should be read only.
"""
# We currently don't allow slicing plasma buffers. We should handle this
# better, but it requires some care because the slice may be backed by the
# same memory in the object store, but the original plasma buffer may go
# out of scope causing the memory to no longer be accessible.
assert not isinstance(index, slice)
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = ord(value)
self.buffer[index] = value
def __getitem__(self, index):
"""Read from the PlasmaBuffer as if it were just a regular buffer."""
# We currently don't allow slicing plasma buffers. We should handle
# this better, but it requires some care because the slice may be
# backed by the same memory in the object store, but the original
# plasma buffer may go out of scope causing the memory to no longer be
# accessible.
assert not isinstance(index, slice)
value = self.buffer[index]
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = chr(value)
return value
def __len__(self):
"""Return the length of the buffer."""
return len(self.buffer)
def __setitem__(self, index, value):
"""Write to the PlasmaBuffer as if it were just a regular buffer.
This should fail because the buffer should be read only.
"""
# We currently don't allow slicing plasma buffers. We should handle
# this better, but it requires some care because the slice may be
# backed by the same memory in the object store, but the original
# plasma buffer may go out of scope causing the memory to no longer be
# accessible.
assert not isinstance(index, slice)
if sys.version_info >= (3, 0) and not isinstance(index, slice):
value = ord(value)
self.buffer[index] = value
def __len__(self):
"""Return the length of the buffer."""
return len(self.buffer)
def buffers_equal(buff1, buff2):
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
"""Compare two buffers. These buffers may be PlasmaBuffer objects.
This method should only be used in the tests. We implement a special helper
method for doing this because doing comparisons by slicing is much faster,
but we don't want to expose slicing of PlasmaBuffer objects because it
currently is not safe.
"""
buff1_to_compare = buff1.buffer if isinstance(buff1, PlasmaBuffer) else buff1
buff2_to_compare = buff2.buffer if isinstance(buff2, PlasmaBuffer) else buff2
return buff1_to_compare[:] == buff2_to_compare[:]
This method should only be used in the tests. We implement a special helper
method for doing this because doing comparisons by slicing is much faster,
but we don't want to expose slicing of PlasmaBuffer objects because it
currently is not safe.
"""
buff1_to_compare = (buff1.buffer if isinstance(buff1, PlasmaBuffer)
else buff1)
buff2_to_compare = (buff2.buffer if isinstance(buff2, PlasmaBuffer)
else buff2)
return buff1_to_compare[:] == buff2_to_compare[:]
class PlasmaClient(object):
"""The PlasmaClient is used to interface with a plasma store and manager.
"""The PlasmaClient is used to interface with a plasma store and manager.
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
buffer, and get a buffer. Buffers are referred to by object IDs, which are
strings.
"""
def __init__(self, store_socket_name, manager_socket_name=None,
release_delay=64):
"""Initialize the PlasmaClient.
Args:
store_socket_name (str): Name of the socket the plasma store is listening
at.
manager_socket_name (str): Name of the socket the plasma manager is
listening at.
release_delay (int): The maximum number of objects that the client will
keep and delay releasing (for caching reasons).
The PlasmaClient can ask the PlasmaStore to allocate a new buffer, seal a
buffer, and get a buffer. Buffers are referred to by object IDs, which are
strings.
"""
self.store_socket_name = store_socket_name
self.manager_socket_name = manager_socket_name
self.alive = True
if manager_socket_name is not None:
self.conn = libplasma.connect(store_socket_name, manager_socket_name,
release_delay)
else:
self.conn = libplasma.connect(store_socket_name, "", release_delay)
def __init__(self, store_socket_name, manager_socket_name=None,
release_delay=64):
"""Initialize the PlasmaClient.
def shutdown(self):
"""Shutdown the client so that it does not send messages.
Args:
store_socket_name (str): Name of the socket the plasma store is
listening at.
manager_socket_name (str): Name of the socket the plasma manager is
listening at.
release_delay (int): The maximum number of objects that the client
will keep and delay releasing (for caching reasons).
"""
self.store_socket_name = store_socket_name
self.manager_socket_name = manager_socket_name
self.alive = True
If we kill the Plasma store and Plasma manager that this client is
connected to, then we can use this method to prevent the client from trying
to send messages to the killed processes.
"""
if self.alive:
libplasma.disconnect(self.conn)
self.alive = False
if manager_socket_name is not None:
self.conn = libplasma.connect(store_socket_name,
manager_socket_name,
release_delay)
else:
self.conn = libplasma.connect(store_socket_name, "", release_delay)
def create(self, object_id, size, metadata=None):
"""Create a new buffer in the PlasmaStore for a particular object ID.
def shutdown(self):
"""Shutdown the client so that it does not send messages.
The returned buffer is mutable until seal is called.
If we kill the Plasma store and Plasma manager that this client is
connected to, then we can use this method to prevent the client from
trying to send messages to the killed processes.
"""
if self.alive:
libplasma.disconnect(self.conn)
self.alive = False
Args:
object_id (str): A string used to identify an object.
size (int): The size in bytes of the created buffer.
metadata (buffer): An optional buffer encoding whatever metadata the user
wishes to encode.
def create(self, object_id, size, metadata=None):
"""Create a new buffer in the PlasmaStore for a particular object ID.
Raises:
plasma_object_exists_error: This exception is raised if the object could
not be created because there already is an object with the same ID in
the plasma store.
plasma_out_of_memory_error: This exception is raised if the object could
not be created because the plasma store is unable to evict enough
objects to create room for it.
"""
# Turn the metadata into the right type.
metadata = bytearray(b"") if metadata is None else metadata
buff = libplasma.create(self.conn, object_id, size, metadata)
return PlasmaBuffer(buff, object_id, self)
The returned buffer is mutable until seal is called.
def get(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
Args:
object_id (str): A string used to identify an object.
size (int): The size in bytes of the created buffer.
metadata (buffer): An optional buffer encoding whatever metadata the
user wishes to encode.
If the object has not been sealed yet, this call will block. The retrieved
buffer is immutable.
Raises:
plasma_object_exists_error: This exception is raised if the object
could not be created because there already is an object with the
same ID in the plasma store.
plasma_out_of_memory_error: This exception is raised if the object
could not be created because the plasma store is unable to evict
enough objects to create room for it.
"""
# Turn the metadata into the right type.
metadata = bytearray(b"") if metadata is None else metadata
buff = libplasma.create(self.conn, object_id, size, metadata)
return PlasmaBuffer(buff, object_id, self)
Args:
object_ids (List[str]): A list of strings used to identify some objects.
timeout_ms (int): The number of milliseconds that the get call should
block before timing out and returning. Pass -1 if the call should block
and 0 if the call should return immediately.
"""
results = libplasma.get(self.conn, object_ids, timeout_ms)
assert len(object_ids) == len(results)
returns = []
for i in range(len(object_ids)):
if results[i] is None:
returns.append(None)
else:
returns.append(PlasmaBuffer(results[i][0], object_ids[i], self))
return returns
def get(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
def get_metadata(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
If the object has not been sealed yet, this call will block. The
retrieved buffer is immutable.
If the object has not been sealed yet, this call will block until the
object has been sealed. The retrieved buffer is immutable.
Args:
object_ids (List[str]): A list of strings used to identify some
objects.
timeout_ms (int): The number of milliseconds that the get call should
block before timing out and returning. Pass -1 if the call should
block and 0 if the call should return immediately.
"""
results = libplasma.get(self.conn, object_ids, timeout_ms)
assert len(object_ids) == len(results)
returns = []
for i in range(len(object_ids)):
if results[i] is None:
returns.append(None)
else:
returns.append(PlasmaBuffer(results[i][0], object_ids[i],
self))
return returns
Args:
object_ids (List[str]): A list of strings used to identify some objects.
timeout_ms (int): The number of milliseconds that the get call should
block before timing out and returning. Pass -1 if the call should block
and 0 if the call should return immediately.
"""
results = libplasma.get(self.conn, object_ids, timeout_ms)
assert len(object_ids) == len(results)
returns = []
for i in range(len(object_ids)):
if results[i] is None:
returns.append(None)
else:
returns.append(PlasmaBuffer(results[i][1], object_ids[i], self))
return returns
def get_metadata(self, object_ids, timeout_ms=-1):
"""Create a buffer from the PlasmaStore based on object ID.
def contains(self, object_id):
"""Check if the object is present and has been sealed in the PlasmaStore.
If the object has not been sealed yet, this call will block until the
object has been sealed. The retrieved buffer is immutable.
Args:
object_id (str): A string used to identify an object.
"""
return libplasma.contains(self.conn, object_id)
Args:
object_ids (List[str]): A list of strings used to identify some
objects.
timeout_ms (int): The number of milliseconds that the get call should
block before timing out and returning. Pass -1 if the call should
block and 0 if the call should return immediately.
"""
results = libplasma.get(self.conn, object_ids, timeout_ms)
assert len(object_ids) == len(results)
returns = []
for i in range(len(object_ids)):
if results[i] is None:
returns.append(None)
else:
returns.append(PlasmaBuffer(results[i][1], object_ids[i],
self))
return returns
def hash(self, object_id):
"""Compute the hash of an object in the object store.
def contains(self, object_id):
"""Check if the object is present and has been sealed.
Args:
object_id (str): A string used to identify an object.
Args:
object_id (str): A string used to identify an object.
"""
return libplasma.contains(self.conn, object_id)
Returns:
A digest string object's SHA256 hash. If the object isn't in the object
store, the string will have length zero.
"""
return libplasma.hash(self.conn, object_id)
def hash(self, object_id):
"""Compute the hash of an object in the object store.
def seal(self, object_id):
"""Seal the buffer in the PlasmaStore for a particular object ID.
Args:
object_id (str): A string used to identify an object.
Once a buffer has been sealed, the buffer is immutable and can only be
accessed through get.
Returns:
A digest string object's SHA256 hash. If the object isn't in the
object store, the string will have length zero.
"""
return libplasma.hash(self.conn, object_id)
Args:
object_id (str): A string used to identify an object.
"""
libplasma.seal(self.conn, object_id)
def seal(self, object_id):
"""Seal the buffer in the PlasmaStore for a particular object ID.
def delete(self, object_id):
"""Delete the buffer in the PlasmaStore for a particular object ID.
Once a buffer has been sealed, the buffer is immutable and can only be
accessed through get.
Once a buffer has been deleted, the buffer is no longer accessible.
Args:
object_id (str): A string used to identify an object.
"""
libplasma.seal(self.conn, object_id)
Args:
object_id (str): A string used to identify an object.
"""
libplasma.delete(self.conn, object_id)
def delete(self, object_id):
"""Delete the buffer in the PlasmaStore for a particular object ID.
def evict(self, num_bytes):
"""Evict some objects until to recover some bytes.
Once a buffer has been deleted, the buffer is no longer accessible.
Recover at least num_bytes bytes if possible.
Args:
object_id (str): A string used to identify an object.
"""
libplasma.delete(self.conn, object_id)
Args:
num_bytes (int): The number of bytes to attempt to recover.
"""
return libplasma.evict(self.conn, num_bytes)
def evict(self, num_bytes):
"""Evict some objects until to recover some bytes.
def transfer(self, addr, port, object_id):
"""Transfer local object with id object_id to another plasma instance
Recover at least num_bytes bytes if possible.
Args:
addr (str): IPv4 address of the plasma instance the object is sent to.
port (int): Port number of the plasma instance the object is sent to.
object_id (str): A string used to identify an object.
"""
return libplasma.transfer(self.conn, object_id, addr, port)
Args:
num_bytes (int): The number of bytes to attempt to recover.
"""
return libplasma.evict(self.conn, num_bytes)
def fetch(self, object_ids):
"""Fetch the objects with the given IDs from other plasma manager instances.
def transfer(self, addr, port, object_id):
"""Transfer local object with id object_id to another plasma instance
Args:
object_ids (List[str]): A list of strings used to identify the objects.
"""
return libplasma.fetch(self.conn, object_ids)
Args:
addr (str): IPv4 address of the plasma instance the object is sent
to.
port (int): Port number of the plasma instance the object is sent to.
object_id (str): A string used to identify an object.
"""
return libplasma.transfer(self.conn, object_id, addr, port)
def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1):
"""Wait until num_returns objects in object_ids are ready.
def fetch(self, object_ids):
"""Fetch the objects with the given IDs from other plasma managers.
Currently, the object ID arguments to wait must be unique.
Args:
object_ids (List[str]): A list of strings used to identify the
objects.
"""
return libplasma.fetch(self.conn, object_ids)
Args:
object_ids (List[str]): List of object IDs to wait for.
timeout (int): Return to the caller after timeout milliseconds.
num_returns (int): We are waiting for this number of objects to be ready.
def wait(self, object_ids, timeout=PLASMA_WAIT_TIMEOUT, num_returns=1):
"""Wait until num_returns objects in object_ids are ready.
Returns:
ready_ids, waiting_ids (List[str], List[str]): List of object IDs that
are ready and list of object IDs we might still wait on respectively.
"""
# Check that the object ID arguments are unique. The plasma manager
# currently crashes if given duplicate object IDs.
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
num_returns)
return ready_ids, list(waiting_ids)
Currently, the object ID arguments to wait must be unique.
def subscribe(self):
"""Subscribe to notifications about sealed objects."""
self.notification_fd = libplasma.subscribe(self.conn)
Args:
object_ids (List[str]): List of object IDs to wait for.
timeout (int): Return to the caller after timeout milliseconds.
num_returns (int): We are waiting for this number of objects to be
ready.
def get_next_notification(self):
"""Get the next notification from the notification socket."""
return libplasma.receive_notification(self.notification_fd)
Returns:
ready_ids, waiting_ids (List[str], List[str]): List of object IDs
that are ready and list of object IDs we might still wait on
respectively.
"""
# Check that the object ID arguments are unique. The plasma manager
# currently crashes if given duplicate object IDs.
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
ready_ids, waiting_ids = libplasma.wait(self.conn, object_ids, timeout,
num_returns)
return ready_ids, list(waiting_ids)
def subscribe(self):
"""Subscribe to notifications about sealed objects."""
self.notification_fd = libplasma.subscribe(self.conn)
def get_next_notification(self):
"""Get the next notification from the notification socket."""
return libplasma.receive_notification(self.notification_fd)
DEFAULT_PLASMA_STORE_MEMORY = 10 ** 9
def random_name():
return str(random.randint(0, 99999999))
return str(random.randint(0, 99999999))
def start_plasma_store(plasma_store_memory=DEFAULT_PLASMA_STORE_MEMORY,
use_valgrind=False, use_profiler=False,
stdout_file=None, stderr_file=None):
"""Start a plasma store process.
"""Start a plasma store process.
Args:
use_valgrind (bool): True if the plasma store should be started inside of
valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the plasma store should be started inside a
profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
redirection should happen, then this should be None.
Args:
use_valgrind (bool): True if the plasma store should be started inside of
valgrind. If this is True, use_profiler must be False.
use_profiler (bool): True if the plasma store should be started inside a
profiler. If this is True, use_valgrind must be False.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If
no redirection should happen, then this should be None.
Return:
A tuple of the name of the plasma store socket and the process ID of the
plasma store process.
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return plasma_store_name, pid
Return:
A tuple of the name of the plasma store socket and the process ID of the
plasma store process.
"""
if use_valgrind and use_profiler:
raise Exception("Cannot use valgrind and profiler at the same time.")
plasma_store_executable = os.path.join(os.path.abspath(
os.path.dirname(__file__)),
"../core/src/plasma/plasma_store")
plasma_store_name = "/tmp/plasma_store{}".format(random_name())
command = [plasma_store_executable,
"-s", plasma_store_name,
"-m", str(plasma_store_memory)]
if use_valgrind:
pid = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--leak-check-heuristics=stdstring",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
elif use_profiler:
pid = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
time.sleep(1.0)
else:
pid = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
time.sleep(0.1)
return plasma_store_name, pid
def new_port():
return random.randint(10000, 65535)
return random.randint(10000, 65535)
def start_plasma_manager(store_name, redis_address,
@@ -364,69 +377,71 @@ def start_plasma_manager(store_name, redis_address,
num_retries=20, use_valgrind=False,
run_profiler=False, stdout_file=None,
stderr_file=None):
"""Start a plasma manager and return the ports it listens on.
"""Start a plasma manager and return the ports it listens on.
Args:
store_name (str): The name of the plasma store socket.
redis_address (str): The address of the Redis server.
node_ip_address (str): The IP address of the node.
plasma_manager_port (int): The port to use for the plasma manager. If this
is not provided, a port will be generated at random.
use_valgrind (bool): True if the Plasma manager should be started inside of
valgrind and False otherwise.
stdout_file: A file handle opened for writing to redirect stdout to. If no
redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If no
redirection should happen, then this should be None.
Args:
store_name (str): The name of the plasma store socket.
redis_address (str): The address of the Redis server.
node_ip_address (str): The IP address of the node.
plasma_manager_port (int): The port to use for the plasma manager. If
this is not provided, a port will be generated at random.
use_valgrind (bool): True if the Plasma manager should be started inside
of valgrind and False otherwise.
stdout_file: A file handle opened for writing to redirect stdout to. If
no redirection should happen, then this should be None.
stderr_file: A file handle opened for writing to redirect stderr to. If
no redirection should happen, then this should be None.
Returns:
A tuple of the Plasma manager socket name, the process ID of the Plasma
manager process, and the port that the manager is listening on.
Returns:
A tuple of the Plasma manager socket name, the process ID of the Plasma
manager process, and the port that the manager is listening on.
Raises:
Exception: An exception is raised if the manager could not be started.
"""
plasma_manager_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_manager")
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
if plasma_manager_port is not None:
if num_retries != 1:
raise Exception("num_retries must be 1 if port is specified.")
else:
plasma_manager_port = new_port()
process = None
counter = 0
while counter < num_retries:
if counter > 0:
print("Plasma manager failed to start, retrying now.")
command = [plasma_manager_executable,
"-s", store_name,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen(["valgrind", "--tool=callgrind"] + command,
stdout=stdout_file, stderr=stderr_file)
Raises:
Exception: An exception is raised if the manager could not be started.
"""
plasma_manager_executable = os.path.join(
os.path.abspath(os.path.dirname(__file__)),
"../core/src/plasma/plasma_manager")
plasma_manager_name = "/tmp/plasma_manager{}".format(random_name())
if plasma_manager_port is not None:
if num_retries != 1:
raise Exception("num_retries must be 1 if port is specified.")
else:
process = subprocess.Popen(command, stdout=stdout_file,
stderr=stderr_file)
# This sleep is critical. If the plasma_manager fails to start because the
# port is already in use, then we need it to fail within 0.1 seconds.
time.sleep(0.1)
# See if the process has terminated
if process.poll() is None:
return plasma_manager_name, process, plasma_manager_port
# Generate a new port and try again.
plasma_manager_port = new_port()
counter += 1
raise Exception("Couldn't start plasma manager.")
plasma_manager_port = new_port()
process = None
counter = 0
while counter < num_retries:
if counter > 0:
print("Plasma manager failed to start, retrying now.")
command = [plasma_manager_executable,
"-s", store_name,
"-m", plasma_manager_name,
"-h", node_ip_address,
"-p", str(plasma_manager_port),
"-r", redis_address,
]
if use_valgrind:
process = subprocess.Popen(["valgrind",
"--track-origins=yes",
"--leak-check=full",
"--show-leak-kinds=all",
"--error-exitcode=1"] + command,
stdout=stdout_file, stderr=stderr_file)
elif run_profiler:
process = subprocess.Popen((["valgrind", "--tool=callgrind"] +
command),
stdout=stdout_file, stderr=stderr_file)
else:
process = subprocess.Popen(command, stdout=stdout_file,
stderr=stderr_file)
# This sleep is critical. If the plasma_manager fails to start because
# the port is already in use, then we need it to fail within 0.1
# seconds.
time.sleep(0.1)
# See if the process has terminated
if process.poll() is None:
return plasma_manager_name, process, plasma_manager_port
# Generate a new port and try again.
plasma_manager_port = new_port()
counter += 1
raise Exception("Couldn't start plasma manager.")
File diff suppressed because it is too large Load Diff
+25 -23
View File
@@ -7,39 +7,41 @@ import random
def random_object_id():
return np.random.bytes(20)
return np.random.bytes(20)
def generate_metadata(length):
metadata_buffer = bytearray(length)
if length > 0:
metadata_buffer[0] = random.randint(0, 255)
metadata_buffer[-1] = random.randint(0, 255)
for _ in range(100):
metadata_buffer[random.randint(0, length - 1)] = random.randint(0, 255)
return metadata_buffer
metadata_buffer = bytearray(length)
if length > 0:
metadata_buffer[0] = random.randint(0, 255)
metadata_buffer[-1] = random.randint(0, 255)
for _ in range(100):
metadata_buffer[random.randint(0, length - 1)] = (
random.randint(0, 255))
return metadata_buffer
def write_to_data_buffer(buff, length):
if length > 0:
buff[0] = chr(random.randint(0, 255))
buff[-1] = chr(random.randint(0, 255))
for _ in range(100):
buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
if length > 0:
buff[0] = chr(random.randint(0, 255))
buff[-1] = chr(random.randint(0, 255))
for _ in range(100):
buff[random.randint(0, length - 1)] = chr(random.randint(0, 255))
def create_object_with_id(client, object_id, data_size, metadata_size,
seal=True):
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
write_to_data_buffer(memory_buffer, data_size)
if seal:
client.seal(object_id)
return memory_buffer, metadata
metadata = generate_metadata(metadata_size)
memory_buffer = client.create(object_id, data_size, metadata)
write_to_data_buffer(memory_buffer, data_size)
if seal:
client.seal(object_id)
return memory_buffer, metadata
def create_object(client, data_size, metadata_size, seal=True):
object_id = random_object_id()
memory_buffer, metadata = create_object_with_id(client, object_id, data_size,
metadata_size, seal=seal)
return object_id, memory_buffer, metadata
object_id = random_object_id()
memory_buffer, metadata = create_object_with_id(client, object_id,
data_size, metadata_size,
seal=seal)
return object_id, memory_buffer, metadata
+80 -79
View File
@@ -16,97 +16,98 @@ use_tf100_api = (distutils.version.LooseVersion(tf.VERSION) >=
class LSTMPolicy(Policy):
def setup_graph(self, ob_space, ac_space):
"""Setup model used for Policy.
def setup_graph(self, ob_space, ac_space):
"""Setup model used for Policy.
In this A3C implementation, both the Critic and the Actor share the model.
"""
self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
In this A3C implementation, both the Critic and the Actor share the
model.
"""
self.x = x = tf.placeholder(tf.float32, [None] + list(ob_space))
for i in range(4):
x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
# Introduce a "fake" batch dimension of 1 after flatten so that we can do
# LSTM over the time dim.
x = tf.expand_dims(flatten(x), [0])
for i in range(4):
x = tf.nn.elu(conv2d(x, 32, "l{}".format(i + 1), [3, 3], [2, 2]))
# Introduce a "fake" batch dimension of 1 after flatten so that we can
# do LSTM over the time dim.
x = tf.expand_dims(flatten(x), [0])
size = 256
if use_tf100_api:
lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
else:
lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
self.state_size = lstm.state_size
step_size = tf.shape(self.x)[:1]
size = 256
if use_tf100_api:
lstm = rnn.BasicLSTMCell(size, state_is_tuple=True)
else:
lstm = rnn.rnn_cell.BasicLSTMCell(size, state_is_tuple=True)
self.state_size = lstm.state_size
step_size = tf.shape(self.x)[:1]
c_init = np.zeros((1, lstm.state_size.c), np.float32)
h_init = np.zeros((1, lstm.state_size.h), np.float32)
self.state_init = [c_init, h_init]
c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
self.state_in = [c_in, h_in]
c_init = np.zeros((1, lstm.state_size.c), np.float32)
h_init = np.zeros((1, lstm.state_size.h), np.float32)
self.state_init = [c_init, h_init]
c_in = tf.placeholder(tf.float32, [1, lstm.state_size.c])
h_in = tf.placeholder(tf.float32, [1, lstm.state_size.h])
self.state_in = [c_in, h_in]
if use_tf100_api:
state_in = rnn.LSTMStateTuple(c_in, h_in)
else:
state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
lstm, x, initial_state=state_in, sequence_length=step_size,
time_major=False)
lstm_c, lstm_h = lstm_state
x = tf.reshape(lstm_outputs, [-1, size])
self.logits = linear(x, ac_space, "action",
normalized_columns_initializer(0.01))
self.vf = tf.reshape(linear(x, 1, "value",
normalized_columns_initializer(1.0)), [-1])
self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
self.sample = categorical_sample(self.logits, ac_space)[0, :]
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
self.global_step = tf.get_variable(
"global_step", [], tf.int32,
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)
if use_tf100_api:
state_in = rnn.LSTMStateTuple(c_in, h_in)
else:
state_in = rnn.rnn_cell.LSTMStateTuple(c_in, h_in)
lstm_outputs, lstm_state = tf.nn.dynamic_rnn(
lstm, x, initial_state=state_in, sequence_length=step_size,
time_major=False)
lstm_c, lstm_h = lstm_state
x = tf.reshape(lstm_outputs, [-1, size])
self.logits = linear(x, ac_space, "action",
normalized_columns_initializer(0.01))
self.vf = tf.reshape(linear(x, 1, "value",
normalized_columns_initializer(1.0)), [-1])
self.state_out = [lstm_c[:1, :], lstm_h[:1, :]]
self.sample = categorical_sample(self.logits, ac_space)[0, :]
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
self.global_step = tf.get_variable(
"global_step", [], tf.int32,
initializer=tf.constant_initializer(0, dtype=tf.int32),
trainable=False)
def get_gradients(self, batch):
"""Computing the gradient is actually model-dependent.
def get_gradients(self, batch):
"""Computing the gradient is actually model-dependent.
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.state_in[0]: batch.features[0],
self.state_in[1]: batch.features[1]
}
self.local_steps += 1
return self.sess.run(self.grads, feed_dict=feed_dict)
The LSTM needs its hidden states in order to compute the gradient
accurately.
"""
feed_dict = {
self.x: batch.si,
self.ac: batch.a,
self.adv: batch.adv,
self.r: batch.r,
self.state_in[0]: batch.features[0],
self.state_in[1]: batch.features[1]
}
self.local_steps += 1
return self.sess.run(self.grads, feed_dict=feed_dict)
def act(self, ob, c, h):
return self.sess.run([self.sample, self.vf] + self.state_out,
{self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})
def act(self, ob, c, h):
return self.sess.run([self.sample, self.vf] + self.state_out,
{self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})
def value(self, ob, c, h):
return self.sess.run(self.vf, {self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})[0]
def value(self, ob, c, h):
return self.sess.run(self.vf, {self.x: [ob],
self.state_in[0]: c,
self.state_in[1]: h})[0]
def get_initial_features(self):
return self.state_init
def get_initial_features(self):
return self.state_init
class RawLSTMPolicy(LSTMPolicy):
def get_weights(self):
if not hasattr(self, "_weights"):
self._weights = self.variables.get_weights()
return self._weights
def get_weights(self):
if not hasattr(self, "_weights"):
self._weights = self.variables.get_weights()
return self._weights
def set_weights(self, weights):
self._weights = weights
def set_weights(self, weights):
self._weights = weights
def model_update(self, grads):
for var, grad in zip(self.var_list, grads):
self._weights[var.name[:-2]] -= 1e-4 * grad
def model_update(self, grads):
for var, grad in zip(self.var_list, grads):
self._weights[var.name[:-2]] -= 1e-4 * grad
+97 -96
View File
@@ -22,108 +22,109 @@ DEFAULT_CONFIG = {
@ray.remote
class Runner(object):
"""Actor object to start running simulation on workers.
"""Actor object to start running simulation on workers.
The gradient computation is also executed from this object.
"""
def __init__(self, env_name, actor_id, logdir, start=True):
env = create_env(env_name)
self.id = actor_id
num_actions = env.action_space.n
self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
actor_id)
self.runner = RunnerThread(env, self.policy, 20)
self.env = env
self.logdir = logdir
if start:
self.start()
def pull_batch_from_queue(self):
"""Take a rollout from the queue of the thread runner."""
rollout = self.runner.queue.get(timeout=600.0)
if isinstance(rollout, BaseException):
raise rollout
while not rollout.terminal:
try:
part = self.runner.queue.get_nowait()
if isinstance(part, BaseException):
raise rollout
rollout.extend(part)
except queue.Empty:
break
return rollout
def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
Calling this clears the queue of completed rollout metrics.
The gradient computation is also executed from this object.
"""
completed = []
while True:
try:
completed.append(self.runner.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
def __init__(self, env_name, actor_id, logdir, start=True):
env = create_env(env_name)
self.id = actor_id
num_actions = env.action_space.n
self.policy = LSTMPolicy(env.observation_space.shape, num_actions,
actor_id)
self.runner = RunnerThread(env, self.policy, 20)
self.env = env
self.logdir = logdir
if start:
self.start()
def start(self):
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
self.runner.start_runner(self.policy.sess, summary_writer)
def pull_batch_from_queue(self):
"""Take a rollout from the queue of the thread runner."""
rollout = self.runner.queue.get(timeout=600.0)
if isinstance(rollout, BaseException):
raise rollout
while not rollout.terminal:
try:
part = self.runner.queue.get_nowait()
if isinstance(part, BaseException):
raise rollout
rollout.extend(part)
except queue.Empty:
break
return rollout
def compute_gradient(self, params):
self.policy.set_weights(params)
rollout = self.pull_batch_from_queue()
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
gradient = self.policy.get_gradients(batch)
info = {"id": self.id,
"size": len(batch.a)}
return gradient, info
def get_completed_rollout_metrics(self):
"""Returns metrics on previously completed rollouts.
Calling this clears the queue of completed rollout metrics.
"""
completed = []
while True:
try:
completed.append(self.runner.metrics_queue.get_nowait())
except queue.Empty:
break
return completed
def start(self):
summary_writer = tf.summary.FileWriter(
os.path.join(self.logdir, "agent_%d" % self.id))
self.summary_writer = summary_writer
self.runner.start_runner(self.policy.sess, summary_writer)
def compute_gradient(self, params):
self.policy.set_weights(params)
rollout = self.pull_batch_from_queue()
batch = process_rollout(rollout, gamma=0.99, lambda_=1.0)
gradient = self.policy.get_gradients(batch)
info = {"id": self.id,
"size": len(batch.a)}
return gradient, info
class A3C(Algorithm):
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "A3C"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
self.env = create_env(env_name)
self.policy = LSTMPolicy(
self.env.observation_space.shape, self.env.action_space.n, 0)
self.agents = [
Runner.remote(env_name, i, self.logdir)
for i in range(config["num_workers"])]
self.parameters = self.policy.get_weights()
self.iteration = 0
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "A3C"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
self.env = create_env(env_name)
self.policy = LSTMPolicy(
self.env.observation_space.shape, self.env.action_space.n, 0)
self.agents = [
Runner.remote(env_name, i, self.logdir)
for i in range(config["num_workers"])]
self.parameters = self.policy.get_weights()
self.iteration = 0
def train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
max_batches = self.config["num_batches_per_iteration"]
batches_so_far = len(gradient_list)
while gradient_list:
done_id, gradient_list = ray.wait(gradient_list)
gradient, info = ray.get(done_id)[0]
self.policy.model_update(gradient)
self.parameters = self.policy.get_weights()
if batches_so_far < max_batches:
batches_so_far += 1
gradient_list.extend(
[self.agents[info["id"]].compute_gradient.remote(self.parameters)])
res = self.fetch_metrics_from_workers()
self.iteration += 1
return res
def train(self):
gradient_list = [
agent.compute_gradient.remote(self.parameters)
for agent in self.agents]
max_batches = self.config["num_batches_per_iteration"]
batches_so_far = len(gradient_list)
while gradient_list:
done_id, gradient_list = ray.wait(gradient_list)
gradient, info = ray.get(done_id)[0]
self.policy.model_update(gradient)
self.parameters = self.policy.get_weights()
if batches_so_far < max_batches:
batches_so_far += 1
gradient_list.extend(
[self.agents[info["id"]].compute_gradient.remote(
self.parameters)])
res = self.fetch_metrics_from_workers()
self.iteration += 1
return res
def fetch_metrics_from_workers(self):
episode_rewards = []
episode_lengths = []
metric_lists = [
a.get_completed_rollout_metrics.remote() for a in self.agents]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
res = TrainingResult(
self.experiment_id.hex, self.iteration,
np.mean(episode_rewards), np.mean(episode_lengths), dict())
return res
def fetch_metrics_from_workers(self):
episode_rewards = []
episode_lengths = []
metric_lists = [
a.get_completed_rollout_metrics.remote() for a in self.agents]
for metrics in metric_lists:
for episode in ray.get(metrics):
episode_lengths.append(episode.episode_length)
episode_rewards.append(episode.episode_reward)
res = TrainingResult(
self.experiment_id.hex, self.iteration,
np.mean(episode_rewards), np.mean(episode_lengths), dict())
return res
+71 -69
View File
@@ -14,94 +14,96 @@ logger.setLevel(logging.INFO)
def create_env(env_id):
env = gym.make(env_id)
env = AtariProcessing(env)
env = Diagnostic(env)
return env
env = gym.make(env_id)
env = AtariProcessing(env)
env = Diagnostic(env)
return env
def _process_frame42(frame):
frame = frame[34:(34 + 160), :160]
# Resize by half, then down to 42x42 (essentially mipmapping). If we resize
# directly we lose pixels that, when mapped to 42x42, aren't close enough to
# the pixel boundary.
frame = cv2.resize(frame, (80, 80))
frame = cv2.resize(frame, (42, 42))
frame = frame.mean(2)
frame = frame.astype(np.float32)
frame *= (1.0 / 255.0)
frame = np.reshape(frame, [42, 42, 1])
return frame
frame = frame[34:(34 + 160), :160]
# Resize by half, then down to 42x42 (essentially mipmapping). If we resize
# directly we lose pixels that, when mapped to 42x42, aren't close enough
# to the pixel boundary.
frame = cv2.resize(frame, (80, 80))
frame = cv2.resize(frame, (42, 42))
frame = frame.mean(2)
frame = frame.astype(np.float32)
frame *= (1.0 / 255.0)
frame = np.reshape(frame, [42, 42, 1])
return frame
class AtariProcessing(gym.ObservationWrapper):
def __init__(self, env=None):
super(AtariProcessing, self).__init__(env)
self.observation_space = Box(0.0, 1.0, [42, 42, 1])
def __init__(self, env=None):
super(AtariProcessing, self).__init__(env)
self.observation_space = Box(0.0, 1.0, [42, 42, 1])
def _observation(self, observation):
return _process_frame42(observation)
def _observation(self, observation):
return _process_frame42(observation)
class Diagnostic(gym.Wrapper):
def __init__(self, env=None):
super(Diagnostic, self).__init__(env)
self.diagnostics = DiagnosticsLogger()
def __init__(self, env=None):
super(Diagnostic, self).__init__(env)
self.diagnostics = DiagnosticsLogger()
def _reset(self):
observation = self.env.reset()
return self.diagnostics._after_reset(observation)
def _reset(self):
observation = self.env.reset()
return self.diagnostics._after_reset(observation)
def _step(self, action):
results = self.env.step(action)
return self.diagnostics._after_step(*results)
def _step(self, action):
results = self.env.step(action)
return self.diagnostics._after_step(*results)
class DiagnosticsLogger(object):
def __init__(self, log_interval=503):
self._episode_time = time.time()
self._last_time = time.time()
self._local_t = 0
self._log_interval = log_interval
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
self._last_episode_id = -1
def __init__(self, log_interval=503):
self._episode_time = time.time()
self._last_time = time.time()
self._local_t = 0
self._log_interval = log_interval
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
self._last_episode_id = -1
def _after_reset(self, observation):
logger.info("Resetting environment")
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
return observation
def _after_reset(self, observation):
logger.info("Resetting environment")
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
return observation
def _after_step(self, observation, reward, done, info):
to_log = {}
if self._episode_length == 0:
self._episode_time = time.time()
def _after_step(self, observation, reward, done, info):
to_log = {}
if self._episode_length == 0:
self._episode_time = time.time()
self._local_t += 1
self._local_t += 1
if self._local_t % self._log_interval == 0:
cur_time = time.time()
self._last_time = cur_time
if self._local_t % self._log_interval == 0:
cur_time = time.time()
self._last_time = cur_time
if reward is not None:
self._episode_reward += reward
if observation is not None:
self._episode_length += 1
self._all_rewards.append(reward)
if reward is not None:
self._episode_reward += reward
if observation is not None:
self._episode_length += 1
self._all_rewards.append(reward)
if done:
logger.info("Episode terminating: episode_reward=%s episode_length=%s",
self._episode_reward, self._episode_length)
total_time = time.time() - self._episode_time
to_log["global/episode_reward"] = self._episode_reward
to_log["global/episode_length"] = self._episode_length
to_log["global/episode_time"] = total_time
to_log["global/reward_per_time"] = self._episode_reward / total_time
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
if done:
logger.info("Episode terminating: episode_reward=%s "
"episode_length=%s",
self._episode_reward, self._episode_length)
total_time = time.time() - self._episode_time
to_log["global/episode_reward"] = self._episode_reward
to_log["global/episode_length"] = self._episode_length
to_log["global/episode_time"] = total_time
to_log["global/reward_per_time"] = (self._episode_reward /
total_time)
self._episode_reward = 0
self._episode_length = 0
self._all_rewards = []
return observation, reward, done, to_log
return observation, reward, done, to_log
+15 -15
View File
@@ -11,22 +11,22 @@ from ray.rllib.a3c import A3C, DEFAULT_CONFIG
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
parser.add_argument("--environment", default="PongDeterministic-v3",
type=str, help="The gym environment to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-workers", default=4, type=int,
help="The number of A3C workers to use>")
parser = argparse.ArgumentParser(description="Run the A3C algorithm.")
parser.add_argument("--environment", default="PongDeterministic-v3",
type=str, help="The gym environment to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--num-workers", default=4, type=int,
help="The number of A3C workers to use>")
args = parser.parse_args()
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)
args = parser.parse_args()
ray.init(redis_address=args.redis_address, num_cpus=args.num_workers)
config = DEFAULT_CONFIG.copy()
config["num_workers"] = args.num_workers
config = DEFAULT_CONFIG.copy()
config["num_workers"] = args.num_workers
a3c = A3C(args.environment, config)
a3c = A3C(args.environment, config)
while True:
res = a3c.train()
print("current status: {}".format(res))
while True:
res = a3c.train()
print("current status: {}".format(res))
+101 -100
View File
@@ -8,138 +8,139 @@ import ray
class Policy(object):
"""The policy base class."""
def __init__(self, ob_space, ac_space, task, name="local"):
self.local_steps = 0
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
self.g = tf.Graph()
with self.g.as_default(), tf.device(worker_device):
with tf.variable_scope(name):
self.setup_graph(ob_space, ac_space)
assert all([hasattr(self, attr)
for attr in ["vf", "logits", "x", "var_list"]])
print("Setting up loss")
self.setup_loss(ac_space)
self.initialize()
"""The policy base class."""
def __init__(self, ob_space, ac_space, task, name="local"):
self.local_steps = 0
worker_device = "/job:localhost/replica:0/task:0/cpu:0"
self.g = tf.Graph()
with self.g.as_default(), tf.device(worker_device):
with tf.variable_scope(name):
self.setup_graph(ob_space, ac_space)
assert all([hasattr(self, attr)
for attr in ["vf", "logits", "x", "var_list"]])
print("Setting up loss")
self.setup_loss(ac_space)
self.initialize()
def setup_graph(self):
raise NotImplementedError
def setup_graph(self):
raise NotImplementedError
def setup_loss(self, num_actions, summarize=True):
self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
self.adv = tf.placeholder(tf.float32, [None], name="adv")
self.r = tf.placeholder(tf.float32, [None], name="r")
def setup_loss(self, num_actions, summarize=True):
self.ac = tf.placeholder(tf.float32, [None, num_actions], name="ac")
self.adv = tf.placeholder(tf.float32, [None], name="adv")
self.r = tf.placeholder(tf.float32, [None], name="r")
log_prob_tf = tf.nn.log_softmax(self.logits)
prob_tf = tf.nn.softmax(self.logits)
log_prob_tf = tf.nn.log_softmax(self.logits)
prob_tf = tf.nn.softmax(self.logits)
# The "policy gradients" loss: its derivative is precisely the policy
# gradient. Notice that self.ac is a placeholder that is provided
# externally. adv will contain the advantages, as calculated in
# process_rollout.
pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
[1]) * self.adv)
# The "policy gradients" loss: its derivative is precisely the policy
# gradient. Notice that self.ac is a placeholder that is provided
# externally. adv will contain the advantages, as calculated in
# process_rollout.
pi_loss = - tf.reduce_sum(tf.reduce_sum(log_prob_tf * self.ac,
[1]) * self.adv)
# loss of value function
vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
# loss of value function
vf_loss = 0.5 * tf.reduce_sum(tf.square(self.vf - self.r))
vf_loss = tf.Print(vf_loss, [vf_loss], "Value Fn Loss")
entropy = - tf.reduce_sum(prob_tf * log_prob_tf)
bs = tf.to_float(tf.shape(self.x)[0])
self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
bs = tf.to_float(tf.shape(self.x)[0])
self.loss = pi_loss + 0.5 * vf_loss - entropy * 0.01
grads = tf.gradients(self.loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
grads = tf.gradients(self.loss, self.var_list)
self.grads, _ = tf.clip_by_global_norm(grads, 40.0)
grads_and_vars = list(zip(self.grads, self.var_list))
opt = tf.train.AdamOptimizer(1e-4)
self._apply_gradients = opt.apply_gradients(grads_and_vars)
grads_and_vars = list(zip(self.grads, self.var_list))
opt = tf.train.AdamOptimizer(1e-4)
self._apply_gradients = opt.apply_gradients(grads_and_vars)
if summarize:
tf.summary.scalar("model/policy_loss", pi_loss / bs)
tf.summary.scalar("model/value_loss", vf_loss / bs)
tf.summary.scalar("model/entropy", entropy / bs)
tf.summary.image("model/state", self.x)
self.summary_op = tf.summary.merge_all()
if summarize:
tf.summary.scalar("model/policy_loss", pi_loss / bs)
tf.summary.scalar("model/value_loss", vf_loss / bs)
tf.summary.scalar("model/entropy", entropy / bs)
tf.summary.image("model/state", self.x)
self.summary_op = tf.summary.merge_all()
def initialize(self):
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
self.variables = ray.experimental.TensorFlowVariables(self.loss, self.sess)
self.sess.run(tf.global_variables_initializer())
def initialize(self):
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
intra_op_parallelism_threads=1, inter_op_parallelism_threads=2))
self.variables = ray.experimental.TensorFlowVariables(self.loss,
self.sess)
self.sess.run(tf.global_variables_initializer())
def model_update(self, grads):
feed_dict = {self.grads[i]: grads[i]
for i in range(len(grads))}
self.sess.run(self._apply_gradients, feed_dict=feed_dict)
def model_update(self, grads):
feed_dict = {self.grads[i]: grads[i]
for i in range(len(grads))}
self.sess.run(self._apply_gradients, feed_dict=feed_dict)
def get_weights(self):
weights = self.variables.get_weights()
return weights
def get_weights(self):
weights = self.variables.get_weights()
return weights
def set_weights(self, weights):
self.variables.set_weights(weights)
def set_weights(self, weights):
self.variables.set_weights(weights)
def get_gradients(self, batch):
raise NotImplementedError
def get_gradients(self, batch):
raise NotImplementedError
def get_vf_loss(self):
raise NotImplementedError
def get_vf_loss(self):
raise NotImplementedError
def act(self, ob):
raise NotImplementedError
def act(self, ob):
raise NotImplementedError
def value(self, ob):
raise NotImplementedError
def value(self, ob):
raise NotImplementedError
def normalized_columns_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def flatten(x):
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
def conv2d(x, num_filters, name, filter_size=(3, 3), stride=(1, 1), pad="SAME",
dtype=tf.float32, collections=None):
with tf.variable_scope(name):
stride_shape = [1, stride[0], stride[1], 1]
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
num_filters]
with tf.variable_scope(name):
stride_shape = [1, stride[0], stride[1], 1]
filter_shape = [filter_size[0], filter_size[1], int(x.get_shape()[3]),
num_filters]
# There are "num input feature maps * filter height * filter width"
# inputs to each hidden unit.
fan_in = np.prod(filter_shape[:3])
# Each unit in the lower layer receives a gradient from:
# "num output feature maps * filter height * filter width" / pooling size.
fan_out = np.prod(filter_shape[:2]) * num_filters
# Initialize weights with random weights.
w_bound = np.sqrt(6 / (fan_in + fan_out))
# There are "num input feature maps * filter height * filter width"
# inputs to each hidden unit.
fan_in = np.prod(filter_shape[:3])
# Each unit in the lower layer receives a gradient from: "num output
# feature maps * filter height * filter width" / pooling size.
fan_out = np.prod(filter_shape[:2]) * num_filters
# Initialize weights with random weights.
w_bound = np.sqrt(6 / (fan_in + fan_out))
w = tf.get_variable("W", filter_shape, dtype,
tf.random_uniform_initializer(-w_bound, w_bound),
collections=collections)
b = tf.get_variable("b", [1, 1, 1, num_filters],
initializer=tf.constant_initializer(0.0),
collections=collections)
return tf.nn.conv2d(x, w, stride_shape, pad) + b
w = tf.get_variable("W", filter_shape, dtype,
tf.random_uniform_initializer(-w_bound, w_bound),
collections=collections)
b = tf.get_variable("b", [1, 1, 1, num_filters],
initializer=tf.constant_initializer(0.0),
collections=collections)
return tf.nn.conv2d(x, w, stride_shape, pad) + b
def linear(x, size, name, initializer=None, bias_init=0):
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
initializer=initializer)
b = tf.get_variable(name + "/b", [size],
initializer=tf.constant_initializer(bias_init))
return tf.matmul(x, w) + b
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
initializer=initializer)
b = tf.get_variable(name + "/b", [size],
initializer=tf.constant_initializer(bias_init))
return tf.matmul(x, w) + b
def categorical_sample(logits, d):
value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1],
keep_dims=True),
1), [1])
return tf.one_hot(value, d)
value = tf.squeeze(tf.multinomial(logits - tf.reduce_max(logits, [1],
keep_dims=True),
1), [1])
return tf.one_hot(value, d)
+132 -131
View File
@@ -11,26 +11,26 @@ import threading
def discount(x, gamma):
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
return scipy.signal.lfilter([1], [1, -gamma], x[::-1], axis=0)[::-1]
def process_rollout(rollout, gamma, lambda_=1.0):
"""Given a rollout, compute its returns and the advantage."""
batch_si = np.asarray(rollout.states)
batch_a = np.asarray(rollout.actions)
rewards = np.asarray(rollout.rewards)
vpred_t = np.asarray(rollout.values + [rollout.r])
"""Given a rollout, compute its returns and the advantage."""
batch_si = np.asarray(rollout.states)
batch_a = np.asarray(rollout.actions)
rewards = np.asarray(rollout.rewards)
vpred_t = np.asarray(rollout.values + [rollout.r])
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
batch_r = discount(rewards_plus_v, gamma)[:-1]
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes "Generalized Advantage Estimation":
# https://arxiv.org/abs/1506.02438
batch_adv = discount(delta_t, gamma * lambda_)
rewards_plus_v = np.asarray(rollout.rewards + [rollout.r])
batch_r = discount(rewards_plus_v, gamma)[:-1]
delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
# This formula for the advantage comes "Generalized Advantage Estimation":
# https://arxiv.org/abs/1506.02438
batch_adv = discount(delta_t, gamma * lambda_)
features = rollout.features[0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
features)
features = rollout.features[0]
return Batch(batch_si, batch_a, batch_adv, batch_r, rollout.terminal,
features)
Batch = namedtuple(
@@ -41,142 +41,143 @@ CompletedRollout = namedtuple(
class PartialRollout(object):
"""A piece of a complete rollout.
"""A piece of a complete rollout.
We run our agent, and process its experience once it has processed enough
steps.
"""
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.r = 0.0
self.terminal = False
self.features = []
We run our agent, and process its experience once it has processed enough
steps.
"""
def __init__(self):
self.states = []
self.actions = []
self.rewards = []
self.values = []
self.r = 0.0
self.terminal = False
self.features = []
def add(self, state, action, reward, value, terminal, features):
self.states += [state]
self.actions += [action]
self.rewards += [reward]
self.values += [value]
self.terminal = terminal
self.features += [features]
def add(self, state, action, reward, value, terminal, features):
self.states += [state]
self.actions += [action]
self.rewards += [reward]
self.values += [value]
self.terminal = terminal
self.features += [features]
def extend(self, other):
assert not self.terminal
self.states.extend(other.states)
self.actions.extend(other.actions)
self.rewards.extend(other.rewards)
self.values.extend(other.values)
self.r = other.r
self.terminal = other.terminal
self.features.extend(other.features)
def extend(self, other):
assert not self.terminal
self.states.extend(other.states)
self.actions.extend(other.actions)
self.rewards.extend(other.rewards)
self.values.extend(other.values)
self.r = other.r
self.terminal = other.terminal
self.features.extend(other.features)
class RunnerThread(threading.Thread):
"""This thread interacts with the environment and tells it what to do."""
def __init__(self, env, policy, num_local_steps, visualise=False):
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.env = env
self.last_features = None
self.policy = policy
self.daemon = True
self.sess = None
self.summary_writer = None
self.visualise = visualise
"""This thread interacts with the environment and tells it what to do."""
def __init__(self, env, policy, num_local_steps, visualise=False):
threading.Thread.__init__(self)
self.queue = queue.Queue(5)
self.metrics_queue = queue.Queue()
self.num_local_steps = num_local_steps
self.env = env
self.last_features = None
self.policy = policy
self.daemon = True
self.sess = None
self.summary_writer = None
self.visualise = visualise
def start_runner(self, sess, summary_writer):
self.sess = sess
self.summary_writer = summary_writer
self.start()
def start_runner(self, sess, summary_writer):
self.sess = sess
self.summary_writer = summary_writer
self.start()
def run(self):
try:
with self.sess.as_default():
self._run()
except BaseException as e:
self.queue.put(e)
raise e
def run(self):
try:
with self.sess.as_default():
self._run()
except BaseException as e:
self.queue.put(e)
raise e
def _run(self):
rollout_provider = env_runner(
self.env, self.policy, self.num_local_steps,
self.summary_writer, self.visualise)
while True:
# The timeout variable exists because apparently, if one worker dies, the
# other workers won't die with it, unless the timeout is set to some
# large number. This is an empirical observation.
item = next(rollout_provider)
if isinstance(item, CompletedRollout):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
def _run(self):
rollout_provider = env_runner(
self.env, self.policy, self.num_local_steps,
self.summary_writer, self.visualise)
while True:
# The timeout variable exists because apparently, if one worker
# dies, the other workers won't die with it, unless the timeout is
# set to some large number. This is an empirical observation.
item = next(rollout_provider)
if isinstance(item, CompletedRollout):
self.metrics_queue.put(item)
else:
self.queue.put(item, timeout=600.0)
def env_runner(env, policy, num_local_steps, summary_writer, render):
"""This implements the logic of the thread runner.
"""This implements the logic of the thread runner.
It continually runs the policy, and as long as the rollout exceeds a certain
length, the thread runner appends the policy to the queue.
"""
last_state = env.reset()
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
last_features = policy.get_initial_features()
length = 0
rewards = 0
rollout_number = 0
It continually runs the policy, and as long as the rollout exceeds a
certain length, the thread runner appends the policy to the queue.
"""
last_state = env.reset()
timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
last_features = policy.get_initial_features()
length = 0
rewards = 0
rollout_number = 0
while True:
terminal_end = False
rollout = PartialRollout()
while True:
terminal_end = False
rollout = PartialRollout()
for _ in range(num_local_steps):
fetched = policy.act(last_state, *last_features)
action, value_, features = fetched[0], fetched[1], fetched[2:]
# Argmax to convert from one-hot.
state, reward, terminal, info = env.step(action.argmax())
if render:
env.render()
for _ in range(num_local_steps):
fetched = policy.act(last_state, *last_features)
action, value_, features = fetched[0], fetched[1], fetched[2:]
# Argmax to convert from one-hot.
state, reward, terminal, info = env.step(action.argmax())
if render:
env.render()
length += 1
rewards += reward
if length >= timestep_limit:
terminal = True
length += 1
rewards += reward
if length >= timestep_limit:
terminal = True
# Collect the experience.
rollout.add(last_state, action, reward, value_, terminal, last_features)
# Collect the experience.
rollout.add(last_state, action, reward, value_, terminal,
last_features)
last_state = state
last_features = features
last_state = state
last_features = features
if info:
summary = tf.Summary()
for k, v in info.items():
summary.value.add(tag=k, simple_value=float(v))
summary_writer.add_summary(summary, rollout_number)
summary_writer.flush()
if info:
summary = tf.Summary()
for k, v in info.items():
summary.value.add(tag=k, simple_value=float(v))
summary_writer.add_summary(summary, rollout_number)
summary_writer.flush()
if terminal:
terminal_end = True
yield CompletedRollout(length, rewards)
if terminal:
terminal_end = True
yield CompletedRollout(length, rewards)
if length >= timestep_limit or not env.metadata.get("semantics"
".autoreset"):
last_state = env.reset()
last_features = policy.get_initial_features()
rollout_number += 1
length = 0
rewards = 0
break
if (length >= timestep_limit or
not env.metadata.get("semantics.autoreset")):
last_state = env.reset()
last_features = policy.get_initial_features()
rollout_number += 1
length = 0
rewards = 0
break
if not terminal_end:
rollout.r = policy.value(last_state, *last_features)
if not terminal_end:
rollout.r = policy.value(last_state, *last_features)
# Once we have enough experience, yield it, and have the ThreadRunner
# place it on a queue.
yield rollout
# Once we have enough experience, yield it, and have the ThreadRunner
# place it on a queue.
yield rollout
+67 -66
View File
@@ -9,39 +9,39 @@ import tempfile
import uuid
import smart_open
if sys.version_info[0] == 2:
import cStringIO as StringIO
import cStringIO as StringIO
elif sys.version_info[0] == 3:
import io as StringIO
import io as StringIO
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class RLLibEncoder(json.JSONEncoder):
def default(self, value):
if isinstance(value, np.float32) or isinstance(value, np.float64):
if np.isnan(value):
return None
else:
return float(value)
def default(self, value):
if isinstance(value, np.float32) or isinstance(value, np.float64):
if np.isnan(value):
return None
else:
return float(value)
class RLLibLogger(object):
"""Writing small amounts of data to S3 with real-time updates.
"""
"""Writing small amounts of data to S3 with real-time updates.
"""
def __init__(self, uri):
self.result_buffer = StringIO.StringIO()
self.uri = uri
def __init__(self, uri):
self.result_buffer = StringIO.StringIO()
self.uri = uri
def write(self, b):
# TODO(pcm): At the moment we are writing the whole results output from
# the beginning in each iteration. This will write O(n^2) bytes where n
# is the number of bytes printed so far. Fix this! This should at least
# only write the last 5MBs (S3 chunksize).
with smart_open.smart_open(self.uri, "w") as f:
self.result_buffer.write(b)
f.write(self.result_buffer.getvalue())
def write(self, b):
# TODO(pcm): At the moment we are writing the whole results output from
# the beginning in each iteration. This will write O(n^2) bytes where n
# is the number of bytes printed so far. Fix this! This should at least
# only write the last 5MBs (S3 chunksize).
with smart_open.smart_open(self.uri, "w") as f:
self.result_buffer.write(b)
f.write(self.result_buffer.getvalue())
TrainingResult = namedtuple("TrainingResult", [
@@ -54,54 +54,55 @@ TrainingResult = namedtuple("TrainingResult", [
class Algorithm(object):
"""All RLlib algorithms extend this base class.
"""All RLlib algorithms extend this base class.
Algorithm objects retain internal model state between calls to train(), so
you should create a new algorithm instance for each training session.
Algorithm objects retain internal model state between calls to train(), so
you should create a new algorithm instance for each training session.
Attributes:
env_name (str): Name of the OpenAI gym environment to train against.
config (obj): Algorithm-specific configuration data.
logdir (str): Directory in which training outputs should be placed.
TODO(ekl): support checkpoint / restore of training state.
"""
def __init__(self, env_name, config, upload_dir="file:///tmp/ray"):
"""Initialize an RLLib algorithm.
Args:
env_name (str): The name of the OpenAI gym environment to use.
Attributes:
env_name (str): Name of the OpenAI gym environment to train against.
config (obj): Algorithm-specific configuration data.
upload_dir (str): Root directory into which the output directory
should be placed. Can be local like file:///tmp/ray/ or on S3
like s3://bucketname/.
"""
self.experiment_id = uuid.uuid4()
self.env_name = env_name
self.config = config
self.config.update({"experiment_id": self.experiment_id.hex})
self.config.update({"env_name": env_name})
prefix = "{}_{}_{}".format(
env_name,
self.__class__.__name__,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if upload_dir.startswith("file"):
self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix, dir="/tmp/ray")
else:
self.logdir = os.path.join(upload_dir, prefix)
log_path = os.path.join(self.logdir, "config.json")
with smart_open.smart_open(log_path, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
logger.info(
"%s algorithm created with logdir '%s'",
self.__class__.__name__, self.logdir)
logdir (str): Directory in which training outputs should be placed.
def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
TODO(ekl): support checkpoint / restore of training state.
"""
raise NotImplementedError
def __init__(self, env_name, config, upload_dir="file:///tmp/ray"):
"""Initialize an RLLib algorithm.
Args:
env_name (str): The name of the OpenAI gym environment to use.
config (obj): Algorithm-specific configuration data.
upload_dir (str): Root directory into which the output directory
should be placed. Can be local like file:///tmp/ray/ or on S3
like s3://bucketname/.
"""
self.experiment_id = uuid.uuid4()
self.env_name = env_name
self.config = config
self.config.update({"experiment_id": self.experiment_id.hex})
self.config.update({"env_name": env_name})
prefix = "{}_{}_{}".format(
env_name,
self.__class__.__name__,
datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
if upload_dir.startswith("file"):
self.logdir = "file://" + tempfile.mkdtemp(prefix=prefix,
dir="/tmp/ray")
else:
self.logdir = os.path.join(upload_dir, prefix)
log_path = os.path.join(self.logdir, "config.json")
with smart_open.smart_open(log_path, "w") as f:
json.dump(self.config, f, sort_keys=True, cls=RLLibEncoder)
logger.info(
"%s algorithm created with logdir '%s'",
self.__class__.__name__, self.logdir)
def train(self):
"""Runs one logical iteration of training.
Returns:
A TrainingResult that describes training progress.
"""
raise NotImplementedError
+177 -170
View File
@@ -80,198 +80,205 @@ from ray.rllib.dqn.common import tf_util as U
def build_act(make_obs_ph, q_func, num_actions, scope="deepq", reuse=None):
"""Creates the act function:
"""Creates the act function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that take a name and creates a placeholder of input with
that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of
every action.
num_actions: int
number of actions.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the
scope must be given.
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that take a name and creates a placeholder of input with
that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of
every action.
num_actions: int
number of actions.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the
scope must be given.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select and action given observation.
` See the top of the file for details.
"""
with tf.variable_scope(scope, reuse=reuse):
observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select and action given observation.
` See the top of the file for details.
"""
with tf.variable_scope(scope, reuse=reuse):
observations_ph = U.ensure_tf_input(make_obs_ph("observation"))
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
eps = tf.get_variable(
"eps", (), initializer=tf.constant_initializer(0))
eps = tf.get_variable(
"eps", (), initializer=tf.constant_initializer(0))
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
deterministic_actions = tf.argmax(q_values, axis=1)
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
deterministic_actions = tf.argmax(q_values, axis=1)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(
tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
chose_random = tf.random_uniform(
tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(
chose_random, random_actions, deterministic_actions)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(
tf.stack([batch_size]), minval=0, maxval=num_actions,
dtype=tf.int64)
chose_random = tf.random_uniform(
tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(
chose_random, random_actions, deterministic_actions)
output_actions = tf.cond(
stochastic_ph, lambda: stochastic_actions,
lambda: deterministic_actions)
update_eps_expr = eps.assign(
tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
output_actions = tf.cond(
stochastic_ph, lambda: stochastic_actions,
lambda: deterministic_actions)
update_eps_expr = eps.assign(
tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
act = U.function(
inputs=[observations_ph, stochastic_ph, update_eps_ph],
outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True},
updates=[update_eps_expr])
return act
act = U.function(
inputs=[observations_ph, stochastic_ph, update_eps_ph],
outputs=output_actions,
givens={update_eps_ph: -1.0, stochastic_ph: True},
updates=[update_eps_expr])
return act
def build_train(
make_obs_ph, q_func, num_actions, optimizer, grad_norm_clipping=None,
gamma=1.0, double_q=True, scope="deepq", reuse=None):
"""Creates the train function:
"""Creates the train function:
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates a placeholder of input with
that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of
every action.
num_actions: int
number of actions
reuse: bool
whether or not to reuse the graph variables
optimizer: tf.train.Optimizer
optimizer to use for the Q-learning objective.
grad_norm_clipping: float or None
clip gradient norms to this value. If None no clipping is performed.
gamma: float
discount rate.
double_q: bool
if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
In general it is a good idea to keep it enabled.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the
scope must be given.
Parameters
----------
make_obs_ph: str -> tf.placeholder or TfInput
a function that takes a name and creates a placeholder of input with
that name
q_func: (tf.Variable, int, str, bool) -> tf.Variable
the model that takes the following inputs:
observation_in: object
the output of observation placeholder
num_actions: int
number of actions
scope: str
reuse: bool
should be passed to outer variable scope
and returns a tensor of shape (batch_size, num_actions) with values of
every action.
num_actions: int
number of actions
reuse: bool
whether or not to reuse the graph variables
optimizer: tf.train.Optimizer
optimizer to use for the Q-learning objective.
grad_norm_clipping: float or None
clip gradient norms to this value. If None no clipping is performed.
gamma: float
discount rate.
double_q: bool
if true will use Double Q Learning (https://arxiv.org/abs/1509.06461).
In general it is a good idea to keep it enabled.
scope: str or VariableScope
optional scope for variable_scope.
reuse: bool or None
whether or not the variables should be reused. To be able to reuse the
scope must be given.
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select and action given observation.
` See the top of the file for details.
train: (object, np.array, np.array, object, np.array, np.array) -> np.array
optimize the error in Bellman's equation.
` See the top of the file for details.
update_target: () -> ()
copy the parameters from optimized Q function to the target Q function.
` See the top of the file for details.
debug: {str: function}
a bunch of functions to print debug data like q_values.
"""
act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope, reuse=reuse)
Returns
-------
act: (tf.Variable, bool, float) -> tf.Variable
function to select and action given observation.
` See the top of the file for details.
train: (object, np.array, np.array, object, np.array, np.array) -> np.array
optimize the error in Bellman's equation.
` See the top of the file for details.
update_target: () -> ()
copy the parameters from optimized Q function to the target Q function.
` See the top of the file for details.
debug: {str: function}
a bunch of functions to print debug data like q_values.
"""
act_f = build_act(make_obs_ph, q_func, num_actions, scope=scope,
reuse=reuse)
with tf.variable_scope(scope, reuse=reuse):
# set up placeholders
obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
act_t_ph = tf.placeholder(tf.int32, [None], name="action")
rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
importance_weights_ph = tf.placeholder(tf.float32, [None], name="weight")
with tf.variable_scope(scope, reuse=reuse):
# set up placeholders
obs_t_input = U.ensure_tf_input(make_obs_ph("obs_t"))
act_t_ph = tf.placeholder(tf.int32, [None], name="action")
rew_t_ph = tf.placeholder(tf.float32, [None], name="reward")
obs_tp1_input = U.ensure_tf_input(make_obs_ph("obs_tp1"))
done_mask_ph = tf.placeholder(tf.float32, [None], name="done")
importance_weights_ph = tf.placeholder(tf.float32, [None],
name="weight")
# q network evaluation
q_t = q_func(
obs_t_input.get(), num_actions, scope="q_func",
reuse=True) # reuse parameters from act
q_func_vars = U.scope_vars(U.absolute_scope_name("q_func"))
# q network evaluation
q_t = q_func(
obs_t_input.get(), num_actions, scope="q_func",
reuse=True) # reuse parameters from act
q_func_vars = U.scope_vars(U.absolute_scope_name("q_func"))
# target q network evalution
q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
target_q_func_vars = U.scope_vars(U.absolute_scope_name("target_q_func"))
# target q network evalution
q_tp1 = q_func(obs_tp1_input.get(), num_actions, scope="target_q_func")
target_q_func_vars = U.scope_vars(
U.absolute_scope_name("target_q_func"))
# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions), 1)
# q scores for actions which we know were selected in the given state.
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(act_t_ph, num_actions),
1)
# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_using_online_net = q_func(
obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(
q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions), 1)
else:
q_tp1_best = tf.reduce_max(q_tp1, 1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best
# compute estimate of best possible value starting from state at t + 1
if double_q:
q_tp1_using_online_net = q_func(
obs_tp1_input.get(), num_actions, scope="q_func", reuse=True)
q_tp1_best_using_online_net = tf.arg_max(q_tp1_using_online_net, 1)
q_tp1_best = tf.reduce_sum(
q_tp1 * tf.one_hot(q_tp1_best_using_online_net, num_actions),
1)
else:
q_tp1_best = tf.reduce_max(q_tp1, 1)
q_tp1_best_masked = (1.0 - done_mask_ph) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked
# compute RHS of bellman equation
q_t_selected_target = rew_t_ph + gamma * q_tp1_best_masked
# compute the error (potentially clipped)
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
errors = U.huber_loss(td_error)
weighted_error = tf.reduce_mean(importance_weights_ph * errors)
# compute optimization op (potentially with gradient clipping)
if grad_norm_clipping is not None:
optimize_expr = U.minimize_and_clip(
optimizer, weighted_error, var_list=q_func_vars,
clip_val=grad_norm_clipping)
else:
optimize_expr = optimizer.minimize(weighted_error, var_list=q_func_vars)
# compute the error (potentially clipped)
td_error = q_t_selected - tf.stop_gradient(q_t_selected_target)
errors = U.huber_loss(td_error)
weighted_error = tf.reduce_mean(importance_weights_ph * errors)
# compute optimization op (potentially with gradient clipping)
if grad_norm_clipping is not None:
optimize_expr = U.minimize_and_clip(
optimizer, weighted_error, var_list=q_func_vars,
clip_val=grad_norm_clipping)
else:
optimize_expr = optimizer.minimize(weighted_error,
var_list=q_func_vars)
# update_target_fn will be called periodically to copy Q network to
# target Q network
update_target_expr = []
for var, var_target in zip(
sorted(q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(var_target.assign(var))
update_target_expr = tf.group(*update_target_expr)
# update_target_fn will be called periodically to copy Q network to
# target Q network
update_target_expr = []
for var, var_target in zip(
sorted(q_func_vars, key=lambda v: v.name),
sorted(target_q_func_vars, key=lambda v: v.name)):
update_target_expr.append(var_target.assign(var))
update_target_expr = tf.group(*update_target_expr)
# Create callable functions
train = U.function(
inputs=[
obs_t_input,
act_t_ph,
rew_t_ph,
obs_tp1_input,
done_mask_ph,
importance_weights_ph
],
outputs=td_error,
updates=[optimize_expr])
update_target = U.function([], [], updates=[update_target_expr])
# Create callable functions
train = U.function(
inputs=[
obs_t_input,
act_t_ph,
rew_t_ph,
obs_tp1_input,
done_mask_ph,
importance_weights_ph
],
outputs=td_error,
updates=[optimize_expr])
update_target = U.function([], [], updates=[update_target_expr])
q_values = U.function([obs_t_input], q_t)
q_values = U.function([obs_t_input], q_t)
return act_f, train, update_target, {'q_values': q_values}
return act_f, train, update_target, {'q_values': q_values}
@@ -11,236 +11,240 @@ from gym import spaces
class NoopResetEnv(gym.Wrapper):
def __init__(self, env=None, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
super(NoopResetEnv, self).__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def __init__(self, env=None, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
super(NoopResetEnv, self).__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def _reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = np.random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
def _reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = np.random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
class FireResetEnv(gym.Wrapper):
def __init__(self, env=None):
"""For environments where the user need to press FIRE for the game to
start."""
super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def __init__(self, env=None):
"""For environments where the user need to press FIRE for the game to
start."""
super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
def _reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env=None):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
super(EpisodicLifeEnv, self).__init__(env)
self.lives = 0
self.was_real_done = True
self.was_real_reset = False
def __init__(self, env=None):
"""Make end-of-life == end-of-episode, but only reset on true game
over. Done by DeepMind for the DQN and co. since it helps value
estimation.
"""
super(EpisodicLifeEnv, self).__init__(env)
self.lives = 0
self.was_real_done = True
self.was_real_reset = False
def _step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def _step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert somtimes we stay in lives == 0 condtion for a few
# frames so its important to keep lives > 0, so that we only reset
# once the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def _reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
self.was_real_reset = True
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.was_real_reset = False
self.lives = self.env.unwrapped.ale.lives()
return obs
def _reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
self.was_real_reset = True
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.was_real_reset = False
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def _step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
def _step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
return max_frame, total_reward, done, info
def _reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
def _reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class ProcessFrame84(gym.ObservationWrapper):
def __init__(self, env=None):
super(ProcessFrame84, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def __init__(self, env=None):
super(ProcessFrame84, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def _observation(self, obs):
return ProcessFrame84.process(obs)
def _observation(self, obs):
return ProcessFrame84.process(obs)
@staticmethod
def process(frame):
if frame.size == 210 * 160 * 3:
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
elif frame.size == 250 * 160 * 3:
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
else:
assert False, "Unknown resolution."
img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
x_t = resized_screen[18:102, :]
x_t = np.reshape(x_t, [84, 84, 1])
return x_t.astype(np.uint8)
@staticmethod
def process(frame):
if frame.size == 210 * 160 * 3:
img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
elif frame.size == 250 * 160 * 3:
img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
else:
assert False, "Unknown resolution."
img = (img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 +
img[:, :, 2] * 0.114)
resized_screen = cv2.resize(img, (84, 110),
interpolation=cv2.INTER_AREA)
x_t = resized_screen[18:102, :]
x_t = np.reshape(x_t, [84, 84, 1])
return x_t.astype(np.uint8)
class ClippedRewardsWrapper(gym.RewardWrapper):
def _reward(self, reward):
"""Change all the positive rewards to 1, negative to -1 and keep zero."""
return np.sign(reward)
def _reward(self, reward):
"""Change all the positive rewards to 1, negative to -1 and keep
zero."""
return np.sign(reward)
class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only
stored once. It exists purely to optimize memory usage which can be huge
for DQN's 1M frames replay buffers.
def __init__(self, frames):
"""This object ensures that common frames between the observations are
only stored once. It exists purely to optimize memory usage which can
be huge for DQN's 1M frames replay buffers.
This object should only be converted to numpy array before being passed to
the model.
This object should only be converted to numpy array before being passed
to the model.
You'd not belive how complex the previous solution was."""
self._frames = frames
You'd not belive how complex the previous solution was."""
self._frames = frames
def __array__(self, dtype=None):
out = np.concatenate(self._frames, axis=2)
if dtype is not None:
out = out.astype(dtype)
return out
def __array__(self, dtype=None):
out = np.concatenate(self._frames, axis=2)
if dtype is not None:
out = out.astype(dtype)
return out
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames.
def __init__(self, env, k):
"""Stack k last frames.
Returns lazy array, which is much more memory efficient.
Returns lazy array, which is much more memory efficient.
See Also
--------
ray.rllib.dqn.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
See Also
--------
ray.rllib.dqn.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k))
def _reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
def _reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
def _step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
def _step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))
def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))
class ScaledFloatFrame(gym.ObservationWrapper):
def _observation(self, obs):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(obs).astype(np.float32) / 255.0
def _observation(self, obs):
# careful! This undoes the memory optimization, use
# with smaller replay buffers only.
return np.array(obs).astype(np.float32) / 255.0
def wrap_dqn(env):
"""Apply a common set of wrappers for Atari games."""
assert 'NoFrameskip' in env.spec.id
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ProcessFrame84(env)
env = FrameStack(env, 4)
env = ClippedRewardsWrapper(env)
return env
"""Apply a common set of wrappers for Atari games."""
assert 'NoFrameskip' in env.spec.id
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = ProcessFrame84(env)
env = FrameStack(env, 4)
env = ClippedRewardsWrapper(env)
return env
class A2cProcessFrame(gym.Wrapper):
def __init__(self, env):
gym.Wrapper.__init__(self, env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def __init__(self, env):
gym.Wrapper.__init__(self, env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1))
def _step(self, action):
ob, reward, done, info = self.env.step(action)
return A2cProcessFrame.process(ob), reward, done, info
def _step(self, action):
ob, reward, done, info = self.env.step(action)
return A2cProcessFrame.process(ob), reward, done, info
def _reset(self):
return A2cProcessFrame.process(self.env.reset())
def _reset(self):
return A2cProcessFrame.process(self.env.reset())
@staticmethod
def process(frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
return frame.reshape(84, 84, 1)
@staticmethod
def process(frame):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)
return frame.reshape(84, 84, 1)
+73 -72
View File
@@ -13,95 +13,96 @@ from __future__ import print_function
class Schedule(object):
def value(self, t):
"""Value of the schedule at time t"""
raise NotImplementedError()
def value(self, t):
"""Value of the schedule at time t"""
raise NotImplementedError()
class ConstantSchedule(object):
def __init__(self, value):
"""Value remains constant over time.
def __init__(self, value):
"""Value remains constant over time.
Parameters
----------
value: float
Constant value of the schedule
"""
self._v = value
Parameters
----------
value: float
Constant value of the schedule
"""
self._v = value
def value(self, t):
"""See Schedule.value"""
return self._v
def value(self, t):
"""See Schedule.value"""
return self._v
def linear_interpolation(l, r, alpha):
return l + alpha * (r - l)
return l + alpha * (r - l)
class PiecewiseSchedule(object):
def __init__(
self, endpoints, interpolation=linear_interpolation,
outside_value=None):
def __init__(
self, endpoints, interpolation=linear_interpolation,
outside_value=None):
"""Piecewise schedule.
"""Piecewise schedule.
endpoints: [(int, int)]
list of pairs `(time, value)` meanining that schedule should output
`value` when `t==time`. All the values for time must be sorted in
an increasing order. When t is between two times, e.g.
`(time_a, value_a)`
and `(time_b, value_b)`, such that `time_a <= t < time_b` then value
outputs `interpolation(value_a, value_b, alpha)` where alpha is a
fraction of time passed between `time_a` and `time_b` for time `t`.
interpolation: lambda float, float, float: float
a function that takes value to the left and to the right of t according
to the `endpoints`. Alpha is the fraction of distance from left endpoint
to right endpoint that t has covered. See linear_interpolation for
example.
outside_value: float
if the value is requested outside of all the intervals sepecified in
`endpoints` this value is returned. If None then AssertionError is
raised when outside value is requested.
"""
idxes = [e[0] for e in endpoints]
assert idxes == sorted(idxes)
self._interpolation = interpolation
self._outside_value = outside_value
self._endpoints = endpoints
endpoints: [(int, int)]
list of pairs `(time, value)` meanining that schedule should output
`value` when `t==time`. All the values for time must be sorted in
an increasing order. When t is between two times, e.g.
`(time_a, value_a)`
and `(time_b, value_b)`, such that `time_a <= t < time_b` then value
outputs `interpolation(value_a, value_b, alpha)` where alpha is a
fraction of time passed between `time_a` and `time_b` for time `t`.
interpolation: lambda float, float, float: float
a function that takes value to the left and to the right of t
according to the `endpoints`. Alpha is the fraction of distance from
left endpoint to right endpoint that t has covered. See
linear_interpolation for example.
outside_value: float
if the value is requested outside of all the intervals sepecified in
`endpoints` this value is returned. If None then AssertionError is
raised when outside value is requested.
"""
idxes = [e[0] for e in endpoints]
assert idxes == sorted(idxes)
self._interpolation = interpolation
self._outside_value = outside_value
self._endpoints = endpoints
def value(self, t):
"""See Schedule.value"""
for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]):
if l_t <= t and t < r_t:
alpha = float(t - l_t) / (r_t - l_t)
return self._interpolation(l, r, alpha)
def value(self, t):
"""See Schedule.value"""
for (l_t, l), (r_t, r) in zip(self._endpoints[:-1],
self._endpoints[1:]):
if l_t <= t and t < r_t:
alpha = float(t - l_t) / (r_t - l_t)
return self._interpolation(l, r, alpha)
# t does not belong to any of the pieces, so doom.
assert self._outside_value is not None
return self._outside_value
# t does not belong to any of the pieces, so doom.
assert self._outside_value is not None
return self._outside_value
class LinearSchedule(object):
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
"""Linear interpolation between initial_p and final_p over
schedule_timesteps. After this many timesteps pass final_p is
returned.
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
"""Linear interpolation between initial_p and final_p over
schedule_timesteps. After this many timesteps pass final_p is
returned.
Parameters
----------
schedule_timesteps: int
Number of timesteps for which to linearly anneal initial_p
to final_p
initial_p: float
initial output value
final_p: float
final output value
"""
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
Parameters
----------
schedule_timesteps: int
Number of timesteps for which to linearly anneal initial_p
to final_p
initial_p: float
initial output value
final_p: float
final output value
"""
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
def value(self, t):
"""See Schedule.value"""
fraction = min(float(t) / self.schedule_timesteps, 1.0)
return self.initial_p + fraction * (self.final_p - self.initial_p)
def value(self, t):
"""See Schedule.value"""
fraction = min(float(t) / self.schedule_timesteps, 1.0)
return self.initial_p + fraction * (self.final_p - self.initial_p)
+120 -118
View File
@@ -6,146 +6,148 @@ import operator
class SegmentTree(object):
def __init__(self, capacity, operation, neutral_element):
"""Build a Segment Tree data structure.
def __init__(self, capacity, operation, neutral_element):
"""Build a Segment Tree data structure.
https://en.wikipedia.org/wiki/Segment_tree
https://en.wikipedia.org/wiki/Segment_tree
Can be used as regular array, but with two
important differences:
Can be used as regular array, but with two
important differences:
a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce`
operation which reduces `operation` over
a contiguous subsequence of items in the
array.
a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce`
operation which reduces `operation` over
a contiguous subsequence of items in the
array.
Paramters
---------
capacity: int
Total size of the array - must be a power of two.
operation: lambda obj, obj -> obj
and operation for combining elements (eg. sum, max)
must for a mathematical group together with the set of
possible values for array elements.
neutral_element: obj
neutral element for the operation above. eg. float('-inf')
for max and 0 for sum.
"""
Paramters
---------
capacity: int
Total size of the array - must be a power of two.
operation: lambda obj, obj -> obj
and operation for combining elements (eg. sum, max)
must for a mathematical group together with the set of
possible values for array elements.
neutral_element: obj
neutral element for the operation above. eg. float('-inf')
for max and 0 for sum.
"""
assert capacity > 0 and capacity & (capacity - 1) == 0, \
"capacity must be positive and a power of 2."
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
assert capacity > 0 and capacity & (capacity - 1) == 0, \
"capacity must be positive and a power of 2."
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
return self._value[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._reduce_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
else:
return self._operation(
self._reduce_helper(start, mid, 2 * node, node_start, mid),
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
)
def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
return self._value[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._reduce_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._reduce_helper(start, end, 2 * node + 1, mid + 1,
node_end)
else:
return self._operation(
self._reduce_helper(start, mid, 2 * node, node_start, mid),
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1,
node_end)
)
def reduce(self, start=0, end=None):
"""Returns result of applying `self.operation`
to a contiguous subsequence of the array.
def reduce(self, start=0, end=None):
"""Returns result of applying `self.operation`
to a contiguous subsequence of the array.
self.operation(
arr[start], operation(arr[start+1], operation(... arr[end])))
self.operation(
arr[start], operation(arr[start+1], operation(... arr[end])))
Parameters
----------
start: int
beginning of the subsequence
end: int
end of the subsequences
Parameters
----------
start: int
beginning of the subsequence
end: int
end of the subsequences
Returns
-------
reduced: obj
result of reducing self.operation over the specified range of array
elements.
"""
if end is None:
end = self._capacity
if end < 0:
end += self._capacity
end -= 1
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
Returns
-------
reduced: obj
result of reducing self.operation over the specified range of array
elements.
"""
if end is None:
end = self._capacity
if end < 0:
end += self._capacity
end -= 1
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
while idx >= 1:
self._value[idx] = self._operation(
self._value[2 * idx],
self._value[2 * idx + 1])
idx //= 2
def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
while idx >= 1:
self._value[idx] = self._operation(
self._value[2 * idx],
self._value[2 * idx + 1])
idx //= 2
def __getitem__(self, idx):
assert 0 <= idx < self._capacity
return self._value[self._capacity + idx]
def __getitem__(self, idx):
assert 0 <= idx < self._capacity
return self._value[self._capacity + idx]
class SumSegmentTree(SegmentTree):
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity,
operation=operator.add,
neutral_element=0.0)
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity,
operation=operator.add,
neutral_element=0.0)
def sum(self, start=0, end=None):
"""Returns arr[start] + ... + arr[end]"""
return super(SumSegmentTree, self).reduce(start, end)
def sum(self, start=0, end=None):
"""Returns arr[start] + ... + arr[end]"""
return super(SumSegmentTree, self).reduce(start, end)
def find_prefixsum_idx(self, prefixsum):
"""Find the highest index `i` in the array such that
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
def find_prefixsum_idx(self, prefixsum):
"""Find the highest index `i` in the array such that
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
Parameters
----------
perfixsum: float
upperbound on the sum of array prefix
Parameters
----------
perfixsum: float
upperbound on the sum of array prefix
Returns
-------
idx: int
highest index satisfying the prefixsum constraint
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
return idx - self._capacity
Returns
-------
idx: int
highest index satisfying the prefixsum constraint
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
return idx - self._capacity
class MinSegmentTree(SegmentTree):
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity,
operation=min,
neutral_element=float('inf'))
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity,
operation=min,
neutral_element=float('inf'))
def min(self, start=0, end=None):
"""Returns min(arr[start], ..., arr[end])"""
def min(self, start=0, end=None):
"""Returns min(arr[start], ..., arr[end])"""
return super(MinSegmentTree, self).reduce(start, end)
return super(MinSegmentTree, self).reduce(start, end)
File diff suppressed because it is too large Load Diff
+128 -119
View File
@@ -88,132 +88,141 @@ DEFAULT_CONFIG = dict(
class DQN(Algorithm):
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "DQN"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
env = gym.make(env_name)
env = ScaledFloatFrame(wrap_dqn(env))
self.env = env
model = models.cnn_to_mlp(
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256], dueling=True)
sess = U.make_session(num_cpu=config["num_cpu"])
sess.__enter__()
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "DQN"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
env = gym.make(env_name)
env = ScaledFloatFrame(wrap_dqn(env))
self.env = env
model = models.cnn_to_mlp(
convs=[(32, 8, 4), (64, 4, 2), (64, 3, 1)],
hiddens=[256], dueling=True)
sess = U.make_session(num_cpu=config["num_cpu"])
sess.__enter__()
def make_obs_ph(name):
return U.BatchInput(env.observation_space.shape, name=name)
def make_obs_ph(name):
return U.BatchInput(env.observation_space.shape, name=name)
self.act, self.optimize, self.update_target, self.debug = build_train(
make_obs_ph=make_obs_ph,
q_func=model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]),
gamma=config["gamma"],
grad_norm_clipping=10)
# Create the replay buffer
if config["prioritized_replay"]:
self.replay_buffer = PrioritizedReplayBuffer(
config["buffer_size"], alpha=config["prioritized_replay_alpha"])
prioritized_replay_beta_iters = config["prioritized_replay_beta_iters"]
if prioritized_replay_beta_iters is None:
prioritized_replay_beta_iters = config["schedule_max_timesteps"]
self.beta_schedule = LinearSchedule(
prioritized_replay_beta_iters,
initial_p=config["prioritized_replay_beta0"],
final_p=1.0)
else:
self.replay_buffer = ReplayBuffer(config["buffer_size"])
self.beta_schedule = None
# Create the schedule for exploration starting from 1.
self.exploration = LinearSchedule(
schedule_timesteps=int(
config["exploration_fraction"] * config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=config["exploration_final_eps"])
# Initialize the parameters and copy them to the target network.
U.initialize()
self.update_target()
self.episode_rewards = [0.0]
self.episode_lengths = [0.0]
self.saved_mean_reward = None
self.obs = self.env.reset()
self.num_timesteps = 0
self.num_iterations = 0
def train(self):
config = self.config
sample_time, learn_time = 0, 0
for t in range(config["timesteps_per_iteration"]):
self.num_timesteps += 1
dt = time.time()
# Take action and update exploration to the newest value
action = self.act(
np.array(self.obs)[None], update_eps=self.exploration.value(t))[0]
new_obs, rew, done, _ = self.env.step(action)
# Store transition in the replay buffer.
self.replay_buffer.add(self.obs, action, rew, new_obs, float(done))
self.obs = new_obs
self.episode_rewards[-1] += rew
self.episode_lengths[-1] += 1
if done:
self.obs = self.env.reset()
self.episode_rewards.append(0.0)
self.episode_lengths.append(0.0)
sample_time += time.time() - dt
if self.num_timesteps > config["learning_starts"] and \
self.num_timesteps % config["train_freq"] == 0:
dt = time.time()
# Minimize the error in Bellman's equation on a batch sampled from
# replay buffer.
self.act, self.optimize, self.update_target, self.debug = build_train(
make_obs_ph=make_obs_ph,
q_func=model,
num_actions=env.action_space.n,
optimizer=tf.train.AdamOptimizer(learning_rate=config["lr"]),
gamma=config["gamma"],
grad_norm_clipping=10)
# Create the replay buffer
if config["prioritized_replay"]:
experience = self.replay_buffer.sample(
config["batch_size"], beta=self.beta_schedule.value(t))
(obses_t, actions, rewards, obses_tp1,
dones, _, batch_idxes) = experience
self.replay_buffer = PrioritizedReplayBuffer(
config["buffer_size"],
alpha=config["prioritized_replay_alpha"])
prioritized_replay_beta_iters = (
config["prioritized_replay_beta_iters"])
if prioritized_replay_beta_iters is None:
prioritized_replay_beta_iters = (
config["schedule_max_timesteps"])
self.beta_schedule = LinearSchedule(
prioritized_replay_beta_iters,
initial_p=config["prioritized_replay_beta0"],
final_p=1.0)
else:
obses_t, actions, rewards, obses_tp1, dones = \
self.replay_buffer.sample(config["batch_size"])
batch_idxes = None
td_errors = self.optimize(
obses_t, actions, rewards, obses_tp1, dones, np.ones_like(rewards))
if config["prioritized_replay"]:
new_priorities = np.abs(td_errors) + config["prioritized_replay_eps"]
self.replay_buffer.update_priorities(batch_idxes, new_priorities)
learn_time += (time.time() - dt)
self.replay_buffer = ReplayBuffer(config["buffer_size"])
self.beta_schedule = None
# Create the schedule for exploration starting from 1.
self.exploration = LinearSchedule(
schedule_timesteps=int(
config["exploration_fraction"] *
config["schedule_max_timesteps"]),
initial_p=1.0,
final_p=config["exploration_final_eps"])
if self.num_timesteps > config["learning_starts"] and \
self.num_timesteps % config["target_network_update_freq"] == 0:
# Update target network periodically.
# Initialize the parameters and copy them to the target network.
U.initialize()
self.update_target()
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
num_episodes = len(self.episode_rewards)
self.episode_rewards = [0.0]
self.episode_lengths = [0.0]
self.saved_mean_reward = None
self.obs = self.env.reset()
self.num_timesteps = 0
self.num_iterations = 0
info = {
"sample_time": sample_time,
"learn_time": learn_time,
"steps": self.num_timesteps,
"episodes": num_episodes,
"exploration": int(100 * self.exploration.value(t))
}
def train(self):
config = self.config
sample_time, learn_time = 0, 0
logger.record_tabular("sample_time", sample_time)
logger.record_tabular("learn_time", learn_time)
logger.record_tabular("steps", self.num_timesteps)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular(
"% time spent exploring", int(100 * self.exploration.value(t)))
logger.dump_tabular()
for t in range(config["timesteps_per_iteration"]):
self.num_timesteps += 1
dt = time.time()
# Take action and update exploration to the newest value
action = self.act(
np.array(self.obs)[None],
update_eps=self.exploration.value(t))[0]
new_obs, rew, done, _ = self.env.step(action)
# Store transition in the replay buffer.
self.replay_buffer.add(self.obs, action, rew, new_obs, float(done))
self.obs = new_obs
res = TrainingResult(
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
mean_100ep_length, info)
self.num_iterations += 1
return res
self.episode_rewards[-1] += rew
self.episode_lengths[-1] += 1
if done:
self.obs = self.env.reset()
self.episode_rewards.append(0.0)
self.episode_lengths.append(0.0)
sample_time += time.time() - dt
if self.num_timesteps > config["learning_starts"] and \
self.num_timesteps % config["train_freq"] == 0:
dt = time.time()
# Minimize the error in Bellman's equation on a batch sampled
# from replay buffer.
if config["prioritized_replay"]:
experience = self.replay_buffer.sample(
config["batch_size"], beta=self.beta_schedule.value(t))
(obses_t, actions, rewards, obses_tp1,
dones, _, batch_idxes) = experience
else:
obses_t, actions, rewards, obses_tp1, dones = \
self.replay_buffer.sample(config["batch_size"])
batch_idxes = None
td_errors = self.optimize(
obses_t, actions, rewards, obses_tp1, dones,
np.ones_like(rewards))
if config["prioritized_replay"]:
new_priorities = (np.abs(td_errors) +
config["prioritized_replay_eps"])
self.replay_buffer.update_priorities(batch_idxes,
new_priorities)
learn_time += (time.time() - dt)
if (self.num_timesteps > config["learning_starts"] and
self.num_timesteps %
config["target_network_update_freq"] == 0):
# Update target network periodically.
self.update_target()
mean_100ep_reward = round(np.mean(self.episode_rewards[-101:-1]), 1)
mean_100ep_length = round(np.mean(self.episode_lengths[-101:-1]), 1)
num_episodes = len(self.episode_rewards)
info = {
"sample_time": sample_time,
"learn_time": learn_time,
"steps": self.num_timesteps,
"episodes": num_episodes,
"exploration": int(100 * self.exploration.value(t))
}
logger.record_tabular("sample_time", sample_time)
logger.record_tabular("learn_time", learn_time)
logger.record_tabular("steps", self.num_timesteps)
logger.record_tabular("episodes", num_episodes)
logger.record_tabular("mean 100 episode reward", mean_100ep_reward)
logger.record_tabular(
"% time spent exploring", int(100 * self.exploration.value(t)))
logger.dump_tabular()
res = TrainingResult(
self.experiment_id.hex, self.num_iterations, mean_100ep_reward,
mean_100ep_length, info)
self.num_iterations += 1
return res
+2 -2
View File
@@ -24,8 +24,8 @@ def main():
dqn = DQN("PongNoFrameskip-v4", config)
while True:
res = dqn.train()
print("current status: {}".format(res))
res = dqn.train()
print("current status: {}".format(res))
if __name__ == '__main__':
+173 -172
View File
@@ -29,88 +29,88 @@ DISABLED = 50
class OutputFormat(object):
def writekvs(self, kvs):
"""
Write key-value pairs
"""
raise NotImplementedError
def writekvs(self, kvs):
"""
Write key-value pairs
"""
raise NotImplementedError
def writeseq(self, args):
"""
Write a sequence of other data (e.g. a logging message)
"""
pass
def writeseq(self, args):
"""
Write a sequence of other data (e.g. a logging message)
"""
pass
def close(self):
return
def close(self):
return
class HumanOutputFormat(OutputFormat):
def __init__(self, file):
self.file = file
def __init__(self, file):
self.file = file
def writekvs(self, kvs):
# Create strings for printing
key2str = OrderedDict()
for (key, val) in kvs.items():
valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val
key2str[self._truncate(key)] = self._truncate(valstr)
def writekvs(self, kvs):
# Create strings for printing
key2str = OrderedDict()
for (key, val) in kvs.items():
valstr = '%-8.3g' % (val,) if hasattr(val, '__float__') else val
key2str[self._truncate(key)] = self._truncate(valstr)
# Find max widths
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Find max widths
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write out the data
dashes = '-' * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in key2str.items():
lines.append('| %s%s | %s%s |' % (
key,
' ' * (keywidth - len(key)),
val,
' ' * (valwidth - len(val)),
))
lines.append(dashes)
self.file.write('\n'.join(lines) + '\n')
# Write out the data
dashes = '-' * (keywidth + valwidth + 7)
lines = [dashes]
for (key, val) in key2str.items():
lines.append('| %s%s | %s%s |' % (
key,
' ' * (keywidth - len(key)),
val,
' ' * (valwidth - len(val)),
))
lines.append(dashes)
self.file.write('\n'.join(lines) + '\n')
# Flush the output to the file
self.file.flush()
# Flush the output to the file
self.file.flush()
def _truncate(self, s):
return s[:20] + '...' if len(s) > 23 else s
def _truncate(self, s):
return s[:20] + '...' if len(s) > 23 else s
def writeseq(self, args):
for arg in args:
self.file.write(arg)
self.file.write('\n')
self.file.flush()
def writeseq(self, args):
for arg in args:
self.file.write(arg)
self.file.write('\n')
self.file.flush()
class JSONOutputFormat(OutputFormat):
def __init__(self, file):
self.file = file
def __init__(self, file):
self.file = file
def writekvs(self, kvs):
for k, v in kvs.items():
if hasattr(v, 'dtype'):
v = v.tolist()
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + '\n')
self.file.flush()
def writekvs(self, kvs):
for k, v in kvs.items():
if hasattr(v, 'dtype'):
v = v.tolist()
kvs[k] = float(v)
self.file.write(json.dumps(kvs) + '\n')
self.file.flush()
def make_output_format(format, ev_dir):
os.makedirs(ev_dir, exist_ok=True)
if format == 'stdout':
return HumanOutputFormat(sys.stdout)
elif format == 'log':
log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
return HumanOutputFormat(log_file)
elif format == 'json':
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
return JSONOutputFormat(json_file)
else:
raise ValueError('Unknown format specified: %s' % (format,))
os.makedirs(ev_dir, exist_ok=True)
if format == 'stdout':
return HumanOutputFormat(sys.stdout)
elif format == 'log':
log_file = open(osp.join(ev_dir, 'log.txt'), 'wt')
return HumanOutputFormat(log_file)
elif format == 'json':
json_file = open(osp.join(ev_dir, 'progress.json'), 'wt')
return JSONOutputFormat(json_file)
else:
raise ValueError('Unknown format specified: %s' % (format,))
# ================================================================
# API
@@ -118,21 +118,21 @@ def make_output_format(format, ev_dir):
def logkv(key, val):
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
"""
Logger.CURRENT.logkv(key, val)
"""
Log a value of some diagnostic
Call this once for each diagnostic quantity, each iteration
"""
Logger.CURRENT.logkv(key, val)
def dumpkvs():
"""
Write all of the diagnostics from the current iteration
"""
Write all of the diagnostics from the current iteration
level: int. (see logger.py docs) If the global logger level is higher than
the level argument here, don't print to stdout.
"""
Logger.CURRENT.dumpkvs()
level: int. (see logger.py docs) If the global logger level is higher than
the level argument here, don't print to stdout.
"""
Logger.CURRENT.dumpkvs()
# for backwards compatibility
@@ -141,49 +141,50 @@ dump_tabular = dumpkvs
def log(*args, level=INFO):
"""
Write the sequence of args, with no separators, to the console and output
files (if you've configured an output file).
"""
Logger.CURRENT.log(*args, level=level)
"""
Write the sequence of args, with no separators, to the console and output
files (if you've configured an output file).
"""
Logger.CURRENT.log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
Logger.CURRENT.set_level(level)
"""
Set logging threshold on current logger.
"""
Logger.CURRENT.set_level(level)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return Logger.CURRENT.get_dir()
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call
start)
"""
return Logger.CURRENT.get_dir()
def get_expt_dir():
sys.stderr.write(
"get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" %
(get_dir(),))
return get_dir()
sys.stderr.write(
"get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" %
(get_dir(),))
return get_dir()
# ================================================================
@@ -192,50 +193,50 @@ def get_expt_dir():
class Logger(object):
# A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output
DEFAULT = None
# A logger with no output files. (See right below class definition)
# So that you can still log to the terminal without setting up any output
DEFAULT = None
# Current logger being used by the free functions above
CURRENT = None
# Current logger being used by the free functions above
CURRENT = None
def __init__(self, dir, output_formats):
self.name2val = OrderedDict() # values this iteration
self.level = INFO
self.dir = dir
self.output_formats = output_formats
def __init__(self, dir, output_formats):
self.name2val = OrderedDict() # values this iteration
self.level = INFO
self.dir = dir
self.output_formats = output_formats
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
self.name2val[key] = val
# Logging API, forwarded
# ----------------------------------------
def logkv(self, key, val):
self.name2val[key] = val
def dumpkvs(self):
for fmt in self.output_formats:
fmt.writekvs(self.name2val)
self.name2val.clear()
def dumpkvs(self):
for fmt in self.output_formats:
fmt.writekvs(self.name2val)
self.name2val.clear()
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(args)
def log(self, *args, level=INFO):
if self.level <= level:
self._do_log(args)
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
# Configuration
# ----------------------------------------
def set_level(self, level):
self.level = level
def get_dir(self):
return self.dir
def get_dir(self):
return self.dir
def close(self):
for fmt in self.output_formats:
fmt.close()
def close(self):
for fmt in self.output_formats:
fmt.close()
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
fmt.writeseq(args)
# Misc
# ----------------------------------------
def _do_log(self, args):
for fmt in self.output_formats:
fmt.writeseq(args)
# ================================================================
@@ -246,60 +247,60 @@ Logger.CURRENT = Logger.DEFAULT
class session(object):
"""
Context manager that sets up the loggers for an experiment.
"""
"""
Context manager that sets up the loggers for an experiment.
"""
CURRENT = None # Set to a LoggerContext object using enter/exit or cm
CURRENT = None # Set to a LoggerContext object using enter/exit or cm
def __init__(self, dir, format_strs=None):
self.dir = dir
if format_strs is None:
format_strs = LOG_OUTPUT_FORMATS
output_formats = [make_output_format(f, dir) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
def __init__(self, dir, format_strs=None):
self.dir = dir
if format_strs is None:
format_strs = LOG_OUTPUT_FORMATS
output_formats = [make_output_format(f, dir) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
def __enter__(self):
os.makedirs(self.evaluation_dir(), exist_ok=True)
output_formats = [
make_output_format(
f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
def __enter__(self):
os.makedirs(self.evaluation_dir(), exist_ok=True)
output_formats = [
make_output_format(
f, self.evaluation_dir()) for f in LOG_OUTPUT_FORMATS]
Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats)
def __exit__(self, *args):
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
def __exit__(self, *args):
Logger.CURRENT.close()
Logger.CURRENT = Logger.DEFAULT
def evaluation_dir(self):
return self.dir
def evaluation_dir(self):
return self.dir
# ================================================================
def _demo():
info("hi")
debug("shouldn't appear")
set_level(DEBUG)
debug("should appear")
dir = "/tmp/testlogging"
if os.path.exists(dir):
shutil.rmtree(dir)
with session(dir=dir):
record_tabular("a", 3)
record_tabular("b", 2.5)
dump_tabular()
info("hi")
debug("shouldn't appear")
set_level(DEBUG)
debug("should appear")
dir = "/tmp/testlogging"
if os.path.exists(dir):
shutil.rmtree(dir)
with session(dir=dir):
record_tabular("a", 3)
record_tabular("b", 2.5)
dump_tabular()
record_tabular("b", -2.5)
record_tabular("a", 5.5)
dump_tabular()
info("^^^ should see a = 5.5")
record_tabular("b", -2.5)
record_tabular("a", 5.5)
dump_tabular()
info("^^^ should see a = 5.5")
record_tabular("b", -2.5)
dump_tabular()
record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
dump_tabular()
record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
dump_tabular()
if __name__ == "__main__":
_demo()
_demo()
+72 -71
View File
@@ -7,91 +7,92 @@ import tensorflow.contrib.layers as layers
def _mlp(hiddens, inpt, num_actions, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
out = inpt
for hidden in hiddens:
out = layers.fully_connected(
out, num_outputs=hidden, activation_fn=tf.nn.relu)
out = layers.fully_connected(
out, num_outputs=num_actions, activation_fn=None)
return out
with tf.variable_scope(scope, reuse=reuse):
out = inpt
for hidden in hiddens:
out = layers.fully_connected(
out, num_outputs=hidden, activation_fn=tf.nn.relu)
out = layers.fully_connected(
out, num_outputs=num_actions, activation_fn=None)
return out
def mlp(hiddens=[]):
"""This model takes as input an observation and returns values of all
actions.
"""This model takes as input an observation and returns values of all
actions.
Parameters
----------
hiddens: [int]
list of sizes of hidden layers
Parameters
----------
hiddens: [int]
list of sizes of hidden layers
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
return lambda *args, **kwargs: _mlp(hiddens, *args, **kwargs)
def _cnn_to_mlp(
convs, hiddens, dueling, inpt, num_actions, scope, reuse=False):
with tf.variable_scope(scope, reuse=reuse):
out = inpt
with tf.variable_scope("convnet"):
for num_outputs, kernel_size, stride in convs:
out = layers.convolution2d(
out,
num_outputs=num_outputs,
kernel_size=kernel_size,
stride=stride,
activation_fn=tf.nn.relu)
out = layers.flatten(out)
with tf.variable_scope("action_value"):
action_out = out
for hidden in hiddens:
action_out = layers.fully_connected(
action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
action_scores = layers.fully_connected(
action_out, num_outputs=num_actions, activation_fn=None)
with tf.variable_scope(scope, reuse=reuse):
out = inpt
with tf.variable_scope("convnet"):
for num_outputs, kernel_size, stride in convs:
out = layers.convolution2d(
out,
num_outputs=num_outputs,
kernel_size=kernel_size,
stride=stride,
activation_fn=tf.nn.relu)
out = layers.flatten(out)
with tf.variable_scope("action_value"):
action_out = out
for hidden in hiddens:
action_out = layers.fully_connected(
action_out, num_outputs=hidden, activation_fn=tf.nn.relu)
action_scores = layers.fully_connected(
action_out, num_outputs=num_actions, activation_fn=None)
if dueling:
with tf.variable_scope("state_value"):
state_out = out
for hidden in hiddens:
state_out = layers.fully_connected(
state_out, num_outputs=hidden, activation_fn=tf.nn.relu)
state_score = layers.fully_connected(
state_out, num_outputs=1, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(
action_scores_mean, 1)
return state_score + action_scores_centered
else:
return action_scores
return out
if dueling:
with tf.variable_scope("state_value"):
state_out = out
for hidden in hiddens:
state_out = layers.fully_connected(
state_out, num_outputs=hidden,
activation_fn=tf.nn.relu)
state_score = layers.fully_connected(
state_out, num_outputs=1, activation_fn=None)
action_scores_mean = tf.reduce_mean(action_scores, 1)
action_scores_centered = action_scores - tf.expand_dims(
action_scores_mean, 1)
return state_score + action_scores_centered
else:
return action_scores
return out
def cnn_to_mlp(convs, hiddens, dueling=False):
"""This model takes as input an observation and returns values of all actions.
"""This model takes an observation and returns values for all actions.
Parameters
----------
convs: [(int, int int)]
list of convolutional layers in form of
(num_outputs, kernel_size, stride)
hiddens: [int]
list of sizes of hidden layers
dueling: bool
if true double the output MLP to compute a baseline
for action scores
Parameters
----------
convs: [(int, int int)]
list of convolutional layers in form of
(num_outputs, kernel_size, stride)
hiddens: [int]
list of sizes of hidden layers
dueling: bool
if true double the output MLP to compute a baseline
for action scores
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
Returns
-------
q_func: function
q_function for DQN algorithm.
"""
return lambda *args, **kwargs: _cnn_to_mlp(
convs, hiddens, dueling, *args, **kwargs)
return lambda *args, **kwargs: _cnn_to_mlp(
convs, hiddens, dueling, *args, **kwargs)
+157 -156
View File
@@ -9,188 +9,189 @@ from ray.rllib.dqn.common.segment_tree import SumSegmentTree, MinSegmentTree
class ReplayBuffer(object):
def __init__(self, size):
"""Create Prioritized Replay buffer.
def __init__(self, size):
"""Create Prioritized Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
def __len__(self):
return len(self._storage)
def __len__(self):
return len(self._storage)
def add(self, obs_t, action, reward, obs_tp1, done):
data = (obs_t, action, reward, obs_tp1, done)
def add(self, obs_t, action, reward, obs_tp1, done):
data = (obs_t, action, reward, obs_tp1, done)
if self._next_idx >= len(self._storage):
self._storage.append(data)
else:
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize
if self._next_idx >= len(self._storage):
self._storage.append(data)
else:
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize
def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done = data
obses_t.append(np.array(obs_t, copy=False))
actions.append(np.array(action, copy=False))
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
return np.array(obses_t), np.array(actions), np.array(rewards), \
np.array(obses_tp1), np.array(dones)
def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], []
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done = data
obses_t.append(np.array(obs_t, copy=False))
actions.append(np.array(action, copy=False))
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
return (np.array(obses_t), np.array(actions), np.array(rewards),
np.array(obses_tp1), np.array(dones))
def sample(self, batch_size):
"""Sample a batch of experiences.
def sample(self, batch_size):
"""Sample a batch of experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
Parameters
----------
batch_size: int
How many transitions to sample.
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1)
for _ in range(batch_size)]
return self._encode_sample(idxes)
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1)
for _ in range(batch_size)]
return self._encode_sample(idxes)
class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, size, alpha):
"""Create Prioritized Replay buffer.
def __init__(self, size, alpha):
"""Create Prioritized Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
alpha: float
how much prioritization is used
(0 - no prioritization, 1 - full prioritization)
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
alpha: float
how much prioritization is used
(0 - no prioritization, 1 - full prioritization)
See Also
--------
ReplayBuffer.__init__
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0
self._alpha = alpha
See Also
--------
ReplayBuffer.__init__
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0
self._alpha = alpha
it_capacity = 1
while it_capacity < size:
it_capacity *= 2
it_capacity = 1
while it_capacity < size:
it_capacity *= 2
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
def add(self, *args, **kwargs):
"""See ReplayBuffer.store_effect"""
idx = self._next_idx
super().add(*args, **kwargs)
self._it_sum[idx] = self._max_priority ** self._alpha
self._it_min[idx] = self._max_priority ** self._alpha
def add(self, *args, **kwargs):
"""See ReplayBuffer.store_effect"""
idx = self._next_idx
super().add(*args, **kwargs)
self._it_sum[idx] = self._max_priority ** self._alpha
self._it_min[idx] = self._max_priority ** self._alpha
def _sample_proportional(self, batch_size):
res = []
for _ in range(batch_size):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
def _sample_proportional(self, batch_size):
res = []
for _ in range(batch_size):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0,
len(self._storage) - 1)
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
def sample(self, batch_size, beta):
"""Sample a batch of experiences.
def sample(self, batch_size, beta):
"""Sample a batch of experiences.
compared to ReplayBuffer.sample
it also returns importance weights and idxes
of sampled experiences.
compared to ReplayBuffer.sample
it also returns importance weights and idxes
of sampled experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
beta: float
To what degree to use importance weights
(0 - no corrections, 1 - full correction)
Parameters
----------
batch_size: int
How many transitions to sample.
beta: float
To what degree to use importance weights
(0 - no corrections, 1 - full correction)
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
weights: np.array
Array of shape (batch_size,) and dtype np.float32
denoting importance weight of each sampled transition
idxes: np.array
Array of shape (batch_size,) and dtype np.int32
idexes in buffer of sampled experiences
"""
assert beta > 0
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
weights: np.array
Array of shape (batch_size,) and dtype np.float32
denoting importance weight of each sampled transition
idxes: np.array
Array of shape (batch_size,) and dtype np.int32
idexes in buffer of sampled experiences
"""
assert beta > 0
idxes = self._sample_proportional(batch_size)
idxes = self._sample_proportional(batch_size)
weights = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)
weights = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
encoded_sample = self._encode_sample(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
def update_priorities(self, idxes, priorities):
"""Update priorities of sampled transitions.
def update_priorities(self, idxes, priorities):
"""Update priorities of sampled transitions.
sets priority of transition at index idxes[i] in buffer
to priorities[i].
sets priority of transition at index idxes[i] in buffer
to priorities[i].
Parameters
----------
idxes: [int]
List of idxes of sampled transitions
priorities: [float]
List of updated priorities corresponding to
transitions at the sampled idxes denoted by
variable `idxes`.
"""
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
Parameters
----------
idxes: [int]
List of idxes of sampled transitions
priorities: [float]
List of updated priorities corresponding to
transitions at the sampled idxes denoted by
variable `idxes`.
"""
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
assert priority > 0
assert 0 <= idx < len(self._storage)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
self._max_priority = max(self._max_priority, priority)
self._max_priority = max(self._max_priority, priority)
@@ -43,256 +43,267 @@ DEFAULT_CONFIG = dict(
@ray.remote
def create_shared_noise():
"""Create a large array of noise to be shared by all workers."""
seed = 123
count = 250000000
noise = np.random.RandomState(seed).randn(count).astype(np.float32)
return noise
"""Create a large array of noise to be shared by all workers."""
seed = 123
count = 250000000
noise = np.random.RandomState(seed).randn(count).astype(np.float32)
return noise
class SharedNoiseTable(object):
def __init__(self, noise):
self.noise = noise
assert self.noise.dtype == np.float32
def __init__(self, noise):
self.noise = noise
assert self.noise.dtype == np.float32
def get(self, i, dim):
return self.noise[i:i + dim]
def get(self, i, dim):
return self.noise[i:i + dim]
def sample_index(self, stream, dim):
return stream.randint(0, len(self.noise) - dim + 1)
def sample_index(self, stream, dim):
return stream.randint(0, len(self.noise) - dim + 1)
@ray.remote
class Worker(object):
def __init__(self, config, policy_params, env_name, noise,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.noise = SharedNoiseTable(noise)
def __init__(self, config, policy_params, env_name, noise,
min_task_runtime=0.2):
self.min_task_runtime = min_task_runtime
self.config = config
self.policy_params = policy_params
self.noise = SharedNoiseTable(noise)
self.env = gym.make(env_name)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.MujocoPolicy(self.env.observation_space,
self.env.action_space,
**policy_params)
tf_util.initialize()
self.env = gym.make(env_name)
self.sess = utils.make_session(single_threaded=True)
self.policy = policies.MujocoPolicy(self.env.observation_space,
self.env.action_space,
**policy_params)
tf_util.initialize()
self.rs = np.random.RandomState()
self.rs = np.random.RandomState()
assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] != 0)
assert self.policy.needs_ob_stat == (self.config["calc_obstat_prob"] !=
0)
def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
if (self.policy.needs_ob_stat and self.config["calc_obstat_prob"] != 0 and
self.rs.rand() < self.config["calc_obstat_prob"]):
rollout_rews, rollout_len, obs = self.policy.rollout(
self.env, timestep_limit=timestep_limit, save_obs=True,
random_stream=self.rs)
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
len(obs))
else:
rollout_rews, rollout_len = self.policy.rollout(
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
return rollout_rews, rollout_len
def rollout_and_update_ob_stat(self, timestep_limit, task_ob_stat):
if (self.policy.needs_ob_stat and
self.config["calc_obstat_prob"] != 0 and
self.rs.rand() < self.config["calc_obstat_prob"]):
rollout_rews, rollout_len, obs = self.policy.rollout(
self.env, timestep_limit=timestep_limit, save_obs=True,
random_stream=self.rs)
task_ob_stat.increment(obs.sum(axis=0), np.square(obs).sum(axis=0),
len(obs))
else:
rollout_rews, rollout_len = self.policy.rollout(
self.env, timestep_limit=timestep_limit, random_stream=self.rs)
return rollout_rews, rollout_len
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
# Set the network weights.
self.policy.set_trainable_flat(params)
def do_rollouts(self, params, ob_mean, ob_std, timestep_limit=None):
# Set the network weights.
self.policy.set_trainable_flat(params)
if self.policy.needs_ob_stat:
self.policy.set_ob_stat(ob_mean, ob_std)
if self.policy.needs_ob_stat:
self.policy.set_ob_stat(ob_mean, ob_std)
if self.config["eval_prob"] != 0:
raise NotImplementedError("Eval rollouts are not implemented.")
if self.config["eval_prob"] != 0:
raise NotImplementedError("Eval rollouts are not implemented.")
noise_inds, returns, sign_returns, lengths = [], [], [], []
# We set eps=0 because we're incrementing only.
task_ob_stat = utils.RunningStat(self.env.observation_space.shape, eps=0)
noise_inds, returns, sign_returns, lengths = [], [], [], []
# We set eps=0 because we're incrementing only.
task_ob_stat = utils.RunningStat(self.env.observation_space.shape,
eps=0)
# Perform some rollouts with noise.
task_tstart = time.time()
while (len(noise_inds) == 0 or
time.time() - task_tstart < self.min_task_runtime):
noise_idx = self.noise.sample_index(self.rs, self.policy.num_params)
perturbation = self.config["noise_stdev"] * self.noise.get(
noise_idx, self.policy.num_params)
# Perform some rollouts with noise.
task_tstart = time.time()
while (len(noise_inds) == 0 or
time.time() - task_tstart < self.min_task_runtime):
noise_idx = self.noise.sample_index(self.rs,
self.policy.num_params)
perturbation = self.config["noise_stdev"] * self.noise.get(
noise_idx, self.policy.num_params)
# These two sampling steps could be done in parallel on different actors
# letting us update twice as frequently.
self.policy.set_trainable_flat(params + perturbation)
rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
# These two sampling steps could be done in parallel on different
# actors letting us update twice as frequently.
self.policy.set_trainable_flat(params + perturbation)
rews_pos, len_pos = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
self.policy.set_trainable_flat(params - perturbation)
rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
self.policy.set_trainable_flat(params - perturbation)
rews_neg, len_neg = self.rollout_and_update_ob_stat(timestep_limit,
task_ob_stat)
noise_inds.append(noise_idx)
returns.append([rews_pos.sum(), rews_neg.sum()])
sign_returns.append([np.sign(rews_pos).sum(), np.sign(rews_neg).sum()])
lengths.append([len_pos, len_neg])
noise_inds.append(noise_idx)
returns.append([rews_pos.sum(), rews_neg.sum()])
sign_returns.append([np.sign(rews_pos).sum(),
np.sign(rews_neg).sum()])
lengths.append([len_pos, len_neg])
return Result(
noise_inds_n=np.array(noise_inds),
returns_n2=np.array(returns, dtype=np.float32),
sign_returns_n2=np.array(sign_returns, dtype=np.float32),
lengths_n2=np.array(lengths, dtype=np.int32),
eval_return=None,
eval_length=None,
ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
ob_sumsq=(None if task_ob_stat.count == 0 else task_ob_stat.sumsq),
ob_count=task_ob_stat.count)
return Result(
noise_inds_n=np.array(noise_inds),
returns_n2=np.array(returns, dtype=np.float32),
sign_returns_n2=np.array(sign_returns, dtype=np.float32),
lengths_n2=np.array(lengths, dtype=np.int32),
eval_return=None,
eval_length=None,
ob_sum=(None if task_ob_stat.count == 0 else task_ob_stat.sum),
ob_sumsq=(None if task_ob_stat.count == 0
else task_ob_stat.sumsq),
ob_count=task_ob_stat.count)
class EvolutionStrategies(Algorithm):
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "EvolutionStrategies"})
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "EvolutionStrategies"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
policy_params = {
"ac_bins": "continuous:",
"ac_noise_std": 0.01,
"nonlin_type": "tanh",
"hidden_dims": [256, 256],
"connection_type": "ff"
}
policy_params = {
"ac_bins": "continuous:",
"ac_noise_std": 0.01,
"nonlin_type": "tanh",
"hidden_dims": [256, 256],
"connection_type": "ff"
}
# Create the shared noise table.
print("Creating shared noise table.")
noise_id = create_shared_noise.remote()
self.noise = SharedNoiseTable(ray.get(noise_id))
# Create the shared noise table.
print("Creating shared noise table.")
noise_id = create_shared_noise.remote()
self.noise = SharedNoiseTable(ray.get(noise_id))
# Create the actors.
print("Creating actors.")
self.workers = [Worker.remote(config, policy_params, env_name, noise_id)
for _ in range(config["num_workers"])]
# Create the actors.
print("Creating actors.")
self.workers = [Worker.remote(config, policy_params, env_name,
noise_id)
for _ in range(config["num_workers"])]
env = gym.make(env_name)
utils.make_session(single_threaded=False)
self.policy = policies.MujocoPolicy(
env.observation_space, env.action_space, **policy_params)
tf_util.initialize()
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
env = gym.make(env_name)
utils.make_session(single_threaded=False)
self.policy = policies.MujocoPolicy(
env.observation_space, env.action_space, **policy_params)
tf_util.initialize()
self.optimizer = optimizers.Adam(self.policy, config["stepsize"])
self.ob_stat = utils.RunningStat(env.observation_space.shape, eps=1e-2)
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.tstart = time.time()
self.iteration = 0
self.episodes_so_far = 0
self.timesteps_so_far = 0
self.tstart = time.time()
self.iteration = 0
def train(self):
config = self.config
def train(self):
config = self.config
step_tstart = time.time()
theta = self.policy.get_trainable_flat()
assert theta.dtype == np.float32
step_tstart = time.time()
theta = self.policy.get_trainable_flat()
assert theta.dtype == np.float32
# Put the current policy weights in the object store.
theta_id = ray.put(theta)
# Use the actors to do rollouts, note that we pass in the ID of the policy
# weights.
rollout_ids = [worker.do_rollouts.remote(
theta_id,
self.ob_stat.mean if self.policy.needs_ob_stat else None,
self.ob_stat.std if self.policy.needs_ob_stat else None)
for worker in self.workers]
# Put the current policy weights in the object store.
theta_id = ray.put(theta)
# Use the actors to do rollouts, note that we pass in the ID of the
# policy weights.
rollout_ids = [worker.do_rollouts.remote(
theta_id,
self.ob_stat.mean if self.policy.needs_ob_stat else None,
self.ob_stat.std if self.policy.needs_ob_stat else None)
for worker in self.workers]
# Get the results of the rollouts.
results = ray.get(rollout_ids)
# Get the results of the rollouts.
results = ray.get(rollout_ids)
curr_task_results = []
ob_count_this_batch = 0
# Loop over the results
for result in results:
assert result.eval_length is None, "We aren't doing eval rollouts."
assert result.noise_inds_n.ndim == 1
assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
assert result.returns_n2.dtype == np.float32
curr_task_results = []
ob_count_this_batch = 0
# Loop over the results
for result in results:
assert result.eval_length is None, "We aren't doing eval rollouts."
assert result.noise_inds_n.ndim == 1
assert result.returns_n2.shape == (len(result.noise_inds_n), 2)
assert result.lengths_n2.shape == (len(result.noise_inds_n), 2)
assert result.returns_n2.dtype == np.float32
result_num_eps = result.lengths_n2.size
result_num_timesteps = result.lengths_n2.sum()
self.episodes_so_far += result_num_eps
self.timesteps_so_far += result_num_timesteps
result_num_eps = result.lengths_n2.size
result_num_timesteps = result.lengths_n2.sum()
self.episodes_so_far += result_num_eps
self.timesteps_so_far += result_num_timesteps
curr_task_results.append(result)
# Update ob stats.
if self.policy.needs_ob_stat and result.ob_count > 0:
self.ob_stat.increment(result.ob_sum, result.ob_sumsq, result.ob_count)
ob_count_this_batch += result.ob_count
curr_task_results.append(result)
# Update ob stats.
if self.policy.needs_ob_stat and result.ob_count > 0:
self.ob_stat.increment(result.ob_sum, result.ob_sumsq,
result.ob_count)
ob_count_this_batch += result.ob_count
# Assemble the results.
noise_inds_n = np.concatenate([r.noise_inds_n for
r in curr_task_results])
returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
assert noise_inds_n.shape[0] == returns_n2.shape[0] == lengths_n2.shape[0]
# Process the returns.
if config["return_proc_mode"] == "centered_rank":
proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
else:
raise NotImplementedError(config["return_proc_mode"])
# Assemble the results.
noise_inds_n = np.concatenate([r.noise_inds_n for
r in curr_task_results])
returns_n2 = np.concatenate([r.returns_n2 for r in curr_task_results])
lengths_n2 = np.concatenate([r.lengths_n2 for r in curr_task_results])
assert (noise_inds_n.shape[0] == returns_n2.shape[0] ==
lengths_n2.shape[0])
# Process the returns.
if config["return_proc_mode"] == "centered_rank":
proc_returns_n2 = utils.compute_centered_ranks(returns_n2)
else:
raise NotImplementedError(config["return_proc_mode"])
# Compute and take a step.
g, count = utils.batched_weighted_sum(
proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
(self.noise.get(idx, self.policy.num_params) for idx in noise_inds_n),
batch_size=500)
g /= returns_n2.size
assert (g.shape == (self.policy.num_params,) and g.dtype == np.float32 and
count == len(noise_inds_n))
update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta)
# Compute and take a step.
g, count = utils.batched_weighted_sum(
proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
(self.noise.get(idx, self.policy.num_params)
for idx in noise_inds_n),
batch_size=500)
g /= returns_n2.size
assert (g.shape == (self.policy.num_params,) and
g.dtype == np.float32 and
count == len(noise_inds_n))
update_ratio = self.optimizer.update(-g + config["l2coeff"] * theta)
# Update ob stat (we're never running the policy in the master, but we
# might be snapshotting the policy).
if self.policy.needs_ob_stat:
self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std)
# Update ob stat (we're never running the policy in the master, but we
# might be snapshotting the policy).
if self.policy.needs_ob_stat:
self.policy.set_ob_stat(self.ob_stat.mean, self.ob_stat.std)
step_tend = time.time()
tlogger.record_tabular("EpRewMean", returns_n2.mean())
tlogger.record_tabular("EpRewStd", returns_n2.std())
tlogger.record_tabular("EpLenMean", lengths_n2.mean())
step_tend = time.time()
tlogger.record_tabular("EpRewMean", returns_n2.mean())
tlogger.record_tabular("EpRewStd", returns_n2.std())
tlogger.record_tabular("EpLenMean", lengths_n2.mean())
tlogger.record_tabular(
"Norm", float(np.square(self.policy.get_trainable_flat()).sum()))
tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
tlogger.record_tabular("UpdateRatio", float(update_ratio))
tlogger.record_tabular(
"Norm", float(np.square(self.policy.get_trainable_flat()).sum()))
tlogger.record_tabular("GradNorm", float(np.square(g).sum()))
tlogger.record_tabular("UpdateRatio", float(update_ratio))
tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
tlogger.record_tabular("EpisodesThisIter", lengths_n2.size)
tlogger.record_tabular("EpisodesSoFar", self.episodes_so_far)
tlogger.record_tabular("TimestepsThisIter", lengths_n2.sum())
tlogger.record_tabular("TimestepsSoFar", self.timesteps_so_far)
tlogger.record_tabular("ObCount", ob_count_this_batch)
tlogger.record_tabular("ObCount", ob_count_this_batch)
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
tlogger.dump_tabular()
tlogger.record_tabular("TimeElapsedThisIter", step_tend - step_tstart)
tlogger.record_tabular("TimeElapsed", step_tend - self.tstart)
tlogger.dump_tabular()
if (config["snapshot_freq"] != 0 and
self.iteration % config["snapshot_freq"] == 0):
filename = os.path.join(
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
assert not os.path.exists(filename)
self.policy.save(filename)
tlogger.log("Saved snapshot {}".format(filename))
if (config["snapshot_freq"] != 0 and
self.iteration % config["snapshot_freq"] == 0):
filename = os.path.join(
self.logdir, "snapshot_iter{:05d}.h5".format(self.iteration))
assert not os.path.exists(filename)
self.policy.save(filename)
tlogger.log("Saved snapshot {}".format(filename))
info = {
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
"grad_norm": np.square(g).sum(),
"update_ratio": update_ratio,
"episodes_this_iter": lengths_n2.size,
"episodes_so_far": self.episodes_so_far,
"timesteps_this_iter": lengths_n2.sum(),
"timesteps_so_far": self.timesteps_so_far,
"ob_count": ob_count_this_batch,
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
res = TrainingResult(self.experiment_id.hex, self.iteration,
returns_n2.mean(), lengths_n2.mean(), info)
info = {
"weights_norm": np.square(self.policy.get_trainable_flat()).sum(),
"grad_norm": np.square(g).sum(),
"update_ratio": update_ratio,
"episodes_this_iter": lengths_n2.size,
"episodes_so_far": self.episodes_so_far,
"timesteps_this_iter": lengths_n2.sum(),
"timesteps_so_far": self.timesteps_so_far,
"ob_count": ob_count_this_batch,
"time_elapsed_this_iter": step_tend - step_tstart,
"time_elapsed": step_tend - self.tstart
}
res = TrainingResult(self.experiment_id.hex, self.iteration,
returns_n2.mean(), lengths_n2.mean(), info)
self.iteration += 1
self.iteration += 1
return res
return res
@@ -11,30 +11,30 @@ from ray.rllib.evolution_strategies import EvolutionStrategies, DEFAULT_CONFIG
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train an RL agent on Pong.")
parser.add_argument("--num-workers", default=10, type=int,
help=("The number of actors to create in aggregate "
"across the cluster."))
parser.add_argument("--env-name", default="Pendulum-v0", type=str,
help="The name of the gym environment to use.")
parser.add_argument("--stepsize", default=0.01, type=float,
help="The stepsize to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser = argparse.ArgumentParser(description="Train an RL agent on Pong.")
parser.add_argument("--num-workers", default=10, type=int,
help=("The number of actors to create in aggregate "
"across the cluster."))
parser.add_argument("--env-name", default="Pendulum-v0", type=str,
help="The name of the gym environment to use.")
parser.add_argument("--stepsize", default=0.01, type=float,
help="The stepsize to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
args = parser.parse_args()
num_workers = args.num_workers
env_name = args.env_name
stepsize = args.stepsize
args = parser.parse_args()
num_workers = args.num_workers
env_name = args.env_name
stepsize = args.stepsize
ray.init(redis_address=args.redis_address,
num_workers=(0 if args.redis_address is None else None))
ray.init(redis_address=args.redis_address,
num_workers=(0 if args.redis_address is None else None))
config = DEFAULT_CONFIG._replace(
num_workers=num_workers,
stepsize=stepsize)
config = DEFAULT_CONFIG._replace(
num_workers=num_workers,
stepsize=stepsize)
alg = EvolutionStrategies(env_name, config)
while True:
result = alg.train()
print("current status: {}".format(result))
alg = EvolutionStrategies(env_name, config)
while True:
result = alg.train()
print("current status: {}".format(result))
@@ -9,49 +9,49 @@ import numpy as np
class Optimizer(object):
def __init__(self, pi):
self.pi = pi
self.dim = pi.num_params
self.t = 0
def __init__(self, pi):
self.pi = pi
self.dim = pi.num_params
self.t = 0
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.pi.get_trainable_flat()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
self.pi.set_trainable_flat(theta + step)
return ratio
def update(self, globalg):
self.t += 1
step = self._compute_step(globalg)
theta = self.pi.get_trainable_flat()
ratio = np.linalg.norm(step) / np.linalg.norm(theta)
self.pi.set_trainable_flat(theta + step)
return ratio
def _compute_step(self, globalg):
raise NotImplementedError
def _compute_step(self, globalg):
raise NotImplementedError
class SGD(Optimizer):
def __init__(self, pi, stepsize, momentum=0.9):
Optimizer.__init__(self, pi)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
def __init__(self, pi, stepsize, momentum=0.9):
Optimizer.__init__(self, pi)
self.v = np.zeros(self.dim, dtype=np.float32)
self.stepsize, self.momentum = stepsize, momentum
def _compute_step(self, globalg):
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
step = -self.stepsize * self.v
return step
def _compute_step(self, globalg):
self.v = self.momentum * self.v + (1. - self.momentum) * globalg
step = -self.stepsize * self.v
return step
class Adam(Optimizer):
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
Optimizer.__init__(self, pi)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = np.zeros(self.dim, dtype=np.float32)
self.v = np.zeros(self.dim, dtype=np.float32)
def __init__(self, pi, stepsize, beta1=0.9, beta2=0.999, epsilon=1e-08):
Optimizer.__init__(self, pi)
self.stepsize = stepsize
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.m = np.zeros(self.dim, dtype=np.float32)
self.v = np.zeros(self.dim, dtype=np.float32)
def _compute_step(self, globalg):
a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
(1 - self.beta1 ** self.t))
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
return step
def _compute_step(self, globalg):
a = self.stepsize * (np.sqrt(1 - self.beta2 ** self.t) /
(1 - self.beta1 ** self.t))
self.m = self.beta1 * self.m + (1 - self.beta1) * globalg
self.v = self.beta2 * self.v + (1 - self.beta2) * (globalg * globalg)
step = -a * self.m / (np.sqrt(self.v) + self.epsilon)
return step
+198 -187
View File
@@ -18,224 +18,235 @@ logger = logging.getLogger(__name__)
class Policy:
def __init__(self, *args, **kwargs):
self.args, self.kwargs = args, kwargs
self.scope = self._initialize(*args, **kwargs)
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
self.scope.name)
def __init__(self, *args, **kwargs):
self.args, self.kwargs = args, kwargs
self.scope = self._initialize(*args, **kwargs)
self.all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
self.scope.name)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
self.num_params = sum(int(np.prod(v.get_shape().as_list()))
for v in self.trainable_variables)
self._setfromflat = U.SetFromFlat(self.trainable_variables)
self._getflat = U.GetFlat(self.trainable_variables)
self.trainable_variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, self.scope.name)
self.num_params = sum(int(np.prod(v.get_shape().as_list()))
for v in self.trainable_variables)
self._setfromflat = U.SetFromFlat(self.trainable_variables)
self._getflat = U.GetFlat(self.trainable_variables)
logger.info('Trainable variables ({} parameters)'.format(self.num_params))
for v in self.trainable_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
logger.info('All variables')
for v in self.all_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp, np.prod(shp)))
logger.info('Trainable variables ({} parameters)'
.format(self.num_params))
for v in self.trainable_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp,
np.prod(shp)))
logger.info('All variables')
for v in self.all_variables:
shp = v.get_shape().as_list()
logger.info('- {} shape:{} size:{}'.format(v.name, shp,
np.prod(shp)))
placeholders = [tf.placeholder(v.value().dtype, v.get_shape().as_list())
for v in self.all_variables]
self.set_all_vars = U.function(
inputs=placeholders,
outputs=[],
updates=[tf.group(*[v.assign(p) for v, p
in zip(self.all_variables, placeholders)])]
)
placeholders = [tf.placeholder(v.value().dtype,
v.get_shape().as_list())
for v in self.all_variables]
self.set_all_vars = U.function(
inputs=placeholders,
outputs=[],
updates=[tf.group(*[v.assign(p) for v, p
in zip(self.all_variables, placeholders)])]
)
def _initialize(self, *args, **kwargs):
raise NotImplementedError
def _initialize(self, *args, **kwargs):
raise NotImplementedError
def save(self, filename):
assert filename.endswith('.h5')
with h5py.File(filename, 'w') as f:
for v in self.all_variables:
f[v.name] = v.eval()
# TODO: It would be nice to avoid pickle, but it's convenient to pass
# Python objects to _initialize (like Gym spaces or numpy arrays).
f.attrs['name'] = type(self).__name__
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args,
self.kwargs),
protocol=-1))
def save(self, filename):
assert filename.endswith('.h5')
with h5py.File(filename, 'w') as f:
for v in self.all_variables:
f[v.name] = v.eval()
# TODO: It would be nice to avoid pickle, but it's convenient to
# pass Python objects to _initialize (like Gym spaces or numpy
# arrays).
f.attrs['name'] = type(self).__name__
f.attrs['args_and_kwargs'] = np.void(pickle.dumps((self.args,
self.kwargs),
protocol=-1))
@classmethod
def Load(cls, filename, extra_kwargs=None):
with h5py.File(filename, 'r') as f:
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
if extra_kwargs:
kwargs.update(extra_kwargs)
policy = cls(*args, **kwargs)
policy.set_all_vars(*[f[v.name][...] for v in policy.all_variables])
return policy
@classmethod
def Load(cls, filename, extra_kwargs=None):
with h5py.File(filename, 'r') as f:
args, kwargs = pickle.loads(f.attrs['args_and_kwargs'].tostring())
if extra_kwargs:
kwargs.update(extra_kwargs)
policy = cls(*args, **kwargs)
policy.set_all_vars(*[f[v.name][...]
for v in policy.all_variables])
return policy
# === Rollouts/training ===
# === Rollouts/training ===
def rollout(self, env, render=False, timestep_limit=None, save_obs=False,
random_stream=None):
"""Do a rollout.
def rollout(self, env, render=False, timestep_limit=None, save_obs=False,
random_stream=None):
"""Do a rollout.
If random_stream is provided, the rollout will take noisy actions with
noise drawn from that stream. Otherwise, no action noise will be added.
"""
env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
timestep_limit = (env_timestep_limit if timestep_limit is None
else min(timestep_limit, env_timestep_limit))
rews = []
t = 0
if save_obs:
obs = []
ob = env.reset()
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
rews.append(rew)
t += 1
if render:
env.render()
if done:
break
rews = np.array(rews, dtype=np.float32)
if save_obs:
return rews, t, np.array(obs)
return rews, t
If random_stream is provided, the rollout will take noisy actions with
noise drawn from that stream. Otherwise, no action noise will be added.
"""
env_timestep_limit = env.spec.tags.get("wrapper_config.TimeLimit"
".max_episode_steps")
timestep_limit = (env_timestep_limit if timestep_limit is None
else min(timestep_limit, env_timestep_limit))
rews = []
t = 0
if save_obs:
obs = []
ob = env.reset()
for _ in range(timestep_limit):
ac = self.act(ob[None], random_stream=random_stream)[0]
if save_obs:
obs.append(ob)
ob, rew, done, _ = env.step(ac)
rews.append(rew)
t += 1
if render:
env.render()
if done:
break
rews = np.array(rews, dtype=np.float32)
if save_obs:
return rews, t, np.array(obs)
return rews, t
def act(self, ob, random_stream=None):
raise NotImplementedError
def act(self, ob, random_stream=None):
raise NotImplementedError
def set_trainable_flat(self, x):
self._setfromflat(x)
def set_trainable_flat(self, x):
self._setfromflat(x)
def get_trainable_flat(self):
return self._getflat()
def get_trainable_flat(self):
return self._getflat()
@property
def needs_ob_stat(self):
raise NotImplementedError
@property
def needs_ob_stat(self):
raise NotImplementedError
def set_ob_stat(self, ob_mean, ob_std):
raise NotImplementedError
def set_ob_stat(self, ob_mean, ob_std):
raise NotImplementedError
def bins(x, dim, num_bins, name):
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
return tf.argmax(scores_nab, 2)
scores = U.dense(x, dim * num_bins, name, U.normc_initializer(0.01))
scores_nab = tf.reshape(scores, [-1, dim, num_bins])
return tf.argmax(scores_nab, 2)
class MujocoPolicy(Policy):
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std, nonlin_type,
hidden_dims, connection_type):
self.ac_space = ac_space
self.ac_bins = ac_bins
self.ac_noise_std = ac_noise_std
self.hidden_dims = hidden_dims
self.connection_type = connection_type
def _initialize(self, ob_space, ac_space, ac_bins, ac_noise_std,
nonlin_type, hidden_dims, connection_type):
self.ac_space = ac_space
self.ac_bins = ac_bins
self.ac_noise_std = ac_noise_std
self.hidden_dims = hidden_dims
self.connection_type = connection_type
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
assert (np.all(np.isfinite(self.ac_space.low)) and
np.all(np.isfinite(self.ac_space.high))), "Action bounds required"
assert len(ob_space.shape) == len(self.ac_space.shape) == 1
assert (np.all(np.isfinite(self.ac_space.low)) and
np.all(np.isfinite(self.ac_space.high))), ("Action bounds "
"required")
self.nonlin = {'tanh': tf.tanh,
'relu': tf.nn.relu,
'lrelu': U.lrelu,
'elu': tf.nn.elu}[nonlin_type]
self.nonlin = {'tanh': tf.tanh,
'relu': tf.nn.relu,
'lrelu': U.lrelu,
'elu': tf.nn.elu}[nonlin_type]
with tf.variable_scope(type(self).__name__) as scope:
# Observation normalization.
ob_mean = tf.get_variable(
'ob_mean', ob_space.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
ob_std = tf.get_variable(
'ob_std', ob_space.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
in_mean = tf.placeholder(tf.float32, ob_space.shape)
in_std = tf.placeholder(tf.float32, ob_space.shape)
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
tf.assign(ob_mean, in_mean),
tf.assign(ob_std, in_std),
])
with tf.variable_scope(type(self).__name__) as scope:
# Observation normalization.
ob_mean = tf.get_variable(
'ob_mean', ob_space.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
ob_std = tf.get_variable(
'ob_std', ob_space.shape, tf.float32,
tf.constant_initializer(np.nan), trainable=False)
in_mean = tf.placeholder(tf.float32, ob_space.shape)
in_std = tf.placeholder(tf.float32, ob_space.shape)
self._set_ob_mean_std = U.function([in_mean, in_std], [], updates=[
tf.assign(ob_mean, in_mean),
tf.assign(ob_std, in_std),
])
# Policy network.
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std, -5.0, 5.0))
self._act = U.function([o], a)
return scope
# Policy network.
o = tf.placeholder(tf.float32, [None] + list(ob_space.shape))
a = self._make_net(tf.clip_by_value((o - ob_mean) / ob_std,
-5.0, 5.0))
self._act = U.function([o], a)
return scope
def _make_net(self, o):
# Process observation.
if self.connection_type == 'ff':
x = o
for ilayer, hd in enumerate(self.hidden_dims):
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer),
U.normc_initializer(1.0)))
else:
raise NotImplementedError(self.connection_type)
def _make_net(self, o):
# Process observation.
if self.connection_type == 'ff':
x = o
for ilayer, hd in enumerate(self.hidden_dims):
x = self.nonlin(U.dense(x, hd, 'l{}'.format(ilayer),
U.normc_initializer(1.0)))
else:
raise NotImplementedError(self.connection_type)
# Map to action.
adim = self.ac_space.shape[0]
ahigh = self.ac_space.high
alow = self.ac_space.low
assert isinstance(self.ac_bins, str)
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
# Map to action.
adim = self.ac_space.shape[0]
ahigh = self.ac_space.high
alow = self.ac_space.low
assert isinstance(self.ac_bins, str)
ac_bin_mode, ac_bin_arg = self.ac_bins.split(':')
if ac_bin_mode == 'uniform':
# Uniformly spaced bins, from ac_space.low to ac_space.high.
num_ac_bins = int(ac_bin_arg)
aidx_na = bins(x, adim, num_ac_bins, 'out')
ac_range_1a = (ahigh - alow)[None, :]
a = (1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a +
alow[None, :])
if ac_bin_mode == 'uniform':
# Uniformly spaced bins, from ac_space.low to ac_space.high.
num_ac_bins = int(ac_bin_arg)
aidx_na = bins(x, adim, num_ac_bins, 'out')
ac_range_1a = (ahigh - alow)[None, :]
a = (1. / (num_ac_bins - 1.) * tf.to_float(aidx_na) * ac_range_1a +
alow[None, :])
elif ac_bin_mode == 'custom':
# Custom bins specified as a list of values from -1 to 1.
# The bins are rescaled to ac_space.low to ac_space.high.
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))),
dtype=np.float32)
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x)
for x in acvals_k))
assert acvals_k.ndim == 1 and acvals_k[0] == -1 and acvals_k[-1] == 1
acvals_ak = ((ahigh - alow)[:, None] / (acvals_k[-1] - acvals_k[0]) *
(acvals_k - acvals_k[0])[None, :] + alow[:, None])
elif ac_bin_mode == 'custom':
# Custom bins specified as a list of values from -1 to 1.
# The bins are rescaled to ac_space.low to ac_space.high.
acvals_k = np.array(list(map(float, ac_bin_arg.split(','))),
dtype=np.float32)
logger.info('Custom action values: ' + ' '.join('{:.3f}'.format(x)
for x in acvals_k))
assert (acvals_k.ndim == 1 and acvals_k[0] == -1 and
acvals_k[-1] == 1)
acvals_ak = ((ahigh - alow)[:, None] /
(acvals_k[-1] - acvals_k[0]) *
(acvals_k - acvals_k[0])[None, :] + alow[:, None])
aidx_na = bins(x, adim, len(acvals_k), 'out') # Values in [0, k-1].
a = tf.gather_nd(
acvals_ak,
tf.concat([
tf.tile(np.arange(adim)[None, :, None],
[tf.shape(aidx_na)[0], 1, 1]),
2,
tf.expand_dims(aidx_na, -1)
]) # (n, a, 2)
) # (n, a)
elif ac_bin_mode == 'continuous':
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
else:
raise NotImplementedError(ac_bin_mode)
aidx_na = bins(x, adim, len(acvals_k),
'out') # Values in [0, k-1].
a = tf.gather_nd(
acvals_ak,
tf.concat([
tf.tile(np.arange(adim)[None, :, None],
[tf.shape(aidx_na)[0], 1, 1]),
2,
tf.expand_dims(aidx_na, -1)
]) # (n, a, 2)
) # (n, a)
elif ac_bin_mode == 'continuous':
a = U.dense(x, adim, 'out', U.normc_initializer(0.01))
else:
raise NotImplementedError(ac_bin_mode)
return a
return a
def act(self, ob, random_stream=None):
a = self._act(ob)
if random_stream is not None and self.ac_noise_std != 0:
a += random_stream.randn(*a.shape) * self.ac_noise_std
return a
def act(self, ob, random_stream=None):
a = self._act(ob)
if random_stream is not None and self.ac_noise_std != 0:
a += random_stream.randn(*a.shape) * self.ac_noise_std
return a
@property
def needs_ob_stat(self):
return True
@property
def needs_ob_stat(self):
return True
@property
def needs_ref_batch(self):
return False
@property
def needs_ref_batch(self):
return False
def set_ob_stat(self, ob_mean, ob_std):
self._set_ob_mean_std(ob_mean, ob_std)
def set_ob_stat(self, ob_mean, ob_std):
self._set_ob_mean_std(ob_mean, ob_std)
@@ -24,199 +24,201 @@ DISABLED = 50
class TbWriter(object):
"""Based on SummaryWriter, but changed to allow for a different prefix."""
def __init__(self, dir, prefix):
self.dir = dir
# Start at 1, because EvWriter automatically generates an object with
# step = 0.
self.step = 1
self.evwriter = pywrap_tensorflow.EventsWriter(
compat.as_bytes(os.path.join(dir, prefix)))
"""Based on SummaryWriter, but changed to allow for a different prefix."""
def __init__(self, dir, prefix):
self.dir = dir
# Start at 1, because EvWriter automatically generates an object with
# step = 0.
self.step = 1
self.evwriter = pywrap_tensorflow.EventsWriter(
compat.as_bytes(os.path.join(dir, prefix)))
def write_values(self, key2val):
summary = tf.Summary(value=[tf.Summary.Value(tag=k, simple_value=float(v))
for (k, v) in key2val.items()])
event = event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step
self.evwriter.WriteEvent(event)
self.evwriter.Flush()
self.step += 1
def write_values(self, key2val):
summary = tf.Summary(value=[tf.Summary.Value(tag=k,
simple_value=float(v))
for (k, v) in key2val.items()])
event = event_pb2.Event(wall_time=time.time(), summary=summary)
event.step = self.step
self.evwriter.WriteEvent(event)
self.evwriter.Flush()
self.step += 1
def close(self):
self.evwriter.Close()
def close(self):
self.evwriter.Close()
# API
def start(dir):
if _Logger.CURRENT is not _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but you "
"never stopped the previous logger (dir=%s)."
"\n" % (dir, _Logger.CURRENT.dir))
_Logger.CURRENT = _Logger(dir=dir)
if _Logger.CURRENT is not _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to start logging (dir=%s), but "
"you never stopped the previous logger (dir=%s)."
"\n" % (dir, _Logger.CURRENT.dir))
_Logger.CURRENT = _Logger(dir=dir)
def stop():
if _Logger.CURRENT is _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to stop logging, but you never "
"started any previous logger."
"\n" % (dir, _Logger.CURRENT.dir))
return
_Logger.CURRENT.close()
_Logger.CURRENT = _Logger.DEFAULT
if _Logger.CURRENT is _Logger.DEFAULT:
sys.stderr.write("WARNING: You asked to stop logging, but you never "
"started any previous logger."
"\n" % (dir, _Logger.CURRENT.dir))
return
_Logger.CURRENT.close()
_Logger.CURRENT = _Logger.DEFAULT
def record_tabular(key, val):
"""Log a value of some diagnostic.
"""Log a value of some diagnostic.
Call this once for each diagnostic quantity, each iteration.
"""
_Logger.CURRENT.record_tabular(key, val)
Call this once for each diagnostic quantity, each iteration.
"""
_Logger.CURRENT.record_tabular(key, val)
def dump_tabular():
"""Write all of the diagnostics from the current iteration."""
_Logger.CURRENT.dump_tabular()
"""Write all of the diagnostics from the current iteration."""
_Logger.CURRENT.dump_tabular()
def log(*args, **kwargs):
"""Write the sequence of args, with no separators.
"""Write the sequence of args, with no separators.
This is written to the console and output files (if you've configured an
output file).
"""
level = kwargs['level'] if 'level' in kwargs else INFO
_Logger.CURRENT.log(*args, level=level)
This is written to the console and output files (if you've configured an
output file).
"""
level = kwargs['level'] if 'level' in kwargs else INFO
_Logger.CURRENT.log(*args, level=level)
def debug(*args):
log(*args, level=DEBUG)
log(*args, level=DEBUG)
def info(*args):
log(*args, level=INFO)
log(*args, level=INFO)
def warn(*args):
log(*args, level=WARN)
log(*args, level=WARN)
def error(*args):
log(*args, level=ERROR)
log(*args, level=ERROR)
def set_level(level):
"""
Set logging threshold on current logger.
"""
_Logger.CURRENT.set_level(level)
"""
Set logging threshold on current logger.
"""
_Logger.CURRENT.set_level(level)
def get_dir():
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call start)
"""
return _Logger.CURRENT.get_dir()
"""
Get directory that log files are being written to.
will be None if there is no output directory (i.e., if you didn't call
start)
"""
return _Logger.CURRENT.get_dir()
def get_expt_dir():
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
return get_dir()
sys.stderr.write("get_expt_dir() is Deprecated. Switch to get_dir()\n")
return get_dir()
# Backend
class _Logger(object):
# A logger with no output files. (See right below class definition) so that
# you can still log to the terminal without setting up any output files.
DEFAULT = None
# Current logger being used by the free functions above.
CURRENT = None
# A logger with no output files. (See right below class definition) so that
# you can still log to the terminal without setting up any output files.
DEFAULT = None
# Current logger being used by the free functions above.
CURRENT = None
def __init__(self, dir=None):
self.name2val = OrderedDict() # Values this iteration.
self.level = INFO
self.dir = dir
self.text_outputs = [sys.stdout]
if dir is not None:
os.makedirs(dir, exist_ok=True)
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
self.tbwriter = TbWriter(dir=dir, prefix="events")
else:
self.tbwriter = None
def __init__(self, dir=None):
self.name2val = OrderedDict() # Values this iteration.
self.level = INFO
self.dir = dir
self.text_outputs = [sys.stdout]
if dir is not None:
os.makedirs(dir, exist_ok=True)
self.text_outputs.append(open(os.path.join(dir, "log.txt"), "w"))
self.tbwriter = TbWriter(dir=dir, prefix="events")
else:
self.tbwriter = None
# Logging API, forwarded
# Logging API, forwarded
def record_tabular(self, key, val):
self.name2val[key] = val
def record_tabular(self, key, val):
self.name2val[key] = val
def dump_tabular(self):
# Create strings for printing.
key2str = OrderedDict()
for (key, val) in self.name2val.items():
if hasattr(val, "__float__"):
valstr = "%-8.3g" % val
else:
valstr = val
key2str[self._truncate(key)] = self._truncate(valstr)
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write to all text outputs
self._write_text("-" * (keywidth + valwidth + 7), "\n")
for (key, val) in key2str.items():
self._write_text("| ", key, " " * (keywidth - len(key)), " | ", val,
" " * (valwidth - len(val)), " |\n")
self._write_text("-" * (keywidth + valwidth + 7), "\n")
for f in self.text_outputs:
try:
f.flush()
except OSError:
sys.stderr.write('Warning! OSError when flushing.\n')
# Write to tensorboard
if self.tbwriter is not None:
self.tbwriter.write_values(self.name2val)
self.name2val.clear()
def dump_tabular(self):
# Create strings for printing.
key2str = OrderedDict()
for (key, val) in self.name2val.items():
if hasattr(val, "__float__"):
valstr = "%-8.3g" % val
else:
valstr = val
key2str[self._truncate(key)] = self._truncate(valstr)
keywidth = max(map(len, key2str.keys()))
valwidth = max(map(len, key2str.values()))
# Write to all text outputs
self._write_text("-" * (keywidth + valwidth + 7), "\n")
for (key, val) in key2str.items():
self._write_text("| ", key, " " * (keywidth - len(key)),
" | ", val, " " * (valwidth - len(val)), " |\n")
self._write_text("-" * (keywidth + valwidth + 7), "\n")
for f in self.text_outputs:
try:
f.flush()
except OSError:
sys.stderr.write('Warning! OSError when flushing.\n')
# Write to tensorboard
if self.tbwriter is not None:
self.tbwriter.write_values(self.name2val)
self.name2val.clear()
def log(self, *args, **kwargs):
level = kwargs['level'] if 'level' in kwargs else INFO
if self.level <= level:
self._do_log(*args)
def log(self, *args, **kwargs):
level = kwargs['level'] if 'level' in kwargs else INFO
if self.level <= level:
self._do_log(*args)
# Configuration
# Configuration
def set_level(self, level):
self.level = level
def set_level(self, level):
self.level = level
def get_dir(self):
return self.dir
def get_dir(self):
return self.dir
def close(self):
for f in self.text_outputs[1:]:
f.close()
if self.tbwriter:
self.tbwriter.close()
def close(self):
for f in self.text_outputs[1:]:
f.close()
if self.tbwriter:
self.tbwriter.close()
# Misc
# Misc
def _do_log(self, *args):
self._write_text(*args + ('\n',))
for f in self.text_outputs:
try:
f.flush()
except OSError:
print('Warning! OSError when flushing.')
def _do_log(self, *args):
self._write_text(*args + ('\n',))
for f in self.text_outputs:
try:
f.flush()
except OSError:
print('Warning! OSError when flushing.')
def _write_text(self, *strings):
for f in self.text_outputs:
for string in strings:
f.write(string)
def _write_text(self, *strings):
for f in self.text_outputs:
for string in strings:
f.write(string)
def _truncate(self, s):
if len(s) > 33:
return s[:30] + "..."
else:
return s
def _truncate(self, s):
if len(s) > 33:
return s[:30] + "..."
else:
return s
_Logger.DEFAULT = _Logger()
+140 -136
View File
@@ -12,8 +12,8 @@ import os
# Tensorflow must be at least version 1.0.0 for the example to work.
if int(tf.__version__.split(".")[0]) < 1:
raise Exception("Your Tensorflow version is less than 1.0.0. Please update "
"Tensorflow to the latest version.")
raise Exception("Your Tensorflow version is less than 1.0.0. Please "
"update Tensorflow to the latest version.")
# ================================================================
# Import all names into common namespace
@@ -25,160 +25,163 @@ clip = tf.clip_by_value
def sum(x, axis=None, keepdims=False):
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
return tf.reduce_sum(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
def mean(x, axis=None, keepdims=False):
return tf.reduce_mean(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
return tf.reduce_mean(x, reduction_indices=(None if axis is None
else [axis]),
keep_dims=keepdims)
def var(x, axis=None, keepdims=False):
meanx = mean(x, axis=axis, keepdims=keepdims)
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
meanx = mean(x, axis=axis, keepdims=keepdims)
return mean(tf.square(x - meanx), axis=axis, keepdims=keepdims)
def std(x, axis=None, keepdims=False):
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))
def max(x, axis=None, keepdims=False):
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
return tf.reduce_max(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
def min(x, axis=None, keepdims=False):
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
return tf.reduce_min(x, reduction_indices=None if axis is None else [axis],
keep_dims=keepdims)
def concatenate(arrs, axis=0):
return tf.concat(arrs, axis)
return tf.concat(arrs, axis)
def argmax(x, axis=None):
return tf.argmax(x, dimension=axis)
return tf.argmax(x, dimension=axis)
# Extras
def l2loss(params):
if len(params) == 0:
return tf.constant(0.0)
else:
return tf.add_n([sum(tf.square(p)) for p in params])
if len(params) == 0:
return tf.constant(0.0)
else:
return tf.add_n([sum(tf.square(p)) for p in params])
def lrelu(x, leak=0.2):
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
f1 = 0.5 * (1 + leak)
f2 = 0.5 * (1 - leak)
return f1 * x + f2 * abs(x)
def categorical_sample_logits(X):
# https://github.com/tensorflow/tensorflow/issues/456
U = tf.random_uniform(tf.shape(X))
return argmax(X - tf.log(-tf.log(U)), axis=1)
# https://github.com/tensorflow/tensorflow/issues/456
U = tf.random_uniform(tf.shape(X))
return argmax(X - tf.log(-tf.log(U)), axis=1)
# Global session
def get_session():
return tf.get_default_session()
return tf.get_default_session()
def single_threaded_session():
tf_config = tf.ConfigProto(inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1)
return tf.Session(config=tf_config)
tf_config = tf.ConfigProto(inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1)
return tf.Session(config=tf_config)
ALREADY_INITIALIZED = set()
def initialize():
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
get_session().run(tf.variables_initializer(new_variables))
ALREADY_INITIALIZED.update(new_variables)
new_variables = set(tf.global_variables()) - ALREADY_INITIALIZED
get_session().run(tf.variables_initializer(new_variables))
ALREADY_INITIALIZED.update(new_variables)
def eval(expr, feed_dict=None):
if feed_dict is None:
feed_dict = {}
return get_session().run(expr, feed_dict=feed_dict)
if feed_dict is None:
feed_dict = {}
return get_session().run(expr, feed_dict=feed_dict)
def set_value(v, val):
get_session().run(v.assign(val))
get_session().run(v.assign(val))
def load_state(fname):
saver = tf.train.Saver()
saver.restore(get_session(), fname)
saver = tf.train.Saver()
saver.restore(get_session(), fname)
def save_state(fname):
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(get_session(), fname)
os.makedirs(os.path.dirname(fname), exist_ok=True)
saver = tf.train.Saver()
saver.save(get_session(), fname)
# Model components
def normc_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def dense(x, size, name, weight_init=None, bias=True):
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
initializer=weight_init)
ret = tf.matmul(x, w)
if bias:
b = tf.get_variable(name + "/b", [size],
initializer=tf.zeros_initializer())
return ret + b
else:
return ret
w = tf.get_variable(name + "/w", [x.get_shape()[1], size],
initializer=weight_init)
ret = tf.matmul(x, w)
if bias:
b = tf.get_variable(name + "/b", [size],
initializer=tf.zeros_initializer())
return ret + b
else:
return ret
# Basic Stuff
def function(inputs, outputs, updates=None, givens=None):
if isinstance(outputs, list):
return _Function(inputs, outputs, updates, givens=givens)
elif isinstance(outputs, dict):
f = _Function(inputs, outputs.values(), updates, givens=givens)
return lambda *inputs: dict(zip(outputs.keys(), f(*inputs)))
else:
f = _Function(inputs, [outputs], updates, givens=givens)
return lambda *inputs: f(*inputs)[0]
if isinstance(outputs, list):
return _Function(inputs, outputs, updates, givens=givens)
elif isinstance(outputs, dict):
f = _Function(inputs, outputs.values(), updates, givens=givens)
return lambda *inputs: dict(zip(outputs.keys(), f(*inputs)))
else:
f = _Function(inputs, [outputs], updates, givens=givens)
return lambda *inputs: f(*inputs)[0]
class _Function(object):
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should all be "
"placeholders")
self.inputs = inputs
updates = updates or []
self.update_group = tf.group(*updates)
self.outputs_update = list(outputs) + [self.update_group]
self.givens = {} if givens is None else givens
self.check_nan = check_nan
def __init__(self, inputs, outputs, updates, givens, check_nan=False):
assert all(len(i.op.inputs) == 0 for i in inputs), ("inputs should "
"all be "
"placeholders")
self.inputs = inputs
updates = updates or []
self.update_group = tf.group(*updates)
self.outputs_update = list(outputs) + [self.update_group]
self.givens = {} if givens is None else givens
self.check_nan = check_nan
def __call__(self, *inputvals):
assert len(inputvals) == len(self.inputs)
feed_dict = dict(zip(self.inputs, inputvals))
feed_dict.update(self.givens)
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
if self.check_nan:
if any(np.isnan(r).any() for r in results):
raise RuntimeError("Nan detected")
return results
def __call__(self, *inputvals):
assert len(inputvals) == len(self.inputs)
feed_dict = dict(zip(self.inputs, inputvals))
feed_dict.update(self.givens)
results = get_session().run(self.outputs_update,
feed_dict=feed_dict)[:-1]
if self.check_nan:
if any(np.isnan(r).any() for r in results):
raise RuntimeError("Nan detected")
return results
# Graph traversal
@@ -189,71 +192,72 @@ VARIABLES = {}
def var_shape(x):
out = [k.value for k in x.get_shape()]
assert all(isinstance(a, int) for a in out), ("shape function assumes that "
"shape is fully known")
return out
out = [k.value for k in x.get_shape()]
assert all(isinstance(a, int) for a in out), ("shape function assumes "
"that shape is fully known")
return out
def numel(x):
return intprod(var_shape(x))
return intprod(var_shape(x))
def intprod(x):
return int(np.prod(x))
return int(np.prod(x))
def flatgrad(loss, var_list):
grads = tf.gradients(loss, var_list)
return tf.concat([tf.reshape(grad, [numel(v)], 0)
for (v, grad) in zip(var_list, grads)])
grads = tf.gradients(loss, var_list)
return tf.concat([tf.reshape(grad, [numel(v)], 0)
for (v, grad) in zip(var_list, grads)])
class SetFromFlat(object):
def __init__(self, var_list, dtype=tf.float32):
assigns = []
shapes = list(map(var_shape, var_list))
total_size = np.sum([intprod(shape) for shape in shapes])
def __init__(self, var_list, dtype=tf.float32):
assigns = []
shapes = list(map(var_shape, var_list))
total_size = np.sum([intprod(shape) for shape in shapes])
self.theta = theta = tf.placeholder(dtype, [total_size])
start = 0
assigns = []
for (shape, v) in zip(shapes, var_list):
size = intprod(shape)
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size],
shape)))
start += size
assert start == total_size
self.op = tf.group(*assigns)
self.theta = theta = tf.placeholder(dtype, [total_size])
start = 0
assigns = []
for (shape, v) in zip(shapes, var_list):
size = intprod(shape)
assigns.append(tf.assign(v, tf.reshape(theta[start:start + size],
shape)))
start += size
assert start == total_size
self.op = tf.group(*assigns)
def __call__(self, theta):
get_session().run(self.op, feed_dict={self.theta: theta})
def __call__(self, theta):
get_session().run(self.op, feed_dict={self.theta: theta})
class GetFlat(object):
def __init__(self, var_list):
self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0)
def __init__(self, var_list):
self.op = tf.concat([tf.reshape(v, [numel(v)]) for v in var_list], 0)
def __call__(self):
return get_session().run(self.op)
def __call__(self):
return get_session().run(self.op)
# Misc
def scope_vars(scope, trainable_only):
"""Get variables inside a scope. The scope can be specified as a string."""
return tf.get_collection((tf.GraphKeys.TRAINABLE_VARIABLES if trainable_only
else tf.GraphKeys.GLOBAL_VARIABLES),
scope=(scope if isinstance(scope, str)
else scope.name))
"""Get variables inside a scope. The scope can be specified as a string."""
return tf.get_collection((tf.GraphKeys.TRAINABLE_VARIABLES
if trainable_only
else tf.GraphKeys.GLOBAL_VARIABLES),
scope=(scope if isinstance(scope, str)
else scope.name))
def in_session(f):
@functools.wraps(f)
def newfunc(*args, **kwargs):
with tf.Session():
f(*args, **kwargs)
return newfunc
@functools.wraps(f)
def newfunc(*args, **kwargs):
with tf.Session():
f(*args, **kwargs)
return newfunc
# A mapping from name -> (placeholder, dtype, shape).
@@ -261,28 +265,28 @@ _PLACEHOLDER_CACHE = {}
def get_placeholder(name, dtype, shape):
print("calling get_placeholder", name)
if name in _PLACEHOLDER_CACHE:
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
assert dtype1 == dtype and shape1 == shape
return out
else:
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
return out
print("calling get_placeholder", name)
if name in _PLACEHOLDER_CACHE:
out, dtype1, shape1 = _PLACEHOLDER_CACHE[name]
assert dtype1 == dtype and shape1 == shape
return out
else:
out = tf.placeholder(dtype=dtype, shape=shape, name=name)
_PLACEHOLDER_CACHE[name] = (out, dtype, shape)
return out
def get_placeholder_cached(name):
return _PLACEHOLDER_CACHE[name][0]
return _PLACEHOLDER_CACHE[name][0]
def flattenallbut0(x):
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
return tf.reshape(x, [-1, intprod(x.get_shape().as_list()[1:])])
def reset():
global _PLACEHOLDER_CACHE
global VARIABLES
_PLACEHOLDER_CACHE = {}
VARIABLES = {}
tf.reset_default_graph()
global _PLACEHOLDER_CACHE
global VARIABLES
_PLACEHOLDER_CACHE = {}
VARIABLES = {}
tf.reset_default_graph()
+55 -54
View File
@@ -10,77 +10,78 @@ import tensorflow as tf
def compute_ranks(x):
"""Returns ranks in [0, len(x))
"""Returns ranks in [0, len(x))
Note: This is different from scipy.stats.rankdata, which returns ranks in
[1, len(x)].
"""
assert x.ndim == 1
ranks = np.empty(len(x), dtype=int)
ranks[x.argsort()] = np.arange(len(x))
return ranks
Note: This is different from scipy.stats.rankdata, which returns ranks in
[1, len(x)].
"""
assert x.ndim == 1
ranks = np.empty(len(x), dtype=int)
ranks[x.argsort()] = np.arange(len(x))
return ranks
def compute_centered_ranks(x):
y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
y /= (x.size - 1)
y -= 0.5
return y
y = compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
y /= (x.size - 1)
y -= 0.5
return y
def make_session(single_threaded):
if not single_threaded:
return tf.InteractiveSession()
return tf.InteractiveSession(
config=tf.ConfigProto(inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1))
if not single_threaded:
return tf.InteractiveSession()
return tf.InteractiveSession(
config=tf.ConfigProto(inter_op_parallelism_threads=1,
intra_op_parallelism_threads=1))
def itergroups(items, group_size):
assert group_size >= 1
group = []
for x in items:
group.append(x)
if len(group) == group_size:
yield tuple(group)
del group[:]
if group:
yield tuple(group)
assert group_size >= 1
group = []
for x in items:
group.append(x)
if len(group) == group_size:
yield tuple(group)
del group[:]
if group:
yield tuple(group)
def batched_weighted_sum(weights, vecs, batch_size):
total = 0
num_items_summed = 0
for batch_weights, batch_vecs in zip(itergroups(weights, batch_size),
itergroups(vecs, batch_size)):
assert len(batch_weights) == len(batch_vecs) <= batch_size
total += np.dot(np.asarray(batch_weights, dtype=np.float32),
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed
total = 0
num_items_summed = 0
for batch_weights, batch_vecs in zip(itergroups(weights, batch_size),
itergroups(vecs, batch_size)):
assert len(batch_weights) == len(batch_vecs) <= batch_size
total += np.dot(np.asarray(batch_weights, dtype=np.float32),
np.asarray(batch_vecs, dtype=np.float32))
num_items_summed += len(batch_weights)
return total, num_items_summed
class RunningStat(object):
def __init__(self, shape, eps):
self.sum = np.zeros(shape, dtype=np.float32)
self.sumsq = np.full(shape, eps, dtype=np.float32)
self.count = eps
def __init__(self, shape, eps):
self.sum = np.zeros(shape, dtype=np.float32)
self.sumsq = np.full(shape, eps, dtype=np.float32)
self.count = eps
def increment(self, s, ssq, c):
self.sum += s
self.sumsq += ssq
self.count += c
def increment(self, s, ssq, c):
self.sum += s
self.sumsq += ssq
self.count += c
@property
def mean(self):
return self.sum / self.count
@property
def mean(self):
return self.sum / self.count
@property
def std(self):
return np.sqrt(np.maximum(self.sumsq / self.count - np.square(self.mean),
1e-2))
@property
def std(self):
return np.sqrt(np.maximum(
self.sumsq / self.count - np.square(self.mean), 1e-2))
def set_from_init(self, init_mean, init_std, init_count):
self.sum[:] = init_mean * init_count
self.sumsq[:] = (np.square(init_mean) + np.square(init_std)) * init_count
self.count = init_count
def set_from_init(self, init_mean, init_std, init_count):
self.sum[:] = init_mean * init_count
self.sumsq[:] = (np.square(init_mean) +
np.square(init_std)) * init_count
self.count = init_count
+25 -24
View File
@@ -11,32 +11,33 @@ import click
@click.option("--stochastic", is_flag=True)
@click.option("--extra_kwargs")
def main(env_id, policy_file, record, stochastic, extra_kwargs):
import gym
from gym import wrappers
import tensorflow as tf
from policies import MujocoPolicy
import numpy as np
env = gym.make(env_id)
if record:
import uuid
env = wrappers.Monitor(env, "/tmp/" + str(uuid.uuid4()), force=True)
if extra_kwargs:
import json
extra_kwargs = json.loads(extra_kwargs)
with tf.Session():
pi = MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
while True:
rews, t = pi.rollout(env, render=True,
random_stream=np.random if stochastic else None)
print("return={:.4f} len={}".format(rews.sum(), t))
import gym
from gym import wrappers
import tensorflow as tf
from policies import MujocoPolicy
import numpy as np
env = gym.make(env_id)
if record:
env.close()
return
import uuid
env = wrappers.Monitor(env, "/tmp/" + str(uuid.uuid4()), force=True)
if extra_kwargs:
import json
extra_kwargs = json.loads(extra_kwargs)
with tf.Session():
pi = MujocoPolicy.Load(policy_file, extra_kwargs=extra_kwargs)
while True:
rews, t = pi.rollout(env, render=True,
random_stream=(np.random if stochastic
else None))
print("return={:.4f} len={}".format(rews.sum(), t))
if record:
env.close()
return
if __name__ == "__main__":
main()
main()
+10 -10
View File
@@ -11,15 +11,15 @@ import ray.rllib.policy_gradient as pg
if __name__ == "__main__":
ray.init()
ray.init()
# TODO(ekl): get the algorithms working on a common set of envs
env_name = "CartPole-v0"
alg1 = es.EvolutionStrategies(env_name, es.DEFAULT_CONFIG)
alg2 = pg.PolicyGradient(env_name, pg.DEFAULT_CONFIG)
# TODO(ekl): get the algorithms working on a common set of envs
env_name = "CartPole-v0"
alg1 = es.EvolutionStrategies(env_name, es.DEFAULT_CONFIG)
alg2 = pg.PolicyGradient(env_name, pg.DEFAULT_CONFIG)
while True:
r1 = alg1.train()
r2 = alg2.train()
print("evolution strategies: {}".format(r1))
print("policy gradient: {}".format(r2))
while True:
r1 = alg1.train()
r2 = alg2.train()
print("evolution strategies: {}".format(r1))
print("policy gradient: {}".format(r2))
+198 -194
View File
@@ -10,189 +10,192 @@ import tensorflow as tf
class LocalSyncParallelOptimizer(object):
"""Optimizer that runs in parallel across multiple local devices.
"""Optimizer that runs in parallel across multiple local devices.
LocalSyncParallelOptimizer automatically splits up and loads training data
onto specified local devices (e.g. GPUs) with `load_data()`. During a call to
`optimize()`, the devices compute gradients over slices of the data in
parallel. The gradients are then averaged and applied to the shared weights.
LocalSyncParallelOptimizer automatically splits up and loads training data
onto specified local devices (e.g. GPUs) with `load_data()`. During a call
to `optimize()`, the devices compute gradients over slices of the data in
parallel. The gradients are then averaged and applied to the shared
weights.
The data loaded is pinned in device memory until the next call to
`load_data`, so you can make multiple passes (possibly in randomized order)
over the same data once loaded.
The data loaded is pinned in device memory until the next call to
`load_data`, so you can make multiple passes (possibly in randomized order)
over the same data once loaded.
This is similar to tf.train.SyncReplicasOptimizer, but works within a single
TensorFlow graph, i.e. implements in-graph replicated training:
This is similar to tf.train.SyncReplicasOptimizer, but works within a
single TensorFlow graph, i.e. implements in-graph replicated training:
https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
Args:
optimizer: delegate TensorFlow optimizer object.
devices: list of the names of TensorFlow devices to parallelize over.
input_placeholders: list of inputs for the loss function. Tensors of
these shapes will be passed to build_loss() in order
to define the per-device loss ops.
per_device_batch_size: number of tuples to optimize over at a time per
device. In each call to `optimize()`,
`len(devices) * per_device_batch_size` tuples of
data will be processed.
build_loss: function that takes the specified inputs and returns an
object with a 'loss' property that is a scalar Tensor. For
example, ray.rllib.policy_gradient.ProximalPolicyLoss.
logdir: directory to place debugging output in.
"""
def __init__(
self,
optimizer,
devices,
input_placeholders,
per_device_batch_size,
build_loss,
logdir):
self.optimizer = optimizer
self.devices = devices
self.batch_size = per_device_batch_size * len(devices)
self.per_device_batch_size = per_device_batch_size
self.input_placeholders = input_placeholders
self.build_loss = build_loss
self.logdir = logdir
# First initialize the shared loss network
with tf.variable_scope("tower"):
self._shared_loss = build_loss(*input_placeholders)
# Then setup the per-device loss graphs that use the shared weights
self._batch_index = tf.placeholder(tf.int32)
data_splits = zip(
*[tf.split(ph, len(devices)) for ph in input_placeholders])
self._towers = []
for device, device_placeholders in zip(self.devices, data_splits):
self._towers.append(self._setup_device(device, device_placeholders))
avg = average_gradients([t.grads for t in self._towers])
self._train_op = self.optimizer.apply_gradients(avg)
def load_data(self, sess, inputs, full_trace=False):
"""Bulk loads the specified inputs into device memory.
The shape of the inputs must conform to the shapes of the input
placeholders this optimizer was constructed with.
The data is split equally across all the devices. If the data is not
evenly divisible by the batch size, excess data will be discarded.
https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
Args:
sess: TensorFlow session.
inputs: list of Tensors matching the input placeholders specified at
construction time of this optimizer.
full_trace: whether to profile data loading.
Returns:
The number of tuples loaded per device.
optimizer: Delegate TensorFlow optimizer object.
devices: List of the names of TensorFlow devices to parallelize over.
input_placeholders: List of inputs for the loss function. Tensors of
these shapes will be passed to build_loss() in order to define the
per-device loss ops.
per_device_batch_size: Number of tuples to optimize over at a time per
device. In each call to `optimize()`,
`len(devices) * per_device_batch_size` tuples of data will be
processed.
build_loss: Function that takes the specified inputs and returns an
object with a 'loss' property that is a scalar Tensor. For example,
ray.rllib.policy_gradient.ProximalPolicyLoss.
logdir: Directory to place debugging output in.
"""
feed_dict = {}
assert len(self.input_placeholders) == len(inputs)
for ph, arr in zip(self.input_placeholders, inputs):
truncated_arr = make_divisible_by(arr, self.batch_size)
feed_dict[ph] = truncated_arr
truncated_len = len(truncated_arr)
def __init__(self, optimizer, devices, input_placeholders,
per_device_batch_size, build_loss, logdir):
self.optimizer = optimizer
self.devices = devices
self.batch_size = per_device_batch_size * len(devices)
self.per_device_batch_size = per_device_batch_size
self.input_placeholders = input_placeholders
self.build_loss = build_loss
self.logdir = logdir
if full_trace:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
# First initialize the shared loss network
with tf.variable_scope("tower"):
self._shared_loss = build_loss(*input_placeholders)
sess.run(
[t.init_op for t in self._towers],
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if full_trace:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-load.json"), "w")
trace_file.write(trace.generate_chrome_trace_format())
# Then setup the per-device loss graphs that use the shared weights
self._batch_index = tf.placeholder(tf.int32)
data_splits = zip(
*[tf.split(ph, len(devices)) for ph in input_placeholders])
self._towers = []
for device, device_placeholders in zip(self.devices, data_splits):
self._towers.append(self._setup_device(device,
device_placeholders))
tuples_per_device = truncated_len / len(self.devices)
assert tuples_per_device % self.per_device_batch_size == 0
return tuples_per_device
avg = average_gradients([t.grads for t in self._towers])
self._train_op = self.optimizer.apply_gradients(avg)
def optimize(
self, sess, batch_index,
extra_ops=[], extra_feed_dict={}, file_writer=None):
"""Run a single step of SGD.
def load_data(self, sess, inputs, full_trace=False):
"""Bulk loads the specified inputs into device memory.
Runs a SGD step over a slice of the preloaded batch with size given by
self.per_device_batch_size and offset given by the batch_index argument.
The shape of the inputs must conform to the shapes of the input
placeholders this optimizer was constructed with.
Updates shared model weights based on the averaged per-device gradients.
The data is split equally across all the devices. If the data is not
evenly divisible by the batch size, excess data will be discarded.
Args:
sess: TensorFlow session.
batch_index: offset into the preloaded data. This value must be
between `0` and `tuples_per_device`. The amount of data
to process is always fixed to `per_device_batch_size`.
extra_ops: extra ops to run with this step (e.g. for metrics).
extra_feed_dict: extra args to feed into this session run.
file_writer: if specified, tf metrics will be written out using this.
Args:
sess: TensorFlow session.
inputs: List of Tensors matching the input placeholders specified
at construction time of this optimizer.
full_trace: Whether to profile data loading.
Returns:
the outputs of extra_ops evaluated over the batch.
"""
Returns:
The number of tuples loaded per device.
"""
if file_writer:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
feed_dict = {}
assert len(self.input_placeholders) == len(inputs)
for ph, arr in zip(self.input_placeholders, inputs):
truncated_arr = make_divisible_by(arr, self.batch_size)
feed_dict[ph] = truncated_arr
truncated_len = len(truncated_arr)
feed_dict = {self._batch_index: batch_index}
feed_dict.update(extra_feed_dict)
outs = sess.run(
[self._train_op] + extra_ops,
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if full_trace:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
if file_writer:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"), "w")
trace_file.write(trace.generate_chrome_trace_format())
file_writer.add_run_metadata(
run_metadata, "sgd_train_{}".format(batch_index))
sess.run(
[t.init_op for t in self._towers],
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if full_trace:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-load.json"),
"w")
trace_file.write(trace.generate_chrome_trace_format())
return outs[1:]
tuples_per_device = truncated_len / len(self.devices)
assert tuples_per_device % self.per_device_batch_size == 0
return tuples_per_device
def get_common_loss(self):
return self._shared_loss
def optimize(self, sess, batch_index, extra_ops=[], extra_feed_dict={},
file_writer=None):
"""Run a single step of SGD.
def get_device_losses(self):
return [t.loss_object for t in self._towers]
Runs a SGD step over a slice of the preloaded batch with size given by
self.per_device_batch_size and offset given by the batch_index
argument.
def _setup_device(self, device, device_input_placeholders):
with tf.device(device):
with tf.variable_scope("tower", reuse=True):
device_input_batches = []
device_input_slices = []
for ph in device_input_placeholders:
current_batch = tf.Variable(
ph, trainable=False, validate_shape=False, collections=[])
device_input_batches.append(current_batch)
current_slice = tf.slice(
current_batch,
[self._batch_index] + [0] * len(ph.shape[1:]),
[self.per_device_batch_size] + [-1] * len(ph.shape[1:]))
current_slice.set_shape(ph.shape)
device_input_slices.append(current_slice)
device_loss_obj = self.build_loss(*device_input_slices)
device_grads = self.optimizer.compute_gradients(
device_loss_obj.loss, colocate_gradients_with_ops=True)
return Tower(
tf.group(*[batch.initializer for batch in device_input_batches]),
device_grads,
device_loss_obj)
Updates shared model weights based on the averaged per-device
gradients.
Args:
sess: TensorFlow session.
batch_index: Offset into the preloaded data. This value must be
between `0` and `tuples_per_device`. The amount of data to
process is always fixed to `per_device_batch_size`.
extra_ops: Extra ops to run with this step (e.g. for metrics).
extra_feed_dict: Extra args to feed into this session run.
file_writer: If specified, tf metrics will be written out using
this.
Returns:
The outputs of extra_ops evaluated over the batch.
"""
if file_writer:
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
else:
run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
run_metadata = tf.RunMetadata()
feed_dict = {self._batch_index: batch_index}
feed_dict.update(extra_feed_dict)
outs = sess.run(
[self._train_op] + extra_ops,
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata)
if file_writer:
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
trace_file = open(os.path.join(self.logdir, "timeline-sgd.json"),
"w")
trace_file.write(trace.generate_chrome_trace_format())
file_writer.add_run_metadata(
run_metadata, "sgd_train_{}".format(batch_index))
return outs[1:]
def get_common_loss(self):
return self._shared_loss
def get_device_losses(self):
return [t.loss_object for t in self._towers]
def _setup_device(self, device, device_input_placeholders):
with tf.device(device):
with tf.variable_scope("tower", reuse=True):
device_input_batches = []
device_input_slices = []
for ph in device_input_placeholders:
current_batch = tf.Variable(
ph, trainable=False, validate_shape=False,
collections=[])
device_input_batches.append(current_batch)
current_slice = tf.slice(
current_batch,
[self._batch_index] + [0] * len(ph.shape[1:]),
([self.per_device_batch_size] + [-1] *
len(ph.shape[1:])))
current_slice.set_shape(ph.shape)
device_input_slices.append(current_slice)
device_loss_obj = self.build_loss(*device_input_slices)
device_grads = self.optimizer.compute_gradients(
device_loss_obj.loss, colocate_gradients_with_ops=True)
return Tower(
tf.group(*[batch.initializer
for batch in device_input_batches]),
device_grads,
device_loss_obj)
# Each tower is a copy of the loss graph pinned to a specific device.
@@ -200,50 +203,51 @@ Tower = namedtuple("Tower", ["init_op", "grads", "loss_object"])
def make_divisible_by(array, n):
return array[0:array.shape[0] - array.shape[0] % n]
return array[0:array.shape[0] - array.shape[0] % n]
def average_gradients(tower_grads):
"""Averages gradients across towers.
"""Averages gradients across towers.
Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over individual gradients. The inner list is over the gradient
calculation for each tower.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer
list is over individual gradients. The inner list is over the
gradient calculation for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been averaged
across all towers.
Returns:
List of pairs of (gradient, variable) where the gradient has been
averaged across all towers.
TODO(ekl): We could use NCCL if this becomes a bottleneck.
"""
TODO(ekl): We could use NCCL if this becomes a bottleneck.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads = []
for g, _ in grad_and_vars:
if g is not None:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads = []
for g, _ in grad_and_vars:
if g is not None:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Append on a 'tower' dimension which we will average over
# below.
grads.append(expanded_g)
# Average over the 'tower' dimension.
grad = tf.concat(axis=0, values=grads)
grad = tf.reduce_mean(grad, 0)
# Average over the 'tower' dimension.
grad = tf.concat(axis=0, values=grads)
grad = tf.reduce_mean(grad, 0)
# Keep in mind that the Variables are redundant because they are shared
# across towers. So .. we will just return the first tower's pointer to
# the Variable.
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
# Keep in mind that the Variables are redundant because they are shared
# across towers. So .. we will just return the first tower's pointer to
# the Variable.
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
return average_grads
return average_grads
+121 -113
View File
@@ -25,133 +25,141 @@ from ray.rllib.policy_gradient.rollout import rollouts, add_advantage_values
class Agent(object):
"""
Agent class that holds the simulator environment and the policy.
"""
Agent class that holds the simulator environment and the policy.
Initializes the tensorflow graphs for both training and evaluation.
One common policy graph is initialized on '/cpu:0' and holds all the shared
network weights. When run as a remote agent, only this graph is used.
"""
Initializes the tensorflow graphs for both training and evaluation.
One common policy graph is initialized on '/cpu:0' and holds all the shared
network weights. When run as a remote agent, only this graph is used.
"""
def __init__(self, name, batchsize, preprocessor, config, logdir, is_remote):
if is_remote:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
devices = ["/cpu:0"]
else:
devices = config["devices"]
self.devices = devices
self.config = config
self.logdir = logdir
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
if preprocessor.shape is None:
preprocessor.shape = self.env.observation_space.shape
if is_remote:
config_proto = tf.ConfigProto()
else:
config_proto = tf.ConfigProto(**config["tf_session_args"])
self.preprocessor = preprocessor
self.sess = tf.Session(config=config_proto)
if config["use_tf_debugger"] and not is_remote:
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
self.sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
def __init__(self, name, batchsize, preprocessor, config, logdir,
is_remote):
if is_remote:
os.environ["CUDA_VISIBLE_DEVICES"] = ""
devices = ["/cpu:0"]
else:
devices = config["devices"]
self.devices = devices
self.config = config
self.logdir = logdir
self.env = BatchedEnv(name, batchsize, preprocessor=preprocessor)
if preprocessor.shape is None:
preprocessor.shape = self.env.observation_space.shape
if is_remote:
config_proto = tf.ConfigProto()
else:
config_proto = tf.ConfigProto(**config["tf_session_args"])
self.preprocessor = preprocessor
self.sess = tf.Session(config=config_proto)
if config["use_tf_debugger"] and not is_remote:
self.sess = tf_debug.LocalCLIDebugWrapperSession(self.sess)
self.sess.add_tensor_filter("has_inf_or_nan",
tf_debug.has_inf_or_nan)
# Defines the training inputs.
self.kl_coeff = tf.placeholder(name="newkl", shape=(), dtype=tf.float32)
self.observations = tf.placeholder(tf.float32,
shape=(None,) + preprocessor.shape)
self.advantages = tf.placeholder(tf.float32, shape=(None,))
# Defines the training inputs.
self.kl_coeff = tf.placeholder(name="newkl", shape=(),
dtype=tf.float32)
self.observations = tf.placeholder(tf.float32,
shape=(None,) + preprocessor.shape)
self.advantages = tf.placeholder(tf.float32, shape=(None,))
action_space = self.env.action_space
if isinstance(action_space, gym.spaces.Box):
# The first half of the dimensions are the means, the second half are the
# standard deviations.
self.action_dim = action_space.shape[0]
self.action_shape = (self.action_dim,)
self.logit_dim = 2 * self.action_dim
self.actions = tf.placeholder(tf.float32, shape=(None, self.action_dim))
self.distribution_class = DiagGaussian
elif isinstance(action_space, gym.spaces.Discrete):
self.action_dim = action_space.n
self.action_shape = ()
self.logit_dim = self.action_dim
self.actions = tf.placeholder(tf.int64, shape=(None,))
self.distribution_class = Categorical
else:
raise NotImplemented("action space" + str(type(action_space)) +
"currently not supported")
self.prev_logits = tf.placeholder(tf.float32, shape=(None, self.logit_dim))
action_space = self.env.action_space
if isinstance(action_space, gym.spaces.Box):
# The first half of the dimensions are the means, the second half
# are the standard deviations.
self.action_dim = action_space.shape[0]
self.action_shape = (self.action_dim,)
self.logit_dim = 2 * self.action_dim
self.actions = tf.placeholder(tf.float32,
shape=(None, self.action_dim))
self.distribution_class = DiagGaussian
elif isinstance(action_space, gym.spaces.Discrete):
self.action_dim = action_space.n
self.action_shape = ()
self.logit_dim = self.action_dim
self.actions = tf.placeholder(tf.int64, shape=(None,))
self.distribution_class = Categorical
else:
raise NotImplemented("action space" + str(type(action_space)) +
"currently not supported")
self.prev_logits = tf.placeholder(tf.float32,
shape=(None, self.logit_dim))
assert config["sgd_batchsize"] % len(devices) == 0, \
"Batch size must be evenly divisible by devices"
if is_remote:
self.batch_size = 1
self.per_device_batch_size = 1
else:
self.batch_size = config["sgd_batchsize"]
self.per_device_batch_size = int(self.batch_size / len(devices))
assert config["sgd_batchsize"] % len(devices) == 0, \
"Batch size must be evenly divisible by devices"
if is_remote:
self.batch_size = 1
self.per_device_batch_size = 1
else:
self.batch_size = config["sgd_batchsize"]
self.per_device_batch_size = int(self.batch_size / len(devices))
def build_loss(obs, advs, acts, plog):
return ProximalPolicyLoss(
self.env.observation_space, self.env.action_space,
obs, advs, acts, plog, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config, self.sess)
def build_loss(obs, advs, acts, plog):
return ProximalPolicyLoss(
self.env.observation_space, self.env.action_space,
obs, advs, acts, plog, self.logit_dim,
self.kl_coeff, self.distribution_class, self.config, self.sess)
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
self.devices,
[self.observations, self.advantages, self.actions, self.prev_logits],
self.per_device_batch_size,
build_loss,
self.logdir)
self.par_opt = LocalSyncParallelOptimizer(
tf.train.AdamOptimizer(self.config["sgd_stepsize"]),
self.devices,
[self.observations, self.advantages, self.actions,
self.prev_logits],
self.per_device_batch_size,
build_loss,
self.logdir)
# Metric ops
with tf.name_scope("test_outputs"):
policies = self.par_opt.get_device_losses()
self.mean_loss = tf.reduce_mean(
tf.stack(values=[policy.loss for policy in policies]), 0)
self.mean_kl = tf.reduce_mean(
tf.stack(values=[policy.mean_kl for policy in policies]), 0)
self.mean_entropy = tf.reduce_mean(
tf.stack(values=[policy.mean_entropy for policy in policies]), 0)
# Metric ops
with tf.name_scope("test_outputs"):
policies = self.par_opt.get_device_losses()
self.mean_loss = tf.reduce_mean(
tf.stack(values=[policy.loss for policy in policies]), 0)
self.mean_kl = tf.reduce_mean(
tf.stack(values=[policy.mean_kl for policy in policies]), 0)
self.mean_entropy = tf.reduce_mean(
tf.stack(values=[policy.mean_entropy for policy in policies]),
0)
# References to the model weights
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss,
self.sess)
self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None)
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())
# References to the model weights
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss,
self.sess)
self.observation_filter = MeanStdFilter(preprocessor.shape, clip=None)
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())
def load_data(self, trajectories, full_trace):
return self.par_opt.load_data(
self.sess,
[trajectories["observations"],
trajectories["advantages"],
trajectories["actions"].squeeze(),
trajectories["logprobs"]],
full_trace=full_trace)
def load_data(self, trajectories, full_trace):
return self.par_opt.load_data(
self.sess,
[trajectories["observations"],
trajectories["advantages"],
trajectories["actions"].squeeze(),
trajectories["logprobs"]],
full_trace=full_trace)
def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace, file_writer):
return self.par_opt.optimize(
self.sess,
batch_index,
extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy],
extra_feed_dict={self.kl_coeff: kl_coeff},
file_writer=file_writer if full_trace else None)
def run_sgd_minibatch(self, batch_index, kl_coeff, full_trace,
file_writer):
return self.par_opt.optimize(
self.sess,
batch_index,
extra_ops=[self.mean_loss, self.mean_kl, self.mean_entropy],
extra_feed_dict={self.kl_coeff: kl_coeff},
file_writer=file_writer if full_trace else None)
def get_weights(self):
return self.variables.get_weights()
def get_weights(self):
return self.variables.get_weights()
def load_weights(self, weights):
self.variables.set_weights(weights)
def load_weights(self, weights):
self.variables.set_weights(weights)
def compute_trajectory(self, gamma, lam, horizon):
trajectory = rollouts(
self.common_policy,
self.env, horizon, self.observation_filter, self.reward_filter)
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
return trajectory
def compute_trajectory(self, gamma, lam, horizon):
trajectory = rollouts(
self.common_policy,
self.env, horizon, self.observation_filter, self.reward_filter)
add_advantage_values(trajectory, gamma, lam, self.reward_filter)
return trajectory
RemoteAgent = ray.remote(Agent)
@@ -7,63 +7,63 @@ import numpy as np
class Categorical(object):
def __init__(self, logits):
self.logits = logits
def __init__(self, logits):
self.logits = logits
def logp(self, x):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
labels=x)
def logp(self, x):
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=self.logits, labels=x)
def entropy(self):
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
keep_dims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
def entropy(self):
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
keep_dims=True)
ea0 = tf.exp(a0)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (tf.log(z0) - a0), reduction_indices=[1])
def kl(self, other):
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
keep_dims=True)
a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1],
keep_dims=True)
ea0 = tf.exp(a0)
ea1 = tf.exp(a1)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)),
reduction_indices=[1])
def kl(self, other):
a0 = self.logits - tf.reduce_max(self.logits, reduction_indices=[1],
keep_dims=True)
a1 = other.logits - tf.reduce_max(other.logits, reduction_indices=[1],
keep_dims=True)
ea0 = tf.exp(a0)
ea1 = tf.exp(a1)
z0 = tf.reduce_sum(ea0, reduction_indices=[1], keep_dims=True)
z1 = tf.reduce_sum(ea1, reduction_indices=[1], keep_dims=True)
p0 = ea0 / z0
return tf.reduce_sum(p0 * (a0 - tf.log(z0) - a1 + tf.log(z1)),
reduction_indices=[1])
def sample(self):
return tf.multinomial(self.logits, 1)
def sample(self):
return tf.multinomial(self.logits, 1)
class DiagGaussian(object):
def __init__(self, flat):
self.flat = flat
mean, logstd = tf.split(flat, 2, axis=1)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
def __init__(self, flat):
self.flat = flat
mean, logstd = tf.split(flat, 2, axis=1)
self.mean = mean
self.logstd = logstd
self.std = tf.exp(logstd)
def logp(self, x):
return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
reduction_indices=[1]) -
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
tf.reduce_sum(self.logstd, reduction_indices=[1]))
def logp(self, x):
return (-0.5 * tf.reduce_sum(tf.square((x - self.mean) / self.std),
reduction_indices=[1]) -
0.5 * np.log(2.0 * np.pi) * tf.to_float(tf.shape(x)[1]) -
tf.reduce_sum(self.logstd, reduction_indices=[1]))
def kl(self, other):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(other.logstd - self.logstd +
(tf.square(self.std) +
tf.square(self.mean - other.mean)) /
(2.0 * tf.square(other.std)) - 0.5,
reduction_indices=[1])
def kl(self, other):
assert isinstance(other, DiagGaussian)
return tf.reduce_sum(other.logstd - self.logstd +
(tf.square(self.std) +
tf.square(self.mean - other.mean)) /
(2.0 * tf.square(other.std)) - 0.5,
reduction_indices=[1])
def entropy(self):
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
reduction_indices=[1])
def entropy(self):
return tf.reduce_sum(self.logstd + .5 * np.log(2.0 * np.pi * np.e),
reduction_indices=[1])
def sample(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
def sample(self):
return self.mean + self.std * tf.random_normal(tf.shape(self.mean))
+43 -42
View File
@@ -7,59 +7,60 @@ import numpy as np
class AtariPixelPreprocessor(object):
def __init__(self):
self.shape = (80, 80, 3)
def __init__(self):
self.shape = (80, 80, 3)
def __call__(self, observation):
"Convert images from (210, 160, 3) to (3, 80, 80) by downsampling."
return (observation[25:-25:2, ::2, :][None] - 128) / 128
def __call__(self, observation):
"Convert images from (210, 160, 3) to (3, 80, 80) by downsampling."
return (observation[25:-25:2, ::2, :][None] - 128) / 128
class AtariRamPreprocessor(object):
def __init__(self):
self.shape = (128,)
def __init__(self):
self.shape = (128,)
def __call__(self, observation):
return (observation - 128) / 128
def __call__(self, observation):
return (observation - 128) / 128
class NoPreprocessor(object):
def __init__(self):
self.shape = None
def __init__(self):
self.shape = None
def __call__(self, observation):
return observation
def __call__(self, observation):
return observation
class BatchedEnv(object):
"""This holds multiple gym enviroments and performs steps on all of them."""
def __init__(self, name, batchsize, preprocessor=None):
self.envs = [gym.make(name) for _ in range(batchsize)]
self.observation_space = self.envs[0].observation_space
self.action_space = self.envs[0].action_space
self.batchsize = batchsize
self.preprocessor = preprocessor if preprocessor else lambda obs: obs[None]
"""This holds multiple gym envs and performs steps on all of them."""
def __init__(self, name, batchsize, preprocessor=None):
self.envs = [gym.make(name) for _ in range(batchsize)]
self.observation_space = self.envs[0].observation_space
self.action_space = self.envs[0].action_space
self.batchsize = batchsize
self.preprocessor = (preprocessor if preprocessor
else lambda obs: obs[None])
def reset(self):
observations = [self.preprocessor(env.reset()) for env in self.envs]
self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations)
def reset(self):
observations = [self.preprocessor(env.reset()) for env in self.envs]
self.shape = observations[0].shape
self.dones = [False for _ in range(self.batchsize)]
return np.vstack(observations)
def step(self, actions, render=False):
observations = []
rewards = []
for i, action in enumerate(actions):
if self.dones[i]:
observations.append(np.zeros(self.shape))
rewards.append(0.0)
continue
observation, reward, done, info = self.envs[i].step(
action if len(action) > 1 else action[0])
if render:
self.envs[0].render()
observations.append(self.preprocessor(observation))
rewards.append(reward)
self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"),
np.array(self.dones))
def step(self, actions, render=False):
observations = []
rewards = []
for i, action in enumerate(actions):
if self.dones[i]:
observations.append(np.zeros(self.shape))
rewards.append(0.0)
continue
observation, reward, done, info = self.envs[i].step(
action if len(action) > 1 else action[0])
if render:
self.envs[0].render()
observations.append(self.preprocessor(observation))
rewards.append(reward)
self.dones[i] = done
return (np.vstack(observations), np.array(rewards, dtype="float32"),
np.array(self.dones))
+21 -21
View File
@@ -11,28 +11,28 @@ from ray.rllib.policy_gradient import PolicyGradient, DEFAULT_CONFIG
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the policy gradient "
"algorithm.")
parser.add_argument("--environment", default="Pong-v0", type=str,
help="The gym environment to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--use-tf-debugger", default=False, type=bool,
help="Run the script inside of tf-dbg.")
parser.add_argument("--load-checkpoint", default=None, type=str,
help="Continue training from a checkpoint.")
parser = argparse.ArgumentParser(description="Run the policy gradient "
"algorithm.")
parser.add_argument("--environment", default="Pong-v0", type=str,
help="The gym environment to use.")
parser.add_argument("--redis-address", default=None, type=str,
help="The Redis address of the cluster.")
parser.add_argument("--use-tf-debugger", default=False, type=bool,
help="Run the script inside of tf-dbg.")
parser.add_argument("--load-checkpoint", default=None, type=str,
help="Continue training from a checkpoint.")
args = parser.parse_args()
config = DEFAULT_CONFIG.copy()
config["use_tf_debugger"] = args.use_tf_debugger
if args.load_checkpoint:
config["load_checkpoint"] = args.load_checkpoint
args = parser.parse_args()
config = DEFAULT_CONFIG.copy()
config["use_tf_debugger"] = args.use_tf_debugger
if args.load_checkpoint:
config["load_checkpoint"] = args.load_checkpoint
ray.init(redis_address=args.redis_address)
ray.init(redis_address=args.redis_address)
alg = PolicyGradient(args.environment, config)
result = alg.train()
while result.training_iteration < config["max_iterations"]:
print("\n== iteration", result.training_iteration)
alg = PolicyGradient(args.environment, config)
result = alg.train()
print("current status: {}".format(result))
while result.training_iteration < config["max_iterations"]:
print("\n== iteration", result.training_iteration)
result = alg.train()
print("current status: {}".format(result))
+97 -97
View File
@@ -6,127 +6,127 @@ import numpy as np
class NoFilter(object):
def __init__(self):
pass
def __init__(self):
pass
def __call__(self, x, update=True):
return np.asarray(x)
def __call__(self, x, update=True):
return np.asarray(x)
# http://www.johndcook.com/blog/standard_deviation/
class RunningStat(object):
def __init__(self, shape=None):
self._n = 0
self._M = np.zeros(shape)
self._S = np.zeros(shape)
def __init__(self, shape=None):
self._n = 0
self._M = np.zeros(shape)
self._S = np.zeros(shape)
def push(self, x):
x = np.asarray(x)
# Unvectorized update of the running statistics.
assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
.format(x.shape, self._M.shape))
n1 = self._n
self._n += 1
if self._n == 1:
self._M[...] = x
else:
delta = x - self._M
self._M[...] += delta / self._n
self._S[...] += delta * delta * n1 / self._n
def push(self, x):
x = np.asarray(x)
# Unvectorized update of the running statistics.
assert x.shape == self._M.shape, ("x.shape = {}, self.shape = {}"
.format(x.shape, self._M.shape))
n1 = self._n
self._n += 1
if self._n == 1:
self._M[...] = x
else:
delta = x - self._M
self._M[...] += delta / self._n
self._S[...] += delta * delta * n1 / self._n
def update(self, other):
n1 = self._n
n2 = other._n
n = n1 + n2
delta = self._M - other._M
delta2 = delta * delta
M = (n1 * self._M + n2 * other._M) / n
S = self._S + other._S + delta2 * n1 * n2 / n
self._n = n
self._M = M
self._S = S
def update(self, other):
n1 = self._n
n2 = other._n
n = n1 + n2
delta = self._M - other._M
delta2 = delta * delta
M = (n1 * self._M + n2 * other._M) / n
S = self._S + other._S + delta2 * n1 * n2 / n
self._n = n
self._M = M
self._S = S
@property
def n(self):
return self._n
@property
def n(self):
return self._n
@property
def mean(self):
return self._M
@property
def mean(self):
return self._M
@property
def var(self):
return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
@property
def var(self):
return self._S / (self._n - 1) if self._n > 1 else np.square(self._M)
@property
def std(self):
return np.sqrt(self.var)
@property
def std(self):
return np.sqrt(self.var)
@property
def shape(self):
return self._M.shape
@property
def shape(self):
return self._M.shape
class MeanStdFilter(object):
def __init__(self, shape, demean=True, destd=True, clip=10.0):
self.demean = demean
self.destd = destd
self.clip = clip
def __init__(self, shape, demean=True, destd=True, clip=10.0):
self.demean = demean
self.destd = destd
self.clip = clip
self.rs = RunningStat(shape)
self.rs = RunningStat(shape)
def __call__(self, x, update=True):
x = np.asarray(x)
if update:
if len(x.shape) == len(self.rs.shape) + 1:
# The vectorized case.
for i in range(x.shape[0]):
self.rs.push(x[i])
else:
# The unvectorized case.
self.rs.push(x)
if self.demean:
x = x - self.rs.mean
if self.destd:
x = x / (self.rs.std + 1e-8)
if self.clip:
x = np.clip(x, -self.clip, self.clip)
return x
def __call__(self, x, update=True):
x = np.asarray(x)
if update:
if len(x.shape) == len(self.rs.shape) + 1:
# The vectorized case.
for i in range(x.shape[0]):
self.rs.push(x[i])
else:
# The unvectorized case.
self.rs.push(x)
if self.demean:
x = x - self.rs.mean
if self.destd:
x = x / (self.rs.std + 1e-8)
if self.clip:
x = np.clip(x, -self.clip, self.clip)
return x
def test_running_stat():
for shp in ((), (3,), (3, 4)):
li = []
rs = RunningStat(shp)
for _ in range(5):
val = np.random.randn(*shp)
rs.push(val)
li.append(val)
m = np.mean(li, axis=0)
assert np.allclose(rs.mean, m)
v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
assert np.allclose(rs.var, v)
for shp in ((), (3,), (3, 4)):
li = []
rs = RunningStat(shp)
for _ in range(5):
val = np.random.randn(*shp)
rs.push(val)
li.append(val)
m = np.mean(li, axis=0)
assert np.allclose(rs.mean, m)
v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0)
assert np.allclose(rs.var, v)
def test_combining_stat():
for shape in [(), (3,), (3, 4)]:
li = []
rs1 = RunningStat(shape)
rs2 = RunningStat(shape)
rs = RunningStat(shape)
for _ in range(5):
val = np.random.randn(*shape)
rs1.push(val)
rs.push(val)
li.append(val)
for _ in range(9):
rs2.push(val)
rs.push(val)
li.append(val)
rs1.update(rs2)
assert np.allclose(rs.mean, rs1.mean)
assert np.allclose(rs.std, rs1.std)
for shape in [(), (3,), (3, 4)]:
li = []
rs1 = RunningStat(shape)
rs2 = RunningStat(shape)
rs = RunningStat(shape)
for _ in range(5):
val = np.random.randn(*shape)
rs1.push(val)
rs.push(val)
li.append(val)
for _ in range(9):
rs2.push(val)
rs.push(val)
li.append(val)
rs1.update(rs2)
assert np.allclose(rs.mean, rs1.mean)
assert np.allclose(rs.std, rs1.std)
test_running_stat()
+35 -35
View File
@@ -10,43 +10,43 @@ from ray.rllib.policy_gradient.models.fcnet import fc_net
class ProximalPolicyLoss(object):
def __init__(
self, observation_space, action_space,
observations, advantages, actions, prev_logits, logit_dim,
kl_coeff, distribution_class, config, sess):
assert (isinstance(action_space, gym.spaces.Discrete) or
isinstance(action_space, gym.spaces.Box))
self.prev_dist = distribution_class(prev_logits)
def __init__(
self, observation_space, action_space,
observations, advantages, actions, prev_logits, logit_dim,
kl_coeff, distribution_class, config, sess):
assert (isinstance(action_space, gym.spaces.Discrete) or
isinstance(action_space, gym.spaces.Box))
self.prev_dist = distribution_class(prev_logits)
# Saved so that we can compute actions given different observations
self.observations = observations
# Saved so that we can compute actions given different observations
self.observations = observations
if len(observation_space.shape) > 1:
self.curr_logits = vision_net(observations, num_classes=logit_dim)
else:
assert len(observation_space.shape) == 1
self.curr_logits = fc_net(observations, num_classes=logit_dim)
self.curr_dist = distribution_class(self.curr_logits)
self.sampler = self.curr_dist.sample()
if len(observation_space.shape) > 1:
self.curr_logits = vision_net(observations, num_classes=logit_dim)
else:
assert len(observation_space.shape) == 1
self.curr_logits = fc_net(observations, num_classes=logit_dim)
self.curr_dist = distribution_class(self.curr_logits)
self.sampler = self.curr_dist.sample()
# Make loss functions.
self.ratio = tf.exp(self.curr_dist.logp(actions) -
self.prev_dist.logp(actions))
self.kl = self.prev_dist.kl(self.curr_dist)
self.mean_kl = tf.reduce_mean(self.kl)
self.entropy = self.curr_dist.entropy()
self.mean_entropy = tf.reduce_mean(self.entropy)
self.surr1 = self.ratio * advantages
self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"],
1 + config["clip_param"]) * advantages
self.surr = tf.minimum(self.surr1, self.surr2)
self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl -
config["entropy_coeff"] * self.entropy)
self.sess = sess
# Make loss functions.
self.ratio = tf.exp(self.curr_dist.logp(actions) -
self.prev_dist.logp(actions))
self.kl = self.prev_dist.kl(self.curr_dist)
self.mean_kl = tf.reduce_mean(self.kl)
self.entropy = self.curr_dist.entropy()
self.mean_entropy = tf.reduce_mean(self.entropy)
self.surr1 = self.ratio * advantages
self.surr2 = tf.clip_by_value(self.ratio, 1 - config["clip_param"],
1 + config["clip_param"]) * advantages
self.surr = tf.minimum(self.surr1, self.surr2)
self.loss = tf.reduce_mean(-self.surr + kl_coeff * self.kl -
config["entropy_coeff"] * self.entropy)
self.sess = sess
def compute_actions(self, observations):
return self.sess.run([self.sampler, self.curr_logits],
feed_dict={self.observations: observations})
def compute_actions(self, observations):
return self.sess.run([self.sampler, self.curr_logits],
feed_dict={self.observations: observations})
def loss(self):
return self.loss
def loss(self):
return self.loss
@@ -9,30 +9,30 @@ import numpy as np
def normc_initializer(std=1.0):
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def _initializer(shape, dtype=None, partition_info=None):
out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True))
return tf.constant(out)
return _initializer
def fc_net(inputs, num_classes=10, logstd=False):
with tf.name_scope("fc_net"):
fc1 = slim.fully_connected(inputs, 128,
weights_initializer=normc_initializer(1.0),
scope="fc1")
fc2 = slim.fully_connected(fc1, 128,
weights_initializer=normc_initializer(1.0),
scope="fc2")
fc3 = slim.fully_connected(fc2, 128,
weights_initializer=normc_initializer(1.0),
scope="fc3")
fc4 = slim.fully_connected(fc3, num_classes,
weights_initializer=normc_initializer(0.01),
activation_fn=None, scope="fc4")
if logstd:
logstd = tf.get_variable(name="logstd", shape=[num_classes],
initializer=tf.zeros_initializer)
return tf.concat(1, [fc4, logstd])
else:
return fc4
with tf.name_scope("fc_net"):
fc1 = slim.fully_connected(inputs, 128,
weights_initializer=normc_initializer(1.0),
scope="fc1")
fc2 = slim.fully_connected(fc1, 128,
weights_initializer=normc_initializer(1.0),
scope="fc2")
fc3 = slim.fully_connected(fc2, 128,
weights_initializer=normc_initializer(1.0),
scope="fc3")
fc4 = slim.fully_connected(fc3, num_classes,
weights_initializer=normc_initializer(0.01),
activation_fn=None, scope="fc4")
if logstd:
logstd = tf.get_variable(name="logstd", shape=[num_classes],
initializer=tf.zeros_initializer)
return tf.concat(1, [fc4, logstd])
else:
return fc4
@@ -7,10 +7,10 @@ import tensorflow.contrib.slim as slim
def vision_net(inputs, num_classes=10):
with tf.name_scope("vision_net"):
conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1")
fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope="fc2")
return tf.squeeze(fc2, [1, 2])
with tf.name_scope("vision_net"):
conv1 = slim.conv2d(inputs, 16, [8, 8], 4, scope="conv1")
conv2 = slim.conv2d(conv1, 32, [4, 4], 2, scope="conv2")
fc1 = slim.conv2d(conv2, 512, [10, 10], padding="VALID", scope="fc1")
fc2 = slim.conv2d(fc1, num_classes, [1, 1], activation_fn=None,
normalizer_fn=None, scope="fc2")
return tf.squeeze(fc2, [1, 2])
@@ -43,165 +43,170 @@ DEFAULT_CONFIG = {
class PolicyGradient(Algorithm):
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "PolicyGradient"})
def __init__(self, env_name, config, upload_dir=None):
config.update({"alg": "PolicyGradient"})
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
Algorithm.__init__(self, env_name, config, upload_dir=upload_dir)
# TODO(ekl) the preprocessor should be associated with the env elsewhere
if self.env_name == "Pong-v0":
preprocessor = AtariPixelPreprocessor()
elif self.env_name == "Pong-ram-v3":
preprocessor = AtariRamPreprocessor()
elif self.env_name == "CartPole-v0":
preprocessor = NoPreprocessor()
elif self.env_name == "Walker2d-v1":
preprocessor = NoPreprocessor()
else:
preprocessor = AtariPixelPreprocessor()
# TODO(ekl): The preprocessor should be associated with the env
# elsewhere.
if self.env_name == "Pong-v0":
preprocessor = AtariPixelPreprocessor()
elif self.env_name == "Pong-ram-v3":
preprocessor = AtariRamPreprocessor()
elif self.env_name == "CartPole-v0":
preprocessor = NoPreprocessor()
elif self.env_name == "Walker2d-v1":
preprocessor = NoPreprocessor()
else:
preprocessor = AtariPixelPreprocessor()
self.preprocessor = preprocessor
self.global_step = 0
self.j = 0
self.kl_coeff = config["kl_coeff"]
self.model = Agent(
self.env_name, 1, self.preprocessor, self.config, self.logdir, False)
self.agents = [
RemoteAgent.remote(
self.env_name, 1, self.preprocessor, self.config,
self.logdir, True)
for _ in range(config["num_agents"])]
self.preprocessor = preprocessor
self.global_step = 0
self.j = 0
self.kl_coeff = config["kl_coeff"]
self.model = Agent(
self.env_name, 1, self.preprocessor, self.config, self.logdir,
False)
self.agents = [
RemoteAgent.remote(
self.env_name, 1, self.preprocessor, self.config,
self.logdir, True)
for _ in range(config["num_agents"])]
def train(self):
agents = self.agents
config = self.config
model = self.model
j = self.j
self.j += 1
def train(self):
agents = self.agents
config = self.config
model = self.model
j = self.j
self.j += 1
saver = tf.train.Saver(max_to_keep=None)
if "load_checkpoint" in config:
saver.restore(model.sess, config["load_checkpoint"])
saver = tf.train.Saver(max_to_keep=None)
if "load_checkpoint" in config:
saver.restore(model.sess, config["load_checkpoint"])
# TF does not support to write logs to S3 at the moment
write_tf_logs = self.logdir.startswith("file")
iter_start = time.time()
if write_tf_logs:
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
if config["model_checkpoint_file"]:
checkpoint_path = saver.save(
model.sess,
os.path.join(self.logdir, config["model_checkpoint_file"] % j))
print("Checkpoint saved in file: %s" % checkpoint_path)
checkpointing_end = time.time()
weights = ray.put(model.get_weights())
[a.load_weights.remote(weights) for a in agents]
trajectory, total_reward, traj_len_mean = collect_samples(
agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
print("total reward is ", total_reward)
print("trajectory length mean is ", traj_len_mean)
print("timesteps:", trajectory["dones"].shape[0])
if write_tf_logs:
traj_stats = tf.Summary(value=[
tf.Summary.Value(
tag="policy_gradient/rollouts/mean_reward",
simple_value=total_reward),
tf.Summary.Value(
tag="policy_gradient/rollouts/traj_len_mean",
simple_value=traj_len_mean)])
file_writer.add_summary(traj_stats, self.global_step)
self.global_step += 1
trajectory["advantages"] = ((trajectory["advantages"] -
trajectory["advantages"].mean()) /
trajectory["advantages"].std())
rollouts_end = time.time()
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
", stepsize=" + str(config["sgd_stepsize"]) + "):")
names = ["iter", "loss", "kl", "entropy"]
print(("{:>15}" * len(names)).format(*names))
trajectory = shuffle(trajectory)
shuffle_end = time.time()
tuples_per_device = model.load_data(
trajectory, j == 0 and config["full_trace_data_load"])
load_end = time.time()
checkpointing_time = checkpointing_end - iter_start
rollouts_time = rollouts_end - checkpointing_end
shuffle_time = shuffle_end - rollouts_end
load_time = load_end - shuffle_end
sgd_time = 0
for i in range(config["num_sgd_iter"]):
sgd_start = time.time()
batch_index = 0
num_batches = int(tuples_per_device) // int(model.per_device_batch_size)
loss, kl, entropy = [], [], []
permutation = np.random.permutation(num_batches)
while batch_index < num_batches:
full_trace = (
i == 0 and j == 0 and
batch_index == config["full_trace_nth_sgd_batch"])
batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch(
permutation[batch_index] * model.per_device_batch_size,
self.kl_coeff, full_trace,
file_writer if write_tf_logs else None)
loss.append(batch_loss)
kl.append(batch_kl)
entropy.append(batch_entropy)
batch_index += 1
loss = np.mean(loss)
kl = np.mean(kl)
entropy = np.mean(entropy)
sgd_end = time.time()
print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl, entropy))
# TF does not support to write logs to S3 at the moment
write_tf_logs = self.logdir.startswith("file")
iter_start = time.time()
if write_tf_logs:
file_writer = tf.summary.FileWriter(self.logdir, model.sess.graph)
if config["model_checkpoint_file"]:
checkpoint_path = saver.save(
model.sess,
os.path.join(self.logdir,
config["model_checkpoint_file"] % j))
print("Checkpoint saved in file: %s" % checkpoint_path)
checkpointing_end = time.time()
weights = ray.put(model.get_weights())
[a.load_weights.remote(weights) for a in agents]
trajectory, total_reward, traj_len_mean = collect_samples(
agents, config["timesteps_per_batch"], 0.995, 1.0, 2000)
print("total reward is ", total_reward)
print("trajectory length mean is ", traj_len_mean)
print("timesteps:", trajectory["dones"].shape[0])
if write_tf_logs:
traj_stats = tf.Summary(value=[
tf.Summary.Value(
tag="policy_gradient/rollouts/mean_reward",
simple_value=total_reward),
tf.Summary.Value(
tag="policy_gradient/rollouts/traj_len_mean",
simple_value=traj_len_mean)])
file_writer.add_summary(traj_stats, self.global_step)
self.global_step += 1
trajectory["advantages"] = ((trajectory["advantages"] -
trajectory["advantages"].mean()) /
trajectory["advantages"].std())
rollouts_end = time.time()
print("Computing policy (iterations=" + str(config["num_sgd_iter"]) +
", stepsize=" + str(config["sgd_stepsize"]) + "):")
names = ["iter", "loss", "kl", "entropy"]
print(("{:>15}" * len(names)).format(*names))
trajectory = shuffle(trajectory)
shuffle_end = time.time()
tuples_per_device = model.load_data(
trajectory, j == 0 and config["full_trace_data_load"])
load_end = time.time()
checkpointing_time = checkpointing_end - iter_start
rollouts_time = rollouts_end - checkpointing_end
shuffle_time = shuffle_end - rollouts_end
load_time = load_end - shuffle_end
sgd_time = 0
for i in range(config["num_sgd_iter"]):
sgd_start = time.time()
batch_index = 0
num_batches = (int(tuples_per_device) //
int(model.per_device_batch_size))
loss, kl, entropy = [], [], []
permutation = np.random.permutation(num_batches)
while batch_index < num_batches:
full_trace = (
i == 0 and j == 0 and
batch_index == config["full_trace_nth_sgd_batch"])
batch_loss, batch_kl, batch_entropy = model.run_sgd_minibatch(
permutation[batch_index] * model.per_device_batch_size,
self.kl_coeff, full_trace,
file_writer if write_tf_logs else None)
loss.append(batch_loss)
kl.append(batch_kl)
entropy.append(batch_entropy)
batch_index += 1
loss = np.mean(loss)
kl = np.mean(kl)
entropy = np.mean(entropy)
sgd_end = time.time()
print("{:>15}{:15.5e}{:15.5e}{:15.5e}".format(i, loss, kl,
entropy))
values = []
if i == config["num_sgd_iter"] - 1:
metric_prefix = "policy_gradient/sgd/final_iter/"
values.append(tf.Summary.Value(
tag=metric_prefix + "kl_coeff",
simple_value=self.kl_coeff))
else:
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
values.extend([
tf.Summary.Value(
tag=metric_prefix + "mean_entropy",
simple_value=entropy),
tf.Summary.Value(
tag=metric_prefix + "mean_loss",
simple_value=loss),
tf.Summary.Value(
tag=metric_prefix + "mean_kl",
simple_value=kl)])
if write_tf_logs:
sgd_stats = tf.Summary(value=values)
file_writer.add_summary(sgd_stats, self.global_step)
self.global_step += 1
sgd_time += sgd_end - sgd_start
if kl > 2.0 * config["kl_target"]:
self.kl_coeff *= 1.5
elif kl < 0.5 * config["kl_target"]:
self.kl_coeff *= 0.5
values = []
if i == config["num_sgd_iter"] - 1:
metric_prefix = "policy_gradient/sgd/final_iter/"
values.append(tf.Summary.Value(
tag=metric_prefix + "kl_coeff",
simple_value=self.kl_coeff))
else:
metric_prefix = "policy_gradient/sgd/intermediate_iters/"
values.extend([
tf.Summary.Value(
tag=metric_prefix + "mean_entropy",
simple_value=entropy),
tf.Summary.Value(
tag=metric_prefix + "mean_loss",
simple_value=loss),
tf.Summary.Value(
tag=metric_prefix + "mean_kl",
simple_value=kl)])
if write_tf_logs:
sgd_stats = tf.Summary(value=values)
file_writer.add_summary(sgd_stats, self.global_step)
self.global_step += 1
sgd_time += sgd_end - sgd_start
if kl > 2.0 * config["kl_target"]:
self.kl_coeff *= 1.5
elif kl < 0.5 * config["kl_target"]:
self.kl_coeff *= 0.5
info = {
"kl_divergence": kl,
"kl_coefficient": self.kl_coeff,
"checkpointing_time": checkpointing_time,
"rollouts_time": rollouts_time,
"shuffle_time": shuffle_time,
"load_time": load_time,
"sgd_time": sgd_time,
"sample_throughput": len(trajectory["observations"]) / sgd_time
}
info = {
"kl_divergence": kl,
"kl_coefficient": self.kl_coeff,
"checkpointing_time": checkpointing_time,
"rollouts_time": rollouts_time,
"shuffle_time": shuffle_time,
"load_time": load_time,
"sgd_time": sgd_time,
"sample_throughput": len(trajectory["observations"]) / sgd_time
}
print("kl div:", kl)
print("kl coeff:", self.kl_coeff)
print("checkpointing time:", checkpointing_time)
print("rollouts time:", rollouts_time)
print("shuffle time:", shuffle_time)
print("load time:", load_time)
print("sgd time:", sgd_time)
print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
print("kl div:", kl)
print("kl coeff:", self.kl_coeff)
print("checkpointing time:", checkpointing_time)
print("rollouts time:", rollouts_time)
print("shuffle time:", shuffle_time)
print("load time:", load_time)
print("sgd time:", sgd_time)
print("sgd examples/s:", len(trajectory["observations"]) / sgd_time)
result = TrainingResult(
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
result = TrainingResult(
self.experiment_id.hex, j, total_reward, traj_len_mean, info)
return result
return result
+73 -71
View File
@@ -11,93 +11,95 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate
def rollouts(policy, env, horizon, observation_filter=NoFilter(),
reward_filter=NoFilter()):
"""Perform a batch of rollouts of a policy in an environment.
"""Perform a batch of rollouts of a policy in an environment.
Args:
policy: The policy that will be rollout out. Can be an arbitrary object
that supports a compute_actions(observation) function.
env: The environment the rollout is computed in. Needs to support the
OpenAI gym API and needs to support batches of data.
horizon: Upper bound for the number of timesteps for each rollout in the
batch.
observation_filter: Function that is applied to each of the observations.
reward_filter: Function that is applied to each of the rewards.
Args:
policy: The policy that will be rollout out. Can be an arbitrary object
that supports a compute_actions(observation) function.
env: The environment the rollout is computed in. Needs to support the
OpenAI gym API and needs to support batches of data.
horizon: Upper bound for the number of timesteps for each rollout in
the batch.
observation_filter: Function that is applied to each of the
observations.
reward_filter: Function that is applied to each of the rewards.
Returns:
A trajectory, which is a dictionary with keys "observations", "rewards",
"orig_rewards", "actions", "logprobs", "dones". Each value is an array of
shape (num_timesteps, env.batchsize, shape).
"""
Returns:
A trajectory, which is a dictionary with keys "observations",
"rewards", "orig_rewards", "actions", "logprobs", "dones". Each
value is an array of shape (num_timesteps, env.batchsize, shape).
"""
observation = observation_filter(env.reset())
done = np.array(env.batchsize * [False])
t = 0
observations = []
raw_rewards = [] # Empirical rewards
actions = []
logprobs = []
dones = []
observation = observation_filter(env.reset())
done = np.array(env.batchsize * [False])
t = 0
observations = []
raw_rewards = [] # Empirical rewards
actions = []
logprobs = []
dones = []
while not done.all() and t < horizon:
action, logprob = policy.compute_actions(observation)
observations.append(observation[None])
actions.append(action[None])
logprobs.append(logprob[None])
observation, raw_reward, done = env.step(action)
observation = observation_filter(observation)
raw_rewards.append(raw_reward[None])
dones.append(done[None])
t += 1
while not done.all() and t < horizon:
action, logprob = policy.compute_actions(observation)
observations.append(observation[None])
actions.append(action[None])
logprobs.append(logprob[None])
observation, raw_reward, done = env.step(action)
observation = observation_filter(observation)
raw_rewards.append(raw_reward[None])
dones.append(done[None])
t += 1
return {"observations": np.vstack(observations),
"raw_rewards": np.vstack(raw_rewards),
"actions": np.vstack(actions),
"logprobs": np.vstack(logprobs),
"dones": np.vstack(dones)}
return {"observations": np.vstack(observations),
"raw_rewards": np.vstack(raw_rewards),
"actions": np.vstack(actions),
"logprobs": np.vstack(logprobs),
"dones": np.vstack(dones)}
def add_advantage_values(trajectory, gamma, lam, reward_filter):
rewards = trajectory["raw_rewards"]
dones = trajectory["dones"]
advantages = np.zeros_like(rewards)
last_advantage = np.zeros(rewards.shape[1], dtype="float32")
rewards = trajectory["raw_rewards"]
dones = trajectory["dones"]
advantages = np.zeros_like(rewards)
last_advantage = np.zeros(rewards.shape[1], dtype="float32")
for t in reversed(range(len(rewards))):
delta = rewards[t, :] * (1 - dones[t, :])
last_advantage = delta + gamma * lam * last_advantage
advantages[t, :] = last_advantage
reward_filter(advantages[t, :])
for t in reversed(range(len(rewards))):
delta = rewards[t, :] * (1 - dones[t, :])
last_advantage = delta + gamma * lam * last_advantage
advantages[t, :] = last_advantage
reward_filter(advantages[t, :])
trajectory["advantages"] = advantages
trajectory["advantages"] = advantages
@ray.remote
def compute_trajectory(policy, env, gamma, lam, horizon, observation_filter,
reward_filter):
trajectory = rollouts(policy, env, horizon, observation_filter,
reward_filter)
add_advantage_values(trajectory, gamma, lam, reward_filter)
return trajectory
trajectory = rollouts(policy, env, horizon, observation_filter,
reward_filter)
add_advantage_values(trajectory, gamma, lam, reward_filter)
return trajectory
def collect_samples(agents, num_timesteps, gamma, lam, horizon,
observation_filter=NoFilter(), reward_filter=NoFilter()):
num_timesteps_so_far = 0
trajectories = []
total_rewards = []
traj_len_means = []
while num_timesteps_so_far < num_timesteps:
trajectory_batch = ray.get(
[agent.compute_trajectory.remote(gamma, lam, horizon)
for agent in agents])
trajectory = concatenate(trajectory_batch)
trajectory = flatten(trajectory)
not_done = np.logical_not(trajectory["dones"])
total_rewards.append(
trajectory["raw_rewards"][not_done].sum(axis=0).mean() / len(agents))
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
trajectory = {key: val[not_done] for key, val in trajectory.items()}
num_timesteps_so_far += len(trajectory["dones"])
trajectories.append(trajectory)
return (concatenate(trajectories), np.mean(total_rewards),
np.mean(traj_len_means))
num_timesteps_so_far = 0
trajectories = []
total_rewards = []
traj_len_means = []
while num_timesteps_so_far < num_timesteps:
trajectory_batch = ray.get(
[agent.compute_trajectory.remote(gamma, lam, horizon)
for agent in agents])
trajectory = concatenate(trajectory_batch)
trajectory = flatten(trajectory)
not_done = np.logical_not(trajectory["dones"])
total_rewards.append(
trajectory["raw_rewards"][not_done].sum(axis=0).mean() /
len(agents))
traj_len_means.append(not_done.sum(axis=0).mean() / len(agents))
trajectory = {key: val[not_done] for key, val in trajectory.items()}
num_timesteps_so_far += len(trajectory["dones"])
trajectories.append(trajectory)
return (concatenate(trajectories), np.mean(total_rewards),
np.mean(traj_len_means))
+37 -37
View File
@@ -13,49 +13,49 @@ from ray.rllib.policy_gradient.utils import flatten, concatenate
class DistibutionsTest(unittest.TestCase):
def testCategorical(self):
num_samples = 100000
logits = tf.placeholder(tf.float32, shape=(None, 10))
z = 8 * (np.random.rand(10) - 0.5)
data = np.tile(z, (num_samples, 1))
c = Categorical(logits)
sample_op = c.sample()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
samples = sess.run(sample_op, feed_dict={logits: data})
counts = np.zeros(10)
for sample in samples:
counts[sample] += 1.0
probs = np.exp(z) / np.sum(np.exp(z))
self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)
def testCategorical(self):
num_samples = 100000
logits = tf.placeholder(tf.float32, shape=(None, 10))
z = 8 * (np.random.rand(10) - 0.5)
data = np.tile(z, (num_samples, 1))
c = Categorical(logits)
sample_op = c.sample()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
samples = sess.run(sample_op, feed_dict={logits: data})
counts = np.zeros(10)
for sample in samples:
counts[sample] += 1.0
probs = np.exp(z) / np.sum(np.exp(z))
self.assertTrue(np.sum(np.abs(probs - counts / num_samples)) <= 0.01)
class UtilsTest(unittest.TestCase):
def testFlatten(self):
d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
"a": np.array([[[5], [-5]], [[6], [-6]]])}
flat = flatten(d.copy(), start=0, stop=2)
assert_allclose(d["s"][0][0][:], flat["s"][0][:])
assert_allclose(d["s"][0][1][:], flat["s"][1][:])
assert_allclose(d["s"][1][0][:], flat["s"][2][:])
assert_allclose(d["s"][1][1][:], flat["s"][3][:])
assert_allclose(d["a"][0][0], flat["a"][0])
assert_allclose(d["a"][0][1], flat["a"][1])
assert_allclose(d["a"][1][0], flat["a"][2])
assert_allclose(d["a"][1][1], flat["a"][3])
def testFlatten(self):
d = {"s": np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]]),
"a": np.array([[[5], [-5]], [[6], [-6]]])}
flat = flatten(d.copy(), start=0, stop=2)
assert_allclose(d["s"][0][0][:], flat["s"][0][:])
assert_allclose(d["s"][0][1][:], flat["s"][1][:])
assert_allclose(d["s"][1][0][:], flat["s"][2][:])
assert_allclose(d["s"][1][1][:], flat["s"][3][:])
assert_allclose(d["a"][0][0], flat["a"][0])
assert_allclose(d["a"][0][1], flat["a"][1])
assert_allclose(d["a"][1][0], flat["a"][2])
assert_allclose(d["a"][1][1], flat["a"][3])
def testConcatenate(self):
d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
d = concatenate([d1, d2])
assert_allclose(d["s"], np.array([0, 1, 4, 5]))
assert_allclose(d["a"], np.array([2, 3, 6, 7]))
def testConcatenate(self):
d1 = {"s": np.array([0, 1]), "a": np.array([2, 3])}
d2 = {"s": np.array([4, 5]), "a": np.array([6, 7])}
d = concatenate([d1, d2])
assert_allclose(d["s"], np.array([0, 1, 4, 5]))
assert_allclose(d["a"], np.array([2, 3, 6, 7]))
D = concatenate([d])
assert_allclose(D["s"], np.array([0, 1, 4, 5]))
assert_allclose(D["a"], np.array([2, 3, 6, 7]))
D = concatenate([d])
assert_allclose(D["s"], np.array([0, 1, 4, 5]))
assert_allclose(D["a"], np.array([2, 3, 6, 7]))
if __name__ == "__main__":
unittest.main(verbosity=2)
unittest.main(verbosity=2)
+20 -20
View File
@@ -6,31 +6,31 @@ import numpy as np
def flatten(weights, start=0, stop=2):
"""This methods reshapes all values in a dictionary.
"""This methods reshapes all values in a dictionary.
The indices from start to stop will be flattened into a single index.
The indices from start to stop will be flattened into a single index.
Args:
weights: A dictionary mapping keys to numpy arrays.
start: The starting index.
stop: The ending index.
"""
for key, val in weights.items():
new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
weights[key] = val.reshape(new_shape)
return weights
Args:
weights: A dictionary mapping keys to numpy arrays.
start: The starting index.
stop: The ending index.
"""
for key, val in weights.items():
new_shape = val.shape[0:start] + (-1,) + val.shape[stop:]
weights[key] = val.reshape(new_shape)
return weights
def concatenate(weights_list):
keys = weights_list[0].keys()
result = {}
for key in keys:
result[key] = np.concatenate([l[key] for l in weights_list])
return result
keys = weights_list[0].keys()
result = {}
for key in keys:
result[key] = np.concatenate([l[key] for l in weights_list])
return result
def shuffle(trajectory):
permutation = np.random.permutation(trajectory["dones"].shape[0])
for key, val in trajectory.items():
trajectory[key] = val[permutation]
return trajectory
permutation = np.random.permutation(trajectory["dones"].shape[0])
for key, val in trajectory.items():
trajectory[key] = val[permutation]
return trajectory
+28 -28
View File
@@ -22,36 +22,36 @@ parser.add_argument("--upload-dir", default="file:///tmp/ray", type=str)
if __name__ == "__main__":
args = parser.parse_args()
args = parser.parse_args()
ray.init()
ray.init()
env_name = args.env
if args.alg == "PolicyGradient":
alg = pg.PolicyGradient(
env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "EvolutionStrategies":
alg = es.EvolutionStrategies(
env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "DQN":
alg = dqn.DQN(
env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "A3C":
alg = a3c.A3C(
env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
else:
assert False, ("Unknown algorithm, check --alg argument. Valid choices "
"are PolicyGradientPolicyGradient, EvolutionStrategies, "
"DQN and A3C.")
env_name = args.env
if args.alg == "PolicyGradient":
alg = pg.PolicyGradient(
env_name, pg.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "EvolutionStrategies":
alg = es.EvolutionStrategies(
env_name, es.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "DQN":
alg = dqn.DQN(
env_name, dqn.DEFAULT_CONFIG, upload_dir=args.upload_dir)
elif args.alg == "A3C":
alg = a3c.A3C(
env_name, a3c.DEFAULT_CONFIG, upload_dir=args.upload_dir)
else:
assert False, ("Unknown algorithm, check --alg argument. Valid "
"choices are PolicyGradientPolicyGradient, "
"EvolutionStrategies, DQN and A3C.")
result_logger = ray.rllib.common.RLLibLogger(
os.path.join(alg.logdir, "result.json"))
result_logger = ray.rllib.common.RLLibLogger(
os.path.join(alg.logdir, "result.json"))
while True:
result = alg.train()
while True:
result = alg.train()
# We need to use a custom json serializer class so that NaNs get encoded
# as null as required by Athena.
json.dump(result._asdict(), result_logger,
cls=ray.rllib.common.RLLibEncoder)
result_logger.write("\n")
# We need to use a custom json serializer class so that NaNs get
# encoded as null as required by Athena.
json.dump(result._asdict(), result_logger,
cls=ray.rllib.common.RLLibEncoder)
result_logger.write("\n")
+128 -125
View File
@@ -10,34 +10,35 @@ import ray.services as services
def check_no_existing_redis_clients(node_ip_address, redis_address):
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.
for key in client_keys:
info = redis_client.hgetall(key)
assert b"ray_client_id" in info
assert b"node_ip_address" in info
assert b"client_type" in info
assert b"deleted" in info
# Clients that ran on the same node but that are marked dead can be
# ignored.
deleted = info[b"deleted"]
deleted = bool(int(deleted))
if deleted:
continue
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
# The client table prefix must be kept in sync with the file
# "src/common/redis_module/ray_redis_module.cc" where it is defined.
REDIS_CLIENT_TABLE_PREFIX = "CL:"
client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))
# Filter to clients on the same node and do some basic checking.
for key in client_keys:
info = redis_client.hgetall(key)
assert b"ray_client_id" in info
assert b"node_ip_address" in info
assert b"client_type" in info
assert b"deleted" in info
# Clients that ran on the same node but that are marked dead can be
# ignored.
deleted = info[b"deleted"]
deleted = bool(int(deleted))
if deleted:
continue
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
raise Exception("This Redis instance is already connected to clients "
"with this IP address.")
if info[b"node_ip_address"].decode("ascii") == node_ip_address:
raise Exception("This Redis instance is already connected to "
"clients with this IP address.")
@click.group()
def cli():
pass
pass
@click.command()
@@ -64,120 +65,122 @@ def cli():
help="provide this argument to block forever in this command")
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
object_manager_port, num_workers, num_cpus, num_gpus, head, block):
# Note that we redirect stdout and stderr to /dev/null because otherwise
# attempts to print may cause exceptions if a process is started inside of an
# SSH connection and the SSH connection dies. TODO(rkn): This is a temporary
# fix. We should actually redirect stdout and stderr to Redis in some way.
# Note that we redirect stdout and stderr to /dev/null because otherwise
# attempts to print may cause exceptions if a process is started inside of
# an SSH connection and the SSH connection dies. TODO(rkn): This is a
# temporary fix. We should actually redirect stdout and stderr to Redis in
# some way.
if head:
# Start Ray on the head node.
if redis_address is not None:
raise Exception("If --head is passed in, a Redis server will be "
"started, so a Redis address should not be provided.")
if head:
# Start Ray on the head node.
if redis_address is not None:
raise Exception("If --head is passed in, a Redis server will be "
"started, so a Redis address should not be "
"provided.")
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address()
print("Using IP address {} for this node.".format(node_ip_address))
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address()
print("Using IP address {} for this node.".format(node_ip_address))
address_info = {}
# Use the provided object manager port if there is one.
if object_manager_port is not None:
address_info["object_manager_ports"] = [object_manager_port]
if address_info == {}:
address_info = None
address_info = {}
# Use the provided object manager port if there is one.
if object_manager_port is not None:
address_info["object_manager_ports"] = [object_manager_port]
if address_info == {}:
address_info = None
address_info = services.start_ray_head(
address_info=address_info,
node_ip_address=node_ip_address,
redis_port=redis_port,
num_workers=num_workers,
cleanup=False,
redirect_output=True,
num_cpus=num_cpus,
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
print(address_info)
print("\nStarted Ray on this node. You can add additional nodes to the "
"cluster by calling\n\n"
" ray start --redis-address {}\n\n"
"from the node you wish to add. You can connect a driver to the "
"cluster from Python by running\n\n"
" import ray\n"
" ray.init(redis_address=\"{}\")\n\n"
"If you have trouble connecting from a different machine, check "
"that your firewall is configured properly. If you wish to "
"terminate the processes that have been started, run\n\n"
" ray stop".format(address_info["redis_address"],
address_info["redis_address"]))
else:
# Start Ray on a non-head node.
if redis_port is not None:
raise Exception("If --head is not passed in, --redis-port is not "
"allowed")
if redis_address is None:
raise Exception("If --head is not passed in, --redis-address must be "
"provided.")
if num_redis_shards is not None:
raise Exception("If --head is not passed in, --num-redis-shards must "
"not be provided.")
redis_ip_address, redis_port = redis_address.split(":")
# Wait for the Redis server to be started. And throw an exception if we
# can't connect to it.
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
print("Using IP address {} for this node.".format(node_ip_address))
# Check that there aren't already Redis clients with the same IP address
# connected with this Redis instance. This raises an exception if the Redis
# server already has clients on this node.
check_no_existing_redis_clients(node_ip_address, redis_address)
address_info = services.start_ray_node(
node_ip_address=node_ip_address,
redis_address=redis_address,
object_manager_ports=[object_manager_port],
num_workers=num_workers,
cleanup=False,
redirect_output=True,
num_cpus=num_cpus,
num_gpus=num_gpus)
print(address_info)
print("\nStarted Ray on this node. If you wish to terminate the processes "
"that have been started, run\n\n"
" ray stop")
address_info = services.start_ray_head(
address_info=address_info,
node_ip_address=node_ip_address,
redis_port=redis_port,
num_workers=num_workers,
cleanup=False,
redirect_output=True,
num_cpus=num_cpus,
num_gpus=num_gpus,
num_redis_shards=num_redis_shards)
print(address_info)
print("\nStarted Ray on this node. You can add additional nodes to "
"the cluster by calling\n\n"
" ray start --redis-address {}\n\n"
"from the node you wish to add. You can connect a driver to the "
"cluster from Python by running\n\n"
" import ray\n"
" ray.init(redis_address=\"{}\")\n\n"
"If you have trouble connecting from a different machine, check "
"that your firewall is configured properly. If you wish to "
"terminate the processes that have been started, run\n\n"
" ray stop".format(address_info["redis_address"],
address_info["redis_address"]))
else:
# Start Ray on a non-head node.
if redis_port is not None:
raise Exception("If --head is not passed in, --redis-port is not "
"allowed")
if redis_address is None:
raise Exception("If --head is not passed in, --redis-address must "
"be provided.")
if num_redis_shards is not None:
raise Exception("If --head is not passed in, --num-redis-shards "
"must not be provided.")
redis_ip_address, redis_port = redis_address.split(":")
# Wait for the Redis server to be started. And throw an exception if we
# can't connect to it.
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
# Get the node IP address if one is not provided.
if node_ip_address is None:
node_ip_address = services.get_node_ip_address(redis_address)
print("Using IP address {} for this node.".format(node_ip_address))
# Check that there aren't already Redis clients with the same IP
# address connected with this Redis instance. This raises an exception
# if the Redis server already has clients on this node.
check_no_existing_redis_clients(node_ip_address, redis_address)
address_info = services.start_ray_node(
node_ip_address=node_ip_address,
redis_address=redis_address,
object_manager_ports=[object_manager_port],
num_workers=num_workers,
cleanup=False,
redirect_output=True,
num_cpus=num_cpus,
num_gpus=num_gpus)
print(address_info)
print("\nStarted Ray on this node. If you wish to terminate the "
"processes that have been started, run\n\n"
" ray stop")
if block:
import time
while True:
time.sleep(30)
if block:
import time
while True:
time.sleep(30)
@click.command()
def stop():
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
"local_scheduler"], shell=True)
subprocess.call(["killall global_scheduler plasma_store plasma_manager "
"local_scheduler"], shell=True)
# Find the PID of the monitor process and kill it.
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the monitor process and kill it.
subprocess.call(["kill $(ps aux | grep monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the Redis process and kill it.
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the Redis process and kill it.
subprocess.call(["kill $(ps aux | grep redis-server | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PIDs of the worker processes and kill them.
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
shell=True)
# Find the PIDs of the worker processes and kill them.
subprocess.call(["kill -9 $(ps aux | grep default_worker.py | "
"grep -v grep | awk '{ print $2 }') 2> /dev/null"],
shell=True)
# Find the PID of the Ray log monitor process and kill it.
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the Ray log monitor process and kill it.
subprocess.call(["kill $(ps aux | grep log_monitor.py | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the jupyter process and kill it.
subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
# Find the PID of the jupyter process and kill it.
subprocess.call(["kill $(ps aux | grep jupyter | grep -v grep | "
"awk '{ print $2 }') 2> /dev/null"], shell=True)
cli.add_command(start)
@@ -185,8 +188,8 @@ cli.add_command(stop)
def main():
return cli()
return cli()
if __name__ == "__main__":
main()
main()
+147 -142
View File
@@ -8,58 +8,61 @@ import ray.numbuf
class RaySerializationException(Exception):
def __init__(self, message, example_object):
Exception.__init__(self, message)
self.example_object = example_object
def __init__(self, message, example_object):
Exception.__init__(self, message)
self.example_object = example_object
class RayDeserializationException(Exception):
def __init__(self, message, class_id):
Exception.__init__(self, message)
self.class_id = class_id
def __init__(self, message, class_id):
Exception.__init__(self, message)
self.class_id = class_id
class RayNotDictionarySerializable(Exception):
pass
pass
def check_serializable(cls):
"""Throws an exception if Ray cannot serialize this class efficiently.
"""Throws an exception if Ray cannot serialize this class efficiently.
Args:
cls (type): The class to be serialized.
Args:
cls (type): The class to be serialized.
Raises:
Exception: An exception is raised if Ray cannot serialize this class
efficiently.
"""
if is_named_tuple(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
print("The class {} does not have a '__new__' attribute and is probably "
"an old-stye class. Please make it a new-style class by inheriting "
"from 'object'.")
raise RayNotDictionarySerializable("The class {} does not have a "
"'__new__' attribute and is probably "
"an old-style class. We do not support "
"this. Please make it a new-style "
"class by inheriting from 'object'."
.format(cls))
try:
obj = cls.__new__(cls)
except:
raise RayNotDictionarySerializable("The class {} has overridden '__new__'"
", so Ray may not be able to serialize "
"it efficiently.".format(cls))
if not hasattr(obj, "__dict__"):
raise RayNotDictionarySerializable("Objects of the class {} do not have a "
"'__dict__' attribute, so Ray cannot "
"serialize it efficiently.".format(cls))
if hasattr(obj, "__slots__"):
raise RayNotDictionarySerializable("The class {} uses '__slots__', so Ray "
"may not be able to serialize it "
"efficiently.".format(cls))
Raises:
Exception: An exception is raised if Ray cannot serialize this class
efficiently.
"""
if is_named_tuple(cls):
# This case works.
return
if not hasattr(cls, "__new__"):
print("The class {} does not have a '__new__' attribute and is "
"probably an old-stye class. Please make it a new-style class "
"by inheriting from 'object'.")
raise RayNotDictionarySerializable("The class {} does not have a "
"'__new__' attribute and is "
"probably an old-style class. We "
"do not support this. Please make "
"it a new-style class by "
"inheriting from 'object'."
.format(cls))
try:
obj = cls.__new__(cls)
except:
raise RayNotDictionarySerializable("The class {} has overridden "
"'__new__', so Ray may not be able "
"to serialize it efficiently."
.format(cls))
if not hasattr(obj, "__dict__"):
raise RayNotDictionarySerializable("Objects of the class {} do not "
"have a '__dict__' attribute, so "
"Ray cannot serialize it "
"efficiently.".format(cls))
if hasattr(obj, "__slots__"):
raise RayNotDictionarySerializable("The class {} uses '__slots__', so "
"Ray may not be able to serialize "
"it efficiently.".format(cls))
# This field keeps track of a whitelisted set of classes that Ray will
@@ -72,134 +75,136 @@ custom_deserializers = dict()
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
def add_class_to_whitelist(cls, class_id, pickle=False, custom_serializer=None,
custom_deserializer=None):
"""Add cls to the list of classes that we can serialize.
"""Add cls to the list of classes that we can serialize.
Args:
cls (type): The class that we can serialize.
class_id: A string of bytes used to identify the class.
pickle (bool): True if the serialization should be done with pickle. False
if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
serialize objects of the class in a particular way.
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
type_to_class_id[cls] = class_id
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
if custom_serializer is not None:
custom_serializers[class_id] = custom_serializer
custom_deserializers[class_id] = custom_deserializer
Args:
cls (type): The class that we can serialize.
class_id: A string of bytes used to identify the class.
pickle (bool): True if the serialization should be done with pickle.
False if it should be done efficiently with Ray.
custom_serializer: This argument is optional, but can be provided to
serialize objects of the class in a particular way.
custom_deserializer: This argument is optional, but can be provided to
deserialize objects of the class in a particular way.
"""
type_to_class_id[cls] = class_id
whitelisted_classes[class_id] = cls
if pickle:
classes_to_pickle.add(class_id)
if custom_serializer is not None:
custom_serializers[class_id] = custom_serializer
custom_deserializers[class_id] = custom_deserializer
def serialize(obj):
"""This is the callback that will be used by numbuf.
"""This is the callback that will be used by numbuf.
If numbuf does not know how to serialize an object, it will call this method.
If numbuf does not know how to serialize an object, it will call this
method.
Args:
obj (object): A Python object.
Args:
obj (object): A Python object.
Returns:
A dictionary that has the key "_pyttype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
if type(obj) not in type_to_class_id:
raise RaySerializationException("Ray does not know how to serialize "
"objects of type {}.".format(type(obj)),
obj)
class_id = type_to_class_id[type(obj)]
Returns:
A dictionary that has the key "_pyttype_" to identify the class, and
contains all information needed to reconstruct the object.
"""
if type(obj) not in type_to_class_id:
raise RaySerializationException("Ray does not know how to serialize "
"objects of type {}."
.format(type(obj)),
obj)
class_id = type_to_class_id[type(obj)]
if class_id in classes_to_pickle:
serialized_obj = {"data": pickle.dumps(obj),
"pickle": True}
elif class_id in custom_serializers:
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
# Handle the namedtuple case.
if is_named_tuple(type(obj)):
serialized_obj = {}
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
if class_id in classes_to_pickle:
serialized_obj = {"data": pickle.dumps(obj),
"pickle": True}
elif class_id in custom_serializers:
serialized_obj = {"data": custom_serializers[class_id](obj)}
else:
raise RaySerializationException("We do not know how to serialize the "
"object '{}'".format(obj), obj)
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
# Handle the namedtuple case.
if is_named_tuple(type(obj)):
serialized_obj = {}
serialized_obj["_ray_getnewargs_"] = obj.__getnewargs__()
elif hasattr(obj, "__dict__"):
serialized_obj = obj.__dict__
else:
raise RaySerializationException("We do not know how to serialize "
"the object '{}'".format(obj), obj)
result = dict(serialized_obj, **{"_pytype_": class_id})
return result
def deserialize(serialized_obj):
"""This is the callback that will be used by numbuf.
"""This is the callback that will be used by numbuf.
If numbuf encounters a dictionary that contains the key "_pytype_" during
If numbuf encounters a dictionary that contains the key "_pytype_" during
deserialization, it will ask this callback to deserialize the object.
Args:
serialized_obj (object): A dictionary that contains the key "_pytype_".
Args:
serialized_obj (object): A dictionary that contains the key "_pytype_".
Returns:
A Python object.
Returns:
A Python object.
Raises:
An exception is raised if we do not know how to deserialize the object.
"""
class_id = serialized_obj["_pytype_"]
Raises:
An exception is raised if we do not know how to deserialize the object.
"""
class_id = serialized_obj["_pytype_"]
if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickle.loads(serialized_obj["data"])
else:
assert class_id not in classes_to_pickle
if class_id not in whitelisted_classes:
# If this happens, that means that the call to _register_class, which
# should have added the class to the list of whitelisted classes, has not
# yet propagated to this worker. It should happen if we wait a little
# longer.
raise RayDeserializationException("The class {} is not one of the "
"whitelisted classes."
.format(class_id), class_id)
cls = whitelisted_classes[class_id]
if class_id in custom_deserializers:
obj = custom_deserializers[class_id](serialized_obj["data"])
if "pickle" in serialized_obj:
# The object was pickled, so unpickle it.
obj = pickle.loads(serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
assert class_id not in classes_to_pickle
if class_id not in whitelisted_classes:
# If this happens, that means that the call to _register_class,
# which should have added the class to the list of whitelisted
# classes, has not yet propagated to this worker. It should happen
# if we wait a little longer.
raise RayDeserializationException("The class {} is not one of the "
"whitelisted classes."
.format(class_id), class_id)
cls = whitelisted_classes[class_id]
if class_id in custom_deserializers:
obj = custom_deserializers[class_id](serialized_obj["data"])
else:
# In this case, serialized_obj should just be the __dict__ field.
if "_ray_getnewargs_" in serialized_obj:
obj = cls.__new__(cls, *serialized_obj["_ray_getnewargs_"])
else:
obj = cls.__new__(cls)
serialized_obj.pop("_pytype_")
obj.__dict__.update(serialized_obj)
return obj
def set_callbacks():
"""Register the custom callbacks with numbuf.
"""Register the custom callbacks with numbuf.
The serialize callback is used to serialize objects that numbuf does not know
how to serialize (for example custom Python classes). The deserialize
callback is used to serialize objects that were serialized by the serialize
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
The serialize callback is used to serialize objects that numbuf does not
know how to serialize (for example custom Python classes). The deserialize
callback is used to serialize objects that were serialized by the serialize
callback.
"""
ray.numbuf.register_callbacks(serialize, deserialize)
def clear_state():
type_to_class_id.clear()
whitelisted_classes.clear()
classes_to_pickle.clear()
custom_serializers.clear()
custom_deserializers.clear()
type_to_class_id.clear()
whitelisted_classes.clear()
classes_to_pickle.clear()
custom_serializers.clear()
custom_deserializers.clear()
+888 -860
View File
File diff suppressed because it is too large Load Diff
+133 -128
View File
@@ -13,160 +13,165 @@ FunctionSignature = namedtuple("FunctionSignature", ["arg_names",
"""This class is used to represent a function signature.
Attributes:
keyword_names: The names of the functions keyword arguments. This is used to
test if an incorrect keyword argument has been passed to the function.
arg_defaults: A dictionary mapping from argument name to argument default
value. If the argument is not a keyword argument, the default value will be
funcsigs._empty.
arg_is_positionals: A dictionary mapping from argument name to a bool. The
bool will be true if the argument is a *args argument. Otherwise it will be
false.
function_name: The name of the function whose signature is being inspected.
This is used for printing better error messages.
keyword_names: The names of the functions keyword arguments. This is used
to test if an incorrect keyword argument has been passed to the
function.
arg_defaults: A dictionary mapping from argument name to argument default
value. If the argument is not a keyword argument, the default value
will be funcsigs._empty.
arg_is_positionals: A dictionary mapping from argument name to a bool. The
bool will be true if the argument is a *args argument. Otherwise it
will be false.
function_name: The name of the function whose signature is being
inspected. This is used for printing better error messages.
"""
def check_signature_supported(func, warn=False):
"""Check if we support the signature of this function.
"""Check if we support the signature of this function.
We currently do not allow remote functions to have **kwargs. We also do not
support keyword arguments in conjunction with a *args argument.
We currently do not allow remote functions to have **kwargs. We also do not
support keyword arguments in conjunction with a *args argument.
Args:
func: The function whose signature should be checked.
warn: If this is true, a warning will be printed if the signature is not
supported. If it is false, an exception will be raised if the signature
is not supported.
Args:
func: The function whose signature should be checked.
warn: If this is true, a warning will be printed if the signature is
not supported. If it is false, an exception will be raised if the
signature is not supported.
Raises:
Exception: An exception is raised if the signature is not supported.
"""
function_name = func.__name__
sig_params = [(k, v) for k, v
in funcsigs.signature(func).parameters.items()]
Raises:
Exception: An exception is raised if the signature is not supported.
"""
function_name = func.__name__
sig_params = [(k, v) for k, v
in funcsigs.signature(func).parameters.items()]
has_vararg_param = False
has_kwargs_param = False
has_keyword_arg = False
for keyword_name, parameter in sig_params:
if parameter.kind == parameter.VAR_KEYWORD:
has_kwargs_param = True
if parameter.kind == parameter.VAR_POSITIONAL:
has_vararg_param = True
if parameter.default != funcsigs._empty:
has_keyword_arg = True
has_vararg_param = False
has_kwargs_param = False
has_keyword_arg = False
for keyword_name, parameter in sig_params:
if parameter.kind == parameter.VAR_KEYWORD:
has_kwargs_param = True
if parameter.kind == parameter.VAR_POSITIONAL:
has_vararg_param = True
if parameter.default != funcsigs._empty:
has_keyword_arg = True
if has_kwargs_param:
message = ("The function {} has a **kwargs argument, which is "
"currently not supported.".format(function_name))
if warn:
print(message)
else:
raise Exception(message)
# Check if the user specified a variable number of arguments and any keyword
# arguments.
if has_vararg_param and has_keyword_arg:
message = ("Function {} has a *args argument as well as a keyword "
"argument, which is currently not supported."
.format(function_name))
if warn:
print(message)
else:
raise Exception(message)
if has_kwargs_param:
message = ("The function {} has a **kwargs argument, which is "
"currently not supported.".format(function_name))
if warn:
print(message)
else:
raise Exception(message)
# Check if the user specified a variable number of arguments and any
# keyword arguments.
if has_vararg_param and has_keyword_arg:
message = ("Function {} has a *args argument as well as a keyword "
"argument, which is currently not supported."
.format(function_name))
if warn:
print(message)
else:
raise Exception(message)
def extract_signature(func, ignore_first=False):
"""Extract the function signature from the function.
"""Extract the function signature from the function.
Args:
func: The function whose signature should be extracted.
ignore_first: True if the first argument should be ignored. This should be
used when func is a method of a class.
Args:
func: The function whose signature should be extracted.
ignore_first: True if the first argument should be ignored. This should
be used when func is a method of a class.
Returns:
A function signature object, which includes the names of the keyword
arguments as well as their default values.
"""
sig_params = [(k, v) for k, v
in funcsigs.signature(func).parameters.items()]
Returns:
A function signature object, which includes the names of the keyword
arguments as well as their default values.
"""
sig_params = [(k, v) for k, v
in funcsigs.signature(func).parameters.items()]
if ignore_first:
if len(sig_params) == 0:
raise Exception("Methods must take a 'self' argument, but the method "
"'{}' does not have one.".format(func.__name__))
sig_params = sig_params[1:]
if ignore_first:
if len(sig_params) == 0:
raise Exception("Methods must take a 'self' argument, but the "
"method '{}' does not have one."
.format(func.__name__))
sig_params = sig_params[1:]
# Extract the names of the keyword arguments.
keyword_names = set()
for keyword_name, parameter in sig_params:
if parameter.default != funcsigs._empty:
keyword_names.add(keyword_name)
# Extract the names of the keyword arguments.
keyword_names = set()
for keyword_name, parameter in sig_params:
if parameter.default != funcsigs._empty:
keyword_names.add(keyword_name)
# Construct the argument default values and other argument information.
arg_names = []
arg_defaults = []
arg_is_positionals = []
for keyword_name, parameter in sig_params:
arg_names.append(keyword_name)
arg_defaults.append(parameter.default)
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
# Construct the argument default values and other argument information.
arg_names = []
arg_defaults = []
arg_is_positionals = []
for keyword_name, parameter in sig_params:
arg_names.append(keyword_name)
arg_defaults.append(parameter.default)
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
keyword_names, func.__name__)
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
keyword_names, func.__name__)
def extend_args(function_signature, args, kwargs):
"""Extend the arguments that were passed into a function.
"""Extend the arguments that were passed into a function.
This extends the arguments that were passed into a function with the default
arguments provided in the function definition.
This extends the arguments that were passed into a function with the
default arguments provided in the function definition.
Args:
function_signature: The function signature of the function being called.
args: The non-keyword arguments passed into the function.
kwargs: The keyword arguments passed into the function.
Args:
function_signature: The function signature of the function being
called.
args: The non-keyword arguments passed into the function.
kwargs: The keyword arguments passed into the function.
Returns:
An extended list of arguments to pass into the function.
Returns:
An extended list of arguments to pass into the function.
Raises:
Exception: An exception may be raised if the function cannot be called with
these arguments.
"""
arg_names = function_signature.arg_names
arg_defaults = function_signature.arg_defaults
arg_is_positionals = function_signature.arg_is_positionals
keyword_names = function_signature.keyword_names
function_name = function_signature.function_name
Raises:
Exception: An exception may be raised if the function cannot be called
with these arguments.
"""
arg_names = function_signature.arg_names
arg_defaults = function_signature.arg_defaults
arg_is_positionals = function_signature.arg_is_positionals
keyword_names = function_signature.keyword_names
function_name = function_signature.function_name
args = list(args)
args = list(args)
for keyword_name in kwargs:
if keyword_name not in keyword_names:
raise Exception("The name '{}' is not a valid keyword argument for the "
"function '{}'.".format(keyword_name, function_name))
for keyword_name in kwargs:
if keyword_name not in keyword_names:
raise Exception("The name '{}' is not a valid keyword argument "
"for the function '{}'."
.format(keyword_name, function_name))
# Fill in the remaining arguments.
zipped_info = list(zip(arg_names, arg_defaults,
arg_is_positionals))[len(args):]
for keyword_name, default_value, is_positional in zipped_info:
if keyword_name in kwargs:
args.append(kwargs[keyword_name])
else:
if default_value != funcsigs._empty:
args.append(default_value)
else:
# This means that there is a missing argument. Unless this is the last
# argument and it is a *args argument in which case it can be omitted.
if not is_positional:
raise Exception("No value was provided for the argument '{}' for "
"the function '{}'.".format(keyword_name,
function_name))
# Fill in the remaining arguments.
zipped_info = list(zip(arg_names, arg_defaults,
arg_is_positionals))[len(args):]
for keyword_name, default_value, is_positional in zipped_info:
if keyword_name in kwargs:
args.append(kwargs[keyword_name])
else:
if default_value != funcsigs._empty:
args.append(default_value)
else:
# This means that there is a missing argument. Unless this is
# the last argument and it is a *args argument in which case it
# can be omitted.
if not is_positional:
raise Exception("No value was provided for the argument "
"'{}' for the function '{}'."
.format(keyword_name, function_name))
too_many_arguments = (len(args) > len(arg_names) and
(len(arg_is_positionals) == 0 or
not arg_is_positionals[-1]))
if too_many_arguments:
raise Exception("Too many arguments were passed to the function '{}'"
.format(function_name))
return args
too_many_arguments = (len(args) > len(arg_names) and
(len(arg_is_positionals) == 0 or
not arg_is_positionals[-1]))
if too_many_arguments:
raise Exception("Too many arguments were passed to the function '{}'"
.format(function_name))
return args
+25 -25
View File
@@ -11,99 +11,99 @@ import numpy as np
@ray.remote(num_return_vals=2)
def handle_int(a, b):
return a + 1, b + 1
return a + 1, b + 1
# Test timing
@ray.remote
def empty_function():
pass
pass
@ray.remote
def trivial_function():
return 1
return 1
# Test keyword arguments
@ray.remote
def keyword_fct1(a, b="hello"):
return "{} {}".format(a, b)
return "{} {}".format(a, b)
@ray.remote
def keyword_fct2(a="hello", b="world"):
return "{} {}".format(a, b)
return "{} {}".format(a, b)
@ray.remote
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)
return "{} {} {} {}".format(a, b, c, d)
# Test variable numbers of arguments
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
return " ".join(map(str, a))
@ray.remote
def varargs_fct2(a, *b):
return " ".join(map(str, b))
return " ".join(map(str, b))
try:
@ray.remote
def kwargs_throw_exception(**c):
return ()
kwargs_exception_thrown = False
@ray.remote
def kwargs_throw_exception(**c):
return ()
kwargs_exception_thrown = False
except:
kwargs_exception_thrown = True
kwargs_exception_thrown = True
try:
@ray.remote
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
return "{} {} {}".format(a, b, c)
varargs_and_kwargs_exception_thrown = False
@ray.remote
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
return "{} {} {}".format(a, b, c)
varargs_and_kwargs_exception_thrown = False
except:
varargs_and_kwargs_exception_thrown = True
varargs_and_kwargs_exception_thrown = True
# test throwing an exception
@ray.remote
def throw_exception_fct1():
raise Exception("Test function 1 intentionally failed.")
raise Exception("Test function 1 intentionally failed.")
@ray.remote
def throw_exception_fct2():
raise Exception("Test function 2 intentionally failed.")
raise Exception("Test function 2 intentionally failed.")
@ray.remote(num_return_vals=3)
def throw_exception_fct3(x):
raise Exception("Test function 3 intentionally failed.")
raise Exception("Test function 3 intentionally failed.")
# test Python mode
@ray.remote
def python_mode_f():
return np.array([0, 0])
return np.array([0, 0])
@ray.remote
def python_mode_g(x):
x[0] = 1
return x
x[0] = 1
return x
# test no return values
@ray.remote
def no_op():
pass
pass
+97 -92
View File
@@ -15,119 +15,124 @@ EVENT_KEY = "RAY_MULTI_NODE_TEST_KEY"
def _wait_for_nodes_to_join(num_nodes, timeout=20):
"""Wait until the nodes have joined the cluster.
"""Wait until the nodes have joined the cluster.
This will wait until exactly num_nodes have joined the cluster and each node
has a local scheduler and a plasma manager.
This will wait until exactly num_nodes have joined the cluster and each
node has a local scheduler and a plasma manager.
Args:
num_nodes: The number of nodes to wait for.
timeout: The amount of time in seconds to wait before failing.
Args:
num_nodes: The number of nodes to wait for.
timeout: The amount of time in seconds to wait before failing.
Raises:
Exception: An exception is raised if too many nodes join the cluster or if
the timeout expires while we are waiting.
"""
start_time = time.time()
while time.time() - start_time < timeout:
client_table = ray.global_state.client_table()
num_ready_nodes = len(client_table)
if num_ready_nodes == num_nodes:
ready = True
# Check that for each node, a local scheduler and a plasma manager are
# present.
for ip_address, clients in client_table.items():
client_types = [client["ClientType"] for client in clients]
if "local_scheduler" not in client_types:
ready = False
if "plasma_manager" not in client_types:
ready = False
if ready:
return
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes, num_nodes))
time.sleep(0.1)
Raises:
Exception: An exception is raised if too many nodes join the cluster or
if the timeout expires while we are waiting.
"""
start_time = time.time()
while time.time() - start_time < timeout:
client_table = ray.global_state.client_table()
num_ready_nodes = len(client_table)
if num_ready_nodes == num_nodes:
ready = True
# Check that for each node, a local scheduler and a plasma manager
# are present.
for ip_address, clients in client_table.items():
client_types = [client["ClientType"] for client in clients]
if "local_scheduler" not in client_types:
ready = False
if "plasma_manager" not in client_types:
ready = False
if ready:
return
if num_ready_nodes > num_nodes:
# Too many nodes have joined. Something must be wrong.
raise Exception("{} nodes have joined the cluster, but we were "
"expecting {} nodes.".format(num_ready_nodes,
num_nodes))
time.sleep(0.1)
# If we get here then we timed out.
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
"nodes have joined so far.".format(num_ready_nodes,
num_nodes))
# If we get here then we timed out.
raise Exception("Timed out while waiting for {} nodes to join. Only {} "
"nodes have joined so far.".format(num_ready_nodes,
num_nodes))
def _broadcast_event(event_name, redis_address, data=None):
"""Broadcast an event.
"""Broadcast an event.
This is used to synchronize drivers for the multi-node tests.
This is used to synchronize drivers for the multi-node tests.
Args:
event_name: The name of the event to wait for.
redis_address: The address of the Redis server to use for synchronization.
data: Extra data to include in the broadcast (this will be returned by the
corresponding _wait_for_event call). This data must be json serializable.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
payload = json.dumps((event_name, data))
redis_client.rpush(EVENT_KEY, payload)
Args:
event_name: The name of the event to wait for.
redis_address: The address of the Redis server to use for
synchronization.
data: Extra data to include in the broadcast (this will be returned by
the corresponding _wait_for_event call). This data must be json
serializable.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
payload = json.dumps((event_name, data))
redis_client.rpush(EVENT_KEY, payload)
def _wait_for_event(event_name, redis_address, extra_buffer=0):
"""Block until an event has been broadcast.
"""Block until an event has been broadcast.
This is used to synchronize drivers for the multi-node tests.
This is used to synchronize drivers for the multi-node tests.
Args:
event_name: The name of the event to wait for.
redis_address: The address of the Redis server to use for synchronization.
extra_buffer: An amount of time in seconds to wait after the event.
Args:
event_name: The name of the event to wait for.
redis_address: The address of the Redis server to use for
synchronization.
extra_buffer: An amount of time in seconds to wait after the event.
Returns:
The data that was passed into the corresponding _broadcast_event call.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
while True:
event_infos = redis_client.lrange(EVENT_KEY, 0, -1)
events = dict()
for event_info in event_infos:
name, data = json.loads(event_info)
if name in events:
raise Exception("The same event {} was broadcast twice.".format(name))
events[name] = data
if event_name in events:
# Potentially sleep a little longer and then return the event data.
time.sleep(extra_buffer)
return events[event_name]
time.sleep(0.1)
Returns:
The data that was passed into the corresponding _broadcast_event call.
"""
redis_host, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
while True:
event_infos = redis_client.lrange(EVENT_KEY, 0, -1)
events = dict()
for event_info in event_infos:
name, data = json.loads(event_info)
if name in events:
raise Exception("The same event {} was broadcast twice."
.format(name))
events[name] = data
if event_name in events:
# Potentially sleep a little longer and then return the event data.
time.sleep(extra_buffer)
return events[event_name]
time.sleep(0.1)
def _pid_alive(pid):
"""Check if the process with this PID is alive or not.
"""Check if the process with this PID is alive or not.
Args:
pid: The pid to check.
Args:
pid: The pid to check.
Returns:
This returns false if the process is dead or defunct. Otherwise, it returns
true.
"""
try:
os.kill(pid, 0)
except OSError:
return False
else:
if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE:
return False
Returns:
This returns false if the process is dead or defunct. Otherwise, it
returns true.
"""
try:
os.kill(pid, 0)
except OSError:
return False
else:
return True
if psutil.Process(pid).status() == psutil.STATUS_ZOMBIE:
return False
else:
return True
def wait_for_pid_to_exit(pid, timeout=20):
start_time = time.time()
while time.time() - start_time < timeout:
if not _pid_alive(pid):
return
time.sleep(0.1)
raise Exception("Timed out while waiting for process to exit.")
start_time = time.time()
while time.time() - start_time < timeout:
if not _pid_alive(pid):
return
time.sleep(0.1)
raise Exception("Timed out while waiting for process to exit.")
+31 -30
View File
@@ -11,51 +11,52 @@ import ray.local_scheduler
def random_string():
"""Generate a random string to use as an ID.
"""Generate a random string to use as an ID.
Note that users may seed numpy, which could cause this function to generate
duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't
interfere with the state of the user's random number generator, so we extract
the state of the random number generator and reset it after we are done.
Note that users may seed numpy, which could cause this function to generate
duplicate IDs. Therefore, we need to seed numpy ourselves, but we can't
interfere with the state of the user's random number generator, so we
extract the state of the random number generator and reset it after we are
done.
TODO(rkn): If we want to later guarantee that these are generated in a
deterministic manner, then we will need to make some changes here.
TODO(rkn): If we want to later guarantee that these are generated in a
deterministic manner, then we will need to make some changes here.
Returns:
A random byte string of length 20.
"""
# Get the state of the numpy random number generator.
numpy_state = np.random.get_state()
# Try to use true randomness.
np.random.seed(None)
# Generate the random ID.
random_id = np.random.bytes(20)
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
return random_id
Returns:
A random byte string of length 20.
"""
# Get the state of the numpy random number generator.
numpy_state = np.random.get_state()
# Try to use true randomness.
np.random.seed(None)
# Generate the random ID.
random_id = np.random.bytes(20)
# Reset the state of the numpy random number generator.
np.random.set_state(numpy_state)
return random_id
def decode(byte_str):
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
return byte_str
"""Make this unicode in Python 3, otherwise leave it as bytes."""
if sys.version_info >= (3, 0):
return byte_str.decode("ascii")
else:
return byte_str
def binary_to_object_id(binary_object_id):
return ray.local_scheduler.ObjectID(binary_object_id)
return ray.local_scheduler.ObjectID(binary_object_id)
def binary_to_hex(identifier):
hex_identifier = binascii.hexlify(identifier)
if sys.version_info >= (3, 0):
hex_identifier = hex_identifier.decode()
return hex_identifier
hex_identifier = binascii.hexlify(identifier)
if sys.version_info >= (3, 0):
hex_identifier = hex_identifier.decode()
return hex_identifier
def hex_to_binary(hex_identifier):
return binascii.unhexlify(hex_identifier)
return binascii.unhexlify(hex_identifier)
FunctionProperties = collections.namedtuple("FunctionProperties",
+1802 -1730
View File
File diff suppressed because it is too large Load Diff
+50 -48
View File
@@ -27,59 +27,61 @@ parser.add_argument("--actor-id", required=False, type=str,
def random_string():
return np.random.bytes(20)
return np.random.bytes(20)
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name}
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name}
if args.actor_id is not None:
actor_id = binascii.unhexlify(args.actor_id)
else:
actor_id = ray.worker.NIL_ACTOR_ID
if args.actor_id is not None:
actor_id = binascii.unhexlify(args.actor_id)
else:
actor_id = ray.worker.NIL_ACTOR_ID
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker crashed
in an unanticipated way causing the main_loop to throw an exception, which is
being caught in "python/ray/workers/default_worker.py".
"""
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
crashed in an unanticipated way causing the main_loop to throw an exception,
which is being caught in "python/ray/workers/default_worker.py".
"""
while True:
try:
# This call to main_loop should never return if things are working. Most
# exceptions that are thrown (e.g., inside the execution of a task)
# should be caught and handled inside of the call to main_loop. If an
# exception is thrown here, then that means that there is some error that
# we didn't anticipate.
ray.worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all
# drivers.
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
redis_ip_address, redis_port = args.redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
redis_client.hmset(error_key, {"type": "worker_crash",
"message": traceback_str,
"note": ("This error is unexpected and "
"should not have happened.")})
redis_client.rpush("ErrorKeys", error_key)
# TODO(rkn): Note that if the worker was in the middle of executing a
# task, the any worker or driver that is blocking in a get call and
# waiting for the output of that task will hang. We need to address this.
while True:
try:
# This call to main_loop should never return if things are working.
# Most exceptions that are thrown (e.g., inside the execution of a
# task) should be caught and handled inside of the call to
# main_loop. If an exception is thrown here, then that means that
# there is some error that we didn't anticipate.
ray.worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all
# drivers.
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
redis_ip_address, redis_port = args.redis_address.split(":")
# For this command to work, some other client (on the same machine
# as Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
redis_client.hmset(error_key, {"type": "worker_crash",
"message": traceback_str,
"note": ("This error is unexpected "
"and should not have "
"happened.")})
redis_client.rpush("ErrorKeys", error_key)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, the any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to
# address this.
# After putting the error message in Redis, this worker will attempt to
# reenter the main loop. TODO(rkn): We should probably reset it's state and
# call connect again.
# After putting the error message in Redis, this worker will attempt to
# reenter the main loop. TODO(rkn): We should probably reset it's state
# and call connect again.
+29 -27
View File
@@ -11,32 +11,34 @@ import setuptools.command.build_ext as _build_ext
class build_ext(_build_ext.build_ext):
def run(self):
subprocess.check_call(["../build.sh"])
# Ideally, we could include these files by putting them in a MANIFEST.in or
# using the package_data argument to setup, but the MANIFEST.in gets
# applied at the very beginning when setup.py runs before these files have
# been created, so we have to move the files manually.
for filename in files_to_include:
self.move_file(filename)
# Copy over the autogenerated flatbuffer Python bindings.
generated_python_directory = "ray/core/generated"
for filename in os.listdir(generated_python_directory):
if filename[-3:] == ".py":
self.move_file(os.path.join(generated_python_directory, filename))
def run(self):
subprocess.check_call(["../build.sh"])
# Ideally, we could include these files by putting them in a
# MANIFEST.in or using the package_data argument to setup, but the
# MANIFEST.in gets applied at the very beginning when setup.py runs
# before these files have been created, so we have to move the files
# manually.
for filename in files_to_include:
self.move_file(filename)
# Copy over the autogenerated flatbuffer Python bindings.
generated_python_directory = "ray/core/generated"
for filename in os.listdir(generated_python_directory):
if filename[-3:] == ".py":
self.move_file(os.path.join(generated_python_directory,
filename))
def move_file(self, filename):
# TODO(rkn): This feels very brittle. It may not handle all cases. See
# https://github.com/apache/arrow/blob/master/python/setup.py for an
# example.
source = filename
destination = os.path.join(self.build_lib, filename)
# Create the target directory if it doesn't already exist.
parent_directory = os.path.dirname(destination)
if not os.path.exists(parent_directory):
os.makedirs(parent_directory)
print("Copying {} to {}.".format(source, destination))
shutil.copy(source, destination)
def move_file(self, filename):
# TODO(rkn): This feels very brittle. It may not handle all cases. See
# https://github.com/apache/arrow/blob/master/python/setup.py for an
# example.
source = filename
destination = os.path.join(self.build_lib, filename)
# Create the target directory if it doesn't already exist.
parent_directory = os.path.dirname(destination)
if not os.path.exists(parent_directory):
os.makedirs(parent_directory)
print("Copying {} to {}.".format(source, destination))
shutil.copy(source, destination)
files_to_include = [
@@ -54,8 +56,8 @@ files_to_include = [
class BinaryDistribution(Distribution):
def has_ext_modules(self):
return True
def has_ext_modules(self):
return True
setup(name="ray",