xai_grpc_client/embedding.rs
1//! Embedding API for generating vector representations.
2//!
3//! This module provides access to xAI's embedding models, allowing you to:
4//! - Generate embeddings from text strings
5//! - Generate embeddings from images
6//! - Support for both text-only and multimodal embedding models
7//!
8//! # Examples
9//!
10//! ## Embedding text
11//!
12//! ```no_run
13//! use xai_grpc_client::{GrokClient, EmbedRequest};
14//!
15//! #[tokio::main]
16//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//! let mut client = GrokClient::from_env().await?;
18//!
19//! let request = EmbedRequest::new("embed-large-v1")
20//! .add_text("Hello, world!")
21//! .add_text("How are you?");
22//!
23//! let response = client.embed(request).await?;
24//!
25//! for embedding in response.embeddings {
26//! println!("Embedding {} has {} dimensions",
27//! embedding.index, embedding.vector.len());
28//! }
29//! Ok(())
30//! }
31//! ```
32//!
33//! ## Embedding images
34//!
35//! ```no_run
36//! use xai_grpc_client::{GrokClient, EmbedRequest};
37//!
38//! #[tokio::main]
39//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
40//! let mut client = GrokClient::from_env().await?;
41//!
42//! let request = EmbedRequest::new("embed-vision-v1")
43//! .add_image("https://example.com/image.jpg");
44//!
45//! let response = client.embed(request).await?;
46//! println!("Generated {} embeddings", response.embeddings.len());
47//! Ok(())
48//! }
49//! ```
50
51use crate::{proto, request::ImageDetail};
52
53/// Request for generating embeddings.
54///
55/// Supports embedding text strings, images, or a mix of both depending on
56/// the model capabilities. You can embed up to 128 inputs in a single request.
57#[derive(Clone, Debug)]
58pub struct EmbedRequest {
59 /// Inputs to embed (text or images).
60 pub inputs: Vec<EmbedInput>,
61 /// Model name or alias to use.
62 pub model: String,
63 /// Encoding format for the embeddings (Float or Base64).
64 pub encoding_format: EmbedEncodingFormat,
65 /// Optional user identifier for tracking.
66 pub user: Option<String>,
67}
68
69impl EmbedRequest {
70 /// Create a new embedding request with the specified model.
71 ///
72 /// # Examples
73 ///
74 /// ```
75 /// use xai_grpc_client::EmbedRequest;
76 ///
77 /// let request = EmbedRequest::new("embed-large-v1");
78 /// ```
79 pub fn new(model: impl Into<String>) -> Self {
80 Self {
81 inputs: Vec::new(),
82 model: model.into(),
83 encoding_format: EmbedEncodingFormat::Float,
84 user: None,
85 }
86 }
87
88 /// Add a text string to embed.
89 ///
90 /// # Examples
91 ///
92 /// ```
93 /// use xai_grpc_client::EmbedRequest;
94 ///
95 /// let request = EmbedRequest::new("embed-large-v1")
96 /// .add_text("Hello, world!");
97 /// ```
98 pub fn add_text(mut self, text: impl Into<String>) -> Self {
99 self.inputs.push(EmbedInput::Text(text.into()));
100 self
101 }
102
103 /// Add an image URL to embed.
104 ///
105 /// # Examples
106 ///
107 /// ```
108 /// use xai_grpc_client::EmbedRequest;
109 ///
110 /// let request = EmbedRequest::new("embed-vision-v1")
111 /// .add_image("https://example.com/image.jpg");
112 /// ```
113 pub fn add_image(self, url: impl Into<String>) -> Self {
114 self.add_image_with_detail(url, ImageDetail::Auto)
115 }
116
117 /// Add an image URL with specific detail level.
118 ///
119 /// # Examples
120 ///
121 /// ```
122 /// use xai_grpc_client::{EmbedRequest, ImageDetail};
123 ///
124 /// let request = EmbedRequest::new("embed-vision-v1")
125 /// .add_image_with_detail("https://example.com/image.jpg", ImageDetail::High);
126 /// ```
127 pub fn add_image_with_detail(mut self, url: impl Into<String>, detail: ImageDetail) -> Self {
128 self.inputs.push(EmbedInput::Image {
129 url: url.into(),
130 detail,
131 });
132 self
133 }
134
135 /// Set the encoding format for embeddings.
136 ///
137 /// # Examples
138 ///
139 /// ```
140 /// use xai_grpc_client::{EmbedRequest, EmbedEncodingFormat};
141 ///
142 /// let request = EmbedRequest::new("embed-large-v1")
143 /// .with_encoding_format(EmbedEncodingFormat::Base64);
144 /// ```
145 pub fn with_encoding_format(mut self, format: EmbedEncodingFormat) -> Self {
146 self.encoding_format = format;
147 self
148 }
149
150 /// Set the user identifier for tracking.
151 pub fn with_user(mut self, user: impl Into<String>) -> Self {
152 self.user = Some(user.into());
153 self
154 }
155}
156
157/// Input to be embedded (text or image).
158#[derive(Clone, Debug)]
159pub enum EmbedInput {
160 /// Text string to embed.
161 Text(String),
162 /// Image URL to embed with optional detail level.
163 Image {
164 /// URL of the image.
165 url: String,
166 /// Detail level for processing.
167 detail: ImageDetail,
168 },
169}
170
/// Encoding format for embedding vectors.
///
/// This is a small, stateless tag type, so it derives `Copy` and `Hash`
/// in addition to equality; `Default` is `Float`, matching the format
/// that [`EmbedRequest::new`] starts with.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum EmbedEncodingFormat {
    /// Return embeddings as arrays of floats.
    Float,
    /// Return embeddings as base64-encoded strings.
    Base64,
}

impl Default for EmbedEncodingFormat {
    /// `Float` is the default, consistent with `EmbedRequest::new`.
    fn default() -> Self {
        Self::Float
    }
}
179
180/// Response from an embedding request.
181#[derive(Clone, Debug)]
182pub struct EmbedResponse {
183 /// Request identifier.
184 pub id: String,
185 /// Generated embeddings (one per input).
186 pub embeddings: Vec<Embedding>,
187 /// Token usage statistics.
188 pub usage: EmbeddingUsage,
189 /// Model name used (may differ from request if alias was used).
190 pub model: String,
191 /// Backend configuration fingerprint.
192 pub system_fingerprint: String,
193}
194
/// One embedding vector produced for a single input.
#[derive(Clone, Debug)]
pub struct Embedding {
    /// Position of the originating input within the request.
    pub index: usize,
    /// The embedding vector itself.
    pub vector: Vec<f32>,
}
203
/// Usage statistics for an embedding request.
///
/// A plain pair of counters, so it additionally derives `Copy`,
/// `PartialEq`, and `Eq` for cheap pass-by-value and direct comparison.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EmbeddingUsage {
    /// Number of text embeddings generated.
    pub num_text_embeddings: u32,
    /// Number of image embeddings generated.
    pub num_image_embeddings: u32,
}
212
/// Converts the generated protobuf usage message into the public
/// [`EmbeddingUsage`] type exposed by this crate.
impl From<proto::EmbeddingUsage> for EmbeddingUsage {
    fn from(proto: proto::EmbeddingUsage) -> Self {
        Self {
            // NOTE(review): `as u32` silently truncates/wraps if the proto
            // fields are wider than 32 bits or negative — confirm the proto
            // field types; `u32::try_from` would surface overflow explicitly.
            num_text_embeddings: proto.num_text_embeddings as u32,
            num_image_embeddings: proto.num_image_embeddings as u32,
        }
    }
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder accumulates text inputs and records the model name.
    #[test]
    fn test_embed_request_builder() {
        let req = EmbedRequest::new("embed-large-v1")
            .add_text("Hello")
            .add_text("World");

        assert_eq!(req.model, "embed-large-v1");
        assert_eq!(req.inputs.len(), 2);
        assert!(matches!(req.inputs[0], EmbedInput::Text(_)));
    }

    /// Both image helpers produce `Image` variants.
    #[test]
    fn test_embed_request_with_images() {
        let req = EmbedRequest::new("embed-vision-v1")
            .add_image("https://example.com/img1.jpg")
            .add_image_with_detail("https://example.com/img2.jpg", ImageDetail::High);

        assert_eq!(req.inputs.len(), 2);
        for input in &req.inputs {
            assert!(matches!(input, EmbedInput::Image { .. }));
        }
    }

    /// Text and image inputs can be freely mixed in one request.
    #[test]
    fn test_embed_request_mixed() {
        let req = EmbedRequest::new("embed-multimodal-v1")
            .add_text("Description")
            .add_image("https://example.com/img.jpg");

        assert_eq!(req.inputs.len(), 2);
    }

    /// `with_encoding_format` overrides the default float encoding.
    #[test]
    fn test_encoding_format() {
        let req = EmbedRequest::new("embed-large-v1")
            .with_encoding_format(EmbedEncodingFormat::Base64);

        assert_eq!(req.encoding_format, EmbedEncodingFormat::Base64);
    }

    /// `with_user` stores the identifier as `Some`.
    #[test]
    fn test_with_user() {
        let req = EmbedRequest::new("embed-large-v1").with_user("user123");

        assert_eq!(req.user, Some("user123".to_string()));
    }

    /// Default usage counters start at zero.
    #[test]
    fn test_embedding_usage_default() {
        let usage = EmbeddingUsage::default();
        assert_eq!(usage.num_text_embeddings, 0);
        assert_eq!(usage.num_image_embeddings, 0);
    }

    /// Proto usage converts field-for-field into the public type.
    #[test]
    fn test_embedding_usage_from_proto() {
        let source = proto::EmbeddingUsage {
            num_text_embeddings: 5,
            num_image_embeddings: 2,
        };

        let usage: EmbeddingUsage = source.into();
        assert_eq!(usage.num_text_embeddings, 5);
        assert_eq!(usage.num_image_embeddings, 2);
    }
}