use crate::types::{Content, Memory, MemoryMetadata, State, UUID};
use crate::{ZoeyError, Result};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use tracing::{debug, error, info, instrument, warn};

/// Supported export formats for collected training data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TrainingFormat {
    /// One JSON object per line (the default interchange format).
    Jsonl,
    /// Stanford Alpaca instruction/input/output records.
    Alpaca,
    /// ShareGPT-style conversation lists.
    ShareGpt,
    /// OpenAI chat fine-tuning message lists.
    OpenAi,
    /// Caller-defined format (currently exported as JSONL).
    Custom,
}

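/// Configuration for training-data collection.
///
/// A minimal construction sketch (not compiled as a doctest; the field
/// values are illustrative, not recommendations):
///
/// ```ignore
/// let config = TrainingConfig {
///     min_quality_score: 0.7,                   // keep only stronger samples
///     output_dir: PathBuf::from("/tmp/training"),
///     default_format: TrainingFormat::OpenAi,
///     ..Default::default()                      // everything else as shipped
/// };
/// ```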
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Master switch for training-data collection.
    pub enabled: bool,
    /// Samples scoring below this threshold are rejected.
    pub min_quality_score: f32,
    /// Maximum number of samples kept in memory; the oldest are evicted first.
    pub max_samples: usize,
    /// Auto-save interval in seconds (0 disables auto-save).
    pub auto_save_interval: u64,
    /// Directory that exported dataset files are written to.
    pub output_dir: PathBuf,
    /// Format used by auto-save.
    pub default_format: TrainingFormat,
    /// Whether to keep the agent's internal thoughts in samples.
    pub include_thoughts: bool,
    /// Whether to retain low-quality samples as negative examples.
    pub include_negative_examples: bool,
    /// Desired ratio of negative examples in the dataset.
    pub negative_example_ratio: f32,
    /// Whether human-feedback (RLHF) collection is enabled.
    pub enable_rlhf: bool,
    /// Whether to auto-assign categories and tags to new samples.
    pub auto_label: bool,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            min_quality_score: 0.6,
            max_samples: 10000,
            auto_save_interval: 300,
            output_dir: PathBuf::from("./training_data"),
            default_format: TrainingFormat::Jsonl,
            include_thoughts: true,
            include_negative_examples: true,
            negative_example_ratio: 0.1,
            enable_rlhf: true,
            auto_label: true,
        }
    }
}

/// A single prompt/response pair collected for training.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TrainingSample {
    pub id: UUID,
    pub prompt: String,
    pub response: String,
    /// The agent's internal reasoning, if captured.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thought: Option<String>,
    /// State values that were active when the turn was recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<HashMap<String, String>>,
    /// Heuristic quality score in `[0.0, 1.0]`.
    pub quality_score: f32,
    /// Human feedback score in `[-1.0, 1.0]`, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feedback_score: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,
    #[serde(default)]
    pub tags: Vec<String>,
    /// Unix timestamp in milliseconds.
    pub timestamp: i64,
    /// IDs of the source messages, if recorded from a conversation turn.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message_ids: Option<MessagePair>,
    /// Free-form metadata (e.g. feedback text, evaluator reviews).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, serde_json::Value>>,
}

/// Links a training sample back to its source messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MessagePair {
    pub user_message_id: UUID,
    pub agent_message_id: UUID,
}

/// A record in the Stanford Alpaca instruction-tuning format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlpacaSample {
    pub instruction: String,
    pub input: String,
    pub output: String,
}

/// A conversation in the ShareGPT format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShareGptConversation {
    pub conversations: Vec<ShareGptMessage>,
}

/// A single ShareGPT message ("human" or "gpt").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShareGptMessage {
    pub from: String,
    pub value: String,
}

/// A record in the OpenAI chat fine-tuning format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiFineTuning {
    pub messages: Vec<OpenAiMessage>,
}

/// A single chat message with a role ("user" or "assistant").
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAiMessage {
    pub role: String,
    pub content: String,
}

/// Aggregate statistics over the collected dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DatasetStatistics {
    pub total_samples: usize,
    /// Samples with a quality score above 0.8.
    pub high_quality_count: usize,
    /// Samples with a quality score in `[0.6, 0.8]`.
    pub medium_quality_count: usize,
    /// Samples with a quality score below 0.6.
    pub low_quality_count: usize,
    pub with_thoughts_count: usize,
    pub with_feedback_count: usize,
    pub avg_quality_score: f32,
    /// Average feedback score; falls back to evaluator review scores when
    /// no explicit feedback exists.
    pub avg_feedback_score: f32,
    pub categories: HashMap<String, usize>,
    pub tags: HashMap<String, usize>,
}

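/// Collects prompt/response pairs at runtime and exports them as training
/// datasets.
///
/// A minimal end-to-end sketch (not compiled as a doctest; assumes a `tokio`
/// runtime):
///
/// ```ignore
/// let collector = TrainingCollector::new(TrainingConfig::default());
///
/// // Record one interaction; returns Err if the score is below threshold.
/// let id = collector
///     .record_interaction("What is Rust?", "A systems language.", None, 0.8)
///     .await?;
///
/// // Export everything collected so far as JSONL.
/// let jsonl = collector.export_jsonl().await?;
/// ```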
pub struct TrainingCollector {
    config: TrainingConfig,
    /// In-memory sample buffer, bounded by `config.max_samples`.
    samples: Arc<RwLock<Vec<TrainingSample>>>,
    /// Time of the last auto-save.
    last_save: Arc<RwLock<std::time::Instant>>,
}

impl TrainingCollector {
    /// Creates a collector with the given configuration.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            samples: Arc::new(RwLock::new(Vec::new())),
            last_save: Arc::new(RwLock::new(std::time::Instant::now())),
        }
    }

    /// Returns whether RLHF feedback collection is enabled.
    pub fn is_rlhf_enabled(&self) -> bool {
        self.config.enable_rlhf
    }
    #[instrument(skip(self, prompt, response, thought), level = "debug")]
    pub async fn record_interaction(
        &self,
        prompt: impl Into<String>,
        response: impl Into<String>,
        thought: Option<String>,
        quality_score: f32,
    ) -> Result<UUID> {
        if !self.config.enabled {
            return Err(ZoeyError::Config(
                "Training collection is disabled".to_string(),
            ));
        }

        let prompt = prompt.into();
        let response = response.into();

        if quality_score < self.config.min_quality_score {
            debug!(
                "Skipping interaction due to low quality score: {}",
                quality_score
            );
            return Err(ZoeyError::Validation(
                "Quality score below threshold".to_string(),
            ));
        }

        let sample = TrainingSample {
            id: uuid::Uuid::new_v4(),
            prompt: prompt.clone(),
            response: response.clone(),
            thought,
            context: None,
            quality_score,
            feedback_score: None,
            category: if self.config.auto_label {
                Some(auto_categorize(&prompt, &response))
            } else {
                None
            },
            tags: if self.config.auto_label {
                auto_generate_tags(&prompt, &response)
            } else {
                vec![]
            },
            timestamp: Utc::now().timestamp_millis(),
            message_ids: None,
            metadata: None,
        };

        let sample_id = sample.id;

        {
            let mut samples = self.samples.write().unwrap();
            samples.push(sample);

            if samples.len() > self.config.max_samples {
                warn!("Training samples exceeded limit, removing oldest");
                samples.remove(0);
            }
        }

        info!(
            "Recorded training sample: {} (quality: {})",
            sample_id, quality_score
        );

        self.check_auto_save().await?;

        Ok(sample_id)
    }

    #[instrument(
        skip(self, runtime_any, thought_text, original_message),
        level = "info"
    )]
    pub async fn store_thought(
        &self,
        runtime_any: Arc<dyn std::any::Any + Send + Sync>,
        thought_text: &str,
        original_message: &Memory,
        quality_score: f32,
    ) -> Result<UUID> {
        info!(
            "💭 Storing agent thought ({} chars, quality: {})",
            thought_text.len(),
            quality_score
        );

        let runtime_ref = crate::runtime_ref::downcast_runtime_ref(&runtime_any)
            .ok_or_else(|| ZoeyError::Runtime("Invalid runtime reference".to_string()))?;

        let runtime_arc = runtime_ref
            .try_upgrade()
            .ok_or_else(|| ZoeyError::Runtime("Runtime has been dropped".to_string()))?;

        // Copy everything we need out of the runtime, then release the read
        // guard so it is not held across the `.await` points below.
        let (agent_id, entity_name, adapter_opt) = {
            let agent_runtime = runtime_arc.read().unwrap();
            (
                agent_runtime.agent_id,
                agent_runtime.character.name.clone(),
                agent_runtime.adapter.read().unwrap().clone(),
            )
        };

        let thought_memory = Memory {
            id: uuid::Uuid::new_v4(),
            entity_id: agent_id,
            agent_id,
            room_id: original_message.room_id,
            content: Content {
                text: thought_text.to_string(),
                source: Some("internal_thought".to_string()),
                thought: Some(thought_text.to_string()),
                ..Default::default()
            },
            embedding: None,
            metadata: Some(MemoryMetadata {
                memory_type: Some("thought".to_string()),
                entity_name: Some(entity_name),
                data: {
                    let mut meta = HashMap::new();
                    meta.insert("purpose".to_string(), serde_json::json!("reflection"));
                    meta.insert(
                        "related_message".to_string(),
                        serde_json::json!(original_message.id.to_string()),
                    );
                    meta.insert(
                        "timestamp".to_string(),
                        serde_json::json!(Utc::now().timestamp_millis()),
                    );
                    meta.insert(
                        "quality_score".to_string(),
                        serde_json::json!(quality_score),
                    );
                    meta.insert(
                        "can_be_used_for".to_string(),
                        serde_json::json!([
                            "decision_pattern_analysis",
                            "response_improvement",
                            "self_reflection",
                            "training_data",
                            "rlhf"
                        ]),
                    );
                    meta
                },
            }),
            created_at: Utc::now().timestamp_millis(),
            unique: Some(false),
            similarity: None,
        };

        let thought_id = thought_memory.id;

        if let Some(adapter) = adapter_opt.as_ref() {
            match adapter.create_memory(&thought_memory, "thoughts").await {
                Ok(id) => {
                    info!("✓ Thought stored with ID: {}", id);
                    info!("💾 Available for: pattern analysis, RLHF, training");
                }
                Err(e) => {
                    error!("Failed to store thought: {}", e);
                    return Err(e);
                }
            }
        }

        if self.config.enabled && quality_score >= self.config.min_quality_score {
            self.record_interaction(
                original_message.content.text.clone(),
                thought_text,
                Some(thought_text.to_string()),
                quality_score,
            )
            .await?;
        }

        Ok(thought_id)
    }

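    /// Attaches a human feedback score in `[-1.0, 1.0]` to an existing sample.
    ///
    /// A minimal sketch (not compiled as a doctest; `sample_id` is a
    /// previously returned sample ID):
    ///
    /// ```ignore
    /// // +1.0 = strongly positive, -1.0 = strongly negative.
    /// collector
    ///     .add_feedback(sample_id, 1.0, Some("Accurate and concise".to_string()))
    ///     .await?;
    /// ```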
    #[instrument(skip(self), level = "info")]
    pub async fn add_feedback(
        &self,
        sample_id: UUID,
        feedback_score: f32,
        feedback_text: Option<String>,
    ) -> Result<()> {
        if !self.config.enable_rlhf {
            return Err(ZoeyError::Config("RLHF is disabled".to_string()));
        }

        if !(-1.0..=1.0).contains(&feedback_score) {
            return Err(ZoeyError::Validation(
                "Feedback score must be between -1.0 and 1.0".to_string(),
            ));
        }

        let mut samples = self.samples.write().unwrap();

        if let Some(sample) = samples.iter_mut().find(|s| s.id == sample_id) {
            sample.feedback_score = Some(feedback_score);

            if let Some(text) = feedback_text {
                let mut metadata = sample.metadata.take().unwrap_or_default();
                metadata.insert("feedback_text".to_string(), serde_json::json!(text));
                metadata.insert(
                    "feedback_timestamp".to_string(),
                    serde_json::json!(Utc::now().timestamp_millis()),
                );
                sample.metadata = Some(metadata);
            }

            info!(
                "✓ Added feedback to sample {} (score: {})",
                sample_id, feedback_score
            );
            Ok(())
        } else {
            Err(ZoeyError::NotFound(format!(
                "Training sample {} not found",
                sample_id
            )))
        }
    }

    #[instrument(skip(self), level = "info")]
    pub async fn add_review(
        &self,
        sample_id: UUID,
        review_score: f32,
        review_text: Option<String>,
    ) -> Result<()> {
        if !(0.0..=1.0).contains(&review_score) {
            return Err(ZoeyError::Validation(
                "Review score must be between 0.0 and 1.0".to_string(),
            ));
        }
        let mut samples = self.samples.write().unwrap();
        if let Some(sample) = samples.iter_mut().find(|s| s.id == sample_id) {
            let mut metadata = sample.metadata.take().unwrap_or_default();
            metadata.insert("review_score".to_string(), serde_json::json!(review_score));
            if let Some(text) = review_text {
                metadata.insert("review_text".to_string(), serde_json::json!(text));
            }
            metadata.insert(
                "review_timestamp".to_string(),
                serde_json::json!(Utc::now().timestamp_millis()),
            );
            sample.metadata = Some(metadata);
            info!(
                "✓ Added evaluator review to sample {} (score: {})",
                sample_id, review_score
            );
            Ok(())
        } else {
            Err(ZoeyError::NotFound(format!(
                "Training sample {} not found",
                sample_id
            )))
        }
    }

    #[instrument(skip(self, message, response, thought, state), level = "debug")]
    pub async fn record_conversation_turn(
        &self,
        message: &Memory,
        response: &Memory,
        thought: Option<String>,
        state: &State,
    ) -> Result<UUID> {
        if !self.config.enabled {
            return Err(ZoeyError::Config(
                "Training collection is disabled".to_string(),
            ));
        }

        let quality_score = calculate_quality_score(message, response, &thought, state);

        if quality_score < self.config.min_quality_score {
            debug!(
                "Skipping low quality interaction (score: {})",
                quality_score
            );
            return Err(ZoeyError::Validation(
                "Quality score below threshold".to_string(),
            ));
        }

        let context: HashMap<String, String> = state.values.clone();

        let sample = TrainingSample {
            id: uuid::Uuid::new_v4(),
            prompt: message.content.text.clone(),
            response: response.content.text.clone(),
            thought: if self.config.include_thoughts {
                thought
            } else {
                None
            },
            context: Some(context),
            quality_score,
            feedback_score: None,
            category: if self.config.auto_label {
                Some(auto_categorize(
                    &message.content.text,
                    &response.content.text,
                ))
            } else {
                None
            },
            tags: if self.config.auto_label {
                auto_generate_tags(&message.content.text, &response.content.text)
            } else {
                vec![]
            },
            timestamp: Utc::now().timestamp_millis(),
            message_ids: Some(MessagePair {
                user_message_id: message.id,
                agent_message_id: response.id,
            }),
            metadata: None,
        };

        let sample_id = sample.id;

        {
            let mut samples = self.samples.write().unwrap();
            samples.push(sample);

            if samples.len() > self.config.max_samples {
                samples.remove(0);
            }
        }

        info!(
            "Recorded conversation turn: {} (quality: {})",
            sample_id, quality_score
        );

        self.check_auto_save().await?;

        Ok(sample_id)
    }

    /// Returns a snapshot of all collected samples.
    pub fn get_samples(&self) -> Vec<TrainingSample> {
        self.samples.read().unwrap().clone()
    }

    /// Returns samples whose quality score lies in `[min_score, max_score]`.
    pub fn get_samples_by_quality(&self, min_score: f32, max_score: f32) -> Vec<TrainingSample> {
        self.samples
            .read()
            .unwrap()
            .iter()
            .filter(|s| s.quality_score >= min_score && s.quality_score <= max_score)
            .cloned()
            .collect()
    }

    /// Returns samples that have received explicit feedback.
    pub fn get_samples_with_feedback(&self) -> Vec<TrainingSample> {
        self.samples
            .read()
            .unwrap()
            .iter()
            .filter(|s| s.feedback_score.is_some())
            .cloned()
            .collect()
    }

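    /// Computes aggregate statistics over the in-memory dataset.
    ///
    /// A usage sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let stats = collector.get_statistics();
    /// println!(
    ///     "{} samples, avg quality {:.2}, {} with feedback",
    ///     stats.total_samples, stats.avg_quality_score, stats.with_feedback_count
    /// );
    /// ```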
    pub fn get_statistics(&self) -> DatasetStatistics {
        let samples = self.samples.read().unwrap();

        let total_samples = samples.len();
        let high_quality_count = samples.iter().filter(|s| s.quality_score > 0.8).count();
        let medium_quality_count = samples
            .iter()
            .filter(|s| s.quality_score >= 0.6 && s.quality_score <= 0.8)
            .count();
        let low_quality_count = samples.iter().filter(|s| s.quality_score < 0.6).count();
        let with_thoughts_count = samples.iter().filter(|s| s.thought.is_some()).count();
        let with_feedback_count = samples
            .iter()
            .filter(|s| s.feedback_score.is_some())
            .count();

        let avg_quality_score = if total_samples > 0 {
            samples.iter().map(|s| s.quality_score).sum::<f32>() / total_samples as f32
        } else {
            0.0
        };

        let feedback_samples: Vec<_> = samples.iter().filter_map(|s| s.feedback_score).collect();
        let avg_feedback_score = if !feedback_samples.is_empty() {
            feedback_samples.iter().sum::<f32>() / feedback_samples.len() as f32
        } else {
            // No explicit feedback: fall back to evaluator review scores
            // stored in sample metadata.
            let review_scores: Vec<f32> = samples
                .iter()
                .filter_map(|s| {
                    s.metadata
                        .as_ref()
                        .and_then(|m| m.get("review_score"))
                        .and_then(|v| v.as_f64())
                        .map(|f| f as f32)
                })
                .collect();
            if !review_scores.is_empty() {
                review_scores.iter().sum::<f32>() / review_scores.len() as f32
            } else {
                0.0
            }
        };

        let mut categories: HashMap<String, usize> = HashMap::new();
        for sample in samples.iter() {
            if let Some(cat) = &sample.category {
                *categories.entry(cat.clone()).or_insert(0) += 1;
            }
        }

        let mut tags: HashMap<String, usize> = HashMap::new();
        for sample in samples.iter() {
            for tag in &sample.tags {
                *tags.entry(tag.clone()).or_insert(0) += 1;
            }
        }

        DatasetStatistics {
            total_samples,
            high_quality_count,
            medium_quality_count,
            low_quality_count,
            with_thoughts_count,
            with_feedback_count,
            avg_quality_score,
            avg_feedback_score,
            categories,
            tags,
        }
    }

    #[instrument(skip(self), level = "info")]
    pub async fn export_jsonl(&self) -> Result<String> {
        let samples = self.samples.read().unwrap();

        let jsonl = samples
            .iter()
            // Metadata (including review_score/review_text) is serialized
            // as part of each sample, so no extra handling is needed here.
            .map(|sample| serde_json::to_string(sample).unwrap())
            .collect::<Vec<_>>()
            .join("\n");

        info!(
            "Exported {} samples as JSONL ({} bytes)",
            samples.len(),
            jsonl.len()
        );
        Ok(jsonl)
    }

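    /// Exports samples as a pretty-printed JSON array in the Alpaca
    /// instruction-tuning layout: the instruction is the (truncated) first
    /// sentence of the prompt, the input is the full prompt. Each record
    /// looks roughly like this (illustrative values):
    ///
    /// ```ignore
    /// {
    ///   "instruction": "What is Rust",
    ///   "input": "What is Rust. I keep hearing about it.",
    ///   "output": "Rust is a systems programming language..."
    /// }
    /// ```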
    #[instrument(skip(self), level = "info")]
    pub async fn export_alpaca(&self) -> Result<String> {
        let samples = self.samples.read().unwrap();

        let alpaca_samples: Vec<AlpacaSample> = samples
            .iter()
            .map(|sample| AlpacaSample {
                instruction: extract_instruction(&sample.prompt),
                input: sample.prompt.clone(),
                output: sample.response.clone(),
            })
            .collect();

        let json = serde_json::to_string_pretty(&alpaca_samples)?;
        info!("Exported {} samples as Alpaca format", samples.len());
        Ok(json)
    }

    #[instrument(skip(self), level = "info")]
    pub async fn export_sharegpt(&self) -> Result<String> {
        let samples = self.samples.read().unwrap();

        let conversations: Vec<ShareGptConversation> = samples
            .iter()
            .map(|sample| ShareGptConversation {
                conversations: vec![
                    ShareGptMessage {
                        from: "human".to_string(),
                        value: sample.prompt.clone(),
                    },
                    ShareGptMessage {
                        from: "gpt".to_string(),
                        value: sample.response.clone(),
                    },
                ],
            })
            .collect();

        let json = serde_json::to_string_pretty(&conversations)?;
        info!("Exported {} samples as ShareGPT format", samples.len());
        Ok(json)
    }

    #[instrument(skip(self), level = "info")]
    pub async fn export_openai(&self) -> Result<String> {
        let samples = self.samples.read().unwrap();

        let training_data: Vec<OpenAiFineTuning> = samples
            .iter()
            .map(|sample| OpenAiFineTuning {
                messages: vec![
                    OpenAiMessage {
                        role: "user".to_string(),
                        content: sample.prompt.clone(),
                    },
                    OpenAiMessage {
                        role: "assistant".to_string(),
                        content: sample.response.clone(),
                    },
                ],
            })
            .collect();

        let jsonl = training_data
            .iter()
            .map(|item| serde_json::to_string(item).unwrap())
            .collect::<Vec<_>>()
            .join("\n");

        info!(
            "Exported {} samples as OpenAI fine-tuning format",
            samples.len()
        );
        Ok(jsonl)
    }

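    /// Writes the current dataset to a timestamped file under
    /// `config.output_dir`, creating the directory if needed.
    ///
    /// A minimal sketch (not compiled as a doctest; assumes a `tokio`
    /// runtime and default config):
    ///
    /// ```ignore
    /// // Produces e.g. ./training_data/training_data_jsonl_20250101_120000.jsonl
    /// let path = collector.save_to_file(TrainingFormat::Jsonl).await?;
    /// ```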
    #[instrument(skip(self), level = "info")]
    pub async fn save_to_file(&self, format: TrainingFormat) -> Result<PathBuf> {
        tokio::fs::create_dir_all(&self.config.output_dir).await?;

        let timestamp = Utc::now().format("%Y%m%d_%H%M%S");
        let (data, extension) = match format {
            TrainingFormat::Jsonl => (self.export_jsonl().await?, "jsonl"),
            TrainingFormat::Alpaca => (self.export_alpaca().await?, "json"),
            TrainingFormat::ShareGpt => (self.export_sharegpt().await?, "json"),
            TrainingFormat::OpenAi => (self.export_openai().await?, "jsonl"),
            TrainingFormat::Custom => (self.export_jsonl().await?, "jsonl"),
        };

        let filename = format!(
            "training_data_{}_{}.{}",
            format!("{:?}", format).to_lowercase(),
            timestamp,
            extension
        );
        let filepath = self.config.output_dir.join(filename);

        tokio::fs::write(&filepath, data).await?;

        info!("✓ Saved training data to: {:?}", filepath);
        Ok(filepath)
    }

    async fn check_auto_save(&self) -> Result<()> {
        if self.config.auto_save_interval == 0 {
            return Ok(());
        }

        let should_save = {
            let last_save = self.last_save.read().unwrap();
            last_save.elapsed().as_secs() >= self.config.auto_save_interval
        };

        if should_save {
            info!("Auto-save triggered");
            self.save_to_file(self.config.default_format).await?;

            let mut last_save = self.last_save.write().unwrap();
            *last_save = std::time::Instant::now();
        }

        Ok(())
    }

    #[instrument(skip(self), level = "info")]
    pub fn remove_sample(&self, sample_id: UUID) -> Result<()> {
        let mut samples = self.samples.write().unwrap();
        let initial_len = samples.len();
        samples.retain(|s| s.id != sample_id);

        if samples.len() < initial_len {
            info!("Removed training sample: {}", sample_id);
            Ok(())
        } else {
            Err(ZoeyError::NotFound(format!(
                "Training sample {} not found",
                sample_id
            )))
        }
    }

    pub fn get_sample(&self, sample_id: UUID) -> Option<TrainingSample> {
        self.samples
            .read()
            .unwrap()
            .iter()
            .find(|s| s.id == sample_id)
            .cloned()
    }

    pub fn clear(&self) {
        let mut samples = self.samples.write().unwrap();
        samples.clear();
        info!("Cleared all training samples");
    }

    pub fn count(&self) -> usize {
        self.samples.read().unwrap().len()
    }
}

fn calculate_quality_score(
    _message: &Memory,
    response: &Memory,
    thought: &Option<String>,
    state: &State,
) -> f32 {
    // Start from a neutral baseline and add heuristic bonuses.
    let mut score: f32 = 0.5;

    // Reward responses of a reasonable length; very long ones get less credit.
    let response_len = response.content.text.len();
    if response_len > 20 && response_len < 1000 {
        score += 0.1;
    } else if response_len >= 1000 {
        score += 0.05;
    }

    // An attached internal thought is a strong positive signal.
    if thought.is_some() {
        score += 0.15;
    }

    // A richer state suggests the response was grounded in more context.
    if state.values.len() > 5 {
        score += 0.1;
    }

    // Reward complete sentences.
    if response.content.text.ends_with('.')
        || response.content.text.ends_with('!')
        || response.content.text.ends_with('?')
    {
        score += 0.05;
    }

    // Reward responses longer than a few words.
    if response.content.text.split_whitespace().count() > 3 {
        score += 0.1;
    }

    score.min(1.0)
}

fn auto_categorize(prompt: &str, response: &str) -> String {
    let prompt_lower = prompt.to_lowercase();
    let response_lower = response.to_lowercase();

    if prompt_lower.contains("how")
        && (prompt_lower.contains("work") || prompt_lower.contains("do"))
    {
        "how_to".to_string()
    } else if prompt_lower.contains("what") || prompt_lower.contains("explain") {
        "explanation".to_string()
    } else if prompt_lower.contains("why") {
        "reasoning".to_string()
    } else if prompt_lower.contains('?') {
        "question_answer".to_string()
    } else if response_lower.contains("error") || response_lower.contains("sorry") {
        "error_handling".to_string()
    } else if prompt_lower.contains("thank") || response_lower.contains("welcome") {
        "social".to_string()
    } else if prompt_lower.contains("help") {
        "help_request".to_string()
    } else {
        "general".to_string()
    }
}

fn auto_generate_tags(prompt: &str, response: &str) -> Vec<String> {
    let mut tags = Vec::new();

    let prompt_lower = prompt.to_lowercase();
    let response_lower = response.to_lowercase();

    if prompt_lower.contains("code") || response_lower.contains("```") {
        tags.push("code".to_string());
    }

    if prompt_lower.contains("data") || prompt_lower.contains("information") {
        tags.push("data".to_string());
    }

    if prompt_lower.len() > 200 {
        tags.push("long_prompt".to_string());
    }

    if response_lower.len() > 500 {
        tags.push("detailed_response".to_string());
    }

    if prompt_lower.contains('?') {
        tags.push("question".to_string());
    }

    if response_lower.contains("step") || response_lower.contains("first") {
        tags.push("instructional".to_string());
    }

    tags
}

fn extract_instruction(prompt: &str) -> String {
    let first_sentence = prompt.split('.').next().unwrap_or(prompt);

    // Truncate on a char boundary: slicing at a fixed byte offset would
    // panic on multi-byte UTF-8 text.
    match first_sentence.char_indices().nth(100) {
        Some((byte_idx, _)) => format!("{}...", &first_sentence[..byte_idx]),
        None => first_sentence.to_string(),
    }
}

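/// Records human preference signals on top of a shared [`TrainingCollector`].
///
/// A minimal sketch (not compiled as a doctest; assumes RLHF is enabled in
/// the collector's config and `sample_id` is a recorded sample):
///
/// ```ignore
/// let collector = Arc::new(TrainingCollector::new(TrainingConfig::default()));
/// let rlhf = RLHFManager::new(collector.clone());
///
/// rlhf.record_positive(sample_id, Some("Helpful".to_string())).await?;
/// let pairs = rlhf.get_rlhf_dataset(); // (preferred, rejected) pairs
/// ```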
pub struct RLHFManager {
    collector: Arc<TrainingCollector>,
}

impl RLHFManager {
    pub fn new(collector: Arc<TrainingCollector>) -> Self {
        Self { collector }
    }

    /// Records maximally positive feedback (+1.0) for a sample.
    pub async fn record_positive(&self, sample_id: UUID, reason: Option<String>) -> Result<()> {
        self.collector.add_feedback(sample_id, 1.0, reason).await
    }

    /// Records maximally negative feedback (-1.0) for a sample.
    pub async fn record_negative(&self, sample_id: UUID, reason: Option<String>) -> Result<()> {
        self.collector.add_feedback(sample_id, -1.0, reason).await
    }

    /// Records neutral feedback (0.0) for a sample.
    pub async fn record_neutral(&self, sample_id: UUID) -> Result<()> {
        self.collector.add_feedback(sample_id, 0.0, None).await
    }

    pub fn get_rlhf_dataset(&self) -> Vec<(TrainingSample, TrainingSample)> {
        let samples = self.collector.get_samples_with_feedback();

        let mut pairs = Vec::new();
        let positive: Vec<_> = samples
            .iter()
            .filter(|s| s.feedback_score.unwrap_or(0.0) > 0.5)
            .cloned()
            .collect();

        let negative: Vec<_> = samples
            .iter()
            .filter(|s| s.feedback_score.unwrap_or(0.0) < -0.5)
            .cloned()
            .collect();

        for (pos, neg) in positive.iter().zip(negative.iter()) {
            pairs.push((pos.clone(), neg.clone()));
        }

        pairs
    }

    pub fn calculate_rewards(&self, sample_ids: &[UUID]) -> HashMap<UUID, f32> {
        let samples = self.collector.get_samples();
        let mut rewards = HashMap::new();

        for id in sample_ids {
            if let Some(sample) = samples.iter().find(|s| s.id == *id) {
                let quality_reward = sample.quality_score;
                let feedback_reward = sample.feedback_score.unwrap_or(0.0);

                // Weight human feedback (60%) above the heuristic score (40%).
                let total_reward = (quality_reward * 0.4) + (feedback_reward * 0.6);

                rewards.insert(*id, total_reward);
            }
        }

        rewards
    }
}

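/// Chainable filters for assembling a dataset from collected samples.
///
/// A usage sketch (not compiled as a doctest):
///
/// ```ignore
/// let dataset = DatasetBuilder::new()
///     .add_from_collector(&collector)
///     .filter_by_quality(0.7)      // drop weaker samples
///     .only_with_thoughts()        // keep samples with reasoning traces
///     .top_n(1000)                 // best 1000 by quality score
///     .build();
/// ```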
pub struct DatasetBuilder {
    samples: Vec<TrainingSample>,
}

impl DatasetBuilder {
    pub fn new() -> Self {
        Self {
            samples: Vec::new(),
        }
    }

    pub fn add_from_collector(mut self, collector: &TrainingCollector) -> Self {
        self.samples.extend(collector.get_samples());
        self
    }

    pub fn filter_by_quality(mut self, min_score: f32) -> Self {
        self.samples.retain(|s| s.quality_score >= min_score);
        self
    }

    pub fn filter_by_category(mut self, category: &str) -> Self {
        self.samples
            .retain(|s| s.category.as_ref().map(|c| c == category).unwrap_or(false));
        self
    }

    pub fn filter_by_tags(mut self, tags: &[String]) -> Self {
        self.samples
            .retain(|s| tags.iter().any(|tag| s.tags.contains(tag)));
        self
    }

    pub fn only_with_thoughts(mut self) -> Self {
        self.samples.retain(|s| s.thought.is_some());
        self
    }

    pub fn only_with_feedback(mut self) -> Self {
        self.samples.retain(|s| s.feedback_score.is_some());
        self
    }

    pub fn top_n(mut self, n: usize) -> Self {
        self.samples.sort_by(|a, b| {
            b.quality_score
                .partial_cmp(&a.quality_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        self.samples.truncate(n);
        self
    }

    pub fn balance_examples(mut self, positive_ratio: f32) -> Self {
        // Clamp the ratio so the subtraction below cannot underflow.
        let positive_ratio = positive_ratio.clamp(0.0, 1.0);

        let positive: Vec<_> = self
            .samples
            .iter()
            .filter(|s| s.quality_score > 0.7)
            .cloned()
            .collect();

        let negative: Vec<_> = self
            .samples
            .iter()
            .filter(|s| s.quality_score < 0.5)
            .cloned()
            .collect();

        let target_positive = (positive.len() as f32 * positive_ratio) as usize;
        let target_negative = positive.len() - target_positive;

        self.samples.clear();
        self.samples
            .extend(positive.into_iter().take(target_positive));
        self.samples
            .extend(negative.into_iter().take(target_negative));

        self
    }

    pub fn build(self) -> Vec<TrainingSample> {
        self.samples
    }

    pub fn count(&self) -> usize {
        self.samples.len()
    }
}

impl Default for DatasetBuilder {
    fn default() -> Self {
        Self::new()
    }
}

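/// Convenience constructor returning an `Arc`-wrapped collector so it can be
/// shared across tasks. A sketch (not compiled as a doctest):
///
/// ```ignore
/// let collector = create_training_collector(TrainingConfig::default());
/// let for_task = collector.clone(); // cheap Arc clone for another task
/// ```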
pub fn create_training_collector(config: TrainingConfig) -> Arc<TrainingCollector> {
    Arc::new(TrainingCollector::new(config))
}

#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn test_training_config() {
        let config = TrainingConfig::default();
        assert!(config.enabled);
        assert_eq!(config.min_quality_score, 0.6);
        assert_eq!(config.max_samples, 10000);
    }

    #[tokio::test]
    async fn test_record_interaction() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        let result = collector
            .record_interaction(
                "Hello, how are you?",
                "I'm doing well, thank you!",
                Some("User is greeting me".to_string()),
                0.8,
            )
            .await;

        assert!(result.is_ok());
        assert_eq!(collector.count(), 1);
    }

    #[tokio::test]
    async fn test_low_quality_rejected() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        let result = collector
            .record_interaction("test", "ok", None, 0.3)
            .await;

        assert!(result.is_err());
        assert_eq!(collector.count(), 0);
    }

    #[tokio::test]
    async fn test_feedback() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        let sample_id = collector
            .record_interaction(
                "What is Rust?",
                "Rust is a systems programming language",
                None,
                0.9,
            )
            .await
            .unwrap();

        collector
            .add_feedback(sample_id, 1.0, Some("Great answer!".to_string()))
            .await
            .unwrap();

        let samples = collector.get_samples_with_feedback();
        assert_eq!(samples.len(), 1);
        assert_eq!(samples[0].feedback_score, Some(1.0));
    }

    #[test]
    fn test_auto_categorize() {
        assert_eq!(
            auto_categorize("How does this work?", "It works by..."),
            "how_to"
        );
        assert_eq!(auto_categorize("What is AI?", "AI is..."), "explanation");
        assert_eq!(auto_categorize("Why is that?", "Because..."), "reasoning");
        assert_eq!(auto_categorize("Help me", "Sure!"), "help_request");
    }

    #[test]
    fn test_auto_generate_tags() {
        let tags = auto_generate_tags("Can you write some code?", "```python\nprint('hello')\n```");
        assert!(tags.contains(&"code".to_string()));
        assert!(tags.contains(&"question".to_string()));
    }

    #[tokio::test]
    async fn test_export_jsonl() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        collector
            .record_interaction("Test", "Response", None, 0.8)
            .await
            .unwrap();

        let jsonl = collector.export_jsonl().await.unwrap();
        assert!(jsonl.contains("Test"));
        assert!(jsonl.contains("Response"));
    }

    #[tokio::test]
    async fn test_statistics() {
        let config = TrainingConfig {
            min_quality_score: 0.5,
            ..Default::default()
        };
        let collector = TrainingCollector::new(config);

        collector
            .record_interaction("Q1", "A1", Some("T1".to_string()), 0.9)
            .await
            .unwrap();
        collector
            .record_interaction("Q2", "A2", None, 0.7)
            .await
            .unwrap();
        collector
            .record_interaction("Q3", "A3", Some("T3".to_string()), 0.5)
            .await
            .unwrap();

        let stats = collector.get_statistics();
        assert_eq!(stats.total_samples, 3);
        assert_eq!(stats.high_quality_count, 1); // only the 0.9 sample is > 0.8
        assert_eq!(stats.with_thoughts_count, 2);
    }

    #[test]
    fn test_dataset_builder() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        let dataset = DatasetBuilder::new()
            .add_from_collector(&collector)
            .filter_by_quality(0.7)
            .top_n(10)
            .build();

        assert!(dataset.len() <= 10);
    }

    #[test]
    fn test_quality_score_calculation() {
        let message = Memory {
            id: Uuid::new_v4(),
            entity_id: Uuid::new_v4(),
            agent_id: Uuid::new_v4(),
            room_id: Uuid::new_v4(),
            content: Content {
                text: "Hello".to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: 12345,
            unique: None,
            similarity: None,
        };

        let response = Memory {
            id: Uuid::new_v4(),
            entity_id: Uuid::new_v4(),
            agent_id: Uuid::new_v4(),
            room_id: Uuid::new_v4(),
            content: Content {
                text: "Hello! How can I help you today?".to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: 12346,
            unique: None,
            similarity: None,
        };

        let thought = Some("User is greeting me".to_string());
        let state = State::new();

        let score = calculate_quality_score(&message, &response, &thought, &state);
        assert!(score >= 0.5);
        assert!(score <= 1.0);
    }

    #[test]
    fn test_rlhf_manager() {
        let config = TrainingConfig::default();
        let collector = Arc::new(TrainingCollector::new(config));
        let rlhf = RLHFManager::new(collector);

        // Construction-only smoke test; feedback behavior is exercised in
        // the async tests above.
        let _ = rlhf;
    }

    #[tokio::test]
    async fn test_export_formats() {
        let config = TrainingConfig::default();
        let collector = TrainingCollector::new(config);

        collector
            .record_interaction("Test Q", "Test A", None, 0.8)
            .await
            .unwrap();

        let jsonl = collector.export_jsonl().await;
        assert!(jsonl.is_ok());

        let alpaca = collector.export_alpaca().await;
        assert!(alpaca.is_ok());

        let sharegpt = collector.export_sharegpt().await;
        assert!(sharegpt.is_ok());

        let openai = collector.export_openai().await;
        assert!(openai.is_ok());
    }

    #[tokio::test]
    async fn test_add_review_non_rlhf() {
        let config = TrainingConfig {
            enable_rlhf: false,
            ..Default::default()
        };
        let collector = TrainingCollector::new(config);

        let sample_id = collector
            .record_interaction("Prompt X", "Response Y", None, 0.8)
            .await
            .unwrap();

        collector
            .add_review(sample_id, 0.9, Some("Good coherence".to_string()))
            .await
            .unwrap();

        let samples = collector.get_samples_by_quality(0.0, 1.0);
        let sample = samples.into_iter().find(|s| s.id == sample_id).unwrap();
        let meta = sample.metadata.unwrap();
        assert_eq!(
            meta.get("review_score").and_then(|v| v.as_f64()).unwrap() as f32,
            0.9
        );
        assert_eq!(
            meta.get("review_text").and_then(|v| v.as_str()).unwrap(),
            "Good coherence"
        );

        let stats = collector.get_statistics();
        // With no explicit feedback, the average falls back to review scores.
        assert!(stats.avg_feedback_score > 0.0);
    }

    #[tokio::test]
    async fn test_export_jsonl_includes_review() {
        let config = TrainingConfig {
            enable_rlhf: false,
            ..Default::default()
        };
        let collector = TrainingCollector::new(config);
        let sample_id = collector
            .record_interaction("A", "B", None, 0.8)
            .await
            .unwrap();
        collector
            .add_review(sample_id, 0.6, Some("Okay".to_string()))
            .await
            .unwrap();
        let jsonl = collector.export_jsonl().await.unwrap();
        assert!(jsonl.contains("\"review_score\""));
        assert!(jsonl.contains("\"review_text\""));
    }

    #[tokio::test]
    async fn print_jsonl_preview() {
        let config = TrainingConfig {
            enable_rlhf: true,
            ..Default::default()
        };
        let collector = TrainingCollector::new(config);

        let s1 = collector
            .record_interaction("How are you?", "I'm well.", None, 0.82)
            .await
            .unwrap();
        collector
            .add_review(s1, 0.9, Some("Coherent".to_string()))
            .await
            .unwrap();

        let s2 = collector
            .record_interaction("Tell a joke", "Why did the dev cross the road?", None, 0.78)
            .await
            .unwrap();
        collector
            .add_feedback(s2, 1.0, Some("Funny".to_string()))
            .await
            .unwrap();

        let jsonl = collector.export_jsonl().await.unwrap();
        println!("{}", jsonl);
    }

    #[tokio::test]
    async fn e2e_conversation_logging_preview() {
        let config = TrainingConfig {
            enable_rlhf: false,
            ..Default::default()
        };
        let collector = TrainingCollector::new(config);

        let mut state = State::new();
        state.set_value("UI_TONE", "friendly".to_string());
        state.set_value("UI_VERBOSITY", "concise".to_string());
        state.set_value(
            "CONTEXT_LAST_THOUGHT",
            "User asked about project status; earlier we shipped v1".to_string(),
        );
        state.set_value(
            "DIALOGUE_SUMMARY",
            "Discussed roadmap, blockers, and timelines".to_string(),
        );

        let room_id = Uuid::new_v4();
        let user_id = Uuid::new_v4();
        let agent_id = Uuid::new_v4();
        let message = Memory {
            id: Uuid::new_v4(),
            entity_id: user_id,
            agent_id,
            room_id,
            content: Content {
                text: "What is the current project status?".to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: chrono::Utc::now().timestamp(),
            unique: Some(false),
            similarity: None,
        };
        let response = Memory {
            id: Uuid::new_v4(),
            entity_id: agent_id,
            agent_id,
            room_id,
            content: Content {
                text: "We completed the core milestones and are preparing the release.".to_string(),
                ..Default::default()
            },
            embedding: None,
            metadata: None,
            created_at: chrono::Utc::now().timestamp(),
            unique: Some(false),
            similarity: None,
        };

        println!(
            "[STATE] UI_TONE={}",
            state.get_value("UI_TONE").cloned().unwrap_or_default()
        );
        println!(
            "[STATE] UI_VERBOSITY={}",
            state.get_value("UI_VERBOSITY").cloned().unwrap_or_default()
        );
        println!(
            "[STATE] CONTEXT_LAST_THOUGHT={}",
            state
                .get_value("CONTEXT_LAST_THOUGHT")
                .cloned()
                .unwrap_or_default()
        );
        println!(
            "[STATE] DIALOGUE_SUMMARY={}",
            state
                .get_value("DIALOGUE_SUMMARY")
                .cloned()
                .unwrap_or_default()
        );

        let sample_id = collector
            .record_conversation_turn(&message, &response, None, &state)
            .await
            .unwrap();

        collector
            .add_review(
                sample_id,
                0.88,
                Some("Direct, concise, and helpful".to_string()),
            )
            .await
            .unwrap();

        let jsonl = collector.export_jsonl().await.unwrap();
        println!("[DATASET]\n{}", jsonl);

        let stats = collector.get_statistics();
        println!(
            "[STATS] total={}, avg_quality={:.2}, avg_feedback_or_review={:.2}",
            stats.total_samples, stats.avg_quality_score, stats.avg_feedback_score
        );
    }
}