Skip to main content

systemprompt_agent/services/external_integrations/webhook/service/
mod.rs

1mod delivery;
2mod types;
3
4pub use types::{RetryPolicy, WebhookConfig, WebhookDeliveryResult, WebhookTestResult};
5
6use hmac::{Hmac, Mac};
7use reqwest::Client;
8use serde_json::Value;
9use sha2::Sha256;
10use std::collections::HashMap;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14use crate::models::external_integrations::{
15    IntegrationError, IntegrationResult, WebhookEndpoint, WebhookRequest, WebhookResponse,
16};
17
18type HmacSha256 = Hmac<Sha256>;
19
/// In-memory registry of webhook endpoints together with an HTTP client.
///
/// Endpoints live behind a `tokio::sync::RwLock`, so lookups take a read lock
/// and registration/removal take a write lock.
#[derive(Debug)]
pub struct WebhookService {
    /// Registered endpoints keyed by their `WebhookEndpoint::id`.
    pub(crate) endpoints: RwLock<HashMap<String, WebhookEndpoint>>,
    /// `reqwest` client. It is not used by the methods in this file.
    /// NOTE(review): presumably consumed by the `delivery` submodule for
    /// outbound calls — confirm.
    pub(crate) http_client: Client,
}
25
26impl WebhookService {
27    pub fn new() -> Self {
28        Self {
29            endpoints: RwLock::new(HashMap::new()),
30            http_client: Client::new(),
31        }
32    }
33
34    pub async fn register_endpoint(
35        &self,
36        mut endpoint: WebhookEndpoint,
37    ) -> IntegrationResult<String> {
38        if endpoint.id.is_empty() {
39            endpoint.id = Uuid::new_v4().to_string();
40        }
41
42        let endpoint_id = endpoint.id.clone();
43
44        {
45            let mut endpoints = self.endpoints.write().await;
46            endpoints.insert(endpoint_id.clone(), endpoint);
47        }
48
49        Ok(endpoint_id)
50    }
51
52    pub async fn update_endpoint(&self, endpoint: WebhookEndpoint) -> IntegrationResult<()> {
53        {
54            let mut endpoints = self.endpoints.write().await;
55            endpoints.insert(endpoint.id.clone(), endpoint);
56        }
57        Ok(())
58    }
59
60    pub async fn get_endpoint(
61        &self,
62        endpoint_id: &str,
63    ) -> IntegrationResult<Option<WebhookEndpoint>> {
64        let endpoints = self.endpoints.read().await;
65        Ok(endpoints.get(endpoint_id).cloned())
66    }
67
68    pub async fn list_endpoints(&self) -> IntegrationResult<Vec<WebhookEndpoint>> {
69        let endpoints = self.endpoints.read().await;
70        Ok(endpoints.values().cloned().collect())
71    }
72
73    pub async fn remove_endpoint(&self, endpoint_id: &str) -> IntegrationResult<bool> {
74        let mut endpoints = self.endpoints.write().await;
75        Ok(endpoints.remove(endpoint_id).is_some())
76    }
77
78    pub async fn handle_webhook(
79        &self,
80        endpoint_id: &str,
81        request: WebhookRequest,
82    ) -> IntegrationResult<WebhookResponse> {
83        let endpoint = {
84            let endpoints = self.endpoints.read().await;
85            endpoints.get(endpoint_id).cloned().ok_or_else(|| {
86                IntegrationError::Webhook(format!("Endpoint not found: {endpoint_id}"))
87            })?
88        };
89
90        if !endpoint.active {
91            return Ok(WebhookResponse {
92                status: 404,
93                body: Some(serde_json::json!({"error": "Endpoint is inactive"})),
94            });
95        }
96
97        if let (Some(_secret), Some(signature)) = (&endpoint.secret, &request.signature) {
98            if !self.verify_signature_internal(&endpoint, &request.body, signature)? {
99                return Ok(WebhookResponse {
100                    status: 401,
101                    body: Some(serde_json::json!({"error": "Invalid signature"})),
102                });
103            }
104        }
105
106        let event_type = request
107            .headers
108            .get("x-webhook-event")
109            .or_else(|| request.headers.get("x-event-type"))
110            .or_else(|| request.headers.get("x-github-event"))
111            .cloned()
112            .unwrap_or_else(|| "unknown".to_string());
113
114        if !endpoint.events.is_empty()
115            && !endpoint.events.contains(&event_type)
116            && !endpoint.events.contains(&"*".to_string())
117        {
118            return Ok(WebhookResponse {
119                status: 200,
120                body: Some(serde_json::json!({"message": "Event type not subscribed"})),
121            });
122        }
123
124        Ok(WebhookResponse {
125            status: 200,
126            body: Some(serde_json::json!({
127                "message": "Webhook processed successfully",
128                "event_type": event_type,
129                "endpoint_id": endpoint_id
130            })),
131        })
132    }
133
134    pub async fn verify_signature(
135        &self,
136        endpoint_id: &str,
137        payload: &Value,
138        signature: &str,
139    ) -> IntegrationResult<bool> {
140        let endpoint = {
141            let endpoints = self.endpoints.read().await;
142            endpoints.get(endpoint_id).cloned().ok_or_else(|| {
143                IntegrationError::Webhook(format!("Endpoint not found: {endpoint_id}"))
144            })?
145        };
146
147        self.verify_signature_internal(&endpoint, payload, signature)
148    }
149
150    pub(crate) fn verify_signature_internal(
151        &self,
152        endpoint: &WebhookEndpoint,
153        payload: &Value,
154        signature: &str,
155    ) -> IntegrationResult<bool> {
156        let secret = endpoint.secret.as_ref().ok_or_else(|| {
157            IntegrationError::Webhook("No secret configured for endpoint".to_string())
158        })?;
159
160        let expected_signature = self.generate_signature(secret, payload)?;
161
162        Ok(self.secure_compare(&expected_signature, signature))
163    }
164
165    pub(crate) fn generate_signature(
166        &self,
167        secret: &str,
168        payload: &Value,
169    ) -> IntegrationResult<String> {
170        let payload_bytes = serde_json::to_vec(payload)?;
171
172        let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
173            .map_err(|e| IntegrationError::Webhook(format!("Invalid secret: {e}")))?;
174
175        mac.update(&payload_bytes);
176        let result = mac.finalize();
177        let hex_result = hex::encode(result.into_bytes());
178
179        Ok(format!("sha256={hex_result}"))
180    }
181
182    fn secure_compare(&self, a: &str, b: &str) -> bool {
183        if a.len() != b.len() {
184            return false;
185        }
186
187        let mut result = 0u8;
188        for (byte_a, byte_b) in a.bytes().zip(b.bytes()) {
189            result |= byte_a ^ byte_b;
190        }
191
192        result == 0
193    }
194}
195
196impl Default for WebhookService {
197    fn default() -> Self {
198        Self::new()
199    }
200}