[tune] Fix numerical error (#5653)

This commit is contained in:
Richard Liaw
2019-09-07 07:34:51 -07:00
committed by GitHub
parent 8a352a8e70
commit d89ceb3ee5
2 changed files with 26 additions and 2 deletions
+18 -2
View File
@@ -5,11 +5,10 @@ from __future__ import print_function
from collections import namedtuple
import logging
import json
from numbers import Number
# For compatibility under py2 to consider unicode as str
from six import string_types
from numbers import Number
from ray.tune import TuneError
logger = logging.getLogger(__name__)
@@ -66,6 +65,23 @@ class Resources(
custom_resources.setdefault(value, 0)
extra_custom_resources.setdefault(value, 0)
cpu = round(cpu, 2)
gpu = round(gpu, 2)
memory = round(memory, 2)
object_store_memory = round(object_store_memory, 2)
extra_cpu = round(extra_cpu, 2)
extra_gpu = round(extra_gpu, 2)
extra_memory = round(extra_memory, 2)
extra_object_store_memory = round(extra_object_store_memory, 2)
custom_resources = {
resource: round(value, 2)
for resource, value in custom_resources.items()
}
extra_custom_resources = {
resource: round(value, 2)
for resource, value in extra_custom_resources.items()
}
all_values = [
cpu, gpu, memory, object_store_memory, extra_cpu, extra_gpu,
extra_memory, extra_object_store_memory
@@ -1531,6 +1531,14 @@ class TrialRunnerTest(unittest.TestCase):
self.assertEqual(trials[2].status, Trial.PENDING)
self.assertEqual(trials[3].status, Trial.PENDING)
def testResourceNumericalError(self):
resource = Resources(cpu=0.99, gpu=0.99, custom_resources={"a": 0.99})
small_resource = Resources(
cpu=0.33, gpu=0.33, custom_resources={"a": 0.33})
for i in range(3):
resource = Resources.subtract(resource, small_resource)
self.assertTrue(resource.is_nonnegative())
def testResourceScheduler(self):
ray.init(num_cpus=4, num_gpus=1)
runner = TrialRunner()