Skip to main content

xai_rust/api/
embeddings.rs

1//! Embeddings API for generating vector embeddings.
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::client::XaiClient;
7use crate::{Error, Result};
8
9/// Embeddings API client.
10#[derive(Debug, Clone)]
11pub struct EmbeddingsApi {
12    client: XaiClient,
13}
14
15impl EmbeddingsApi {
16    pub(crate) fn new(client: XaiClient) -> Self {
17        Self { client }
18    }
19
20    /// Create embeddings from an input payload.
21    pub async fn create(&self, request: EmbeddingsRequest) -> Result<EmbeddingsResponse> {
22        let url = format!("{}/embeddings", self.client.base_url());
23
24        let response = self
25            .client
26            .send(self.client.http().post(&url).json(&request))
27            .await?;
28
29        if !response.status().is_success() {
30            return Err(Error::from_response(response).await);
31        }
32
33        Ok(response.json().await?)
34    }
35}
36
37/// Request body for `/v1/embeddings`.
38#[derive(Debug, Clone, Serialize)]
39pub struct EmbeddingsRequest {
40    /// Model name for embedding generation.
41    pub model: String,
42    /// Input payload (text, array of text, etc.).
43    pub input: Value,
44    /// Optional embedding dimensions.
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub dimensions: Option<u32>,
47}
48
49impl EmbeddingsRequest {
50    /// Create a new embeddings request.
51    pub fn new(model: impl Into<String>, input: Value) -> Self {
52        Self {
53            model: model.into(),
54            input,
55            dimensions: None,
56        }
57    }
58}
59
60/// Response from `/v1/embeddings`.
61#[derive(Debug, Clone, Deserialize)]
62pub struct EmbeddingsResponse {
63    /// Optional response object payload.
64    #[serde(default)]
65    pub data: Vec<Value>,
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a client whose base URL points at the given mock server.
    fn client_for(server: &MockServer) -> crate::client::XaiClient {
        crate::client::XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn create_posts_to_embeddings_endpoint() {
        let server = MockServer::start().await;

        // `body_json` matches the entire body, so this also verifies that an
        // unset `dimensions` field is omitted from the serialized request.
        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .and(body_json(serde_json::json!({
                "model": "text-embedding",
                "input": "sample",
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": [{
                    "embedding": [0.1, 0.2, 0.3]
                }]
            })))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let response = client
            .embeddings()
            .create(EmbeddingsRequest::new(
                "text-embedding",
                serde_json::json!("sample"),
            ))
            .await
            .unwrap();

        assert_eq!(response.data.len(), 1);
        assert_eq!(
            response.data[0]["embedding"],
            serde_json::json!([0.1, 0.2, 0.3])
        );
    }

    #[tokio::test]
    async fn create_returns_error_on_non_success_status() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "bad request"
            })))
            .mount(&server)
            .await;

        let client = client_for(&server);

        let result = client
            .embeddings()
            .create(EmbeddingsRequest::new(
                "text-embedding",
                serde_json::json!("sample"),
            ))
            .await;

        // A non-2xx status must surface as an error rather than a decoded body.
        assert!(result.is_err());
    }
}