1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::{Context as _, Result};
11use async_trait::async_trait;
12use parking_lot::RwLock;
13use tracing::{debug, instrument};
14
15use crate::{
16 config::{BotConfig, PipelineConfig},
17 context::Context,
18 error::Error,
19 message::{Message, Response},
20};
21
22pub struct MessagePipeline {
24 #[allow(dead_code)]
25 config: PipelineConfig,
26 stages: Vec<Box<dyn PipelineStage>>,
27 middleware: Vec<Box<dyn PipelineMiddleware>>,
28 metrics: Arc<PipelineMetrics>,
29}
30
31impl MessagePipeline {
32 #[instrument(skip(config))]
38 pub async fn new(config: &BotConfig) -> Result<Self> {
39 debug!("Creating message pipeline");
40
41 let mut stages: Vec<Box<dyn PipelineStage>> = Vec::new();
42
43 for stage_name in &config.pipeline_config.enabled_stages {
45 let stage = Self::create_stage(stage_name, config)?;
46 stages.push(stage);
47 }
48
49 let middleware = vec![
51 Box::new(LoggingMiddleware::new()) as Box<dyn PipelineMiddleware>,
52 Box::new(MetricsMiddleware::new()) as Box<dyn PipelineMiddleware>,
53 Box::new(TimeoutMiddleware::new(
54 config.pipeline_config.max_processing_time,
55 )) as Box<dyn PipelineMiddleware>,
56 ];
57
58 Ok(Self {
59 config: config.pipeline_config.clone(),
60 stages,
61 middleware,
62 metrics: Arc::new(PipelineMetrics::new()),
63 })
64 }
65
66 #[instrument(skip(self, message, context))]
72 pub async fn process(
73 &self,
74 mut message: Message,
75 context: Arc<RwLock<Context>>,
76 ) -> Result<Response> {
77 let start = std::time::Instant::now();
78 self.metrics.increment_requests();
79
80 for mw in &self.middleware {
82 message = mw.before_pipeline(message).await?;
83 }
84
85 let mut pipeline_ctx = PipelineContext {
87 message,
88 context,
89 metadata: HashMap::default(),
90 };
91
92 for stage in &self.stages {
94 debug!("Processing stage: {}", stage.name());
95 pipeline_ctx = stage.process(pipeline_ctx).await?;
96 }
97
98 let mut response = self.generate_response(pipeline_ctx)?;
100
101 for mw in self.middleware.iter().rev() {
103 response = mw.after_pipeline(response).await?;
104 }
105
106 let duration = start.elapsed();
108 self.metrics.record_processing_time(duration);
109
110 debug!("Pipeline processed in {:?}", duration);
111 Ok(response)
112 }
113
114 pub fn add_stage(&mut self, stage: Box<dyn PipelineStage>) {
116 self.stages.push(stage);
117 }
118
119 pub fn add_middleware(&mut self, middleware: Box<dyn PipelineMiddleware>) {
121 self.middleware.push(middleware);
122 }
123
124 #[must_use]
126 pub fn metrics(&self) -> &PipelineMetrics {
127 &self.metrics
128 }
129
130 fn create_stage(name: &str, config: &BotConfig) -> Result<Box<dyn PipelineStage>> {
133 match name {
134 "sanitize" => Ok(Box::new(SanitizeStage::new())),
135 "enrich" => Ok(Box::new(EnrichStage::new())),
136 "route" => Ok(Box::new(RouteStage::new())),
137 "process" => Ok(Box::new(ProcessStage::new(config.clone()))),
138 "format" => Ok(Box::new(FormatStage::new())),
139 _ => Err(Error::Configuration(format!("Unknown pipeline stage: {name}")).into()),
140 }
141 }
142
143 #[allow(clippy::unused_self)]
144 fn generate_response(&self, ctx: PipelineContext) -> Result<Response> {
145 if let Some(response) = ctx.metadata.get("response") {
147 let response: Response = serde_json::from_value(response.clone())
148 .context("Failed to deserialize response")?;
149 Ok(response)
150 } else {
151 Ok(Response::text(
153 ctx.message.conversation_id,
154 "Message processed successfully",
155 ))
156 }
157 }
158}
159
160#[derive(Debug)]
162pub struct PipelineContext {
163 pub message: Message,
165 pub context: Arc<RwLock<Context>>,
167 pub metadata: std::collections::HashMap<String, serde_json::Value>,
169}
170
171#[async_trait]
173pub trait PipelineStage: Send + Sync {
174 fn name(&self) -> &str;
176
177 async fn process(&self, ctx: PipelineContext) -> Result<PipelineContext>;
179}
180
181#[async_trait]
183pub trait PipelineMiddleware: Send + Sync {
184 async fn before_pipeline(&self, message: Message) -> Result<Message> {
186 Ok(message)
187 }
188
189 async fn after_pipeline(&self, response: Response) -> Result<Response> {
191 Ok(response)
192 }
193}
194
/// Stage that cleans message content and strips sensitive metadata entries.
struct SanitizeStage;

impl SanitizeStage {
    /// Creates the (stateless) stage.
    fn new() -> Self {
        SanitizeStage
    }
}
203
204#[async_trait]
205impl PipelineStage for SanitizeStage {
206 fn name(&self) -> &str {
207 "sanitize"
208 }
209
210 async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
211 ctx.message.content = self.sanitize_content(&ctx.message.content);
213
214 ctx.message
216 .validate()
217 .context("Message validation failed")?;
218
219 self.sanitize_metadata(&mut ctx.message.metadata);
221
222 Ok(ctx)
223 }
224}
225
226impl SanitizeStage {
227 #[allow(clippy::unused_self)]
228 fn sanitize_content(&self, content: &str) -> String {
229 let sanitized = content
231 .chars()
232 .filter(|c| !c.is_control() || c.is_whitespace())
233 .collect::<String>();
234
235 sanitized
237 .lines()
238 .map(str::trim)
239 .filter(|line| !line.is_empty())
240 .collect::<Vec<_>>()
241 .join("\n")
242 }
243
244 #[allow(clippy::unused_self)]
245 fn sanitize_metadata(
246 &self,
247 metadata: &mut std::collections::HashMap<String, serde_json::Value>,
248 ) {
249 const SENSITIVE_KEYS: &[&str] = &["password", "token", "secret", "api_key", "auth"];
251
252 metadata.retain(|key, _| {
253 !SENSITIVE_KEYS
254 .iter()
255 .any(|&sensitive| key.to_lowercase().contains(sensitive))
256 });
257 }
258}
259
/// Stage that annotates the pipeline context with timing, context summary and language data.
struct EnrichStage;

impl EnrichStage {
    /// Creates the (stateless) stage.
    fn new() -> Self {
        EnrichStage
    }
}
268
269#[async_trait]
270impl PipelineStage for EnrichStage {
271 fn name(&self) -> &str {
272 "enrich"
273 }
274
275 async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
276 ctx.metadata.insert(
278 "processed_at".to_string(),
279 serde_json::json!(chrono::Utc::now()),
280 );
281
282 let context_summary = {
284 let context = ctx.context.read();
285 serde_json::json!({
286 "message_count": context.metadata.message_count,
287 "token_count": context.token_count,
288 "age_seconds": context.age().as_secs(),
289 })
290 };
291 ctx.metadata
292 .insert("context_summary".to_string(), context_summary);
293
294 if !ctx.message.metadata.contains_key("language") {
296 let language = self.detect_language(&ctx.message.content);
297 ctx.message
298 .metadata
299 .insert("language".to_string(), serde_json::json!(language));
300 }
301
302 Ok(ctx)
303 }
304}
305
306impl EnrichStage {
307 #[allow(clippy::unused_self)]
308 fn detect_language(&self, _content: &str) -> &str {
309 "en"
311 }
312}
313
/// Stage that classifies a message into a processing route.
struct RouteStage;

impl RouteStage {
    /// Creates the (stateless) stage.
    fn new() -> Self {
        RouteStage
    }
}
322
323#[async_trait]
324impl PipelineStage for RouteStage {
325 fn name(&self) -> &str {
326 "route"
327 }
328
329 async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
330 use crate::message::MessageType;
331
332 let route = match ctx.message.message_type {
334 MessageType::Command => "command",
335 MessageType::System => "system",
336 MessageType::Error => "error",
337 _ if ctx.message.has_attachments() => "media",
338 _ => "default",
339 };
340
341 ctx.metadata
342 .insert("route".to_string(), serde_json::json!(route));
343
344 match route {
346 "command" => {
347 if let Some(command) = self.extract_command(&ctx.message.content) {
348 ctx.metadata
349 .insert("command".to_string(), serde_json::json!(command));
350 }
351 }
352 "media" => {
353 let media_types: Vec<String> = ctx
354 .message
355 .attachments
356 .iter()
357 .map(|a| a.mime_type.clone())
358 .collect();
359 ctx.metadata
360 .insert("media_types".to_string(), serde_json::json!(media_types));
361 }
362 _ => {}
363 }
364
365 Ok(ctx)
366 }
367}
368
369impl RouteStage {
370 #[allow(clippy::unused_self)]
371 fn extract_command(&self, content: &str) -> Option<String> {
372 if content.starts_with('/') {
373 content
374 .split_whitespace()
375 .next()
376 .map(|cmd| cmd.trim_start_matches('/').to_string())
377 } else {
378 None
379 }
380 }
381}
382
383struct ProcessStage {
385 #[allow(dead_code)]
386 config: BotConfig,
387}
388
389impl ProcessStage {
390 fn new(config: BotConfig) -> Self {
391 Self { config }
392 }
393}
394
395#[async_trait]
396impl PipelineStage for ProcessStage {
397 fn name(&self) -> &str {
398 "process"
399 }
400
401 async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
402 let route = ctx
406 .metadata
407 .get("route")
408 .and_then(|v| v.as_str())
409 .unwrap_or("default");
410
411 let response_content = match route {
412 "command" => self.process_command(&ctx),
413 "system" => "System message received".to_string(),
414 "error" => "Error processed".to_string(),
415 "media" => format!("Received {} attachment(s)", ctx.message.attachments.len()),
416 _ => format!("Processing message: {}", ctx.message.content),
417 };
418
419 let response = Response::text(ctx.message.conversation_id.clone(), response_content);
420
421 ctx.metadata
422 .insert("response".to_string(), serde_json::to_value(response)?);
423
424 Ok(ctx)
425 }
426}
427
428impl ProcessStage {
429 #[allow(clippy::unused_self)]
430 fn process_command(&self, ctx: &PipelineContext) -> String {
431 let command = ctx
432 .metadata
433 .get("command")
434 .and_then(|v| v.as_str())
435 .unwrap_or("unknown");
436
437 format!("Executing command: {command}")
438 }
439}
440
/// Stage that adapts the generated response to the message's requested output format.
struct FormatStage;

impl FormatStage {
    /// Creates the (stateless) stage.
    fn new() -> Self {
        FormatStage
    }
}
449
450#[async_trait]
451impl PipelineStage for FormatStage {
452 fn name(&self) -> &str {
453 "format"
454 }
455
456 async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
457 if let Some(response_value) = ctx.metadata.get_mut("response") {
458 if let Ok(mut response) = serde_json::from_value::<Response>(response_value.clone()) {
459 if let Some(format_pref) = ctx.message.metadata.get("format") {
461 if let Some(format) = format_pref.as_str() {
462 match format {
463 "markdown" => {
464 response.response_type = crate::message::ResponseType::Markdown;
465 }
466 "html" => {
467 response.response_type = crate::message::ResponseType::Html;
468 response.content = self.to_html(&response.content);
469 }
470 "json" => {
471 response.response_type = crate::message::ResponseType::Json;
472 }
473 _ => {}
474 }
475 }
476 }
477
478 *response_value = serde_json::to_value(response)?;
479 }
480 }
481
482 Ok(ctx)
483 }
484}
485
486impl FormatStage {
487 #[allow(clippy::unused_self)]
488 fn to_html(&self, content: &str) -> String {
489 format!(
491 "<p>{}</p>",
492 content
493 .lines()
494 .map(|line| format!("{}<br>", html_escape::encode_text(line)))
495 .collect::<Vec<_>>()
496 .join("\n")
497 )
498 }
499}
500
501struct LoggingMiddleware {
503 enabled: bool,
504}
505
506impl LoggingMiddleware {
507 fn new() -> Self {
508 Self { enabled: true }
509 }
510}
511
512#[async_trait]
513impl PipelineMiddleware for LoggingMiddleware {
514 async fn before_pipeline(&self, message: Message) -> Result<Message> {
515 if self.enabled {
516 debug!("Pipeline processing message: {}", message.id);
517 }
518 Ok(message)
519 }
520
521 async fn after_pipeline(&self, response: Response) -> Result<Response> {
522 if self.enabled {
523 debug!("Pipeline generated response: {}", response.id);
524 }
525 Ok(response)
526 }
527}
528
529struct MetricsMiddleware {
531 start_time: Arc<RwLock<Option<std::time::Instant>>>,
532}
533
534impl MetricsMiddleware {
535 fn new() -> Self {
536 Self {
537 start_time: Arc::new(RwLock::new(None)),
538 }
539 }
540}
541
542#[async_trait]
543impl PipelineMiddleware for MetricsMiddleware {
544 async fn before_pipeline(&self, message: Message) -> Result<Message> {
545 *self.start_time.write() = Some(std::time::Instant::now());
546 Ok(message)
547 }
548
549 async fn after_pipeline(&self, response: Response) -> Result<Response> {
550 if let Some(start) = *self.start_time.read() {
551 let duration = start.elapsed();
552 debug!("Pipeline processing took {:?}", duration);
553 }
554 Ok(response)
555 }
556}
557
558struct TimeoutMiddleware {
560 #[allow(dead_code)]
561 timeout: Duration,
562}
563
564impl TimeoutMiddleware {
565 fn new(timeout: Duration) -> Self {
566 Self { timeout }
567 }
568}
569
570#[async_trait]
571impl PipelineMiddleware for TimeoutMiddleware {
572 async fn before_pipeline(&self, message: Message) -> Result<Message> {
573 Ok(message)
575 }
576}
577
578#[derive(Debug)]
580pub struct PipelineMetrics {
581 requests_total: Arc<RwLock<u64>>,
582 processing_times: Arc<RwLock<Vec<Duration>>>,
583}
584
585impl PipelineMetrics {
586 fn new() -> Self {
587 Self {
588 requests_total: Arc::new(RwLock::new(0)),
589 processing_times: Arc::new(RwLock::new(Vec::new())),
590 }
591 }
592
593 fn increment_requests(&self) {
594 *self.requests_total.write() += 1;
595 }
596
597 fn record_processing_time(&self, duration: Duration) {
598 let mut times = self.processing_times.write();
599 times.push(duration);
600 if times.len() > 1000 {
601 times.remove(0);
602 }
603 }
604
605 #[must_use]
607 pub fn requests_total(&self) -> u64 {
608 *self.requests_total.read()
609 }
610
611 #[must_use]
613 #[allow(clippy::cast_possible_truncation)]
614 pub fn average_processing_time(&self) -> Option<Duration> {
615 let times = self.processing_times.read();
616 if times.is_empty() {
617 return None;
618 }
619
620 let total: Duration = times.iter().sum();
621 Some(total / times.len() as u32)
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pipeline_creation() {
        let config = BotConfig::default();
        let pipeline = MessagePipeline::new(&config).await;
        assert!(pipeline.is_ok());
    }

    #[test]
    fn test_unknown_stage_is_rejected() {
        let config = BotConfig::default();
        assert!(MessagePipeline::create_stage("does-not-exist", &config).is_err());
    }

    #[test]
    fn test_sanitize_stage() {
        let stage = SanitizeStage::new();
        let sanitized = stage.sanitize_content("Hello\x00World\x01Test");
        assert!(!sanitized.contains('\x00'));
        assert!(!sanitized.contains('\x01'));
        assert_eq!(sanitized, "HelloWorldTest");
    }

    #[test]
    fn test_sanitize_metadata_drops_sensitive_keys() {
        let stage = SanitizeStage::new();
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("API_KEY".to_string(), serde_json::json!("x"));
        metadata.insert("user".to_string(), serde_json::json!("alice"));
        stage.sanitize_metadata(&mut metadata);
        assert_eq!(metadata.len(), 1);
        assert!(metadata.contains_key("user"));
    }

    #[test]
    fn test_route_stage_command_extraction() {
        let stage = RouteStage::new();
        assert_eq!(stage.extract_command("/help me"), Some("help".to_string()));
        assert_eq!(stage.extract_command("not a command"), None);
    }

    #[test]
    fn test_pipeline_metrics_average() {
        let metrics = PipelineMetrics::new();
        assert_eq!(metrics.requests_total(), 0);
        assert_eq!(metrics.average_processing_time(), None);

        metrics.increment_requests();
        metrics.record_processing_time(std::time::Duration::from_millis(10));
        metrics.record_processing_time(std::time::Duration::from_millis(30));

        assert_eq!(metrics.requests_total(), 1);
        assert_eq!(
            metrics.average_processing_time(),
            Some(std::time::Duration::from_millis(20))
        );
    }
}