Skip to main content

scivex_core/tensor/
display.rs

1//! `Display` formatting and HTML rendering for [`Tensor`].
2
3use core::fmt;
4
5use crate::Scalar;
6
7use super::Tensor;
8
9impl<T: Scalar> fmt::Display for Tensor<T> {
10    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
11        if self.is_empty() {
12            return write!(f, "tensor([], shape={:?})", self.shape);
13        }
14
15        match self.ndim() {
16            0 => write!(f, "tensor({})", self.data[0]),
17            1 => {
18                write!(f, "tensor([")?;
19                for (i, v) in self.data.iter().enumerate() {
20                    if i > 0 {
21                        write!(f, ", ")?;
22                    }
23                    write!(f, "{v}")?;
24                }
25                write!(f, "])")
26            }
27            2 => {
28                let rows = self.shape[0];
29                let cols = self.shape[1];
30                writeln!(f, "tensor([")?;
31                for r in 0..rows {
32                    write!(f, "  [")?;
33                    for c in 0..cols {
34                        if c > 0 {
35                            write!(f, ", ")?;
36                        }
37                        write!(f, "{}", self.data[r * cols + c])?;
38                    }
39                    if r < rows - 1 {
40                        writeln!(f, "],")?;
41                    } else {
42                        writeln!(f, "]")?;
43                    }
44                }
45                write!(f, "])")
46            }
47            _ => {
48                // For 3-D+ tensors, show shape and flat data summary
49                write!(
50                    f,
51                    "tensor(shape={:?}, data=[{}, {}, ..., {}])",
52                    self.shape,
53                    self.data[0],
54                    self.data[1],
55                    self.data[self.data.len() - 1]
56                )
57            }
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// HTML rendering for Jupyter / evcxr
64// ---------------------------------------------------------------------------
65
66/// Maximum elements to show for 1-D/flat tensors in HTML.
67const MAX_HTML_ELEMENTS: usize = 100;
68
69impl<T: Scalar> Tensor<T> {
70    /// Render the tensor as an HTML string.
71    ///
72    /// - Scalars and 1-D tensors are shown as formatted values.
73    /// - 2-D tensors are rendered as `<table>`.
74    /// - Higher-dimensional tensors show shape and a data summary.
75    pub fn to_html(&self) -> String {
76        use fmt::Write;
77
78        if self.is_empty() {
79            return format!("<pre>tensor([], shape={:?})</pre>", self.shape);
80        }
81
82        match self.ndim() {
83            0 => format!("<pre>tensor({})</pre>", self.data[0]),
84            1 => {
85                let n = self.data.len();
86                let truncated = n > MAX_HTML_ELEMENTS;
87                let show = if truncated { MAX_HTML_ELEMENTS } else { n };
88                let mut s = String::from("<pre>tensor([");
89                for (i, v) in self.data.iter().take(show).enumerate() {
90                    if i > 0 {
91                        s.push_str(", ");
92                    }
93                    let _ = write!(s, "{v}");
94                }
95                if truncated {
96                    let _ = write!(s, ", ... ({n} elements)");
97                }
98                s.push_str("])</pre>");
99                s
100            }
101            2 => {
102                let rows = self.shape[0];
103                let cols = self.shape[1];
104                let mut html = String::with_capacity(64 + rows * cols * 12);
105                let _ = writeln!(
106                    html,
107                    "<div><strong>Tensor</strong> shape=[{rows}, {cols}]</div>"
108                );
109                html.push_str("<table style=\"border-collapse:collapse;\">\n<tbody>\n");
110                let max_rows = 20;
111                let truncated = rows > max_rows;
112                for r in 0..rows.min(max_rows) {
113                    html.push_str("<tr>");
114                    for c in 0..cols {
115                        html.push_str("<td style=\"border:1px solid #ddd;padding:2px 6px;text-align:right;\">");
116                        let _ = write!(html, "{}", self.data[r * cols + c]);
117                        html.push_str("</td>");
118                    }
119                    html.push_str("</tr>\n");
120                }
121                if truncated {
122                    html.push_str("<tr>");
123                    for _ in 0..cols {
124                        html.push_str("<td style=\"border:1px solid #ddd;padding:2px 6px;text-align:center;\">…</td>");
125                    }
126                    html.push_str("</tr>\n");
127                }
128                html.push_str("</tbody>\n</table>\n");
129                html
130            }
131            _ => {
132                format!(
133                    "<pre>tensor(shape={:?}, numel={})</pre>",
134                    self.shape,
135                    self.data.len()
136                )
137            }
138        }
139    }
140
141    /// Display this tensor in an evcxr Jupyter notebook.
142    ///
143    /// Auto-detected by the evcxr kernel for rich HTML output.
144    pub fn evcxr_display(&self) {
145        println!(
146            "EVCXR_BEGIN_CONTENT text/html\n{}\nEVCXR_END_CONTENT",
147            self.to_html()
148        );
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    // --- `Display` ---------------------------------------------------------

    /// A 0-D tensor displays as its bare value.
    #[test]
    fn test_display_scalar() {
        let scalar = Tensor::scalar(42_i32);
        assert_eq!(scalar.to_string(), "tensor(42)");
    }

    /// A 1-D tensor displays as a single comma-separated list.
    #[test]
    fn test_display_1d() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert_eq!(vector.to_string(), "tensor([1, 2, 3])");
    }

    /// A 2-D tensor displays each row as its own bracketed list.
    #[test]
    fn test_display_2d() {
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let rendered = matrix.to_string();
        for needle in ["tensor(", "[1, 2]", "[3, 4]"] {
            assert!(rendered.contains(needle));
        }
    }

    /// An empty tensor displays an empty list.
    #[test]
    fn test_display_empty() {
        let empty = Tensor::<f64>::zeros(vec![0]);
        assert!(empty.to_string().contains("[]"));
    }

    /// A 3-D tensor display includes its shape.
    #[test]
    fn test_display_3d() {
        let cube = Tensor::<i32>::arange(24).reshape(vec![2, 3, 4]).unwrap();
        assert!(cube.to_string().contains("shape=[2, 3, 4]"));
    }

    // --- `to_html` ---------------------------------------------------------

    /// A 0-D tensor renders inside a `<pre>` element.
    #[test]
    fn test_to_html_scalar() {
        let markup = Tensor::scalar(42_i32).to_html();
        for needle in ["tensor(42)", "<pre>"] {
            assert!(markup.contains(needle));
        }
    }

    /// A 1-D tensor renders its values as a list.
    #[test]
    fn test_to_html_1d() {
        let vector = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let markup = vector.to_html();
        assert!(markup.contains("tensor(["));
        for digit in ['1', '3'] {
            assert!(markup.contains(digit));
        }
    }

    /// A 2-D tensor renders as a table with a shape header and one cell per value.
    #[test]
    fn test_to_html_2d() {
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let markup = matrix.to_html();
        for needle in ["<table", "<tbody>", "shape=[2, 2]", ">1</td>", ">4</td>"] {
            assert!(markup.contains(needle));
        }
    }

    /// An empty tensor renders an empty list.
    #[test]
    fn test_to_html_empty() {
        let markup = Tensor::<f64>::zeros(vec![0]).to_html();
        assert!(markup.contains("[]"));
    }

    /// A 3-D tensor renders its shape and element count.
    #[test]
    fn test_to_html_3d() {
        let cube = Tensor::<i32>::arange(24).reshape(vec![2, 3, 4]).unwrap();
        let markup = cube.to_html();
        for needle in ["shape=[2, 3, 4]", "numel=24"] {
            assert!(markup.contains(needle));
        }
    }
}
233}