Using the Trainer API

PyTorch · Published June 26, 2023

Limit train data to 5000 samples

The load_dataset function can load just a slice of a split by slicing it directly in the split argument, which we use here to keep only the first 5000 training examples.

import evaluate
import numpy as np

from datasets import load_dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    TrainingArguments,
    AutoModelForSequenceClassification,
    Trainer
)
train_ds, val_ds, test_ds = load_dataset('glue', 'sst2',
                            split=['train[:5000]', 'validation', 'test'])

raw_datasets = DatasetDict({
    'train': train_ds,
    'validation': val_ds,
    'test': test_ds
})

raw_datasets
Found cached dataset glue (/home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})
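
As an aside, the split-slicing syntax is more general than the absolute indices used above: it also accepts percentages and concatenated splits. A minimal sketch (not used in the rest of this post):

# Optional variations of the split argument for the same dataset
ten_percent    = load_dataset('glue', 'sst2', split='train[:10%]')       # first 10% of train
train_plus_val = load_dataset('glue', 'sst2', split='train+validation')  # concatenate splits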

Tokenize datasets

checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Tokenize without padding; padding is deferred to the data collator so each
# batch is only padded to the length of its longest sequence.
def tokenize_function(batch):
    return tokenizer(batch["sentence"], truncation=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-28dc12396551a413.arrow
Loading cached processed dataset at /home/limin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-165bd9e0b1d649fa.arrow
tokenized_datasets
DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})
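
Because padding was skipped in the map step, the tokenized examples still have different lengths; a quick optional check:

# Sequence lengths vary from example to example since padding is deferred to the collator
print([len(ids) for ids in tokenized_datasets['train'][:4]['input_ids']])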

Prepare for training

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
training_args = TrainingArguments('test-trainer', evaluation_strategy='epoch')
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)

    # The 'mrpc' config of the GLUE metric reports both accuracy and F1;
    # the 'sst2' config only reports accuracy, so we reuse 'mrpc' here even
    # though the data is SST-2.
    metric = evaluate.load('glue', 'mrpc')
    return metric.compute(predictions=preds, references=labels)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
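
DataCollatorWithPadding pads each batch to the length of its longest sequence rather than to a global maximum. As a quick sanity check (optional; the column filtering below is an assumption about which columns the collator cannot convert to tensors):

# Take a few examples, drop the raw-text and index columns, and collate them
samples = tokenized_datasets['train'][:8]
samples = {k: v for k, v in samples.items() if k not in ('sentence', 'idx')}
features = [{k: v[i] for k, v in samples.items()} for i in range(8)]

batch = data_collator(features)
print({k: v.shape for k, v in batch.items()})  # every tensor padded to the longest of the 8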

Training

trainer.train()
/home/limin/conversational-ai-lab/venv/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
[1875/1875 03:38, Epoch 3/3]
Epoch | Training Loss | Validation Loss | Accuracy | F1
    1 |      0.696600 |        0.693132 | 0.509174 | 0.674772
    2 |      0.695000 |        0.697514 | 0.509174 | 0.674772
    3 |      0.699100 |        0.374979 | 0.845183 | 0.853420

TrainOutput(global_step=1875, training_loss=0.669135546875, metrics={'train_runtime': 219.2111, 'train_samples_per_second': 68.427, 'train_steps_per_second': 8.553, 'total_flos': 228729840422880.0, 'train_loss': 0.669135546875, 'epoch': 3.0})
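
After training, trainer.predict (or trainer.evaluate) can be used to get final numbers on the validation split; a short sketch:

# Run inference on the validation set and recompute the metrics from the raw logits
preds_output = trainer.predict(tokenized_datasets['validation'])
print(preds_output.predictions.shape, preds_output.label_ids.shape)
print(compute_metrics((preds_output.predictions, preds_output.label_ids)))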