1use crate::spec_ai_api::persistence::Persistence;
2use anyhow::Result;
3use axum::{
5 extract::{Json, Path, State},
6 http::StatusCode,
7 response::IntoResponse,
8};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct MeshInstance {
18 pub instance_id: String,
19 pub hostname: String,
20 pub port: u16,
21 pub capabilities: Vec<String>,
22 pub is_leader: bool,
23 pub last_heartbeat: DateTime<Utc>,
24 pub created_at: DateTime<Utc>,
25 pub agent_profiles: Vec<String>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
30pub struct RegisterRequest {
31 pub instance_id: String,
32 pub hostname: String,
33 pub port: u16,
34 pub capabilities: Vec<String>,
35 pub agent_profiles: Vec<String>,
36}
37
38#[derive(Debug, Serialize, Deserialize)]
40pub struct RegisterResponse {
41 pub success: bool,
42 pub instance_id: String,
43 pub is_leader: bool,
44 pub leader_id: Option<String>,
45 pub peers: Vec<MeshInstance>,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
50pub struct InstancesResponse {
51 pub instances: Vec<MeshInstance>,
52 pub leader_id: Option<String>,
53}
54
55#[derive(Debug, Serialize, Deserialize)]
57pub struct HeartbeatRequest {
58 pub status: String,
59 pub metrics: Option<HashMap<String, serde_json::Value>>,
60}
61
62#[derive(Debug, Serialize, Deserialize)]
64pub struct HeartbeatResponse {
65 pub acknowledged: bool,
66 pub leader_id: Option<String>,
67 pub should_sync: bool,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
72pub enum MessageType {
73 Query, Response, Notification, TaskDelegation, TaskResult, GraphSync, CapabilityUpdate, CapabilityQuery, LearningShare, ProposalSubmit, ProposalVote, WorkflowAssignment, WorkflowStageResult, Custom(String), }
89
90impl MessageType {
91 pub fn as_str(&self) -> String {
92 match self {
93 MessageType::Query => "query".to_string(),
94 MessageType::Response => "response".to_string(),
95 MessageType::Notification => "notification".to_string(),
96 MessageType::TaskDelegation => "task_delegation".to_string(),
97 MessageType::TaskResult => "task_result".to_string(),
98 MessageType::GraphSync => "graph_sync".to_string(),
99 MessageType::CapabilityUpdate => "capability_update".to_string(),
100 MessageType::CapabilityQuery => "capability_query".to_string(),
101 MessageType::LearningShare => "learning_share".to_string(),
102 MessageType::ProposalSubmit => "proposal_submit".to_string(),
103 MessageType::ProposalVote => "proposal_vote".to_string(),
104 MessageType::WorkflowAssignment => "workflow_assignment".to_string(),
105 MessageType::WorkflowStageResult => "workflow_stage_result".to_string(),
106 MessageType::Custom(s) => s.clone(),
107 }
108 }
109
110 #[allow(clippy::should_implement_trait)]
111 pub fn from_str(s: &str) -> Self {
112 match s.to_lowercase().as_str() {
113 "query" => MessageType::Query,
114 "response" => MessageType::Response,
115 "notification" => MessageType::Notification,
116 "task_delegation" => MessageType::TaskDelegation,
117 "task_result" => MessageType::TaskResult,
118 "graph_sync" => MessageType::GraphSync,
119 "capability_update" => MessageType::CapabilityUpdate,
120 "capability_query" => MessageType::CapabilityQuery,
121 "learning_share" => MessageType::LearningShare,
122 "proposal_submit" => MessageType::ProposalSubmit,
123 "proposal_vote" => MessageType::ProposalVote,
124 "workflow_assignment" => MessageType::WorkflowAssignment,
125 "workflow_stage_result" => MessageType::WorkflowStageResult,
126 custom => MessageType::Custom(custom.to_string()),
127 }
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct AgentMessage {
134 pub message_id: String,
135 pub source_instance: String,
136 pub target_instance: Option<String>, pub message_type: MessageType,
138 pub payload: serde_json::Value,
139 pub correlation_id: Option<String>, pub created_at: DateTime<Utc>,
141}
142
143#[derive(Debug, Serialize, Deserialize)]
145pub struct SendMessageRequest {
146 pub target_instance: Option<String>, pub message_type: MessageType,
148 pub payload: serde_json::Value,
149 pub correlation_id: Option<String>,
150}
151
152#[derive(Debug, Serialize, Deserialize)]
154pub struct SendMessageResponse {
155 pub message_id: String,
156 pub status: String,
157 pub delivered_to: Vec<String>,
158}
159
160#[derive(Debug, Serialize, Deserialize)]
162pub struct PendingMessagesResponse {
163 pub messages: Vec<AgentMessage>,
164}
165
166#[derive(Clone)]
168pub struct MeshRegistry {
169 instances: Arc<RwLock<HashMap<String, MeshInstance>>>,
170 leader_id: Arc<RwLock<Option<String>>>,
171 message_queue: Arc<RwLock<Vec<AgentMessage>>>,
172 persistence: Option<Persistence>,
173}
174
175impl Default for MeshRegistry {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181impl MeshRegistry {
182 pub fn new() -> Self {
183 Self {
184 instances: Arc::new(RwLock::new(HashMap::new())),
185 leader_id: Arc::new(RwLock::new(None)),
186 message_queue: Arc::new(RwLock::new(Vec::new())),
187 persistence: None,
188 }
189 }
190
191 pub fn with_persistence(persistence: Persistence) -> Self {
192 Self {
193 instances: Arc::new(RwLock::new(HashMap::new())),
194 leader_id: Arc::new(RwLock::new(None)),
195 message_queue: Arc::new(RwLock::new(Vec::new())),
196 persistence: Some(persistence),
197 }
198 }
199
200 pub async fn register(&self, instance: MeshInstance) -> RegisterResponse {
202 let mut instances = self.instances.write().await;
203 let mut leader = self.leader_id.write().await;
204
205 let is_leader = instances.is_empty();
207 let mut new_instance = instance.clone();
208 new_instance.is_leader = is_leader;
209
210 if is_leader {
211 *leader = Some(instance.instance_id.clone());
212 }
213
214 instances.insert(instance.instance_id.clone(), new_instance);
215
216 RegisterResponse {
217 success: true,
218 instance_id: instance.instance_id.clone(),
219 is_leader,
220 leader_id: leader.clone(),
221 peers: instances.values().cloned().collect(),
222 }
223 }
224
225 pub async fn heartbeat(&self, instance_id: &str) -> HeartbeatResponse {
227 let mut instances = self.instances.write().await;
228 let leader = self.leader_id.read().await;
229
230 if let Some(instance) = instances.get_mut(instance_id) {
231 instance.last_heartbeat = Utc::now();
232 HeartbeatResponse {
233 acknowledged: true,
234 leader_id: leader.clone(),
235 should_sync: false,
236 }
237 } else {
238 HeartbeatResponse {
239 acknowledged: false,
240 leader_id: leader.clone(),
241 should_sync: false,
242 }
243 }
244 }
245
246 pub async fn deregister(&self, instance_id: &str) -> bool {
248 let mut instances = self.instances.write().await;
249 let mut leader = self.leader_id.write().await;
250
251 if let Some(instance) = instances.remove(instance_id) {
252 if instance.is_leader && !instances.is_empty() {
254 if let Some((new_leader_id, new_leader)) = instances.iter_mut().next() {
256 new_leader.is_leader = true;
257 *leader = Some(new_leader_id.clone());
258 }
259 } else if instances.is_empty() {
260 *leader = None;
261 }
262 true
263 } else {
264 false
265 }
266 }
267
268 pub async fn list(&self) -> Vec<MeshInstance> {
270 let instances = self.instances.read().await;
271 instances.values().cloned().collect()
272 }
273
274 pub async fn cleanup_stale(&self, timeout_secs: u64) {
276 let now = Utc::now();
277 let mut instances = self.instances.write().await;
278 let mut leader = self.leader_id.write().await;
279
280 let stale_ids: Vec<String> = instances
281 .iter()
282 .filter_map(|(id, instance)| {
283 let elapsed = now.timestamp() - instance.last_heartbeat.timestamp();
284 if elapsed > timeout_secs as i64 {
285 Some(id.clone())
286 } else {
287 None
288 }
289 })
290 .collect();
291
292 for id in stale_ids {
293 if let Some(instance) = instances.remove(&id) {
294 if instance.is_leader && !instances.is_empty() {
296 if let Some((new_leader_id, new_leader)) = instances.iter_mut().next() {
297 new_leader.is_leader = true;
298 *leader = Some(new_leader_id.clone());
299 }
300 }
301 }
302 }
303
304 if instances.is_empty() {
305 *leader = None;
306 }
307 }
308
309 pub async fn get_leader(&self) -> Option<String> {
311 let leader = self.leader_id.read().await;
312 leader.clone()
313 }
314
315 pub async fn send_message(
317 &self,
318 source_instance: String,
319 target_instance: Option<String>,
320 message_type: MessageType,
321 payload: serde_json::Value,
322 correlation_id: Option<String>,
323 ) -> Result<SendMessageResponse> {
324 let message_id = uuid::Uuid::new_v7(uuid::Timestamp::now(uuid::NoContext)).to_string();
326
327 let message = AgentMessage {
328 message_id: message_id.clone(),
329 source_instance,
330 target_instance: target_instance.clone(),
331 message_type,
332 payload,
333 correlation_id,
334 created_at: Utc::now(),
335 };
336
337 if let Some(ref persistence) = self.persistence {
339 let target_str = target_instance.as_deref();
340 if let Err(e) = persistence.mesh_message_store(
341 &message_id,
342 &message.source_instance,
343 target_str,
344 &message.message_type.as_str(),
345 &message.payload,
346 "pending",
347 ) {
348 tracing::warn!("Failed to persist mesh message: {}", e);
349 }
350 }
351
352 let mut queue = self.message_queue.write().await;
354 queue.push(message.clone());
355
356 let delivered_to = if let Some(ref target) = target_instance {
361 let instances = self.instances.read().await;
362 if instances.contains_key(target) {
363 vec![target.clone()]
364 } else {
365 return Err(anyhow::anyhow!("Target instance '{}' not found", target));
366 }
367 } else {
368 let instances = self.instances.read().await;
370 instances.keys().cloned().collect()
371 };
372
373 Ok(SendMessageResponse {
374 message_id,
375 status: "queued".to_string(),
376 delivered_to,
377 })
378 }
379
380 pub async fn get_pending_messages(&self, instance_id: &str) -> Vec<AgentMessage> {
382 let queue = self.message_queue.read().await;
383 queue
384 .iter()
385 .filter(|msg| {
386 msg.target_instance.as_deref() == Some(instance_id) || msg.target_instance.is_none()
388 })
389 .cloned()
390 .collect()
391 }
392
393 pub async fn acknowledge_messages(&self, message_ids: Vec<String>) {
395 let mut queue = self.message_queue.write().await;
396 queue.retain(|msg| !message_ids.contains(&msg.message_id));
397 }
398}
399
400#[derive(Clone)]
402pub struct MeshClient {
403 base_url: String,
404 client: reqwest::Client,
405}
406
407impl MeshClient {
408 pub fn new(host: &str, port: u16) -> Self {
409 let client = reqwest::Client::builder()
410 .no_proxy()
411 .build()
412 .unwrap_or_else(|e| {
413 tracing::warn!("Failed to build mesh client without proxy lookup: {}", e);
414 reqwest::Client::new()
415 });
416
417 Self {
418 base_url: format!("http://{}:{}", host, port),
419 client,
420 }
421 }
422
423 pub fn generate_instance_id() -> String {
425 let hostname = hostname::get()
426 .ok()
427 .and_then(|h| h.into_string().ok())
428 .unwrap_or_else(|| "unknown".to_string());
429 let uuid = uuid::Uuid::new_v7(uuid::Timestamp::now(uuid::NoContext));
431 format!("{}-{}", hostname, uuid)
432 }
433
434 pub async fn register(
436 &self,
437 instance_id: String,
438 hostname: String,
439 port: u16,
440 capabilities: Vec<String>,
441 agent_profiles: Vec<String>,
442 ) -> Result<RegisterResponse> {
443 let request = RegisterRequest {
444 instance_id,
445 hostname,
446 port,
447 capabilities,
448 agent_profiles,
449 };
450
451 let response = self
452 .client
453 .post(format!("{}/registry/register", self.base_url))
454 .json(&request)
455 .send()
456 .await?;
457
458 if response.status().is_success() {
459 Ok(response.json().await?)
460 } else {
461 anyhow::bail!("Registration failed: {}", response.status())
462 }
463 }
464
465 pub async fn heartbeat(
467 &self,
468 instance_id: &str,
469 metrics: Option<HashMap<String, serde_json::Value>>,
470 ) -> Result<HeartbeatResponse> {
471 let request = HeartbeatRequest {
472 status: "healthy".to_string(),
473 metrics,
474 };
475
476 let response = self
477 .client
478 .post(format!(
479 "{}/registry/heartbeat/{}",
480 self.base_url, instance_id
481 ))
482 .json(&request)
483 .send()
484 .await?;
485
486 if response.status().is_success() {
487 Ok(response.json().await?)
488 } else {
489 anyhow::bail!("Heartbeat failed: {}", response.status())
490 }
491 }
492
493 pub async fn list_instances(&self) -> Result<InstancesResponse> {
495 let response = self
496 .client
497 .get(format!("{}/registry/agents", self.base_url))
498 .send()
499 .await?;
500
501 if response.status().is_success() {
502 Ok(response.json().await?)
503 } else {
504 anyhow::bail!("Failed to list instances: {}", response.status())
505 }
506 }
507
508 pub async fn deregister(&self, instance_id: &str) -> Result<()> {
510 let response = self
511 .client
512 .delete(format!(
513 "{}/registry/deregister/{}",
514 self.base_url, instance_id
515 ))
516 .send()
517 .await?;
518
519 if response.status().is_success() {
520 Ok(())
521 } else {
522 anyhow::bail!("Deregistration failed: {}", response.status())
523 }
524 }
525
526 pub async fn send_message(
528 &self,
529 source_instance: String,
530 target_instance: Option<String>,
531 message_type: MessageType,
532 payload: serde_json::Value,
533 correlation_id: Option<String>,
534 ) -> Result<SendMessageResponse> {
535 let request = SendMessageRequest {
536 target_instance,
537 message_type,
538 payload,
539 correlation_id,
540 };
541
542 let response = self
543 .client
544 .post(format!(
545 "{}/messages/send/{}",
546 self.base_url, source_instance
547 ))
548 .json(&request)
549 .send()
550 .await?;
551
552 if response.status().is_success() {
553 Ok(response.json().await?)
554 } else {
555 anyhow::bail!("Send message failed: {}", response.status())
556 }
557 }
558
559 pub async fn get_messages(&self, instance_id: &str) -> Result<PendingMessagesResponse> {
561 let response = self
562 .client
563 .get(format!("{}/messages/{}", self.base_url, instance_id))
564 .send()
565 .await?;
566
567 if response.status().is_success() {
568 Ok(response.json().await?)
569 } else {
570 anyhow::bail!("Get messages failed: {}", response.status())
571 }
572 }
573
574 pub async fn acknowledge_messages(
576 &self,
577 instance_id: &str,
578 message_ids: Vec<String>,
579 ) -> Result<()> {
580 let request = AcknowledgeMessagesRequest { message_ids };
581
582 let response = self
583 .client
584 .post(format!("{}/messages/ack/{}", self.base_url, instance_id))
585 .json(&request)
586 .send()
587 .await?;
588
589 if response.status().is_success() {
590 Ok(())
591 } else {
592 anyhow::bail!("Acknowledge failed: {}", response.status())
593 }
594 }
595}
596
597pub trait MeshState {
599 fn mesh_registry(&self) -> &MeshRegistry;
600}
601
602pub async fn register_instance<S: MeshState>(
604 State(state): State<S>,
605 Json(request): Json<RegisterRequest>,
606) -> impl IntoResponse {
607 let instance = MeshInstance {
608 instance_id: request.instance_id,
609 hostname: request.hostname,
610 port: request.port,
611 capabilities: request.capabilities,
612 is_leader: false, last_heartbeat: Utc::now(),
614 created_at: Utc::now(),
615 agent_profiles: request.agent_profiles,
616 };
617
618 let response = state.mesh_registry().register(instance).await;
619 (StatusCode::OK, Json(response))
620}
621
622pub async fn list_instances<S: MeshState>(State(state): State<S>) -> impl IntoResponse {
624 let instances = state.mesh_registry().list().await;
625 let leader_id = instances
626 .iter()
627 .find(|i| i.is_leader)
628 .map(|i| i.instance_id.clone());
629
630 Json(InstancesResponse {
631 instances,
632 leader_id,
633 })
634}
635
636pub async fn heartbeat<S: MeshState>(
638 State(state): State<S>,
639 Path(instance_id): Path<String>,
640 Json(_request): Json<HeartbeatRequest>,
641) -> impl IntoResponse {
642 let response = state.mesh_registry().heartbeat(&instance_id).await;
643
644 if response.acknowledged {
645 (StatusCode::OK, Json(response))
646 } else {
647 (StatusCode::NOT_FOUND, Json(response))
648 }
649}
650
651pub async fn deregister_instance<S: MeshState>(
653 State(state): State<S>,
654 Path(instance_id): Path<String>,
655) -> impl IntoResponse {
656 let removed = state.mesh_registry().deregister(&instance_id).await;
657
658 if removed {
659 StatusCode::NO_CONTENT
660 } else {
661 StatusCode::NOT_FOUND
662 }
663}
664
665pub async fn send_message<S: MeshState>(
667 State(state): State<S>,
668 Path(source_instance): Path<String>,
669 Json(request): Json<SendMessageRequest>,
670) -> impl IntoResponse {
671 match state
672 .mesh_registry()
673 .send_message(
674 source_instance,
675 request.target_instance,
676 request.message_type,
677 request.payload,
678 request.correlation_id,
679 )
680 .await
681 {
682 Ok(response) => (StatusCode::OK, Json(response)).into_response(),
683 Err(e) => (
684 StatusCode::BAD_REQUEST,
685 Json(serde_json::json!({
686 "error": e.to_string()
687 })),
688 )
689 .into_response(),
690 }
691}
692
693pub async fn get_messages<S: MeshState>(
695 State(state): State<S>,
696 Path(instance_id): Path<String>,
697) -> impl IntoResponse {
698 let messages = state
699 .mesh_registry()
700 .get_pending_messages(&instance_id)
701 .await;
702
703 Json(PendingMessagesResponse { messages })
704}
705
706#[derive(Debug, Serialize, Deserialize)]
708pub struct AcknowledgeMessagesRequest {
709 pub message_ids: Vec<String>,
710}
711
712pub async fn acknowledge_messages<S: MeshState>(
714 State(state): State<S>,
715 Path(_instance_id): Path<String>,
716 Json(request): Json<AcknowledgeMessagesRequest>,
717) -> impl IntoResponse {
718 state
719 .mesh_registry()
720 .acknowledge_messages(request.message_ids)
721 .await;
722
723 StatusCode::NO_CONTENT
724}