mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[joblib] Fix flaky joblib test. (#13046)
This commit is contained in:
Binary file not shown.
@@ -1,12 +1,13 @@
|
||||
import joblib
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from sklearn.datasets import load_digits, load_iris
|
||||
from sklearn.model_selection import RandomizedSearchCV
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.ensemble import ExtraTreesClassifier
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.kernel_approximation import Nystroem
|
||||
@@ -14,7 +15,6 @@ from sklearn.kernel_approximation import RBFSampler
|
||||
from sklearn.pipeline import make_pipeline
|
||||
from sklearn.svm import LinearSVC, SVC
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.utils import check_array
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.model_selection import cross_val_score
|
||||
@@ -112,20 +112,14 @@ def test_sklearn_benchmarks(ray_start_cluster_2_nodes):
|
||||
}
|
||||
# Load dataset.
|
||||
print("Loading dataset...")
|
||||
data = fetch_openml("mnist_784")
|
||||
X = check_array(data["data"], dtype=np.float32, order="C")
|
||||
y = data["target"]
|
||||
|
||||
unnormalized_X_train, y_train = pickle.load(
|
||||
open(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "mnist_784_100_samples.pkl"), "rb"))
|
||||
# Normalize features.
|
||||
X = X / 255
|
||||
X_train = unnormalized_X_train / 255
|
||||
|
||||
# Create train-test split.
|
||||
print("Creating train-test split...")
|
||||
n_train = 100
|
||||
X_train = X[:n_train]
|
||||
y_train = y[:n_train]
|
||||
register_ray()
|
||||
|
||||
train_time = {}
|
||||
random_seed = 0
|
||||
# Use two workers per classifier.
|
||||
|
||||
Reference in New Issue
Block a user