// rust_logic_graph/distributed/context.rs
use anyhow::{Context as AnyhowContext, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct DistributedContext {
15 pub session_id: String,
17
18 pub data: HashMap<String, Value>,
20
21 pub metadata: ContextMetadata,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ContextMetadata {
28 pub created_at: u64,
30
31 pub updated_at: u64,
33
34 pub version: u64,
36
37 pub modified_by: Option<String>,
39
40 pub tags: Vec<String>,
42}
43
44impl DistributedContext {
45 pub fn new(session_id: impl Into<String>) -> Self {
56 let now = std::time::SystemTime::now()
57 .duration_since(std::time::UNIX_EPOCH)
58 .unwrap()
59 .as_millis() as u64;
60
61 Self {
62 session_id: session_id.into(),
63 data: HashMap::new(),
64 metadata: ContextMetadata {
65 created_at: now,
66 updated_at: now,
67 version: 1,
68 modified_by: None,
69 tags: Vec::new(),
70 },
71 }
72 }
73
74 pub fn set(&mut self, key: impl Into<String>, value: Value) {
76 self.data.insert(key.into(), value);
77 self.bump_version();
78 }
79
80 pub fn get(&self, key: &str) -> Option<&Value> {
82 self.data.get(key)
83 }
84
85 pub fn remove(&mut self, key: &str) -> Option<Value> {
87 let result = self.data.remove(key);
88 if result.is_some() {
89 self.bump_version();
90 }
91 result
92 }
93
94 pub fn serialize(&self) -> Result<Vec<u8>> {
98 rmp_serde::to_vec(self).context("Failed to serialize distributed context")
99 }
100
101 pub fn deserialize(bytes: &[u8]) -> Result<Self> {
103 rmp_serde::from_slice(bytes).context("Failed to deserialize distributed context")
104 }
105
106 pub fn to_json(&self) -> Result<String> {
108 serde_json::to_string_pretty(self).context("Failed to serialize context to JSON")
109 }
110
111 pub fn from_json(json: &str) -> Result<Self> {
113 serde_json::from_str(json).context("Failed to deserialize context from JSON")
114 }
115
116 pub fn snapshot(&self) -> ContextSnapshot {
118 ContextSnapshot {
119 session_id: self.session_id.clone(),
120 data: self.data.clone(),
121 version: self.metadata.version,
122 timestamp: self.metadata.updated_at,
123 }
124 }
125
126 pub fn merge(&mut self, other: &DistributedContext) {
130 for (key, value) in &other.data {
131 self.data.insert(key.clone(), value.clone());
132 }
133 self.bump_version();
134 }
135
136 fn bump_version(&mut self) {
138 self.metadata.version += 1;
139 self.metadata.updated_at = std::time::SystemTime::now()
140 .duration_since(std::time::UNIX_EPOCH)
141 .unwrap()
142 .as_millis() as u64;
143 }
144
145 pub fn add_tag(&mut self, tag: impl Into<String>) {
147 let tag = tag.into();
148 if !self.metadata.tags.contains(&tag) {
149 self.metadata.tags.push(tag);
150 }
151 }
152
153 pub fn set_modified_by(&mut self, service: impl Into<String>) {
155 self.metadata.modified_by = Some(service.into());
156 self.bump_version();
157 }
158}
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
162pub struct ContextSnapshot {
163 pub session_id: String,
164 pub data: HashMap<String, Value>,
165 pub version: u64,
166 pub timestamp: u64,
167}
168
169#[derive(Debug, Clone)]
171pub struct SharedContext {
172 inner: Arc<RwLock<DistributedContext>>,
173}
174
175impl SharedContext {
176 pub fn new(session_id: impl Into<String>) -> Self {
178 Self {
179 inner: Arc::new(RwLock::new(DistributedContext::new(session_id))),
180 }
181 }
182
183 pub async fn get(&self, key: &str) -> Option<Value> {
185 let ctx = self.inner.read().await;
186 ctx.get(key).cloned()
187 }
188
189 pub async fn set(&self, key: impl Into<String>, value: Value) {
191 let mut ctx = self.inner.write().await;
192 ctx.set(key, value);
193 }
194
195 pub async fn serialize(&self) -> Result<Vec<u8>> {
197 let ctx = self.inner.read().await;
198 ctx.serialize()
199 }
200
201 pub async fn version(&self) -> u64 {
203 let ctx = self.inner.read().await;
204 ctx.metadata.version
205 }
206
207 pub async fn snapshot(&self) -> ContextSnapshot {
209 let ctx = self.inner.read().await;
210 ctx.snapshot()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_context_creation() {
        let ctx = DistributedContext::new("test-session");
        assert_eq!(ctx.session_id, "test-session");
        assert_eq!(ctx.metadata.version, 1);
        assert!(ctx.data.is_empty());
        assert!(ctx.metadata.tags.is_empty());
        assert_eq!(ctx.metadata.modified_by, None);
        // Both timestamps come from the same clock read in `new`.
        assert_eq!(ctx.metadata.created_at, ctx.metadata.updated_at);
    }

    #[test]
    fn test_set_and_get() {
        let mut ctx = DistributedContext::new("test");
        ctx.set("key1", json!("value1"));

        assert_eq!(ctx.get("key1"), Some(&json!("value1")));
        assert_eq!(ctx.get("missing"), None);
        assert_eq!(ctx.metadata.version, 2);
    }

    #[test]
    fn test_remove_bumps_version_only_when_key_existed() {
        let mut ctx = DistributedContext::new("test");
        ctx.set("key", json!(1));
        assert_eq!(ctx.metadata.version, 2);

        // Removing an absent key is a no-op for the version.
        assert_eq!(ctx.remove("missing"), None);
        assert_eq!(ctx.metadata.version, 2);

        assert_eq!(ctx.remove("key"), Some(json!(1)));
        assert_eq!(ctx.metadata.version, 3);
        assert_eq!(ctx.get("key"), None);
    }

    #[test]
    fn test_serialization() {
        let mut ctx = DistributedContext::new("test");
        ctx.set("user_id", json!("user-123"));
        ctx.set("count", json!(42));
        ctx.add_tag("billing");
        ctx.set_modified_by("svc-a");

        let bytes = ctx.serialize().unwrap();
        let deserialized = DistributedContext::deserialize(&bytes).unwrap();

        assert_eq!(deserialized.session_id, "test");
        assert_eq!(deserialized.get("user_id"), Some(&json!("user-123")));
        assert_eq!(deserialized.get("count"), Some(&json!(42)));
        // Metadata must survive the MessagePack round trip as well.
        assert_eq!(deserialized.metadata.version, ctx.metadata.version);
        assert_eq!(deserialized.metadata.created_at, ctx.metadata.created_at);
        assert_eq!(deserialized.metadata.updated_at, ctx.metadata.updated_at);
        assert_eq!(deserialized.metadata.tags, vec!["billing"]);
        assert_eq!(deserialized.metadata.modified_by.as_deref(), Some("svc-a"));
    }

    #[test]
    fn test_deserialize_rejects_garbage() {
        assert!(DistributedContext::deserialize(&[0xc1]).is_err());
        assert!(DistributedContext::from_json("not json").is_err());
    }

    #[test]
    fn test_json_serialization() {
        let mut ctx = DistributedContext::new("test");
        ctx.set("name", json!("Alice"));

        let json_str = ctx.to_json().unwrap();
        let deserialized = DistributedContext::from_json(&json_str).unwrap();

        assert_eq!(deserialized.session_id, "test");
        assert_eq!(deserialized.get("name"), Some(&json!("Alice")));
        assert_eq!(deserialized.metadata.version, ctx.metadata.version);
    }

    #[test]
    fn test_merge() {
        let mut ctx1 = DistributedContext::new("test");
        ctx1.set("key1", json!("value1"));

        let mut ctx2 = DistributedContext::new("test");
        ctx2.set("key2", json!("value2"));

        ctx1.merge(&ctx2);

        assert_eq!(ctx1.get("key1"), Some(&json!("value1")));
        assert_eq!(ctx1.get("key2"), Some(&json!("value2")));
        // new -> 1, set -> 2, merge -> 3.
        assert_eq!(ctx1.metadata.version, 3);
    }

    #[test]
    fn test_merge_overwrites_existing_keys() {
        let mut ctx1 = DistributedContext::new("test");
        ctx1.set("shared", json!("old"));

        let mut ctx2 = DistributedContext::new("test");
        ctx2.set("shared", json!("new"));

        ctx1.merge(&ctx2);

        assert_eq!(ctx1.get("shared"), Some(&json!("new")));
        assert_eq!(ctx1.data.len(), 1);
    }

    #[test]
    fn test_add_tag_deduplicates() {
        let mut ctx = DistributedContext::new("test");
        ctx.add_tag("a");
        ctx.add_tag("a");
        ctx.add_tag("b");

        assert_eq!(ctx.metadata.tags, vec!["a", "b"]);
    }

    #[test]
    fn test_set_modified_by() {
        let mut ctx = DistributedContext::new("test");
        ctx.set_modified_by("svc-a");

        assert_eq!(ctx.metadata.modified_by.as_deref(), Some("svc-a"));
        assert_eq!(ctx.metadata.version, 2);
    }

    #[test]
    fn test_snapshot_reflects_current_state() {
        let mut ctx = DistributedContext::new("snap");
        ctx.set("k", json!(true));

        let snap = ctx.snapshot();

        assert_eq!(snap.session_id, "snap");
        assert_eq!(snap.data.get("k"), Some(&json!(true)));
        assert_eq!(snap.version, ctx.metadata.version);
        assert_eq!(snap.timestamp, ctx.metadata.updated_at);
    }

    #[tokio::test]
    async fn test_shared_context() {
        let ctx = SharedContext::new("test");

        ctx.set("key1", json!("value1")).await;
        let value = ctx.get("key1").await;

        assert_eq!(value, Some(json!("value1")));
        assert_eq!(ctx.version().await, 2);
    }

    #[tokio::test]
    async fn test_shared_context_clones_share_state() {
        let ctx = SharedContext::new("test");
        let other = ctx.clone();

        other.set("key2", json!(2)).await;

        assert_eq!(ctx.get("key2").await, Some(json!(2)));
        assert_eq!(ctx.version().await, 2);
        assert_eq!(ctx.snapshot().await.version, 2);

        let bytes = ctx.serialize().await.unwrap();
        let decoded = DistributedContext::deserialize(&bytes).unwrap();
        assert_eq!(decoded.get("key2"), Some(&json!(2)));
    }
}