1use crate::reasoning::conversation::Conversation;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
/// Description of a tool the model may call, advertised to the provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name the model uses to reference this tool in a call request.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema describing the tool's accepted arguments.
    pub parameters: serde_json::Value,
}
22
/// A tool invocation requested by the model in a response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Call identifier — presumably used to correlate the tool's result
    /// back to this request; confirm against the provider integration.
    pub id: String,
    /// Name of the tool to invoke (should match a [`ToolDefinition`] name).
    pub name: String,
    /// Tool arguments as a raw JSON string (not yet parsed).
    pub arguments: String,
}
33
/// Why the model stopped generating.
///
/// Serialized in snake_case, e.g. `ToolCalls` -> `"tool_calls"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    /// Natural end of generation.
    Stop,
    /// The model requested one or more tool calls.
    ToolCalls,
    /// Generation was cut off by the token limit.
    MaxTokens,
    /// Generation was stopped by a content filter.
    ContentFilter,
}
47
/// Output format requested from the model.
///
/// Internally tagged: serializes as `{"type": "text" | "json_object" | "json_schema", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    /// Free-form text output.
    #[serde(rename = "text")]
    Text,
    /// Output constrained to a valid JSON object.
    #[serde(rename = "json_object")]
    JsonObject,
    /// Output constrained to JSON matching the given schema.
    #[serde(rename = "json_schema")]
    JsonSchema {
        /// JSON Schema the output must conform to.
        schema: serde_json::Value,
        /// Optional schema name; omitted from JSON when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
68
/// Token accounting for a single completion call.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens — presumably prompt + completion as reported by the
    /// provider; not enforced here.
    pub total_tokens: u32,
}
79
/// Tunable parameters for a single completion request.
///
/// Every field has a serde default, so an empty JSON object deserializes to
/// the same values as [`InferenceOptions::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceOptions {
    /// Maximum number of tokens to generate (default: 4096).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: u32,
    /// Sampling temperature (default: 0.3).
    #[serde(default = "default_temperature")]
    pub temperature: f32,
    /// Tools the model may call; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_definitions: Vec<ToolDefinition>,
    /// Requested output format (default: plain text).
    #[serde(default = "default_response_format")]
    pub response_format: ResponseFormat,
    /// Model override; `None` leaves model selection to the provider.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Additional free-form parameters; omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra: HashMap<String, serde_json::Value>,
}
102
/// Serde default for [`InferenceOptions::max_tokens`].
fn default_max_tokens() -> u32 {
    4096
}
106
/// Serde default for [`InferenceOptions::temperature`].
fn default_temperature() -> f32 {
    0.3
}
110
/// Serde default for [`InferenceOptions::response_format`].
fn default_response_format() -> ResponseFormat {
    ResponseFormat::Text
}
114
115impl Default for InferenceOptions {
116 fn default() -> Self {
117 Self {
118 max_tokens: default_max_tokens(),
119 temperature: default_temperature(),
120 tool_definitions: Vec::new(),
121 response_format: ResponseFormat::Text,
122 model: None,
123 extra: HashMap::new(),
124 }
125 }
126}
127
/// Result of a completion call from an inference provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Generated text content (may be empty when only tool calls were made).
    pub content: String,
    /// Tool invocations requested by the model; empty when none.
    pub tool_calls: Vec<ToolCallRequest>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Token usage for this call.
    pub usage: Usage,
    /// Identifier of the model that produced the response.
    pub model: String,
}
142
143impl InferenceResponse {
144 pub fn has_tool_calls(&self) -> bool {
146 !self.tool_calls.is_empty()
147 }
148}
149
/// Errors an [`InferenceProvider`] can return from a completion call.
#[derive(Debug, thiserror::Error)]
pub enum InferenceError {
    /// Generic provider-side failure, carrying the provider's message.
    #[error("Provider error: {0}")]
    Provider(String),

    /// The provider throttled the request; retry after the given delay.
    #[error("Rate limited, retry after {retry_after_ms}ms")]
    RateLimited { retry_after_ms: u64 },

    /// The request exceeds the model's context window
    /// (tokens requested, tokens available).
    #[error("Context window exceeded: {0} tokens requested, {1} available")]
    ContextOverflow(usize, usize),

    /// The named model is not available from this provider.
    #[error("Model not available: {0}")]
    ModelUnavailable(String),

    /// The request was malformed or rejected before inference ran.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    /// The call did not complete within the given duration.
    #[error("Timeout after {0:?}")]
    Timeout(std::time::Duration),

    /// The provider's response could not be parsed.
    #[error("Response parse error: {0}")]
    ParseError(String),
}
174
/// Abstraction over an LLM backend capable of chat completion.
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks.
#[async_trait]
pub trait InferenceProvider: Send + Sync {
    /// Run a single completion over `conversation` using `options`.
    ///
    /// # Errors
    /// Returns an [`InferenceError`] on provider failure, rate limiting,
    /// context overflow, timeout, or an unparseable response.
    async fn complete(
        &self,
        conversation: &Conversation,
        options: &InferenceOptions,
    ) -> Result<InferenceResponse, InferenceError>;

    /// Stable identifier for this provider implementation.
    fn provider_name(&self) -> &str;

    /// The provider's default model identifier — presumably used when
    /// [`InferenceOptions::model`] is `None`; confirm against implementors.
    fn default_model(&self) -> &str;

    /// Whether the provider supports native tool/function calling.
    fn supports_native_tools(&self) -> bool;

    /// Whether the provider supports structured output
    /// (see [`ResponseFormat::JsonObject`] / [`ResponseFormat::JsonSchema`]).
    fn supports_structured_output(&self) -> bool;
}
203
#[cfg(test)]
mod tests {
    use super::*;

    // Code defaults must line up with the serde default functions.
    #[test]
    fn test_inference_options_default() {
        let defaults = InferenceOptions::default();
        assert_eq!(defaults.max_tokens, 4096);
        assert!((defaults.temperature - 0.3).abs() < f32::EPSILON);
        assert!(defaults.tool_definitions.is_empty());
        assert!(matches!(defaults.response_format, ResponseFormat::Text));
    }

    // A tool definition survives a JSON round trip intact.
    #[test]
    fn test_tool_definition_serde() {
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "query": { "type": "string" }
            },
            "required": ["query"]
        });
        let original = ToolDefinition {
            name: "web_search".into(),
            description: "Search the web".into(),
            parameters: params,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ToolDefinition = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.name, "web_search");
    }

    // The internally-tagged enum emits its rename strings.
    #[test]
    fn test_response_format_serde() {
        let encoded = serde_json::to_string(&ResponseFormat::Text).unwrap();
        assert!(encoded.contains("text"));

        let structured = ResponseFormat::JsonSchema {
            schema: serde_json::json!({"type": "object"}),
            name: Some("MySchema".into()),
        };
        let encoded = serde_json::to_string(&structured).unwrap();
        assert!(encoded.contains("json_schema"));
        assert!(encoded.contains("MySchema"));
    }

    // has_tool_calls reflects whether the tool_calls vec is non-empty.
    #[test]
    fn test_inference_response_has_tool_calls() {
        let call = ToolCallRequest {
            id: "tc_1".into(),
            name: "search".into(),
            arguments: "{}".into(),
        };
        let with_tools = InferenceResponse {
            content: String::new(),
            tool_calls: vec![call],
            finish_reason: FinishReason::ToolCalls,
            usage: Usage::default(),
            model: "test".into(),
        };
        assert!(with_tools.has_tool_calls());

        let without_tools = InferenceResponse {
            content: "Hello".into(),
            tool_calls: Vec::new(),
            finish_reason: FinishReason::Stop,
            usage: Usage::default(),
            model: "test".into(),
        };
        assert!(!without_tools.has_tool_calls());
    }

    // snake_case rename round-trips through serde.
    #[test]
    fn test_finish_reason_serde() {
        let encoded = serde_json::to_string(&FinishReason::ToolCalls).unwrap();
        assert_eq!(encoded, "\"tool_calls\"");
        let decoded: FinishReason = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, FinishReason::ToolCalls);
    }
}