diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 10436bb97..6a75d73aa 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -42,7 +42,7 @@ class Agent(object): def __init__( self, env_creator, config, local_dir='/tmp/ray', - upload_dir=None, agent_id=None): + upload_dir=None, experiment_tag=None): """Initialize an RLLib agent. Args: @@ -53,8 +53,10 @@ class Agent(object): be placed. upload_dir (str): Optional remote URI like s3://bucketname/ where results will be uploaded. - agent_id (str): Optional unique identifier for this agent, used - to determine where to store results in the local dir. + experiment_tag (str): Optional string containing extra metadata + about the experiment, e.g. a summary of parameters. This string + will be included in the logdir path and when displaying agent + progress. """ self._initialize_ok = False self._experiment_id = uuid.uuid4().hex @@ -63,7 +65,10 @@ class Agent(object): env_name = env_creator self.env_creator = lambda: gym.make(env_name) else: - env_name = "custom" + if hasattr(env_creator, "env_name"): + env_name = env_creator.env_name + else: + env_name = "custom" self.env_creator = env_creator self.config = self._default_config.copy() @@ -75,7 +80,7 @@ class Agent(object): "all agent configs: {}".format(k, self.config.keys())) self.config.update(config) self.config.update({ - "agent_id": agent_id, + "experiment_tag": experiment_tag, "alg": self._agent_name, "env_name": env_name, "experiment_id": self._experiment_id, @@ -84,7 +89,7 @@ class Agent(object): logdir_suffix = "{}_{}_{}".format( env_name, self._agent_name, - agent_id or datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) + experiment_tag or datetime.today().strftime("%Y-%m-%d_%H-%M-%S")) if not os.path.exists(local_dir): os.makedirs(local_dir) diff --git a/python/ray/tune/config_parser.py b/python/ray/tune/config_parser.py index 9f4135841..1cba83f73 100644 --- a/python/ray/tune/config_parser.py +++ b/python/ray/tune/config_parser.py @@ -93,13 +93,13 @@ def parse_to_trials(config): next_cfg, resolved_vars = grid_search.next() resolved, resolved_vars = resolve(next_cfg, resolved_vars, i) if resolved_vars: - agent_id = "{}_{}".format( + experiment_tag = "{}_{}".format( i, param_str(resolved, resolved_vars)) else: - agent_id = str(i) + experiment_tag = str(i) trials.append(Trial( args.env, args.alg, resolved, - os.path.join(args.local_dir, experiment_name), agent_id, + os.path.join(args.local_dir, experiment_name), experiment_tag, args.resources, args.stop, args.checkpoint_freq, None, args.upload_dir)) diff --git a/python/ray/tune/examples/tune_mnist_ray.py b/python/ray/tune/examples/tune_mnist_ray.py index 2a10070d7..a306e7af5 100755 --- a/python/ray/tune/examples/tune_mnist_ray.py +++ b/python/ray/tune/examples/tune_mnist_ray.py @@ -214,7 +214,7 @@ if __name__ == '__main__': 'script_min_iter_time_s': 1, 'activation': act, }, - agent_id='act={}'.format(act))) + experiment_tag='act={}'.format(act))) ray.init() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index d92c8a9c5..9f5713253 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -31,7 +31,7 @@ class Trial(object): def __init__( self, env_creator, alg, config={}, local_dir='/tmp/ray', - agent_id=None, resources=Resources(cpu=1, gpu=0), + experiment_tag=None, resources=Resources(cpu=1, gpu=0), stopping_criterion={}, checkpoint_freq=None, restore_path=None, upload_dir=None): """Initialize a new trial. @@ -45,11 +45,14 @@ class Trial(object): if type(env_creator) is str: self.env_name = env_creator else: - self.env_name = "custom" + if hasattr(env_creator, "env_name"): + self.env_name = env_creator.env_name + else: + self.env_name = "custom" self.alg = alg self.config = config self.local_dir = local_dir - self.agent_id = agent_id + self.experiment_tag = experiment_tag self.resources = resources self.stopping_criterion = stopping_criterion self.checkpoint_freq = checkpoint_freq @@ -77,7 +80,7 @@ class Trial(object): agent_cls) self.agent = cls.remote( self.env_creator, self.config, self.local_dir, self.upload_dir, - agent_id=self.agent_id) + experiment_tag=self.experiment_tag) if self.restore_path: ray.get(self.agent.restore.remote(self.restore_path)) @@ -178,8 +181,8 @@ class Trial(object): def __str__(self): identifier = '{}_{}'.format(self.alg, self.env_name) - if self.agent_id: - identifier += '_' + self.agent_id + if self.experiment_tag: + identifier += '_' + self.experiment_tag return identifier def __eq__(self, other): diff --git a/test/trial_runner_test.py b/test/trial_runner_test.py index f4ae36fdd..3b842df44 100644 --- a/test/trial_runner_test.py +++ b/test/trial_runner_test.py @@ -26,9 +26,9 @@ class ConfigParserTest(unittest.TestCase): self.assertEqual(trials[0].env_name, "Pong-v0") self.assertEqual(trials[0].config, {"foo": "bar"}) self.assertEqual(trials[0].alg, "PPO") - self.assertEqual(trials[0].agent_id, "0") + self.assertEqual(trials[0].experiment_tag, "0") self.assertEqual(trials[0].local_dir, "/tmp/ray/tune-pong") - self.assertEqual(trials[1].agent_id, "1") + self.assertEqual(trials[1].experiment_tag, "1") def testEval(self): trials = parse_to_trials({ @@ -43,7 +43,7 @@ class ConfigParserTest(unittest.TestCase): }) self.assertEqual(len(trials), 1) self.assertEqual(trials[0].config, {"foo": 4}) - self.assertEqual(trials[0].agent_id, "0_foo=4") + self.assertEqual(trials[0].experiment_tag, "0_foo=4") def testGridSearch(self): trials = parse_to_trials({ @@ -62,9 +62,9 @@ class ConfigParserTest(unittest.TestCase): }) self.assertEqual(len(trials), 6) self.assertEqual(trials[0].config, {"bar": True, "foo": 1}) - self.assertEqual(trials[0].agent_id, "0_bar=True_foo=1") + self.assertEqual(trials[0].experiment_tag, "0_bar=True_foo=1") self.assertEqual(trials[1].config, {"bar": False, "foo": 1}) - self.assertEqual(trials[1].agent_id, "1_bar=False_foo=1") + self.assertEqual(trials[1].experiment_tag, "1_bar=False_foo=1") self.assertEqual(trials[2].config, {"bar": True, "foo": 2}) self.assertEqual(trials[3].config, {"bar": False, "foo": 2}) self.assertEqual(trials[4].config, {"bar": True, "foo": 3}) @@ -90,7 +90,7 @@ class ConfigParserTest(unittest.TestCase): }) self.assertEqual(len(trials), 1) self.assertEqual(trials[0].config, {"bar": True, "foo": 1, "qux": 4}) - self.assertEqual(trials[0].agent_id, "0_bar=True_foo=1_qux=4") + self.assertEqual(trials[0].experiment_tag, "0_bar=True_foo=1_qux=4") class TrialRunnerTest(unittest.TestCase):