mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:53:20 +08:00
[rllib/tune] Add test for fractional gpu support in xray mode; add rllib support for fractional gpu (#2768)
* frac gpu * doc * Update rllib-training.rst * yapf * remove xray
This commit is contained in:
@@ -34,6 +34,6 @@ class A2CAgent(A3CAgent):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1,
|
||||
gpu=1 if cf["gpu"] else 0,
|
||||
gpu=cf["gpu_fraction"] if cf["gpu"] else 0,
|
||||
extra_cpu=cf["num_workers"],
|
||||
extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)
|
||||
|
||||
@@ -64,6 +64,8 @@ COMMON_CONFIG = {
|
||||
"compress_observations": False,
|
||||
# Whether to write episode stats and videos to the agent log dir
|
||||
"monitor": False,
|
||||
# Allocate a fraction of a GPU instead of one (e.g., 0.3 GPUs)
|
||||
"gpu_fraction": 1,
|
||||
|
||||
# === Multiagent ===
|
||||
"multiagent": {
|
||||
|
||||
@@ -53,12 +53,12 @@ class BCAgent(Agent):
|
||||
def default_resource_request(cls, config):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
if cf["use_gpu_for_workers"]:
|
||||
num_gpus_per_worker = 1
|
||||
num_gpus_per_worker = cf["gpu_fraction"]
|
||||
else:
|
||||
num_gpus_per_worker = 0
|
||||
return Resources(
|
||||
cpu=1,
|
||||
gpu=cf["gpu"] and 1 or 0,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_workers"],
|
||||
extra_gpu=num_gpus_per_worker * cf["num_workers"])
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ class ApexDDPGAgent(DDPGAgent):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
|
||||
gpu=cf["gpu"] and 1 or 0,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class ApexAgent(DQNAgent):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1 + cf["optimizer"]["num_replay_buffer_shards"],
|
||||
gpu=cf["gpu"] and 1 or 0,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
|
||||
@@ -132,7 +132,7 @@ class DQNAgent(Agent):
|
||||
cf = merge_dicts(cls._default_config, config)
|
||||
return Resources(
|
||||
cpu=1,
|
||||
gpu=cf["gpu"] and 1 or 0,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class ImpalaAgent(Agent):
|
||||
cf = dict(cls._default_config, **config)
|
||||
return Resources(
|
||||
cpu=1,
|
||||
gpu=cf["gpu"] and 1 or 0,
|
||||
gpu=cf["gpu"] and cf["gpu_fraction"] or 0,
|
||||
extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
|
||||
extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])
|
||||
|
||||
|
||||
@@ -800,6 +800,29 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[0].status, Trial.TERMINATED)
|
||||
self.assertEqual(trials[1].status, Trial.PENDING)
|
||||
|
||||
def testFractionalGpus(self):
|
||||
ray.init(num_cpus=4, num_gpus=1, use_raylet=True)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
kwargs = {
|
||||
"resources": Resources(cpu=1, gpu=0.5),
|
||||
}
|
||||
trials = [
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs),
|
||||
Trial("__fake", **kwargs)
|
||||
]
|
||||
for t in trials:
|
||||
runner.add_trial(t)
|
||||
|
||||
for _ in range(10):
|
||||
runner.step()
|
||||
|
||||
self.assertEqual(trials[0].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[1].status, Trial.RUNNING)
|
||||
self.assertEqual(trials[2].status, Trial.PENDING)
|
||||
self.assertEqual(trials[3].status, Trial.PENDING)
|
||||
|
||||
def testResourceScheduler(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
runner = TrialRunner(BasicVariantGenerator())
|
||||
|
||||
Reference in New Issue
Block a user