diff --git a/python/ray/rllib/evaluation/episode.py b/python/ray/rllib/evaluation/episode.py index 307568915..acf7d85cb 100644 --- a/python/ray/rllib/evaluation/episode.py +++ b/python/ray/rllib/evaluation/episode.py @@ -58,6 +58,7 @@ class MultiAgentEpisode(object): self._agent_to_policy = {} self._agent_to_rnn_state = {} self._agent_to_last_obs = {} + self._agent_to_last_raw_obs = {} self._agent_to_last_info = {} self._agent_to_last_action = {} self._agent_to_last_pi_info = {} @@ -82,6 +83,12 @@ class MultiAgentEpisode(object): return self._agent_to_last_obs.get(agent_id) + @DeveloperAPI + def last_raw_obs_for(self, agent_id=_DUMMY_AGENT_ID): + """Returns the last un-preprocessed obs for the specified agent.""" + + return self._agent_to_last_raw_obs.get(agent_id) + @DeveloperAPI def last_info_for(self, agent_id=_DUMMY_AGENT_ID): """Returns the last info for the specified agent.""" @@ -149,10 +156,16 @@ class MultiAgentEpisode(object): def _set_last_observation(self, agent_id, obs): self._agent_to_last_obs[agent_id] = obs + def _set_last_raw_obs(self, agent_id, obs): + self._agent_to_last_raw_obs[agent_id] = obs + def _set_last_info(self, agent_id, info): self._agent_to_last_info[agent_id] = info def _set_last_action(self, agent_id, action): + if agent_id in self._agent_to_last_action: + self._agent_to_prev_action[agent_id] = \ + self._agent_to_last_action[agent_id] self._agent_to_last_action[agent_id] = action def _set_last_pi_info(self, agent_id, pi_info): diff --git a/python/ray/rllib/evaluation/sampler.py b/python/ray/rllib/evaluation/sampler.py index 76a73274d..3d91cc44f 100644 --- a/python/ray/rllib/evaluation/sampler.py +++ b/python/ray/rllib/evaluation/sampler.py @@ -372,6 +372,7 @@ def _process_observations(base_env, policies, batch_builder_pool, last_observation = episode.last_observation_for(agent_id) episode._set_last_observation(agent_id, filtered_obs) + episode._set_last_raw_obs(agent_id, raw_obs) episode._set_last_info(agent_id, infos[env_id].get(agent_id, {})) # Record transition info if applicable diff --git a/python/ray/rllib/examples/custom_metrics_and_callbacks.py b/python/ray/rllib/examples/custom_metrics_and_callbacks.py index af1d25f16..0f0dcb040 100644 --- a/python/ray/rllib/examples/custom_metrics_and_callbacks.py +++ b/python/ray/rllib/examples/custom_metrics_and_callbacks.py @@ -20,6 +20,8 @@ def on_episode_start(info): def on_episode_step(info): episode = info["episode"] pole_angle = abs(episode.last_observation_for()[2]) + raw_angle = abs(episode.last_raw_obs_for()[2]) + assert pole_angle == raw_angle episode.user_data["pole_angles"].append(pole_angle) diff --git a/python/ray/rllib/tests/test_policy_evaluator.py b/python/ray/rllib/tests/test_policy_evaluator.py index a810542ad..d5466865b 100644 --- a/python/ray/rllib/tests/test_policy_evaluator.py +++ b/python/ray/rllib/tests/test_policy_evaluator.py @@ -4,6 +4,7 @@ from __future__ import print_function import gym import numpy as np +import random import time import unittest from collections import Counter @@ -27,7 +28,7 @@ class MockPolicyGraph(PolicyGraph): prev_reward_batch=None, episodes=None, **kwargs): - return [0] * len(obs_batch), [], {} + return [random.choice([0, 1])] * len(obs_batch), [], {} def postprocess_trajectory(self, batch, @@ -138,6 +139,7 @@ class TestPolicyEvaluator(unittest.TestCase): "prev_rewards", "prev_actions" ]: self.assertIn(key, batch) + self.assertGreater(np.abs(np.mean(batch[key])), 0) def to_prev(vec): out = np.zeros_like(vec)