diff --git a/rllib/env/external_env.py b/rllib/env/external_env.py index 89b1c544f..6f2c36928 100644 --- a/rllib/env/external_env.py +++ b/rllib/env/external_env.py @@ -148,6 +148,7 @@ class ExternalEnv(threading.Thread): episode = self._get(episode_id) episode.cur_reward += reward + if info: episode.cur_info = info or {} @@ -238,6 +239,9 @@ class _ExternalEnvEpisode: def _send(self): if self.multiagent: + if not self.training_enabled: + for agent_id in self.cur_info_dict: + self.cur_info_dict[agent_id]["training_enabled"] = False item = { "obs": self.new_observation_dict, "reward": self.cur_reward_dict, @@ -261,8 +265,8 @@ class _ExternalEnvEpisode: self.new_observation = None self.new_action = None self.cur_reward = 0.0 - if not self.training_enabled: - item["info"]["training_enabled"] = False + if not self.training_enabled: + item["info"]["training_enabled"] = False with self.results_avail_condition: self.data_queue.put_nowait(item) self.results_avail_condition.notify() diff --git a/rllib/env/external_multi_agent_env.py b/rllib/env/external_multi_agent_env.py index 414679274..7604aca42 100644 --- a/rllib/env/external_multi_agent_env.py +++ b/rllib/env/external_multi_agent_env.py @@ -105,7 +105,11 @@ class ExternalMultiAgentEnv(ExternalEnv): @PublicAPI @override(ExternalEnv) - def log_returns(self, episode_id, reward_dict, info_dict=None): + def log_returns(self, + episode_id, + reward_dict, + info_dict=None, + multiagent_done_dict=None): """Record returns from the environment. The reward will be attributed to the previous action taken by the @@ -115,7 +119,8 @@ class ExternalMultiAgentEnv(ExternalEnv): Arguments: episode_id (str): Episode id returned from start_episode(). reward_dict (dict): Reward from the environment agents. - info (dict): Optional info dict. + info_dict (dict): Optional info dict. + multiagent_done_dict (dict): Optional done dict for agents. """ episode = self._get(episode_id) @@ -127,6 +132,14 @@ class ExternalMultiAgentEnv(ExternalEnv): episode.cur_reward_dict[agent] += rew else: episode.cur_reward_dict[agent] = rew + + if multiagent_done_dict: + for agent, done in multiagent_done_dict.items(): + if agent in episode.cur_done_dict: + episode.cur_done_dict[agent] = done + else: + episode.cur_done_dict[agent] = done + if info_dict: episode.cur_info_dict = info_dict or {} diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 2fbace2d3..6249c3f93 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -49,8 +49,9 @@ class PolicyClient: address (str): Server to connect to (e.g., "localhost:9090"). inference_mode (str): Whether to use 'local' or 'remote' policy inference for computing actions. - update_interval (float): If using 'local' inference mode, the - policy is refreshed after this many seconds have passed. + update_interval (float or None): If using 'local' inference mode, + the policy is refreshed after this many seconds have passed, + or None for manual control via client. """ self.address = address if inference_mode == "local": @@ -130,7 +131,11 @@ class PolicyClient: }) @PublicAPI - def log_returns(self, episode_id, reward, info=None): + def log_returns(self, + episode_id, + reward, + info=None, + multiagent_done_dict=None): """Record returns from the environment. The reward will be attributed to the previous action taken by the @@ -140,17 +145,24 @@ class PolicyClient: Arguments: episode_id (str): Episode id returned from start_episode(). reward (float): Reward from the environment. + info (dict): Extra info dict. + multiagent_done_dict (dict): Multi-agent done information. """ if self.local: self._update_local_policy() - return self.env.log_returns(episode_id, reward, info) + if multiagent_done_dict: + return self.env.log_returns(episode_id, reward, info, + multiagent_done_dict) + else: + return self.env.log_returns(episode_id, reward, info) self._send({ "command": PolicyClient.LOG_RETURNS, "reward": reward, "info": info, "episode_id": episode_id, + "done": multiagent_done_dict, }) @PublicAPI @@ -172,6 +184,12 @@ class PolicyClient: "episode_id": episode_id, }) + @PublicAPI + def update_policy_weights(self): + """Query the server for new policy weights, if local inference is enabled. + """ + self._update_local_policy(force=True) + def _send(self, data): payload = pickle.dumps(data) response = requests.post(self.address, data=payload) @@ -195,9 +213,10 @@ class PolicyClient: kwargs, self._send) self.env = self.rollout_worker.env - def _update_local_policy(self): + def _update_local_policy(self, force=False): assert self.inference_thread.is_alive() - if time.time() - self.last_updated > self.update_interval: + if (self.update_interval and time.time() - self.last_updated > + self.update_interval) or force: logger.info("Querying server for new policy weights.") resp = self._send({ "command": PolicyClient.GET_WEIGHTS, @@ -253,7 +272,7 @@ def auto_wrap_external(real_env_creator): "Attempting to convert it automatically to ExternalEnv.") if isinstance(real_env, MultiAgentEnv): - external_cls = MultiAgentEnv + external_cls = ExternalMultiAgentEnv else: external_cls = ExternalEnv @@ -268,6 +287,7 @@ def auto_wrap_external(real_env_creator): time.sleep(999999) return ExternalEnvWrapper(real_env) + return real_env return wrapped_creator diff --git a/rllib/env/policy_server_input.py b/rllib/env/policy_server_input.py index 6bc2d7267..2d8a4770d 100644 --- a/rllib/env/policy_server_input.py +++ b/rllib/env/policy_server_input.py @@ -179,8 +179,13 @@ def _make_handler(rollout_worker, samples_queue, metrics_queue): args["episode_id"], args["observation"], args["action"]) elif command == PolicyClient.LOG_RETURNS: assert inference_thread.is_alive() - child_rollout_worker.env.log_returns( - args["episode_id"], args["reward"], args["info"]) + if args["done"]: + child_rollout_worker.env.log_returns( + args["episode_id"], args["reward"], args["info"], + args["done"]) + else: + child_rollout_worker.env.log_returns( + args["episode_id"], args["reward"], args["info"]) elif command == PolicyClient.END_EPISODE: assert inference_thread.is_alive() child_rollout_worker.env.end_episode(args["episode_id"],