Custom Embedding Providers

raghilda includes built-in support for OpenAI and Cohere embeddings, but you can create custom providers for other embedding services or local models.

The EmbeddingProvider Interface

All embedding providers subclass the EmbeddingProvider abstract base class and implement three required methods:

  • embed(x, input_type) — Generate embeddings for a sequence of texts
  • get_config() — Return a configuration dict for serialization
  • from_config(config) — Create an instance from a configuration dict

Creating a Custom Provider

Here’s a complete example using the Voyage AI embedding API:

from typing import Any, Sequence
import voyageai
from raghilda.embedding import (
    EmbeddingProvider,
    EmbedInputType,
    register_embedding_provider,
)

@register_embedding_provider("EmbeddingVoyage")
class EmbeddingVoyage(EmbeddingProvider):
    """Embedding provider using Voyage AI models."""

    def __init__(
        self,
        model: str = "voyage-3",
        api_key: str | None = None,
        batch_size: int = 128,
    ):
        self.model = model
        self.api_key = api_key
        self.batch_size = batch_size
        self.client = voyageai.Client(api_key=api_key)

    def embed(
        self,
        x: Sequence[str],
        input_type: EmbedInputType = EmbedInputType.DOCUMENT,
    ) -> Sequence[Sequence[float]]:
        if isinstance(x, str):
            raise TypeError("Input must be a sequence of strings")

        if len(x) == 0:
            return []

        # Map to Voyage's input types
        voyage_input_type = (
            "query" if input_type == EmbedInputType.QUERY else "document"
        )

        result = []
        for i in range(0, len(x), self.batch_size):
            batch = list(x[i : i + self.batch_size])
            response = self.client.embed(
                texts=batch,
                model=self.model,
                input_type=voyage_input_type,
            )
            result.extend(response.embeddings)

        return result

    def get_config(self) -> dict[str, Any]:
        return {
            "type": "EmbeddingVoyage",  # Must match the registered name
            "model": self.model,
            "batch_size": self.batch_size,
            # Never include api_key in config for security
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EmbeddingVoyage":
        return cls(
            model=config.get("model", "voyage-3"),
            batch_size=config.get("batch_size", 128),
        )

Registration

The @register_embedding_provider("EmbeddingVoyage") decorator registers your provider in a global registry. This enables automatic restoration when reconnecting to a store:

from raghilda.store import DuckDBStore
from raghilda.chunker import MarkdownChunker
from raghilda.read import read_as_markdown

# Create store with custom provider
store = DuckDBStore.create(
    location="my_store.db",
    embed=EmbeddingVoyage(model="voyage-3"),
)
chunker = MarkdownChunker()
for uri in documents:
    store.upsert(chunker.chunk(read_as_markdown(uri)))

# Later, reconnect - provider is automatically restored
store = DuckDBStore.connect("my_store.db")
# store.embed is now an EmbeddingVoyage instance

The registered name in the decorator must match the "type" value in get_config().

Configuration Serialization

The get_config() and from_config() methods handle serialization:

  • get_config() returns a dict with all parameters needed to recreate the provider
  • from_config() creates a new instance from that dict
  • Never include API keys in the config — they should come from environment variables

When you connect to an existing store, raghilda reads the stored config and calls from_config() to recreate the provider.

Input Types

The EmbedInputType enum distinguishes between queries and documents:

from raghilda.embedding import EmbedInputType

# Embedding documents for storage
doc_embeddings = provider.embed(documents, input_type=EmbedInputType.DOCUMENT)

# Embedding a query for search
query_embedding = provider.embed([query], input_type=EmbedInputType.QUERY)

Some models (like Cohere and Voyage) produce different embeddings for queries vs documents to optimize retrieval. Others (like OpenAI) ignore this parameter. Your provider should handle both cases appropriately.

Local Models Example

Here’s an example using a local model with sentence-transformers:

from typing import Any, Sequence
from sentence_transformers import SentenceTransformer
from raghilda.embedding import (
    EmbeddingProvider,
    EmbedInputType,
    register_embedding_provider,
)

@register_embedding_provider("EmbeddingLocal")
class EmbeddingLocal(EmbeddingProvider):
    """Embedding provider using local sentence-transformers models."""

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        self.model_name = model
        self._model = SentenceTransformer(model)

    def embed(
        self,
        x: Sequence[str],
        input_type: EmbedInputType = EmbedInputType.DOCUMENT,
    ) -> Sequence[Sequence[float]]:
        if isinstance(x, str):
            raise TypeError("Input must be a sequence of strings")

        if len(x) == 0:
            return []

        # sentence-transformers doesn't distinguish query vs document
        embeddings = self._model.encode(list(x))
        return [emb.tolist() for emb in embeddings]

    def get_config(self) -> dict[str, Any]:
        return {
            "type": "EmbeddingLocal",
            "model": self.model_name,
        }

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "EmbeddingLocal":
        return cls(model=config.get("model", "all-MiniLM-L6-v2"))

Usage:

from raghilda.store import DuckDBStore

store = DuckDBStore.create(
    location="local_store.db",
    embed=EmbeddingLocal(model="all-MiniLM-L6-v2"),
)

ChromaDB Compatibility

Custom providers work with ChromaDBStore without any extra effort. Pass the provider to ChromaDBStore.create(), and raghilda adapts it internally for Chroma.

If your provider does not map to a native Chroma embedding function, raghilda uses the provider’s regular embed() implementation. That path is Python-only, so cross-language Chroma clients cannot restore it.

Best Practices

  1. Validate inputs — Check for empty strings and wrong types in embed()
  2. Batch efficiently — Process texts in batches to avoid API limits and improve performance
  3. Handle errors gracefully — Provide clear error messages for common issues
  4. Omit secrets from config — Never store API keys; use environment variables
  5. Test round-trip serialization — Ensure from_config(provider.get_config()) produces an equivalent provider
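Practices 1 and 2 can be combined in a small standalone helper. This is a sketch; the batched name and shape are illustrative, not part of raghilda's API:

```python
from typing import Iterator, Sequence

def batched(texts: Sequence[str], batch_size: int) -> Iterator[list[str]]:
    """Validate input, then yield successive batches of at most batch_size texts."""
    # A bare str is a Sequence[str] of characters, so reject it explicitly.
    if isinstance(texts, str):
        raise TypeError("Input must be a sequence of strings, not a single string")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for i in range(0, len(texts), batch_size):
        yield list(texts[i : i + batch_size])

batches = list(batched(["a", "b", "c", "d", "e"], batch_size=2))
# → [["a", "b"], ["c", "d"], ["e"]]
```

Reusing a helper like this across providers keeps the validation and batching logic out of each embed() implementation.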

Example: Testing Your Provider

def test_provider_round_trip():
    # Create provider
    original = EmbeddingVoyage(model="voyage-3", batch_size=64)

    # Serialize and deserialize
    config = original.get_config()
    restored = EmbeddingVoyage.from_config(config)

    # Verify
    assert restored.model == original.model
    assert restored.batch_size == original.batch_size

def test_embedding_output():
    provider = EmbeddingVoyage()
    texts = ["Hello world", "Testing embeddings"]

    embeddings = provider.embed(texts)

    assert len(embeddings) == 2
    assert all(isinstance(emb, (list, tuple)) for emb in embeddings)
    assert all(isinstance(v, float) for emb in embeddings for v in emb)