Skip to content

Commit

Permalink
Add skipSpecialTokens option to Tokenizer.decode (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
finnvoor authored Dec 26, 2024
1 parent fc95ce1 commit 44e2c04
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
18 changes: 15 additions & 3 deletions Sources/Tokenizers/Tokenizer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ public protocol Tokenizer {

/// Decode
func decode(tokens: [Int]) -> String
func decode(tokens: [Int], skipSpecialTokens: Bool) -> String

func convertTokenToId(_ token: String) -> Int?
func convertTokensToIds(_ tokens: [String]) -> [Int?]
Expand Down Expand Up @@ -150,6 +151,10 @@ public extension Tokenizer {
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
encode(text: text, addSpecialTokens: addSpecialTokens)
}

func decode(tokens: [Int]) -> String {
decode(tokens: tokens, skipSpecialTokens: false)
}

func convertTokensToIds(_ tokens: [String]) -> [Int?] {
return tokens.map { convertTokenToId($0) }
Expand Down Expand Up @@ -315,10 +320,17 @@ public class PreTrainedTokenizer: Tokenizer {
return encode(text: text, addSpecialTokens: true)
}

/// Decode
public func decode(tokens: [Int]) -> String {
public func decode(tokens: [Int], skipSpecialTokens: Bool = false) -> String {
// IDs to tokens
let tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
let tokenStrings: [String]
if skipSpecialTokens {
let specialTokenIDs = Set(specialTokens.values)
tokenStrings = tokens
.filter { !specialTokenIDs.contains($0) }
.compactMap { model.convertIdToToken($0) }
} else {
tokenStrings = tokens.compactMap { model.convertIdToToken($0) }
}
let decoded = decodeTokens(tokenStrings)
// At this point we should have a single String
return cleanUp(text: decoded.joined(separator: ""))
Expand Down
4 changes: 4 additions & 0 deletions Tests/TokenizersTests/TokenizerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ class TokenizerTester {
tokenizer.decode(tokens: edgeCase.encoded.input_ids),
edgeCase.decoded_with_special
)
XCTAssertEqual(
tokenizer.decode(tokens: edgeCase.encoded.input_ids, skipSpecialTokens: true),
edgeCase.decoded_without_special
)
}
}

Expand Down

0 comments on commit 44e2c04

Please sign in to comment.