1use std::pin::Pin;
4
5use async_trait::async_trait;
6use futures::Stream;
7
8use crate::client::{CreateMessageRequest, MessageResponse, StreamEvent};
9use crate::error::Result;
10
/// Optional features a provider advertises.
///
/// `Default` yields every flag `false`, i.e. a provider that advertises none of
/// the optional features.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ProviderCapabilities {
    /// The provider supports streamed responses.
    pub streaming: bool,
    /// The provider supports tool use.
    pub tool_use: bool,
    /// The provider supports "thinking" output.
    pub thinking: bool,
    /// The provider supports prompt caching.
    pub prompt_caching: bool,
}
23
/// Token pricing for a model, expressed per one million tokens.
///
/// The cache multipliers scale `input_per_million`. A multiplier of `None`
/// means that class of cache token is billed at the plain input rate
/// (equivalent to a multiplier of `1.0`).
#[derive(Debug, Clone, PartialEq)]
pub struct CostRates {
    /// Price of one million (non-cache) input tokens.
    pub input_per_million: f64,
    /// Price of one million output tokens.
    pub output_per_million: f64,
    /// Factor applied to `input_per_million` for cache-read tokens; `None` means `1.0`.
    pub cache_read_multiplier: Option<f64>,
    /// Factor applied to `input_per_million` for cache-creation tokens; `None` means `1.0`.
    pub cache_creation_multiplier: Option<f64>,
}
36
37impl CostRates {
38 pub fn compute(&self, input_tokens: u64, output_tokens: u64) -> f64 {
40 self.compute_with_cache(input_tokens, output_tokens, 0, 0)
41 }
42
43 pub fn compute_with_cache(
48 &self,
49 input_tokens: u64,
50 output_tokens: u64,
51 cache_read_tokens: u64,
52 cache_creation_tokens: u64,
53 ) -> f64 {
54 let read_rate = self.input_per_million * self.cache_read_multiplier.unwrap_or(1.0);
55 let create_rate = self.input_per_million * self.cache_creation_multiplier.unwrap_or(1.0);
56 (input_tokens as f64 * self.input_per_million
57 + cache_read_tokens as f64 * read_rate
58 + cache_creation_tokens as f64 * create_rate
59 + output_tokens as f64 * self.output_per_million)
60 / 1_000_000.0
61 }
62}
63
/// Common interface over an LLM backend.
///
/// Implementors must be `Send + Sync`. The `async` methods are lowered by
/// `#[async_trait]` so the trait can be used as a trait object.
// NOTE(review): object-safety use (e.g. `Box<dyn LlmProvider>`) is assumed from
// the `async_trait` + `Send + Sync` bounds — confirm against callers.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Identifier for this provider.
    // NOTE(review): presumably a short human-readable name used in logs/config — confirm.
    fn name(&self) -> &str;

    /// Optional features this provider advertises; see [`ProviderCapabilities`].
    fn capabilities(&self) -> ProviderCapabilities;

    /// Pricing for `model`.
    ///
    /// The returned [`CostRates`] turns token counts into a cost via
    /// [`CostRates::compute`] or [`CostRates::compute_with_cache`].
    // NOTE(review): the result for an unrecognised model name is implementation-defined
    // (the signature cannot report "unknown") — confirm the expected fallback.
    fn cost_rates(&self, model: &str) -> CostRates;

    /// Creates a message from `request` and returns the complete [`MessageResponse`].
    ///
    /// # Errors
    /// Returns the crate's error type if the request fails.
    async fn create_message(&self, request: &CreateMessageRequest) -> Result<MessageResponse>;

    /// Streaming counterpart of [`LlmProvider::create_message`].
    ///
    /// The outer `Result` reports failure to start the stream. Each item of the
    /// returned pinned, boxed, `Send` stream is its own `Result<StreamEvent>`,
    /// so errors can also surface mid-stream.
    // NOTE(review): presumably only meaningful when `capabilities().streaming` is true — confirm.
    async fn create_message_stream(
        &self,
        request: &CreateMessageRequest,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>>;
}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Rates with no cache multipliers, so cache traffic bills at the plain input rate.
    fn simple_rates(input: f64, output: f64) -> CostRates {
        CostRates {
            input_per_million: input,
            output_per_million: output,
            cache_read_multiplier: None,
            cache_creation_multiplier: None,
        }
    }

    /// 3.0 input / 15.0 output per million, with 0.1x cache reads and 1.25x cache creation.
    fn cached_rates() -> CostRates {
        CostRates {
            input_per_million: 3.0,
            output_per_million: 15.0,
            cache_read_multiplier: Some(0.1),
            cache_creation_multiplier: Some(1.25),
        }
    }

    /// Asserts that `actual` is within `tol` of `expected`.
    ///
    /// `#[track_caller]` makes a failure report the calling test's line, not this helper's.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn cost_rates_compute() {
        let cost = simple_rates(2.0, 8.0).compute(1_000_000, 500_000);
        // 1M input * 2.0 + 0.5M output * 8.0 = 6.0
        assert_close(cost, 6.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_zero_tokens() {
        let cost = simple_rates(10.0, 40.0).compute(0, 0);
        assert_close(cost, 0.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_small_usage() {
        let cost = simple_rates(2.5, 10.0).compute(100, 50);
        // 100 * 2.5 + 50 * 10.0 = 750 per-million units
        let expected = 750.0 / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn compute_equals_compute_with_cache_without_cache_tokens() {
        let rates = cached_rates();
        let plain = rates.compute(1_234, 567);
        let explicit = rates.compute_with_cache(1_234, 567, 0, 0);
        // `compute` delegates directly, so the results must be bit-identical.
        assert_eq!(plain.to_bits(), explicit.to_bits());
    }

    #[test]
    fn cost_rates_with_cache() {
        let cost = cached_rates().compute_with_cache(1000, 500, 10_000, 2000);
        // Cache reads bill at 3.0 * 0.1 = 0.3; cache creation at 3.0 * 1.25 = 3.75.
        let expected =
            (1000.0 * 3.0 + 10_000.0 * 0.3 + 2000.0 * 3.75 + 500.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_cache_read_only() {
        let cost = cached_rates().compute_with_cache(0, 200, 13_000, 0);
        let expected = (13_000.0 * 0.3 + 200.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_cache_creation_only() {
        let cost = cached_rates().compute_with_cache(500, 452, 0, 13_000);
        let expected = (500.0 * 3.0 + 13_000.0 * 3.75 + 452.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_no_cache_multipliers_bills_at_standard_rate() {
        let cost = simple_rates(2.0, 8.0).compute_with_cache(1000, 500, 5000, 3000);
        // With no multipliers, all 1000 + 5000 + 3000 input-side tokens bill at 2.0.
        let expected = (9000.0 * 2.0 + 500.0 * 8.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_mixed_multipliers() {
        let rates = CostRates {
            cache_creation_multiplier: None,
            ..cached_rates()
        };
        let cost = rates.compute_with_cache(0, 0, 10_000, 10_000);
        // Reads use the 0.1 multiplier; creation falls back to the plain 3.0 input rate.
        let expected = (10_000.0 * 0.3 + 10_000.0 * 3.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn multi_turn_cost_accumulation_with_cache() {
        let rates = cached_rates();

        // Turn 1 creates a 12k-token cached prefix; turns 2 and 3 read it back.
        let turn1 = rates.compute_with_cache(500, 800, 0, 12_000);
        let turn2 = rates.compute_with_cache(200, 400, 12_000, 0);
        let turn3 = rates.compute_with_cache(300, 600, 12_000, 0);
        let total_cost = turn1 + turn2 + turn3;

        let expected_turn1 = (500.0 * 3.0 + 12_000.0 * 3.75 + 800.0 * 15.0) / 1_000_000.0;
        let expected_turn2 = (200.0 * 3.0 + 12_000.0 * 0.3 + 400.0 * 15.0) / 1_000_000.0;
        let expected_turn3 = (300.0 * 3.0 + 12_000.0 * 0.3 + 600.0 * 15.0) / 1_000_000.0;
        assert_close(turn1, expected_turn1, 1e-12);
        assert_close(turn2, expected_turn2, 1e-12);
        assert_close(turn3, expected_turn3, 1e-12);

        let expected = expected_turn1 + expected_turn2 + expected_turn3;
        assert!(
            (total_cost - expected).abs() < 1e-12,
            "multi-turn total: expected {}, got {}",
            expected,
            total_cost
        );

        // Compare the *computed* costs, not the hand-written expectations, so these
        // checks actually exercise `compute_with_cache`.
        assert!(turn2 < turn1, "turn 2 should be cheaper than turn 1 (cache reads vs creation)");
        assert!(turn3 < turn1, "turn 3 should be cheaper than turn 1 (cache reads vs creation)");
    }
}