// spec_ai/spec_ai_api/api/mesh.rs

1use crate::spec_ai_api::persistence::Persistence;
2use anyhow::Result;
3/// Mesh registry handlers and models
4use axum::{
5    extract::{Json, Path, State},
6    http::StatusCode,
7    response::IntoResponse,
8};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15/// Agent instance information in the mesh
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MeshInstance {
18    pub instance_id: String,
19    pub hostname: String,
20    pub port: u16,
21    pub capabilities: Vec<String>,
22    pub is_leader: bool,
23    pub last_heartbeat: DateTime<Utc>,
24    pub created_at: DateTime<Utc>,
25    pub agent_profiles: Vec<String>,
26}
27
28/// Request to register a new instance
29#[derive(Debug, Serialize, Deserialize)]
30pub struct RegisterRequest {
31    pub instance_id: String,
32    pub hostname: String,
33    pub port: u16,
34    pub capabilities: Vec<String>,
35    pub agent_profiles: Vec<String>,
36}
37
38/// Response from registration
39#[derive(Debug, Serialize, Deserialize)]
40pub struct RegisterResponse {
41    pub success: bool,
42    pub instance_id: String,
43    pub is_leader: bool,
44    pub leader_id: Option<String>,
45    pub peers: Vec<MeshInstance>,
46}
47
48/// List of registered instances
49#[derive(Debug, Serialize, Deserialize)]
50pub struct InstancesResponse {
51    pub instances: Vec<MeshInstance>,
52    pub leader_id: Option<String>,
53}
54
55/// Heartbeat request
56#[derive(Debug, Serialize, Deserialize)]
57pub struct HeartbeatRequest {
58    pub status: String,
59    pub metrics: Option<HashMap<String, serde_json::Value>>,
60}
61
62/// Heartbeat response
63#[derive(Debug, Serialize, Deserialize)]
64pub struct HeartbeatResponse {
65    pub acknowledged: bool,
66    pub leader_id: Option<String>,
67    pub should_sync: bool,
68}
69
70/// Message types for inter-agent communication
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum MessageType {
73    Query,          // Request information from another agent
74    Response,       // Response to a query
75    Notification,   // One-way notification
76    TaskDelegation, // Delegate a task to another agent
77    TaskResult,     // Result of a delegated task
78    GraphSync,      // Knowledge graph synchronization
79    // Collective intelligence message types
80    CapabilityUpdate,    // Share capability/expertise profile updates
81    CapabilityQuery,     // Request capability information from peers
82    LearningShare,       // Share a learned strategy with the mesh
83    ProposalSubmit,      // Submit a proposal for collective decision
84    ProposalVote,        // Cast a vote on a proposal
85    WorkflowAssignment,  // Assign a workflow stage to an agent
86    WorkflowStageResult, // Report completion of a workflow stage
87    Custom(String),      // Custom message type
88}
89
90impl MessageType {
91    pub fn as_str(&self) -> String {
92        match self {
93            MessageType::Query => "query".to_string(),
94            MessageType::Response => "response".to_string(),
95            MessageType::Notification => "notification".to_string(),
96            MessageType::TaskDelegation => "task_delegation".to_string(),
97            MessageType::TaskResult => "task_result".to_string(),
98            MessageType::GraphSync => "graph_sync".to_string(),
99            MessageType::CapabilityUpdate => "capability_update".to_string(),
100            MessageType::CapabilityQuery => "capability_query".to_string(),
101            MessageType::LearningShare => "learning_share".to_string(),
102            MessageType::ProposalSubmit => "proposal_submit".to_string(),
103            MessageType::ProposalVote => "proposal_vote".to_string(),
104            MessageType::WorkflowAssignment => "workflow_assignment".to_string(),
105            MessageType::WorkflowStageResult => "workflow_stage_result".to_string(),
106            MessageType::Custom(s) => s.clone(),
107        }
108    }
109
110    #[allow(clippy::should_implement_trait)]
111    pub fn from_str(s: &str) -> Self {
112        match s.to_lowercase().as_str() {
113            "query" => MessageType::Query,
114            "response" => MessageType::Response,
115            "notification" => MessageType::Notification,
116            "task_delegation" => MessageType::TaskDelegation,
117            "task_result" => MessageType::TaskResult,
118            "graph_sync" => MessageType::GraphSync,
119            "capability_update" => MessageType::CapabilityUpdate,
120            "capability_query" => MessageType::CapabilityQuery,
121            "learning_share" => MessageType::LearningShare,
122            "proposal_submit" => MessageType::ProposalSubmit,
123            "proposal_vote" => MessageType::ProposalVote,
124            "workflow_assignment" => MessageType::WorkflowAssignment,
125            "workflow_stage_result" => MessageType::WorkflowStageResult,
126            custom => MessageType::Custom(custom.to_string()),
127        }
128    }
129}
130
131/// Inter-agent message
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct AgentMessage {
134    pub message_id: String,
135    pub source_instance: String,
136    pub target_instance: Option<String>, // None for broadcast
137    pub message_type: MessageType,
138    pub payload: serde_json::Value,
139    pub correlation_id: Option<String>, // For request/response correlation
140    pub created_at: DateTime<Utc>,
141}
142
143/// Message send request
144#[derive(Debug, Serialize, Deserialize)]
145pub struct SendMessageRequest {
146    pub target_instance: Option<String>, // None for broadcast
147    pub message_type: MessageType,
148    pub payload: serde_json::Value,
149    pub correlation_id: Option<String>,
150}
151
152/// Message send response
153#[derive(Debug, Serialize, Deserialize)]
154pub struct SendMessageResponse {
155    pub message_id: String,
156    pub status: String,
157    pub delivered_to: Vec<String>,
158}
159
160/// Pending messages response
161#[derive(Debug, Serialize, Deserialize)]
162pub struct PendingMessagesResponse {
163    pub messages: Vec<AgentMessage>,
164}
165
166/// Mesh registry state
167#[derive(Clone)]
168pub struct MeshRegistry {
169    instances: Arc<RwLock<HashMap<String, MeshInstance>>>,
170    leader_id: Arc<RwLock<Option<String>>>,
171    message_queue: Arc<RwLock<Vec<AgentMessage>>>,
172    persistence: Option<Persistence>,
173}
174
175impl Default for MeshRegistry {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181impl MeshRegistry {
182    pub fn new() -> Self {
183        Self {
184            instances: Arc::new(RwLock::new(HashMap::new())),
185            leader_id: Arc::new(RwLock::new(None)),
186            message_queue: Arc::new(RwLock::new(Vec::new())),
187            persistence: None,
188        }
189    }
190
191    pub fn with_persistence(persistence: Persistence) -> Self {
192        Self {
193            instances: Arc::new(RwLock::new(HashMap::new())),
194            leader_id: Arc::new(RwLock::new(None)),
195            message_queue: Arc::new(RwLock::new(Vec::new())),
196            persistence: Some(persistence),
197        }
198    }
199
200    /// Register a new instance
201    pub async fn register(&self, instance: MeshInstance) -> RegisterResponse {
202        let mut instances = self.instances.write().await;
203        let mut leader = self.leader_id.write().await;
204
205        // First instance becomes the leader
206        let is_leader = instances.is_empty();
207        let mut new_instance = instance.clone();
208        new_instance.is_leader = is_leader;
209
210        if is_leader {
211            *leader = Some(instance.instance_id.clone());
212        }
213
214        instances.insert(instance.instance_id.clone(), new_instance);
215
216        RegisterResponse {
217            success: true,
218            instance_id: instance.instance_id.clone(),
219            is_leader,
220            leader_id: leader.clone(),
221            peers: instances.values().cloned().collect(),
222        }
223    }
224
225    /// Update heartbeat timestamp
226    pub async fn heartbeat(&self, instance_id: &str) -> HeartbeatResponse {
227        let mut instances = self.instances.write().await;
228        let leader = self.leader_id.read().await;
229
230        if let Some(instance) = instances.get_mut(instance_id) {
231            instance.last_heartbeat = Utc::now();
232            HeartbeatResponse {
233                acknowledged: true,
234                leader_id: leader.clone(),
235                should_sync: false,
236            }
237        } else {
238            HeartbeatResponse {
239                acknowledged: false,
240                leader_id: leader.clone(),
241                should_sync: false,
242            }
243        }
244    }
245
246    /// Remove an instance
247    pub async fn deregister(&self, instance_id: &str) -> bool {
248        let mut instances = self.instances.write().await;
249        let mut leader = self.leader_id.write().await;
250
251        if let Some(instance) = instances.remove(instance_id) {
252            // If leader is leaving, elect a new one
253            if instance.is_leader && !instances.is_empty() {
254                // Simple election: first remaining instance becomes leader
255                if let Some((new_leader_id, new_leader)) = instances.iter_mut().next() {
256                    new_leader.is_leader = true;
257                    *leader = Some(new_leader_id.clone());
258                }
259            } else if instances.is_empty() {
260                *leader = None;
261            }
262            true
263        } else {
264            false
265        }
266    }
267
268    /// Get all instances
269    pub async fn list(&self) -> Vec<MeshInstance> {
270        let instances = self.instances.read().await;
271        instances.values().cloned().collect()
272    }
273
274    /// Check for stale instances and remove them
275    pub async fn cleanup_stale(&self, timeout_secs: u64) {
276        let now = Utc::now();
277        let mut instances = self.instances.write().await;
278        let mut leader = self.leader_id.write().await;
279
280        let stale_ids: Vec<String> = instances
281            .iter()
282            .filter_map(|(id, instance)| {
283                let elapsed = now.timestamp() - instance.last_heartbeat.timestamp();
284                if elapsed > timeout_secs as i64 {
285                    Some(id.clone())
286                } else {
287                    None
288                }
289            })
290            .collect();
291
292        for id in stale_ids {
293            if let Some(instance) = instances.remove(&id) {
294                // Handle leader failover if needed
295                if instance.is_leader && !instances.is_empty() {
296                    if let Some((new_leader_id, new_leader)) = instances.iter_mut().next() {
297                        new_leader.is_leader = true;
298                        *leader = Some(new_leader_id.clone());
299                    }
300                }
301            }
302        }
303
304        if instances.is_empty() {
305            *leader = None;
306        }
307    }
308
309    /// Get the current leader ID
310    pub async fn get_leader(&self) -> Option<String> {
311        let leader = self.leader_id.read().await;
312        leader.clone()
313    }
314
315    /// Send a message to an instance or broadcast
316    pub async fn send_message(
317        &self,
318        source_instance: String,
319        target_instance: Option<String>,
320        message_type: MessageType,
321        payload: serde_json::Value,
322        correlation_id: Option<String>,
323    ) -> Result<SendMessageResponse> {
324        // Generate time-ordered UUID v7 for better database performance and distributed safety
325        let message_id = uuid::Uuid::new_v7(uuid::Timestamp::now(uuid::NoContext)).to_string();
326
327        let message = AgentMessage {
328            message_id: message_id.clone(),
329            source_instance,
330            target_instance: target_instance.clone(),
331            message_type,
332            payload,
333            correlation_id,
334            created_at: Utc::now(),
335        };
336
337        // Persist to database if available
338        if let Some(ref persistence) = self.persistence {
339            let target_str = target_instance.as_deref();
340            if let Err(e) = persistence.mesh_message_store(
341                &message_id,
342                &message.source_instance,
343                target_str,
344                &message.message_type.as_str(),
345                &message.payload,
346                "pending",
347            ) {
348                tracing::warn!("Failed to persist mesh message: {}", e);
349            }
350        }
351
352        // Add to message queue
353        let mut queue = self.message_queue.write().await;
354        queue.push(message.clone());
355
356        // GraphSync messages are handled when retrieved from the queue
357        // to avoid recursion issues
358
359        // Determine who received it
360        let delivered_to = if let Some(ref target) = target_instance {
361            let instances = self.instances.read().await;
362            if instances.contains_key(target) {
363                vec![target.clone()]
364            } else {
365                return Err(anyhow::anyhow!("Target instance '{}' not found", target));
366            }
367        } else {
368            // Broadcast - delivered to all instances
369            let instances = self.instances.read().await;
370            instances.keys().cloned().collect()
371        };
372
373        Ok(SendMessageResponse {
374            message_id,
375            status: "queued".to_string(),
376            delivered_to,
377        })
378    }
379
380    /// Get pending messages for an instance
381    pub async fn get_pending_messages(&self, instance_id: &str) -> Vec<AgentMessage> {
382        let queue = self.message_queue.read().await;
383        queue
384            .iter()
385            .filter(|msg| {
386                // Return messages targeted at this instance or broadcasts (None)
387                msg.target_instance.as_deref() == Some(instance_id) || msg.target_instance.is_none()
388            })
389            .cloned()
390            .collect()
391    }
392
393    /// Acknowledge/remove messages after delivery
394    pub async fn acknowledge_messages(&self, message_ids: Vec<String>) {
395        let mut queue = self.message_queue.write().await;
396        queue.retain(|msg| !message_ids.contains(&msg.message_id));
397    }
398}
399
400/// Client-side mesh operations
401#[derive(Clone)]
402pub struct MeshClient {
403    base_url: String,
404    client: reqwest::Client,
405}
406
407impl MeshClient {
408    pub fn new(host: &str, port: u16) -> Self {
409        let client = reqwest::Client::builder()
410            .no_proxy()
411            .build()
412            .unwrap_or_else(|e| {
413                tracing::warn!("Failed to build mesh client without proxy lookup: {}", e);
414                reqwest::Client::new()
415            });
416
417        Self {
418            base_url: format!("http://{}:{}", host, port),
419            client,
420        }
421    }
422
423    /// Generate a unique instance ID
424    pub fn generate_instance_id() -> String {
425        let hostname = hostname::get()
426            .ok()
427            .and_then(|h| h.into_string().ok())
428            .unwrap_or_else(|| "unknown".to_string());
429        // Use UUID v7 for time-ordered, globally unique IDs with better collision resistance
430        let uuid = uuid::Uuid::new_v7(uuid::Timestamp::now(uuid::NoContext));
431        format!("{}-{}", hostname, uuid)
432    }
433
434    /// Register this instance with a mesh registry
435    pub async fn register(
436        &self,
437        instance_id: String,
438        hostname: String,
439        port: u16,
440        capabilities: Vec<String>,
441        agent_profiles: Vec<String>,
442    ) -> Result<RegisterResponse> {
443        let request = RegisterRequest {
444            instance_id,
445            hostname,
446            port,
447            capabilities,
448            agent_profiles,
449        };
450
451        let response = self
452            .client
453            .post(format!("{}/registry/register", self.base_url))
454            .json(&request)
455            .send()
456            .await?;
457
458        if response.status().is_success() {
459            Ok(response.json().await?)
460        } else {
461            anyhow::bail!("Registration failed: {}", response.status())
462        }
463    }
464
465    /// Send heartbeat to registry
466    pub async fn heartbeat(
467        &self,
468        instance_id: &str,
469        metrics: Option<HashMap<String, serde_json::Value>>,
470    ) -> Result<HeartbeatResponse> {
471        let request = HeartbeatRequest {
472            status: "healthy".to_string(),
473            metrics,
474        };
475
476        let response = self
477            .client
478            .post(format!(
479                "{}/registry/heartbeat/{}",
480                self.base_url, instance_id
481            ))
482            .json(&request)
483            .send()
484            .await?;
485
486        if response.status().is_success() {
487            Ok(response.json().await?)
488        } else {
489            anyhow::bail!("Heartbeat failed: {}", response.status())
490        }
491    }
492
493    /// List all instances in the mesh
494    pub async fn list_instances(&self) -> Result<InstancesResponse> {
495        let response = self
496            .client
497            .get(format!("{}/registry/agents", self.base_url))
498            .send()
499            .await?;
500
501        if response.status().is_success() {
502            Ok(response.json().await?)
503        } else {
504            anyhow::bail!("Failed to list instances: {}", response.status())
505        }
506    }
507
508    /// Deregister from the mesh
509    pub async fn deregister(&self, instance_id: &str) -> Result<()> {
510        let response = self
511            .client
512            .delete(format!(
513                "{}/registry/deregister/{}",
514                self.base_url, instance_id
515            ))
516            .send()
517            .await?;
518
519        if response.status().is_success() {
520            Ok(())
521        } else {
522            anyhow::bail!("Deregistration failed: {}", response.status())
523        }
524    }
525
526    /// Send a message to another instance
527    pub async fn send_message(
528        &self,
529        source_instance: String,
530        target_instance: Option<String>,
531        message_type: MessageType,
532        payload: serde_json::Value,
533        correlation_id: Option<String>,
534    ) -> Result<SendMessageResponse> {
535        let request = SendMessageRequest {
536            target_instance,
537            message_type,
538            payload,
539            correlation_id,
540        };
541
542        let response = self
543            .client
544            .post(format!(
545                "{}/messages/send/{}",
546                self.base_url, source_instance
547            ))
548            .json(&request)
549            .send()
550            .await?;
551
552        if response.status().is_success() {
553            Ok(response.json().await?)
554        } else {
555            anyhow::bail!("Send message failed: {}", response.status())
556        }
557    }
558
559    /// Get pending messages for an instance
560    pub async fn get_messages(&self, instance_id: &str) -> Result<PendingMessagesResponse> {
561        let response = self
562            .client
563            .get(format!("{}/messages/{}", self.base_url, instance_id))
564            .send()
565            .await?;
566
567        if response.status().is_success() {
568            Ok(response.json().await?)
569        } else {
570            anyhow::bail!("Get messages failed: {}", response.status())
571        }
572    }
573
574    /// Acknowledge received messages
575    pub async fn acknowledge_messages(
576        &self,
577        instance_id: &str,
578        message_ids: Vec<String>,
579    ) -> Result<()> {
580        let request = AcknowledgeMessagesRequest { message_ids };
581
582        let response = self
583            .client
584            .post(format!("{}/messages/ack/{}", self.base_url, instance_id))
585            .json(&request)
586            .send()
587            .await?;
588
589        if response.status().is_success() {
590            Ok(())
591        } else {
592            anyhow::bail!("Acknowledge failed: {}", response.status())
593        }
594    }
595}
596
597/// Extension trait to add mesh registry to app state
598pub trait MeshState {
599    fn mesh_registry(&self) -> &MeshRegistry;
600}
601
602/// Handler: Register a new instance
603pub async fn register_instance<S: MeshState>(
604    State(state): State<S>,
605    Json(request): Json<RegisterRequest>,
606) -> impl IntoResponse {
607    let instance = MeshInstance {
608        instance_id: request.instance_id,
609        hostname: request.hostname,
610        port: request.port,
611        capabilities: request.capabilities,
612        is_leader: false, // Will be set by registry
613        last_heartbeat: Utc::now(),
614        created_at: Utc::now(),
615        agent_profiles: request.agent_profiles,
616    };
617
618    let response = state.mesh_registry().register(instance).await;
619    (StatusCode::OK, Json(response))
620}
621
622/// Handler: List all instances
623pub async fn list_instances<S: MeshState>(State(state): State<S>) -> impl IntoResponse {
624    let instances = state.mesh_registry().list().await;
625    let leader_id = instances
626        .iter()
627        .find(|i| i.is_leader)
628        .map(|i| i.instance_id.clone());
629
630    Json(InstancesResponse {
631        instances,
632        leader_id,
633    })
634}
635
636/// Handler: Heartbeat from an instance
637pub async fn heartbeat<S: MeshState>(
638    State(state): State<S>,
639    Path(instance_id): Path<String>,
640    Json(_request): Json<HeartbeatRequest>,
641) -> impl IntoResponse {
642    let response = state.mesh_registry().heartbeat(&instance_id).await;
643
644    if response.acknowledged {
645        (StatusCode::OK, Json(response))
646    } else {
647        (StatusCode::NOT_FOUND, Json(response))
648    }
649}
650
651/// Handler: Deregister an instance
652pub async fn deregister_instance<S: MeshState>(
653    State(state): State<S>,
654    Path(instance_id): Path<String>,
655) -> impl IntoResponse {
656    let removed = state.mesh_registry().deregister(&instance_id).await;
657
658    if removed {
659        StatusCode::NO_CONTENT
660    } else {
661        StatusCode::NOT_FOUND
662    }
663}
664
665/// Handler: Send a message to another instance
666pub async fn send_message<S: MeshState>(
667    State(state): State<S>,
668    Path(source_instance): Path<String>,
669    Json(request): Json<SendMessageRequest>,
670) -> impl IntoResponse {
671    match state
672        .mesh_registry()
673        .send_message(
674            source_instance,
675            request.target_instance,
676            request.message_type,
677            request.payload,
678            request.correlation_id,
679        )
680        .await
681    {
682        Ok(response) => (StatusCode::OK, Json(response)).into_response(),
683        Err(e) => (
684            StatusCode::BAD_REQUEST,
685            Json(serde_json::json!({
686                "error": e.to_string()
687            })),
688        )
689            .into_response(),
690    }
691}
692
693/// Handler: Get pending messages for an instance
694pub async fn get_messages<S: MeshState>(
695    State(state): State<S>,
696    Path(instance_id): Path<String>,
697) -> impl IntoResponse {
698    let messages = state
699        .mesh_registry()
700        .get_pending_messages(&instance_id)
701        .await;
702
703    Json(PendingMessagesResponse { messages })
704}
705
706/// Acknowledge messages request
707#[derive(Debug, Serialize, Deserialize)]
708pub struct AcknowledgeMessagesRequest {
709    pub message_ids: Vec<String>,
710}
711
712/// Handler: Acknowledge received messages
713pub async fn acknowledge_messages<S: MeshState>(
714    State(state): State<S>,
715    Path(_instance_id): Path<String>,
716    Json(request): Json<AcknowledgeMessagesRequest>,
717) -> impl IntoResponse {
718    state
719        .mesh_registry()
720        .acknowledge_messages(request.message_ids)
721        .await;
722
723    StatusCode::NO_CONTENT
724}