diff --git a/python/ray/tune/resources.py b/python/ray/tune/resources.py index e2a8faffa..b340f133f 100644 --- a/python/ray/tune/resources.py +++ b/python/ray/tune/resources.py @@ -5,11 +5,10 @@ from __future__ import print_function from collections import namedtuple import logging import json +from numbers import Number # For compatibility under py2 to consider unicode as str from six import string_types -from numbers import Number - from ray.tune import TuneError logger = logging.getLogger(__name__) @@ -66,6 +65,23 @@ class Resources( custom_resources.setdefault(value, 0) extra_custom_resources.setdefault(value, 0) + cpu = round(cpu, 2) + gpu = round(gpu, 2) + memory = round(memory, 2) + object_store_memory = round(object_store_memory, 2) + extra_cpu = round(extra_cpu, 2) + extra_gpu = round(extra_gpu, 2) + extra_memory = round(extra_memory, 2) + extra_object_store_memory = round(extra_object_store_memory, 2) + custom_resources = { + resource: round(value, 2) + for resource, value in custom_resources.items() + } + extra_custom_resources = { + resource: round(value, 2) + for resource, value in extra_custom_resources.items() + } + all_values = [ cpu, gpu, memory, object_store_memory, extra_cpu, extra_gpu, extra_memory, extra_object_store_memory diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 0e75f6f73..d6164bd13 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ b/python/ray/tune/tests/test_trial_runner.py @@ -1531,6 +1531,14 @@ class TrialRunnerTest(unittest.TestCase): self.assertEqual(trials[2].status, Trial.PENDING) self.assertEqual(trials[3].status, Trial.PENDING) + def testResourceNumericalError(self): + resource = Resources(cpu=0.99, gpu=0.99, custom_resources={"a": 0.99}) + small_resource = Resources( + cpu=0.33, gpu=0.33, custom_resources={"a": 0.33}) + for i in range(3): + resource = Resources.subtract(resource, small_resource) + self.assertTrue(resource.is_nonnegative()) + def testResourceScheduler(self): ray.init(num_cpus=4, num_gpus=1) runner = TrialRunner()