Skip to content

Commit

Permalink
Support multiple chat templates per model (#134)
Browse files Browse the repository at this point in the history
* Improve chat template parsing

* Clean up

* Improve chat template selection

* Add tests for chat templates

* Update Sources/Tokenizers/Tokenizer.swift

Co-authored-by: Pedro Cuenca <[email protected]>

* Improve template selection

* More elegant solution for chatTemplate argument

* Update Sources/Tokenizers/Tokenizer.swift

Co-authored-by: Pedro Cuenca <[email protected]>

* Add overload with `chatTemplate` argument of type `String`

---------

Co-authored-by: Pedro Cuenca <[email protected]>
  • Loading branch information
DePasqualeOrg and pcuenca authored Oct 3, 2024
1 parent a7a61a2 commit 4d25d20
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 17 deletions.
102 changes: 85 additions & 17 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import Hub
import Foundation
import Jinja

enum TokenizerError : Error {
/// Errors thrown while loading a tokenizer or applying a chat template.
enum TokenizerError: Error {
/// No tokenizer config could be obtained for the model.
case missingConfig
// NOTE(review): presumably raised when the config does not name a tokenizer class — the throw site is not visible here, confirm.
case missingTokenizerClassInConfig
// NOTE(review): the associated string is presumably the unsupported tokenizer's name — confirm at the throw site.
case unsupportedTokenizer(String)
case missingVocab
case malformedVocab

/// A chat template could not be selected; the associated string is a human-readable description of the problem
/// (for example, an unknown template name or no template being specified at all).
case chatTemplate(String)
// NOTE(review): throw site not visible here; the associated string is presumably a description of the length violation — confirm.
case tooLong(String)
}

Expand Down Expand Up @@ -94,6 +94,13 @@ struct TokenizerModel {
}
}

/// Describes how a caller selects a chat template explicitly: either by supplying the template text itself
/// or by giving the name of one of the templates listed in the tokenizer config.
public enum ChatTemplateArgument {
/// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config.
case literal(String)
/// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
case name(String)
}

public protocol Tokenizer {
func tokenize(text: String) -> [String]

Expand All @@ -117,15 +124,24 @@ public protocol Tokenizer {
var eosTokenId: Int? { get }
var unknownToken: String? { get }
var unknownTokenId: Int? { get }


/// The appropriate chat template is selected from the tokenizer config
func applyChatTemplate(messages: [[String: String]]) throws -> [Int]


/// The chat template is provided as a string literal or specified by name
func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int]

/// The chat template is provided as a string literal
func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int]

func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
/// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
chatTemplate: ChatTemplateArgument?,
addGenerationPrompt: Bool,
truncation: Bool,
maxLength: Int?
maxLength: Int?,
tools: [[String: Any]]?
) throws -> [Int]
}

Expand Down Expand Up @@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer {
private let tokenizerConfig: Config

private let cleanUpTokenizationSpaces: Bool

private let defaultChatTemplate: String = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

required public init(tokenizerConfig: Config, tokenizerData: Config) throws {
var addedTokens: [String : Int] = [:]
Expand Down Expand Up @@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer {
self.decoder = DecoderFactory.fromConfig(config: tokenizerData.decoder)
self.cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces?.boolValue ?? true
self.tokenizerConfig = tokenizerConfig

model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
}

Expand Down Expand Up @@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer {
/// Looks up the token string for a token id by delegating to the underlying tokenizer model.
///
/// - Parameter id: The token id to look up.
/// - Returns: Whatever the model returns for `id`; `nil` when the model has no token for it.
public func convertIdToToken(_ id: Int) -> String? {
    let token = model.convertIdToToken(id)
    return token
}

/// Applies a chat template to `messages` without naming one explicitly, so template selection is left to the
/// full overload, and requests the generation prompt (`addGenerationPrompt: true`).
///
/// - Parameter messages: The conversation, one dictionary per message.
/// - Returns: The encoded tokens produced by the full overload.
/// - Throws: Whatever the full `applyChatTemplate` overload throws.
public func applyChatTemplate(messages: [[String: String]]) throws -> [Int] {
try applyChatTemplate(messages: messages, chatTemplate: nil, addGenerationPrompt: true, maxLength: nil)
try applyChatTemplate(messages: messages, addGenerationPrompt: true)
}

/// Applies an explicitly chosen chat template (given literally or by name) to `messages`,
/// always requesting the generation prompt.
///
/// - Parameters:
///   - messages: The conversation, one dictionary per message.
///   - chatTemplate: The template to use, forwarded unchanged to the full overload.
/// - Returns: The encoded tokens produced by the full overload.
/// - Throws: Whatever the full `applyChatTemplate` overload throws.
public func applyChatTemplate(messages: [[String: String]], chatTemplate: ChatTemplateArgument) throws -> [Int] {
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: chatTemplate,
        addGenerationPrompt: true
    )
}


/// Convenience overload that takes the chat template as a plain string.
///
/// The string is wrapped in `ChatTemplateArgument.literal` and passed to the full overload with the
/// generation prompt enabled.
///
/// - Parameters:
///   - messages: The conversation, one dictionary per message.
///   - chatTemplate: The template text to use.
/// - Returns: The encoded tokens produced by the full overload.
/// - Throws: Whatever the full `applyChatTemplate` overload throws.
public func applyChatTemplate(messages: [[String: String]], chatTemplate: String) throws -> [Int] {
    let literalTemplate = ChatTemplateArgument.literal(chatTemplate)
    return try applyChatTemplate(
        messages: messages,
        chatTemplate: literalTemplate,
        addGenerationPrompt: true
    )
}

public func applyChatTemplate(
messages: [[String: String]],
chatTemplate: String?,
chatTemplate: ChatTemplateArgument? = nil,
addGenerationPrompt: Bool = false,
truncation: Bool = false,
maxLength: Int?
maxLength: Int? = nil,
/// A list of tools (callable functions) that will be accessible to the model. If the template does not
/// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
/// giving the name, description and argument types for the tool. See the
/// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
/// for more information.
/// Note: tool calling is not supported yet, it will be available in a future update.
tools: [[String: Any]]? = nil
) throws -> [Int] {
let template = try Template(chatTemplate ?? tokenizerConfig.chatTemplate?.stringValue ?? defaultChatTemplate)
var selectedChatTemplate: String?
if let chatTemplate, case .literal(let template) = chatTemplate {
// Use chat template from argument
selectedChatTemplate = template
} else if let valueFromConfig = tokenizerConfig.chatTemplate {
if let arrayValue = valueFromConfig.arrayValue {
// If the config specifies a list of chat templates, convert them to a dictionary
let templateDict = Dictionary<String, String>(uniqueKeysWithValues: arrayValue.compactMap { item in
guard let name = item.name?.stringValue, let template = item.template?.stringValue else {
return nil
}
return (name, template)
})
if let chatTemplate, case .name(let name) = chatTemplate {
// Select chat template from config by name
if let matchingDictEntry = templateDict[name] {
selectedChatTemplate = matchingDictEntry
} else {
throw TokenizerError.chatTemplate("No chat template named \"\(name)\" was found in the tokenizer config")
}
} else if let tools, !tools.isEmpty, let toolUseTemplate = templateDict["tool_use"] {
// Use tool use chat template from config
selectedChatTemplate = toolUseTemplate
} else if let defaultChatTemplate = templateDict["default"] {
// Use default chat template from config
selectedChatTemplate = defaultChatTemplate
}
} else if let stringValue = valueFromConfig.stringValue {
// Use chat template from config
selectedChatTemplate = stringValue
}
}

guard let selectedChatTemplate else {
throw TokenizerError.chatTemplate("No chat template was specified")
}

let template = try Template(selectedChatTemplate)
var context: [String: Any] = [
"messages": messages,
"add_generation_prompt": addGenerationPrompt
// TODO: Add `tools` entry when support is added in Jinja
// "tools": tools
]

// TODO: maybe keep NSString here
Expand Down Expand Up @@ -397,15 +465,15 @@ extension AutoTokenizer {

return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}

/// Builds a tokenizer from the configuration files found for a model folder.
///
/// - Parameters:
///   - modelFolder: Location of the model whose tokenizer configuration should be read.
///   - hubApi: The Hub client handed to the configuration loader; defaults to the shared instance.
/// - Returns: A `PreTrainedTokenizer` built from the loaded tokenizer config and tokenizer data.
/// - Throws: `TokenizerError.missingConfig` when no tokenizer config is available, plus any error raised while
///   loading the configuration or constructing the tokenizer.
public static func from(
    modelFolder: URL,
    hubApi: HubApi = .shared
) async throws -> Tokenizer {
    let hubConfiguration = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)

    // The tokenizer config is optional on the loader; without it a tokenizer cannot be built.
    guard let loadedTokenizerConfig = try await hubConfiguration.tokenizerConfig else {
        throw TokenizerError.missingConfig
    }
    let loadedTokenizerData = try await hubConfiguration.tokenizerData

    return try PreTrainedTokenizer(
        tokenizerConfig: loadedTokenizerConfig,
        tokenizerData: loadedTokenizerData
    )
}
}
Expand Down
73 changes: 73 additions & 0 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//
// ChatTemplateTests.swift
// swift-transformers
//
// Created by Anthony DePasquale on 2/10/24.
//

import XCTest
import Tokenizers

/// Checks that `applyChatTemplate` picks the expected template in each supported situation:
/// a single template in the config, a list of templates in the config, a template passed as an argument
/// (enum or plain string), and a template selected by name.
class ChatTemplateTests: XCTestCase {
    /// Single-turn conversation shared by every test.
    let messages = [[
        "role": "user",
        "content": "Describe the Swift programming language.",
    ]]

    /// Mistral 7B's default template. It is deliberately applied to a Phi-3 tokenizer below to verify that
    /// a template passed as an argument is used instead of the one from the config.
    private static let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"

    /// The config holds a single template string, which should be used as-is.
    func testTemplateFromConfig() async throws {
        let expectedTokens = [32010, 4002, 29581, 278, 14156, 8720, 4086, 29889, 32007, 32001]
        let expectedText = "<|user|>Describe the Swift programming language.<|end|><|assistant|>"

        let phi3 = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
        let tokens = try phi3.applyChatTemplate(messages: messages)
        let text = phi3.decode(tokens: tokens)

        XCTAssertEqual(tokens, expectedTokens)
        XCTAssertEqual(text, expectedText)
    }

    /// The config holds a list of templates; with no argument the default template should be picked.
    func testDefaultTemplateFromArrayInConfig() async throws {
        let expectedTokens = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
        let expectedText = "<s> [INST] Describe the Swift programming language. [/INST]"

        let mistral = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
        let tokens = try mistral.applyChatTemplate(messages: messages)
        let text = mistral.decode(tokens: tokens)

        XCTAssertEqual(tokens, expectedTokens)
        XCTAssertEqual(text, expectedText)
    }

    /// A template passed as `.literal` must be used instead of the one from the config.
    func testTemplateFromArgumentWithEnum() async throws {
        let expectedTokens = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
        let expectedText = "<s> [INST] Describe the Swift programming language. [/INST]"

        let phi3 = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
        let tokens = try phi3.applyChatTemplate(
            messages: messages,
            chatTemplate: .literal(Self.mistral7BDefaultTemplate)
        )
        let text = phi3.decode(tokens: tokens)

        XCTAssertEqual(tokens, expectedTokens)
        XCTAssertEqual(text, expectedText)
    }

    /// A template passed as a plain `String` must be used instead of the one from the config.
    func testTemplateFromArgumentWithString() async throws {
        let expectedTokens = [1, 518, 25580, 29962, 20355, 915, 278, 14156, 8720, 4086, 29889, 518, 29914, 25580, 29962]
        let expectedText = "<s> [INST] Describe the Swift programming language. [/INST]"

        let phi3 = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
        let tokens = try phi3.applyChatTemplate(
            messages: messages,
            chatTemplate: Self.mistral7BDefaultTemplate
        )
        let text = phi3.decode(tokens: tokens)

        XCTAssertEqual(tokens, expectedTokens)
        XCTAssertEqual(text, expectedText)
    }

    /// Selecting a template by name from a config that lists several templates.
    func testNamedTemplateFromArgument() async throws {
        let expectedTokens = [1, 29473, 3, 28752, 1040, 4672, 2563, 17060, 4610, 29491, 29473, 4]
        let expectedText = "<s> [INST] Describe the Swift programming language. [/INST]"

        let mistral = try await AutoTokenizer.from(pretrained: "mlx-community/Mistral-7B-Instruct-v0.3-4bit")
        // Naming `default` explicitly is normally unnecessary, but no known model lists templates in its
        // config under names other than `default` or `tool_use`.
        let tokens = try mistral.applyChatTemplate(messages: messages, chatTemplate: .name("default"))
        let text = mistral.decode(tokens: tokens)

        XCTAssertEqual(tokens, expectedTokens)
        XCTAssertEqual(text, expectedText)
    }

    // TODO: Add tests for tool use template
}

0 comments on commit 4d25d20

Please sign in to comment.