Fix: UTF8 support in Unigram tokenizer
ashvardanian committed Apr 21, 2024
1 parent 9ef46a5 commit 89fb5d9
Showing 2 changed files with 26 additions and 7 deletions.
1 change: 1 addition & 0 deletions Sources/Tokenizers/Tokenizer.swift
@@ -66,6 +66,7 @@ struct TokenizerModel {
         "Gemma" : GemmaTokenizer.self,
         "GPT2" : GPT2Tokenizer.self,
         "Llama" : LlamaTokenizer.self,
+        "Unigram" : UnigramTokenizer.self,
         "T5" : T5Tokenizer.self,
         "Whisper" : WhisperTokenizer.self,
         "Cohere" : CohereTokenizer.self,
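The one-line change above registers `UnigramTokenizer` under the `"Unigram"` class name. As a hedged, self-contained sketch of the string-to-metatype dispatch pattern such a registry enables (toy names, not the library's actual API):

import Foundation

protocol TokenizingModel {
    init()  // Simplified: the real initializers take configuration arguments.
}

struct ToyUnigram: TokenizingModel {}
struct ToyBPE: TokenizingModel {}

// The registry maps the class name found in a model's tokenizer config
// to a concrete Swift metatype.
let registry: [String: TokenizingModel.Type] = [
    "Unigram": ToyUnigram.self,
    "BPE": ToyBPE.self,
]

func makeTokenizer(named className: String) -> TokenizingModel? {
    registry[className].map { $0.init() }  // Look up the metatype, then instantiate it.
}

print(makeTokenizer(named: "Unigram") != nil)  // true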
32 changes: 25 additions & 7 deletions Sources/Tokenizers/UnigramTokenizer.swift
@@ -22,14 +22,26 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     public var unknownToken: String? { unknownPiece.token }

     let minScore: Float
-    let tokensToIds: [String: Int]
+    let tokensToIds: [LiteralString: Int]

     let bosToken: String? = " "
     let bosTokenId: Int?
     let eosToken: String?
     let eosTokenId: Int?

     private let trie: Trie<Character>
+
+    struct LiteralString: Hashable {
+        let value: String
+
+        static func ==(lhs: LiteralString, rhs: LiteralString) -> Bool {
+            return lhs.value.compare(rhs.value, options: .literal) == .orderedSame
+        }
+
+        func hash(into hasher: inout Hasher) {
+            hasher.combine(value)
+        }
+    }

     required init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String : Int]) throws {
         guard let configVocab = tokenizerData.model?.vocab?.value as? [[Any]] else {
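The `LiteralString` wrapper exists because Swift compares `String`s by canonical Unicode equivalence, so two different UTF-8 byte sequences can be "equal". A minimal demonstration of the collision the wrapper avoids (illustrative only):

import Foundation

let precomposed = "\u{00E9}"  // "é" as the single scalar U+00E9
let decomposed = "e\u{0301}"  // "é" as "e" plus combining acute U+0301

// Default equality uses canonical equivalence, so the two forms collide:
print(precomposed == decomposed)  // true

// Literal comparison distinguishes the underlying scalar sequences:
print(precomposed.compare(decomposed, options: .literal) == .orderedSame)  // false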
@@ -38,7 +50,10 @@ class UnigramTokenizer: PreTrainedTokenizerModel {

         vocab = try configVocab.map { piece in
             guard let token = piece.first as? String else { throw TokenizerError.malformedVocab }
-            guard let score = piece.last as? Float else { throw TokenizerError.malformedVocab }
+            // Casting straight to `Float` fails when the number cannot be represented
+            // without precision loss, so read the score as a `Double` first.
+            guard let scoreDouble = piece.last as? Double else { throw TokenizerError.malformedVocab }
+            let score = Float(scoreDouble)
             return SentencePieceToken(token: token, score: score)
         }
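The `Double` detour above matters because of how dynamically typed numbers are cast in Swift. A small check, assuming the parsed config scores arrive as `Double`-backed values (typical for JSON):

import Foundation

let score: Any = 0.1  // A JSON score, backed by a Double

// The dynamic cast to Float fails: 0.1 cannot round-trip through Float exactly.
print(score as? Float)  // nil
// Casting to Double succeeds; the explicit narrowing conversion is then fine:
print((score as? Double).map(Float.init))  // Optional(0.1)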

@@ -50,11 +65,14 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
         self.unknownTokenId = unknownTokenId
         self.unknownPiece = SentencePieceToken(token: vocab[unknownTokenId].token, score: minScore - 10)

-        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { ($1, $0) })
-        bosTokenId = tokensToIds[bosToken!] // May be nil
+        // `Dictionary(uniqueKeysWithValues:)` is the natural way to build this mapping, but
+        // Swift's default String equality treats canonically equivalent UTF-8 sequences as the
+        // same key, causing collisions. `LiteralString` compares with the `.literal` option instead.
+        tokensToIds = Dictionary(uniqueKeysWithValues: vocab.map { $0.token }.enumerated().map { (LiteralString(value: $1), $0) })
+        bosTokenId = tokensToIds[LiteralString(value: bosToken!)] // May be nil

         eosToken = tokenizerConfig.eosToken?.stringValue
-        eosTokenId = eosToken == nil ? nil : tokensToIds[eosToken!]
+        eosTokenId = eosToken == nil ? nil : tokensToIds[LiteralString(value: eosToken!)]

         trie = Trie()
         trie.append(contentsOf: vocab.map { $0.token })
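The collision is not hypothetical: `Dictionary(uniqueKeysWithValues:)` traps at runtime when two keys compare equal, so a vocabulary containing both normalization forms of the same glyph would crash without the wrapper. A self-contained sketch using the same `LiteralString` definition:

import Foundation

struct LiteralString: Hashable {
    let value: String
    static func == (lhs: LiteralString, rhs: LiteralString) -> Bool {
        lhs.value.compare(rhs.value, options: .literal) == .orderedSame
    }
    func hash(into hasher: inout Hasher) { hasher.combine(value) }
}

let vocab = ["\u{00E9}", "e\u{0301}"]  // Distinct byte sequences, canonically equivalent

// With plain String keys the next line would trap with "Duplicate keys":
// let ids = Dictionary(uniqueKeysWithValues: vocab.enumerated().map { ($1, $0) })

let ids = Dictionary(uniqueKeysWithValues:
    vocab.enumerated().map { (LiteralString(value: $1), $0) })
print(ids.count)                                // 2 — both forms keep their own id
print(ids[LiteralString(value: "e\u{0301}")]!)  // 1

Note that the hash still combines the plain `String`, so the two forms land in the same bucket and are separated by the literal equality check, which keeps `Hashable`'s contract (equal values must hash equally) intact.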
@@ -63,7 +81,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {
     }

     func convertTokenToId(_ token: String) -> Int? {
-        return tokensToIds[token] ?? self.unknownTokenId
+        return tokensToIds[LiteralString(value: token)] ?? self.unknownTokenId
     }

     func convertIdToToken(_ id: Int) -> String? {
@@ -82,7 +100,7 @@ class UnigramTokenizer: PreTrainedTokenizerModel {

             let beginIndex = sentence.index(sentence.startIndex, offsetBy: beginPos)
             for token in trie.commonPrefixSearchIterator(sentence[beginIndex...]).map({ String($0) }) {
-                guard let tokenId = tokensToIds[token] else { fatalError("Token not in vocab: \(token)") }
+                guard let tokenId = tokensToIds[LiteralString(value: token)] else { fatalError("Token not in vocab: \(token)") }
                 let tokenScore = vocab[tokenId].score
                 lattice.insert(startOffset: beginPos, length: token.count, score: tokenScore, tokenId: tokenId)
                 if !hasSingleNode && token.count == mblen {
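For context on the loop above: Unigram tokenization walks each start offset in the sentence and asks the trie for every vocabulary token that is a prefix of the remaining text, inserting each hit as a lattice edge. A minimal character-trie sketch of that common-prefix search (illustrative only, not the package's internal `Trie`):

final class MiniTrie {
    private var children: [Character: MiniTrie] = [:]
    private var isWord = false

    func append(_ word: String) {
        var node = self
        for ch in word {
            let next = node.children[ch] ?? MiniTrie()
            node.children[ch] = next
            node = next
        }
        node.isWord = true
    }

    /// Returns every stored token that is a prefix of `text`, shortest first:
    /// the candidate edges a Unigram lattice scores at one start offset.
    func commonPrefixSearch(_ text: Substring) -> [String] {
        var node = self
        var prefix = ""
        var hits: [String] = []
        for ch in text {
            guard let next = node.children[ch] else { break }
            node = next
            prefix.append(ch)
            if node.isWord { hits.append(prefix) }
        }
        return hits
    }
}

let trie = MiniTrie()
for token in ["t", "to", "tok", "token", "ize"] { trie.append(token) }
print(trie.commonPrefixSearch("tokenizer"))  // ["t", "to", "tok", "token"]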
