diff --git a/python/ray/state.py b/python/ray/state.py index 66cc60a13..04be03a8a 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -31,6 +31,9 @@ class GlobalState: def __init__(self): """Create a GlobalState object.""" + # Args used for lazy init of this object. + self.redis_address = None + self.redis_password = None # The redis server storing metadata, such as function table, client # table, log files, event logs, workers/actions info. self.redis_client = None @@ -39,12 +42,17 @@ class GlobalState: self.global_state_accessor = None def _check_connected(self): - """Check that the object has been initialized before it is used. + """Ensure that the object has been initialized before it is used. + + This lazily initializes clients needed for state accessors. Raises: RuntimeError: An exception is raised if ray.init() has not been called yet. """ + if self.redis_client is None and self.redis_address is not None: + self._really_init_global_state() + if (self.redis_client is None or self.redis_clients is None or self.global_state_accessor is None): raise ray.exceptions.RaySystemError( @@ -55,15 +63,14 @@ class GlobalState: """Disconnect global state from GCS.""" self.redis_client = None self.redis_clients = None + self.redis_address = None + self.redis_password = None if self.global_state_accessor is not None: self.global_state_accessor.disconnect() self.global_state_accessor = None - def _initialize_global_state(self, - redis_address, - redis_password=None, - timeout=20): - """Initialize the GlobalState object by connecting to Redis. + def _initialize_global_state(self, redis_address, redis_password=None): + """Set args for lazily initialization of the GlobalState object. It's possible that certain keys in Redis may not have been fully populated yet. In this case, we will retry this method until they have @@ -73,10 +80,17 @@ class GlobalState: redis_address: The Redis address to connect. redis_password: The password of the redis server. """ + + # Save args for lazy init of global state. This avoids opening extra + # redis connections from each worker until needed. + self.redis_address = redis_address + self.redis_password = redis_password + + def _really_init_global_state(self, timeout=20): self.redis_client = services.create_redis_client( - redis_address, redis_password) + self.redis_address, self.redis_password) self.global_state_accessor = GlobalStateAccessor( - redis_address, redis_password, False) + self.redis_address, self.redis_password, False) self.global_state_accessor.connect() start_time = time.time() @@ -119,7 +133,7 @@ class GlobalState: for shard_address in redis_shard_addresses: self.redis_clients.append( services.create_redis_client(shard_address.decode(), - redis_password)) + self.redis_password)) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key.