universal_bot_core/
pipeline.rs

1//! Message processing pipeline
2//!
3//! This module implements the message processing pipeline that handles
4//! sanitization, enrichment, routing, processing, and formatting of messages.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use anyhow::{Context as _, Result};
11use async_trait::async_trait;
12use parking_lot::RwLock;
13use tracing::{debug, instrument};
14
15use crate::{
16    config::{BotConfig, PipelineConfig},
17    context::Context,
18    error::Error,
19    message::{Message, Response},
20};
21
22/// Message processing pipeline
23pub struct MessagePipeline {
24    #[allow(dead_code)]
25    config: PipelineConfig,
26    stages: Vec<Box<dyn PipelineStage>>,
27    middleware: Vec<Box<dyn PipelineMiddleware>>,
28    metrics: Arc<PipelineMetrics>,
29}
30
31impl MessagePipeline {
32    /// Create a new message pipeline
33    ///
34    /// # Errors
35    ///
36    /// Returns an error if pipeline initialization fails.
37    #[instrument(skip(config))]
38    pub async fn new(config: &BotConfig) -> Result<Self> {
39        debug!("Creating message pipeline");
40
41        let mut stages: Vec<Box<dyn PipelineStage>> = Vec::new();
42
43        // Add stages based on configuration
44        for stage_name in &config.pipeline_config.enabled_stages {
45            let stage = Self::create_stage(stage_name, config)?;
46            stages.push(stage);
47        }
48
49        // Add default middleware
50        let middleware = vec![
51            Box::new(LoggingMiddleware::new()) as Box<dyn PipelineMiddleware>,
52            Box::new(MetricsMiddleware::new()) as Box<dyn PipelineMiddleware>,
53            Box::new(TimeoutMiddleware::new(
54                config.pipeline_config.max_processing_time,
55            )) as Box<dyn PipelineMiddleware>,
56        ];
57
58        Ok(Self {
59            config: config.pipeline_config.clone(),
60            stages,
61            middleware,
62            metrics: Arc::new(PipelineMetrics::new()),
63        })
64    }
65
66    /// Process a message through the pipeline
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if any stage in the pipeline fails
71    #[instrument(skip(self, message, context))]
72    pub async fn process(
73        &self,
74        mut message: Message,
75        context: Arc<RwLock<Context>>,
76    ) -> Result<Response> {
77        let start = std::time::Instant::now();
78        self.metrics.increment_requests();
79
80        // Apply middleware pre-processing
81        for mw in &self.middleware {
82            message = mw.before_pipeline(message).await?;
83        }
84
85        // Create pipeline context
86        let mut pipeline_ctx = PipelineContext {
87            message,
88            context,
89            metadata: HashMap::default(),
90        };
91
92        // Process through stages
93        for stage in &self.stages {
94            debug!("Processing stage: {}", stage.name());
95            pipeline_ctx = stage.process(pipeline_ctx).await?;
96        }
97
98        // Generate response
99        let mut response = self.generate_response(pipeline_ctx)?;
100
101        // Apply middleware post-processing
102        for mw in self.middleware.iter().rev() {
103            response = mw.after_pipeline(response).await?;
104        }
105
106        // Record metrics
107        let duration = start.elapsed();
108        self.metrics.record_processing_time(duration);
109
110        debug!("Pipeline processed in {:?}", duration);
111        Ok(response)
112    }
113
114    /// Add a custom stage to the pipeline
115    pub fn add_stage(&mut self, stage: Box<dyn PipelineStage>) {
116        self.stages.push(stage);
117    }
118
119    /// Add middleware to the pipeline
120    pub fn add_middleware(&mut self, middleware: Box<dyn PipelineMiddleware>) {
121        self.middleware.push(middleware);
122    }
123
124    /// Get pipeline metrics
125    #[must_use]
126    pub fn metrics(&self) -> &PipelineMetrics {
127        &self.metrics
128    }
129
130    // Private helper methods
131
132    fn create_stage(name: &str, config: &BotConfig) -> Result<Box<dyn PipelineStage>> {
133        match name {
134            "sanitize" => Ok(Box::new(SanitizeStage::new())),
135            "enrich" => Ok(Box::new(EnrichStage::new())),
136            "route" => Ok(Box::new(RouteStage::new())),
137            "process" => Ok(Box::new(ProcessStage::new(config.clone()))),
138            "format" => Ok(Box::new(FormatStage::new())),
139            _ => Err(Error::Configuration(format!("Unknown pipeline stage: {name}")).into()),
140        }
141    }
142
143    #[allow(clippy::unused_self)]
144    fn generate_response(&self, ctx: PipelineContext) -> Result<Response> {
145        // Extract response from pipeline context
146        if let Some(response) = ctx.metadata.get("response") {
147            let response: Response = serde_json::from_value(response.clone())
148                .context("Failed to deserialize response")?;
149            Ok(response)
150        } else {
151            // Create default response if none was generated
152            Ok(Response::text(
153                ctx.message.conversation_id,
154                "Message processed successfully",
155            ))
156        }
157    }
158}
159
160/// Pipeline processing context
161#[derive(Debug)]
162pub struct PipelineContext {
163    /// The message being processed
164    pub message: Message,
165    /// Conversation context
166    pub context: Arc<RwLock<Context>>,
167    /// Pipeline metadata
168    pub metadata: std::collections::HashMap<String, serde_json::Value>,
169}
170
171/// Trait for pipeline stages
172#[async_trait]
173pub trait PipelineStage: Send + Sync {
174    /// Stage name
175    fn name(&self) -> &str;
176
177    /// Process the pipeline context
178    async fn process(&self, ctx: PipelineContext) -> Result<PipelineContext>;
179}
180
181/// Trait for pipeline middleware
182#[async_trait]
183pub trait PipelineMiddleware: Send + Sync {
184    /// Called before pipeline processing
185    async fn before_pipeline(&self, message: Message) -> Result<Message> {
186        Ok(message)
187    }
188
189    /// Called after pipeline processing
190    async fn after_pipeline(&self, response: Response) -> Result<Response> {
191        Ok(response)
192    }
193}
194
195/// Sanitization stage - cleans and validates input
196struct SanitizeStage;
197
198impl SanitizeStage {
199    fn new() -> Self {
200        Self
201    }
202}
203
204#[async_trait]
205impl PipelineStage for SanitizeStage {
206    fn name(&self) -> &str {
207        "sanitize"
208    }
209
210    async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
211        // Sanitize message content
212        ctx.message.content = self.sanitize_content(&ctx.message.content);
213
214        // Validate message
215        ctx.message
216            .validate()
217            .context("Message validation failed")?;
218
219        // Remove sensitive data from metadata
220        self.sanitize_metadata(&mut ctx.message.metadata);
221
222        Ok(ctx)
223    }
224}
225
226impl SanitizeStage {
227    #[allow(clippy::unused_self)]
228    fn sanitize_content(&self, content: &str) -> String {
229        // Remove control characters
230        let sanitized = content
231            .chars()
232            .filter(|c| !c.is_control() || c.is_whitespace())
233            .collect::<String>();
234
235        // Trim excessive whitespace
236        sanitized
237            .lines()
238            .map(str::trim)
239            .filter(|line| !line.is_empty())
240            .collect::<Vec<_>>()
241            .join("\n")
242    }
243
244    #[allow(clippy::unused_self)]
245    fn sanitize_metadata(
246        &self,
247        metadata: &mut std::collections::HashMap<String, serde_json::Value>,
248    ) {
249        // Remove potentially sensitive keys
250        const SENSITIVE_KEYS: &[&str] = &["password", "token", "secret", "api_key", "auth"];
251
252        metadata.retain(|key, _| {
253            !SENSITIVE_KEYS
254                .iter()
255                .any(|&sensitive| key.to_lowercase().contains(sensitive))
256        });
257    }
258}
259
260/// Enrichment stage - adds context and metadata
261struct EnrichStage;
262
263impl EnrichStage {
264    fn new() -> Self {
265        Self
266    }
267}
268
269#[async_trait]
270impl PipelineStage for EnrichStage {
271    fn name(&self) -> &str {
272        "enrich"
273    }
274
275    async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
276        // Add timestamp if not present
277        ctx.metadata.insert(
278            "processed_at".to_string(),
279            serde_json::json!(chrono::Utc::now()),
280        );
281
282        // Add context summary
283        let context_summary = {
284            let context = ctx.context.read();
285            serde_json::json!({
286                "message_count": context.metadata.message_count,
287                "token_count": context.token_count,
288                "age_seconds": context.age().as_secs(),
289            })
290        };
291        ctx.metadata
292            .insert("context_summary".to_string(), context_summary);
293
294        // Detect language if needed
295        if !ctx.message.metadata.contains_key("language") {
296            let language = self.detect_language(&ctx.message.content);
297            ctx.message
298                .metadata
299                .insert("language".to_string(), serde_json::json!(language));
300        }
301
302        Ok(ctx)
303    }
304}
305
306impl EnrichStage {
307    #[allow(clippy::unused_self)]
308    fn detect_language(&self, _content: &str) -> &str {
309        // Simple language detection (would use a proper library in production)
310        "en"
311    }
312}
313
314/// Routing stage - determines processing path
315struct RouteStage;
316
317impl RouteStage {
318    fn new() -> Self {
319        Self
320    }
321}
322
323#[async_trait]
324impl PipelineStage for RouteStage {
325    fn name(&self) -> &str {
326        "route"
327    }
328
329    async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
330        use crate::message::MessageType;
331
332        // Determine route based on message type
333        let route = match ctx.message.message_type {
334            MessageType::Command => "command",
335            MessageType::System => "system",
336            MessageType::Error => "error",
337            _ if ctx.message.has_attachments() => "media",
338            _ => "default",
339        };
340
341        ctx.metadata
342            .insert("route".to_string(), serde_json::json!(route));
343
344        // Add route-specific metadata
345        match route {
346            "command" => {
347                if let Some(command) = self.extract_command(&ctx.message.content) {
348                    ctx.metadata
349                        .insert("command".to_string(), serde_json::json!(command));
350                }
351            }
352            "media" => {
353                let media_types: Vec<String> = ctx
354                    .message
355                    .attachments
356                    .iter()
357                    .map(|a| a.mime_type.clone())
358                    .collect();
359                ctx.metadata
360                    .insert("media_types".to_string(), serde_json::json!(media_types));
361            }
362            _ => {}
363        }
364
365        Ok(ctx)
366    }
367}
368
369impl RouteStage {
370    #[allow(clippy::unused_self)]
371    fn extract_command(&self, content: &str) -> Option<String> {
372        if content.starts_with('/') {
373            content
374                .split_whitespace()
375                .next()
376                .map(|cmd| cmd.trim_start_matches('/').to_string())
377        } else {
378            None
379        }
380    }
381}
382
383/// Processing stage - main AI processing
384struct ProcessStage {
385    #[allow(dead_code)]
386    config: BotConfig,
387}
388
389impl ProcessStage {
390    fn new(config: BotConfig) -> Self {
391        Self { config }
392    }
393}
394
395#[async_trait]
396impl PipelineStage for ProcessStage {
397    fn name(&self) -> &str {
398        "process"
399    }
400
401    async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
402        // This is where we would integrate with AI providers
403        // For now, create a simple response
404
405        let route = ctx
406            .metadata
407            .get("route")
408            .and_then(|v| v.as_str())
409            .unwrap_or("default");
410
411        let response_content = match route {
412            "command" => self.process_command(&ctx),
413            "system" => "System message received".to_string(),
414            "error" => "Error processed".to_string(),
415            "media" => format!("Received {} attachment(s)", ctx.message.attachments.len()),
416            _ => format!("Processing message: {}", ctx.message.content),
417        };
418
419        let response = Response::text(ctx.message.conversation_id.clone(), response_content);
420
421        ctx.metadata
422            .insert("response".to_string(), serde_json::to_value(response)?);
423
424        Ok(ctx)
425    }
426}
427
428impl ProcessStage {
429    #[allow(clippy::unused_self)]
430    fn process_command(&self, ctx: &PipelineContext) -> String {
431        let command = ctx
432            .metadata
433            .get("command")
434            .and_then(|v| v.as_str())
435            .unwrap_or("unknown");
436
437        format!("Executing command: {command}")
438    }
439}
440
441/// Formatting stage - formats the response
442struct FormatStage;
443
444impl FormatStage {
445    fn new() -> Self {
446        Self
447    }
448}
449
450#[async_trait]
451impl PipelineStage for FormatStage {
452    fn name(&self) -> &str {
453        "format"
454    }
455
456    async fn process(&self, mut ctx: PipelineContext) -> Result<PipelineContext> {
457        if let Some(response_value) = ctx.metadata.get_mut("response") {
458            if let Ok(mut response) = serde_json::from_value::<Response>(response_value.clone()) {
459                // Apply formatting based on preferences
460                if let Some(format_pref) = ctx.message.metadata.get("format") {
461                    if let Some(format) = format_pref.as_str() {
462                        match format {
463                            "markdown" => {
464                                response.response_type = crate::message::ResponseType::Markdown;
465                            }
466                            "html" => {
467                                response.response_type = crate::message::ResponseType::Html;
468                                response.content = self.to_html(&response.content);
469                            }
470                            "json" => {
471                                response.response_type = crate::message::ResponseType::Json;
472                            }
473                            _ => {}
474                        }
475                    }
476                }
477
478                *response_value = serde_json::to_value(response)?;
479            }
480        }
481
482        Ok(ctx)
483    }
484}
485
486impl FormatStage {
487    #[allow(clippy::unused_self)]
488    fn to_html(&self, content: &str) -> String {
489        // Simple HTML conversion
490        format!(
491            "<p>{}</p>",
492            content
493                .lines()
494                .map(|line| format!("{}<br>", html_escape::encode_text(line)))
495                .collect::<Vec<_>>()
496                .join("\n")
497        )
498    }
499}
500
501/// Logging middleware
502struct LoggingMiddleware {
503    enabled: bool,
504}
505
506impl LoggingMiddleware {
507    fn new() -> Self {
508        Self { enabled: true }
509    }
510}
511
512#[async_trait]
513impl PipelineMiddleware for LoggingMiddleware {
514    async fn before_pipeline(&self, message: Message) -> Result<Message> {
515        if self.enabled {
516            debug!("Pipeline processing message: {}", message.id);
517        }
518        Ok(message)
519    }
520
521    async fn after_pipeline(&self, response: Response) -> Result<Response> {
522        if self.enabled {
523            debug!("Pipeline generated response: {}", response.id);
524        }
525        Ok(response)
526    }
527}
528
529/// Metrics middleware
530struct MetricsMiddleware {
531    start_time: Arc<RwLock<Option<std::time::Instant>>>,
532}
533
534impl MetricsMiddleware {
535    fn new() -> Self {
536        Self {
537            start_time: Arc::new(RwLock::new(None)),
538        }
539    }
540}
541
542#[async_trait]
543impl PipelineMiddleware for MetricsMiddleware {
544    async fn before_pipeline(&self, message: Message) -> Result<Message> {
545        *self.start_time.write() = Some(std::time::Instant::now());
546        Ok(message)
547    }
548
549    async fn after_pipeline(&self, response: Response) -> Result<Response> {
550        if let Some(start) = *self.start_time.read() {
551            let duration = start.elapsed();
552            debug!("Pipeline processing took {:?}", duration);
553        }
554        Ok(response)
555    }
556}
557
558/// Timeout middleware
559struct TimeoutMiddleware {
560    #[allow(dead_code)]
561    timeout: Duration,
562}
563
564impl TimeoutMiddleware {
565    fn new(timeout: Duration) -> Self {
566        Self { timeout }
567    }
568}
569
570#[async_trait]
571impl PipelineMiddleware for TimeoutMiddleware {
572    async fn before_pipeline(&self, message: Message) -> Result<Message> {
573        // Timeout would be enforced at the pipeline level
574        Ok(message)
575    }
576}
577
578/// Pipeline metrics
579#[derive(Debug)]
580pub struct PipelineMetrics {
581    requests_total: Arc<RwLock<u64>>,
582    processing_times: Arc<RwLock<Vec<Duration>>>,
583}
584
585impl PipelineMetrics {
586    fn new() -> Self {
587        Self {
588            requests_total: Arc::new(RwLock::new(0)),
589            processing_times: Arc::new(RwLock::new(Vec::new())),
590        }
591    }
592
593    fn increment_requests(&self) {
594        *self.requests_total.write() += 1;
595    }
596
597    fn record_processing_time(&self, duration: Duration) {
598        let mut times = self.processing_times.write();
599        times.push(duration);
600        if times.len() > 1000 {
601            times.remove(0);
602        }
603    }
604
605    /// Get total requests processed
606    #[must_use]
607    pub fn requests_total(&self) -> u64 {
608        *self.requests_total.read()
609    }
610
611    /// Get average processing time
612    #[must_use]
613    #[allow(clippy::cast_possible_truncation)]
614    pub fn average_processing_time(&self) -> Option<Duration> {
615        let times = self.processing_times.read();
616        if times.is_empty() {
617            return None;
618        }
619
620        let total: Duration = times.iter().sum();
621        Some(total / times.len() as u32)
622    }
623}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// A pipeline built from the default configuration must construct.
    #[tokio::test]
    async fn test_pipeline_creation() {
        let config = BotConfig::default();
        let outcome = MessagePipeline::new(&config).await;
        assert!(outcome.is_ok());
    }

    /// Non-whitespace control characters must not survive sanitization.
    #[test]
    fn test_sanitize_stage() {
        let sanitized = SanitizeStage::new().sanitize_content("Hello\x00World\x01Test");
        for forbidden in ['\x00', '\x01'] {
            assert!(!sanitized.contains(forbidden));
        }
    }

    /// Slash-prefixed text yields the command name; other text yields nothing.
    #[test]
    fn test_route_stage_command_extraction() {
        let router = RouteStage::new();
        let extracted = router.extract_command("/help me");
        assert_eq!(extracted, Some("help".to_string()));
        let missing = router.extract_command("not a command");
        assert_eq!(missing, None);
    }
}