mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 17:01:47 +08:00
play and gitignore
This commit is contained in:
+203
@@ -1,3 +1,206 @@
|
||||
__pycache__/
|
||||
runs/
|
||||
data
|
||||
models/
|
||||
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/code,python,jupyternotebooks,windows,linux
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=code,python,jupyternotebooks,windows,linux
|
||||
|
||||
### Code ###
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
*.code-workspace
|
||||
|
||||
### JupyterNotebooks ###
|
||||
# gitignore template for Jupyter Notebooks
|
||||
# website: http://jupyter.org/
|
||||
|
||||
.ipynb_checkpoints
|
||||
*/.ipynb_checkpoints/*
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# Remove previous ipynb_checkpoints
|
||||
# git rm -r .ipynb_checkpoints/
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
pytestdebug.log
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
doc/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
|
||||
# IPython
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# # Environments
|
||||
# .env
|
||||
# .venv
|
||||
# # env/
|
||||
# venv/
|
||||
# ENV/
|
||||
# env.bak/
|
||||
# venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/code,python,jupyternotebooks,windows,linux
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
run:
|
||||
python main.py --demonstrations data/demonstrations
|
||||
python main.py --demonstrations data/demonstrations --tau 1 --target_update_interval 100
|
||||
|
||||
play:
|
||||
python play.py --load-actor models/actor_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --load-critic models/critic_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl --render
|
||||
|
||||
@@ -26,6 +26,9 @@ usage: main.py [-h] [--env-name ENV_NAME] [--policy POLICY] [--eval EVAL]
|
||||
|
||||
(Note: There is no need for setting Temperature(`--alpha`) if `--automatic_entropy_tuning` is True.)
|
||||
|
||||
`make run`
|
||||
|
||||
--------------
|
||||
#### For SAC
|
||||
|
||||
```
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
def load_demonstrations(mem: ReplayMemory, recordings: Path):
|
||||
records = get_recordings(str(recordings))
|
||||
print('picks in recordings', sum(records['reward']>10))
|
||||
ends=records["episodes_end_point"]
|
||||
for i in tqdm(range(len(ends)-1), desc='loading demonstrations'):
|
||||
a = ends[i]
|
||||
|
||||
@@ -13,7 +13,7 @@ import pickle
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||
parser.add_argument('--env-name', default="ApplePick-v0",
|
||||
parser.add_argument('-e', '--env-name', default="ApplePick-v0",
|
||||
help='Mujoco Gym environment (default: ApplePick-v0)')
|
||||
parser.add_argument('--policy', default="Gaussian",
|
||||
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
||||
@@ -28,8 +28,8 @@ parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
|
||||
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
||||
help='Temperature parameter α determines the relative importance of the entropy\
|
||||
term against the reward (default: 0.2)')
|
||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
|
||||
help='Automaically adjust α (default: False)')
|
||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=True, metavar='G',
|
||||
help='Automaically adjust α (default: True)')
|
||||
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
||||
help='random seed (default: 123456)')
|
||||
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
||||
@@ -50,11 +50,15 @@ parser.add_argument('--cuda', action="store_true",
|
||||
help='run on CUDA (default: False)')
|
||||
parser.add_argument('--demonstrations', default=False,
|
||||
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
||||
parser.add_argument('-l', '--load', default=False,
|
||||
help='Load models')
|
||||
parser.add_argument('-r', '--render', action="store_true",
|
||||
help='show')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Environment
|
||||
# env = NormalizedActions(gym.make(args.env_name))
|
||||
env = gym.make(args.env_name, render=False)
|
||||
env = gym.make(args.env_name, render=args.render)
|
||||
env.seed(args.seed)
|
||||
env.action_space.seed(args.seed)
|
||||
|
||||
@@ -64,7 +68,7 @@ np.random.seed(args.seed)
|
||||
# Agent
|
||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||
|
||||
#Tesnorboard
|
||||
#Tensorboard
|
||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
||||
writer = SummaryWriter('runs/' + log_name)
|
||||
@@ -77,8 +81,13 @@ if args.demonstrations:
|
||||
def save():
|
||||
agent.save_model(args.env_name, "", "models/actor_" + log_name+'.pkl', "models/critic_"+log_name+'.pkl')
|
||||
memory.save(args.env_name, "", "models/memory_" + log_name +'.pkl')
|
||||
# agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
||||
# memory.load("models/memory_" + log_name +'.pkl')
|
||||
|
||||
def load(log_name):
|
||||
agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
||||
memory.load("models/memory_" + log_name +'.pkl')
|
||||
|
||||
if args.load:
|
||||
load(args.load)
|
||||
|
||||
# Training Loop
|
||||
total_numsteps = 0
|
||||
@@ -107,15 +116,19 @@ with tqdm(unit='frames') as prog:
|
||||
writer.add_scalar('loss/critic_2', critic_2_loss, updates)
|
||||
writer.add_scalar('loss/policy', policy_loss, updates)
|
||||
writer.add_scalar('loss/entropy_loss', ent_loss, updates)
|
||||
writer.add_scalar('entropy_temprature/alpha', alpha, updates)
|
||||
writer.add_scalar('entropy_temperature/alpha', alpha, updates)
|
||||
updates += 1
|
||||
|
||||
next_state, reward, done, _ = env.step(action) # Step
|
||||
next_state, reward, done, info = env.step(action) # Step
|
||||
episode_steps += 1
|
||||
total_numsteps += 1
|
||||
episode_reward += reward
|
||||
|
||||
prog.update(1)
|
||||
prog.desc = f'er={episode_reward/episode_steps:2.2f}'
|
||||
# for k, v in info.items():
|
||||
# if len(v) == 1:
|
||||
# writer.add_scalar('env/'+k, v, episode_steps)
|
||||
|
||||
# Ignore the "done" signal if it comes from hitting the time horizon.
|
||||
# (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
|
||||
|
||||
@@ -5,14 +5,12 @@ import numpy as np
|
||||
import itertools
|
||||
import torch
|
||||
from sac import SAC
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from replay_memory import ReplayMemory
|
||||
from load_demonstrations import load_demonstrations
|
||||
from tqdm.auto import tqdm
|
||||
import apple_gym.env
|
||||
import pickle
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
|
||||
parser.add_argument('--env-name', default="ApplePick-v0",
|
||||
parser.add_argument('-e', '--env-name', default="ApplePick-v0",
|
||||
help='Mujoco Gym environment (default: ApplePick-v0)')
|
||||
parser.add_argument('--policy', default="Gaussian",
|
||||
help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
|
||||
@@ -27,8 +25,8 @@ parser.add_argument('--lr', type=float, default=0.0003, metavar='G',
|
||||
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
|
||||
help='Temperature parameter α determines the relative importance of the entropy\
|
||||
term against the reward (default: 0.2)')
|
||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=False, metavar='G',
|
||||
help='Automaically adjust α (default: False)')
|
||||
parser.add_argument('--automatic_entropy_tuning', type=bool, default=True, metavar='G',
|
||||
help='Automaically adjust α (default: True)')
|
||||
parser.add_argument('--seed', type=int, default=123456, metavar='N',
|
||||
help='random seed (default: 123456)')
|
||||
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
||||
@@ -49,29 +47,48 @@ parser.add_argument('--cuda', action="store_true",
|
||||
help='run on CUDA (default: False)')
|
||||
parser.add_argument('--demonstrations', default=False,
|
||||
help='Load demonstrations from https://github.com/erfanMhi/gym-recording-modified')
|
||||
parser.add_argument('-l', '--load', default=False,
|
||||
help='Load models')
|
||||
parser.add_argument('-r', '--render', action="store_true",
|
||||
help='show')
|
||||
parser.add_argument('--load-actor', type=str, help='e.g. models/actor_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl')
|
||||
parser.add_argument('--load-critic', type=str, help='e.g. models/critic_2021-01-02_10-26-23_SAC_ApplePick-v0_Gaussian_autotune.pkl')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Environment
|
||||
# env = NormalizedActions(gym.make(args.env_name))
|
||||
env = gym.make(args.env_name, render=True)
|
||||
env.seed(args.seed)
|
||||
env.action_space.seed(args.seed)
|
||||
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Environment
|
||||
# env = NormalizedActions(gym.make(args.env_name))
|
||||
env = gym.make(args.env_name, render=args.render)
|
||||
env.seed(args.seed)
|
||||
env.action_space.seed(args.seed)
|
||||
|
||||
|
||||
# Agent
|
||||
agent = SAC(env.observation_space.shape[0], env.action_space, args)
|
||||
agent.load_model(args.load_actor, args.load_critic)
|
||||
|
||||
#Tesnorboard
|
||||
log_name = '{}_SAC_{}_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
|
||||
args.policy, "autotune" if args.automatic_entropy_tuning else "")
|
||||
writer = SummaryWriter('runs/' + log_name)
|
||||
# Test
|
||||
avg_reward = 0.
|
||||
episodes = 10
|
||||
for _ in tqdm(range(episodes)):
|
||||
state = env.reset()
|
||||
episode_reward = 0
|
||||
done = False
|
||||
while not done:
|
||||
action = agent.select_action(state, evaluate=True)
|
||||
|
||||
# Memory
|
||||
memory=ReplayMemory(args.replay_size, args.seed)
|
||||
if args.demonstrations:
|
||||
load_demonstrations(memory, args.demonstrations)
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
episode_reward += reward
|
||||
|
||||
agent.load_model("models/actor_" + log_name + '.pkl', "models/critic_" + log_name + '.pkl')
|
||||
memory.load("models/memory_" + log_name +'.pkl')
|
||||
|
||||
state = next_state
|
||||
avg_reward += episode_reward
|
||||
avg_reward /= episodes
|
||||
|
||||
print("----------------------------------------")
|
||||
print("Test Episodes: {}, Avg. Reward: {}".format(episodes, round(avg_reward, 2)))
|
||||
print("----------------------------------------")
|
||||
|
||||
env.close()
|
||||
|
||||
Reference in New Issue
Block a user