mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 03:42:52 +08:00
Add is_actor_checkpoint_method to TaskSpec. (#1117)
* Add is_actor_checkpoint_method to TaskSpec. * Fix linting. * Fix rebase error. * Fix errors from rebase.
This commit is contained in:
committed by
Philipp Moritz
parent
802941994d
commit
f3e3c7ec71
+8
-14
@@ -113,12 +113,6 @@ def put_dummy_object(worker, dummy_object_id):
|
||||
worker.actor_pinned_objects[dummy_object_id] = dummy_object
|
||||
|
||||
|
||||
def is_checkpoint_task(task_counter, checkpoint_interval):
|
||||
if checkpoint_interval <= 0:
|
||||
return False
|
||||
return (task_counter % checkpoint_interval == 0)
|
||||
|
||||
|
||||
def make_actor_method_executor(worker, method_name, method):
|
||||
"""Make an executor that wraps a user-defined actor method.
|
||||
|
||||
@@ -333,6 +327,8 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
|
||||
|
||||
|
||||
def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
|
||||
if checkpoint_interval == 0:
|
||||
raise Exception("checkpoint_interval must be greater than 0.")
|
||||
# Add one to the checkpoint interval since we will insert a mock task for
|
||||
# every checkpoint.
|
||||
checkpoint_interval += 1
|
||||
@@ -621,23 +617,21 @@ def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
|
||||
# task.
|
||||
args.append(dependency)
|
||||
|
||||
actor_counter = self._ray_actor_counter
|
||||
# Mark checkpoint methods with a negative task counter.
|
||||
if is_checkpoint_task(actor_counter, checkpoint_interval):
|
||||
actor_counter = self._ray_actor_counter * -1
|
||||
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
|
||||
|
||||
function_id = get_actor_method_function_id(method_name)
|
||||
object_ids = ray.worker.global_worker.submit_task(
|
||||
function_id, args, actor_id=self._ray_actor_id,
|
||||
actor_counter=actor_counter)
|
||||
actor_counter=self._ray_actor_counter,
|
||||
is_actor_checkpoint_method=is_actor_checkpoint_method)
|
||||
# Update the actor counter and cursor to reflect the most recent
|
||||
# invocation.
|
||||
self._ray_actor_counter += 1
|
||||
self._ray_actor_cursor = object_ids.pop()
|
||||
|
||||
# Submit a checkpoint task if necessary.
|
||||
if is_checkpoint_task(self._ray_actor_counter,
|
||||
checkpoint_interval):
|
||||
# Submit a checkpoint task if it is time to do so.
|
||||
if (checkpoint_interval > 1 and
|
||||
self._ray_actor_counter % checkpoint_interval == 0):
|
||||
self.__ray_checkpoint__.remote()
|
||||
|
||||
# The last object returned is the dummy object that should be
|
||||
|
||||
@@ -170,7 +170,7 @@ class TestGlobalScheduler(unittest.TestCase):
|
||||
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
|
||||
[random_object_id()], 0, random_task_id(),
|
||||
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
0, [1.0, 2.0, 0.0])
|
||||
0, 0, [1.0, 2.0, 0.0])
|
||||
self.assertEqual(task2.required_resources(), [1.0, 2.0, 0.0])
|
||||
|
||||
def test_redis_only_single_task(self):
|
||||
|
||||
+12
-4
@@ -454,7 +454,8 @@ class Worker(object):
|
||||
assert len(final_results) == len(object_ids)
|
||||
return final_results
|
||||
|
||||
def submit_task(self, function_id, args, actor_id=None, actor_counter=0):
|
||||
def submit_task(self, function_id, args, actor_id=None, actor_counter=0,
|
||||
is_actor_checkpoint_method=False):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with ID
|
||||
@@ -462,9 +463,14 @@ class Worker(object):
|
||||
the function from the scheduler and immediately return them.
|
||||
|
||||
Args:
|
||||
args (List[Any]): The arguments to pass into the function.
|
||||
Arguments can be object IDs or they can be values. If they are
|
||||
values, they must be serializable objects.
|
||||
function_id: The ID of the function to execute.
|
||||
args: The arguments to pass into the function. Arguments can be
|
||||
object IDs or they can be values. If they are values, they must
|
||||
be serializable objects.
|
||||
actor_id: The ID of the actor that this task is for.
|
||||
actor_counter: The counter of the actor task.
|
||||
is_actor_checkpoint_method: True if this is an actor checkpoint
|
||||
task and false otherwise.
|
||||
"""
|
||||
with log_span("ray:submit_task", worker=self):
|
||||
check_main_thread()
|
||||
@@ -495,6 +501,7 @@ class Worker(object):
|
||||
self.task_index,
|
||||
actor_id,
|
||||
actor_counter,
|
||||
is_actor_checkpoint_method,
|
||||
[function_properties.num_cpus, function_properties.num_gpus,
|
||||
function_properties.num_custom_resource])
|
||||
# Increment the worker's task index to track how many tasks have
|
||||
@@ -1834,6 +1841,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
|
||||
worker.task_index,
|
||||
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
|
||||
nil_actor_counter,
|
||||
False,
|
||||
[0, 0, 0])
|
||||
global_state._execute_command(
|
||||
driver_task.task_id(),
|
||||
|
||||
Reference in New Issue
Block a user