Added Done to MultiAgentExternalEnv. (#8478)

Co-authored-by: devanderhoff <devanderhoff@hotmail.com>
This commit is contained in:
Dennis van der Hoff
2020-05-18 01:29:47 +02:00
committed by GitHub
parent 87cbf2aedd
commit be1f158747
4 changed files with 55 additions and 13 deletions
+6 -2
View File
@@ -148,6 +148,7 @@ class ExternalEnv(threading.Thread):
episode = self._get(episode_id)
episode.cur_reward += reward
if info:
episode.cur_info = info or {}
@@ -238,6 +239,9 @@ class _ExternalEnvEpisode:
def _send(self):
if self.multiagent:
if not self.training_enabled:
for agent_id in self.cur_info_dict:
self.cur_info_dict[agent_id]["training_enabled"] = False
item = {
"obs": self.new_observation_dict,
"reward": self.cur_reward_dict,
@@ -261,8 +265,8 @@ class _ExternalEnvEpisode:
self.new_observation = None
self.new_action = None
self.cur_reward = 0.0
if not self.training_enabled:
item["info"]["training_enabled"] = False
if not self.training_enabled:
item["info"]["training_enabled"] = False
with self.results_avail_condition:
self.data_queue.put_nowait(item)
self.results_avail_condition.notify()
+15 -2
View File
@@ -105,7 +105,11 @@ class ExternalMultiAgentEnv(ExternalEnv):
@PublicAPI
@override(ExternalEnv)
def log_returns(self, episode_id, reward_dict, info_dict=None):
def log_returns(self,
episode_id,
reward_dict,
info_dict=None,
multiagent_done_dict=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
@@ -115,7 +119,8 @@ class ExternalMultiAgentEnv(ExternalEnv):
Arguments:
episode_id (str): Episode id returned from start_episode().
reward_dict (dict): Reward from the environment agents.
info (dict): Optional info dict.
info_dict (dict): Optional info dict.
multiagent_done_dict (dict): Optional done dict for agents.
"""
episode = self._get(episode_id)
@@ -127,6 +132,14 @@ class ExternalMultiAgentEnv(ExternalEnv):
episode.cur_reward_dict[agent] += rew
else:
episode.cur_reward_dict[agent] = rew
if multiagent_done_dict:
for agent, done in multiagent_done_dict.items():
if agent in episode.cur_done_dict:
episode.cur_done_dict[agent] = done
else:
episode.cur_done_dict[agent] = done
if info_dict:
episode.cur_info_dict = info_dict or {}
+27 -7
View File
@@ -49,8 +49,9 @@ class PolicyClient:
address (str): Server to connect to (e.g., "localhost:9090").
inference_mode (str): Whether to use 'local' or 'remote' policy
inference for computing actions.
update_interval (float): If using 'local' inference mode, the
policy is refreshed after this many seconds have passed.
update_interval (float or None): If using 'local' inference mode,
the policy is refreshed after this many seconds have passed,
or None for manual control via client.
"""
self.address = address
if inference_mode == "local":
@@ -130,7 +131,11 @@ class PolicyClient:
})
@PublicAPI
def log_returns(self, episode_id, reward, info=None):
def log_returns(self,
episode_id,
reward,
info=None,
multiagent_done_dict=None):
"""Record returns from the environment.
The reward will be attributed to the previous action taken by the
@@ -140,17 +145,24 @@ class PolicyClient:
Arguments:
episode_id (str): Episode id returned from start_episode().
reward (float): Reward from the environment.
info (dict): Extra info dict.
multiagent_done_dict (dict): Multi-agent done information.
"""
if self.local:
self._update_local_policy()
return self.env.log_returns(episode_id, reward, info)
if multiagent_done_dict:
return self.env.log_returns(episode_id, reward, info,
multiagent_done_dict)
else:
return self.env.log_returns(episode_id, reward, info)
self._send({
"command": PolicyClient.LOG_RETURNS,
"reward": reward,
"info": info,
"episode_id": episode_id,
"done": multiagent_done_dict,
})
@PublicAPI
@@ -172,6 +184,12 @@ class PolicyClient:
"episode_id": episode_id,
})
@PublicAPI
def update_policy_weights(self):
"""Query the server for new policy weights, if local inference is enabled.
"""
self._update_local_policy(force=True)
def _send(self, data):
payload = pickle.dumps(data)
response = requests.post(self.address, data=payload)
@@ -195,9 +213,10 @@ class PolicyClient:
kwargs, self._send)
self.env = self.rollout_worker.env
def _update_local_policy(self):
def _update_local_policy(self, force=False):
assert self.inference_thread.is_alive()
if time.time() - self.last_updated > self.update_interval:
if (self.update_interval and time.time() - self.last_updated >
self.update_interval) or force:
logger.info("Querying server for new policy weights.")
resp = self._send({
"command": PolicyClient.GET_WEIGHTS,
@@ -253,7 +272,7 @@ def auto_wrap_external(real_env_creator):
"Attempting to convert it automatically to ExternalEnv.")
if isinstance(real_env, MultiAgentEnv):
external_cls = MultiAgentEnv
external_cls = ExternalMultiAgentEnv
else:
external_cls = ExternalEnv
@@ -268,6 +287,7 @@ def auto_wrap_external(real_env_creator):
time.sleep(999999)
return ExternalEnvWrapper(real_env)
return real_env
return wrapped_creator
+7 -2
View File
@@ -179,8 +179,13 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue):
args["episode_id"], args["observation"], args["action"])
elif command == PolicyClient.LOG_RETURNS:
assert inference_thread.is_alive()
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"])
if args["done"]:
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"],
args["done"])
else:
child_rollout_worker.env.log_returns(
args["episode_id"], args["reward"], args["info"])
elif command == PolicyClient.END_EPISODE:
assert inference_thread.is_alive()
child_rollout_worker.env.end_episode(args["episode_id"],