Skip to main content

use_ai_prompt/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        AiPromptError, PromptFormat, PromptId, PromptInstructionKind, PromptName, PromptPartKind,
10        PromptStatus, PromptTemplate, PromptText, PromptVariableKind, PromptVariableName,
11    };
12}
13
/// Generates a validated text newtype named `$name` wrapping a `String`.
///
/// The generated type always holds non-empty text with surrounding whitespace
/// removed; all constructors route through `non_empty_text`.
macro_rules! prompt_text_newtype {
    ($name:ident) => {
        /// Non-empty prompt metadata text with surrounding whitespace trimmed.
        ///
        /// Every constructor (`new`, `FromStr`, `TryFrom<&str>`, `TryFrom<String>`)
        /// trims the input and rejects text that is empty afterwards with
        /// [`AiPromptError::Empty`].
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Creates a value from `value` with surrounding whitespace trimmed.
            ///
            /// # Errors
            ///
            /// Returns [`AiPromptError::Empty`] when `value` is empty or contains
            /// only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, AiPromptError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored (already trimmed) text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Borrows the stored text; identical to [`Self::as_str`].
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the value and returns the owned text.
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiPromptError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = AiPromptError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Owned input goes through the same validation as borrowed input so the
        // trimming / non-empty rules live in one place.
        impl TryFrom<String> for $name {
            type Error = AiPromptError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers write `let s: String = value.into();` instead of only
        // `value.into_string()`.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_string()
            }
        }
    };
}
66
/// Generates a `Copy` label enum `$name` whose variants map one-to-one onto the
/// given string labels.
///
/// Variant order in the invocation determines both `$name::ALL` and the derived
/// `Ord`.
macro_rules! prompt_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of prompt metadata labels.
        ///
        /// `Display` writes a variant's label. `FromStr` and `TryFrom<&str>` parse a
        /// label after normalizing the input, failing with
        /// [`AiPromptError::UnknownLabel`] when nothing matches and
        /// [`AiPromptError::Empty`] for blank input.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("The `", $label, "` label.")]
                $variant
            ),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the label associated with this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiPromptError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Normalize first so case and separator spelling do not matter.
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(AiPromptError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` offered by the text newtypes so every
        // metadata type can be built from a `&str` the same way.
        impl TryFrom<&str> for $name {
            type Error = AiPromptError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
102
103prompt_text_newtype!(PromptName);
104prompt_text_newtype!(PromptId);
105prompt_text_newtype!(PromptText);
106prompt_text_newtype!(PromptVariableName);
107
108#[derive(Clone, Debug, Eq, PartialEq)]
109pub struct PromptTemplate {
110    name: PromptName,
111    text: PromptText,
112    format: PromptFormat,
113}
114
115impl PromptTemplate {
116    pub fn new(name: PromptName, text: PromptText, format: PromptFormat) -> Self {
117        Self { name, text, format }
118    }
119
120    pub fn name(&self) -> &PromptName {
121        &self.name
122    }
123
124    pub fn text(&self) -> &PromptText {
125        &self.text
126    }
127
128    pub const fn format(&self) -> PromptFormat {
129        self.format
130    }
131}
132
133prompt_enum!(PromptVariableKind {
134    String => "string",
135    Number => "number",
136    Boolean => "boolean",
137    Json => "json",
138    Text => "text",
139    List => "list",
140    Object => "object",
141    FileRef => "file-ref",
142    ImageRef => "image-ref",
143    AudioRef => "audio-ref",
144    Custom => "custom",
145});
146
147prompt_enum!(PromptPartKind {
148    System => "system",
149    Developer => "developer",
150    User => "user",
151    Assistant => "assistant",
152    Tool => "tool",
153    Context => "context",
154    Example => "example",
155    Constraint => "constraint",
156    OutputFormat => "output-format",
157    Metadata => "metadata",
158});
159
160prompt_enum!(PromptInstructionKind {
161    Task => "task",
162    Constraint => "constraint",
163    Style => "style",
164    Safety => "safety",
165    Role => "role",
166    Format => "format",
167    Context => "context",
168    Example => "example",
169    ToolUse => "tool-use",
170    Refusal => "refusal",
171    Custom => "custom",
172});
173
174prompt_enum!(PromptFormat {
175    PlainText => "plain-text",
176    Markdown => "markdown",
177    Json => "json",
178    Xml => "xml",
179    Yaml => "yaml",
180    ChatMessages => "chat-messages",
181    Custom => "custom",
182});
183
184prompt_enum!(PromptStatus {
185    Draft => "draft",
186    Active => "active",
187    Deprecated => "deprecated",
188    Archived => "archived",
189    Experimental => "experimental",
190});
191
/// Error produced when prompt metadata text or a label fails validation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AiPromptError {
    /// The input was empty or consisted only of whitespace.
    Empty,
    /// The input did not match any label of the requested enum.
    UnknownLabel,
}

impl fmt::Display for AiPromptError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            AiPromptError::Empty => "AI prompt metadata text cannot be empty",
            AiPromptError::UnknownLabel => "unknown AI prompt metadata label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause is carried, so the default `source()` (None) is correct.
impl Error for AiPromptError {}
208
209fn non_empty_text(value: impl AsRef<str>) -> Result<String, AiPromptError> {
210    let trimmed = value.as_ref().trim();
211    if trimmed.is_empty() {
212        Err(AiPromptError::Empty)
213    } else {
214        Ok(trimmed.to_string())
215    }
216}
217
218fn normalized_label(value: &str) -> Result<String, AiPromptError> {
219    let trimmed = value.trim();
220    if trimmed.is_empty() {
221        Err(AiPromptError::Empty)
222    } else {
223        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::{
        AiPromptError, PromptFormat, PromptId, PromptInstructionKind, PromptName, PromptPartKind,
        PromptStatus, PromptTemplate, PromptText, PromptVariableKind, PromptVariableName,
    };
    use core::{fmt, str::FromStr};

    // Builds `$type` from `$expected` padded with spaces, then checks that every
    // accessor and conversion yields the trimmed text. Uses `?`, so it must be
    // expanded inside a function returning `Result<_, AiPromptError>`.
    macro_rules! assert_text_newtype {
        ($type:ty, $expected:literal) => {{
            let padded = concat!(" ", $expected, " ");
            let built = <$type>::new(padded)?;
            assert_eq!(built.as_str(), $expected);
            assert_eq!(built.value(), $expected);
            assert_eq!(built.as_ref(), $expected);
            assert_eq!(built.to_string(), $expected);
            let converted = <$type as TryFrom<&str>>::try_from($expected)?;
            assert_eq!(converted, built);
            assert_eq!(built.into_string(), String::from($expected));
        }};
    }

    /// For each variant, checks that its label parses back to it, both as written
    /// and with `-` replaced by `_` or by a space.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), AiPromptError>
    where
        T: Copy + Eq + fmt::Debug + fmt::Display + FromStr<Err = AiPromptError>,
    {
        variants
            .iter()
            .try_for_each(|&expected| -> Result<(), AiPromptError> {
                let canonical = expected.to_string();
                assert_eq!(canonical.parse::<T>()?, expected);
                let alternates = [canonical.replace('-', "_"), canonical.replace('-', " ")];
                for spelling in &alternates {
                    assert_eq!(spelling.parse::<T>()?, expected);
                }
                Ok(())
            })
    }

    #[test]
    fn validates_prompt_text_newtypes() -> Result<(), AiPromptError> {
        assert_text_newtype!(PromptName, "support-triage");
        assert_text_newtype!(PromptId, "prompt-001");
        assert_text_newtype!(PromptText, "Classify the support request");
        assert_text_newtype!(PromptVariableName, "customer_tier");
        let blank = PromptName::new("  ");
        assert_eq!(blank, Err(AiPromptError::Empty));
        Ok(())
    }

    #[test]
    fn models_prompt_templates() -> Result<(), AiPromptError> {
        let template = PromptTemplate::new(
            PromptName::new("support-triage")?,
            PromptText::new("Classify the support request")?,
            PromptFormat::Markdown,
        );

        assert_eq!(template.name().as_str(), "support-triage");
        assert_eq!(template.text().as_str(), "Classify the support request");
        assert_eq!(template.format(), PromptFormat::Markdown);
        Ok(())
    }

    #[test]
    fn displays_and_parses_prompt_enums() -> Result<(), AiPromptError> {
        assert_enum_family(PromptVariableKind::ALL)?;
        assert_enum_family(PromptPartKind::ALL)?;
        assert_enum_family(PromptInstructionKind::ALL)?;
        assert_enum_family(PromptFormat::ALL)?;
        assert_enum_family(PromptStatus::ALL)?;

        let spaced: PromptFormat = "plain text".parse()?;
        assert_eq!(spaced, PromptFormat::PlainText);
        assert_eq!(
            PromptFormat::from_str("nope"),
            Err(AiPromptError::UnknownLabel)
        );
        Ok(())
    }
}