Skip to main content

spec_ai/spec_ai_collective/
capability.rs

1//! Agent capability tracking and expertise management.
2//!
3//! This module provides infrastructure for tracking agent capabilities,
4//! recording task outcomes, and building expertise profiles over time.
5
6use crate::spec_ai_collective::types::{Domain, InstanceId};
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11/// Represents an agent's capability in a specific domain.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Capability {
14    /// The domain this capability applies to (e.g., "code_review", "data_analysis")
15    pub domain: Domain,
16
17    /// Proficiency level from 0.0 (novice) to 1.0 (expert)
18    pub proficiency: f32,
19
20    /// Number of tasks completed in this domain
21    pub experience_count: u64,
22
23    /// Historical success rate (0.0 to 1.0)
24    pub success_rate: f32,
25
26    /// Average task completion time in milliseconds
27    pub avg_duration_ms: Option<u64>,
28
29    /// When this capability was last updated
30    pub last_updated: DateTime<Utc>,
31}
32
33impl Capability {
34    /// Create a new capability with default values.
35    pub fn new(domain: Domain) -> Self {
36        Self {
37            domain,
38            proficiency: 0.0,
39            experience_count: 0,
40            success_rate: 0.0,
41            avg_duration_ms: None,
42            last_updated: Utc::now(),
43        }
44    }
45
46    /// Update capability based on a task outcome.
47    pub fn update(&mut self, outcome: &TaskOutcome) {
48        self.experience_count += 1;
49
50        // Update success rate with exponential moving average
51        let success_value = match outcome {
52            TaskOutcome::Success { .. } => 1.0,
53            TaskOutcome::Partial { completion_ratio } => *completion_ratio,
54            TaskOutcome::Failure { .. } => 0.0,
55        };
56
57        let alpha = 0.1; // Learning rate
58        self.success_rate = (1.0 - alpha) * self.success_rate + alpha * success_value;
59
60        // Update average duration if available
61        if let TaskOutcome::Success { duration_ms, .. } = outcome {
62            self.avg_duration_ms = Some(match self.avg_duration_ms {
63                Some(avg) => ((avg as f64 * 0.9) + (*duration_ms as f64 * 0.1)) as u64,
64                None => *duration_ms,
65            });
66        }
67
68        // Update proficiency based on experience and success rate
69        self.proficiency = self.calculate_proficiency();
70        self.last_updated = Utc::now();
71    }
72
73    /// Calculate proficiency based on experience and success rate.
74    fn calculate_proficiency(&self) -> f32 {
75        // Proficiency grows with experience but is bounded by success rate
76        let experience_factor = (1.0 - (-0.01 * self.experience_count as f32).exp()).min(1.0);
77        (experience_factor * self.success_rate).min(1.0)
78    }
79
80    /// Check if this agent is a specialist in this domain (proficiency > 0.8).
81    pub fn is_specialist(&self) -> bool {
82        self.proficiency > 0.8
83    }
84}
85
86/// Outcome of a task execution.
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub enum TaskOutcome {
89    /// Task completed successfully
90    Success {
91        /// Confidence in the result (0.0 to 1.0)
92        confidence: f32,
93        /// Duration in milliseconds
94        duration_ms: u64,
95    },
96    /// Task failed
97    Failure {
98        /// Category of the error
99        error_category: String,
100        /// Whether the error is recoverable
101        recoverable: bool,
102    },
103    /// Task partially completed
104    Partial {
105        /// Ratio of completion (0.0 to 1.0)
106        completion_ratio: f32,
107    },
108}
109
110/// A single learning event from task execution.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct LearningEvent {
113    /// Type of task performed
114    pub task_type: String,
115
116    /// Outcome of the task
117    pub outcome: TaskOutcome,
118
119    /// Strategy or approach used
120    pub strategy_used: String,
121
122    /// Optional context embedding for semantic matching
123    pub context_embedding: Option<Vec<f32>>,
124
125    /// When this event occurred
126    pub timestamp: DateTime<Utc>,
127}
128
129/// Agent expertise profile tracking all capabilities.
130#[derive(Debug, Clone, Serialize, Deserialize)]
131pub struct ExpertiseProfile {
132    /// The instance ID of the agent
133    pub instance_id: InstanceId,
134
135    /// Capabilities by domain
136    pub capabilities: HashMap<Domain, Capability>,
137
138    /// Domains where the agent is a specialist (proficiency > 0.8)
139    pub specializations: Vec<Domain>,
140
141    /// Recent learning events
142    pub learning_history: Vec<LearningEvent>,
143
144    /// Maximum learning history to keep
145    #[serde(default = "default_max_history")]
146    pub max_history: usize,
147}
148
/// Default cap on an `ExpertiseProfile`'s learning history (also used as the serde default).
fn default_max_history() -> usize {
    const DEFAULT_MAX_HISTORY: usize = 100;
    DEFAULT_MAX_HISTORY
}
152
153impl ExpertiseProfile {
154    /// Create a new expertise profile for an agent.
155    pub fn new(instance_id: InstanceId) -> Self {
156        Self {
157            instance_id,
158            capabilities: HashMap::new(),
159            specializations: Vec::new(),
160            learning_history: Vec::new(),
161            max_history: default_max_history(),
162        }
163    }
164
165    /// Get capability for a domain, creating if necessary.
166    pub fn get_or_create_capability(&mut self, domain: &str) -> &mut Capability {
167        self.capabilities
168            .entry(domain.to_string())
169            .or_insert_with(|| Capability::new(domain.to_string()))
170    }
171
172    /// Record a task outcome and update capabilities.
173    pub fn record_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
174        // Update capability
175        let capability = self.get_or_create_capability(domain);
176        capability.update(&outcome);
177
178        // Add learning event
179        let event = LearningEvent {
180            task_type: domain.to_string(),
181            outcome,
182            strategy_used: strategy,
183            context_embedding: None,
184            timestamp: Utc::now(),
185        };
186        self.learning_history.push(event);
187
188        // Trim history if needed
189        while self.learning_history.len() > self.max_history {
190            self.learning_history.remove(0);
191        }
192
193        // Update specializations
194        self.update_specializations();
195    }
196
197    /// Update the list of specializations based on current capabilities.
198    fn update_specializations(&mut self) {
199        self.specializations = self
200            .capabilities
201            .iter()
202            .filter(|(_, cap)| cap.is_specialist())
203            .map(|(domain, _)| domain.clone())
204            .collect();
205    }
206
207    /// Get the best capability match for required capabilities.
208    pub fn match_score(&self, required: &[String]) -> f32 {
209        if required.is_empty() {
210            return 1.0;
211        }
212
213        let total: f32 = required
214            .iter()
215            .map(|domain| {
216                self.capabilities
217                    .get(domain)
218                    .map(|c| c.proficiency)
219                    .unwrap_or(0.0)
220            })
221            .sum();
222
223        total / required.len() as f32
224    }
225}
226
227/// Tracks capabilities for agents in the mesh.
228#[derive(Debug)]
229pub struct CapabilityTracker {
230    /// This agent's instance ID
231    instance_id: InstanceId,
232
233    /// This agent's expertise profile
234    profile: ExpertiseProfile,
235
236    /// Known peer profiles (from capability updates)
237    peers: HashMap<InstanceId, ExpertiseProfile>,
238}
239
240impl CapabilityTracker {
241    /// Create a new capability tracker.
242    pub fn new(instance_id: InstanceId) -> Self {
243        Self {
244            instance_id: instance_id.clone(),
245            profile: ExpertiseProfile::new(instance_id),
246            peers: HashMap::new(),
247        }
248    }
249
250    /// Get this agent's instance ID.
251    pub fn instance_id(&self) -> &str {
252        &self.instance_id
253    }
254
255    /// Get this agent's expertise profile.
256    pub fn profile(&self) -> &ExpertiseProfile {
257        &self.profile
258    }
259
260    /// Get mutable reference to this agent's expertise profile.
261    pub fn profile_mut(&mut self) -> &mut ExpertiseProfile {
262        &mut self.profile
263    }
264
265    /// Record a task outcome for this agent.
266    pub fn record_task_outcome(&mut self, domain: &str, outcome: TaskOutcome, strategy: String) {
267        self.profile.record_outcome(domain, outcome, strategy);
268    }
269
270    /// Update a peer's profile from a capability update message.
271    pub fn update_peer_profile(&mut self, profile: ExpertiseProfile) {
272        self.peers.insert(profile.instance_id.clone(), profile);
273    }
274
275    /// Get the best agent for a task requiring specific capabilities.
276    pub fn get_best_agent(
277        &self,
278        required_capabilities: &[String],
279    ) -> Option<RoutingRecommendation> {
280        let mut best: Option<(String, f32)> = None;
281
282        // Check self
283        let self_score = self.profile.match_score(required_capabilities);
284        if self_score > 0.0 {
285            best = Some((self.instance_id.clone(), self_score));
286        }
287
288        // Check peers
289        for (instance_id, profile) in &self.peers {
290            let score = profile.match_score(required_capabilities);
291            if let Some((_, best_score)) = &best {
292                if score > *best_score {
293                    best = Some((instance_id.clone(), score));
294                }
295            } else if score > 0.0 {
296                best = Some((instance_id.clone(), score));
297            }
298        }
299
300        best.map(|(instance_id, score)| {
301            let is_self = instance_id == self.instance_id;
302            RoutingRecommendation {
303                instance_id,
304                score,
305                is_self,
306            }
307        })
308    }
309
310    /// Get all known agents with a minimum capability score.
311    pub fn get_capable_agents(
312        &self,
313        required_capabilities: &[String],
314        min_score: f32,
315    ) -> Vec<RoutingRecommendation> {
316        let mut agents = Vec::new();
317
318        // Check self
319        let self_score = self.profile.match_score(required_capabilities);
320        if self_score >= min_score {
321            agents.push(RoutingRecommendation {
322                instance_id: self.instance_id.clone(),
323                score: self_score,
324                is_self: true,
325            });
326        }
327
328        // Check peers
329        for (instance_id, profile) in &self.peers {
330            let score = profile.match_score(required_capabilities);
331            if score >= min_score {
332                agents.push(RoutingRecommendation {
333                    instance_id: instance_id.clone(),
334                    score,
335                    is_self: false,
336                });
337            }
338        }
339
340        // Sort by score descending
341        agents.sort_by(|a, b| {
342            b.score
343                .partial_cmp(&a.score)
344                .unwrap_or(std::cmp::Ordering::Equal)
345        });
346
347        agents
348    }
349
350    /// Get all known peer profiles.
351    pub fn peers(&self) -> &HashMap<InstanceId, ExpertiseProfile> {
352        &self.peers
353    }
354}
355
356/// Recommendation for routing a task to an agent.
357#[derive(Debug, Clone)]
358pub struct RoutingRecommendation {
359    /// The recommended instance ID
360    pub instance_id: InstanceId,
361
362    /// Match score (0.0 to 1.0)
363    pub score: f32,
364
365    /// Whether this is the local agent
366    pub is_self: bool,
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful outcome with fixed confidence and the given duration.
    fn success(duration_ms: u64) -> TaskOutcome {
        TaskOutcome::Success {
            confidence: 0.9,
            duration_ms,
        }
    }

    #[test]
    fn test_capability_update() {
        let mut cap = Capability::new(String::from("code_review"));
        assert_eq!(cap.proficiency, 0.0);

        // Ten consecutive successes.
        (0..10).for_each(|_| cap.update(&success(1000)));

        assert!(cap.proficiency > 0.0);
        // EMA with alpha = 0.1 starting from 0.0: 1 - 0.9^10 ≈ 0.65.
        assert!(cap.success_rate > 0.5);
        assert_eq!(cap.experience_count, 10);
    }

    #[test]
    fn test_expertise_profile_matching() {
        let mut profile = ExpertiseProfile::new(String::from("agent-1"));

        for _ in 0..20 {
            profile.record_outcome(
                "code_review",
                success(1000),
                String::from("standard_review"),
            );
        }

        let known = profile.match_score(&[String::from("code_review")]);
        assert!(known > 0.0);

        let unknown = profile.match_score(&[String::from("unknown_domain")]);
        assert_eq!(unknown, 0.0);
    }

    #[test]
    fn test_capability_tracker_routing() {
        let mut tracker = CapabilityTracker::new(String::from("agent-1"));

        for _ in 0..10 {
            tracker.record_task_outcome("data_analysis", success(500), String::from("standard"));
        }

        let best = tracker.get_best_agent(&[String::from("data_analysis")]);
        assert!(best.is_some());
        let best = best.unwrap();
        assert!(best.is_self);
    }
}