[Java] Fix multiple FunctionManagers creating multiple ClassLoader s (#6434)

This commit is contained in:
Kai Yang
2019-12-16 14:04:44 +08:00
committed by Hao Chen
parent e38b25edfb
commit b7d5c8f220
10 changed files with 245 additions and 48 deletions
@@ -1,10 +1,12 @@
package org.ray.runtime;
import java.util.List;
import java.util.concurrent.Callable;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -52,9 +54,9 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected TaskSubmitter taskSubmitter;
protected WorkerContext workerContext;
public AbstractRayRuntime(RayConfig rayConfig) {
public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.jobResourcePath);
this.functionManager = functionManager;
runtimeContext = new RuntimeContextImpl(this);
}
@@ -4,6 +4,7 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtime.RayRuntimeFactory;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -19,14 +20,15 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.create();
try {
FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath);
RayRuntime runtime;
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
runtime = new RayDevRuntime(rayConfig);
runtime = new RayDevRuntime(rayConfig, functionManager);
} else {
if (rayConfig.workerMode == WorkerType.DRIVER) {
runtime = new RayNativeRuntime(rayConfig);
runtime = new RayNativeRuntime(rayConfig, functionManager);
} else {
runtime = new RayMultiWorkerNativeRuntime(rayConfig);
runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager);
}
}
return runtime;
@@ -1,10 +1,12 @@
package org.ray.runtime;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.task.LocalModeTaskExecutor;
import org.ray.runtime.task.LocalModeTaskSubmitter;
@@ -17,8 +19,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
private AtomicInteger jobCounter = new AtomicInteger(0);
public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
@@ -1,8 +1,10 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.concurrent.Callable;
import com.google.common.base.Preconditions;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -16,6 +18,7 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -28,6 +31,8 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);
private final FunctionManager functionManager;
/**
* The number of workers per worker process.
*/
@@ -45,7 +50,8 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
*/
private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();
public RayMultiWorkerNativeRuntime(RayConfig rayConfig) {
public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
this.functionManager = functionManager;
Preconditions.checkState(
rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
@@ -59,7 +65,7 @@ public class RayMultiWorkerNativeRuntime implements RayRuntime {
for (int i = 0; i < numWorkers; i++) {
final int workerIndex = i;
threads[i] = new Thread(() -> {
RayNativeRuntime runtime = new RayNativeRuntime(rayConfig);
RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager);
runtimes[workerIndex] = runtime;
currentThreadRuntime.set(runtime);
runtime.run();
@@ -12,6 +12,7 @@ import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.GcsClientOptions;
import org.ray.runtime.gcs.RedisClient;
@@ -90,8 +91,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
}
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
// Reset library path at runtime.
resetLibraryPath(rayConfig);
@@ -136,6 +137,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
manager.cleanup();
manager = null;
}
LOGGER.info("RayNativeRuntime shutdown");
}
@Override
@@ -16,6 +16,8 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
@@ -42,24 +44,26 @@ public class FunctionManager {
* Cache from a RayFunc object to its corresponding JavaFunctionDescriptor. Because
* `LambdaUtils.getSerializedLambda` is expensive.
*/
// If the cache is not thread local, we'll need a lock to protect it,
// which means competition is highly possible.
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
/**
* Mapping from the job id to the functions that belong to this job.
*/
private Map<JobId, JobFunctionTable> jobFunctionTables = new HashMap<>();
private ConcurrentMap<JobId, JobFunctionTable> jobFunctionTables = new ConcurrentHashMap<>();
/**
* The resource path which we can load the job's jar resources.
*/
private String jobResourcePath;
private final String jobResourcePath;
/**
* Construct a FunctionManager with the specified job resource path.
*
* @param jobResourcePath The specified job resource that can store the job's
* resources.
* resources.
*/
public FunctionManager(String jobResourcePath) {
this.jobResourcePath = jobResourcePath;
@@ -75,6 +79,8 @@ public class FunctionManager {
public RayFunction getFunction(JobId jobId, RayFunc func) {
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
if (functionDescriptor == null) {
// It's OK to not lock here, because it's OK to have multiple JavaFunctionDescriptor instances
// for the same RayFunc instance.
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
final String className = serializedLambda.getImplClass().replace('/', '.');
final String methodName = serializedLambda.getImplMethodName();
@@ -92,35 +98,45 @@ public class FunctionManager {
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
*/
public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
public RayFunction getFunction(JobId jobId,
JavaFunctionDescriptor functionDescriptor) {
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
ClassLoader classLoader;
if (Strings.isNullOrEmpty(jobResourcePath)) {
classLoader = getClass().getClassLoader();
} else {
File resourceDir = new File(jobResourcePath + "/" + jobId.toString() + "/");
Collection<File> files = FileUtils.listFiles(resourceDir,
new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY);
files.add(resourceDir);
final List<URL> urlList = files.stream().map(file -> {
try {
return file.toURI().toURL();
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toList());
classLoader = new URLClassLoader(urlList.toArray(new URL[urlList.size()]));
LOGGER.debug("Resource loaded for job {} from path {}.", jobId,
resourceDir.getAbsolutePath());
synchronized (this) {
jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
jobFunctionTable = createJobFunctionTable(jobId);
jobFunctionTables.put(jobId, jobFunctionTable);
}
}
jobFunctionTable = new JobFunctionTable(classLoader);
jobFunctionTables.put(jobId, jobFunctionTable);
}
return jobFunctionTable.getFunction(functionDescriptor);
}
private JobFunctionTable createJobFunctionTable(JobId jobId) {
ClassLoader classLoader;
if (Strings.isNullOrEmpty(jobResourcePath)) {
classLoader = getClass().getClassLoader();
} else {
File resourceDir = new File(jobResourcePath + "/" + jobId.toString() + "/");
Collection<File> files = FileUtils.listFiles(resourceDir,
new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY);
files.add(resourceDir);
final List<URL> urlList = files.stream().map(file -> {
try {
return file.toURI().toURL();
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}).collect(Collectors.toList());
classLoader = new URLClassLoader(urlList.toArray(new URL[urlList.size()]));
LOGGER.debug("Resource loaded for job {} from path {}.", jobId,
resourceDir.getAbsolutePath());
}
return new JobFunctionTable(classLoader);
}
/**
* Manages all functions that belong to one job.
*/
@@ -129,22 +145,27 @@ public class FunctionManager {
/**
* The job's corresponding class loader.
*/
ClassLoader classLoader;
final ClassLoader classLoader;
/**
* Functions per class, per function name + type descriptor.
*/
Map<String, Map<Pair<String, String>, RayFunction>> functions;
ConcurrentMap<String, Map<Pair<String, String>, RayFunction>> functions;
JobFunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
this.functions = new HashMap<>();
this.functions = new ConcurrentHashMap<>();
}
RayFunction getFunction(JavaFunctionDescriptor descriptor) {
Map<Pair<String, String>, RayFunction> classFunctions = functions.get(descriptor.className);
if (classFunctions == null) {
classFunctions = loadFunctionsForClass(descriptor.className);
functions.put(descriptor.className, classFunctions);
synchronized (this) {
classFunctions = functions.get(descriptor.className);
if (classFunctions == null) {
classFunctions = loadFunctionsForClass(descriptor.className);
functions.put(descriptor.className, classFunctions);
}
}
}
return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor));
}
@@ -308,6 +308,10 @@ public class RunManager {
cmd.add("-Dray.logging.file.path=" + logFile);
}
if (!Strings.isNullOrEmpty(rayConfig.jobResourcePath)) {
cmd.add("-Dray.job.resource-path=" + rayConfig.jobResourcePath);
}
// socket names
cmd.add("-Dray.raylet.socket-name=" + rayConfig.rayletSocketName);
cmd.add("-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName);