mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 19:14:35 +08:00
Cross-language invocation Part 1: Java calling Python functions and actors (#4166)
This commit is contained in:
@@ -0,0 +1,112 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
import java.lang.ProcessBuilder.Redirect;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Ray;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.SkipException;
|
||||
import org.testng.annotations.AfterClass;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
|
||||
public abstract class BaseMultiLanguageTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);
|
||||
|
||||
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
|
||||
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
|
||||
|
||||
/**
|
||||
* Execute an external command.
|
||||
*
|
||||
* @return Whether the command succeeded.
|
||||
*/
|
||||
private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
|
||||
Map<String, String> env) {
|
||||
try {
|
||||
LOGGER.info("Executing command: {}", String.join(" ", command));
|
||||
ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT)
|
||||
.redirectError(Redirect.INHERIT);
|
||||
for (Entry<String, String> entry : env.entrySet()) {
|
||||
processBuilder.environment().put(entry.getKey(), entry.getValue());
|
||||
}
|
||||
Process process = processBuilder.start();
|
||||
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
|
||||
return process.exitValue() == 0;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
|
||||
LOGGER.info("Skip Multi-language tests because environment variable "
|
||||
+ "ENABLE_MULTI_LANGUAGE_TESTS isn't set");
|
||||
throw new SkipException("Skip test.");
|
||||
}
|
||||
|
||||
// Delete existing socket files.
|
||||
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
|
||||
File file = new File(socket);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
// Start ray cluster.
|
||||
String workerOptions =
|
||||
" -classpath " + System.getProperty("java.class.path");
|
||||
final List<String> startCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"start",
|
||||
"--head",
|
||||
"--redis-port=6379",
|
||||
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
|
||||
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
|
||||
"--load-code-from-local",
|
||||
"--include-java",
|
||||
"--java-worker-options=" + workerOptions
|
||||
);
|
||||
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
|
||||
throw new RuntimeException("Couldn't start ray cluster.");
|
||||
}
|
||||
|
||||
// Connect to the cluster.
|
||||
System.setProperty("ray.redis.address", "127.0.0.1:6379");
|
||||
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
|
||||
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
|
||||
Ray.init();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The environment variables needed for the `ray start` command.
|
||||
*/
|
||||
protected Map<String, String> getRayStartEnv() {
|
||||
return ImmutableMap.of();
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
System.clearProperty("ray.redis.address");
|
||||
System.clearProperty("ray.object-store.socket-name");
|
||||
System.clearProperty("ray.raylet.socket-name");
|
||||
|
||||
// Stop ray cluster.
|
||||
final List<String> stopCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"stop"
|
||||
);
|
||||
if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
|
||||
throw new RuntimeException("Couldn't stop ray cluster");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
|
||||
|
||||
private static final String PYTHON_MODULE = "test_cross_language_invocation";
|
||||
|
||||
@Override
|
||||
protected Map<String, String> getRayStartEnv() {
|
||||
// Delete and re-create the temp dir.
|
||||
File tempDir = new File(
|
||||
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
|
||||
FileUtils.deleteQuietly(tempDir);
|
||||
tempDir.mkdirs();
|
||||
tempDir.deleteOnExit();
|
||||
|
||||
// Write the test Python file to the temp dir.
|
||||
InputStream in = CrossLanguageInvocationTest.class
|
||||
.getResourceAsStream("/" + PYTHON_MODULE + ".py");
|
||||
File pythonFile = new File(
|
||||
tempDir.getAbsolutePath() + File.separator + PYTHON_MODULE + ".py");
|
||||
try {
|
||||
FileUtils.copyInputStreamToFile(in, pythonFile);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCallingPythonFunction() {
|
||||
RayObject res = Ray.callPy(PYTHON_MODULE, "py_func", "hello".getBytes());
|
||||
Assert.assertEquals(res.get(), "Response from Python: hello".getBytes());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCallingPythonActor() {
|
||||
RayPyActor actor = Ray.createPyActor(PYTHON_MODULE, "Counter", "1".getBytes());
|
||||
RayObject res = Ray.callPy(actor, "increase", "1".getBytes());
|
||||
Assert.assertEquals(res.get(), "2".getBytes());
|
||||
}
|
||||
}
|
||||
@@ -1,113 +1,18 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.File;
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
import org.testng.SkipException;
|
||||
import org.testng.annotations.AfterMethod;
|
||||
import org.testng.annotations.BeforeMethod;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
/**
|
||||
* Test starting a ray cluster with multi-language support.
|
||||
*/
|
||||
public class MultiLanguageClusterTest {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class);
|
||||
|
||||
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
|
||||
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
|
||||
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
|
||||
|
||||
@RayRemote
|
||||
public static String echo(String word) {
|
||||
return word;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute an external command.
|
||||
*
|
||||
* @return Whether the command succeeded.
|
||||
*/
|
||||
private boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
|
||||
try {
|
||||
LOGGER.info("Executing command: {}", String.join(" ", command));
|
||||
Process process = new ProcessBuilder(command).inheritIO().start();
|
||||
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
|
||||
return process.exitValue() == 0;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
|
||||
}
|
||||
}
|
||||
|
||||
@BeforeMethod
|
||||
public void setUp(Method method) {
|
||||
String testName = method.getName();
|
||||
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
|
||||
LOGGER.info("Skip " + testName +
|
||||
" because env variable ENABLE_MULTI_LANGUAGE_TESTS isn't set");
|
||||
throw new SkipException("Skip test.");
|
||||
}
|
||||
|
||||
// Delete existing socket files.
|
||||
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
|
||||
File file = new File(socket);
|
||||
if (file.exists()) {
|
||||
file.delete();
|
||||
}
|
||||
}
|
||||
|
||||
// Start ray cluster.
|
||||
String testDir = System.getProperty("user.dir");
|
||||
String workerOptions =
|
||||
" -classpath " + String.format("%s/../../build/java/*:%s/target/*", testDir, testDir);
|
||||
final List<String> startCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"start",
|
||||
"--head",
|
||||
"--redis-port=6379",
|
||||
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
|
||||
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
|
||||
"--load-code-from-local",
|
||||
"--include-java",
|
||||
"--java-worker-options=" + workerOptions
|
||||
);
|
||||
if (!executeCommand(startCommand, 10)) {
|
||||
throw new RuntimeException("Couldn't start ray cluster.");
|
||||
}
|
||||
|
||||
// Connect to the cluster.
|
||||
System.setProperty("ray.redis.address", "127.0.0.1:6379");
|
||||
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
|
||||
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
|
||||
Ray.init();
|
||||
}
|
||||
|
||||
@AfterMethod
|
||||
public void tearDown() {
|
||||
// Disconnect to the cluster.
|
||||
Ray.shutdown();
|
||||
System.clearProperty("ray.redis.address");
|
||||
System.clearProperty("ray.object-store.socket-name");
|
||||
System.clearProperty("ray.raylet.socket-name");
|
||||
|
||||
// Stop ray cluster.
|
||||
final List<String> stopCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"stop"
|
||||
);
|
||||
if (!executeCommand(stopCommand, 10)) {
|
||||
throw new RuntimeException("Couldn't stop ray cluster");
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiLanguageCluster() {
|
||||
RayObject<String> obj = Ray.call(MultiLanguageClusterTest::echo, "hello");
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# This file is used by CrossLanguageInvocationTest.java to test cross-language
|
||||
# invocation.
|
||||
import ray
|
||||
import six
|
||||
|
||||
|
||||
@ray.remote
|
||||
def py_func(value):
|
||||
assert isinstance(value, bytes)
|
||||
return b"Response from Python: " + value
|
||||
|
||||
|
||||
@ray.remote
|
||||
class Counter(object):
|
||||
def __init__(self, value):
|
||||
self.value = int(value)
|
||||
|
||||
def increase(self, delta):
|
||||
self.value += int(delta)
|
||||
return str(self.value).encode("utf-8") if six.PY3 else str(self.value)
|
||||
Reference in New Issue
Block a user