[rllib] Rollout script needs to pipe in config and update states (#1566)

* Mon Feb 19 15:20:09 PST 2018

* fix it actually
This commit is contained in:
Eric Liang
2018-02-20 12:04:41 -08:00
committed by GitHub
parent fd03fb967f
commit 1b596f7d3b
Regular → Executable
+12 -3
View File
@@ -12,6 +12,7 @@ import pickle
import gym
import ray
from ray.rllib.agent import get_agent_class
from ray.rllib.dqn.common.wrappers import wrap_dqn
from ray.rllib.models import ModelCatalog
from ray.tune.registry import get_registry
@@ -67,12 +68,16 @@ if __name__ == "__main__":
ray.init()
cls = get_agent_class(args.run)
agent = cls(env=args.env)
agent = cls(env=args.env, config=args.config)
agent.restore(args.checkpoint)
num_steps = int(args.steps)
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
gym.make(args.env))
if args.run == "DQN":
env = gym.make(args.env)
env = wrap_dqn(get_registry(), env, args.config.get("model", {}))
else:
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
gym.make(args.env))
if args.out is not None:
rollouts = []
steps = 0
@@ -81,15 +86,19 @@ if __name__ == "__main__":
rollout = []
state = env.reset()
done = False
reward_total = 0.0
while not done and steps < (num_steps or steps + 1):
action = agent.compute_action(state)
next_state, reward, done, _ = env.step(action)
reward_total += reward
if not args.no_render:
env.render()
if args.out is not None:
rollout.append([state, action, next_state, reward, done])
steps += 1
state = next_state
if args.out is not None:
rollouts.append(rollout)
print("Episode reward", reward_total)
if args.out is not None:
pickle.dump(rollouts, open(args.out, "wb"))