1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//! Scalar elements.

use crate::TchError;

/// A single scalar value backed by a torch-sys `C_scalar` handle.
///
/// Instances are created with [`Scalar::int`] / [`Scalar::float`] (or the `From<i64>` /
/// `From<f64>` conversions). The handle is released when the value is dropped.
/// Because the handle is a raw pointer, the type is neither `Send` nor `Sync`.
pub struct Scalar {
    // Handle owned by this value: obtained from `ats_int`/`ats_float`, freed via `ats_free` in `Drop`.
    pub(super) c_scalar: *mut torch_sys::C_scalar,
}

impl Scalar {
    /// Creates an integer scalar holding `v`.
    pub fn int(v: i64) -> Scalar {
        let c_scalar = unsafe_torch!(torch_sys::ats_int(v));
        Scalar { c_scalar }
    }

    /// Creates a float scalar holding `v`.
    pub fn float(v: f64) -> Scalar {
        let c_scalar = unsafe_torch!(torch_sys::ats_float(v));
        Scalar { c_scalar }
    }

    /// Returns the value of this scalar as an `i64`.
    ///
    /// A float scalar is converted (for example, a scalar holding π yields `3`).
    ///
    /// # Errors
    ///
    /// Returns a [`TchError`] if the torch call reports an error (surfaced by
    /// `unsafe_torch_err!`).
    pub fn to_int(&self) -> Result<i64, TchError> {
        Ok(unsafe_torch_err!(torch_sys::ats_to_int(self.c_scalar)))
    }

    /// Returns the value of this scalar as an `f64`.
    ///
    /// # Errors
    ///
    /// Returns a [`TchError`] if the torch call reports an error (surfaced by
    /// `unsafe_torch_err!`).
    pub fn to_float(&self) -> Result<f64, TchError> {
        Ok(unsafe_torch_err!(torch_sys::ats_to_float(self.c_scalar)))
    }

    /// Returns the torch-side string representation of the scalar.
    ///
    /// This is an inherent method returning a `Result`; it is not the `ToString` trait.
    ///
    /// # Errors
    ///
    /// Returns a [`TchError`] if the torch call reports an error, or
    /// `TchError::Kind("nullptr representation")` if torch hands back a null string.
    pub fn to_string(&self) -> Result<String, TchError> {
        let repr = unsafe_torch_err!({
            super::utils::ptr_to_string(torch_sys::ats_to_string(self.c_scalar))
        });
        // `ptr_to_string` yields `None` for a null pointer; surface that as an error.
        repr.ok_or_else(|| TchError::Kind("nullptr representation".to_string()))
    }
}

impl std::fmt::Debug for Scalar {
    /// Renders as `scalar<REPR>` using the torch string representation.
    /// Falls back to the bare text `err` when that representation cannot be obtained.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        if let Ok(repr) = self.to_string() {
            write!(f, "scalar<{}>", repr)
        } else {
            f.write_str("err")
        }
    }
}

impl Drop for Scalar {
    /// Hands the owned torch-sys handle back to `ats_free`.
    fn drop(&mut self) {
        let handle = self.c_scalar;
        unsafe_torch!(torch_sys::ats_free(handle));
    }
}

impl From<i64> for Scalar {
    fn from(v: i64) -> Scalar {
        Scalar::int(v)
    }
}

impl From<f64> for Scalar {
    fn from(v: f64) -> Scalar {
        Scalar::float(v)
    }
}

impl From<Scalar> for i64 {
    /// Consumes the scalar and returns its integer value by delegating to the
    /// `From<&Scalar>` conversion. It panics whenever that conversion does.
    fn from(scalar: Scalar) -> Self {
        let borrowed = &scalar;
        i64::from(borrowed)
    }
}

impl From<Scalar> for f64 {
    /// Consumes the scalar and returns its float value by delegating to the
    /// `From<&Scalar>` conversion. It panics whenever that conversion does.
    fn from(scalar: Scalar) -> Self {
        let borrowed = &scalar;
        f64::from(borrowed)
    }
}

impl From<&Scalar> for i64 {
    /// Extracts the integer value of the scalar.
    ///
    /// # Panics
    ///
    /// Panics if [`Scalar::to_int`] returns an error. Call `to_int` directly to handle
    /// that failure without panicking.
    fn from(s: &Scalar) -> i64 {
        s.to_int().expect("failed to read Scalar as i64")
    }
}

impl From<&Scalar> for f64 {
    /// Extracts the float value of the scalar.
    ///
    /// # Panics
    ///
    /// Panics if [`Scalar::to_float`] returns an error. Call `to_float` directly to
    /// handle that failure without panicking.
    fn from(s: &Scalar) -> f64 {
        s.to_float().expect("failed to read Scalar as f64")
    }
}

#[cfg(test)]
mod tests {
    use super::Scalar;
    use std::f64::consts::PI;

    /// Checks the `From` conversions and `Debug` output for float and integer scalars.
    #[test]
    fn scalar() {
        let pi = Scalar::float(PI);
        assert_eq!(i64::from(&pi), 3);
        assert_eq!(f64::from(&pi), PI);
        assert_eq!(format!("{pi:?}"), "scalar<3.14159>");

        let leet = Scalar::int(1337);
        assert_eq!(i64::from(&leet), 1337);
        assert_eq!(f64::from(&leet), 1337.0);
    }
}