mirror of
https://github.com/wassname/pytorch-soft-actor-critic.git
synced 2026-06-27 15:16:26 +08:00
tidy
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
SHELL=/bin/bash
|
||||
python=/home/wassname/anaconda/envs/diygym3/bin/python
|
||||
date=2021-01-03_13-30-07
|
||||
LOGURU_LEVEL=INFO
|
||||
# ulimit -S -m 35000000
|
||||
# ulimit -S -v 35000000
|
||||
|
||||
run:
|
||||
LOGURU_LEVEL=INFO ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 15000 --load auto
|
||||
ulimit -S -m 65000000
|
||||
ulimit -S -v 65000000
|
||||
LOGURU_LEVEL=INFO ${python} main.py --cuda --automatic_entropy_tuning true --replay_size 50000 --load auto
|
||||
# ${python} -m pdb main.py --cuda --automatic_entropy_tuning true --replay_size 10000 --load auto --start_steps 200
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --automatic_entropy_tuning true --replay_size 20000 --load auto
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --alpha 0.1 --tau 1 --target_update_interval 1000
|
||||
# LOGURU_LEVEL=INFO ${python} main.py --demonstrations data/demonstrations --cuda --updates_per_step 2 --load auto --tau 1 --target_update_interval 1000 --policy Deterministic
|
||||
|
||||
@@ -13,26 +13,12 @@ import apple_gym.env
|
||||
import pickle
|
||||
from process_obs import ProcessObservation
|
||||
# from torchinfo import summary
|
||||
from tqdm.auto import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
from progress import RichTQDM
|
||||
from loguru import logger
|
||||
from rich import print
|
||||
from rich.logging import RichHandler
|
||||
from rich.progress import (
|
||||
ProgressColumn,
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
TextColumn,
|
||||
TransferSpeedColumn,
|
||||
TimeRemainingColumn,
|
||||
Progress,
|
||||
TaskID,
|
||||
TimeElapsedColumn,
|
||||
SpinnerColumn,
|
||||
Text
|
||||
)
|
||||
logging.basicConfig(level=logging.INFO, handlers=[RichHandler(rich_tracebacks=True, markup=True)])
|
||||
logger.configure(handlers=[{"sink": RichHandler(markup=True),
|
||||
"format": "{message}"}])
|
||||
@@ -156,32 +142,11 @@ if args.demonstrations:
|
||||
total_numsteps = 0
|
||||
updates = 0
|
||||
|
||||
class SpeedColumn(ProgressColumn):
|
||||
"""Renders human readable transfer speed."""
|
||||
|
||||
def render(self, task: "Task") -> Text:
|
||||
"""Show data transfer speed."""
|
||||
speed = task.speed
|
||||
if speed is None:
|
||||
return Text("?", style="progress.data.speed")
|
||||
return Text(f"{speed:2.2f} it/s", style="progress.data.speed")
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
"[progress.description]{task.description}",
|
||||
BarColumn(),
|
||||
TextColumn("{task.completed}/{task.total}"),
|
||||
"[",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
',',
|
||||
SpeedColumn(),
|
||||
']',
|
||||
refresh_per_second=1, speed_estimate_period=360
|
||||
) as prog:
|
||||
with RichTQDM() as prog:
|
||||
task1 = prog.add_task("[red]steps", total=args.num_steps)
|
||||
task2 = prog.add_task("[red]updates", total=args.num_steps)
|
||||
task3 = prog.add_task("[red]test", total=args.num_steps)
|
||||
for i_episode in itertools.count(0):
|
||||
print('1')
|
||||
episode_reward = 0
|
||||
@@ -234,7 +199,7 @@ with Progress(
|
||||
logger.info("\nEpisode: {}, total numsteps: {}, episode steps: {}, reward: {}, updates: {}".format(i_episode, total_numsteps, episode_steps, round(episode_reward, 2), updates))
|
||||
prog.desc = "e: {}, r: {}, u: {}, m: {}".format(i_episode, round(episode_reward, 2), updates, len(memory))
|
||||
|
||||
if i_episode % 10 == 0 and args.eval is True:
|
||||
if (i_episode % 100 == 0) and (args.eval is True) and i_episode>0:
|
||||
avg_reward = 0.
|
||||
episodes = 10
|
||||
for _ in range(episodes):
|
||||
@@ -243,6 +208,7 @@ with Progress(
|
||||
done = False
|
||||
while not done:
|
||||
action = agent.select_action(state, evaluate=True)
|
||||
prog.update(task3, advance=1)
|
||||
|
||||
next_state, reward, done, _ = env.step(action)
|
||||
episode_reward += reward
|
||||
@@ -266,7 +232,7 @@ with Progress(
|
||||
if total_numsteps >= args.num_steps:
|
||||
break
|
||||
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
if args.train:
|
||||
save(save_dir)
|
||||
|
||||
env.close()
|
||||
|
||||
+3
-2
@@ -122,7 +122,7 @@ class ProcessObservation(nn.Module):
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'data/nets/cornell-randsplit-rgbd-grconvnet3-drop1-ch16/epoch_30_iou_0.97.pt'
|
||||
)
|
||||
self.feature_extractor = GenerativeResnet3Headless()#.half()
|
||||
self.feature_extractor = GenerativeResnet3Headless().half()
|
||||
self.feature_extractor.load_state_dict(state_dict=torch.load(grconvnet3_path), strict=False)
|
||||
|
||||
old_img_size = (res[0], res[1], 8)
|
||||
@@ -145,7 +145,8 @@ class ProcessObservation(nn.Module):
|
||||
|
||||
# make a batch
|
||||
x = torch.cat([base_rgbd, arm_rgbd], 0)
|
||||
x = x.permute((0, 3, 1, 2)) # to ((-1, 4, x, y))
|
||||
x = x.permute((0, 3, 1, 2)) # to ((-1, 4, x, y))
|
||||
x = x.half()
|
||||
h = self.feature_extractor(x)
|
||||
|
||||
# undo fake batch
|
||||
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
from rich.progress import (
|
||||
ProgressColumn,
|
||||
BarColumn,
|
||||
DownloadColumn,
|
||||
TextColumn,
|
||||
TransferSpeedColumn,
|
||||
TimeRemainingColumn,
|
||||
Progress,
|
||||
TaskID,
|
||||
TimeElapsedColumn,
|
||||
SpinnerColumn,
|
||||
Text
|
||||
)
|
||||
|
||||
class SpeedColumn(ProgressColumn):
    """Progress column that prints a task's speed with an ``it/s`` suffix."""

    def render(self, task: "Task") -> Text:
        """Return the task speed as styled text.

        While ``task.speed`` is ``None`` a ``"?"`` placeholder is shown;
        otherwise the speed is formatted to two decimal places followed
        by ``" it/s"``. Both variants use the ``progress.data.speed`` style.
        """
        speed = task.speed
        label = "?" if speed is None else f"{speed:2.2f} it/s"
        return Text(label, style="progress.data.speed")
|
||||
|
||||
def RichTQDM(*, refresh_per_second=.1, speed_estimate_period=30):
    """Build a rich ``Progress`` display with a fixed column layout.

    The columns are, in order: a spinner, the task description, a bar,
    ``completed/total`` counts, then ``[elapsed<remaining,speed]``.

    Args:
        refresh_per_second: Forwarded to ``Progress``. The default of
            ``.1`` is the value this helper previously hard-coded.
        speed_estimate_period: Forwarded to ``Progress``. The default of
            ``30`` is the value this helper previously hard-coded.

    Returns:
        A ``Progress`` instance. It is not started here; callers use it
        as a context manager (``with RichTQDM() as prog: ...``).
    """
    return Progress(
        SpinnerColumn(),
        "[progress.description]{task.description}",
        BarColumn(),
        TextColumn("{task.completed}/{task.total}"),
        "[",
        TimeElapsedColumn(),
        "<",
        TimeRemainingColumn(),
        ',',
        SpeedColumn(),
        ']',
        refresh_per_second=refresh_per_second,
        speed_estimate_period=speed_estimate_period,
    )
|
||||
+3
-3
@@ -21,8 +21,8 @@ def unpack(data):
|
||||
return data
|
||||
|
||||
|
||||
class ReplayMemory2:
|
||||
def __init__(self, capacity, seed):
|
||||
class ReplayMemory:
|
||||
def __init__(self, capacity, seed, *args, **kwargs):
|
||||
random.seed(seed)
|
||||
self.capacity = capacity
|
||||
self.buffer = []
|
||||
@@ -56,7 +56,7 @@ class ReplayMemory2:
|
||||
self.position = len(self.buffer)
|
||||
|
||||
|
||||
class ReplayMemory:
|
||||
class ReplayMemory2:
|
||||
def __init__(self, capacity, seed, observation_dim, action_dim):
|
||||
random.seed(seed)
|
||||
self.capacity = capacity
|
||||
|
||||
Reference in New Issue
Block a user