// xai_grpc_client/models.rs

1//! Model listing and information API.
2//!
3//! This module provides access to xAI's model listing API, allowing you to:
4//! - List all available language models with [`GrokClient::list_models`](crate::GrokClient::list_models)
5//! - Get detailed information about specific models with [`GrokClient::get_model`](crate::GrokClient::get_model)
6//! - Check pricing, context lengths, and capabilities
7//!
8//! # Examples
9//!
10//! ## Listing all models
11//!
12//! ```no_run
13//! use xai_grpc_client::GrokClient;
14//!
15//! #[tokio::main]
16//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//!     let mut client = GrokClient::from_env().await?;
18//!     let models = client.list_models().await?;
19//!
20//!     for model in models {
21//!         println!("{}: {} tokens", model.name, model.max_prompt_length);
22//!     }
23//!     Ok(())
24//! }
25//! ```
26//!
27//! ## Getting specific model information
28//!
29//! ```no_run
30//! use xai_grpc_client::GrokClient;
31//!
32//! #[tokio::main]
33//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
34//!     let mut client = GrokClient::from_env().await?;
35//!     let model = client.get_model("grok-2-1212").await?;
36//!
37//!     println!("Model: {}", model.name);
38//!     println!("Version: {}", model.version);
39//!     println!("Max context: {} tokens", model.max_prompt_length);
40//!     println!("Multimodal: {}", model.supports_multimodal());
41//!
42//!     // Calculate cost for a typical request
43//!     let cost = model.calculate_cost(10_000, 1_000, 0);
44//!     println!("Cost for 10K prompt + 1K completion: ${:.4}", cost);
45//!
46//!     Ok(())
47//! }
48//! ```
49
50use crate::proto;
51
/// Information about a language model.
///
/// This struct contains comprehensive metadata about an xAI language model,
/// including its capabilities, pricing, and technical specifications.
///
/// # Pricing Units
///
/// The pricing fields use specific units to represent fractional cents:
/// - `prompt_text_token_price`: 1/100 USD cents per 1M tokens (e.g., 500 = $0.05 per 1M tokens)
/// - `prompt_image_token_price`: 1/100 USD cents per 1M tokens
/// - `completion_text_token_price`: 1/100 USD cents per 1M tokens
/// - `cached_prompt_token_price`: USD cents per 100M tokens (e.g., 50 = $0.50 per 100M tokens)
/// - `search_price`: 1/100 USD cents per 1M searches
///
/// NOTE(review): [`calculate_cost`](LanguageModel::calculate_cost) divides
/// text-token prices by 1e8 per token (which corresponds to USD *cents* per
/// 1M tokens) and cached prices by 1e8 (USD per 100M tokens) — 100x the
/// units stated above. Confirm which unit the API actually returns.
///
/// Use [`calculate_cost`](LanguageModel::calculate_cost) to convert these to USD amounts.
///
/// # Examples
///
/// ```no_run
/// # use xai_grpc_client::GrokClient;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut client = GrokClient::from_env().await?;
/// let model = client.get_model("grok-2-1212").await?;
///
/// // Check capabilities
/// if model.supports_multimodal() {
///     println!("{} supports images!", model.name);
/// }
///
/// // Calculate costs
/// let cost = model.calculate_cost(50_000, 5_000, 0);
/// println!("50K prompt + 5K completion costs: ${:.4}", cost);
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct LanguageModel {
    /// The model name used in API requests (e.g., "grok-2-1212").
    pub name: String,

    /// Alternative names that can be used for this model (e.g., ["grok-2-latest"]).
    ///
    /// Aliases provide convenient shortcuts for referring to models without
    /// needing to know the specific version number.
    pub aliases: Vec<String>,

    /// Version number of the model (e.g., "2.0").
    pub version: String,

    /// Supported input modalities.
    ///
    /// Common combinations:
    /// - `[Text]` - Text-only model
    /// - `[Text, Image]` - Multimodal model supporting vision
    pub input_modalities: Vec<Modality>,

    /// Supported output modalities.
    ///
    /// Most models output `[Text]`, but some specialized models may
    /// support image generation or embeddings.
    pub output_modalities: Vec<Modality>,

    /// Price per million prompt text tokens in 1/100 USD cents.
    ///
    /// Example: 500 = $0.05 per 1M tokens
    /// (NOTE(review): `calculate_cost` treats 500 as $5 per 1M — see struct docs.)
    pub prompt_text_token_price: i64,

    /// Price per million prompt image tokens in 1/100 USD cents.
    ///
    /// Only applicable for multimodal models that accept images.
    pub prompt_image_token_price: i64,

    /// Price per 100 million cached prompt tokens in USD cents.
    ///
    /// Example: 50 = $0.50 per 100M tokens
    /// (NOTE(review): `calculate_cost` treats 50 as $50 per 100M — see struct docs.)
    ///
    /// Cached tokens are significantly cheaper as they're reused from
    /// previous requests with the same prefix.
    pub cached_prompt_token_price: i64,

    /// Price per million completion text tokens in 1/100 USD cents.
    ///
    /// Example: 1500 = $0.15 per 1M tokens
    pub completion_text_token_price: i64,

    /// Price per million searches in 1/100 USD cents.
    ///
    /// Only applicable when using web search or X search tools.
    pub search_price: i64,

    /// Maximum context length in tokens (prompt + completion).
    ///
    /// This represents the total number of tokens the model can process
    /// in a single request, including both input and output.
    pub max_prompt_length: i32,

    /// Backend configuration fingerprint.
    ///
    /// This identifier tracks the specific backend configuration used by
    /// the model, useful for debugging and reproducibility.
    pub system_fingerprint: String,
}
155
/// Modality supported by a model for input or output.
///
/// Models can support different combinations of modalities:
/// - Text-only models: `input_modalities: [Text]`, `output_modalities: [Text]`
/// - Vision models: `input_modalities: [Text, Image]`, `output_modalities: [Text]`
/// - Embedding models: `input_modalities: [Text]`, `output_modalities: [Embedding]`
///
/// Values are converted from the proto enum; the proto's invalid sentinel
/// maps to [`Modality::Text`] as a fallback.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Modality {
    /// Text input/output - supported by all language models.
    Text,

    /// Image input/output - supported by multimodal vision models.
    ///
    /// Models with `Image` in `input_modalities` can process image URLs
    /// alongside text prompts.
    Image,

    /// Embedding input/output - vector representations.
    ///
    /// Used by embedding models that convert text or images into
    /// high-dimensional vector representations for semantic search.
    Embedding,
}
179
/// Information about an embedding model.
///
/// Embedding models convert text or images into high-dimensional vector
/// representations that can be used for semantic search, clustering, and
/// similarity comparisons.
///
/// # Pricing Units
///
/// - `prompt_text_token_price`: 1/100 USD cents per 1M tokens
/// - `prompt_image_token_price`: 1/100 USD cents per 1M tokens
///
/// # Examples
///
/// ```no_run
/// # use xai_grpc_client::GrokClient;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut client = GrokClient::from_env().await?;
/// let model = client.get_embedding_model("embed-large-v1").await?;
///
/// println!("Model: {}", model.name);
/// println!("Version: {}", model.version);
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct EmbeddingModel {
    /// The model name used in API requests (e.g., "embed-large-v1").
    pub name: String,

    /// Alternative names that can be used for this model.
    pub aliases: Vec<String>,

    /// Version number of the model.
    pub version: String,

    /// Supported input modalities (typically Text and optionally Image).
    pub input_modalities: Vec<Modality>,

    /// Supported output modalities (always includes Embedding).
    pub output_modalities: Vec<Modality>,

    /// Price per million text prompt tokens in 1/100 USD cents.
    pub prompt_text_token_price: i64,

    /// Price per million image prompt tokens in 1/100 USD cents.
    pub prompt_image_token_price: i64,

    /// Backend configuration fingerprint.
    pub system_fingerprint: String,
}
231
/// Information about an image generation model.
///
/// Image generation models create images from text prompts.
///
/// # Pricing Units
///
/// - `image_price`: USD cents per image
///
/// # Examples
///
/// ```no_run
/// # use xai_grpc_client::GrokClient;
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let mut client = GrokClient::from_env().await?;
/// let model = client.get_image_generation_model("image-gen-1").await?;
///
/// println!("Model: {}", model.name);
/// println!("Cost per image: ${:.2}", model.image_price as f64 / 100.0);
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct ImageGenerationModel {
    /// The model name used in API requests.
    pub name: String,

    /// Alternative names that can be used for this model.
    pub aliases: Vec<String>,

    /// Version number of the model.
    pub version: String,

    /// Supported input modalities (typically Text).
    pub input_modalities: Vec<Modality>,

    /// Supported output modalities (typically Image).
    pub output_modalities: Vec<Modality>,

    /// Price per image in USD cents.
    ///
    /// Example: 200 = $2.00 per image
    pub image_price: i64,

    /// Maximum length of the prompt/input in tokens.
    pub max_prompt_length: i32,

    /// Backend configuration fingerprint.
    pub system_fingerprint: String,
}
282
283impl From<proto::LanguageModel> for LanguageModel {
284    fn from(proto: proto::LanguageModel) -> Self {
285        Self {
286            name: proto.name,
287            aliases: proto.aliases,
288            version: proto.version,
289            input_modalities: proto
290                .input_modalities
291                .into_iter()
292                .filter_map(|m| proto::Modality::try_from(m).ok())
293                .map(Modality::from)
294                .collect(),
295            output_modalities: proto
296                .output_modalities
297                .into_iter()
298                .filter_map(|m| proto::Modality::try_from(m).ok())
299                .map(Modality::from)
300                .collect(),
301            prompt_text_token_price: proto.prompt_text_token_price,
302            prompt_image_token_price: proto.prompt_image_token_price,
303            cached_prompt_token_price: proto.cached_prompt_token_price,
304            completion_text_token_price: proto.completion_text_token_price,
305            search_price: proto.search_price,
306            max_prompt_length: proto.max_prompt_length,
307            system_fingerprint: proto.system_fingerprint,
308        }
309    }
310}
311
312impl From<proto::Modality> for Modality {
313    fn from(proto: proto::Modality) -> Self {
314        match proto {
315            proto::Modality::Text => Modality::Text,
316            proto::Modality::Image => Modality::Image,
317            proto::Modality::Embedding => Modality::Embedding,
318            proto::Modality::InvalidModality => Modality::Text, // Default fallback
319        }
320    }
321}
322
323impl LanguageModel {
324    /// Calculate the cost (in USD) for a given number of prompt and completion tokens.
325    ///
326    /// # Examples
327    ///
328    /// ```no_run
329    /// # use xai_grpc_client::models::LanguageModel;
330    /// # let model = LanguageModel {
331    /// #     name: "grok-2".to_string(),
332    /// #     aliases: vec![],
333    /// #     version: "1.0".to_string(),
334    /// #     input_modalities: vec![],
335    /// #     output_modalities: vec![],
336    /// #     prompt_text_token_price: 500,
337    /// #     prompt_image_token_price: 0,
338    /// #     cached_prompt_token_price: 0,
339    /// #     completion_text_token_price: 1500,
340    /// #     search_price: 0,
341    /// #     max_prompt_length: 131072,
342    /// #     system_fingerprint: "".to_string(),
343    /// # };
344    /// let cost = model.calculate_cost(1000, 500, 0);
345    /// println!("Cost: ${:.4}", cost);
346    /// ```
347    pub fn calculate_cost(
348        &self,
349        prompt_tokens: u32,
350        completion_tokens: u32,
351        cached_tokens: u32,
352    ) -> f64 {
353        let prompt_cost =
354            (prompt_tokens as f64 * self.prompt_text_token_price as f64) / 1_000_000.0 / 100.0;
355        let cached_cost =
356            (cached_tokens as f64 * self.cached_prompt_token_price as f64) / 100_000_000.0;
357        let completion_cost = (completion_tokens as f64 * self.completion_text_token_price as f64)
358            / 1_000_000.0
359            / 100.0;
360
361        prompt_cost + cached_cost + completion_cost
362    }
363
364    /// Check if the model supports multimodal input (text + images).
365    ///
366    /// Returns `true` if the model accepts both text and image inputs,
367    /// allowing you to send image URLs alongside text prompts.
368    ///
369    /// # Examples
370    ///
371    /// ```no_run
372    /// # use xai_grpc_client::GrokClient;
373    /// # #[tokio::main]
374    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
375    /// # let mut client = GrokClient::from_env().await?;
376    /// let model = client.get_model("grok-2-vision-1212").await?;
377    ///
378    /// if model.supports_multimodal() {
379    ///     println!("{} can process images!", model.name);
380    /// } else {
381    ///     println!("{} is text-only", model.name);
382    /// }
383    /// # Ok(())
384    /// # }
385    /// ```
386    pub fn supports_multimodal(&self) -> bool {
387        self.input_modalities.contains(&Modality::Text)
388            && self.input_modalities.contains(&Modality::Image)
389    }
390}
391
392impl From<proto::EmbeddingModel> for EmbeddingModel {
393    fn from(proto: proto::EmbeddingModel) -> Self {
394        Self {
395            name: proto.name,
396            aliases: proto.aliases,
397            version: proto.version,
398            input_modalities: proto
399                .input_modalities
400                .into_iter()
401                .filter_map(|m| proto::Modality::try_from(m).ok())
402                .map(Modality::from)
403                .collect(),
404            output_modalities: proto
405                .output_modalities
406                .into_iter()
407                .filter_map(|m| proto::Modality::try_from(m).ok())
408                .map(Modality::from)
409                .collect(),
410            prompt_text_token_price: proto.prompt_text_token_price,
411            prompt_image_token_price: proto.prompt_image_token_price,
412            system_fingerprint: proto.system_fingerprint,
413        }
414    }
415}
416
417impl From<proto::ImageGenerationModel> for ImageGenerationModel {
418    fn from(proto: proto::ImageGenerationModel) -> Self {
419        Self {
420            name: proto.name,
421            aliases: proto.aliases,
422            version: proto.version,
423            input_modalities: proto
424                .input_modalities
425                .into_iter()
426                .filter_map(|m| proto::Modality::try_from(m).ok())
427                .map(Modality::from)
428                .collect(),
429            output_modalities: proto
430                .output_modalities
431                .into_iter()
432                .filter_map(|m| proto::Modality::try_from(m).ok())
433                .map(Modality::from)
434                .collect(),
435            image_price: proto.image_price,
436            max_prompt_length: proto.max_prompt_length,
437            system_fingerprint: proto.system_fingerprint,
438        }
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline model fixture shared by the tests below.
    ///
    /// Per-token charges applied by `calculate_cost` for these prices:
    /// text tokens cost `price / 1e8` USD each; cached tokens cost
    /// `cached_prompt_token_price / 1e8` USD each.
    fn create_test_model() -> LanguageModel {
        LanguageModel {
            name: "grok-2".to_string(),
            aliases: vec!["grok-latest".to_string()],
            version: "1.0".to_string(),
            input_modalities: vec![Modality::Text],
            output_modalities: vec![Modality::Text],
            prompt_text_token_price: 500, // 500 / 1e8 = $0.000005 per token
            prompt_image_token_price: 0,
            cached_prompt_token_price: 50, // 50 / 1e8 = $0.0000005 per cached token
            completion_text_token_price: 1500, // 1500 / 1e8 = $0.000015 per token
            search_price: 0,
            max_prompt_length: 131072,
            system_fingerprint: "fp_test".to_string(),
        }
    }

    #[test]
    fn test_calculate_cost_basic() {
        let model = create_test_model();

        // calculate_cost charges (tokens * price) / 1e8 USD for text tokens:
        // prompt:     1000 * 500  / 1e8 = $0.0050
        // completion:  500 * 1500 / 1e8 = $0.0075
        // total                         = $0.0125
        let cost = model.calculate_cost(1000, 500, 0);
        assert!(
            (cost - 0.0125).abs() < 0.0001,
            "Expected ~$0.0125, got ${cost}"
        );
    }

    #[test]
    fn test_calculate_cost_with_cached() {
        let model = create_test_model();

        // Cached tokens use the formula:
        // (cached_tokens * cached_prompt_token_price) / 100_000_000
        // = (10000 * 50) / 100_000_000 = $0.005
        // Total: $0.005 (prompt) + $0.0075 (completion) + $0.005 (cached) = $0.0175
        let cost = model.calculate_cost(1000, 500, 10000);
        assert!(
            (cost - 0.0175).abs() < 0.0001,
            "Expected ~$0.0175, got ${cost}"
        );
    }

    #[test]
    fn test_calculate_cost_large_numbers() {
        let model = create_test_model();

        // 1M prompt + 100K completion:
        // 1_000_000 * 500 / 1e8 + 100_000 * 1500 / 1e8 = $5.00 + $1.50 = $6.50
        let cost = model.calculate_cost(1_000_000, 100_000, 0);
        assert!((cost - 6.5).abs() < 0.01, "Expected ~$6.50, got ${cost}");
    }

    #[test]
    fn test_calculate_cost_zero() {
        let model = create_test_model();
        let cost = model.calculate_cost(0, 0, 0);
        assert_eq!(cost, 0.0);
    }

    #[test]
    fn test_supports_multimodal_text_only() {
        let text_only = LanguageModel {
            input_modalities: vec![Modality::Text],
            output_modalities: vec![Modality::Text],
            ..create_test_model()
        };

        assert!(!text_only.supports_multimodal());
    }

    #[test]
    fn test_supports_multimodal_vision() {
        let multimodal = LanguageModel {
            input_modalities: vec![Modality::Text, Modality::Image],
            output_modalities: vec![Modality::Text],
            ..create_test_model()
        };

        assert!(multimodal.supports_multimodal());
    }

    #[test]
    fn test_supports_multimodal_image_only() {
        // Image without Text must NOT count as multimodal.
        let image_only = LanguageModel {
            input_modalities: vec![Modality::Image],
            output_modalities: vec![Modality::Image],
            ..create_test_model()
        };

        assert!(!image_only.supports_multimodal());
    }

    #[test]
    fn test_modality_from_proto() {
        assert_eq!(Modality::from(proto::Modality::Text), Modality::Text);
        assert_eq!(Modality::from(proto::Modality::Image), Modality::Image);
        assert_eq!(
            Modality::from(proto::Modality::Embedding),
            Modality::Embedding
        );
        // The invalid sentinel should fall back to Text.
        assert_eq!(
            Modality::from(proto::Modality::InvalidModality),
            Modality::Text
        );
    }

    #[test]
    fn test_language_model_clone() {
        let model = create_test_model();
        let cloned = model.clone();

        assert_eq!(model.name, cloned.name);
        assert_eq!(model.version, cloned.version);
        assert_eq!(model.max_prompt_length, cloned.max_prompt_length);
    }

    #[test]
    fn test_language_model_debug() {
        let model = create_test_model();
        let debug_str = format!("{model:?}");
        assert!(debug_str.contains("grok-2"));
        assert!(debug_str.contains("1.0"));
    }

    #[test]
    fn test_language_model_aliases() {
        let model = create_test_model();
        assert_eq!(model.aliases.len(), 1);
        assert_eq!(model.aliases[0], "grok-latest");
    }

    #[test]
    fn test_language_model_from_proto() {
        let proto_model = proto::LanguageModel {
            name: "test-model".to_string(),
            aliases: vec!["test-alias".to_string()],
            version: "2.0".to_string(),
            input_modalities: vec![proto::Modality::Text as i32, proto::Modality::Image as i32],
            output_modalities: vec![proto::Modality::Text as i32],
            prompt_text_token_price: 1000,
            prompt_image_token_price: 2000,
            cached_prompt_token_price: 100,
            completion_text_token_price: 3000,
            search_price: 500,
            created: None,
            max_prompt_length: 32768,
            system_fingerprint: "fp_test_123".to_string(),
        };

        let model: LanguageModel = proto_model.into();

        assert_eq!(model.name, "test-model");
        assert_eq!(model.aliases, vec!["test-alias"]);
        assert_eq!(model.version, "2.0");
        assert_eq!(model.input_modalities.len(), 2);
        assert!(model.input_modalities.contains(&Modality::Text));
        assert!(model.input_modalities.contains(&Modality::Image));
        assert_eq!(model.prompt_text_token_price, 1000);
        assert_eq!(model.prompt_image_token_price, 2000);
        assert_eq!(model.cached_prompt_token_price, 100);
        assert_eq!(model.completion_text_token_price, 3000);
        assert_eq!(model.search_price, 500);
        assert_eq!(model.max_prompt_length, 32768);
        assert_eq!(model.system_fingerprint, "fp_test_123");
    }
}