From 4dc16730a7cd97c7dec3484331f508dc5d7a79cd Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Tue, 1 Dec 2020 16:47:03 -0800 Subject: [PATCH] [tune] with-params fix (#12522) --- python/ray/tune/function_runner.py | 3 ++- python/ray/tune/registry.py | 1 + python/ray/tune/tests/test_function_api.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py index f9938bc2d..79e0f5da9 100644 --- a/python/ray/tune/function_runner.py +++ b/python/ray/tune/function_runner.py @@ -627,6 +627,7 @@ def with_parameters(fn, **kwargs): parameter_registry.put(prefix + k, v) use_checkpoint = detect_checkpoint_function(fn) + keys = list(kwargs.keys()) def inner(config, checkpoint_dir=None): fn_kwargs = {} @@ -638,7 +639,7 @@ def with_parameters(fn, **kwargs): or default fn_kwargs["checkpoint_dir"] = default - for k in kwargs: + for k in keys: fn_kwargs[k] = parameter_registry.get(prefix + k) fn(config, **fn_kwargs) diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py index 27935e641..7409aaa0b 100644 --- a/python/ray/tune/registry.py +++ b/python/ray/tune/registry.py @@ -169,6 +169,7 @@ class _ParameterRegistry: def flush(self): for k, v in self.to_flush.items(): self.references[k] = ray.put(v) + self.to_flush.clear() parameter_registry = _ParameterRegistry() diff --git a/python/ray/tune/tests/test_function_api.py b/python/ray/tune/tests/test_function_api.py index 305e8abb0..9ee2cdc64 100644 --- a/python/ray/tune/tests/test_function_api.py +++ b/python/ray/tune/tests/test_function_api.py @@ -1,10 +1,12 @@ import json import os +import sys import shutil import tempfile import unittest import ray +import ray.cloudpickle as cloudpickle from ray.rllib import _register_all from ray import tune @@ -468,6 +470,19 @@ class FunctionApiTest(unittest.TestCase): self.assertEquals(trial_2.last_result["metric"], 500_000) self.assertEquals(trial_2.last_result["cp"], "DIR") + def testWithParameters2(self): + class Data: + def __init__(self): + import numpy as np + self.data = np.random.rand((2 * 1024 * 1024)) + + def train(config, data=None): + tune.report(metric=len(data.data)) + + trainable = tune.with_parameters(train, data=Data()) + dumped = cloudpickle.dumps(trainable) + assert sys.getsizeof(dumped) < 100 * 1024 + def testReturnAnonymous(self): def train(config): return config["a"]