use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

use crate::client::{CreateMessageRequest, MessageResponse, StreamEvent};
use crate::error::Result;

/// Optional features a provider reports through [`LlmProvider::capabilities`].
///
/// Every flag is a plain `bool`; nothing in this type enforces them. The
/// per-field descriptions are inferred from the field names only — confirm
/// them against the concrete provider implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Presumably: responses can be delivered incrementally as a stream.
    pub streaming: bool,
    /// Presumably: the provider accepts tool definitions and can emit tool calls.
    pub tool_use: bool,
    /// Presumably: the provider exposes a separate reasoning ("thinking") output.
    pub thinking: bool,
    /// Presumably: prompt tokens can be cached and reused across requests.
    pub prompt_caching: bool,
}

/// Per-model pricing used to turn token counts into a cost.
///
/// The two base rates are prices per one million tokens; the cost functions
/// multiply them by a token count and divide by `1_000_000`. The currency unit
/// is whatever the caller expresses the rates in.
#[derive(Debug, Clone, PartialEq)]
pub struct CostRates {
    /// Price per one million regular (non-cached) input tokens.
    pub input_per_million: f64,
    /// Price per one million output tokens.
    pub output_per_million: f64,
    /// Factor applied to `input_per_million` for tokens read from the cache.
    ///
    /// `None` is treated as `1.0`, so cache reads bill at the plain input rate.
    pub cache_read_multiplier: Option<f64>,
    /// Factor applied to `input_per_million` for tokens written to the cache.
    ///
    /// `None` is treated as `1.0`, so cache writes bill at the plain input rate.
    pub cache_creation_multiplier: Option<f64>,
}

37impl CostRates {
38 pub fn compute(&self, input_tokens: u64, output_tokens: u64) -> f64 {
40 self.compute_with_cache(input_tokens, output_tokens, 0, 0)
41 }
42
43 pub fn compute_with_cache(
48 &self,
49 input_tokens: u64,
50 output_tokens: u64,
51 cache_read_tokens: u64,
52 cache_creation_tokens: u64,
53 ) -> f64 {
54 let read_rate = self.input_per_million * self.cache_read_multiplier.unwrap_or(1.0);
55 let create_rate = self.input_per_million * self.cache_creation_multiplier.unwrap_or(1.0);
56 (input_tokens as f64 * self.input_per_million
57 + cache_read_tokens as f64 * read_rate
58 + cache_creation_tokens as f64 * create_rate
59 + output_tokens as f64 * self.output_per_million)
60 / 1_000_000.0
61 }
62}
63
64#[async_trait]
70pub trait LlmProvider: Send + Sync {
71 fn name(&self) -> &str;
73
74 fn capabilities(&self) -> ProviderCapabilities;
76
77 fn cost_rates(&self, model: &str) -> CostRates;
79
80 async fn create_message(&self, request: &CreateMessageRequest) -> Result<MessageResponse>;
82
83 async fn create_message_stream(
85 &self,
86 request: &CreateMessageRequest,
87 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>>;
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds rates with no cache multipliers, so cached tokens bill at the
    /// plain input rate.
    fn simple_rates(input: f64, output: f64) -> CostRates {
        CostRates {
            input_per_million: input,
            output_per_million: output,
            cache_read_multiplier: None,
            cache_creation_multiplier: None,
        }
    }

    /// Rates shared by the cache tests: 3.0 in / 15.0 out per million tokens,
    /// cache reads at 0.1x and cache creation at 1.25x the input rate
    /// (0.3 and 3.75 per million respectively).
    fn cached_rates() -> CostRates {
        CostRates {
            input_per_million: 3.0,
            output_per_million: 15.0,
            cache_read_multiplier: Some(0.1),
            cache_creation_multiplier: Some(1.25),
        }
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn cost_rates_compute() {
        // 1M input at 2.0 plus 0.5M output at 8.0 = 2.0 + 4.0.
        let cost = simple_rates(2.0, 8.0).compute(1_000_000, 500_000);
        assert_close(cost, 6.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_zero_tokens() {
        let cost = simple_rates(10.0, 40.0).compute(0, 0);
        assert_close(cost, 0.0, 1e-9);
    }

    #[test]
    fn cost_rates_compute_small_usage() {
        // 100 * 2.5 + 50 * 10.0 = 750, scaled down by one million.
        let cost = simple_rates(2.5, 10.0).compute(100, 50);
        assert_close(cost, 750.0 / 1_000_000.0, 1e-12);
    }

    #[test]
    fn cost_rates_with_cache() {
        let cost = cached_rates().compute_with_cache(1000, 500, 10_000, 2000);
        let expected =
            (1000.0 * 3.0 + 10_000.0 * 0.3 + 2000.0 * 3.75 + 500.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_cache_read_only() {
        let cost = cached_rates().compute_with_cache(0, 200, 13_000, 0);
        let expected = (13_000.0 * 0.3 + 200.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_cache_creation_only() {
        let cost = cached_rates().compute_with_cache(500, 452, 0, 13_000);
        let expected = (500.0 * 3.0 + 13_000.0 * 3.75 + 452.0 * 15.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn cost_rates_no_cache_multipliers_bills_at_standard_rate() {
        // With no multipliers, all 1000 + 5000 + 3000 input-side tokens bill at 2.0.
        let cost = simple_rates(2.0, 8.0).compute_with_cache(1000, 500, 5000, 3000);
        let expected = (9000.0 * 2.0 + 500.0 * 8.0) / 1_000_000.0;
        assert_close(cost, expected, 1e-12);
    }

    #[test]
    fn multi_turn_cost_accumulation_with_cache() {
        let rates = cached_rates();

        // (input, output, cache read, cache creation) per turn: turn 1 writes
        // a 12k-token cache entry, turns 2 and 3 read it back.
        let turns: [(u64, u64, u64, u64); 3] = [
            (500, 800, 0, 12_000),
            (200, 400, 12_000, 0),
            (300, 600, 12_000, 0),
        ];
        let total_cost: f64 = turns
            .iter()
            .map(|&(input, output, read, creation)| {
                rates.compute_with_cache(input, output, read, creation)
            })
            .sum();

        let turn1 = (500.0 * 3.0 + 12_000.0 * 3.75 + 800.0 * 15.0) / 1_000_000.0;
        let turn2 = (200.0 * 3.0 + 12_000.0 * 0.3 + 400.0 * 15.0) / 1_000_000.0;
        let turn3 = (300.0 * 3.0 + 12_000.0 * 0.3 + 600.0 * 15.0) / 1_000_000.0;
        let expected = turn1 + turn2 + turn3;

        assert!(
            (total_cost - expected).abs() < 1e-12,
            "multi-turn total: expected {}, got {}",
            expected,
            total_cost
        );

        assert!(
            turn2 < turn1,
            "turn 2 should be cheaper than turn 1 (cache reads vs creation)"
        );
        assert!(
            turn3 < turn1,
            "turn 3 should be cheaper than turn 1 (cache reads vs creation)"
        );
    }
}