mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
Add option of load_code_from_local which is required in cross-language ray call. (#3675)
This commit is contained in:
+171
-92
@@ -3,9 +3,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import six
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
@@ -87,9 +89,9 @@ class FunctionDescriptor(object):
|
||||
return FunctionDescriptor.for_driver_task()
|
||||
elif (len(function_descriptor_list) == 3
|
||||
or len(function_descriptor_list) == 4):
|
||||
module_name = function_descriptor_list[0].decode()
|
||||
class_name = function_descriptor_list[1].decode()
|
||||
function_name = function_descriptor_list[2].decode()
|
||||
module_name = six.ensure_str(function_descriptor_list[0])
|
||||
class_name = six.ensure_str(function_descriptor_list[1])
|
||||
function_name = six.ensure_str(function_descriptor_list[2])
|
||||
if len(function_descriptor_list) == 4:
|
||||
return cls(module_name, function_name, class_name,
|
||||
function_descriptor_list[3])
|
||||
@@ -256,6 +258,14 @@ class FunctionDescriptor(object):
|
||||
descriptor_list.append(self._function_source_hash)
|
||||
return descriptor_list
|
||||
|
||||
def is_actor_method(self):
|
||||
"""Wether this function descriptor is an actor method.
|
||||
|
||||
Returns:
|
||||
True if it's an actor method, False if it's a normal function.
|
||||
"""
|
||||
return len(self._class_name) > 0
|
||||
|
||||
|
||||
class FunctionActorManager(object):
|
||||
"""A class used to export/load remote functions and actors.
|
||||
@@ -289,13 +299,18 @@ class FunctionActorManager(object):
|
||||
# import thread. It is safe to convert this worker into an actor of
|
||||
# these types.
|
||||
self.imported_actor_classes = set()
|
||||
self._loaded_actor_classes = {}
|
||||
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
@@ -336,6 +351,8 @@ class FunctionActorManager(object):
|
||||
Args:
|
||||
remote_function: the RemoteFunction object.
|
||||
"""
|
||||
if self._worker.load_code_from_local:
|
||||
return
|
||||
# Work around limitations of Python pickling.
|
||||
function = remote_function._function
|
||||
function_name_global_valid = function.__name__ in function.__globals__
|
||||
@@ -436,16 +453,24 @@ class FunctionActorManager(object):
|
||||
Returns:
|
||||
A FunctionExecutionInfo object.
|
||||
"""
|
||||
function_id = function_descriptor.function_id
|
||||
|
||||
# Wait until the function to be executed has actually been
|
||||
# registered on this worker. We will push warnings to the user if
|
||||
# we spend too long in this loop.
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function"):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
if self._worker.load_code_from_local:
|
||||
# Load function from local code.
|
||||
# Currently, we don't support isolating code by drivers,
|
||||
# thus always set driver ID to NIL here.
|
||||
driver_id = ray.DriverID.nil()
|
||||
if not function_descriptor.is_actor_method():
|
||||
self._load_function_from_local(driver_id, function_descriptor)
|
||||
else:
|
||||
# Load function from GCS.
|
||||
# Wait until the function to be executed has actually been
|
||||
# registered on this worker. We will push warnings to the user if
|
||||
# we spend too long in this loop.
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function"):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
try:
|
||||
function_id = function_descriptor.function_id
|
||||
info = self._function_execution_info[driver_id][function_id]
|
||||
except KeyError as e:
|
||||
message = ("Error occurs in get_execution_info: "
|
||||
@@ -454,6 +479,33 @@ class FunctionActorManager(object):
|
||||
raise KeyError(message)
|
||||
return info
|
||||
|
||||
def _load_function_from_local(self, driver_id, function_descriptor):
|
||||
assert not function_descriptor.is_actor_method()
|
||||
function_id = function_descriptor.function_id
|
||||
if (driver_id in self._function_execution_info
|
||||
and function_id in self._function_execution_info[function_id]):
|
||||
return
|
||||
module_name, function_name = (
|
||||
function_descriptor.module_name,
|
||||
function_descriptor.function_name,
|
||||
)
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
function = getattr(module, function_name)._function
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load function %s.".format(function_name))
|
||||
raise Exception(
|
||||
"Function {} failed to be loaded from local code.".format(
|
||||
function_descriptor))
|
||||
|
||||
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
@@ -513,6 +565,8 @@ class FunctionActorManager(object):
|
||||
self._worker.redis_client.rpush("Exports", key)
|
||||
|
||||
def export_actor_class(self, Class, actor_method_names):
|
||||
if self._worker.load_code_from_local:
|
||||
return
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
# `task_driver_id` shouldn't be NIL, unless:
|
||||
# 1) This worker isn't an actor;
|
||||
@@ -553,7 +607,87 @@ class FunctionActorManager(object):
|
||||
# within tasks. I tried to disable this, but it may be necessary
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def load_actor(self, driver_id, function_descriptor):
|
||||
def load_actor_class(self, driver_id, function_descriptor):
|
||||
"""Load the actor class.
|
||||
|
||||
Args:
|
||||
driver_id: Driver ID of the actor.
|
||||
function_descriptor: Function descriptor of the actor constructor.
|
||||
|
||||
Returns:
|
||||
The actor class.
|
||||
"""
|
||||
function_id = function_descriptor.function_id
|
||||
# Check if the actor class already exists in the cache.
|
||||
actor_class = self._loaded_actor_classes.get(function_id, None)
|
||||
if actor_class is None:
|
||||
# Load actor class.
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
# Load actor class from local code.
|
||||
actor_class = self._load_actor_from_local(
|
||||
driver_id, function_descriptor)
|
||||
else:
|
||||
# Load actor class from GCS.
|
||||
actor_class = self._load_actor_class_from_gcs(
|
||||
driver_id, function_descriptor)
|
||||
# Save the loaded actor class in cache.
|
||||
self._loaded_actor_classes[function_id] = actor_class
|
||||
|
||||
# Generate execution info for the methods of this actor class.
|
||||
module_name = function_descriptor.module_name
|
||||
actor_class_name = function_descriptor.class_name
|
||||
actor_methods = inspect.getmembers(
|
||||
actor_class, predicate=is_function_or_method)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
method_descriptor = FunctionDescriptor(
|
||||
module_name, actor_method_name, actor_class_name)
|
||||
method_id = method_descriptor.function_id
|
||||
executor = self._make_actor_method_executor(
|
||||
actor_method_name,
|
||||
actor_method,
|
||||
actor_imported=True,
|
||||
)
|
||||
self._function_execution_info[driver_id][method_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][method_id] = 0
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
return actor_class
|
||||
|
||||
def _load_actor_from_local(self, driver_id, function_descriptor):
|
||||
"""Load actor class from local code."""
|
||||
module_name, class_name = (function_descriptor.module_name,
|
||||
function_descriptor.class_name)
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)._modified_class
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load actor_class %s.".format(class_name))
|
||||
raise Exception(
|
||||
"Actor {} failed to be imported from local code.".format(
|
||||
class_name))
|
||||
|
||||
def _create_fake_actor_class(self, actor_class_name, actor_method_names):
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception(
|
||||
"The actor with name {} failed to be imported, "
|
||||
"and so cannot execute this method.".format(actor_class_name))
|
||||
|
||||
for method in actor_method_names:
|
||||
setattr(TemporaryActor, method, temporary_actor_method)
|
||||
|
||||
return TemporaryActor
|
||||
|
||||
def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
|
||||
"""Load actor class from GCS."""
|
||||
key = (b"ActorClass:" + driver_id.binary() + b":" +
|
||||
function_descriptor.function_id.binary())
|
||||
# Wait for the actor class key to have been imported by the
|
||||
@@ -562,74 +696,32 @@ class FunctionActorManager(object):
|
||||
# the driver if too much time is spent here.
|
||||
while key not in self.imported_actor_classes:
|
||||
time.sleep(0.001)
|
||||
with self._worker.lock:
|
||||
self.fetch_and_register_actor(key)
|
||||
|
||||
def fetch_and_register_actor(self, actor_class_key):
|
||||
"""Import an actor.
|
||||
|
||||
This will be called by the worker's import thread when the worker
|
||||
receives the actor_class export, assuming that the worker is an actor
|
||||
for that class.
|
||||
|
||||
Args:
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
"""
|
||||
actor_id = self._worker.actor_id
|
||||
# Fetch raw data from GCS.
|
||||
(driver_id_str, class_name, module, pickled_class,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"actor_method_names"
|
||||
])
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
class_name = six.ensure_str(class_name)
|
||||
module_name = six.ensure_str(module)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
# In Python 2, json loads strings as unicode, so convert them back to
|
||||
# strings.
|
||||
if sys.version_info < (3, 0):
|
||||
actor_method_names = [
|
||||
method_name.encode("ascii")
|
||||
for method_name in actor_method_names
|
||||
]
|
||||
|
||||
# Create a temporary actor with some temporary methods so that if
|
||||
# the actor fails to be unpickled, the temporary actor can be used
|
||||
# (just to produce error messages and to prevent the driver from
|
||||
# hanging).
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
self._worker.actors[actor_id] = TemporaryActor()
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
raise Exception(
|
||||
"The actor with name {} failed to be imported, "
|
||||
"and so cannot execute this method".format(class_name))
|
||||
|
||||
# Register the actor method executors.
|
||||
for actor_method_name in actor_method_names:
|
||||
function_descriptor = FunctionDescriptor(module, actor_method_name,
|
||||
class_name)
|
||||
function_id = function_descriptor.function_id
|
||||
temporary_executor = self._make_actor_method_executor(
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
actor_imported=False)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=temporary_executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
actor_method_names = json.loads(six.ensure_str(actor_method_names))
|
||||
|
||||
actor_class = None
|
||||
try:
|
||||
unpickled_class = pickle.loads(pickled_class)
|
||||
self._worker.actor_class = unpickled_class
|
||||
with self._worker.lock:
|
||||
actor_class = pickle.loads(pickled_class)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load actor class %s.".format(class_name))
|
||||
# The actor class failed to be unpickled, create a fake actor
|
||||
# class instead (just to produce error messages and to prevent
|
||||
# the driver from hanging).
|
||||
actor_class = self._create_fake_actor_class(
|
||||
class_name, actor_method_names)
|
||||
# If an exception was thrown when the actor was imported, we record
|
||||
# the traceback and notify the scheduler of the failure.
|
||||
traceback_str = ray.utils.format_error_message(
|
||||
@@ -638,33 +730,20 @@ class FunctionActorManager(object):
|
||||
push_error_to_driver(
|
||||
self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
"Failed to unpickle actor class '{}' for actor ID {}. "
|
||||
"Traceback:\n{}".format(class_name, actor_id.hex(),
|
||||
"Traceback:\n{}".format(class_name,
|
||||
self._worker.actor_id.hex(),
|
||||
traceback_str), driver_id)
|
||||
# TODO(rkn): In the future, it might make sense to have the worker
|
||||
# exit here. However, currently that would lead to hanging if
|
||||
# someone calls ray.get on a method invoked on the actor.
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
self._worker.actors[actor_id] = unpickled_class.__new__(
|
||||
unpickled_class)
|
||||
|
||||
actor_methods = inspect.getmembers(
|
||||
unpickled_class, predicate=is_function_or_method)
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_descriptor = FunctionDescriptor(
|
||||
module, actor_method_name, class_name)
|
||||
function_id = function_descriptor.function_id
|
||||
executor = self._make_actor_method_executor(
|
||||
actor_method_name, actor_method, actor_imported=True)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0))
|
||||
# We do not set function_properties[driver_id][function_id]
|
||||
# because we currently do need the actor worker to submit new
|
||||
# tasks for the actor.
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python script
|
||||
# was started from, its module is `__main__`.
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
actor_class.__module__ = module_name
|
||||
return actor_class
|
||||
|
||||
def _make_actor_method_executor(self, method_name, method, actor_imported):
|
||||
"""Make an executor that wraps a user-defined actor method.
|
||||
|
||||
@@ -355,6 +355,7 @@ class Node(object):
|
||||
config=self._config,
|
||||
include_java=self._ray_params.include_java,
|
||||
java_worker_options=self._ray_params.java_worker_options,
|
||||
load_code_from_local=self._ray_params.load_code_from_local,
|
||||
)
|
||||
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
|
||||
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
|
||||
|
||||
@@ -73,6 +73,7 @@ class RayParams(object):
|
||||
include_java (bool): If True, the raylet backend can also support
|
||||
Java worker.
|
||||
java_worker_options (str): The command options for Java worker.
|
||||
load_code_from_local: Whether load code from local file or from GCS.
|
||||
_internal_config (str): JSON configuration for overriding
|
||||
RayConfig defaults. For testing purposes ONLY.
|
||||
"""
|
||||
@@ -110,6 +111,7 @@ class RayParams(object):
|
||||
autoscaling_config=None,
|
||||
include_java=False,
|
||||
java_worker_options=None,
|
||||
load_code_from_local=False,
|
||||
_internal_config=None):
|
||||
self.object_id_seed = object_id_seed
|
||||
self.redis_address = redis_address
|
||||
@@ -141,6 +143,7 @@ class RayParams(object):
|
||||
self.autoscaling_config = autoscaling_config
|
||||
self.include_java = include_java
|
||||
self.java_worker_options = java_worker_options
|
||||
self.load_code_from_local = load_code_from_local
|
||||
self._internal_config = _internal_config
|
||||
self._check_usage()
|
||||
|
||||
|
||||
@@ -208,6 +208,11 @@ def cli(logging_level, logging_format):
|
||||
default=None,
|
||||
type=str,
|
||||
help="Do NOT use this. This is for debugging/development purposes ONLY.")
|
||||
@click.option(
|
||||
"--load-code-from-local",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Specify whether load code from local file or GCS serialization.")
|
||||
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
redis_max_clients, redis_password, redis_shard_ports,
|
||||
object_manager_port, node_manager_port, object_store_memory,
|
||||
@@ -215,7 +220,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
plasma_directory, huge_pages, autoscaling_config,
|
||||
no_redirect_worker_output, no_redirect_output,
|
||||
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
|
||||
java_worker_options, internal_config):
|
||||
java_worker_options, load_code_from_local, internal_config):
|
||||
# Convert hostnames to numerical IP address.
|
||||
if node_ip_address is not None:
|
||||
node_ip_address = services.address_to_ip(node_ip_address)
|
||||
@@ -250,6 +255,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
|
||||
temp_dir=temp_dir,
|
||||
include_java=include_java,
|
||||
java_worker_options=java_worker_options,
|
||||
load_code_from_local=load_code_from_local,
|
||||
_internal_config=internal_config)
|
||||
|
||||
if head:
|
||||
|
||||
@@ -973,7 +973,8 @@ def start_raylet(redis_address,
|
||||
stderr_file=None,
|
||||
config=None,
|
||||
include_java=False,
|
||||
java_worker_options=None):
|
||||
java_worker_options=None,
|
||||
load_code_from_local=False):
|
||||
"""Start a raylet, which is a combined local scheduler and object manager.
|
||||
|
||||
Args:
|
||||
@@ -1068,6 +1069,9 @@ def start_raylet(redis_address,
|
||||
if node_manager_port is None:
|
||||
node_manager_port = 0
|
||||
|
||||
if load_code_from_local:
|
||||
start_worker_command += " --load-code-from-local "
|
||||
|
||||
command = [
|
||||
RAYLET_EXECUTABLE,
|
||||
raylet_name,
|
||||
|
||||
+13
-4
@@ -152,6 +152,7 @@ class Worker(object):
|
||||
# A dictionary that maps from driver id to SerializationContext
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
self.load_code_from_local = False
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
# It is a DriverID.
|
||||
@@ -915,8 +916,9 @@ class Worker(object):
|
||||
assert self.actor_id.is_nil()
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
self.function_actor_manager.load_actor(driver_id,
|
||||
function_descriptor)
|
||||
actor_class = self.function_actor_manager.load_actor_class(
|
||||
driver_id, function_descriptor)
|
||||
self.actors[self.actor_id] = actor_class.__new__(actor_class)
|
||||
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
|
||||
num_tasks_since_last_checkpoint=0,
|
||||
last_checkpoint_timestamp=int(1000 * time.time()),
|
||||
@@ -1271,6 +1273,7 @@ def init(redis_address=None,
|
||||
plasma_store_socket_name=None,
|
||||
raylet_socket_name=None,
|
||||
temp_dir=None,
|
||||
load_code_from_local=False,
|
||||
_internal_config=None):
|
||||
"""Connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
@@ -1346,6 +1349,8 @@ def init(redis_address=None,
|
||||
used by the raylet process.
|
||||
temp_dir (str): If provided, it will specify the root temporary
|
||||
directory for the Ray process.
|
||||
load_code_from_local: Whether code should be loaded from a local module
|
||||
or from the GCS.
|
||||
_internal_config (str): JSON configuration for overriding
|
||||
RayConfig defaults. For testing purposes ONLY.
|
||||
|
||||
@@ -1427,6 +1432,7 @@ def init(redis_address=None,
|
||||
plasma_store_socket_name=plasma_store_socket_name,
|
||||
raylet_socket_name=raylet_socket_name,
|
||||
temp_dir=temp_dir,
|
||||
load_code_from_local=load_code_from_local,
|
||||
_internal_config=_internal_config,
|
||||
)
|
||||
# Start the Ray processes. We set shutdown_at_exit=False because we
|
||||
@@ -1516,7 +1522,8 @@ def init(redis_address=None,
|
||||
mode=driver_mode,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id)
|
||||
driver_id=driver_id,
|
||||
load_code_from_local=load_code_from_local)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
@@ -1745,7 +1752,8 @@ def connect(info,
|
||||
mode=WORKER_MODE,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
driver_id=None):
|
||||
driver_id=None,
|
||||
load_code_from_local=False):
|
||||
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
@@ -1802,6 +1810,7 @@ def connect(info,
|
||||
worker.actor_id = ActorID.nil()
|
||||
worker.connected = True
|
||||
worker.set_mode(mode)
|
||||
worker.load_code_from_local = load_code_from_local
|
||||
|
||||
# If running Ray in LOCAL_MODE, there is no need to create call
|
||||
# create_worker or to start the worker service.
|
||||
|
||||
@@ -57,6 +57,11 @@ parser.add_argument(
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify the path of the temporary directory use by Ray process.")
|
||||
parser.add_argument(
|
||||
"--load-code-from-local",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="True if code is loaded from local files, as opposed to the GCS.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
@@ -77,7 +82,8 @@ if __name__ == "__main__":
|
||||
redis_password=args.redis_password,
|
||||
plasma_store_socket_name=args.object_store_name,
|
||||
raylet_socket_name=args.raylet_name,
|
||||
temp_dir=args.temp_dir)
|
||||
temp_dir=args.temp_dir,
|
||||
load_code_from_local=args.load_code_from_local)
|
||||
|
||||
node = ray.node.Node(
|
||||
ray_params, head=False, shutdown_at_exit=False, connect_only=True)
|
||||
@@ -85,7 +91,10 @@ if __name__ == "__main__":
|
||||
|
||||
# TODO(suquark): Use "node" as the input of "connect".
|
||||
ray.worker.connect(
|
||||
info, redis_password=args.redis_password, mode=ray.WORKER_MODE)
|
||||
info,
|
||||
redis_password=args.redis_password,
|
||||
mode=ray.WORKER_MODE,
|
||||
load_code_from_local=args.load_code_from_local)
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
|
||||
+1
-2
@@ -148,8 +148,7 @@ requires = [
|
||||
"pytest",
|
||||
"pyyaml",
|
||||
"redis",
|
||||
# The six module is required by pyarrow.
|
||||
"six >= 1.0.0",
|
||||
"six >= 1.12.0",
|
||||
# The typing module is required by modin.
|
||||
"typing",
|
||||
"flatbuffers",
|
||||
|
||||
+58
-3
@@ -156,9 +156,9 @@ def test_complex_serialization(ray_start):
|
||||
assert_equal(obj1[i], obj2[i])
|
||||
elif (ray.serialization.is_named_tuple(type(obj1))
|
||||
or ray.serialization.is_named_tuple(type(obj2))):
|
||||
assert len(obj1) == len(obj2), ("Objects {} and {} are named "
|
||||
"tuples with different lengths."
|
||||
.format(obj1, obj2))
|
||||
assert len(obj1) == len(obj2), (
|
||||
"Objects {} and {} are named "
|
||||
"tuples with different lengths.".format(obj1, obj2))
|
||||
for i in range(len(obj1)):
|
||||
assert_equal(obj1[i], obj2[i])
|
||||
else:
|
||||
@@ -2843,3 +2843,58 @@ def test_non_ascii_comment(ray_start):
|
||||
return 1
|
||||
|
||||
assert ray.get(f.remote()) == 1
|
||||
|
||||
|
||||
@ray.remote
|
||||
def echo(x):
|
||||
return x
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WithConstructor(object):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WithoutConstructor(object):
|
||||
def set_data(self, data):
|
||||
self.data = data
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
|
||||
class BaseClass(object):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
|
||||
@ray.remote
|
||||
class DerivedClass(BaseClass):
|
||||
def __init__(self, data):
|
||||
# Due to different behaviors of super in Python 2 and Python 3,
|
||||
# we use BaseClass directly here.
|
||||
BaseClass.__init__(self, data)
|
||||
|
||||
def get_data(self):
|
||||
return self.data
|
||||
|
||||
|
||||
def test_load_code_from_local(shutdown_only):
|
||||
ray.init(load_code_from_local=True, num_cpus=4)
|
||||
# Test normal function.
|
||||
assert ray.get(echo.remote("foo")) == "foo"
|
||||
# Test actor class with constructor.
|
||||
actor = WithConstructor.remote(1)
|
||||
assert ray.get(actor.get_data.remote()) == 1
|
||||
# Test actor class without constructor.
|
||||
actor = WithoutConstructor.remote()
|
||||
actor.set_data.remote(1)
|
||||
assert ray.get(actor.get_data.remote()) == 1
|
||||
# Test derived actor class.
|
||||
actor = DerivedClass.remote(1)
|
||||
assert ray.get(actor.get_data.remote()) == 1
|
||||
|
||||
Reference in New Issue
Block a user