1use crate::{Error, Result};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
24pub enum SubtitleFormat {
25 Srt,
27 Vtt,
29}
30
31impl fmt::Display for SubtitleFormat {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 match self {
34 Self::Srt => write!(f, "srt"),
35 Self::Vtt => write!(f, "vtt"),
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct SubtitleCue {
43 pub index: usize,
45 pub start_secs: f64,
47 pub end_secs: f64,
49 pub text: String,
51}
52
53#[derive(Debug, Clone)]
55pub struct SubtitleTrack {
56 pub format: SubtitleFormat,
58 pub cues: Vec<SubtitleCue>,
60}
61
62impl SubtitleTrack {
63 #[must_use]
65 pub fn duration_secs(&self) -> f64 {
66 self.cues.last().map(|c| c.end_secs).unwrap_or(0.0)
67 }
68
69 #[must_use]
71 pub fn to_plain_text(&self) -> String {
72 self.cues.iter().map(|c| c.text.as_str()).collect::<Vec<_>>().join(" ")
73 }
74
75 #[must_use]
77 pub fn cues_in_range(&self, start: f64, end: f64) -> Vec<&SubtitleCue> {
78 self.cues.iter().filter(|c| c.end_secs > start && c.start_secs < end).collect()
79 }
80
81 #[must_use]
83 pub fn to_srt_string(&self) -> String {
84 use std::fmt::Write;
85 let mut out = String::new();
86 for (i, cue) in self.cues.iter().enumerate() {
87 if i > 0 {
88 out.push('\n');
89 }
90 let _ = writeln!(out, "{}", i + 1);
91 let _ = writeln!(
92 out,
93 "{} --> {}",
94 format_srt_time(cue.start_secs),
95 format_srt_time(cue.end_secs),
96 );
97 out.push_str(&cue.text);
98 out.push('\n');
99 }
100 out
101 }
102}
103
/// Formats a time in seconds as an SRT timestamp, `HH:MM:SS,mmm`.
///
/// Negative inputs are clamped to zero and the value is rounded to the
/// nearest millisecond.
// The value is clamped to be non-negative before the cast, so no sign is lost.
#[allow(clippy::cast_sign_loss)]
fn format_srt_time(secs: f64) -> String {
    const MS_PER_SEC: u64 = 1_000;
    const MS_PER_MIN: u64 = 60 * MS_PER_SEC;
    const MS_PER_HOUR: u64 = 60 * MS_PER_MIN;

    let clamped = secs.max(0.0);
    let whole_ms = (clamped * 1000.0).round() as u64;

    let hours = whole_ms / MS_PER_HOUR;
    let minutes = (whole_ms / MS_PER_MIN) % 60;
    let seconds = (whole_ms / MS_PER_SEC) % 60;
    let millis = whole_ms % MS_PER_SEC;

    format!("{hours:02}:{minutes:02}:{seconds:02},{millis:03}")
}
116
/// Formats a time in seconds for display as `M:SS`, or `H:MM:SS` once the
/// value reaches a full hour.
///
/// Negative inputs are clamped to zero and the value is rounded to the
/// nearest whole second.
#[must_use]
// The value is clamped to be non-negative before the cast, so no sign is lost.
#[allow(clippy::cast_sign_loss)]
pub fn format_display_time(secs: f64) -> String {
    let whole = secs.max(0.0).round() as u64;
    let (hours, minutes, seconds) = (whole / 3600, (whole / 60) % 60, whole % 60);
    match hours {
        0 => format!("{minutes}:{seconds:02}"),
        _ => format!("{hours}:{minutes:02}:{seconds:02}"),
    }
}
132
133pub fn parse_subtitles(input: &str) -> Result<SubtitleTrack> {
135 let trimmed = strip_bom(input);
136 if trimmed.starts_with("WEBVTT") {
137 parse_vtt(trimmed)
138 } else {
139 parse_srt(trimmed)
140 }
141}
142
/// Returns `s` without one leading U+FEFF byte-order mark, if present.
fn strip_bom(s: &str) -> &str {
    match s.strip_prefix('\u{FEFF}') {
        Some(rest) => rest,
        None => s,
    }
}
147
/// Normalizes line endings and splits `input` into blank-line-separated blocks.
///
/// `\r\n` and lone `\r` are converted to `\n`. Any line that is empty or
/// contains only whitespace acts as a separator, so files whose "blank" lines
/// carry stray spaces or tabs are still split into individual cues instead of
/// being merged. Each returned block holds its non-blank lines joined with
/// `\n`, with no leading or trailing blank lines. Input without any non-blank
/// line yields an empty vector.
fn normalize_and_split(input: &str) -> Vec<String> {
    let normalized = input.replace("\r\n", "\n").replace('\r', "\n");

    let mut blocks = Vec::new();
    let mut current: Vec<&str> = Vec::new();
    for line in normalized.split('\n') {
        if line.trim().is_empty() {
            // A blank line closes the block collected so far, if there is one.
            if !current.is_empty() {
                blocks.push(current.join("\n"));
                current.clear();
            }
        } else {
            current.push(line);
        }
    }
    if !current.is_empty() {
        blocks.push(current.join("\n"));
    }
    blocks
}
155
/// Returns the index of the first line containing the `-->` arrow, if any.
fn find_timestamp_index(lines: &[&str]) -> Option<usize> {
    for (idx, line) in lines.iter().enumerate() {
        if line.contains("-->") {
            return Some(idx);
        }
    }
    None
}
160
161fn find_srt_timestamp_index(lines: &[&str]) -> Option<usize> {
163 find_timestamp_index(lines).filter(|_| lines.len() >= 2)
164}
165
/// Reads the numeric cue index from the first line of an SRT block.
///
/// When the timestamp is itself the first line (`ts_idx == 0`) there is no
/// index line, and `fallback` is returned. `fallback` is also returned when
/// the first line, after trimming, is not a valid `usize`.
fn parse_srt_index(lines: &[&str], ts_idx: usize, fallback: usize) -> usize {
    if ts_idx == 0 {
        return fallback;
    }
    match lines[0].trim().parse::<usize>() {
        Ok(number) => number,
        Err(_) => fallback,
    }
}
174
/// Joins every line after the timestamp line with `\n` and trims the result.
///
/// Panics if `ts_idx + 1` is greater than `lines.len()`.
fn extract_cue_text(lines: &[&str], ts_idx: usize) -> String {
    let body_lines = &lines[ts_idx + 1..];
    let joined = body_lines.join("\n");
    String::from(joined.trim())
}
179
180fn build_srt_cue(index: usize, start: f64, end: f64, text: String) -> Option<SubtitleCue> {
182 if text.is_empty() {
183 return None;
184 }
185 Some(SubtitleCue { index: index.saturating_sub(1), start_secs: start, end_secs: end, text })
186}
187
188fn parse_srt_block(block: &str, fallback_index: usize) -> Result<Option<SubtitleCue>> {
190 let lines: Vec<&str> = block.lines().collect();
191 let Some(ts_idx) = find_srt_timestamp_index(&lines) else {
192 return Ok(None);
193 };
194
195 let index = parse_srt_index(&lines, ts_idx, fallback_index);
196 let (start, end) = parse_timestamp_line(lines[ts_idx], ',')?;
197 let text = extract_cue_text(&lines, ts_idx);
198 Ok(build_srt_cue(index, start, end, text))
199}
200
201fn reindex_cues(cues: &mut [SubtitleCue]) {
203 for (i, cue) in cues.iter_mut().enumerate() {
204 cue.index = i;
205 }
206}
207
208fn parse_srt(input: &str) -> Result<SubtitleTrack> {
210 let blocks = normalize_and_split(input);
211 let mut cues = Vec::new();
212
213 for block in &blocks {
214 if let Some(cue) = parse_srt_block(block, cues.len())? {
215 cues.push(cue);
216 }
217 }
218
219 if cues.is_empty() {
220 return Err(Error::InvalidInput("No valid SRT cues found".into()));
221 }
222
223 reindex_cues(&mut cues);
224
225 Ok(SubtitleTrack { format: SubtitleFormat::Srt, cues })
226}
227
/// Returns everything after the first blank line (`\n\n`) of `normalized`.
///
/// The part before that blank line is discarded. When there is no blank line
/// at all, the body is empty.
fn vtt_body(normalized: &str) -> &str {
    match normalized.split_once("\n\n") {
        Some((_header, body)) => body,
        None => "",
    }
}
234
235fn extract_vtt_cue_text(lines: &[&str], ts_idx: usize) -> String {
237 strip_vtt_tags(&lines[ts_idx + 1..].join("\n")).trim().to_string()
238}
239
240fn build_vtt_cue(index: usize, start: f64, end: f64, text: String) -> Option<SubtitleCue> {
242 if text.is_empty() {
243 return None;
244 }
245 Some(SubtitleCue { index, start_secs: start, end_secs: end, text })
246}
247
248fn parse_vtt_block(block: &str, index: usize) -> Result<Option<SubtitleCue>> {
250 let lines: Vec<&str> = block.lines().collect();
251 let Some(ts_idx) = find_timestamp_index(&lines) else {
252 return Ok(None);
253 };
254
255 let (start, end) = parse_timestamp_line(lines[ts_idx], '.')?;
256 let text = extract_vtt_cue_text(&lines, ts_idx);
257 Ok(build_vtt_cue(index, start, end, text))
258}
259
260fn parse_vtt(input: &str) -> Result<SubtitleTrack> {
262 let normalized = input.replace("\r\n", "\n").replace('\r', "\n");
263 let body = vtt_body(&normalized);
264 let mut cues = Vec::new();
265
266 for block in body.split("\n\n").filter(|b| !b.trim().is_empty()) {
267 if let Some(cue) = parse_vtt_block(block, cues.len())? {
268 cues.push(cue);
269 }
270 }
271
272 if cues.is_empty() {
273 return Err(Error::InvalidInput("No valid VTT cues found".into()));
274 }
275
276 Ok(SubtitleTrack { format: SubtitleFormat::Vtt, cues })
277}
278
279fn split_arrow(line: &str) -> Result<(&str, &str)> {
283 line.split_once("-->")
284 .ok_or_else(|| Error::InvalidInput(format!("Invalid timestamp line: {line}")))
285}
286
/// Returns the first whitespace-delimited token of `end_half`, or `""` when
/// there is none.
///
/// This drops anything that follows the end timestamp on the same line, such
/// as VTT cue settings.
fn extract_end_timestamp(end_half: &str) -> &str {
    match end_half.split_whitespace().next() {
        Some(token) => token,
        None => "",
    }
}
291
292fn parse_timestamp_line(line: &str, ms_sep: char) -> Result<(f64, f64)> {
294 let (start_half, end_half) = split_arrow(line)?;
295 let start = parse_time(start_half.trim(), ms_sep)?;
296 let end = parse_time(extract_end_timestamp(end_half), ms_sep)?;
297 Ok((start, end))
298}
299
300fn parse_ts_field(field: &str, label: &str, raw: &str) -> Result<f64> {
302 field.parse().map_err(|e| Error::InvalidInput(format!("Bad timestamp {label} '{raw}': {e}")))
303}
304
305fn secs_from_mm_ss(parts: &[&str], raw: &str) -> Result<f64> {
307 let mins = parse_ts_field(parts[0], "minutes", raw)?;
308 let secs = parse_ts_field(parts[1], "seconds", raw)?;
309 Ok(mins * 60.0 + secs)
310}
311
312fn secs_from_hh_mm_ss(parts: &[&str], raw: &str) -> Result<f64> {
314 let hrs = parse_ts_field(parts[0], "hours", raw)?;
315 let mins = parse_ts_field(parts[1], "minutes", raw)?;
316 let secs = parse_ts_field(parts[2], "seconds", raw)?;
317 Ok(hrs * 3600.0 + mins * 60.0 + secs)
318}
319
320fn parse_time(s: &str, ms_sep: char) -> Result<f64> {
323 let normalized = s.replace(ms_sep, ".");
324 let parts: Vec<&str> = normalized.split(':').collect();
325 match parts.len() {
326 2 => secs_from_mm_ss(&parts, s),
327 3 => secs_from_hh_mm_ss(&parts, s),
328 _ => Err(Error::InvalidInput(format!("Invalid timestamp: {s}"))),
329 }
330}
331
/// Decides whether `ch` is visible cue text, tracking tag state in `in_tag`.
///
/// Returns `true` when the character should be kept. A `<` opens a tag and is
/// dropped; every character up to and including the closing `>` is dropped as
/// well. A `>` that appears while no tag is open is ordinary text and is kept,
/// so content such as `a > b` or `->` survives tag stripping.
fn vtt_tag_filter(ch: char, in_tag: &mut bool) -> bool {
    match ch {
        '<' => {
            *in_tag = true;
            false
        }
        // Only a `>` that actually terminates an open tag is swallowed.
        '>' if *in_tag => {
            *in_tag = false;
            false
        }
        _ => !*in_tag,
    }
}
349
350fn strip_vtt_tags(s: &str) -> String {
352 let mut in_tag = false;
353 s.chars().filter(|&ch| vtt_tag_filter(ch, &mut in_tag)).collect()
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing times expressed in seconds.
    const TOLERANCE: f64 = 0.01;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Shorthand for constructing a cue in fixtures.
    fn cue(index: usize, start_secs: f64, end_secs: f64, text: &str) -> SubtitleCue {
        SubtitleCue {
            index,
            start_secs,
            end_secs,
            text: text.into(),
        }
    }

    /// Wraps cues in an SRT-format track.
    fn srt_track(cues: Vec<SubtitleCue>) -> SubtitleTrack {
        SubtitleTrack {
            format: SubtitleFormat::Srt,
            cues,
        }
    }

    // ---- SRT parsing -----------------------------------------------------

    #[test]
    fn test_parse_srt_basic() {
        let input = concat!(
            "1\n",
            "00:00:01,000 --> 00:00:04,500\n",
            "Welcome to this lecture.\n",
            "\n",
            "2\n",
            "00:00:05,000 --> 00:00:09,200\n",
            "Today we cover supervised learning.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.format, SubtitleFormat::Srt);
        assert_eq!(track.cues.len(), 2);

        let (first, second) = (&track.cues[0], &track.cues[1]);
        assert_eq!(first.index, 0);
        assert_close(first.start_secs, 1.0);
        assert_close(first.end_secs, 4.5);
        assert_eq!(first.text, "Welcome to this lecture.");
        assert_close(second.start_secs, 5.0);
    }

    #[test]
    fn test_parse_srt_multiline_text() {
        let input = concat!(
            "1\n",
            "00:00:01,000 --> 00:00:04,500\n",
            "Line one of the cue\n",
            "and line two of the cue.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
        assert_eq!(
            track.cues[0].text,
            "Line one of the cue\nand line two of the cue."
        );
    }

    #[test]
    fn test_parse_srt_with_bom() {
        let input = "\u{FEFF}1\n00:00:01,000 --> 00:00:04,500\nHello.\n";
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
        assert_eq!(track.cues[0].text, "Hello.");
    }

    #[test]
    fn test_parse_srt_crlf() {
        let input = "1\r\n00:00:01,000 --> 00:00:04,500\r\nHello.\r\n\r\n2\r\n00:00:05,000 --> 00:00:09,000\r\nWorld.\r\n";
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 2);
    }

    #[test]
    fn test_parse_srt_empty_cue_skipped() {
        // The first cue has no text and is followed by two blank lines.
        let input = concat!(
            "1\n",
            "00:00:01,000 --> 00:00:04,500\n",
            "\n",
            "\n",
            "2\n",
            "00:00:05,000 --> 00:00:09,000\n",
            "Actual text.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
        assert_eq!(track.cues[0].text, "Actual text.");
    }

    #[test]
    fn test_parse_srt_error_on_empty() {
        assert!(parse_subtitles("").is_err());
    }

    // ---- VTT parsing -----------------------------------------------------

    #[test]
    fn test_parse_vtt_basic() {
        let input = concat!(
            "WEBVTT\n",
            "\n",
            "00:00:01.000 --> 00:00:04.500\n",
            "Welcome to this lecture.\n",
            "\n",
            "00:00:05.000 --> 00:00:09.200\n",
            "Today we cover supervised learning.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.format, SubtitleFormat::Vtt);
        assert_eq!(track.cues.len(), 2);

        let first = &track.cues[0];
        assert_close(first.start_secs, 1.0);
        assert_close(first.end_secs, 4.5);
        assert_eq!(first.text, "Welcome to this lecture.");
    }

    #[test]
    fn test_parse_vtt_with_cue_ids() {
        let input = concat!(
            "WEBVTT\n",
            "\n",
            "intro-1\n",
            "00:00:01.000 --> 00:00:04.500\n",
            "Hello world.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
        assert_eq!(track.cues[0].text, "Hello world.");
    }

    #[test]
    fn test_parse_vtt_with_metadata_header() {
        let input = concat!(
            "WEBVTT\n",
            "Kind: captions\n",
            "Language: en\n",
            "\n",
            "00:00:01.000 --> 00:00:04.500\n",
            "Hello.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
    }

    #[test]
    fn test_parse_vtt_strips_tags() {
        let input = concat!(
            "WEBVTT\n",
            "\n",
            "00:00:01.000 --> 00:00:04.500\n",
            "<b>Bold</b> and <i>italic</i> text.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues[0].text, "Bold and italic text.");
    }

    #[test]
    fn test_parse_vtt_mm_ss_format() {
        let input = concat!(
            "WEBVTT\n",
            "\n",
            "01:30.000 --> 02:00.000\n",
            "Short timestamp format.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_close(track.cues[0].start_secs, 90.0);
        assert_close(track.cues[0].end_secs, 120.0);
    }

    #[test]
    fn test_parse_vtt_position_settings() {
        let input = concat!(
            "WEBVTT\n",
            "\n",
            "00:00:01.000 --> 00:00:04.500 position:10% align:start\n",
            "Positioned text.\n",
        );
        let track = parse_subtitles(input).unwrap();
        assert_eq!(track.cues.len(), 1);
        assert_close(track.cues[0].end_secs, 4.5);
    }

    // ---- SubtitleTrack methods -------------------------------------------

    #[test]
    fn test_track_duration() {
        let track = srt_track(vec![cue(0, 0.0, 5.0, "A"), cue(1, 5.0, 120.5, "B")]);
        assert_close(track.duration_secs(), 120.5);
    }

    #[test]
    fn test_track_duration_empty() {
        let track = srt_track(Vec::new());
        assert_close(track.duration_secs(), 0.0);
    }

    #[test]
    fn test_track_plain_text() {
        let track = srt_track(vec![cue(0, 0.0, 3.0, "Hello"), cue(1, 3.0, 6.0, "world")]);
        assert_eq!(track.to_plain_text(), "Hello world");
    }

    #[test]
    fn test_track_cues_in_range() {
        let track = srt_track(vec![
            cue(0, 0.0, 5.0, "A"),
            cue(1, 5.0, 10.0, "B"),
            cue(2, 10.0, 15.0, "C"),
        ]);
        let texts_in = |start: f64, end: f64| -> Vec<&str> {
            track
                .cues_in_range(start, end)
                .into_iter()
                .map(|c| c.text.as_str())
                .collect()
        };

        // A wide window overlaps all three cues.
        assert_eq!(texts_in(4.0, 11.0), ["A", "B", "C"]);
        // A window strictly inside the middle cue matches only that cue.
        assert_eq!(texts_in(6.0, 9.0), ["B"]);
    }

    #[test]
    fn test_srt_roundtrip() {
        let input = concat!(
            "1\n",
            "00:00:01,000 --> 00:00:04,500\n",
            "Hello world.\n",
            "\n",
            "2\n",
            "00:01:30,500 --> 00:02:00,000\n",
            "Second cue here.\n",
        );
        let track = parse_subtitles(input).unwrap();
        let rendered = track.to_srt_string();
        let reparsed = parse_srt(&rendered).unwrap();

        assert_eq!(reparsed.cues.len(), track.cues.len());
        for (before, after) in track.cues.iter().zip(&reparsed.cues) {
            assert_close(before.start_secs, after.start_secs);
            assert_close(before.end_secs, after.end_secs);
            assert_eq!(before.text, after.text);
        }
    }

    // ---- Timestamp parsing -----------------------------------------------

    #[test]
    fn test_parse_time_zero() {
        let secs = parse_time("00:00:00.000", '.').unwrap();
        assert!(secs.abs() < 0.001);
    }

    #[test]
    fn test_parse_time_large() {
        let secs = parse_time("99:59:59.999", '.').unwrap();
        assert_close(secs, 99.0 * 3600.0 + 59.0 * 60.0 + 59.999);
    }

    #[test]
    fn test_parse_time_mm_ss() {
        let secs = parse_time("01:30.500", '.').unwrap();
        assert_close(secs, 90.5);
    }

    #[test]
    fn test_parse_time_invalid() {
        for bad in ["invalid", "1:2:3:4"] {
            assert!(parse_time(bad, '.').is_err(), "{bad:?} should be rejected");
        }
    }

    // ---- Formatting ------------------------------------------------------

    #[test]
    fn test_format_display_time() {
        let cases = [
            (0.0, "0:00"),
            (65.0, "1:05"),
            (3661.0, "1:01:01"),
            (90.4, "1:30"),
        ];
        for (secs, expected) in cases {
            assert_eq!(format_display_time(secs), expected);
        }
    }

    #[test]
    fn test_format_srt_time() {
        let cases = [
            (0.0, "00:00:00,000"),
            (90.5, "00:01:30,500"),
            (3661.123, "01:01:01,123"),
        ];
        for (secs, expected) in cases {
            assert_eq!(format_srt_time(secs), expected);
        }
    }

    #[test]
    fn test_subtitle_format_display() {
        assert_eq!(SubtitleFormat::Srt.to_string(), "srt");
        assert_eq!(SubtitleFormat::Vtt.to_string(), "vtt");
    }

    // ---- Tag stripping ---------------------------------------------------

    #[test]
    fn test_strip_vtt_tags_none() {
        assert_eq!(strip_vtt_tags("plain text"), "plain text");
    }

    #[test]
    fn test_strip_vtt_tags_bold() {
        assert_eq!(strip_vtt_tags("<b>bold</b>"), "bold");
    }

    #[test]
    fn test_strip_vtt_tags_nested() {
        assert_eq!(strip_vtt_tags("<b><i>text</i></b>"), "text");
    }

    #[test]
    fn test_strip_vtt_tags_class() {
        assert_eq!(strip_vtt_tags("<c.highlight>text</c>"), "text");
    }
}