1use once_cell::sync::Lazy;
7use serde_json::{Map, Value};
8use std::collections::HashSet;
9
/// Accumulated outcome of validating a prompt-learning config document.
///
/// Diagnostics are split by severity: `errors` make the config invalid,
/// `warnings` flag likely mistakes (unknown/deprecated fields), and `info`
/// carries purely informational notes such as migration hints.
#[derive(Debug, Clone, Default)]
pub struct PromptLearningValidationResult {
    // Fatal problems; `is_valid()` is false when any are present.
    pub errors: Vec<String>,
    // Non-fatal issues worth surfacing to the user.
    pub warnings: Vec<String>,
    // Informational notices only.
    pub info: Vec<String>,
}
16
impl PromptLearningValidationResult {
    /// Creates an empty result with no diagnostics recorded.
    pub fn new() -> Self {
        Self::default()
    }

    /// True when no errors were recorded; warnings and info do not count.
    pub fn is_valid(&self) -> bool {
        self.errors.is_empty()
    }

    // Internal helpers used by the validators below to record diagnostics.
    fn add_error(&mut self, msg: String) {
        self.errors.push(msg);
    }

    fn add_warning(&mut self, msg: String) {
        self.warnings.push(msg);
    }

    fn add_info(&mut self, msg: String) {
        self.info.push(msg);
    }
}
38
/// Sections recognized at the top level of the config document.
const KNOWN_TOP_LEVEL_SECTIONS: &[&str] = &["prompt_learning", "display", "termination_config"];

/// Fields accepted directly under [prompt_learning].
const KNOWN_PROMPT_LEARNING_FIELDS: &[&str] = &[
    "algorithm",
    "container_url",
    "container_api_key",
    "container_id",
    "initial_prompt",
    "policy",
    "mipro",
    "gepa",
    "ontology",
    "verifier",
    "proxy_models",
    "env_config",
    "env_name",
    "termination_config",
    "results_folder",
    "bootstrap_train_seeds",
    "online_pool",
    "test_pool",
    "reference_pool",
    "auto_discover_patterns",
    "use_byok",
];

/// Fields accepted under [prompt_learning.policy].
const KNOWN_POLICY_FIELDS: &[&str] = &[
    "model",
    "provider",
    "inference_url",
    "inference_mode",
    "temperature",
    "max_completion_tokens",
    "policy_name",
    "config",
    "context_override",
    "timeout",
];

/// Fields accepted under [prompt_learning.termination_config].
const KNOWN_TERMINATION_CONFIG_FIELDS: &[&str] = &[
    "max_cost_usd",
    "max_trials",
    "max_seconds",
    "max_time_seconds",
    "max_rollouts",
    "max_trials_without_improvement",
    "pessimism_enabled",
    "max_category_costs_usd",
];
88
/// Fields accepted under [prompt_learning.gepa]. Includes both the nested
/// subsection keys (rollout, evaluation, ...) and their flat legacy aliases
/// (rollout_budget, evaluation_seeds, ...).
const KNOWN_GEPA_FIELDS: &[&str] = &[
    "env_name",
    "env_config",
    "rng_seed",
    "proposer_type",
    "proposer_mode",
    "proposal_pipeline",
    "proposer_backend",
    "proposer",
    "context_override",
    "proposer_effort",
    "proposer_output_tokens",
    "metaprompt",
    "dspy_meta_model",
    "dspy_meta_provider",
    "dspy_inference_url",
    "synth_meta_model",
    "synth_meta_provider",
    "synth_inference_url",
    "gepa_ai_meta_model",
    "gepa_ai_meta_provider",
    "gepa_ai_inference_url",
    "modules",
    "rollout",
    "evaluation",
    "mutation",
    "population",
    "archive",
    "token",
    "verifier",
    "proxy_models",
    "adaptive_pool",
    "adaptive_batch",
    "rollout_budget",
    "max_concurrent_rollouts",
    "minibatch_size",
    "evaluation_seeds",
    "validation_seeds",
    "test_pool",
    "validation_pool",
    "validation_top_k",
    "mutation_rate",
    "mutation_llm_model",
    "mutation_llm_provider",
    "mutation_llm_inference_url",
    "mutation_prompt",
    "initial_population_size",
    "num_generations",
    "children_per_generation",
    "crossover_rate",
    "selection_pressure",
    "patience_generations",
    "archive_size",
    "pareto_set_size",
    "pareto_eps",
    "feedback_fraction",
    "max_token_limit",
    "token_counting_model",
    "enforce_pattern_token_limit",
    "max_spend_usd",
    "unified_optimization",
    "baseline_context_override",
    "proposed_prompt_max_tokens",
    "use_byok",
];

/// Fields accepted under [prompt_learning.gepa.rollout].
const KNOWN_GEPA_ROLLOUT_FIELDS: &[&str] =
    &["budget", "max_concurrent", "minibatch_size", "timeout"];

/// Fields accepted under [prompt_learning.gepa.evaluation].
const KNOWN_GEPA_EVALUATION_FIELDS: &[&str] = &[
    "seeds",
    "train_seeds",
    "validation_seeds",
    "val_seeds",
    "test_pool",
    "validation_pool",
    "validation_top_k",
];

/// Fields accepted under [prompt_learning.gepa.mutation].
const KNOWN_GEPA_MUTATION_FIELDS: &[&str] = &[
    "rate",
    "llm_model",
    "llm_provider",
    "llm_inference_url",
    "prompt",
];

/// Fields accepted under [prompt_learning.gepa.population].
const KNOWN_GEPA_POPULATION_FIELDS: &[&str] = &[
    "initial_size",
    "num_generations",
    "children_per_generation",
    "crossover_rate",
    "selection_pressure",
    "patience_generations",
];

/// Fields accepted under [prompt_learning.gepa.archive].
const KNOWN_GEPA_ARCHIVE_FIELDS: &[&str] =
    &["size", "pareto_set_size", "pareto_eps", "feedback_fraction"];

/// Fields accepted under [prompt_learning.gepa.token].
const KNOWN_GEPA_TOKEN_FIELDS: &[&str] = &[
    "max_limit",
    "counting_model",
    "enforce_pattern_limit",
    "max_spend_usd",
];
194
/// Fields accepted under [prompt_learning.mipro].
const KNOWN_MIPRO_FIELDS: &[&str] = &[
    "container_url",
    "container_api_key",
    "container_id",
    "mode",
    "online",
    "num_iterations",
    "num_evaluations_per_iteration",
    "batch_size",
    "max_concurrent",
    "env_name",
    "env_config",
    "meta_model",
    "meta_model_provider",
    "meta_model_inference_url",
    "few_shot_score_threshold",
    "results_file",
    "max_wall_clock_seconds",
    "max_total_tokens",
    "policy_config",
    "meta",
    "modules",
    "seeds",
    "proposer_effort",
    "proposer_output_tokens",
    "max_token_limit",
    "max_spend_usd",
    "token_counting_model",
    "enforce_token_limit",
    "tpe",
    "demo",
    "grounding",
    "meta_update",
    "verifier",
    "proxy_models",
    "adaptive_pool",
    "spec_path",
    "spec_max_tokens",
    "spec_include_examples",
    "spec_priority_threshold",
    "metaprompt",
    "bootstrap_train_seeds",
    "online_pool",
    "online_rollouts_per_candidate",
    "rollouts_per_candidate",
    "candidate_rollouts",
    "online_recent_event_window",
    "online_recent_rollout_window",
    "online_event_shard_size",
    "online_rollout_shard_size",
    "online_snapshot_max_events",
    "online_snapshot_max_seconds",
    "online_proposer_min_rewards",
    "online_proposer_min_rollouts",
    "online_proposer_min_seconds",
    "online_proposer_max_candidates",
    "online_proposer_mode",
    "max_instruction_slots",
    "transform_slots",
    "ontology",
    "text_dreamer",
    "test_pool",
    "reference_pool",
    "min_bootstrap_demos",
];

/// Fields accepted under [prompt_learning.mipro.ontology] (several keys have
/// singular/plural aliases, e.g. read/reads).
const KNOWN_MIPRO_ONTOLOGY_FIELDS: &[&str] = &[
    "enabled",
    "reads",
    "read",
    "writes",
    "write",
    "node_name",
    "node",
    "context_node",
    "batch_proposer",
];

/// Fields accepted under [prompt_learning.mipro.ontology.batch_proposer].
const KNOWN_MIPRO_ONTOLOGY_BATCH_PROPOSER_FIELDS: &[&str] = &[
    "enabled",
    "min_rollouts",
    "trigger_mode",
    "stage_rollout_intervals",
    "staged_rollout_intervals",
    "stage_intervals",
    "repeat_after_last_stage_rollouts",
    "repeat_stage_rollouts",
    "repeat_every_rollouts",
    "batch_size",
    "model",
    "provider",
    "inference_url",
    "temperature",
    "max_tokens",
    "api_key_env",
];

/// Fields accepted under [prompt_learning.mipro.text_dreamer].
const KNOWN_MIPRO_TEXT_DREAMER_FIELDS: &[&str] = &[
    "enabled",
    "mode",
    "world_model_mode",
    "on_overlap",
    "runtime_backend",
    "max_pending_jobs_per_system",
    "include_trace_required",
    "max_replay_rollouts",
    "observation_trigger_every_rollouts",
    "observation_log_window",
    "container_url",
    "container_api_key",
    "shadow_rollouts",
    "shadow_max_turns",
    "shadow_timeout_seconds",
    "shadow_seed_base",
];
310
/// Fields accepted in a [*.verifier] section (shared by top-level, gepa and
/// mipro verifier tables).
const KNOWN_VERIFIER_FIELDS: &[&str] = &[
    "enabled",
    "reward_source",
    "backend_base",
    "backend_api_key_env",
    "backend_provider",
    "backend_model",
    "verifier_graph_id",
    "backend_event_enabled",
    "backend_outcome_enabled",
    "backend_options",
    "concurrency",
    "timeout",
    "weight_env",
    "weight_event",
    "weight_outcome",
    "spec_path",
    "spec_max_tokens",
    "spec_context",
];

/// Fields accepted in an [*.adaptive_pool] section (shared by gepa and mipro).
const KNOWN_ADAPTIVE_POOL_FIELDS: &[&str] = &[
    "level",
    "anchor_size",
    "pool_init_size",
    "pool_min_size",
    "warmup_iters",
    "anneal_stop_iter",
    "pool_update_period",
    "min_evals_per_example",
    "k_info_prompts",
    "info_buffer_factor",
    "info_epsilon",
    "anchor_selection_method",
    "exploration_strategy",
    "heatup_reserve_pool",
    "heatup_trigger",
    "heatup_size",
    "heatup_cooldown_trials",
    "heatup_schedule",
];

/// Fields accepted in an [*.adaptive_batch] section.
const KNOWN_ADAPTIVE_BATCH_FIELDS: &[&str] = &[
    "level",
    "reflection_minibatch_size",
    "min_local_improvement",
    "val_evaluation_mode",
    "val_subsample_size",
    "candidate_selection_strategy",
];

/// Fields accepted in a [*.proxy_models] section (shared across algorithms).
const KNOWN_PROXY_MODELS_FIELDS: &[&str] = &[
    "hi_provider",
    "hi_model",
    "lo_provider",
    "lo_model",
    "n_min_hi",
    "r2_thresh",
    "r2_stop",
    "sigma_max",
    "sigma_stop",
    "verify_every",
    "proxy_patience_usd",
];
375
/// Returns the migration/deprecation note for `key`, or `None` when the key
/// carries no deprecation.
fn deprecated_message(key: &str) -> Option<&'static str> {
    let msg = match key {
        "display" => "The [display] section is deprecated and ignored by the backend. Remove it from your config.",
        "results_folder" => "'results_folder' is deprecated and ignored by the backend. Remove it from your config.",
        "rollout_budget" => "Use [prompt_learning.gepa.rollout].budget instead of flat rollout_budget.",
        "max_concurrent_rollouts" => "Use [prompt_learning.gepa.rollout].max_concurrent instead.",
        "evaluation_seeds" => "Use [prompt_learning.gepa.evaluation].seeds instead of flat evaluation_seeds.",
        "validation_seeds" => "Use [prompt_learning.gepa.evaluation].validation_seeds instead.",
        "backend_rubric_id" => "Use 'verifier_graph_id' in [prompt_learning.verifier].",
        _ => return None,
    };
    Some(msg)
}
400
/// True when `key` appears in the `known` allow-list.
fn contains_known(known: &[&str], key: &str) -> bool {
    known.contains(&key)
}
404
405fn check_unknown_fields(
406 map: &Map<String, Value>,
407 known_fields: &[&str],
408 section_path: &str,
409 result: &mut PromptLearningValidationResult,
410) {
411 for key in map.keys() {
412 if !contains_known(known_fields, key) {
413 result.add_warning(format!(
414 "Unknown field '{}' in [{}]. This field will be ignored. Check spelling or remove it.",
415 key, section_path
416 ));
417 }
418 }
419}
420
421fn check_deprecated_fields(
422 map: &Map<String, Value>,
423 section_path: &str,
424 result: &mut PromptLearningValidationResult,
425) {
426 for key in map.keys() {
427 if let Some(msg) = deprecated_message(key) {
428 result.add_warning(format!("[{}] {}", section_path, msg));
429 }
430 }
431}
432
433fn validate_gepa_config(
434 gepa: &Map<String, Value>,
435 result: &mut PromptLearningValidationResult,
436 path_prefix: &str,
437) {
438 check_unknown_fields(gepa, KNOWN_GEPA_FIELDS, "prompt_learning.gepa", result);
439
440 for field in [
441 "rollout_budget",
442 "max_concurrent_rollouts",
443 "evaluation_seeds",
444 "validation_seeds",
445 ] {
446 if gepa.contains_key(field) {
447 result.add_info(format!(
448 "Using flat '{}' in [prompt_learning.gepa] - consider migrating to nested structure for clarity",
449 field
450 ));
451 }
452 }
453
454 if let Some(Value::Object(rollout)) = gepa.get("rollout") {
455 check_unknown_fields(
456 rollout,
457 KNOWN_GEPA_ROLLOUT_FIELDS,
458 "prompt_learning.gepa.rollout",
459 result,
460 );
461 }
462 if let Some(Value::Object(evaluation)) = gepa.get("evaluation") {
463 check_unknown_fields(
464 evaluation,
465 KNOWN_GEPA_EVALUATION_FIELDS,
466 "prompt_learning.gepa.evaluation",
467 result,
468 );
469 }
470 if let Some(Value::Object(mutation)) = gepa.get("mutation") {
471 check_unknown_fields(
472 mutation,
473 KNOWN_GEPA_MUTATION_FIELDS,
474 "prompt_learning.gepa.mutation",
475 result,
476 );
477 }
478 if let Some(Value::Object(population)) = gepa.get("population") {
479 check_unknown_fields(
480 population,
481 KNOWN_GEPA_POPULATION_FIELDS,
482 "prompt_learning.gepa.population",
483 result,
484 );
485 }
486 if let Some(Value::Object(archive)) = gepa.get("archive") {
487 check_unknown_fields(
488 archive,
489 KNOWN_GEPA_ARCHIVE_FIELDS,
490 "prompt_learning.gepa.archive",
491 result,
492 );
493 }
494 if let Some(Value::Object(token)) = gepa.get("token") {
495 check_unknown_fields(
496 token,
497 KNOWN_GEPA_TOKEN_FIELDS,
498 "prompt_learning.gepa.token",
499 result,
500 );
501 }
502 if let Some(Value::Object(adaptive_pool)) = gepa.get("adaptive_pool") {
503 check_unknown_fields(
504 adaptive_pool,
505 KNOWN_ADAPTIVE_POOL_FIELDS,
506 "prompt_learning.gepa.adaptive_pool",
507 result,
508 );
509 }
510 if let Some(Value::Object(adaptive_batch)) = gepa.get("adaptive_batch") {
511 check_unknown_fields(
512 adaptive_batch,
513 KNOWN_ADAPTIVE_BATCH_FIELDS,
514 "prompt_learning.gepa.adaptive_batch",
515 result,
516 );
517 }
518 if let Some(Value::Object(proxy_models)) = gepa.get("proxy_models") {
519 check_unknown_fields(
520 proxy_models,
521 KNOWN_PROXY_MODELS_FIELDS,
522 "prompt_learning.gepa.proxy_models",
523 result,
524 );
525 }
526 if let Some(Value::Object(verifier)) = gepa.get("verifier") {
527 check_unknown_fields(
528 verifier,
529 KNOWN_VERIFIER_FIELDS,
530 "prompt_learning.gepa.verifier",
531 result,
532 );
533 }
534
535 if gepa.is_empty() {
536 result.add_warning(format!(
537 "{}No [prompt_learning.gepa] section found for GEPA algorithm",
538 path_prefix
539 ));
540 }
541}
542
543fn validate_mipro_config(
544 mipro: &Map<String, Value>,
545 result: &mut PromptLearningValidationResult,
546 path_prefix: &str,
547) {
548 check_unknown_fields(mipro, KNOWN_MIPRO_FIELDS, "prompt_learning.mipro", result);
549
550 if let Some(Value::Object(verifier)) = mipro.get("verifier") {
551 check_unknown_fields(
552 verifier,
553 KNOWN_VERIFIER_FIELDS,
554 "prompt_learning.mipro.verifier",
555 result,
556 );
557 }
558 if let Some(Value::Object(adaptive_pool)) = mipro.get("adaptive_pool") {
559 check_unknown_fields(
560 adaptive_pool,
561 KNOWN_ADAPTIVE_POOL_FIELDS,
562 "prompt_learning.mipro.adaptive_pool",
563 result,
564 );
565 }
566 if let Some(Value::Object(proxy_models)) = mipro.get("proxy_models") {
567 check_unknown_fields(
568 proxy_models,
569 KNOWN_PROXY_MODELS_FIELDS,
570 "prompt_learning.mipro.proxy_models",
571 result,
572 );
573 }
574 if let Some(Value::Object(ontology)) = mipro.get("ontology") {
575 validate_ontology_config(ontology, "prompt_learning.mipro.ontology", result);
576 }
577 if let Some(Value::Object(text_dreamer)) = mipro.get("text_dreamer") {
578 check_unknown_fields(
579 text_dreamer,
580 KNOWN_MIPRO_TEXT_DREAMER_FIELDS,
581 "prompt_learning.mipro.text_dreamer",
582 result,
583 );
584 }
585
586 if mipro.is_empty() {
587 result.add_warning(format!(
588 "{}No [prompt_learning.mipro] section found for MIPRO algorithm",
589 path_prefix
590 ));
591 }
592}
593
594fn validate_ontology_config(
595 ontology: &Map<String, Value>,
596 path: &str,
597 result: &mut PromptLearningValidationResult,
598) {
599 check_unknown_fields(ontology, KNOWN_MIPRO_ONTOLOGY_FIELDS, path, result);
600
601 if let Some(Value::Object(batch)) = ontology.get("batch_proposer") {
602 let batch_path = format!("{path}.batch_proposer");
603 check_unknown_fields(
604 batch,
605 KNOWN_MIPRO_ONTOLOGY_BATCH_PROPOSER_FIELDS,
606 &batch_path,
607 result,
608 );
609 }
610}
611
/// Lenient validation of a full prompt-learning config document.
///
/// Collects errors, warnings, and info instead of failing fast; callers
/// inspect `is_valid()` on the returned result. `config_path`, when given,
/// is prepended to each message as "(path) " so diagnostics identify the
/// source file.
pub fn validate_prompt_learning_config(
    config: &Value,
    config_path: Option<&str>,
) -> PromptLearningValidationResult {
    let mut result = PromptLearningValidationResult::new();
    let path_prefix = config_path
        .map(|p| format!("({}) ", p))
        .unwrap_or_else(String::new);

    // The document root must be a table/object; anything else is fatal.
    let config_map = match config.as_object() {
        Some(map) => map,
        None => {
            result.add_error(format!("{}Config must be an object", path_prefix));
            return result;
        }
    };

    // Top-level sections outside the allow-list are warnings only.
    for key in config_map.keys() {
        if !contains_known(KNOWN_TOP_LEVEL_SECTIONS, key) {
            result.add_warning(format!(
                "{}Unknown top-level section '[{}]'. Known sections: {}",
                path_prefix,
                key,
                KNOWN_TOP_LEVEL_SECTIONS.join(", ")
            ));
        }
    }

    // [display] is still "known" (no unknown-section warning) but deprecated.
    if config_map.contains_key("display") {
        result.add_warning(format!(
            "{}The [display] section is deprecated and ignored by the backend. Remove it to clean up your config.",
            path_prefix
        ));
    }

    // [prompt_learning] is mandatory and must be a table.
    let pl_value = config_map.get("prompt_learning");
    let pl_map = match pl_value.and_then(|v| v.as_object()) {
        Some(map) => map,
        None => {
            result.add_error(format!(
                "{}Missing required [prompt_learning] section",
                path_prefix
            ));
            return result;
        }
    };

    check_unknown_fields(
        pl_map,
        KNOWN_PROMPT_LEARNING_FIELDS,
        "prompt_learning",
        &mut result,
    );
    check_deprecated_fields(pl_map, "prompt_learning", &mut result);

    // 'algorithm' is required and restricted to the two supported values.
    let algorithm = pl_map.get("algorithm").and_then(|v| v.as_str());
    if algorithm.is_none() {
        result.add_error(format!(
            "{}Missing required 'algorithm' field in [prompt_learning]",
            path_prefix
        ));
    } else if !matches!(algorithm, Some("gepa") | Some("mipro")) {
        if let Some(value) = algorithm {
            result.add_error(format!(
                "{}Invalid algorithm '{}'. Must be 'gepa' or 'mipro'",
                path_prefix, value
            ));
        }
    }

    // 'container_url' must be present and be a string.
    if pl_map
        .get("container_url")
        .and_then(|v| v.as_str())
        .is_none()
    {
        result.add_error(format!(
            "{}Missing required 'container_url' in [prompt_learning]",
            path_prefix
        ));
    }

    // Optional subsections: each is field-checked only when present as a table.
    if let Some(Value::Object(policy)) = pl_map.get("policy") {
        check_unknown_fields(
            policy,
            KNOWN_POLICY_FIELDS,
            "prompt_learning.policy",
            &mut result,
        );
    }

    if let Some(Value::Object(termination)) = pl_map.get("termination_config") {
        check_unknown_fields(
            termination,
            KNOWN_TERMINATION_CONFIG_FIELDS,
            "prompt_learning.termination_config",
            &mut result,
        );
        result.add_info(
            "termination_config is supported and will create backend TerminationManager conditions"
                .to_string(),
        );
    }

    if let Some(Value::Object(verifier)) = pl_map.get("verifier") {
        check_unknown_fields(
            verifier,
            KNOWN_VERIFIER_FIELDS,
            "prompt_learning.verifier",
            &mut result,
        );
    }

    if let Some(Value::Object(proxy_models)) = pl_map.get("proxy_models") {
        check_unknown_fields(
            proxy_models,
            KNOWN_PROXY_MODELS_FIELDS,
            "prompt_learning.proxy_models",
            &mut result,
        );
    }

    if let Some(Value::Object(ontology)) = pl_map.get("ontology") {
        validate_ontology_config(ontology, "prompt_learning.ontology", &mut result);
    }

    // Algorithm-specific section: validate it when present, warn when absent.
    match algorithm {
        Some("gepa") => {
            if let Some(Value::Object(gepa)) = pl_map.get("gepa") {
                validate_gepa_config(gepa, &mut result, &path_prefix);
            } else {
                result.add_warning(format!(
                    "{}No [prompt_learning.gepa] section found for GEPA algorithm",
                    path_prefix
                ));
            }
        }
        Some("mipro") => {
            if let Some(Value::Object(mipro)) = pl_map.get("mipro") {
                validate_mipro_config(mipro, &mut result, &path_prefix);
            } else {
                result.add_warning(format!(
                    "{}No [prompt_learning.mipro] section found for MIPRO algorithm",
                    path_prefix
                ));
            }
        }
        _ => {}
    }

    result
}
763
/// Lowercased model allow-lists per provider, loaded from the embedded
/// supported_models.json asset (image-model lists are merged into the
/// text-model sets for openai/google).
#[derive(Debug, Clone)]
struct SupportedModels {
    openai: HashSet<String>,
    groq: HashSet<String>,
    google: HashSet<String>,
}
774
775fn extract_model_list(value: Option<&Value>) -> Vec<String> {
776 match value.and_then(|v| v.as_array()) {
777 Some(arr) => arr
778 .iter()
779 .filter_map(|v| v.as_str())
780 .map(|s| s.to_lowercase())
781 .collect(),
782 None => Vec::new(),
783 }
784}
785
786fn load_supported_models() -> Option<SupportedModels> {
787 let raw = include_str!("../../assets/supported_models.json");
788 let value: Value = serde_json::from_str(raw).ok()?;
789 let prompt_opt = value.get("prompt_optimization")?.as_object()?;
790
791 let openai = extract_model_list(prompt_opt.get("openai").and_then(|v| v.get("models")));
792 let openai_image =
793 extract_model_list(prompt_opt.get("openai_image").and_then(|v| v.get("models")));
794 let google = extract_model_list(prompt_opt.get("google").and_then(|v| v.get("models")));
795 let google_image =
796 extract_model_list(prompt_opt.get("google_image").and_then(|v| v.get("models")));
797 let groq = extract_model_list(prompt_opt.get("groq").and_then(|v| v.get("models")));
798
799 let mut openai_set = HashSet::new();
800 for item in openai.into_iter().chain(openai_image.into_iter()) {
801 openai_set.insert(item);
802 }
803 let mut google_set = HashSet::new();
804 for item in google.into_iter().chain(google_image.into_iter()) {
805 google_set.insert(item);
806 }
807 let groq_set: HashSet<String> = groq.into_iter().collect();
808
809 Some(SupportedModels {
810 openai: openai_set,
811 groq: groq_set,
812 google: google_set,
813 })
814}
815
// Lazily parsed allow-list; `None` when the embedded JSON asset is malformed,
// in which case the is_supported_* checks below fail open (accept everything).
static SUPPORTED_MODELS: Lazy<Option<SupportedModels>> = Lazy::new(load_supported_models);
817
818fn is_supported_openai_model(model: &str) -> bool {
819 if let Some(models) = SUPPORTED_MODELS.as_ref() {
820 let key = model.to_lowercase();
821 return models.openai.contains(&key);
822 }
823 true
824}
825
826fn is_supported_groq_model(model: &str) -> bool {
827 if let Some(models) = SUPPORTED_MODELS.as_ref() {
828 let key = model.to_lowercase();
829 return models.groq.contains(&key);
830 }
831 true
832}
833
834fn is_supported_google_model(model: &str) -> bool {
835 if let Some(models) = SUPPORTED_MODELS.as_ref() {
836 let key = model.to_lowercase();
837 return models.google.contains(&key);
838 }
839 true
840}
841
842fn value_type_name(value: &Value) -> &'static str {
843 match value {
844 Value::Null => "None",
845 Value::Bool(_) => "bool",
846 Value::Number(_) => "number",
847 Value::String(_) => "str",
848 Value::Array(_) => "list",
849 Value::Object(_) => "dict",
850 }
851}
852
853fn parse_int(value: &Value) -> Option<i64> {
854 match value {
855 Value::Number(n) => n.as_i64().or_else(|| n.as_f64().map(|f| f as i64)),
856 Value::String(s) => s.trim().parse::<i64>().ok(),
857 _ => None,
858 }
859}
860
861fn parse_float(value: &Value) -> Option<f64> {
862 match value {
863 Value::Number(n) => n.as_f64(),
864 Value::String(s) => s.trim().parse::<f64>().ok(),
865 _ => None,
866 }
867}
868
869fn ontology_connector_enabled_for_text_dreamer(
870 pl_section: &Map<String, Value>,
871 mipro_map: &Map<String, Value>,
872) -> bool {
873 let ontology_value = mipro_map
874 .get("ontology")
875 .or_else(|| pl_section.get("ontology"));
876
877 let Some(ontology_value) = ontology_value else {
878 return false;
879 };
880
881 match ontology_value {
882 Value::Bool(enabled) => *enabled,
883 Value::Object(map) => {
884 let enabled = map.get("enabled").and_then(|v| v.as_bool());
885 let reads = map
886 .get("reads")
887 .or_else(|| map.get("read"))
888 .and_then(|v| v.as_bool());
889 let writes = map
890 .get("writes")
891 .or_else(|| map.get("write"))
892 .and_then(|v| v.as_bool());
893 reads.or(writes).or(enabled).unwrap_or(false)
894 }
895 _ => false,
896 }
897}
898
899fn value_to_string(value: &Value) -> Option<String> {
900 match value {
901 Value::String(s) => Some(s.to_string()),
902 Value::Number(n) => Some(n.to_string()),
903 Value::Bool(b) => Some(b.to_string()),
904 _ => None,
905 }
906}
907
/// Validates a (model, provider) pair for prompt learning and returns any
/// error messages (empty vec means valid). `field_name` labels the config
/// field in messages. `allow_nano` permits "*-nano" models — true for policy
/// (task-execution) models, false for proposal/mutation models.
fn validate_model_for_provider(
    model: &str,
    provider: &str,
    field_name: &str,
    allow_nano: bool,
) -> Vec<String> {
    let mut errors = Vec::new();

    if model.trim().is_empty() {
        errors.push(format!("Missing or empty {}", field_name));
        return errors;
    }

    let provider_lower = provider.trim().to_lowercase();
    let model_lower = model.trim().to_lowercase();
    // Strip an org prefix like "openai/" so "openai/gpt-4o" matches "gpt-4o".
    let model_without_prefix = if let Some((_, rest)) = model_lower.split_once('/') {
        rest
    } else {
        model_lower.as_str()
    };

    // gpt-5-pro is rejected outright regardless of provider.
    if model_without_prefix == "gpt-5-pro" {
        errors.push(format!(
            "Model '{}' is not supported for prompt learning (too expensive).\n  gpt-5-pro is excluded due to high cost ($15/$120 per 1M tokens).\n  Please use a supported model instead.",
            model
        ));
        return errors;
    }

    // Nano models lack the capability needed for proposal/mutation roles.
    if !allow_nano && model_without_prefix.ends_with("-nano") {
        errors.push(format!(
            "Model '{}' is not supported for {}.\n  ❌ Nano models (e.g., gpt-4.1-nano, gpt-5-nano) are NOT allowed for proposal/mutation models.\n  \n  Why?\n  Proposal and mutation models need to be SMART and capable of generating high-quality,\n  creative prompt variations. Nano models are too small and lack the reasoning capability\n  needed for effective prompt optimization.\n  \n  ✅ Use a larger model instead:\n  - For OpenAI: gpt-4.1-mini, gpt-4o-mini, gpt-4o, or gpt-4.1\n  - For Groq: openai/gpt-oss-120b, llama-3.3-70b-versatile\n  - For Google: gemini-2.5-flash, gemini-2.5-pro\n  \n  Note: Nano models ARE allowed for policy models (task execution), but NOT for\n  proposal/mutation models (prompt generation).",
            model, field_name
        ));
        return errors;
    }

    // Per-provider allow-list checks. Note: groq uses the full (possibly
    // prefixed) lowercase name since its list includes prefixed entries like
    // "openai/gpt-oss-120b"; openai/google use the prefix-stripped name.
    match provider_lower.as_str() {
        "openai" => {
            if !is_supported_openai_model(model_without_prefix) {
                errors.push(format!(
                    "Unsupported OpenAI model: '{}'\n  Supported OpenAI models for prompt learning:\n  - gpt-4o\n  - gpt-4o-mini\n  - gpt-4.1, gpt-4.1-mini, gpt-4.1-nano\n  - gpt-5, gpt-5-mini, gpt-5-nano\n  - Image generation: gpt-image-1.5, gpt-image-1, gpt-image-1-mini, chatgpt-image-latest\n  Note: gpt-5-pro is excluded (too expensive)\n  Got: '{}'",
                    model, model
                ));
            }
        }
        "groq" => {
            if !is_supported_groq_model(&model_lower) {
                errors.push(format!(
                    "Unsupported Groq model: '{}'\n  Supported Groq models for prompt learning:\n  - gpt-oss-Xb (e.g., gpt-oss-20b, openai/gpt-oss-120b)\n  - llama-3.3-70b (and variants like llama-3.3-70b-versatile)\n  - llama-3.1-8b-instant\n  - qwen/qwen3-32b (and variants)\n  Got: '{}'",
                    model, model
                ));
            }
        }
        "google" => {
            if !is_supported_google_model(model_without_prefix) {
                errors.push(format!(
                    "Unsupported Google/Gemini model: '{}'\n  Supported Google models for prompt learning:\n  - gemini-2.5-pro, gemini-2.5-pro-gt200k\n  - gemini-2.5-flash\n  - gemini-2.5-flash-lite\n  - Image generation: gemini-2.5-flash-image, gemini-3-pro-image-preview\n  Got: '{}'",
                    model, model
                ));
            }
        }
        _ => {
            errors.push(format!(
                "Unsupported provider: '{}'\n  Supported providers for prompt learning: 'openai', 'groq', 'google'\n  Got: '{}'",
                provider, provider
            ));
        }
    }

    errors
}
980
/// Validates an adaptive_pool section (gepa or mipro), appending messages to
/// `errors`. `prefix` names the section in messages (e.g.
/// "prompt_learning.gepa.adaptive_pool"). Checks types, ranges, enum-like
/// string fields, and cross-field size relationships.
fn validate_adaptive_pool_config(
    adaptive_pool_section: &Value,
    prefix: &str,
    errors: &mut Vec<String>,
) {
    // The section itself must be a table.
    let section = match adaptive_pool_section.as_object() {
        Some(map) => map,
        None => {
            errors.push(format!("❌ {} must be a table/dict when provided", prefix));
            return;
        }
    };

    // 'level' is a case-insensitive enum.
    if let Some(level) = section.get("level") {
        let level_str = value_to_string(level).unwrap_or_default().to_uppercase();
        let valid = ["NONE", "LOW", "MODERATE", "HIGH"];
        if !valid.contains(&level_str.as_str()) {
            errors.push(format!(
                "❌ {}.level must be one of {:?}, got '{}'",
                prefix, valid, level_str
            ));
        }
    }

    // Integer fields with per-field lower bounds.
    for (field, min_val) in [
        ("anchor_size", 0),
        ("pool_init_size", 0),
        ("pool_min_size", 0),
        ("warmup_iters", 0),
        ("anneal_stop_iter", 0),
        ("pool_update_period", 1),
        ("min_evals_per_example", 1),
        ("k_info_prompts", 0),
    ] {
        if let Some(val) = section.get(field) {
            match parse_int(val) {
                Some(ival) => {
                    if ival < min_val {
                        errors.push(format!(
                            "❌ {}.{} must be >= {}, got {}",
                            prefix, field, min_val, ival
                        ));
                    }
                }
                None => {
                    errors.push(format!(
                        "❌ {}.{} must be an integer, got {}",
                        prefix,
                        field,
                        value_type_name(val)
                    ));
                }
            }
        }
    }

    // Cross-field constraint: pool_init_size >= pool_min_size.
    let pool_init = section.get("pool_init_size").and_then(parse_int);
    let pool_min = section.get("pool_min_size").and_then(parse_int);
    if let (Some(init), Some(min)) = (pool_init, pool_min) {
        if init < min {
            errors.push(format!(
                "❌ {}.pool_init_size ({}) must be >= pool_min_size ({})",
                prefix, init, min
            ));
        }
    }

    // Cross-field constraint: pool_min_size >= anchor_size.
    let anchor_size = section.get("anchor_size").and_then(parse_int);
    if let (Some(min), Some(anchor)) = (pool_min, anchor_size) {
        if min < anchor {
            errors.push(format!(
                "❌ {}.pool_min_size ({}) must be >= anchor_size ({})",
                prefix, min, anchor
            ));
        }
    }

    // Float fields with a lower bound and an optional upper bound.
    for (field, min_val, max_val) in [
        ("info_buffer_factor", 0.0, Some(1.0)),
        ("info_epsilon", 0.0, None),
    ] {
        if let Some(val) = section.get(field) {
            match parse_float(val) {
                Some(fval) => {
                    if fval < min_val {
                        errors.push(format!(
                            "❌ {}.{} must be >= {}, got {}",
                            prefix, field, min_val, fval
                        ));
                    }
                    if let Some(max) = max_val {
                        if fval > max {
                            errors.push(format!(
                                "❌ {}.{} must be <= {}, got {}",
                                prefix, field, max, fval
                            ));
                        }
                    }
                }
                None => {
                    errors.push(format!(
                        "❌ {}.{} must be numeric, got {}",
                        prefix,
                        field,
                        value_type_name(val)
                    ));
                }
            }
        }
    }

    // Enum-like string fields (case-sensitive).
    if let Some(val) = section.get("anchor_selection_method") {
        let method = value_to_string(val).unwrap_or_default();
        if !["random", "clustering"].contains(&method.as_str()) {
            errors.push(format!(
                "❌ {}.anchor_selection_method must be 'random' or 'clustering', got '{}'",
                prefix, method
            ));
        }
    }

    if let Some(val) = section.get("exploration_strategy") {
        let method = value_to_string(val).unwrap_or_default();
        if !["random", "diversity"].contains(&method.as_str()) {
            errors.push(format!(
                "❌ {}.exploration_strategy must be 'random' or 'diversity', got '{}'",
                prefix, method
            ));
        }
    }

    if let Some(val) = section.get("heatup_trigger") {
        let trigger = value_to_string(val).unwrap_or_default();
        if !["after_min_size", "immediate", "every_N_trials_after_min"].contains(&trigger.as_str())
        {
            errors.push(format!(
                "❌ {}.heatup_trigger must be 'after_min_size', 'immediate', or 'every_N_trials_after_min', got '{}'",
                prefix, trigger
            ));
        }
    }

    if let Some(val) = section.get("heatup_schedule") {
        let schedule = value_to_string(val).unwrap_or_default();
        if !["repeat", "once"].contains(&schedule.as_str()) {
            errors.push(format!(
                "❌ {}.heatup_schedule must be 'repeat' or 'once', got '{}'",
                prefix, schedule
            ));
        }
    }

    // heatup_size must be a strictly positive integer.
    if let Some(val) = section.get("heatup_size") {
        match parse_int(val) {
            Some(ival) => {
                if ival <= 0 {
                    errors.push(format!(
                        "❌ {}.heatup_size must be > 0, got {}",
                        prefix, ival
                    ));
                }
            }
            None => {
                errors.push(format!(
                    "❌ {}.heatup_size must be an integer, got {}",
                    prefix,
                    value_type_name(val)
                ));
            }
        }
    }

    // heatup_cooldown_trials must be a non-negative integer.
    if let Some(val) = section.get("heatup_cooldown_trials") {
        match parse_int(val) {
            Some(ival) => {
                if ival < 0 {
                    errors.push(format!(
                        "❌ {}.heatup_cooldown_trials must be >= 0, got {}",
                        prefix, ival
                    ));
                }
            }
            None => {
                errors.push(format!(
                    "❌ {}.heatup_cooldown_trials must be an integer, got {}",
                    prefix,
                    value_type_name(val)
                ));
            }
        }
    }

    // heatup_reserve_pool must be a list of integers when present.
    if let Some(val) = section.get("heatup_reserve_pool") {
        match val.as_array() {
            Some(list) => {
                if list.iter().any(|item| parse_int(item).is_none()) {
                    errors.push(format!(
                        "❌ {}.heatup_reserve_pool must contain only integers",
                        prefix
                    ));
                }
            }
            None => {
                errors.push(format!(
                    "❌ {}.heatup_reserve_pool must be a list, got {}",
                    prefix,
                    value_type_name(val)
                ));
            }
        }
    }
}
1193
1194fn extract_pipeline_modules(initial_prompt: Option<&Value>) -> Vec<String> {
1195 let mut out = Vec::new();
1196 let initial_prompt = match initial_prompt.and_then(|v| v.as_object()) {
1197 Some(map) => map,
1198 None => return out,
1199 };
1200 let metadata = match initial_prompt.get("metadata").and_then(|v| v.as_object()) {
1201 Some(map) => map,
1202 None => return out,
1203 };
1204 let pipeline_modules = match metadata.get("pipeline_modules").and_then(|v| v.as_array()) {
1205 Some(arr) => arr,
1206 None => return out,
1207 };
1208
1209 for entry in pipeline_modules {
1210 if let Some(name) = entry.as_str() {
1211 let trimmed = name.trim();
1212 if !trimmed.is_empty() {
1213 out.push(trimmed.to_string());
1214 }
1215 continue;
1216 }
1217 if let Some(map) = entry.as_object() {
1218 let name = map
1219 .get("name")
1220 .or_else(|| map.get("module_id"))
1221 .or_else(|| map.get("stage_id"))
1222 .and_then(|v| v.as_str())
1223 .unwrap_or("")
1224 .trim()
1225 .to_string();
1226 if !name.is_empty() {
1227 out.push(name);
1228 }
1229 }
1230 }
1231
1232 out
1233}
1234
1235pub fn validate_prompt_learning_config_strict(config: &Value) -> Vec<String> {
1236 let mut errors: Vec<String> = Vec::new();
1237
1238 let config_map = match config.as_object() {
1239 Some(map) => map,
1240 None => {
1241 errors.push("Missing [prompt_learning] section in config. Expected: [prompt_learning] with algorithm, container_url, etc.".to_string());
1242 return errors;
1243 }
1244 };
1245
1246 let pl_section = match config_map.get("prompt_learning") {
1247 Some(Value::Object(map)) => map,
1248 Some(other) => {
1249 errors.push(format!(
1250 "[prompt_learning] must be a table/dict, got {}",
1251 value_type_name(other)
1252 ));
1253 return errors;
1254 }
1255 None => {
1256 errors.push(
1257 "Missing [prompt_learning] section in config. Expected: [prompt_learning] with algorithm, container_url, etc."
1258 .to_string(),
1259 );
1260 return errors;
1261 }
1262 };
1263
1264 let algorithm = pl_section.get("algorithm").and_then(|v| v.as_str());
1265 if algorithm.is_none() {
1266 errors.push(
1267 "Missing required field: prompt_learning.algorithm\n Must be one of: 'gepa', 'mipro'\n Example:\n [prompt_learning]\n algorithm = \"gepa\""
1268 .to_string(),
1269 );
1270 } else if !matches!(algorithm, Some("gepa") | Some("mipro")) {
1271 let algo = algorithm.unwrap_or_default();
1272 errors.push(format!(
1273 "Invalid algorithm: '{}'\n Must be one of: 'gepa', 'mipro'\n Got: '{}'",
1274 algo, algo
1275 ));
1276 }
1277
1278 let container_url = pl_section.get("container_url");
1279 let container_id = pl_section.get("container_id");
1280 if container_url.is_none() && container_id.is_none() {
1281 errors.push(
1282 "Missing required field: prompt_learning.container_url or prompt_learning.container_id\n Example:\n container_url = \"http://127.0.0.1:8102\"\n Or:\n container_id = \"container_abc123\""
1283 .to_string(),
1284 );
1285 } else if let Some(val) = container_url {
1286 if let Some(url) = val.as_str() {
1287 if !url.starts_with("http://") && !url.starts_with("https://") {
1288 errors.push(format!(
1289 "container_url must start with http:// or https://, got: '{}'",
1290 url
1291 ));
1292 }
1293 } else {
1294 errors.push(format!(
1295 "container_url must be a string, got {}",
1296 value_type_name(val)
1297 ));
1298 }
1299 }
1300 if let Some(val) = container_id {
1301 if let Some(id) = val.as_str() {
1302 if id.trim().is_empty() {
1303 errors.push("container_id cannot be empty when provided".to_string());
1304 }
1305 } else {
1306 errors.push(format!(
1307 "container_id must be a string, got {}",
1308 value_type_name(val)
1309 ));
1310 }
1311 }
1312
1313 if let Some(initial_prompt) = pl_section.get("initial_prompt") {
1314 if let Some(map) = initial_prompt.as_object() {
1315 if let Some(messages) = map.get("messages") {
1316 match messages.as_array() {
1317 Some(arr) => {
1318 if arr.is_empty() {
1319 errors.push("prompt_learning.initial_prompt.messages is empty (must have at least one message)".to_string());
1320 }
1321 }
1322 None => {
1323 errors.push(format!(
1324 "prompt_learning.initial_prompt.messages must be an array, got {}",
1325 value_type_name(messages)
1326 ));
1327 }
1328 }
1329 }
1330 } else {
1331 errors.push(format!(
1332 "prompt_learning.initial_prompt must be a table/dict, got {}",
1333 value_type_name(initial_prompt)
1334 ));
1335 }
1336 }
1337
1338 let policy = pl_section.get("policy");
1339 if let Some(Value::Object(policy_map)) = policy {
1340 let mode = policy_map
1341 .get("inference_mode")
1342 .and_then(|v| v.as_str())
1343 .unwrap_or("")
1344 .trim()
1345 .to_lowercase();
1346 if mode.is_empty() {
1347 errors.push(
1348 "Missing required field: prompt_learning.policy.inference_mode (must be 'synth_hosted')"
1349 .to_string(),
1350 );
1351 } else if mode != "synth_hosted" {
1352 errors.push(
1353 "prompt_learning.policy.inference_mode must be 'synth_hosted' (bring_your_own unsupported)"
1354 .to_string(),
1355 );
1356 }
1357
1358 let provider = policy_map
1359 .get("provider")
1360 .and_then(|v| v.as_str())
1361 .unwrap_or("")
1362 .trim()
1363 .to_string();
1364 let model = policy_map
1365 .get("model")
1366 .and_then(|v| v.as_str())
1367 .unwrap_or("")
1368 .trim()
1369 .to_string();
1370 if provider.is_empty() {
1371 errors.push("Missing required field: prompt_learning.policy.provider".to_string());
1372 }
1373 if model.is_empty() {
1374 errors.push("Missing required field: prompt_learning.policy.model".to_string());
1375 } else if !provider.is_empty() {
1376 errors.extend(validate_model_for_provider(
1377 &model,
1378 &provider,
1379 "prompt_learning.policy.model",
1380 true,
1381 ));
1382 }
1383
1384 for forbidden in ["inference_url", "api_base", "base_url"] {
1385 if policy_map.contains_key(forbidden) {
1386 errors.push(format!(
1387 "{} must not be specified in [prompt_learning.policy]. The trainer provides the inference URL in rollout requests. Remove {} from your config file.",
1388 forbidden, forbidden
1389 ));
1390 }
1391 }
1392 } else {
1393 errors.push("Missing [prompt_learning.policy] section or not a table".to_string());
1394 }
1395
1396 if let Some(proxy_models) = pl_section.get("proxy_models") {
1397 match proxy_models.as_object() {
1398 Some(map) => {
1399 for field in ["hi_provider", "hi_model", "lo_provider", "lo_model"] {
1400 if map
1401 .get(field)
1402 .and_then(|v| v.as_str())
1403 .unwrap_or("")
1404 .trim()
1405 .is_empty()
1406 {
1407 errors.push(format!(
1408 "prompt_learning.proxy_models.{} is required",
1409 field
1410 ));
1411 }
1412 }
1413 for (field, min_val) in [
1414 ("n_min_hi", 0.0),
1415 ("r2_thresh", 0.0),
1416 ("r2_stop", 0.0),
1417 ("sigma_max", 0.0),
1418 ("sigma_stop", 0.0),
1419 ("verify_every", 0.0),
1420 ] {
1421 if let Some(val) = map.get(field) {
1422 match parse_float(val) {
1423 Some(fval) => {
1424 if (field == "r2_thresh" || field == "r2_stop")
1425 && !(0.0..=1.0).contains(&fval)
1426 {
1427 errors.push(format!(
1428 "prompt_learning.proxy_models.{} must be between 0.0 and 1.0, got {}",
1429 field, fval
1430 ));
1431 } else if fval < min_val {
1432 errors.push(format!(
1433 "prompt_learning.proxy_models.{} must be >= {}, got {}",
1434 field, min_val, fval
1435 ));
1436 }
1437 }
1438 None => errors.push(format!(
1439 "prompt_learning.proxy_models.{} must be numeric, got {}",
1440 field,
1441 value_type_name(val)
1442 )),
1443 }
1444 }
1445 }
1446
1447 let hi_provider = map
1448 .get("hi_provider")
1449 .and_then(|v| v.as_str())
1450 .unwrap_or("");
1451 let hi_model = map.get("hi_model").and_then(|v| v.as_str()).unwrap_or("");
1452 if !hi_provider.is_empty() && !hi_model.is_empty() {
1453 errors.extend(validate_model_for_provider(
1454 hi_model,
1455 hi_provider,
1456 "prompt_learning.proxy_models.hi_model",
1457 true,
1458 ));
1459 }
1460
1461 let lo_provider = map
1462 .get("lo_provider")
1463 .and_then(|v| v.as_str())
1464 .unwrap_or("");
1465 let lo_model = map.get("lo_model").and_then(|v| v.as_str()).unwrap_or("");
1466 if !lo_provider.is_empty() && !lo_model.is_empty() {
1467 errors.extend(validate_model_for_provider(
1468 lo_model,
1469 lo_provider,
1470 "prompt_learning.proxy_models.lo_model",
1471 true,
1472 ));
1473 }
1474 }
1475 None => errors.push(format!(
1476 "prompt_learning.proxy_models must be a table/dict, got {}",
1477 value_type_name(proxy_models)
1478 )),
1479 }
1480 }
1481
1482 if let Some(verifier) = pl_section.get("verifier") {
1483 match verifier.as_object() {
1484 Some(map) => {
1485 let reward_source = map
1486 .get("reward_source")
1487 .and_then(|v| v.as_str())
1488 .unwrap_or("container")
1489 .trim()
1490 .to_lowercase();
1491 if !reward_source.is_empty()
1492 && !matches!(reward_source.as_str(), "container" | "verifier" | "fused")
1493 {
1494 errors.push(
1495 "prompt_learning.verifier.reward_source must be 'container', 'verifier', or 'fused'"
1496 .to_string(),
1497 );
1498 }
1499 if reward_source == "fused" {
1500 let weight_event = map.get("weight_event");
1501 let weight_outcome = map.get("weight_outcome");
1502 let weight_event_f = weight_event.and_then(parse_float);
1503 let weight_outcome_f = weight_outcome.and_then(parse_float);
1504 if weight_event.is_some() && weight_event_f.is_none() {
1505 errors.push(
1506 "prompt_learning.verifier.weight_event must be numeric".to_string(),
1507 );
1508 }
1509 if weight_outcome.is_some() && weight_outcome_f.is_none() {
1510 errors.push(
1511 "prompt_learning.verifier.weight_outcome must be numeric".to_string(),
1512 );
1513 }
1514 if weight_event_f.unwrap_or(0.0) <= 0.0
1515 && weight_outcome_f.unwrap_or(0.0) <= 0.0
1516 {
1517 errors.push(
1518 "prompt_learning.verifier.reward_source='fused' requires weight_event > 0 or weight_outcome > 0"
1519 .to_string(),
1520 );
1521 }
1522 }
1523 }
1524 None => errors.push(format!(
1525 "prompt_learning.verifier must be a table/dict, got {}",
1526 value_type_name(verifier)
1527 )),
1528 }
1529 }
1530
1531 let pipeline_modules = extract_pipeline_modules(pl_section.get("initial_prompt"));
1532 let has_multi_stage = !pipeline_modules.is_empty();
1533
1534 match algorithm {
1535 Some("gepa") => {
1536 let gepa_config = pl_section.get("gepa");
1537 let gepa_map = match gepa_config.and_then(|v| v.as_object()) {
1538 Some(map) => map,
1539 None => {
1540 errors.push(
1541 "Missing [prompt_learning.gepa] section for GEPA algorithm".to_string(),
1542 );
1543 return errors;
1544 }
1545 };
1546
1547 if has_multi_stage {
1548 let modules_config = gepa_map.get("modules");
1549 match modules_config.and_then(|v| v.as_array()) {
1550 Some(arr) if !arr.is_empty() => {
1551 let mut module_ids = HashSet::new();
1552 for module in arr {
1553 if let Some(map) = module.as_object() {
1554 if let Some(id) = map
1555 .get("module_id")
1556 .or_else(|| map.get("stage_id"))
1557 .and_then(|v| v.as_str())
1558 {
1559 module_ids.insert(id.trim().to_string());
1560 }
1561 }
1562 }
1563 let pipeline_set: HashSet<String> =
1564 pipeline_modules.iter().cloned().collect();
1565 let missing: Vec<String> =
1566 pipeline_set.difference(&module_ids).cloned().collect();
1567 if !missing.is_empty() {
1568 errors.push(format!(
1569 "Pipeline modules {:?} are missing from [prompt_learning.gepa.modules]. Each pipeline module must have a corresponding module config with matching module_id.",
1570 missing
1571 ));
1572 }
1573 }
1574 _ => {
1575 errors.push(format!(
1576 "GEPA multi-stage pipeline detected (found {} modules in prompt_learning.initial_prompt.metadata.pipeline_modules), but [prompt_learning.gepa.modules] is missing or empty. Define module configs for each pipeline stage.",
1577 pipeline_modules.len()
1578 ));
1579 }
1580 }
1581 }
1582
1583 let pos_int = |name: &str, errors: &mut Vec<String>| {
1584 if let Some(val) = gepa_map.get(name) {
1585 match parse_int(val) {
1586 Some(ival) => {
1587 if ival <= 0 {
1588 errors.push(format!("prompt_learning.gepa.{} must be > 0", name));
1589 }
1590 }
1591 None => {
1592 errors.push(format!("prompt_learning.gepa.{} must be an integer", name))
1593 }
1594 }
1595 }
1596 };
1597 let non_neg_int = |name: &str, errors: &mut Vec<String>| {
1598 if let Some(val) = gepa_map.get(name) {
1599 match parse_int(val) {
1600 Some(ival) => {
1601 if ival < 0 {
1602 errors.push(format!("prompt_learning.gepa.{} must be >= 0", name));
1603 }
1604 }
1605 None => {
1606 errors.push(format!("prompt_learning.gepa.{} must be an integer", name))
1607 }
1608 }
1609 }
1610 };
1611 let rate_float = |name: &str, errors: &mut Vec<String>| {
1612 if let Some(val) = gepa_map.get(name) {
1613 match parse_float(val) {
1614 Some(fval) => {
1615 if !(0.0..=1.0).contains(&fval) {
1616 errors.push(format!(
1617 "prompt_learning.gepa.{} must be between 0.0 and 1.0",
1618 name
1619 ));
1620 }
1621 }
1622 None => {
1623 errors.push(format!("prompt_learning.gepa.{} must be numeric", name))
1624 }
1625 }
1626 }
1627 };
1628 let pos_float = |name: &str, errors: &mut Vec<String>| {
1629 if let Some(val) = gepa_map.get(name) {
1630 match parse_float(val) {
1631 Some(fval) => {
1632 if fval <= 0.0 {
1633 errors.push(format!("prompt_learning.gepa.{} must be > 0", name));
1634 }
1635 }
1636 None => {
1637 errors.push(format!("prompt_learning.gepa.{} must be numeric", name))
1638 }
1639 }
1640 }
1641 };
1642 let pos_int_nested = |section: &str, name: &str, errors: &mut Vec<String>| {
1643 if let Some(Value::Object(section_map)) = gepa_map.get(section) {
1644 if let Some(val) = section_map.get(name) {
1645 match parse_int(val) {
1646 Some(ival) => {
1647 if ival <= 0 {
1648 errors.push(format!(
1649 "prompt_learning.gepa.{}.{} must be > 0",
1650 section, name
1651 ));
1652 }
1653 }
1654 None => errors.push(format!(
1655 "prompt_learning.gepa.{}.{} must be an integer",
1656 section, name
1657 )),
1658 }
1659 }
1660 }
1661 };
1662
1663 for fld in [
1664 "initial_population_size",
1665 "num_generations",
1666 "children_per_generation",
1667 "max_concurrent_rollouts",
1668 ] {
1669 pos_int(fld, &mut errors);
1670 }
1671 pos_int_nested("rollout", "budget", &mut errors);
1672 pos_int_nested("rollout", "max_concurrent", &mut errors);
1673 pos_int_nested("rollout", "minibatch_size", &mut errors);
1674 pos_int_nested("population", "initial_size", &mut errors);
1675 pos_int_nested("population", "num_generations", &mut errors);
1676 pos_int_nested("population", "children_per_generation", &mut errors);
1677 rate_float("mutation_rate", &mut errors);
1678 rate_float("crossover_rate", &mut errors);
1679 pos_float("selection_pressure", &mut errors);
1680 if let Some(val) = gepa_map.get("selection_pressure") {
1681 if let Some(sp) = parse_float(val) {
1682 if sp < 1.0 {
1683 errors.push(
1684 "prompt_learning.gepa.selection_pressure must be >= 1.0".to_string(),
1685 );
1686 }
1687 }
1688 }
1689 non_neg_int("patience_generations", &mut errors);
1690 pos_int_nested("archive", "size", &mut errors);
1691 pos_int_nested("archive", "pareto_set_size", &mut errors);
1692 pos_float("pareto_eps", &mut errors);
1693 rate_float("feedback_fraction", &mut errors);
1694
1695 if let Some(Value::Object(mutation)) = gepa_map.get("mutation") {
1696 let mutation_model = mutation.get("llm_model").and_then(|v| v.as_str());
1697 let mutation_provider = mutation
1698 .get("llm_provider")
1699 .and_then(|v| v.as_str())
1700 .unwrap_or("")
1701 .trim()
1702 .to_string();
1703 if let Some(model) = mutation_model {
1704 if mutation_provider.is_empty() {
1705 errors.push(
1706 "Missing required field: prompt_learning.gepa.mutation.llm_provider\n Required when prompt_learning.gepa.mutation.llm_model is set"
1707 .to_string(),
1708 );
1709 } else {
1710 errors.extend(validate_model_for_provider(
1711 model,
1712 &mutation_provider,
1713 "prompt_learning.gepa.mutation.llm_model",
1714 false,
1715 ));
1716 }
1717 }
1718 }
1719
1720 if let Some(val) = gepa_map.get("max_spend_usd") {
1721 match parse_float(val) {
1722 Some(fval) => {
1723 if fval <= 0.0 {
1724 errors.push(
1725 "prompt_learning.gepa.max_spend_usd must be > 0 when provided"
1726 .to_string(),
1727 );
1728 }
1729 }
1730 None => errors
1731 .push("prompt_learning.gepa.max_spend_usd must be numeric".to_string()),
1732 }
1733 }
1734
1735 let rollout_budget = gepa_map
1736 .get("rollout")
1737 .and_then(|v| v.get("budget"))
1738 .or_else(|| gepa_map.get("rollout_budget"));
1739 if let Some(val) = rollout_budget {
1740 match parse_int(val) {
1741 Some(ival) => {
1742 if ival <= 0 {
1743 errors.push("prompt_learning.gepa.rollout.budget (or rollout_budget) must be > 0 when provided".to_string());
1744 }
1745 }
1746 None => errors.push("prompt_learning.gepa.rollout.budget (or rollout_budget) must be an integer".to_string()),
1747 }
1748 }
1749
1750 let minibatch_size = gepa_map
1751 .get("rollout")
1752 .and_then(|v| v.get("minibatch_size"))
1753 .or_else(|| gepa_map.get("minibatch_size"));
1754 if let Some(val) = minibatch_size {
1755 match parse_int(val) {
1756 Some(ival) => {
1757 if ival <= 0 {
1758 errors.push("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be > 0".to_string());
1759 }
1760 }
1761 None => errors.push("prompt_learning.gepa.rollout.minibatch_size (or minibatch_size) must be an integer".to_string()),
1762 }
1763 }
1764
1765 let proposer_backend = gepa_map
1766 .get("proposer_backend")
1767 .and_then(|v| v.as_str())
1768 .unwrap_or("prompt")
1769 .trim()
1770 .to_lowercase();
1771 if !matches!(proposer_backend.as_str(), "prompt" | "rlm" | "agent") {
1772 errors.push(format!(
1773 "Invalid proposer_backend: '{}'\n Must be one of: 'prompt', 'rlm', 'agent'\n Got: '{}'",
1774 proposer_backend, proposer_backend
1775 ));
1776 }
1777 let proposer_prompt_strategy = gepa_map
1781 .get("proposer")
1782 .and_then(|v| v.as_object())
1783 .and_then(|proposer_map| proposer_map.get("prompt"))
1784 .and_then(|v| v.as_object())
1785 .and_then(|prompt_map| prompt_map.get("strategy"))
1786 .and_then(|v| v.as_str());
1787 let proposer_type_raw = proposer_prompt_strategy
1788 .or_else(|| gepa_map.get("proposer_type").and_then(|v| v.as_str()))
1789 .unwrap_or("dspy");
1790 let proposer_type = match proposer_type_raw {
1791 "gepa_ai" => "gepa-ai",
1792 "builtin" => "synth",
1793 other => other,
1794 };
1795 if !matches!(proposer_type, "dspy" | "spec" | "synth" | "gepa-ai") {
1796 errors.push(format!(
1797 "Invalid proposer_type: '{}'\n Must be one of: 'dspy', 'spec', 'synth', 'gepa-ai'\n Got: '{}'",
1798 proposer_type, proposer_type
1799 ));
1800 }
1801
1802 let proposer_effort = gepa_map
1803 .get("proposer_effort")
1804 .and_then(|v| v.as_str())
1805 .unwrap_or("LOW")
1806 .to_uppercase();
1807 let valid_effort = [
1808 "LOW_CONTEXT",
1809 "LOW",
1810 "MEDIUM",
1811 "HIGH",
1812 "GEMINI",
1813 "GEMINI_PRO",
1814 ];
1815 if !valid_effort.contains(&proposer_effort.as_str()) {
1816 errors.push(format!(
1817 "Invalid proposer_effort: '{}'\n Must be one of: {}\n Got: '{}'",
1818 proposer_effort,
1819 valid_effort.join(", "),
1820 proposer_effort
1821 ));
1822 }
1823
1824 let proposer_output_tokens = gepa_map
1825 .get("proposer_output_tokens")
1826 .and_then(|v| v.as_str())
1827 .unwrap_or("FAST")
1828 .to_uppercase();
1829 let valid_output = ["RAPID", "FAST", "SLOW"];
1830 if !valid_output.contains(&proposer_output_tokens.as_str()) {
1831 errors.push(format!(
1832 "Invalid proposer_output_tokens: '{}'\n Must be one of: {}\n Got: '{}'",
1833 proposer_output_tokens,
1834 valid_output.join(", "),
1835 proposer_output_tokens
1836 ));
1837 }
1838
1839 if proposer_type == "spec" {
1840 if gepa_map
1841 .get("spec_path")
1842 .and_then(|v| v.as_str())
1843 .unwrap_or("")
1844 .is_empty()
1845 {
1846 errors.push(
1847 "Missing required field: prompt_learning.gepa.spec_path\n Required when proposer_type='spec'\n Example:\n [prompt_learning.gepa]\n proposer_type = \"spec\"\n spec_path = \"examples/containers/banking77/banking77_spec.json\""
1848 .to_string(),
1849 );
1850 } else {
1851 if let Some(val) = gepa_map.get("spec_max_tokens") {
1852 match parse_int(val) {
1853 Some(ival) => {
1854 if ival <= 0 {
1855 errors.push(
1856 "prompt_learning.gepa.spec_max_tokens must be > 0"
1857 .to_string(),
1858 );
1859 }
1860 }
1861 None => errors.push(
1862 "prompt_learning.gepa.spec_max_tokens must be an integer"
1863 .to_string(),
1864 ),
1865 }
1866 }
1867 if let Some(val) = gepa_map.get("spec_priority_threshold") {
1868 match parse_int(val) {
1869 Some(ival) => {
1870 if ival < 0 {
1871 errors.push(
1872 "prompt_learning.gepa.spec_priority_threshold must be >= 0"
1873 .to_string(),
1874 );
1875 }
1876 }
1877 None => errors.push(
1878 "prompt_learning.gepa.spec_priority_threshold must be an integer"
1879 .to_string(),
1880 ),
1881 }
1882 }
1883 }
1884 }
1885
1886 let archive_size = gepa_map
1887 .get("archive")
1888 .and_then(|v| v.get("size"))
1889 .or_else(|| gepa_map.get("archive_size"));
1890 if let Some(val) = archive_size {
1891 match parse_int(val) {
1892 Some(ival) => {
1893 if ival <= 0 {
1894 errors.push(
1895 "prompt_learning.gepa.archive.size (or archive_size) must be > 0"
1896 .to_string(),
1897 );
1898 }
1899 }
1900 None => errors.push(
1901 "prompt_learning.gepa.archive.size (or archive_size) must be an integer"
1902 .to_string(),
1903 ),
1904 }
1905 }
1906
1907 let eval_config = gepa_map.get("evaluation").and_then(|v| v.as_object());
1908 if let Some(eval_map) = eval_config {
1909 let train_seeds = eval_map
1910 .get("seeds")
1911 .or_else(|| eval_map.get("train_seeds"))
1912 .and_then(|v| v.as_array());
1913 if let Some(seeds_list) = train_seeds {
1914 if !seeds_list.is_empty() {
1915 let total_seeds = seeds_list.len();
1916 let pareto_set_size = gepa_map
1917 .get("archive")
1918 .and_then(|v| v.get("pareto_set_size"))
1919 .or_else(|| gepa_map.get("pareto_set_size"))
1920 .and_then(parse_int)
1921 .unwrap_or(64);
1922 let feedback_fraction = gepa_map
1923 .get("archive")
1924 .and_then(|v| v.get("feedback_fraction"))
1925 .or_else(|| gepa_map.get("feedback_fraction"))
1926 .and_then(parse_float)
1927 .unwrap_or(0.5);
1928 let _ = feedback_fraction;
1929
1930 let feedback_count = total_seeds as i64 - pareto_set_size;
1931 let min_pareto_set_size = 10;
1932 let min_feedback_seeds = 3;
1933
1934 if pareto_set_size > total_seeds as i64 {
1935 errors.push(format!(
1936 "CONFIG ERROR: pareto_set_size={} > total_seeds={}. Increase [prompt_learning.gepa.evaluation].seeds or decrease [prompt_learning.gepa.archive].pareto_set_size. Seeds: {:?}{}",
1937 pareto_set_size,
1938 total_seeds,
1939 seeds_list.iter().take(10).filter_map(value_to_string).collect::<Vec<_>>(),
1940 if seeds_list.len() > 10 { "..." } else { "" }
1941 ));
1942 }
1943 if pareto_set_size < min_pareto_set_size {
1944 errors.push(format!(
1945 "CONFIG ERROR: pareto_set_size={} < MIN_PARETO_SET_SIZE={}. Increase [prompt_learning.gepa.archive].pareto_set_size to at least {}. Below this threshold, accuracy estimates are too noisy for reliable optimization.",
1946 pareto_set_size, min_pareto_set_size, min_pareto_set_size
1947 ));
1948 }
1949 if feedback_count < min_feedback_seeds {
1950 errors.push(format!(
1951 "CONFIG ERROR: feedback_count={} < MIN_FEEDBACK_SEEDS={}. Increase total seeds or decrease pareto_set_size to ensure at least {} feedback seeds. Below this threshold, reflection prompts lack sufficient diversity.",
1952 feedback_count, min_feedback_seeds, min_feedback_seeds
1953 ));
1954 }
1955 }
1956 }
1957 }
1958
1959 let pareto_eps = gepa_map
1960 .get("archive")
1961 .and_then(|v| v.get("pareto_eps"))
1962 .or_else(|| gepa_map.get("pareto_eps"));
1963 if let Some(val) = pareto_eps {
1964 match parse_float(val) {
1965 Some(fval) => {
1966 if fval <= 0.0 {
1967 errors.push("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be > 0".to_string());
1968 } else if fval >= 1.0 {
1969 errors.push("prompt_learning.gepa.archive.pareto_eps (or pareto_eps) should be < 1.0 (typically 1e-6)".to_string());
1970 }
1971 }
1972 None => errors.push(
1973 "prompt_learning.gepa.archive.pareto_eps (or pareto_eps) must be numeric"
1974 .to_string(),
1975 ),
1976 }
1977 }
1978
1979 let feedback_fraction = gepa_map
1980 .get("archive")
1981 .and_then(|v| v.get("feedback_fraction"))
1982 .or_else(|| gepa_map.get("feedback_fraction"));
1983 if let Some(val) = feedback_fraction {
1984 match parse_float(val) {
1985 Some(fval) => {
1986 if !(0.0..=1.0).contains(&fval) {
1987 errors.push("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be between 0.0 and 1.0".to_string());
1988 }
1989 }
1990 None => errors.push("prompt_learning.gepa.archive.feedback_fraction (or feedback_fraction) must be numeric".to_string()),
1991 }
1992 }
1993
1994 let token_config = gepa_map
1995 .get("token")
1996 .or_else(|| gepa_map.get("prompt_budget"));
1997 let token_counting_model = token_config
1998 .and_then(|v| v.get("counting_model"))
1999 .or_else(|| gepa_map.get("token_counting_model"));
2000 if let Some(val) = token_counting_model {
2001 let ok = val.as_str().map(|s| !s.trim().is_empty()).unwrap_or(false);
2002 if !ok {
2003 errors.push("prompt_learning.gepa.token.counting_model (or prompt_budget.counting_model, token_counting_model) must be a non-empty string".to_string());
2004 }
2005 }
2006
2007 if has_multi_stage {
2008 if let Some(Value::Array(modules)) = gepa_map.get("modules") {
2009 for (idx, module_entry) in modules.iter().enumerate() {
2010 if let Some(map) = module_entry.as_object() {
2011 if let Some(val) = map.get("max_instruction_slots") {
2012 match parse_int(val) {
2013 Some(ival) => {
2014 if ival < 1 {
2015 errors.push(format!(
2016 "prompt_learning.gepa.modules[{}].max_instruction_slots must be >= 1",
2017 idx
2018 ));
2019 }
2020 }
2021 None => errors.push(format!(
2022 "prompt_learning.gepa.modules[{}].max_instruction_slots must be an integer",
2023 idx
2024 )),
2025 }
2026 }
2027 if let Some(val) = map.get("max_tokens") {
2028 match parse_int(val) {
2029 Some(ival) => {
2030 if ival <= 0 {
2031 errors.push(format!(
2032 "prompt_learning.gepa.modules[{}].max_tokens must be > 0",
2033 idx
2034 ));
2035 }
2036 }
2037 None => errors.push(format!(
2038 "prompt_learning.gepa.modules[{}].max_tokens must be an integer",
2039 idx
2040 )),
2041 }
2042 }
2043 if let Some(val) = map.get("allowed_tools") {
2044 match val.as_array() {
2045 Some(tools) => {
2046 if tools.is_empty() {
2047 errors.push(format!(
2048 "prompt_learning.gepa.modules[{}].allowed_tools cannot be empty (use null/omit to allow all tools)",
2049 idx
2050 ));
2051 } else {
2052 let mut seen = HashSet::new();
2053 for (tool_idx, tool) in tools.iter().enumerate() {
2054 let name = tool.as_str().unwrap_or("").trim().to_string();
2055 if name.is_empty() {
2056 errors.push(format!(
2057 "prompt_learning.gepa.modules[{}].allowed_tools[{}] cannot be empty",
2058 idx, tool_idx
2059 ));
2060 } else if seen.contains(&name) {
2061 errors.push(format!(
2062 "prompt_learning.gepa.modules[{}].allowed_tools contains duplicate '{}'",
2063 idx, name
2064 ));
2065 } else {
2066 seen.insert(name);
2067 }
2068 }
2069 }
2070 }
2071 None => errors.push(format!(
2072 "prompt_learning.gepa.modules[{}].allowed_tools must be a list",
2073 idx
2074 )),
2075 }
2076 }
2077 let module_policy = map.get("policy");
2078 match module_policy {
2079 None => errors.push(format!(
2080 "❌ gepa.modules[{}]: [policy] table is REQUIRED. Each module must have its own policy configuration with 'model' and 'provider' fields.",
2081 idx
2082 )),
2083 Some(Value::Object(policy_map)) => {
2084 if policy_map
2085 .get("provider")
2086 .and_then(|v| v.as_str())
2087 .unwrap_or("")
2088 .trim()
2089 .is_empty()
2090 {
2091 errors.push(format!(
2092 "❌ gepa.modules[{}]: [policy].provider is required",
2093 idx
2094 ));
2095 }
2096 let module_model = policy_map.get("model").and_then(|v| v.as_str());
2097 let module_provider = policy_map.get("provider").and_then(|v| v.as_str());
2098 if let (Some(model), Some(provider)) = (module_model, module_provider)
2099 {
2100 errors.extend(validate_model_for_provider(
2101 model,
2102 provider,
2103 &format!(
2104 "prompt_learning.gepa.modules[{}].policy.model",
2105 idx
2106 ),
2107 true,
2108 ));
2109 }
2110 for forbidden in ["inference_url", "api_base", "base_url"] {
2111 if policy_map.contains_key(forbidden) {
2112 errors.push(format!(
2113 "❌ gepa.modules[{}]: [policy].{} must not be specified. The trainer provides the inference URL in rollout requests. Remove {} from module policy.",
2114 idx, forbidden, forbidden
2115 ));
2116 }
2117 }
2118 }
2119 Some(other) => errors.push(format!(
2120 "❌ gepa.modules[{}]: [policy] must be a table/dict, got {}",
2121 idx,
2122 value_type_name(other)
2123 )),
2124 }
2125 }
2126 }
2127 }
2128 }
2129 }
2130 Some("mipro") => {
2131 let mipro_config = pl_section.get("mipro");
2132 let mipro_map = match mipro_config.and_then(|v| v.as_object()) {
2133 Some(map) => map,
2134 None => {
2135 errors.push(
2136 "Missing [prompt_learning.mipro] section for MIPRO algorithm".to_string(),
2137 );
2138 return errors;
2139 }
2140 };
2141
2142 let pos_int = |name: &str, errors: &mut Vec<String>| {
2143 if let Some(val) = mipro_map.get(name) {
2144 match parse_int(val) {
2145 Some(ival) => {
2146 if ival <= 0 {
2147 errors.push(format!("prompt_learning.mipro.{} must be > 0", name));
2148 }
2149 }
2150 None => errors
2151 .push(format!("prompt_learning.mipro.{} must be an integer", name)),
2152 }
2153 }
2154 };
2155 for fld in [
2156 "num_iterations",
2157 "num_evaluations_per_iteration",
2158 "batch_size",
2159 "max_concurrent",
2160 ] {
2161 pos_int(fld, &mut errors);
2162 }
2163 for fld in [
2164 "max_demo_set_size",
2165 "max_demo_sets",
2166 "max_instruction_sets",
2167 "full_eval_every_k",
2168 "instructions_per_batch",
2169 "max_instructions",
2170 "duplicate_retry_limit",
2171 ] {
2172 pos_int(fld, &mut errors);
2173 }
2174
2175 if let Some(meta_model) = mipro_map.get("meta_model").and_then(|v| v.as_str()) {
2176 let provider = mipro_map
2177 .get("meta_model_provider")
2178 .and_then(|v| v.as_str())
2179 .unwrap_or("")
2180 .trim()
2181 .to_string();
2182 if provider.is_empty() {
2183 errors.push(
2184 "Missing required field: prompt_learning.mipro.meta_model_provider\n Required when prompt_learning.mipro.meta_model is set"
2185 .to_string(),
2186 );
2187 } else {
2188 errors.extend(validate_model_for_provider(
2189 meta_model,
2190 &provider,
2191 "prompt_learning.mipro.meta_model",
2192 false,
2193 ));
2194 }
2195 }
2196
2197 if let Some(val) = mipro_map.get("meta_model_temperature") {
2198 match parse_float(val) {
2199 Some(fval) => {
2200 if fval < 0.0 {
2201 errors.push(
2202 "prompt_learning.mipro.meta_model_temperature must be >= 0.0"
2203 .to_string(),
2204 );
2205 }
2206 }
2207 None => errors.push(
2208 "prompt_learning.mipro.meta_model_temperature must be numeric".to_string(),
2209 ),
2210 }
2211 }
2212 if let Some(val) = mipro_map.get("meta_model_max_tokens") {
2213 match parse_int(val) {
2214 Some(ival) => {
2215 if ival <= 0 {
2216 errors.push(
2217 "prompt_learning.mipro.meta_model_max_tokens must be > 0"
2218 .to_string(),
2219 );
2220 }
2221 }
2222 None => errors.push(
2223 "prompt_learning.mipro.meta_model_max_tokens must be an integer"
2224 .to_string(),
2225 ),
2226 }
2227 }
2228
2229 if let Some(val) = mipro_map.get("generate_at_iterations") {
2230 match val.as_array() {
2231 Some(arr) => {
2232 for (idx, item) in arr.iter().enumerate() {
2233 match parse_int(item) {
2234 Some(ival) => {
2235 if ival < 0 {
2236 errors.push(format!(
2237 "prompt_learning.mipro.generate_at_iterations[{}] must be >= 0",
2238 idx
2239 ));
2240 }
2241 }
2242 None => errors.push(format!(
2243 "prompt_learning.mipro.generate_at_iterations[{}] must be an integer",
2244 idx
2245 )),
2246 }
2247 }
2248 }
2249 None => errors.push(
2250 "prompt_learning.mipro.generate_at_iterations must be a list".to_string(),
2251 ),
2252 }
2253 }
2254
2255 if mipro_map
2256 .get("spec_path")
2257 .and_then(|v| v.as_str())
2258 .is_some()
2259 {
2260 if let Some(val) = mipro_map.get("spec_max_tokens") {
2261 match parse_int(val) {
2262 Some(ival) => {
2263 if ival <= 0 {
2264 errors.push(
2265 "prompt_learning.mipro.spec_max_tokens must be > 0".to_string(),
2266 );
2267 }
2268 }
2269 None => errors.push(
2270 "prompt_learning.mipro.spec_max_tokens must be an integer".to_string(),
2271 ),
2272 }
2273 }
2274 if let Some(val) = mipro_map.get("spec_priority_threshold") {
2275 match parse_int(val) {
2276 Some(ival) => {
2277 if ival < 0 {
2278 errors.push(
2279 "prompt_learning.mipro.spec_priority_threshold must be >= 0"
2280 .to_string(),
2281 );
2282 }
2283 }
2284 None => errors.push(
2285 "prompt_learning.mipro.spec_priority_threshold must be an integer"
2286 .to_string(),
2287 ),
2288 }
2289 }
2290 }
2291
2292 if let Some(modules) = mipro_map.get("modules").and_then(|v| v.as_array()) {
2293 let max_instruction_sets = mipro_map
2294 .get("max_instruction_sets")
2295 .and_then(parse_int)
2296 .unwrap_or(128);
2297 let max_demo_sets = mipro_map
2298 .get("max_demo_sets")
2299 .and_then(parse_int)
2300 .unwrap_or(128);
2301 let mut seen_module_ids = HashSet::new();
2302 let mut seen_stage_ids = HashSet::new();
2303
2304 for (module_idx, module_entry) in modules.iter().enumerate() {
2305 let module_map = match module_entry.as_object() {
2306 Some(map) => map,
2307 None => {
2308 errors.push(format!(
2309 "prompt_learning.mipro.modules[{}] must be a table/dict",
2310 module_idx
2311 ));
2312 continue;
2313 }
2314 };
2315
2316 let module_id = module_map
2317 .get("module_id")
2318 .or_else(|| module_map.get("id"))
2319 .and_then(|v| v.as_str())
2320 .unwrap_or(&format!("module_{}", module_idx))
2321 .to_string();
2322 if !seen_module_ids.insert(module_id.clone()) {
2323 errors.push(format!(
2324 "Duplicate module_id '{}' in prompt_learning.mipro.modules",
2325 module_id
2326 ));
2327 }
2328
2329 let stages = module_map.get("stages");
2330 if let Some(stages_val) = stages {
2331 match stages_val.as_array() {
2332 Some(stage_list) => {
2333 for (stage_idx, stage_entry) in stage_list.iter().enumerate() {
2334 if let Some(stage_map) = stage_entry.as_object() {
2335 let stage_id = stage_map
2336 .get("stage_id")
2337 .or_else(|| stage_map.get("module_stage_id"))
2338 .and_then(|v| v.as_str())
2339 .unwrap_or(&format!("stage_{}", stage_idx))
2340 .to_string();
2341 if !seen_stage_ids.insert(stage_id.clone()) {
2342 errors.push(format!(
2343 "Duplicate stage_id '{}' across modules",
2344 stage_id
2345 ));
2346 }
2347 if let Some(val) = stage_map.get("max_instruction_slots") {
2348 match parse_int(val) {
2349 Some(ival) => {
2350 if ival < 1 {
2351 errors.push(format!(
2352 "prompt_learning.mipro.modules[{}].stages[{}].max_instruction_slots must be >= 1",
2353 module_idx, stage_idx
2354 ));
2355 } else if ival > max_instruction_sets {
2356 errors.push(format!(
2357 "prompt_learning.mipro.modules[{}].stages[{}].max_instruction_slots ({}) exceeds max_instruction_sets ({})",
2358 module_idx, stage_idx, ival, max_instruction_sets
2359 ));
2360 }
2361 }
2362 None => errors.push(format!(
2363 "prompt_learning.mipro.modules[{}].stages[{}].max_instruction_slots must be an integer",
2364 module_idx, stage_idx
2365 )),
2366 }
2367 }
2368 if let Some(val) = stage_map.get("max_demo_slots") {
2369 match parse_int(val) {
2370 Some(ival) => {
2371 if ival < 0 {
2372 errors.push(format!(
2373 "prompt_learning.mipro.modules[{}].stages[{}].max_demo_slots must be >= 0",
2374 module_idx, stage_idx
2375 ));
2376 } else if ival > max_demo_sets {
2377 errors.push(format!(
2378 "prompt_learning.mipro.modules[{}].stages[{}].max_demo_slots ({}) exceeds max_demo_sets ({})",
2379 module_idx, stage_idx, ival, max_demo_sets
2380 ));
2381 }
2382 }
2383 None => errors.push(format!(
2384 "prompt_learning.mipro.modules[{}].stages[{}].max_demo_slots must be an integer",
2385 module_idx, stage_idx
2386 )),
2387 }
2388 }
2389 }
2390 }
2391 }
2392 None => errors.push(format!(
2393 "prompt_learning.mipro.modules[{}].stages must be a list",
2394 module_idx
2395 )),
2396 }
2397 }
2398
2399 if let Some(edges_val) = module_map.get("edges") {
2400 match edges_val.as_array() {
2401 Some(edges) => {
2402 let mut stage_ids_in_module = HashSet::new();
2403 if let Some(Value::Array(stage_list)) = stages {
2404 for stage_entry in stage_list {
2405 if let Some(stage_map) = stage_entry.as_object() {
2406 if let Some(id) = stage_map
2407 .get("stage_id")
2408 .or_else(|| stage_map.get("module_stage_id"))
2409 .and_then(|v| v.as_str())
2410 {
2411 stage_ids_in_module.insert(id.to_string());
2412 }
2413 }
2414 }
2415 }
2416 for (edge_idx, edge) in edges.iter().enumerate() {
2417 let (source, target) = if let Some(arr) = edge.as_array() {
2418 if arr.len() == 2 {
2419 (arr[0].clone(), arr[1].clone())
2420 } else {
2421 errors.push(format!(
2422 "prompt_learning.mipro.modules[{}].edges[{}] must be a pair or mapping",
2423 module_idx, edge_idx
2424 ));
2425 continue;
2426 }
2427 } else if let Some(map) = edge.as_object() {
2428 let source = map
2429 .get("from")
2430 .or_else(|| map.get("source"))
2431 .cloned()
2432 .unwrap_or(Value::Null);
2433 let target = map
2434 .get("to")
2435 .or_else(|| map.get("target"))
2436 .cloned()
2437 .unwrap_or(Value::Null);
2438 (source, target)
2439 } else {
2440 errors.push(format!(
2441 "prompt_learning.mipro.modules[{}].edges[{}] must be a pair or mapping",
2442 module_idx, edge_idx
2443 ));
2444 continue;
2445 };
2446
2447 let source_str = value_to_string(&source)
2448 .unwrap_or_default()
2449 .trim()
2450 .to_string();
2451 let target_str = value_to_string(&target)
2452 .unwrap_or_default()
2453 .trim()
2454 .to_string();
2455 if !source_str.is_empty()
2456 && !stage_ids_in_module.contains(&source_str)
2457 {
2458 errors.push(format!(
2459 "prompt_learning.mipro.modules[{}].edges[{}] references unknown source stage '{}'",
2460 module_idx, edge_idx, source_str
2461 ));
2462 }
2463 if !target_str.is_empty()
2464 && !stage_ids_in_module.contains(&target_str)
2465 {
2466 errors.push(format!(
2467 "prompt_learning.mipro.modules[{}].edges[{}] references unknown target stage '{}'",
2468 module_idx, edge_idx, target_str
2469 ));
2470 }
2471 }
2472 }
2473 None => errors.push(format!(
2474 "prompt_learning.mipro.modules[{}].edges must be a list",
2475 module_idx
2476 )),
2477 }
2478 }
2479 }
2480 }
2481
2482 let bootstrap_seeds = pl_section
2483 .get("bootstrap_train_seeds")
2484 .or_else(|| mipro_map.get("bootstrap_train_seeds"));
2485 let online_pool = pl_section
2486 .get("online_pool")
2487 .or_else(|| mipro_map.get("online_pool"));
2488
2489 match bootstrap_seeds {
2490 None => errors.push(
2491 "Missing required field: prompt_learning.bootstrap_train_seeds\n MIPRO requires bootstrap seeds for the few-shot bootstrapping phase.\n Example:\n [prompt_learning]\n bootstrap_train_seeds = [0, 1, 2, 3, 4]"
2492 .to_string(),
2493 ),
2494 Some(Value::Array(arr)) => {
2495 if arr.is_empty() {
2496 errors.push("prompt_learning.bootstrap_train_seeds cannot be empty".to_string());
2497 }
2498 }
2499 Some(_) => errors.push("prompt_learning.bootstrap_train_seeds must be an array".to_string()),
2500 }
2501
2502 match online_pool {
2503 None => errors.push(
2504 "Missing required field: prompt_learning.online_pool\n MIPRO requires online_pool seeds for mini-batch evaluation during optimization.\n Example:\n [prompt_learning]\n online_pool = [5, 6, 7, 8, 9]"
2505 .to_string(),
2506 ),
2507 Some(Value::Array(arr)) => {
2508 if arr.is_empty() {
2509 errors.push("prompt_learning.online_pool cannot be empty".to_string());
2510 }
2511 }
2512 Some(_) => errors.push("prompt_learning.online_pool must be an array".to_string()),
2513 }
2514
2515 if let Some(threshold) = mipro_map.get("few_shot_score_threshold") {
2516 match parse_float(threshold) {
2517 Some(fval) => {
2518 if !(0.0..=1.0).contains(&fval) {
2519 errors.push("prompt_learning.mipro.few_shot_score_threshold must be between 0.0 and 1.0".to_string());
2520 }
2521 }
2522 None => errors.push(
2523 "prompt_learning.mipro.few_shot_score_threshold must be a number"
2524 .to_string(),
2525 ),
2526 }
2527 }
2528
2529 if let Some(val) = mipro_map.get("min_bootstrap_demos") {
2530 match parse_int(val) {
2531 Some(ival) => {
2532 if ival < 0 {
2533 errors.push(
2534 "prompt_learning.mipro.min_bootstrap_demos must be >= 0"
2535 .to_string(),
2536 );
2537 } else if let Some(Value::Array(arr)) = bootstrap_seeds {
2538 if ival as usize > arr.len() {
2539 errors.push(format!(
2540 "prompt_learning.mipro.min_bootstrap_demos ({}) exceeds bootstrap_train_seeds count ({}). You can never have more demos than bootstrap seeds.",
2541 ival,
2542 arr.len()
2543 ));
2544 }
2545 }
2546 }
2547 None => errors.push(
2548 "prompt_learning.mipro.min_bootstrap_demos must be an integer".to_string(),
2549 ),
2550 }
2551 }
2552
2553 if let Some(reference_pool) = mipro_map
2554 .get("reference_pool")
2555 .or_else(|| pl_section.get("reference_pool"))
2556 {
2557 match reference_pool.as_array() {
2558 Some(ref_list) => {
2559 let mut all_train_test = HashSet::new();
2560 if let Some(Value::Array(arr)) = bootstrap_seeds {
2561 for item in arr {
2562 if let Some(val) = value_to_string(item) {
2563 all_train_test.insert(val);
2564 }
2565 }
2566 }
2567 if let Some(Value::Array(arr)) = online_pool {
2568 for item in arr {
2569 if let Some(val) = value_to_string(item) {
2570 all_train_test.insert(val);
2571 }
2572 }
2573 }
2574 let test_pool = mipro_map
2575 .get("test_pool")
2576 .or_else(|| pl_section.get("test_pool"));
2577 if let Some(Value::Array(arr)) = test_pool {
2578 for item in arr {
2579 if let Some(val) = value_to_string(item) {
2580 all_train_test.insert(val);
2581 }
2582 }
2583 }
2584 let mut overlapping = Vec::new();
2585 for item in ref_list {
2586 if let Some(val) = value_to_string(item) {
2587 if all_train_test.contains(&val) {
2588 overlapping.push(val);
2589 }
2590 }
2591 }
2592 if !overlapping.is_empty() {
2593 errors.push(format!(
2594 "reference_pool seeds must not overlap with bootstrap/online/test pools. Found overlapping seeds: {:?}",
2595 overlapping
2596 ));
2597 }
2598 }
2599 None => errors.push(
2600 "prompt_learning.mipro.reference_pool (or prompt_learning.reference_pool) must be an array"
2601 .to_string(),
2602 ),
2603 }
2604 }
2605
2606 if let Some(text_dreamer) = mipro_map.get("text_dreamer") {
2607 if let Some(td_map) = text_dreamer.as_object() {
2608 let enabled = td_map
2609 .get("enabled")
2610 .and_then(|v| v.as_bool())
2611 .unwrap_or(false);
2612
2613 if let Some(mode) = td_map.get("mode") {
2614 if mode.as_str().is_none() {
2615 errors.push(
2616 "prompt_learning.mipro.text_dreamer.mode must be a string"
2617 .to_string(),
2618 );
2619 }
2620 }
2621
2622 if let Some(world_model_mode) =
2623 td_map.get("world_model_mode").and_then(|v| v.as_str())
2624 {
2625 let normalized = world_model_mode
2626 .trim()
2627 .to_ascii_lowercase()
2628 .replace('-', "_");
2629 if normalized == "wm_only" {
2630 errors.push(
2631 "prompt_learning.mipro.text_dreamer.world_model_mode cannot be 'wm_only'; use 'ontology_only' or 'ontology_plus_wm'"
2632 .to_string(),
2633 );
2634 } else if !normalized.is_empty()
2635 && normalized != "ontology_only"
2636 && normalized != "ontology_plus_wm"
2637 {
2638 errors.push(
2639 "prompt_learning.mipro.text_dreamer.world_model_mode must be 'ontology_only' or 'ontology_plus_wm'"
2640 .to_string(),
2641 );
2642 }
2643 } else if td_map.contains_key("world_model_mode") {
2644 errors.push(
2645 "prompt_learning.mipro.text_dreamer.world_model_mode must be a string"
2646 .to_string(),
2647 );
2648 }
2649
2650 if let Some(on_overlap) = td_map.get("on_overlap").and_then(|v| v.as_str()) {
2651 let normalized = on_overlap.trim().to_ascii_lowercase().replace('-', "_");
2652 if !normalized.is_empty()
2653 && normalized != "queue"
2654 && normalized != "skip_latest"
2655 {
2656 errors.push(
2657 "prompt_learning.mipro.text_dreamer.on_overlap must be 'queue' or 'skip_latest'"
2658 .to_string(),
2659 );
2660 }
2661 } else if td_map.contains_key("on_overlap") {
2662 errors.push(
2663 "prompt_learning.mipro.text_dreamer.on_overlap must be a string"
2664 .to_string(),
2665 );
2666 }
2667
2668 if let Some(runtime_backend) =
2669 td_map.get("runtime_backend").and_then(|v| v.as_str())
2670 {
2671 let normalized = runtime_backend.trim().to_ascii_lowercase();
2672 if !normalized.is_empty()
2673 && !matches!(
2674 normalized.as_str(),
2675 "rhodes" | "harbor" | "openenv" | "archipelago"
2676 )
2677 {
2678 errors.push(
2679 "prompt_learning.mipro.text_dreamer.runtime_backend must be one of: rhodes, harbor, openenv, archipelago"
2680 .to_string(),
2681 );
2682 }
2683 } else if td_map.contains_key("runtime_backend") {
2684 errors.push(
2685 "prompt_learning.mipro.text_dreamer.runtime_backend must be a string"
2686 .to_string(),
2687 );
2688 }
2689
2690 for (field, min_value) in [
2691 ("max_pending_jobs_per_system", 1),
2692 ("max_replay_rollouts", 1),
2693 ("observation_trigger_every_rollouts", 1),
2694 ("observation_log_window", 1),
2695 ("shadow_max_turns", 1),
2696 ("shadow_timeout_seconds", 1),
2697 ] {
2698 if let Some(value) = td_map.get(field) {
2699 match parse_int(value) {
2700 Some(parsed) => {
2701 if parsed < min_value {
2702 errors.push(format!(
2703 "prompt_learning.mipro.text_dreamer.{} must be >= {}",
2704 field, min_value
2705 ));
2706 }
2707 }
2708 None => errors.push(format!(
2709 "prompt_learning.mipro.text_dreamer.{} must be an integer",
2710 field
2711 )),
2712 }
2713 }
2714 }
2715
2716 for field in ["shadow_rollouts", "shadow_seed_base"] {
2717 if let Some(value) = td_map.get(field) {
2718 match parse_int(value) {
2719 Some(parsed) => {
2720 if parsed < 0 {
2721 errors.push(format!(
2722 "prompt_learning.mipro.text_dreamer.{} must be >= 0",
2723 field
2724 ));
2725 }
2726 }
2727 None => errors.push(format!(
2728 "prompt_learning.mipro.text_dreamer.{} must be an integer",
2729 field
2730 )),
2731 }
2732 }
2733 }
2734
2735 for field in ["container_url", "container_api_key"] {
2736 if let Some(value) = td_map.get(field) {
2737 if value.as_str().is_none() {
2738 errors.push(format!(
2739 "prompt_learning.mipro.text_dreamer.{} must be a string",
2740 field
2741 ));
2742 }
2743 }
2744 }
2745
2746 if enabled
2747 && !ontology_connector_enabled_for_text_dreamer(pl_section, mipro_map)
2748 {
2749 errors.push(
2750 "prompt_learning.mipro.text_dreamer.enabled=true requires ontology connector enabled (set prompt_learning.mipro.ontology.reads=true and/or writes=true)"
2751 .to_string(),
2752 );
2753 }
2754 } else {
2755 errors.push(
2756 "prompt_learning.mipro.text_dreamer must be a table/dict".to_string(),
2757 );
2758 }
2759 }
2760 }
2761 _ => {}
2762 }
2763
2764 if let Some(Value::Object(gepa)) = pl_section.get("gepa") {
2765 if let Some(adaptive_pool) = gepa.get("adaptive_pool") {
2766 validate_adaptive_pool_config(adaptive_pool, "gepa.adaptive_pool", &mut errors);
2767 }
2768 }
2769 if let Some(Value::Object(mipro)) = pl_section.get("mipro") {
2770 if let Some(adaptive_pool) = mipro.get("adaptive_pool") {
2771 validate_adaptive_pool_config(adaptive_pool, "mipro.adaptive_pool", &mut errors);
2772 }
2773 }
2774
2775 errors
2776}
2777
#[cfg(test)]
mod tests {
    use super::validate_prompt_learning_config;
    use serde_json::json;

    // Each test builds a config literal, runs the validator, and checks the
    // presence/absence of specific warning or error substrings.

    #[test]
    fn accepts_mipro_ontology_batch_proposer_stage_fields() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3],
                    "ontology": {
                        "enabled": true,
                        "reads": true,
                        "writes": true,
                        "batch_proposer": {
                            "enabled": true,
                            "trigger_mode": "stage_intervals",
                            "stage_rollout_intervals": [10, 50, 200],
                            "repeat_after_last_stage_rollouts": 25,
                            "batch_size": 40,
                            "model": "gpt-5.2",
                            "provider": "openai",
                            "temperature": 0.7,
                            "max_tokens": 4096
                        }
                    }
                }
            }
        });

        let outcome = validate_prompt_learning_config(&cfg, None);

        // A fully-known ontology section must not trigger unknown-field warnings.
        let unknown_ontology_field = outcome
            .warnings
            .iter()
            .any(|w| w.contains("Unknown field 'ontology' in [prompt_learning.mipro]"));
        assert!(
            !unknown_ontology_field,
            "unexpected ontology unknown-field warning(s): {:?}",
            outcome.warnings
        );

        let any_ontology_warning = outcome
            .warnings
            .iter()
            .any(|w| w.contains("prompt_learning.mipro.ontology"));
        assert!(
            !any_ontology_warning,
            "unexpected ontology warning(s): {:?}",
            outcome.warnings
        );

        let any_batch_proposer_warning = outcome
            .warnings
            .iter()
            .any(|w| w.contains("prompt_learning.mipro.ontology.batch_proposer"));
        assert!(
            !any_batch_proposer_warning,
            "unexpected batch proposer warning(s): {:?}",
            outcome.warnings
        );
    }

    #[test]
    fn accepts_top_level_prompt_learning_ontology_section() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "ontology": {
                    "reads": true,
                    "batch_proposer": {
                        "enabled": true
                    }
                },
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3]
                }
            }
        });

        let outcome = validate_prompt_learning_config(&cfg, None);

        // `ontology` is valid directly under [prompt_learning] as well.
        let flagged = outcome
            .warnings
            .iter()
            .any(|w| w.contains("Unknown field 'ontology' in [prompt_learning]"));
        assert!(
            !flagged,
            "top-level ontology should be accepted: {:?}",
            outcome.warnings
        );
    }

    #[test]
    fn warns_on_unknown_ontology_keys() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3],
                    "ontology": {
                        "enabled": true,
                        "mystery": "value",
                        "batch_proposer": {
                            "enabled": true,
                            "mystery_setting": 123
                        }
                    }
                }
            }
        });

        let outcome = validate_prompt_learning_config(&cfg, None);

        // Unknown keys at both nesting levels should each produce a warning.
        let ontology_key_flagged = outcome
            .warnings
            .iter()
            .any(|w| w.contains("Unknown field 'mystery' in [prompt_learning.mipro.ontology]"));
        assert!(
            ontology_key_flagged,
            "missing ontology unknown-key warning: {:?}",
            outcome.warnings
        );

        let batch_proposer_key_flagged = outcome.warnings.iter().any(|w| {
            w.contains(
                "Unknown field 'mystery_setting' in [prompt_learning.mipro.ontology.batch_proposer]"
            )
        });
        assert!(
            batch_proposer_key_flagged,
            "missing batch proposer unknown-key warning: {:?}",
            outcome.warnings
        );
    }

    #[test]
    fn accepts_mipro_text_dreamer_fields() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3],
                    "ontology": {
                        "reads": true
                    },
                    "text_dreamer": {
                        "enabled": true,
                        "mode": "observation_only",
                        "world_model_mode": "ontology_plus_wm",
                        "on_overlap": "queue",
                        "runtime_backend": "rhodes",
                        "max_pending_jobs_per_system": 2,
                        "max_replay_rollouts": 4,
                        "observation_trigger_every_rollouts": 1,
                        "observation_log_window": 25,
                        "shadow_rollouts": 2,
                        "shadow_max_turns": 4,
                        "shadow_timeout_seconds": 45,
                        "shadow_seed_base": 1000
                    }
                }
            }
        });

        let outcome = validate_prompt_learning_config(&cfg, None);

        // A fully-populated, valid text_dreamer section must be warning-free.
        let td_warning = outcome
            .warnings
            .iter()
            .any(|w| w.contains("prompt_learning.mipro.text_dreamer"));
        assert!(
            !td_warning,
            "unexpected text_dreamer warning(s): {:?}",
            outcome.warnings
        );

        let unknown_ontology = outcome
            .warnings
            .iter()
            .any(|w| w.contains("Unknown field 'ontology' in [prompt_learning.mipro]"));
        assert!(
            !unknown_ontology,
            "unexpected ontology unknown-field warning(s): {:?}",
            outcome.warnings
        );

        let unknown_text_dreamer = outcome
            .warnings
            .iter()
            .any(|w| w.contains("Unknown field 'text_dreamer' in [prompt_learning.mipro]"));
        assert!(
            !unknown_text_dreamer,
            "unexpected text_dreamer unknown-field warning(s): {:?}",
            outcome.warnings
        );
    }

    #[test]
    fn warns_on_unknown_text_dreamer_keys() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3],
                    "text_dreamer": {
                        "enabled": true,
                        "mystery_setting": "value"
                    }
                }
            }
        });

        let outcome = validate_prompt_learning_config(&cfg, None);
        let flagged = outcome.warnings.iter().any(|w| {
            w.contains(
                "Unknown field 'mystery_setting' in [prompt_learning.mipro.text_dreamer]",
            )
        });
        assert!(
            flagged,
            "missing text_dreamer unknown-key warning: {:?}",
            outcome.warnings
        );
    }

    #[test]
    fn strict_rejects_text_dreamer_wm_only_and_missing_ontology_connector() {
        let cfg = json!({
            "prompt_learning": {
                "algorithm": "mipro",
                "container_url": "http://localhost:8102",
                "mipro": {
                    "bootstrap_train_seeds": [0, 1],
                    "online_pool": [2, 3],
                    "text_dreamer": {
                        "enabled": true,
                        "world_model_mode": "wm_only"
                    }
                }
            }
        });

        // Strict validation: 'wm_only' is rejected outright, and an enabled
        // text_dreamer without an ontology connector is also an error.
        let errors = super::validate_prompt_learning_config_strict(&cfg);

        let wm_only_rejected = errors
            .iter()
            .any(|e| e.contains("world_model_mode cannot be 'wm_only'"));
        assert!(
            wm_only_rejected,
            "missing wm_only strict error: {:?}",
            errors
        );

        let connector_required = errors
            .iter()
            .any(|e| e.contains("requires ontology connector enabled"));
        assert!(
            connector_required,
            "missing ontology connector strict error: {:?}",
            errors
        );
    }
}