xai_rust/api/
embeddings.rs1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::client::XaiClient;
7use crate::{Error, Result};
8
9#[derive(Debug, Clone)]
11pub struct EmbeddingsApi {
12 client: XaiClient,
13}
14
15impl EmbeddingsApi {
16 pub(crate) fn new(client: XaiClient) -> Self {
17 Self { client }
18 }
19
20 pub async fn create(&self, request: EmbeddingsRequest) -> Result<EmbeddingsResponse> {
22 let url = format!("{}/embeddings", self.client.base_url());
23
24 let response = self
25 .client
26 .send(self.client.http().post(&url).json(&request))
27 .await?;
28
29 if !response.status().is_success() {
30 return Err(Error::from_response(response).await);
31 }
32
33 Ok(response.json().await?)
34 }
35}
36
37#[derive(Debug, Clone, Serialize)]
39pub struct EmbeddingsRequest {
40 pub model: String,
42 pub input: Value,
44 #[serde(skip_serializing_if = "Option::is_none")]
46 pub dimensions: Option<u32>,
47}
48
49impl EmbeddingsRequest {
50 pub fn new(model: impl Into<String>, input: Value) -> Self {
52 Self {
53 model: model.into(),
54 input,
55 dimensions: None,
56 }
57 }
58}
59
60#[derive(Debug, Clone, Deserialize)]
62pub struct EmbeddingsResponse {
63 #[serde(default)]
65 pub data: Vec<Value>,
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_json, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A successful response is decoded and its `data` entries are exposed.
    #[tokio::test]
    async fn create_posts_to_embeddings_endpoint() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            // `dimensions` is `None`, so it must not appear in the body.
            .and(body_json(serde_json::json!({
                "model": "text-embedding",
                "input": "sample",
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": [{
                    "embedding": [0.1, 0.2, 0.3]
                }]
            })))
            .mount(&server)
            .await;

        let client = crate::client::XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let response = client
            .embeddings()
            .create(EmbeddingsRequest::new(
                "text-embedding",
                serde_json::json!("sample"),
            ))
            .await
            .unwrap();

        assert_eq!(response.data.len(), 1);
        // Check the payload's shape rather than exact float values, which
        // avoids depending on float round-trip precision.
        assert_eq!(
            response.data[0]["embedding"].as_array().map(Vec::len),
            Some(3)
        );
    }

    /// A non-success status is surfaced as an error instead of being decoded.
    #[tokio::test]
    async fn create_returns_error_on_non_success_status() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/embeddings"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "bad request"
            })))
            .mount(&server)
            .await;

        let client = crate::client::XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let result = client
            .embeddings()
            .create(EmbeddingsRequest::new(
                "text-embedding",
                serde_json::json!("sample"),
            ))
            .await;

        assert!(result.is_err());
    }
}