Skip to content

Commit

Permalink
Read model weights (#91)
Browse files Browse the repository at this point in the history
* Read weights from safetensors

* Check file extension

* Deintegrate Safetensor

  * Separate Safetensor from the weights
  * Rename test tensors to include type
  * Rename ModelWeights to Weights
  * Throw error for unsupported data types
  * Remove model weights from LanguageModel.Configurations

* Move Weights to TensorUtils

* Specify filenames to download in tests.

* Make the weights optional and public

Enable safe access to keys.
  • Loading branch information
shavit authored Dec 14, 2024
1 parent 37e234e commit d656ad4
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 3 deletions.
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ let package = Package(
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources"), .process("Vocabs")]),
.testTarget(name: "HubTests", dependencies: ["Hub"]),
.testTarget(name: "PreTokenizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils"]),
.testTarget(name: "TensorUtilsTests", dependencies: ["TensorUtils", "Models", "Hub"], resources: [.process("Resources")]),
.testTarget(name: "NormalizerTests", dependencies: ["Tokenizers", "Hub"]),
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"])
.testTarget(name: "PostProcessorTests", dependencies: ["Tokenizers", "Hub"]),
]
)
2 changes: 1 addition & 1 deletion Sources/Models/LanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class LanguageModel {
var tokenizerConfig: Config?
var tokenizerData: Config
}

private var configuration: LanguageModelConfigurationFromHub? = nil
private var _tokenizer: Tokenizer? = nil

Expand Down
88 changes: 88 additions & 0 deletions Sources/TensorUtils/Weights.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import CoreML


/// A read-only collection of named tensors loaded from a model weights file.
///
/// Only safetensors content is decoded. GGUF and NumPy-style files are detected by
/// their leading signature so that a descriptive `notSupported` error can be thrown.
public struct Weights {

    /// Errors thrown while loading a weights file.
    enum WeightsError: Error {
        /// The file extension, container format or tensor data type is not supported;
        /// `message` names the offending value.
        case notSupported(message: String)
        /// The file is truncated or otherwise malformed.
        case invalidFile
    }

    /// ASCII "GGUF".
    private static let ggufSignature: [UInt8] = [0x47, 0x47, 0x55, 0x46]
    /// 0x93 followed by ASCII "NUMPY"; reported to callers as "mlx".
    private static let numpySignature: [UInt8] = [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59]

    private let dictionary: [String: MLMultiArray]

    init(_ dictionary: [String: MLMultiArray]) {
        self.dictionary = dictionary
    }

    /// Returns the tensor stored under `key`, or `nil` when no such tensor was loaded.
    subscript(key: String) -> MLMultiArray? { dictionary[key] }

    /// Loads the tensors contained in the file at `fileURL`.
    ///
    /// - Parameter fileURL: Location of a file whose extension is `safetensors`, `gguf` or `mlx`.
    /// - Returns: The decoded tensors, keyed by name.
    /// - Throws: `WeightsError.notSupported` when the extension is not one of the above or the
    ///   content carries a GGUF / NumPy signature; any error raised while reading the file or
    ///   while decoding it as safetensors.
    public static func from(fileURL: URL) throws -> Weights {
        guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension)
        else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") }

        let data = try Data(contentsOf: fileURL, options: .mappedIfSafe)

        // Prefix comparison instead of fixed-range slicing: a file shorter than a
        // signature simply fails to match rather than trapping on an out-of-range slice.
        if data.starts(with: ggufSignature) {
            throw WeightsError.notSupported(message: "gguf")
        }
        if data.starts(with: numpySignature) {
            throw WeightsError.notSupported(message: "mlx")
        }
        return try Safetensor.from(data: data)
    }
}

/// Decoder for safetensors content.
///
/// Layout as consumed by this code: an 8-byte little-endian header length `N`, `N` bytes of
/// JSON describing each tensor (dtype, shape, byte offsets), then the raw tensor bytes that
/// the offsets index into.
struct Safetensor {

    typealias Error = Weights.WeightsError

    struct Header {

        /// One entry of the JSON header. Every field is optional so that entries which do
        /// not describe a tensor still decode and can be skipped by the caller.
        struct Offset: Decodable {
            let dataOffsets: [Int]?
            let dtype: String?
            let shape: [Int]?

            /// Unsupported: "I8", "U8", "I16", "U16", "BF16", "I64", "U64"
            ///
            /// "U32" shares the `.int32` storage, so values above `Int32.max` read back
            /// negative. "U64" is rejected: viewing 64-bit integers as `.float64` would
            /// reinterpret their bit patterns as doubles.
            var dataType: MLMultiArrayDataType? {
                get throws {
                    switch dtype {
                    case "I32", "U32": .int32
                    case "F16": .float16
                    case "F32": .float32
                    case "F64": .float64
                    default: throw Error.notSupported(message: "\(dtype ?? "empty")")
                    }
                }
            }

            /// Size in bytes of one element of `dtype`, or `nil` for a type `dataType` rejects.
            var bytesPerElement: Int? {
                switch dtype {
                case "F16": return 2
                case "I32", "U32", "F32": return 4
                case "F64": return 8
                default: return nil
                }
            }
        }

        /// Decodes the JSON header, mapping snake_case keys (`data_offsets`) to the
        /// camelCase properties of `Offset`.
        static func from(data: Data) throws -> [String: Offset?] {
            let decoder = JSONDecoder()
            decoder.keyDecodingStrategy = .convertFromSnakeCase
            return try decoder.decode([String: Offset?].self, from: data)
        }
    }

    /// Row-major strides for `dimensions` together with the total element count.
    ///
    /// Returns `nil` when a dimension is negative or a product overflows `Int`, so that
    /// hostile shapes are reported as an invalid file instead of trapping.
    private static func rowMajorLayout(of dimensions: [Int]) -> (strides: [Int], count: Int)? {
        var strides = [Int]()
        var running = 1
        for dimension in dimensions.reversed() {
            guard dimension >= 0 else { return nil }
            strides.insert(running, at: 0)
            let (product, overflow) = running.multipliedReportingOverflow(by: dimension)
            guard !overflow else { return nil }
            running = product
        }
        // An empty shape keeps the single unit stride the decoder has always produced.
        return (strides.isEmpty ? [1] : strides, running)
    }

    /// Decodes every supported tensor found in `data`.
    ///
    /// - Parameter data: The complete file content, indexed from zero.
    /// - Returns: The tensors keyed by name. Header entries without offsets or a shape are skipped.
    /// - Throws: `Error.invalidFile` when the header length, a tensor's offsets or its byte size
    ///   are inconsistent with the data; `Error.notSupported` for an unsupported dtype; any
    ///   error raised by `JSONDecoder` or `MLMultiArray`.
    static func from(data: Data) throws -> Weights {
        let prefixLength = 8
        guard data.count >= prefixLength else { throw Error.invalidFile }

        // The length prefix is little-endian and unsigned; read it without assuming
        // alignment or host byte order, then make sure it fits both `Int` and the file.
        let declaredSize = data.subdata(in: 0..<prefixLength).withUnsafeBytes {
            UInt64(littleEndian: $0.loadUnaligned(as: UInt64.self))
        }
        guard let headerSize = Int(exactly: declaredSize),
              headerSize <= data.count - prefixLength
        else { throw Error.invalidFile }

        let payloadStart = prefixLength + headerSize
        let payloadSize = data.count - payloadStart
        let header = try Header.from(data: data.subdata(in: prefixLength..<payloadStart))

        var dict = [String: MLMultiArray]()
        for (key, point) in header {
            // Offsets are checked first so that non-tensor entries are skipped before
            // the throwing dtype lookup runs.
            guard let offsets = point?.dataOffsets, offsets.count == 2,
                  let dimensions = point?.shape,
                  let dType = try point?.dataType,
                  let elementSize = point?.bytesPerElement
            else { continue }

            guard let layout = rowMajorLayout(of: dimensions) else { throw Error.invalidFile }
            let (expectedBytes, overflow) = layout.count.multipliedReportingOverflow(by: elementSize)
            let lower = offsets[0]
            let upper = offsets[1]
            // The byte range must lie inside the payload and match the shape exactly,
            // otherwise the array would read outside the copied buffer.
            guard !overflow,
                  lower >= 0, lower <= upper, upper <= payloadSize,
                  upper - lower == expectedBytes
            else { throw Error.invalidFile }

            let start = payloadStart + lower
            let end = payloadStart + upper
            let tensorData = data.subdata(in: start..<end) as NSData
            let ptr = UnsafeMutableRawPointer(mutating: tensorData.bytes)
            // MLMultiArray does not copy the bytes it is pointed at. Capturing `tensorData`
            // in the deallocator keeps the backing storage alive for as long as the array
            // exists; without it the pointer dangles once this iteration ends.
            dict[key] = try MLMultiArray(
                dataPointer: ptr,
                shape: dimensions.map { NSNumber(value: $0) },
                dataType: dType,
                strides: layout.strides.map { NSNumber(value: $0) },
                deallocator: { _ in withExtendedLifetime(tensorData) {} }
            )
        }

        return Weights(dict)
    }
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
104 changes: 104 additions & 0 deletions Tests/TensorUtilsTests/WeightsTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
@testable import TensorUtils
@testable import Hub
import XCTest

/// Tests for loading tensors through `Weights.from(fileURL:)`.
///
/// Optional values are unwrapped with `XCTUnwrap` so that a missing fixture or tensor
/// fails the individual test instead of crashing the whole test process.
class WeightsTests: XCTestCase {

    let downloadDestination: URL = {
        FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests")
    }()

    var hubApi: HubApi { HubApi(downloadBase: downloadDestination) }

    /// Loads the bundled fixture `<name>.safetensors` from the test resources.
    private func loadBundledWeights(
        _ name: String,
        file: StaticString = #filePath,
        line: UInt = #line
    ) throws -> Weights {
        let modelFile = try XCTUnwrap(
            Bundle.module.url(forResource: name, withExtension: "safetensors"),
            "Missing test resource \(name).safetensors",
            file: file,
            line: line
        )
        return try Weights.from(fileURL: modelFile)
    }

    /// Downloads a small BERT checkpoint and spot-checks two of its tensors.
    func testLoadWeightsFromFileURL() async throws {
        let repo = "google/bert_uncased_L-2_H-128_A-2"
        let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"])

        let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey])
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" }))
        XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" }))

        let modelFile = modelDir.appending(path: "model.safetensors")
        let weights = try Weights.from(fileURL: modelFile)

        let layerNormBias = try XCTUnwrap(weights["bert.embeddings.LayerNorm.bias"])
        XCTAssertEqual(layerNormBias.dataType, .float32)
        XCTAssertEqual(layerNormBias.count, 128)
        XCTAssertEqual(layerNormBias.shape.count, 1)

        let wordEmbeddings = try XCTUnwrap(weights["bert.embeddings.word_embeddings.weight"])
        XCTAssertEqual(wordEmbeddings.dataType, .float32)
        XCTAssertEqual(wordEmbeddings.count, 3906816)
        XCTAssertEqual(wordEmbeddings.shape.count, 2)

        XCTAssertEqual(wordEmbeddings[[0, 0]].floatValue, -0.0041, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[3, 4]].floatValue, 0.0037, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[5, 3]].floatValue, -0.5371, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[7, 8]].floatValue, 0.0460, accuracy: 1e-3)
        XCTAssertEqual(wordEmbeddings[[11, 7]].floatValue, -0.0058, accuracy: 1e-3)
    }

    func testSafetensorReadTensor1D() throws {
        let weights = try loadBundledWeights("tensor-1d-int32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .int32)
        XCTAssertEqual(tensor[[0]], 1)
        XCTAssertEqual(tensor[[1]], 2)
        XCTAssertEqual(tensor[[2]], 3)
    }

    func testSafetensorReadTensor2D() throws {
        let weights = try loadBundledWeights("tensor-2d-float64")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float64)
        XCTAssertEqual(tensor[[0, 0]], 1)
        XCTAssertEqual(tensor[[0, 1]], 2)
        XCTAssertEqual(tensor[[0, 2]], 3)
        XCTAssertEqual(tensor[[1, 0]], 24)
        XCTAssertEqual(tensor[[1, 1]], 25)
        XCTAssertEqual(tensor[[1, 2]], 26)
    }

    func testSafetensorReadTensor3D() throws {
        let weights = try loadBundledWeights("tensor-3d-float32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0]], 22)
        XCTAssertEqual(tensor[[0, 0, 1]], 23)
        XCTAssertEqual(tensor[[0, 0, 2]], 24)
        XCTAssertEqual(tensor[[0, 1, 0]], 11)
        XCTAssertEqual(tensor[[0, 1, 1]], 12)
        XCTAssertEqual(tensor[[0, 1, 2]], 13)
        XCTAssertEqual(tensor[[1, 0, 0]], 2)
        XCTAssertEqual(tensor[[1, 0, 1]], 3)
        XCTAssertEqual(tensor[[1, 0, 2]], 4)
        XCTAssertEqual(tensor[[1, 1, 0]], 1)
        XCTAssertEqual(tensor[[1, 1, 1]], 2)
        XCTAssertEqual(tensor[[1, 1, 2]], 3)
    }

    func testSafetensorReadTensor4D() throws {
        let weights = try loadBundledWeights("tensor-4d-float32")
        let tensor = try XCTUnwrap(weights["embedding"])
        XCTAssertEqual(tensor.dataType, .float32)
        XCTAssertEqual(tensor[[0, 0, 0, 0]], 11)
        XCTAssertEqual(tensor[[0, 0, 0, 1]], 12)
        XCTAssertEqual(tensor[[0, 0, 0, 2]], 13)
        XCTAssertEqual(tensor[[0, 0, 1, 0]], 1)
        XCTAssertEqual(tensor[[0, 0, 1, 1]], 2)
        XCTAssertEqual(tensor[[0, 0, 1, 2]], 3)
        XCTAssertEqual(tensor[[0, 0, 2, 0]], 4)
        XCTAssertEqual(tensor[[0, 0, 2, 1]], 5)
        XCTAssertEqual(tensor[[0, 0, 2, 2]], 6)
        XCTAssertEqual(tensor[[1, 0, 0, 0]], 22)
        XCTAssertEqual(tensor[[1, 0, 0, 1]], 23)
        XCTAssertEqual(tensor[[1, 0, 0, 2]], 24)
        XCTAssertEqual(tensor[[1, 0, 1, 0]], 15)
        XCTAssertEqual(tensor[[1, 0, 1, 1]], 16)
        XCTAssertEqual(tensor[[1, 0, 1, 2]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 0]], 17)
        XCTAssertEqual(tensor[[1, 0, 2, 1]], 18)
        XCTAssertEqual(tensor[[1, 0, 2, 2]], 19)
    }
}

0 comments on commit d656ad4

Please sign in to comment.