mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:56:20 +08:00
Support of scikit-learn with ray joblib backend (#6925)
This commit is contained in:
committed by
Edward Oakes
parent
396d7fafc8
commit
a7ecda6017
@@ -0,0 +1,17 @@
|
||||
from joblib.parallel import register_parallel_backend
|
||||
|
||||
|
||||
def register_ray():
    """Register the Ray backend so it can be used via parallel_backend("ray").

    The Ray backend class is imported lazily so that this module can be
    imported even when the backend's dependencies are missing.

    Raises:
        ImportError: If the Ray joblib backend cannot be imported. The
            message contains installation instructions.
    """
    try:
        from ray.experimental.joblib.ray_backend import RayBackend
        register_parallel_backend("ray", RayBackend)
    except ImportError:
        # Adjacent string literals are concatenated verbatim, so every
        # fragment except the last needs its own trailing space; without
        # it the sentences run together and the URL is fused with the
        # following word.
        msg = ("To use the ray backend you must install ray. "
               "Try running 'pip install ray'. "
               "See https://ray.readthedocs.io/en/latest/installation.html "
               "for more information.")
        raise ImportError(msg)
|
||||
|
||||
|
||||
# Public API of this module: only the backend registration helper is exported.
__all__ = ["register_ray"]
|
||||
@@ -0,0 +1,58 @@
|
||||
from joblib._parallel_backends import MultiprocessingBackend
|
||||
from joblib.pool import PicklingPool
|
||||
import logging
|
||||
|
||||
from ray.experimental.multiprocessing.pool import Pool
|
||||
import ray
|
||||
|
||||
# Name of the environment variable that RayBackend.configure reads to find
# the address of an existing Ray cluster when Ray is not yet initialized.
RAY_ADDRESS_ENV = "RAY_ADDRESS"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayBackend(MultiprocessingBackend):
    """Joblib backend that runs work on Ray.

    Ray is a system for scalable distributed computing; see
    https://ray.readthedocs.io for more information.
    """

    def configure(self,
                  n_jobs=1,
                  parallel=None,
                  prefer=None,
                  require=None,
                  **memmappingpool_args):
        """Configure the backend and return the effective number of jobs.

        With ``n_jobs == -1`` every CPU reported by the Ray cluster is
        used. If Ray is not initialized at that point, it is started
        here: it connects to the address in the RAY_ADDRESS environment
        variable when that variable is set, and starts a local cluster
        otherwise. To run on multiple nodes, set RAY_ADDRESS or call
        ray.init(address=..) beforehand.

        All arguments are forwarded to the parent class's ``configure``
        (with ``n_jobs`` replaced by the cluster CPU count when it was
        -1), and its result is returned.
        """
        # PicklingPool normally derives from multiprocessing's Pool.
        # Re-parent it onto Ray's Pool so the pool created by the parent
        # backend is backed by ray.experimental.multiprocessing.pool.
        PicklingPool.__bases__ = (Pool, )

        use_all_resources = n_jobs == -1
        if use_all_resources:
            if not ray.is_initialized():
                import os
                if RAY_ADDRESS_ENV not in os.environ:
                    logger.info("Starting local ray cluster")
                    ray.init()
                else:
                    cluster_address = os.environ[RAY_ADDRESS_ENV]
                    logger.info(
                        "Connecting to ray cluster at address='{}'".format(
                            cluster_address))
                    ray.init(address=cluster_address)
            n_jobs = int(ray.state.cluster_resources()["CPU"])

        return super(RayBackend, self).configure(
            n_jobs, parallel, prefer, require, **memmappingpool_args)

    def effective_n_jobs(self, n_jobs):
        """Return the number of jobs that will actually be used.

        The parent implementation is always consulted first. For
        ``n_jobs == -1`` its answer is replaced by the number of CPUs
        reported by the Ray cluster.
        """
        parent_n_jobs = super(RayBackend, self).effective_n_jobs(n_jobs)
        if n_jobs != -1:
            return parent_n_jobs
        return int(ray.state.cluster_resources()["CPU"])
|
||||
Reference in New Issue
Block a user