Training and Evaluation Loop

PyTorch
Published

June 29, 2023

Prepare for training

Steps to complete before training can start:

  • preprocessing
    • remove and rename columns
    • set the dataset format to return PyTorch tensors
  • dataloaders
  • model
  • optimizer
  • learning rate scheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_scheduler
)
# Pretrained checkpoint: its tokenizer is loaded here and its weights are
# loaded later for the classification model.
checkpoint = 'bert-base-uncased'
# GLUE SST-2 dataset; indexed below by split name ('train', 'validation').
raw_datasets = load_dataset('glue', 'sst2')
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def f(x):
    """Tokenize the ``'sentence'`` field of a batch of examples.

    Used with ``Dataset.map(..., batched=True)``, so ``x['sentence']`` is
    presumably a list of strings (verify against the dataset schema).
    Over-long inputs are truncated; padding is left to the data collator.
    """
    sentences = x['sentence']
    return tokenizer(sentences, truncation=True)
Found cached dataset glue (/home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
# Collator that pads each batch at collation time (no padding during map).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Tokenize every split; batched=True hands `f` lists of examples at once.
tokenized_datasets = raw_datasets.map(f, batched=True)
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-989431ea55d09aff.arrow
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dcfc5e44e548784f.arrow
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-974d6c4aa35125e2.arrow
# Drop raw text/index columns and rename `label` -> `labels`, the key the
# model's forward() uses to compute `outputs.loss` during training.
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(['sentence', 'idx'])
    .rename_column('label', 'labels')
)
# Make indexing return torch tensors (set_format mutates in place).
tokenized_datasets.set_format('torch')
tokenized_datasets['train'].column_names
['labels', 'input_ids', 'token_type_ids', 'attention_mask']
# Training batches are shuffled; both loaders pad per batch via the collator.
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    batch_size=64,
    shuffle=True,
    collate_fn=data_collator,
)

# Validation order is kept fixed (no shuffle).
eval_dataloader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=8,
    collate_fn=data_collator,
)

# Fetch a single training batch to sanity-check the collated tensor shapes.
# `next(iter(...))` states the intent directly instead of `for ...: break`.
batch = next(iter(train_dataloader))

{k: v.shape for k, v in batch.items()}
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
{'labels': torch.Size([64]),
 'input_ids': torch.Size([64, 36]),
 'token_type_ids': torch.Size([64, 36]),
 'attention_mask': torch.Size([64, 36])}
# Pretrained encoder plus a freshly initialised 2-way classification head.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=5e-5)

# The scheduler is stepped once per training batch, over every epoch.
epochs = 3
training_steps = len(train_dataloader) * epochs

# Linear learning-rate schedule with no warm-up phase.
lr_scheduler = get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_training_steps=training_steps,
    num_warmup_steps=0,
)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Training loop

import torch
from tqdm.auto import tqdm
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
device(type='cuda')
model.to(device=device)  # move parameters/buffers to the chosen device (in place)
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
)
# Train for `epochs` full passes over the training data.
# The tqdm progress bar counts optimizer steps (`training_steps` in total).
# It is managed by a `with` block so it is closed even if training raises.
model.train()

with tqdm(total=training_steps) as progress_bar:
    for epoch in range(epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            # `batch` contains `labels`, so the model also returns a loss.
            outputs = model(**batch)

            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            # Clear gradients so they do not accumulate into the next step.
            optimizer.zero_grad()

            progress_bar.update(1)

Evaluation loop

import evaluate
# Accumulate SST-2 metric inputs over the whole validation split, then score.
metric = evaluate.load('glue', 'sst2')
model.eval()  # switch Dropout layers to inference behaviour

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}

    # Inference only: no autograd graph is needed, saving memory and time.
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    # Predicted class is the index of the largest logit along the last dim.
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch['labels'])

metric.compute()
{'accuracy': 0.930045871559633}