mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 16:31:38 +08:00
[rllib]Add config for rllib to support set python environments (#8026)
* support set extra python environments * wrap value with str * Apply suggestions from code review Co-Authored-By: Eric Liang <ekhliang@gmail.com> * addresses comments * fix lint errors * remove unrelated changes due to format.sh * remove unrelated changes due to format.sh Co-authored-by: Eric Liang <ekhliang@gmail.com>
This commit is contained in:
@@ -276,6 +276,11 @@ COMMON_CONFIG = {
|
||||
# each worker, so that identically configured trials will have identical
|
||||
# results. This makes experiments reproducible.
|
||||
"seed": None,
|
||||
# Any extra python env vars to set in the trainer process, e.g.,
|
||||
# {"OMP_NUM_THREADS": "16"}
|
||||
"extra_python_environs_for_driver": {},
|
||||
# The extra python environments need to set for worker processes.
|
||||
"extra_python_environs_for_worker": {},
|
||||
|
||||
# === Advanced Resource Settings ===
|
||||
# Number of CPUs to allocate per worker.
|
||||
@@ -397,7 +402,8 @@ class Trainer(Trainable):
|
||||
_allow_unknown_subkeys = [
|
||||
"tf_session_args", "local_tf_session_args", "env_config", "model",
|
||||
"optimizer", "multiagent", "custom_resources_per_worker",
|
||||
"evaluation_config", "exploration_config"
|
||||
"evaluation_config", "exploration_config",
|
||||
"extra_python_environs_for_driver", "extra_python_environs_for_worker"
|
||||
]
|
||||
|
||||
# List of top level keys with value=dict, for which we always override the
|
||||
|
||||
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import gym
|
||||
import logging
|
||||
import pickle
|
||||
import os
|
||||
|
||||
import ray
|
||||
from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
@@ -146,7 +147,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
soft_horizon=False,
|
||||
no_done_at_end=False,
|
||||
seed=None,
|
||||
_fake_sampler=False):
|
||||
_fake_sampler=False,
|
||||
extra_python_environs=None):
|
||||
"""Initialize a rollout worker.
|
||||
|
||||
Arguments:
|
||||
@@ -239,6 +241,8 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
seed (int): Set the seed of both np and tf to this value to
|
||||
to ensure each remote worker has unique exploration behavior.
|
||||
_fake_sampler (bool): Use a fake (inf speed) sampler for testing.
|
||||
extra_python_environs (dict): Extra python environments need to
|
||||
be set.
|
||||
"""
|
||||
self._original_kwargs = locals().copy()
|
||||
del self._original_kwargs["self"]
|
||||
@@ -246,6 +250,11 @@ class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
|
||||
global _global_worker
|
||||
_global_worker = self
|
||||
|
||||
# set extra environs first
|
||||
if extra_python_environs:
|
||||
for key, value in extra_python_environs.items():
|
||||
os.environ[key] = str(value)
|
||||
|
||||
def gen_rollouts():
|
||||
while True:
|
||||
yield self.sample()
|
||||
|
||||
@@ -234,6 +234,13 @@ class WorkerSet:
|
||||
tmp[k] = (policy, v[1], v[2], v[3])
|
||||
policy = tmp
|
||||
|
||||
if worker_index == 0:
|
||||
extra_python_environs = config.get(
|
||||
"extra_python_environs_for_driver", None)
|
||||
else:
|
||||
extra_python_environs = config.get(
|
||||
"extra_python_environs_for_worker", None)
|
||||
|
||||
worker = cls(
|
||||
env_creator,
|
||||
policy,
|
||||
@@ -269,7 +276,8 @@ class WorkerSet:
|
||||
no_done_at_end=config["no_done_at_end"],
|
||||
seed=(config["seed"] + worker_index)
|
||||
if config["seed"] is not None else None,
|
||||
_fake_sampler=config.get("_fake_sampler", False))
|
||||
_fake_sampler=config.get("_fake_sampler", False),
|
||||
extra_python_environs=extra_python_environs)
|
||||
|
||||
# Check for correct policy class (only locally, remote Workers should
|
||||
# create the exact same Policy types).
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import Counter
|
||||
import gym
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
@@ -488,6 +489,21 @@ class TestRolloutWorker(unittest.TestCase):
|
||||
self.assertNotEqual(obs_f.buffer.n, 0)
|
||||
return obs_f
|
||||
|
||||
def test_extra_python_envs(self):
|
||||
extra_envs = {"env_key_1": "env_value_1", "env_key_2": "env_value_2"}
|
||||
self.assertFalse("env_key_1" in os.environ)
|
||||
self.assertFalse("env_key_2" in os.environ)
|
||||
RolloutWorker(
|
||||
env_creator=lambda _: MockEnv(10),
|
||||
policy=MockPolicy,
|
||||
extra_python_environs=extra_envs)
|
||||
self.assertTrue("env_key_1" in os.environ)
|
||||
self.assertTrue("env_key_2" in os.environ)
|
||||
|
||||
# reset to original
|
||||
del os.environ["env_key_1"]
|
||||
del os.environ["env_key_2"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
|
||||
Reference in New Issue
Block a user