mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:06:25 +08:00
[tune] Try to enable PTL, SKlearn tests (#11542)
This commit is contained in:
@@ -295,18 +295,10 @@ install_dependencies() {
|
||||
pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
|
||||
fi
|
||||
|
||||
# Additional Tune dependency for Horovod.
|
||||
if [ "${INSTALL_HOROVOD-}" = 1 ]; then
|
||||
# TODO: eventually pin this to master.
|
||||
HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git
|
||||
fi
|
||||
|
||||
# Additional RaySGD test dependencies.
|
||||
if [ "${SGD_TESTING-}" = 1 ]; then
|
||||
pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
|
||||
# TODO: eventually have a separate requirements file for Ray SGD.
|
||||
# Fix PTL version to 0.10 for now.
|
||||
pip install -U pytorch-lightning==0.10.0
|
||||
fi
|
||||
|
||||
# Additional Doc test dependencies.
|
||||
@@ -328,6 +320,13 @@ install_dependencies() {
|
||||
tensorflow=="${TF_VERSION-2.2.0}" gym
|
||||
fi
|
||||
|
||||
# Additional Tune dependency for Horovod.
|
||||
# This must be run last (i.e., torch cannot be re-installed after this)
|
||||
if [ "${INSTALL_HOROVOD-}" = 1 ]; then
|
||||
# TODO: eventually pin this to master.
|
||||
HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git
|
||||
fi
|
||||
|
||||
if [ -n "${PYTHON-}" ] || [ -n "${LINT-}" ] || [ "${MAC_WHEELS-}" = 1 ]; then
|
||||
install_node
|
||||
fi
|
||||
|
||||
@@ -54,14 +54,14 @@ py_test(
|
||||
# Please keep these sorted alphabetically.
|
||||
# --------------------------------------------------------------------
|
||||
|
||||
# py_test(
|
||||
# name = "tune_sklearn",
|
||||
# size = "medium",
|
||||
# main = "source/tune/_tutorials/tune-sklearn.py",
|
||||
# srcs = ["source/tune/_tutorials/tune-sklearn.py"],
|
||||
# tags = ["exclusive", "example"],
|
||||
# args = ["--smoke-test"]
|
||||
# )
|
||||
py_test(
|
||||
name = "tune_sklearn",
|
||||
size = "medium",
|
||||
main = "source/tune/_tutorials/tune-sklearn.py",
|
||||
srcs = ["source/tune/_tutorials/tune-sklearn.py"],
|
||||
tags = ["exclusive", "example"],
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "tune_serve_integration_mnist",
|
||||
|
||||
@@ -36,7 +36,9 @@ Let's compare Tune's Scikit-Learn APIs to the standard scikit-learn GridSearchCV
|
||||
To start out, change the import statement to get tune-scikit-learn’s grid search cross validation interface:
|
||||
|
||||
"""
|
||||
# from sklearn.model_selection import GridSearchCV
|
||||
# Keep this here for https://github.com/ray-project/ray/issues/11547
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
# Replace above line with:
|
||||
from ray.tune.sklearn import TuneGridSearchCV
|
||||
|
||||
#######################################################################
|
||||
|
||||
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_rank() -> str:
|
||||
"""Returns rank of worker."""
|
||||
return os.environ["HOROVOD_RANK"]
|
||||
|
||||
|
||||
|
||||
@@ -155,4 +155,4 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
sys.exit(pytest.main(sys.argv[1:] + ["-v", __file__]))
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# --------------------------------------------------------------------
|
||||
py_test(
|
||||
name = "test_ptl",
|
||||
size = "small",
|
||||
size = "large",
|
||||
srcs = ["tests/test_ptl.py"],
|
||||
tags = ["exclusive", "pytorch-lightning", "pytorch"],
|
||||
deps = [":sgd_lib"],
|
||||
|
||||
@@ -209,3 +209,9 @@ def test_correctness(ray_start_2_cpus, num_workers, use_local):
|
||||
assert train1_stats["train_loss"] == train2_stats["train_loss"]
|
||||
assert val1_stats["val_loss"] == val2_stats["val_loss"]
|
||||
assert val1_stats["val_acc"] == val2_stats["val_accuracy"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(sys.argv[1:] + ["-v", __file__]))
|
||||
|
||||
@@ -3,7 +3,6 @@ import json
|
||||
import os
|
||||
|
||||
import ray
|
||||
import ray._private.services
|
||||
from ray.util.sgd import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -148,7 +147,7 @@ class TFRunner:
|
||||
|
||||
def get_node_ip(self):
|
||||
"""Returns the IP address of the current node."""
|
||||
return ray._private.services.get_node_ip_address()
|
||||
return ray.services.get_node_ip_address()
|
||||
|
||||
def find_free_port(self):
|
||||
"""Finds a free port on the current node."""
|
||||
|
||||
@@ -168,7 +168,7 @@ def clear_dummy_actor():
|
||||
|
||||
|
||||
def reserve_resources(num_cpus, num_gpus, retries=20):
|
||||
ip = ray._private.services.get_node_ip_address()
|
||||
ip = ray.services.get_node_ip_address()
|
||||
|
||||
reserved_cuda_device = None
|
||||
|
||||
|
||||
@@ -57,9 +57,12 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
"""Returns list of scheduler dictionaries.
|
||||
|
||||
List is empty if no schedulers are returned in the
|
||||
configure_optimizers method of your LightningModule. Default
|
||||
configuration is used if configure_optimizers returns scheduler
|
||||
objects instead of scheduler dicts. See
|
||||
configure_optimizers method of your LightningModule.
|
||||
|
||||
Default configuration is used if configure_optimizers
|
||||
returns scheduler objects.
|
||||
|
||||
See
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers
|
||||
"""
|
||||
return self._scheduler_dicts
|
||||
@@ -266,7 +269,8 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
return_output = meter_collection.summary()
|
||||
|
||||
if self.is_function_implemented("on_train_epoch_end", model):
|
||||
model.on_train_epoch_end()
|
||||
model.on_train_epoch_end(
|
||||
[eo.get("raw_output") for eo in epoch_outputs])
|
||||
|
||||
for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
|
||||
if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
|
||||
@@ -345,10 +349,9 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
with self.timers.record("grad"):
|
||||
if self.use_fp16:
|
||||
with self._amp.scale_loss(loss, optimizer) as scaled_loss:
|
||||
model.backward(
|
||||
self, scaled_loss, optimizer, optimizer_idx=0)
|
||||
model.backward(scaled_loss, optimizer, optimizer_idx=0)
|
||||
else:
|
||||
model.backward(self, loss, optimizer, optimizer_idx=0)
|
||||
model.backward(loss, optimizer, optimizer_idx=0)
|
||||
|
||||
if self.is_function_implemented("on_after_backward", model):
|
||||
model.on_after_backward()
|
||||
@@ -370,7 +373,10 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
|
||||
if self.is_function_implemented("on_train_batch_end", model):
|
||||
model.on_train_batch_end(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
outputs=output,
|
||||
batch=batch,
|
||||
batch_idx=batch_idx,
|
||||
dataloader_idx=0)
|
||||
|
||||
return {
|
||||
"signal": 0,
|
||||
@@ -468,7 +474,10 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
|
||||
|
||||
if self.is_function_implemented("on_validation_batch_end", model):
|
||||
model.on_validation_batch_end(
|
||||
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
|
||||
outputs=output,
|
||||
batch=batch,
|
||||
batch_idx=batch_idx,
|
||||
dataloader_idx=0)
|
||||
return {
|
||||
"raw_output": output,
|
||||
# NUM_SAMPLES: len(batch)
|
||||
|
||||
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_address():
|
||||
ip = ray._private.services.get_node_ip_address()
|
||||
ip = ray.services.get_node_ip_address()
|
||||
port = find_free_port()
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ sigopt
|
||||
smart_open
|
||||
tensorflow-probability
|
||||
timm
|
||||
torch>=1.5.0
|
||||
torch>=1.6.0
|
||||
torchvision>=0.6.0
|
||||
# transformers
|
||||
git+git://github.com/huggingface/transformers.git@bdcc4b78a27775d1ec8f3fd297cb679c257289db#transformers
|
||||
|
||||
Reference in New Issue
Block a user