mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 16:46:28 +08:00
play and gitignore
This commit is contained in:
+203
@@ -1,3 +1,206 @@
|
|||||||
__pycache__/
|
__pycache__/
|
||||||
runs/
|
runs/
|
||||||
data
|
data
|
||||||
|
models/
|
||||||
|
|
||||||
|
|
||||||
|
# Created by https://www.toptal.com/developers/gitignore/api/code,python,jupyternotebooks,windows,linux
|
||||||
|
# Edit at https://www.toptal.com/developers/gitignore?templates=code,python,jupyternotebooks,windows,linux
|
||||||
|
|
||||||
|
### Code ###
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/settings.json
|
||||||
|
!.vscode/tasks.json
|
||||||
|
!.vscode/launch.json
|
||||||
|
!.vscode/extensions.json
|
||||||
|
*.code-workspace
|
||||||
|
|
||||||
|
### JupyterNotebooks ###
|
||||||
|
# gitignore template for Jupyter Notebooks
|
||||||
|
# website: http://jupyter.org/
|
||||||
|
|
||||||
|
.ipynb_checkpoints
|
||||||
|
*/.ipynb_checkpoints/*
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# Remove previous ipynb_checkpoints
|
||||||
|
# git rm -r .ipynb_checkpoints/
|
||||||
|
|
||||||
|
### Linux ###
|
||||||
|
*~
|
||||||
|
|
||||||
|
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||||
|
.fuse_hidden*
|
||||||
|
|
||||||
|
# KDE directory preferences
|
||||||
|
.directory
|
||||||
|
|
||||||
|
# Linux trash folder which might appear on any partition or disk
|
||||||
|
.Trash-*
|
||||||
|
|
||||||
|
# .nfs files are created when an open file is removed but is still being accessed
|
||||||
|
.nfs*
|
||||||
|
|
||||||
|
### Python ###
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
pytestdebug.log
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
doc/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# # Environments
|
||||||
|
# .env
|
||||||
|
# .venv
|
||||||
|
# # env/
|
||||||
|
# venv/
|
||||||
|
# ENV/
|
||||||
|
# env.bak/
|
||||||
|
# venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
### Windows ###
|
||||||
|
# Windows thumbnail cache files
|
||||||
|
Thumbs.db
|
||||||
|
Thumbs.db:encryptable
|
||||||
|
ehthumbs.db
|
||||||
|
ehthumbs_vista.db
|
||||||
|
|
||||||
|
# Dump file
|
||||||
|
*.stackdump
|
||||||
|
|
||||||
|
# Folder config file
|
||||||
|
[Dd]esktop.ini
|
||||||
|
|
||||||
|
# Recycle Bin used on file shares
|
||||||
|
$RECYCLE.BIN/
|
||||||
|
|
||||||
|
# Windows Installer files
|
||||||
|
*.cab
|
||||||
|
*.msi
|
||||||
|
*.msix
|
||||||
|
*.msm
|
||||||
|
*.msp
|
||||||
|
|
||||||
|
# Windows shortcuts
|
||||||
|
*.lnk
|
||||||
|
|
||||||
|
# End of https://www.toptal.com/developers/gitignore/api/code,python,jupyternotebooks,windows,linux
|
||||||
|
|||||||
@@ -1,2 +1,5 @@
|
|||||||
run:
|
run:
|
||||||
python main.py --demonstrations data/demonstrations
|
python main.py --demonstrations data/demonstrations --tau 1 --target_update_interval 100
|
||||||
|
|
||||||
|
play:
|
||||||
|
python play.py --load-actor models/actor_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --load-critic models/critic_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --render
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ usage: main.py [-h] [--env-name ENV_NAME] [--policy POLICY] [--eval EVAL]
|
|||||||
|
|
||||||
(Note: There is no need to set the temperature (`--alpha`) when `--automatic_entropy_tuning` is True.)
|
(Note: There is no need to set the temperature (`--alpha`) when `--automatic_entropy_tuning` is True.)
|
||||||
|
|
||||||
|
`make run`
|
||||||
|
|
||||||
|
--------------
|
||||||
#### For SAC
|
#### For SAC
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
def load_demonstrations(mem: ReplayMemory, recordings: Path):
|
def load_demonstrations(mem: ReplayMemory, recordings: Path):
|
||||||
records = get_recordings(str(recordings))
|
records = get_recordings(str(recordings))
|
||||||
|
print('picks in recordings', sum(records['reward']>10))
|
||||||
ends=records["episodes_end_point"]
|
ends=records["episodes_end_point"]
|
||||||
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
|
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
|
||||||
a = ends[i]
|
a = ends[i]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import pickle
|
|||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||||
parser.add_argument('--env-name', default="ApplePick-v0",
|
parser.add_argument('-e', '--env-name', default="ApplePick-v0",
|
||||||
help='Mujoco Gym environment (default: ApplePick-v0)')
|
help='Mujoco Gym environment (default: ApplePick-v0)')
|
||||||
parser.add_argument('--policy', default="Gaussian",
|
parser.add_argument('--policy', default="Gaussian",
|
||||||
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
||||||
@@ -28,8 +28,8 @@ parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
|
|||||||
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
||||||
help='Temperature parameter α determines the relative importance of the entropy\
|
help='Temperature parameter α determines the relative importance of the entropy\
|
||||||
term against the reward (default: 0.2)')
|
term against the reward (default: 0.2)')
|
||||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
|
parser.add_argument('--automatic_entropy_tuning', type=bool, default=True, metavar='G',
|
||||||
help='Automaically adjust α (default: False)')
|
help='Automaically adjust α (default: True)')
|
||||||
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
||||||
help='random seed (default: 123456)')
|
help='random seed (default: 123456)')
|
||||||
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
||||||
@@ -50,11 +50,15 @@ parser.add_argument('--cuda', action="store_true",
|
|||||||
help='run on CUDA (default: False)')
|
help='run on CUDA (default: False)')
|
||||||
parser.add_argument('--demonstrations', default=False,
|
parser.add_argument('--demonstrations', default=False,
|
||||||
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
||||||
|
parser.add_argument('-l', '--load', default=False,
|
||||||
|
help='Load models')
|
||||||
|
parser.add_argument('-r', '--render', action="store_true",
|
||||||
|
help='show')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Environment
|
# Environment
|
||||||
# env = NormalizedActions(gym.make(args.env_name))
|
# env = NormalizedActions(gym.make(args.env_name))
|
||||||
env = gym.make(args.env_name, render=False)
|
env = gym.make(args.env_name, render=args.render)
|
||||||
env.seed(args.seed)
|
env.seed(args.seed)
|
||||||
env.action_space.seed(args.seed)
|
env.action_space.seed(args.seed)
|
||||||
|
|
||||||
@@ -64,7 +68,7 @@ np.random.seed(args.seed)
|
|||||||
# Agent
|
# Agent
|
||||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||||
|
|
||||||
#Tesnorboard
|
#Tensorboard
|
||||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||||
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
||||||
writer = SummaryWriter('runs/' + log_name)
|
writer = SummaryWriter('runs/' + log_name)
|
||||||
@@ -77,8 +81,13 @@ if args.demonstrations:
|
|||||||
def save():
|
def save():
|
||||||
agent.save_model(args.env_name, "", "models/actor_" + log_name+'.pkl', "models/critic_"+log_name+'.pkl')
|
agent.save_model(args.env_name, "", "models/actor_" + log_name+'.pkl', "models/critic_"+log_name+'.pkl')
|
||||||
memory.save(args.env_name, "", "models/memory_" + log_name +'.pkl')
|
memory.save(args.env_name, "", "models/memory_" + log_name +'.pkl')
|
||||||
# agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
|
||||||
# memory.load("models/memory_" + log_name +'.pkl')
|
def load(log_name):
|
||||||
|
agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
||||||
|
memory.load("models/memory_" + log_name +'.pkl')
|
||||||
|
|
||||||
|
if args.load:
|
||||||
|
load(args.load)
|
||||||
|
|
||||||
# Training Loop
|
# Training Loop
|
||||||
total_numsteps = 0
|
total_numsteps = 0
|
||||||
@@ -107,15 +116,19 @@ with tqdm(unit='frames') as prog:
|
|||||||
writer.add_scalar('loss/critic_2', critic_2_loss, updates)
|
writer.add_scalar('loss/critic_2', critic_2_loss, updates)
|
||||||
writer.add_scalar('loss/policy', policy_loss, updates)
|
writer.add_scalar('loss/policy', policy_loss, updates)
|
||||||
writer.add_scalar('loss/entropy_loss', ent_loss, updates)
|
writer.add_scalar('loss/entropy_loss', ent_loss, updates)
|
||||||
writer.add_scalar('entropy_temprature/alpha', alpha, updates)
|
writer.add_scalar('entropy_temperature/alpha', alpha, updates)
|
||||||
updates += 1
|
updates += 1
|
||||||
|
|
||||||
next_state, reward, done, _ = env.step(action) # Step
|
next_state, reward, done, info = env.step(action) # Step
|
||||||
episode_steps += 1
|
episode_steps += 1
|
||||||
total_numsteps += 1
|
total_numsteps += 1
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
|
|
||||||
prog.update(1)
|
prog.update(1)
|
||||||
prog.desc = f'er={episode_reward/episode_steps:2.2f}'
|
prog.desc = f'er={episode_reward/episode_steps:2.2f}'
|
||||||
|
# for k, v in info.items():
|
||||||
|
# if len(v) == 1:
|
||||||
|
# writer.add_scalar('env/'+k, v, episode_steps)
|
||||||
|
|
||||||
# Ignore the "done" signal if it comes from hitting the time horizon.
|
# Ignore the "done" signal if it comes from hitting the time horizon.
|
||||||
# (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
|
# (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
|
||||||
|
|||||||
@@ -5,14 +5,12 @@ import numpy as np
|
|||||||
import itertools
|
import itertools
|
||||||
import torch
|
import torch
|
||||||
from sac import SAC
|
from sac import SAC
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from tqdm.auto import tqdm
|
||||||
from replay_memory import ReplayMemory
|
|
||||||
from load_demonstrations import load_demonstrations
|
|
||||||
import apple_gym.env
|
import apple_gym.env
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||||
parser.add_argument('--env-name', default="ApplePick-v0",
|
parser.add_argument('-e', '--env-name', default="ApplePick-v0",
|
||||||
help='Mujoco Gym environment (default: ApplePick-v0)')
|
help='Mujoco Gym environment (default: ApplePick-v0)')
|
||||||
parser.add_argument('--policy', default="Gaussian",
|
parser.add_argument('--policy', default="Gaussian",
|
||||||
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
||||||
@@ -27,8 +25,8 @@ parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
|
|||||||
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
||||||
help='Temperature parameter α determines the relative importance of the entropy\
|
help='Temperature parameter α determines the relative importance of the entropy\
|
||||||
term against the reward (default: 0.2)')
|
term against the reward (default: 0.2)')
|
||||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
|
parser.add_argument('--automatic_entropy_tuning', type=bool, default=True, metavar='G',
|
||||||
help='Automaically adjust α (default: False)')
|
help='Automaically adjust α (default: True)')
|
||||||
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
||||||
help='random seed (default: 123456)')
|
help='random seed (default: 123456)')
|
||||||
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
||||||
@@ -49,29 +47,48 @@ parser.add_argument('--cuda', action="store_true",
|
|||||||
help='run on CUDA (default: False)')
|
help='run on CUDA (default: False)')
|
||||||
parser.add_argument('--demonstrations', default=False,
|
parser.add_argument('--demonstrations', default=False,
|
||||||
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
||||||
|
parser.add_argument('-l', '--load', default=False,
|
||||||
|
help='Load models')
|
||||||
|
parser.add_argument('-r', '--render', action="store_true",
|
||||||
|
help='show')
|
||||||
|
parser.add_argument('--load-actor', type=str, help='e.g. models/actor_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl')
|
||||||
|
parser.add_argument('--load-critic', type=str, help='e.g. models/critic_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Environment
|
|
||||||
# env = NormalizedActions(gym.make(args.env_name))
|
|
||||||
env = gym.make(args.env_name, render=True)
|
|
||||||
env.seed(args.seed)
|
|
||||||
env.action_space.seed(args.seed)
|
|
||||||
|
|
||||||
torch.manual_seed(args.seed)
|
torch.manual_seed(args.seed)
|
||||||
np.random.seed(args.seed)
|
np.random.seed(args.seed)
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
# env = NormalizedActions(gym.make(args.env_name))
|
||||||
|
env = gym.make(args.env_name, render=args.render)
|
||||||
|
env.seed(args.seed)
|
||||||
|
env.action_space.seed(args.seed)
|
||||||
|
|
||||||
|
|
||||||
# Agent
|
# Agent
|
||||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||||
|
agent.load_model(args.load_actor, args.load_critic)
|
||||||
|
|
||||||
#Tesnorboard
|
# Test
|
||||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
avg_reward = 0.
|
||||||
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
episodes = 10
|
||||||
writer = SummaryWriter('runs/' + log_name)
|
for _ in tqdm(range(episodes)):
|
||||||
|
state = env.reset()
|
||||||
|
episode_reward = 0
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action = agent.select_action(state, evaluate=True)
|
||||||
|
|
||||||
# Memory
|
next_state, reward, done, _ = env.step(action)
|
||||||
memory=ReplayMemory(args.replay_size, args.seed)
|
episode_reward += reward
|
||||||
if args.demonstrations:
|
|
||||||
load_demonstrations(memory, args.demonstrations)
|
|
||||||
|
|
||||||
agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
|
||||||
memory.load("models/memory_" + log_name +'.pkl')
|
state = next_state
|
||||||
|
avg_reward += episode_reward
|
||||||
|
avg_reward /= episodes
|
||||||
|
|
||||||
|
print("----------------------------------------")
|
||||||
|
print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
|
||||||
|
print("----------------------------------------")
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user