mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:44:28 +08:00
[tune] Add encoder for PBT (#5599)
* Add encoder * Apply suggestions from code review
This commit is contained in:
@@ -13,6 +13,7 @@ import shutil
|
||||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import TRAINING_ITERATION
|
||||
from ray.tune.logger import _SafeFallbackEncoder
|
||||
from ray.tune.schedulers import FIFOScheduler, TrialScheduler
|
||||
from ray.tune.suggest.variant_generator import format_vars
|
||||
from ray.tune.trial import Trial, Checkpoint
|
||||
@@ -276,13 +277,13 @@ class PopulationBasedTraining(FIFOScheduler):
|
||||
]
|
||||
# Log to global file.
|
||||
with open(os.path.join(trial.local_dir, "pbt_global.txt"), "a+") as f:
|
||||
f.write(json.dumps(policy) + "\n")
|
||||
print(json.dumps(policy, cls=_SafeFallbackEncoder), file=f)
|
||||
# Overwrite state in target trial from trial_to_clone.
|
||||
if os.path.exists(trial_to_clone_path):
|
||||
shutil.copyfile(trial_to_clone_path, trial_path)
|
||||
# Log new exploit in target trial log.
|
||||
with open(trial_path, "a+") as f:
|
||||
f.write(json.dumps(policy) + "\n")
|
||||
f.write(json.dumps(policy, cls=_SafeFallbackEncoder) + "\n")
|
||||
|
||||
def _exploit(self, trial_executor, trial, trial_to_clone):
|
||||
"""Transfers perturbed state from trial_to_clone -> trial.
|
||||
|
||||
Reference in New Issue
Block a user