[tune] Try to enable PTL, SKlearn tests (#11542)

This commit is contained in:
Richard Liaw
2020-10-24 01:08:46 -07:00
committed by GitHub
parent d3ee83205b
commit 1b357533b1
12 changed files with 49 additions and 33 deletions
+1
View File
@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
def get_rank() -> str:
"""Returns rank of worker."""
return os.environ["HOROVOD_RANK"]
@@ -155,4 +155,4 @@ class PyTorchLightningIntegrationTest(unittest.TestCase):
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
sys.exit(pytest.main(sys.argv[1:] + ["-v", __file__]))
+1 -1
View File
@@ -4,7 +4,7 @@
# --------------------------------------------------------------------
py_test(
name = "test_ptl",
size = "small",
size = "large",
srcs = ["tests/test_ptl.py"],
tags = ["exclusive", "pytorch-lightning", "pytorch"],
deps = [":sgd_lib"],
+6
View File
@@ -209,3 +209,9 @@ def test_correctness(ray_start_2_cpus, num_workers, use_local):
assert train1_stats["train_loss"] == train2_stats["train_loss"]
assert val1_stats["val_loss"] == val2_stats["val_loss"]
assert val1_stats["val_acc"] == val2_stats["val_accuracy"]
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(sys.argv[1:] + ["-v", __file__]))
+1 -2
View File
@@ -3,7 +3,6 @@ import json
import os
import ray
import ray._private.services
from ray.util.sgd import utils
logger = logging.getLogger(__name__)
@@ -148,7 +147,7 @@ class TFRunner:
def get_node_ip(self):
"""Returns the IP address of the current node."""
return ray._private.services.get_node_ip_address()
return ray.services.get_node_ip_address()
def find_free_port(self):
"""Finds a free port on the current node."""
@@ -168,7 +168,7 @@ def clear_dummy_actor():
def reserve_resources(num_cpus, num_gpus, retries=20):
ip = ray._private.services.get_node_ip_address()
ip = ray.services.get_node_ip_address()
reserved_cuda_device = None
+18 -9
View File
@@ -57,9 +57,12 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
"""Returns list of scheduler dictionaries.
List is empty if no schedulers are returned in the
configure_optimizers method of your LightningModule. Default
configuration is used if configure_optimizers returns scheduler
objects instead of scheduler dicts. See
configure_optimizers method of your LightningModule.
Default configuration is used if configure_optimizers
returns scheduler objects.
See
https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#configure-optimizers
"""
return self._scheduler_dicts
@@ -266,7 +269,8 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
return_output = meter_collection.summary()
if self.is_function_implemented("on_train_epoch_end", model):
model.on_train_epoch_end()
model.on_train_epoch_end(
[eo.get("raw_output") for eo in epoch_outputs])
for s_dict, scheduler in zip(self.scheduler_dicts, self.schedulers):
if s_dict["interval"] == SCHEDULER_STEP_EPOCH:
@@ -345,10 +349,9 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
with self.timers.record("grad"):
if self.use_fp16:
with self._amp.scale_loss(loss, optimizer) as scaled_loss:
model.backward(
self, scaled_loss, optimizer, optimizer_idx=0)
model.backward(scaled_loss, optimizer, optimizer_idx=0)
else:
model.backward(self, loss, optimizer, optimizer_idx=0)
model.backward(loss, optimizer, optimizer_idx=0)
if self.is_function_implemented("on_after_backward", model):
model.on_after_backward()
@@ -370,7 +373,10 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
if self.is_function_implemented("on_train_batch_end", model):
model.on_train_batch_end(
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
outputs=output,
batch=batch,
batch_idx=batch_idx,
dataloader_idx=0)
return {
"signal": 0,
@@ -468,7 +474,10 @@ class LightningOperator(TrainingOperator, TrainerModelHooksMixin,
if self.is_function_implemented("on_validation_batch_end", model):
model.on_validation_batch_end(
batch=batch, batch_idx=batch_idx, dataloader_idx=0)
outputs=output,
batch=batch,
batch_idx=batch_idx,
dataloader_idx=0)
return {
"raw_output": output,
# NUM_SAMPLES: len(batch)
+1 -1
View File
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
def setup_address():
ip = ray._private.services.get_node_ip_address()
ip = ray.services.get_node_ip_address()
port = find_free_port()
return f"tcp://{ip}:{port}"
+1 -1
View File
@@ -24,7 +24,7 @@ sigopt
smart_open
tensorflow-probability
timm
torch>=1.5.0
torch>=1.6.0
torchvision>=0.6.0
# transformers
git+git://github.com/huggingface/transformers.git@bdcc4b78a27775d1ec8f3fd297cb679c257289db#transformers