From e1d3f7eba6775093dff6765784e2390fd326b826 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 16 Apr 2020 16:13:45 +0800 Subject: [PATCH] [rllib]Add config for rllib to support set python environments (#8026) * support set extra python environments * wrap value with str * Apply suggestions from code review Co-Authored-By: Eric Liang * addresses comments * fix lint errors * remove unrelated changes due to format.sh * remove unrelated changes due to format.sh Co-authored-by: Eric Liang --- rllib/agents/trainer.py | 8 +++++++- rllib/evaluation/rollout_worker.py | 11 ++++++++++- rllib/evaluation/worker_set.py | 10 +++++++++- rllib/tests/test_rollout_worker.py | 16 ++++++++++++++++ 4 files changed, 42 insertions(+), 3 deletions(-) diff --git a/rllib/agents/trainer.py b/rllib/agents/trainer.py index 072b4a072..f3c8f68a8 100644 --- a/rllib/agents/trainer.py +++ b/rllib/agents/trainer.py @@ -276,6 +276,11 @@ COMMON_CONFIG = { # each worker, so that identically configured trials will have identical # results. This makes experiments reproducible. "seed": None, + # Any extra python env vars to set in the trainer process, e.g., + # {"OMP_NUM_THREADS": "16"} + "extra_python_environs_for_driver": {}, + # The extra python environments need to set for worker processes. + "extra_python_environs_for_worker": {}, # === Advanced Resource Settings === # Number of CPUs to allocate per worker. @@ -397,7 +402,8 @@ class Trainer(Trainable): _allow_unknown_subkeys = [ "tf_session_args", "local_tf_session_args", "env_config", "model", "optimizer", "multiagent", "custom_resources_per_worker", - "evaluation_config", "exploration_config" + "evaluation_config", "exploration_config", + "extra_python_environs_for_driver", "extra_python_environs_for_worker" ] # List of top level keys with value=dict, for which we always override the diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 6f9443d60..20386764f 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -3,6 +3,7 @@ import numpy as np import gym import logging import pickle +import os import ray from ray.util.debug import log_once, disable_log_once_globally, \ @@ -146,7 +147,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): soft_horizon=False, no_done_at_end=False, seed=None, - _fake_sampler=False): + _fake_sampler=False, + extra_python_environs=None): """Initialize a rollout worker. Arguments: @@ -239,6 +241,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): seed (int): Set the seed of both np and tf to this value to to ensure each remote worker has unique exploration behavior. _fake_sampler (bool): Use a fake (inf speed) sampler for testing. + extra_python_environs (dict): Extra python environments need to + be set. """ self._original_kwargs = locals().copy() del self._original_kwargs["self"] @@ -246,6 +250,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker): global _global_worker _global_worker = self + # set extra environs first + if extra_python_environs: + for key, value in extra_python_environs.items(): + os.environ[key] = str(value) + def gen_rollouts(): while True: yield self.sample() diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 44bd10932..3acabf9aa 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -234,6 +234,13 @@ class WorkerSet: tmp[k] = (policy, v[1], v[2], v[3]) policy = tmp + if worker_index == 0: + extra_python_environs = config.get( + "extra_python_environs_for_driver", None) + else: + extra_python_environs = config.get( + "extra_python_environs_for_worker", None) + worker = cls( env_creator, policy, @@ -269,7 +276,8 @@ class WorkerSet: no_done_at_end=config["no_done_at_end"], seed=(config["seed"] + worker_index) if config["seed"] is not None else None, - _fake_sampler=config.get("_fake_sampler", False)) + _fake_sampler=config.get("_fake_sampler", False), + extra_python_environs=extra_python_environs) # Check for correct policy class (only locally, remote Workers should # create the exact same Policy types). diff --git a/rllib/tests/test_rollout_worker.py b/rllib/tests/test_rollout_worker.py index 21e79a435..9b73c43ee 100644 --- a/rllib/tests/test_rollout_worker.py +++ b/rllib/tests/test_rollout_worker.py @@ -1,6 +1,7 @@ from collections import Counter import gym import numpy as np +import os import random import time import unittest @@ -488,6 +489,21 @@ class TestRolloutWorker(unittest.TestCase): self.assertNotEqual(obs_f.buffer.n, 0) return obs_f + def test_extra_python_envs(self): + extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"} + self.assertFalse("env_key_1" in os.environ) + self.assertFalse("env_key_2" in os.environ) + RolloutWorker( + env_creator=lambda _: MockEnv(10), + policy=MockPolicy, + extra_python_environs=extra_envs) + self.assertTrue("env_key_1" in os.environ) + self.assertTrue("env_key_2" in os.environ) + + # reset to original + del os.environ["env_key_1"] + del os.environ["env_key_2"] + if __name__ == "__main__": import pytest