spring_ai_rs/
ai_logger.rs

1use std::{
2    error::Error,
3    ffi::CString,
4    fmt::{Display, Formatter, Result as FmtResult},
5};
6
7use slog::{Drain, Level, OwnedKVList, Record};
8
9use crate::get_callback;
10
11fn basic_log<S>(ai_id: i32, message: S) -> Result<(), Box<dyn Error>>
12where
13    S: Into<Vec<u8>>,
14{
15    let log_func = get_callback!(ai_id, Log_log)?;
16
17    let c_message = CString::new(message)?;
18
19    Ok(unsafe { log_func(ai_id, c_message.as_ptr()) })
20}
21
22fn basic_exception_log<S>(
23    ai_id: i32,
24    message: S,
25    level: i32,
26    die: bool,
27) -> Result<(), Box<dyn Error>>
28where
29    S: Into<Vec<u8>>,
30{
31    let log_exception_func = get_callback!(ai_id, Log_exception)?;
32
33    let c_message = CString::new(message)?;
34
35    Ok(unsafe { log_exception_func(ai_id, c_message.as_ptr(), level, die) })
36}
37
38pub struct AILogger {
39    ai_id: i32,
40    level: Level,
41    file_info: bool,
42}
43
44impl AILogger {
45    pub fn new(ai_id: i32) -> Self {
46        Self {
47            ai_id,
48            level: Level::Warning,
49            file_info: false,
50        }
51    }
52
53    pub fn with_level(self, level: Level) -> Self {
54        Self { level, ..self }
55    }
56
57    pub fn with_file_info(self, file_info: bool) -> Self {
58        Self { file_info, ..self }
59    }
60}
61
/// Success value of the `AILogger` drain (`Drain::Ok`); carries no data.
#[derive(Debug, Clone, Copy)]
pub struct AILoggerOk {}
64
/// Error value of the `AILogger` drain (`Drain::Err`); carries no data.
#[derive(Debug, Clone, Copy)]
pub struct AILoggerErr {}

impl Display for AILoggerErr {
    /// Renders the error using its derived `Debug` text.
    fn fmt(&self, out: &mut Formatter) -> FmtResult {
        out.write_fmt(format_args!("{:?}", self))
    }
}

impl Error for AILoggerErr {}
75
76impl Drain for AILogger {
77    type Ok = AILoggerOk;
78    type Err = AILoggerErr;
79
80    fn log(&self, record: &Record, values: &OwnedKVList) -> Result<Self::Ok, Self::Err> {
81        let message = if self.file_info {
82            format!(
83                "{} [{}::{} ({}:{})]: {}",
84                record.level().to_string(),
85                record.module(),
86                record.function(),
87                record.file(),
88                record.line(),
89                record.msg()
90            )
91        } else {
92            format!("{}: {}", record.level().to_string(), record.msg())
93        };
94        if record.level() < self.level {
95            match record.level() {
96                Level::Debug | Level::Trace | Level::Info => {
97                    basic_log(self.ai_id, message).unwrap()
98                }
99                Level::Warning => basic_exception_log(self.ai_id, message, 3, false).unwrap(),
100                Level::Error => basic_exception_log(self.ai_id, message, 6, false).unwrap(),
101                Level::Critical => basic_exception_log(self.ai_id, message, 9, true).unwrap(),
102            }
103        }
104
105        Ok(AILoggerOk {})
106    }
107}