Mirror of https://github.com/wassname/openai-transformer-lm-gutenberg-erotic.git
(last synced 2026-06-27 16:10:19 +08:00)
clean up multi gpu logic
This commit is contained in:
+2
-1
@@ -177,7 +177,8 @@ class LMHead(nn.Module):
|
||||
def __init__(self, model, cfg):
|
||||
super(LMHead, self).__init__()
|
||||
self.n_embd = cfg.n_embd
|
||||
self.decoder = lambda x: F.linear(x, model.embed.weight) # Tied weights
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.decoder.weight = model.embed.weight # Tied weights
|
||||
|
||||
def forward(self, h):
|
||||
# Truncated Language modeling logits (we remove the last token)
|
||||
|
||||
@@ -151,7 +151,7 @@ def run_epoch():
|
||||
compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
|
||||
n_updates += 1
|
||||
if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
|
||||
log()
|
||||
log(save_dir, desc)
|
||||
|
||||
|
||||
argmax = lambda x: np.argmax(x, 1)
|
||||
@@ -194,7 +194,6 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--clf_pdrop', type=float, default=0.1)
|
||||
parser.add_argument('--l2', type=float, default=0.01)
|
||||
parser.add_argument('--vector_l2', action='store_true')
|
||||
parser.add_argument('--n_gpu', type=int, default=1)
|
||||
parser.add_argument('--opt', type=str, default='adam')
|
||||
parser.add_argument('--afn', type=str, default='gelu')
|
||||
parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
|
||||
@@ -223,9 +222,11 @@ if __name__ == '__main__':
|
||||
desc = args.desc
|
||||
data_dir = args.data_dir
|
||||
log_dir = args.log_dir
|
||||
submission_dir = args.submission_dir
|
||||
|
||||
# torch.device object used throughout this script. TODO: add a GPU setting.
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
n_gpu = torch.cuda.device_count()
|
||||
print("device", device, "n_gpu", n_gpu)
|
||||
|
||||
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
|
||||
text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
|
||||
@@ -259,10 +260,11 @@ if __name__ == '__main__':
|
||||
|
||||
n_train = len(trY)
|
||||
n_valid = len(vaY)
|
||||
n_batch_train = args.n_batch * args.n_gpu
|
||||
n_batch_train = args.n_batch * n_gpu
|
||||
n_updates_total = (n_train // n_batch_train) * args.n_iter
|
||||
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
if n_gpu > 1:
|
||||
dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
|
||||
|
||||
criterion = nn.CrossEntropyLoss(reduce=False)
|
||||
model_opt = OpenAIAdam(dh_model.parameters(),
|
||||
|
||||
Reference in New Issue
Block a user