Skip to main content

smooth_operator_adapter_postgres/
embedder.rs

1//! Adapter-specific embedder: the live [`GatewayEmbedder`].
2//!
3//! The provider-agnostic [`Embedder`] trait, [`InputType`], the network-free
4//! [`DeterministicEmbedder`], and [`DEFAULT_EMBEDDING_DIM`] all live in
5//! [`smooth_operator::embedding`] — the one shared home so the
6//! Postgres adapter and the ingestion pipeline embed text identically (same
7//! FNV-1a hashing, same L2 normalization, same vectors). This module only holds
8//! the adapter-specific [`GatewayEmbedder`]: an OpenAI-compatible
//! `/v1/embeddings` HTTP client (the SmooAI LiteLLM gateway). It pulls in
//! `reqwest`, which is why it lives here rather than in `core`.
11//!
12//! ## Dimension decision
13//!
14//! Voyage (`voyage-3-large`, 1024-d) is the production north-star (it backs
15//! smooai's `knowledge_vectors`), but Voyage is *not* exposed on the LiteLLM
16//! gateway. The gateway does expose OpenAI `text-embedding-3-small` (1536-d).
17//! Rather than couple the column width to whichever embedder happens to be
18//! configured, the vector dimension is a first-class adapter parameter:
19//!
20//! | Embedder                | Dim  | Use                              |
21//! | ----------------------- | ---- | -------------------------------- |
22//! | `DeterministicEmbedder` | 1024 | tests / default (Voyage-shaped)  |
23//! | `GatewayEmbedder`       | 1536 | live `text-embedding-3-small`    |
24//!
25//! The `vector(N)` column and the HNSW index are created at `init` time using
26//! the adapter's configured dimension, so dense retrieval is always
27//! dimension-consistent.
28
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;

use smooth_operator::embedding::{Embedder, InputType};
33
/// Vector width (number of `f32` components) produced by OpenAI's
/// `text-embedding-3-small` model — the model [`GatewayEmbedder`] uses by default.
pub const OPENAI_SMALL_EMBEDDING_DIM: usize = 1_536;
36
37/// OpenAI-compatible `/v1/embeddings` embedder (the SmooAI LiteLLM gateway).
38///
39/// Only used when explicitly configured. Reads the endpoint from
40/// `SMOOAI_GATEWAY_URL` and the key from `SMOOAI_GATEWAY_KEY` (or pass them in).
41/// The default model is `text-embedding-3-small` (1536-d) — set the adapter
42/// dimension to [`OPENAI_SMALL_EMBEDDING_DIM`] when using it.
43#[derive(Clone)]
44pub struct GatewayEmbedder {
45    client: reqwest::Client,
46    base_url: String,
47    api_key: String,
48    model: String,
49    dim: usize,
50}
51
52impl GatewayEmbedder {
53    /// Build from explicit config.
54    #[must_use]
55    pub fn new(
56        base_url: impl Into<String>,
57        api_key: impl Into<String>,
58        model: impl Into<String>,
59        dim: usize,
60    ) -> Self {
61        Self {
62            client: reqwest::Client::new(),
63            base_url: base_url.into(),
64            api_key: api_key.into(),
65            model: model.into(),
66            dim,
67        }
68    }
69
70    /// Build from `SMOOAI_GATEWAY_URL` + `SMOOAI_GATEWAY_KEY`, defaulting the
71    /// model to `text-embedding-3-small` and the dimension to 1536.
72    ///
73    /// # Errors
74    /// Returns an error if either environment variable is unset.
75    pub fn from_env() -> Result<Self> {
76        let base_url = std::env::var("SMOOAI_GATEWAY_URL")
77            .map_err(|_| anyhow!("SMOOAI_GATEWAY_URL is not set"))?;
78        let api_key = std::env::var("SMOOAI_GATEWAY_KEY")
79            .map_err(|_| anyhow!("SMOOAI_GATEWAY_KEY is not set"))?;
80        Ok(Self::new(
81            base_url,
82            api_key,
83            "text-embedding-3-small",
84            OPENAI_SMALL_EMBEDDING_DIM,
85        ))
86    }
87}
88
89#[async_trait]
90impl Embedder for GatewayEmbedder {
91    fn dim(&self) -> usize {
92        self.dim
93    }
94
95    async fn embed(&self, texts: &[String], _input_type: InputType) -> Result<Vec<Vec<f32>>> {
96        if texts.is_empty() {
97            return Ok(Vec::new());
98        }
99        // Trim a trailing slash so `{base}/v1/embeddings` is well-formed whether
100        // the configured URL ends in `/` or not.
101        let url = format!("{}/v1/embeddings", self.base_url.trim_end_matches('/'));
102        let body = serde_json::json!({ "model": self.model, "input": texts });
103
104        let resp = self
105            .client
106            .post(&url)
107            .bearer_auth(&self.api_key)
108            .json(&body)
109            .send()
110            .await?;
111
112        if !resp.status().is_success() {
113            let status = resp.status();
114            let text = resp.text().await.unwrap_or_default();
115            return Err(anyhow!("embeddings request failed ({status}): {text}"));
116        }
117
118        #[derive(serde::Deserialize)]
119        struct EmbeddingData {
120            embedding: Vec<f32>,
121            index: usize,
122        }
123        #[derive(serde::Deserialize)]
124        struct EmbeddingResponse {
125            data: Vec<EmbeddingData>,
126        }
127
128        let mut parsed: EmbeddingResponse = resp.json().await?;
129        // OpenAI returns results in request order but documents `index`; sort to
130        // be safe, then validate the dimension matches the column.
131        parsed.data.sort_by_key(|d| d.index);
132        let out: Vec<Vec<f32>> = parsed.data.into_iter().map(|d| d.embedding).collect();
133
134        if out.len() != texts.len() {
135            return Err(anyhow!(
136                "embeddings count mismatch: got {} for {} inputs",
137                out.len(),
138                texts.len()
139            ));
140        }
141        for (i, v) in out.iter().enumerate() {
142            if v.len() != self.dim {
143                return Err(anyhow!(
144                    "embedding {i} has dim {} but adapter expects {}",
145                    v.len(),
146                    self.dim
147                ));
148            }
149        }
150        Ok(out)
151    }
152}