GCP credentials (#9052)

This commit is contained in:
henktillman
2020-06-23 17:11:18 -07:00
committed by GitHub
parent 9b4428c668
commit acb3280235
3 changed files with 92 additions and 37 deletions
+82 -33
View File
@@ -1,4 +1,5 @@
from functools import partial
import json
import os
import logging
import time
@@ -7,13 +8,10 @@ from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from googleapiclient import discovery, errors
from google.oauth2 import service_account
logger = logging.getLogger(__name__)
crm = discovery.build("cloudresourcemanager", "v1")
iam = discovery.build("iam", "v1")
compute = discovery.build("compute", "v1")
VERSION = "v1"
RAY = "ray-autoscaler"
@@ -30,7 +28,7 @@ MAX_POLLS = 12
POLL_INTERVAL = 5
def wait_for_crm_operation(operation):
def wait_for_crm_operation(operation, crm):
"""Poll for cloud resource manager operation until finished."""
logger.info("wait_for_crm_operation: "
"Waiting for operation {} to finish...".format(operation))
@@ -49,7 +47,7 @@ def wait_for_crm_operation(operation):
return result
def wait_for_compute_global_operation(project_name, operation):
def wait_for_compute_global_operation(project_name, operation, compute):
"""Poll for global compute operation until finished."""
logger.info("wait_for_compute_global_operation: "
"Waiting for operation {} to finish...".format(
@@ -75,7 +73,8 @@ def wait_for_compute_global_operation(project_name, operation):
def key_pair_name(i, region, project_id, ssh_user):
    """Returns the ith default gcp_key_pair_name.

    The name is built from the ``RAY`` prefix, the literal ``gcp``, the
    region, the project id, the ssh user and the index ``i``, joined by
    underscores.
    """
    # Five placeholders for five arguments: a four-placeholder template would
    # silently drop ``i`` and yield the same name for every index.
    key_name = "{}_gcp_{}_{}_{}_{}".format(RAY, region, project_id, ssh_user,
                                           i)
    return key_name
@@ -104,16 +103,62 @@ def generate_rsa_key_pair():
return public_key, pem
def _create_crm(gcp_credentials):
    """Build a ``cloudresourcemanager`` v1 API client.

    ``gcp_credentials`` is forwarded unchanged to ``discovery.build`` as its
    ``credentials`` argument.
    """
    crm_client = discovery.build(
        "cloudresourcemanager", "v1", credentials=gcp_credentials)
    return crm_client
def _create_iam(gcp_credentials):
    """Build an ``iam`` v1 API client.

    ``gcp_credentials`` is forwarded unchanged to ``discovery.build`` as its
    ``credentials`` argument.
    """
    iam_client = discovery.build("iam", "v1", credentials=gcp_credentials)
    return iam_client
def _create_compute(gcp_credentials):
    """Build a ``compute`` v1 API client.

    ``gcp_credentials`` is forwarded unchanged to ``discovery.build`` as its
    ``credentials`` argument.
    """
    compute_client = discovery.build(
        "compute", "v1", credentials=gcp_credentials)
    return compute_client
def fetch_gcp_credentials_from_provider_config(provider_config):
    """Fetch and parse the JSON GCP credentials from the provider config.

    Args:
        provider_config: Mapping for the ``provider`` section of the cluster
            yaml. Its optional ``gcp_credentials`` entry is read as a JSON
            string holding service account info.

    Returns:
        The object returned by
        ``service_account.Credentials.from_service_account_info`` for the
        parsed JSON, or ``None`` when ``gcp_credentials`` is absent.

    Raises:
        RuntimeError: If ``gcp_credentials`` is present but cannot be parsed
            as JSON. The underlying parse error is attached as ``__cause__``.
    """
    service_account_info_string = provider_config.get("gcp_credentials")
    if service_account_info_string is None:
        logger.info("gcp_credentials not found in cluster yaml file. "
                    "Falling back to GOOGLE_APPLICATION_CREDENTIALS "
                    "environment variable.")
        # If gcp_credentials is None, then discovery.build will search for
        # credentials in the local environment.
        return None

    # If parsing the gcp_credentials failed, then the user likely made a
    # mistake in copying the credentials into the config yaml. A non-string
    # value (e.g. a YAML mapping instead of a quoted JSON string) makes
    # json.loads raise TypeError; report it the same way.
    try:
        service_account_info = json.loads(service_account_info_string)
    except (TypeError, json.JSONDecodeError) as err:
        raise RuntimeError("gcp_credentials found in cluster yaml file but "
                           "formatted improperly.") from err

    gcp_credentials = service_account.Credentials.from_service_account_info(
        service_account_info)
    return gcp_credentials
def bootstrap_gcp(config):
    """Run the GCP configuration steps on ``config`` and return the result.

    Credentials are read from ``config["provider"]`` and used to build the
    Cloud Resource Manager, IAM and Compute clients. Those clients are then
    passed to the project, IAM role, key pair and subnet configuration steps,
    in that order.
    """
    gcp_credentials = fetch_gcp_credentials_from_provider_config(
        config["provider"])

    # Build each API client once and pass it down explicitly, so every helper
    # uses the same credentials.
    crm = _create_crm(gcp_credentials)
    iam = _create_iam(gcp_credentials)
    compute = _create_compute(gcp_credentials)

    config = _configure_project(config, crm)
    config = _configure_iam_role(config, crm, iam)
    config = _configure_key_pair(config, compute)
    config = _configure_subnet(config, compute)

    return config
def _configure_project(config):
def _configure_project(config, crm):
"""Setup a Google Cloud Platform Project.
Google Compute Platform organizes all the resources, such as storage
@@ -124,12 +169,12 @@ def _configure_project(config):
assert config["provider"]["project_id"] is not None, (
"'project_id' must be set in the 'provider' section of the autoscaler"
" config. Notice that the project id must be globally unique.")
project = _get_project(project_id)
project = _get_project(project_id, crm)
if project is None:
# Project not found, try creating it
_create_project(project_id)
project = _get_project(project_id)
_create_project(project_id, crm)
project = _get_project(project_id, crm)
assert project is not None, "Failed to create project"
assert project["lifecycleState"] == "ACTIVE", (
@@ -141,7 +186,7 @@ def _configure_project(config):
return config
def _configure_iam_role(config):
def _configure_iam_role(config, crm, iam):
"""Setup a gcp service account with IAM roles.
Creates a gcp service account and binds IAM roles which allow it to control
@@ -154,7 +199,7 @@ def _configure_iam_role(config):
email = SERVICE_ACCOUNT_EMAIL_TEMPLATE.format(
account_id=DEFAULT_SERVICE_ACCOUNT_ID,
project_id=config["provider"]["project_id"])
service_account = _get_service_account(email, config)
service_account = _get_service_account(email, config, iam)
if service_account is None:
logger.info("_configure_iam_role: "
@@ -162,11 +207,13 @@ def _configure_iam_role(config):
DEFAULT_SERVICE_ACCOUNT_ID))
service_account = _create_service_account(
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config)
DEFAULT_SERVICE_ACCOUNT_ID, DEFAULT_SERVICE_ACCOUNT_CONFIG, config,
iam)
assert service_account is not None, "Failed to create service account"
_add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES)
_add_iam_policy_binding(service_account, DEFAULT_SERVICE_ACCOUNT_ROLES,
crm)
config["head_node"]["serviceAccounts"] = [{
"email": service_account["email"],
@@ -180,7 +227,7 @@ def _configure_iam_role(config):
return config
def _configure_key_pair(config):
def _configure_key_pair(config, compute):
"""Configure SSH access, using an existing key pair if possible.
Creates a project-wide ssh key that can be used to access all the instances
@@ -235,7 +282,8 @@ def _configure_key_pair(config):
"Creating new key pair {}".format(key_name))
public_key, private_key = generate_rsa_key_pair()
_create_project_ssh_key_pair(project, public_key, ssh_user)
_create_project_ssh_key_pair(project, public_key, ssh_user,
compute)
# We need to make sure to _create_ the file with the right
# permissions. In order to do that we need to change the default
@@ -272,7 +320,7 @@ def _configure_key_pair(config):
return config
def _configure_subnet(config):
def _configure_subnet(config, compute):
"""Pick a reasonable subnet if not specified by the config."""
# Rationale: avoid subnet lookup if the network is already
@@ -281,7 +329,7 @@ def _configure_subnet(config):
and "networkInterfaces" in config["worker_nodes"]):
return config
subnets = _list_subnets(config)
subnets = _list_subnets(config, compute)
if not subnets:
raise NotImplementedError("Should be able to create subnet.")
@@ -312,7 +360,7 @@ def _configure_subnet(config):
return config
def _list_subnets(config):
def _list_subnets(config, compute):
response = compute.subnetworks().list(
project=config["provider"]["project_id"],
region=config["provider"]["region"]).execute()
@@ -320,7 +368,7 @@ def _list_subnets(config):
return response["items"]
def _get_subnet(config, subnet_id):
def _get_subnet(config, subnet_id, compute):
subnet = compute.subnetworks().get(
project=config["provider"]["project_id"],
region=config["provider"]["region"],
@@ -330,7 +378,7 @@ def _get_subnet(config, subnet_id):
return subnet
def _get_project(project_id):
def _get_project(project_id, crm):
try:
project = crm.projects().get(projectId=project_id).execute()
except errors.HttpError as e:
@@ -341,18 +389,18 @@ def _get_project(project_id):
return project
def _create_project(project_id, crm):
    """Create a project and wait for the creation operation to finish.

    Args:
        project_id: Used as both the ``projectId`` and the ``name`` of the
            new project.
        crm: Cloud Resource Manager client used to issue the create request;
            it is also handed to ``wait_for_crm_operation``.

    Returns:
        Whatever ``wait_for_crm_operation`` returns for the operation.
    """
    operation = crm.projects().create(body={
        "projectId": project_id,
        "name": project_id
    }).execute()

    # The poller takes the client explicitly, so pass the same one along.
    result = wait_for_crm_operation(operation, crm)

    return result
def _get_service_account(account, config):
def _get_service_account(account, config, iam):
project_id = config["provider"]["project_id"]
full_name = ("projects/{project_id}/serviceAccounts/{account}"
"".format(project_id=project_id, account=account))
@@ -367,7 +415,7 @@ def _get_service_account(account, config):
return service_account
def _create_service_account(account_id, account_config, config):
def _create_service_account(account_id, account_config, config, iam):
project_id = config["provider"]["project_id"]
service_account = iam.projects().serviceAccounts().create(
@@ -380,7 +428,7 @@ def _create_service_account(account_id, account_config, config):
return service_account
def _add_iam_policy_binding(service_account, roles):
def _add_iam_policy_binding(service_account, roles, crm):
"""Add new IAM roles for the service account."""
project_id = service_account["projectId"]
email = service_account["email"]
@@ -420,7 +468,7 @@ def _add_iam_policy_binding(service_account, roles):
return result
def _create_project_ssh_key_pair(project, public_key, ssh_user):
def _create_project_ssh_key_pair(project, public_key, ssh_user, compute):
"""Inserts an ssh-key into project commonInstanceMetadata"""
key_parts = public_key.split(" ")
@@ -450,6 +498,7 @@ def _create_project_ssh_key_pair(project, public_key, ssh_user):
operation = compute.projects().setCommonInstanceMetadata(
project=project["name"], body=common_instance_metadata).execute()
response = wait_for_compute_global_operation(project["name"], operation)
response = wait_for_compute_global_operation(project["name"], operation,
compute)
return response
+6 -4
View File
@@ -3,11 +3,10 @@ from threading import RLock
import time
import logging
from googleapiclient import discovery
from ray.autoscaler.node_provider import NodeProvider
from ray.autoscaler.tags import TAG_RAY_CLUSTER_NAME, TAG_RAY_NODE_NAME
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL
from ray.autoscaler.gcp.config import MAX_POLLS, POLL_INTERVAL, \
fetch_gcp_credentials_from_provider_config, _create_compute
logger = logging.getLogger(__name__)
@@ -43,7 +42,10 @@ class GCPNodeProvider(NodeProvider):
NodeProvider.__init__(self, provider_config, cluster_name)
self.lock = RLock()
self.compute = discovery.build("compute", "v1")
gcp_credentials = fetch_gcp_credentials_from_provider_config(
provider_config)
self.compute = _create_compute(gcp_credentials)
# Cache of node objects from the last nodes() call. This avoids
# excessive DescribeInstances requests.
+4
View File
@@ -139,6 +139,10 @@
"project_id": {
"type": ["string", "null"],
"description": "GCP globally unique project id"
},
"gcp_credentials": {
"type": "string",
"description": "JSON string constituting GCP credentials"
}
}
},