From 4175569d96857d146f982ae9232048383f1c728f Mon Sep 17 00:00:00 2001 From: architkulkarni Date: Thu, 29 Oct 2020 12:22:44 -0700 Subject: [PATCH] [Core] Add option to override environment variables for tasks and actors (#11619) --- python/ray/_raylet.pyx | 30 ++++--- python/ray/actor.py | 17 +++- python/ray/includes/common.pxd | 8 +- python/ray/remote_function.py | 22 +++-- python/ray/tests/test_advanced_3.py | 121 ++++++++++++++++++++++++++++ python/ray/worker.py | 6 ++ src/ray/common/task/task_spec.cc | 5 ++ src/ray/common/task/task_spec.h | 2 + src/ray/common/task/task_util.h | 7 +- src/ray/core_worker/common.h | 24 +++++- src/ray/core_worker/context.cc | 7 ++ src/ray/core_worker/context.h | 6 +- src/ray/core_worker/core_worker.cc | 44 +++++++--- src/ray/protobuf/common.proto | 2 + src/ray/raylet/worker_pool.cc | 28 +++++-- src/ray/raylet/worker_pool.h | 7 +- 16 files changed, 286 insertions(+), 50 deletions(-) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index efbbd5f2e..47cee86ba 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1001,28 +1001,32 @@ cdef class CoreWorker: int max_retries, PlacementGroupID placement_group_id, int64_t placement_group_bundle_index, - c_bool placement_group_capture_child_tasks): + c_bool placement_group_capture_child_tasks, + override_environment_variables): cdef: unordered_map[c_string, double] c_resources - CTaskOptions task_options CRayFunction ray_function c_vector[unique_ptr[CTaskArg]] args_vector c_vector[CObjectID] return_ids CPlacementGroupID c_placement_group_id = \ placement_group_id.native() + unordered_map[c_string, c_string] \ + c_override_environment_variables = \ + override_environment_variables with self.profile_event(b"submit_task"): prepare_resources(resources, &c_resources) - task_options = CTaskOptions( - name, num_returns, c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) prepare_args(self, language, args, &args_vector) with nogil: CCoreWorkerProcess.GetCoreWorker().SubmitTask( - ray_function, args_vector, task_options, &return_ids, - max_retries, c_pair[CPlacementGroupID, int64_t]( + ray_function, args_vector, CTaskOptions( + name, num_returns, c_resources, + c_override_environment_variables), + &return_ids, max_retries, + c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), placement_group_capture_child_tasks) @@ -1043,7 +1047,8 @@ cdef class CoreWorker: PlacementGroupID placement_group_id, int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, - c_string extension_data + c_string extension_data, + override_environment_variables ): cdef: CRayFunction ray_function @@ -1054,6 +1059,9 @@ cdef class CoreWorker: CActorID c_actor_id CPlacementGroupID c_placement_group_id = \ placement_group_id.native() + unordered_map[c_string, c_string] \ + c_override_environment_variables = \ + override_environment_variables with self.profile_event(b"submit_task"): prepare_resources(resources, &c_resources) @@ -1072,7 +1080,8 @@ cdef class CoreWorker: c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), - placement_group_capture_child_tasks), + placement_group_capture_child_tasks, + c_override_environment_variables), extension_data, &c_actor_id)) @@ -1134,7 +1143,6 @@ cdef class CoreWorker: cdef: CActorID c_actor_id = actor_id.native() unordered_map[c_string, double] c_resources - CTaskOptions task_options CRayFunction ray_function c_vector[unique_ptr[CTaskArg]] args_vector c_vector[CObjectID] return_ids @@ -1142,7 +1150,6 @@ cdef class CoreWorker: with self.profile_event(b"submit_task"): if num_method_cpus > 0: c_resources[b"CPU"] = num_method_cpus - task_options = CTaskOptions(name, num_returns, c_resources) ray_function = CRayFunction( language.lang, function_descriptor.descriptor) prepare_args(self, language, args, &args_vector) @@ -1151,7 +1158,8 @@ cdef class CoreWorker: CCoreWorkerProcess.GetCoreWorker().SubmitActorTask( c_actor_id, ray_function, - args_vector, task_options, &return_ids) + args_vector, CTaskOptions(name, num_returns, c_resources), + &return_ids) return VectorToObjectRefs(return_ids) diff --git a/python/ray/actor.py b/python/ray/actor.py index e59c494f4..d3fa34ff8 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -418,7 +418,8 @@ class ActorClass: lifetime=None, placement_group=None, placement_group_bundle_index=-1, - placement_group_capture_child_tasks=None): + placement_group_capture_child_tasks=None, + override_environment_variables=None): """Configures and overrides the actor instantiation parameters. The arguments are the same as those that can be passed @@ -458,7 +459,9 @@ class ActorClass: placement_group=placement_group, placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( - placement_group_capture_child_tasks)) + placement_group_capture_child_tasks), + override_environment_variables=( + override_environment_variables)) return ActorOptionWrapper() @@ -478,7 +481,8 @@ class ActorClass: lifetime=None, placement_group=None, placement_group_bundle_index=-1, - placement_group_capture_child_tasks=None): + placement_group_capture_child_tasks=None, + override_environment_variables=None): """Create an actor. This method allows more flexibility than the remote method because @@ -515,6 +519,9 @@ class ActorClass: placement_group_capture_child_tasks: Whether or not children tasks of this actor should implicitly use the same placement group as its parent. It is True by default. + override_environment_variables: Environment variables to override + and/or introduce for this actor. This is a dictionary mapping + variable names to their values. Returns: A handle to the newly created actor. @@ -661,7 +668,9 @@ class ActorClass: placement_group_bundle_index, placement_group_capture_child_tasks, # Store actor_method_cpu in actor handle's extension data. - extension_data=str(actor_method_cpu)) + extension_data=str(actor_method_cpu), + override_environment_variables=override_environment_variables + or dict()) actor_handle = ActorHandle( meta.language, diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index b21635fd9..4ecc8d90a 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -243,6 +243,10 @@ cdef extern from "ray/core_worker/common.h" nogil: CTaskOptions() CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources) + CTaskOptions(c_string name, int num_returns, + unordered_map[c_string, double] &resources, + const unordered_map[c_string, c_string] + &override_environment_variables) cdef cppclass CActorCreationOptions "ray::ActorCreationOptions": CActorCreationOptions() @@ -255,7 +259,9 @@ cdef extern from "ray/core_worker/common.h" nogil: const c_vector[c_string] &dynamic_worker_options, c_bool is_detached, c_string &name, c_bool is_asyncio, c_pair[CPlacementGroupID, int64_t] placement_options, - c_bool placement_group_capture_child_tasks) + c_bool placement_group_capture_child_tasks, + const unordered_map[c_string, c_string] + &override_environment_variables) cdef cppclass CPlacementGroupCreationOptions \ "ray::PlacementGroupCreationOptions": diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 8734fdb04..0375d84b9 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -138,10 +138,12 @@ class RemoteFunction: placement_group=None, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, + override_environment_variables=None, name=""): """Configures and overrides the task invocation parameters. - Options are overlapping values provided by :obj:`ray.remote`. + The arguments are the same as those that can be passed to + :obj:`ray.remote`. Examples: @@ -173,6 +175,8 @@ class RemoteFunction: placement_group_bundle_index=placement_group_bundle_index, placement_group_capture_child_tasks=( placement_group_capture_child_tasks), + override_environment_variables=( + override_environment_variables), name=name) return FuncWrapper() @@ -191,6 +195,7 @@ class RemoteFunction: placement_group=None, placement_group_bundle_index=-1, placement_group_capture_child_tasks=None, + override_environment_variables=None, name=""): """Submit the remote function for execution.""" worker = ray.worker.global_worker @@ -260,11 +265,18 @@ class RemoteFunction: "Cross language remote function " \ "cannot be executed locally." object_refs = worker.core_worker.submit_task( - self._language, self._function_descriptor, list_args, name, - num_returns, resources, max_retries, placement_group.id, + self._language, + self._function_descriptor, + list_args, + name, + num_returns, + resources, + max_retries, + placement_group.id, placement_group_bundle_index, - placement_group_capture_child_tasks) - + placement_group_capture_child_tasks, + override_environment_variables=override_environment_variables + or dict()) if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index e76b77620..32690ae62 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -790,6 +790,127 @@ def test_detect_docker_cpus(): cpuset_file_name=cpuset_file.name) == 0.42 +def test_override_environment_variables_task(ray_start_regular): + @ray.remote + def get_env(key): + return os.environ.get(key) + + assert (ray.get( + get_env.options(override_environment_variables={ + "a": "b" + }).remote("a")) == "b") + + +def test_override_environment_variables_actor(ray_start_regular): + @ray.remote + class EnvGetter: + def get(self, key): + return os.environ.get(key) + + a = EnvGetter.options(override_environment_variables={ + "a": "b", + "c": "d" + }).remote() + assert (ray.get(a.get.remote("a")) == "b") + assert (ray.get(a.get.remote("c")) == "d") + + +def test_override_environment_variables_nested_task(ray_start_regular): + @ray.remote + def get_env(key): + return os.environ.get(key) + + @ray.remote + def get_env_wrapper(key): + return ray.get(get_env.remote(key)) + + assert (ray.get( + get_env_wrapper.options(override_environment_variables={ + "a": "b" + }).remote("a")) == "b") + + +def test_override_environment_variables_multitenancy(shutdown_only): + ray.init( + job_config=ray.job_config.JobConfig(worker_env={ + "foo1": "bar1", + "foo2": "bar2" + })) + + @ray.remote + def get_env(key): + return os.environ.get(key) + + assert ray.get(get_env.remote("foo1")) == "bar1" + assert ray.get(get_env.remote("foo2")) == "bar2" + assert ray.get( + get_env.options(override_environment_variables={ + "foo1": "baz1" + }).remote("foo1")) == "baz1" + assert ray.get( + get_env.options(override_environment_variables={ + "foo1": "baz1" + }).remote("foo2")) == "bar2" + + +def test_override_environment_variables_complex(shutdown_only): + ray.init( + job_config=ray.job_config.JobConfig(worker_env={ + "a": "job_a", + "b": "job_b", + "z": "job_z" + })) + + @ray.remote + def get_env(key): + return os.environ.get(key) + + @ray.remote + class NestedEnvGetter: + def get(self, key): + return os.environ.get(key) + + def get_task(self, key): + return ray.get(get_env.remote(key)) + + @ray.remote + class EnvGetter: + def get(self, key): + return os.environ.get(key) + + def get_task(self, key): + return ray.get(get_env.remote(key)) + + def nested_get(self, key): + aa = NestedEnvGetter.options(override_environment_variables={ + "c": "e", + "d": "dd" + }).remote() + return ray.get(aa.get.remote(key)) + + a = EnvGetter.options(override_environment_variables={ + "a": "b", + "c": "d" + }).remote() + assert (ray.get(a.get.remote("a")) == "b") + assert (ray.get(a.get_task.remote("a")) == "b") + assert (ray.get(a.nested_get.remote("a")) == "b") + assert (ray.get(a.nested_get.remote("c")) == "e") + assert (ray.get(a.nested_get.remote("d")) == "dd") + assert (ray.get( + get_env.options(override_environment_variables={ + "a": "b" + }).remote("a")) == "b") + + assert (ray.get(a.get.remote("z")) == "job_z") + assert (ray.get(a.get_task.remote("z")) == "job_z") + assert (ray.get(a.nested_get.remote("z")) == "job_z") + assert (ray.get( + get_env.options(override_environment_variables={ + "a": "b" + }).remote("z")) == "job_z") + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 8cab2bd89..0f94aac97 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1857,6 +1857,12 @@ def remote(*args, **kwargs): crashes unexpectedly. The minimum valid value is 0, the default is 4 (default), and a value of -1 indicates infinite retries. + override_environment_variables (Dict[str, str]): This specifies + environment variables to override for the actor or task. The + overrides are propagated to all child actors and tasks. This + is a dictionary mapping variable names to their values. Existing + variables can be overridden, new ones can be created, and an + existing variable can be unset by setting it to an empty string. """ worker = global_worker diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index eaa13b7e9..c69ebde86 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -192,6 +192,11 @@ const ResourceSet &TaskSpecification::GetRequiredPlacementResources() const { return *required_placement_resources_; } +std::unordered_map +TaskSpecification::OverrideEnvironmentVariables() const { + return MapFromProtobuf(message_->override_environment_variables()); +} + bool TaskSpecification::IsDriverTask() const { return message_->type() == TaskType::DRIVER_TASK; } diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 311b34281..4399bf9c2 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -131,6 +131,8 @@ class TaskSpecification : public MessageWrapper { /// \return The recomputed dependencies for the task. std::vector GetDependencies() const; + std::unordered_map OverrideEnvironmentVariables() const; + bool IsDriverTask() const; Language GetLanguage() const; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 8c72e431c..6e06c5f1c 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -87,7 +87,9 @@ class TaskSpecBuilder { const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const PlacementGroupID &placement_group_id, - bool placement_group_capture_child_tasks) { + bool placement_group_capture_child_tasks, + const std::unordered_map &override_environment_variables = + {}) { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); message_->set_language(language); @@ -106,6 +108,9 @@ class TaskSpecBuilder { message_->set_placement_group_id(placement_group_id.Binary()); message_->set_placement_group_capture_child_tasks( placement_group_capture_child_tasks); + for (const auto &env : override_environment_variables) { + (*message_->mutable_override_environment_variables())[env.first] = env.second; + } return *this; } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index de790bcb6..a78f39722 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -55,8 +55,13 @@ class RayFunction { struct TaskOptions { TaskOptions() {} TaskOptions(std::string name, int num_returns, - std::unordered_map &resources) - : name(name), num_returns(num_returns), resources(resources) {} + std::unordered_map &resources, + const std::unordered_map + &override_environment_variables = {}) + : name(name), + num_returns(num_returns), + resources(resources), + override_environment_variables(override_environment_variables) {} /// The name of this task. std::string name; @@ -64,6 +69,10 @@ struct TaskOptions { int num_returns = 1; /// Resources required by this task. std::unordered_map resources; + /// Environment variables to update for this task. Maps a variable name to its + /// value. Can override existing environment variables and introduce new ones. + /// Propagated to child actors and/or tasks. + const std::unordered_map override_environment_variables; }; /// Options for actor creation tasks. @@ -76,7 +85,9 @@ struct ActorCreationOptions { const std::vector &dynamic_worker_options, bool is_detached, std::string &name, bool is_asyncio, PlacementOptions placement_options = std::make_pair(PlacementGroupID::Nil(), -1), - bool placement_group_capture_child_tasks = true) + bool placement_group_capture_child_tasks = true, + const std::unordered_map &override_environment_variables = + {}) : max_restarts(max_restarts), max_task_retries(max_task_retries), max_concurrency(max_concurrency), @@ -87,7 +98,8 @@ struct ActorCreationOptions { name(name), is_asyncio(is_asyncio), placement_options(placement_options), - placement_group_capture_child_tasks(placement_group_capture_child_tasks){}; + placement_group_capture_child_tasks(placement_group_capture_child_tasks), + override_environment_variables(override_environment_variables){}; /// Maximum number of times that the actor should be restarted if it dies /// unexpectedly. A value of -1 indicates infinite restarts. If it's 0, the @@ -122,6 +134,10 @@ struct ActorCreationOptions { /// When true, the child task will always scheduled on the same placement group /// specified in the PlacementOptions. bool placement_group_capture_child_tasks = true; + /// Environment variables to update for this actor. Maps a variable name to its + /// value. Can override existing environment variables and introduce new ones. + /// Propagated to child actors and/or tasks. + const std::unordered_map override_environment_variables; }; using PlacementStrategy = rpc::PlacementStrategy; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 9ff630dc0..7e47fc212 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -156,6 +156,11 @@ bool WorkerContext::ShouldCaptureChildTasksInPlacementGroup() const { } } +const std::unordered_map + &WorkerContext::GetCurrentOverrideEnvironmentVariables() const { + return override_environment_variables_; +} + void WorkerContext::SetCurrentJobId(const JobID &job_id) { current_job_id_ = job_id; } void WorkerContext::SetCurrentTaskId(const TaskID &task_id) { @@ -168,6 +173,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); current_task_is_direct_call_ = true; + override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); } else if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); @@ -178,6 +184,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { current_actor_is_asyncio_ = task_spec.IsAsyncioActor(); current_actor_placement_group_id_ = task_spec.PlacementGroupId(); placement_group_capture_child_tasks_ = task_spec.PlacementGroupCaptureChildTasks(); + override_environment_variables_ = task_spec.OverrideEnvironmentVariables(); } else if (task_spec.IsActorTask()) { RAY_CHECK(current_job_id_ == task_spec.JobId()); RAY_CHECK(current_actor_id_ == task_spec.ActorId()); diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 81c714e7a..6bd3b1bfa 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -39,6 +39,9 @@ class WorkerContext { bool ShouldCaptureChildTasksInPlacementGroup() const; + const std::unordered_map + &GetCurrentOverrideEnvironmentVariables() const; + // TODO(edoakes): remove this once Python core worker uses the task interfaces. void SetCurrentJobId(const JobID &job_id); @@ -92,7 +95,8 @@ class WorkerContext { PlacementGroupID current_actor_placement_group_id_; // Whether or not we should implicitly capture parent's placement group. bool placement_group_capture_child_tasks_; - + // The environment variable overrides for the current actor or task. + std::unordered_map override_environment_variables_; /// The id of the (main) thread that constructed this worker context. boost::thread::id main_thread_id_; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 8430172da..49d90b91f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -39,13 +39,14 @@ void BuildCommonTaskSpec( const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, std::vector *return_ids, const ray::PlacementGroupID &placement_group_id, - bool placement_group_capture_child_tasks) { + bool placement_group_capture_child_tasks, + const std::unordered_map &override_environment_variables) { // Build common task spec. - builder.SetCommonTaskSpec(task_id, name, function.GetLanguage(), - function.GetFunctionDescriptor(), job_id, current_task_id, - task_index, caller_id, address, num_returns, - required_resources, required_placement_resources, - placement_group_id, placement_group_capture_child_tasks); + builder.SetCommonTaskSpec( + task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, + current_task_id, task_index, caller_id, address, num_returns, required_resources, + required_placement_resources, placement_group_id, + placement_group_capture_child_tasks, override_environment_variables); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1282,19 +1283,27 @@ void CoreWorker::SubmitTask(const RayFunction &function, const auto task_id = TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index); - auto constrained_resources = AddPlacementGroupConstraint( task_options.resources, placement_options.first, placement_options.second); const std::unordered_map required_resources; auto task_name = task_options.name.empty() ? function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; + // Propagate existing environment variable overrides, but override them with any new + // ones + std::unordered_map current_override_environment_variables = + worker_context_.GetCurrentOverrideEnvironmentVariables(); + std::unordered_map override_environment_variables = + task_options.override_environment_variables; + override_environment_variables.insert(current_override_environment_variables.begin(), + current_override_environment_variables.end()); // TODO(ekl) offload task building onto a thread pool for performance BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, task_name, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, task_options.num_returns, constrained_resources, required_resources, return_ids, - placement_options.first, placement_group_capture_child_tasks); + placement_options.first, placement_group_capture_child_tasks, + override_environment_variables); TaskSpecification task_spec = builder.Build(); if (options_.is_local_mode) { ExecuteTaskLocalMode(task_spec); @@ -1322,6 +1331,14 @@ Status CoreWorker::CreateActor(const RayFunction &function, next_task_index); const TaskID actor_creation_task_id = TaskID::ForActorCreationTask(actor_id); const JobID job_id = worker_context_.GetCurrentJobID(); + // Propagate existing environment variable overrides, but override them with any new + // ones + std::unordered_map current_override_environment_variables = + worker_context_.GetCurrentOverrideEnvironmentVariables(); + std::unordered_map override_environment_variables = + actor_creation_options.override_environment_variables; + override_environment_variables.insert(current_override_environment_variables.begin(), + current_override_environment_variables.end()); std::vector return_ids; TaskSpecBuilder builder; auto new_placement_resources = @@ -1341,7 +1358,8 @@ Status CoreWorker::CreateActor(const RayFunction &function, rpc_address_, function, args, 1, new_resource, new_placement_resources, &return_ids, actor_creation_options.placement_options.first, - actor_creation_options.placement_group_capture_child_tasks); + actor_creation_options.placement_group_capture_child_tasks, + override_environment_variables); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_restarts, actor_creation_options.dynamic_worker_options, actor_creation_options.max_concurrency, @@ -1442,13 +1460,15 @@ void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &fun const auto task_name = task_options.name.empty() ? function.GetFunctionDescriptor()->DefaultTaskName() : task_options.name; + const std::unordered_map override_environment_variables = {}; BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, task_name, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, num_returns, task_options.resources, required_resources, return_ids, PlacementGroupID::Nil(), - true /* placement_group_capture_child_tasks */); - // NOTE: placement_group_capture_child_tasks will be ignored in the actor because - // we should always follow actor's option. + true, /* placement_group_capture_child_tasks */ + override_environment_variables); + // NOTE: placement_group_capture_child_tasks and override_environment_variables will be + // ignored in the actor because we should always follow the actor's option. const ObjectID new_cursor = return_ids->back(); actor_handle->SetActorTaskSpec(builder, new_cursor); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 2c60cf3b8..99e654c45 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -190,6 +190,8 @@ message TaskSpec { bytes placement_group_id = 18; // Whether or not this task should capture parent's placement group automatically. bool placement_group_capture_child_tasks = 19; + // Environment variables to override for this task + map override_environment_variables = 20; } message Bundle { diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index ab092cd7d..30dd33ffe 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -173,10 +173,10 @@ WorkerPool::~WorkerPool() { } } -Process WorkerPool::StartWorkerProcess(const Language &language, - const rpc::WorkerType worker_type, - const JobID &job_id, - std::vector dynamic_options) { +Process WorkerPool::StartWorkerProcess( + const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, + std::vector dynamic_options, + std::unordered_map override_environment_variables) { rpc::JobConfig *job_config = nullptr; if (RayConfig::instance().enable_multi_tenancy() && worker_type != rpc::WorkerType::IO_WORKER) { @@ -324,6 +324,11 @@ Process WorkerPool::StartWorkerProcess(const Language &language, if (RayConfig::instance().enable_multi_tenancy() && job_config) { env.insert(job_config->worker_env().begin(), job_config->worker_env().end()); } + + for (const auto &pair : override_environment_variables) { + env[pair.first] = pair.second; + } + Process proc = StartProcess(worker_command_args, env); if (RayConfig::instance().enable_multi_tenancy() && job_config) { // If the pid is reused between processes, the old process must have exited. @@ -646,7 +651,7 @@ void WorkerPool::PushWorker(const std::shared_ptr &worker) { const auto task_id = it->second; state.idle_dedicated_workers[task_id] = worker; } else { - // The worker is not used for the actor creation task without dynamic options. + // The worker is not used for the actor creation task with dynamic options. // Put the worker to the corresponding idle pool. if (worker->GetActorId().IsNil()) { state.idle.insert(worker); @@ -782,8 +787,10 @@ std::shared_ptr WorkerPool::PopWorker( std::shared_ptr worker = nullptr; Process proc; - if (task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) { - // Code path of actor creation task with dynamic worker options. + if ((task_spec.IsActorCreationTask() && !task_spec.DynamicWorkerOptions().empty()) || + task_spec.OverrideEnvironmentVariables().size() > 0) { + // Code path of task that needs a dedicated worker: an actor creation task with + // dynamic worker options, or any task with environment variable overrides. // Try to pop it from idle dedicated pool. auto it = state.idle_dedicated_workers.find(task_spec.TaskId()); if (it != state.idle_dedicated_workers.end()) { @@ -797,8 +804,13 @@ std::shared_ptr WorkerPool::PopWorker( } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { // We are not pending a registration from a worker for this task, // so start a new worker process for this task. + std::vector dynamic_options = {}; + if (task_spec.IsActorCreationTask()) { + dynamic_options = task_spec.DynamicWorkerOptions(); + } proc = StartWorkerProcess(task_spec.GetLanguage(), rpc::WorkerType::WORKER, - task_spec.JobId(), task_spec.DynamicWorkerOptions()); + task_spec.JobId(), dynamic_options, + task_spec.OverrideEnvironmentVariables()); if (proc.IsValid()) { state.dedicated_workers_to_tasks[proc] = task_spec.TaskId(); state.tasks_to_dedicated_workers[task_spec.TaskId()] = proc; diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index b325cf33a..4a876777a 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -283,9 +283,10 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface { /// \param dynamic_options The dynamic options that we should add for worker command. /// \return The id of the process that we started if it's positive, /// otherwise it means we didn't start a process. - Process StartWorkerProcess(const Language &language, const rpc::WorkerType worker_type, - const JobID &job_id, - std::vector dynamic_options = {}); + Process StartWorkerProcess( + const Language &language, const rpc::WorkerType worker_type, const JobID &job_id, + std::vector dynamic_options = {}, + std::unordered_map override_environment_variables = {}); /// The implementation of how to start a new worker process with command arguments. /// The lifetime of the process is tied to that of the returned object,