import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
Fine-tune Mistral-7B-Instruct-v0.2 to generate prompts for given texts. The fine-tuned model can be used to create synthetic datasets for a particular domain, which in turn can be used to fine-tune a model for domain-specific tasks.
Dataset used for fine-tuning: the Alpaca-GPT-4 dataset.
1. Load model
= "mistralai/Mistral-7B-Instruct-v0.2"
model_name
= BitsAndBytesConfig(
bnb_config =True,
load_in_4bit="nf4",
bnb_4bit_quant_type=torch.bfloat16,
bnb_4bit_compute_dtype=True
bnb_4bit_use_double_quant
)
= AutoModelForCausalLM.from_pretrained(
base_model
model_name,=bnb_config,
quantization_config=False,
use_cache='flash_attention_2',
attn_implementation="auto"
device_map
)
= AutoTokenizer.from_pretrained(model_name)
tokenizer = tokenizer.eos_token
tokenizer.pad_token = "right" tokenizer.padding_side
base_model
MistralForCausalLM(
(model): MistralModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x MistralDecoderLayer(
(self_attn): MistralFlashAttention2(
(q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(rotary_emb): MistralRotaryEmbedding()
)
(mlp): MistralMLP(
(gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): MistralRMSNorm()
(post_attention_layernorm): MistralRMSNorm()
)
)
(norm): MistralRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
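As the printed modules show, all attention and MLP projections were loaded as Linear4bit layers, while the embeddings, norms, and lm_head remain unquantized. To check how much GPU memory the quantized model actually occupies, transformers exposes get_memory_footprint(); a minimal sketch:

# Rough sanity check (sketch): on-device size of the 4-bit model in GiB.
print(f"Memory footprint: {base_model.get_memory_footprint() / 1024**3:.2f} GiB")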
2. Load dataset
from datasets import load_dataset

dataset_name = "c-s-ale/alpaca-gpt4-data"
dataset = load_dataset(dataset_name, split='train[:6000]')

SEED = 42

train_test_ds = dataset.train_test_split(test_size=1000, seed=SEED)
train_ds = train_test_ds['train']
test_ds = train_test_ds['test']

train_ds, test_ds
(Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 5000
}),
Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 1000
}))
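Before writing the formatting function, it helps to look at one raw record: each example carries an 'instruction', an optional 'input' (often an empty string), and an 'output'. A quick inspection sketch:

# Peek at a single record (sketch); 'input' may be an empty string.
example = train_ds[0]
for key in ('instruction', 'input', 'output'):
    print(f"{key}: {example[key]!r}")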
2.1 Prompt formatting
def prompt_formatting_fn(example, training=True):
    # Template for examples without additional context
    prompt_template_1 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response.
Response:
{output}
""".strip()

    # Template for examples that include an input/context field
    prompt_template_2 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response. \
Use the provided context to help you create the prompt.
Response:
{output}
Context:
{input}
""".strip()

    input = example['input']
    output = example['output']

    if input is not None and len(input) > 0:
        messages = [
            {'role': 'user', 'content': prompt_template_2.format(output=output, input=input)}
        ]
    else:
        messages = [
            {'role': 'user', 'content': prompt_template_1.format(output=output)}
        ]

    # During training, the original instruction becomes the target (assistant turn)
    if training:
        messages.append(
            {'role': 'assistant', 'content': example['instruction']}
        )

    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
print(prompt_formatting_fn(train_ds[1]))
<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.
Response:
A pine tree is an evergreen conifer belonging to the genus Pinus, in the family Pinaceae. This type of tree is characterized by its needle-like leaves, which grow in clusters and are typically 1-8 inches long. Pine trees have a distinct fragrance and produce cones that contain the seeds of the plant. The bark of a pine tree is usually thick and scaly, with deep furrows, providing protection from the elements. Pine trees can grow to be very tall, sometimes reaching over 80 feet in height, and have a conical shape with branches that are often level or slightly ascending. They are also known for their longevity, with some species capable of living for hundreds or thousands of years. [/INST]Describe the attributes of a pine tree.</s>
3. Train
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

base_model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(base_model, lora_config)
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='finetuning_output',
    num_train_epochs=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant"
)
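Note that per_device_train_batch_size is left at the transformers default (8), so with gradient accumulation each optimizer step sees 16 sequences. A small sketch to make the effective batch size explicit (assuming a single GPU):

# Effective batch size per optimizer step = per-device batch x gradient accumulation (sketch, single GPU).
effective_batch = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
print(f"Effective batch size: {effective_batch}")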
from trl import SFTTrainer

max_seq_length = 4096

trainer = SFTTrainer(
    model=model,
    train_dataset=train_ds,
    peft_config=lora_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=prompt_formatting_fn,
    args=training_args,
)
/home/jovyan/.local/lib/python3.11/site-packages/trl/trainer/utils.py:434: UserWarning: The passed formatting_func has more than one argument. Usually that function should have a single argument `example` which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing.
warnings.warn(
/home/jovyan/.local/lib/python3.11/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead:
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
warnings.warn(
trainer.train()
/home/jovyan/.local/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
[16/16 14:40, Epoch 0/1]
Step | Training Loss
-----|--------------
  10 | 1.029300
TrainOutput(global_step=16, training_loss=0.9475531280040741, metrics={'train_runtime': 939.2543, 'train_samples_per_second': 0.278, 'train_steps_per_second': 0.017, 'total_flos': 4.579193628760474e+16, 'train_loss': 0.9475531280040741, 'epoch': 0.97})
model_folder = 'mistral-7b-instruct-v0.2-sft'
trainer.save_model(model_folder)
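Because the model is a PEFT model, trainer.save_model writes only the LoRA adapter (its config and weights) plus the tokenizer files, not a full 7B checkpoint. A quick sketch to see what actually lands on disk (file names vary slightly across peft/trl versions):

import os
# List saved artifacts (sketch): expect adapter config/weights and tokenizer files, not full model shards.
print(sorted(os.listdir(model_folder)))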
4. Test the PEFT adapter
from peft import AutoPeftModelForCausalLM

peft_model = AutoPeftModelForCausalLM.from_pretrained(
    model_folder,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map='auto'
)

tokenizer = AutoTokenizer.from_pretrained(model_folder)
def generate(prompt, max_new_tokens=256, model=model, tokenizer=tokenizer):
    tokenized_prompt = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_prompt.input_ids.cuda()
    attn_mask = tokenized_prompt.attention_mask.cuda()

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attn_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode and keep only the text generated after the [/INST] tag
    generated_texts = tokenizer.batch_decode(
        outputs.detach().cpu().numpy(),
        skip_special_tokens=True
    )[0].split('[/INST]')[-1]
    return generated_texts.strip()
import random

random.seed(SEED)
n = random.randrange(len(test_ds))

sample = test_ds[n]
prompt = prompt_formatting_fn(sample, training=False)
print(prompt)
<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.
Response:
One popular dating app is Tinder. [/INST]
generated_texts = generate(prompt, model=peft_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Name a popular dating app.
Ground truth:
Name a popular dating app.
4.1 Compare with the original model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

untuned_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
generated_texts = generate(prompt, model=untuned_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Could you please provide information about well-known dating applications? Which one is frequently used and has gained significant popularity?
Ground truth:
Name a popular dating app.
5. Merge and Save
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_folder)

model = AutoPeftModelForCausalLM.from_pretrained(
    model_folder,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16
)

# Merge LoRA and base model
merged_model = model.merge_and_unload()

output_folder = 'merged-mistral-7b-instruct-v0.2-sft'

merged_model.save_pretrained(output_folder, safe_serialization=True)
tokenizer.save_pretrained(output_folder)
('merged-mistral-7b-instruct-v0.2-sft/tokenizer_config.json',
'merged-mistral-7b-instruct-v0.2-sft/special_tokens_map.json',
'merged-mistral-7b-instruct-v0.2-sft/tokenizer.model',
'merged-mistral-7b-instruct-v0.2-sft/added_tokens.json',
'merged-mistral-7b-instruct-v0.2-sft/tokenizer.json')
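To share the merged checkpoint, both the model and the tokenizer support push_to_hub; a sketch, assuming you are already authenticated (e.g. via huggingface-cli login) and using a placeholder repo id:

# Optional (sketch): upload the merged model and tokenizer to the Hugging Face Hub.
# "your-username/mistral-7b-instruct-v0.2-sft-merged" is a placeholder repo id.
merged_model.push_to_hub("your-username/mistral-7b-instruct-v0.2-sft-merged")
tokenizer.push_to_hub("your-username/mistral-7b-instruct-v0.2-sft-merged")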
6. Inference using merged model
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained(output_folder)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

merged_model = AutoModelForCausalLM.from_pretrained(
    output_folder,
    quantization_config=bnb_config,
    attn_implementation='flash_attention_2',
    device_map="auto",
)
generated_texts = generate(prompt, model=merged_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")
Generated prompt:
Name a popular dating app.
Ground truth:
Name a popular dating app.