//! Context-id source that reads the JSON-RPC request payload
//! (`systemprompt_api/services/middleware/context/sources/payload.rs`).

use axum::body::Body;
use axum::extract::Request;
use serde_json::Value;
use systemprompt_models::execution::{ContextExtractionError, ContextIdSource};

/// Context-id source that inspects the JSON-RPC request body.
///
/// A stateless unit type; its functionality is exposed through associated
/// functions on `PayloadSource`.
#[derive(Debug, Clone, Copy)]
pub struct PayloadSource;

9impl PayloadSource {
10    pub fn extract_context_source(
11        body_bytes: &[u8],
12    ) -> Result<ContextIdSource, ContextExtractionError> {
13        // JSON: A2A JSON-RPC envelope is an external protocol boundary; the
14        // method name drives which typed field is read, so the shape is dynamic.
15        let payload: Value = serde_json::from_slice(body_bytes).map_err(|e| {
16            ContextExtractionError::InvalidHeaderValue {
17                header: "payload".to_string(),
18                reason: format!("Invalid JSON: {e}"),
19            }
20        })?;
21
22        let method = payload.get("method").and_then(|m| m.as_str()).unwrap_or("");
23
24        if method.starts_with("tasks/") {
25            let task_id = payload
26                .get("params")
27                .and_then(|p| p.get("id"))
28                .and_then(|id| id.as_str())
29                .map(ToString::to_string)
30                .ok_or_else(|| ContextExtractionError::InvalidHeaderValue {
31                    header: "params.id".to_string(),
32                    reason: "Task ID required for task methods".to_string(),
33                })?;
34
35            return Ok(ContextIdSource::FromTask {
36                task_id: systemprompt_identifiers::TaskId::new(task_id),
37            });
38        }
39
40        payload
41            .get("params")
42            .and_then(|p| p.get("message"))
43            .and_then(|m| m.get("contextId"))
44            .and_then(|c| c.as_str())
45            .map(|s| ContextIdSource::Direct(s.to_string()))
46            .ok_or(ContextExtractionError::MissingContextId)
47    }
48
49    pub async fn read_and_reconstruct(
50        request: Request<Body>,
51    ) -> Result<(Vec<u8>, Request<Body>), ContextExtractionError> {
52        let (parts, body) = request.into_parts();
53
54        let body_bytes = axum::body::to_bytes(body, usize::MAX)
55            .await
56            .map_err(|e| ContextExtractionError::InvalidHeaderValue {
57                header: "body".to_string(),
58                reason: format!("Failed to read body: {e}"),
59            })?
60            .to_vec();
61
62        let new_body = Body::from(body_bytes.clone());
63        let new_request = Request::from_parts(parts, new_body);
64
65        Ok((body_bytes, new_request))
66    }
67}