From acb3280235634a455b68d903758fd943f790458f Mon Sep 17 00:00:00 2001 From: henktillman Date: Tue, 23 Jun 2020 17:11:18 -0700 Subject: [PATCH] GCP credentials (#9052) --- python/ray/autoscaler/gcp/config.py | 115 +++++++++++++++------ python/ray/autoscaler/gcp/node_provider.py | 10 +- python/ray/autoscaler/ray-schema.json | 4 + 3 files changed, 92 insertions(+), 37 deletions(-) diff --git a/python/ray/autoscaler/gcp/config.py b/python/ray/autoscaler/gcp/config.py index ff54d1286..0584f8382 100644 --- a/python/ray/autoscaler/gcp/config.py +++ b/python/ray/autoscaler/gcp/config.py @@ -1,4 +1,5 @@ from functools import partial +import json import os import logging import time @@ -7,13 +8,10 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.backends import default_backend from googleapiclient import discovery, errors +from google.oauth2 import service_account logger = logging.getLogger(__name__) -crm = discovery.build("cloudresourcemanager", "v1") -iam = discovery.build("iam", "v1") -compute = discovery.build("compute", "v1") - VERSION = "v1" RAY = "ray-autoscaler" @@ -30,7 +28,7 @@ MAX_POLLS = 12 POLL_INTERVAL = 5 -def wait_for_crm_operation(operation): +def wait_for_crm_operation(operation, crm): """Poll for cloud resource manager operation until finished.""" logger.info("wait_for_crm_operation: " "Waiting for operation {} to finish...".format(operation)) @@ -49,7 +47,7 @@ def wait_for_crm_operation(operation): return result -def wait_for_compute_global_operation(project_name, operation): +def wait_for_compute_global_operation(project_name, operation, compute): """Poll for global compute operation until finished.""" logger.info("wait_for_compute_global_operation: " "Waiting for operation {} to finish...".format( @@ -75,7 +73,8 @@ def wait_for_compute_global_operation(project_name, operation): def key_pair_name(i, region, project_id, ssh_user): """Returns the ith default gcp_key_pair_name.""" - key_name = "{}_gcp_{}_{}_{}".format(RAY, region, project_id, ssh_user, i) + key_name = "{}_gcp_{}_{}_{}_{}".format(RAY, region, project_id, ssh_user, + i) return key_name @@ -104,16 +103,62 @@ def generate_rsa_key_pair(): return public_key, pem +def _create_crm(gcp_credentials): + return discovery.build( + "cloudresourcemanager", "v1", credentials=gcp_credentials) + + +def _create_iam(gcp_credentials): + return discovery.build("iam", "v1", credentials=gcp_credentials) + + +def _create_compute(gcp_credentials): + return discovery.build("compute", "v1", credentials=gcp_credentials) + + +def fetch_gcp_credentials_from_provider_config(provider_config): + """ + Attempt to fetch and parse the JSON GCP credentials from the provider + config yaml file. + """ + service_account_info_string = provider_config.get("gcp_credentials") + if service_account_info_string is None: + logger.info("gcp_credentials not found in cluster yaml file. " + "Falling back to GOOGLE_APPLICATION_CREDENTIALS " + "environment variable.") + # If gcp_credentials is None, then discovery.build will search for + # credentials in the local environment. + return None + + # If parsing the gcp_credentials failed, then the user likely made a + # mistake in copying the credentials into the config yaml. + try: + service_account_info = json.loads(service_account_info_string) + except json.decoder.JSONDecodeError: + raise RuntimeError("gcp_credentials found in cluster yaml file but " + "formatted improperly.") + gcp_credentials = service_account.Credentials.from_service_account_info( + service_account_info) + return gcp_credentials + + def bootstrap_gcp(config): - config = _configure_project(config) - config = _configure_iam_role(config) - config = _configure_key_pair(config) - config = _configure_subnet(config) + gcp_credentials = fetch_gcp_credentials_from_provider_config( + config["provider"]) + + crm = _create_crm(gcp_credentials) + iam = _create_iam(gcp_credentials) + compute = _create_compute(gcp_credentials) + + config = _configure_project(config, crm) + config = _configure_iam_role(config, crm, iam) + config = _configure_key_pair(config, compute) + config = _configure_subnet(config, compute) return config -def _configure_project(config): +def _configure_project(config, crm): """Setup a Google Cloud Platform Project. Google Compute Platform organizes all the resources, such as storage @@ -124,12 +169,12 @@ def _configure_project(config): assert config["provider"]["project_id"] is not None, ( "'project_id' must be set in the 'provider' section of the autoscaler" " config. Notice that the project id must be globally unique.") - project = _get_project(project_id) + project = _get_project(project_id, crm) if project is None: # Project not found, try creating it - _create_project(project_id) - project = _get_project(project_id) + _create_project(project_id, crm) + project = _get_project(project_id, crm) assert project is not None, "Failed to create project" assert project["lifecycleState"] == "ACTIVE", ( @@ -141,7 +186,7 @@ def _configure_project(config): return config -def _configure_iam_role(config): +def _configure_iam_role(config, crm, iam): """Setup a gcp service account with IAM roles. Creates a gcp service acconut and binds IAM roles which allow it to control @@ -154,7 +199,7 @@ def _configure_iam_role(config): email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format( account_id=DEFAULT_SERVICE_ACCOUNT_ID, project_id=config["provider"]["project_id"]) - service_account = _get_service_account(email, config) + service_account = _get_service_account(email, config, iam) if service_account is None: logger.info("_configure_iam_role: " @@ -162,11 +207,13 @@ def _configure_iam_role(config): DEFAULT_SERVICE_ACCOUNT_ID)) service_account = _create_service_account( - DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config) + DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config, + iam) assert service_account is not None, "Failed to create service account" - _add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES) + _add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES, + crm) config["head_node"]["serviceAccounts"] = [{ "email": service_account["email"], @@ -180,7 +227,7 @@ def _configure_iam_role(config): return config -def _configure_key_pair(config): +def _configure_key_pair(config, compute): """Configure SSH access, using an existing key pair if possible. Creates a project-wide ssh key that can be used to access all the instances @@ -235,7 +282,8 @@ def _configure_key_pair(config): "Creating new key pair {}".format(key_name)) public_key, private_key = generate_rsa_key_pair() - _create_project_ssh_key_pair(project, public_key, ssh_user) + _create_project_ssh_key_pair(project, public_key, ssh_user, + compute) # We need to make sure to _create_ the file with the right # permissions. In order to do that we need to change the default @@ -272,7 +320,7 @@ def _configure_key_pair(config): return config -def _configure_subnet(config): +def _configure_subnet(config, compute): """Pick a reasonable subnet if not specified by the config.""" # Rationale: avoid subnet lookup if the network is already @@ -281,7 +329,7 @@ def _configure_subnet(config): and "networkInterfaces" in config["worker_nodes"]): return config - subnets = _list_subnets(config) + subnets = _list_subnets(config, compute) if not subnets: raise NotImplementedError("Should be able to create subnet.") @@ -312,7 +360,7 @@ def _configure_subnet(config): return config -def _list_subnets(config): +def _list_subnets(config, compute): response = compute.subnetworks().list( project=config["provider"]["project_id"], region=config["provider"]["region"]).execute() @@ -320,7 +368,7 @@ def _list_subnets(config): return response["items"] -def _get_subnet(config, subnet_id): +def _get_subnet(config, subnet_id, compute): subnet = compute.subnetworks().get( project=config["provider"]["project_id"], region=config["provider"]["region"], @@ -330,7 +378,7 @@ def _get_subnet(config, subnet_id): return subnet -def _get_project(project_id): +def _get_project(project_id, crm): try: project = crm.projects().get(projectId=project_id).execute() except errors.HttpError as e: @@ -341,18 +389,18 @@ def _get_project(project_id): return project -def _create_project(project_id): +def _create_project(project_id, crm): operation = crm.projects().create(body={ "projectId": project_id, "name": project_id }).execute() - result = wait_for_crm_operation(operation) + result = wait_for_crm_operation(operation, crm) return result -def _get_service_account(account, config): +def _get_service_account(account, config, iam): project_id = config["provider"]["project_id"] full_name = ("projects/{project_id}/serviceAccounts/{account}" "".format(project_id=project_id, account=account)) @@ -367,7 +415,7 @@ def _get_service_account(account, config): return service_account -def _create_service_account(account_id, account_config, config): +def _create_service_account(account_id, account_config, config, iam): project_id = config["provider"]["project_id"] service_account = iam.projects().serviceAccounts().create( @@ -380,7 +428,7 @@ def _create_service_account(account_id, account_config, config): return service_account -def _add_iam_policy_binding(service_account, roles): +def _add_iam_policy_binding(service_account, roles, crm): """Add new IAM roles for the service account.""" project_id = service_account["projectId"] email = service_account["email"] @@ -420,7 +468,7 @@ def _add_iam_policy_binding(service_account, roles): return result -def _create_project_ssh_key_pair(project, public_key, ssh_user): +def _create_project_ssh_key_pair(project, public_key, ssh_user, compute): """Inserts an ssh-key into project commonInstanceMetadata""" key_parts = public_key.split(" ") @@ -450,6 +498,7 @@ def _create_project_ssh_key_pair(project, public_key, ssh_user): operation = compute.projects().setCommonInstanceMetadata( project=project["name"], body=common_instance_metadata).execute() - response = wait_for_compute_global_operation(project["name"], operation) + response = wait_for_compute_global_operation(project["name"], operation, + compute) return response diff --git a/python/ray/autoscaler/gcp/node_provider.py b/python/ray/autoscaler/gcp/node_provider.py index af48f2334..4b46c186f 100644 --- a/python/ray/autoscaler/gcp/node_provider.py +++ b/python/ray/autoscaler/gcp/node_provider.py @@ -3,11 +3,10 @@ from threading import RLock import time import logging -from googleapiclient import discovery - from ray.autoscaler.node_provider import NodeProvider from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME -from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL +from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \ + fetch_gcp_credentials_from_provider_config, _create_compute logger = logging.getLogger(__name__) @@ -43,7 +42,10 @@ class GCPNodeProvider(NodeProvider): NodeProvider.__init__(self, provider_config, cluster_name) self.lock = RLock() - self.compute = discovery.build("compute", "v1") + gcp_credentials = fetch_gcp_credentials_from_provider_config( + provider_config) + + self.compute = _create_compute(gcp_credentials) # Cache of node objects from the last nodes() call. This avoids # excessive DescribeInstances requests. diff --git a/python/ray/autoscaler/ray-schema.json b/python/ray/autoscaler/ray-schema.json index d48f4f231..ca1d2ff07 100644 --- a/python/ray/autoscaler/ray-schema.json +++ b/python/ray/autoscaler/ray-schema.json @@ -139,6 +139,10 @@ "project_id": { "type": ["string", "null"], "description": "GCP globally unique project id" + }, + "gcp_credentials": { + "type": "string", + "description": "JSON string constituting GCP credentials" } } },