//! A Rust wrapper around the libtorch tensor type, together with the autograd,
//! autocast, and (de)serialization helpers built on top of it.
use super::stream::ReadSeekAdapter;
use super::utils::{path_to_cstring, ptr_to_string};
use super::{
device::{Cuda, Device},
kind,
kind::Kind,
};
use crate::TchError;
use libc::{c_char, c_int, c_void};
use std::borrow::Borrow;
use std::io::{Read, Seek, Write};
use std::path::Path;
use torch_sys::io::ReadStream;
use torch_sys::*;
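/// A tensor wrapping a pointer to the underlying C++ tensor object.
///
/// The storage is reference-counted on the C++ side, so [`Tensor::shallow_clone`]
/// produces a new handle that shares the same data.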
#[must_use]
pub struct Tensor {
pub(super) c_tensor: *mut C_tensor,
}
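// Safety: the underlying C++ tensor is not tied to the thread that created it,
// so moving the owning handle to another thread is sound.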
unsafe impl Send for Tensor {}
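/// Callback handed to the C loading functions below: it pushes each named
/// tensor into a `Vec<(String, Tensor)>` and undoes the `.` -> `|` name
/// mangling applied when saving.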
pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) {
let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() };
let name = name.replace('|', ".");
let v: &mut Vec<(String, Tensor)> = unsafe { &mut *(data as *mut Vec<(String, Tensor)>) };
v.push((name, Tensor { c_tensor }))
}
impl Tensor {
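    /// Creates a new tensor with an undefined underlying value.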
pub fn new() -> Tensor {
let c_tensor = unsafe_torch!(at_new_tensor());
Tensor { c_tensor }
}
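    /// Wraps a raw pointer to a C++ tensor.
    ///
    /// # Safety
    /// `c_tensor` must point to a valid tensor. The returned `Tensor` takes
    /// ownership of the pointer and frees it when dropped.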
pub unsafe fn from_ptr(c_tensor: *mut C_tensor) -> Self {
Self { c_tensor }
}
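    /// Creates a new handle by shallow-cloning the tensor behind a raw pointer.
    ///
    /// # Safety
    /// `c_tensor` must point to a valid tensor; the pointer itself is not
    /// consumed, only cloned.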
pub unsafe fn clone_from_ptr(c_tensor: *mut C_tensor) -> Self {
let c_tensor = at_shallow_clone(c_tensor);
crate::wrappers::utils::read_and_clean_error().unwrap();
Self { c_tensor }
}
pub fn as_ptr(&self) -> *const C_tensor {
self.c_tensor
}
pub fn as_mut_ptr(&mut self) -> *mut C_tensor {
self.c_tensor
}
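    /// Returns the number of dimensions of the tensor.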
pub fn dim(&self) -> usize {
unsafe_torch!(at_dim(self.c_tensor))
}
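    /// Returns the shape of the tensor, one entry per dimension.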
pub fn size(&self) -> Vec<i64> {
let dim = unsafe_torch!(at_dim(self.c_tensor));
let mut sz = vec![0i64; dim];
unsafe_torch!(at_shape(self.c_tensor, sz.as_mut_ptr()));
sz
}
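    /// Returns the size of a one-dimensional tensor, failing with a shape
    /// error for any other rank. `size2` through `size6` do the same for
    /// higher ranks.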
pub fn size1(&self) -> Result<i64, TchError> {
match self.size().as_slice() {
&[s0] => Ok(s0),
size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
}
}
pub fn size2(&self) -> Result<(i64, i64), TchError> {
match self.size().as_slice() {
&[s0, s1] => Ok((s0, s1)),
size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
}
}
pub fn size3(&self) -> Result<(i64, i64, i64), TchError> {
match self.size().as_slice() {
&[s0, s1, s2] => Ok((s0, s1, s2)),
size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
}
}
pub fn size4(&self) -> Result<(i64, i64, i64, i64), TchError> {
match self.size().as_slice() {
&[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
}
}
pub fn size5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
match self.size().as_slice() {
&[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
}
}
pub fn size6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
match self.size().as_slice() {
&[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
}
}
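    /// Returns the stride of the tensor for each dimension; `stride1` through
    /// `stride6` return them as tuples for fixed ranks.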
pub fn stride(&self) -> Vec<i64> {
let dim = unsafe_torch!(at_dim(self.c_tensor));
let mut sz = vec![0i64; dim];
unsafe_torch!(at_stride(self.c_tensor, sz.as_mut_ptr()));
sz
}
pub fn stride1(&self) -> Result<i64, TchError> {
match self.stride().as_slice() {
&[s0] => Ok(s0),
size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
}
}
pub fn stride2(&self) -> Result<(i64, i64), TchError> {
match self.stride().as_slice() {
&[s0, s1] => Ok((s0, s1)),
size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
}
}
pub fn stride3(&self) -> Result<(i64, i64, i64), TchError> {
match self.stride().as_slice() {
&[s0, s1, s2] => Ok((s0, s1, s2)),
size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
}
}
pub fn stride4(&self) -> Result<(i64, i64, i64, i64), TchError> {
match self.stride().as_slice() {
&[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
}
}
pub fn stride5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
match self.stride().as_slice() {
&[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
}
}
pub fn stride6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
match self.stride().as_slice() {
&[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
}
}
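    /// Returns the element kind (dtype) of the tensor, or an error if it
    /// cannot be represented as a [`Kind`].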
pub fn f_kind(&self) -> Result<Kind, TchError> {
let kind = unsafe_torch!(at_scalar_type(self.c_tensor));
Kind::from_c_int(kind)
}
pub fn kind(&self) -> Kind {
self.f_kind().unwrap()
}
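    /// Returns the device on which the tensor is stored.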
pub fn device(&self) -> Device {
let device = unsafe_torch!(at_device(self.c_tensor));
Device::from_c_int(device)
}
pub fn print(&self) {
unsafe_torch!(at_print(self.c_tensor))
}
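    /// Returns the element at the given multi-dimensional index, converted to
    /// `f64`. `f_int64_value` is the integer counterpart.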
pub fn f_double_value(&self, idx: &[i64]) -> Result<f64, TchError> {
Ok(unsafe_torch_err!({
at_double_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
}))
}
pub fn f_int64_value(&self, idx: &[i64]) -> Result<i64, TchError> {
Ok(unsafe_torch_err!({
at_int64_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
}))
}
pub fn double_value(&self, idx: &[i64]) -> f64 {
self.f_double_value(idx).unwrap()
}
pub fn int64_value(&self, idx: &[i64]) -> i64 {
self.f_int64_value(idx).unwrap()
}
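    /// Returns true if gradients are tracked for this tensor.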
pub fn requires_grad(&self) -> bool {
unsafe_torch!(at_requires_grad(self.c_tensor)) != 0
}
pub fn data_ptr(&self) -> *mut c_void {
unsafe_torch!(at_data_ptr(self.c_tensor))
}
pub fn defined(&self) -> bool {
unsafe_torch!(at_defined(self.c_tensor) != 0)
}
pub fn is_mkldnn(&self) -> bool {
unsafe_torch!(at_is_mkldnn(self.c_tensor) != 0)
}
pub fn is_sparse(&self) -> bool {
unsafe_torch!(at_is_sparse(self.c_tensor) != 0)
}
pub fn is_contiguous(&self) -> bool {
unsafe_torch!(at_is_contiguous(self.c_tensor) != 0)
}
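    /// Zeroes the gradient tensor attached to this tensor, if it is defined.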
pub fn zero_grad(&mut self) {
let mut grad = self.grad();
if grad.defined() {
let _ = grad.detach_().zero_();
}
}
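    /// Runs the backward pass from this tensor, accumulating gradients into
    /// the leaves of the computation graph.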
pub fn f_backward(&self) -> Result<(), TchError> {
unsafe_torch_err!(at_backward(self.c_tensor, 0, 0));
Ok(())
}
pub fn backward(&self) {
self.f_backward().unwrap()
}
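    /// Computes the gradients of `tensors` with respect to `inputs`, returning
    /// one gradient tensor per input. `keep_graph` keeps the graph alive for
    /// further backward calls, and `create_graph` builds the graph of the
    /// backward pass itself so that higher-order gradients can be taken.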
pub fn f_run_backward<T1, T2>(
tensors: &[T1],
inputs: &[T2],
keep_graph: bool,
create_graph: bool,
) -> Result<Vec<Tensor>, TchError>
where
T1: Borrow<Tensor>,
T2: Borrow<Tensor>,
{
let mut outputs = vec![std::ptr::null_mut(); inputs.len()];
let tensors: Vec<_> = tensors.iter().map(|x| x.borrow().c_tensor).collect();
let inputs: Vec<_> = inputs.iter().map(|x| x.borrow().c_tensor).collect();
unsafe_torch_err!(at_run_backward(
tensors.as_ptr(),
tensors.len() as c_int,
inputs.as_ptr(),
inputs.len() as c_int,
outputs.as_mut_ptr(),
keep_graph as c_int,
create_graph as c_int,
));
Ok(outputs.into_iter().map(|c_tensor| Tensor { c_tensor }).collect())
}
pub fn run_backward<T1, T2>(
tensors: &[T1],
inputs: &[T2],
keep_graph: bool,
create_graph: bool,
) -> Vec<Tensor>
where
T1: Borrow<Tensor>,
T2: Borrow<Tensor>,
{
Tensor::f_run_backward(tensors, inputs, keep_graph, create_graph).unwrap()
}
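    /// Copies `numel` elements of this tensor into `dst`, viewed as raw bytes;
    /// `dst` must hold at least `numel` times the element size in bytes.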
pub fn f_copy_data_u8(&self, dst: &mut [u8], numel: usize) -> Result<(), TchError> {
let elt_size_in_bytes = self.f_kind()?.elt_size_in_bytes();
        if dst.len() < numel * elt_size_in_bytes {
            return Err(TchError::Shape(format!(
                "slice len {} < {} bytes ({numel} elements of {elt_size_in_bytes} bytes each)",
                dst.len(),
                numel * elt_size_in_bytes
            )));
        }
unsafe_torch_err!(at_copy_data(
self.c_tensor,
dst.as_mut_ptr() as *const c_void,
numel,
elt_size_in_bytes,
));
Ok(())
}
pub fn f_internal_amp_non_finite_check_and_unscale(
&mut self,
found_inf: &mut Tensor,
inv_scale: &Tensor,
) -> Result<(), TchError> {
unsafe_torch_err!(at__amp_non_finite_check_and_unscale(
self.c_tensor,
found_inf.c_tensor,
inv_scale.c_tensor
));
Ok(())
}
pub fn internal_amp_non_finite_check_and_unscale(
&mut self,
found_inf: &mut Tensor,
inv_scale: &Tensor,
) {
self.f_internal_amp_non_finite_check_and_unscale(found_inf, inv_scale).unwrap()
}
pub fn copy_data_u8(&self, dst: &mut [u8], numel: usize) {
self.f_copy_data_u8(dst, numel).unwrap()
}
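    /// Copies `numel` elements of this tensor into `dst`; the element type `T`
    /// must match the tensor kind and `dst` must hold at least `numel` values.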
pub fn f_copy_data<T: kind::Element>(
&self,
dst: &mut [T],
numel: usize,
) -> Result<(), TchError> {
        let kind = self.f_kind()?;
        if T::KIND != kind {
            return Err(TchError::Kind(format!(
                "incoherent elt kind, {kind:?} != {:?}",
                T::KIND
            )));
        }
if dst.len() < numel {
return Err(TchError::Shape(format!("slice len < {numel}")));
}
unsafe_torch_err!(at_copy_data(
self.c_tensor,
dst.as_mut_ptr() as *const c_void,
numel,
T::KIND.elt_size_in_bytes(),
));
Ok(())
}
pub fn copy_data<T: kind::Element>(&self, dst: &mut [T], numel: usize) {
self.f_copy_data(dst, numel).unwrap()
}
pub fn numel(&self) -> usize {
self.size().iter().product::<i64>() as usize
}
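    /// Creates a one-dimensional tensor holding a copy of the given slice.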
pub fn f_from_slice<T: kind::Element>(data: &[T]) -> Result<Tensor, TchError> {
let data_len = data.len();
let data = data.as_ptr() as *const c_void;
let c_tensor = unsafe_torch_err!(at_tensor_of_data(
data,
[data_len as i64].as_ptr(),
1,
T::KIND.elt_size_in_bytes(),
T::KIND.c_int(),
));
Ok(Tensor { c_tensor })
}
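    /// Creates a one-dimensional tensor from a slice, panicking on failure.
    ///
    /// A minimal usage sketch, assuming the conventional crate-root re-export
    /// of this type (e.g. `tch::Tensor`):
    ///
    /// ```no_run
    /// let t = tch::Tensor::from_slice(&[3i64, 1, 4, 1, 5]);
    /// assert_eq!(t.size(), vec![5]);
    /// ```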
pub fn from_slice<T: kind::Element>(data: &[T]) -> Tensor {
Self::f_from_slice(data).unwrap()
}
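    /// Builds a tensor of the given shape and kind from its raw in-memory byte
    /// representation.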
    pub fn f_from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Result<Tensor, TchError> {
        let elt_size_in_bytes = kind.elt_size_in_bytes();
        let numel = size.iter().product::<i64>() as usize;
        // The C side copies numel * elt_size_in_bytes bytes out of `data`, so
        // reject slices that are too short for the requested shape.
        if data.len() < numel * elt_size_in_bytes {
            return Err(TchError::Shape(format!(
                "data len {} < {} bytes ({numel} elements of {elt_size_in_bytes} bytes each)",
                data.len(),
                numel * elt_size_in_bytes
            )));
        }
        let data = data.as_ptr() as *const c_void;
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            size.as_ptr(),
            size.len(),
            elt_size_in_bytes,
            kind.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }
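    /// Creates a tensor that borrows the given memory buffer instead of
    /// copying it.
    ///
    /// # Safety
    /// The buffer behind `data` must be valid for the given size and strides,
    /// must outlive the returned tensor, and must not be mutated while the
    /// tensor is in use.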
pub unsafe fn f_from_blob(
data: *const u8,
size: &[i64],
strides: &[i64],
kind: Kind,
device: Device,
) -> Result<Tensor, TchError> {
let data = data as *const c_void;
#[allow(unused_unsafe)]
let c_tensor = unsafe_torch_err!(at_tensor_of_blob(
data,
size.as_ptr(),
size.len(),
strides.as_ptr(),
strides.len(),
kind.c_int(),
device.c_int()
));
Ok(Tensor { c_tensor })
}
pub unsafe fn from_blob(
data: *const u8,
size: &[i64],
strides: &[i64],
kind: Kind,
device: Device,
) -> Tensor {
Self::f_from_blob(data, size, strides, kind, device).unwrap()
}
pub fn from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Tensor {
Self::f_from_data_size(data, size, kind).unwrap()
}
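    /// Returns a new handle to the tensor, sharing the same underlying storage.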
pub fn shallow_clone(&self) -> Tensor {
let c_tensor = unsafe_torch!(at_shallow_clone(self.c_tensor));
Tensor { c_tensor }
}
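    /// Returns the sub-tensor at the given index along the first dimension.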
pub fn f_get(&self, index: i64) -> Result<Tensor, TchError> {
let c_tensor = unsafe_torch_err!(at_get(self.c_tensor, index as c_int));
Ok(Tensor { c_tensor })
}
pub fn get(&self, index: i64) -> Tensor {
self.f_get(index).unwrap()
}
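    /// Copies the values of `src` into this tensor, converting kind and device
    /// if needed.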
pub fn f_copy_(&mut self, src: &Tensor) -> Result<(), TchError> {
unsafe_torch_err!(at_copy_(self.c_tensor, src.c_tensor));
Ok(())
}
pub fn copy_(&mut self, src: &Tensor) {
self.f_copy_(src).unwrap()
}
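    /// Loads a tensor saved with [`Tensor::save`] from the given path.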
pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
let path = path_to_cstring(path)?;
let c_tensor = unsafe_torch_err!(at_load(path.as_ptr()));
Ok(Tensor { c_tensor })
}
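    /// Reads a tensor from any seekable stream, in the same format as
    /// [`Tensor::load`].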
pub fn load_from_stream<T: Read + Seek>(stream: T) -> Result<Tensor, TchError> {
let adapter = ReadSeekAdapter::new(stream);
let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
let c_tensor =
unsafe_torch_err!(at_load_from_stream(Box::into_raw(boxed_stream) as *mut c_void,));
Ok(Tensor { c_tensor })
}
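    /// Saves the tensor to the given path; it can be read back with
    /// [`Tensor::load`].
    ///
    /// A minimal round-trip sketch, assuming the conventional crate-root
    /// re-exports (e.g. `tch::Tensor`, `tch::TchError`):
    ///
    /// ```no_run
    /// # fn main() -> Result<(), tch::TchError> {
    /// let t = tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    /// t.save("t.pt")?;
    /// let _reloaded = tch::Tensor::load("t.pt")?;
    /// # Ok(())
    /// # }
    /// ```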
pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
let path = path_to_cstring(path)?;
unsafe_torch_err!(at_save(self.c_tensor, path.as_ptr()));
Ok(())
}
pub fn save_to_stream<W: Write>(&self, stream: W) -> Result<(), TchError> {
let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
unsafe_torch_err!(at_save_to_stream(
self.c_tensor,
Box::into_raw(boxed_stream) as *mut c_void,
));
Ok(())
}
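    /// Saves several named tensors to a file; dots in the names are mapped to
    /// `|` on disk and mapped back by [`Tensor::load_multi`].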
pub fn save_multi<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
named_tensors: &[(S, T)],
path: P,
) -> Result<(), TchError> {
let path = path_to_cstring(path)?;
let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
let names = named_tensors
.iter()
.map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
.map(std::ffi::CString::new)
.collect::<Result<Vec<_>, _>>()?;
let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
unsafe_torch_err!(at_save_multi(
c_tensors.as_ptr(),
name_ptrs.as_ptr(),
names.len() as i32,
path.as_ptr(),
));
Ok(())
}
pub fn save_multi_to_stream<S: AsRef<str>, T: AsRef<Tensor>, W: Write>(
named_tensors: &[(S, T)],
stream: W,
) -> Result<(), TchError> {
let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
let names = named_tensors
.iter()
.map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
.map(std::ffi::CString::new)
.collect::<Result<Vec<_>, _>>()?;
let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
unsafe_torch_err!(at_save_multi_to_stream(
c_tensors.as_ptr(),
name_ptrs.as_ptr(),
names.len() as i32,
Box::into_raw(boxed_stream) as *mut c_void,
));
Ok(())
}
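    /// Loads the named tensors saved by [`Tensor::save_multi`] from the given
    /// path. `load_multi_with_device` additionally places them on a specific
    /// device.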
pub fn load_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
let path = path_to_cstring(path)?;
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_load_callback(
path.as_ptr(),
&mut v as *mut _ as *mut c_void,
add_callback
));
Ok(v)
}
pub fn load_multi_with_device<T: AsRef<Path>>(
path: T,
device: Device,
) -> Result<Vec<(String, Tensor)>, TchError> {
let path = path_to_cstring(path)?;
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_load_callback_with_device(
path.as_ptr(),
&mut v as *mut _ as *mut c_void,
add_callback,
device.c_int(),
));
Ok(v)
}
pub fn loadz_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
let path = path_to_cstring(path)?;
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_loadz_callback(
path.as_ptr(),
&mut v as *mut _ as *mut c_void,
add_callback
));
Ok(v)
}
pub fn loadz_multi_with_device<T: AsRef<Path>>(
path: T,
device: Device,
) -> Result<Vec<(String, Tensor)>, TchError> {
let path = path_to_cstring(path)?;
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_loadz_callback_with_device(
path.as_ptr(),
&mut v as *mut _ as *mut c_void,
add_callback,
device.c_int(),
));
Ok(v)
}
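    /// Loads named tensors from a seekable stream, in the same format as
    /// [`Tensor::load_multi`].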
pub fn load_multi_from_stream<T: Read + Seek>(
stream: T,
) -> Result<Vec<(String, Tensor)>, TchError> {
let adapter = ReadSeekAdapter::new(stream);
let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_load_from_stream_callback(
Box::into_raw(boxed_stream) as *mut c_void,
&mut v as *mut _ as *mut c_void,
add_callback,
false,
0,
));
Ok(v)
}
pub fn load_multi_from_stream_with_device<T: Read + Seek>(
stream: T,
device: Device,
) -> Result<Vec<(String, Tensor)>, TchError> {
let adapter = ReadSeekAdapter::new(stream);
let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
let mut v: Vec<(String, Tensor)> = vec![];
unsafe_torch_err!(at_load_from_stream_callback(
Box::into_raw(boxed_stream) as *mut c_void,
&mut v as *mut _ as *mut c_void,
add_callback,
true,
device.c_int(),
));
Ok(v)
}
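    /// Returns a textual representation of the tensor as produced by libtorch;
    /// `lw` controls the line width used when formatting.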
pub fn to_string(&self, lw: i64) -> Result<String, TchError> {
let s =
unsafe_torch_err!(ptr_to_string(torch_sys::at_to_string(self.c_tensor, lw as c_int)));
match s {
None => Err(TchError::Kind("nullptr representation".to_string())),
Some(s) => Ok(s),
}
}
}
impl Default for Tensor {
fn default() -> Self {
Self::new()
}
}
impl Drop for Tensor {
fn drop(&mut self) {
unsafe_torch!(at_free(self.c_tensor))
}
}
fn autocast_clear_cache() {
unsafe_torch!(at_autocast_clear_cache())
}
fn autocast_decrement_nesting() -> isize {
unsafe_torch!(at_autocast_decrement_nesting() as isize)
}
fn autocast_increment_nesting() -> isize {
unsafe_torch!(at_autocast_increment_nesting() as isize)
}
fn autocast_is_enabled() -> bool {
unsafe_torch!(at_autocast_is_enabled() != 0)
}
fn autocast_set_enabled(b: bool) -> bool {
unsafe_torch!(at_autocast_set_enabled(i32::from(b)) != 0)
}
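/// Runs the closure with automatic mixed precision (autocast) set to `enabled`,
/// restoring the previous state afterwards. When CUDA is not available the
/// closure is simply run unchanged.
///
/// A minimal sketch, assuming a crate-root re-export such as `tch::autocast`
/// and a hypothetical `model` value with a `forward` method:
///
/// ```ignore
/// let logits = tch::autocast(true, || model.forward(&images));
/// ```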
pub fn autocast<T, F>(enabled: bool, f: F) -> T
where
F: FnOnce() -> T,
{
if !Cuda::is_available() {
return f();
}
let prev = autocast_is_enabled();
autocast_set_enabled(enabled);
autocast_increment_nesting();
let result = f();
if autocast_decrement_nesting() == 0 {
autocast_clear_cache();
}
autocast_set_enabled(prev);
result
}
fn grad_set_enabled(b: bool) -> bool {
unsafe_torch!(at_grad_set_enabled(i32::from(b)) != 0)
}
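/// Runs the closure with gradient tracking disabled, restoring the previous
/// mode afterwards.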
pub fn no_grad<T, F>(f: F) -> T
where
F: FnOnce() -> T,
{
let prev = grad_set_enabled(false);
let result = f();
    let _ = grad_set_enabled(prev);
result
}
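/// Runs the closure with gradient tracking enabled, restoring the previous
/// mode afterwards; this is mainly useful to re-enable gradients locally
/// inside a [`no_grad`] scope.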
pub fn with_grad<T, F>(f: F) -> T
where
F: FnOnce() -> T,
{
let prev = grad_set_enabled(true);
let result = f();
    let _ = grad_set_enabled(prev);
result
}
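/// A RAII guard that keeps gradient tracking disabled until it is dropped.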
pub struct NoGradGuard {
enabled: bool,
}
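/// Disables gradient tracking until the returned guard goes out of scope.
///
/// A minimal sketch, assuming a crate-root re-export such as
/// `tch::no_grad_guard`:
///
/// ```no_run
/// let _guard = tch::no_grad_guard();
/// // Computations performed here do not record gradients; tracking is
/// // restored when `_guard` is dropped.
/// ```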
pub fn no_grad_guard() -> NoGradGuard {
NoGradGuard { enabled: grad_set_enabled(false) }
}
impl std::convert::AsRef<Tensor> for Tensor {
fn as_ref(&self) -> &Self {
self
}
}
impl Drop for NoGradGuard {
fn drop(&mut self) {
let _enabled = grad_set_enabled(self.enabled);
}
}
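/// Specifies how a loss is reduced over its elements: not at all, by taking
/// the mean, or by summing. `Other` forwards a raw integer value to libtorch.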
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction {
None,
Mean,
Sum,
Other(i64),
}
impl Reduction {
pub fn to_int(self) -> i64 {
match self {
Reduction::None => 0,
Reduction::Mean => 1,
Reduction::Sum => 2,
Reduction::Other(i) => i,
}
}
}