embedding.EmbeddingProvider

Interface for embedding function providers.

Usage

```python
embedding.EmbeddingProvider()
```

To create a custom embedding provider:

- Subclass `EmbeddingProvider` and implement `embed()`, `get_config()`, and `from_config()`.
- Register it with the `@register_embedding_provider("MyProvider")` decorator. The name you pass is the one `get_config()` must return under its `"type"` key.

Registered providers are automatically restored when connecting to a database that was created with that provider.

Examples

```python
from raghilda.embedding import EmbeddingProvider, register_embedding_provider


@register_embedding_provider("MyCustomEmbedding")
class MyCustomEmbedding(EmbeddingProvider):
    def __init__(self, model: str = "default", api_key: str | None = None):
        self.model = model
        self.api_key = api_key
        # Initialize your embedding client here

    def embed(self, x, input_type=None):
        # Return a list of embedding vectors, one per text in x
        ...

    def get_config(self):
        # Return a config dict; exclude sensitive values such as api_key.
        # "type" must match the name passed to @register_embedding_provider.
        return {"type": "MyCustomEmbedding", "model": self.model}

    @classmethod
    def from_config(cls, config):
        # api_key is not stored in the config, so it is not restored here
        return cls(model=config.get("model", "default"))
```
Methods
| Name | Description |
|---|---|
| embed() | Generate embeddings for a sequence of texts. |
| from_config() | Create a provider instance from a configuration dict. |
| get_config() | Get the configuration dict for this provider. |
embed()
Generate embeddings for a sequence of texts.
Usage

```python
embed(x, input_type=EmbedInputType.DOCUMENT)
```

Parameters

x: Sequence[str] - A sequence of texts to generate embeddings for.

input_type: EmbedInputType = EmbedInputType.DOCUMENT - The type of input being embedded. Some models (e.g., Cohere) produce different embeddings for queries vs. documents. Default is `DOCUMENT`.

Returns

Sequence[Sequence[float]] - A sequence of embeddings, the same length as `x`, where each embedding is a sequence of floats.
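The return contract of `embed()` (one vector per input text, each a sequence of floats) can be illustrated with a deterministic toy embedder. This is a hypothetical sketch, not a real model: `InputType` stands in for `EmbedInputType` (its `QUERY` member is an assumption, since only `DOCUMENT` is named here), and `toy_embed` and `check_embeddings` are illustrative names.

```python
import hashlib
from enum import Enum
from typing import List, Sequence


class InputType(Enum):
    # Hypothetical stand-in for EmbedInputType; QUERY is assumed for illustration.
    DOCUMENT = "document"
    QUERY = "query"


def toy_embed(
    x: Sequence[str], input_type: InputType = InputType.DOCUMENT, dim: int = 8
) -> List[List[float]]:
    """Map each text to `dim` floats in [0, 1] taken from a SHA-256 digest.

    The input type is mixed into the hash, so the same text gets different
    vectors as a query and as a document. `dim` must be at most 32.
    """
    vectors = []
    for text in x:
        digest = hashlib.sha256(f"{input_type.value}:{text}".encode("utf-8")).digest()
        vectors.append([byte / 255.0 for byte in digest[:dim]])
    return vectors


def check_embeddings(x: Sequence[str], vectors: Sequence[Sequence[float]]) -> None:
    """Raise ValueError if `vectors` violates the documented return contract."""
    if len(vectors) != len(x):
        raise ValueError(f"expected {len(x)} embeddings, got {len(vectors)}")
    for vec in vectors:
        if not all(isinstance(v, float) for v in vec):
            raise ValueError("each embedding must be a sequence of floats")


texts = ["first document", "second document"]
doc_vectors = toy_embed(texts)
check_embeddings(texts, doc_vectors)
```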
from_config()
Create a provider instance from a configuration dict.
Usage

```python
from_config(config)
```

Parameters

config: dict[str, Any] - Configuration dict from `get_config()`.

Returns

EmbeddingProvider - A new instance of the provider.
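Because the stored config omits API keys, a `from_config()` implementation that needs credentials has to obtain them another way. A minimal sketch, assuming the key is read from an environment variable at restore time; `ToyProvider` and `MY_EMBEDDING_API_KEY` are hypothetical names:

```python
import os
from typing import Any, Dict, Optional


class ToyProvider:
    """Hypothetical provider showing only the config and restore path."""

    def __init__(self, model: str = "default", api_key: Optional[str] = None):
        self.model = model
        self.api_key = api_key

    def get_config(self) -> Dict[str, Any]:
        # The API key is deliberately excluded from the stored config.
        return {"type": "ToyProvider", "model": self.model}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyProvider":
        # The config carries no API key, so fetch it when restoring.
        # MY_EMBEDDING_API_KEY is a hypothetical variable name.
        return cls(
            model=config.get("model", "default"),
            api_key=os.environ.get("MY_EMBEDDING_API_KEY"),
        )
```

Using `config.get(..., default)` keeps older configs restorable if a parameter is added later.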
get_config()
Get the configuration dict for this provider.
Usage

```python
get_config()
```

The config should contain all parameters needed to recreate the provider, except for sensitive values like API keys. It must include a `"type"` key with the registered name of the provider.

Returns

dict - Configuration dict that can be passed to `from_config()`.
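A quick check of a `get_config()` implementation is a round trip through `from_config()`. The sketch below assumes the config is persisted as JSON (the storage format is not specified here), so it keeps every value JSON-serializable; `ToyProvider` and its `dimensions` parameter are hypothetical.

```python
import json
from typing import Any, Dict, Optional


class ToyProvider:
    def __init__(self, model: str = "default", dimensions: int = 256,
                 api_key: Optional[str] = None):
        self.model = model
        self.dimensions = dimensions
        self.api_key = api_key

    def get_config(self) -> Dict[str, Any]:
        # Every non-secret constructor argument plus the registered "type" name;
        # api_key is deliberately left out.
        return {"type": "ToyProvider", "model": self.model, "dimensions": self.dimensions}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ToyProvider":
        return cls(model=config["model"], dimensions=config["dimensions"])


original = ToyProvider(model="small", dimensions=128, api_key="secret")
# Assumption: the config is stored as JSON text and parsed back on reconnect.
stored = json.dumps(original.get_config())
restored = ToyProvider.from_config(json.loads(stored))
```

If `restored.get_config()` differs from `original.get_config()`, some parameter needed to recreate the provider is missing from the config.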