mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 18:45:03 +08:00
[sgd] Replaced class Resources in sgd with use_gpu (#5252)
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
|
||||
from ray.experimental.sgd.tests.pytorch_utils import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
@@ -15,10 +15,11 @@ def train_example(num_replicas=1, use_gpu=False):
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
num_replicas=num_replicas,
|
||||
resources_per_replica=Resources(
|
||||
num_cpus=1, num_gpus=int(use_gpu), resources={}))
|
||||
use_gpu=use_gpu,
|
||||
backend="gloo")
|
||||
trainer1.train()
|
||||
trainer1.shutdown()
|
||||
print("success!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -28,6 +29,12 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
type=str,
|
||||
help="the address to use for Redis")
|
||||
parser.add_argument(
|
||||
"--num-replicas",
|
||||
"-n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Sets number of replicas for training.")
|
||||
parser.add_argument(
|
||||
"--use-gpu",
|
||||
action="store_true",
|
||||
@@ -36,5 +43,6 @@ if __name__ == "__main__":
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
import ray
|
||||
|
||||
ray.init(redis_address=args.redis_address)
|
||||
train_example(num_replicas=2, use_gpu=args.use_gpu)
|
||||
train_example(num_replicas=args.num_replicas, use_gpu=args.use_gpu)
|
||||
|
||||
@@ -3,6 +3,5 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.experimental.sgd.pytorch.pytorch_trainer import PyTorchTrainer
|
||||
from ray.experimental.sgd.pytorch.utils import Resources
|
||||
|
||||
__all__ = ["PyTorchTrainer", "Resources"]
|
||||
__all__ = ["PyTorchTrainer"]
|
||||
|
||||
@@ -30,7 +30,7 @@ class PyTorchTrainer(object):
|
||||
optimizer_creator=utils.sgd_mse_optimizer,
|
||||
config=None,
|
||||
num_replicas=1,
|
||||
resources_per_replica=None,
|
||||
use_gpu=False,
|
||||
batch_size=16,
|
||||
backend="auto"):
|
||||
"""Sets up the PyTorch trainer.
|
||||
@@ -46,8 +46,8 @@ class PyTorchTrainer(object):
|
||||
'data_creator', and 'optimizer_creator'.
|
||||
num_replicas (int): the number of workers used in distributed
|
||||
training.
|
||||
resources_per_replica (Resources): resources used by each worker.
|
||||
Defaults to Resources(num_cpus=1).
|
||||
use_gpu (bool): Sets resource allocation for workers to 1 GPU
|
||||
if true.
|
||||
batch_size (int): batch size for an update.
|
||||
backend (string): backend used by distributed PyTorch.
|
||||
"""
|
||||
@@ -64,19 +64,15 @@ class PyTorchTrainer(object):
|
||||
self.config = {} if config is None else config
|
||||
self.optimizer_timer = utils.TimerStat(window_size=1)
|
||||
|
||||
if resources_per_replica is None:
|
||||
resources_per_replica = utils.Resources(
|
||||
num_cpus=1, num_gpus=0, resources={})
|
||||
|
||||
if backend == "auto":
|
||||
backend = "nccl" if resources_per_replica.num_gpus > 0 else "gloo"
|
||||
backend = "nccl" if use_gpu else "gloo"
|
||||
|
||||
logger.info("Using {} as backend.".format(backend))
|
||||
|
||||
if num_replicas == 1:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=resources_per_replica.num_cpus,
|
||||
num_gpus=resources_per_replica.num_gpus,
|
||||
resources=resources_per_replica.resources)(PyTorchRunner)
|
||||
num_cpus=1, num_gpus=int(use_gpu))(PyTorchRunner)
|
||||
# Start workers
|
||||
self.workers = [
|
||||
Runner.remote(model_creator, data_creator, optimizer_creator,
|
||||
@@ -87,10 +83,7 @@ class PyTorchTrainer(object):
|
||||
else:
|
||||
# Generate actor class
|
||||
Runner = ray.remote(
|
||||
num_cpus=resources_per_replica.num_cpus,
|
||||
num_gpus=resources_per_replica.num_gpus,
|
||||
resources=resources_per_replica.resources)(
|
||||
DistributedPyTorchRunner)
|
||||
num_cpus=1, num_gpus=int(use_gpu))(DistributedPyTorchRunner)
|
||||
# Compute batch size per replica
|
||||
batch_size_per_replica = batch_size // num_replicas
|
||||
if batch_size % num_replicas > 0:
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
from contextlib import closing
|
||||
import numpy as np
|
||||
import socket
|
||||
@@ -214,18 +213,6 @@ class AverageMeter(object):
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class Resources(
|
||||
namedtuple("Resources", ["num_cpus", "num_gpus", "resources"])):
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls, num_cpus=1, num_gpus=0, resources=None):
|
||||
if resources is None:
|
||||
resources = {}
|
||||
|
||||
return super(Resources, cls).__new__(cls, num_cpus, num_gpus,
|
||||
resources)
|
||||
|
||||
|
||||
def sgd_mse_optimizer(model, config):
|
||||
"""Returns the mean squared error criterion and SGD optimizer.
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ray.tests.conftest import ray_start_2_cpus # noqa: F401
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer, Resources
|
||||
from ray.experimental.sgd.pytorch import PyTorchTrainer
|
||||
|
||||
from ray.experimental.sgd.tests.pytorch_utils import (
|
||||
model_creator, optimizer_creator, data_creator)
|
||||
@@ -22,8 +22,7 @@ def test_train(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
num_replicas=num_replicas,
|
||||
resources_per_replica=Resources(num_cpus=1))
|
||||
num_replicas=num_replicas)
|
||||
train_loss1 = trainer.train()["train_loss"]
|
||||
validation_loss1 = trainer.validate()["validation_loss"]
|
||||
|
||||
@@ -44,8 +43,7 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
num_replicas=num_replicas,
|
||||
resources_per_replica=Resources(num_cpus=1))
|
||||
num_replicas=num_replicas)
|
||||
trainer1.train()
|
||||
|
||||
filename = os.path.join(tempfile.mkdtemp(), "checkpoint")
|
||||
@@ -59,8 +57,7 @@ def test_save_and_restore(ray_start_2_cpus, num_replicas): # noqa: F811
|
||||
model_creator,
|
||||
data_creator,
|
||||
optimizer_creator,
|
||||
num_replicas=num_replicas,
|
||||
resources_per_replica=Resources(num_cpus=1))
|
||||
num_replicas=num_replicas)
|
||||
trainer2.restore(filename)
|
||||
|
||||
os.remove(filename)
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.rllib.optimizers import AsyncSamplesOptimizer
|
||||
from ray.rllib.optimizers.aso_tree_aggregator import TreeAggregator
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources
|
||||
from ray.tune.resources import Resources
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
|
||||
@@ -24,7 +24,8 @@ from ray.rllib.utils.memory import ray_get_and_free
|
||||
from ray.rllib.utils import try_import_tf
|
||||
from ray.tune.registry import ENV_CREATOR, register_env, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.trial import Resources, ExportFormat
|
||||
from ray.tune.trial import ExportFormat
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import ray
|
||||
from ray.tests.cluster_utils import Cluster
|
||||
from ray.tune.config_parser import make_parser
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.trial import resources_to_json
|
||||
from ray.tune.resources import resources_to_json
|
||||
from ray.tune.tune import _make_scheduler, run_experiments
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
|
||||
@@ -10,7 +10,8 @@ import os
|
||||
from six import string_types
|
||||
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trial import Trial, json_to_resources
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.resources import json_to_resources
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ import traceback
|
||||
import ray
|
||||
from ray.tune.error import AbortTrialExecution
|
||||
from ray.tune.logger import NoopLogger
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.util import warn_if_slow
|
||||
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import json
|
||||
# For compatibility under py2 to consider unicode as str
|
||||
from six import string_types
|
||||
|
||||
from numbers import Number
|
||||
|
||||
from ray.tune import TuneError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Resources(
        namedtuple("Resources", [
            "cpu", "gpu", "extra_cpu", "extra_gpu", "custom_resources",
            "extra_custom_resources"
        ])):
    """Ray resources required to schedule a trial.

    Attributes:
        cpu (float): Number of CPUs to allocate to the trial.
        gpu (float): Number of GPUs to allocate to the trial.
        extra_cpu (float): Extra CPUs to reserve in case the trial needs to
            launch additional Ray actors that use CPUs.
        extra_gpu (float): Extra GPUs to reserve in case the trial needs to
            launch additional Ray actors that use GPUs.
        custom_resources (dict): Mapping of resource to quantity to allocate
            to the trial.
        extra_custom_resources (dict): Extra custom resources to reserve in
            case the trial needs to launch additional Ray actors that use
            any of these custom resources.

    After construction both custom-resource dicts hold the same set of keys;
    a key given in only one of them is added to the other with quantity 0.
    """

    __slots__ = ()

    def __new__(cls,
                cpu,
                gpu,
                extra_cpu=0,
                extra_gpu=0,
                custom_resources=None,
                extra_custom_resources=None):
        """Validate the quantities and build the tuple.

        Raises:
            AssertionError: if any quantity is not a `numbers.Number`.
        """
        # Work on private copies: the zero-filling below must not write
        # into dictionaries that belong to the caller.
        custom_resources = dict(custom_resources or {})
        extra_custom_resources = dict(extra_custom_resources or {})

        # Give both dicts the same key set. Filling in insertion order
        # (rather than iterating a set) keeps the resulting key order,
        # and therefore summary_string(), deterministic.
        for name in list(extra_custom_resources):
            custom_resources.setdefault(name, 0)
        for name in list(custom_resources):
            extra_custom_resources.setdefault(name, 0)

        all_values = [cpu, gpu, extra_cpu, extra_gpu]
        all_values += list(custom_resources.values())
        all_values += list(extra_custom_resources.values())
        assert len(custom_resources) == len(extra_custom_resources)
        # Kept as asserts so existing callers still see AssertionError.
        for entry in all_values:
            assert isinstance(entry, Number), "Improper resource value."
        return super(Resources,
                     cls).__new__(cls, cpu, gpu, extra_cpu, extra_gpu,
                                  custom_resources, extra_custom_resources)

    def summary_string(self):
        """Return a one-line summary of the total (base + extra) resources."""
        summary = "{} CPUs, {} GPUs".format(self.cpu_total(),
                                            self.gpu_total())
        custom_summary = ", ".join([
            "{} {}".format(self.get_res_total(res), res)
            for res in self.custom_resources
        ])
        if custom_summary:
            summary += " ({})".format(custom_summary)
        return summary

    def cpu_total(self):
        """Return cpu + extra_cpu."""
        return self.cpu + self.extra_cpu

    def gpu_total(self):
        """Return gpu + extra_gpu."""
        return self.gpu + self.extra_gpu

    def get_res_total(self, key):
        """Return base + extra quantity of custom resource `key` (0 if absent)."""
        return self.custom_resources.get(
            key, 0) + self.extra_custom_resources.get(key, 0)

    def get(self, key):
        """Return the base quantity of custom resource `key` (0 if absent)."""
        return self.custom_resources.get(key, 0)

    def is_nonnegative(self):
        """Return True when every quantity, built-in or custom, is >= 0."""
        all_values = [self.cpu, self.gpu, self.extra_cpu, self.extra_gpu]
        all_values += list(self.custom_resources.values())
        all_values += list(self.extra_custom_resources.values())
        return all(v >= 0 for v in all_values)

    @classmethod
    def subtract(cls, original, to_remove):
        """Return a new instance holding `original - to_remove`, field by field.

        Custom resources missing on either side count as 0. The result may
        contain negative quantities; use is_nonnegative() to check.
        """
        cpu = original.cpu - to_remove.cpu
        gpu = original.gpu - to_remove.gpu
        extra_cpu = original.extra_cpu - to_remove.extra_cpu
        extra_gpu = original.extra_gpu - to_remove.extra_gpu
        all_resources = set(original.custom_resources).union(
            set(to_remove.custom_resources))
        new_custom_res = {
            k: original.custom_resources.get(k, 0) -
            to_remove.custom_resources.get(k, 0)
            for k in all_resources
        }
        extra_custom_res = {
            k: original.extra_custom_resources.get(k, 0) -
            to_remove.extra_custom_resources.get(k, 0)
            for k in all_resources
        }
        # Build through cls so a subclass gets an instance of itself back.
        return cls(cpu, gpu, extra_cpu, extra_gpu, new_custom_res,
                   extra_custom_res)

    def to_json(self):
        """Return the dict form produced by `resources_to_json`."""
        return resources_to_json(self)
|
||||
|
||||
|
||||
def json_to_resources(data):
    """Build a `Resources` object from a JSON string or an already-parsed dict.

    Args:
        data: None, a JSON-encoded string, or a dict keyed by `Resources`
            field names.

    Returns:
        None if `data` is None or encodes JSON null; otherwise a `Resources`
        built from the given fields. Missing fields default to cpu=1, gpu=0,
        extra_cpu=0, extra_gpu=0; absent custom-resource entries are passed
        to `Resources` as None.

    Raises:
        TuneError: if `driver_cpu_limit` or `driver_gpu_limit` is present.
        ValueError: if a key is not one of `Resources._fields`.
    """
    if data is None or data == "null":
        return None
    if isinstance(data, string_types):
        data = json.loads(data)
        # A string such as " null" slips past the exact-match check above
        # but still decodes to None; treat it the same way instead of
        # failing on the iteration below.
        if data is None:
            return None
    for k in data:
        if k in ["driver_cpu_limit", "driver_gpu_limit"]:
            raise TuneError(
                "The field `{}` is no longer supported. Use `extra_cpu` "
                "or `extra_gpu` instead.".format(k))
        if k not in Resources._fields:
            raise ValueError(
                "Unknown resource field {}, must be one of {}".format(
                    k, Resources._fields))
    return Resources(
        data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
        data.get("extra_gpu", 0), data.get("custom_resources"),
        data.get("extra_custom_resources"))
|
||||
|
||||
|
||||
def resources_to_json(resources):
    """Convert a resources object into a plain, JSON-friendly dict.

    Returns None when `resources` is None. The two custom-resource dicts
    are shallow-copied, so the result does not share them with the input.
    """
    if resources is None:
        return None
    payload = {
        field: getattr(resources, field)
        for field in ("cpu", "gpu", "extra_cpu", "extra_gpu")
    }
    payload["custom_resources"] = resources.custom_resources.copy()
    payload["extra_custom_resources"] = (
        resources.extra_custom_resources.copy())
    return payload
|
||||
@@ -11,7 +11,8 @@ from ray.tune import Trainable
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.registry import _global_registry, TRAINABLE_CLASS
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.trial import Trial, Checkpoint, Resources
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.resources import Resources
|
||||
|
||||
|
||||
class RayTrialExecutorTest(unittest.TestCase):
|
||||
|
||||
@@ -27,9 +27,9 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, TIMESTEPS_TOTAL, DONE,
|
||||
from ray.tune.logger import Logger
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import (Trial, ExportFormat, Resources, resources_to_json,
|
||||
json_to_resources)
|
||||
from ray.tune.trial import Trial, ExportFormat
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
from ray.tune.suggest import grid_search, BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import (_MockSuggestionAlgorithm,
|
||||
SuggestionAlgorithm)
|
||||
|
||||
@@ -18,8 +18,9 @@ from ray.tune.schedulers import (HyperBandScheduler, AsyncHyperBandScheduler,
|
||||
TrialScheduler)
|
||||
|
||||
from ray.tune.schedulers.pbt import explore
|
||||
from ray.tune.trial import Trial, Resources, Checkpoint
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
from ray.tune.trial_executor import TrialExecutor
|
||||
from ray.tune.resources import Resources
|
||||
|
||||
from ray.rllib import _register_all
|
||||
_register_all()
|
||||
|
||||
+1
-144
@@ -2,21 +2,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from collections import namedtuple
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import json
|
||||
import uuid
|
||||
import time
|
||||
import tempfile
|
||||
import os
|
||||
from numbers import Number
|
||||
|
||||
# For compatibility under py2 to consider unicode as str
|
||||
from six import string_types
|
||||
|
||||
import ray
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.logger import pretty_print, UnifiedLogger
|
||||
@@ -28,6 +21,7 @@ from ray.tune.result import (DEFAULT_RESULTS_DIR, DONE, HOSTNAME, PID,
|
||||
TIME_TOTAL_S, TRAINING_ITERATION, TIMESTEPS_TOTAL,
|
||||
EPISODE_REWARD_MEAN, MEAN_LOSS, MEAN_ACCURACY)
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
MAX_LEN_IDENTIFIER = int(os.environ.get("MAX_LEN_IDENTIFIER", 130))
|
||||
@@ -38,143 +32,6 @@ def date_str():
|
||||
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
class Resources(
|
||||
namedtuple("Resources", [
|
||||
"cpu", "gpu", "extra_cpu", "extra_gpu", "custom_resources",
|
||||
"extra_custom_resources"
|
||||
])):
|
||||
"""Ray resources required to schedule a trial.
|
||||
|
||||
Attributes:
|
||||
cpu (float): Number of CPUs to allocate to the trial.
|
||||
gpu (float): Number of GPUs to allocate to the trial.
|
||||
extra_cpu (float): Extra CPUs to reserve in case the trial needs to
|
||||
launch additional Ray actors that use CPUs.
|
||||
extra_gpu (float): Extra GPUs to reserve in case the trial needs to
|
||||
launch additional Ray actors that use GPUs.
|
||||
custom_resources (dict): Mapping of resource to quantity to allocate
|
||||
to the trial.
|
||||
extra_custom_resources (dict): Extra custom resources to reserve in
|
||||
case the trial needs to launch additional Ray actors that use
|
||||
any of these custom resources.
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __new__(cls,
|
||||
cpu,
|
||||
gpu,
|
||||
extra_cpu=0,
|
||||
extra_gpu=0,
|
||||
custom_resources=None,
|
||||
extra_custom_resources=None):
|
||||
custom_resources = custom_resources or {}
|
||||
extra_custom_resources = extra_custom_resources or {}
|
||||
leftovers = set(custom_resources) ^ set(extra_custom_resources)
|
||||
|
||||
for value in leftovers:
|
||||
custom_resources.setdefault(value, 0)
|
||||
extra_custom_resources.setdefault(value, 0)
|
||||
|
||||
all_values = [cpu, gpu, extra_cpu, extra_gpu]
|
||||
all_values += list(custom_resources.values())
|
||||
all_values += list(extra_custom_resources.values())
|
||||
assert len(custom_resources) == len(extra_custom_resources)
|
||||
for entry in all_values:
|
||||
assert isinstance(entry, Number), "Improper resource value."
|
||||
return super(Resources,
|
||||
cls).__new__(cls, cpu, gpu, extra_cpu, extra_gpu,
|
||||
custom_resources, extra_custom_resources)
|
||||
|
||||
def summary_string(self):
|
||||
summary = "{} CPUs, {} GPUs".format(self.cpu + self.extra_cpu,
|
||||
self.gpu + self.extra_gpu)
|
||||
custom_summary = ", ".join([
|
||||
"{} {}".format(self.get_res_total(res), res)
|
||||
for res in self.custom_resources
|
||||
])
|
||||
if custom_summary:
|
||||
summary += " ({})".format(custom_summary)
|
||||
return summary
|
||||
|
||||
def cpu_total(self):
|
||||
return self.cpu + self.extra_cpu
|
||||
|
||||
def gpu_total(self):
|
||||
return self.gpu + self.extra_gpu
|
||||
|
||||
def get_res_total(self, key):
|
||||
return self.custom_resources.get(
|
||||
key, 0) + self.extra_custom_resources.get(key, 0)
|
||||
|
||||
def get(self, key):
|
||||
return self.custom_resources.get(key, 0)
|
||||
|
||||
def is_nonnegative(self):
|
||||
all_values = [self.cpu, self.gpu, self.extra_cpu, self.extra_gpu]
|
||||
all_values += list(self.custom_resources.values())
|
||||
all_values += list(self.extra_custom_resources.values())
|
||||
return all(v >= 0 for v in all_values)
|
||||
|
||||
@classmethod
|
||||
def subtract(cls, original, to_remove):
|
||||
cpu = original.cpu - to_remove.cpu
|
||||
gpu = original.gpu - to_remove.gpu
|
||||
extra_cpu = original.extra_cpu - to_remove.extra_cpu
|
||||
extra_gpu = original.extra_gpu - to_remove.extra_gpu
|
||||
all_resources = set(original.custom_resources).union(
|
||||
set(to_remove.custom_resources))
|
||||
new_custom_res = {
|
||||
k: original.custom_resources.get(k, 0) -
|
||||
to_remove.custom_resources.get(k, 0)
|
||||
for k in all_resources
|
||||
}
|
||||
extra_custom_res = {
|
||||
k: original.extra_custom_resources.get(k, 0) -
|
||||
to_remove.extra_custom_resources.get(k, 0)
|
||||
for k in all_resources
|
||||
}
|
||||
return Resources(cpu, gpu, extra_cpu, extra_gpu, new_custom_res,
|
||||
extra_custom_res)
|
||||
|
||||
def to_json(self):
|
||||
return resources_to_json(self)
|
||||
|
||||
|
||||
def json_to_resources(data):
|
||||
if data is None or data == "null":
|
||||
return None
|
||||
if isinstance(data, string_types):
|
||||
data = json.loads(data)
|
||||
for k in data:
|
||||
if k in ["driver_cpu_limit", "driver_gpu_limit"]:
|
||||
raise TuneError(
|
||||
"The field `{}` is no longer supported. Use `extra_cpu` "
|
||||
"or `extra_gpu` instead.".format(k))
|
||||
if k not in Resources._fields:
|
||||
raise ValueError(
|
||||
"Unknown resource field {}, must be one of {}".format(
|
||||
k, Resources._fields))
|
||||
return Resources(
|
||||
data.get("cpu", 1), data.get("gpu", 0), data.get("extra_cpu", 0),
|
||||
data.get("extra_gpu", 0), data.get("custom_resources"),
|
||||
data.get("extra_custom_resources"))
|
||||
|
||||
|
||||
def resources_to_json(resources):
|
||||
if resources is None:
|
||||
return None
|
||||
return {
|
||||
"cpu": resources.cpu,
|
||||
"gpu": resources.gpu,
|
||||
"extra_cpu": resources.extra_cpu,
|
||||
"extra_gpu": resources.extra_gpu,
|
||||
"custom_resources": resources.custom_resources.copy(),
|
||||
"extra_custom_resources": resources.extra_custom_resources.copy()
|
||||
}
|
||||
|
||||
|
||||
def has_trainable(trainable_name):
    """Ask Tune's global registry whether it contains `trainable_name`.

    The lookup is made under the registry's TRAINABLE_CLASS category and
    the registry's answer is returned unchanged.
    """
    registry = ray.tune.registry
    return registry._global_registry.contains(registry.TRAINABLE_CLASS,
                                              trainable_name)
|
||||
|
||||
Reference in New Issue
Block a user