Skip to main content

spec_ai/spec_ai_api/api/
models.rs

//! API request and response models.
2use crate::spec_ai_api::agent::safety::SafetyStats;
3use crate::spec_ai_api::config::SafetyConfig;
4use serde::{Deserialize, Serialize};
5
6/// Request to query the agent
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct QueryRequest {
9    /// The user's message/query
10    pub message: String,
11    /// Optional session ID for conversation continuity
12    pub session_id: Option<String>,
13    /// Optional agent profile to use
14    pub agent: Option<String>,
15    /// Whether to stream the response
16    #[serde(default)]
17    pub stream: bool,
18    /// Optional temperature override
19    pub temperature: Option<f32>,
20    /// Optional max tokens
21    pub max_tokens: Option<usize>,
22    /// Optional per-request safety overrides.
23    #[serde(default)]
24    pub safety: Option<SafetyConfig>,
25}
26
27/// Response from the agent
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct QueryResponse {
30    /// The agent's response message
31    pub response: String,
32    /// Session ID for this conversation
33    pub session_id: String,
34    /// Agent profile used
35    pub agent: String,
36    /// Tool calls made (if any)
37    #[serde(skip_serializing_if = "Vec::is_empty", default)]
38    pub tool_calls: Vec<ToolCallInfo>,
39    /// Token usage information
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub token_usage: Option<crate::spec_ai_api::agent::model::TokenUsage>,
42    /// Processing metadata
43    pub metadata: ResponseMetadata,
44}
45
46/// Information about a tool call
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ToolCallInfo {
49    /// Tool name
50    pub name: String,
51    /// Tool arguments
52    pub arguments: serde_json::Value,
53    /// Execution status
54    pub success: bool,
55    /// Tool output (if any)
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub output: Option<String>,
58    /// Error message if the tool failed
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub error: Option<String>,
61}
62
63/// Response metadata
64#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct ResponseMetadata {
66    /// Timestamp of response
67    pub timestamp: String,
68    /// Model used
69    pub model: String,
70    /// Processing time in milliseconds
71    pub processing_time_ms: u64,
72    /// Token usage information (for streaming end chunk)
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub token_usage: Option<crate::spec_ai_api::agent::model::TokenUsage>,
75    /// Unique identifier for correlating with telemetry
76    pub run_id: String,
77    /// Recursion/cost safety counters
78    #[serde(default)]
79    pub safety: SafetyStats,
80}
81
82/// Streaming response chunk
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(tag = "type")]
85pub enum StreamChunk {
86    /// Initial metadata
87    #[serde(rename = "start")]
88    Start { session_id: String, agent: String },
89    /// Content chunk
90    #[serde(rename = "chunk")]
91    Content { text: String },
92    /// Tool call notification
93    #[serde(rename = "tool_call")]
94    ToolCall {
95        name: String,
96        arguments: serde_json::Value,
97    },
98    /// Tool result
99    #[serde(rename = "tool_result")]
100    ToolResult {
101        name: String,
102        result: serde_json::Value,
103    },
104    /// End of stream
105    #[serde(rename = "end")]
106    End { metadata: ResponseMetadata },
107    /// Error occurred
108    #[serde(rename = "error")]
109    Error { message: String },
110}
111
112/// Error response
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct ErrorResponse {
115    /// Error message
116    pub error: String,
117    /// Error code
118    pub code: String,
119    /// Additional details
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub details: Option<serde_json::Value>,
122}
123
124impl ErrorResponse {
125    pub fn new(code: impl Into<String>, error: impl Into<String>) -> Self {
126        Self {
127            error: error.into(),
128            code: code.into(),
129            details: None,
130        }
131    }
132
133    pub fn with_details(mut self, details: serde_json::Value) -> Self {
134        self.details = Some(details);
135        self
136    }
137}
138
139/// Health check response
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct HealthResponse {
142    /// Service status
143    pub status: String,
144    /// Server version
145    pub version: String,
146    /// Uptime in seconds
147    pub uptime_seconds: u64,
148    /// Active sessions count
149    pub active_sessions: usize,
150}
151
152/// Agent list response
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct AgentListResponse {
155    /// Available agents
156    pub agents: Vec<AgentInfo>,
157}
158
159/// Agent information
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct AgentInfo {
162    /// Agent ID
163    pub id: String,
164    /// Agent description/prompt
165    pub description: String,
166    /// Allowed tools
167    #[serde(skip_serializing_if = "Vec::is_empty", default)]
168    pub allowed_tools: Vec<String>,
169    /// Denied tools
170    #[serde(skip_serializing_if = "Vec::is_empty", default)]
171    pub denied_tools: Vec<String>,
172}
173
174/// Semantic code search request
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct SearchRequest {
177    /// Search query text
178    pub query: String,
179    /// Repository root to search (defaults to current dir)
180    pub root: Option<String>,
181    /// Page number (0-indexed, default 0)
182    #[serde(default)]
183    pub page: usize,
184    /// Results per page (default 10, max 25)
185    pub page_size: Option<usize>,
186    /// Force re-generation of embeddings
187    #[serde(default)]
188    pub refresh: bool,
189}
190
191/// Search result item
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct SearchResult {
194    /// File path
195    pub path: String,
196    /// Similarity score
197    pub similarity: f32,
198    /// Code snippet
199    pub snippet: String,
200}
201
202/// Paginated search response
203#[derive(Debug, Clone, Serialize, Deserialize)]
204pub struct SearchResponse {
205    /// Search query
206    pub query: String,
207    /// Repository root searched
208    pub root: String,
209    /// Current page (0-indexed)
210    pub page: usize,
211    /// Results per page
212    pub page_size: usize,
213    /// Total results available
214    pub total_results: usize,
215    /// Total pages
216    pub total_pages: usize,
217    /// Results for this page
218    pub results: Vec<SearchResult>,
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata fixture shared by the response and stream tests.
    fn sample_metadata() -> ResponseMetadata {
        ResponseMetadata {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            model: "mock".to_string(),
            processing_time_ms: 100,
            token_usage: None,
            run_id: "run-1".to_string(),
            safety: SafetyStats::default(),
        }
    }

    #[test]
    fn test_query_request_serialization() {
        let req = QueryRequest {
            message: "Hello".to_string(),
            session_id: Some("sess123".to_string()),
            agent: Some("coder".to_string()),
            stream: false,
            temperature: Some(0.7),
            max_tokens: Some(1000),
            safety: Some(SafetyConfig {
                max_model_calls_per_run: 8,
                ..SafetyConfig::default()
            }),
        };

        let json = serde_json::to_string(&req).unwrap();
        let deserialized: QueryRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.message, "Hello");
        assert_eq!(deserialized.session_id, Some("sess123".to_string()));
        assert_eq!(deserialized.safety.unwrap().max_model_calls_per_run, 8);
    }

    #[test]
    fn test_query_request_optional_fields_default() {
        // Clients may send only the message; everything else must default.
        let req: QueryRequest = serde_json::from_str(r#"{"message":"hi"}"#).unwrap();

        assert_eq!(req.message, "hi");
        assert!(req.session_id.is_none());
        assert!(req.agent.is_none());
        assert!(!req.stream);
        assert!(req.temperature.is_none());
        assert!(req.max_tokens.is_none());
        assert!(req.safety.is_none());
    }

    #[test]
    fn test_query_response_serialization() {
        let resp = QueryResponse {
            response: "Hi there".to_string(),
            session_id: "sess123".to_string(),
            agent: "coder".to_string(),
            tool_calls: vec![],
            token_usage: None,
            metadata: sample_metadata(),
        };

        let json = serde_json::to_string(&resp).unwrap();

        // Empty tool_calls and absent token_usage are skipped on output.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(value.get("tool_calls").is_none());
        assert!(value.get("token_usage").is_none());

        let deserialized: QueryResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.response, "Hi there");
        assert_eq!(deserialized.session_id, "sess123");
        assert!(deserialized.tool_calls.is_empty());
    }

    #[test]
    fn test_stream_chunk_variants() {
        // Every variant with the wire tag clients rely on.
        let cases = vec![
            (
                StreamChunk::Start {
                    session_id: "sess1".to_string(),
                    agent: "coder".to_string(),
                },
                "start",
            ),
            (
                StreamChunk::Content {
                    text: "Hello".to_string(),
                },
                "chunk",
            ),
            (
                StreamChunk::ToolCall {
                    name: "search".to_string(),
                    arguments: serde_json::json!({"q": "x"}),
                },
                "tool_call",
            ),
            (
                StreamChunk::ToolResult {
                    name: "search".to_string(),
                    result: serde_json::json!(["a"]),
                },
                "tool_result",
            ),
            (
                StreamChunk::End {
                    metadata: sample_metadata(),
                },
                "end",
            ),
            (
                StreamChunk::Error {
                    message: "boom".to_string(),
                },
                "error",
            ),
        ];

        for (chunk, expected_tag) in cases {
            let value = serde_json::to_value(&chunk).unwrap();
            assert_eq!(value["type"], expected_tag);

            // Round-trip through a JSON string must come back as the same variant.
            let json = serde_json::to_string(&chunk).unwrap();
            let back: StreamChunk = serde_json::from_str(&json).unwrap();
            assert_eq!(
                std::mem::discriminant(&back),
                std::mem::discriminant(&chunk)
            );
        }
    }

    #[test]
    fn test_tool_call_info_skips_absent_fields() {
        let info = ToolCallInfo {
            name: "search".to_string(),
            arguments: serde_json::json!({}),
            success: true,
            output: None,
            error: None,
        };

        let value = serde_json::to_value(&info).unwrap();
        assert_eq!(value["success"], true);
        assert!(value.get("output").is_none());
        assert!(value.get("error").is_none());
    }

    #[test]
    fn test_error_response() {
        let err = ErrorResponse::new("invalid_request", "Invalid API key")
            .with_details(serde_json::json!({"hint": "Check your configuration"}));

        assert_eq!(err.error, "Invalid API key");
        assert_eq!(err.code, "invalid_request");
        assert_eq!(
            err.details,
            Some(serde_json::json!({"hint": "Check your configuration"}))
        );
    }

    #[test]
    fn test_error_response_omits_missing_details() {
        let value =
            serde_json::to_value(ErrorResponse::new("not_found", "No such session")).unwrap();

        assert_eq!(value["code"], "not_found");
        assert_eq!(value["error"], "No such session");
        assert!(value.get("details").is_none());
    }

    #[test]
    fn test_health_response() {
        let health = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
            uptime_seconds: 3600,
            active_sessions: 5,
        };

        let json = serde_json::to_string(&health).unwrap();
        let deserialized: HealthResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.status, "healthy");
        assert_eq!(deserialized.uptime_seconds, 3600);
    }
}