//! Capability tracking for collective instances (`spec_ai/spec_ai_collective/capability.rs`):
//! per-domain proficiency, expertise profiles, and capability-based routing between peers.

use std::collections::HashMap;

use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

use crate::spec_ai_collective::types::{Domain, InstanceId};

11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Capability {
14 pub domain: Domain,
16
17 pub proficiency: f32,
19
20 pub experience_count: u64,
22
23 pub success_rate: f32,
25
26 pub avg_duration_ms: Option<u64>,
28
29 pub last_updated: DateTime<Utc>,
31}
32
33impl Capability {
34 pub fn new(domain: Domain) -> Self {
36 Self {
37 domain,
38 proficiency: 0.0,
39 experience_count: 0,
40 success_rate: 0.0,
41 avg_duration_ms: None,
42 last_updated: Utc::now(),
43 }
44 }
45
46 pub fn update(&mut self, outcome: &TaskOutcome) {
48 self.experience_count += 1;
49
50 let success_value = match outcome {
52 TaskOutcome::Success { .. } => 1.0,
53 TaskOutcome::Partial { completion_ratio } => *completion_ratio,
54 TaskOutcome::Failure { .. } => 0.0,
55 };
56
57 let alpha = 0.1; self.success_rate = (1.0 - alpha) * self.success_rate + alpha * success_value;
59
60 if let TaskOutcome::Success { duration_ms, .. } = outcome {
62 self.avg_duration_ms = Some(match self.avg_duration_ms {
63 Some(avg) => ((avg as f64 * 0.9) + (*duration_ms as f64 * 0.1)) as u64,
64 None => *duration_ms,
65 });
66 }
67
68 self.proficiency = self.calculate_proficiency();
70 self.last_updated = Utc::now();
71 }
72
73 fn calculate_proficiency(&self) -> f32 {
75 let experience_factor = (1.0 - (-0.01 * self.experience_count as f32).exp()).min(1.0);
77 (experience_factor * self.success_rate).min(1.0)
78 }
79
80 pub fn is_specialist(&self) -> bool {
82 self.proficiency > 0.8
83 }
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub enum TaskOutcome {
89 Success {
91 confidence: f32,
93 duration_ms: u64,
95 },
96 Failure {
98 error_category: String,
100 recoverable: bool,
102 },
103 Partial {
105 completion_ratio: f32,
107 },
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct LearningEvent {
113 pub task_type: String,
115
116 pub outcome: TaskOutcome,
118
119 pub strategy_used: String,
121
122 pub context_embedding: Option<Vec<f32>>,
124
125 pub timestamp: DateTime<Utc>,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ExpertiseProfile {
132 pub instance_id: InstanceId,
134
135 pub capabilities: HashMap<Domain, Capability>,
137
138 pub specializations: Vec<Domain>,
140
141 pub learning_history: Vec<LearningEvent>,
143
144 #[serde(default = "default_max_history")]
146 pub max_history: usize,
147}
148
/// Default cap on `ExpertiseProfile::max_history`, the number of learning events a
/// profile retains. Used both by serde and by `ExpertiseProfile::new`.
fn default_max_history() -> usize {
    const DEFAULT_MAX_HISTORY: usize = 100;
    DEFAULT_MAX_HISTORY
}

153impl ExpertiseProfile {
154 pub fn new(instance_id: InstanceId) -> Self {
156 Self {
157 instance_id,
158 capabilities: HashMap::new(),
159 specializations: Vec::new(),
160 learning_history: Vec::new(),
161 max_history: default_max_history(),
162 }
163 }
164
165 pub fn get_or_create_capability(&mut self, domain: &str) -> &mut Capability {
167 self.capabilities
168 .entry(domain.to_string())
169 .or_insert_with(|| Capability::new(domain.to_string()))
170 }
171
172 pub fn record_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
174 let capability = self.get_or_create_capability(domain);
176 capability.update(&outcome);
177
178 let event = LearningEvent {
180 task_type: domain.to_string(),
181 outcome,
182 strategy_used: strategy,
183 context_embedding: None,
184 timestamp: Utc::now(),
185 };
186 self.learning_history.push(event);
187
188 while self.learning_history.len() > self.max_history {
190 self.learning_history.remove(0);
191 }
192
193 self.update_specializations();
195 }
196
197 fn update_specializations(&mut self) {
199 self.specializations = self
200 .capabilities
201 .iter()
202 .filter(|(_, cap)| cap.is_specialist())
203 .map(|(domain, _)| domain.clone())
204 .collect();
205 }
206
207 pub fn match_score(&self, required: &[String]) -> f32 {
209 if required.is_empty() {
210 return 1.0;
211 }
212
213 let total: f32 = required
214 .iter()
215 .map(|domain| {
216 self.capabilities
217 .get(domain)
218 .map(|c| c.proficiency)
219 .unwrap_or(0.0)
220 })
221 .sum();
222
223 total / required.len() as f32
224 }
225}
226
227#[derive(Debug)]
229pub struct CapabilityTracker {
230 instance_id: InstanceId,
232
233 profile: ExpertiseProfile,
235
236 peers: HashMap<InstanceId, ExpertiseProfile>,
238}
239
240impl CapabilityTracker {
241 pub fn new(instance_id: InstanceId) -> Self {
243 Self {
244 instance_id: instance_id.clone(),
245 profile: ExpertiseProfile::new(instance_id),
246 peers: HashMap::new(),
247 }
248 }
249
250 pub fn instance_id(&self) -> &str {
252 &self.instance_id
253 }
254
255 pub fn profile(&self) -> &ExpertiseProfile {
257 &self.profile
258 }
259
260 pub fn profile_mut(&mut self) -> &mut ExpertiseProfile {
262 &mut self.profile
263 }
264
265 pub fn record_task_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
267 self.profile.record_outcome(domain, outcome, strategy);
268 }
269
270 pub fn update_peer_profile(&mut self, profile: ExpertiseProfile) {
272 self.peers.insert(profile.instance_id.clone(), profile);
273 }
274
275 pub fn get_best_agent(
277 &self,
278 required_capabilities: &[String],
279 ) -> Option<RoutingRecommendation> {
280 let mut best: Option<(String, f32)> = None;
281
282 let self_score = self.profile.match_score(required_capabilities);
284 if self_score > 0.0 {
285 best = Some((self.instance_id.clone(), self_score));
286 }
287
288 for (instance_id, profile) in &self.peers {
290 let score = profile.match_score(required_capabilities);
291 if let Some((_, best_score)) = &best {
292 if score > *best_score {
293 best = Some((instance_id.clone(), score));
294 }
295 } else if score > 0.0 {
296 best = Some((instance_id.clone(), score));
297 }
298 }
299
300 best.map(|(instance_id, score)| {
301 let is_self = instance_id == self.instance_id;
302 RoutingRecommendation {
303 instance_id,
304 score,
305 is_self,
306 }
307 })
308 }
309
310 pub fn get_capable_agents(
312 &self,
313 required_capabilities: &[String],
314 min_score: f32,
315 ) -> Vec<RoutingRecommendation> {
316 let mut agents = Vec::new();
317
318 let self_score = self.profile.match_score(required_capabilities);
320 if self_score >= min_score {
321 agents.push(RoutingRecommendation {
322 instance_id: self.instance_id.clone(),
323 score: self_score,
324 is_self: true,
325 });
326 }
327
328 for (instance_id, profile) in &self.peers {
330 let score = profile.match_score(required_capabilities);
331 if score >= min_score {
332 agents.push(RoutingRecommendation {
333 instance_id: instance_id.clone(),
334 score,
335 is_self: false,
336 });
337 }
338 }
339
340 agents.sort_by(|a, b| {
342 b.score
343 .partial_cmp(&a.score)
344 .unwrap_or(std::cmp::Ordering::Equal)
345 });
346
347 agents
348 }
349
350 pub fn peers(&self) -> &HashMap<InstanceId, ExpertiseProfile> {
352 &self.peers
353 }
354}
355
356#[derive(Debug, Clone)]
358pub struct RoutingRecommendation {
359 pub instance_id: InstanceId,
361
362 pub score: f32,
364
365 pub is_self: bool,
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful outcome with the given duration.
    fn success(duration_ms: u64) -> TaskOutcome {
        TaskOutcome::Success {
            confidence: 0.9,
            duration_ms,
        }
    }

    #[test]
    fn test_capability_update() {
        let mut cap = Capability::new("code_review".to_string());
        assert_eq!(cap.proficiency, 0.0);

        for _ in 0..10 {
            cap.update(&TaskOutcome::Success {
                confidence: 0.9,
                duration_ms: 1000,
            });
        }

        assert!(cap.proficiency > 0.0);
        assert!(cap.success_rate > 0.5);
        assert_eq!(cap.experience_count, 10);
    }

    #[test]
    fn test_failure_lowers_success_rate_without_touching_duration() {
        let mut cap = Capability::new("code_review".to_string());
        cap.update(&success(1000));
        cap.update(&TaskOutcome::Failure {
            error_category: "timeout".to_string(),
            recoverable: true,
        });

        assert_eq!(cap.experience_count, 2);
        assert!(cap.success_rate < 0.1);
        assert_eq!(cap.avg_duration_ms, Some(1000));
        assert!(!cap.is_specialist());
    }

    #[test]
    fn test_expertise_profile_matching() {
        let mut profile = ExpertiseProfile::new("agent-1".to_string());

        for _ in 0..20 {
            profile.record_outcome(
                "code_review",
                TaskOutcome::Success {
                    confidence: 0.9,
                    duration_ms: 1000,
                },
                "standard_review".to_string(),
            );
        }

        let score = profile.match_score(&["code_review".to_string()]);
        assert!(score > 0.0);

        let score2 = profile.match_score(&["unknown_domain".to_string()]);
        assert_eq!(score2, 0.0);

        assert_eq!(profile.match_score(&[]), 1.0);
    }

    #[test]
    fn test_learning_history_is_capped() {
        let mut profile = ExpertiseProfile::new("agent-1".to_string());
        profile.max_history = 3;

        for i in 0..5 {
            profile.record_outcome("code_review", success(100), format!("strategy-{}", i));
        }

        assert_eq!(profile.learning_history.len(), 3);
        let strategies: Vec<&str> = profile
            .learning_history
            .iter()
            .map(|event| event.strategy_used.as_str())
            .collect();
        assert_eq!(strategies, ["strategy-2", "strategy-3", "strategy-4"]);
        assert_eq!(profile.capabilities["code_review"].experience_count, 5);
    }

    #[test]
    fn test_sustained_success_becomes_specialization() {
        let mut profile = ExpertiseProfile::new("agent-1".to_string());

        for _ in 0..500 {
            profile.record_outcome("code_review", success(100), "standard".to_string());
        }

        assert_eq!(profile.specializations, vec!["code_review".to_string()]);
    }

    #[test]
    fn test_capability_tracker_routing() {
        let mut tracker = CapabilityTracker::new("agent-1".to_string());

        for _ in 0..10 {
            tracker.record_task_outcome(
                "data_analysis",
                TaskOutcome::Success {
                    confidence: 0.9,
                    duration_ms: 500,
                },
                "standard".to_string(),
            );
        }

        let rec = tracker.get_best_agent(&["data_analysis".to_string()]);
        assert!(rec.is_some());
        assert!(rec.unwrap().is_self);
    }

    #[test]
    fn test_routes_to_stronger_peer() {
        let mut tracker = CapabilityTracker::new("agent-1".to_string());
        let mut peer = ExpertiseProfile::new("agent-2".to_string());
        for _ in 0..10 {
            peer.record_outcome("data_analysis", success(500), "standard".to_string());
        }
        tracker.update_peer_profile(peer);

        let rec = tracker
            .get_best_agent(&["data_analysis".to_string()])
            .expect("the peer is the only agent with data_analysis experience");
        assert_eq!(rec.instance_id, "agent-2");
        assert!(!rec.is_self);
    }

    #[test]
    fn test_capable_agents_sorted_and_filtered() {
        let mut tracker = CapabilityTracker::new("agent-1".to_string());
        for _ in 0..5 {
            tracker.record_task_outcome("data_analysis", success(500), "standard".to_string());
        }
        let mut peer = ExpertiseProfile::new("agent-2".to_string());
        for _ in 0..20 {
            peer.record_outcome("data_analysis", success(500), "standard".to_string());
        }
        tracker.update_peer_profile(peer);

        let required = ["data_analysis".to_string()];

        let all = tracker.get_capable_agents(&required, 0.0);
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].instance_id, "agent-2");
        assert!(all[1].is_self);
        assert!(all[0].score >= all[1].score);

        let strong = tracker.get_capable_agents(&required, 0.1);
        assert_eq!(strong.len(), 1);
        assert_eq!(strong[0].instance_id, "agent-2");
    }

    #[test]
    fn test_empty_requirements_prefer_self() {
        let mut tracker = CapabilityTracker::new("agent-1".to_string());
        tracker.update_peer_profile(ExpertiseProfile::new("agent-2".to_string()));

        let rec = tracker
            .get_best_agent(&[])
            .expect("empty requirements match every agent");
        assert!(rec.is_self);
        assert_eq!(rec.score, 1.0);
    }
}