From 00bd84c59995c7b3daa0b4fa1597f77608806fdb Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:02:22 -0700 Subject: [PATCH 1/4] Add: Initial Swift support --- .gitignore | 4 +- .vscode/settings.json | 10 +- Package.resolved | 22 ++ Package.swift | 38 ++ python/scripts/export.ipynb | 666 +++++++++++++++++++++++++++++++++++ python/uform/torch_models.py | 36 +- swift/Embeddings.swift | 290 +++++++++++++++ swift/EmbeddingsTests.swift | 95 +++++ 8 files changed, 1152 insertions(+), 9 deletions(-) create mode 100644 Package.resolved create mode 100644 Package.swift create mode 100644 python/scripts/export.ipynb create mode 100644 swift/Embeddings.swift create mode 100644 swift/EmbeddingsTests.swift diff --git a/.gitignore b/.gitignore index 1732614..af7d4af 100755 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ build/ package-lock.json *.egg-info *.onnx -__pycache__ \ No newline at end of file +__pycache__ +.build +.swiftpm \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 9b08d04..0ac7435 100755 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,16 +1,22 @@ { "cSpell.words": [ + "coreml", "dtype", "embs", "huggingface", "keepdim", + "linalg", "logits", "Matryoshka", + "mlmodel", + "mlpackage", + "mlprogram", "multimodal", "ndarray", "numpy", "ONNX", "onnxruntime", + "packbits", "preprocess", "pretrained", "probs", @@ -18,11 +24,13 @@ "rerank", "reranker", "reranking", + "SIMD", "softmax", "transfromers", "uform", "unimodal", - "unsqueeze" + "unsqueeze", + "Vardanian" ], "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..de00bbf --- /dev/null +++ b/Package.resolved @@ -0,0 +1,22 @@ +{ + "pins" : [ + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ashvardanian/swift-transformers", + "state" : { + "revision" : "4060e8ff7c959b89afa7f672cb0a479e87add284" + } + } + ], + "version" : 2 +} diff --git a/Package.swift b/Package.swift new file mode 100644 index 0000000..39fc6ed --- /dev/null +++ b/Package.swift @@ -0,0 +1,38 @@ +// swift-tools-version:5.9 +import PackageDescription + +let package = Package( + name: "UForm", + platforms: [ + // Linux doesn't have to be explicitly listed + .iOS(.v16), // For iOS, version 13 and later + .tvOS(.v16), // For tvOS, version 13 and later + .macOS(.v13), // For macOS, version 10.15 (Catalina) and later + .watchOS(.v6) // For watchOS, version 6 and later + ], + products: [ + .library( + name: "UForm", + targets: ["UForm"] + ) + ], + dependencies: [ + .package(url: "https://github.com/ashvardanian/swift-transformers", revision: "4060e8ff7c959b89afa7f672cb0a479e87add284") + ], + targets: [ + .target( + name: "UForm", + dependencies: [ + .product(name: "Transformers", package: "swift-transformers") + ], + path: "swift", + exclude: ["EmbeddingsTests.swift"] + ), + .testTarget( + name: "UFormTests", + dependencies: ["UForm"], + path: "swift", + sources: ["EmbeddingsTests.swift"] + ) + ] +) diff --git a/python/scripts/export.ipynb b/python/scripts/export.ipynb new file mode 100644 index 0000000..ce8cf10 --- /dev/null +++ 
b/python/scripts/export.ipynb @@ -0,0 +1,666 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Scripts for Exporting PyTorch Models to ONNX and CoreML" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --upgrade \"uform[torch]\" coremltools" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/av/miniconda3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: dlopen(/Users/av/miniconda3/lib/python3.10/site-packages/torchvision/image.so, 0x0006): Symbol not found: __ZN3c106detail19maybe_wrap_dim_slowExxb\n", + " Referenced from: <0B637046-A38B-3A5C-80C6-E847C27DCCD5> /Users/av/miniconda3/lib/python3.10/site-packages/torchvision/image.so\n", + " Expected in: <3AE92490-D363-3FD7-8532-CB6F5F795BC8> /Users/av/miniconda3/lib/python3.10/site-packages/torch/lib/libc10.dylib\n", + " warn(f\"Failed to load image Python extension: {e}\")\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fadffc0299c04e249fd4f7a5b40ba0af", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 5 files: 0%| | 0/5 [00:00 MIL Ops: 100%|█████████▉| 453/455 [00:00<00:00, 5638.83 ops/s]\n", + "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 381.07 passes/s]\n", + "Running MIL default pipeline: 100%|██████████| 69/69 [00:00<00:00, 156.08 passes/s]\n", + "Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 699.38 passes/s]\n" + ] + } + ], + "source": [ + "coreml_model = ct.convert(\n", + " traced_script_module, source=\"pytorch\",\n", + " inputs=[image_input], outputs=[image_features, image_embeddings],\n", + " convert_to='mlprogram', compute_precision=ct.precision.FLOAT32)\n", + "\n", + "coreml_model.author = 'Unum Cloud'\n", + "coreml_model.license = 'Apache 2.0'\n", + "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", + "coreml_model.save(\"../uform-vl-english-small-image.mlpackage\")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TextEncoder(\n", + " original_name=TextEncoder\n", + " (word_embeddings): Embedding(original_name=Embedding)\n", + " (position_embeddings): Embedding(original_name=Embedding)\n", + " (layer_norm): LayerNorm(original_name=LayerNorm)\n", + " (dropout): Dropout(original_name=Dropout)\n", + " (blocks): ModuleList(\n", + " original_name=ModuleList\n", + " (0): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (1): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " 
(query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (2): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_crossattn): LayerNorm(original_name=LayerNorm)\n", + " (crossattn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " (3): TextEncoderBlock(\n", + " original_name=TextEncoderBlock\n", + " (norm_attn): LayerNorm(original_name=LayerNorm)\n", + " (attention): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_crossattn): LayerNorm(original_name=LayerNorm)\n", + " (crossattn): Attention(\n", + " original_name=Attention\n", + " (query): Linear(original_name=Linear)\n", + " (key): Linear(original_name=Linear)\n", + " (value): Linear(original_name=Linear)\n", + " (out): Linear(original_name=Linear)\n", + " )\n", + " (norm_mlp): LayerNorm(original_name=LayerNorm)\n", + " (mlp): MLP(\n", + " original_name=MLP\n", + " (hidden_layer): Linear(original_name=Linear)\n", + " (output_layer): Linear(original_name=Linear)\n", + " )\n", + " (dropout): Dropout(original_name=Dropout)\n", + " )\n", + " )\n", + " (embedding_projection): Linear(original_name=Linear)\n", + " (matching_head): Linear(original_name=Linear)\n", + " (context_projection): Linear(original_name=Linear)\n", + ")" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "module = model.text_encoder\n", + "module.eval()\n", + "module.return_features = True\n", + "\n", + "traced_script_module = torch.jit.trace(module, example_inputs=[text_data['input_ids'], text_data['attention_mask']])\n", + "traced_script_module" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tuple detected at graph output. 
This will be flattened in the converted model.\n", + "Converting PyTorch Frontend ==> MIL Ops: 0%| | 0/157 [00:00 MIL Ops: 99%|█████████▊| 155/157 [00:00<00:00, 6809.29 ops/s]\n", + "Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 1947.76 passes/s]\n", + "Running MIL default pipeline: 100%|██████████| 69/69 [00:00<00:00, 816.08 passes/s]\n", + "Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 3294.17 passes/s]\n" + ] + } + ], + "source": [ + "coreml_model = ct.convert(\n", + " traced_script_module, source=\"pytorch\",\n", + " inputs=[text_input, text_attention_input], outputs=[text_features, text_embeddings],\n", + " convert_to='mlprogram', compute_precision=ct.precision.FLOAT32)\n", + "\n", + "coreml_model.author = 'Unum Cloud'\n", + "coreml_model.license = 'Apache 2.0'\n", + "coreml_model.short_description = 'Pocket-Sized Multimodal AI for Content Understanding'\n", + "coreml_model.save(\"../uform-vl-english-small-text.mlpackage\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/uform/torch_models.py b/python/uform/torch_models.py index d0d810b..ab86622 100644 --- a/python/uform/torch_models.py +++ b/python/uform/torch_models.py @@ -207,6 +207,7 @@ def __post_init__(self): self.context_projection = nn.Linear(self.context_dim, self.dim, bias=False) else: self.context_projection = nn.Identity() + self.return_features = False def forward_features(self, x: Tensor, attn_mask: Tensor) -> Tensor: x = self.embed_text(x) @@ -267,10 +268,27 @@ def embed_text(self, x: Tensor) -> Tensor: x = self.word_embeddings(x) + positional_embedding return self.dropout(self.layer_norm(x)) - def forward(self, x: dict) -> Tensor: - features = self.forward_features(x["input_ids"], x["attention_mask"]) - embeddings = self.forward_embedding(features, x["attention_mask"]) - return features, embeddings + def forward( + self, + x: Union[Tensor, dict], + attention_mask: Optional[Tensor] = None, + return_features: Optional[bool] = None, + ) -> Tensor: + if isinstance(x, dict): + assert attention_mask is None, "If `x` is a dictionary, then `attention_mask` should be None" + attention_mask = x["attention_mask"] + x = x["input_ids"] + elif attention_mask is None: + # If no attention mask is provided - create one with all ones + attention_mask = torch.ones_like(x) + + features = self.forward_features(x, attention_mask) + embeddings = self.forward_embedding(features, attention_mask) + + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings @dataclass(eq=False) @@ -301,6 +319,7 @@ def __post_init__(self): self.norm = nn.LayerNorm(self.dim, eps=1e-6) self.embedding_projection = nn.Linear(self.dim, self.embedding_dim, bias=False) + self.return_features = False def forward_features(self, x: Tensor) -> Tensor: x = self.patch_embed(x).flatten(start_dim=2).transpose(2, 1) @@ -325,10 +344,13 @@ def forward_embedding(self, x: Tensor) -> Tensor: return 
self.embedding_projection(x) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, return_features: Optional[bool] = None) -> Tensor: features = self.forward_features(x) embeddings = self.forward_embedding(features) - return features, embeddings + return_features = return_features if return_features is not None else self.return_features + if return_features: + return features, embeddings + return embeddings class VLM(nn.Module): @@ -430,7 +452,7 @@ def encode_multimodal( attention_mask if attention_mask is not None else text["attention_mask"], image_features, ) - + if return_scores: return self.get_matching_scores(embeddings), embeddings diff --git a/swift/Embeddings.swift b/swift/Embeddings.swift new file mode 100644 index 0000000..5da7258 --- /dev/null +++ b/swift/Embeddings.swift @@ -0,0 +1,290 @@ +// +// Embeddings.swift +// +// +// Created by Ash Vardanian on 3/27/24. +// +import Foundation +import CoreGraphics +import Accelerate +import CoreML + +import Hub // `Config` +import Tokenizers // `AutoTokenizer` + +// MARK: - Helpers + +func readConfig(fromPath path: String) throws -> [String: Any] { + let data = try Data(contentsOf: URL(fileURLWithPath: path)) + return try JSONSerialization.jsonObject(with: data, options: []) as! [String: Any] +} + +func readModel(fromPath path: String) throws -> MLModel { + // If compilation succeeds, you can then load the compiled model + let modelURL = URL(fileURLWithPath: path, isDirectory: true) + let compiledModelURL = try MLModel.compileModel(at: modelURL) + return try MLModel(contentsOf: compiledModelURL) +} + +// MARK: - Encoders + +public class TextEncoder { + let model: MLModel + let processor: TextProcessor + + public init(modelPath: String, configPath: String, tokenizerPath: String) throws { + self.model = try readModel(fromPath: modelPath) + self.processor = try TextProcessor(configPath: configPath, tokenizerPath: tokenizerPath, model: self.model) + } + + public func forward(with text: String) throws -> [Float32] { + let inputFeatureProvider = try self.processor.preprocess(text) + let prediction = try self.model.prediction(from: inputFeatureProvider) + let predictionFeature = prediction.featureValue(for: "embeddings") + // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 + let output = predictionFeature!.multiArrayValue! + return Array(UnsafeBufferPointer(start: output.dataPointer.assumingMemoryBound(to: Float32.self), count: Int(truncating: output.shape[1]))) + } +} + + +public class ImageEncoder { + let model: MLModel + let processor: ImageProcessor + + public init(modelPath: String, configPath: String) throws { + self.model = try readModel(fromPath: modelPath) + self.processor = try ImageProcessor(configPath: configPath) + } + + public func forward(with image: CGImage) throws -> [Float32] { + let inputFeatureProvider = try self.processor.preprocess(image) + let prediction = try self.model.prediction(from: inputFeatureProvider) + let predictionFeature = prediction.featureValue(for: "embeddings") + // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 + let output = predictionFeature!.multiArrayValue! 
+ return Array(UnsafeBufferPointer(start: output.dataPointer.assumingMemoryBound(to: Float32.self), count: Int(truncating: output.shape[1]))) + } + +} + +// MARK: - Processors + +class TextProcessor { + let tokenizer: Tokenizer + let minContextLength: Int + let maxContextLength: Int + + public init(configPath: String, tokenizerPath: String, model: MLModel) throws { + let configDict = try readConfig(fromPath: configPath) + let tokenizerDict = try readConfig(fromPath: tokenizerPath) + + let config = Config(configDict) + let tokenizerData = Config(tokenizerDict) + self.tokenizer = try AutoTokenizer.from(tokenizerConfig: config, tokenizerData: tokenizerData) + + let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"] + guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else { + fatalError("Cannot obtain shape information") + } + + switch shapeConstraint.type { + case .enumerated: + minContextLength = shapeConstraint.enumeratedShapes[0][1].intValue + maxContextLength = minContextLength + case .range: + let range = inputDescription?.multiArrayConstraint?.shapeConstraint.sizeRangeForDimension[1] as? NSRange + minContextLength = range?.location ?? 1 + maxContextLength = range?.length ?? 128 + case .unspecified: + minContextLength = 128 + maxContextLength = 128 + @unknown default: + minContextLength = 128 + maxContextLength = 128 + } + } + + public func preprocess(_ text: String) throws -> MLFeatureProvider { + let inputIDs = self.tokenizer.encode(text: text) + return TextInput(inputIDs: inputIDs, sequenceLength: self.maxContextLength) + } +} + +class ImageProcessor { + let imageSize: Int + let mean: [Float] = [0.485, 0.456, 0.406] // Common mean values for normalization + let std: [Float] = [0.229, 0.224, 0.225] // Common std values for normalization + + init(configPath: String) throws { + let configDict = try readConfig(fromPath: configPath) + let config = Config(configDict) + self.imageSize = config.imageSize!.intValue! + } + + func preprocess(_ cgImage: CGImage) throws -> MLFeatureProvider { + // Populate a tensor of size 3 x `imageSize` x `imageSize`, + // by resizing the image, then performing a center crop. + // Then normalize with the `mean` and `std` and export as a provider. + let cropped = resizeAndCrop(image: cgImage, toSideLength: self.imageSize)! + let normalized = exportToTensorAndNormalize(image: cropped, mean: self.mean, std: self.std)! + let featureValue = MLFeatureValue(multiArray: normalized) + return try ImageInput(precomputedFeature: featureValue) + } + + private func resizeAndCrop(image: CGImage, toSideLength imageSize: Int) -> CGImage? { + let originalWidth = CGFloat(image.width) + let originalHeight = CGFloat(image.height) + + let widthRatio = CGFloat(imageSize) / originalWidth + let heightRatio = CGFloat(imageSize) / originalHeight + let scaleFactor = max(widthRatio, heightRatio) + + let scaledWidth = originalWidth * scaleFactor + let scaledHeight = originalHeight * scaleFactor + + let dx = (scaledWidth - CGFloat(imageSize)) / 2.0 + let dy = (scaledHeight - CGFloat(imageSize)) / 2.0 + let insetRect = CGRect(x: dx, y: dy, width: CGFloat(imageSize) - dx*2, height: CGFloat(imageSize) - dy*2) + + // Create a new context (off-screen canvas) with the desired dimensions + guard let context = CGContext( + data: nil, + width: imageSize, height: imageSize, bitsPerComponent: image.bitsPerComponent, bytesPerRow: 0, + space: image.colorSpace ?? 
CGColorSpaceCreateDeviceRGB(), bitmapInfo: image.bitmapInfo.rawValue) else { return nil } + + // Draw the image in the context with the specified inset (cropping as necessary) + context.interpolationQuality = .high + context.draw(image, in: insetRect, byTiling: false) + + // Extract the new image from the context + return context.makeImage() + } + + private func exportToTensorAndNormalize(image: CGImage, mean: [Float], std: [Float]) -> MLMultiArray? { + let width = image.width + let height = image.height + + // Prepare the bitmap context for drawing the image. + var pixelData = [UInt8](repeating: 0, count: width * height * 4) + let colorSpace = CGColorSpaceCreateDeviceRGB() + let context = CGContext(data: &pixelData, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 4 * width, space: colorSpace, bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) + context?.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height)) + + // Convert pixel data to float and normalize + let totalCount = width * height * 4 + var floatPixels = [Float](repeating: 0, count: totalCount) + vDSP_vfltu8(pixelData, 1, &floatPixels, 1, vDSP_Length(totalCount)) + + // Scale the pixel values to [0, 1] + var divisor = Float(255.0) + vDSP_vsdiv(floatPixels, 1, &divisor, &floatPixels, 1, vDSP_Length(totalCount)) + + // Normalize the pixel values + for c in 0..<3 { + var slice = [Float](repeating: 0, count: width * height) + for i in 0..<(width * height) { + slice[i] = (floatPixels[i * 4 + c] - mean[c]) / std[c] + } + floatPixels.replaceSubrange(c*width*height..<(c+1)*width*height, with: slice) + } + + // Rearrange the array to C x H x W + var tensor = [Float](repeating: 0, count: width * height * 3) + for y in 0.. { + return Set(["input_ids", "attention_mask"]) + } + + // The model expects the input IDs to be an array of integers + // of length `sequenceLength`, padded with `paddingID` if necessary + func featureValue(for featureName: String) -> MLFeatureValue? { + switch featureName { + case "input_ids", "attention_mask": + return createFeatureValue(for: featureName) + default: + return nil + } + } + + private func createFeatureValue(for featureName: String) -> MLFeatureValue? { + let count = min(inputIDs.count, sequenceLength) + let totalElements = sequenceLength + guard let multiArray = try? MLMultiArray(shape: [1, NSNumber(value: totalElements)], dataType: .int32) else { + return nil + } + + if featureName == "input_ids" { + for i in 0.. { + return Set(["input"]) + } + + // The model expects the input IDs to be an array of integers + // of length `sequenceLength`, padded with `paddingID` if necessary + func featureValue(for featureName: String) -> MLFeatureValue? 
{ + switch featureName { + case "input": + return precomputedFeature + default: + return nil + } + } +} + diff --git a/swift/EmbeddingsTests.swift b/swift/EmbeddingsTests.swift new file mode 100644 index 0000000..221842a --- /dev/null +++ b/swift/EmbeddingsTests.swift @@ -0,0 +1,95 @@ +import UForm + +import XCTest +import CoreGraphics +import ImageIO + +final class TokenizerTests: XCTestCase { + + + func cosineSimilarity(between vectorA: [T], and vectorB: [T]) -> T { + guard vectorA.count == vectorB.count else { + fatalError("Vectors must be of the same length.") + } + + let dotProduct = zip(vectorA, vectorB).reduce(T.zero) { $0 + ($1.0 * $1.1) } + let magnitudeA = sqrt(vectorA.reduce(T.zero) { $0 + $1 * $1 }) + let magnitudeB = sqrt(vectorB.reduce(T.zero) { $0 + $1 * $1 }) + + // Avoid division by zero + if magnitudeA == T.zero || magnitudeB == T.zero { + return T.zero + } + + return dotProduct / (magnitudeA * magnitudeB) + } + + + func testTextEmbeddings() async throws { + let model = try TextEncoder( + modelPath: "uform/uform-vl-english-small-text.mlpackage", + configPath: "uform/config.json", + tokenizerPath: "uform/tokenizer.json" + ) + + let texts = [ + "sunny beach with clear blue water", + "crowded sandbeach under the bright sun", + "dense forest with tall green trees", + "quiet park in the morning light" + ] + + var embeddings: [[Float32]] = [] + for text in texts { + let embedding: [Float32] = try model.forward(with: text) + embeddings.append(embedding) + } + + // Now let's compute the cosine similarity between the embeddings + let similarityBeach = cosineSimilarity(between: embeddings[0], and: embeddings[1]) + let similarityForest = cosineSimilarity(between: embeddings[2], and: embeddings[3]) + let dissimilarityBetweenScenes = cosineSimilarity(between: embeddings[0], and: embeddings[2]) + + // Assert that similar texts have higher similarity scores + XCTAssertTrue(similarityBeach > dissimilarityBetweenScenes, "Beach texts should be more similar to each other than to forest texts.") + XCTAssertTrue(similarityForest > dissimilarityBetweenScenes, "Forest texts should be more similar to each other than to beach texts.") + } + + func testImageEmbeddings() async throws { + let model = try ImageEncoder( + modelPath: "uform/uform-vl-english-small-image.mlpackage", + configPath: "uform/config_image.json" + ) + + let imageURLs = [ + "https://github.com/ashvardanian/ashvardanian/blob/master/demos/bbq-on-beach.jpg?raw=true", + "https://github.com/ashvardanian/ashvardanian/blob/master/demos/cat-in-garden.jpg?raw=true", + "https://github.com/ashvardanian/ashvardanian/blob/master/demos/girl-and-rain.jpg?raw=true", + "https://github.com/ashvardanian/ashvardanian/blob/master/demos/light-bedroom-furniture.jpg?raw=true", + "https://github.com/ashvardanian/ashvardanian/blob/master/demos/louvre-at-night.jpg?raw=true", + ] + + var embeddings: [[Float32]] = [] + for imageURL in imageURLs { + guard let url = URL(string: imageURL), + let imageSource = CGImageSourceCreateWithURL(url as CFURL, nil), + let cgImage = CGImageSourceCreateImageAtIndex(imageSource, 0, nil) else { + throw NSError(domain: "ImageError", code: 100, userInfo: [NSLocalizedDescriptionKey: "Could not load image from URL: \(imageURL)"]) + } + + let embedding: [Float32] = try model.forward(with: cgImage) + embeddings.append(embedding) + } + + // Now let's compute the cosine similarity between the embeddings + let similarityGirlAndBeach = cosineSimilarity(between: embeddings[2], and: embeddings[0]) + let similarityGirlAndLouvre = 
cosineSimilarity(between: embeddings[2], and: embeddings[4]) + let similarityBeachAndLouvre = cosineSimilarity(between: embeddings[0], and: embeddings[4]) + + // Assert that similar images have higher similarity scores + XCTAssertTrue(similarityGirlAndBeach > similarityGirlAndLouvre, ""); + XCTAssertTrue(similarityGirlAndBeach > similarityBeachAndLouvre, ""); + } + + +} From f6faf4cd877f6034d8c66edb108bc07ec1735232 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Thu, 11 Apr 2024 16:04:35 -0700 Subject: [PATCH 2/4] Make: Formatting Swift code --- .swift-format | 13 +++ CONTRIBUTING.md | 18 +++- Package.resolved | 2 +- Package.swift | 15 ++-- swift/Embeddings.swift | 168 +++++++++++++++++++++--------------- swift/EmbeddingsTests.swift | 130 ++++++++++++++++++---------- 6 files changed, 222 insertions(+), 124 deletions(-) create mode 100644 .swift-format diff --git a/.swift-format b/.swift-format new file mode 100644 index 0000000..53bf631 --- /dev/null +++ b/.swift-format @@ -0,0 +1,13 @@ +{ + "version": 1, + "lineLength": 120, + "indentation": { + "spaces": 4 + }, + "maximumBlankLines": 1, + "respectsExistingLineBreaks": true, + "lineBreakBeforeControlFlowKeywords": true, + "lineBreakBeforeEachArgument": true, + "multiElementCollectionTrailingCommas": true, + "spacesAroundRangeFormationOperators": true +} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1548c30..181d9e2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,9 @@ # Contributing to UForm We welcome contributions to UForm! + +## Python + Before submitting any changes, please make sure that the tests pass. ```sh @@ -13,4 +16,17 @@ pip install -e ".[torch,onnx]" # For PyTorch and ONNX Python tests pytest python/scripts/ -s -x -Wd -v pytest python/scripts/ -s -x -Wd -v -k onnx # To run only ONNX tests without loading Torch -``` \ No newline at end of file +``` + +## Swift + +Swift formatting is enforced with `swift-format` default utility from Apple. +To install and run it on all the files in the project, use the following command: + +```bash +brew install swift-format +swift-format . -i -r +``` + +The style is controlled by the `.swift-format` JSON file in the root of the repository. +As there is no standard for Swift formatting, even Apple's own `swift-format` tool and Xcode differ in their formatting rules, and available settings. 
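
For context, the encoders added in `swift/Embeddings.swift` can be exercised roughly as follows at this stage of the patch series. This is only a hedged sketch: the model, config, tokenizer, and image paths are hypothetical placeholders, and the cosine-similarity helper simply mirrors the one in `swift/EmbeddingsTests.swift`.

```swift
import CoreGraphics
import Foundation
import ImageIO
import UForm

do {
    // Hypothetical local paths to the exported CoreML packages and configs.
    let textEncoder = try TextEncoder(
        modelPath: "uform/uform-vl-english-small-text.mlpackage",
        configPath: "uform/config.json",
        tokenizerPath: "uform/tokenizer.json"
    )
    let imageEncoder = try ImageEncoder(
        modelPath: "uform/uform-vl-english-small-image.mlpackage",
        configPath: "uform/config_image.json"
    )

    // Embed one caption and one image.
    let textEmbedding: [Float32] = try textEncoder.forward(with: "sunny beach with clear blue water")

    let imageURL = URL(fileURLWithPath: "beach.jpg")  // hypothetical image file
    guard let source = CGImageSourceCreateWithURL(imageURL as CFURL, nil),
        let cgImage = CGImageSourceCreateImageAtIndex(source, 0, nil)
    else { fatalError("Could not load image") }
    let imageEmbedding: [Float32] = try imageEncoder.forward(with: cgImage)

    // Cosine similarity, matching the helper used in `EmbeddingsTests.swift`.
    func cosineSimilarity(_ a: [Float32], _ b: [Float32]) -> Float32 {
        let dot = zip(a, b).reduce(Float32(0)) { $0 + $1.0 * $1.1 }
        let normA = a.reduce(Float32(0)) { $0 + $1 * $1 }.squareRoot()
        let normB = b.reduce(Float32(0)) { $0 + $1 * $1 }.squareRoot()
        return (normA == 0 || normB == 0) ? 0 : dot / (normA * normB)
    }
    print("text-to-image similarity:", cosineSimilarity(textEmbedding, imageEmbedding))
}
catch {
    print("UForm example failed:", error)
}
```
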
diff --git a/Package.resolved b/Package.resolved index de00bbf..fe63c94 100644 --- a/Package.resolved +++ b/Package.resolved @@ -14,7 +14,7 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ashvardanian/swift-transformers", "state" : { - "revision" : "4060e8ff7c959b89afa7f672cb0a479e87add284" + "revision" : "9ef46a51eca46978b62773f8887926dfe72b0ab4" } } ], diff --git a/Package.swift b/Package.swift index 39fc6ed..6ac8372 100644 --- a/Package.swift +++ b/Package.swift @@ -5,10 +5,10 @@ let package = Package( name: "UForm", platforms: [ // Linux doesn't have to be explicitly listed - .iOS(.v16), // For iOS, version 13 and later - .tvOS(.v16), // For tvOS, version 13 and later - .macOS(.v13), // For macOS, version 10.15 (Catalina) and later - .watchOS(.v6) // For watchOS, version 6 and later + .iOS(.v16), // For iOS, version 13 and later + .tvOS(.v16), // For tvOS, version 13 and later + .macOS(.v13), // For macOS, version 10.15 (Catalina) and later + .watchOS(.v6), // For watchOS, version 6 and later ], products: [ .library( @@ -17,7 +17,10 @@ let package = Package( ) ], dependencies: [ - .package(url: "https://github.com/ashvardanian/swift-transformers", revision: "4060e8ff7c959b89afa7f672cb0a479e87add284") + .package( + url: "https://github.com/ashvardanian/swift-transformers", + revision: "9ef46a51eca46978b62773f8887926dfe72b0ab4" + ) ], targets: [ .target( @@ -33,6 +36,6 @@ let package = Package( dependencies: ["UForm"], path: "swift", sources: ["EmbeddingsTests.swift"] - ) + ), ] ) diff --git a/swift/Embeddings.swift b/swift/Embeddings.swift index 5da7258..00d56c4 100644 --- a/swift/Embeddings.swift +++ b/swift/Embeddings.swift @@ -4,24 +4,26 @@ // // Created by Ash Vardanian on 3/27/24. // -import Foundation -import CoreGraphics import Accelerate +import CoreGraphics import CoreML - -import Hub // `Config` -import Tokenizers // `AutoTokenizer` +import Foundation +import Hub // `Config` +import Tokenizers // `AutoTokenizer` // MARK: - Helpers func readConfig(fromPath path: String) throws -> [String: Any] { - let data = try Data(contentsOf: URL(fileURLWithPath: path)) + // If it's not an absolute path, let's assume it's a path relative to the current working directory + let absPath = path.hasPrefix("/") ? path : FileManager.default.currentDirectoryPath + "/" + path + let data = try Data(contentsOf: URL(fileURLWithPath: absPath)) return try JSONSerialization.jsonObject(with: data, options: []) as! [String: Any] } func readModel(fromPath path: String) throws -> MLModel { - // If compilation succeeds, you can then load the compiled model - let modelURL = URL(fileURLWithPath: path, isDirectory: true) + // If it's not an absolute path, let's assume it's a path relative to the current working directory + let absPath = path.hasPrefix("/") ? 
path : FileManager.default.currentDirectoryPath + "/" + path + let modelURL = URL(fileURLWithPath: absPath, isDirectory: true) let compiledModelURL = try MLModel.compileModel(at: modelURL) return try MLModel(contentsOf: compiledModelURL) } @@ -31,41 +33,50 @@ func readModel(fromPath path: String) throws -> MLModel { public class TextEncoder { let model: MLModel let processor: TextProcessor - + public init(modelPath: String, configPath: String, tokenizerPath: String) throws { self.model = try readModel(fromPath: modelPath) self.processor = try TextProcessor(configPath: configPath, tokenizerPath: tokenizerPath, model: self.model) } - + public func forward(with text: String) throws -> [Float32] { let inputFeatureProvider = try self.processor.preprocess(text) let prediction = try self.model.prediction(from: inputFeatureProvider) let predictionFeature = prediction.featureValue(for: "embeddings") // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 let output = predictionFeature!.multiArrayValue! - return Array(UnsafeBufferPointer(start: output.dataPointer.assumingMemoryBound(to: Float32.self), count: Int(truncating: output.shape[1]))) + return Array( + UnsafeBufferPointer( + start: output.dataPointer.assumingMemoryBound(to: Float32.self), + count: Int(truncating: output.shape[1]) + ) + ) } } - public class ImageEncoder { let model: MLModel let processor: ImageProcessor - + public init(modelPath: String, configPath: String) throws { self.model = try readModel(fromPath: modelPath) self.processor = try ImageProcessor(configPath: configPath) } - + public func forward(with image: CGImage) throws -> [Float32] { let inputFeatureProvider = try self.processor.preprocess(image) let prediction = try self.model.prediction(from: inputFeatureProvider) let predictionFeature = prediction.featureValue(for: "embeddings") // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 let output = predictionFeature!.multiArrayValue! 
- return Array(UnsafeBufferPointer(start: output.dataPointer.assumingMemoryBound(to: Float32.self), count: Int(truncating: output.shape[1]))) + return Array( + UnsafeBufferPointer( + start: output.dataPointer.assumingMemoryBound(to: Float32.self), + count: Int(truncating: output.shape[1]) + ) + ) } - + } // MARK: - Processors @@ -74,20 +85,20 @@ class TextProcessor { let tokenizer: Tokenizer let minContextLength: Int let maxContextLength: Int - + public init(configPath: String, tokenizerPath: String, model: MLModel) throws { let configDict = try readConfig(fromPath: configPath) let tokenizerDict = try readConfig(fromPath: tokenizerPath) - + let config = Config(configDict) let tokenizerData = Config(tokenizerDict) self.tokenizer = try AutoTokenizer.from(tokenizerConfig: config, tokenizerData: tokenizerData) - + let inputDescription = model.modelDescription.inputDescriptionsByName["input_ids"] guard let shapeConstraint = inputDescription?.multiArrayConstraint?.shapeConstraint else { fatalError("Cannot obtain shape information") } - + switch shapeConstraint.type { case .enumerated: minContextLength = shapeConstraint.enumeratedShapes[0][1].intValue @@ -104,7 +115,7 @@ class TextProcessor { maxContextLength = 128 } } - + public func preprocess(_ text: String) throws -> MLFeatureProvider { let inputIDs = self.tokenizer.encode(text: text) return TextInput(inputIDs: inputIDs, sequenceLength: self.maxContextLength) @@ -113,15 +124,15 @@ class TextProcessor { class ImageProcessor { let imageSize: Int - let mean: [Float] = [0.485, 0.456, 0.406] // Common mean values for normalization - let std: [Float] = [0.229, 0.224, 0.225] // Common std values for normalization - + let mean: [Float] = [0.485, 0.456, 0.406] // Common mean values for normalization + let std: [Float] = [0.229, 0.224, 0.225] // Common std values for normalization + init(configPath: String) throws { let configDict = try readConfig(fromPath: configPath) let config = Config(configDict) self.imageSize = config.imageSize!.intValue! } - + func preprocess(_ cgImage: CGImage) throws -> MLFeatureProvider { // Populate a tensor of size 3 x `imageSize` x `imageSize`, // by resizing the image, then performing a center crop. @@ -129,79 +140,97 @@ class ImageProcessor { let cropped = resizeAndCrop(image: cgImage, toSideLength: self.imageSize)! let normalized = exportToTensorAndNormalize(image: cropped, mean: self.mean, std: self.std)! let featureValue = MLFeatureValue(multiArray: normalized) - return try ImageInput(precomputedFeature: featureValue) + return try ImageInput(precomputedFeature: featureValue) } - + private func resizeAndCrop(image: CGImage, toSideLength imageSize: Int) -> CGImage? 
{ let originalWidth = CGFloat(image.width) let originalHeight = CGFloat(image.height) - + let widthRatio = CGFloat(imageSize) / originalWidth let heightRatio = CGFloat(imageSize) / originalHeight let scaleFactor = max(widthRatio, heightRatio) - + let scaledWidth = originalWidth * scaleFactor let scaledHeight = originalHeight * scaleFactor - + let dx = (scaledWidth - CGFloat(imageSize)) / 2.0 let dy = (scaledHeight - CGFloat(imageSize)) / 2.0 - let insetRect = CGRect(x: dx, y: dy, width: CGFloat(imageSize) - dx*2, height: CGFloat(imageSize) - dy*2) - + let insetRect = CGRect(x: dx, y: dy, width: CGFloat(imageSize) - dx * 2, height: CGFloat(imageSize) - dy * 2) + // Create a new context (off-screen canvas) with the desired dimensions - guard let context = CGContext( - data: nil, - width: imageSize, height: imageSize, bitsPerComponent: image.bitsPerComponent, bytesPerRow: 0, - space: image.colorSpace ?? CGColorSpaceCreateDeviceRGB(), bitmapInfo: image.bitmapInfo.rawValue) else { return nil } - + guard + let context = CGContext( + data: nil, + width: imageSize, + height: imageSize, + bitsPerComponent: image.bitsPerComponent, + bytesPerRow: 0, + space: image.colorSpace ?? CGColorSpaceCreateDeviceRGB(), + bitmapInfo: image.bitmapInfo.rawValue + ) + else { return nil } + // Draw the image in the context with the specified inset (cropping as necessary) context.interpolationQuality = .high context.draw(image, in: insetRect, byTiling: false) - + // Extract the new image from the context return context.makeImage() } - + private func exportToTensorAndNormalize(image: CGImage, mean: [Float], std: [Float]) -> MLMultiArray? { let width = image.width let height = image.height - + // Prepare the bitmap context for drawing the image. var pixelData = [UInt8](repeating: 0, count: width * height * 4) let colorSpace = CGColorSpaceCreateDeviceRGB() - let context = CGContext(data: &pixelData, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 4 * width, space: colorSpace, bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) + let context = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: 4 * width, + space: colorSpace, + bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue + ) context?.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height)) - + // Convert pixel data to float and normalize let totalCount = width * height * 4 var floatPixels = [Float](repeating: 0, count: totalCount) vDSP_vfltu8(pixelData, 1, &floatPixels, 1, vDSP_Length(totalCount)) - + // Scale the pixel values to [0, 1] var divisor = Float(255.0) vDSP_vsdiv(floatPixels, 1, &divisor, &floatPixels, 1, vDSP_Length(totalCount)) - + // Normalize the pixel values - for c in 0..<3 { + for c in 0 ..< 3 { var slice = [Float](repeating: 0, count: width * height) - for i in 0..<(width * height) { + for i in 0 ..< (width * height) { slice[i] = (floatPixels[i * 4 + c] - mean[c]) / std[c] } - floatPixels.replaceSubrange(c*width*height..<(c+1)*width*height, with: slice) + floatPixels.replaceSubrange(c * width * height ..< (c + 1) * width * height, with: slice) } - + // Rearrange the array to C x H x W var tensor = [Float](repeating: 0, count: width * height * 3) - for y in 0.. { return Set(["input_ids", "attention_mask"]) } - + // The model expects the input IDs to be an array of integers // of length `sequenceLength`, padded with `paddingID` if necessary func featureValue(for featureName: String) -> MLFeatureValue? 
{ @@ -236,46 +265,46 @@ class TextInput: MLFeatureProvider { return nil } } - + private func createFeatureValue(for featureName: String) -> MLFeatureValue? { let count = min(inputIDs.count, sequenceLength) let totalElements = sequenceLength guard let multiArray = try? MLMultiArray(shape: [1, NSNumber(value: totalElements)], dataType: .int32) else { return nil } - + if featureName == "input_ids" { - for i in 0.. { return Set(["input"]) } - + // The model expects the input IDs to be an array of integers // of length `sequenceLength`, padded with `paddingID` if necessary func featureValue(for featureName: String) -> MLFeatureValue? { @@ -287,4 +316,3 @@ class ImageInput: MLFeatureProvider { } } } - diff --git a/swift/EmbeddingsTests.swift b/swift/EmbeddingsTests.swift index 221842a..671a547 100644 --- a/swift/EmbeddingsTests.swift +++ b/swift/EmbeddingsTests.swift @@ -1,12 +1,10 @@ -import UForm - -import XCTest import CoreGraphics import ImageIO +import UForm +import XCTest final class TokenizerTests: XCTestCase { - - + func cosineSimilarity(between vectorA: [T], and vectorB: [T]) -> T { guard vectorA.count == vectorB.count else { fatalError("Vectors must be of the same length.") @@ -24,43 +22,64 @@ final class TokenizerTests: XCTestCase { return dotProduct / (magnitudeA * magnitudeB) } - func testTextEmbeddings() async throws { - let model = try TextEncoder( - modelPath: "uform/uform-vl-english-small-text.mlpackage", - configPath: "uform/config.json", - tokenizerPath: "uform/tokenizer.json" + + let root = "/uform/" + let textModel = try TextEncoder( + modelPath: root + "uform-vl-english-large-text.mlpackage", + configPath: root + "uform-vl-english-large-text.json", + tokenizerPath: root + "uform-vl-english-large-text.tokenizer.json" ) - + let texts = [ "sunny beach with clear blue water", "crowded sandbeach under the bright sun", "dense forest with tall green trees", - "quiet park in the morning light" + "quiet park in the morning light", ] - - var embeddings: [[Float32]] = [] + + var textEmbeddings: [[Float32]] = [] for text in texts { - let embedding: [Float32] = try model.forward(with: text) - embeddings.append(embedding) + let embedding: [Float32] = try textModel.forward(with: text) + textEmbeddings.append(embedding) } - - // Now let's compute the cosine similarity between the embeddings - let similarityBeach = cosineSimilarity(between: embeddings[0], and: embeddings[1]) - let similarityForest = cosineSimilarity(between: embeddings[2], and: embeddings[3]) - let dissimilarityBetweenScenes = cosineSimilarity(between: embeddings[0], and: embeddings[2]) - + + // Now let's compute the cosine similarity between the textEmbeddings + let similarityBeach = cosineSimilarity(between: textEmbeddings[0], and: textEmbeddings[1]) + let similarityForest = cosineSimilarity(between: textEmbeddings[2], and: textEmbeddings[3]) + let dissimilarityBetweenScenes = cosineSimilarity(between: textEmbeddings[0], and: textEmbeddings[2]) + // Assert that similar texts have higher similarity scores - XCTAssertTrue(similarityBeach > dissimilarityBetweenScenes, "Beach texts should be more similar to each other than to forest texts.") - XCTAssertTrue(similarityForest > dissimilarityBetweenScenes, "Forest texts should be more similar to each other than to beach texts.") + XCTAssertTrue( + similarityBeach > dissimilarityBetweenScenes, + "Beach texts should be more similar to each other than to forest texts." 
+ ) + XCTAssertTrue( + similarityForest > dissimilarityBetweenScenes, + "Forest texts should be more similar to each other than to beach texts." + ) } - + func testImageEmbeddings() async throws { - let model = try ImageEncoder( - modelPath: "uform/uform-vl-english-small-image.mlpackage", - configPath: "uform/config_image.json" + + let root = "/uform/" + let textModel = try TextEncoder( + modelPath: root + "uform-vl-english-large-text.mlpackage", + configPath: root + "uform-vl-english-large-text.json", + tokenizerPath: root + "uform-vl-english-large-text.tokenizer.json" ) - + let imageModel = try ImageEncoder( + modelPath: root + "uform-vl-english-large-image.mlpackage", + configPath: root + "uform-vl-english-large-image.json" + ) + + let texts = [ + "A group of friends enjoy a barbecue on a sandy beach, with one person grilling over a large black grill, while the other sits nearby, laughing and enjoying the camaraderie.", + "A white and orange cat stands on its hind legs, reaching towards a wicker basket filled with red raspberries on a wooden table in a garden, surrounded by orange flowers and a white teapot, creating a serene and whimsical scene.", + "A young girl in a yellow dress stands in a grassy field, holding an umbrella and looking at the camera, amidst rain.", + "This serene bedroom features a white bed with a black canopy, a gray armchair, a black dresser with a mirror, a vase with a plant, a window with white curtains, a rug, and a wooden floor, creating a tranquil and elegant atmosphere.", + "The image captures the iconic Louvre Museum in Paris, illuminated by warm lights against a dark sky, with the iconic glass pyramid in the center, surrounded by ornate buildings and a large courtyard, showcasing the museum's grandeur and historical significance.", + ] let imageURLs = [ "https://github.com/ashvardanian/ashvardanian/blob/master/demos/bbq-on-beach.jpg?raw=true", "https://github.com/ashvardanian/ashvardanian/blob/master/demos/cat-in-garden.jpg?raw=true", @@ -69,27 +88,46 @@ final class TokenizerTests: XCTestCase { "https://github.com/ashvardanian/ashvardanian/blob/master/demos/louvre-at-night.jpg?raw=true", ] - var embeddings: [[Float32]] = [] - for imageURL in imageURLs { + var textEmbeddings: [[Float32]] = [] + var imageEmbeddings: [[Float32]] = [] + for (text, imageURL) in zip(texts, imageURLs) { guard let url = URL(string: imageURL), - let imageSource = CGImageSourceCreateWithURL(url as CFURL, nil), - let cgImage = CGImageSourceCreateImageAtIndex(imageSource, 0, nil) else { - throw NSError(domain: "ImageError", code: 100, userInfo: [NSLocalizedDescriptionKey: "Could not load image from URL: \(imageURL)"]) + let imageSource = CGImageSourceCreateWithURL(url as CFURL, nil), + let cgImage = CGImageSourceCreateImageAtIndex(imageSource, 0, nil) + else { + throw NSError( + domain: "ImageError", + code: 100, + userInfo: [NSLocalizedDescriptionKey: "Could not load image from URL: \(imageURL)"] + ) } - - let embedding: [Float32] = try model.forward(with: cgImage) - embeddings.append(embedding) + + let textEmbedding: [Float32] = try textModel.forward(with: text) + textEmbeddings.append(textEmbedding) + let imageEmbedding: [Float32] = try imageModel.forward(with: cgImage) + imageEmbeddings.append(imageEmbedding) } - // Now let's compute the cosine similarity between the embeddings - let similarityGirlAndBeach = cosineSimilarity(between: embeddings[2], and: embeddings[0]) - let similarityGirlAndLouvre = cosineSimilarity(between: embeddings[2], and: embeddings[4]) - let 
similarityBeachAndLouvre = cosineSimilarity(between: embeddings[0], and: embeddings[4]) + // Now let's make sure that the cosine distance between image and respective text embeddings is low. + // Make sure that the similarity between image and text at index `i` is higher than with other texts and images. + for i in 0 ..< texts.count { + let pairSimilarity = cosineSimilarity(between: textEmbeddings[i], and: imageEmbeddings[i]) + let otherTextSimilarities = (0 ..< texts.count).filter { $0 != i }.map { + cosineSimilarity(between: textEmbeddings[$0], and: imageEmbeddings[i]) + } + let otherImageSimilarities = (0 ..< texts.count).filter { $0 != i }.map { + cosineSimilarity(between: textEmbeddings[i], and: imageEmbeddings[$0]) + } - // Assert that similar images have higher similarity scores - XCTAssertTrue(similarityGirlAndBeach > similarityGirlAndLouvre, ""); - XCTAssertTrue(similarityGirlAndBeach > similarityBeachAndLouvre, ""); + XCTAssertTrue( + pairSimilarity > otherTextSimilarities.max()!, + "Text should be more similar to its corresponding image than to other images." + ) + XCTAssertTrue( + pairSimilarity > otherImageSimilarities.max()!, + "Text should be more similar to its corresponding image than to other texts." + ) + } } - - + } From f2772d0d92317818c4d1c49166bc7ec3ee314f60 Mon Sep 17 00:00:00 2001 From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com> Date: Fri, 12 Apr 2024 16:44:58 -0700 Subject: [PATCH 3/4] Fix: Image preprocessing in Swift --- .vscode/launch.json | 15 +++ python/scripts/test_embeddings.py | 7 +- python/uform/numpy_preprocessor.py | 3 + swift/Embeddings.swift | 150 +++++++++++++++++++---------- swift/EmbeddingsTests.swift | 6 +- 5 files changed, 127 insertions(+), 54 deletions(-) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..59eb78c --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. 
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File with Arguments", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + } + ] +} \ No newline at end of file diff --git a/python/scripts/test_embeddings.py b/python/scripts/test_embeddings.py index f831d72..d71bf0b 100644 --- a/python/scripts/test_embeddings.py +++ b/python/scripts/test_embeddings.py @@ -11,7 +11,7 @@ torch_available = True except: torch_available = False - + # ONNX is not a very light dependency either try: import onnx @@ -34,6 +34,7 @@ ("unum-cloud/uform-vl-english-large", "gpu", "fp16"), ] + @pytest.mark.skipif(not torch_available, reason="PyTorch is not installed") @pytest.mark.parametrize("model_name", torch_models) def test_torch_one_embedding(model_name: str): @@ -141,3 +142,7 @@ def test_onnx_many_embeddings(model_specs: Tuple[str, str, str], batch_size: int except ExecutionProviderError as e: pytest.skip(f"Execution provider error: {e}") + + +if __name__ == "__main__": + pytest.main(["-s", "-x", __file__]) diff --git a/python/uform/numpy_preprocessor.py b/python/uform/numpy_preprocessor.py index 6cdd54b..a556db4 100644 --- a/python/uform/numpy_preprocessor.py +++ b/python/uform/numpy_preprocessor.py @@ -89,6 +89,9 @@ def _resize_crop_normalize(self, image: Image): bottom = (height + self._image_size) / 2 image = image.convert("RGB").crop((left, top, right, bottom)) + # At this point `image` is a PIL Image with RGB channels. + # If you convert it to `np.ndarray` it will have shape (H, W, C) where C is the number of channels. image = (np.array(image).astype(np.float32) / 255.0 - self.image_mean) / self.image_std + # To make it compatible with PyTorch, we need to transpose the image to (C, H, W). 
return np.transpose(image, (2, 0, 1)) diff --git a/swift/Embeddings.swift b/swift/Embeddings.swift index 00d56c4..176e884 100644 --- a/swift/Embeddings.swift +++ b/swift/Embeddings.swift @@ -11,6 +11,69 @@ import Foundation import Hub // `Config` import Tokenizers // `AutoTokenizer` +public enum Embedding { + case i32s([Int32]) + case f16s([Float16]) + case f32s([Float32]) + case f64s([Float64]) + + init?(from multiArray: MLMultiArray) { + switch multiArray.dataType { + case .float64: + self = .f64s( + Array( + UnsafeBufferPointer( + start: multiArray.dataPointer.assumingMemoryBound(to: Float64.self), + count: Int(truncating: multiArray.shape[1]) + ) + ) + ) + case .float32: + self = .f32s( + Array( + UnsafeBufferPointer( + start: multiArray.dataPointer.assumingMemoryBound(to: Float32.self), + count: Int(truncating: multiArray.shape[1]) + ) + ) + ) + case .float16: + self = .f16s( + Array( + UnsafeBufferPointer( + start: multiArray.dataPointer.assumingMemoryBound(to: Float16.self), + count: Int(truncating: multiArray.shape[1]) + ) + ) + ) + case .int32: + self = .i32s( + Array( + UnsafeBufferPointer( + start: multiArray.dataPointer.assumingMemoryBound(to: Int32.self), + count: Int(truncating: multiArray.shape[1]) + ) + ) + ) + @unknown default: + return nil // return nil for unsupported data types + } + } + + public func asFloats() -> [Float] { + switch self { + case .f32s(let array): + return array + case .i32s(let array): + return array.map { Float($0) } + case .f16s(let array): + return array.map { Float($0) } + case .f64s(let array): + return array.map { Float($0) } + } + } +} + // MARK: - Helpers func readConfig(fromPath path: String) throws -> [String: Any] { @@ -39,18 +102,20 @@ public class TextEncoder { self.processor = try TextProcessor(configPath: configPath, tokenizerPath: tokenizerPath, model: self.model) } - public func forward(with text: String) throws -> [Float32] { + public func forward(with text: String) throws -> Embedding { let inputFeatureProvider = try self.processor.preprocess(text) let prediction = try self.model.prediction(from: inputFeatureProvider) - let predictionFeature = prediction.featureValue(for: "embeddings") - // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 - let output = predictionFeature!.multiArrayValue! - return Array( - UnsafeBufferPointer( - start: output.dataPointer.assumingMemoryBound(to: Float32.self), - count: Int(truncating: output.shape[1]) + guard let predictionFeature = prediction.featureValue(for: "embeddings"), + let output = predictionFeature.multiArrayValue, + let embedding = Embedding(from: output) + else { + throw NSError( + domain: "TextEncoder", + code: 0, + userInfo: [NSLocalizedDescriptionKey: "Failed to extract embeddings or unsupported data type."] ) - ) + } + return embedding } } @@ -63,20 +128,21 @@ public class ImageEncoder { self.processor = try ImageProcessor(configPath: configPath) } - public func forward(with image: CGImage) throws -> [Float32] { + public func forward(with image: CGImage) throws -> Embedding { let inputFeatureProvider = try self.processor.preprocess(image) let prediction = try self.model.prediction(from: inputFeatureProvider) - let predictionFeature = prediction.featureValue(for: "embeddings") - // The `predictionFeature` is an MLMultiArray, which can be converted to an array of Float32 - let output = predictionFeature!.multiArrayValue! 
-        return Array(
-            UnsafeBufferPointer(
-                start: output.dataPointer.assumingMemoryBound(to: Float32.self),
-                count: Int(truncating: output.shape[1])
+        guard let predictionFeature = prediction.featureValue(for: "embeddings"),
+            let output = predictionFeature.multiArrayValue,
+            let embedding = Embedding(from: output)
+        else {
+            throw NSError(
+                domain: "ImageEncoder",
+                code: 0,
+                userInfo: [NSLocalizedDescriptionKey: "Failed to extract embeddings or unsupported data type."]
             )
-        )
+        }
+        return embedding
     }
-
 }
 
 // MARK: - Processors
@@ -147,6 +213,7 @@ class ImageProcessor {
         let originalWidth = CGFloat(image.width)
         let originalHeight = CGFloat(image.height)
 
+        // Calculate new size preserving the aspect ratio
         let widthRatio = CGFloat(imageSize) / originalWidth
         let heightRatio = CGFloat(imageSize) / originalHeight
         let scaleFactor = max(widthRatio, heightRatio)
@@ -154,11 +221,9 @@ class ImageProcessor {
         let scaledWidth = originalWidth * scaleFactor
         let scaledHeight = originalHeight * scaleFactor
 
+        // Calculate the crop rectangle
         let dx = (scaledWidth - CGFloat(imageSize)) / 2.0
         let dy = (scaledHeight - CGFloat(imageSize)) / 2.0
-        let insetRect = CGRect(x: dx, y: dy, width: CGFloat(imageSize) - dx * 2, height: CGFloat(imageSize) - dy * 2)
-
-        // Create a new context (off-screen canvas) with the desired dimensions
 
         guard let context = CGContext(
             data: nil,
@@ -171,11 +236,9 @@ class ImageProcessor {
         ) else {
             return nil
         }
 
-        // Draw the image in the context with the specified inset (cropping as necessary)
+        // Draw the scaled and cropped image in the context
         context.interpolationQuality = .high
-        context.draw(image, in: insetRect, byTiling: false)
-
-        // Extract the new image from the context
+        context.draw(image, in: CGRect(x: -dx, y: -dy, width: scaledWidth, height: scaledHeight))
         return context.makeImage()
     }
@@ -193,44 +256,31 @@ class ImageProcessor {
             bitsPerComponent: 8,
             bytesPerRow: 4 * width,
             space: colorSpace,
-            bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue
+            bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
         )
 
         context?.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height))
 
-        // Convert pixel data to float and normalize
-        let totalCount = width * height * 4
-        var floatPixels = [Float](repeating: 0, count: totalCount)
-        vDSP_vfltu8(pixelData, 1, &floatPixels, 1, vDSP_Length(totalCount))
-
-        // Scale the pixel values to [0, 1]
-        var divisor = Float(255.0)
-        vDSP_vsdiv(floatPixels, 1, &divisor, &floatPixels, 1, vDSP_Length(totalCount))
-
-        // Normalize the pixel values
+        // Normalize the pixel data
+        var floatPixels = [Float](repeating: 0, count: width * height * 3)
         for c in 0 ..< 3 {
-            var slice = [Float](repeating: 0, count: width * height)
             for i in 0 ..< (width * height) {
-                slice[i] = (floatPixels[i * 4 + c] - mean[c]) / std[c]
+                floatPixels[i * 3 + c] = (Float(pixelData[i * 4 + c]) / 255.0 - mean[c]) / std[c]
             }
-            floatPixels.replaceSubrange(c * width * height ..< (c + 1) * width * height, with: slice)
         }
 
-        // Rearrange the array to C x H x W
-        var tensor = [Float](repeating: 0, count: width * height * 3)
-        for y in 0 ..< height {
-            for x in 0 ..< width {
-                for c in 0 ..< 3 {
-                    tensor[c * width * height + y * width + x] = floatPixels[y * width * 4 + x * 4 + c]
-                }
+        // Create the tensor array
+        var tensor = [Float](repeating: 0, count: 3 * width * height)
+        for i in 0 ..< (width * height) {
+            for c in 0 ..< 3 {
+                tensor[c * width * height + i] = floatPixels[i * 3 + c]
             }
         }
-        // Reshape the tensor to 1 x 3 x H x W and pack into a rank-3 `MLFeatureValue`
         let multiArray = try? MLMultiArray(
-            shape: [1, 3, NSNumber(value: self.imageSize), NSNumber(value: self.imageSize)],
+            shape: [1, 3, NSNumber(value: height), NSNumber(value: width)],
             dataType: .float32
         )
-        for i in 0 ..< (width * height * 3) {
+        for i in 0 ..< tensor.count {
             multiArray?[i] = NSNumber(value: tensor[i])
         }
         return multiArray
diff --git a/swift/EmbeddingsTests.swift b/swift/EmbeddingsTests.swift
index 671a547..2797c63 100644
--- a/swift/EmbeddingsTests.swift
+++ b/swift/EmbeddingsTests.swift
@@ -40,7 +40,7 @@ final class TokenizerTests: XCTestCase {
 
         var textEmbeddings: [[Float32]] = []
         for text in texts {
-            let embedding: [Float32] = try textModel.forward(with: text)
+            let embedding: [Float32] = try textModel.forward(with: text).asFloats()
             textEmbeddings.append(embedding)
         }
 
@@ -102,9 +102,9 @@ final class TokenizerTests: XCTestCase {
                 )
             }
 
-            let textEmbedding: [Float32] = try textModel.forward(with: text)
+            let textEmbedding: [Float32] = try textModel.forward(with: text).asFloats()
            textEmbeddings.append(textEmbedding)
-            let imageEmbedding: [Float32] = try imageModel.forward(with: cgImage)
+            let imageEmbedding: [Float32] = try imageModel.forward(with: cgImage).asFloats()
             imageEmbeddings.append(imageEmbedding)
         }
 

From 729b9d9f73990f2689c22af593171184589a2b27 Mon Sep 17 00:00:00 2001
From: Ash Vardanian <1983160+ashvardanian@users.noreply.github.com>
Date: Sat, 13 Apr 2024 17:49:44 -0700
Subject: [PATCH 4/4] Improve: Fetching nested configs

---
 .vscode/settings.json       |  5 ++++
 swift/Embeddings.swift      | 49 +++++++++++++++++++++++++++++------
 swift/EmbeddingsTests.swift | 39 +++++++++++++++++++----------
 3 files changed, 73 insertions(+), 20 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 0ac7435..a6cceb8 100755
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,8 +1,12 @@
 {
     "cSpell.words": [
+        "arange",
+        "CFURL",
         "coreml",
+        "cumsum",
         "dtype",
         "embs",
+        "finfo",
         "huggingface",
         "keepdim",
         "linalg",
@@ -24,6 +28,7 @@
         "rerank",
         "reranker",
         "reranking",
+        "sess",
         "SIMD",
         "softmax",
         "transfromers",
diff --git a/swift/Embeddings.swift b/swift/Embeddings.swift
index 176e884..6d973ac 100644
--- a/swift/Embeddings.swift
+++ b/swift/Embeddings.swift
@@ -83,12 +83,16 @@ func readConfig(fromPath path: String) throws -> [String: Any] {
     return try JSONSerialization.jsonObject(with: data, options: []) as! [String: Any]
 }
 
+func readModel(fromURL modelURL: URL) throws -> MLModel {
+    let compiledModelURL = try MLModel.compileModel(at: modelURL)
+    return try MLModel(contentsOf: compiledModelURL)
+}
+
 func readModel(fromPath path: String) throws -> MLModel {
     // If it's not an absolute path, let's assume it's a path relative to the current working directory
     let absPath = path.hasPrefix("/") ? path : FileManager.default.currentDirectoryPath + "/" + path
     let modelURL = URL(fileURLWithPath: absPath, isDirectory: true)
-    let compiledModelURL = try MLModel.compileModel(at: modelURL)
-    return try MLModel(contentsOf: compiledModelURL)
+    return try readModel(fromURL: modelURL)
 }
 
 // MARK: - Encoders
@@ -97,8 +101,20 @@ public class TextEncoder {
     let model: MLModel
     let processor: TextProcessor
 
-    public init(modelPath: String, configPath: String, tokenizerPath: String) throws {
+    public init(modelPath: String, configPath: String? = nil, tokenizerPath: String? = nil) throws {
+        let finalConfigPath = configPath ?? modelPath + "/config.json"
+        let finalTokenizerPath = tokenizerPath ?? modelPath + "/tokenizer.json"
         self.model = try readModel(fromPath: modelPath)
+        self.processor = try TextProcessor(configPath: finalConfigPath, tokenizerPath: finalTokenizerPath, model: self.model)
+    }
+
+
+    public init(modelName: String, hubApi: HubApi = .shared) async throws {
+        let repo = Hub.Repo(id: modelName)
+        let modelURL = try await hubApi.snapshot(from: repo, matching: ["text.mlpackage/*", "config.json", "tokenizer.json"])
+        let configPath = modelURL.appendingPathComponent("config.json").path
+        let tokenizerPath = modelURL.appendingPathComponent("tokenizer.json").path
+        self.model = try readModel(fromURL: modelURL.appendingPathComponent("text.mlpackage", isDirectory: true))
         self.processor = try TextProcessor(configPath: configPath, tokenizerPath: tokenizerPath, model: self.model)
     }
 
@@ -123,11 +139,20 @@ public class ImageEncoder {
     let model: MLModel
     let processor: ImageProcessor
 
-    public init(modelPath: String, configPath: String) throws {
+    public init(modelPath: String, configPath: String? = nil) throws {
+        let finalConfigPath = configPath ?? modelPath + "/config.json"
         self.model = try readModel(fromPath: modelPath)
-        self.processor = try ImageProcessor(configPath: configPath)
+        self.processor = try ImageProcessor(configPath: finalConfigPath)
     }
 
+    public init(modelName: String, hubApi: HubApi = .shared) async throws {
+        let repo = Hub.Repo(id: modelName)
+        let modelURL = try await hubApi.snapshot(from: repo, matching: ["image.mlpackage/*", "config.json"])
+        let configPath = modelURL.appendingPathComponent("config.json").path
+        self.model = try readModel(fromURL: modelURL.appendingPathComponent("image.mlpackage", isDirectory: true))
+        self.processor = try ImageProcessor(configPath: configPath)
+    }
+
     public func forward(with image: CGImage) throws -> Embedding {
         let inputFeatureProvider = try self.processor.preprocess(image)
         let prediction = try self.model.prediction(from: inputFeatureProvider)
@@ -153,9 +178,14 @@ class TextProcessor {
     let maxContextLength: Int
 
     public init(configPath: String, tokenizerPath: String, model: MLModel) throws {
-        let configDict = try readConfig(fromPath: configPath)
+        var configDict = try readConfig(fromPath: configPath)
         let tokenizerDict = try readConfig(fromPath: tokenizerPath)
 
+        // Check if there's a specific 'text_encoder' configuration within the main configuration
+        if let textEncoderConfig = configDict["text_encoder"] as? [String: Any] {
+            configDict = textEncoderConfig  // Use the specific 'text_encoder' configuration
+        }
+
         let config = Config(configDict)
         let tokenizerData = Config(tokenizerDict)
         self.tokenizer = try AutoTokenizer.from(tokenizerConfig: config, tokenizerData: tokenizerData)
@@ -194,7 +224,12 @@ class ImageProcessor {
     let std: [Float] = [0.229, 0.224, 0.225]  // Common std values for normalization
 
     init(configPath: String) throws {
-        let configDict = try readConfig(fromPath: configPath)
+        var configDict = try readConfig(fromPath: configPath)
+        // Check if there's a specific 'image_encoder' configuration within the main configuration
+        if let imageEncoderConfig = configDict["image_encoder"] as? [String: Any] {
+            configDict = imageEncoderConfig
+        }
+
         let config = Config(configDict)
         self.imageSize = config.imageSize!.intValue!
     }
diff --git a/swift/EmbeddingsTests.swift b/swift/EmbeddingsTests.swift
index 2797c63..5efb87f 100644
--- a/swift/EmbeddingsTests.swift
+++ b/swift/EmbeddingsTests.swift
@@ -1,6 +1,7 @@
 import CoreGraphics
 import ImageIO
 import UForm
+import Hub
 import XCTest
 
 final class TokenizerTests: XCTestCase {
@@ -24,11 +25,10 @@ final class TokenizerTests: XCTestCase {
 
     func testTextEmbeddings() async throws {
 
-        let root = "/uform/"
-        let textModel = try TextEncoder(
-            modelPath: root + "uform-vl-english-large-text.mlpackage",
-            configPath: root + "uform-vl-english-large-text.json",
-            tokenizerPath: root + "uform-vl-english-large-text.tokenizer.json"
+        let api = HubApi(hfToken: "xxx")
+        let textModel = try await TextEncoder(
+            modelName: "unum-cloud/uform-vl2-english-small",
+            hubApi: api
         )
 
         let texts = [
@@ -62,15 +62,28 @@ final class TokenizerTests: XCTestCase {
 
     func testImageEmbeddings() async throws {
 
-        let root = "/uform/"
-        let textModel = try TextEncoder(
-            modelPath: root + "uform-vl-english-large-text.mlpackage",
-            configPath: root + "uform-vl-english-large-text.json",
-            tokenizerPath: root + "uform-vl-english-large-text.tokenizer.json"
+        // One option is to use a local model repository.
+        //
+        // let root = "uform/"
+        // let textModel = try TextEncoder(
+        //     modelPath: root + "uform-vl-english-large-text.mlpackage",
+        //     configPath: root + "uform-vl-english-large-text.json",
+        //     tokenizerPath: root + "uform-vl-english-large-text.tokenizer.json"
+        // )
+        // let imageModel = try ImageEncoder(
+        //     modelPath: root + "uform-vl-english-large-image.mlpackage",
+        //     configPath: root + "uform-vl-english-large-image.json"
+        // )
+        //
+        // A better option is to fetch directly from HuggingFace, similar to how users would do that:
+        let api = HubApi(hfToken: "xxx")
+        let textModel = try await TextEncoder(
+            modelName: "unum-cloud/uform-vl2-english-small",
+            hubApi: api
         )
-        let imageModel = try ImageEncoder(
-            modelPath: root + "uform-vl-english-large-image.mlpackage",
-            configPath: root + "uform-vl-english-large-image.json"
+        let imageModel = try await ImageEncoder(
+            modelName: "unum-cloud/uform-vl2-english-small",
+            hubApi: api
         )
 
         let texts = [