mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 07:27:37 +08:00
Added Done to MultiAgentExternalEnv. (#8478)
Co-authored-by: devanderhoff <devanderhoff@hotmail.com>
This commit is contained in:
committed by
GitHub
parent
87cbf2aedd
commit
be1f158747
Vendored
+6
-2
@@ -148,6 +148,7 @@ class ExternalEnv(threading.Thread):
|
||||
|
||||
episode = self._get(episode_id)
|
||||
episode.cur_reward += reward
|
||||
|
||||
if info:
|
||||
episode.cur_info = info or {}
|
||||
|
||||
@@ -238,6 +239,9 @@ class _ExternalEnvEpisode:
|
||||
|
||||
def _send(self):
|
||||
if self.multiagent:
|
||||
if not self.training_enabled:
|
||||
for agent_id in self.cur_info_dict:
|
||||
self.cur_info_dict[agent_id]["training_enabled"] = False
|
||||
item = {
|
||||
"obs": self.new_observation_dict,
|
||||
"reward": self.cur_reward_dict,
|
||||
@@ -261,8 +265,8 @@ class _ExternalEnvEpisode:
|
||||
self.new_observation = None
|
||||
self.new_action = None
|
||||
self.cur_reward = 0.0
|
||||
if not self.training_enabled:
|
||||
item["info"]["training_enabled"] = False
|
||||
if not self.training_enabled:
|
||||
item["info"]["training_enabled"] = False
|
||||
with self.results_avail_condition:
|
||||
self.data_queue.put_nowait(item)
|
||||
self.results_avail_condition.notify()
|
||||
|
||||
+15
-2
@@ -105,7 +105,11 @@ class ExternalMultiAgentEnv(ExternalEnv):
|
||||
|
||||
@PublicAPI
|
||||
@override(ExternalEnv)
|
||||
def log_returns(self, episode_id, reward_dict, info_dict=None):
|
||||
def log_returns(self,
|
||||
episode_id,
|
||||
reward_dict,
|
||||
info_dict=None,
|
||||
multiagent_done_dict=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
The reward will be attributed to the previous action taken by the
|
||||
@@ -115,7 +119,8 @@ class ExternalMultiAgentEnv(ExternalEnv):
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
reward_dict (dict): Reward from the environment agents.
|
||||
info (dict): Optional info dict.
|
||||
info_dict (dict): Optional info dict.
|
||||
multiagent_done_dict (dict): Optional done dict for agents.
|
||||
"""
|
||||
|
||||
episode = self._get(episode_id)
|
||||
@@ -127,6 +132,14 @@ class ExternalMultiAgentEnv(ExternalEnv):
|
||||
episode.cur_reward_dict[agent] += rew
|
||||
else:
|
||||
episode.cur_reward_dict[agent] = rew
|
||||
|
||||
if multiagent_done_dict:
|
||||
for agent, done in multiagent_done_dict.items():
|
||||
if agent in episode.cur_done_dict:
|
||||
episode.cur_done_dict[agent] = done
|
||||
else:
|
||||
episode.cur_done_dict[agent] = done
|
||||
|
||||
if info_dict:
|
||||
episode.cur_info_dict = info_dict or {}
|
||||
|
||||
|
||||
Vendored
+27
-7
@@ -49,8 +49,9 @@ class PolicyClient:
|
||||
address (str): Server to connect to (e.g., "localhost:9090").
|
||||
inference_mode (str): Whether to use 'local' or 'remote' policy
|
||||
inference for computing actions.
|
||||
update_interval (float): If using 'local' inference mode, the
|
||||
policy is refreshed after this many seconds have passed.
|
||||
update_interval (float or None): If using 'local' inference mode,
|
||||
the policy is refreshed after this many seconds have passed,
|
||||
or None for manual control via client.
|
||||
"""
|
||||
self.address = address
|
||||
if inference_mode == "local":
|
||||
@@ -130,7 +131,11 @@ class PolicyClient:
|
||||
})
|
||||
|
||||
@PublicAPI
|
||||
def log_returns(self, episode_id, reward, info=None):
|
||||
def log_returns(self,
|
||||
episode_id,
|
||||
reward,
|
||||
info=None,
|
||||
multiagent_done_dict=None):
|
||||
"""Record returns from the environment.
|
||||
|
||||
The reward will be attributed to the previous action taken by the
|
||||
@@ -140,17 +145,24 @@ class PolicyClient:
|
||||
Arguments:
|
||||
episode_id (str): Episode id returned from start_episode().
|
||||
reward (float): Reward from the environment.
|
||||
info (dict): Extra info dict.
|
||||
multiagent_done_dict (dict): Multi-agent done information.
|
||||
"""
|
||||
|
||||
if self.local:
|
||||
self._update_local_policy()
|
||||
return self.env.log_returns(episode_id, reward, info)
|
||||
if multiagent_done_dict:
|
||||
return self.env.log_returns(episode_id, reward, info,
|
||||
multiagent_done_dict)
|
||||
else:
|
||||
return self.env.log_returns(episode_id, reward, info)
|
||||
|
||||
self._send({
|
||||
"command": PolicyClient.LOG_RETURNS,
|
||||
"reward": reward,
|
||||
"info": info,
|
||||
"episode_id": episode_id,
|
||||
"done": multiagent_done_dict,
|
||||
})
|
||||
|
||||
@PublicAPI
|
||||
@@ -172,6 +184,12 @@ class PolicyClient:
|
||||
"episode_id": episode_id,
|
||||
})
|
||||
|
||||
@PublicAPI
|
||||
def update_policy_weights(self):
|
||||
"""Query the server for new policy weights, if local inference is enabled.
|
||||
"""
|
||||
self._update_local_policy(force=True)
|
||||
|
||||
def _send(self, data):
|
||||
payload = pickle.dumps(data)
|
||||
response = requests.post(self.address, data=payload)
|
||||
@@ -195,9 +213,10 @@ class PolicyClient:
|
||||
kwargs, self._send)
|
||||
self.env = self.rollout_worker.env
|
||||
|
||||
def _update_local_policy(self):
|
||||
def _update_local_policy(self, force=False):
|
||||
assert self.inference_thread.is_alive()
|
||||
if time.time() - self.last_updated > self.update_interval:
|
||||
if (self.update_interval and time.time() - self.last_updated >
|
||||
self.update_interval) or force:
|
||||
logger.info("Querying server for new policy weights.")
|
||||
resp = self._send({
|
||||
"command": PolicyClient.GET_WEIGHTS,
|
||||
@@ -253,7 +272,7 @@ def auto_wrap_external(real_env_creator):
|
||||
"Attempting to convert it automatically to ExternalEnv.")
|
||||
|
||||
if isinstance(real_env, MultiAgentEnv):
|
||||
external_cls = MultiAgentEnv
|
||||
external_cls = ExternalMultiAgentEnv
|
||||
else:
|
||||
external_cls = ExternalEnv
|
||||
|
||||
@@ -268,6 +287,7 @@ def auto_wrap_external(real_env_creator):
|
||||
time.sleep(999999)
|
||||
|
||||
return ExternalEnvWrapper(real_env)
|
||||
return real_env
|
||||
|
||||
return wrapped_creator
|
||||
|
||||
|
||||
Vendored
+7
-2
@@ -179,8 +179,13 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
|
||||
args["episode_id"], args["observation"], args["action"])
|
||||
elif command == PolicyClient.LOG_RETURNS:
|
||||
assert inference_thread.is_alive()
|
||||
child_rollout_worker.env.log_returns(
|
||||
args["episode_id"], args["reward"], args["info"])
|
||||
if args["done"]:
|
||||
child_rollout_worker.env.log_returns(
|
||||
args["episode_id"], args["reward"], args["info"],
|
||||
args["done"])
|
||||
else:
|
||||
child_rollout_worker.env.log_returns(
|
||||
args["episode_id"], args["reward"], args["info"])
|
||||
elif command == PolicyClient.END_EPISODE:
|
||||
assert inference_thread.is_alive()
|
||||
child_rollout_worker.env.end_episode(args["episode_id"],
|
||||
|
||||
Reference in New Issue
Block a user