1pub mod anthropic;
7pub mod factory;
8pub mod openai_compat;
9pub mod pricing;
10
11use crate::error::LlmError;
12
13#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
15pub struct TokenUsage {
16 pub input_tokens: u32,
18 pub output_tokens: u32,
20}
21
22impl TokenUsage {
23 pub fn accumulate(&mut self, other: &TokenUsage) {
25 self.input_tokens += other.input_tokens;
26 self.output_tokens += other.output_tokens;
27 }
28}
29
/// A completed (non-streaming) response from an LLM provider.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// The generated completion text.
    pub text: String,
    /// Token accounting for this single request/response pair.
    pub usage: TokenUsage,
    /// Identifier of the model that produced the response
    /// (e.g. "gpt-4o-mini" as used in the tests below).
    pub model: String,
}
40
/// Tuning knobs for a single completion request.
///
/// `Default` provides conservative values (512 max tokens, 0.7
/// temperature, no system prompt) suitable for struct-update syntax:
/// `GenerationParams { max_tokens: 1024, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct GenerationParams {
    /// Upper bound on the number of tokens the model may generate.
    pub max_tokens: u32,
    /// Sampling temperature; higher values produce more varied output.
    pub temperature: f32,
    /// Optional system prompt. NOTE(review): [`LlmProvider::complete`]
    /// also takes a separate `system` argument — confirm which of the
    /// two each provider implementation actually honors.
    pub system_prompt: Option<String>,
}
51
52impl Default for GenerationParams {
53 fn default() -> Self {
54 Self {
55 max_tokens: 512,
56 temperature: 0.7,
57 system_prompt: None,
58 }
59 }
60}
61
/// Common interface implemented by every LLM backend (see the
/// `anthropic` and `openai_compat` submodules; instances are built via
/// `factory`).
///
/// `Send + Sync` is required so provider handles can be shared across
/// async tasks.
#[async_trait::async_trait]
pub trait LlmProvider: Send + Sync {
    /// Short human-readable provider name, used for logging/diagnostics.
    fn name(&self) -> &str;

    /// Sends a single-turn completion request.
    ///
    /// * `system` — system prompt for this request. NOTE(review):
    ///   `params.system_prompt` exists as well — confirm which one
    ///   implementations honor.
    /// * `user_message` — the user-turn content.
    /// * `params` — generation limits and sampling settings.
    ///
    /// Returns the generated text plus token usage, or an [`LlmError`]
    /// on transport/provider failure.
    async fn complete(
        &self,
        system: &str,
        user_message: &str,
        params: &GenerationParams,
    ) -> Result<LlmResponse, LlmError>;

    /// Cheap liveness probe; `Ok(())` when the backend is reachable
    /// and credentials are valid.
    async fn health_check(&self) -> Result<(), LlmError>;
}
84
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_usage_default_is_zero() {
        // A default-constructed counter must start from zero on both axes.
        let fresh = TokenUsage::default();
        assert_eq!(fresh.input_tokens, 0);
        assert_eq!(fresh.output_tokens, 0);
    }

    #[test]
    fn token_usage_accumulate() {
        // Accumulating one usage record into another sums per-field.
        let mut running = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        running.accumulate(&TokenUsage {
            input_tokens: 200,
            output_tokens: 80,
        });
        assert_eq!(running.input_tokens, 300);
        assert_eq!(running.output_tokens, 130);
    }

    #[test]
    fn token_usage_accumulate_multiple() {
        // Five accumulations: inputs 10+20+30+40+50, outputs 5+10+15+20+25.
        let mut running = TokenUsage::default();
        for step in 1u32..=5 {
            let delta = TokenUsage {
                input_tokens: step * 10,
                output_tokens: step * 5,
            };
            running.accumulate(&delta);
        }
        assert_eq!(running.input_tokens, 150);
        assert_eq!(running.output_tokens, 75);
    }

    #[test]
    fn token_usage_accumulate_zero() {
        // Adding an all-zero record is a no-op.
        let mut running = TokenUsage {
            input_tokens: 42,
            output_tokens: 17,
        };
        running.accumulate(&TokenUsage::default());
        assert_eq!(running.input_tokens, 42);
        assert_eq!(running.output_tokens, 17);
    }

    #[test]
    fn generation_params_default() {
        // Defaults documented on the type: 512 tokens, 0.7 temperature, no prompt.
        let defaults = GenerationParams::default();
        assert_eq!(defaults.max_tokens, 512);
        assert!((defaults.temperature - 0.7).abs() < f32::EPSILON);
        assert!(defaults.system_prompt.is_none());
    }

    #[test]
    fn generation_params_with_system_prompt() {
        // Struct-update syntax overrides only the named field.
        let customized = GenerationParams {
            system_prompt: Some("You are a helpful assistant.".to_string()),
            ..Default::default()
        };
        assert_eq!(
            customized.system_prompt.as_deref(),
            Some("You are a helpful assistant.")
        );
        assert_eq!(customized.max_tokens, 512);
    }

    #[test]
    fn llm_response_fields() {
        // Field round-trip through a literal construction.
        let reply = LlmResponse {
            text: "Hello, world!".to_string(),
            usage: TokenUsage {
                input_tokens: 10,
                output_tokens: 3,
            },
            model: "gpt-4o-mini".to_string(),
        };
        assert_eq!(reply.text, "Hello, world!");
        assert_eq!(reply.usage.input_tokens, 10);
        assert_eq!(reply.usage.output_tokens, 3);
        assert_eq!(reply.model, "gpt-4o-mini");
    }

    #[test]
    fn token_usage_serde_roundtrip() {
        // Serialize then deserialize must preserve both counters.
        let original = TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
        };
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: TokenUsage = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded.input_tokens, 100);
        assert_eq!(decoded.output_tokens, 50);
    }

    #[test]
    fn token_usage_clone() {
        // Clone produces an independent copy with equal fields.
        let original = TokenUsage {
            input_tokens: 42,
            output_tokens: 17,
        };
        let copy = original.clone();
        assert_eq!(copy.input_tokens, 42);
        assert_eq!(copy.output_tokens, 17);
    }

    #[test]
    fn generation_params_clone() {
        // Clone preserves every field, including the owned Option<String>.
        let original = GenerationParams {
            max_tokens: 1024,
            temperature: 0.5,
            system_prompt: Some("test".to_string()),
        };
        let copy = original.clone();
        assert_eq!(copy.max_tokens, 1024);
        assert!((copy.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(copy.system_prompt.as_deref(), Some("test"));
    }
}