1use crate::error::PreviewError;
13use tt_provider_openai::ClientConfig;
14use tt_shared::{ModelPricing, Provider};
15
/// Result of a successful pricing lookup: which provider matched and the
/// rates it reported for the model.
///
/// `PartialEq` is derived so whole hits can be compared directly (for example
/// with `assert_eq!`) instead of field by field. Because the rates are `f64`,
/// a hit containing `NaN` never compares equal, not even to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct LookupHit {
    /// Static label of the provider whose pricing table matched.
    pub provider: &'static str,
    /// Price per one million input tokens.
    pub input_per_m: f64,
    /// Price per one million output tokens.
    pub output_per_m: f64,
}
24
25fn hit(provider: &'static str, p: &ModelPricing) -> LookupHit {
26 LookupHit {
27 provider,
28 input_per_m: p.input_per_million,
29 output_per_m: p.output_per_million,
30 }
31}
32
33pub fn lookup(model: &str) -> Result<LookupHit, PreviewError> {
34 if let Some(p) = tt_provider_anthropic::pricing::pricing_for(model) {
36 return Ok(hit("anthropic", &p));
37 }
38 if let Some(p) = tt_provider_openai::pricing::pricing_for(model) {
39 return Ok(hit("openai", &p));
40 }
41 if let Some(p) = tt_provider_gemini::pricing::pricing_for(model) {
42 return Ok(hit("gemini", &p));
43 }
44 let cfg = ClientConfig::default;
46 if let Some(p) = tt_provider_groq::GroqProvider::new(cfg()).pricing(model) {
47 return Ok(hit("groq", &p));
48 }
49 if let Some(p) = tt_provider_mistral::MistralProvider::new(cfg()).pricing(model) {
50 return Ok(hit("mistral", &p));
51 }
52 if let Some(p) = tt_provider_together::TogetherProvider::new(cfg()).pricing(model) {
53 return Ok(hit("together", &p));
54 }
55 if let Some(p) = tt_provider_openrouter::OpenRouterProvider::new(cfg()).pricing(model) {
56 return Ok(hit("openrouter", &p));
57 }
58 Err(PreviewError::UnknownModel(model.to_string()))
59}
60
61pub fn lookup_with_provider(model: &str, provider: &str) -> Result<LookupHit, PreviewError> {
71 let cfg = ClientConfig::default;
72 let found = match provider {
73 "anthropic" => {
74 tt_provider_anthropic::pricing::pricing_for(model).map(|p| hit("anthropic", &p))
75 }
76 "openai" => tt_provider_openai::pricing::pricing_for(model).map(|p| hit("openai", &p)),
77 "gemini" => tt_provider_gemini::pricing::pricing_for(model).map(|p| hit("gemini", &p)),
78 "groq" => tt_provider_groq::GroqProvider::new(cfg())
79 .pricing(model)
80 .map(|p| hit("groq", &p)),
81 "mistral" => tt_provider_mistral::MistralProvider::new(cfg())
82 .pricing(model)
83 .map(|p| hit("mistral", &p)),
84 "together" => tt_provider_together::TogetherProvider::new(cfg())
85 .pricing(model)
86 .map(|p| hit("together", &p)),
87 "openrouter" => tt_provider_openrouter::OpenRouterProvider::new(cfg())
88 .pricing(model)
89 .map(|p| hit("openrouter", &p)),
90 _ => None,
91 };
92 found.ok_or_else(|| PreviewError::UnknownModel(model.to_string()))
93}
94
95pub fn cost_usd(input_tokens: u32, output_tokens: u32, hit: &LookupHit) -> f64 {
97 let i = (input_tokens as f64) * hit.input_per_m / 1_000_000.0;
98 let o = (output_tokens as f64) * hit.output_per_m / 1_000_000.0;
99 i + o
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a lookup result is the `UnknownModel` error.
    fn assert_unknown_model(result: Result<LookupHit, PreviewError>) {
        assert!(matches!(result, Err(PreviewError::UnknownModel(_))));
    }

    #[test]
    fn cost_math_basics() {
        let rates = LookupHit {
            provider: "x",
            input_per_m: 3.0,
            output_per_m: 15.0,
        };
        // 1000 * 3 / 1e6 = 0.003 and 100 * 15 / 1e6 = 0.0015, so 0.0045 total.
        let c = cost_usd(1000, 100, &rates);
        assert!((c - 0.0045).abs() < 1e-9, "cost = {c}");
    }

    #[test]
    fn lookup_unknown_model_errors() {
        assert_unknown_model(lookup("does-not-exist-model"));
    }

    #[test]
    fn lookup_resolves_compat_provider_models() {
        // This model name is expected to resolve through the groq provider.
        let found = lookup("llama-3.1-8b-instant").expect("groq model should resolve");
        assert_eq!(found.provider, "groq");
        assert!(found.input_per_m > 0.0, "groq pricing should be > 0");
    }

    #[test]
    fn lookup_with_provider_attributes_to_named_provider() {
        let cases = [
            ("gpt-4o-mini", "openai", "openai carries gpt-4o-mini"),
            (
                "claude-haiku-4-5",
                "anthropic",
                "anthropic carries claude-haiku-4-5",
            ),
        ];
        for (model, provider, why) in cases {
            let found = lookup_with_provider(model, provider).expect(why);
            assert_eq!(found.provider, provider);
        }
    }

    #[test]
    fn lookup_with_provider_errors_when_provider_lacks_model() {
        // A recognised provider that does not carry the model.
        assert_unknown_model(lookup_with_provider("gpt-4o-mini", "groq"));
        // A provider name that is not recognised at all.
        assert_unknown_model(lookup_with_provider("gpt-4o-mini", "nope"));
    }
}