Skip to main content

spec_ai/spec_ai_api/api/
models.rs

//! API request and response models
2use serde::{Deserialize, Serialize};
3
4/// Request to query the agent
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct QueryRequest {
7    /// The user's message/query
8    pub message: String,
9    /// Optional session ID for conversation continuity
10    pub session_id: Option<String>,
11    /// Optional agent profile to use
12    pub agent: Option<String>,
13    /// Whether to stream the response
14    #[serde(default)]
15    pub stream: bool,
16    /// Optional temperature override
17    pub temperature: Option<f32>,
18    /// Optional max tokens
19    pub max_tokens: Option<usize>,
20}
21
22/// Response from the agent
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct QueryResponse {
25    /// The agent's response message
26    pub response: String,
27    /// Session ID for this conversation
28    pub session_id: String,
29    /// Agent profile used
30    pub agent: String,
31    /// Tool calls made (if any)
32    #[serde(skip_serializing_if = "Vec::is_empty", default)]
33    pub tool_calls: Vec<ToolCallInfo>,
34    /// Processing metadata
35    pub metadata: ResponseMetadata,
36}
37
38/// Information about a tool call
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ToolCallInfo {
41    /// Tool name
42    pub name: String,
43    /// Tool arguments
44    pub arguments: serde_json::Value,
45    /// Execution status
46    pub success: bool,
47    /// Tool output (if any)
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub output: Option<String>,
50    /// Error message if the tool failed
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub error: Option<String>,
53}
54
55/// Response metadata
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct ResponseMetadata {
58    /// Timestamp of response
59    pub timestamp: String,
60    /// Model used
61    pub model: String,
62    /// Processing time in milliseconds
63    pub processing_time_ms: u64,
64    /// Unique identifier for correlating with telemetry
65    pub run_id: String,
66}
67
68/// Streaming response chunk
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(tag = "type")]
71pub enum StreamChunk {
72    /// Initial metadata
73    #[serde(rename = "start")]
74    Start { session_id: String, agent: String },
75    /// Content chunk
76    #[serde(rename = "chunk")]
77    Content { text: String },
78    /// Tool call notification
79    #[serde(rename = "tool_call")]
80    ToolCall {
81        name: String,
82        arguments: serde_json::Value,
83    },
84    /// Tool result
85    #[serde(rename = "tool_result")]
86    ToolResult {
87        name: String,
88        result: serde_json::Value,
89    },
90    /// End of stream
91    #[serde(rename = "end")]
92    End { metadata: ResponseMetadata },
93    /// Error occurred
94    #[serde(rename = "error")]
95    Error { message: String },
96}
97
98/// Error response
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ErrorResponse {
101    /// Error message
102    pub error: String,
103    /// Error code
104    pub code: String,
105    /// Additional details
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub details: Option<serde_json::Value>,
108}
109
110impl ErrorResponse {
111    pub fn new(code: impl Into<String>, error: impl Into<String>) -> Self {
112        Self {
113            error: error.into(),
114            code: code.into(),
115            details: None,
116        }
117    }
118
119    pub fn with_details(mut self, details: serde_json::Value) -> Self {
120        self.details = Some(details);
121        self
122    }
123}
124
125/// Health check response
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct HealthResponse {
128    /// Service status
129    pub status: String,
130    /// Server version
131    pub version: String,
132    /// Uptime in seconds
133    pub uptime_seconds: u64,
134    /// Active sessions count
135    pub active_sessions: usize,
136}
137
138/// Agent list response
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct AgentListResponse {
141    /// Available agents
142    pub agents: Vec<AgentInfo>,
143}
144
145/// Agent information
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct AgentInfo {
148    /// Agent ID
149    pub id: String,
150    /// Agent description/prompt
151    pub description: String,
152    /// Allowed tools
153    #[serde(skip_serializing_if = "Vec::is_empty", default)]
154    pub allowed_tools: Vec<String>,
155    /// Denied tools
156    #[serde(skip_serializing_if = "Vec::is_empty", default)]
157    pub denied_tools: Vec<String>,
158}
159
160/// Semantic code search request
161#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct SearchRequest {
163    /// Search query text
164    pub query: String,
165    /// Repository root to search (defaults to current dir)
166    pub root: Option<String>,
167    /// Page number (0-indexed, default 0)
168    #[serde(default)]
169    pub page: usize,
170    /// Results per page (default 10, max 25)
171    pub page_size: Option<usize>,
172    /// Force re-generation of embeddings
173    #[serde(default)]
174    pub refresh: bool,
175}
176
177/// Search result item
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct SearchResult {
180    /// File path
181    pub path: String,
182    /// Similarity score
183    pub similarity: f32,
184    /// Code snippet
185    pub snippet: String,
186}
187
188/// Paginated search response
189#[derive(Debug, Clone, Serialize, Deserialize)]
190pub struct SearchResponse {
191    /// Search query
192    pub query: String,
193    /// Repository root searched
194    pub root: String,
195    /// Current page (0-indexed)
196    pub page: usize,
197    /// Results per page
198    pub page_size: usize,
199    /// Total results available
200    pub total_results: usize,
201    /// Total pages
202    pub total_pages: usize,
203    /// Results for this page
204    pub results: Vec<SearchResult>,
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata fixture shared by the response and stream tests.
    fn sample_metadata() -> ResponseMetadata {
        ResponseMetadata {
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            model: "mock".to_string(),
            processing_time_ms: 100,
            run_id: "run-1".to_string(),
        }
    }

    #[test]
    fn test_query_request_serialization() {
        let req = QueryRequest {
            message: "Hello".to_string(),
            session_id: Some("sess123".to_string()),
            agent: Some("coder".to_string()),
            stream: false,
            temperature: Some(0.7),
            max_tokens: Some(1000),
        };

        let json = serde_json::to_string(&req).unwrap();
        let deserialized: QueryRequest = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.message, "Hello");
        assert_eq!(deserialized.session_id, Some("sess123".to_string()));
        assert_eq!(deserialized.agent.as_deref(), Some("coder"));
        assert!(!deserialized.stream);
        assert_eq!(deserialized.max_tokens, Some(1000));
    }

    /// A request carrying only `message` must parse, with every optional
    /// field taking its default.
    #[test]
    fn test_query_request_defaults() {
        let req: QueryRequest = serde_json::from_str(r#"{"message":"Hello"}"#).unwrap();

        assert_eq!(req.message, "Hello");
        assert!(req.session_id.is_none());
        assert!(req.agent.is_none());
        assert!(!req.stream);
        assert!(req.temperature.is_none());
        assert!(req.max_tokens.is_none());
    }

    #[test]
    fn test_query_response_serialization() {
        let resp = QueryResponse {
            response: "Hi there".to_string(),
            session_id: "sess123".to_string(),
            agent: "coder".to_string(),
            tool_calls: vec![],
            metadata: sample_metadata(),
        };

        let json = serde_json::to_string(&resp).unwrap();
        // An empty tool-call list is skipped on the way out...
        assert!(!json.contains("tool_calls"));

        let deserialized: QueryResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.response, "Hi there");
        assert_eq!(deserialized.session_id, "sess123");
        // ...and defaults back to empty on the way in.
        assert!(deserialized.tool_calls.is_empty());
        assert_eq!(deserialized.metadata.run_id, "run-1");
    }

    /// Every variant must round-trip and carry the expected `type` tag.
    #[test]
    fn test_stream_chunk_variants() {
        let cases = vec![
            (
                StreamChunk::Start {
                    session_id: "sess1".to_string(),
                    agent: "coder".to_string(),
                },
                "start",
            ),
            (
                StreamChunk::Content {
                    text: "Hello".to_string(),
                },
                "chunk",
            ),
            (
                StreamChunk::ToolCall {
                    name: "search".to_string(),
                    arguments: serde_json::json!({"query": "foo"}),
                },
                "tool_call",
            ),
            (
                StreamChunk::ToolResult {
                    name: "search".to_string(),
                    result: serde_json::json!({"hits": 1}),
                },
                "tool_result",
            ),
            (
                StreamChunk::End {
                    metadata: sample_metadata(),
                },
                "end",
            ),
            (
                StreamChunk::Error {
                    message: "boom".to_string(),
                },
                "error",
            ),
        ];

        for (chunk, expected_tag) in cases {
            let json = serde_json::to_string(&chunk).unwrap();

            let value: serde_json::Value = serde_json::from_str(&json).unwrap();
            assert_eq!(value["type"], expected_tag);

            let _deserialized: StreamChunk = serde_json::from_str(&json).unwrap();
        }
    }

    #[test]
    fn test_error_response() {
        let err = ErrorResponse::new("invalid_request", "Invalid API key")
            .with_details(serde_json::json!({"hint": "Check your configuration"}));

        assert_eq!(err.error, "Invalid API key");
        assert_eq!(err.code, "invalid_request");
        assert!(err.details.is_some());

        // Without details the key must not appear in the serialized form.
        let bare = ErrorResponse::new("not_found", "Missing");
        let value = serde_json::to_value(&bare).unwrap();
        assert_eq!(value["code"], "not_found");
        assert_eq!(value["error"], "Missing");
        assert!(value.get("details").is_none());
    }

    #[test]
    fn test_health_response() {
        let health = HealthResponse {
            status: "healthy".to_string(),
            version: "0.1.0".to_string(),
            uptime_seconds: 3600,
            active_sessions: 5,
        };

        let json = serde_json::to_string(&health).unwrap();
        let deserialized: HealthResponse = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.status, "healthy");
        assert_eq!(deserialized.version, "0.1.0");
        assert_eq!(deserialized.uptime_seconds, 3600);
        assert_eq!(deserialized.active_sessions, 5);
    }
}