Fine-tuning Mistral-7B-Instruct-v0.2 for synthetic dataset generation

Fine-tuning
Mistral
Published

January 10, 2024

Fine-tune Mistral-7B-Instruct-v0.2 to generate prompts from given texts: given a response (and optional context), the model writes the instruction that would produce it. The fine-tuned model can be used to generate synthetic datasets for a particular domain, which can in turn be used to fine-tune a model for domain-specific tasks.

Dataset used for fine-tuning: Alpaca-GPT-4 dataset

1. Load model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# use_cache=False: the KV cache is not needed for training and is incompatible
# with gradient checkpointing (enabled in the training arguments further down).
# NOTE(review): flash_attention_2 requires the flash-attn package and a
# supported GPU — confirm for the target environment.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    use_cache=False,
    attn_implementation='flash_attention_2',
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# No dedicated pad token is configured here, so EOS is reused for padding;
# right-side padding is used for training.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
base_model  # notebook cell result: displays the module tree below
MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralFlashAttention2(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

2. Load dataset

from datasets import load_dataset
dataset_name = "c-s-ale/alpaca-gpt4-data"
# Only the first 6000 rows of the train split are loaded.
dataset = load_dataset(dataset_name, split='train[:6000]')

SEED = 42

# Hold out 1000 rows for evaluation; seeded so the split is reproducible.
train_test_ds = dataset.train_test_split(test_size=1000, seed=SEED)
train_ds = train_test_ds['train']
test_ds = train_test_ds['test']

train_ds, test_ds  # notebook cell result: shows both splits
(Dataset({
     features: ['instruction', 'input', 'output'],
     num_rows: 5000
 }),
 Dataset({
     features: ['instruction', 'input', 'output'],
     num_rows: 1000
 }))

2.1 Prompt formatting

def prompt_formatting_fn(example, training=True):
    """Render an Alpaca-style record as a Mistral chat prompt for prompt generation.

    The Alpaca task is inverted: the user turn shows the record's ``output``
    (plus its ``input`` as context, when non-empty) and asks the model to write
    the prompt that would produce it. The record's ``instruction`` is the
    target assistant reply.

    Args:
        example: mapping with ``'instruction'``, ``'input'`` and ``'output'`` keys.
        training: if True, append ``example['instruction']`` as the assistant
            turn (the supervised target). If False, leave the conversation open
            and request a generation prompt so the model can answer.

    Returns:
        The formatted conversation as an untokenized string, rendered with the
        module-level ``tokenizer``'s chat template.
    """
    prompt_template_1 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response.

Response:
{output}
    """.strip()

    prompt_template_2 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response. \
Use the provided context to help you create the prompt.

Response:
{output}

Context:
{input}
    """.strip()

    # Named `context` rather than `input` to avoid shadowing the builtin.
    context = example['input']
    output = example['output']
    # None and "" both mean "no context" -> use the shorter template.
    if context:
        content = prompt_template_2.format(output=output, input=context)
    else:
        content = prompt_template_1.format(output=output)
    messages = [{'role': 'user', 'content': content}]

    if training:
        messages.append({'role': 'assistant', 'content': example['instruction']})

    # A generation prompt only makes sense when the conversation ends on the
    # user's turn; after the assistant target it would append a dangling
    # assistant header on templates that emit one. (Judging by the printed
    # prompts, Mistral v0.2's template does not use this flag, so the output
    # here is unchanged.)
    # NOTE(review): the rendered string already starts with the BOS token
    # (`<s>`); tokenizing it again with add_special_tokens=True adds a second
    # BOS — confirm how downstream consumers tokenize these texts.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=not training
    )
    return prompt
# Sanity check: one formatted training example (includes the assistant target).
print(prompt_formatting_fn(train_ds[1]))
<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.

Response:
A pine tree is an evergreen conifer belonging to the genus Pinus, in the family Pinaceae. This type of tree is characterized by its needle-like leaves, which grow in clusters and are typically 1-8 inches long. Pine trees have a distinct fragrance and produce cones that contain the seeds of the plant. The bark of a pine tree is usually thick and scaly, with deep furrows, providing protection from the elements. Pine trees can grow to be very tall, sometimes reaching over 80 feet in height, and have a conical shape with branches that are often level or slightly ascending. They are also known for their longevity, with some species capable of living for hundreds or thousands of years. [/INST]Describe the attributes of a pine tree.</s>

3. Train

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# LoRA adapters on every attention (q/k/v/o) and MLP (gate/up/down) projection
# of each decoder layer.
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Prepare the 4-bit model for training, then wrap it with the LoRA adapters;
# `model` is the PeftModel that gets trained.
base_model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(base_model, lora_config)
from transformers import TrainingArguments

# One pass over the (packed) training set; per-device batch size is left at
# the TrainingArguments default.
training_args = TrainingArguments(
    output_dir='finetuning_output',
    num_train_epochs=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    # Warm up over the first 3% of steps, then hold the LR constant. The plain
    # "constant" scheduler ignores warmup_ratio entirely, so the warmup this
    # config asks for never happened.
    warmup_ratio=0.03,
    lr_scheduler_type="constant_with_warmup",
)
from trl import SFTTrainer

max_seq_length = 4096

# `model` is already a PeftModel wrapped with `lora_config`, so the LoRA
# config is not passed again. Passing peft_config alongside a PeftModel is
# redundant in older trl releases and rejected by newer ones.
# NOTE(review): confirm against the installed trl version.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    # Concatenate formatted examples into max_seq_length-sized chunks.
    packing=True,
    # trl calls this with the example only, so `training` keeps its default
    # (True) and the assistant target is included (see the warning below).
    formatting_func=prompt_formatting_fn,
    args=training_args,
)
/home/jovyan/.local/lib/python3.11/site-packages/trl/trainer/utils.py:434: UserWarning: The passed formatting_func has more than one argument. Usually that function should have a single argument `example` which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing.
  warnings.warn(
/home/jovyan/.local/lib/python3.11/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: 
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
  warnings.warn(
# Run fine-tuning; only the LoRA adapter weights are trained.
trainer.train()
/home/jovyan/.local/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
[16/16 14:40, Epoch 0/1]
Step Training Loss
10 1.029300

TrainOutput(global_step=16, training_loss=0.9475531280040741, metrics={'train_runtime': 939.2543, 'train_samples_per_second': 0.278, 'train_steps_per_second': 0.017, 'total_flos': 4.579193628760474e+16, 'train_loss': 0.9475531280040741, 'epoch': 0.97})
model_folder = 'mistral-7b-instruct-v0.2-sft'

# For a PeftModel this saves the adapter (not a full copy of the base model),
# together with the trainer's tokenizer; both are reloaded from here below.
trainer.save_model(model_folder)

4. Test the PEFT adapter

from peft import AutoPeftModelForCausalLM

# Reload the base model (4-bit, same bnb_config as training) with the saved
# adapter attached; the base model is resolved from the adapter's config.
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    model_folder,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map='auto'
)

tokenizer = AutoTokenizer.from_pretrained(model_folder)
def generate(prompt, max_new_tokens=256, model=None, tokenizer=None):
    """Generate a completion for a single, already-formatted prompt.

    Args:
        prompt: prompt text, e.g. from ``prompt_formatting_fn(..., training=False)``.
        max_new_tokens: upper bound on the number of generated tokens.
        model: object exposing ``generate`` and ``device`` (e.g. a causal LM).
            Defaults to the module-level ``model``, looked up at call time.
        tokenizer: tokenizer used to encode the prompt and decode the output.
            Defaults to the module-level ``tokenizer``, looked up at call time.

    Returns:
        Only the newly generated text (the prompt is excluded), with special
        tokens removed and surrounding whitespace stripped.

    Raises:
        ValueError: if no model/tokenizer is passed and none exists at module level.
    """
    # Resolve defaults at call time. Defaults written in the signature were
    # frozen when the function was defined, so later reassignments of
    # `tokenizer`/`model` were silently ignored (and the training model kept alive).
    if model is None:
        model = globals().get('model')
    if tokenizer is None:
        tokenizer = globals().get('tokenizer')
    if model is None or tokenizer is None:
        raise ValueError("generate() needs a model and a tokenizer")

    # Chat-template output already begins with the BOS token; letting the
    # tokenizer add special tokens again would produce a duplicated BOS.
    bos = getattr(tokenizer, 'bos_token', None)
    add_special_tokens = not (bos and prompt.startswith(bos))
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=add_special_tokens)

    # Put inputs wherever the model lives instead of assuming CUDA.
    input_ids = encoded.input_ids.to(model.device)
    attn_mask = encoded.attention_mask.to(model.device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence is prompt + continuation; slice off the prompt
    # tokens rather than splitting on '[/INST]' text, which breaks if the
    # generation itself contains that marker.
    new_tokens = outputs[0, input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
import random

# Pick one reproducible held-out example (this seeds Python's global RNG).
random.seed(SEED)
n = random.randrange(len(test_ds))

sample = test_ds[n]
# Inference-style prompt: no assistant turn, so the model must write the instruction.
prompt = prompt_formatting_fn(sample, training=False)
print(prompt)
<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.

Response:
One popular dating app is Tinder. [/INST]
# Fine-tuned adapter output vs. the dataset's original instruction.
generated_texts = generate(prompt, model=peft_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Name a popular dating app.

Ground truth:
Name a popular dating app.

4.1 Compare with the original model

# Same quantization as the fine-tuned model, including bfloat16 compute, so the
# comparison isolates the effect of the adapter rather than the compute dtype
# (this cell previously used float16 here and bfloat16 everywhere else).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

untuned_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pass the tokenizer explicitly: generate()'s default may not reflect the
# reassignment above.
generated_texts = generate(prompt, model=untuned_model, tokenizer=tokenizer)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Could you please provide information about well-known dating applications? Which one is frequently used and has gained significant popularity?

Ground truth:
Name a popular dating app.

5. Merge and Save

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_folder)

# Load base + adapter unquantized in bfloat16 (no quantization_config), so the
# LoRA deltas are merged into full-precision base weights rather than 4-bit ones.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_folder,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16
)

# Merge LoRA and base model
merged_model = model.merge_and_unload()
output_folder = 'merged-mistral-7b-instruct-v0.2-sft'

# Save merged weights as safetensors next to the tokenizer so the folder loads on its own.
merged_model.save_pretrained(output_folder, safe_serialization=True)
tokenizer.save_pretrained(output_folder)
('merged-mistral-7b-instruct-v0.2-sft/tokenizer_config.json',
 'merged-mistral-7b-instruct-v0.2-sft/special_tokens_map.json',
 'merged-mistral-7b-instruct-v0.2-sft/tokenizer.model',
 'merged-mistral-7b-instruct-v0.2-sft/added_tokens.json',
 'merged-mistral-7b-instruct-v0.2-sft/tokenizer.json')

6. Inference using merged model

import torch
# BitsAndBytesConfig is imported here too so this cell does not depend on an
# earlier cell having imported it.
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(output_folder)

# Same 4-bit NF4 / bfloat16-compute setup used for training.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

merged_model = AutoModelForCausalLM.from_pretrained(
    output_folder,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map="auto",
)
# Pass the tokenizer explicitly: generate()'s default may not reflect the
# reassignment above.
generated_texts = generate(prompt, model=merged_model, tokenizer=tokenizer)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Name a popular dating app.

Ground truth:
Name a popular dating app.