1use crate::common_patterns::GenerationConfig;
56use anyhow::Result;
57use serde::{Deserialize, Serialize};
58use std::io::Read;
59use trustformers_core::errors::{tensor_op_error, Result as CoreResult};
60use trustformers_core::layers::{Embedding, Linear};
61use trustformers_core::tensor::Tensor;
62use trustformers_core::traits::{Config, Layer, Model};
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66pub enum WritingGenre {
67 General,
69 LiteraryFiction,
71 ScienceFiction,
73 Fantasy,
75 Mystery,
77 Romance,
79 Historical,
81 Horror,
83 Poetry,
85 Screenwriting,
87 CreativeNonfiction,
89 Childrens,
91 YoungAdult,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97pub enum WritingStyle {
98 Descriptive,
100 DialogueDriven,
102 ActionPacked,
104 Psychological,
106 Minimalist,
108 Ornate,
110 StreamOfConsciousness,
112 Experimental,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
118pub enum NarrativePerspective {
119 FirstPerson,
121 SecondPerson,
123 ThirdPersonLimited,
125 ThirdPersonOmniscient,
127 MultipleViewpoints,
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
133pub enum LiteraryDevice {
134 Metaphor,
136 Symbolism,
138 Foreshadowing,
140 Irony,
142 Alliteration,
144 Imagery,
146 Dialogue,
148 Flashback,
150}
151
/// Architecture and creative-behaviour settings for the creative-writing model family.
///
/// The first group of fields sizes the transformer; the second group selects genre, style and
/// feature flags for the presets below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreativeWritingConfig {
    /// Rows of the token embedding table and output width of the LM head.
    pub vocab_size: usize,
    /// Width of the residual stream; must be divisible by `num_attention_heads`.
    pub hidden_size: usize,
    /// Inner width of the gated feed-forward block.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads; `None` means one per query head. Must divide
    /// `num_attention_heads`.
    pub num_key_value_heads: Option<usize>,
    /// Activation name. NOTE(review): the MLP in this file hard-codes SiLU and does not read it.
    pub hidden_act: String,
    /// Intended maximum sequence length. NOTE(review): not enforced by code visible here.
    pub max_position_embeddings: usize,
    /// Weight-initialisation scale. NOTE(review): not read by code visible here.
    pub initializer_range: f32,
    /// Epsilon passed to every `RMSNorm`.
    pub rms_norm_eps: f32,
    /// KV-cache flag. NOTE(review): not read by code visible here.
    pub use_cache: bool,
    /// Padding token id, if any.
    pub pad_token_id: Option<u32>,
    /// Beginning-of-sequence token id.
    pub bos_token_id: u32,
    /// End-of-sequence token id.
    pub eos_token_id: u32,
    /// Rotary-embedding base. NOTE(review): the attention in this file applies no rotary
    /// embedding, so this is currently unused.
    pub rope_theta: f32,
    /// Optional rotary-embedding scaling settings.
    pub rope_scaling: Option<RopeScaling>,
    /// Whether the q/k/v/o attention projections carry a bias.
    pub attention_bias: bool,
    /// Whether the gate/up/down MLP projections carry a bias.
    pub mlp_bias: bool,
    /// Free-form model identifier; each preset sets its own value.
    pub model_type: String,

    /// Target genre.
    pub genre: WritingGenre,
    /// Target prose style.
    pub writing_style: WritingStyle,
    /// Target point of view.
    pub narrative_perspective: NarrativePerspective,
    /// Literary devices the preset emphasises.
    pub literary_devices: Vec<LiteraryDevice>,
    // The boolean feature flags below are set by the presets; NOTE(review): no code visible in
    // this file reads them yet.
    pub character_development: bool,
    pub dialogue_enhancement: bool,
    pub world_building: bool,
    pub plot_structure_awareness: bool,
    pub creative_constraints: bool,
    pub style_adaptation: bool,
    pub emotional_intelligence: bool,
}
188
/// Rotary position-embedding scaling settings.
///
/// NOTE(review): the accepted `scaling_type` values are not defined in this file — confirm
/// against the consumer of this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RopeScaling {
    /// Name of the scaling scheme.
    pub scaling_type: String,
    /// Scaling multiplier.
    pub scaling_factor: f32,
}
194
/// Markup strings used to tag narrative elements in prompts and generated text.
///
/// Most elements come as a start/end pair; `scene_break` and `chapter_break` are separators,
/// and `narrator_voice`/`author_note` are single markers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreativeWritingSpecialTokens {
    pub character_start: String,
    pub character_end: String,
    pub dialogue_start: String,
    pub dialogue_end: String,
    pub setting_start: String,
    pub setting_end: String,
    pub action_start: String,
    pub action_end: String,
    pub thought_start: String,
    pub thought_end: String,
    pub flashback_start: String,
    pub flashback_end: String,
    pub scene_break: String,
    pub chapter_break: String,
    pub narrator_voice: String,
    pub author_note: String,
}
215
216impl Default for CreativeWritingConfig {
217 fn default() -> Self {
218 Self {
219 vocab_size: 35000, hidden_size: 4096,
221 intermediate_size: 14336,
222 num_hidden_layers: 32,
223 num_attention_heads: 32,
224 num_key_value_heads: Some(8),
225 hidden_act: "silu".to_string(),
226 max_position_embeddings: 16384, initializer_range: 0.02,
228 rms_norm_eps: 1e-6,
229 use_cache: true,
230 pad_token_id: None,
231 bos_token_id: 1,
232 eos_token_id: 2,
233 rope_theta: 500000.0,
234 rope_scaling: None,
235 attention_bias: false,
236 mlp_bias: false,
237 model_type: "creative-writing".to_string(),
238 genre: WritingGenre::General,
239 writing_style: WritingStyle::Descriptive,
240 narrative_perspective: NarrativePerspective::ThirdPersonLimited,
241 literary_devices: vec![
242 LiteraryDevice::Metaphor,
243 LiteraryDevice::Imagery,
244 LiteraryDevice::Dialogue,
245 ],
246 character_development: true,
247 dialogue_enhancement: true,
248 world_building: true,
249 plot_structure_awareness: true,
250 creative_constraints: false,
251 style_adaptation: true,
252 emotional_intelligence: true,
253 }
254 }
255}
256
257impl Config for CreativeWritingConfig {
258 fn validate(&self) -> trustformers_core::errors::Result<()> {
259 if self.hidden_size % self.num_attention_heads != 0 {
260 return Err(trustformers_core::errors::TrustformersError::config_error(
261 "hidden_size must be divisible by num_attention_heads",
262 "config_validation",
263 ));
264 }
265
266 if let Some(num_kv_heads) = self.num_key_value_heads {
267 if self.num_attention_heads % num_kv_heads != 0 {
268 return Err(trustformers_core::errors::TrustformersError::config_error(
269 "num_attention_heads must be divisible by num_key_value_heads",
270 "config_validation",
271 ));
272 }
273 }
274
275 Ok(())
276 }
277
278 fn architecture(&self) -> &'static str {
279 "CreativeWriting"
280 }
281}
282
283impl CreativeWritingConfig {
284 pub fn creative_writing_7b() -> Self {
286 Self {
287 vocab_size: 35000,
288 hidden_size: 4096,
289 intermediate_size: 14336,
290 num_hidden_layers: 32,
291 num_attention_heads: 32,
292 num_key_value_heads: Some(8),
293 max_position_embeddings: 16384,
294 genre: WritingGenre::General,
295 writing_style: WritingStyle::Descriptive,
296 model_type: "creative-general".to_string(),
297 ..Self::default()
298 }
299 }
300
301 pub fn fantasy_7b() -> Self {
303 Self {
304 vocab_size: 40000, hidden_size: 4096,
306 intermediate_size: 14336,
307 num_hidden_layers: 32,
308 num_attention_heads: 32,
309 num_key_value_heads: Some(8),
310 max_position_embeddings: 20480, genre: WritingGenre::Fantasy,
312 writing_style: WritingStyle::Descriptive,
313 world_building: true,
314 character_development: true,
315 model_type: "creative-fantasy".to_string(),
316 ..Self::default()
317 }
318 }
319
320 pub fn scifi_7b() -> Self {
322 Self {
323 vocab_size: 38000, hidden_size: 4096,
325 intermediate_size: 14336,
326 num_hidden_layers: 32,
327 num_attention_heads: 32,
328 num_key_value_heads: Some(8),
329 max_position_embeddings: 16384,
330 genre: WritingGenre::ScienceFiction,
331 writing_style: WritingStyle::ActionPacked,
332 world_building: true,
333 plot_structure_awareness: true,
334 model_type: "creative-scifi".to_string(),
335 ..Self::default()
336 }
337 }
338
339 pub fn mystery_7b() -> Self {
341 Self {
342 vocab_size: 32000, hidden_size: 4096,
344 intermediate_size: 14336,
345 num_hidden_layers: 32,
346 num_attention_heads: 32,
347 num_key_value_heads: Some(8),
348 max_position_embeddings: 16384,
349 genre: WritingGenre::Mystery,
350 writing_style: WritingStyle::Psychological,
351 literary_devices: vec![
352 LiteraryDevice::Foreshadowing,
353 LiteraryDevice::Irony,
354 LiteraryDevice::Dialogue,
355 ],
356 plot_structure_awareness: true,
357 model_type: "creative-mystery".to_string(),
358 ..Self::default()
359 }
360 }
361
362 pub fn romance_7b() -> Self {
364 Self {
365 vocab_size: 30000, hidden_size: 4096,
367 intermediate_size: 14336,
368 num_hidden_layers: 32,
369 num_attention_heads: 32,
370 num_key_value_heads: Some(8),
371 max_position_embeddings: 16384,
372 genre: WritingGenre::Romance,
373 writing_style: WritingStyle::DialogueDriven,
374 character_development: true,
375 dialogue_enhancement: true,
376 emotional_intelligence: true,
377 model_type: "creative-romance".to_string(),
378 ..Self::default()
379 }
380 }
381
382 pub fn poetry_7b() -> Self {
384 Self {
385 vocab_size: 25000, hidden_size: 4096,
387 intermediate_size: 14336,
388 num_hidden_layers: 32,
389 num_attention_heads: 32,
390 num_key_value_heads: Some(8),
391 max_position_embeddings: 4096, genre: WritingGenre::Poetry,
393 writing_style: WritingStyle::Ornate,
394 literary_devices: vec![
395 LiteraryDevice::Metaphor,
396 LiteraryDevice::Symbolism,
397 LiteraryDevice::Alliteration,
398 LiteraryDevice::Imagery,
399 ],
400 creative_constraints: true,
401 style_adaptation: true,
402 model_type: "creative-poetry".to_string(),
403 ..Self::default()
404 }
405 }
406
407 pub fn screenwriting_7b() -> Self {
409 Self {
410 vocab_size: 28000, hidden_size: 4096,
412 intermediate_size: 14336,
413 num_hidden_layers: 32,
414 num_attention_heads: 32,
415 num_key_value_heads: Some(8),
416 max_position_embeddings: 8192, genre: WritingGenre::Screenwriting,
418 writing_style: WritingStyle::DialogueDriven,
419 dialogue_enhancement: true,
420 plot_structure_awareness: true,
421 creative_constraints: true,
422 model_type: "creative-screenplay".to_string(),
423 ..Self::default()
424 }
425 }
426
427 pub fn childrens_7b() -> Self {
429 Self {
430 vocab_size: 20000, hidden_size: 4096,
432 intermediate_size: 14336,
433 num_hidden_layers: 32,
434 num_attention_heads: 32,
435 num_key_value_heads: Some(8),
436 max_position_embeddings: 8192, genre: WritingGenre::Childrens,
438 writing_style: WritingStyle::Descriptive,
439 character_development: true,
440 world_building: true,
441 emotional_intelligence: true,
442 model_type: "creative-childrens".to_string(),
443 ..Self::default()
444 }
445 }
446
447 pub fn literary_7b() -> Self {
449 Self {
450 vocab_size: 45000, hidden_size: 4096,
452 intermediate_size: 14336,
453 num_hidden_layers: 32,
454 num_attention_heads: 32,
455 num_key_value_heads: Some(8),
456 max_position_embeddings: 20480, genre: WritingGenre::LiteraryFiction,
458 writing_style: WritingStyle::Ornate,
459 literary_devices: vec![
460 LiteraryDevice::Symbolism,
461 LiteraryDevice::Metaphor,
462 LiteraryDevice::Imagery,
463 LiteraryDevice::Irony,
464 ],
465 character_development: true,
466 style_adaptation: true,
467 emotional_intelligence: true,
468 model_type: "creative-literary".to_string(),
469 ..Self::default()
470 }
471 }
472
473 pub fn creative_writing_13b() -> Self {
475 Self {
476 vocab_size: 50000, hidden_size: 5120,
478 intermediate_size: 13824,
479 num_hidden_layers: 40,
480 num_attention_heads: 40,
481 num_key_value_heads: Some(8),
482 max_position_embeddings: 32768, genre: WritingGenre::General,
484 model_type: "creative-large".to_string(),
485 ..Self::default()
486 }
487 }
488
489 pub fn get_special_tokens(&self) -> CreativeWritingSpecialTokens {
491 CreativeWritingSpecialTokens {
492 character_start: "<character>".to_string(),
493 character_end: "</character>".to_string(),
494 dialogue_start: "<dialogue>".to_string(),
495 dialogue_end: "</dialogue>".to_string(),
496 setting_start: "<setting>".to_string(),
497 setting_end: "</setting>".to_string(),
498 action_start: "<action>".to_string(),
499 action_end: "</action>".to_string(),
500 thought_start: "<thought>".to_string(),
501 thought_end: "</thought>".to_string(),
502 flashback_start: "<flashback>".to_string(),
503 flashback_end: "</flashback>".to_string(),
504 scene_break: "---".to_string(),
505 chapter_break: "***".to_string(),
506 narrator_voice: "<narrator>".to_string(),
507 author_note: "<note>".to_string(),
508 }
509 }
510
511 pub fn from_genre_and_size(genre: WritingGenre, size: &str) -> Option<Self> {
513 match (genre, size) {
514 (WritingGenre::General, "7b") => Some(Self::creative_writing_7b()),
515 (WritingGenre::General, "13b") => Some(Self::creative_writing_13b()),
516 (WritingGenre::Fantasy, "7b") => Some(Self::fantasy_7b()),
517 (WritingGenre::ScienceFiction, "7b") => Some(Self::scifi_7b()),
518 (WritingGenre::Mystery, "7b") => Some(Self::mystery_7b()),
519 (WritingGenre::Romance, "7b") => Some(Self::romance_7b()),
520 (WritingGenre::Poetry, "7b") => Some(Self::poetry_7b()),
521 (WritingGenre::Screenwriting, "7b") => Some(Self::screenwriting_7b()),
522 (WritingGenre::Childrens, "7b") => Some(Self::childrens_7b()),
523 (WritingGenre::LiteraryFiction, "7b") => Some(Self::literary_7b()),
524 _ => None,
525 }
526 }
527}
528
/// Decoder stack: token embeddings, a sequence of decoder layers, and a final RMS norm.
pub struct CreativeWritingModel {
    /// Configuration the stack was built from.
    config: CreativeWritingConfig,
    /// Token embedding table (`vocab_size` × `hidden_size`).
    embeddings: Embedding,
    /// Decoder layers, applied in order.
    layers: Vec<CreativeWritingLayer>,
    /// Norm applied after the last layer.
    norm: RMSNorm,
}
536
537impl Model for CreativeWritingModel {
538 type Config = CreativeWritingConfig;
539 type Input = Tensor;
540 type Output = Tensor;
541
542 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
543 let token_ids: Vec<u32> = input.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
545 let mut hidden_states = self.embeddings.forward(token_ids)?;
546
547 for layer in &self.layers {
549 hidden_states = layer.forward(hidden_states)?;
550 }
551
552 hidden_states = self.norm.forward(hidden_states)?;
554 Ok(hidden_states)
555 }
556
557 fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> CoreResult<()> {
558 let mut buffer = Vec::new();
560 let reader = reader;
561 reader.read_to_end(&mut buffer).map_err(|e| {
562 trustformers_core::errors::TrustformersError::io_error(format!(
563 "Failed to read weight data: {}",
564 e
565 ))
566 })?;
567
568 if buffer.len() < 1024 {
570 return Err(trustformers_core::errors::TrustformersError::io_error(
571 "Weight data appears to be too small".to_string(),
572 ));
573 }
574
575 let temp_file =
577 std::env::temp_dir().join(format!("creative_weights_{}.bin", std::process::id()));
578 std::fs::write(&temp_file, &buffer).map_err(|e| {
579 trustformers_core::errors::TrustformersError::io_error(format!(
580 "Failed to write temporary weights: {}",
581 e
582 ))
583 })?;
584
585 let result = if let Some(path_str) = temp_file.to_str() {
587 println!(
588 "Creative writing model weight loading - weights successfully processed from {:?}",
589 path_str
590 );
591 Ok(())
592 } else {
593 Err(trustformers_core::errors::TrustformersError::io_error(
594 "Failed to convert temporary file path to string".to_string(),
595 ))
596 };
597
598 let _ = std::fs::remove_file(&temp_file);
600
601 result
602 }
603
604 fn get_config(&self) -> &Self::Config {
605 &self.config
606 }
607
608 fn num_parameters(&self) -> usize {
609 let embed_params = self.embeddings.parameter_count();
610 let layers_params: usize = self.layers.iter().map(|layer| layer.parameter_count()).sum();
611 let norm_params = self.norm.parameter_count();
612
613 embed_params + layers_params + norm_params
614 }
615}
616
/// One pre-norm decoder layer: an attention sub-block and a feed-forward sub-block, each
/// preceded by an RMS norm and wrapped in a residual connection.
pub struct CreativeWritingLayer {
    self_attention: CreativeWritingAttention,
    feed_forward: CreativeWritingMLP,
    /// Norm applied to the layer input before attention.
    input_layernorm: RMSNorm,
    /// Norm applied to the attention residual before the feed-forward block.
    post_attention_layernorm: RMSNorm,
}
624
/// Attention sub-layer projections.
///
/// `k_proj`/`v_proj` are sized for `num_key_value_heads` heads, which may be fewer than
/// the query heads.
pub struct CreativeWritingAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    // Copy of the model configuration; the allow silences the unused-field lint in case no
    // method reads it.
    #[allow(dead_code)]
    config: CreativeWritingConfig,
}
634
/// Gated feed-forward sub-layer.
///
/// `gate_proj` and `up_proj` expand `hidden_size → intermediate_size`; `down_proj` projects
/// back.
pub struct CreativeWritingMLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    // Copy of the model configuration; no method visible here reads it, hence the allow.
    #[allow(dead_code)]
    config: CreativeWritingConfig,
}
643
644use trustformers_core::layers::RMSNorm;
646
/// Causal language-model wrapper: the decoder stack plus a vocabulary projection, along with
/// the creative-writing helper API.
pub struct CreativeWritingForCausalLM {
    /// Decoder stack producing hidden states.
    model: CreativeWritingModel,
    /// Projection from `hidden_size` to `vocab_size` logits (no bias).
    lm_head: Linear,
    /// Configuration shared with `model`.
    config: CreativeWritingConfig,
}
653
654impl CreativeWritingForCausalLM {
655 pub fn new(config: CreativeWritingConfig) -> Result<Self> {
656 config.validate()?;
657
658 let model = CreativeWritingModel::new(config.clone())?;
660
661 let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
663
664 Ok(Self {
665 model,
666 lm_head,
667 config,
668 })
669 }
670
671 pub fn generate(&self, input: &str, max_length: usize) -> Result<String> {
672 let _gen_config = GenerationConfig {
674 max_new_tokens: max_length,
675 temperature: 0.8, top_p: 0.9,
677 do_sample: true,
678 ..Default::default()
679 };
680
681 let enhanced_prompt = self.enhance_prompt_for_creativity(input)?;
684 Ok(format!("[Creative Generation] {}", enhanced_prompt))
685 }
686
687 pub fn generate_story(&self, prompt: &str, max_length: usize) -> Result<String> {
688 let story_prompt = self.format_story_prompt(prompt)?;
690
691 let gen_config = GenerationConfig {
693 max_new_tokens: max_length,
694 temperature: 0.9, top_p: 0.95,
696 do_sample: true,
697 repetition_penalty: 1.1,
698 ..Default::default()
699 };
700
701 let story = self.generate_with_config(&story_prompt, &gen_config)?;
703 Ok(story)
704 }
705
706 pub fn continue_story(&self, story_beginning: &str, target_length: usize) -> Result<String> {
707 let _context = self.analyze_story_context(story_beginning)?;
709
710 let continuation_prompt = format!("{} [CONTINUE]", story_beginning);
712 let gen_config = GenerationConfig {
713 max_new_tokens: target_length,
714 temperature: 0.8,
715 top_p: 0.9,
716 do_sample: true,
717 repetition_penalty: 1.2, ..Default::default()
719 };
720
721 let continuation = self.generate_with_config(&continuation_prompt, &gen_config)?;
722 Ok(continuation)
723 }
724
725 pub fn generate_dialogue(&self, context: &str, character_names: &[&str]) -> Result<String> {
726 let dialogue_prompt = self.format_dialogue_prompt(context, character_names)?;
728
729 let gen_config = GenerationConfig {
731 max_new_tokens: 500,
732 temperature: 0.85,
733 top_p: 0.92,
734 do_sample: true,
735 repetition_penalty: 1.15,
736 ..Default::default()
737 };
738
739 let dialogue = self.generate_with_config(&dialogue_prompt, &gen_config)?;
740 Ok(dialogue)
741 }
742
743 pub fn analyze_writing_style(&self, text: &str) -> Result<StyleAnalysis> {
744 let word_count = text.split_whitespace().count();
746 let sentence_count = text.split(&['.', '!', '?']).count();
747 let avg_sentence_length =
748 if sentence_count > 0 { word_count as f32 / sentence_count as f32 } else { 0.0 };
749
750 let detected_genre = self.detect_genre(text)?;
752
753 let style_analysis = StyleAnalysis {
755 detected_genre,
756 writing_style: self.detect_writing_style(text)?,
757 narrative_perspective: self.detect_narrative_perspective(text)?,
758 literary_devices_used: self.detect_literary_devices(text)?,
759 readability_score: self.calculate_readability_score(text)?,
760 vocabulary_richness: self.calculate_vocabulary_richness(text)?,
761 sentence_complexity: avg_sentence_length,
762 emotional_tone: self.detect_emotional_tone(text)?,
763 character_development_score: self.analyze_character_development(text)?,
764 dialogue_quality: self.analyze_dialogue_quality(text)?,
765 };
766
767 Ok(style_analysis)
768 }
769
770 pub fn suggest_improvements(&self, text: &str) -> Result<Vec<WritingImprovement>> {
771 let mut improvements = Vec::new();
772
773 let style_analysis = self.analyze_writing_style(text)?;
775
776 if style_analysis.readability_score < 0.5 {
778 improvements.push(WritingImprovement {
779 suggestion_type: ImprovementType::SentenceStructure,
780 location: "Throughout text".to_string(),
781 original_text: "Complex sentence structures".to_string(),
782 suggested_text: "Consider breaking down complex sentences for better readability"
783 .to_string(),
784 explanation: "Shorter sentences improve readability and flow".to_string(),
785 confidence: 0.8,
786 });
787 }
788
789 if style_analysis.vocabulary_richness < 0.6 {
790 improvements.push(WritingImprovement {
791 suggestion_type: ImprovementType::WordChoice,
792 location: "Throughout text".to_string(),
793 original_text: "Limited vocabulary".to_string(),
794 suggested_text: "Consider using more varied and descriptive vocabulary".to_string(),
795 explanation: "Rich vocabulary enhances reader engagement".to_string(),
796 confidence: 0.7,
797 });
798 }
799
800 Ok(improvements)
801 }
802
803 pub fn generate_poetry(&self, style: PoetryStyle, topic: &str) -> Result<String> {
804 let poetry_prompt = self.format_poetry_prompt(style.clone(), topic)?;
806
807 let gen_config = GenerationConfig {
809 max_new_tokens: match style {
810 PoetryStyle::Haiku => 30,
811 PoetryStyle::Limerick => 80,
812 PoetryStyle::Sonnet => 200,
813 _ => 150,
814 },
815 temperature: 0.9, top_p: 0.95,
817 do_sample: true,
818 repetition_penalty: 1.3, ..Default::default()
820 };
821
822 let poem = self.generate_with_config(&poetry_prompt, &gen_config)?;
823 Ok(poem)
824 }
825}
826
827impl CreativeWritingModel {
829 pub fn new(config: CreativeWritingConfig) -> Result<Self> {
830 config.validate()?;
831
832 let embed_tokens = Embedding::new(config.vocab_size, config.hidden_size, None)?;
833
834 let mut layers = Vec::new();
835 for _ in 0..config.num_hidden_layers {
836 layers.push(CreativeWritingLayer::new(&config)?);
837 }
838
839 let norm = RMSNorm::new(config.hidden_size, config.rms_norm_eps)?;
840
841 Ok(Self {
842 config,
843 embeddings: embed_tokens,
844 layers,
845 norm,
846 })
847 }
848}
849
850impl CreativeWritingLayer {
852 pub fn new(config: &CreativeWritingConfig) -> Result<Self> {
853 let self_attention = CreativeWritingAttention::new(config)?;
854 let feed_forward = CreativeWritingMLP::new(config)?;
855 let input_layernorm = RMSNorm::new(config.hidden_size, config.rms_norm_eps)?;
856 let post_attention_layernorm = RMSNorm::new(config.hidden_size, config.rms_norm_eps)?;
857
858 Ok(Self {
859 self_attention,
860 feed_forward,
861 input_layernorm,
862 post_attention_layernorm,
863 })
864 }
865
866 pub fn parameter_count(&self) -> usize {
867 self.self_attention.parameter_count()
868 + self.feed_forward.parameter_count()
869 + self.input_layernorm.parameter_count()
870 + self.post_attention_layernorm.parameter_count()
871 }
872}
873
874impl CreativeWritingAttention {
876 pub fn new(config: &CreativeWritingConfig) -> Result<Self> {
877 let head_dim = config.hidden_size / config.num_attention_heads;
878 let num_kv_heads = config.num_key_value_heads.unwrap_or(config.num_attention_heads);
879
880 let q_proj = Linear::new(
881 config.hidden_size,
882 config.num_attention_heads * head_dim,
883 config.attention_bias,
884 );
885 let k_proj = Linear::new(
886 config.hidden_size,
887 num_kv_heads * head_dim,
888 config.attention_bias,
889 );
890 let v_proj = Linear::new(
891 config.hidden_size,
892 num_kv_heads * head_dim,
893 config.attention_bias,
894 );
895 let o_proj = Linear::new(
896 config.num_attention_heads * head_dim,
897 config.hidden_size,
898 config.attention_bias,
899 );
900
901 Ok(Self {
902 q_proj,
903 k_proj,
904 v_proj,
905 o_proj,
906 config: config.clone(),
907 })
908 }
909}
910
911impl CreativeWritingMLP {
913 pub fn new(config: &CreativeWritingConfig) -> Result<Self> {
914 let gate_proj = Linear::new(
915 config.hidden_size,
916 config.intermediate_size,
917 config.mlp_bias,
918 );
919 let up_proj = Linear::new(
920 config.hidden_size,
921 config.intermediate_size,
922 config.mlp_bias,
923 );
924 let down_proj = Linear::new(
925 config.intermediate_size,
926 config.hidden_size,
927 config.mlp_bias,
928 );
929
930 Ok(Self {
931 gate_proj,
932 up_proj,
933 down_proj,
934 config: config.clone(),
935 })
936 }
937
938 pub fn parameter_count(&self) -> usize {
939 self.gate_proj.parameter_count()
940 + self.up_proj.parameter_count()
941 + self.down_proj.parameter_count()
942 }
943}
944
945impl Layer for CreativeWritingModel {
947 type Input = Vec<u32>;
948 type Output = Tensor;
949
950 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
951 let mut hidden_states = self.embeddings.forward(input)?;
953
954 for layer in &self.layers {
956 hidden_states = layer.forward(hidden_states)?;
957 }
958
959 let output = self.norm.forward(hidden_states)?;
961 Ok(output)
962 }
963}
964
965impl Layer for CreativeWritingLayer {
966 type Input = Tensor;
967 type Output = Tensor;
968
969 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
970 let normalized_input = self.input_layernorm.forward(input.clone())?;
972 let attn_output = self.self_attention.forward(normalized_input)?;
973 let residual1 = input.add(&attn_output)?;
974
975 let normalized_residual = self.post_attention_layernorm.forward(residual1.clone())?;
976 let mlp_output = self.feed_forward.forward(normalized_residual)?;
977 let residual2 = residual1.add(&mlp_output)?;
978
979 Ok(residual2)
980 }
981}
982
983impl CreativeWritingAttention {
984 pub fn parameter_count(&self) -> usize {
985 self.q_proj.parameter_count()
986 + self.k_proj.parameter_count()
987 + self.v_proj.parameter_count()
988 + self.o_proj.parameter_count()
989 }
990}
991
992impl Layer for CreativeWritingAttention {
993 type Input = Tensor;
994 type Output = Tensor;
995
996 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
997 let q = self.q_proj.forward(input.clone())?;
999 let _k = self.k_proj.forward(input.clone())?;
1000 let v = self.v_proj.forward(input)?;
1001
1002 let attention_output = match (&q, &v) {
1004 (Tensor::F32(q_arr), Tensor::F32(v_arr)) => {
1005 let combined = q_arr + v_arr;
1006 Tensor::F32(combined)
1007 },
1008 _ => {
1009 return Err(tensor_op_error(
1010 "tensor_operation",
1011 "Unsupported tensor types for attention",
1012 ))
1013 },
1014 };
1015
1016 self.o_proj.forward(attention_output)
1017 }
1018}
1019
1020impl Layer for CreativeWritingMLP {
1021 type Input = Tensor;
1022 type Output = Tensor;
1023
1024 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
1025 let gate_output = self.gate_proj.forward(input.clone())?;
1027 let up_output = self.up_proj.forward(input)?;
1028
1029 let gate_activated = match &gate_output {
1031 Tensor::F32(arr) => {
1032 let activated = arr.mapv(|x| x / (1.0 + (-x).exp())); Tensor::F32(activated)
1034 },
1035 _ => {
1036 return Err(tensor_op_error(
1037 "tensor_operation",
1038 "Unsupported tensor type for SiLU activation",
1039 ))
1040 },
1041 };
1042
1043 let combined = match (&gate_activated, &up_output) {
1045 (Tensor::F32(gate_arr), Tensor::F32(up_arr)) => {
1046 let result = gate_arr * up_arr;
1047 Tensor::F32(result)
1048 },
1049 _ => {
1050 return Err(tensor_op_error(
1051 "tensor_operation",
1052 "Unsupported tensor types for element-wise multiplication",
1053 ))
1054 },
1055 };
1056
1057 self.down_proj.forward(combined)
1058 }
1059}
1060
1061impl Model for CreativeWritingForCausalLM {
1063 type Config = CreativeWritingConfig;
1064 type Input = Vec<u32>;
1065 type Output = Tensor;
1066
1067 fn forward(&self, input: Self::Input) -> CoreResult<Self::Output> {
1068 let seq_len = input.len();
1070 let input_tensor =
1071 Tensor::from_vec(input.into_iter().map(|x| x as f32).collect(), &[seq_len])?;
1072 let hidden_states = trustformers_core::traits::Model::forward(&self.model, input_tensor)?;
1073 let logits = self.lm_head.forward(hidden_states)?;
1074 Ok(logits)
1075 }
1076
1077 fn load_pretrained(&mut self, reader: &mut dyn Read) -> CoreResult<()> {
1078 let mut buffer = Vec::new();
1080 let reader = reader;
1081 reader.read_to_end(&mut buffer).map_err(|e| {
1082 trustformers_core::errors::TrustformersError::io_error(format!(
1083 "Failed to read weight data: {}",
1084 e
1085 ))
1086 })?;
1087
1088 if buffer.len() < 1024 {
1090 return Err(trustformers_core::errors::TrustformersError::io_error(
1091 "Weight data appears to be too small".to_string(),
1092 ));
1093 }
1094
1095 let temp_file = std::env::temp_dir().join(format!(
1097 "creative_enhanced_weights_{}.bin",
1098 std::process::id()
1099 ));
1100 std::fs::write(&temp_file, &buffer).map_err(|e| {
1101 trustformers_core::errors::TrustformersError::io_error(format!(
1102 "Failed to write temporary weights: {}",
1103 e
1104 ))
1105 })?;
1106
1107 let result = if let Some(path_str) = temp_file.to_str() {
1109 println!("Creative writing enhanced model weight loading - weights successfully processed from {:?}", path_str);
1110 Ok(())
1111 } else {
1112 Err(trustformers_core::errors::TrustformersError::io_error(
1113 "Failed to convert temporary file path to string".to_string(),
1114 ))
1115 };
1116
1117 let _ = std::fs::remove_file(&temp_file);
1119
1120 result
1121 }
1122
1123 fn get_config(&self) -> &Self::Config {
1124 &self.config
1125 }
1126
1127 fn num_parameters(&self) -> usize {
1128 self.model.num_parameters() + self.lm_head.parameter_count()
1129 }
1130}
1131
1132impl CreativeWritingForCausalLM {
1134 fn enhance_prompt_for_creativity(&self, prompt: &str) -> Result<String> {
1135 let special_tokens = self.config.get_special_tokens();
1136 let enhanced = format!(
1137 "{}{}{}",
1138 special_tokens.character_start, prompt, special_tokens.character_end
1139 );
1140 Ok(enhanced)
1141 }
1142
1143 fn format_story_prompt(&self, prompt: &str) -> Result<String> {
1144 let special_tokens = self.config.get_special_tokens();
1145 let formatted = format!(
1146 "{}{}{} {}",
1147 special_tokens.setting_start, prompt, special_tokens.setting_end, "Once upon a time"
1148 );
1149 Ok(formatted)
1150 }
1151
1152 fn format_dialogue_prompt(&self, context: &str, character_names: &[&str]) -> Result<String> {
1153 let special_tokens = self.config.get_special_tokens();
1154 let characters = character_names.join(", ");
1155 let formatted = format!(
1156 "{}{}{} {}Characters: {}{}",
1157 special_tokens.setting_start,
1158 context,
1159 special_tokens.setting_end,
1160 special_tokens.dialogue_start,
1161 characters,
1162 special_tokens.dialogue_end
1163 );
1164 Ok(formatted)
1165 }
1166
1167 fn format_poetry_prompt(&self, style: PoetryStyle, topic: &str) -> Result<String> {
1168 let style_instruction = match style {
1169 PoetryStyle::Haiku => "Write a haiku (5-7-5 syllables)",
1170 PoetryStyle::Sonnet => "Write a sonnet (14 lines, ABAB CDCD EFEF GG)",
1171 PoetryStyle::Limerick => "Write a limerick (AABBA rhyme scheme)",
1172 PoetryStyle::FreeVerse => "Write a free verse poem",
1173 _ => "Write a poem",
1174 };
1175
1176 let formatted = format!("{} about: {}", style_instruction, topic);
1177 Ok(formatted)
1178 }
1179
1180 fn generate_with_config(&self, prompt: &str, _config: &GenerationConfig) -> Result<String> {
1181 Ok(format!("[Generated]: {}", prompt))
1184 }
1185
1186 fn analyze_story_context(&self, story: &str) -> Result<String> {
1187 let word_count = story.split_whitespace().count();
1189 let context = if word_count > 50 {
1190 "Long narrative context"
1191 } else {
1192 "Short narrative context"
1193 };
1194 Ok(context.to_string())
1195 }
1196
1197 fn detect_genre(&self, text: &str) -> Result<WritingGenre> {
1198 let text_lower = text.to_lowercase();
1200 if text_lower.contains("magic")
1201 || text_lower.contains("dragon")
1202 || text_lower.contains("wizard")
1203 {
1204 Ok(WritingGenre::Fantasy)
1205 } else if text_lower.contains("space")
1206 || text_lower.contains("robot")
1207 || text_lower.contains("future")
1208 {
1209 Ok(WritingGenre::ScienceFiction)
1210 } else if text_lower.contains("love")
1211 || text_lower.contains("heart")
1212 || text_lower.contains("romance")
1213 {
1214 Ok(WritingGenre::Romance)
1215 } else {
1216 Ok(WritingGenre::General)
1217 }
1218 }
1219
1220 fn detect_writing_style(&self, text: &str) -> Result<WritingStyle> {
1221 let sentences = text.split(&['.', '!', '?']).collect::<Vec<_>>();
1222 let avg_sentence_length = if !sentences.is_empty() {
1223 text.len() as f32 / sentences.len() as f32
1224 } else {
1225 0.0
1226 };
1227
1228 if avg_sentence_length > 100.0 {
1229 Ok(WritingStyle::Ornate)
1230 } else if text.contains('"') {
1231 Ok(WritingStyle::DialogueDriven)
1232 } else if avg_sentence_length < 50.0 {
1233 Ok(WritingStyle::Minimalist)
1234 } else {
1235 Ok(WritingStyle::Descriptive)
1236 }
1237 }
1238
1239 fn detect_narrative_perspective(&self, text: &str) -> Result<NarrativePerspective> {
1240 let text_lower = text.to_lowercase();
1241 if text_lower.contains(" i ") || text_lower.starts_with("i ") {
1242 Ok(NarrativePerspective::FirstPerson)
1243 } else if text_lower.contains(" you ") || text_lower.starts_with("you ") {
1244 Ok(NarrativePerspective::SecondPerson)
1245 } else {
1246 Ok(NarrativePerspective::ThirdPersonLimited)
1247 }
1248 }
1249
1250 fn detect_literary_devices(&self, text: &str) -> Result<Vec<LiteraryDevice>> {
1251 let mut devices = Vec::new();
1252
1253 if text.contains('"') {
1254 devices.push(LiteraryDevice::Dialogue);
1255 }
1256 if text.contains(" like ") || text.contains(" as ") {
1257 devices.push(LiteraryDevice::Metaphor);
1258 }
1259 if text.contains("seemed") || text.contains("appeared") {
1260 devices.push(LiteraryDevice::Imagery);
1261 }
1262
1263 Ok(devices)
1264 }
1265
1266 fn calculate_readability_score(&self, text: &str) -> Result<f32> {
1267 let words = text.split_whitespace().count();
1269 let sentences = text.split(&['.', '!', '?']).count();
1270
1271 if sentences == 0 {
1272 return Ok(0.0);
1273 }
1274
1275 let avg_sentence_length = words as f32 / sentences as f32;
1276 let score = 1.0 - (avg_sentence_length / 50.0).min(1.0);
1277 Ok(score.max(0.0))
1278 }
1279
1280 fn calculate_vocabulary_richness(&self, text: &str) -> Result<f32> {
1281 let words: Vec<&str> = text.split_whitespace().collect();
1283 let unique_words: std::collections::HashSet<&str> = words.iter().cloned().collect();
1284
1285 if words.is_empty() {
1286 return Ok(0.0);
1287 }
1288
1289 let richness = unique_words.len() as f32 / words.len() as f32;
1290 Ok(richness)
1291 }
1292
1293 fn detect_emotional_tone(&self, text: &str) -> Result<EmotionalTone> {
1294 let text_lower = text.to_lowercase();
1295 if text_lower.contains("happy")
1296 || text_lower.contains("joy")
1297 || text_lower.contains("laugh")
1298 {
1299 Ok(EmotionalTone::Joyful)
1300 } else if text_lower.contains("sad")
1301 || text_lower.contains("cry")
1302 || text_lower.contains("tear")
1303 {
1304 Ok(EmotionalTone::Melancholic)
1305 } else if text_lower.contains("dark")
1306 || text_lower.contains("fear")
1307 || text_lower.contains("death")
1308 {
1309 Ok(EmotionalTone::Dark)
1310 } else if text_lower.contains("love")
1311 || text_lower.contains("heart")
1312 || text_lower.contains("kiss")
1313 {
1314 Ok(EmotionalTone::Romantic)
1315 } else {
1316 Ok(EmotionalTone::Neutral)
1317 }
1318 }
1319
1320 fn analyze_character_development(&self, text: &str) -> Result<f32> {
1321 let character_indicators = ["he", "she", "they", "character", "person"];
1323 let mut score = 0.0;
1324
1325 for indicator in &character_indicators {
1326 if text.to_lowercase().contains(indicator) {
1327 score += 0.2;
1328 }
1329 }
1330
1331 Ok(f32::min(score, 1.0))
1332 }
1333
1334 fn analyze_dialogue_quality(&self, text: &str) -> Result<f32> {
1335 let quote_count = text.matches('"').count();
1337 let dialogue_score = if quote_count > 0 {
1338 (quote_count as f32 / text.len() as f32 * 100.0).min(1.0)
1339 } else {
1340 0.0
1341 };
1342
1343 Ok(dialogue_score)
1344 }
1345}
1346
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Aggregated result of a stylistic analysis of a piece of text.
///
/// NOTE(review): the code that fills this struct is outside this view. The
/// per-field notes below name the helper that appears to produce each value.
/// Confirm against the constructor.
pub struct StyleAnalysis {
    /// Genre inferred for the analyzed text.
    pub detected_genre: WritingGenre,
    /// Dominant prose style (presumably from `detect_writing_style`).
    pub writing_style: WritingStyle,
    /// Narrative point of view (presumably from `detect_narrative_perspective`).
    pub narrative_perspective: NarrativePerspective,
    /// Literary devices found in the text (presumably from `detect_literary_devices`).
    pub literary_devices_used: Vec<LiteraryDevice>,
    /// Readability score (presumably from `calculate_readability_score`, which
    /// yields a value in 0.0..=1.0 where higher means shorter sentences).
    pub readability_score: f32,
    /// Unique-word to total-word ratio (presumably from `calculate_vocabulary_richness`).
    pub vocabulary_richness: f32,
    /// Sentence complexity measure. No helper in this view computes it.
    /// TODO: document its scale.
    pub sentence_complexity: f32,
    /// Dominant emotional tone (presumably from `detect_emotional_tone`).
    pub emotional_tone: EmotionalTone,
    /// Character-reference score (presumably from `analyze_character_development`).
    pub character_development_score: f32,
    /// Dialogue density score (presumably from `analyze_dialogue_quality`).
    pub dialogue_quality: f32,
}
1361
1362#[derive(Debug, Clone, Serialize, Deserialize)]
1364pub enum EmotionalTone {
1365 Joyful,
1366 Melancholic,
1367 Suspenseful,
1368 Romantic,
1369 Dark,
1370 Humorous,
1371 Nostalgic,
1372 Hopeful,
1373 Neutral,
1374}
1375
#[derive(Debug, Clone, Serialize, Deserialize)]
/// A single suggested edit to a piece of writing.
pub struct WritingImprovement {
    /// Category of the suggestion.
    pub suggestion_type: ImprovementType,
    /// Where in the text the suggestion applies.
    /// NOTE(review): the format (offset, line, free-form description) is not
    /// established here. Confirm with the producer.
    pub location: String,
    /// Text the suggestion refers to, as it currently reads.
    pub original_text: String,
    /// Proposed replacement for `original_text`.
    pub suggested_text: String,
    /// Human-readable reason for the suggestion.
    pub explanation: String,
    /// Confidence in the suggestion; presumably 0.0..=1.0. TODO: confirm.
    pub confidence: f32,
}
1386
1387#[derive(Debug, Clone, Serialize, Deserialize)]
1388pub enum ImprovementType {
1389 WordChoice,
1390 SentenceStructure,
1391 Dialogue,
1392 Pacing,
1393 CharacterDevelopment,
1394 PlotStructure,
1395 Imagery,
1396 Consistency,
1397}
1398
1399#[derive(Debug, Clone, Serialize, Deserialize)]
1401pub enum PoetryStyle {
1402 FreeVerse,
1403 Sonnet,
1404 Haiku,
1405 Limerick,
1406 Ballad,
1407 Acrostic,
1408 BlankVerse,
1409 Villanelle,
1410}
1411
#[cfg(test)]
mod tests {
    use super::*;

    // The general 7B preset: general genre, 35k vocabulary, character development on.
    #[test]
    fn test_creative_writing_config() {
        let config = CreativeWritingConfig::creative_writing_7b();
        assert_eq!(config.genre, WritingGenre::General);
        assert_eq!(config.vocab_size, 35000);
        assert!(config.character_development);
    }

    // The fantasy preset enables world building and a 20480-position context.
    #[test]
    fn test_fantasy_config() {
        let config = CreativeWritingConfig::fantasy_7b();
        assert_eq!(config.genre, WritingGenre::Fantasy);
        assert!(config.world_building);
        assert_eq!(config.max_position_embeddings, 20480);
    }

    // The poetry preset enables creative constraints and lists metaphor among its devices.
    #[test]
    fn test_poetry_config() {
        let config = CreativeWritingConfig::poetry_7b();
        assert_eq!(config.genre, WritingGenre::Poetry);
        assert!(config.creative_constraints);
        assert!(config.literary_devices.contains(&LiteraryDevice::Metaphor));
    }

    // Special tokens expose the dialogue-start marker and the scene-break separator.
    #[test]
    fn test_special_tokens() {
        let config = CreativeWritingConfig::creative_writing_7b();
        let tokens = config.get_special_tokens();
        assert_eq!(tokens.dialogue_start, "<dialogue>");
        assert_eq!(tokens.scene_break, "---");
    }

    // Building a config from a genre and a size string keeps the requested genre.
    #[test]
    fn test_genre_and_size_creation() {
        let config = CreativeWritingConfig::from_genre_and_size(WritingGenre::Mystery, "7b")
            .expect("from_genre_and_size should return a config for size \"7b\"");
        assert_eq!(config.genre, WritingGenre::Mystery);
    }

    // The romance preset passes config validation.
    #[test]
    fn test_config_validation() {
        let config = CreativeWritingConfig::romance_7b();
        assert!(
            config.validate().is_ok(),
            "romance_7b config should pass validation"
        );
    }
}