1use super::error::TaskType;
4use crate::config::ResourceConstraints;
5use crate::sandbox::SandboxTier;
6use crate::types::AgentId;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, VecDeque};
9use std::time::Duration;
10
/// The router's verdict for a single request: run it on a small language
/// model, escalate to a large-language-model provider, or refuse it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RouteDecision {
    /// Execute on a small language model.
    UseSLM {
        /// Identifier of the SLM to invoke.
        model_id: String,
        /// How closely the execution should be monitored.
        monitoring: MonitoringLevel,
        /// If true, the caller may retry elsewhere when the SLM fails
        /// (presumably escalating to an LLM — confirm with the routing engine).
        fallback_on_failure: bool,
        /// Sandbox tier to run under; `None` leaves the tier unspecified.
        sandbox_tier: Option<SandboxTier>,
    },
    /// Escalate to an external large-language-model provider.
    UseLLM {
        /// Which provider (and optionally which model) to call.
        provider: LLMProvider,
        /// Human-readable explanation of why the LLM route was chosen.
        reason: String,
        /// Sandbox tier to run under; `None` leaves the tier unspecified.
        sandbox_tier: Option<SandboxTier>,
    },
    /// Refuse to execute the request.
    Deny {
        /// Human-readable explanation of the denial.
        reason: String,
        /// Name/identifier of the policy that triggered the denial.
        policy_violated: String,
    },
}
33
/// Amount of runtime scrutiny applied to an SLM execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MonitoringLevel {
    /// No monitoring at all.
    None,
    /// Lightweight monitoring.
    Basic,
    /// Stricter monitoring driven by a confidence cutoff
    /// (presumably in the 0.0–1.0 range — confirm with the monitor).
    Enhanced { confidence_threshold: f64 },
}
41
/// An upstream large-language-model provider with an optional model override.
///
/// A `None` model means "use the provider default"; the `Display` impl
/// substitutes a default model name when formatting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LLMProvider {
    /// OpenAI API; `None` model displays as "gpt-3.5-turbo".
    OpenAI {
        model: Option<String>,
    },
    /// Anthropic API; `None` model displays as "claude-3-haiku".
    Anthropic {
        model: Option<String>,
    },
    /// Self-hosted or third-party endpoint reached at `endpoint`.
    Custom {
        endpoint: String,
        model: Option<String>,
    },
}
56
57impl std::fmt::Display for LLMProvider {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 LLMProvider::OpenAI { model } => {
61 let model_name = model.as_deref().unwrap_or("gpt-3.5-turbo");
62 write!(f, "OpenAI({})", model_name)
63 }
64 LLMProvider::Anthropic { model } => {
65 let model_name = model.as_deref().unwrap_or("claude-3-haiku");
66 write!(f, "Anthropic({})", model_name)
67 }
68 LLMProvider::Custom { endpoint, model } => {
69 let model_name = model.as_deref().unwrap_or("unknown");
70 write!(f, "Custom({}, {})", endpoint, model_name)
71 }
72 }
73 }
74}
75
/// Everything the router needs to decide where a single request should run.
#[derive(Debug, Clone)]
pub struct RoutingContext {
    /// Unique id for this request (a UUID v4 when built via `new`).
    pub request_id: String,
    /// Agent issuing the request.
    pub agent_id: AgentId,
    /// UTC creation time of the context.
    pub timestamp: chrono::DateTime<chrono::Utc>,

    /// Category of work being requested.
    pub task_type: TaskType,
    /// Raw prompt text to execute.
    pub prompt: String,
    /// Output shape the caller expects back (defaults to `Text` in `new`).
    pub expected_output_type: OutputType,

    /// Wall-clock execution limit, if any.
    pub max_execution_time: Option<Duration>,
    /// Resource caps to apply, if any.
    pub resource_limits: Option<ResourceConstraints>,

    /// Capability names held by the agent (free-form strings).
    pub agent_capabilities: Vec<String>,
    /// Security clearance of the agent (defaults to `Medium` in `new`).
    pub agent_security_level: SecurityLevel,

    /// Arbitrary key/value annotations attached by the caller.
    pub metadata: HashMap<String, String>,
}
100
/// Shape of output the caller expects from the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputType {
    /// Free-form natural-language text.
    Text,
    /// Source code.
    Code,
    /// JSON payload.
    Json,
    /// Structured (non-JSON) output — exact format not defined here.
    Structured,
    /// Raw binary data.
    Binary,
}
110
/// Security clearance attached to an agent.
///
/// The derived `Ord` follows declaration order, so
/// `Low < Medium < High < Critical`, matching the explicit discriminants.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum SecurityLevel {
    Low = 1,
    Medium = 2,
    High = 3,
    Critical = 4,
}
119
/// A prompt plus generation settings, ready to send to a model backend.
#[derive(Debug, Clone)]
pub struct ModelRequest {
    /// The prompt text.
    pub prompt: String,
    /// Backend-specific extra parameters (free-form JSON values).
    pub parameters: HashMap<String, serde_json::Value>,
    /// Cap on generated tokens; `None` leaves it to the backend default.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; `None` leaves it to the backend default.
    pub temperature: Option<f32>,
    /// Sequences that stop generation when emitted, if any.
    pub stop_sequences: Option<Vec<String>>,
}
129
/// A model backend's reply to a [`ModelRequest`].
#[derive(Debug, Clone)]
pub struct ModelResponse {
    /// Generated text content.
    pub content: String,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
    /// Token accounting, when the backend reports it.
    pub token_usage: Option<TokenUsage>,
    /// Backend-specific metadata (free-form JSON values).
    pub metadata: HashMap<String, serde_json::Value>,
    /// Model/monitor confidence in the output, when available.
    pub confidence_score: Option<f64>,
}
139
/// Why a model stopped generating.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// Natural completion (stop token/sequence reached).
    Stop,
    /// Hit the token limit.
    Length,
    /// Output was blocked by a content filter.
    ContentFilter,
    /// Generation failed with an error.
    Error,
}
148
/// Token accounting for a single model call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Prompt + completion total (as reported; not recomputed here).
    pub total_tokens: u32,
}
156
/// Running counters and rolling metrics over routing decisions.
///
/// Mutated via `update` and `add_confidence_score`; starts zeroed via
/// `Default`.
#[derive(Debug, Clone)]
pub struct RoutingStatistics {
    /// Total requests seen (incremented on every `update`).
    pub total_requests: u64,
    /// Requests routed to an SLM.
    pub slm_routes: u64,
    /// Requests routed to an LLM.
    pub llm_routes: u64,
    /// Requests denied.
    pub denied_routes: u64,
    /// Fallback routes taken (not maintained by `update`; presumably
    /// incremented elsewhere — confirm with callers).
    pub fallback_routes: u64,
    // Exact running sum of response times in nanoseconds; kept private so
    // the derived `average_response_time` cannot drift out of sync with it.
    cumulative_response_time_nanos: u128,
    /// Mean response time over all requests, derived from the cumulative sum.
    pub average_response_time: Duration,
    /// Fraction of requests reported successful (incremental mean, 0.0–1.0).
    pub success_rate: f64,
    // Rolling window of the most recent confidence scores.
    confidence_scores: VecDeque<f64>,
    // Capacity of the rolling window (1000 by default).
    max_confidence_scores: usize,
}
175
176impl Default for RoutingStatistics {
177 fn default() -> Self {
178 Self {
179 total_requests: 0,
180 slm_routes: 0,
181 llm_routes: 0,
182 denied_routes: 0,
183 fallback_routes: 0,
184 cumulative_response_time_nanos: 0,
185 average_response_time: Duration::from_millis(0),
186 success_rate: 0.0,
187 confidence_scores: VecDeque::new(),
188 max_confidence_scores: 1000,
189 }
190 }
191}
192
193impl RoutingContext {
194 pub fn new(agent_id: AgentId, task_type: TaskType, prompt: String) -> Self {
196 Self {
197 request_id: uuid::Uuid::new_v4().to_string(),
198 agent_id,
199 timestamp: chrono::Utc::now(),
200 task_type,
201 prompt,
202 expected_output_type: OutputType::Text,
203 max_execution_time: None,
204 resource_limits: None,
205 agent_capabilities: Vec::new(),
206 agent_security_level: SecurityLevel::Medium,
207 metadata: HashMap::new(),
208 }
209 }
210
211 pub fn with_output_type(mut self, output_type: OutputType) -> Self {
213 self.expected_output_type = output_type;
214 self
215 }
216
217 pub fn with_resource_limits(mut self, limits: ResourceConstraints) -> Self {
219 self.resource_limits = Some(limits);
220 self
221 }
222
223 pub fn with_security_level(mut self, level: SecurityLevel) -> Self {
225 self.agent_security_level = level;
226 self
227 }
228
229 pub fn with_metadata(mut self, key: String, value: String) -> Self {
231 self.metadata.insert(key, value);
232 self
233 }
234}
235
236impl ModelRequest {
237 pub fn from_task(prompt: String) -> Self {
239 Self {
240 prompt,
241 parameters: HashMap::new(),
242 max_tokens: None,
243 temperature: None,
244 stop_sequences: None,
245 }
246 }
247
248 pub fn with_temperature(mut self, temperature: f32) -> Self {
250 self.temperature = Some(temperature);
251 self
252 }
253
254 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
256 self.max_tokens = Some(max_tokens);
257 self
258 }
259}
260
261impl RoutingStatistics {
262 pub fn update(&mut self, decision: &RouteDecision, response_time: Duration, success: bool) {
264 self.total_requests += 1;
265
266 match decision {
267 RouteDecision::UseSLM { .. } => self.slm_routes += 1,
268 RouteDecision::UseLLM { .. } => self.llm_routes += 1,
269 RouteDecision::Deny { .. } => self.denied_routes += 1,
270 }
271
272 self.cumulative_response_time_nanos += response_time.as_nanos();
274 self.average_response_time = Duration::from_nanos(
275 (self.cumulative_response_time_nanos / self.total_requests as u128) as u64,
276 );
277
278 let successful_requests = if success { 1 } else { 0 };
280 self.success_rate = (self.success_rate * (self.total_requests - 1) as f64
281 + successful_requests as f64)
282 / self.total_requests as f64;
283 }
284
285 pub fn add_confidence_score(&mut self, score: f64) {
287 self.confidence_scores.push_back(score);
288 if self.confidence_scores.len() > self.max_confidence_scores {
290 self.confidence_scores.pop_front();
291 }
292 }
293
294 pub fn average_confidence(&self) -> Option<f64> {
296 if self.confidence_scores.is_empty() {
297 None
298 } else {
299 Some(self.confidence_scores.iter().sum::<f64>() / self.confidence_scores.len() as f64)
300 }
301 }
302}