systemprompt_agent/services/
context_provider.rs1use async_trait::async_trait;
4use systemprompt_database::DbPool;
5use systemprompt_identifiers::{ContextId, SessionId, UserId};
6use systemprompt_traits::{ContextProvider, ContextProviderError, ContextWithStats};
7
8use crate::repository::ContextRepository;
9
10#[derive(Debug, Clone)]
11pub struct ContextProviderService {
12 repo: ContextRepository,
13}
14
15impl ContextProviderService {
16 pub fn new(db_pool: DbPool) -> Self {
17 Self {
18 repo: ContextRepository::new(db_pool),
19 }
20 }
21}
22
23#[async_trait]
24impl ContextProvider for ContextProviderService {
25 async fn list_contexts_with_stats(
26 &self,
27 user_id: &str,
28 ) -> Result<Vec<ContextWithStats>, ContextProviderError> {
29 let user_id = UserId::new(user_id);
30 let contexts = self
31 .repo
32 .list_contexts_with_stats(&user_id)
33 .await
34 .map_err(|e| ContextProviderError::Database(e.to_string()))?;
35
36 Ok(contexts
37 .into_iter()
38 .map(|c| ContextWithStats {
39 context_id: c.context_id.to_string(),
40 user_id: c.user_id.to_string(),
41 name: c.name,
42 created_at: c.created_at,
43 updated_at: c.updated_at,
44 task_count: c.task_count,
45 message_count: c.message_count,
46 last_message_at: c.last_message_at,
47 })
48 .collect())
49 }
50
51 async fn get_context(
52 &self,
53 context_id: &str,
54 user_id: &str,
55 ) -> Result<ContextWithStats, ContextProviderError> {
56 let context_id = ContextId::new(context_id);
57 let user_id = UserId::new(user_id);
58
59 let context = self
60 .repo
61 .get_context(&context_id, &user_id)
62 .await
63 .map_err(|e| match e {
64 systemprompt_traits::RepositoryError::NotFound(msg) => {
65 ContextProviderError::NotFound(msg)
66 },
67 other => ContextProviderError::Database(other.to_string()),
68 })?;
69
70 let all_contexts = self
71 .repo
72 .list_contexts_with_stats(&user_id)
73 .await
74 .map_err(|e| ContextProviderError::Database(e.to_string()))?;
75
76 let context_with_stats = all_contexts
77 .into_iter()
78 .find(|c| c.context_id == context.context_id)
79 .ok_or_else(|| {
80 ContextProviderError::NotFound(format!("Context {} not found", context_id))
81 })?;
82
83 Ok(ContextWithStats {
84 context_id: context_with_stats.context_id.to_string(),
85 user_id: context_with_stats.user_id.to_string(),
86 name: context_with_stats.name,
87 created_at: context_with_stats.created_at,
88 updated_at: context_with_stats.updated_at,
89 task_count: context_with_stats.task_count,
90 message_count: context_with_stats.message_count,
91 last_message_at: context_with_stats.last_message_at,
92 })
93 }
94
95 async fn create_context(
96 &self,
97 user_id: &str,
98 session_id: Option<&str>,
99 name: &str,
100 ) -> Result<String, ContextProviderError> {
101 let user_id = UserId::new(user_id);
102 let session_id = session_id.map(SessionId::new);
103
104 let context_id = self
105 .repo
106 .create_context(&user_id, session_id.as_ref(), name)
107 .await
108 .map_err(|e| ContextProviderError::Database(e.to_string()))?;
109
110 Ok(context_id.to_string())
111 }
112
113 async fn update_context_name(
114 &self,
115 context_id: &str,
116 user_id: &str,
117 name: &str,
118 ) -> Result<(), ContextProviderError> {
119 let context_id = ContextId::new(context_id);
120 let user_id = UserId::new(user_id);
121
122 self.repo
123 .update_context_name(&context_id, &user_id, name)
124 .await
125 .map_err(|e| match e {
126 systemprompt_traits::RepositoryError::NotFound(msg) => {
127 ContextProviderError::NotFound(msg)
128 },
129 other => ContextProviderError::Database(other.to_string()),
130 })
131 }
132
133 async fn delete_context(
134 &self,
135 context_id: &str,
136 user_id: &str,
137 ) -> Result<(), ContextProviderError> {
138 let context_id = ContextId::new(context_id);
139 let user_id = UserId::new(user_id);
140
141 self.repo
142 .delete_context(&context_id, &user_id)
143 .await
144 .map_err(|e| match e {
145 systemprompt_traits::RepositoryError::NotFound(msg) => {
146 ContextProviderError::NotFound(msg)
147 },
148 other => ContextProviderError::Database(other.to_string()),
149 })
150 }
151}