mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
[rllib] Adds agent name & env id to default logdir prefix (#2859)
* Added agent name & env id to default logdir prefix * Revert "Added agent name & env id to default logdir prefix" This reverts commit 07cfdf80d2537da3c67dd4f553c5f3e43671cc7d. * Added default logger creator with informative prefix to Agent * Updated import order & improved str cat * Update agent.py
This commit is contained in:
committed by
Eric Liang
parent
3a3782c39f
commit
b23fd5de13
@@ -7,6 +7,8 @@ import json
|
||||
import numpy as np
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
|
||||
import tensorflow as tf
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
@@ -14,6 +16,8 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.utils import FilterManager, deep_update, merge_dicts
|
||||
from ray.tune.registry import ENV_CREATOR, _global_registry
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
|
||||
COMMON_CONFIG = {
|
||||
# Discount factor of the MDP
|
||||
@@ -190,6 +194,24 @@ class Agent(Trainable):
|
||||
|
||||
# Agents allow env ids to be passed directly to the constructor.
|
||||
self._env_id = env or config.get("env")
|
||||
|
||||
# Create a default logger creator if no logger_creator is specified
|
||||
if logger_creator is None:
|
||||
timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
logdir_prefix = '_'.join([self._agent_name, self._env_id, timestr])
|
||||
|
||||
def default_logger_creator(config):
|
||||
"""Creates a Unified logger with a default logdir prefix
|
||||
containing the agent name and the env id
|
||||
"""
|
||||
if not os.path.exists(DEFAULT_RESULTS_DIR):
|
||||
os.makedirs(DEFAULT_RESULTS_DIR)
|
||||
logdir = tempfile.mkdtemp(
|
||||
prefix=logdir_prefix, dir=DEFAULT_RESULTS_DIR)
|
||||
return UnifiedLogger(config, logdir, None)
|
||||
|
||||
logger_creator = default_logger_creator
|
||||
|
||||
Trainable.__init__(self, config, logger_creator)
|
||||
|
||||
def train(self):
|
||||
|
||||
Reference in New Issue
Block a user