Skip to content

Commit

Permalink
Add get_hf_file_metadata Functionality (#142)
Browse files Browse the repository at this point in the history
* add getHfFileMetadata function to HubApi

* only allow huggingface endpoints in getHfFileMetadata

* add test case for getHfFileMetadata

* remove hardcoded string from location check in test case

* rename getHfFileMetadata to getFileMetadata and refactor

* add blob search for file metadata

* Update Tests/HubTests/HubApiTests.swift

Co-authored-by: Pedro Cuenca <[email protected]>

---------

Co-authored-by: Pedro Cuenca <[email protected]>
  • Loading branch information
ardaatahan and pcuenca authored Dec 12, 2024
1 parent 2f611bf commit 37e234e
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
98 changes: 98 additions & 0 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,25 @@ public extension HubApi {
return (data, response)
}

func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
request.httpMethod = "HEAD"
if let hfToken = hfToken {
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
}
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
let (data, response) = try await URLSession.shared.data(for: request)
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }

switch response.statusCode {
case 200..<300: break
case 400..<500: throw Hub.HubClientError.authorizationRequired
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
}

return (data, response)
}

func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] {
// Read repo info and only parse "siblings"
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")!
Expand Down Expand Up @@ -222,6 +241,65 @@ public extension HubApi {
}
}

/// Metadata
public extension HubApi {
/// A structure representing metadata for a remote file
struct FileMetadata {
/// The file's Git commit hash
public let commitHash: String?

/// Server-provided ETag for caching
public let etag: String?

/// Stringified URL location of the file
public let location: String

/// The file's size in bytes
public let size: Int?
}

private func normalizeEtag(_ etag: String?) -> String? {
guard let etag = etag else { return nil }
return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\""))
}

func getFileMetadata(url: URL) async throws -> FileMetadata {
let (_, response) = try await httpHead(for: url)

return FileMetadata(
commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"),
etag: normalizeEtag(
(response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag"))
),
location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString,
size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "")
)
}

func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] {
let files = try await getFilenames(from: repo, matching: globs)
let url = URL(string: "\(endpoint)/\(repo.id)/resolve/main")! // TODO: revisions
var selectedMetadata: Array<FileMetadata> = []
for file in files {
let fileURL = url.appending(path: file)
selectedMetadata.append(try await getFileMetadata(url: fileURL))
}
return selectedMetadata
}

func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] {
return try await getFileMetadata(from: Repo(id: repoId), matching: globs)
}

func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] {
return try await getFileMetadata(from: repo, matching: [glob])
}

func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] {
return try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
}
}

/// Stateless wrappers that use `HubApi` instances
public extension Hub {
static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
Expand Down Expand Up @@ -259,6 +337,26 @@ public extension Hub {
static func whoami(token: String) async throws -> Config {
return try await HubApi(hfToken: token).whoami()
}

static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
return try await HubApi.shared.getFileMetadata(url: fileURL)
}

static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
}

static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: globs)
}

static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
return try await HubApi.shared.getFileMetadata(from: repo, matching: [glob])
}

static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
return try await HubApi.shared.getFileMetadata(from: Repo(id: repoId), matching: [glob])
}
}

public extension [String] {
Expand Down
57 changes: 57 additions & 0 deletions Tests/HubTests/HubApiTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,63 @@ class HubApiTests: XCTestCase {
XCTFail("\(error)")
}
}

func testGetFileMetadata() async throws {
do {
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")
let metadata = try await Hub.getFileMetadata(fileURL: url!)

XCTAssertNotNil(metadata.commitHash)
XCTAssertNotNil(metadata.etag)
XCTAssertEqual(metadata.location, url?.absoluteString)
XCTAssertEqual(metadata.size, 163)
} catch {
XCTFail("\(error)")
}
}

func testGetFileMetadataBlobPath() async throws {
do {
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/blob/main/config.json")
let metadata = try await Hub.getFileMetadata(fileURL: url!)

XCTAssertEqual(metadata.commitHash, nil)
XCTAssertTrue(metadata.etag != nil && metadata.etag!.hasPrefix("10841-"))
XCTAssertEqual(metadata.location, url?.absoluteString)
XCTAssertEqual(metadata.size, 67649)
} catch {
XCTFail("\(error)")
}
}

func testGetFileMetadataWithRevision() async throws {
do {
let revision = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
let url = URL(string: "https://huggingface.co/julien-c/dummy-unknown/resolve/\(revision)/config.json")
let metadata = try await Hub.getFileMetadata(fileURL: url!)

XCTAssertEqual(metadata.commitHash, revision)
XCTAssertNotNil(metadata.etag)
XCTAssertGreaterThan(metadata.etag!.count, 0)
XCTAssertEqual(metadata.location, url?.absoluteString)
XCTAssertEqual(metadata.size, 851)
} catch {
XCTFail("\(error)")
}
}

func testGetFileMetadataWithBlobSearch() async throws {
let repo = "coreml-projects/Llama-2-7b-chat-coreml"
let metadataFromBlob = try await Hub.getFileMetadata(from: repo, matching: "*.json").sorted { $0.location < $1.location }
let files = try await Hub.getFilenames(from: repo, matching: "*.json").sorted()
for (metadata, file) in zip(metadataFromBlob, files) {
XCTAssertNotNil(metadata.commitHash)
XCTAssertNotNil(metadata.etag)
XCTAssertGreaterThan(metadata.etag!.count, 0)
XCTAssertTrue(metadata.location.contains(file))
XCTAssertGreaterThan(metadata.size!, 0)
}
}
}

class SnapshotDownloadTests: XCTestCase {
Expand Down

0 comments on commit 37e234e

Please sign in to comment.