diff --git a/python/ray/services.py b/python/ray/services.py
index 424b1009c..588e867b7 100644
--- a/python/ray/services.py
+++ b/python/ray/services.py
@@ -4,6 +4,8 @@ from __future__ import print_function
 
 import binascii
 from collections import namedtuple, OrderedDict
+import cloudpickle
+import json
 import os
 import psutil
 import random
@@ -261,6 +263,68 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
                         "configured properly.")
 
 
+def _compute_version_info():
+    """Compute the versions of Python, cloudpickle, and Ray.
+
+    Returns:
+        A tuple containing the version information.
+    """
+    ray_version = ray.__version__
+    ray_location = ray.__file__
+    python_version = ".".join(map(str, sys.version_info[:3]))
+    cloudpickle_version = cloudpickle.__version__
+    return ray_version, ray_location, python_version, cloudpickle_version
+
+
+def _put_version_info_in_redis(redis_client):
+    """Store version information in Redis.
+
+    This will be used to detect if workers or drivers are started using
+    different versions of Python, cloudpickle, or Ray.
+
+    Args:
+        redis_client: A client for the primary Redis shard.
+    """
+    redis_client.set("VERSION_INFO", json.dumps(_compute_version_info()))
+
+
+def check_version_info(redis_client):
+    """Check if various version info of this process is correct.
+
+    This will be used to detect if workers or drivers are started using
+    different versions of Python, cloudpickle, or Ray. If the version
+    information is not present in Redis, then no check is done.
+
+    Args:
+        redis_client: A client for the primary Redis shard.
+
+    Raises:
+        Exception: An exception is raised if there is a version mismatch.
+    """
+    redis_reply = redis_client.get("VERSION_INFO")
+
+    # Don't do the check if there is no version information in Redis. This
+    # is to make it easier to do things like start the processes by hand.
+    if redis_reply is None:
+        return
+
+    true_version_info = tuple(json.loads(redis_reply.decode("ascii")))
+    version_info = _compute_version_info()
+    if version_info != true_version_info:
+        node_ip_address = ray.services.get_node_ip_address()
+        raise Exception("Version mismatch: The cluster was started with:\n"
+                        " Ray: " + true_version_info[0] + "\n"
+                        " Ray location: " + true_version_info[1] + "\n"
+                        " Python: " + true_version_info[2] + "\n"
+                        " Cloudpickle: " + true_version_info[3] + "\n"
+                        "This process on node " + node_ip_address +
+                        " was started with:" + "\n"
+                        " Ray: " + version_info[0] + "\n"
+                        " Ray location: " + version_info[1] + "\n"
+                        " Python: " + version_info[2] + "\n"
+                        " Cloudpickle: " + version_info[3])
+
+
 def start_redis(node_ip_address,
                 port=None,
                 num_redis_shards=1,
@@ -311,6 +375,9 @@ def start_redis(node_ip_address,
     # can access it and know whether or not to redirect their output.
     redis_client.set("RedirectOutput", 1 if redirect_worker_output else 0)
 
+    # Store version information in the primary Redis shard.
+    _put_version_info_in_redis(redis_client)
+
     # Start other Redis shards listening on random ports. Each Redis shard logs
     # to a separate file, prefixed by "redis-".
     redis_shards = []
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 06dfa8b53..443459f7b 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -1706,6 +1706,15 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
     redis_ip_address, redis_port = info["redis_address"].split(":")
     worker.redis_client = redis.StrictRedis(host=redis_ip_address,
                                             port=int(redis_port))
+
+    # Check that the version information matches the version information that
+    # the Ray cluster was started with. Workers are skipped here because
+    # default_worker.py runs the same check itself and pushes a
+    # "version_mismatch" error to the drivers; raising here would kill the
+    # worker before it could report anything.
+    if mode != WORKER_MODE:
+        ray.services.check_version_info(worker.redis_client)
+
     worker.lock = threading.Lock()
 
     # Check the RedirectOutput key in Redis and based on its value redirect
diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py
index 7a9257e67..6e480ade5 100644
--- a/python/ray/workers/default_worker.py
+++ b/python/ray/workers/default_worker.py
@@ -41,12 +41,13 @@ def create_redis_client(redis_address):
     return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
 
 
-def push_error_to_all_drivers(redis_client, message):
+def push_error_to_all_drivers(redis_client, message, error_type):
     """Push an error message to all drivers.
 
     Args:
         redis_client: The redis client to use.
         message: The error message to push.
+        error_type: The type of the error.
     """
     DRIVER_ID_LENGTH = 20
     # We use a driver ID of all zeros to push an error message to all
@@ -54,7 +55,7 @@ def push_error_to_all_drivers(redis_client, message):
     driver_id = DRIVER_ID_LENGTH * b"\x00"
     error_key = b"Error:" + driver_id + b":" + random_string()
     # Create a Redis client.
-    redis_client.hmset(error_key, {"type": "worker_crash",
+    redis_client.hmset(error_key, {"type": error_type,
                                    "message": message})
     redis_client.rpush("ErrorKeys", error_key)
 
@@ -79,6 +80,15 @@ if __name__ == "__main__":
 
     ray.worker.connect(info, mode=ray.WORKER_MODE, actor_id=actor_id)
 
+    # On a version mismatch, push the traceback to all drivers as a
+    # "version_mismatch" error instead of letting the exception propagate.
+    try:
+        ray.services.check_version_info(ray.worker.global_worker.redis_client)
+    except Exception:
+        traceback_str = traceback.format_exc()
+        push_error_to_all_drivers(ray.worker.global_worker.redis_client,
+                                  traceback_str, "version_mismatch")
+
     error_explanation = """
   This error is unexpected and should not have happened. Somehow a worker
   crashed in an unanticipated way causing the main_loop to throw an exception,
@@ -96,7 +106,7 @@ if __name__ == "__main__":
         traceback_str = traceback.format_exc() + error_explanation
         # Create a Redis client.
         redis_client = create_redis_client(args.redis_address)
-        push_error_to_all_drivers(redis_client, traceback_str)
+        push_error_to_all_drivers(redis_client, traceback_str, "worker_crash")
         # TODO(rkn): Note that if the worker was in the middle of executing
         # a task, then any worker or driver that is blocking in a get call
         # and waiting for the output of that task will hang. We need to
diff --git a/test/failure_test.py b/test/failure_test.py
index 85bdeea2b..13f696781 100644
--- a/test/failure_test.py
+++ b/test/failure_test.py
@@ -259,6 +259,21 @@ class ActorTest(unittest.TestCase):
 
 class WorkerDeath(unittest.TestCase):
 
+    def testWorkerRaisingException(self):
+        ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
+
+        @ray.remote
+        def f():
+            ray.worker.global_worker._get_next_task_from_local_scheduler = None
+
+        # Running this task should cause the worker to raise an exception after
+        # the task has successfully completed.
+        f.remote()
+
+        wait_for_errors(b"worker_crash", 1)
+        wait_for_errors(b"worker_died", 1)
+        self.assertEqual(len(ray.error_info()), 2)
+
     def testWorkerDying(self):
         ray.init(num_workers=0, driver_mode=ray.SILENT_MODE)
 
@@ -434,5 +449,22 @@ class PutErrorTest(unittest.TestCase):
         ray.worker.cleanup()
 
 
+class ConfigurationTest(unittest.TestCase):
+
+    def testVersionMismatch(self):
+        import cloudpickle
+        cloudpickle_version = cloudpickle.__version__
+        cloudpickle.__version__ = "fake cloudpickle version"
+
+        # Restore the real version and shut Ray down even if ray.init or
+        # wait_for_errors raises, so a failure cannot leak into later tests.
+        try:
+            ray.init(num_workers=1, driver_mode=ray.SILENT_MODE)
+            wait_for_errors(b"version_mismatch", 1)
+        finally:
+            cloudpickle.__version__ = cloudpickle_version
+            ray.worker.cleanup()
+
+
 if __name__ == "__main__":
     unittest.main(verbosity=2)