Skip to main content

systemprompt_agent/services/
context_provider.rs

//! Implementation of the `ContextProvider` trait for the agent module.
2
3use async_trait::async_trait;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{ContextId, SessionId, UserId};
6use systemprompt_traits::{ContextProvider, ContextProviderError, ContextWithStats};
7
8use crate::repository::ContextRepository;
9
10#[derive(Debug, Clone)]
11pub struct ContextProviderService {
12    repo: ContextRepository,
13}
14
15impl ContextProviderService {
16    pub fn new(db_pool: DbPool) -> Self {
17        Self {
18            repo: ContextRepository::new(db_pool),
19        }
20    }
21}
22
23#[async_trait]
24impl ContextProvider for ContextProviderService {
25    async fn list_contexts_with_stats(
26        &self,
27        user_id: &str,
28    ) -> Result<Vec<ContextWithStats>, ContextProviderError> {
29        let user_id = UserId::new(user_id);
30        let contexts = self
31            .repo
32            .list_contexts_with_stats(&user_id)
33            .await
34            .map_err(|e| ContextProviderError::Database(e.to_string()))?;
35
36        Ok(contexts
37            .into_iter()
38            .map(|c| ContextWithStats {
39                context_id: c.context_id.to_string(),
40                user_id: c.user_id.to_string(),
41                name: c.name,
42                created_at: c.created_at,
43                updated_at: c.updated_at,
44                task_count: c.task_count,
45                message_count: c.message_count,
46                last_message_at: c.last_message_at,
47            })
48            .collect())
49    }
50
51    async fn get_context(
52        &self,
53        context_id: &str,
54        user_id: &str,
55    ) -> Result<ContextWithStats, ContextProviderError> {
56        let context_id = ContextId::new(context_id);
57        let user_id = UserId::new(user_id);
58
59        let context = self
60            .repo
61            .get_context(&context_id, &user_id)
62            .await
63            .map_err(|e| match e {
64                systemprompt_traits::RepositoryError::NotFound(msg) => {
65                    ContextProviderError::NotFound(msg)
66                },
67                other => ContextProviderError::Database(other.to_string()),
68            })?;
69
70        let all_contexts = self
71            .repo
72            .list_contexts_with_stats(&user_id)
73            .await
74            .map_err(|e| ContextProviderError::Database(e.to_string()))?;
75
76        let context_with_stats = all_contexts
77            .into_iter()
78            .find(|c| c.context_id == context.context_id)
79            .ok_or_else(|| {
80                ContextProviderError::NotFound(format!("Context {} not found", context_id))
81            })?;
82
83        Ok(ContextWithStats {
84            context_id: context_with_stats.context_id.to_string(),
85            user_id: context_with_stats.user_id.to_string(),
86            name: context_with_stats.name,
87            created_at: context_with_stats.created_at,
88            updated_at: context_with_stats.updated_at,
89            task_count: context_with_stats.task_count,
90            message_count: context_with_stats.message_count,
91            last_message_at: context_with_stats.last_message_at,
92        })
93    }
94
95    async fn create_context(
96        &self,
97        user_id: &str,
98        session_id: Option<&str>,
99        name: &str,
100    ) -> Result<String, ContextProviderError> {
101        let user_id = UserId::new(user_id);
102        let session_id = session_id.map(SessionId::new);
103
104        let context_id = self
105            .repo
106            .create_context(&user_id, session_id.as_ref(), name)
107            .await
108            .map_err(|e| ContextProviderError::Database(e.to_string()))?;
109
110        Ok(context_id.to_string())
111    }
112
113    async fn update_context_name(
114        &self,
115        context_id: &str,
116        user_id: &str,
117        name: &str,
118    ) -> Result<(), ContextProviderError> {
119        let context_id = ContextId::new(context_id);
120        let user_id = UserId::new(user_id);
121
122        self.repo
123            .update_context_name(&context_id, &user_id, name)
124            .await
125            .map_err(|e| match e {
126                systemprompt_traits::RepositoryError::NotFound(msg) => {
127                    ContextProviderError::NotFound(msg)
128                },
129                other => ContextProviderError::Database(other.to_string()),
130            })
131    }
132
133    async fn delete_context(
134        &self,
135        context_id: &str,
136        user_id: &str,
137    ) -> Result<(), ContextProviderError> {
138        let context_id = ContextId::new(context_id);
139        let user_id = UserId::new(user_id);
140
141        self.repo
142            .delete_context(&context_id, &user_id)
143            .await
144            .map_err(|e| match e {
145                systemprompt_traits::RepositoryError::NotFound(msg) => {
146                    ContextProviderError::NotFound(msg)
147                },
148                other => ContextProviderError::Database(other.to_string()),
149            })
150    }
151}