diff --git a/main.py b/main.py index 9bf85ce..97a6b1c 100644 --- a/main.py +++ b/main.py @@ -109,7 +109,7 @@ for i_episode in itertools.count(): np.round(np.mean(rewards[-100:]),2))) if i_episode % 10 == 0 and args.eval == True: - state = torch.Tensor([env.reset()]) + state = env.reset() episode_reward = 0 while True: action = agent.select_action(state, eval=True)