Skip to main content

tandem_server/
preset_composer.rs

1use serde::{Deserialize, Serialize};
2use sha2::{Digest, Sha256};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct PromptFragment {
6    pub id: String,
7    pub phase: String,
8    pub content: String,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12pub struct PromptComposeInput {
13    #[serde(default)]
14    pub base_prompt: String,
15    #[serde(default)]
16    pub fragments: Vec<PromptFragment>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct PromptComposeOutput {
21    pub prompt: String,
22    pub composition_hash: String,
23    pub ordered_fragment_ids: Vec<String>,
24}
25
26pub fn compose(input: PromptComposeInput) -> PromptComposeOutput {
27    let mut fragments = input.fragments.clone();
28    fragments.sort_by(|a, b| {
29        phase_rank(&a.phase)
30            .cmp(&phase_rank(&b.phase))
31            .then_with(|| a.id.cmp(&b.id))
32            .then_with(|| a.content.cmp(&b.content))
33    });
34
35    let mut sections = Vec::new();
36    if !input.base_prompt.trim().is_empty() {
37        sections.push(input.base_prompt.trim().to_string());
38    }
39    let ordered_fragment_ids = fragments.iter().map(|f| f.id.clone()).collect::<Vec<_>>();
40    for fragment in &fragments {
41        if fragment.content.trim().is_empty() {
42            continue;
43        }
44        sections.push(format!(
45            "[{}:{}]\n{}",
46            fragment.phase.trim().to_ascii_lowercase(),
47            fragment.id.trim(),
48            fragment.content.trim()
49        ));
50    }
51    let prompt = sections.join("\n\n---\n\n");
52    let composition_hash = format!("{:x}", Sha256::digest(prompt.as_bytes()));
53    PromptComposeOutput {
54        prompt,
55        composition_hash,
56        ordered_fragment_ids,
57    }
58}
59
/// Maps a fragment phase name to its position in the composition order.
///
/// Matching ignores surrounding whitespace and ASCII case. The known phases are
/// `core` (0), `domain` (1), `style` (2) and `safety` (3); any other value ranks
/// 99, so it sorts after every known phase.
fn phase_rank(phase: &str) -> usize {
    const KNOWN_PHASES: [&str; 4] = ["core", "domain", "style", "safety"];
    let needle = phase.trim();
    KNOWN_PHASES
        .iter()
        .position(|known| known.eq_ignore_ascii_case(needle))
        .unwrap_or(99)
}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fragment from borrowed strings to keep the fixtures compact.
    fn frag(id: &str, phase: &str, content: &str) -> PromptFragment {
        PromptFragment {
            id: id.to_string(),
            phase: phase.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn compose_is_deterministic_by_phase_and_id() {
        let fragments = vec![
            frag("zeta", "style", "Style Z"),
            frag("alpha", "core", "Core A"),
            frag("beta", "style", "Style B"),
            frag("safe", "safety", "Do no harm"),
        ];
        let mut reversed = fragments.clone();
        reversed.reverse();

        let out = compose(PromptComposeInput {
            base_prompt: "Base".to_string(),
            fragments,
        });
        assert_eq!(
            out.ordered_fragment_ids,
            vec![
                "alpha".to_string(),
                "beta".to_string(),
                "zeta".to_string(),
                "safe".to_string()
            ]
        );

        let out2 = compose(PromptComposeInput {
            base_prompt: "Base".to_string(),
            fragments: reversed,
        });
        assert_eq!(out.prompt, out2.prompt);
        assert_eq!(out.composition_hash, out2.composition_hash);
    }

    #[test]
    fn compose_renders_trimmed_sections_and_skips_blank_fragments() {
        let out = compose(PromptComposeInput {
            base_prompt: "  Base  ".to_string(),
            fragments: vec![frag("blank", "style", "   "), frag("a", " Core ", "  A  ")],
        });
        assert_eq!(out.prompt, "Base\n\n---\n\n[core:a]\nA");
        // Blank fragments are omitted from the prompt but still reported.
        assert_eq!(
            out.ordered_fragment_ids,
            vec!["a".to_string(), "blank".to_string()]
        );
        assert_eq!(out.composition_hash.len(), 64);
        assert!(out
            .composition_hash
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn compose_places_unknown_phases_after_known_ones() {
        let out = compose(PromptComposeInput {
            base_prompt: String::new(),
            fragments: vec![frag("x", "custom", "X"), frag("s", "safety", "S")],
        });
        assert_eq!(
            out.ordered_fragment_ids,
            vec!["s".to_string(), "x".to_string()]
        );
        assert_eq!(out.prompt, "[safety:s]\nS\n\n---\n\n[custom:x]\nX");
    }

    #[test]
    fn compose_of_empty_input_hashes_empty_prompt() {
        let out = compose(PromptComposeInput::default());
        assert_eq!(out.prompt, "");
        assert!(out.ordered_fragment_ids.is_empty());
        assert_eq!(
            out.composition_hash,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }
}