Skip to content

daft.functions.classify_image#

classify_image #

classify_image(image: Expression, labels: Label | list[Label], *, provider: str | Provider | None = None, model: str | None = None, **options: Unpack[ClassifyImageOptions]) -> Expression

Returns an expression that classifies images using the specified model and provider.

Parameters:

Name Type Description Default
image Image Expression

The input image column expression.

required
labels str | list[str]

Label(s) for classification.

required
provider str | Provider | None

The provider to use for the image classification model. By default this will use the 'transformers' provider.

None
model str | None

The classifier model to use, given as a model name. By default this will use a zero-shot-classification model.

None
**options Unpack[ClassifyImageOptions]

Any additional options to pass for the model.

{}
Note

Make sure the required provider packages are installed (e.g. vllm, transformers, openai).

Returns:

Name Type Description
Expression String Expression

An expression representing the most-probable label string.

Examples:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
>>> import daft
>>> from daft.functions import classify_image, decode_image
>>> df = (
...     # Discover a few images from HuggingFace
...     daft.from_glob_path("hf://datasets/datasets-examples/doc-image-3/images")
...     # Read the 4 PNG, JPEG, TIFF, WEBP Images
...     .with_column("image_bytes", daft.col("path").download())
...     # Decode the image bytes into a daft Image DataType
...     .with_column("image_type", decode_image(daft.col("image_bytes")))
...     # Convert Image to RGB and resize the image to 288x288
...     .with_column("image_resized", daft.col("image_type").convert_image("RGB").resize(288, 288))
...     # Classify the image
...     .with_column(
...         "image_label",
...         classify_image(
...             daft.col("image_resized"),
...             labels=["bulbasaur", "catapie", "voltorb", "electrode"],
...             provider="transformers",
...             model="openai/clip-vit-base-patch32",
...         ),
...     )
... )
>>> df.show()
╭────────────────────────────────┬─────────┬────────────────┬──────────────┬───────────────────────┬───────────────╮
│ path                           ┆ size    ┆ image_bytes    ┆ image_type   ┆ image_resized         ┆ image_label   │
│ ---                            ┆ ---     ┆ ---            ┆ ---          ┆ ---                   ┆ ---           │
│ String                         ┆ Int64   ┆ Binary         ┆ Image[MIXED] ┆ Image[RGB; 288 x 288] ┆ String        │
╞════════════════════════════════╪═════════╪════════════════╪══════════════╪═══════════════════════╪═══════════════╡
│ hf://datasets/datasets-exampl… ┆ 113469  ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ bulbasaur     │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ hf://datasets/datasets-exampl… ┆ 206898  ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ catapie       │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ hf://datasets/datasets-exampl… ┆ 1871034 ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ voltorb       │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ hf://datasets/datasets-exampl… ┆ 22022   ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ electrode     │
╰────────────────────────────────┴─────────┴────────────────┴──────────────┴───────────────────────┴───────────────╯
(Showing first 4 of 4 rows)
Source code in daft/functions/ai/__init__.py
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
def classify_image(
    image: Expression,
    labels: Label | list[Label],
    *,
    provider: str | Provider | None = None,
    model: str | None = None,
    **options: Unpack[ClassifyImageOptions],
) -> Expression:
    """Returns an expression that classifies images using the specified model and provider.

    Args:
        image (Image Expression):
            The input image column expression.
        labels (str | list[str]):
            Candidate label(s) for classification. A single string is treated
            as a one-element list. The labels are copied when the expression
            is built, so mutating the list afterwards has no effect.
        provider (str | Provider | None):
            The provider to use for the image classification model.
            By default this will use the 'transformers' provider.
        model (str | None):
            The classifier model to use, given as a model name.
            By default this will use a `zero-shot-classification` model.
        **options:
            Any additional options to pass for the model.

    Note:
        Make sure the required provider packages are installed (e.g. vllm, transformers, openai).

    Returns:
        Expression (String Expression): An expression representing the most-probable label string.

    Examples:
        >>> import daft
        >>> from daft.functions import classify_image, decode_image
        >>> df = (
        ...     # Discover a few images from HuggingFace
        ...     daft.from_glob_path("hf://datasets/datasets-examples/doc-image-3/images")
        ...     # Read the 4 PNG, JPEG, TIFF, WEBP Images
        ...     .with_column("image_bytes", daft.col("path").download())
        ...     # Decode the image bytes into a daft Image DataType
        ...     .with_column("image_type", decode_image(daft.col("image_bytes")))
        ...     # Convert Image to RGB and resize the image to 288x288
        ...     .with_column("image_resized", daft.col("image_type").convert_image("RGB").resize(288, 288))
        ...     # Classify the image
        ...     .with_column(
        ...         "image_label",
        ...         classify_image(
        ...             daft.col("image_resized"),
        ...             labels=["bulbasaur", "catapie", "voltorb", "electrode"],
        ...             provider="transformers",
        ...             model="openai/clip-vit-base-patch32",
        ...         ),
        ...     )
        ... )
        >>> df.show()
        ╭────────────────────────────────┬─────────┬────────────────┬──────────────┬───────────────────────┬───────────────╮
        │ path                           ┆ size    ┆ image_bytes    ┆ image_type   ┆ image_resized         ┆ image_label   │
        │ ---                            ┆ ---     ┆ ---            ┆ ---          ┆ ---                   ┆ ---           │
        │ String                         ┆ Int64   ┆ Binary         ┆ Image[MIXED] ┆ Image[RGB; 288 x 288] ┆ String        │
        ╞════════════════════════════════╪═════════╪════════════════╪══════════════╪═══════════════════════╪═══════════════╡
        │ hf://datasets/datasets-exampl… ┆ 113469  ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ bulbasaur     │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ hf://datasets/datasets-exampl… ┆ 206898  ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ catapie       │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ hf://datasets/datasets-exampl… ┆ 1871034 ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ voltorb       │
        ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤
        │ hf://datasets/datasets-exampl… ┆ 22022   ┆ ...            ┆ <Image>      ┆ <FixedShapeImage>     ┆ electrode     │
        ╰────────────────────────────────┴─────────┴────────────────┴──────────────┴───────────────────────┴───────────────╯
        <BLANKLINE>
        (Showing first 4 of 4 rows)
    """
    # Imported lazily, inside the function, rather than at module import time.
    from daft.ai._expressions import _ImageClassificationExpression

    # Resolve the provider (falling back to "transformers") and ask it for a
    # classifier descriptor for the requested model and options.
    image_classifier = _resolve_provider(provider, "transformers").get_image_classifier(model, **options)

    # TODO: classification with structured outputs will be more interesting
    # Normalize to a list: a bare string becomes a single-label list. Any other
    # input is copied so that later mutation of the caller's list cannot alter
    # the labels captured by the UDF instance below.
    label_list = [labels] if isinstance(labels, str) else list(labels)
    # Decorate the __call__ method with @daft.method to specify return_dtype.
    # Note that this reassigns the attribute on the class itself on every call.
    _ImageClassificationExpression.__call__ = method.batch(  # type: ignore[method-assign]
        method=_ImageClassificationExpression.__call__,
        return_dtype=DataType.string(),
    )
    # implemented as a class-based udf for now
    udf_options = image_classifier.get_udf_options()
    wrapped_cls = daft_cls(
        _ImageClassificationExpression,
        max_concurrency=udf_options.concurrency,
        # A falsy num_gpus (e.g. None) is passed through as 0.
        gpus=udf_options.num_gpus or 0,
        max_retries=udf_options.max_retries,
        on_error=udf_options.on_error,
        name_override="classify_image",
    )
    # Bind the classifier descriptor and labels, then apply the UDF to the image column.
    instance = wrapped_cls(image_classifier, label_list)
    return instance(image)