mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 12:58:37 +08:00
[tune] Fix numerical error (#5653)
This commit is contained in:
@@ -5,11 +5,10 @@ from __future__ import print_function
|
||||
from collections import namedtuple
|
||||
import logging
|
||||
import json
|
||||
from numbers import Number
|
||||
# For compatibility under py2 to consider unicode as str
|
||||
from six import string_types
|
||||
|
||||
from numbers import Number
|
||||
|
||||
from ray.tune import TuneError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,6 +65,23 @@ class Resources(
|
||||
custom_resources.setdefault(value, 0)
|
||||
extra_custom_resources.setdefault(value, 0)
|
||||
|
||||
cpu = round(cpu, 2)
|
||||
gpu = round(gpu, 2)
|
||||
memory = round(memory, 2)
|
||||
object_store_memory = round(object_store_memory, 2)
|
||||
extra_cpu = round(extra_cpu, 2)
|
||||
extra_gpu = round(extra_gpu, 2)
|
||||
extra_memory = round(extra_memory, 2)
|
||||
extra_object_store_memory = round(extra_object_store_memory, 2)
|
||||
custom_resources = {
|
||||
resource: round(value, 2)
|
||||
for resource, value in custom_resources.items()
|
||||
}
|
||||
extra_custom_resources = {
|
||||
resource: round(value, 2)
|
||||
for resource, value in extra_custom_resources.items()
|
||||
}
|
||||
|
||||
all_values = [
|
||||
cpu, gpu, memory, object_store_memory, extra_cpu, extra_gpu,
|
||||
extra_memory, extra_object_store_memory
|
||||
|
||||
@@ -1531,6 +1531,14 @@ class TrialRunnerTest(unittest.TestCase):
|
||||
self.assertEqual(trials[2].status, Trial.PENDING)
|
||||
self.assertEqual(trials[3].status, Trial.PENDING)
|
||||
|
||||
def testResourceNumericalError(self):
|
||||
resource = Resources(cpu=0.99, gpu=0.99, custom_resources={"a": 0.99})
|
||||
small_resource = Resources(
|
||||
cpu=0.33, gpu=0.33, custom_resources={"a": 0.33})
|
||||
for i in range(3):
|
||||
resource = Resources.subtract(resource, small_resource)
|
||||
self.assertTrue(resource.is_nonnegative())
|
||||
|
||||
def testResourceScheduler(self):
|
||||
ray.init(num_cpus=4, num_gpus=1)
|
||||
runner = TrialRunner()
|
||||
|
||||
Reference in New Issue
Block a user