//! Variable initialization.
use crate::{Device, Kind, TchError, Tensor};
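
/// Number of features used as input or output of a layer.
///
/// In Kaiming initialization, choosing `FanIn` preserves the magnitude of
/// the variance of the weights in the forward pass, while `FanOut`
/// preserves it in the backward pass.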
#[derive(Debug, Copy, Clone)]
pub enum FanInOut {
    FanIn,
    FanOut,
}
impl FanInOut {
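    /// Computes the fan-in or fan-out value for a weight tensor of the
    /// specified dimensions, mirroring the fan computation in PyTorch's
    /// `torch.nn.init._calculate_fan_in_and_fan_out`.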
    pub fn for_weight_dims(&self, dims: &[i64]) -> i64 {
        // Dimensions beyond the first two form the receptive field of a
        // convolution kernel; for a plain linear weight this is the
        // empty product, i.e. 1.
        let receptive_field_size: i64 = dims.iter().skip(2).product();
        match &self {
            FanInOut::FanIn => {
                if dims.len() < 2 {
                    1
                } else {
                    dims[1] * receptive_field_size
                }
            }
            FanInOut::FanOut => {
                if dims.is_empty() {
                    1
                } else {
                    dims[0] * receptive_field_size
                }
            }
        }
    }
}
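
/// The distribution to sample from, either a standard normal or a uniform
/// distribution.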
#[derive(Debug, Copy, Clone)]
pub enum NormalOrUniform {
    Normal,
    Uniform,
}
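
/// The non-linear function that follows this layer. ReLU is the
/// recommended value.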
#[derive(Debug, Copy, Clone)]
pub enum NonLinearity {
    ReLU,
    Linear,
    Sigmoid,
    Tanh,
    SELU,
    ExplicitGain(f64),
}
impl NonLinearity {
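    /// Returns the recommended gain value for this non-linearity, matching
    /// PyTorch's `torch.nn.init.calculate_gain`.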
    pub fn gain(&self) -> f64 {
        match *self {
            NonLinearity::ReLU => 2f64.sqrt(),
            NonLinearity::Tanh => 5. / 3.,
            NonLinearity::Linear | NonLinearity::Sigmoid => 1.,
            NonLinearity::SELU => 0.75,
            NonLinearity::ExplicitGain(g) => g,
        }
    }
}
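
/// Variable initializations.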
#[derive(Debug, Copy, Clone)]
pub enum Init {
    /// Constant value.
    Const(f64),
    /// Random normal distribution with some mean and standard deviation.
    Randn { mean: f64, stdev: f64 },
    /// Uniform initialization between some lower and upper bounds.
    Uniform { lo: f64, up: f64 },
    /// Kaiming initialization, see "Delving Deep into Rectifiers:
    /// Surpassing Human-Level Performance on ImageNet Classification",
    /// He et al., 2015.
    Kaiming { dist: NormalOrUniform, fan: FanInOut, non_linearity: NonLinearity },
    /// Orthogonal initialization, see "Exact solutions to the nonlinear
    /// dynamics of learning in deep linear neural networks",
    /// Saxe et al., 2013.
    Orthogonal { gain: f64 },
}
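
/// The default Kaiming-uniform initialization: fan-in scaling with a ReLU
/// gain, sampling from a uniform distribution.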
pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
    dist: NormalOrUniform::Uniform,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};
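
/// Same as [`DEFAULT_KAIMING_UNIFORM`], but sampling from a normal
/// distribution.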
pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
    dist: NormalOrUniform::Normal,
    fan: FanInOut::FanIn,
    non_linearity: NonLinearity::ReLU,
};
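
/// Creates a new float tensor with the specified shape and device,
/// initialized according to `i`. This is the fallible counterpart of
/// [`init`] and returns an error rather than panicking.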
pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError> {
    match i {
        Init::Const(cst) => {
            // Fast paths: zeros and ones need a single tensor
            // allocation, other constants scale a tensor of ones.
            if cst == 0. {
                Tensor::f_zeros(dims, (Kind::Float, device))
            } else if (cst - 1.).abs() <= f64::EPSILON {
                Tensor::f_ones(dims, (Kind::Float, device))
            } else {
                Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
            }
        }
        Init::Uniform { lo, up } => {
            Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
        }
        Init::Randn { mean, stdev } => {
            if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
                Tensor::f_randn(dims, (Kind::Float, device))
            } else {
                Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
            }
        }
        Init::Kaiming { dist, fan, non_linearity } => {
            let fan = fan.for_weight_dims(dims);
            let gain = non_linearity.gain();
            // Target standard deviation for the weights.
            let std = gain / (fan as f64).sqrt();
            match dist {
                NormalOrUniform::Uniform => {
                    // A uniform distribution over [-bound, bound] has
                    // standard deviation bound / sqrt(3).
                    let bound = 3f64.sqrt() * std;
                    Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(-bound, bound)
                }
                NormalOrUniform::Normal => {
                    let randn = Tensor::f_randn(dims, (Kind::Float, device))?;
                    Ok(randn * std)
                }
            }
        }
        Init::Orthogonal { gain } => {
            if dims.len() < 2 {
                return Err(TchError::Shape(
                    "Only tensors with 2 or more dimensions are supported".to_string(),
                ));
            }
            let rows = dims[0];
            let cols: i64 = dims.iter().skip(1).product();
            // Sample a Gaussian matrix and orthogonalize it via a QR
            // decomposition; transposing first when rows < cols keeps
            // the decomposition applied to a tall matrix.
            let mut flattened =
                Tensor::f_empty([rows, cols], (Kind::Float, device))?.f_normal_(0.0, 1.0)?;
            let flattened = if rows < cols { flattened.f_t_()? } else { flattened };
            let (mut q, r) = Tensor::f_linalg_qr(&flattened, "reduced")?;
            // Multiply by the signs of the diagonal of `r` so the
            // decomposition is unique, making `q` uniformly distributed
            // over orthogonal matrices.
            let d = r.f_diag(0)?;
            let ph = d.f_sign()?;
            q *= ph;
            let mut q = if rows < cols { q.f_t_()? } else { q };
            crate::no_grad(|| q *= gain);
            q.f_contiguous()
        }
    }
}
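
/// Creates a new float tensor with the specified shape, device, and
/// initialization, panicking if the underlying tensor operation fails.
///
/// A minimal usage sketch; the `tch::nn` re-export paths below are
/// assumptions based on the published crate layout, hence `no_run`:
///
/// ```no_run
/// use tch::{nn, Device};
/// // Fill a 2x3 tensor with the constant 0.5.
/// let t = nn::init(nn::Init::Const(0.5), &[2, 3], Device::Cpu);
/// assert_eq!(t.size(), [2, 3]);
/// ```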
pub fn init(i: Init, dims: &[i64], device: Device) -> Tensor {
    f_init(i, dims, device).unwrap()
}
impl Init {
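    /// Re-initializes an existing tensor in place with the specified
    /// initialization.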
    pub fn set(self, tensor: &mut Tensor) {
        match self {
            Init::Const(cst) => {
                let _ = tensor.fill_(cst);
            }
            Init::Uniform { lo, up } => {
                let _ = tensor.uniform_(lo, up);
            }
            Init::Kaiming { dist, fan, non_linearity } => {
                let fan = fan.for_weight_dims(&tensor.size());
                let gain = non_linearity.gain();
                let std = gain / (fan as f64).sqrt();
                match dist {
                    NormalOrUniform::Uniform => {
                        let bound = 3f64.sqrt() * std;
                        let _ = tensor.uniform_(-bound, bound);
                    }
                    NormalOrUniform::Normal => {
                        tensor.copy_(&(tensor.randn_like() * std));
                    }
                }
            }
            Init::Randn { mean, stdev } => {
                tensor.copy_(&(tensor.randn_like() * stdev + mean));
            }
            Init::Orthogonal { gain } => {
                // Build the orthogonal matrix, then copy it into a
                // flattened two-dimensional view of the target tensor.
                let q =
                    f_init(Init::Orthogonal { gain }, &tensor.size(), tensor.device()).unwrap();
                crate::no_grad(|| tensor.view_as(&q).copy_(&q));
            }
        }
    }
}
impl Tensor {
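    /// Re-initializes the tensor in place using the specified
    /// initialization.
    ///
    /// A minimal usage sketch, assuming `Init` is re-exported as
    /// `tch::nn::Init` as in the published crate, hence `no_run`:
    ///
    /// ```no_run
    /// use tch::{nn::Init, Device, Kind, Tensor};
    /// let mut w = Tensor::zeros(&[256, 128], (Kind::Float, Device::Cpu));
    /// // Overwrite `w` with samples from N(0, 0.02^2).
    /// w.init(Init::Randn { mean: 0.0, stdev: 0.02 });
    /// ```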
    pub fn init(&mut self, i: Init) {
        i.set(self)
    }
}