Skip to main content

sanctum_ai/
protocol.rs

1use serde::{Deserialize, Serialize};
2
3/// JSON-RPC request over the Unix socket.
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct RpcRequest {
6    pub id: u64,
7    pub method: String,
8    #[serde(default)]
9    pub params: serde_json::Value,
10}
11
12/// Structured error for AI-agent-friendly diagnostics.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct RpcError {
15    pub code: String,
16    pub message: String,
17    #[serde(skip_serializing_if = "Option::is_none")]
18    pub detail: Option<String>,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub suggestion: Option<String>,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub docs_url: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub context: Option<serde_json::Value>,
25}
26
27impl RpcError {
28    pub fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
29        let code = code.into();
30        let docs_url = format!("https://sanctum.dev/errors/{code}");
31        Self {
32            code,
33            message: message.into(),
34            detail: None,
35            suggestion: None,
36            docs_url: Some(docs_url),
37            context: None,
38        }
39    }
40
41    pub fn detail(mut self, d: impl Into<String>) -> Self {
42        self.detail = Some(d.into());
43        self
44    }
45
46    pub fn suggestion(mut self, s: impl Into<String>) -> Self {
47        self.suggestion = Some(s.into());
48        self
49    }
50
51    pub fn context(mut self, ctx: serde_json::Value) -> Self {
52        self.context = Some(ctx);
53        self
54    }
55}
56
57/// JSON-RPC response. The `error` field is either a string (legacy) or a
58/// structured `RpcError` object, keeping backward compatibility.
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct RpcResponse {
61    pub id: u64,
62    #[serde(skip_serializing_if = "Option::is_none")]
63    pub result: Option<serde_json::Value>,
64    #[serde(skip_serializing_if = "Option::is_none")]
65    pub error: Option<serde_json::Value>,
66}
67
68impl RpcResponse {
69    pub fn success(id: u64, result: serde_json::Value) -> Self {
70        Self {
71            id,
72            result: Some(result),
73            error: None,
74        }
75    }
76
77    /// Legacy string error (backward compatible).
78    pub fn error(id: u64, message: impl Into<String>) -> Self {
79        Self {
80            id,
81            result: None,
82            error: Some(serde_json::Value::String(message.into())),
83        }
84    }
85
86    /// Structured error with full diagnostic context for AI agents.
87    pub fn structured_error(id: u64, err: RpcError) -> Self {
88        Self {
89            id,
90            result: None,
91            error: Some(
92                serde_json::to_value(err)
93                    .unwrap_or(serde_json::Value::String("serialization error".into())),
94            ),
95        }
96    }
97
98    /// Check whether this response is an error (structured or legacy).
99    pub fn is_error(&self) -> bool {
100        self.error.is_some()
101    }
102
103    /// Extract the error code from a structured error, or None for legacy strings.
104    pub fn error_code(&self) -> Option<&str> {
105        self.error
106            .as_ref()
107            .and_then(|v| v.as_object())
108            .and_then(|obj| obj.get("code"))
109            .and_then(|c| c.as_str())
110    }
111
112    /// Extract the error message (works for both legacy strings and structured errors).
113    pub fn error_message(&self) -> Option<&str> {
114        match &self.error {
115            Some(serde_json::Value::String(s)) => Some(s.as_str()),
116            Some(v) => v
117                .as_object()
118                .and_then(|o| o.get("message"))
119                .and_then(|m| m.as_str()),
120            None => None,
121        }
122    }
123}
124
/// Frame `msg` by prepending its length as a 4-byte big-endian `u32`.
///
/// The returned buffer is `4 + msg.len()` bytes long: the length header
/// followed by `msg` unchanged. [`decode_length`] reads the header back.
///
/// # Panics
///
/// Panics if `msg` is longer than `u32::MAX` bytes, because such a length
/// cannot be represented in the 4-byte prefix.
pub fn encode_message(msg: &[u8]) -> Vec<u8> {
    // A plain `as u32` cast would silently truncate an oversized length, yielding
    // a header that disagrees with the payload and desynchronizing the reader.
    let len = u32::try_from(msg.len())
        .expect("message length exceeds u32::MAX and cannot be framed with a 4-byte prefix");
    let mut buf = Vec::with_capacity(4 + msg.len());
    buf.extend_from_slice(&len.to_be_bytes());
    buf.extend_from_slice(msg);
    buf
}
133
/// Decode a 4-byte big-endian length header into the payload length it announces.
///
/// This is the inverse of the header written by [`encode_message`]. The value
/// is returned as-is with no upper bound applied.
pub fn decode_length(header: &[u8; 4]) -> u32 {
    // Most significant byte first: shift the accumulator left one byte per step.
    header
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
}
138
139// --- Param types for each RPC method ---
140
141#[derive(Debug, Deserialize)]
142pub struct AuthenticateParams {
143    pub agent_name: String,
144}
145
146#[derive(Debug, Deserialize)]
147pub struct ChallengeResponseParams {
148    pub session_id: String,
149    pub signature: String, // hex
150}
151
152#[derive(Debug, Deserialize)]
153pub struct RetrieveParams {
154    pub session_id: String,
155    pub path: String,
156    pub ttl: Option<u64>,
157}
158
159#[derive(Debug, Deserialize)]
160pub struct ListParams {
161    pub session_id: String,
162}
163
164#[derive(Debug, Deserialize)]
165pub struct ReleaseLeaseParams {
166    pub lease_id: String,
167}
168
169#[derive(Debug, Deserialize)]
170pub struct UseParams {
171    pub session_id: String,
172    pub path: String,
173    pub operation: String,
174    pub params: serde_json::Value,
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_request_serialize() {
        let req = RpcRequest {
            id: 1,
            method: "authenticate".to_string(),
            params: serde_json::json!({"agent_name": "test"}),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: RpcRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, 1);
        assert_eq!(parsed.method, "authenticate");
    }

    /// A request without a `params` key must still parse, with `params` = null.
    #[test]
    fn test_rpc_request_params_default_to_null() {
        let parsed: RpcRequest = serde_json::from_str(r#"{"id": 7, "method": "list"}"#).unwrap();
        assert_eq!(parsed.id, 7);
        assert_eq!(parsed.method, "list");
        assert!(parsed.params.is_null());
    }

    #[test]
    fn test_rpc_response_success() {
        let resp = RpcResponse::success(1, serde_json::json!({"ok": true}));
        assert!(resp.result.is_some());
        assert!(resp.error.is_none());
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"ok\":true"));
        assert!(!json.contains("error"));
    }

    #[test]
    fn test_rpc_response_error() {
        let resp = RpcResponse::error(2, "access denied");
        assert!(resp.result.is_none());
        assert!(resp.is_error());
        assert_eq!(resp.error_message(), Some("access denied"));
    }

    /// Legacy string errors carry no machine-readable code.
    #[test]
    fn test_legacy_error_has_no_code() {
        let resp = RpcResponse::error(3, "boom");
        assert_eq!(resp.error_code(), None);
        let ok = RpcResponse::success(4, serde_json::json!(null));
        assert!(!ok.is_error());
        assert_eq!(ok.error_code(), None);
        assert_eq!(ok.error_message(), None);
    }

    #[test]
    fn test_rpc_structured_error() {
        let err = RpcError::new("ACCESS_DENIED", "Agent 'a' cannot retrieve 'k'")
            .detail("No matching policy")
            .suggestion("Add a policy for agent:a")
            .context(serde_json::json!({"agent": "a", "resource": "k"}));
        let resp = RpcResponse::structured_error(1, err);
        assert!(resp.is_error());
        assert_eq!(resp.error_code(), Some("ACCESS_DENIED"));
        assert_eq!(resp.error_message(), Some("Agent 'a' cannot retrieve 'k'"));
        // Verify it serializes to JSON with all fields
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("ACCESS_DENIED"));
        assert!(json.contains("docs_url"));
    }

    /// `new` fills `docs_url` from the code and leaves unset optionals out of the JSON.
    #[test]
    fn test_rpc_error_docs_url_and_omitted_fields() {
        let err = RpcError::new("NOT_FOUND", "missing");
        assert_eq!(
            err.docs_url.as_deref(),
            Some("https://sanctum.dev/errors/NOT_FOUND")
        );
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("detail"));
        assert!(!json.contains("suggestion"));
        assert!(!json.contains("context"));
    }

    /// A structured error survives a JSON round trip with its code and message intact.
    #[test]
    fn test_structured_error_round_trip() {
        let resp = RpcResponse::structured_error(9, RpcError::new("LEASE_EXPIRED", "lease gone"));
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: RpcResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, 9);
        assert!(parsed.result.is_none());
        assert_eq!(parsed.error_code(), Some("LEASE_EXPIRED"));
        assert_eq!(parsed.error_message(), Some("lease gone"));
    }

    #[test]
    fn test_encode_decode_message() {
        let msg = b"hello world";
        let encoded = encode_message(msg);
        assert_eq!(encoded.len(), 4 + msg.len());
        let header: [u8; 4] = encoded[..4].try_into().unwrap();
        assert_eq!(decode_length(&header), msg.len() as u32);
        assert_eq!(&encoded[4..], msg);
    }

    /// An empty payload is framed as a zero-length header only.
    #[test]
    fn test_encode_empty_message() {
        let encoded = encode_message(b"");
        assert_eq!(encoded, vec![0, 0, 0, 0]);
        assert_eq!(decode_length(&[0, 0, 0, 0]), 0);
    }

    #[test]
    fn test_authenticate_params_deserialize() {
        let json = r#"{"agent_name": "my-agent"}"#;
        let params: AuthenticateParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.agent_name, "my-agent");
    }

    #[test]
    fn test_retrieve_params_deserialize() {
        let json = r#"{"session_id": "abc", "path": "openai/key", "ttl": 300}"#;
        let params: RetrieveParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.path, "openai/key");
        assert_eq!(params.ttl, Some(300));
    }

    /// `ttl` is optional and becomes `None` when omitted.
    #[test]
    fn test_retrieve_params_without_ttl() {
        let json = r#"{"session_id": "abc", "path": "openai/key"}"#;
        let params: RetrieveParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.session_id, "abc");
        assert_eq!(params.ttl, None);
    }

    #[test]
    fn test_use_params_deserialize() {
        let json = r#"{"session_id": "s", "path": "db/pw", "operation": "sign", "params": {"x": 1}}"#;
        let params: UseParams = serde_json::from_str(json).unwrap();
        assert_eq!(params.operation, "sign");
        assert_eq!(params.params, serde_json::json!({"x": 1}));
    }
}