Fix python linting (#2076)

This commit is contained in:
Melih Elibol
2018-05-16 15:04:31 -07:00
committed by Robert Nishihara
parent 88fa98e851
commit bea97b425b
14 changed files with 91 additions and 88 deletions
+3 -3
View File
@@ -391,9 +391,9 @@ class Bracket():
def __repr__(self):
status = ", ".join([
"Max Size (n)={}".format(self._n), "Milestone (r)={}".format(
self._cumul_r), "completed={:.1%}".format(
self.completion_percentage())
"Max Size (n)={}".format(self._n),
"Milestone (r)={}".format(self._cumul_r),
"completed={:.1%}".format(self.completion_percentage())
])
counts = collections.Counter([t.status for t in self._all_trials])
trial_statuses = ", ".join(
+22 -18
View File
@@ -370,12 +370,14 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_error(mock_runner, t3)
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialErrored2(self):
"""Check successive halving happened even when last trial failed"""
@@ -405,12 +407,14 @@ class HyperbandSuite(unittest.TestCase):
mock_runner._launch_trial(t)
sched.on_trial_complete(mock_runner, t3, result(1, 12))
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t1,
result(stats[str(1)]["r"], 10)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t2,
result(stats[str(1)]["r"], 10)))
def testTrialEndedEarly2(self):
"""Check successive halving happened even when last trial failed"""
@@ -449,13 +453,13 @@ class HyperbandSuite(unittest.TestCase):
self.assertEqual(len(sched._state["bracket"].current_trials()), 2)
# Make sure that newly added trial gets fair computation (not just 1)
self.assertEqual(TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t,
result(init_units, 12)))
self.assertEqual(
TrialScheduler.CONTINUE,
sched.on_trial_result(mock_runner, t, result(init_units, 12)))
new_units = init_units + int(init_units * sched._eta)
self.assertEqual(TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t,
result(new_units, 12)))
self.assertEqual(
TrialScheduler.PAUSE,
sched.on_trial_result(mock_runner, t, result(new_units, 12)))
def testAlternateMetrics(self):
"""Checking that alternate metrics will pass."""
+10 -7
View File
@@ -174,8 +174,8 @@ class Trial(object):
try:
if error_msg and self.logdir:
self.num_failures += 1
error_file = os.path.join(self.logdir, "error_{}.txt".format(
date_str()))
error_file = os.path.join(self.logdir,
"error_{}.txt".format(date_str()))
with open(error_file, "w") as f:
f.write(error_msg)
self.error_file = error_file
@@ -259,9 +259,10 @@ class Trial(object):
return '{} pid={}'.format(hostname, pid)
pieces = [
'{} [{}]'.format(self._status_string(),
location_string(self.last_result.hostname,
self.last_result.pid)),
'{} [{}]'.format(
self._status_string(),
location_string(self.last_result.hostname,
self.last_result.pid)),
'{} s'.format(int(self.last_result.time_total_s)), '{} ts'.format(
int(self.last_result.timesteps_total))
]
@@ -281,8 +282,10 @@ class Trial(object):
return ', '.join(pieces)
def _status_string(self):
return "{}{}".format(self.status, ", {} failures: {}".format(
self.num_failures, self.error_file) if self.error_file else "")
return "{}{}".format(
self.status, ", {} failures: {}".format(self.num_failures,
self.error_file)
if self.error_file else "")
def has_checkpoint(self):
return self._checkpoint_path is not None or \