Fine-tune Mistral-7B-Instruct-v0.2 for generating prompts based on given texts. The fine-tuned model could be used to generate synthetic datasets for a certain domain, which in turn could be used to fine-tune a model for domain-specific tasks.

Dataset used for fine-tuning: Alpaca-GPT-4 dataset

1. Load model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
base_model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
use_cache=False,
attn_implementation='flash_attention_2',
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model

MistralForCausalLM(
(model): MistralModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x MistralDecoderLayer(
(self_attn): MistralFlashAttention2(
(q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
(rotary_emb): MistralRotaryEmbedding()
)
(mlp): MistralMLP(
(gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): MistralRMSNorm()
(post_attention_layernorm): MistralRMSNorm()
)
)
(norm): MistralRMSNorm()
)
(lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
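As an optional sanity check, the memory footprint of the 4-bit quantized model can be inspected; get_memory_footprint is a standard transformers model method that returns bytes.

# Optional: report the quantized model's approximate memory footprint in GB.
print(f"Memory footprint: {base_model.get_memory_footprint() / 1024**3:.2f} GB")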
2. Load dataset
from datasets import load_dataset

dataset_name = "c-s-ale/alpaca-gpt4-data"
dataset = load_dataset(dataset_name, split='train[:6000]')
SEED = 42
train_test_ds = dataset.train_test_split(test_size=1000, seed=SEED)
train_ds = train_test_ds['train']
test_ds = train_test_ds['test']
train_ds, test_ds

(Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 5000
}),
Dataset({
features: ['instruction', 'input', 'output'],
num_rows: 1000
}))
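Not every Alpaca example carries a non-empty 'input' context, and the formatting function in the next subsection branches on this; a quick count (an optional sketch) shows the split:

# Optional: count training examples that include a non-empty 'input' context.
with_context = train_ds.filter(lambda ex: len(ex['input']) > 0)
print(f"{len(with_context)} of {len(train_ds)} training examples have an input context")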
2.1 Prompt formatting
def prompt_formatting_fn(example, training=True):
prompt_template_1 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response.
Response:
{output}
""".strip()
prompt_template_2 = """
Your task is to generate a concise prompt for querying a large language model so that \
the model can output the following response. \
Use the provided context to help you create the prompt.
Response:
{output}
Context:
{input}
""".strip()
input = example['input']
output = example['output']
if input is not None and len(input) > 0:
messages = [
{'role': 'user', 'content': prompt_template_2.format(output=output, input=input)}
]
else:
messages = [
{'role': 'user', 'content': prompt_template_1.format(output=output)}
]
if training:
messages.append(
{'role': 'assistant', 'content': example['instruction']}
)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt

print(prompt_formatting_fn(train_ds[1]))

<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.
Response:
A pine tree is an evergreen conifer belonging to the genus Pinus, in the family Pinaceae. This type of tree is characterized by its needle-like leaves, which grow in clusters and are typically 1-8 inches long. Pine trees have a distinct fragrance and produce cones that contain the seeds of the plant. The bark of a pine tree is usually thick and scaly, with deep furrows, providing protection from the elements. Pine trees can grow to be very tall, sometimes reaching over 80 feet in height, and have a conical shape with branches that are often level or slightly ascending. They are also known for their longevity, with some species capable of living for hundreds or thousands of years. [/INST]Describe the attributes of a pine tree.</s>
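Before training, it can be worth verifying on a small sample that the formatted prompts fit within the max_seq_length of 4096 used in the trainer below (an optional sketch):

# Optional: rough check of tokenized prompt lengths on the first 100 training examples.
sample_lengths = [
    len(tokenizer(prompt_formatting_fn(ex)).input_ids)
    for ex in train_ds.select(range(100))
]
print(f"Longest formatted prompt in sample: {max(sample_lengths)} tokens")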
3. Train
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
lora_config = LoraConfig(
r=64,
lora_alpha=128,
target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
lora_dropout=0.1,
bias="none",
task_type="CAUSAL_LM"
)
base_model = prepare_model_for_kbit_training(base_model)
model = get_peft_model(base_model, lora_config)
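PEFT can report how small the trainable LoRA adapter is relative to the frozen 4-bit base model (an optional check):

# Optional: print trainable (LoRA) parameter count vs. total parameters.
model.print_trainable_parameters()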
from transformers import TrainingArguments

training_args = TrainingArguments(
output_dir='finetuning_output',
num_train_epochs=1,
gradient_accumulation_steps=2,
gradient_checkpointing=True,
optim="paged_adamw_32bit",
logging_steps=10,
save_strategy="epoch",
learning_rate=2e-4,
bf16=True,
max_grad_norm=0.3,
warmup_ratio=0.03,
lr_scheduler_type="constant"
)

from trl import SFTTrainer
max_seq_length = 4096
trainer = SFTTrainer(
model=model,
train_dataset=train_ds,
peft_config=lora_config,
max_seq_length=max_seq_length,
tokenizer=tokenizer,
packing=True,
formatting_func=prompt_formatting_fn,
args=training_args,
)

/home/jovyan/.local/lib/python3.11/site-packages/trl/trainer/utils.py:434: UserWarning: The passed formatting_func has more than one argument. Usually that function should have a single argument `example` which corresponds to the dictionary returned by each element of the dataset. Make sure you know what you are doing.
warnings.warn(
/home/jovyan/.local/lib/python3.11/site-packages/accelerate/accelerator.py:432: FutureWarning: Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead:
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
warnings.warn(
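As an aside, the use_reentrant warning emitted during training below can be avoided by configuring checkpointing explicitly; this is a sketch assuming a transformers version that supports gradient_checkpointing_kwargs, and it was not run here:

# Hypothetical alternative (assumption, not used above): silence the
# torch.utils.checkpoint use_reentrant deprecation warning.
alt_training_args = TrainingArguments(
    output_dir='finetuning_output',
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
)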
trainer.train()

/home/jovyan/.local/lib/python3.11/site-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
warnings.warn(
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
[16/16 14:40, Epoch 0/1]
| Step | Training Loss |
|---|---|
| 10 | 1.029300 |
TrainOutput(global_step=16, training_loss=0.9475531280040741, metrics={'train_runtime': 939.2543, 'train_samples_per_second': 0.278, 'train_steps_per_second': 0.017, 'total_flos': 4.579193628760474e+16, 'train_loss': 0.9475531280040741, 'epoch': 0.97})
model_folder = 'mistral-7b-instruct-v0.2-sft'
trainer.save_model(model_folder)
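trainer.save_model typically writes only the LoRA adapter weights and config (plus tokenizer files) here, not a full copy of the base model; listing the folder (an optional sketch) makes that visible:

import os

# Optional: inspect the saved adapter folder; expect adapter config/weights and tokenizer files.
print(sorted(os.listdir(model_folder)))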
4. Test PEFT adapter

from peft import AutoPeftModelForCausalLM
peft_model = AutoPeftModelForCausalLM.from_pretrained(
model_folder,
quantization_config=bnb_config,
attn_implementation='flash_attention_2',
device_map='auto'
)
tokenizer = AutoTokenizer.from_pretrained(model_folder)

def generate(prompt, max_new_tokens=256, model=model, tokenizer=tokenizer):
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = tokenized_prompt.input_ids.cuda()
attn_mask = tokenized_prompt.attention_mask.cuda()
outputs = model.generate(
input_ids=input_ids,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
pad_token_id=tokenizer.eos_token_id
)
generated_texts = tokenizer.batch_decode(
outputs.detach().cpu().numpy(),
skip_special_tokens=True
)[0].split('[/INST]')[-1]
    return generated_texts.strip()

import random
random.seed(SEED)

n = random.randrange(len(test_ds))
sample = test_ds[n]
prompt = prompt_formatting_fn(sample, training=False)
print(prompt)

<s>[INST] Your task is to generate a concise prompt for querying a large language model so that the model can output the following response.
Response:
One popular dating app is Tinder. [/INST]
generated_texts = generate(prompt, model=peft_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")

Generated prompt:
Name a popular dating app.
Ground truth:
Name a popular dating app.
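A single example is not conclusive, so a short loop over a few more held-out samples (an optional sketch) gives a better qualitative picture:

# Optional: spot-check a few more random test samples against their reference instructions.
for i in random.sample(range(len(test_ds)), 3):
    s = test_ds[i]
    p = prompt_formatting_fn(s, training=False)
    print('Generated :', generate(p, model=peft_model))
    print('Reference :', s['instruction'])
    print('-' * 40)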
4.1 Compare with the original model
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16
)
untuned_model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
attn_implementation='flash_attention_2',
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

generated_texts = generate(prompt, model=untuned_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")

Generated prompt:
Could you please provide information about well-known dating applications? Which one is frequently used and has gained significant popularity?
Ground truth:
Name a popular dating app.
5. Merge and Save
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_folder)
model = AutoPeftModelForCausalLM.from_pretrained(
model_folder,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16
)
# Merge LoRA and base model
merged_model = model.merge_and_unload()

output_folder = 'merged-mistral-7b-instruct-v0.2-sft'
merged_model.save_pretrained(output_folder, safe_serialization=True)
tokenizer.save_pretrained(output_folder)

('merged-mistral-7b-instruct-v0.2-sft/tokenizer_config.json',
'merged-mistral-7b-instruct-v0.2-sft/special_tokens_map.json',
'merged-mistral-7b-instruct-v0.2-sft/tokenizer.model',
'merged-mistral-7b-instruct-v0.2-sft/added_tokens.json',
'merged-mistral-7b-instruct-v0.2-sft/tokenizer.json')
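The merged checkpoint is a full bf16 copy of the 7B model, so it is much larger on disk than the adapter alone; a quick size check (an optional sketch):

from pathlib import Path

# Optional: total size of the safetensors shards written for the merged model.
total_bytes = sum(p.stat().st_size for p in Path(output_folder).glob('*.safetensors'))
print(f"Merged weights on disk: {total_bytes / 1024**3:.1f} GB")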
6. Inference using merged model
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
tokenizer = AutoTokenizer.from_pretrained(output_folder)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True
)
merged_model = AutoModelForCausalLM.from_pretrained(
output_folder,
quantization_config=bnb_config,
attn_implementation='flash_attention_2',
device_map="auto",
)

generated_texts = generate(prompt, model=merged_model)
print(f'Generated prompt:\n{generated_texts}')
print(f"\nGround truth:\n{sample['instruction']}")

Generated prompt:
Name a popular dating app.
Ground truth:
Name a popular dating app.
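To tie back to the stated goal of building synthetic datasets, here is an illustrative sketch (the domain text below is made up for the example) of turning domain passages into (instruction, output) pairs with the merged model:

# Illustrative only: generate synthetic instruction/output pairs from domain texts.
domain_texts = [
    "Gradient checkpointing trades extra compute for lower memory use by recomputing "
    "activations during the backward pass instead of storing them."
]
synthetic_pairs = []
for text in domain_texts:
    p = prompt_formatting_fn({'input': '', 'output': text}, training=False)
    synthetic_pairs.append({'instruction': generate(p, model=merged_model), 'output': text})
print(synthetic_pairs[0])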