Skip to main content

use_rag/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        RagChunkId, RagChunkKind, RagCitationKind, RagContextAssemblyKind, RagCorpusName,
10        RagDocumentId, RagError, RagFreshnessStatus, RagGroundingStatus, RagRankerKind,
11        RagRetrievalMode, RagRetrieverKind,
12    };
13}
14
/// Defines a public, validated text newtype named `$name`.
///
/// The generated struct wraps a `String` that is never blank. Every constructor
/// routes through `non_empty_text`, which trims surrounding whitespace and
/// rejects input that is empty afterwards with [`RagError::Empty`].
///
/// Outer attributes, including `///` doc comments, written before the name are
/// forwarded to the generated struct. Both `rag_text_newtype!(Name)` and a
/// documented invocation therefore work.
///
/// Generated API: `new`, `as_str`, `value`, `into_string`, `AsRef<str>`,
/// `Display`, `FromStr`, `TryFrom<&str>`, `TryFrom<String>`, and
/// `From<$name> for String`.
macro_rules! rag_text_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            #[doc = concat!(
                "Creates a [`",
                stringify!($name),
                "`] from `value`, trimming surrounding whitespace."
            )]
            ///
            /// # Errors
            ///
            /// Returns [`RagError::Empty`] if `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, RagError> {
                non_empty_text(value).map(Self)
            }

            /// Returns the trimmed, non-empty text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Returns the trimmed, non-empty text (same as [`Self::as_str`]).
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the value and returns the owned inner `String`.
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = RagError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = RagError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // `new` already accepts owned strings via `AsRef<str>`; this mirrors
        // that so `String::try_into()` works as well as `&str::try_into()`.
        impl TryFrom<String> for $name {
            type Error = RagError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_string()
            }
        }
    };
}
67
/// Defines a public, fieldless label enum named `$name`.
///
/// Each `Variant => "label"` pair becomes a variant whose canonical label is
/// `"label"`. The generated API is:
///
/// - `ALL`: every variant, in declaration order;
/// - `as_str`: the canonical label (a `const fn`);
/// - `Display`: writes the canonical label;
/// - `FromStr` / `TryFrom<&str>`: parse a label after `normalized_label` has
///   trimmed it, lowercased ASCII letters, and mapped `_` and spaces to `-`.
///   For example, `"Top_K"` and `"top k"` both match a `"top-k"` label.
///   Blank input yields [`RagError::Empty`]; any other unmatched input yields
///   [`RagError::UnknownLabel`].
///
/// Outer attributes and `///` doc comments may precede the enum name and each
/// variant; they are forwarded to the generated items.
macro_rules! rag_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $($(#[$variant_meta:meta])* $variant:ident => $label:literal),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                $(#[$variant_meta])*
                #[doc = concat!("Canonical label: `", $label, "`.")]
                $variant,
            )+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = RagError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(RagError::UnknownLabel),
                }
            }
        }

        // Mirrors the text newtypes, which also expose `TryFrom<&str>`.
        impl TryFrom<&str> for $name {
            type Error = RagError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
103
104rag_text_newtype!(RagCorpusName);
105rag_text_newtype!(RagDocumentId);
106rag_text_newtype!(RagChunkId);
107
108rag_enum!(RagChunkKind {
109    Text => "text",
110    Table => "table",
111    Code => "code",
112    Image => "image",
113    Audio => "audio",
114    Metadata => "metadata",
115    Mixed => "mixed",
116    Custom => "custom",
117});
118
119rag_enum!(RagRetrieverKind {
120    Keyword => "keyword",
121    Vector => "vector",
122    Hybrid => "hybrid",
123    Graph => "graph",
124    Sql => "sql",
125    Api => "api",
126    FileSearch => "file-search",
127    WebSearch => "web-search",
128    Custom => "custom",
129});
130
131rag_enum!(RagRetrievalMode {
132    TopK => "top-k",
133    Threshold => "threshold",
134    Filtered => "filtered",
135    Hybrid => "hybrid",
136    Recursive => "recursive",
137    MultiQuery => "multi-query",
138    Custom => "custom",
139});
140
141rag_enum!(RagRankerKind {
142    None => "none",
143    Score => "score",
144    Reranker => "reranker",
145    CrossEncoder => "cross-encoder",
146    Diversity => "diversity",
147    Recency => "recency",
148    Custom => "custom",
149});
150
151rag_enum!(RagCitationKind {
152    Document => "document",
153    Url => "url",
154    File => "file",
155    LineRange => "line-range",
156    PageRange => "page-range",
157    Timestamp => "timestamp",
158    Chunk => "chunk",
159    Custom => "custom",
160});
161
162rag_enum!(RagGroundingStatus {
163    Grounded => "grounded",
164    PartiallyGrounded => "partially-grounded",
165    Ungrounded => "ungrounded",
166    Conflicting => "conflicting",
167    Unknown => "unknown",
168});
169
170rag_enum!(RagContextAssemblyKind {
171    Append => "append",
172    Compress => "compress",
173    Summarize => "summarize",
174    MapReduce => "map-reduce",
175    Windowed => "windowed",
176    Hierarchical => "hierarchical",
177    Custom => "custom",
178});
179
180rag_enum!(RagFreshnessStatus {
181    Fresh => "fresh",
182    Stale => "stale",
183    Unknown => "unknown",
184    TimeSensitive => "time-sensitive",
185});
186
/// Error returned when RAG metadata text or a label fails validation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RagError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The input did not match any label of the target enum.
    UnknownLabel,
}

impl fmt::Display for RagError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            RagError::Empty => "RAG metadata text cannot be empty",
            RagError::UnknownLabel => "unknown RAG metadata label",
        };
        formatter.write_str(message)
    }
}

impl Error for RagError {}
203
204fn non_empty_text(value: impl AsRef<str>) -> Result<String, RagError> {
205    let trimmed = value.as_ref().trim();
206    if trimmed.is_empty() {
207        Err(RagError::Empty)
208    } else {
209        Ok(trimmed.to_string())
210    }
211}
212
213fn normalized_label(value: &str) -> Result<String, RagError> {
214    let trimmed = value.trim();
215    if trimmed.is_empty() {
216        Err(RagError::Empty)
217    } else {
218        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::{
        RagChunkId, RagChunkKind, RagCitationKind, RagContextAssemblyKind, RagCorpusName,
        RagDocumentId, RagError, RagFreshnessStatus, RagGroundingStatus, RagRankerKind,
        RagRetrievalMode, RagRetrieverKind,
    };
    use core::{fmt, str::FromStr};
    use std::collections::HashSet;

    /// Exercises every constructor, accessor, and conversion of a text newtype.
    /// It checks that surrounding whitespace is trimmed and that blank input is
    /// rejected with `RagError::Empty` on every construction path.
    macro_rules! assert_text_newtype {
        ($type:ty, $value:literal) => {{
            let value = <$type>::new(concat!(" ", $value, " "))?;
            assert_eq!(value.as_str(), $value);
            assert_eq!(value.value(), $value);
            assert_eq!(value.as_ref(), $value);
            assert_eq!(value.to_string(), $value);
            assert_eq!(<$type as TryFrom<&str>>::try_from($value)?, value);
            assert_eq!(concat!("\t", $value, "\n").parse::<$type>()?, value);
            assert_eq!(<$type>::new(" \t\n "), Err(RagError::Empty));
            assert_eq!(<$type as TryFrom<&str>>::try_from(""), Err(RagError::Empty));
            assert_eq!("   ".parse::<$type>(), Err(RagError::Empty));
            assert_eq!(value.into_string(), $value.to_string());
        }};
    }

    /// Checks a whole label-enum family:
    /// - every variant round-trips through its label;
    /// - labels are unique;
    /// - parsing ignores case and padding, and treats `_` and spaces as `-`;
    /// - blank or unknown input is rejected with the matching error.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), RagError>
    where
        T: Copy + Eq + fmt::Debug + fmt::Display + FromStr<Err = RagError>,
    {
        assert!(!variants.is_empty(), "label enum must have at least one variant");
        let mut seen = HashSet::new();
        for variant in variants {
            let label = variant.to_string();
            assert!(seen.insert(label.clone()), "duplicate label `{label}`");
            assert_eq!(label.parse::<T>()?, *variant);
            assert_eq!(label.to_ascii_uppercase().parse::<T>()?, *variant);
            assert_eq!(format!("  {label}\t").parse::<T>()?, *variant);
            assert_eq!(label.replace('-', "_").parse::<T>()?, *variant);
            assert_eq!(label.replace('-', " ").parse::<T>()?, *variant);
        }
        assert_eq!("".parse::<T>(), Err(RagError::Empty));
        assert_eq!(" \t ".parse::<T>(), Err(RagError::Empty));
        assert_eq!("no-such-label".parse::<T>(), Err(RagError::UnknownLabel));
        Ok(())
    }

    #[test]
    fn validates_rag_text_newtypes() -> Result<(), RagError> {
        assert_text_newtype!(RagCorpusName, "support-docs");
        assert_text_newtype!(RagDocumentId, "doc-001");
        assert_text_newtype!(RagChunkId, "chunk-001");
        assert_eq!(RagCorpusName::new("  "), Err(RagError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_rag_enums() -> Result<(), RagError> {
        assert_enum_family(RagChunkKind::ALL)?;
        assert_enum_family(RagRetrieverKind::ALL)?;
        assert_enum_family(RagRetrievalMode::ALL)?;
        assert_enum_family(RagRankerKind::ALL)?;
        assert_enum_family(RagCitationKind::ALL)?;
        assert_enum_family(RagGroundingStatus::ALL)?;
        assert_enum_family(RagContextAssemblyKind::ALL)?;
        assert_enum_family(RagFreshnessStatus::ALL)?;
        assert_eq!(
            "file search".parse::<RagRetrieverKind>()?,
            RagRetrieverKind::FileSearch
        );
        assert_eq!(" Top_K ".parse::<RagRetrievalMode>()?, RagRetrievalMode::TopK);
        assert_eq!("CROSS_ENCODER".parse::<RagRankerKind>()?, RagRankerKind::CrossEncoder);
        Ok(())
    }

    #[test]
    fn rag_error_messages_are_readable() {
        assert_eq!(RagError::Empty.to_string(), "RAG metadata text cannot be empty");
        assert_eq!(RagError::UnknownLabel.to_string(), "unknown RAG metadata label");
        let _: &dyn std::error::Error = &RagError::UnknownLabel;
    }
}