Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read model weights #91

Merged
merged 6 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ let package = Package(
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
.testTarget(name: "HubTests", dependencies: ["Hub"]),
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
]
)
2 changes: 1 addition & 1 deletion Sources/Models/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class LanguageModel {
var tokenizerConfig: Config?
var tokenizerData: Config
}

private var configuration: LanguageModelConfigurationFromHub? = nil
private var _tokenizer: Tokenizer? = nil

Expand Down
88 changes: 88 additions & 0 deletions Sources/TensorUtils/Weights.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import CoreML


/// A read-only collection of named tensors loaded from a weights file.
///
/// Only the safetensors container is actually decoded; GGUF and NumPy/MLX
/// payloads are recognised by their magic bytes and rejected with
/// `WeightsError.notSupported`.
public struct Weights {

    /// Errors raised while locating or decoding a weights file.
    enum WeightsError: Error {
        /// The file extension, container format, or tensor dtype is not handled.
        case notSupported(message: String)
        /// The file is truncated or structurally inconsistent.
        case invalidFile
    }

    private let dictionary: [String: MLMultiArray]

    init(_ dictionary: [String: MLMultiArray]) {
        self.dictionary = dictionary
    }

    /// Returns the tensor stored under `key`, or `nil` when no such tensor exists.
    subscript(key: String) -> MLMultiArray? { dictionary[key] }

    /// Loads the tensors stored in the file at `fileURL`.
    ///
    /// - Parameter fileURL: Location of a file whose extension is
    ///   `safetensors`, `gguf` or `mlx`.
    /// - Returns: The decoded weights.
    /// - Throws: `WeightsError.notSupported` for an unknown extension or a
    ///   GGUF / NumPy payload, `WeightsError.invalidFile` when the file is too
    ///   short to hold a safetensors length prefix, and any error raised while
    ///   reading the file or decoding its contents.
    public static func from(fileURL: URL) throws -> Weights {
        guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension)
        else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") }

        let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)

        // `starts(with:)` simply returns false when the file is shorter than the
        // magic sequence, so tiny or empty files no longer trap on slicing.
        let ggufMagic: [UInt8] = [0x47, 0x47, 0x55, 0x46]              // "GGUF"
        let numpyMagic: [UInt8] = [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59] // "\u{93}NUMPY"
        if data.starts(with: ggufMagic) {
            throw WeightsError.notSupported(message: "gguf")
        }
        if data.starts(with: numpyMagic) {
            throw WeightsError.notSupported(message: "mlx")
        }

        // The safetensors reader starts by reading an 8-byte length prefix;
        // reject anything that cannot contain it instead of crashing there.
        guard data.count >= 8 else { throw WeightsError.invalidFile }
        return try Safetensor.from(data: data)
    }
}

/// Decoder for the safetensors container.
///
/// The layout handled here is: an 8-byte little-endian header length, a JSON
/// header of that length, then the raw tensor bytes. Offsets in the header are
/// relative to the first byte after the JSON header.
struct Safetensor {

    typealias Error = Weights.WeightsError

    /// Number of bytes used by the header-length prefix.
    private static let lengthPrefixSize = 8

    struct Header {

        /// One entry of the JSON header. Every field is optional so that
        /// non-tensor entries decode without error; they are skipped later
        /// because they carry no `dataOffsets`.
        struct Offset: Decodable {
            let dataOffsets: [Int]?
            let dtype: String?
            let shape: [Int]?

            /// Unsupported: "I8", "U8", "I16", "U16", "BF16"
            ///
            /// NOTE(review): "U32" and "U64" are exposed as `.int32` / `.float64`,
            /// i.e. the raw bytes are reinterpreted rather than converted.
            /// Confirm this is intended before relying on values of those dtypes.
            var dataType: MLMultiArrayDataType? {
                get throws {
                    switch dtype {
                    case "I32", "U32": .int32
                    case "F16": .float16
                    case "F32": .float32
                    case "F64", "U64": .float64
                    default: throw Error.notSupported(message: "\(dtype ?? "empty")")
                    }
                }
            }

            /// Size in bytes of a single element, or `nil` for a dtype that
            /// `dataType` does not support.
            var elementSize: Int? {
                switch dtype {
                case "F16": 2
                case "I32", "U32", "F32": 4
                case "F64", "U64": 8
                default: nil
                }
            }
        }

        /// Decodes the JSON header, mapping snake_case keys (`data_offsets`)
        /// onto the camelCase properties of `Offset`.
        static func from(data: Data) throws -> [String: Offset?] {
            let decoder = JSONDecoder()
            decoder.keyDecodingStrategy = .convertFromSnakeCase
            return try decoder.decode([String: Offset?].self, from: data)
        }
    }

    /// Decodes every tensor described by the header of `data`.
    ///
    /// - Parameter data: The complete contents of a safetensors file.
    /// - Returns: The tensors keyed by name. Header entries without offsets
    ///   or shape are skipped.
    /// - Throws: `Error.invalidFile` when the length prefix, header size, a
    ///   tensor's offsets or a tensor's shape do not fit inside `data`;
    ///   `Error.notSupported` for an unsupported dtype; plus any JSON decoding
    ///   or `MLMultiArray` construction error.
    static func from(data: Data) throws -> Weights {
        guard data.count >= lengthPrefixSize else { throw Error.invalidFile }

        // Decode the prefix explicitly as an unaligned little-endian UInt64 so
        // the result does not depend on host byte order or buffer alignment.
        let rawHeaderSize = data.prefix(lengthPrefixSize).withUnsafeBytes {
            UInt64(littleEndian: $0.loadUnaligned(as: UInt64.self))
        }
        // The header must fit in what follows the prefix; comparing against
        // `data.count` alone would let the header slice run past the end.
        guard let headerSize = Int(exactly: rawHeaderSize),
              headerSize <= data.count - lengthPrefixSize
        else { throw Error.invalidFile }

        let headerStart = data.startIndex + lengthPrefixSize
        let payloadStart = headerStart + headerSize
        let payloadSize = data.endIndex - payloadStart
        let header = try Header.from(data: data.subdata(in: headerStart..<payloadStart))

        var dict = [String: MLMultiArray]()
        for (key, point) in header {
            guard let offsets = point?.dataOffsets, offsets.count == 2,
                  let shape = point?.shape as? [NSNumber],
                  let dType = try point?.dataType,
                  let elementSize = point?.elementSize
            else { continue }

            let (begin, end) = (offsets[0], offsets[1])
            guard begin >= 0, begin <= end, end <= payloadSize else { throw Error.invalidFile }

            // The declared shape must not describe more bytes than the entry
            // owns, otherwise indexing the array would read out of bounds.
            guard let requiredBytes = byteCount(shape: shape, elementSize: elementSize),
                  requiredBytes <= end - begin
            else { throw Error.invalidFile }

            // Row-major strides: the last dimension is contiguous.
            let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in
                acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0)
            }

            let tensorData = data.subdata(in: (payloadStart + begin)..<(payloadStart + end)) as NSData
            let ptr = UnsafeMutableRawPointer(mutating: tensorData.bytes)
            // MLMultiArray does not copy the buffer. Capturing `tensorData` in the
            // deallocator keeps the backing storage alive for as long as the array
            // holds the closure, instead of freeing it at the end of this iteration.
            dict[key] = try MLMultiArray(
                dataPointer: ptr,
                shape: shape,
                dataType: dType,
                strides: strides,
                deallocator: { _ in withExtendedLifetime(tensorData) {} }
            )
        }

        return Weights(dict)
    }

    /// Returns the number of bytes needed for `shape` with the given element
    /// size, or `nil` when a dimension is negative or the product overflows.
    private static func byteCount(shape: [NSNumber], elementSize: Int) -> Int? {
        var total = elementSize
        for dimension in shape {
            let extent = dimension.intValue
            guard extent >= 0 else { return nil }
            let (product, overflow) = total.multipliedReportingOverflow(by: extent)
            guard !overflow else { return nil }
            total = product
        }
        return total
    }
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
104 changes: 104 additions & 0 deletions Tests/TensorUtilsTests/WeightsTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
@testable import TensorUtils
@testable import Hub
import XCTest

/// Tests for loading safetensors weights, both from a Hub snapshot and from
/// small fixtures bundled with the test target.
///
/// Optionals are unwrapped with `XCTUnwrap` so a missing tensor or fixture is
/// reported as a failure of that test instead of crashing the test process.
class WeightsTests: XCTestCase {

    let downloadDestination: URL = {
        FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests")
    }()

    var hubApi: HubApi { HubApi(downloadBase: downloadDestination) }

    /// Downloads a snapshot of a small BERT repo and checks dtype, element
    /// count, rank and a few sampled values of two of its tensors.
    func testLoadWeightsFromFileURL() async throws {
        let repo = "google/bert_uncased_L-2_H-128_A-2"
        let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"])

        let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey])
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" }))
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" }))

        let modelFile = modelDir.appending(path: "model.safetensors")
        let weights = try Weights.from(fileURL: modelFile)

        let layerNormBias = try XCTUnwrap(weights["bert.embeddings.LayerNorm.bias"])
        XCTAssertEqual(layerNormBias.dataType, .float32)
        XCTAssertEqual(layerNormBias.count, 128)
        XCTAssertEqual(layerNormBias.shape.count, 1)

        let wordEmbeddings = try XCTUnwrap(weights["bert.embeddings.word_embeddings.weight"])
        XCTAssertEqual(wordEmbeddings.dataType, .float32)
        XCTAssertEqual(wordEmbeddings.count, 3906816)
        XCTAssertEqual(wordEmbeddings.shape.count, 2)

        XCTAssertEqual(wordEmbeddings[[0, 0]].floatValue, -0.0041, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[3, 4]].floatValue, 0.0037, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[5, 3]].floatValue, -0.5371, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[7, 8]].floatValue, 0.0460, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[11, 7]].floatValue, -0.0058, accuracy: 1e-3)
    }

    /// Reads the bundled 1-D int32 fixture and checks its first three values.
    func testSafetensorReadTensor1D() throws {
        let modelFile = try XCTUnwrap(Bundle.module.url(forResource: "tensor-1d-int32", withExtension: "safetensors"))
        let weights: Weights = try Weights.from(fileURL: modelFile)
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .int32)
        XCTAssertEqual(tensor[[0]], 1)
        XCTAssertEqual(tensor[[1]], 2)
        XCTAssertEqual(tensor[[2]], 3)
    }

    /// Reads the bundled 2-D float64 fixture and checks values in two rows.
    func testSafetensorReadTensor2D() throws {
        let modelFile = try XCTUnwrap(Bundle.module.url(forResource: "tensor-2d-float64", withExtension: "safetensors"))
        let weights: Weights = try Weights.from(fileURL: modelFile)
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float64)
        XCTAssertEqual(tensor[[0, 0]], 1)
        XCTAssertEqual(tensor[[0, 1]], 2)
        XCTAssertEqual(tensor[[0, 2]], 3)
        XCTAssertEqual(tensor[[1, 0]], 24)
        XCTAssertEqual(tensor[[1, 1]], 25)
        XCTAssertEqual(tensor[[1, 2]], 26)
    }

    /// Reads the bundled 3-D float32 fixture and checks values across both
    /// leading indices.
    func testSafetensorReadTensor3D() throws {
        let modelFile = try XCTUnwrap(Bundle.module.url(forResource: "tensor-3d-float32", withExtension: "safetensors"))
        let weights: Weights = try Weights.from(fileURL: modelFile)
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0]], 22)
        XCTAssertEqual(tensor[[0, 0, 1]], 23)
        XCTAssertEqual(tensor[[0, 0, 2]], 24)
        XCTAssertEqual(tensor[[0, 1, 0]], 11)
        XCTAssertEqual(tensor[[0, 1, 1]], 12)
        XCTAssertEqual(tensor[[0, 1, 2]], 13)
        XCTAssertEqual(tensor[[1, 0, 0]], 2)
        XCTAssertEqual(tensor[[1, 0, 1]], 3)
        XCTAssertEqual(tensor[[1, 0, 2]], 4)
        XCTAssertEqual(tensor[[1, 1, 0]], 1)
        XCTAssertEqual(tensor[[1, 1, 1]], 2)
        XCTAssertEqual(tensor[[1, 1, 2]], 3)
    }

    /// Reads the bundled 4-D float32 fixture and checks values across both
    /// leading indices.
    func testSafetensorReadTensor4D() throws {
        let modelFile = try XCTUnwrap(Bundle.module.url(forResource: "tensor-4d-float32", withExtension: "safetensors"))
        let weights: Weights = try Weights.from(fileURL: modelFile)
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0, 0]], 11)
        XCTAssertEqual(tensor[[0, 0, 0, 1]], 12)
        XCTAssertEqual(tensor[[0, 0, 0, 2]], 13)
        XCTAssertEqual(tensor[[0, 0, 1, 0]], 1)
        XCTAssertEqual(tensor[[0, 0, 1, 1]], 2)
        XCTAssertEqual(tensor[[0, 0, 1, 2]], 3)
        XCTAssertEqual(tensor[[0, 0, 2, 0]], 4)
        XCTAssertEqual(tensor[[0, 0, 2, 1]], 5)
        XCTAssertEqual(tensor[[0, 0, 2, 2]], 6)
        XCTAssertEqual(tensor[[1, 0, 0, 0]], 22)
        XCTAssertEqual(tensor[[1, 0, 0, 1]], 23)
        XCTAssertEqual(tensor[[1, 0, 0, 2]], 24)
        XCTAssertEqual(tensor[[1, 0, 1, 0]], 15)
        XCTAssertEqual(tensor[[1, 0, 1, 1]], 16)
        XCTAssertEqual(tensor[[1, 0, 1, 2]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 0]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 1]], 18)
        XCTAssertEqual(tensor[[1, 0, 2, 2]], 19)
    }
}