Skip to main content

use_ai_context/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        AiContextCitation, AiContextError, AiContextFitStatus, AiContextId, AiContextItemId,
10        AiContextItemKind, AiContextPriority, AiContextSourceKind, AiContextWindow,
11        AiContextWindowSize, AiGroundingStatus,
12    };
13}
14
/// Generates a validated text newtype around a `String`.
///
/// Every value is produced by `non_empty_text`, so it has had surrounding
/// whitespace trimmed and is never empty. The generated type offers borrowed
/// access (`as_str`, `value`, `AsRef<str>`, `Display`), owned access
/// (`into_string`, `From<$name> for String`) and fallible construction
/// (`new`, `FromStr`, `TryFrom<&str>`), all reporting `AiContextError`.
macro_rules! context_text_newtype {
    ($name:ident) => {
        #[doc = concat!(
            "Trimmed, non-empty text value (`",
            stringify!($name),
            "`)."
        )]
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Builds a value from `value`, trimming surrounding whitespace.
            ///
            /// # Errors
            ///
            /// Returns [`AiContextError::Empty`] when `value` is empty or
            /// contains only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, AiContextError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored (already trimmed) text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Alias of [`as_str`](Self::as_str).
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the value and returns the owned text without copying.
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl From<$name> for String {
            // Lets callers write `String::from(id)` / `id.into()` instead of
            // reaching for the inherent `into_string`.
            fn from(value: $name) -> Self {
                value.0
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiContextError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = AiContextError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
67
/// Generates a closed label enum with canonical string forms.
///
/// Each `Variant => "label"` pair becomes an enum variant whose `Display`
/// output is exactly `"label"`. Parsing (`FromStr` and `TryFrom<&str>`) first
/// runs the input through `normalized_label` and then matches the result
/// against the labels. `ALL` lists the variants in declaration order, which is
/// also the order used by the derived `Ord`.
macro_rules! context_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[doc = concat!(
            "Closed set of `",
            stringify!($name),
            "` values with canonical text labels."
        )]
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("Canonical label: `", $label, "`.")]
                $variant
            ),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiContextError;

            // Errors: `AiContextError::Empty` (propagated from
            // `normalized_label`) or `AiContextError::UnknownLabel` when the
            // normalized text matches no variant.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(AiContextError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl of the text newtypes so both
        // families offer the same conversion surface.
        impl TryFrom<&str> for $name {
            type Error = AiContextError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
103
104context_text_newtype!(AiContextId);
105context_text_newtype!(AiContextItemId);
106context_text_newtype!(AiContextCitation);
107
108#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
109pub struct AiContextWindowSize(u32);
110
111impl AiContextWindowSize {
112    pub fn new(value: u32) -> Result<Self, AiContextError> {
113        if value == 0 {
114            Err(AiContextError::Zero)
115        } else {
116            Ok(Self(value))
117        }
118    }
119
120    pub const fn value(self) -> u32 {
121        self.0
122    }
123
124    pub const fn get(self) -> u32 {
125        self.0
126    }
127}
128
129#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
130pub struct AiContextWindow {
131    size: AiContextWindowSize,
132}
133
134impl AiContextWindow {
135    pub const fn new(size: AiContextWindowSize) -> Self {
136        Self { size }
137    }
138
139    pub const fn size(self) -> AiContextWindowSize {
140        self.size
141    }
142}
143
144context_enum!(AiContextItemKind {
145    Message => "message",
146    Document => "document",
147    File => "file",
148    WebPage => "web-page",
149    Code => "code",
150    Table => "table",
151    Image => "image",
152    Audio => "audio",
153    ToolResult => "tool-result",
154    Memory => "memory",
155    RetrievedChunk => "retrieved-chunk",
156    Metadata => "metadata",
157    Custom => "custom",
158});
159
160context_enum!(AiContextSourceKind {
161    UserProvided => "user-provided",
162    SystemProvided => "system-provided",
163    Retrieved => "retrieved",
164    ToolGenerated => "tool-generated",
165    Memory => "memory",
166    Web => "web",
167    File => "file",
168    Database => "database",
169    Api => "api",
170    Synthetic => "synthetic",
171    Unknown => "unknown",
172});
173
174context_enum!(AiContextPriority {
175    Low => "low",
176    Normal => "normal",
177    High => "high",
178    Critical => "critical",
179});
180
181context_enum!(AiContextFitStatus {
182    Fits => "fits",
183    Truncated => "truncated",
184    Summarized => "summarized",
185    Omitted => "omitted",
186    Overflow => "overflow",
187});
188
189context_enum!(AiGroundingStatus {
190    Grounded => "grounded",
191    PartiallyGrounded => "partially-grounded",
192    Ungrounded => "ungrounded",
193    Unknown => "unknown",
194});
195
/// Validation failure reported by the constructors and parsers in this crate.
///
/// `Hash` is derived alongside `Eq` so errors can be collected into sets or
/// used as map keys (e.g. when tallying validation failures).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AiContextError {
    /// Text input was empty or contained only whitespace.
    Empty,
    /// A numeric value that must be positive was zero.
    Zero,
    /// A label did not match any variant of the target enum.
    UnknownLabel,
}

impl fmt::Display for AiContextError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "AI context metadata text cannot be empty",
            Self::Zero => "AI context numeric value must be positive",
            Self::UnknownLabel => "unknown AI context metadata label",
        };
        formatter.write_str(message)
    }
}

impl Error for AiContextError {}
214
215fn non_empty_text(value: impl AsRef<str>) -> Result<String, AiContextError> {
216    let trimmed = value.as_ref().trim();
217    if trimmed.is_empty() {
218        Err(AiContextError::Empty)
219    } else {
220        Ok(trimmed.to_string())
221    }
222}
223
224fn normalized_label(value: &str) -> Result<String, AiContextError> {
225    let trimmed = value.trim();
226    if trimmed.is_empty() {
227        Err(AiContextError::Empty)
228    } else {
229        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::{
        AiContextCitation, AiContextError, AiContextFitStatus, AiContextId, AiContextItemId,
        AiContextItemKind, AiContextPriority, AiContextSourceKind, AiContextWindow,
        AiContextWindowSize, AiGroundingStatus,
    };
    use core::{fmt, str::FromStr};

    // Builds `$type` from a whitespace-padded literal and checks that every
    // accessor and conversion yields the trimmed text.
    macro_rules! assert_text_newtype {
        ($type:ty, $value:literal) => {{
            let value = <$type>::new(concat!(" ", $value, " "))?;
            assert_eq!(value.as_str(), $value);
            assert_eq!(value.value(), $value);
            assert_eq!(value.as_ref(), $value);
            assert_eq!(value.to_string(), $value);
            assert_eq!(<$type as TryFrom<&str>>::try_from($value)?, value);
            assert_eq!(value.into_string(), $value.to_string());
        }};
    }

    // Round-trips every variant through `Display` and `FromStr`, including the
    // underscore and space spellings of hyphenated labels.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), AiContextError>
    where
        T: Copy + Eq + fmt::Debug + fmt::Display + FromStr<Err = AiContextError>,
    {
        for variant in variants {
            let label = variant.to_string();
            assert_eq!(label.parse::<T>()?, *variant);
            assert_eq!(label.replace('-', "_").parse::<T>()?, *variant);
            assert_eq!(label.replace('-', " ").parse::<T>()?, *variant);
        }
        Ok(())
    }

    #[test]
    fn validates_context_text_newtypes() -> Result<(), AiContextError> {
        assert_text_newtype!(AiContextId, "ctx-001");
        assert_text_newtype!(AiContextItemId, "item-001");
        assert_text_newtype!(AiContextCitation, "doc:1");
        assert_eq!(AiContextId::new("  "), Err(AiContextError::Empty));
        Ok(())
    }

    #[test]
    fn parses_text_newtypes_from_str() -> Result<(), AiContextError> {
        assert_eq!(" ctx-001 ".parse::<AiContextId>()?, AiContextId::new("ctx-001")?);
        assert_eq!("".parse::<AiContextCitation>(), Err(AiContextError::Empty));
        Ok(())
    }

    #[test]
    fn validates_context_windows() -> Result<(), AiContextError> {
        let size = AiContextWindowSize::new(8_192)?;
        let window = AiContextWindow::new(size);

        assert_eq!(size.get(), 8_192);
        assert_eq!(window.size().value(), 8_192);
        assert_eq!(AiContextWindowSize::new(0), Err(AiContextError::Zero));
        Ok(())
    }

    #[test]
    fn displays_and_parses_context_enums() -> Result<(), AiContextError> {
        assert_enum_family(AiContextItemKind::ALL)?;
        assert_enum_family(AiContextSourceKind::ALL)?;
        assert_enum_family(AiContextPriority::ALL)?;
        assert_enum_family(AiContextFitStatus::ALL)?;
        assert_enum_family(AiGroundingStatus::ALL)?;
        assert_eq!(
            "web page".parse::<AiContextItemKind>()?,
            AiContextItemKind::WebPage
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_enum_labels() {
        assert_eq!("  ".parse::<AiContextPriority>(), Err(AiContextError::Empty));
        assert_eq!(
            "urgent".parse::<AiContextPriority>(),
            Err(AiContextError::UnknownLabel)
        );
        assert_eq!(
            "web.page".parse::<AiContextItemKind>(),
            Err(AiContextError::UnknownLabel)
        );
    }

    #[test]
    fn reports_errors_through_std_error() {
        let error: Box<dyn std::error::Error> = Box::new(AiContextError::Zero);
        assert_eq!(error.to_string(), "AI context numeric value must be positive");
        for error in [
            AiContextError::Empty,
            AiContextError::Zero,
            AiContextError::UnknownLabel,
        ]
        .iter()
        {
            assert!(!error.to_string().is_empty());
        }
    }
}