Mirror of https://github.com/wassname/ray.git (last synced 2026-06-27 22:38:16 +08:00)
[RaySGD] revised existing transformer example to work with transformers>=3.0 (#9661)
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -36,10 +36,11 @@ from torch.utils.data import DataLoader, RandomSampler
 from tqdm import trange
 import torch.distributed as dist
 
-from transformers import (MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, AdamW,
-AutoConfig, AutoModelForSequenceClassification,
-AutoTokenizer, get_linear_schedule_with_warmup,
-HfArgumentParser, TrainingArguments)
+from transformers import (MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AdamW, AutoConfig,
+AutoModelForSequenceClassification, AutoTokenizer,
+get_linear_schedule_with_warmup, HfArgumentParser,
+TrainingArguments)
 from transformers import glue_output_modes as output_modes
 from transformers import glue_processors as processors
 
@@ -58,7 +59,8 @@ MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
 ALL_MODELS = sum(
-(tuple(conf.pretrained_config_archive_map.keys())
+(tuple(key for key in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()
+if key.startswith(conf.model_type))
 for conf in MODEL_CONFIG_CLASSES),
 (),
 )
@@ -78,11 +80,11 @@ def announce_training(args, dataset_len, t_total):
 logger.info("***** Running training *****")
 logger.info(" Num examples = %d", dataset_len)
 logger.info(" Num Epochs = %d", args.num_train_epochs)
-logger.info(" Instantaneous batch size per GPU = %d",
-args.per_gpu_train_batch_size)
+logger.info(" Instantaneous batch size per device = %d",
+args.per_device_train_batch_size)
 logger.info(
 " Total train batch size (w. parallel, distributed & accum) = %d",
-args.per_gpu_train_batch_size * args.gradient_accumulation_steps *
+args.per_device_train_batch_size * args.gradient_accumulation_steps *
 args.num_workers,
 )
 logger.info(" Gradient Accumulation steps = %d",
@@ -154,7 +156,7 @@ def data_creator(config):
 return DataLoader(
 train_dataset,
 sampler=train_sampler,
-batch_size=args.per_gpu_train_batch_size)
+batch_size=args.per_device_train_batch_size)
 
 
 class TransformerOperator(TrainingOperator):
|
||||
Reference in New Issue
Block a user