rust_expect/
encoding.rs

//! Encoding detection and conversion utilities.
//!
//! This module provides utilities for handling text encoding in terminal I/O,
//! including UTF-8 validation, encoding detection, and line ending normalization.
5
6use std::borrow::Cow;
7use std::fmt::Write;
8
/// Outcome of decoding a byte sequence into text.
#[derive(Debug, Clone)]
pub struct EncodedText {
    /// Text produced by the decoder.
    pub text: String,
    /// How many bytes of the input were consumed.
    pub bytes_consumed: usize,
    /// `true` when at least one encoding error was recorded.
    pub had_errors: bool,
    /// How many replacements the decoder made.
    pub replacements: usize,
}

impl EncodedText {
    /// Build a result for input that decoded cleanly (no errors, zero replacements).
    #[must_use]
    pub fn ok(text: impl Into<String>, bytes_consumed: usize) -> Self {
        Self::with_errors(text, bytes_consumed, 0)
    }

    /// Build a result that records `replacements` decoding problems.
    ///
    /// `had_errors` is derived from the count: it is set only when
    /// `replacements` is non-zero.
    #[must_use]
    pub fn with_errors(
        text: impl Into<String>,
        bytes_consumed: usize,
        replacements: usize,
    ) -> Self {
        let had_errors = replacements != 0;
        Self {
            text: text.into(),
            bytes_consumed,
            had_errors,
            replacements,
        }
    }
}
49
50/// Decode bytes as UTF-8, replacing invalid sequences.
51///
52/// This is the default behavior for rust-expect. Invalid UTF-8 sequences
53/// are replaced with the Unicode replacement character (U+FFFD).
54#[must_use]
55pub fn decode_utf8_lossy(bytes: &[u8]) -> EncodedText {
56    let text = String::from_utf8_lossy(bytes);
57    let replacements = text.matches('\u{FFFD}').count();
58
59    EncodedText {
60        text: text.into_owned(),
61        bytes_consumed: bytes.len(),
62        had_errors: replacements > 0,
63        replacements,
64    }
65}
66
67/// Decode bytes as UTF-8, returning an error on invalid sequences.
68///
69/// # Errors
70///
71/// Returns an error if the input is not valid UTF-8.
72pub fn decode_utf8_strict(bytes: &[u8]) -> Result<EncodedText, std::str::Utf8Error> {
73    let text = std::str::from_utf8(bytes)?;
74    Ok(EncodedText::ok(text, bytes.len()))
75}
76
77/// Decode bytes as UTF-8, escaping invalid bytes as hex.
78///
79/// Invalid bytes are replaced with `\xHH` escape sequences.
80#[must_use]
81#[allow(unsafe_code)]
82pub fn decode_utf8_escape(bytes: &[u8]) -> EncodedText {
83    let mut result = String::with_capacity(bytes.len());
84    let mut replacements = 0;
85    let mut i = 0;
86
87    while i < bytes.len() {
88        match std::str::from_utf8(&bytes[i..]) {
89            Ok(valid) => {
90                result.push_str(valid);
91                break;
92            }
93            Err(e) => {
94                // Add the valid prefix
95                let valid_up_to = e.valid_up_to();
96                if valid_up_to > 0 {
97                    // Safe: from_utf8 confirmed these bytes are valid
98                    result.push_str(unsafe {
99                        std::str::from_utf8_unchecked(&bytes[i..i + valid_up_to])
100                    });
101                }
102                i += valid_up_to;
103
104                // Handle the invalid byte(s)
105                let error_len = e.error_len().unwrap_or(1);
106                for byte in &bytes[i..i + error_len] {
107                    let _ = write!(result, "\\x{byte:02x}");
108                    replacements += 1;
109                }
110                i += error_len;
111            }
112        }
113    }
114
115    EncodedText::with_errors(result, bytes.len(), replacements)
116}
117
118/// Skip invalid UTF-8 sequences.
119///
120/// Invalid bytes are simply removed from the output.
121#[must_use]
122#[allow(unsafe_code)]
123pub fn decode_utf8_skip(bytes: &[u8]) -> EncodedText {
124    let mut result = String::with_capacity(bytes.len());
125    let mut replacements = 0;
126    let mut i = 0;
127
128    while i < bytes.len() {
129        match std::str::from_utf8(&bytes[i..]) {
130            Ok(valid) => {
131                result.push_str(valid);
132                break;
133            }
134            Err(e) => {
135                let valid_up_to = e.valid_up_to();
136                if valid_up_to > 0 {
137                    result.push_str(unsafe {
138                        std::str::from_utf8_unchecked(&bytes[i..i + valid_up_to])
139                    });
140                }
141                i += valid_up_to;
142
143                let error_len = e.error_len().unwrap_or(1);
144                replacements += error_len;
145                i += error_len;
146            }
147        }
148    }
149
150    EncodedText::with_errors(result, bytes.len(), replacements)
151}
152
153/// Normalize line endings in text.
154///
155/// Converts all line endings (CRLF, CR, LF) to the specified style.
156#[must_use]
157pub fn normalize_line_endings(text: &str, ending: LineEndingStyle) -> Cow<'_, str> {
158    let target = ending.as_str();
159
160    // Check if normalization is needed
161    let needs_crlf = text.contains("\r\n");
162    let needs_cr = text.contains('\r') && !needs_crlf;
163    let needs_lf = text.contains('\n') && !needs_crlf;
164
165    // If already normalized, return as-is
166    match ending {
167        LineEndingStyle::Lf if !needs_crlf && !needs_cr => return Cow::Borrowed(text),
168        LineEndingStyle::CrLf if needs_crlf && !needs_cr && !needs_lf => {
169            return Cow::Borrowed(text);
170        }
171        LineEndingStyle::Cr if needs_cr && !needs_crlf && !needs_lf => return Cow::Borrowed(text),
172        _ => {}
173    }
174
175    // First normalize all endings to LF
176    let normalized = if needs_crlf {
177        text.replace("\r\n", "\n")
178    } else {
179        text.to_string()
180    };
181
182    let normalized = if normalized.contains('\r') {
183        normalized.replace('\r', "\n")
184    } else {
185        normalized
186    };
187
188    // Then convert to target if not LF
189    let result = if target == "\n" {
190        normalized
191    } else {
192        normalized.replace('\n', target)
193    };
194
195    Cow::Owned(result)
196}
197
/// The line terminator conventions understood by this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LineEndingStyle {
    /// Unix-style line feed (`\n`). This is the default.
    #[default]
    Lf,
    /// Windows-style carriage return followed by line feed (`\r\n`).
    CrLf,
    /// Classic Mac carriage return (`\r`).
    Cr,
}

impl LineEndingStyle {
    /// The terminator as a string slice.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Cr => "\r",
            Self::CrLf => "\r\n",
            Self::Lf => "\n",
        }
    }

    /// The terminator as raw bytes (the UTF-8 bytes of [`Self::as_str`]).
    #[must_use]
    pub const fn as_bytes(self) -> &'static [u8] {
        self.as_str().as_bytes()
    }

    /// The conventional style for the platform this crate was compiled for.
    ///
    /// The choice is made from the compile-time `windows` cfg: CRLF on
    /// Windows targets, LF everywhere else.
    #[must_use]
    pub const fn from_env() -> Self {
        if !cfg!(windows) {
            return Self::Lf;
        }
        Self::CrLf
    }
}
237
238/// Detect the predominant line ending in text.
239#[must_use]
240pub fn detect_line_ending(text: &str) -> Option<LineEndingStyle> {
241    let crlf_count = text.matches("\r\n").count();
242    let lf_only_count = text.matches('\n').count().saturating_sub(crlf_count);
243    let cr_only_count = text
244        .chars()
245        .zip(text.chars().skip(1).chain(std::iter::once('\0')))
246        .filter(|&(c, next)| c == '\r' && next != '\n')
247        .count();
248
249    if crlf_count == 0 && lf_only_count == 0 && cr_only_count == 0 {
250        return None;
251    }
252
253    if crlf_count >= lf_only_count && crlf_count >= cr_only_count {
254        Some(LineEndingStyle::CrLf)
255    } else if lf_only_count >= cr_only_count {
256        Some(LineEndingStyle::Lf)
257    } else {
258        Some(LineEndingStyle::Cr)
259    }
260}
261
262/// Detect encoding from environment variables.
263///
264/// Checks `LC_ALL`, `LC_CTYPE`, and `LANG` in order.
265#[must_use]
266pub fn detect_encoding_from_env() -> DetectedEncoding {
267    let locale = std::env::var("LC_ALL")
268        .or_else(|_| std::env::var("LC_CTYPE"))
269        .or_else(|_| std::env::var("LANG"))
270        .unwrap_or_default();
271
272    let locale_lower = locale.to_lowercase();
273
274    if locale_lower.contains("utf-8") || locale_lower.contains("utf8") {
275        DetectedEncoding::Utf8
276    } else if locale_lower.contains("iso-8859-1") || locale_lower.contains("iso8859-1") {
277        DetectedEncoding::Latin1
278    } else if locale_lower.contains("1252") {
279        DetectedEncoding::Windows1252
280    } else if locale.is_empty() {
281        // Default to UTF-8 for modern systems
282        DetectedEncoding::Utf8
283    } else {
284        DetectedEncoding::Unknown(locale)
285    }
286}
287
/// Text encoding inferred from the process environment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DetectedEncoding {
    /// UTF-8.
    Utf8,
    /// ISO-8859-1, also known as Latin-1.
    Latin1,
    /// Windows code page 1252.
    Windows1252,
    /// An encoding that was not recognized; carries the locale string.
    Unknown(String),
}

impl DetectedEncoding {
    /// Returns `true` only for the [`DetectedEncoding::Utf8`] variant.
    #[must_use]
    pub const fn is_utf8(&self) -> bool {
        match self {
            Self::Utf8 => true,
            Self::Latin1 | Self::Windows1252 | Self::Unknown(_) => false,
        }
    }
}
308
/// Strip ANSI escape sequences from text.
///
/// Returns the input borrowed when it contains no ESC character. Otherwise
/// the following sequences are removed:
///
/// - CSI: `ESC [`, any parameter/intermediate bytes (`0x20..=0x3F`), and one
///   final byte (`0x40..=0x7E`, which includes `~` as used by key codes).
///   A malformed or truncated CSI sequence ends at the first character that
///   does not fit this shape, and that character is kept as text.
/// - OSC: `ESC ]` up to and including the terminator, which is either BEL or
///   ST (`ESC \`). An ESC that does not start ST ends the OSC sequence and is
///   then parsed as the start of a new escape sequence.
/// - Character-set designation: `ESC` followed by `(`, `)`, `*` or `+` and
///   one further character.
/// - Two-character sequences: `ESC` followed by an ASCII uppercase letter,
///   `=` or `>`.
///
/// For any other ESC only the ESC character itself is dropped.
#[must_use]
pub fn strip_ansi(text: &str) -> Cow<'_, str> {
    // Quick check: without an escape character there is nothing to strip.
    if !text.contains('\x1b') {
        return Cow::Borrowed(text);
    }

    let mut result = String::with_capacity(text.len());
    let mut chars = text.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\x1b' {
            result.push(c);
            continue;
        }

        // A trailing ESC with nothing after it is simply dropped.
        let next = match chars.peek() {
            Some(&n) => n,
            None => break,
        };

        match next {
            '[' => {
                chars.next(); // consume '['

                // Parameter and intermediate bytes.
                while let Some(&param) = chars.peek() {
                    if !matches!(param, '\x20'..='\x3f') {
                        break;
                    }
                    chars.next();
                }

                // Final byte. Anything else (including a new ESC) is left
                // for the main loop so ordinary text is never swallowed.
                if let Some(&last) = chars.peek() {
                    if matches!(last, '\x40'..='\x7e') {
                        chars.next();
                    }
                }
            }
            ']' => {
                chars.next(); // consume ']'

                while let Some(&osc_char) = chars.peek() {
                    match osc_char {
                        '\x07' => {
                            chars.next(); // BEL terminator
                            break;
                        }
                        '\x1b' => {
                            // Only consume the ESC when it is part of ST;
                            // otherwise leave it so the main loop can parse
                            // the escape sequence it introduces.
                            let mut ahead = chars.clone();
                            ahead.next();
                            if ahead.next() == Some('\\') {
                                chars.next(); // ESC
                                chars.next(); // '\\'
                            }
                            break;
                        }
                        _ => {
                            chars.next();
                        }
                    }
                }
            }
            '(' | ')' | '*' | '+' => {
                // Designate character set: ESC ( X
                chars.next();
                chars.next();
            }
            _ if next.is_ascii_uppercase() || next == '=' || next == '>' => {
                // Simple two-character escape sequence: ESC X
                chars.next();
            }
            _ => {
                // Unknown: drop only the ESC itself.
            }
        }
    }

    Cow::Owned(result)
}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain ASCII decodes unchanged and records no errors.
    #[test]
    fn decode_valid_utf8() {
        let decoded = decode_utf8_lossy(b"hello world");
        assert_eq!(decoded.text, "hello world");
        assert_eq!(decoded.replacements, 0);
        assert!(!decoded.had_errors);
    }

    /// Invalid bytes become U+FFFD and are reported as errors.
    #[test]
    fn decode_invalid_utf8_lossy() {
        let decoded = decode_utf8_lossy(b"hello\xff\xfeworld");
        assert!(decoded.had_errors);
        assert!(decoded.replacements > 0);
        assert!(decoded.text.contains('\u{FFFD}'));
    }

    /// Invalid bytes are rendered as `\xHH` escapes.
    #[test]
    fn decode_invalid_utf8_escape() {
        let decoded = decode_utf8_escape(b"hello\xffworld");
        assert!(decoded.had_errors);
        assert!(decoded.text.contains("\\xff"));
    }

    /// Invalid bytes are dropped from the output.
    #[test]
    fn decode_invalid_utf8_skip() {
        let decoded = decode_utf8_skip(b"hello\xff\xfeworld");
        assert!(decoded.had_errors);
        assert_eq!(decoded.text, "helloworld");
    }

    /// CRLF input converted to LF.
    #[test]
    fn normalize_crlf_to_lf() {
        let normalized = normalize_line_endings("line1\r\nline2\r\nline3", LineEndingStyle::Lf);
        assert_eq!(normalized, "line1\nline2\nline3");
    }

    /// LF input converted to CRLF.
    #[test]
    fn normalize_lf_to_crlf() {
        let normalized = normalize_line_endings("line1\nline2\nline3", LineEndingStyle::CrLf);
        assert_eq!(normalized, "line1\r\nline2\r\nline3");
    }

    /// LF-only text is detected as LF.
    #[test]
    fn detect_line_ending_lf() {
        let detected = detect_line_ending("line1\nline2\n");
        assert_eq!(detected, Some(LineEndingStyle::Lf));
    }

    /// CRLF-only text is detected as CRLF.
    #[test]
    fn detect_line_ending_crlf() {
        let detected = detect_line_ending("line1\r\nline2\r\n");
        assert_eq!(detected, Some(LineEndingStyle::CrLf));
    }

    /// SGR colour sequences are stripped.
    #[test]
    fn strip_ansi_csi() {
        let stripped = strip_ansi("\x1b[32mgreen\x1b[0m text");
        assert_eq!(stripped, "green text");
    }

    /// Text without ESC is returned borrowed and unchanged.
    #[test]
    fn strip_ansi_no_escape() {
        let stripped = strip_ansi("plain text");
        assert_eq!(stripped, "plain text");
        assert!(matches!(stripped, Cow::Borrowed(_)));
    }

    /// A BEL-terminated OSC sequence is stripped.
    #[test]
    fn strip_ansi_osc() {
        let stripped = strip_ansi("\x1b]0;Window Title\x07normal text");
        assert_eq!(stripped, "normal text");
    }
}
458}