use crate::{nn, nn::Conv2D, nn::ModuleT, Tensor};
/// Creates a bias-free 2D convolution registered under the path `p`.
///
/// `ksize`, `padding` and `stride` are forwarded to the convolution
/// configuration; every other configuration field keeps its default value.
fn conv2d(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, padding: i64, stride: i64) -> Conv2D {
    nn::conv2d(
        p,
        c_in,
        c_out,
        ksize,
        nn::ConvConfig { stride, padding, bias: false, ..Default::default() },
    )
}
/// Builds a single dense layer.
///
/// The layer applies batch-norm -> ReLU -> 1x1 convolution (producing
/// `bn_size * growth` channels) followed by batch-norm -> ReLU -> 3x3
/// convolution (producing `growth` channels), then concatenates the result
/// with the layer input along dimension 1.
fn dense_layer(p: nn::Path, c_in: i64, bn_size: i64, growth: i64) -> impl ModuleT {
    let bottleneck_channels = bn_size * growth;
    // Parameters are created in the order norm1, conv1, norm2, conv2.
    let norm1 = nn::batch_norm2d(&p / "norm1", c_in, Default::default());
    let conv1 = conv2d(&p / "conv1", c_in, bottleneck_channels, 1, 0, 1);
    let norm2 = nn::batch_norm2d(&p / "norm2", bottleneck_channels, Default::default());
    let conv2 = conv2d(&p / "conv2", bottleneck_channels, growth, 3, 1, 1);
    nn::func_t(move |xs, train| {
        // 1x1 bottleneck stage.
        let bottleneck = xs.apply_t(&norm1, train).relu().apply(&conv1);
        // 3x3 stage producing the `growth` new feature maps.
        let new_features = bottleneck.apply_t(&norm2, train).relu().apply(&conv2);
        // Dense connectivity: the input is kept and the new maps are appended.
        Tensor::cat(&[xs, &new_features], 1)
    })
}
/// Stacks `nlayers` dense layers named `denselayer1`, `denselayer2`, ...
/// under the path `p`.
///
/// Each layer appends `growth` channels to its input, so the layer at
/// zero-based index `i` is built with `c_in + i * growth` input channels.
fn dense_block(p: nn::Path, c_in: i64, bn_size: i64, growth: i64, nlayers: i64) -> impl ModuleT {
    (0..nlayers).fold(nn::seq_t(), |block, idx| {
        let layer_name = format!("denselayer{}", idx + 1);
        let layer_in = c_in + idx * growth;
        block.add(dense_layer(&p / &layer_name, layer_in, bn_size, growth))
    })
}
/// Builds a transition layer: batch-norm -> ReLU -> 1x1 convolution from
/// `c_in` to `c_out` channels -> average pooling with kernel size 2.
fn transition(p: nn::Path, c_in: i64, c_out: i64) -> impl ModuleT {
    // The norm is created before the convolution, matching registration order.
    let norm = nn::batch_norm2d(&p / "norm", c_in, Default::default());
    let conv = conv2d(&p / "conv", c_in, c_out, 1, 0, 1);
    nn::seq_t()
        .add(norm)
        .add_fn(|xs| xs.relu())
        .add(conv)
        .add_fn(|xs| xs.avg_pool2d_default(2))
}
/// Assembles a DenseNet classifier.
///
/// Layout, as built below:
/// * a stem under `features/`: 7x7 convolution (3 input channels, stride 2,
///   padding 3) to `c_in` channels, batch-norm, ReLU and a 3x3 max-pool with
///   stride 2;
/// * one dense block per entry of `block_config` (the entry is the number of
///   dense layers in that block), each adding `nlayers * growth` channels;
/// * a transition layer halving the channel count after every block except
///   the last one;
/// * a final batch-norm, ReLU, 7x7 average pooling, flattening and a linear
///   layer under `classifier` producing `c_out` outputs.
///
/// `bn_size` is the bottleneck multiplier forwarded to every dense layer.
fn densenet(
    p: &nn::Path,
    c_in: i64,
    bn_size: i64,
    growth: i64,
    block_config: &[i64],
    c_out: i64,
) -> impl ModuleT {
    let fp = p / "features";
    let mut seq = nn::seq_t()
        .add(conv2d(&fp / "conv0", 3, c_in, 7, 3, 2))
        .add(nn::batch_norm2d(&fp / "norm0", c_in, Default::default()))
        .add_fn(|xs| xs.relu().max_pool2d([3, 3], [2, 2], [1, 1], [1, 1], false));
    // Running channel count of the feature map.
    let mut nfeat = c_in;
    for (i, &nlayers) in block_config.iter().enumerate() {
        seq = seq.add(dense_block(
            &fp / &format!("denseblock{}", 1 + i),
            nfeat,
            bn_size,
            growth,
            nlayers,
        ));
        nfeat += nlayers * growth;
        // No transition after the last dense block.
        if i + 1 != block_config.len() {
            seq = seq.add(transition(&fp / &format!("transition{}", 1 + i), nfeat, nfeat / 2));
            nfeat /= 2;
        }
    }
    seq.add(nn::batch_norm2d(&fp / "norm5", nfeat, Default::default()))
        // The last argument is the divisor override. It must equal the window
        // size (7 * 7 = 49) for this to be an average; a divisor of 1 would
        // turn the pooling into a sum and scale the classifier input by 49.
        // NOTE(review): the fixed 7x7 window presumably targets 224x224
        // inputs, where the final feature map is 7x7 — confirm with callers.
        .add_fn(|xs| xs.relu().avg_pool2d([7, 7], [1, 1], [0, 0], false, true, 49).flat_view())
        .add(nn::linear(p / "classifier", nfeat, c_out, Default::default()))
}
/// Builds a DenseNet with 64 initial features, bottleneck multiplier 4,
/// growth rate 32 and dense blocks of 6, 12, 24 and 16 layers, ending in a
/// classifier with `nclasses` outputs.
pub fn densenet121(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    let (init_features, bn_size, growth) = (64, 4, 32);
    let layers_per_block = [6, 12, 24, 16];
    densenet(p, init_features, bn_size, growth, &layers_per_block, nclasses)
}
/// Builds a DenseNet with 96 initial features, bottleneck multiplier 4,
/// growth rate 48 and dense blocks of 6, 12, 36 and 24 layers, ending in a
/// classifier with `nclasses` outputs.
pub fn densenet161(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    let (init_features, bn_size, growth) = (96, 4, 48);
    let layers_per_block = [6, 12, 36, 24];
    densenet(p, init_features, bn_size, growth, &layers_per_block, nclasses)
}
/// Builds a DenseNet with 64 initial features, bottleneck multiplier 4,
/// growth rate 32 and dense blocks of 6, 12, 32 and 32 layers, ending in a
/// classifier with `nclasses` outputs.
pub fn densenet169(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    let (init_features, bn_size, growth) = (64, 4, 32);
    let layers_per_block = [6, 12, 32, 32];
    densenet(p, init_features, bn_size, growth, &layers_per_block, nclasses)
}
/// Builds a DenseNet with 64 initial features, bottleneck multiplier 4,
/// growth rate 32 and dense blocks of 6, 12, 48 and 32 layers, ending in a
/// classifier with `nclasses` outputs.
pub fn densenet201(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    let (init_features, bn_size, growth) = (64, 4, 32);
    let layers_per_block = [6, 12, 48, 32];
    densenet(p, init_features, bn_size, growth, &layers_per_block, nclasses)
}