Skip to main content

tower_llm/validation/
gen.rs

1//! Proptest-based generators for valid conversations.
2
3use async_openai::types::*;
4use proptest::prelude::*;
5
/// Knobs controlling the shape of conversations produced by [`valid_conversation`].
///
/// Not every field is consulted yet: [`valid_conversation`] currently reads only
/// `must_have_system`, `must_have_tool_calls` and `min_tool_calls`. The remaining fields are
/// carried for callers/validators but are not enforced by that generator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratorConfig {
    /// Intended lower bound on the number of messages in a conversation
    /// (not currently enforced by [`valid_conversation`]).
    pub min_messages: usize,
    /// Intended upper bound on the number of messages in a conversation
    /// (not currently enforced by [`valid_conversation`]).
    pub max_messages: usize,
    /// When true, [`valid_conversation`] starts every conversation with a system message;
    /// when false, it emits no system message at all.
    pub must_have_system: bool,
    /// When true, every assistant turn produced by [`valid_conversation`] carries tool calls.
    pub must_have_tool_calls: bool,
    /// Number of tool calls emitted by each tool-calling assistant turn
    /// (at least one is always emitted). Does not by itself force a tool-calling turn.
    pub min_tool_calls: usize,
    /// Presumably whether developer-role messages may appear — not currently read by
    /// [`valid_conversation`]; confirm against the validator that consumes this config.
    pub allow_developer: bool,
    /// Presumably whether tool responses must follow their calls in call order — not
    /// currently read by [`valid_conversation`], which always emits responses in call order.
    pub enforce_tool_order: bool,
}

impl Default for GeneratorConfig {
    /// Returns the baseline configuration.
    ///
    /// It sets 2–12 messages, requires a leading system message, makes tool calls optional
    /// with no per-turn minimum beyond one, disallows developer messages and enforces tool order.
    fn default() -> Self {
        Self {
            min_messages: 2,
            max_messages: 12,
            must_have_system: true,
            must_have_tool_calls: false,
            min_tool_calls: 0,
            allow_developer: false,
            enforce_tool_order: true,
        }
    }
}
30
31/// Strategy that yields valid conversations that should pass the default validator.
32pub fn valid_conversation(
33    cfg: GeneratorConfig,
34) -> impl Strategy<Value = Vec<ChatCompletionRequestMessage>> {
35    // Keep an intentionally simple generator for MVP: up to 3 turns, optional tools.
36    let turn = (any::<bool>(), any::<bool>()); // (with_tools, with_text)
37    let turns = proptest::collection::vec(turn, 1..=3);
38    let sys_flag = proptest::strategy::Just(cfg.must_have_system);
39    (sys_flag, turns).prop_map(move |(must_sys, turns)| {
40        let mut msgs: Vec<ChatCompletionRequestMessage> = Vec::new();
41        if must_sys {
42            let sys = ChatCompletionRequestSystemMessageArgs::default()
43                .content("sys")
44                .build()
45                .unwrap();
46            msgs.push(sys.into());
47        }
48        // Ensure first non-system is user
49        let usr = ChatCompletionRequestUserMessageArgs::default()
50            .content("hi")
51            .build()
52            .unwrap();
53        msgs.push(usr.into());
54
55        let mut tool_id_counter = 1usize;
56        let last_turn = turns.len().saturating_sub(1);
57        for (idx, (with_tools, with_text)) in turns.into_iter().enumerate() {
58            // Assistant response
59            if with_tools || cfg.must_have_tool_calls {
60                let min_calls = cfg.min_tool_calls.max(1);
61                let num_calls = std::cmp::max(min_calls, 1);
62                let mut calls: Vec<ChatCompletionMessageToolCall> = Vec::new();
63                for _ in 0..num_calls {
64                    let id = format!("c{}", tool_id_counter);
65                    tool_id_counter += 1;
66                    calls.push(ChatCompletionMessageToolCall {
67                        id: id.clone(),
68                        r#type: ChatCompletionToolType::Function,
69                        function: FunctionCall {
70                            name: "tool".into(),
71                            arguments: "{}".into(),
72                        },
73                    });
74                }
75                let asst = ChatCompletionRequestAssistantMessageArgs::default()
76                    .content("")
77                    .tool_calls(calls.clone())
78                    .build()
79                    .unwrap();
80                msgs.push(asst.into());
81                // Tool responses, in order
82                for tc in calls.into_iter() {
83                    let t = ChatCompletionRequestToolMessageArgs::default()
84                        .tool_call_id(tc.id)
85                        .content("{}")
86                        .build()
87                        .unwrap();
88                    msgs.push(t.into());
89                }
90            } else {
91                let content = if with_text { "ok" } else { "" };
92                let asst = ChatCompletionRequestAssistantMessageArgs::default()
93                    .content(content)
94                    .build()
95                    .unwrap();
96                msgs.push(asst.into());
97            }
98            // Always add a user between assistant turns, except after the last turn
99            if idx != last_turn {
100                let u = ChatCompletionRequestUserMessageArgs::default()
101                    .content("next")
102                    .build()
103                    .unwrap();
104                msgs.push(u.into());
105            }
106        }
107        msgs
108    })
109}