mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 22:23:17 +08:00
[rllib] Add callback accessor for raw observation, fix prev actions (#4212)
This commit is contained in:
@@ -58,6 +58,7 @@ class MultiAgentEpisode(object):
|
||||
self._agent_to_policy = {}
|
||||
self._agent_to_rnn_state = {}
|
||||
self._agent_to_last_obs = {}
|
||||
self._agent_to_last_raw_obs = {}
|
||||
self._agent_to_last_info = {}
|
||||
self._agent_to_last_action = {}
|
||||
self._agent_to_last_pi_info = {}
|
||||
@@ -82,6 +83,12 @@ class MultiAgentEpisode(object):
|
||||
|
||||
return self._agent_to_last_obs.get(agent_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def last_raw_obs_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last un-preprocessed obs for the specified agent."""
|
||||
|
||||
return self._agent_to_last_raw_obs.get(agent_id)
|
||||
|
||||
@DeveloperAPI
|
||||
def last_info_for(self, agent_id=_DUMMY_AGENT_ID):
|
||||
"""Returns the last info for the specified agent."""
|
||||
@@ -149,10 +156,16 @@ class MultiAgentEpisode(object):
|
||||
def _set_last_observation(self, agent_id, obs):
|
||||
self._agent_to_last_obs[agent_id] = obs
|
||||
|
||||
def _set_last_raw_obs(self, agent_id, obs):
|
||||
self._agent_to_last_raw_obs[agent_id] = obs
|
||||
|
||||
def _set_last_info(self, agent_id, info):
|
||||
self._agent_to_last_info[agent_id] = info
|
||||
|
||||
def _set_last_action(self, agent_id, action):
|
||||
if agent_id in self._agent_to_last_action:
|
||||
self._agent_to_prev_action[agent_id] = \
|
||||
self._agent_to_last_action[agent_id]
|
||||
self._agent_to_last_action[agent_id] = action
|
||||
|
||||
def _set_last_pi_info(self, agent_id, pi_info):
|
||||
|
||||
@@ -372,6 +372,7 @@ def _process_observations(base_env, policies, batch_builder_pool,
|
||||
|
||||
last_observation = episode.last_observation_for(agent_id)
|
||||
episode._set_last_observation(agent_id, filtered_obs)
|
||||
episode._set_last_raw_obs(agent_id, raw_obs)
|
||||
episode._set_last_info(agent_id, infos[env_id].get(agent_id, {}))
|
||||
|
||||
# Record transition info if applicable
|
||||
|
||||
@@ -20,6 +20,8 @@ def on_episode_start(info):
|
||||
def on_episode_step(info):
|
||||
episode = info["episode"]
|
||||
pole_angle = abs(episode.last_observation_for()[2])
|
||||
raw_angle = abs(episode.last_raw_obs_for()[2])
|
||||
assert pole_angle == raw_angle
|
||||
episode.user_data["pole_angles"].append(pole_angle)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import print_function
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
from collections import Counter
|
||||
@@ -27,7 +28,7 @@ class MockPolicyGraph(PolicyGraph):
|
||||
prev_reward_batch=None,
|
||||
episodes=None,
|
||||
**kwargs):
|
||||
return [0] * len(obs_batch), [], {}
|
||||
return [random.choice([0, 1])] * len(obs_batch), [], {}
|
||||
|
||||
def postprocess_trajectory(self,
|
||||
batch,
|
||||
@@ -138,6 +139,7 @@ class TestPolicyEvaluator(unittest.TestCase):
|
||||
"prev_rewards", "prev_actions"
|
||||
]:
|
||||
self.assertIn(key, batch)
|
||||
self.assertGreater(np.abs(np.mean(batch[key])), 0)
|
||||
|
||||
def to_prev(vec):
|
||||
out = np.zeros_like(vec)
|
||||
|
||||
Reference in New Issue
Block a user