mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[tune] with-params fix (#12522)
This commit is contained in:
@@ -627,6 +627,7 @@ def with_parameters(fn, **kwargs):
|
||||
parameter_registry.put(prefix + k, v)
|
||||
|
||||
use_checkpoint = detect_checkpoint_function(fn)
|
||||
keys = list(kwargs.keys())
|
||||
|
||||
def inner(config, checkpoint_dir=None):
|
||||
fn_kwargs = {}
|
||||
@@ -638,7 +639,7 @@ def with_parameters(fn, **kwargs):
|
||||
or default
|
||||
fn_kwargs["checkpoint_dir"] = default
|
||||
|
||||
for k in kwargs:
|
||||
for k in keys:
|
||||
fn_kwargs[k] = parameter_registry.get(prefix + k)
|
||||
fn(config, **fn_kwargs)
|
||||
|
||||
|
||||
@@ -169,6 +169,7 @@ class _ParameterRegistry:
|
||||
def flush(self):
|
||||
for k, v in self.to_flush.items():
|
||||
self.references[k] = ray.put(v)
|
||||
self.to_flush.clear()
|
||||
|
||||
|
||||
parameter_registry = _ParameterRegistry()
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.rllib import _register_all
|
||||
|
||||
from ray import tune
|
||||
@@ -468,6 +470,19 @@ class FunctionApiTest(unittest.TestCase):
|
||||
self.assertEquals(trial_2.last_result["metric"], 500_000)
|
||||
self.assertEquals(trial_2.last_result["cp"], "DIR")
|
||||
|
||||
def testWithParameters2(self):
|
||||
class Data:
|
||||
def __init__(self):
|
||||
import numpy as np
|
||||
self.data = np.random.rand((2 * 1024 * 1024))
|
||||
|
||||
def train(config, data=None):
|
||||
tune.report(metric=len(data.data))
|
||||
|
||||
trainable = tune.with_parameters(train, data=Data())
|
||||
dumped = cloudpickle.dumps(trainable)
|
||||
assert sys.getsizeof(dumped) < 100 * 1024
|
||||
|
||||
def testReturnAnonymous(self):
|
||||
def train(config):
|
||||
return config["a"]
|
||||
|
||||
Reference in New Issue
Block a user