3.2.4

@xenova released this 28 Dec 12:03
307a490

What's new?

  • Add support for visualizing self-attention heatmaps in #1117

    [Figure: the cats input image, followed by the attention maps for attention heads 0 to 5]
    Example code
    import { AutoProcessor, AutoModelForImageClassification, interpolate_4d, RawImage } from "@huggingface/transformers";
    
    // Load model and processor
    const model_id = "onnx-community/dinov2-with-registers-small-with-attentions";
    const model = await AutoModelForImageClassification.from_pretrained(model_id);
    const processor = await AutoProcessor.from_pretrained(model_id);
    
    // Load image from URL
    const image = await RawImage.read("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg");
    
    // Pre-process image
    const inputs = await processor(image);
    
    // Perform inference
    const { logits, attentions } = await model(inputs);
    
    // Get the predicted class
    const cls = logits[0].argmax().item();
    const label = model.config.id2label[cls];
    console.log(`Predicted class: ${label}`);
    
    // Set config values
    const patch_size = model.config.patch_size;
    // pixel_values has dims [batch, channels, height, width]
    const [height, width] = inputs.pixel_values.dims.slice(-2);
    const h_featmap = Math.floor(height / patch_size);
    const w_featmap = Math.floor(width / patch_size);
    const num_heads = model.config.num_attention_heads;
    const num_cls_tokens = 1;
    const num_register_tokens = model.config.num_register_tokens ?? 0;
    
    // Visualize attention maps
    const selected_attentions = attentions
        .at(-1) // we are only interested in the attention maps of the last layer
        // batch 0, all heads, CLS query token, keys restricted to the patch tokens
        .slice(0, null, 0, [num_cls_tokens + num_register_tokens, null])
        .view(num_heads, 1, h_featmap, w_featmap);
    
    const upscaled = await interpolate_4d(selected_attentions, {
        size: [height, width],
        mode: "nearest",
    });
    
    for (let i = 0; i < num_heads; ++i) {
        const head_attentions = upscaled[i];
        const minval = head_attentions.min().item();
        const maxval = head_attentions.max().item();
        const image = RawImage.fromTensor(
            head_attentions
                .sub_(minval)
                .div_(maxval - minval)
                .mul_(255)
                .to("uint8"),
        );
        await image.save(`attn-head-${i}.png`);
    }
  • Add min, max, argmin, and argmax tensor ops for dim=null (reduction over the whole tensor)

  • Add support for nearest-neighbour interpolation in interpolate_4d

  • Depth Estimation pipeline improvements (faster & returns resized depth map)

  • TypeScript improvements by @ocavue and @shrirajh in #1081 and #1122

  • Remove unused imports from tokenizers.js by @pratapvardhan in #1116
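
The last two steps of the example above (upscaling the patch-level attention grid with nearest-neighbour interpolation, then rescaling each head to the 0 to 255 range) can be sketched in plain JavaScript without the library. This is only an illustration under stated assumptions, not the library's implementation: upscaleNearest and toUint8 are hypothetical helper names, the index mapping is one common nearest-neighbour convention, and rounding with Math.round is an assumption of this sketch.

```javascript
// Hypothetical helper: nearest-neighbour upscaling of a row-major 2D grid.
function upscaleNearest(data, inH, inW, outH, outW) {
    const out = new Float32Array(outH * outW);
    for (let y = 0; y < outH; ++y) {
        // Map each output row back to the source row that covers it
        const srcY = Math.min(inH - 1, Math.floor((y * inH) / outH));
        for (let x = 0; x < outW; ++x) {
            const srcX = Math.min(inW - 1, Math.floor((x * inW) / outW));
            out[y * outW + x] = data[srcY * inW + srcX];
        }
    }
    return out;
}

// Hypothetical helper: min-max normalisation to the 0..255 range.
function toUint8(data) {
    let min = Infinity;
    let max = -Infinity;
    for (const v of data) {
        if (v < min) min = v;
        if (v > max) max = v;
    }
    const range = max - min;
    const out = new Uint8Array(data.length);
    for (let i = 0; i < data.length; ++i) {
        // A constant map has range 0; emit zeros instead of dividing by zero
        out[i] = range === 0 ? 0 : Math.round(((data[i] - min) / range) * 255);
    }
    return out;
}

// A 2x2 grid of attention weights (one value per patch), upscaled to 4x4 pixels
const grid = new Float32Array([0, 0.25, 0.5, 1]);
const upscaled = upscaleNearest(grid, 2, 2, 4, 4);
const pixels = toUint8(upscaled);

console.log(Array.from(pixels));
// [0, 0, 64, 64, 0, 0, 64, 64, 128, 128, 255, 255, 128, 128, 255, 255]
```

Each source value is repeated over a 2x2 block of pixels, which is why nearest-neighbour mode keeps the patch boundaries visible in the saved heatmaps. The guard for a zero range is an addition of this sketch; the example above divides by maxval - minval directly.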

New Contributors

Full Changelog: 3.2.3...3.2.4