Skip to main content

systemprompt_agent/services/
context_provider.rs

use async_trait::async_trait;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, SessionId, UserId};
use systemprompt_traits::{ContextProvider, ContextProviderError, ContextWithStats, RepositoryError};

use crate::repository::ContextRepository;

8#[derive(Debug, Clone)]
9pub struct ContextProviderService {
10    repo: ContextRepository,
11}
12
13impl ContextProviderService {
14    pub fn new(db_pool: &DbPool) -> anyhow::Result<Self> {
15        Ok(Self {
16            repo: ContextRepository::new(db_pool)?,
17        })
18    }
19}
20
21#[async_trait]
22impl ContextProvider for ContextProviderService {
23    async fn list_contexts_with_stats(
24        &self,
25        user_id: &UserId,
26    ) -> Result<Vec<ContextWithStats>, ContextProviderError> {
27        let contexts = self
28            .repo
29            .list_contexts_with_stats(user_id)
30            .await
31            .map_err(|e| ContextProviderError::Database(e.to_string()))?;
32
33        Ok(contexts
34            .into_iter()
35            .map(|c| ContextWithStats {
36                context_id: c.context_id,
37                user_id: c.user_id,
38                name: c.name,
39                created_at: c.created_at,
40                updated_at: c.updated_at,
41                task_count: c.task_count,
42                message_count: c.message_count,
43                last_message_at: c.last_message_at,
44            })
45            .collect())
46    }
47
48    async fn get_context(
49        &self,
50        context_id: &ContextId,
51        user_id: &UserId,
52    ) -> Result<ContextWithStats, ContextProviderError> {
53        let context = self
54            .repo
55            .get_context(context_id, user_id)
56            .await
57            .map_err(|e| match e {
58                systemprompt_traits::RepositoryError::NotFound(msg) => {
59                    ContextProviderError::NotFound(msg)
60                },
61                other => ContextProviderError::Database(other.to_string()),
62            })?;
63
64        let all_contexts = self
65            .repo
66            .list_contexts_with_stats(user_id)
67            .await
68            .map_err(|e| ContextProviderError::Database(e.to_string()))?;
69
70        let context_with_stats = all_contexts
71            .into_iter()
72            .find(|c| c.context_id == context.context_id)
73            .ok_or_else(|| {
74                ContextProviderError::NotFound(format!("Context {} not found", context_id))
75            })?;
76
77        Ok(ContextWithStats {
78            context_id: context_with_stats.context_id,
79            user_id: context_with_stats.user_id,
80            name: context_with_stats.name,
81            created_at: context_with_stats.created_at,
82            updated_at: context_with_stats.updated_at,
83            task_count: context_with_stats.task_count,
84            message_count: context_with_stats.message_count,
85            last_message_at: context_with_stats.last_message_at,
86        })
87    }
88
89    async fn create_context(
90        &self,
91        user_id: &UserId,
92        session_id: Option<&SessionId>,
93        name: &str,
94    ) -> Result<ContextId, ContextProviderError> {
95        self.repo
96            .create_context(user_id, session_id, name)
97            .await
98            .map_err(|e| ContextProviderError::Database(e.to_string()))
99    }
100
101    async fn update_context_name(
102        &self,
103        context_id: &ContextId,
104        user_id: &UserId,
105        name: &str,
106    ) -> Result<(), ContextProviderError> {
107        self.repo
108            .update_context_name(context_id, user_id, name)
109            .await
110            .map_err(|e| match e {
111                systemprompt_traits::RepositoryError::NotFound(msg) => {
112                    ContextProviderError::NotFound(msg)
113                },
114                other => ContextProviderError::Database(other.to_string()),
115            })
116    }
117
118    async fn delete_context(
119        &self,
120        context_id: &ContextId,
121        user_id: &UserId,
122    ) -> Result<(), ContextProviderError> {
123        self.repo
124            .delete_context(context_id, user_id)
125            .await
126            .map_err(|e| match e {
127                systemprompt_traits::RepositoryError::NotFound(msg) => {
128                    ContextProviderError::NotFound(msg)
129                },
130                other => ContextProviderError::Database(other.to_string()),
131            })
132    }
133}