[rllib] Add callback accessor for raw observation, fix prev actions (#4212)

This commit is contained in:
Eric Liang
2019-03-06 10:21:05 -08:00
committed by GitHub
parent 0e77a8f8c0
commit 6d705036f3
4 changed files with 19 additions and 1 deletion
+13
View File
@@ -58,6 +58,7 @@ class MultiAgentEpisode(object):
self._agent_to_policy = {}
self._agent_to_rnn_state = {}
self._agent_to_last_obs = {}
self._agent_to_last_raw_obs = {}
self._agent_to_last_info = {}
self._agent_to_last_action = {}
self._agent_to_last_pi_info = {}
@@ -82,6 +83,12 @@ class MultiAgentEpisode(object):
return self._agent_to_last_obs.get(agent_id)
@DeveloperAPI
def last_raw_obs_for(self, agent_id=_DUMMY_AGENT_ID):
    """Returns the last raw (un-preprocessed) observation for an agent.

    The value is looked up in the per-agent raw observation table; None is
    returned when nothing has been recorded for `agent_id`.
    """
    raw_obs_by_agent = self._agent_to_last_raw_obs
    return raw_obs_by_agent.get(agent_id)
@DeveloperAPI
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
"""Returns the last info for the specified agent."""
@@ -149,10 +156,16 @@ class MultiAgentEpisode(object):
def _set_last_observation(self, agent_id, obs):
    """Stores `obs` as the most recent observation for `agent_id`.

    Any observation previously stored for the same agent is overwritten.
    """
    obs_by_agent = self._agent_to_last_obs
    obs_by_agent[agent_id] = obs
def _set_last_raw_obs(self, agent_id, obs):
    """Stores `obs` as the most recent raw observation for `agent_id`.

    Any raw observation previously stored for the same agent is overwritten.
    """
    raw_obs_by_agent = self._agent_to_last_raw_obs
    raw_obs_by_agent[agent_id] = obs
def _set_last_info(self, agent_id, info):
    """Stores `info` as the most recent info entry for `agent_id`.

    Any info previously stored for the same agent is overwritten.
    """
    info_by_agent = self._agent_to_last_info
    info_by_agent[agent_id] = info
def _set_last_action(self, agent_id, action):
    """Stores `action` as the agent's last action, keeping the prior one.

    If a last action is already recorded for `agent_id`, it is first copied
    into the previous-action table so it is not lost when overwritten. On
    the first call for an agent the previous-action table is not touched.
    """
    last_actions = self._agent_to_last_action
    if agent_id in last_actions:
        # Shift the currently stored action into the "prev" slot before
        # it is replaced below.
        self._agent_to_prev_action[agent_id] = last_actions[agent_id]
    last_actions[agent_id] = action
def _set_last_pi_info(self, agent_id, pi_info):
+1
View File
@@ -372,6 +372,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
last_observation = episode.last_observation_for(agent_id)
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_raw_obs(agent_id, raw_obs)
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
# Record transition info if applicable
@@ -20,6 +20,8 @@ def on_episode_start(info):
def on_episode_step(info):
    """Per-step callback that records the current pole angle.

    Reads element 2 of both the last observation and the last raw
    observation from `info["episode"]`, checks that their absolute values
    agree, and appends that value to `episode.user_data["pole_angles"]`.
    """
    ep = info["episode"]
    processed_angle = abs(ep.last_observation_for()[2])
    unprocessed_angle = abs(ep.last_raw_obs_for()[2])
    # Sanity check: the two accessors are expected to report the same angle.
    assert processed_angle == unprocessed_angle
    ep.user_data["pole_angles"].append(processed_angle)
@@ -4,6 +4,7 @@ from __future__ import print_function
import gym
import numpy as np
import random
import time
import unittest
from collections import Counter
@@ -27,7 +28,7 @@ class MockPolicyGraph(PolicyGraph):
prev_reward_batch=None,
episodes=None,
**kwargs):
return [0] * len(obs_batch), [], {}
return [random.choice([0, 1])] * len(obs_batch), [], {}
def postprocess_trajectory(self,
batch,
@@ -138,6 +139,7 @@ class TestPolicyEvaluator(unittest.TestCase):
"prev_rewards", "prev_actions"
]:
self.assertIn(key, batch)
self.assertGreater(np.abs(np.mean(batch[key])), 0)
def to_prev(vec):
out = np.zeros_like(vec)