mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:54:27 +08:00
[Core] Read resources from an environment variable (#9831)
This commit is contained in:
+25
-3
@@ -2,8 +2,9 @@ import atexit
|
||||
import collections
|
||||
import datetime
|
||||
import errno
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import socket
|
||||
@@ -254,12 +255,33 @@ class Node:
|
||||
|
||||
def get_resource_spec(self):
|
||||
"""Resolve and return the current resource spec for the node."""
|
||||
|
||||
def merge_resources(env_dict, params_dict):
|
||||
"""Merge two dictionaries, picking from the second in the event of a conflict.
|
||||
Also emit a warning on every conflict.
|
||||
"""
|
||||
result = params_dict.copy()
|
||||
result.update(env_dict)
|
||||
|
||||
for key in set(env_dict.keys()).intersection(
|
||||
set(params_dict.keys())):
|
||||
logger.warning("Autoscaler is overriding your resource:"
|
||||
"{}: {} with {}.".format(
|
||||
key, params_dict[key], env_dict[key]))
|
||||
return result
|
||||
|
||||
env_resources = {}
|
||||
env_string = os.getenv("RAY_OVERRIDE_RESOURCES")
|
||||
if env_string:
|
||||
env_resources = json.loads(env_string)
|
||||
|
||||
if not self._resource_spec:
|
||||
resources = merge_resources(env_resources,
|
||||
self._ray_params.resources)
|
||||
self._resource_spec = ResourceSpec(
|
||||
self._ray_params.num_cpus, self._ray_params.num_gpus,
|
||||
self._ray_params.memory, self._ray_params.object_store_memory,
|
||||
self._ray_params.resources,
|
||||
self._ray_params.redis_max_memory).resolve(
|
||||
resources, self._ray_params.redis_max_memory).resolve(
|
||||
is_head=self.head, node_ip_address=self.node_ip_address)
|
||||
return self._resource_spec
|
||||
|
||||
|
||||
Reference in New Issue
Block a user