[rllib]Add config for rllib to support set python environments (#8026)

* support set extra python environments

* wrap value with str

* Apply suggestions from code review

Co-Authored-By: Eric Liang <ekhliang@gmail.com>

* addresses comments

* fix lint errors

* remove unrelated changes due to format.sh

* remove unrelated changes due to format.sh

Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
Xianyang Liu
2020-04-16 16:13:45 +08:00
committed by GitHub
parent 9345d03ffb
commit e1d3f7eba6
4 changed files with 42 additions and 3 deletions
+7 -1
View File
@@ -276,6 +276,11 @@ COMMON_CONFIG = {
# each worker, so that identically configured trials will have identical
# results. This makes experiments reproducible.
"seed": None,
# Any extra python env vars to set in the trainer process, e.g.,
# {"OMP_NUM_THREADS": "16"}
"extra_python_environs_for_driver": {},
# Any extra python env vars to set in the worker processes.
"extra_python_environs_for_worker": {},
# === Advanced Resource Settings ===
# Number of CPUs to allocate per worker.
@@ -397,7 +402,8 @@ class Trainer(Trainable):
_allow_unknown_subkeys = [
"tf_session_args", "local_tf_session_args", "env_config", "model",
"optimizer", "multiagent", "custom_resources_per_worker",
"evaluation_config", "exploration_config"
"evaluation_config", "exploration_config",
"extra_python_environs_for_driver", "extra_python_environs_for_worker"
]
# List of top level keys with value=dict, for which we always override the
+10 -1
View File
@@ -3,6 +3,7 @@ import numpy as np
import gym
import logging
import pickle
import os
import ray
from ray.util.debug import log_once, disable_log_once_globally, \
@@ -146,7 +147,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
soft_horizon=False,
no_done_at_end=False,
seed=None,
_fake_sampler=False):
_fake_sampler=False,
extra_python_environs=None):
"""Initialize a rollout worker.
Arguments:
@@ -239,6 +241,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
seed (int): Set the seed of both np and tf to this value to
to ensure each remote worker has unique exploration behavior.
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
extra_python_environs (dict): Extra environment variables to set
(via os.environ) when this worker is initialized.
"""
self._original_kwargs = locals().copy()
del self._original_kwargs["self"]
@@ -246,6 +250,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
global _global_worker
_global_worker = self
# set extra environs first
if extra_python_environs:
for key, value in extra_python_environs.items():
os.environ[key] = str(value)
def gen_rollouts():
while True:
yield self.sample()
+9 -1
View File
@@ -234,6 +234,13 @@ class WorkerSet:
tmp[k] = (policy, v[1], v[2], v[3])
policy = tmp
if worker_index == 0:
extra_python_environs = config.get(
"extra_python_environs_for_driver", None)
else:
extra_python_environs = config.get(
"extra_python_environs_for_worker", None)
worker = cls(
env_creator,
policy,
@@ -269,7 +276,8 @@ class WorkerSet:
no_done_at_end=config["no_done_at_end"],
seed=(config["seed"] + worker_index)
if config["seed"] is not None else None,
_fake_sampler=config.get("_fake_sampler", False))
_fake_sampler=config.get("_fake_sampler", False),
extra_python_environs=extra_python_environs)
# Check for correct policy class (only locally, remote Workers should
# create the exact same Policy types).
+16
View File
@@ -1,6 +1,7 @@
from collections import Counter
import gym
import numpy as np
import os
import random
import time
import unittest
@@ -488,6 +489,21 @@ class TestRolloutWorker(unittest.TestCase):
self.assertNotEqual(obs_f.buffer.n, 0)
return obs_f
def test_extra_python_envs(self):
    """Check that RolloutWorker exports `extra_python_environs`.

    Constructing a RolloutWorker with `extra_python_environs` must copy
    every key/value pair of that dict into `os.environ`.
    """
    extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
    for key in extra_envs:
        # Precondition: the variable must not be set yet, otherwise the
        # test could not tell whether the worker was the one setting it.
        self.assertNotIn(key, os.environ)
        # Registered up front so the variable is removed even when a
        # later assertion fails; a trailing `del` would be skipped and
        # leak the variable into other tests in the same process.
        self.addCleanup(os.environ.pop, key, None)

    RolloutWorker(
        env_creator=lambda _: MockEnv(10),
        policy=MockPolicy,
        extra_python_environs=extra_envs)

    # Check the values too, not just that the keys exist.
    for key, value in extra_envs.items():
        self.assertEqual(os.environ.get(key), value)
if __name__ == "__main__":
import pytest