1use super::error::TaskType;
4use crate::config::ResourceConstraints;
5use crate::sandbox::SandboxTier;
6use crate::types::AgentId;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Mutex;
11use std::time::Duration;
12
/// Outcome of a routing decision: which model class (if any) should handle
/// a request, plus the execution constraints attached to that choice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteDecision {
    /// Route the request to a small language model.
    UseSLM {
        /// Identifier of the SLM selected to serve the request.
        model_id: String,
        /// How closely the SLM's output should be monitored.
        monitoring: MonitoringLevel,
        /// Whether the router may fall back elsewhere if the SLM fails
        /// (fallbacks are counted via `RoutingStatistics::record_fallback`).
        fallback_on_failure: bool,
        /// Sandbox tier to execute under, if any is required.
        sandbox_tier: Option<SandboxTier>,
    },
    /// Route the request to a large language model backend.
    UseLLM {
        /// Which LLM backend to call.
        provider: LLMProvider,
        /// Human-readable explanation of why the LLM path was chosen.
        reason: String,
        /// Sandbox tier to execute under, if any is required.
        sandbox_tier: Option<SandboxTier>,
    },
    /// Refuse to route the request at all.
    Deny {
        /// Human-readable explanation for the denial.
        reason: String,
        /// Name of the policy that was violated.
        policy_violated: String,
    },
}
35
/// How much scrutiny an SLM response receives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MonitoringLevel {
    /// No monitoring.
    None,
    /// Lightweight monitoring.
    Basic,
    /// Monitoring with a minimum confidence requirement; responses scoring
    /// below `confidence_threshold` presumably trigger escalation — confirm
    /// against the routing engine that consumes this value.
    Enhanced { confidence_threshold: f64 },
}
43
/// Supported large-language-model backends.
///
/// Every variant carries an optional model name; the `Display` impl
/// substitutes a provider-specific default when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LLMProvider {
    /// OpenAI API (default model name in `Display`: "gpt-3.5-turbo").
    OpenAI {
        model: Option<String>,
    },
    /// Anthropic API (default model name in `Display`: "claude-3-haiku").
    Anthropic {
        model: Option<String>,
    },
    /// A custom, self-described endpoint.
    Custom {
        /// URL (or address) of the custom backend.
        endpoint: String,
        model: Option<String>,
    },
}
58
59impl std::fmt::Display for LLMProvider {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 LLMProvider::OpenAI { model } => {
63 let model_name = model.as_deref().unwrap_or("gpt-3.5-turbo");
64 write!(f, "OpenAI({})", model_name)
65 }
66 LLMProvider::Anthropic { model } => {
67 let model_name = model.as_deref().unwrap_or("claude-3-haiku");
68 write!(f, "Anthropic({})", model_name)
69 }
70 LLMProvider::Custom { endpoint, model } => {
71 let model_name = model.as_deref().unwrap_or("unknown");
72 write!(f, "Custom({}, {})", endpoint, model_name)
73 }
74 }
75 }
76}
77
/// Everything the router needs to decide how to serve one request.
#[derive(Debug, Clone)]
pub struct RoutingContext {
    /// Unique id for this request (a fresh UUID v4 in [`RoutingContext::new`]).
    pub request_id: String,
    /// Agent issuing the request.
    pub agent_id: AgentId,
    /// When the context was created (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Classification of the task being requested.
    pub task_type: TaskType,
    /// Raw prompt text to be routed.
    pub prompt: String,
    /// Kind of output the caller expects (defaults to `OutputType::Text`).
    pub expected_output_type: OutputType,

    /// Optional cap on wall-clock execution time.
    pub max_execution_time: Option<Duration>,
    /// Optional resource limits for executing the task.
    pub resource_limits: Option<ResourceConstraints>,

    /// Capability names advertised by the agent.
    pub agent_capabilities: Vec<String>,
    /// The agent's security clearance (defaults to `SecurityLevel::Medium`).
    pub agent_security_level: SecurityLevel,

    /// Free-form key/value annotations attached by the caller.
    pub metadata: HashMap<String, String>,
}
102
/// Kind of output the caller expects the model to produce.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputType {
    /// Free-form natural-language text (the default in `RoutingContext::new`).
    Text,
    /// Source code.
    Code,
    /// A JSON document.
    Json,
    /// Structured output other than JSON — exact semantics defined by the consumer.
    Structured,
    /// Raw binary data.
    Binary,
}
112
/// An agent's security clearance, from least to most privileged.
///
/// The derived `PartialOrd`/`Ord` compare by declaration order, so
/// `Low < Medium < High < Critical`; the explicit discriminants (1..=4)
/// fix the serialized/numeric representation independently of that.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum SecurityLevel {
    Low = 1,
    Medium = 2,
    High = 3,
    Critical = 4,
}
121
/// A prompt plus generation parameters to send to a model.
#[derive(Debug, Clone)]
pub struct ModelRequest {
    /// Prompt text to generate from.
    pub prompt: String,
    /// Arbitrary provider-specific parameters.
    pub parameters: HashMap<String, serde_json::Value>,
    /// Maximum number of tokens to generate, if capped.
    pub max_tokens: Option<u32>,
    /// Sampling temperature, if overridden.
    pub temperature: Option<f32>,
    /// Sequences that terminate generation when emitted, if any.
    pub stop_sequences: Option<Vec<String>>,
}
131
/// A model's reply to a [`ModelRequest`].
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// Generated text.
    pub content: String,
    /// Why generation ended.
    pub finish_reason: FinishReason,
    /// Token accounting, when the backend reports it.
    pub token_usage: Option<TokenUsage>,
    /// Provider-specific extras.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Model confidence score, when available — presumably in [0, 1];
    /// confirm the scale with the producing backend.
    pub confidence_score: Option<f64>,
}
141
/// Why a model stopped generating.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Generation completed normally.
    Stop,
    /// The token limit was reached.
    Length,
    /// Output was blocked by a content filter.
    ContentFilter,
    /// Generation failed.
    Error,
}
150
/// Token counts reported by a model backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion, as reported by the backend).
    pub total_tokens: u32,
}
158
/// Thread-safe counters describing routing activity.
///
/// Counters are lock-free atomics updated with `Ordering::Relaxed`; only the
/// rolling window of confidence scores needs a `Mutex`, because it is a
/// growable collection.
pub struct RoutingStatistics {
    /// Total requests recorded via `record_request`.
    total_requests: AtomicU64,
    /// Requests routed to an SLM.
    slm_routes: AtomicU64,
    /// Requests routed to an LLM.
    llm_routes: AtomicU64,
    /// Requests denied by policy.
    denied_routes: AtomicU64,
    /// Fallbacks recorded via `record_fallback` (counted separately from routes).
    fallback_routes: AtomicU64,
    /// Sum of response times in nanoseconds, for computing the average.
    cumulative_response_time_nanos: AtomicU64,
    /// Requests recorded with `success == true`.
    successful_requests: AtomicU64,
    /// Rolling window of recent confidence scores.
    confidence_state: Mutex<ConfidenceState>,
}
177
/// Bounded FIFO window of recent confidence scores, kept behind the `Mutex`
/// in [`RoutingStatistics`].
struct ConfidenceState {
    /// Recent scores, oldest at the front.
    scores: VecDeque<f64>,
    /// Window capacity; the oldest score is evicted beyond this.
    max_scores: usize,
}
183
184impl Default for RoutingStatistics {
185 fn default() -> Self {
186 Self {
187 total_requests: AtomicU64::new(0),
188 slm_routes: AtomicU64::new(0),
189 llm_routes: AtomicU64::new(0),
190 denied_routes: AtomicU64::new(0),
191 fallback_routes: AtomicU64::new(0),
192 cumulative_response_time_nanos: AtomicU64::new(0),
193 successful_requests: AtomicU64::new(0),
194 confidence_state: Mutex::new(ConfidenceState {
195 scores: VecDeque::new(),
196 max_scores: 1000,
197 }),
198 }
199 }
200}
201
202impl std::fmt::Debug for RoutingStatistics {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("RoutingStatistics")
205 .field("total_requests", &self.total_requests())
206 .field("slm_routes", &self.slm_routes())
207 .field("llm_routes", &self.llm_routes())
208 .field("denied_routes", &self.denied_routes())
209 .field("fallback_routes", &self.fallback_routes())
210 .field("average_response_time", &self.average_response_time())
211 .field("success_rate", &self.success_rate())
212 .finish()
213 }
214}
215
216impl Clone for RoutingStatistics {
217 fn clone(&self) -> Self {
218 let confidence_state = self.confidence_state.lock().unwrap();
219 Self {
220 total_requests: AtomicU64::new(self.total_requests.load(Ordering::Relaxed)),
221 slm_routes: AtomicU64::new(self.slm_routes.load(Ordering::Relaxed)),
222 llm_routes: AtomicU64::new(self.llm_routes.load(Ordering::Relaxed)),
223 denied_routes: AtomicU64::new(self.denied_routes.load(Ordering::Relaxed)),
224 fallback_routes: AtomicU64::new(self.fallback_routes.load(Ordering::Relaxed)),
225 cumulative_response_time_nanos: AtomicU64::new(
226 self.cumulative_response_time_nanos.load(Ordering::Relaxed),
227 ),
228 successful_requests: AtomicU64::new(self.successful_requests.load(Ordering::Relaxed)),
229 confidence_state: Mutex::new(ConfidenceState {
230 scores: confidence_state.scores.clone(),
231 max_scores: confidence_state.max_scores,
232 }),
233 }
234 }
235}
236
237impl RoutingContext {
238 pub fn new(agent_id: AgentId, task_type: TaskType, prompt: String) -> Self {
240 Self {
241 request_id: uuid::Uuid::new_v4().to_string(),
242 agent_id,
243 timestamp: chrono::Utc::now(),
244 task_type,
245 prompt,
246 expected_output_type: OutputType::Text,
247 max_execution_time: None,
248 resource_limits: None,
249 agent_capabilities: Vec::new(),
250 agent_security_level: SecurityLevel::Medium,
251 metadata: HashMap::new(),
252 }
253 }
254
255 pub fn with_output_type(mut self, output_type: OutputType) -> Self {
257 self.expected_output_type = output_type;
258 self
259 }
260
261 pub fn with_resource_limits(mut self, limits: ResourceConstraints) -> Self {
263 self.resource_limits = Some(limits);
264 self
265 }
266
267 pub fn with_security_level(mut self, level: SecurityLevel) -> Self {
269 self.agent_security_level = level;
270 self
271 }
272
273 pub fn with_metadata(mut self, key: String, value: String) -> Self {
275 self.metadata.insert(key, value);
276 self
277 }
278}
279
280impl ModelRequest {
281 pub fn from_task(prompt: String) -> Self {
283 Self {
284 prompt,
285 parameters: HashMap::new(),
286 max_tokens: None,
287 temperature: None,
288 stop_sequences: None,
289 }
290 }
291
292 pub fn with_temperature(mut self, temperature: f32) -> Self {
294 self.temperature = Some(temperature);
295 self
296 }
297
298 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
300 self.max_tokens = Some(max_tokens);
301 self
302 }
303}
304
305impl RoutingStatistics {
306 pub fn total_requests(&self) -> u64 {
310 self.total_requests.load(Ordering::Relaxed)
311 }
312
313 pub fn slm_routes(&self) -> u64 {
315 self.slm_routes.load(Ordering::Relaxed)
316 }
317
318 pub fn llm_routes(&self) -> u64 {
320 self.llm_routes.load(Ordering::Relaxed)
321 }
322
323 pub fn denied_routes(&self) -> u64 {
325 self.denied_routes.load(Ordering::Relaxed)
326 }
327
328 pub fn fallback_routes(&self) -> u64 {
330 self.fallback_routes.load(Ordering::Relaxed)
331 }
332
333 pub fn average_response_time(&self) -> Duration {
335 let total = self.total_requests.load(Ordering::Relaxed);
336 if total == 0 {
337 return Duration::ZERO;
338 }
339 let cumulative = self.cumulative_response_time_nanos.load(Ordering::Relaxed);
340 Duration::from_nanos(cumulative / total)
341 }
342
343 pub fn success_rate(&self) -> f64 {
345 let total = self.total_requests.load(Ordering::Relaxed);
346 if total == 0 {
347 return 0.0;
348 }
349 let successful = self.successful_requests.load(Ordering::Relaxed);
350 successful as f64 / total as f64
351 }
352
353 pub fn record_request(&self, decision: &RouteDecision, response_time: Duration, success: bool) {
357 self.total_requests.fetch_add(1, Ordering::Relaxed);
358
359 match decision {
360 RouteDecision::UseSLM { .. } => {
361 self.slm_routes.fetch_add(1, Ordering::Relaxed);
362 }
363 RouteDecision::UseLLM { .. } => {
364 self.llm_routes.fetch_add(1, Ordering::Relaxed);
365 }
366 RouteDecision::Deny { .. } => {
367 self.denied_routes.fetch_add(1, Ordering::Relaxed);
368 }
369 }
370
371 let nanos = response_time.as_nanos() as u64;
373 self.cumulative_response_time_nanos
374 .fetch_add(nanos, Ordering::Relaxed);
375
376 if success {
377 self.successful_requests.fetch_add(1, Ordering::Relaxed);
378 }
379 }
380
381 pub fn record_fallback(&self) {
383 self.fallback_routes.fetch_add(1, Ordering::Relaxed);
384 }
385
386 pub fn add_confidence_score(&self, score: f64) {
388 let mut state = self.confidence_state.lock().unwrap();
389 state.scores.push_back(score);
390 if state.scores.len() > state.max_scores {
391 state.scores.pop_front();
392 }
393 }
394
395 pub fn average_confidence(&self) -> Option<f64> {
397 let state = self.confidence_state.lock().unwrap();
398 if state.scores.is_empty() {
399 None
400 } else {
401 Some(state.scores.iter().sum::<f64>() / state.scores.len() as f64)
402 }
403 }
404}