1use crate::error::check_for_model_error;
4use crate::error::XaiError;
5use crate::traits::{ClientConfig, EmbeddingFetcher};
6use reqwest::Method;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct EmbeddingRequest {
11 pub input: Vec<String>,
12 pub model: String,
13 pub encoding_format: String,
14}
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct EmbeddingResponse {
18 pub data: Vec<EmbeddingData>,
19 pub model: String,
20 pub object: String,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct EmbeddingData {
25 pub embedding: EmbeddingValue,
26 pub index: u32,
27 pub object: String,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31#[serde(untagged)]
32pub enum EmbeddingValue {
33 Float(Vec<f32>),
34}
35
36#[derive(Debug, Clone)]
37pub struct EmbeddingRequestBuilder<T: ClientConfig + Clone + Send + Sync> {
38 client: T,
39 request: EmbeddingRequest,
40}
41
42impl<T> EmbeddingRequestBuilder<T>
43where
44 T: ClientConfig + Clone + Send + Sync,
45{
46 pub fn new(client: T, model: String, input: Vec<String>, encoding_format: String) -> Self {
47 Self {
48 client,
49 request: EmbeddingRequest {
50 input,
51 model,
52 encoding_format,
53 },
54 }
55 }
56
57 pub fn build(self) -> Result<EmbeddingRequest, XaiError> {
58 Ok(self.request)
59 }
60}
61
62impl<T> EmbeddingFetcher for EmbeddingRequestBuilder<T>
63where
64 T: ClientConfig + Clone + Send + Sync,
65{
66 async fn create_embedding(
67 &self,
68 request: EmbeddingRequest,
69 ) -> Result<EmbeddingResponse, XaiError> {
70 let response = self
71 .client
72 .request(Method::POST, "embeddings")?
73 .json(&request)
74 .send()
75 .await?;
76
77 if response.status().is_success() {
78 let chat_completion = response.json::<EmbeddingResponse>().await?;
79 Ok(chat_completion)
80 } else {
81 let error_body = response.text().await.unwrap_or_else(|_| "".to_string());
82
83 if let Some(model_error) = check_for_model_error(&error_body) {
84 return Err(model_error);
85 }
86
87 Err(XaiError::Http(error_body))
88 }
89 }
90}