Give error if a worker has a version mismatch for Python Ray, or clou… (#1245)

* Give error if a worker has a version mismatch for Python Ray, or cloudpickle.

* Check version when attaching driver to cluster.

* Only do check if the version info is present.

* Bug fix.

* Fix typo.
This commit is contained in:
Robert Nishihara
2017-11-23 23:31:03 -08:00
committed by Philipp Moritz
parent ddfe00b7e8
commit 7af5292646
4 changed files with 113 additions and 3 deletions
+67
View File
@@ -4,6 +4,8 @@ from __future__ import print_function
import binascii
from collections import namedtuple, OrderedDict
import cloudpickle
import json
import os
import psutil
import random
@@ -261,6 +263,68 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
"configured properly.")
def _compute_version_info():
"""Compute the versions of Python, cloudpickle, and Ray.
Returns:
A tuple containing the version information.
"""
ray_version = ray.__version__
ray_location = ray.__file__
python_version = ".".join(map(str, sys.version_info[:3]))
cloudpickle_version = cloudpickle.__version__
return ray_version, ray_location, python_version, cloudpickle_version
def _put_version_info_in_redis(redis_client):
"""Store version information in Redis.
This will be used to detect if workers or drivers are started using
different versions of Python, cloudpickle, or Ray.
Args:
redis_client: A client for the primary Redis shard.
"""
redis_client.set("VERSION_INFO", json.dumps(_compute_version_info()))
def check_version_info(redis_client):
"""Check if various version info of this process is correct.
This will be used to detect if workers or drivers are started using
different versions of Python, cloudpickle, or Ray. If the version
information is not present in Redis, then no check is done.
Args:
redis_client: A client for the primary Redis shard.
Raises:
Exception: An exception is raised if there is a version mismatch.
"""
redis_reply = redis_client.get("VERSION_INFO")
# Don't do the check if there is no version information in Redis. This
# is to make it easier to do things like start the processes by hand.
if redis_reply is None:
return
true_version_info = tuple(json.loads(redis_reply.decode("ascii")))
version_info = _compute_version_info()
if version_info != true_version_info:
node_ip_address = ray.services.get_node_ip_address()
raise Exception("Version mismatch: The cluster was started with:\n"
" Ray: " + true_version_info[0] + "\n"
" Ray location: " + true_version_info[1] + "\n"
" Python: " + true_version_info[2] + "\n"
" Cloudpickle: " + true_version_info[3] + "\n"
"This process on node " + node_ip_address +
" was started with:" + "\n"
" Ray: " + version_info[0] + "\n"
" Ray location: " + version_info[1] + "\n"
" Python: " + version_info[2] + "\n"
" Cloudpickle: " + version_info[3])
def start_redis(node_ip_address,
port=None,
num_redis_shards=1,
@@ -311,6 +375,9 @@ def start_redis(node_ip_address,
# can access it and know whether or not to redirect their output.
redis_client.set("RedirectOutput", 1 if redirect_worker_output else 0)
# Store version information in the primary Redis shard.
_put_version_info_in_redis(redis_client)
# Start other Redis shards listening on random ports. Each Redis shard logs
# to a separate file, prefixed by "redis-<shard number>".
redis_shards = []
+5
View File
@@ -1706,6 +1706,11 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
redis_ip_address, redis_port = info["redis_address"].split(":")
worker.redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
# Check that the version information matches the version information that
# the Ray cluster was started with.
ray.services.check_version_info(worker.redis_client)
worker.lock = threading.Lock()
# Check the RedirectOutput key in Redis and based on its value redirect
+11 -3
View File
@@ -41,12 +41,13 @@ def create_redis_client(redis_address):
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
def push_error_to_all_drivers(redis_client, message):
def push_error_to_all_drivers(redis_client, message, error_type):
"""Push an error message to all drivers.
Args:
redis_client: The redis client to use.
message: The error message to push.
error_type: The type of the error.
"""
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all
@@ -54,7 +55,7 @@ def push_error_to_all_drivers(redis_client, message):
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
# Create a Redis client.
redis_client.hmset(error_key, {"type": "worker_crash",
redis_client.hmset(error_key, {"type": error_type,
"message": message})
redis_client.rpush("ErrorKeys", error_key)
@@ -79,6 +80,13 @@ if __name__ == "__main__":
ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
try:
ray.services.check_version_info(ray.worker.global_worker.redis_client)
except Exception as e:
traceback_str = traceback.format_exc()
push_error_to_all_drivers(ray.worker.global_worker.redis_client,
traceback_str, "version_mismatch")
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker
crashed in an unanticipated way causing the main_loop to throw an exception,
@@ -96,7 +104,7 @@ if __name__ == "__main__":
traceback_str = traceback.format_exc() + error_explanation
# Create a Redis client.
redis_client = create_redis_client(args.redis_address)
push_error_to_all_drivers(redis_client, traceback_str)
push_error_to_all_drivers(redis_client, traceback_str, "worker_crash")
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to