Files
2020-07-11 14:18:02 +05:30

136 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import argparse
import datetime
import gym
import numpy as np
import itertools
import torch
import pybullet_envs
from sac import SAC
from torch.utils.tensorboard import SummaryWriter
from replay_memory import ReplayMemory
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` and ``"0"`` -- parses as
    True. This converter interprets the usual spellings instead.

    Args:
        value: the raw string from the command line (a bool is passed through).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays "false": with ``type=bool`` it was the only
    # value that could switch one of these options off.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value))


# Command-line interface for the SAC training script. Help texts quote the
# actual default of each option.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--env-name', default="HalfCheetahBulletEnv-v0",
                    help='name of the environment to run')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates the policy every 10 episodes (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Temperature parameter α automatically adjusted (default: False)')
parser.add_argument('--seed', type=int, default=456, metavar='N',
                    help='random seed (default: 456)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps (default: 1000001)')
parser.add_argument('--hidden_size', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')
parser.add_argument('--cuda', action="store_true",
                    help='run on CUDA (default: False)')
# Read the command line once; everything below is configured from `args`.
args = parser.parse_args()
seed = args.seed

# Environment: build it, then apply the same seed to the environment, its
# action space, PyTorch and NumPy (in that order).
env = gym.make(args.env_name)
env.seed(seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Agent: receives the first entry of the observation-space shape, the action
# space object and the full argument namespace.
obs_dim = env.observation_space.shape[0]
agent = SAC(obs_dim, env.action_space, args)

# TensorBoard writer: one run directory per launch, named from the start
# timestamp and the environment name.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
writer = SummaryWriter('runs/{}_SAC_V_{}'.format(run_stamp, args.env_name))

# Memory: replay buffer built from the configured size and the run's seed.
memory = ReplayMemory(args.replay_size, seed)
def _evaluate_policy(env, agent, episodes):
    """Roll out `episodes` evaluation episodes and return the mean episode reward.

    Actions come from ``agent.select_action(state, eval=True)``; each
    episode's reward is the plain sum of the per-step rewards.
    NOTE(review): `eval=True` presumably selects the policy's deterministic
    action -- confirm in sac.py.

    Args:
        env: environment exposing ``reset()`` and ``step(action)``.
        agent: object exposing ``select_action(state, eval=True)``.
        episodes: number of episodes to average over.

    Returns:
        The summed episode rewards divided by `episodes`.
    """
    total_reward = 0.
    for _ in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            action = agent.select_action(state, eval=True)
            state, reward, done, _info = env.step(action)
            episode_reward += reward
        total_reward += episode_reward
    return total_reward / episodes


# Training Loop
total_numsteps = 0  # environment steps taken across all episodes
updates = 0         # calls made to agent.update_parameters so far

try:
    for i_episode in itertools.count(1):
        episode_reward = 0
        episode_steps = 0
        done = False
        state = env.reset()

        while not done:
            if args.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            # Only train once the buffer holds more than one batch of transitions.
            if len(memory) > args.batch_size:
                for _ in range(args.updates_per_step):  # Number of updates per step in environment
                    # Update parameters of all the networks
                    value_loss, critic_1_loss, critic_2_loss, policy_loss = agent.update_parameters(memory, args.batch_size, updates)

                    writer.add_scalar('loss/value', value_loss, updates)
                    writer.add_scalar('loss/critic_1', critic_1_loss, updates)
                    writer.add_scalar('loss/critic_2', critic_2_loss, updates)
                    writer.add_scalar('loss/policy', policy_loss, updates)
                    updates += 1

            next_state, reward, done, _ = env.step(action)  # Step
            episode_steps += 1
            total_numsteps += 1
            episode_reward += reward

            # Ignore the "done" signal if it comes from hitting the time horizon.
            # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
            # The mask is 1 when the step count reached env._max_episode_steps,
            # otherwise 1.0 while the episode continues and 0.0 once it is done.
            mask = 1 if episode_steps == env._max_episode_steps else float(not done)

            memory.push(state, action, reward, next_state, mask)  # Append transition to memory

            state = next_state

        # The step budget is checked only at episode boundaries, so the last
        # episode always runs to completion before training stops.
        if total_numsteps > args.num_steps:
            break

        writer.add_scalar('reward/train', episode_reward, i_episode)
        print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))

        # Periodic evaluation: every 10th episode, averaged over 10 rollouts.
        if i_episode % 10 == 0 and args.eval:
            episodes = 10
            avg_reward = _evaluate_policy(env, agent, episodes)

            writer.add_scalar('reward/test', avg_reward, i_episode)

            print("----------------------------------------")
            print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
            print("----------------------------------------")
finally:
    # Runs on normal completion, Ctrl-C, or an error: close the TensorBoard
    # writer so pending events reach disk, then release the environment.
    writer.close()
    env.close()