-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Read weights from safetensors * Check file extension * Deintegrate Safetensor * Separate Safetensor from the weights * Rename test tensors to include type * Rename ModelWeights to Weights * Throw error for unsupported data types * Remove model weights from LanguageModel.Configurations * Move Weights to TensorUtils * Specify filenames to download in tests. * Make the weights optional and public Enable safe access to keys.
- Loading branch information
Showing
8 changed files
with
195 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import CoreML | ||
|
||
|
||
public struct Weights { | ||
|
||
enum WeightsError: Error { | ||
case notSupported(message: String) | ||
case invalidFile | ||
} | ||
|
||
private let dictionary: [String: MLMultiArray] | ||
|
||
init(_ dictionary: [String: MLMultiArray]) { | ||
self.dictionary = dictionary | ||
} | ||
|
||
subscript(key: String) -> MLMultiArray? { dictionary[key] } | ||
|
||
public static func from(fileURL: URL) throws -> Weights { | ||
guard ["safetensors", "gguf", "mlx"].contains(fileURL.pathExtension) | ||
else { throw WeightsError.notSupported(message: "\(fileURL.pathExtension)") } | ||
|
||
let data = try Data(contentsOf: fileURL, options: .mappedIfSafe) | ||
switch ([UInt8](data.subdata(in: 0..<4)), [UInt8](data.subdata(in: 4..<6))) { | ||
case ([0x47, 0x47, 0x55, 0x46], _): throw WeightsError.notSupported(message: ("gguf")) | ||
case ([0x93, 0x4e, 0x55, 0x4d], [0x50, 0x59]): throw WeightsError.notSupported(message: "mlx") | ||
default: return try Safetensor.from(data: data) | ||
} | ||
} | ||
} | ||
|
||
struct Safetensor { | ||
|
||
typealias Error = Weights.WeightsError | ||
|
||
struct Header { | ||
|
||
struct Offset: Decodable { | ||
let dataOffsets: [Int]? | ||
let dtype: String? | ||
let shape: [Int]? | ||
|
||
/// Unsupported: "I8", "U8", "I16", "U16", "BF16" | ||
var dataType: MLMultiArrayDataType? { | ||
get throws { | ||
switch dtype { | ||
case "I32", "U32": .int32 | ||
case "F16": .float16 | ||
case "F32": .float32 | ||
case "F64", "U64": .float64 | ||
default: throw Error.notSupported(message: "\(dtype ?? "empty")") | ||
} | ||
} | ||
} | ||
} | ||
|
||
static func from(data: Data) throws -> [String: Offset?] { | ||
let decoder = JSONDecoder() | ||
decoder.keyDecodingStrategy = .convertFromSnakeCase | ||
return try decoder.decode([String: Offset?].self, from: data) | ||
} | ||
} | ||
|
||
static func from(data: Data) throws -> Weights { | ||
let headerSize: Int = data.subdata(in: 0..<8).withUnsafeBytes({ $0.load(as: Int.self) }) | ||
guard headerSize < data.count else { throw Error.invalidFile } | ||
let header = try Header.from(data: data.subdata(in: 8..<(headerSize + 8))) | ||
|
||
var dict = [String: MLMultiArray]() | ||
for (key, point) in header { | ||
guard let offsets = point?.dataOffsets, offsets.count == 2, | ||
let shape = point?.shape as? [NSNumber], | ||
let dType = try point?.dataType | ||
else { continue } | ||
|
||
let strides = shape.dropFirst().reversed().reduce(into: [1]) { acc, a in | ||
acc.insert(acc[0].intValue * a.intValue as NSNumber, at: 0) | ||
} | ||
let start = 8 + offsets[0] + headerSize | ||
let end = 8 + offsets[1] + headerSize | ||
let tensorData = data.subdata(in: start..<end) as NSData | ||
let ptr = UnsafeMutableRawPointer(mutating: tensorData.bytes) | ||
dict[key] = try MLMultiArray(dataPointer: ptr, shape: shape, dataType: dType, strides: strides) | ||
} | ||
|
||
return Weights(dict) | ||
} | ||
} |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
@testable import TensorUtils | ||
@testable import Hub | ||
import XCTest | ||
|
||
class WeightsTests: XCTestCase { | ||
|
||
let downloadDestination: URL = { | ||
FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first!.appending(component: "huggingface-tests") | ||
}() | ||
|
||
var hubApi: HubApi { HubApi(downloadBase: downloadDestination) } | ||
|
||
func testLoadWeightsFromFileURL() async throws { | ||
let repo = "google/bert_uncased_L-2_H-128_A-2" | ||
let modelDir = try await hubApi.snapshot(from: repo, matching: ["config.json", "model.safetensors"]) | ||
|
||
let files = try FileManager.default.contentsOfDirectory(at: modelDir, includingPropertiesForKeys: [.isReadableKey]) | ||
XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "config.json" })) | ||
XCTAssertTrue(files.contains(where: { $0.lastPathComponent == "model.safetensors" })) | ||
|
||
let modelFile = modelDir.appending(path: "/model.safetensors") | ||
let weights = try Weights.from(fileURL: modelFile) | ||
XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.dataType, .float32) | ||
XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.count, 128) | ||
XCTAssertEqual(weights["bert.embeddings.LayerNorm.bias"]!.shape.count, 1) | ||
|
||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.dataType, .float32) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.count, 3906816) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]!.shape.count, 2) | ||
|
||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[0, 0]].floatValue, -0.0041, accuracy: 1e-3) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[3, 4]].floatValue, 0.0037, accuracy: 1e-3) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[5, 3]].floatValue, -0.5371, accuracy: 1e-3) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[7, 8]].floatValue, 0.0460, accuracy: 1e-3) | ||
XCTAssertEqual(weights["bert.embeddings.word_embeddings.weight"]![[11, 7]].floatValue, -0.0058, accuracy: 1e-3) | ||
} | ||
|
||
func testSafetensorReadTensor1D() throws { | ||
let modelFile = Bundle.module.url(forResource: "tensor-1d-int32", withExtension: "safetensors")! | ||
let weights: Weights = try Weights.from(fileURL: modelFile) | ||
let tensor = weights["embedding"]! | ||
XCTAssertEqual(tensor.dataType, .int32) | ||
XCTAssertEqual(tensor[[0]], 1) | ||
XCTAssertEqual(tensor[[1]], 2) | ||
XCTAssertEqual(tensor[[2]], 3) | ||
} | ||
|
||
func testSafetensorReadTensor2D() throws { | ||
let modelFile = Bundle.module.url(forResource: "tensor-2d-float64", withExtension: "safetensors")! | ||
let weights: Weights = try Weights.from(fileURL: modelFile) | ||
let tensor = weights["embedding"]! | ||
XCTAssertEqual(tensor.dataType, .float64) | ||
XCTAssertEqual(tensor[[0, 0]], 1) | ||
XCTAssertEqual(tensor[[0, 1]], 2) | ||
XCTAssertEqual(tensor[[0, 2]], 3) | ||
XCTAssertEqual(tensor[[1, 0]], 24) | ||
XCTAssertEqual(tensor[[1, 1]], 25) | ||
XCTAssertEqual(tensor[[1, 2]], 26) | ||
} | ||
|
||
func testSafetensorReadTensor3D() throws { | ||
let modelFile = Bundle.module.url(forResource: "tensor-3d-float32", withExtension: "safetensors")! | ||
let weights: Weights = try Weights.from(fileURL: modelFile) | ||
let tensor = weights["embedding"]! | ||
XCTAssertEqual(tensor.dataType, .float32) | ||
XCTAssertEqual(tensor[[0, 0, 0]], 22) | ||
XCTAssertEqual(tensor[[0, 0, 1]], 23) | ||
XCTAssertEqual(tensor[[0, 0, 2]], 24) | ||
XCTAssertEqual(tensor[[0, 1, 0]], 11) | ||
XCTAssertEqual(tensor[[0, 1, 1]], 12) | ||
XCTAssertEqual(tensor[[0, 1, 2]], 13) | ||
XCTAssertEqual(tensor[[1, 0, 0]], 2) | ||
XCTAssertEqual(tensor[[1, 0, 1]], 3) | ||
XCTAssertEqual(tensor[[1, 0, 2]], 4) | ||
XCTAssertEqual(tensor[[1, 1, 0]], 1) | ||
XCTAssertEqual(tensor[[1, 1, 1]], 2) | ||
XCTAssertEqual(tensor[[1, 1, 2]], 3) | ||
} | ||
|
||
func testSafetensorReadTensor4D() throws { | ||
let modelFile = Bundle.module.url(forResource: "tensor-4d-float32", withExtension: "safetensors")! | ||
let weights: Weights = try Weights.from(fileURL: modelFile) | ||
let tensor = weights["embedding"]! | ||
XCTAssertEqual(tensor.dataType, .float32) | ||
XCTAssertEqual(tensor[[0, 0, 0, 0]], 11) | ||
XCTAssertEqual(tensor[[0, 0, 0, 1]], 12) | ||
XCTAssertEqual(tensor[[0, 0, 0, 2]], 13) | ||
XCTAssertEqual(tensor[[0, 0, 1, 0]], 1) | ||
XCTAssertEqual(tensor[[0, 0, 1, 1]], 2) | ||
XCTAssertEqual(tensor[[0, 0, 1, 2]], 3) | ||
XCTAssertEqual(tensor[[0, 0, 2, 0]], 4) | ||
XCTAssertEqual(tensor[[0, 0, 2, 1]], 5) | ||
XCTAssertEqual(tensor[[0, 0, 2, 2]], 6) | ||
XCTAssertEqual(tensor[[1, 0, 0, 0]], 22) | ||
XCTAssertEqual(tensor[[1, 0, 0, 1]], 23) | ||
XCTAssertEqual(tensor[[1, 0, 0, 2]], 24) | ||
XCTAssertEqual(tensor[[1, 0, 1, 0]], 15) | ||
XCTAssertEqual(tensor[[1, 0, 1, 1]], 16) | ||
XCTAssertEqual(tensor[[1, 0, 1, 2]], 17) | ||
XCTAssertEqual(tensor[[1, 0, 2, 0]], 17) | ||
XCTAssertEqual(tensor[[1, 0, 2, 1]], 18) | ||
XCTAssertEqual(tensor[[1, 0, 2, 2]], 19) | ||
} | ||
} |