Skip to main content

systemprompt_agent/services/
context.rs

1use crate::services::shared::{AgentServiceError, Result};
2use base64::Engine;
3use systemprompt_database::DbPool;
4use systemprompt_models::{
5    AiContentPart, AiMessage, MessageRole, is_supported_audio, is_supported_image,
6    is_supported_text, is_supported_video,
7};
8
9use crate::models::a2a::{Artifact, FilePart, Message, Part};
10use crate::repository::task::TaskRepository;
11
12#[derive(Debug, Clone)]
13pub struct ContextService {
14    task_repo: TaskRepository,
15}
16
17impl ContextService {
18    pub fn new(db_pool: &DbPool) -> Result<Self> {
19        Ok(Self {
20            task_repo: TaskRepository::new(db_pool)?,
21        })
22    }
23
24    pub async fn load_conversation_history(
25        &self,
26        context_id: &systemprompt_identifiers::ContextId,
27    ) -> Result<Vec<AiMessage>> {
28        let tasks = self
29            .task_repo
30            .list_tasks_by_context(context_id)
31            .await
32            .map_err(|e| {
33                AgentServiceError::Internal(format!("Failed to load conversation history: {}", e))
34            })?;
35
36        let mut history_messages = Vec::new();
37
38        for task in tasks {
39            if let Some(task_history) = task.history {
40                for msg in task_history {
41                    let (text, parts) = Self::extract_message_content(&msg);
42                    if text.is_empty() && parts.is_empty() {
43                        continue;
44                    }
45
46                    let role = match msg.role {
47                        crate::models::a2a::MessageRole::User => MessageRole::User,
48                        crate::models::a2a::MessageRole::Agent => MessageRole::Assistant,
49                    };
50
51                    history_messages.push(AiMessage {
52                        role,
53                        content: text,
54                        parts,
55                    });
56                }
57            }
58
59            if let Some(artifacts) = task.artifacts {
60                for artifact in artifacts {
61                    let artifact_content = Self::serialize_artifact_for_context(&artifact);
62                    history_messages.push(AiMessage {
63                        role: MessageRole::Assistant,
64                        content: artifact_content,
65                        parts: Vec::new(),
66                    });
67                }
68            }
69        }
70
71        Ok(history_messages)
72    }
73
74    fn extract_message_content(message: &Message) -> (String, Vec<AiContentPart>) {
75        let mut text_content = String::new();
76        let mut content_parts = Vec::new();
77
78        for part in &message.parts {
79            match part {
80                Part::Text(text_part) => {
81                    if text_content.is_empty() {
82                        text_content.clone_from(&text_part.text);
83                    }
84                    content_parts.push(AiContentPart::text(&text_part.text));
85                },
86                Part::File(file_part) => {
87                    if let Some(content_part) = Self::file_to_content_part(file_part) {
88                        content_parts.push(content_part);
89                    }
90                },
91                Part::Data(_) => {},
92            }
93        }
94
95        (text_content, content_parts)
96    }
97
98    fn file_to_content_part(file_part: &FilePart) -> Option<AiContentPart> {
99        let mime_type = file_part.file.mime_type.as_deref()?;
100        let file_name = file_part.file.name.as_deref().unwrap_or("unnamed");
101
102        let bytes = file_part.file.bytes.as_deref()?;
103
104        if is_supported_image(mime_type) {
105            return Some(AiContentPart::image(mime_type, bytes));
106        }
107
108        if is_supported_audio(mime_type) {
109            return Some(AiContentPart::audio(mime_type, bytes));
110        }
111
112        if is_supported_video(mime_type) {
113            return Some(AiContentPart::video(mime_type, bytes));
114        }
115
116        if is_supported_text(mime_type) {
117            return Self::decode_text_file(bytes, file_name, mime_type);
118        }
119
120        tracing::warn!(
121            file_name = %file_name,
122            mime_type = %mime_type,
123            "Unsupported file type - file will not be sent to AI"
124        );
125        None
126    }
127
128    fn decode_text_file(bytes: &str, file_name: &str, mime_type: &str) -> Option<AiContentPart> {
129        let decoded = base64::engine::general_purpose::STANDARD
130            .decode(bytes)
131            .map_err(|e| {
132                tracing::warn!(
133                    file_name = %file_name,
134                    mime_type = %mime_type,
135                    error = %e,
136                    "Failed to decode base64 text file"
137                );
138                e
139            })
140            .ok()?;
141
142        let text_content = String::from_utf8(decoded)
143            .map_err(|e| {
144                tracing::warn!(
145                    file_name = %file_name,
146                    mime_type = %mime_type,
147                    error = %e,
148                    "Failed to decode text file as UTF-8"
149                );
150                e
151            })
152            .ok()?;
153
154        let formatted = format!("[File: {file_name} ({mime_type})]\n{text_content}");
155        Some(AiContentPart::text(formatted))
156    }
157
158    fn serialize_artifact_for_context(artifact: &Artifact) -> String {
159        let artifact_name = artifact.title.as_deref().unwrap_or("unnamed");
160
161        let mut content = format!(
162            "[Artifact: {} (type: {}, id: {})]",
163            artifact_name, artifact.metadata.artifact_type, artifact.id
164        );
165
166        if let Some(description) = &artifact.description {
167            if !description.is_empty() {
168                let truncated = if description.len() > 300 {
169                    format!("{}...", &description[..300])
170                } else {
171                    description.clone()
172                };
173                content.push_str(&format!("\n{truncated}"));
174            }
175        }
176
177        content
178    }
179}