mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 04:44:28 +08:00
[rllib] Rollout script needs to pipe in config and update states (#1566)
* Mon Feb 19 15:20:09 PST 2018 * fix it actually
This commit is contained in:
Regular → Executable
+12
-3
@@ -12,6 +12,7 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agent import get_agent_class
|
||||
from ray.rllib.dqn.common.wrappers import wrap_dqn
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.tune.registry import get_registry
|
||||
|
||||
@@ -67,12 +68,16 @@ if __name__ == "__main__":
|
||||
ray.init()
|
||||
|
||||
cls = get_agent_class(args.run)
|
||||
agent = cls(env=args.env)
|
||||
agent = cls(env=args.env, config=args.config)
|
||||
agent.restore(args.checkpoint)
|
||||
num_steps = int(args.steps)
|
||||
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
|
||||
gym.make(args.env))
|
||||
if args.run == "DQN":
|
||||
env = gym.make(args.env)
|
||||
env = wrap_dqn(get_registry(), env, args.config.get("model", {}))
|
||||
else:
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(get_registry(),
|
||||
gym.make(args.env))
|
||||
if args.out is not None:
|
||||
rollouts = []
|
||||
steps = 0
|
||||
@@ -81,15 +86,19 @@ if __name__ == "__main__":
|
||||
rollout = []
|
||||
state = env.reset()
|
||||
done = False
|
||||
reward_total = 0.0
|
||||
while not done and steps < (num_steps or steps + 1):
|
||||
action = agent.compute_action(state)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
reward_total += reward
|
||||
if not args.no_render:
|
||||
env.render()
|
||||
if args.out is not None:
|
||||
rollout.append([state, action, next_state, reward, done])
|
||||
steps += 1
|
||||
state = next_state
|
||||
if args.out is not None:
|
||||
rollouts.append(rollout)
|
||||
print("Episode reward", reward_total)
|
||||
if args.out is not None:
|
||||
pickle.dump(rollouts, open(args.out, "wb"))
|
||||
|
||||
Reference in New Issue
Block a user