swarm_rs/
agent.rs

1use std::{any::Any, fmt::Display};
2
3use async_trait::async_trait;
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::prelude::Swarm;
8
/// An asynchronous handler for [`Action`]s.
///
/// Implementors must be `Send + Sync`. The `Any` supertrait, together with
/// [`Agent::as_any`], lets callers holding a `dyn Agent` downcast to the
/// concrete agent type.
#[async_trait]
pub trait Agent: Any + Send + Sync {
    /// Processes `input` and returns the resulting [`Output`].
    ///
    /// `swarm` is the [`Swarm`] handed in by the caller.
    // NOTE(review): presumably `swarm` lets an agent dispatch follow-up actions
    // to other agents — confirm against the `Swarm` implementation.
    async fn execute(&self, input: &Action, swarm: &Swarm) -> Output;
    /// Returns `self` as `&dyn Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;
}
14
15#[derive(Serialize, Deserialize)]
16pub struct Action {
17    id: String,
18    pub(crate) payload: Value,
19}
20
21impl Action {
22    pub fn new<T: Serialize>(id: &str, payload: T) -> Self {
23        Self {
24            id: id.to_string(),
25            payload: serde_json::to_value(payload).unwrap(),
26        }
27    }
28
29    pub fn get_payload<T: DeserializeOwned>(&self) -> Result<T, String> {
30        if let Ok(payload) = serde_json::from_value(self.payload.clone()) {
31            Ok(payload)
32        } else {
33            Err("Unable to deserialize payload".to_string())
34        }
35    }
36
37    pub fn get_name(&self) -> &str {
38        if let Some((_, action_id)) = &self.id.split_once(".") {
39            *action_id
40        } else {
41            "default"
42        }
43    }
44
45    pub fn get_agent(&self) -> &str {
46        if let Some((agent_id, _)) = &self.id.split_once(".") {
47            *agent_id
48        } else {
49            &self.id
50        }
51    }
52
53    pub fn get_id(&self) -> &str {
54        &self.id
55    }
56}
57
58#[derive(Serialize, Deserialize)]
59pub struct Output {
60    pub agent_id: String,
61    status: String,
62    payload: Value,
63}
64
65impl Output {
66    pub fn new_success<T: Serialize>(payload: T) -> Self {
67        Self {
68            agent_id: "".to_string(),
69            status: "SUCCESS".to_string(),
70            payload: serde_json::to_value(payload).unwrap(),
71        }
72    }
73
74    pub fn new_error(message: &str) -> Self {
75        Self {
76            agent_id: "".to_string(),
77            status: "ERROR".to_string(),
78            payload: serde_json::to_value(message).unwrap(),
79        }
80    }
81
82    pub fn get_payload<T: DeserializeOwned>(&self) -> T {
83        serde_json::from_value(self.payload.clone()).unwrap()
84    }
85
86    pub fn get_value(&self) -> &Value {
87        &self.payload
88    }
89
90    pub fn get_error_message(&self) -> String {
91        self.get_payload::<String>()
92    }
93
94    pub fn is_success(&self) -> bool {
95        self.status.as_str() == "SUCCESS"
96    }
97}
98
99impl Display for Output {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        if self.is_success() {
102            write!(f, "SUCCESS : {}", self.get_value())
103        } else {
104            write!(f, "ERROR : {}", self.get_error_message())
105        }
106    }
107}