diff --git a/python/ray/rllib/rollout.py b/python/ray/rllib/rollout.py old mode 100644 new mode 100755 index 324903d91..64174866a --- a/python/ray/rllib/rollout.py +++ b/python/ray/rllib/rollout.py @@ -12,6 +12,7 @@ import pickle import gym import ray from ray.rllib.agent import get_agent_class +from ray.rllib.dqn.common.wrappers import wrap_dqn from ray.rllib.models import ModelCatalog from ray.tune.registry import get_registry @@ -67,12 +68,16 @@ if __name__ == "__main__": ray.init() cls = get_agent_class(args.run) - agent = cls(env=args.env) + agent = cls(env=args.env, config=args.config) agent.restore(args.checkpoint) num_steps = int(args.steps) - env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(), - gym.make(args.env)) + if args.run == "DQN": + env = gym.make(args.env) + env = wrap_dqn(get_registry(), env, args.config.get("model", {})) + else: + env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(), + gym.make(args.env)) if args.out is not None: rollouts = [] steps = 0 @@ -81,15 +86,19 @@ if __name__ == "__main__": rollout = [] state = env.reset() done = False + reward_total = 0.0 while not done and steps < (num_steps or steps + 1): action = agent.compute_action(state) next_state, reward, done, _ = env.step(action) + reward_total += reward if not args.no_render: env.render() if args.out is not None: rollout.append([state, action, next_state, reward, done]) steps += 1 + state = next_state if args.out is not None: rollouts.append(rollout) + print("Episode reward", reward_total) if args.out is not None: pickle.dump(rollouts, open(args.out, "wb"))