mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 07:03:48 +08:00
GCP credentials (#9052)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
@@ -7,13 +8,10 @@ from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from googleapiclient import discovery, errors
|
||||
from google.oauth2 import service_account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
crm = discovery.build("cloudresourcemanager", "v1")
|
||||
iam = discovery.build("iam", "v1")
|
||||
compute = discovery.build("compute", "v1")
|
||||
|
||||
VERSION = "v1"
|
||||
|
||||
RAY = "ray-autoscaler"
|
||||
@@ -30,7 +28,7 @@ MAX_POLLS = 12
|
||||
POLL_INTERVAL = 5
|
||||
|
||||
|
||||
def wait_for_crm_operation(operation):
|
||||
def wait_for_crm_operation(operation, crm):
|
||||
"""Poll for cloud resource manager operation until finished."""
|
||||
logger.info("wait_for_crm_operation: "
|
||||
"Waiting for operation {} to finish...".format(operation))
|
||||
@@ -49,7 +47,7 @@ def wait_for_crm_operation(operation):
|
||||
return result
|
||||
|
||||
|
||||
def wait_for_compute_global_operation(project_name, operation):
|
||||
def wait_for_compute_global_operation(project_name, operation, compute):
|
||||
"""Poll for global compute operation until finished."""
|
||||
logger.info("wait_for_compute_global_operation: "
|
||||
"Waiting for operation {} to finish...".format(
|
||||
@@ -75,7 +73,8 @@ def wait_for_compute_global_operation(project_name, operation):
|
||||
|
||||
def key_pair_name(i, region, project_id, ssh_user):
|
||||
"""Returns the ith default gcp_key_pair_name."""
|
||||
key_name = "{}_gcp_{}_{}_{}".format(RAY, region, project_id, ssh_user, i)
|
||||
key_name = "{}_gcp_{}_{}_{}_{}".format(RAY, region, project_id, ssh_user,
|
||||
i)
|
||||
return key_name
|
||||
|
||||
|
||||
@@ -104,16 +103,62 @@ def generate_rsa_key_pair():
|
||||
return public_key, pem
|
||||
|
||||
|
||||
def _create_crm(gcp_credentials):
|
||||
return discovery.build(
|
||||
"cloudresourcemanager", "v1", credentials=gcp_credentials)
|
||||
|
||||
|
||||
def _create_iam(gcp_credentials):
|
||||
return discovery.build("iam", "v1", credentials=gcp_credentials)
|
||||
|
||||
|
||||
def _create_compute(gcp_credentials):
|
||||
return discovery.build("compute", "v1", credentials=gcp_credentials)
|
||||
|
||||
|
||||
def fetch_gcp_credentials_from_provider_config(provider_config):
|
||||
"""
|
||||
Attempt to fetch and parse the JSON GCP credentials from the provider
|
||||
config yaml file.
|
||||
"""
|
||||
service_account_info_string = provider_config.get("gcp_credentials")
|
||||
if service_account_info_string is None:
|
||||
logger.info("gcp_credentials not found in cluster yaml file. "
|
||||
"Falling back to GOOGLE_APPLICATION_CREDENTIALS "
|
||||
"environment variable.")
|
||||
# If gcp_credentials is None, then discovery.build will search for
|
||||
# credentials in the local environment.
|
||||
return None
|
||||
|
||||
# If parsing the gcp_credentials failed, then the user likely made a
|
||||
# mistake in copying the credentials into the config yaml.
|
||||
try:
|
||||
service_account_info = json.loads(service_account_info_string)
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise RuntimeError("gcp_credentials found in cluster yaml file but "
|
||||
"formatted improperly.")
|
||||
gcp_credentials = service_account.Credentials.from_service_account_info(
|
||||
service_account_info)
|
||||
return gcp_credentials
|
||||
|
||||
|
||||
def bootstrap_gcp(config):
|
||||
config = _configure_project(config)
|
||||
config = _configure_iam_role(config)
|
||||
config = _configure_key_pair(config)
|
||||
config = _configure_subnet(config)
|
||||
gcp_credentials = fetch_gcp_credentials_from_provider_config(
|
||||
config["provider"])
|
||||
|
||||
crm = _create_crm(gcp_credentials)
|
||||
iam = _create_iam(gcp_credentials)
|
||||
compute = _create_compute(gcp_credentials)
|
||||
|
||||
config = _configure_project(config, crm)
|
||||
config = _configure_iam_role(config, crm, iam)
|
||||
config = _configure_key_pair(config, compute)
|
||||
config = _configure_subnet(config, compute)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _configure_project(config):
|
||||
def _configure_project(config, crm):
|
||||
"""Setup a Google Cloud Platform Project.
|
||||
|
||||
Google Compute Platform organizes all the resources, such as storage
|
||||
@@ -124,12 +169,12 @@ def _configure_project(config):
|
||||
assert config["provider"]["project_id"] is not None, (
|
||||
"'project_id' must be set in the 'provider' section of the autoscaler"
|
||||
" config. Notice that the project id must be globally unique.")
|
||||
project = _get_project(project_id)
|
||||
project = _get_project(project_id, crm)
|
||||
|
||||
if project is None:
|
||||
# Project not found, try creating it
|
||||
_create_project(project_id)
|
||||
project = _get_project(project_id)
|
||||
_create_project(project_id, crm)
|
||||
project = _get_project(project_id, crm)
|
||||
|
||||
assert project is not None, "Failed to create project"
|
||||
assert project["lifecycleState"] == "ACTIVE", (
|
||||
@@ -141,7 +186,7 @@ def _configure_project(config):
|
||||
return config
|
||||
|
||||
|
||||
def _configure_iam_role(config):
|
||||
def _configure_iam_role(config, crm, iam):
|
||||
"""Setup a gcp service account with IAM roles.
|
||||
|
||||
Creates a gcp service acconut and binds IAM roles which allow it to control
|
||||
@@ -154,7 +199,7 @@ def _configure_iam_role(config):
|
||||
email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
|
||||
account_id=DEFAULT_SERVICE_ACCOUNT_ID,
|
||||
project_id=config["provider"]["project_id"])
|
||||
service_account = _get_service_account(email, config)
|
||||
service_account = _get_service_account(email, config, iam)
|
||||
|
||||
if service_account is None:
|
||||
logger.info("_configure_iam_role: "
|
||||
@@ -162,11 +207,13 @@ def _configure_iam_role(config):
|
||||
DEFAULT_SERVICE_ACCOUNT_ID))
|
||||
|
||||
service_account = _create_service_account(
|
||||
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config)
|
||||
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config,
|
||||
iam)
|
||||
|
||||
assert service_account is not None, "Failed to create service account"
|
||||
|
||||
_add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES)
|
||||
_add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES,
|
||||
crm)
|
||||
|
||||
config["head_node"]["serviceAccounts"] = [{
|
||||
"email": service_account["email"],
|
||||
@@ -180,7 +227,7 @@ def _configure_iam_role(config):
|
||||
return config
|
||||
|
||||
|
||||
def _configure_key_pair(config):
|
||||
def _configure_key_pair(config, compute):
|
||||
"""Configure SSH access, using an existing key pair if possible.
|
||||
|
||||
Creates a project-wide ssh key that can be used to access all the instances
|
||||
@@ -235,7 +282,8 @@ def _configure_key_pair(config):
|
||||
"Creating new key pair {}".format(key_name))
|
||||
public_key, private_key = generate_rsa_key_pair()
|
||||
|
||||
_create_project_ssh_key_pair(project, public_key, ssh_user)
|
||||
_create_project_ssh_key_pair(project, public_key, ssh_user,
|
||||
compute)
|
||||
|
||||
# We need to make sure to _create_ the file with the right
|
||||
# permissions. In order to do that we need to change the default
|
||||
@@ -272,7 +320,7 @@ def _configure_key_pair(config):
|
||||
return config
|
||||
|
||||
|
||||
def _configure_subnet(config):
|
||||
def _configure_subnet(config, compute):
|
||||
"""Pick a reasonable subnet if not specified by the config."""
|
||||
|
||||
# Rationale: avoid subnet lookup if the network is already
|
||||
@@ -281,7 +329,7 @@ def _configure_subnet(config):
|
||||
and "networkInterfaces" in config["worker_nodes"]):
|
||||
return config
|
||||
|
||||
subnets = _list_subnets(config)
|
||||
subnets = _list_subnets(config, compute)
|
||||
|
||||
if not subnets:
|
||||
raise NotImplementedError("Should be able to create subnet.")
|
||||
@@ -312,7 +360,7 @@ def _configure_subnet(config):
|
||||
return config
|
||||
|
||||
|
||||
def _list_subnets(config):
|
||||
def _list_subnets(config, compute):
|
||||
response = compute.subnetworks().list(
|
||||
project=config["provider"]["project_id"],
|
||||
region=config["provider"]["region"]).execute()
|
||||
@@ -320,7 +368,7 @@ def _list_subnets(config):
|
||||
return response["items"]
|
||||
|
||||
|
||||
def _get_subnet(config, subnet_id):
|
||||
def _get_subnet(config, subnet_id, compute):
|
||||
subnet = compute.subnetworks().get(
|
||||
project=config["provider"]["project_id"],
|
||||
region=config["provider"]["region"],
|
||||
@@ -330,7 +378,7 @@ def _get_subnet(config, subnet_id):
|
||||
return subnet
|
||||
|
||||
|
||||
def _get_project(project_id):
|
||||
def _get_project(project_id, crm):
|
||||
try:
|
||||
project = crm.projects().get(projectId=project_id).execute()
|
||||
except errors.HttpError as e:
|
||||
@@ -341,18 +389,18 @@ def _get_project(project_id):
|
||||
return project
|
||||
|
||||
|
||||
def _create_project(project_id):
|
||||
def _create_project(project_id, crm):
|
||||
operation = crm.projects().create(body={
|
||||
"projectId": project_id,
|
||||
"name": project_id
|
||||
}).execute()
|
||||
|
||||
result = wait_for_crm_operation(operation)
|
||||
result = wait_for_crm_operation(operation, crm)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _get_service_account(account, config):
|
||||
def _get_service_account(account, config, iam):
|
||||
project_id = config["provider"]["project_id"]
|
||||
full_name = ("projects/{project_id}/serviceAccounts/{account}"
|
||||
"".format(project_id=project_id, account=account))
|
||||
@@ -367,7 +415,7 @@ def _get_service_account(account, config):
|
||||
return service_account
|
||||
|
||||
|
||||
def _create_service_account(account_id, account_config, config):
|
||||
def _create_service_account(account_id, account_config, config, iam):
|
||||
project_id = config["provider"]["project_id"]
|
||||
|
||||
service_account = iam.projects().serviceAccounts().create(
|
||||
@@ -380,7 +428,7 @@ def _create_service_account(account_id, account_config, config):
|
||||
return service_account
|
||||
|
||||
|
||||
def _add_iam_policy_binding(service_account, roles):
|
||||
def _add_iam_policy_binding(service_account, roles, crm):
|
||||
"""Add new IAM roles for the service account."""
|
||||
project_id = service_account["projectId"]
|
||||
email = service_account["email"]
|
||||
@@ -420,7 +468,7 @@ def _add_iam_policy_binding(service_account, roles):
|
||||
return result
|
||||
|
||||
|
||||
def _create_project_ssh_key_pair(project, public_key, ssh_user):
|
||||
def _create_project_ssh_key_pair(project, public_key, ssh_user, compute):
|
||||
"""Inserts an ssh-key into project commonInstanceMetadata"""
|
||||
|
||||
key_parts = public_key.split(" ")
|
||||
@@ -450,6 +498,7 @@ def _create_project_ssh_key_pair(project, public_key, ssh_user):
|
||||
operation = compute.projects().setCommonInstanceMetadata(
|
||||
project=project["name"], body=common_instance_metadata).execute()
|
||||
|
||||
response = wait_for_compute_global_operation(project["name"], operation)
|
||||
response = wait_for_compute_global_operation(project["name"], operation,
|
||||
compute)
|
||||
|
||||
return response
|
||||
|
||||
@@ -3,11 +3,10 @@ from threading import RLock
|
||||
import time
|
||||
import logging
|
||||
|
||||
from googleapiclient import discovery
|
||||
|
||||
from ray.autoscaler.node_provider import NodeProvider
|
||||
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
|
||||
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
|
||||
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \
|
||||
fetch_gcp_credentials_from_provider_config, _create_compute
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,7 +42,10 @@ class GCPNodeProvider(NodeProvider):
|
||||
NodeProvider.__init__(self, provider_config, cluster_name)
|
||||
|
||||
self.lock = RLock()
|
||||
self.compute = discovery.build("compute", "v1")
|
||||
gcp_credentials = fetch_gcp_credentials_from_provider_config(
|
||||
provider_config)
|
||||
|
||||
self.compute = _create_compute(gcp_credentials)
|
||||
|
||||
# Cache of node objects from the last nodes() call. This avoids
|
||||
# excessive DescribeInstances requests.
|
||||
|
||||
@@ -139,6 +139,10 @@
|
||||
"project_id": {
|
||||
"type": ["string", "null"],
|
||||
"description": "GCP globally unique project id"
|
||||
},
|
||||
"gcp_credentials": {
|
||||
"type": "string",
|
||||
"description": "JSON string constituting GCP credentials"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user