x_ai/
embedding.rs

1//! Reference: https://docs.x.ai/api/endpoints#create-embeddings
2
3use crate::error::check_for_model_error;
4use crate::error::XaiError;
5use crate::traits::{ClientConfig, EmbeddingFetcher};
6use reqwest::Method;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EmbeddingRequest {
11    pub input: Vec<String>,
12    pub model: String,
13    pub encoding_format: String,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EmbeddingResponse {
18    pub data: Vec<EmbeddingData>,
19    pub model: String,
20    pub object: String,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct EmbeddingData {
25    pub embedding: EmbeddingValue,
26    pub index: u32,
27    pub object: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(untagged)]
32pub enum EmbeddingValue {
33    Float(Vec<f32>),
34}
35
36#[derive(Debug, Clone)]
37pub struct EmbeddingRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
38    client: T,
39    request: EmbeddingRequest,
40}
41
42impl<T> EmbeddingRequestBuilder<T>
43where
44    T: ClientConfig + Clone + Send + Sync,
45{
46    pub fn new(client: T, model: String, input: Vec<String>, encoding_format: String) -> Self {
47        Self {
48            client,
49            request: EmbeddingRequest {
50                input,
51                model,
52                encoding_format,
53            },
54        }
55    }
56
57    pub fn build(self) -> Result<EmbeddingRequest, XaiError> {
58        Ok(self.request)
59    }
60}
61
62impl<T> EmbeddingFetcher for EmbeddingRequestBuilder<T>
63where
64    T: ClientConfig + Clone + Send + Sync,
65{
66    async fn create_embedding(
67        &self,
68        request: EmbeddingRequest,
69    ) -> Result<EmbeddingResponse, XaiError> {
70        let response = self
71            .client
72            .request(Method::POST, "embeddings")?
73            .json(&request)
74            .send()
75            .await?;
76
77        if response.status().is_success() {
78            let chat_completion = response.json::<EmbeddingResponse>().await?;
79            Ok(chat_completion)
80        } else {
81            let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
82
83            if let Some(model_error) = check_for_model_error(&error_body) {
84                return Err(model_error);
85            }
86
87            Err(XaiError::Http(error_body))
88        }
89    }
90}