Skip to main content

systemprompt_files/services/
ai_provider.rs

1use async_trait::async_trait;
2use systemprompt_database::DbPool;
3use systemprompt_identifiers::{FileId, UserId};
4use systemprompt_traits::{
5    AiFilePersistenceProvider, AiGeneratedFile, AiProviderError, AiProviderResult,
6    ImageStorageConfig, InsertAiFileParams,
7};
8
9use crate::config::FilesConfig;
10use crate::repository::{FileRepository, InsertFileRequest};
11
12#[derive(Debug)]
13pub struct FilesAiPersistenceProvider {
14    repository: FileRepository,
15}
16
17impl FilesAiPersistenceProvider {
18    pub fn new(db: &DbPool) -> Result<Self, anyhow::Error> {
19        Ok(Self {
20            repository: FileRepository::new(db)?,
21        })
22    }
23
24    pub const fn from_repository(repository: FileRepository) -> Self {
25        Self { repository }
26    }
27}
28
29#[async_trait]
30impl AiFilePersistenceProvider for FilesAiPersistenceProvider {
31    async fn insert_file(&self, params: InsertAiFileParams) -> AiProviderResult<()> {
32        let file_id = FileId::new(params.id.to_string());
33        let mut request =
34            InsertFileRequest::new(file_id, params.path, params.public_url, params.mime_type)
35                .with_ai_content(true)
36                .with_metadata(params.metadata);
37
38        if let Some(size) = params.size_bytes {
39            request = request.with_size(size);
40        }
41
42        if let Some(user_id) = params.user_id {
43            request = request.with_user_id(user_id);
44        }
45
46        if let Some(session_id) = params.session_id {
47            request = request.with_session_id(session_id);
48        }
49
50        if let Some(trace_id) = params.trace_id {
51            request = request.with_trace_id(trace_id);
52        }
53
54        if let Some(context_id) = params.context_id {
55            request = request.with_context_id(context_id);
56        }
57
58        self.repository
59            .insert(request)
60            .await
61            .map(|_| ())
62            .map_err(|e| AiProviderError::Internal(e.to_string()))
63    }
64
65    async fn find_by_id(&self, id: &FileId) -> AiProviderResult<Option<AiGeneratedFile>> {
66        let file = self
67            .repository
68            .find_by_id(id)
69            .await
70            .map_err(|e| AiProviderError::Internal(e.to_string()))?;
71
72        Ok(file.map(|f| AiGeneratedFile {
73            id: f.id,
74            path: f.path,
75            public_url: f.public_url,
76            mime_type: f.mime_type,
77            size_bytes: f.size_bytes,
78            ai_content: f.ai_content,
79            metadata: f.metadata,
80            user_id: f.user_id,
81            session_id: f.session_id,
82            trace_id: f.trace_id,
83            context_id: f.context_id,
84            created_at: f.created_at,
85            updated_at: f.updated_at,
86            deleted_at: f.deleted_at,
87        }))
88    }
89
90    async fn list_by_user(
91        &self,
92        user_id: &UserId,
93        limit: i64,
94        offset: i64,
95    ) -> AiProviderResult<Vec<AiGeneratedFile>> {
96        let files = self
97            .repository
98            .list_by_user(user_id, limit, offset)
99            .await
100            .map_err(|e| AiProviderError::Internal(e.to_string()))?;
101
102        Ok(files
103            .into_iter()
104            .map(|f| AiGeneratedFile {
105                id: f.id,
106                path: f.path,
107                public_url: f.public_url,
108                mime_type: f.mime_type,
109                size_bytes: f.size_bytes,
110                ai_content: f.ai_content,
111                metadata: f.metadata,
112                user_id: f.user_id,
113                session_id: f.session_id,
114                trace_id: f.trace_id,
115                context_id: f.context_id,
116                created_at: f.created_at,
117                updated_at: f.updated_at,
118                deleted_at: f.deleted_at,
119            })
120            .collect())
121    }
122
123    async fn delete(&self, id: &FileId) -> AiProviderResult<()> {
124        self.repository
125            .delete(id)
126            .await
127            .map_err(|e| AiProviderError::Internal(e.to_string()))
128    }
129
130    fn storage_config(&self) -> AiProviderResult<ImageStorageConfig> {
131        let config =
132            FilesConfig::get().map_err(|e| AiProviderError::ConfigurationError(e.to_string()))?;
133
134        Ok(ImageStorageConfig {
135            base_path: config.generated_images(),
136            url_prefix: format!("{}/images/generated", config.url_prefix()),
137        })
138    }
139}