mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:36:53 +08:00
Fix passing duplicate by-reference arguments (#7306)
This commit is contained in:
@@ -1731,6 +1731,35 @@ def test_wait(ray_start_regular):
|
||||
ray.wait([1])
|
||||
|
||||
|
||||
def test_duplicate_args(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(arg1,
|
||||
arg2,
|
||||
arg1_duplicate,
|
||||
kwarg1=None,
|
||||
kwarg2=None,
|
||||
kwarg1_duplicate=None):
|
||||
assert arg1 == kwarg1
|
||||
assert arg1 != arg2
|
||||
assert arg1 == arg1_duplicate
|
||||
assert kwarg1 != kwarg2
|
||||
assert kwarg1 == kwarg1_duplicate
|
||||
|
||||
# Test by-value arguments.
|
||||
arg1 = [1]
|
||||
arg2 = [2]
|
||||
ray.get(
|
||||
f.remote(
|
||||
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
|
||||
|
||||
# Test by-reference arguments.
|
||||
arg1 = ray.put([1])
|
||||
arg2 = ray.put([2])
|
||||
ray.get(
|
||||
f.remote(
|
||||
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
@@ -1102,7 +1102,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
||||
arg_reference_ids->resize(num_args);
|
||||
|
||||
absl::flat_hash_set<ObjectID> by_ref_ids;
|
||||
absl::flat_hash_map<ObjectID, int> by_ref_indices;
|
||||
absl::flat_hash_map<ObjectID, std::vector<size_t>> by_ref_indices;
|
||||
|
||||
for (size_t i = 0; i < task.NumArgs(); ++i) {
|
||||
if (task.ArgByRef(i)) {
|
||||
@@ -1117,7 +1117,12 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
||||
}
|
||||
const auto &arg_id = task.ArgId(i, 0);
|
||||
by_ref_ids.insert(arg_id);
|
||||
by_ref_indices.emplace(arg_id, i);
|
||||
auto it = by_ref_indices.find(arg_id);
|
||||
if (it == by_ref_indices.end()) {
|
||||
by_ref_indices.emplace(arg_id, std::vector<size_t>({i}));
|
||||
} else {
|
||||
it->second.push_back(i);
|
||||
}
|
||||
arg_reference_ids->at(i) = arg_id;
|
||||
// The task borrows all args passed by reference. Because the task does
|
||||
// not have a reference to the argument ID in the frontend, it is not
|
||||
@@ -1155,7 +1160,9 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Get(by_ref_ids, -1, worker_context_,
|
||||
&result_map, &got_exception));
|
||||
for (const auto &it : result_map) {
|
||||
args->at(by_ref_indices[it.first]) = it.second;
|
||||
for (size_t idx : by_ref_indices[it.first]) {
|
||||
args->at(idx) = it.second;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
||||
Reference in New Issue
Block a user