Add is_actor_checkpoint_method to TaskSpec. (#1117)

* Add is_actor_checkpoint_method to TaskSpec.

* Fix linting.

* Fix rebase error.

* Fix errors from rebase.
This commit is contained in:
Robert Nishihara
2017-10-15 16:52:10 -07:00
committed by Philipp Moritz
parent 802941994d
commit f3e3c7ec71
10 changed files with 72 additions and 45 deletions
+8 -14
View File
@@ -113,12 +113,6 @@ def put_dummy_object(worker, dummy_object_id):
worker.actor_pinned_objects[dummy_object_id] = dummy_object
def is_checkpoint_task(task_counter, checkpoint_interval):
    """Decide whether the task with the given counter is a checkpoint task.

    Args:
        task_counter: The counter value of the task being considered.
        checkpoint_interval: The interval against which the counter is
            tested. A value less than or equal to zero means that no task
            is ever treated as a checkpoint task.

    Returns:
        True if checkpoint_interval is positive and task_counter is an
            exact multiple of it (a counter of 0 counts as a multiple),
            and False otherwise.
    """
    # The non-positive check comes first so that the modulo is never
    # evaluated with a zero (or negative) interval.
    return (False if checkpoint_interval <= 0
            else task_counter % checkpoint_interval == 0)
def make_actor_method_executor(worker, method_name, method):
"""Make an executor that wraps a user-defined actor method.
@@ -333,6 +327,8 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
if checkpoint_interval == 0:
raise Exception("checkpoint_interval must be greater than 0.")
# Add one to the checkpoint interval since we will insert a mock task for
# every checkpoint.
checkpoint_interval += 1
@@ -621,23 +617,21 @@ def make_actor(cls, num_cpus, num_gpus, checkpoint_interval):
# task.
args.append(dependency)
actor_counter = self._ray_actor_counter
# Mark checkpoint methods with a negative task counter.
if is_checkpoint_task(actor_counter, checkpoint_interval):
actor_counter = self._ray_actor_counter * -1
is_actor_checkpoint_method = (method_name == "__ray_checkpoint__")
function_id = get_actor_method_function_id(method_name)
object_ids = ray.worker.global_worker.submit_task(
function_id, args, actor_id=self._ray_actor_id,
actor_counter=actor_counter)
actor_counter=self._ray_actor_counter,
is_actor_checkpoint_method=is_actor_checkpoint_method)
# Update the actor counter and cursor to reflect the most recent
# invocation.
self._ray_actor_counter += 1
self._ray_actor_cursor = object_ids.pop()
# Submit a checkpoint task if necessary.
if is_checkpoint_task(self._ray_actor_counter,
checkpoint_interval):
# Submit a checkpoint task if it is time to do so.
if (checkpoint_interval > 1 and
self._ray_actor_counter % checkpoint_interval == 0):
self.__ray_checkpoint__.remote()
# The last object returned is the dummy object that should be
+1 -1
View File
@@ -170,7 +170,7 @@ class TestGlobalScheduler(unittest.TestCase):
task2 = local_scheduler.Task(random_driver_id(), random_function_id(),
[random_object_id()], 0, random_task_id(),
0, local_scheduler.ObjectID(NIL_ACTOR_ID),
0, [1.0, 2.0, 0.0])
0, 0, [1.0, 2.0, 0.0])
self.assertEqual(task2.required_resources(), [1.0, 2.0, 0.0])
def test_redis_only_single_task(self):
+12 -4
View File
@@ -454,7 +454,8 @@ class Worker(object):
assert len(final_results) == len(object_ids)
return final_results
def submit_task(self, function_id, args, actor_id=None, actor_counter=0):
def submit_task(self, function_id, args, actor_id=None, actor_counter=0,
is_actor_checkpoint_method=False):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with ID
@@ -462,9 +463,14 @@ class Worker(object):
the function from the scheduler and immediately return them.
Args:
args (List[Any]): The arguments to pass into the function.
Arguments can be object IDs or they can be values. If they are
values, they must be serializable objects.
function_id: The ID of the function to execute.
args: The arguments to pass into the function. Arguments can be
object IDs or they can be values. If they are values, they must
be serializable objects.
actor_id: The ID of the actor that this task is for.
actor_counter: The counter of the actor task.
is_actor_checkpoint_method: True if this is an actor checkpoint
task and false otherwise.
"""
with log_span("ray:submit_task", worker=self):
check_main_thread()
@@ -495,6 +501,7 @@ class Worker(object):
self.task_index,
actor_id,
actor_counter,
is_actor_checkpoint_method,
[function_properties.num_cpus, function_properties.num_gpus,
function_properties.num_custom_resource])
# Increment the worker's task index to track how many tasks have
@@ -1834,6 +1841,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
worker.task_index,
ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
nil_actor_counter,
False,
[0, 0, 0])
global_state._execute_command(
driver_task.task_id(),