mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 18:11:38 +08:00
Give error if a worker has a version mismatch for Python, Ray, or cloudpickle (#1245)
* Give error if a worker has a version mismatch for Python, Ray, or cloudpickle.
* Check version when attaching driver to cluster.
* Only do check if the version info is present.
* Bug fix.
* Fix typo.
This commit is contained in:
committed by
Philipp Moritz
parent
ddfe00b7e8
commit
7af5292646
@@ -4,6 +4,8 @@ from __future__ import print_function
|
||||
|
||||
import binascii
|
||||
from collections import namedtuple, OrderedDict
|
||||
import cloudpickle
|
||||
import json
|
||||
import os
|
||||
import psutil
|
||||
import random
|
||||
@@ -261,6 +263,68 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
|
||||
"configured properly.")
|
||||
|
||||
|
||||
def _compute_version_info():
    """Gather the version details of the environment running this process.

    Returns:
        A tuple of the form (ray_version, ray_location, python_version,
            cloudpickle_version), where python_version is formatted as
            "major.minor.micro" and ray_location is the file that the ray
            package was imported from.
    """
    # Only the first three fields of sys.version_info (major, minor, micro)
    # take part in the version string.
    interpreter_version = "{}.{}.{}".format(*sys.version_info[:3])
    return (ray.__version__,
            ray.__file__,
            interpreter_version,
            cloudpickle.__version__)
|
||||
|
||||
|
||||
def _put_version_info_in_redis(redis_client):
    """Record this process's version information in Redis.

    The stored value lets workers and drivers that connect later detect
    whether they were started with a different version of Python,
    cloudpickle, or Ray.

    Args:
        redis_client: A client for the primary Redis shard.
    """
    # The version tuple is stored JSON-encoded under a single fixed key.
    serialized_info = json.dumps(_compute_version_info())
    redis_client.set("VERSION_INFO", serialized_info)
|
||||
|
||||
|
||||
def check_version_info(redis_client):
    """Verify that this process matches the versions the cluster recorded.

    The version information stored in Redis under "VERSION_INFO" is compared
    with the version information of the current process (Ray, Ray location,
    Python, and cloudpickle). If nothing is stored in Redis, the check is
    skipped entirely.

    Args:
        redis_client: A client for the primary Redis shard.

    Raises:
        Exception: An exception is raised if there is a version mismatch.
    """
    stored_reply = redis_client.get("VERSION_INFO")

    # A missing key means nobody recorded version information, so skip the
    # check. This keeps it easy to do things like start processes by hand.
    if stored_reply is None:
        return

    # JSON decoding yields a list, so convert it back to a tuple before
    # comparing it with the locally computed tuple.
    cluster_info = tuple(json.loads(stored_reply.decode("ascii")))
    local_info = _compute_version_info()
    if local_info == cluster_info:
        return

    node_ip_address = ray.services.get_node_ip_address()

    def describe(info):
        # One line per field, in the order produced by
        # _compute_version_info.
        return "\n".join([
            " Ray: " + info[0],
            " Ray location: " + info[1],
            " Python: " + info[2],
            " Cloudpickle: " + info[3],
        ])

    message = ("Version mismatch: The cluster was started with:\n" +
               describe(cluster_info) + "\n" +
               "This process on node " + node_ip_address +
               " was started with:" + "\n" +
               describe(local_info))
    raise Exception(message)
|
||||
|
||||
|
||||
def start_redis(node_ip_address,
|
||||
port=None,
|
||||
num_redis_shards=1,
|
||||
@@ -311,6 +375,9 @@ def start_redis(node_ip_address,
|
||||
# can access it and know whether or not to redirect their output.
|
||||
redis_client.set("RedirectOutput", 1 if redirect_worker_output else 0)
|
||||
|
||||
# Store version information in the primary Redis shard.
|
||||
_put_version_info_in_redis(redis_client)
|
||||
|
||||
# Start other Redis shards listening on random ports. Each Redis shard logs
|
||||
# to a separate file, prefixed by "redis-<shard number>".
|
||||
redis_shards = []
|
||||
|
||||
@@ -1706,6 +1706,11 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
redis_ip_address, redis_port = info["redis_address"].split(":")
|
||||
worker.redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
|
||||
# Check that the version information matches the version information that
|
||||
# the Ray cluster was started with.
|
||||
ray.services.check_version_info(worker.redis_client)
|
||||
|
||||
worker.lock = threading.Lock()
|
||||
|
||||
# Check the RedirectOutput key in Redis and based on its value redirect
|
||||
|
||||
@@ -41,12 +41,13 @@ def create_redis_client(redis_address):
|
||||
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
|
||||
|
||||
def push_error_to_all_drivers(redis_client, message):
|
||||
def push_error_to_all_drivers(redis_client, message, error_type):
|
||||
"""Push an error message to all drivers.
|
||||
|
||||
Args:
|
||||
redis_client: The redis client to use.
|
||||
message: The error message to push.
|
||||
error_type: The type of the error.
|
||||
"""
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
@@ -54,7 +55,7 @@ def push_error_to_all_drivers(redis_client, message):
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
# Create a Redis client.
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
redis_client.hmset(error_key, {"type": error_type,
|
||||
"message": message})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
@@ -79,6 +80,13 @@ if __name__ == "__main__":
|
||||
|
||||
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
|
||||
|
||||
try:
|
||||
ray.services.check_version_info(ray.worker.global_worker.redis_client)
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc()
|
||||
push_error_to_all_drivers(ray.worker.global_worker.redis_client,
|
||||
traceback_str, "version_mismatch")
|
||||
|
||||
error_explanation = """
|
||||
This error is unexpected and should not have happened. Somehow a worker
|
||||
crashed in an unanticipated way causing the main_loop to throw an exception,
|
||||
@@ -96,7 +104,7 @@ if __name__ == "__main__":
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
# Create a Redis client.
|
||||
redis_client = create_redis_client(args.redis_address)
|
||||
push_error_to_all_drivers(redis_client, traceback_str)
|
||||
push_error_to_all_drivers(redis_client, traceback_str, "worker_crash")
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
||||
Reference in New Issue
Block a user