from torch.optim import AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler,
)
Prepare for training
Preparation steps before starting training:
- preprocessing
  - remove and rename columns
  - specify to return PyTorch tensors
- dataloaders
- model
- optimizer
- learning rate scheduler
raw_datasets = load_dataset('glue', 'sst2')
checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def f(x):
    return tokenizer(x['sentence'], truncation=True)
Found cached dataset glue (/home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
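To see what the preprocessing function returns, a quick check on a single training example (reusing `tokenizer` and `raw_datasets` from above; the exact token ids depend on the sentence):

print(tokenizer(raw_datasets['train'][0]['sentence'], truncation=True))
# {'input_ids': [...], 'token_type_ids': [...], 'attention_mask': [...]}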
tokenized_datasets = raw_datasets.map(f, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-989431ea55d09aff.arrow
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dcfc5e44e548784f.arrow
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-974d6c4aa35125e2.arrow
tokenized_datasets = tokenized_datasets.remove_columns(['sentence', 'idx'])
tokenized_datasets = tokenized_datasets.rename_column('label', 'labels')
tokenized_datasets.set_format('torch')
tokenized_datasets['train'].column_names
['labels', 'input_ids', 'token_type_ids', 'attention_mask']
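As a sanity check on dynamic padding (a small sketch reusing `tokenized_datasets` and `data_collator` from above), collating a handful of examples pads them to the longest sequence in that group rather than to a fixed length:

samples = [tokenized_datasets['train'][i] for i in range(4)]
batch_preview = data_collator(samples)
{k: v.shape for k, v in batch_preview.items()}
# every tensor is padded to the length of the longest of the four examples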
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    batch_size=64,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=8,
    collate_fn=data_collator,
)
for batch in train_dataloader:
    break

{k: v.shape for k, v in batch.items()}
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
{'labels': torch.Size([64]),
'input_ids': torch.Size([64, 36]),
'token_type_ids': torch.Size([64, 36]),
'attention_mask': torch.Size([64, 36])}
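Because padding happens per batch, the sequence length (36 above) changes from batch to batch; a quick look at a few batches (reusing `train_dataloader`) makes that visible:

for i, b in zip(range(3), train_dataloader):
    print(b['input_ids'].shape)  # batch size stays 64; the sequence dimension varies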
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=5e-5)

epochs = 3
training_steps = epochs * len(train_dataloader)

lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=training_steps,
)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
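As an aside (not part of the training code), a throwaway optimizer and scheduler illustrate what the 'linear' schedule with zero warmup does: the learning rate decays from 5e-5 toward zero over the given number of steps. The 100-step horizon below is only for illustration:

import torch

_opt = AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
_sched = get_scheduler('linear', optimizer=_opt, num_warmup_steps=0, num_training_steps=100)
lrs = []
for _ in range(100):
    _opt.step()
    _sched.step()
    lrs.append(_sched.get_last_lr()[0])
print(lrs[0], lrs[49], lrs[-1])  # ~5e-5 near the start, ~2.5e-5 halfway, 0.0 at the end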
Training loop
import torch
from tqdm.auto import tqdm
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device
device(type='cuda')
model.to(device)
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=768, out_features=2, bias=True)
)
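A small aside on the printout above (reusing `model`): the freshly initialized classification head is tiny compared to the pretrained encoder, which is why the warning about newly initialized weights is expected:

total_params = sum(p.numel() for p in model.parameters())
head_params = sum(p.numel() for p in model.classifier.parameters())
print(total_params, head_params)  # roughly 110M in total; 768 * 2 + 2 = 1538 in the classifier head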
# use `tqdm` to add a progress bar
progress_bar = tqdm(range(training_steps))

model.train()
for epoch in range(epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
Evaluation loop
import evaluate
metric = evaluate.load('glue', 'sst2')

model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch['labels'])

metric.compute()
{'accuracy': 0.930045871559633}
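As a final spot check (not part of the original notebook), the fine-tuned model can score a single sentence directly; the example text is arbitrary, and label 1 is the positive class in SST-2:

text = 'a gripping, beautifully shot film'
inputs = tokenizer(text, return_tensors='pt').to(device)
with torch.no_grad():
    probs = model(**inputs).logits.softmax(dim=-1)
print(probs)  # probabilities for [negative, positive]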