diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 7140f393c..d1f84358e 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1731,6 +1731,35 @@ def test_wait(ray_start_regular): ray.wait([1]) +def test_duplicate_args(ray_start_regular): + @ray.remote + def f(arg1, + arg2, + arg1_duplicate, + kwarg1=None, + kwarg2=None, + kwarg1_duplicate=None): + assert arg1 == kwarg1 + assert arg1 != arg2 + assert arg1 == arg1_duplicate + assert kwarg1 != kwarg2 + assert kwarg1 == kwarg1_duplicate + + # Test by-value arguments. + arg1 = [1] + arg2 = [2] + ray.get( + f.remote( + arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) + + # Test by-reference arguments. + arg1 = ray.put([1]) + arg2 = ray.put([2]) + ray.get( + f.remote( + arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 743d08a09..ec41da335 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1102,7 +1102,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, arg_reference_ids->resize(num_args); absl::flat_hash_set by_ref_ids; - absl::flat_hash_map by_ref_indices; + absl::flat_hash_map> by_ref_indices; for (size_t i = 0; i < task.NumArgs(); ++i) { if (task.ArgByRef(i)) { @@ -1117,7 +1117,12 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, } const auto &arg_id = task.ArgId(i, 0); by_ref_ids.insert(arg_id); - by_ref_indices.emplace(arg_id, i); + auto it = by_ref_indices.find(arg_id); + if (it == by_ref_indices.end()) { + by_ref_indices.emplace(arg_id, std::vector({i})); + } else { + it->second.push_back(i); + } arg_reference_ids->at(i) = arg_id; // The task borrows all args passed by reference. Because the task does // not have a reference to the argument ID in the frontend, it is not @@ -1155,7 +1160,9 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, RAY_RETURN_NOT_OK(plasma_store_provider_->Get(by_ref_ids, -1, worker_context_, &result_map, &got_exception)); for (const auto &it : result_map) { - args->at(by_ref_indices[it.first]) = it.second; + for (size_t idx : by_ref_indices[it.first]) { + args->at(idx) = it.second; + } } return Status::OK();