From 89fb5d97e1df347f9f588f62fc538dcad6fdb16c Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Sat, 20 Apr 2024 22:10:00 -0700
Subject: [PATCH] Fix: UTF8 support in Unigram tokenizer

---
 Sources/Tokenizers/Tokenizer.swift        |  1 +
 Sources/Tokenizers/UnigramTokenizer.swift | 32 ++++++++++++++++++-----
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/Sources/Tokenizers/Tokenizer.swift b/Sources/Tokenizers/Tokenizer.swift
index 084bad2..ccbdd52 100644
--- a/Sources/Tokenizers/Tokenizer.swift
+++ b/Sources/Tokenizers/Tokenizer.swift
@@ -66,6 +66,7 @@ struct TokenizerModel {
         "Gemma"   : GemmaTokenizer.self,
         "GPT2"    : GPT2Tokenizer.self,
         "Llama"   : LlamaTokenizer.self,
+        "Unigram" : UnigramTokenizer.self,
         "T5"      : T5Tokenizer.self,
         "Whisper" : WhisperTokenizer.self,
         "Cohere"  : CohereTokenizer.self,
diff --git a/Sources/Tokenizers/UnigramTokenizer.swift b/Sources/Tokenizers/UnigramTokenizer.swift
index 2fe754d..6e5966d 100644
--- a/Sources/Tokenizers/UnigramTokenizer.swift
+++ b/Sources/Tokenizers/UnigramTokenizer.swift
@@ -22,7 +22,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     public var unknownToken: String? { unknownPiece.token }
 
     let minScore: Float
-    let tokensToIds: [String: Int]
+    let tokensToIds: [LiteralString: Int]
 
     let bosToken: String? = " "
     let bosTokenId: Int?
@@ -30,6 +30,18 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     let eosTokenId: Int?
 
     private let trie: Trie
+
+    struct LiteralString: Hashable {
+        let value: String
+
+        static func ==(lhs: LiteralString, rhs: LiteralString) -> Bool {
+            return lhs.value.compare(rhs.value, options: .literal) == .orderedSame
+        }
+
+        func hash(into hasher: inout Hasher) {
+            hasher.combine(value)
+        }
+    }
 
     required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
         guard let configVocab = tokenizerData.model?.vocab?.value as? [[Any]] else {
@@ -38,7 +50,10 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
 
         vocab = try configVocab.map { piece in
            guard let token = piece.first as? String else { throw TokenizerError.malformedVocab }
-            guard let score = piece.last as? Float else { throw TokenizerError.malformedVocab }
+            // Casting straight to `Float` fails when precision loss is detected,
+            // so decode the score as `Double` first and then narrow it.
+            guard let scoreDouble = piece.last as? Double else { throw TokenizerError.malformedVocab }
+            let score = Float(scoreDouble)
             return SentencePieceToken(token: token, score: score)
         }
 
@@ -50,11 +65,14 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
         self.unknownTokenId = unknownTokenId
         self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10)
 
-        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { ($1, $0) })
-        bosTokenId = tokensToIds[bosToken!] // May be nil
+        // Building the mapping over plain `String` keys would rely on default string
+        // equality, which treats canonically equivalent UTF8 sequences as the same key
+        // and collides distinct tokens. `LiteralString` compares with `.literal` instead.
+        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { (LiteralString(value: $1), $0) })
+        bosTokenId = tokensToIds[LiteralString(value: bosToken!)] // May be nil
 
         eosToken = tokenizerConfig.eosToken?.stringValue
-        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
+        eosTokenId = eosToken == nil ? nil : tokensToIds[LiteralString(value: eosToken!)]
 
         trie = Trie()
         trie.append(contentsOf: vocab.map { $0.token })
@@ -63,7 +81,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     }
 
     func convertTokenToId(_ token: String) -> Int? {
-        return tokensToIds[token] ?? self.unknownTokenId
+        return tokensToIds[LiteralString(value: token)] ?? self.unknownTokenId
     }
 
     func convertIdToToken(_ id: Int) -> String? {
@@ -82,7 +100,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
             let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos)
 
             for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) {
-                guard let tokenId = tokensToIds[token] else { fatalError("Token not in vocab: \(token)") }
+                guard let tokenId = tokensToIds[LiteralString(value: token)] else { fatalError("Token not in vocab: \(token)") }
                 let tokenScore = vocab[tokenId].score
                 lattice.insert(startOffset: beginPos, length: token.count, score: tokenScore, tokenId: tokenId)
                 if !hasSingleNode && token.count == mblen {
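-- 
Reviewer note, not part of the patch: a minimal standalone Swift sketch of the
key collision that `LiteralString` works around. Swift's default `String`
equality applies Unicode canonical equivalence, so two different UTF8 byte
sequences can land on the same `[String: Int]` key, while `.literal`
comparison keeps them apart:

    import Foundation

    // Two different UTF8 encodings of the glyph "é":
    let precomposed = "\u{00E9}"   // single code point U+00E9
    let decomposed = "e\u{0301}"   // "e" followed by combining acute accent U+0301

    // Default equality applies canonical equivalence, so a [String: Int]
    // vocabulary would merge these two tokens into a single key:
    print(precomposed == decomposed)  // true

    // Literal comparison inspects the underlying representation and keeps
    // the two vocabulary entries distinct:
    print(precomposed.compare(decomposed, options: .literal) == .orderedSame)  // false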
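The score-parsing change follows from the same decoding path: scores come out
of the parsed vocab as boxed numbers, and Swift's numeric bridging via `as?`
refuses lossy conversions, so a direct `as? Float` comes back nil for typical
fractional scores and the old guard raised `malformedVocab`. A sketch of that
behavior, assuming a score decoded by JSONSerialization:

    import Foundation

    let raw: Any = NSNumber(value: -12.345)  // how a JSON number is typically boxed

    // Bridging straight to Float fails: -12.345 is not exactly representable.
    if let direct = raw as? Float {
        print("direct:", direct)
    } else {
        print("direct cast to Float failed")  // this branch runs
    }

    // Bridging to Double succeeds; an explicit conversion then narrows it.
    if let viaDouble = raw as? Double {
        print("via Double:", Float(viaDouble))  // via Double: -12.345
    }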