[tune] with-params fix (#12522)

This commit is contained in:
Richard Liaw
2020-12-01 16:47:03 -08:00
committed by GitHub
parent 7022278ce9
commit 4dc16730a7
3 changed files with 18 additions and 1 deletions
+2 -1
View File
@@ -627,6 +627,7 @@ def with_parameters(fn, **kwargs):
parameter_registry.put(prefix + k, v)
use_checkpoint = detect_checkpoint_function(fn)
keys = list(kwargs.keys())
def inner(config, checkpoint_dir=None):
fn_kwargs = {}
@@ -638,7 +639,7 @@ def with_parameters(fn, **kwargs):
or default
fn_kwargs["checkpoint_dir"] = default
for k in kwargs:
for k in keys:
fn_kwargs[k] = parameter_registry.get(prefix + k)
fn(config, **fn_kwargs)
+1
View File
@@ -169,6 +169,7 @@ class _ParameterRegistry:
def flush(self):
for k, v in self.to_flush.items():
self.references[k] = ray.put(v)
self.to_flush.clear()
parameter_registry = _ParameterRegistry()
@@ -1,10 +1,12 @@
import json
import os
import sys
import shutil
import tempfile
import unittest
import ray
import ray.cloudpickle as cloudpickle
from ray.rllib import _register_all
from ray import tune
@@ -468,6 +470,19 @@ class FunctionApiTest(unittest.TestCase):
self.assertEquals(trial_2.last_result["metric"], 500_000)
self.assertEquals(trial_2.last_result["cp"], "DIR")
def testWithParameters2(self):
class Data:
def __init__(self):
import numpy as np
self.data = np.random.rand((2 * 1024 * 1024))
def train(config, data=None):
tune.report(metric=len(data.data))
trainable = tune.with_parameters(train, data=Data())
dumped = cloudpickle.dumps(trainable)
assert sys.getsizeof(dumped) < 100 * 1024
def testReturnAnonymous(self):
def train(config):
return config["a"]