mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 02:01:24 +08:00
[rllib] [minor] Rename agent_id to experiment_tag (#1143)

* tagstr
* doc
* rename
* fix test
This commit is contained in:
@@ -42,7 +42,7 @@ class Agent(object):
|
||||
|
||||
def __init__(
|
||||
self, env_creator, config, local_dir='/tmp/ray',
|
||||
upload_dir=None, agent_id=None):
|
||||
upload_dir=None, experiment_tag=None):
|
||||
"""Initialize an RLLib agent.
|
||||
|
||||
Args:
|
||||
@@ -53,8 +53,10 @@ class Agent(object):
|
||||
be placed.
|
||||
upload_dir (str): Optional remote URI like s3://bucketname/ where
|
||||
results will be uploaded.
|
||||
agent_id (str): Optional unique identifier for this agent, used
|
||||
to determine where to store results in the local dir.
|
||||
experiment_tag (str): Optional string containing extra metadata
|
||||
about the experiment, e.g. a summary of parameters. This string
|
||||
will be included in the logdir path and when displaying agent
|
||||
progress.
|
||||
"""
|
||||
self._initialize_ok = False
|
||||
self._experiment_id = uuid.uuid4().hex
|
||||
@@ -63,7 +65,10 @@ class Agent(object):
|
||||
env_name = env_creator
|
||||
self.env_creator = lambda: gym.make(env_name)
|
||||
else:
|
||||
env_name = "custom"
|
||||
if hasattr(env_creator, "env_name"):
|
||||
env_name = env_creator.env_name
|
||||
else:
|
||||
env_name = "custom"
|
||||
self.env_creator = env_creator
|
||||
|
||||
self.config = self._default_config.copy()
|
||||
@@ -75,7 +80,7 @@ class Agent(object):
|
||||
"all agent configs: {}".format(k, self.config.keys()))
|
||||
self.config.update(config)
|
||||
self.config.update({
|
||||
"agent_id": agent_id,
|
||||
"experiment_tag": experiment_tag,
|
||||
"alg": self._agent_name,
|
||||
"env_name": env_name,
|
||||
"experiment_id": self._experiment_id,
|
||||
@@ -84,7 +89,7 @@ class Agent(object):
|
||||
logdir_suffix = "{}_{}_{}".format(
|
||||
env_name,
|
||||
self._agent_name,
|
||||
agent_id or datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
experiment_tag or datetime.today().strftime("%Y-%m-%d_%H-%M-%S"))
|
||||
|
||||
if not os.path.exists(local_dir):
|
||||
os.makedirs(local_dir)
|
||||
|
||||
@@ -93,13 +93,13 @@ def parse_to_trials(config):
|
||||
next_cfg, resolved_vars = grid_search.next()
|
||||
resolved, resolved_vars = resolve(next_cfg, resolved_vars, i)
|
||||
if resolved_vars:
|
||||
agent_id = "{}_{}".format(
|
||||
experiment_tag = "{}_{}".format(
|
||||
i, param_str(resolved, resolved_vars))
|
||||
else:
|
||||
agent_id = str(i)
|
||||
experiment_tag = str(i)
|
||||
trials.append(Trial(
|
||||
args.env, args.alg, resolved,
|
||||
os.path.join(args.local_dir, experiment_name), agent_id,
|
||||
os.path.join(args.local_dir, experiment_name), experiment_tag,
|
||||
args.resources, args.stop, args.checkpoint_freq, None,
|
||||
args.upload_dir))
|
||||
|
||||
|
||||
@@ -214,7 +214,7 @@ if __name__ == '__main__':
|
||||
'script_min_iter_time_s': 1,
|
||||
'activation': act,
|
||||
},
|
||||
agent_id='act={}'.format(act)))
|
||||
experiment_tag='act={}'.format(act)))
|
||||
|
||||
ray.init()
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class Trial(object):
|
||||
|
||||
def __init__(
|
||||
self, env_creator, alg, config={}, local_dir='/tmp/ray',
|
||||
agent_id=None, resources=Resources(cpu=1, gpu=0),
|
||||
experiment_tag=None, resources=Resources(cpu=1, gpu=0),
|
||||
stopping_criterion={}, checkpoint_freq=None,
|
||||
restore_path=None, upload_dir=None):
|
||||
"""Initialize a new trial.
|
||||
@@ -45,11 +45,14 @@ class Trial(object):
|
||||
if type(env_creator) is str:
|
||||
self.env_name = env_creator
|
||||
else:
|
||||
self.env_name = "custom"
|
||||
if hasattr(env_creator, "env_name"):
|
||||
self.env_name = env_creator.env_name
|
||||
else:
|
||||
self.env_name = "custom"
|
||||
self.alg = alg
|
||||
self.config = config
|
||||
self.local_dir = local_dir
|
||||
self.agent_id = agent_id
|
||||
self.experiment_tag = experiment_tag
|
||||
self.resources = resources
|
||||
self.stopping_criterion = stopping_criterion
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
@@ -77,7 +80,7 @@ class Trial(object):
|
||||
agent_cls)
|
||||
self.agent = cls.remote(
|
||||
self.env_creator, self.config, self.local_dir, self.upload_dir,
|
||||
agent_id=self.agent_id)
|
||||
experiment_tag=self.experiment_tag)
|
||||
if self.restore_path:
|
||||
ray.get(self.agent.restore.remote(self.restore_path))
|
||||
|
||||
@@ -178,8 +181,8 @@ class Trial(object):
|
||||
|
||||
def __str__(self):
|
||||
identifier = '{}_{}'.format(self.alg, self.env_name)
|
||||
if self.agent_id:
|
||||
identifier += '_' + self.agent_id
|
||||
if self.experiment_tag:
|
||||
identifier += '_' + self.experiment_tag
|
||||
return identifier
|
||||
|
||||
def __eq__(self, other):
|
||||
|
||||
Reference in New Issue
Block a user