Skip to content

Contributing New AI Functions#

This page contains references and guides for developing new AI Functions in daft. These steps will guide you through implementing a model expression like:

  • embed_text
  • embed_image
  • classify_text
  • prompt

Step 1. Define the Protocol and Descriptor#

All model expressions are backed by a Protocol and Descriptor. These protocols are based on the verb and modality, and the descriptor is used to instantiate the model at runtime.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
# daft.ai.protocols

@runtime_checkable
class TextClassifier(Protocol):
    """Protocol for text classification implementations."""

    def classify_text(self, text: list[str], labels: list[str]) -> list[str]:
        """Classifies a batch of text strings using the given label(s)."""
        ...

class TextClassifierDescriptor(Descriptor[TextClassifier]):
    """Descriptor for a TextClassifier implementation."""

Step 2. Add to the Provider Interface#

You must update the Provider interface with a new method to create your descriptor. This should have a default implementation which simply raises; this makes it so that you need not update all existing providers.

1
2
3
4
5
6
7
8
# daft.ai.provider
class Provider(ABC):

    # ... existing code

    def get_text_classifier(self, model: str | None = None, **options: Any) -> TextClassifierDescriptor:
        """Returns a TextClassifierDescriptor for this provider."""
        raise not_implemented_err(self, method="classify_text")

Step 3. Define the Function.#

In daft.functions.ai you can add the function, and then re-export it in daft.functions.__init__.py. The implementation is responsible for resolving the provider from the given arguments, then you will call the appropriate provider method to get the relevant descriptor.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import daft
from daft import DataType, Series
from daft.ai.protocols import TextClassifier, TextClassifierDescriptor

def classify_text(
    text: Expression,
    labels: list[str],
    *,
    provider: str | Provider | None = None,
    model: str | None = None,
) -> Expression:
    # Load a TextClassifierDescriptor from the resolved provider
    text_classifier = _resolve_provider(provider, "transformers").get_text_classifier(model)

    # Create the stateful class UDF
    classifier = _TextClassifierExpression(text_classifier, labels)

    # Return the expression
    return classifier.classify(text)


@daft.cls
class _TextClassifierExpression:
    """Function expression implementation for a TextClassifier protocol."""

    def __init__(self, descriptor: TextClassifierDescriptor, labels: list[str]):
        # Instantiate from the descriptor in __init__
        self.text_classifier = descriptor.instantiate()
        self.labels = labels

    @daft.method.batch(return_dtype=DataType.string())
    def classify(self, text: Series) -> list[str]:
        text_list = text.to_pylist()
        if not text_list:
            return []
        return self.text_classifier.classify_text(text_list, self.labels)

Step 4. Implement the Protocol for some Provider.#

Here is a simplified example implementation of embed_text for OpenAI. This should give you and idea of where you actual logic should live, and the previous steps are to properly hook your new expression into the provider/model system.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
@dataclass
class OpenAITextEmbedderDescriptor(TextEmbedderDescriptor):
    model: str  # store some metadata

    # We can use the stored metadata to instantiate the protocol implementation
    def instantiate(self) -> TextEmbedder:
        return OpenAITextEmbedder(client=OpenAI(), model=self.model)

@dataclass
class OpenAITextEmbedder(TextEmbedder):
    client: OpenAI
    model: str

    # This is a simple version using the batch API. The full implementation
    # uses dynamic batching and has error handling mechanisms.
    def embed_text(self, text: list[str]) -> list[Embedding]:
        response = self.client.embeddings.create(
            input=text,
            model=self.model,
            encoding_format="float",
        )
        return [np.array(embedding.embedding) for embedding in response.data]

Step 5. Expression Usage#

You can now use this like any other expression.

1
2
3
4
5
import daft

df = daft.read_parquet("/path/to/file.parquet")  # assuming has some column 'text'
df = df.with_column("embedding", embed_text(df["text"], provider="openai"))  # <- set provider to 'openai'
df.show()