Improve code related to node (#4383)

* Make full use of node

implement local node

fix bugs mentioned in comments

* Add more tests

* Use more specific exception handling

* fix, lint

* fix for py2.x
This commit is contained in:
Si-Yuan
2019-04-09 17:27:54 +08:00
committed by GitHub
parent c5bcec54f3
commit dab99d26af
7 changed files with 243 additions and 245 deletions
+98 -33
View File
@@ -31,7 +31,8 @@ PY3 = sys.version_info.major >= 3
class Node(object):
"""An encapsulation of the Ray processes on a single node.
This class is responsible for starting Ray processes and killing them.
This class is responsible for starting Ray processes and killing them,
and it also controls the temp file policy.
Attributes:
all_processes (dict): A mapping from process type (str) to a list of
@@ -63,8 +64,17 @@ class Node(object):
"be both true.")
self.all_processes = {}
# Try to get node IP address with the parameters.
if ray_params.node_ip_address:
node_ip_address = ray_params.node_ip_address
elif ray_params.redis_address:
node_ip_address = ray.services.get_node_ip_address(
ray_params.redis_address)
else:
node_ip_address = ray.services.get_node_ip_address()
self._node_ip_address = node_ip_address
ray_params.update_if_absent(
node_ip_address=ray.services.get_node_ip_address(),
include_log_monitor=True,
resources={},
include_webui=False,
@@ -73,31 +83,51 @@ class Node(object):
"workers/default_worker.py"))
self._ray_params = ray_params
self._node_ip_address = ray_params.node_ip_address
self._redis_address = ray_params.redis_address
self._config = (json.loads(ray_params._internal_config)
if ray_params._internal_config else None)
if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._plasma_store_socket_name = None
self._raylet_socket_name = None
self._webui_url = None
else:
self._init_temp()
if connect_only:
# Get socket names from the configuration.
self._plasma_store_socket_name = (
ray_params.plasma_store_socket_name)
self._raylet_socket_name = ray_params.raylet_socket_name
# If user does not provide the socket name, get it from Redis.
if (self._plasma_store_socket_name is None
or self._raylet_socket_name is None):
# Get the address info of the processes to connect to
# from Redis.
address_info = ray.services.get_address_info_from_redis(
self.redis_address,
self._node_ip_address,
redis_password=self.redis_password)
self._plasma_store_socket_name = address_info[
"object_store_address"]
self._raylet_socket_name = address_info["raylet_socket_name"]
else:
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
if head:
ray_params.update_if_absent(num_redis_shards=1, include_webui=True)
self._webui_url = None
else:
redis_client = self.create_redis_client()
# TODO(suquark): Replace _webui_url_helper in worker.py in
# another PR.
_webui_url = redis_client.hmget("webui", "url")[0]
self._webui_url = (ray.utils.decode(_webui_url)
if _webui_url is not None else None)
self._webui_url = (
ray.services.get_webui_url_from_redis(redis_client))
ray_params.include_java = (
ray.services.include_java_from_redis(redis_client))
self._init_temp()
# Start processes.
if head:
self.start_head_processes()
if not connect_only:
self.start_ray_processes()
@@ -136,6 +166,20 @@ class Node(object):
"""Get the cluster Redis address."""
return self._redis_address
@property
def redis_password(self):
"""Get the cluster Redis password"""
return self._ray_params.redis_password
@property
def load_code_from_local(self):
return self._ray_params.load_code_from_local
@property
def object_id_seed(self):
"""Get the seed for deterministic generation of object IDs"""
return self._ray_params.object_id_seed
@property
def plasma_store_socket_name(self):
"""Get the node's plasma store socket name."""
@@ -151,6 +195,17 @@ class Node(object):
"""Get the node's raylet socket name."""
return self._raylet_socket_name
@property
def address_info(self):
"""Get a dictionary of addresses."""
return {
"node_ip_address": self._node_ip_address,
"redis_address": self._redis_address,
"object_store_address": self._plasma_store_socket_name,
"raylet_socket_name": self._raylet_socket_name,
"webui_url": self._webui_url,
}
def create_redis_client(self):
"""Create a redis client."""
return ray.services.create_redis_client(
@@ -321,11 +376,6 @@ class Node(object):
def start_plasma_store(self):
"""Start the plasma store."""
assert self._plasma_store_socket_name is None
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
self._ray_params.plasma_store_socket_name,
default_prefix="plasma_store")
stdout_file, stderr_file = self.new_log_files("plasma_store")
process_info = ray.services.start_plasma_store(
stdout_file=stdout_file,
@@ -349,10 +399,6 @@ class Node(object):
use_profiler (bool): True if we should start the process in the
valgrind profiler.
"""
assert self._raylet_socket_name is None
# If the user specified a socket name, use it.
self._raylet_socket_name = self._prepare_socket_file(
self._ray_params.raylet_socket_name, default_prefix="raylet")
stdout_file, stderr_file = self.new_log_files("raylet")
process_info = ray.services.start_raylet(
self._redis_address,
@@ -416,20 +462,26 @@ class Node(object):
process_info
]
def start_head_processes(self):
"""Start head processes on the node."""
logger.info(
"Process STDOUT and STDERR is being redirected to {}.".format(
self._logs_dir))
assert self._redis_address is None
# If this is the head node, start the relevant head node processes.
self.start_redis()
self.start_monitor()
self.start_raylet_monitor()
# The dashboard is Python3.x only.
if PY3 and self._ray_params.include_webui:
self.start_dashboard()
def start_ray_processes(self):
"""Start all of the processes on the node."""
logger.info(
"Process STDOUT and STDERR is being redirected to {}.".format(
self._logs_dir))
# If this is the head node, start the relevant head node processes.
if self._redis_address is None:
self.start_redis()
self.start_monitor()
self.start_raylet_monitor()
if PY3 and self._ray_params.include_webui:
self.start_dashboard()
self.start_plasma_store()
self.start_raylet()
if PY3:
@@ -685,3 +737,16 @@ class Node(object):
True if any process that wasn't explicitly killed is still alive.
"""
return not any(self.dead_processes())
class LocalNode(object):
"""Imitate the node that manages the processes in local mode."""
def kill_all_processes(self, *args, **kwargs):
"""Kill all of the processes."""
pass # Keep this function empty because it will be used in worker.py
@property
def address_info(self):
"""Get a dictionary of addresses."""
return {} # Return a null dict.