Tokenizer fixes (#113)
* Bring over hf token envvar from preview branch

* Add tests for Gemma, including edge cases

Edge cases also added for other BPE tokenizers, but not for T5 yet.

* Sort added tokens by length (descending) to avoid early partial matches

Similar to huggingface/transformers.js@c305c38

* Store vocab as NSString to allow multiple tokens with the same Unicode
canonical representation.

* Remove comments

* Go back to making vocab dictionaries private

* Use ungated copy of Gemma tokenizer

* Use NSString in UnigramTokenizer

* Switch test to microsoft tokenizer, verify in Python
pcuenca authored Aug 19, 2024
1 parent e72d032 commit 4c8cf07
Showing 10 changed files with 129 additions and 40 deletions.
12 changes: 6 additions & 6 deletions Sources/Hub/Hub.swift
@@ -38,9 +38,9 @@ public extension Hub

@dynamicMemberLookup
public struct Config {
- public private(set) var dictionary: [String: Any]
+ public private(set) var dictionary: [NSString: Any]

- public init(_ dictionary: [String: Any]) {
+ public init(_ dictionary: [NSString: Any]) {
self.dictionary = dictionary
}

@@ -76,8 +76,8 @@ public struct Config


public subscript(dynamicMember member: String) -> Config? {
- let key = dictionary[member] != nil ? member : uncamelCase(member)
- if let value = dictionary[key] as? [String: Any] {
+ let key = (dictionary[member as NSString] != nil ? member : uncamelCase(member)) as NSString
+ if let value = dictionary[key] as? [NSString: Any] {
return Config(value)
} else if let value = dictionary[key] {
return Config(["value": value])
@@ -96,7 +96,7 @@ public struct Config
// Instead of doing this we could provide custom classes and decode to them
public var arrayValue: [Config]? {
guard let list = value as? [Any] else { return nil }
- return list.map { Config($0 as! [String : Any]) }
+ return list.map { Config($0 as! [NSString : Any]) }
}

/// Tuple of token identifier and string value
@@ -206,7 +206,7 @@ public class LanguageModelConfigurationFromHub
do {
let data = try Data(contentsOf: url)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
- guard let dictionary = parsed as? [String: Any] else { return nil }
+ guard let dictionary = parsed as? [NSString: Any] else { return nil }
return Config(dictionary)
} catch {
return nil
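The `NSString` keys above are the point of this change: Swift `String` equality and hashing use canonical Unicode equivalence, so two vocab entries that differ only in composition (U+0061 U+0300 vs. the precomposed U+00E0) would collapse into a single `[String: Any]` key, silently dropping a token. `NSString` compares UTF-16 code units literally and keeps both. A minimal standalone sketch of the difference (illustration only, not code from this commit):

```swift
import Foundation

let decomposed = "a\u{300}"  // "à" as LATIN SMALL LETTER A + COMBINING GRAVE ACCENT
let precomposed = "\u{E0}"   // "à" as LATIN SMALL LETTER A WITH GRAVE

// String keys: the two spellings are canonically equivalent,
// so the second assignment overwrites the first entry.
var stringKeyed: [String: Int] = [decomposed: 1]
stringKeyed[precomposed] = 2
print(stringKeyed.count)  // 1

// NSString keys: compared code unit by code unit, so both entries survive.
var nsStringKeyed: [NSString: Int] = [decomposed as NSString: 1]
nsStringKeyed[precomposed as NSString] = 2
print(nsStringKeyed.count)  // 2
```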
6 changes: 3 additions & 3 deletions Sources/Hub/HubApi.swift
@@ -17,7 +17,7 @@ public struct HubApi
public typealias Repo = Hub.Repo

public init(downloadBase: URL? = nil, hfToken: String? = nil, endpoint: String = "https://huggingface.co", useBackgroundSession: Bool = false) {
- self.hfToken = hfToken
+ self.hfToken = hfToken ?? ProcessInfo.processInfo.environment["HUGGING_FACE_HUB_TOKEN"]
if let downloadBase {
self.downloadBase = downloadBase
} else {
@@ -102,7 +102,7 @@ public extension HubApi
func configuration(fileURL: URL) throws -> Config {
let data = try Data(contentsOf: fileURL)
let parsed = try JSONSerialization.jsonObject(with: data, options: [])
- guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
+ guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
return Config(dictionary)
}
}
@@ -116,7 +116,7 @@ public extension HubApi
let (data, _) = try await httpGet(for: url)

let parsed = try JSONSerialization.jsonObject(with: data, options: [])
- guard let dictionary = parsed as? [String: Any] else { throw Hub.HubClientError.parse }
+ guard let dictionary = parsed as? [NSString: Any] else { throw Hub.HubClientError.parse }
return Config(dictionary)
}
}
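The nil-coalescing above means an explicitly passed token still wins; the environment variable is only a fallback. A hedged usage sketch (initializer signature as in the diff above; token values are placeholders):

```swift
import Hub

// Explicit token takes precedence over the environment.
let api = HubApi(hfToken: "hf_placeholder_token")

// No explicit token: falls back to HUGGING_FACE_HUB_TOKEN if it is set
// in the process environment, otherwise hfToken remains nil.
// e.g. HUGGING_FACE_HUB_TOKEN=hf_xxx swift test
let apiFromEnvironment = HubApi()
```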
24 changes: 13 additions & 11 deletions Sources/Tokenizers/BPETokenizer.swift
@@ -33,9 +33,11 @@ struct BytePair: Hashable

class BPETokenizer: PreTrainedTokenizerModel {
let bpeRanks: Dictionary<BytePair, Int>
- private let tokensToIds: [String: Int]
- private let idsToTokens: [Int: String]

+ private let tokensToIds: [NSString: Int]
+ private let idsToTokens: [Int: NSString]

+ var vocabCount: Int { tokensToIds.count }

public let bosToken: String?
public let bosTokenId: Int?
public let eosToken: String?
@@ -45,7 +47,7 @@ class BPETokenizer: PreTrainedTokenizerModel {

required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
guard let merges = tokenizerData.model?.merges?.value as? [String] else { fatalError("BPETokenizer requires merges") }
- guard let vocab = tokenizerData.model?.vocab?.dictionary as? [String: Int] else {
+ guard let vocab = tokenizerData.model?.vocab?.dictionary as? [NSString: Int] else {
throw TokenizerError.missingVocab
}
var bpeRanks: Dictionary<BytePair, Int> = [:]
@@ -56,31 +58,31 @@
}
self.bpeRanks = bpeRanks

- self.tokensToIds = vocab.merging(addedTokens) { $1 }
+ self.tokensToIds = vocab.merging(addedTokens as [NSString : Int]) { $1 }
self.idsToTokens = Utils.invert(self.tokensToIds)

// Populate tokens
if let unknownToken = TokenizerModel.unknownToken(from: tokenizerConfig) {
self.unknownToken = unknownToken
- self.unknownTokenId = self.tokensToIds[unknownToken]
+ self.unknownTokenId = self.tokensToIds[unknownToken as NSString]
} else {
self.unknownToken = nil
self.unknownTokenId = nil
}

eosToken = tokenizerConfig.eosToken?.stringValue
- eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
+ eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]

bosToken = tokenizerConfig.bosToken?.stringValue
- bosTokenId = bosToken == nil ? nil : tokensToIds[bosToken!]
+ bosTokenId = bosToken == nil ? nil : tokensToIds[bosToken! as NSString]
}

func convertTokenToId(_ token: String) -> Int? {
- return tokensToIds[token] ?? self.unknownTokenId
+ return tokensToIds[token as NSString] ?? self.unknownTokenId
}

func convertIdToToken(_ id: Int) -> String? {
- return idsToTokens[id]
+ return idsToTokens[id] as String?
}

func byteEncode(text: String) -> [String] {
@@ -162,7 +164,7 @@ class BPETokenizer: PreTrainedTokenizerModel {
var tokens: [String] = []
let bpeTokens = self.bpe(token: text).split(separator: " ").map { String($0) }
for token in bpeTokens {
- if let _ = tokensToIds[token] {
+ if convertTokenToId(token) != unknownTokenId {
tokens.append(token)
} else {
// TODO: if config.byte_fallback is False, append the unknown token instead
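The rewritten check above (`convertTokenToId(token) != unknownTokenId`) decides whether a BPE piece is kept verbatim; pieces that are not in the vocabulary go through byte fallback and come out as `<0xNN>` tokens, which is what the Llama `testHexaEncode` test added later in this commit asserts. A rough sketch of that byte-fallback formatting, hedged because the actual helper lives outside this hunk:

```swift
import Foundation

// Hypothetical helper: map each UTF-8 byte of an out-of-vocabulary piece
// to a "<0xNN>" token, the SentencePiece-style byte-fallback convention.
func byteFallbackTokens(for piece: String) -> [String] {
    piece.utf8.map { String(format: "<0x%02X>", Int($0)) }
}

print(byteFallbackTokens(for: "\n"))  // ["<0x0A>"]
```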
23 changes: 17 additions & 6 deletions Sources/Tokenizers/Tokenizer.swift
@@ -163,12 +163,23 @@ public class PreTrainedTokenizer: Tokenizer {
}
}

- let addedTokensRegexString = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
-     guard let content = addedToken.content?.stringValue else { return nil }
-     let prefix = (addedToken.lstrip?.boolValue ?? false ? #"\s*"# : "")
-     let suffix = (addedToken.rstrip?.boolValue ?? false ? #"\s*"# : "")
-     let token = NSRegularExpression.escapedPattern(for: content)
-     return "\(prefix)(\(token))\(suffix)"
+ // Convert to tuples for easier access, then sort by length (descending) to avoid early partial matches
+ // (https://github.com/xenova/transformers.js/commit/c305c3824f628f1f02806a6310bd3b18b0f7f8f5)
+ let unwrappedAddedTokens : [(content: String, prefix: Bool, suffix: Bool)] = (tokenizerData.addedTokens?.arrayValue ?? []).compactMap { addedToken in
+     guard let content = addedToken.content?.stringValue else { return nil }
+     let prefix = addedToken.lstrip?.boolValue ?? false
+     let suffix = addedToken.rstrip?.boolValue ?? false
+     return (content: content, prefix: prefix, suffix: suffix)
+ }.sorted {
+     $0.content.count > $1.content.count
+ }
+
+ // then concatenate into regular expression
+ let addedTokensRegexString = unwrappedAddedTokens.map {
+     let token = NSRegularExpression.escapedPattern(for: $0.content)
+     let prefix = $0.prefix ? #"\s*"# : ""
+     let suffix = $0.suffix ? #"\s*"# : ""
+     return "\(prefix)(\(token))\(suffix)"
}.joined(separator: "|")
addedTokensRegex = try? NSRegularExpression(pattern: addedTokensRegexString, options: [])

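Why sorting by length matters: `NSRegularExpression` alternation takes the first alternative that matches at a position, not the longest, so a shorter added token that is a prefix of a longer one would shadow it. A self-contained sketch with hypothetical tokens (not taken from any real tokenizer config):

```swift
import Foundation

let text = "Done.<|endoftext|>"
let addedTokens = ["<|end|>", "<|endoftext|>"]  // shorter token listed first

func firstMatch(of tokens: [String], in text: String) -> String? {
    // Build the same kind of alternation the tokenizer uses.
    let pattern = tokens
        .map { NSRegularExpression.escapedPattern(for: $0) }
        .joined(separator: "|")
    let regex = try! NSRegularExpression(pattern: pattern, options: [])
    let range = NSRange(text.startIndex..., in: text)
    guard let match = regex.firstMatch(in: text, options: [], range: range),
          let swiftRange = Range(match.range, in: text) else { return nil }
    return String(text[swiftRange])
}

// Unsorted: the alternation stops at the partial match.
print(firstMatch(of: addedTokens, in: text) ?? "none")                                 // <|end|>
// Sorted by length (descending), as in the change above: the full token matches.
print(firstMatch(of: addedTokens.sorted { $0.count > $1.count }, in: text) ?? "none")  // <|endoftext|>
```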
18 changes: 9 additions & 9 deletions Sources/Tokenizers/UnigramTokenizer.swift
@@ -23,8 +23,8 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
public var unknownToken: String? { unknownPiece.token }

let minScore: Float
- let tokensToIds: [String: Int]
+ let tokensToIds: [NSString: Int]

let bosToken: String? = " "
let bosTokenId: Int?
let eosToken: String?
@@ -63,20 +63,20 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
self.unknownTokenId = unknownTokenId
self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10)

- tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { ($1, $0) })
- bosTokenId = tokensToIds[bosToken!] // May be nil
+ tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token as NSString }.enumerated().map { ($1, $0) })
+ bosTokenId = tokensToIds[bosToken! as NSString] // May be nil

eosToken = tokenizerConfig.eosToken?.stringValue
- eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
+ eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken! as NSString]

trie = Trie()
trie.append(contentsOf: vocab.map { $0.token })

// TODO: set fuse_unk to true
}

func convertTokenToId(_ token: String) -> Int? {
- return tokensToIds[token] ?? self.unknownTokenId
+ return tokensToIds[token as NSString] ?? self.unknownTokenId
}

func convertIdToToken(_ id: Int) -> String? {
@@ -95,7 +95,7 @@

let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos)
for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) {
- guard let tokenId = tokensToIds[token] else { fatalError("Token not in vocab: \(token)") }
+ guard let tokenId = tokensToIds[token as NSString] else { fatalError("Token not in vocab: \(token)") }
let tokenScore = vocab[tokenId].score
lattice.insert(startOffset: beginPos, length: token.count, score: tokenScore, tokenId: tokenId)
if !hasSingleNode && token.count == mblen {
21 changes: 21 additions & 0 deletions Tests/HubTests/HubTests.swift
@@ -97,4 +97,25 @@ class HubTests: XCTestCase {
XCTFail("Cannot download test configuration from the Hub: \(error)")
}
}

func testConfigUnicode() {
// These are two different characters
let json = "{\"vocab\": {\"à\": 1, \"à\": 2}}"
let data = json.data(using: .utf8)
let dict = try! JSONSerialization.jsonObject(with: data!, options: []) as! [NSString: Any]
let config = Config(dict)

let vocab_nsdict = config.dictionary["vocab"] as! NSDictionary
let vocab_nsstring = config.dictionary["vocab"] as! [NSString: Int]
let vocab = config.vocab!.dictionary

XCTAssertEqual(vocab_nsdict.count, 2)
XCTAssertEqual(vocab_nsstring.count, 2)
XCTAssertEqual(vocab.count, 2)

// This is expected because, unlike with NSString, String comparison uses the canonical Unicode representation
// https://developer.apple.com/documentation/swift/string#Modifying-and-Comparing-Strings
let vocab_dict = config.dictionary["vocab"] as! [String: Int]
XCTAssertNotEqual(vocab_dict.count, 2)
}
}
15 changes: 12 additions & 3 deletions Tests/TokenizersTests/AddedTokensTests.swift
@@ -11,12 +11,21 @@ import Hub

class AddedTokensTests: XCTestCase {
func testPhiAddedTokens() async throws {
- let tokenizer = try await AutoTokenizer.from(pretrained: "mlx-community/Phi-3-mini-128k-instruct-4bit")
+ let tokenizer = try await AutoTokenizer.from(pretrained: "microsoft/Phi-3-mini-128k-instruct")
let inputIds = tokenizer("This is the <|end|>. My only friend, the <|end|>")
- XCTAssertEqual(inputIds, [1, 910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])
+ XCTAssertEqual(inputIds, [910, 338, 278, 29871, 32007, 29889, 1619, 871, 5121, 29892, 278, 29871, 32007])

let decoded = tokenizer.decode(tokens: inputIds)
- XCTAssertEqual(decoded, "<s> This is the <|end|>. My only friend, the <|end|>")
+ XCTAssertEqual(decoded, "This is the <|end|>. My only friend, the <|end|>")
}

func testGemmaAddedTokens() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer")
let inputIds = tokenizer("This\n\nis\na\ntest.")
XCTAssertEqual(inputIds, [2, 1596, 109, 502, 108, 235250, 108, 2195, 235265])

let decoded = tokenizer.decode(tokens: inputIds)
XCTAssertEqual(decoded, "<bos>This\n\nis\na\ntest.")
}

func testSplitWithCaptureGroups() {
1 change: 1 addition & 0 deletions Tests/TokenizersTests/Resources/gemma_encoded.json
@@ -0,0 +1 @@
{"text": "Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer.", "bpe_tokens": ["Fat", "ou", "ville", "-", "G", "rest", "ain", "\u2581est", "\u2581une", "\u2581commune", "\u2581du", "\u2581Nord", "-", "Ouest", "\u2581du", "\u2581d\u00e9partement", "\u2581de", "\u2581l", "'", "Eure", "\u2581situ\u00e9e", "\u2581au", "\u2581", "\n", "bord", "\u2581de", "\u2581l", "'", "est", "uaire", "\u2581de", "\u2581la", "\u2581Seine", "\u2581et", "\u2581\u00e0", "\u2581proximit\u00e9", "\u2581du", "\u2581d\u00e9partement", "\u2581du", "\u2581Cal", "vados", ".", "\u2581Selon", "\u2581l", "'", "atlas", "\u2581des", "\u2581paysages", "\u2581", "\n", "de", "\u2581Haute", "-", "Norman", "die", ",", "\u2581elle", "\u2581appartient", "\u2581\u00e0", "\u2581la", "\u2581r\u00e9gion", "\u2581naturelle", "\u2581du", "\u2581Lieu", "vin", ".", "\u2581Toutefois", ",", "\u2581l", "'", "Ag", "reste", ",", "\u2581le", "\u2581service", "\u2581", "\n", "de", "\u2581la", "\u2581statistique", "\u2581et", "\u2581de", "\u2581la", "\u2581prospective", "\u2581du", "\u2581minist\u00e8re", "\u2581de", "\u2581l", "'", "Agriculture", ",", "\u2581de", "\u2581l", "'", "Agro", "alimenta", "ire", "\u2581et", "\u2581de", "\u2581la", "\u2581For", "\u00eat", ",", "\u2581", "\n", "la", "\u2581classe", "\u2581au", "\u2581sein", "\u2581du", "\u2581pays", "\u2581d", "'", "Au", "ge", "\u2581(", "en", "\u2581tant", "\u2581que", "\u2581r\u00e9gion", "\u2581agricole", ").", "La", "\u2581commune", "\u2581est", "\u2581\u00e0", "\u2581moins", "\u2581de", "\u2581dix", "\u2581kilom\u00e8tres", "\u2581\u00e0", "\u2581", "\n", "l", "'", "est", "\u2581de", "\u2581Hon", "fleur", ",", "\u2581\u00e0", "\u2581autant", "\u2581de", "\u2581Be", "uze", "ville", "\u2581et", "\u2581\u00e0", "\u2581environ", "\u2581dix", "-", "sept", "\u2581kilom\u00e8tres", "\u2581de", "\u2581Pont", "-", "Au", "de", "mer", "."], "token_ids": [2, 33690, 507, 5259, 235290, 235319, 4803, 985, 1455, 2360, 34960, 1344, 14852, 235290, 101323, 1344, 57781, 581, 533, 235303, 128985, 80493, 992, 235248, 108, 51123, 581, 533, 235303, 644, 106910, 581, 683, 53876, 1008, 1305, 72883, 1344, 57781, 1344, 2659, 119613, 235265, 86721, 533, 235303, 64117, 848, 141362, 235248, 108, 495, 70628, 235290, 74906, 3917, 235269, 11340, 133635, 1305, 683, 33927, 72277, 1344, 174959, 2964, 235265, 145673, 235269, 533, 235303, 6665, 62423, 235269, 709, 2566, 235248, 108, 495, 683, 160719, 1008, 581, 683, 40675, 1344, 85986, 581, 533, 235303, 79742, 235269, 581, 533, 235303, 166317, 104544, 844, 1008, 581, 683, 1699, 19941, 235269, 235248, 108, 522, 30739, 992, 8399, 1344, 11928, 499, 235303, 2159, 541, 591, 479, 21482, 907, 33927, 113917, 846, 2841, 34960, 1455, 1305, 15006, 581, 51102, 118516, 1305, 235248, 108, 235257, 235303, 644, 581, 9073, 129564, 235269, 1305, 54409, 581, 2065, 52172, 5259, 1008, 1305, 15265, 51102, 235290, 91012, 118516, 581, 52291, 235290, 2159, 495, 977, 235265], 
"decoded_text": "<bos>Fatouville-Grestain est une commune du Nord-Ouest du d\u00e9partement de l'Eure situ\u00e9e au \nbord de l'estuaire de la Seine et \u00e0 proximit\u00e9 du d\u00e9partement du Calvados. Selon l'atlas des paysages \nde Haute-Normandie, elle appartient \u00e0 la r\u00e9gion naturelle du Lieuvin. Toutefois, l'Agreste, le service \nde la statistique et de la prospective du minist\u00e8re de l'Agriculture, de l'Agroalimentaire et de la For\u00eat, \nla classe au sein du pays d'Auge (en tant que r\u00e9gion agricole).La commune est \u00e0 moins de dix kilom\u00e8tres \u00e0 \nl'est de Honfleur, \u00e0 autant de Beuzeville et \u00e0 environ dix-sept kilom\u00e8tres de Pont-Audemer."}
2 changes: 1 addition & 1 deletion Tests/TokenizersTests/Resources/tokenizer_tests.json

Large diffs are not rendered by default.

47 changes: 46 additions & 1 deletion Tests/TokenizersTests/TokenizerTests.swift
@@ -28,6 +28,13 @@ class LlamaTokenizerTests: TokenizerTests {
override class var hubModelName: String? { "coreml-projects/Llama-2-7b-chat-coreml" }
override class var encodedSamplesFilename: String? { "llama_encoded" }
override class var unknownTokenId: Int? { 0 }

func testHexaEncode() async {
if let tester = Self._tester {
let tokenized = await tester.tokenizer?.tokenize(text: "\n")
XCTAssertEqual(tokenized, ["▁", "<0x0A>"])
}
}
}

class WhisperLargeTokenizerTests: TokenizerTests {
@@ -48,6 +55,41 @@ class T5TokenizerTests: TokenizerTests {
override class var unknownTokenId: Int? { 2 }
}

class GemmaTokenizerTests: TokenizerTests {
override class var hubModelName: String? { "pcuenq/gemma-tokenizer" }
override class var encodedSamplesFilename: String? { "gemma_encoded" }
override class var unknownTokenId: Int? { 3 }

func testUnicodeEdgeCase() async {
guard let tester = Self._tester else {
XCTFail()
return
}

// These are two different characters
let cases = ["à" /* 0x61 0x300 */, "à" /* 0xe0 */]
let expected = [217138, 1305]

// These are different characters
for (s, expected) in zip(cases, expected) {
let encoded = await tester.tokenizer?.encode(text: " " + s)
XCTAssertEqual(encoded, [2, expected])
}
}
}

class GemmaUnicodeTests: XCTestCase {
func testGemmaVocab() async throws {
guard let tokenizer = try await AutoTokenizer.from(pretrained: "pcuenq/gemma-tokenizer") as? PreTrainedTokenizer else {
XCTFail()
return
}

// FIXME: This should be 256_000, I believe
XCTAssertEqual((tokenizer.model as? BPETokenizer)?.vocabCount, 255994)
}
}


struct EncodedTokenizerSamplesDataset: Decodable {
let text: String
@@ -156,7 +198,10 @@ class TokenizerTester {

/// Test encode and decode for a few edge cases
func testEdgeCases() async {
- guard let edgeCases = edgeCases else { return }
+ guard let edgeCases = edgeCases else {
+     print("Edge cases test ignored")
+     return
+ }
guard let tokenizer = await tokenizer else { return }
for edgeCase in edgeCases {
print("Testing \(edgeCase.input)")
