Fix passing duplicate by-reference arguments (#7306)

This commit is contained in:
Edward Oakes
2020-02-24 19:18:16 -08:00
committed by GitHub
parent 8b6784de06
commit f2faf8d26e
2 changed files with 39 additions and 3 deletions
+29
View File
@@ -1731,6 +1731,35 @@ def test_wait(ray_start_regular):
ray.wait([1])
def test_duplicate_args(ray_start_regular):
@ray.remote
def f(arg1,
arg2,
arg1_duplicate,
kwarg1=None,
kwarg2=None,
kwarg1_duplicate=None):
assert arg1 == kwarg1
assert arg1 != arg2
assert arg1 == arg1_duplicate
assert kwarg1 != kwarg2
assert kwarg1 == kwarg1_duplicate
# Test by-value arguments.
arg1 = [1]
arg2 = [2]
ray.get(
f.remote(
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
# Test by-reference arguments.
arg1 = ray.put([1])
arg2 = ray.put([2])
ray.get(
f.remote(
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
+10 -3
View File
@@ -1102,7 +1102,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
arg_reference_ids->resize(num_args);
absl::flat_hash_set<ObjectID> by_ref_ids;
absl::flat_hash_map<ObjectID, int> by_ref_indices;
absl::flat_hash_map<ObjectID, std::vector<size_t>> by_ref_indices;
for (size_t i = 0; i < task.NumArgs(); ++i) {
if (task.ArgByRef(i)) {
@@ -1117,7 +1117,12 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
}
const auto &arg_id = task.ArgId(i, 0);
by_ref_ids.insert(arg_id);
by_ref_indices.emplace(arg_id, i);
auto it = by_ref_indices.find(arg_id);
if (it == by_ref_indices.end()) {
by_ref_indices.emplace(arg_id, std::vector<size_t>({i}));
} else {
it->second.push_back(i);
}
arg_reference_ids->at(i) = arg_id;
// The task borrows all args passed by reference. Because the task does
// not have a reference to the argument ID in the frontend, it is not
@@ -1155,7 +1160,9 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
RAY_RETURN_NOT_OK(plasma_store_provider_->Get(by_ref_ids, -1, worker_context_,
&result_map, &got_exception));
for (const auto &it : result_map) {
args->at(by_ref_indices[it.first]) = it.second;
for (size_t idx : by_ref_indices[it.first]) {
args->at(idx) = it.second;
}
}
return Status::OK();