// systemprompt_api/services/middleware/context/sources/payload.rs
use axum::body::Body;
use axum::extract::Request;
use serde_json::Value;
use systemprompt_models::execution::{ContextExtractionError, ContextIdSource};

/// Stateless namespace type for resolving a context ID source from a JSON request body.
///
/// It carries no data; its functionality lives in the associated functions of its
/// `impl` block.
#[derive(Debug, Clone, Copy)]
pub struct PayloadSource;

9impl PayloadSource {
10 pub fn extract_context_source(
11 body_bytes: &[u8],
12 ) -> Result<ContextIdSource, ContextExtractionError> {
13 let payload: Value = serde_json::from_slice(body_bytes).map_err(|e| {
14 ContextExtractionError::InvalidHeaderValue {
15 header: "payload".to_string(),
16 reason: format!("Invalid JSON: {e}"),
17 }
18 })?;
19
20 let method = payload.get("method").and_then(|m| m.as_str()).unwrap_or("");
21
22 if method.starts_with("tasks/") {
23 let task_id = payload
24 .get("params")
25 .and_then(|p| p.get("id"))
26 .and_then(|id| id.as_str())
27 .map(ToString::to_string)
28 .ok_or_else(|| ContextExtractionError::InvalidHeaderValue {
29 header: "params.id".to_string(),
30 reason: "Task ID required for task methods".to_string(),
31 })?;
32
33 return Ok(ContextIdSource::FromTask {
34 task_id: systemprompt_identifiers::TaskId::new(task_id),
35 });
36 }
37
38 payload
39 .get("params")
40 .and_then(|p| p.get("message"))
41 .and_then(|m| m.get("contextId"))
42 .and_then(|c| c.as_str())
43 .map(|s| ContextIdSource::Direct(s.to_string()))
44 .ok_or(ContextExtractionError::MissingContextId)
45 }
46
47 pub async fn read_and_reconstruct(
48 request: Request<Body>,
49 ) -> Result<(Vec<u8>, Request<Body>), ContextExtractionError> {
50 let (parts, body) = request.into_parts();
51
52 let body_bytes = axum::body::to_bytes(body, usize::MAX)
53 .await
54 .map_err(|e| ContextExtractionError::InvalidHeaderValue {
55 header: "body".to_string(),
56 reason: format!("Failed to read body: {e}"),
57 })?
58 .to_vec();
59
60 let new_body = Body::from(body_bytes.clone());
61 let new_request = Request::from_parts(parts, new_body);
62
63 Ok((body_bytes, new_request))
64 }
65}