use crate::{nn, nn::ModuleT, Tensor};
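/// A convolution followed by batch normalization and ReLU: the basic
/// building block used throughout the network. The convolution carries no
/// bias, as the batch-norm that follows would cancel it out anyway.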
fn conv_bn(p: nn::Path, c_in: i64, c_out: i64, ksize: i64, pad: i64, stride: i64) -> impl ModuleT {
    let conv2d_cfg = nn::ConvConfig { stride, padding: pad, bias: false, ..Default::default() };
    let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
    nn::seq_t()
        .add(nn::conv2d(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
        .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
        .add_fn(|xs| xs.relu())
}
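/// The same conv + batch-norm + ReLU block with a rectangular kernel, used
/// for the factorized 1x7/7x1 and 1x3/3x1 convolutions of the later blocks.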
fn conv_bn2(p: nn::Path, c_in: i64, c_out: i64, ksize: [i64; 2], pad: [i64; 2]) -> impl ModuleT {
    let conv2d_cfg =
        nn::ConvConfigND::<[i64; 2]> { padding: pad, bias: false, ..Default::default() };
    let bn_cfg = nn::BatchNormConfig { eps: 0.001, ..Default::default() };
    nn::seq_t()
        .add(nn::conv(&p / "conv", c_in, c_out, ksize, conv2d_cfg))
        .add(nn::batch_norm2d(&p / "bn", c_out, bn_cfg))
        .add_fn(|xs| xs.relu())
}
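/// Square max-pooling with no padding and unit dilation.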
fn max_pool2d(xs: &Tensor, ksize: i64, stride: i64) -> Tensor {
    xs.max_pool2d([ksize, ksize], [stride, stride], [0, 0], [1, 1], false)
}
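/// Inception-A block (Mixed_5b/5c/5d): four parallel branches (1x1, 5x5,
/// double 3x3, and average-pool) concatenated along the channel dimension,
/// producing 64 + 64 + 96 + c_pool output channels at unchanged resolution.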
fn inception_a(p: nn::Path, c_in: i64, c_pool: i64) -> impl ModuleT {
    let b1 = conv_bn(&p / "branch1x1", c_in, 64, 1, 0, 1);
    let b2_1 = conv_bn(&p / "branch5x5_1", c_in, 48, 1, 0, 1);
    let b2_2 = conv_bn(&p / "branch5x5_2", 48, 64, 5, 2, 1);
    let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
    let b3_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
    let b3_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 1, 1);
    let bpool = conv_bn(&p / "branch_pool", c_in, c_pool, 1, 0, 1);
    nn::func_t(move |xs, tr| {
        let b1 = xs.apply_t(&b1, tr);
        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr);
        let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr).apply_t(&b3_3, tr);
        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
        Tensor::cat(&[b1, b2, b3, bpool], 1)
    })
}
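/// Inception-B reduction block (Mixed_6a): stride-2 branches halve the
/// spatial resolution while growing the channel count from 288 to
/// 384 + 96 + 288 = 768.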
fn inception_b(p: nn::Path, c_in: i64) -> impl ModuleT {
    let b1 = conv_bn(&p / "branch3x3", c_in, 384, 3, 0, 2);
    let b2_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 64, 1, 0, 1);
    let b2_2 = conv_bn(&p / "branch3x3dbl_2", 64, 96, 3, 1, 1);
    let b2_3 = conv_bn(&p / "branch3x3dbl_3", 96, 96, 3, 0, 2);
    nn::func_t(move |xs, tr| {
        let b1 = xs.apply_t(&b1, tr);
        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
        let bpool = max_pool2d(xs, 3, 2);
        Tensor::cat(&[b1, b2, bpool], 1)
    })
}
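/// Inception-C block (Mixed_6b..6e): the 7x7 convolutions are factorized
/// into cheaper 1x7 and 7x1 pairs. `c7` is the internal width of the 7x7
/// branches (128, 160, or 192 depending on the layer); each of the four
/// branches outputs 192 channels, for 768 in total.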
fn inception_c(p: nn::Path, c_in: i64, c7: i64) -> impl ModuleT {
    let b1 = conv_bn(&p / "branch1x1", c_in, 192, 1, 0, 1);
    let b2_1 = conv_bn(&p / "branch7x7_1", c_in, c7, 1, 0, 1);
    let b2_2 = conv_bn2(&p / "branch7x7_2", c7, c7, [1, 7], [0, 3]);
    let b2_3 = conv_bn2(&p / "branch7x7_3", c7, 192, [7, 1], [3, 0]);
    let b3_1 = conv_bn(&p / "branch7x7dbl_1", c_in, c7, 1, 0, 1);
    let b3_2 = conv_bn2(&p / "branch7x7dbl_2", c7, c7, [7, 1], [3, 0]);
    let b3_3 = conv_bn2(&p / "branch7x7dbl_3", c7, c7, [1, 7], [0, 3]);
    let b3_4 = conv_bn2(&p / "branch7x7dbl_4", c7, c7, [7, 1], [3, 0]);
    let b3_5 = conv_bn2(&p / "branch7x7dbl_5", c7, 192, [1, 7], [0, 3]);
    let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
    nn::func_t(move |xs, tr| {
        let b1 = xs.apply_t(&b1, tr);
        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr);
        let b3 = xs
            .apply_t(&b3_1, tr)
            .apply_t(&b3_2, tr)
            .apply_t(&b3_3, tr)
            .apply_t(&b3_4, tr)
            .apply_t(&b3_5, tr);
        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
        Tensor::cat(&[b1, b2, b3, bpool], 1)
    })
}
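/// Inception-D reduction block (Mixed_7a): stride-2 3x3 and 7x7x3 branches
/// plus a max-pool branch take 768 input channels to 320 + 192 + 768 = 1280.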
fn inception_d(p: nn::Path, c_in: i64) -> impl ModuleT {
    let b1_1 = conv_bn(&p / "branch3x3_1", c_in, 192, 1, 0, 1);
    let b1_2 = conv_bn(&p / "branch3x3_2", 192, 320, 3, 0, 2);
    let b2_1 = conv_bn(&p / "branch7x7x3_1", c_in, 192, 1, 0, 1);
    let b2_2 = conv_bn2(&p / "branch7x7x3_2", 192, 192, [1, 7], [0, 3]);
    let b2_3 = conv_bn2(&p / "branch7x7x3_3", 192, 192, [7, 1], [3, 0]);
    let b2_4 = conv_bn(&p / "branch7x7x3_4", 192, 192, 3, 0, 2);
    nn::func_t(move |xs, tr| {
        let b1 = xs.apply_t(&b1_1, tr).apply_t(&b1_2, tr);
        let b2 = xs.apply_t(&b2_1, tr).apply_t(&b2_2, tr).apply_t(&b2_3, tr).apply_t(&b2_4, tr);
        let bpool = max_pool2d(xs, 3, 2);
        Tensor::cat(&[b1, b2, bpool], 1)
    })
}
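/// Inception-E block (Mixed_7b/7c): the 3x3 convolutions are expanded into
/// parallel 1x3 and 3x1 convolutions whose outputs are concatenated, giving
/// 320 + 768 + 768 + 192 = 2048 output channels.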
fn inception_e(p: nn::Path, c_in: i64) -> impl ModuleT {
    let b1 = conv_bn(&p / "branch1x1", c_in, 320, 1, 0, 1);
    let b2_1 = conv_bn(&p / "branch3x3_1", c_in, 384, 1, 0, 1);
    let b2_2a = conv_bn2(&p / "branch3x3_2a", 384, 384, [1, 3], [0, 1]);
    let b2_2b = conv_bn2(&p / "branch3x3_2b", 384, 384, [3, 1], [1, 0]);
    let b3_1 = conv_bn(&p / "branch3x3dbl_1", c_in, 448, 1, 0, 1);
    let b3_2 = conv_bn(&p / "branch3x3dbl_2", 448, 384, 3, 1, 1);
    let b3_3a = conv_bn2(&p / "branch3x3dbl_3a", 384, 384, [1, 3], [0, 1]);
    let b3_3b = conv_bn2(&p / "branch3x3dbl_3b", 384, 384, [3, 1], [1, 0]);
    let bpool = conv_bn(&p / "branch_pool", c_in, 192, 1, 0, 1);
    nn::func_t(move |xs, tr| {
        let b1 = xs.apply_t(&b1, tr);
        let b2 = xs.apply_t(&b2_1, tr);
        let b2 = Tensor::cat(&[b2.apply_t(&b2_2a, tr), b2.apply_t(&b2_2b, tr)], 1);
        let b3 = xs.apply_t(&b3_1, tr).apply_t(&b3_2, tr);
        let b3 = Tensor::cat(&[b3.apply_t(&b3_3a, tr), b3.apply_t(&b3_3b, tr)], 1);
        let bpool = xs.avg_pool2d([3, 3], [1, 1], [1, 1], false, true, 9).apply_t(&bpool, tr);
        Tensor::cat(&[b1, b2, b3, bpool], 1)
    })
}
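/// Creates an InceptionV3 model.
///
/// The variable names mirror the torchvision `inception_v3` layout
/// (`Conv2d_1a_3x3` through `Mixed_7c` plus the final `fc` layer), so a
/// converted pretrained checkpoint can be loaded into the returned module.
/// The auxiliary classifier (`AuxLogits`) used during the original training
/// is not included.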
pub fn v3(p: &nn::Path, nclasses: i64) -> impl ModuleT {
    nn::seq_t()
        .add(conv_bn(p / "Conv2d_1a_3x3", 3, 32, 3, 0, 2))
        .add(conv_bn(p / "Conv2d_2a_3x3", 32, 32, 3, 0, 1))
        .add(conv_bn(p / "Conv2d_2b_3x3", 32, 64, 3, 1, 1))
        // conv_bn already ends in a ReLU, so no extra activation is needed here.
        .add_fn(|xs| max_pool2d(xs, 3, 2))
.add(conv_bn(p / "Conv2d_3b_1x1", 64, 80, 1, 0, 1))
.add(conv_bn(p / "Conv2d_4a_3x3", 80, 192, 3, 0, 1))
        .add_fn(|xs| max_pool2d(xs, 3, 2))
        .add(inception_a(p / "Mixed_5b", 192, 32))
        .add(inception_a(p / "Mixed_5c", 256, 64))
        .add(inception_a(p / "Mixed_5d", 288, 64))
        .add(inception_b(p / "Mixed_6a", 288))
        .add(inception_c(p / "Mixed_6b", 768, 128))
        .add(inception_c(p / "Mixed_6c", 768, 160))
        .add(inception_c(p / "Mixed_6d", 768, 160))
        .add(inception_c(p / "Mixed_6e", 768, 192))
        .add(inception_d(p / "Mixed_7a", 768))
        .add(inception_e(p / "Mixed_7b", 1280))
        .add(inception_e(p / "Mixed_7c", 2048))
        .add_fn_t(|xs, train| xs.adaptive_avg_pool2d([1, 1]).dropout(0.5, train).flat_view())
        .add(nn::linear(p / "fc", 2048, nclasses, Default::default()))
}
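// A minimal usage sketch (an addition, not upstream code): build the network
// with randomly initialized weights and check the logits shape of a forward
// pass in eval mode. 299x299 is the resolution InceptionV3 is conventionally
// trained and evaluated at, though other sizes also work as long as every
// pooled feature map stays non-empty.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Device, Kind};
    #[test]
    fn v3_forward_shape() {
        // Fresh variable store; a pretrained checkpoint would be loaded with
        // `VarStore::load` before running inference.
        let vs = nn::VarStore::new(Device::Cpu);
        let net = v3(&vs.root(), 1000);
        // A dummy batch of two 299x299 RGB images.
        let xs = Tensor::randn([2, 3, 299, 299], (Kind::Float, Device::Cpu));
        // `train = false`: batch-norm uses its running statistics and the
        // final dropout is a no-op.
        let logits = xs.apply_t(&net, false);
        assert_eq!(logits.size(), [2, 1000]);
    }
}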