mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:12:00 +08:00
Add test for mutually recursive remote functions. (#5349)
This commit is contained in:
@@ -22,7 +22,6 @@ from ray import ray_constants
|
||||
from ray import cloudpickle as pickle
|
||||
from ray.utils import (
|
||||
binary_to_hex,
|
||||
is_cython,
|
||||
is_function_or_method,
|
||||
is_class_method,
|
||||
check_oversized_pickle,
|
||||
@@ -355,23 +354,8 @@ class FunctionActorManager(object):
|
||||
"""
|
||||
if self._worker.load_code_from_local:
|
||||
return
|
||||
# Work around limitations of Python pickling.
|
||||
function = remote_function._function
|
||||
function_name_global_valid = function.__name__ in function.__globals__
|
||||
function_name_global_value = function.__globals__.get(
|
||||
function.__name__)
|
||||
# Allow the function to reference itself as a global variable
|
||||
if not is_cython(function):
|
||||
function.__globals__[function.__name__] = remote_function
|
||||
try:
|
||||
pickled_function = pickle.dumps(function)
|
||||
finally:
|
||||
# Undo our changes
|
||||
if function_name_global_valid:
|
||||
function.__globals__[function.__name__] = (
|
||||
function_name_global_value)
|
||||
else:
|
||||
del function.__globals__[function.__name__]
|
||||
pickled_function = pickle.dumps(function)
|
||||
|
||||
check_oversized_pickle(pickled_function,
|
||||
remote_function._function_name,
|
||||
|
||||
@@ -319,6 +319,38 @@ def test_nested_functions(ray_start_regular):
|
||||
|
||||
assert ray.get(f.remote()) == (1, 2)
|
||||
|
||||
# Test a remote function that recursively calls itself.
|
||||
|
||||
@ray.remote
|
||||
def factorial(n):
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * ray.get(factorial.remote(n - 1))
|
||||
|
||||
assert ray.get(factorial.remote(0)) == 1
|
||||
assert ray.get(factorial.remote(1)) == 1
|
||||
assert ray.get(factorial.remote(2)) == 2
|
||||
assert ray.get(factorial.remote(3)) == 6
|
||||
assert ray.get(factorial.remote(4)) == 24
|
||||
assert ray.get(factorial.remote(5)) == 120
|
||||
|
||||
# Test remote functions that recursively call each other.
|
||||
|
||||
@ray.remote
|
||||
def factorial_even(n):
|
||||
assert n % 2 == 0
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * ray.get(factorial_odd.remote(n - 1))
|
||||
|
||||
@ray.remote
|
||||
def factorial_odd(n):
|
||||
assert n % 2 == 1
|
||||
return n * ray.get(factorial_even.remote(n - 1))
|
||||
|
||||
assert ray.get(factorial_even.remote(4)) == 24
|
||||
assert ray.get(factorial_odd.remote(5)) == 120
|
||||
|
||||
|
||||
def test_ray_recursive_objects(ray_start_regular):
|
||||
class ClassA(object):
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
# This test is not inside of test_basic.py because when a recursive remote
|
||||
# function is defined inside of another function, we currently can't handle
|
||||
# that.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
@ray.remote
|
||||
def factorial(n):
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * ray.get(factorial.remote(n - 1))
|
||||
|
||||
|
||||
def test_recursion(ray_start_regular):
|
||||
assert ray.get(factorial.remote(0)) == 1
|
||||
assert ray.get(factorial.remote(1)) == 1
|
||||
assert ray.get(factorial.remote(2)) == 2
|
||||
assert ray.get(factorial.remote(3)) == 6
|
||||
assert ray.get(factorial.remote(4)) == 24
|
||||
assert ray.get(factorial.remote(5)) == 120
|
||||
Reference in New Issue
Block a user