mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 04:50:40 +08:00
c9fafe7733
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
from typing import Dict, List, Optional
|
|
|
|
from ray.tune.suggest.suggestion import Searcher, ConcurrencyLimiter
|
|
from ray.tune.suggest.search_generator import SearchGenerator
|
|
from ray.tune.trial import Trial
|
|
|
|
|
|
class _MockSearcher(Searcher):
    """Searcher test double that records every callback it receives.

    Attributes:
        live_trials: Dict keyed by the ids of trials that were suggested
            and have not completed yet (the value is always ``1``).
        counter: Number of ``on_trial_result`` / ``on_trial_complete``
            calls, stored under the keys ``"result"`` and ``"complete"``.
        results: Every result passed to ``on_trial_result``, in order.
        final_results: Every truthy result passed to
            ``on_trial_complete``, in order.
        stall: While true, ``suggest`` returns ``None`` instead of a
            configuration.
    """

    def __init__(self, **kwargs):
        """Set up the bookkeeping state, then forward ``kwargs`` to the base."""
        # The bookkeeping attributes are created before the base-class
        # constructor runs.
        self.stall = False
        self.live_trials = {}
        self.results = []
        self.final_results = []
        self.counter = {"result": 0, "complete": 0}
        super().__init__(**kwargs)

    def suggest(self, trial_id: str):
        """Return a fixed configuration, or ``None`` while stalled.

        Args:
            trial_id: Id of the trial asking for a configuration. It is
                registered in ``live_trials`` unless the searcher stalls.
        """
        if self.stall:
            return None
        self.live_trials[trial_id] = 1
        return {"test_variable": 2}

    def on_trial_result(self, trial_id: str, result: Dict):
        """Count the call and remember ``result``; ``trial_id`` is unused."""
        self.counter["result"] += 1
        self.results.append(result)

    def on_trial_complete(self,
                          trial_id: str,
                          result: Optional[Dict] = None,
                          error: bool = False):
        """Count the completion and drop the trial from ``live_trials``.

        Args:
            trial_id: Id of the finished trial; unknown ids are ignored.
            result: Final result. ``None`` and empty dicts are not stored.
            error: Accepted for interface compatibility; not read here.
        """
        self.counter["complete"] += 1
        if result:
            self._process_result(result)
        # No-op when the trial was never suggested or was already removed.
        self.live_trials.pop(trial_id, None)

    def _process_result(self, result: Dict):
        """Store ``result`` as a final result."""
        self.final_results.append(result)
|
|
|
|
|
|
class _MockSuggestionAlgorithm(SearchGenerator):
    """SearchGenerator wired to a ``_MockSearcher`` for use in tests.

    The searcher that is actually used is kept on ``self.searcher``. When
    ``max_concurrent`` is truthy it is the ``ConcurrencyLimiter`` wrapper,
    otherwise it is the bare ``_MockSearcher``.
    """

    def __init__(self, max_concurrent: Optional[int] = None, **kwargs):
        """Build the (optionally concurrency-limited) mock searcher.

        Args:
            max_concurrent: If truthy, the mock searcher is wrapped in a
                ``ConcurrencyLimiter`` with this limit. ``None`` and ``0``
                leave it unwrapped.
            **kwargs: Forwarded unchanged to ``_MockSearcher``.
        """
        searcher = _MockSearcher(**kwargs)
        if max_concurrent:
            searcher = ConcurrencyLimiter(
                searcher, max_concurrent=max_concurrent)
        self.searcher = searcher
        super().__init__(searcher)

    # NOTE(review): both properties read from ``self.searcher``, which is
    # the ConcurrencyLimiter wrapper when ``max_concurrent`` was given --
    # confirm that the wrapper exposes ``live_trials`` and ``results``.

    @property
    def live_trials(self) -> List[Trial]:
        """The ``live_trials`` attribute of the underlying searcher."""
        return self.searcher.live_trials

    @property
    def results(self) -> List[Dict]:
        """The ``results`` attribute of the underlying searcher."""
        return self.searcher.results
|