Skip to content

Commit

Permalink
add StripNormalizer (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
mzbac authored Oct 2, 2024
1 parent 71963c3 commit a7a61a2
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 45 deletions.
107 changes: 67 additions & 40 deletions Sources/Tokenizers/Normalizer.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// Normalizer.swift
//
//
//
// Created by Pedro Cuenca on 17/7/23.
//
Expand All @@ -11,7 +11,7 @@ import Hub
/// Common interface adopted by the text normalizer classes in this file.
public protocol Normalizer {
/// Returns the normalized form of `text`.
func normalize(text: String) -> String
/// Allows a normalizer instance to be invoked with call syntax, e.g. `normalizer(text: value)`.
/// NOTE(review): presumably forwards to `normalize(text:)`; the default implementation is not visible here.
func callAsFunction(text: String) -> String

/// Creates the normalizer from its configuration object.
init(config: Config)
}

Expand All @@ -33,6 +33,7 @@ enum NormalizerType: String {
case Bert
case Precompiled
case StripAccents
case Strip
case Unknown = ""
}

Expand All @@ -43,29 +44,32 @@ struct NormalizerFactory {
let type = NormalizerType(rawValue: typeName)
switch type {
case .Sequence: return NormalizerSequence(config: config)
case .Prepend : return PrependNormalizer(config: config)
case .Replace : return ReplaceNormalizer(config: config)
case .Lowercase : return LowercaseNormalizer(config: config)
case .NFD : return NFDNormalizer(config: config)
case .NFC : return NFCNormalizer(config: config)
case .NFKD : return NFKDNormalizer(config: config)
case .NFKC : return NFKCNormalizer(config: config)
case .Bert : return BertNormalizer(config: config)
case .Precompiled : return PrecompiledNormalizer(config: config)
case .StripAccents : return StripAccentsNormalizer(config: config)
default : fatalError("Unsupported Normalizer type: \(typeName)")
case .Prepend: return PrependNormalizer(config: config)
case .Replace: return ReplaceNormalizer(config: config)
case .Lowercase: return LowercaseNormalizer(config: config)
case .NFD: return NFDNormalizer(config: config)
case .NFC: return NFCNormalizer(config: config)
case .NFKD: return NFKDNormalizer(config: config)
case .NFKC: return NFKCNormalizer(config: config)
case .Bert: return BertNormalizer(config: config)
case .Precompiled: return PrecompiledNormalizer(config: config)
case .StripAccents: return StripAccentsNormalizer(config: config)
case .Strip: return StripNormalizer(config: config)
default: fatalError("Unsupported Normalizer type: \(typeName)")
}
}
}

class NormalizerSequence: Normalizer {
let normalizers: [Normalizer]

required public init(config: Config) {
guard let configs = config.normalizers?.arrayValue else { fatalError("No normalizers in Sequence") }
guard let configs = config.normalizers?.arrayValue else {
fatalError("No normalizers in Sequence")
}
normalizers = configs.compactMap { NormalizerFactory.fromConfig(config: $0) }
}

public func normalize(text: String) -> String {
normalizers.reduce(text) { current, normalizer in
normalizer(text: current)
Expand All @@ -75,23 +79,23 @@ class NormalizerSequence: Normalizer {

/// Normalizer that places a fixed string in front of the input text.
class PrependNormalizer: Normalizer {
    /// Text added ahead of every input; empty when the config's `prepend` lookup yields no string.
    let prepend: String

    required public init(config: Config) {
        let configured = config.prepend?.stringValue
        self.prepend = configured ?? ""
    }

    /// Returns `text` with the configured prefix attached to its front.
    public func normalize(text: String) -> String {
        "\(prepend)\(text)"
    }
}

class ReplaceNormalizer: Normalizer {
let pattern: StringReplacePattern?

required public init(config: Config) {
self.pattern = StringReplacePattern.from(config: config)
}

public func normalize(text: String) -> String {
guard let pattern = pattern else { return text }
return pattern.replace(text)
Expand All @@ -106,7 +110,7 @@ class LowercaseNormalizer: Normalizer {
}
}

class NFDNormalizer: Normalizer {
class NFDNormalizer: Normalizer {
required public init(config: Config) {}

public func normalize(text: String) -> String {
Expand All @@ -122,7 +126,7 @@ class NFCNormalizer: Normalizer {
}
}

class NFKDNormalizer: Normalizer {
class NFKDNormalizer: Normalizer {
required init(config: Config) {}

func normalize(text: String) -> String {
Expand Down Expand Up @@ -172,15 +176,13 @@ class BertNormalizer: Normalizer {
private func cleanText(text: String) -> String {
text.map { c in
guard let scalar = c.unicodeScalars.first,
scalar.value != 0x0,
scalar.value != 0xFFFD,
!isControl(scalar)
scalar.value != 0x0,
scalar.value != 0xFFFD,
!isControl(scalar)
else { return "\(c)" }

// Replace whitespace: \t, \n, \r
if scalar.value == 0x009 ||
scalar.value == 0x00A ||
scalar.value == 0x000D {
if scalar.value == 0x009 || scalar.value == 0x00A || scalar.value == 0x000D {
return " "
} else {
return "\(c)"
Expand All @@ -201,29 +203,27 @@ class BertNormalizer: Normalizer {
}

private func isOther(_ c: Unicode.GeneralCategory) -> Bool {
c == .control ||
c == .format ||
c == .surrogate ||
c == .privateUse ||
c == .unassigned
c == .control || c == .format || c == .surrogate || c == .privateUse || c == .unassigned
}

private func handleChineseChars(text: String) -> String {
text.map { c in
if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) {
" \(c) "
} else {
"\(c)"
"\(c)"
}
}
.joined()
}

private func stripAccents(text: String) -> String {
text.decomposedStringWithCanonicalMapping
.filter { $0.unicodeScalars.allSatisfy { scalar in
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
}}
.filter {
$0.unicodeScalars.allSatisfy { scalar in
!(0x0300 <= scalar.value && scalar.value <= 0x036F)
}
}
}
}

Expand All @@ -245,7 +245,8 @@ class PrecompiledNormalizer: Normalizer {
case 0x0001...0x0008, 0x000B, 0x000E...0x001F, 0x007F, 0x008F, 0x009F:
// Non-printing control characters
output.append("")
case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581, 0xFEFF, 0xFFFD:
case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581,
0xFEFF, 0xFFFD:
// Separators
output.append(" ")
case 0xFF5E:
Expand All @@ -257,7 +258,8 @@ class PrecompiledNormalizer: Normalizer {
}

if hasFullwidthTilde {
return output
return
output
.split(by: "\u{FF5E}")
.map({ $0.precomposedStringWithCompatibilityMapping })
.joined(separator: "\u{FF5E}")
Expand All @@ -275,6 +277,30 @@ class StripAccentsNormalizer: Normalizer {
}
}

/// Normalizer that trims whitespace from the start and/or the end of the text.
///
/// The sides to trim are read from the `stripLeft` and `stripRight` config
/// entries; each side defaults to `true` when its lookup yields no boolean.
class StripNormalizer: Normalizer {
    let leftStrip: Bool
    let rightStrip: Bool

    required init(config: Config) {
        leftStrip = config.stripLeft?.boolValue ?? true
        rightStrip = config.stripRight?.boolValue ?? true
    }

    /// Returns `text` without leading whitespace (when `leftStrip` is set) and
    /// without trailing whitespace (when `rightStrip` is set).
    func normalize(text: String) -> String {
        // Narrow a Substring view from each enabled side, then copy out once at the end.
        var remaining = text[...]

        if leftStrip {
            // No non-whitespace character at all means everything is dropped.
            let start = remaining.firstIndex(where: { !$0.isWhitespace }) ?? remaining.endIndex
            remaining = remaining[start...]
        }

        if rightStrip {
            if let lastKept = remaining.lastIndex(where: { !$0.isWhitespace }) {
                remaining = remaining[...lastKept]
            } else {
                // Only whitespace (or nothing) is left: the result is empty.
                remaining = remaining[remaining.startIndex..<remaining.startIndex]
            }
        }

        return String(remaining)
    }
}

enum StringReplacePattern {
case regexp(regexp: NSRegularExpression, replacement: String)
case string(pattern: String, replacement: String)
Expand All @@ -285,7 +311,8 @@ extension StringReplacePattern {
switch self {
case .regexp(let regexp, let replacement):
let range = NSRange(text.startIndex..., in: text)
let replaced = regexp.stringByReplacingMatches(in: text, options: [], range: range, withTemplate: replacement)
let replaced = regexp.stringByReplacingMatches(
in: text, options: [], range: range, withTemplate: replacement)
return replaced
case .string(let toReplace, let replacement):
return text.replacingOccurrences(of: toReplace, with: replacement)
Expand Down
40 changes: 35 additions & 5 deletions Tests/NormalizerTests/NormalizerTests.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import XCTest
@testable import Tokenizers

@testable import Hub
@testable import Tokenizers

class NormalizerTests: XCTestCase {

Expand All @@ -22,7 +23,7 @@ class NormalizerTests: XCTestCase {
let normalizer = LowercaseNormalizer(config: config)
XCTAssertEqual(normalizer.normalize(text: arg), expect)
}

let config = Config(["type": NormalizerType.Lowercase.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? LowercaseNormalizer)
}
Expand Down Expand Up @@ -68,11 +69,11 @@ class NormalizerTests: XCTestCase {
let normalizer = NFCNormalizer(config: config)
XCTAssertEqual(normalizer.normalize(text: arg), expect)
}

let config = Config(["type": NormalizerType.NFC.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFCNormalizer)
}

func testNFKDNormalizer() {
let testCases: [(String, String)] = [
("café", "cafe\u{301}"),
Expand Down Expand Up @@ -118,7 +119,7 @@ class NormalizerTests: XCTestCase {
let config = Config(["type": NormalizerType.NFKC.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? NFKCNormalizer)
}

func testBertNormalizer() {
let testCases: [(String, String)] = [
("Café", "café"),
Expand All @@ -141,6 +142,7 @@ class NormalizerTests: XCTestCase {
let config = Config(["type": NormalizerType.Bert.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? BertNormalizer)
}

func testPrecompiledNormalizer() {
let testCases: [(String, String)] = [
("café", "café"),
Expand Down Expand Up @@ -188,4 +190,32 @@ class NormalizerTests: XCTestCase {
let config = Config(["type": NormalizerType.StripAccents.rawValue])
XCTAssertNotNil(NormalizerFactory.fromConfig(config: config) as? StripAccentsNormalizer)
}

/// Checks that `StripNormalizer` trims only the sides enabled in its config, and that
/// the factory resolves the `Strip` type name to a `StripNormalizer`.
func testStripNormalizer() {
    // Each entry: (input, expected output, strip left?, strip right?)
    let scenarios: [(String, String, Bool, Bool)] = [
        (" hello ", "hello", true, true),
        (" hello ", "hello ", true, false),
        (" hello ", " hello", false, true),
        (" hello ", " hello ", false, false),
        ("\t\nHello\t\n", "Hello", true, true),
        (" ", "", true, true),
        ("", "", true, true),
    ]

    for scenario in scenarios {
        let (source, wanted, trimLeft, trimRight) = scenario
        let settings = Config([
            "type": NormalizerType.Strip.rawValue,
            "stripLeft": trimLeft,
            "stripRight": trimRight,
        ])
        let actual = StripNormalizer(config: settings).normalize(text: source)
        XCTAssertEqual(
            actual, wanted,
            "Failed for input: '\(source)', leftStrip: \(trimLeft), rightStrip: \(trimRight)")
    }

    // With no explicit flags the factory must still build a StripNormalizer.
    let bareConfig = Config(["type": NormalizerType.Strip.rawValue])
    let built = NormalizerFactory.fromConfig(config: bareConfig)
    XCTAssertNotNil(built as? StripNormalizer)
}

}

0 comments on commit a7a61a2

Please sign in to comment.