Function calling with Mistral-7B-Instruct-v0.3

Mistral
Function calling
Tabular data
Published

May 26, 2024

import json
import os
import sqlite3
import torch

from pathlib import Path

from mistral_inference.model import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, ToolMessage, AssistantMessage, SystemMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool, ToolCall, FunctionCall

Download the Chinook sample database for SQLite

# Shell escape: download the zipped Chinook sample database into the current working directory.
!wget https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip
--2024-05-26 14:35:19--  https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip
Resolving www.sqlitetutorial.net (www.sqlitetutorial.net)... 172.64.80.1, 2606:4700:130:436c:6f75:6466:6c61:7265
Connecting to www.sqlitetutorial.net (www.sqlitetutorial.net)|172.64.80.1|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 305596 (298K) [application/zip]
Saving to: ‘chinook.zip’

chinook.zip         100%[===================>] 298.43K  --.-KB/s    in 0.006s  

2024-05-26 14:35:19 (50.9 MB/s) - ‘chinook.zip’ saved [305596/305596]
# Shell escape: extract the archive; it contains the single file chinook.db.
!unzip chinook.zip
Archive:  chinook.zip
  inflating: chinook.db              

Load the model and its tokenizer

# Folder under the user's home directory that holds the model files.
mistral_models_path = Path.home() / '.cache/huggingface/hub/mistral_models/7B-Instruct-v0.3'

# The tokenizer is loaded from the tokenizer.model.v3 file inside that folder.
tokenizer = MistralTokenizer.from_file(
    f'{mistral_models_path}/tokenizer.model.v3'
)

# The model is loaded from the same folder with float16 as its dtype.
model = Transformer.from_folder(
    mistral_models_path,
    dtype=torch.float16,
)

Define function tools and some helper functions

# 
# helper functions
# 
def format_prompt(messages, tokenizer=tokenizer, tools=None):
    """Encode a conversation (and optional tool definitions) into a prompt.

    Args:
        messages: Chat messages to place in the ``ChatCompletionRequest``.
        tokenizer: Object providing ``encode_chat_completion``; defaults to
            the module-level ``tokenizer``.
        tools: Optional tool definitions forwarded to the request.

    Returns:
        A ``(tokens, text)`` tuple read from the encoder's result.
    """
    encoded = tokenizer.encode_chat_completion(
        ChatCompletionRequest(tools=tools, messages=messages)
    )
    prompt_tokens = encoded.tokens
    prompt_text = encoded.text
    return prompt_tokens, prompt_text

def ask_llm(messages, tools=None, model=model, tokenizer=tokenizer, max_tokens=512, temperature=0.0):
    """Run one generation step for a conversation and return the decoded reply.

    Args:
        messages: Chat messages making up the conversation so far.
        tools: Optional tool definitions to expose to the model.
        model: Model passed to ``generate``; defaults to the module-level
            ``model``.
        tokenizer: Tokenizer used both to encode the prompt and to decode the
            generated tokens; defaults to the module-level ``tokenizer``.
        max_tokens: Forwarded to ``generate`` as ``max_tokens``.
        temperature: Forwarded to ``generate`` as ``temperature``.

    Returns:
        A ``(result, formatted_prompt)`` tuple: the decoded text of the first
        generated sequence and the textual form of the encoded prompt.
    """
    # Forward the caller's tokenizer explicitly. Previously format_prompt fell
    # back to the global tokenizer, so a custom tokenizer was used only for
    # decoding and not for encoding.
    tokens, formatted_prompt = format_prompt(messages, tokenizer=tokenizer, tools=tools)

    # The underlying tokenizer provides both the EOS id and decode().
    raw_tokenizer = tokenizer.instruct_tokenizer.tokenizer
    output_tokens, _ = generate(
        [tokens],
        model,
        max_tokens=max_tokens,
        temperature=temperature,
        eos_id=raw_tokenizer.eos_id,
    )
    # A single prompt was submitted, so only the first output is decoded.
    result = raw_tokenizer.decode(output_tokens[0])
    return result, formatted_prompt

def call_tool(response):
    """Execute the first tool call found in the model's JSON reply.

    Args:
        response: JSON text whose top level is a sequence of tool calls, each
            an object with ``name`` and ``arguments`` entries.

    Returns:
        A ``(fn_results, toolcall)`` tuple: whatever the mapped function
        returned, and the parsed tool-call dict that was executed.
    """
    parsed_calls = json.loads(response)
    toolcall = parsed_calls[0]
    # Resolve the callable through the module-level name -> function table.
    handler = name_fn_mappings[toolcall['name']]
    return handler(**toolcall['arguments']), toolcall

# 
# functions to be used as tools
# 
def search_customer_support(customer_firstname: str, customer_lastname) -> list[tuple[str, str, int]]|None:
    """search customers table and return firstname, lastname and support representative id."""
    stmt = (
        "SELECT FirstName, LastName, SupportRepid FROM customers "
        f"WHERE FirstName LIKE '%{customer_firstname}%' and LastName LIKE '%{customer_lastname}%'"
    )
    res = conn.execute(stmt).fetchall()
    if len(res) == 0:
        res = None
    return res

def search_employee(employee_id: int) -> list[tuple[str, str, str]]|None:
    """search employees table and return firstname, lastname and title."""
    stmt = (
        "SELECT FirstName, LastName, Title FROM employees "
        f"WHERE Employeeid={employee_id}"
    )
    res = conn.execute(stmt).fetchall()
    if len(res) == 0:
        res = None
    return res
# Dispatch table used by call_tool: tool name emitted by the model -> callable.
name_fn_mappings = dict(
    search_customer_support=search_customer_support,
    search_employee=search_employee,
)

# Tool definitions handed to the model. Each entry below is a
# (name, description, JSON-schema parameters) triple that is wrapped in
# Tool(function=Function(...)), in the order listed.
tools = [
    Tool(
        function=Function(
            name=tool_name,
            description=tool_description,
            parameters=tool_parameters,
        )
    )
    for tool_name, tool_description, tool_parameters in (
        (
            "search_customer_support",
            "Useful when you want to find out who provided support to a customer.",
            {
                "type": "object",
                "properties": {
                    "customer_firstname": {
                        "type": "string",
                        "description": "A customer's first name.",
                    },
                    "customer_lastname": {
                        "type": "string",
                        "description": "A customer's last name.",
                    },
                },
                "required": ["customer_firstname", "customer_lastname"],
            },
        ),
        (
            "search_employee",
            "Useful when you want to retrieve more information about an employee.",
            {
                "type": "object",
                "properties": {
                    "employee_id": {
                        "type": "integer",
                        "description": "employee's id",
                    },
                },
                "required": ["employee_id"],
            },
        ),
    )
]

Run query with function calling

# Open the Chinook database. The search_* tool functions read this
# module-level `conn`, so it must exist before any tool is executed.
conn = sqlite3.connect('chinook.db')
# Conversation history: a system prompt asking for one tool call at a time in
# JSON, followed by the user's question.
# NOTE: the typos inside the prompt strings ("ourput", "fristname") are runtime
# text; the recorded outputs below were produced with them.
messages=[
    SystemMessage(
        content=(
            "Your task is to answer user's questions. "
            "If you need to use tools multiple times in order to answer a question, "
            "only respond with the first tool call. "
            "Format your ourput in a valid JSON format so that a python function can consume it."
        )
    ),
    UserMessage(
        content="Get the fristname and lastname of the employee who provided customer support to Stanisław Wójcik."
    )
]
# First turn: send the conversation together with the tool definitions, then
# show the exact prompt text and the model's raw reply.
response, formatted_prompt = ask_llm(messages, tools)
print(formatted_prompt)
response
<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"search_customer_support",▁"description":▁"Useful▁when▁you▁want▁to▁find▁out▁who▁provided▁support▁to▁a▁customer.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"customer_firstname":▁{"type":▁"string",▁"description":▁"A▁customer's▁first▁name."},▁"customer_lastname":▁{"type":▁"string",▁"description":▁"A▁customer's▁last▁name."}},▁"required":▁["customer_firstname",▁"customer_lastname"]}}},▁{"type":▁"function",▁"function":▁{"name":▁"search_employee",▁"description":▁"Useful▁when▁you▁want▁to▁retrieve▁more▁information▁about▁an▁employee.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"employee_id":▁{"type":▁"integer",▁"description":▁"employee's▁id"}},▁"required":▁["employee_id"]}}}][/AVAILABLE_TOOLS][INST]▁Your▁task▁is▁to▁answer▁user's▁questions.▁If▁you▁need▁to▁use▁tools▁multiple▁times▁in▁order▁to▁answer▁a▁question,▁only▁respond▁with▁the▁first▁tool▁call.▁Format▁your▁ourput▁in▁a▁valid▁JSON▁format▁so▁that▁a▁python▁function▁can▁consume▁it.<0x0A><0x0A>Get▁the▁fristname▁and▁lastname▁of▁the▁employee▁who▁provided▁customer▁support▁to▁Stanisław▁Wójcik.[/INST]
'[{"name": "search_customer_support", "arguments": {"customer_firstname": "Stanisław", "customer_lastname": "Wójcik"}}]'
# Parse the model's JSON reply and run the first tool call it contains.
fn_results, toolcall = call_tool(response)
print(toolcall)
fn_results
{'name': 'search_customer_support', 'arguments': {'customer_firstname': 'Stanisław', 'customer_lastname': 'Wójcik'}}
[('Stanisław', 'Wójcik', 4)]
# Record the tool call the model made as an assistant message ...
messages.append(
    AssistantMessage(
        tool_calls=[
            ToolCall(function=FunctionCall(**toolcall))
        ]
    )
)
# ... and feed the tool's result back. Only the last column of the first
# matching row (the support representative id) is passed on, under the key
# "employee_id".
messages.append(
    ToolMessage(
        content=json.dumps({"employee_id": fn_results[0][-1]}),
        # Fixed placeholder id; it shows up as "call_id" in the formatted prompt.
        tool_call_id='abcdefghi',
        name=toolcall['name']
    )
)
# Second turn: ask the model again with the tool result in the history.
response, formatted_prompt = ask_llm(messages, tools)
print(formatted_prompt)
response
<s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"search_customer_support",▁"description":▁"Useful▁when▁you▁want▁to▁find▁out▁who▁provided▁support▁to▁a▁customer.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"customer_firstname":▁{"type":▁"string",▁"description":▁"A▁customer's▁first▁name."},▁"customer_lastname":▁{"type":▁"string",▁"description":▁"A▁customer's▁last▁name."}},▁"required":▁["customer_firstname",▁"customer_lastname"]}}},▁{"type":▁"function",▁"function":▁{"name":▁"search_employee",▁"description":▁"Useful▁when▁you▁want▁to▁retrieve▁more▁information▁about▁an▁employee.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"employee_id":▁{"type":▁"integer",▁"description":▁"employee's▁id"}},▁"required":▁["employee_id"]}}}][/AVAILABLE_TOOLS][INST]▁Your▁task▁is▁to▁answer▁user's▁questions.▁If▁you▁need▁to▁use▁tools▁multiple▁times▁in▁order▁to▁answer▁a▁question,▁only▁respond▁with▁the▁first▁tool▁call.▁Format▁your▁ourput▁in▁a▁valid▁JSON▁format▁so▁that▁a▁python▁function▁can▁consume▁it.<0x0A><0x0A>Get▁the▁fristname▁and▁lastname▁of▁the▁employee▁who▁provided▁customer▁support▁to▁Stanisław▁Wójcik.[/INST][TOOL_CALLS]▁[{"name":▁"search_customer_support",▁"arguments":▁{"customer_firstname":▁"Stanisław",▁"customer_lastname":▁"Wójcik"}}]</s>[TOOL_RESULTS]▁{"content":▁{"employee_id":▁4},▁"call_id":▁"abcdefghi"}[/TOOL_RESULTS]
"I found that the employee who provided support to Stanisław Wójcik is employee with id 4. To get more information about this employee, you can use the search_employee function with the employee's id 4."
# The previous reply was plain prose rather than a tool call, so keep it in
# the history and add a follow-up user message asking for the final answer.
messages.append(AssistantMessage(content=response))
messages.append(UserMessage(content=('Given the above tool results, answer the original question.')))

# Third turn: the tools are still offered, so the model may issue another
# tool call.
response, formatted_prompt = ask_llm(messages, tools)
print(formatted_prompt)
response
<s>[INST]▁Get▁the▁fristname▁and▁lastname▁of▁the▁employee▁who▁provided▁customer▁support▁to▁Stanisław▁Wójcik.[/INST][TOOL_CALLS]▁[{"name":▁"search_customer_support",▁"arguments":▁{"customer_firstname":▁"Stanisław",▁"customer_lastname":▁"Wójcik"}}]</s>[TOOL_RESULTS]▁{"content":▁{"employee_id":▁4},▁"call_id":▁"abcdefghi"}[/TOOL_RESULTS]▁I▁found▁that▁the▁employee▁who▁provided▁support▁to▁Stanisław▁Wójcik▁is▁employee▁with▁id▁4.▁To▁get▁more▁information▁about▁this▁employee,▁you▁can▁use▁the▁search_employee▁function▁with▁the▁employee's▁id▁4.</s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"search_customer_support",▁"description":▁"Useful▁when▁you▁want▁to▁find▁out▁who▁provided▁support▁to▁a▁customer.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"customer_firstname":▁{"type":▁"string",▁"description":▁"A▁customer's▁first▁name."},▁"customer_lastname":▁{"type":▁"string",▁"description":▁"A▁customer's▁last▁name."}},▁"required":▁["customer_firstname",▁"customer_lastname"]}}},▁{"type":▁"function",▁"function":▁{"name":▁"search_employee",▁"description":▁"Useful▁when▁you▁want▁to▁retrieve▁more▁information▁about▁an▁employee.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"employee_id":▁{"type":▁"integer",▁"description":▁"employee's▁id"}},▁"required":▁["employee_id"]}}}][/AVAILABLE_TOOLS][INST]▁Your▁task▁is▁to▁answer▁user's▁questions.▁If▁you▁need▁to▁use▁tools▁multiple▁times▁in▁order▁to▁answer▁a▁question,▁only▁respond▁with▁the▁first▁tool▁call.▁Format▁your▁ourput▁in▁a▁valid▁JSON▁format▁so▁that▁a▁python▁function▁can▁consume▁it.<0x0A><0x0A>Given▁the▁above▁tool▁results,▁answer▁the▁original▁question.[/INST]
'[{"name": "search_employee", "arguments": {"employee_id": 4}}]'
# Run the tool call contained in the latest reply.
fn_results, toolcall = call_tool(response)
print(toolcall)
fn_results
{'name': 'search_employee', 'arguments': {'employee_id': 4}}
[('Margaret', 'Park', 'Sales Support Agent')]
# Record the second tool call as an assistant message ...
messages.append(
    AssistantMessage(
        tool_calls=[
            ToolCall(function=FunctionCall(**toolcall))
        ]
    )
)
# ... followed by its result: the whole first row, under the key "employee"
# (json.dumps serializes the tuple as a JSON array).
messages.append(
    ToolMessage(
        content=json.dumps({"employee": fn_results[0]}),
        # Same fixed placeholder id as used for the first tool result.
        tool_call_id='abcdefghi',
        name=toolcall['name']
    )
)
# Final turn: with both tool results in the history, ask for the answer.
response, formatted_prompt = ask_llm(messages, tools)
print(formatted_prompt)
response
<s>[INST]▁Get▁the▁fristname▁and▁lastname▁of▁the▁employee▁who▁provided▁customer▁support▁to▁Stanisław▁Wójcik.[/INST][TOOL_CALLS]▁[{"name":▁"search_customer_support",▁"arguments":▁{"customer_firstname":▁"Stanisław",▁"customer_lastname":▁"Wójcik"}}]</s>[TOOL_RESULTS]▁{"content":▁{"employee_id":▁4},▁"call_id":▁"abcdefghi"}[/TOOL_RESULTS]▁I▁found▁that▁the▁employee▁who▁provided▁support▁to▁Stanisław▁Wójcik▁is▁employee▁with▁id▁4.▁To▁get▁more▁information▁about▁this▁employee,▁you▁can▁use▁the▁search_employee▁function▁with▁the▁employee's▁id▁4.</s>[AVAILABLE_TOOLS]▁[{"type":▁"function",▁"function":▁{"name":▁"search_customer_support",▁"description":▁"Useful▁when▁you▁want▁to▁find▁out▁who▁provided▁support▁to▁a▁customer.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"customer_firstname":▁{"type":▁"string",▁"description":▁"A▁customer's▁first▁name."},▁"customer_lastname":▁{"type":▁"string",▁"description":▁"A▁customer's▁last▁name."}},▁"required":▁["customer_firstname",▁"customer_lastname"]}}},▁{"type":▁"function",▁"function":▁{"name":▁"search_employee",▁"description":▁"Useful▁when▁you▁want▁to▁retrieve▁more▁information▁about▁an▁employee.",▁"parameters":▁{"type":▁"object",▁"properties":▁{"employee_id":▁{"type":▁"integer",▁"description":▁"employee's▁id"}},▁"required":▁["employee_id"]}}}][/AVAILABLE_TOOLS][INST]▁Your▁task▁is▁to▁answer▁user's▁questions.▁If▁you▁need▁to▁use▁tools▁multiple▁times▁in▁order▁to▁answer▁a▁question,▁only▁respond▁with▁the▁first▁tool▁call.▁Format▁your▁ourput▁in▁a▁valid▁JSON▁format▁so▁that▁a▁python▁function▁can▁consume▁it.<0x0A><0x0A>Given▁the▁above▁tool▁results,▁answer▁the▁original▁question.[/INST][TOOL_CALLS]▁[{"name":▁"search_employee",▁"arguments":▁{"employee_id":▁4}}]</s>[TOOL_RESULTS]▁{"content":▁{"employee":▁["Margaret",▁"Park",▁"Sales▁Support▁Agent"]},▁"call_id":▁"abcdefghi"}[/TOOL_RESULTS]
'The employee who provided support to Stanisław Wójcik is Margaret Park, a Sales Support Agent.'
# Close the SQLite connection opened before the first turn.
conn.close()