Skip to main content

symbi_runtime/reasoning/
inference.rs

//! Unified inference provider trait
//!
//! Defines the `InferenceProvider` trait that abstracts over cloud LLM APIs
//! and local SLM runners, adding tool calling and structured output support
//! on top of the existing `LlmClient` and `SlmRunner`.
6
7use crate::reasoning::conversation::Conversation;
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// A tool definition that can be provided to an inference call.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ToolDefinition {
15    /// Tool name (must match the name the LLM will use to call it).
16    pub name: String,
17    /// Human-readable description of what the tool does.
18    pub description: String,
19    /// JSON Schema describing the tool's parameters.
20    pub parameters: serde_json::Value,
21}
22
23/// A tool call request returned by the model.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct ToolCallRequest {
26    /// Unique identifier for this tool call.
27    pub id: String,
28    /// Name of the tool to invoke.
29    pub name: String,
30    /// JSON-encoded arguments for the tool.
31    pub arguments: String,
32}
33
34/// The reason the model stopped generating.
35#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "snake_case")]
37pub enum FinishReason {
38    /// Model produced a complete response.
39    Stop,
40    /// Model wants to call one or more tools.
41    ToolCalls,
42    /// Generation was truncated due to max_tokens.
43    MaxTokens,
44    /// Generation was truncated due to content filter.
45    ContentFilter,
46}
47
48/// Desired response format from the model.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50#[serde(tag = "type")]
51pub enum ResponseFormat {
52    /// Free-form text response.
53    #[serde(rename = "text")]
54    Text,
55    /// JSON object response (model is instructed to return valid JSON).
56    #[serde(rename = "json_object")]
57    JsonObject,
58    /// JSON response conforming to a specific schema.
59    #[serde(rename = "json_schema")]
60    JsonSchema {
61        /// The JSON schema the response must conform to.
62        schema: serde_json::Value,
63        /// Optional name for the schema (used in API calls).
64        #[serde(default, skip_serializing_if = "Option::is_none")]
65        name: Option<String>,
66    },
67}
68
69/// Token usage information.
70#[derive(Debug, Clone, Default, Serialize, Deserialize)]
71pub struct Usage {
72    /// Tokens in the prompt/input.
73    pub prompt_tokens: u32,
74    /// Tokens in the completion/output.
75    pub completion_tokens: u32,
76    /// Total tokens used.
77    pub total_tokens: u32,
78}
79
80/// Options for an inference call.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct InferenceOptions {
83    /// Maximum tokens to generate.
84    #[serde(default = "default_max_tokens")]
85    pub max_tokens: u32,
86    /// Sampling temperature (0.0 = deterministic, 1.0 = creative).
87    #[serde(default = "default_temperature")]
88    pub temperature: f32,
89    /// Tool definitions available for this call.
90    #[serde(default, skip_serializing_if = "Vec::is_empty")]
91    pub tool_definitions: Vec<ToolDefinition>,
92    /// Desired response format.
93    #[serde(default = "default_response_format")]
94    pub response_format: ResponseFormat,
95    /// Optional model override (provider decides default otherwise).
96    #[serde(default, skip_serializing_if = "Option::is_none")]
97    pub model: Option<String>,
98    /// Additional provider-specific parameters.
99    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
100    pub extra: HashMap<String, serde_json::Value>,
101}
102
/// Default value for `InferenceOptions::max_tokens`.
///
/// Referenced by the field's `#[serde(default = ...)]` attribute and by
/// `InferenceOptions::default()`.
fn default_max_tokens() -> u32 {
    4096
}
106
/// Default value for `InferenceOptions::temperature`.
///
/// Referenced by the field's `#[serde(default = ...)]` attribute and by
/// `InferenceOptions::default()`.
fn default_temperature() -> f32 {
    0.3
}
110
111fn default_response_format() -> ResponseFormat {
112    ResponseFormat::Text
113}
114
115impl Default for InferenceOptions {
116    fn default() -> Self {
117        Self {
118            max_tokens: default_max_tokens(),
119            temperature: default_temperature(),
120            tool_definitions: Vec::new(),
121            response_format: ResponseFormat::Text,
122            model: None,
123            extra: HashMap::new(),
124        }
125    }
126}
127
128/// Response from an inference call.
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct InferenceResponse {
131    /// Text content of the response.
132    pub content: String,
133    /// Tool calls requested by the model (empty if none).
134    pub tool_calls: Vec<ToolCallRequest>,
135    /// Why the model stopped generating.
136    pub finish_reason: FinishReason,
137    /// Token usage statistics.
138    pub usage: Usage,
139    /// The model ID that actually served the request.
140    pub model: String,
141}
142
143impl InferenceResponse {
144    /// Returns true if the model requested tool calls.
145    pub fn has_tool_calls(&self) -> bool {
146        !self.tool_calls.is_empty()
147    }
148}
149
150/// Errors that can occur during inference.
151#[derive(Debug, thiserror::Error)]
152pub enum InferenceError {
153    #[error("Provider error: {0}")]
154    Provider(String),
155
156    #[error("Rate limited, retry after {retry_after_ms}ms")]
157    RateLimited { retry_after_ms: u64 },
158
159    #[error("Context window exceeded: {0} tokens requested, {1} available")]
160    ContextOverflow(usize, usize),
161
162    #[error("Model not available: {0}")]
163    ModelUnavailable(String),
164
165    #[error("Invalid request: {0}")]
166    InvalidRequest(String),
167
168    #[error("Timeout after {0:?}")]
169    Timeout(std::time::Duration),
170
171    #[error("Response parse error: {0}")]
172    ParseError(String),
173}
174
175/// Unified trait for inference providers (cloud LLMs and local SLMs).
176///
177/// Wraps existing `LlmClient` and `SlmRunner` to add:
178/// - Multi-turn conversation support
179/// - Tool calling
180/// - Structured output (response_format)
181/// - Token usage tracking
182#[async_trait]
183pub trait InferenceProvider: Send + Sync {
184    /// Run inference on a conversation with the given options.
185    async fn complete(
186        &self,
187        conversation: &Conversation,
188        options: &InferenceOptions,
189    ) -> Result<InferenceResponse, InferenceError>;
190
191    /// Get the provider's name for logging and routing.
192    fn provider_name(&self) -> &str;
193
194    /// Get the default model ID for this provider.
195    fn default_model(&self) -> &str;
196
197    /// Check if this provider supports tool calling natively.
198    fn supports_native_tools(&self) -> bool;
199
200    /// Check if this provider supports structured output natively.
201    fn supports_structured_output(&self) -> bool;
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// `InferenceOptions::default()` must match the documented defaults.
    #[test]
    fn test_inference_options_default() {
        let opts = InferenceOptions::default();
        assert_eq!(opts.max_tokens, 4096);
        assert!((opts.temperature - 0.3).abs() < f32::EPSILON);
        assert!(opts.tool_definitions.is_empty());
        assert!(matches!(opts.response_format, ResponseFormat::Text));
        assert!(opts.model.is_none());
        assert!(opts.extra.is_empty());
    }

    /// A `ToolDefinition` survives a JSON round trip.
    #[test]
    fn test_tool_definition_serde() {
        let tool = ToolDefinition {
            name: "web_search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "query": { "type": "string" }
                },
                "required": ["query"]
            }),
        };
        let json = serde_json::to_string(&tool).unwrap();
        let restored: ToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.name, "web_search");
        assert_eq!(restored.description, "Search the web");
    }

    /// `ResponseFormat` serializes with its renamed tag values.
    #[test]
    fn test_response_format_serde() {
        let text = ResponseFormat::Text;
        let json = serde_json::to_string(&text).unwrap();
        assert!(json.contains("text"));

        let schema = ResponseFormat::JsonSchema {
            schema: serde_json::json!({"type": "object"}),
            name: Some("MySchema".into()),
        };
        let json = serde_json::to_string(&schema).unwrap();
        assert!(json.contains("json_schema"));
        assert!(json.contains("MySchema"));
    }

    /// `has_tool_calls` reflects whether `tool_calls` is non-empty.
    #[test]
    fn test_inference_response_has_tool_calls() {
        let resp = InferenceResponse {
            content: String::new(),
            tool_calls: vec![ToolCallRequest {
                id: "tc_1".into(),
                name: "search".into(),
                arguments: "{}".into(),
            }],
            finish_reason: FinishReason::ToolCalls,
            usage: Usage::default(),
            model: "test".into(),
        };
        assert!(resp.has_tool_calls());

        let resp_no_tools = InferenceResponse {
            content: "Hello".into(),
            tool_calls: vec![],
            finish_reason: FinishReason::Stop,
            usage: Usage::default(),
            model: "test".into(),
        };
        assert!(!resp_no_tools.has_tool_calls());
    }

    /// `FinishReason` uses snake_case on the wire and round-trips.
    #[test]
    fn test_finish_reason_serde() {
        let json = serde_json::to_string(&FinishReason::ToolCalls).unwrap();
        assert_eq!(json, "\"tool_calls\"");
        let restored: FinishReason = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, FinishReason::ToolCalls);
    }
}