Funlib API comparison

This notebook compares the UNet implementation in funlib.learn.torch with the one in tems.

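Specifically, we show that:

  • the constructors are nearly identical, so migrating from funlib to tems is easy;

  • funlib.UNet fails torch.jit.script, while tems.UNet can be scripted;

  • funlib.UNet crops more aggressively than tems.UNet to maintain translation equivariance, so it needs larger inputs and produces smaller outputs.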

Setup

[1]:
import torch
from funlib.learn.torch.models import UNet as FunlibUNet

from tems import UNet

# A set of downsampling factors for demonstration purposes. We use 2D
# downsampling factors to reduce compute and because the funlib UNet does not
# support 1D. The second axis is never downsampled, so these are effectively 1D.
downsample_factors = [
    [[2, 1], [2, 1], [2, 1], [2, 1]],
    [[2, 1], [3, 1], [4, 1], [2, 1]],
    [[2, 1], [4, 1], [3, 1], [5, 1]],
]
# The extra input is necessary because the funlib UNet crops more aggressively
# than the tems UNet and therefore needs inputs larger than the tems UNet's
# `min_input_shape`. These values were found by trial and error. Figuring out the
# appropriate input shape for the funlib UNet is not always easy, which is one of
# the main motivations for `tems`.
extra_inputs = [
    (16, 0),
    (48 * 2, 0),
    (120 * 2, 0),
]
[2]:
def build_unet(downsample_factors):
    return UNet.funlib_api(
        dims=2,
        in_channels=1,
        num_fmaps=1,
        fmap_inc_factor=1,
        downsample_factors=downsample_factors,
        activation="Identity",
    )


def build_funlib_unet(downsample_factors):
    return FunlibUNet(
        in_channels=1,
        num_fmaps=1,
        fmap_inc_factor=1,
        downsample_factors=downsample_factors,
        kernel_size_down=[[[3, 3], [3, 3]]] * (len(downsample_factors) + 1),
        kernel_size_up=[[[3, 3], [3, 3]]] * len(downsample_factors),
        activation="Identity",
    )
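The funlib constructor above is given explicit 2D kernel sizes, while `UNet.funlib_api` is given `dims=2` instead. The plain-Python sketch below shows the structure of those kernel arguments for `downsample_factors[0]`: `kernel_size_down` has one entry per level (the number of downsampling factors plus one for the bottom level), `kernel_size_up` has one entry per level above the bottom, and each entry lists the kernel shape of every convolution at that level. It also shows a side effect of building them with list multiplication: every level shares the same inner list object.

```python
# Plain-Python sketch of the kernel-size arguments passed to FunlibUNet above.
factors = [[2, 1], [2, 1], [2, 1], [2, 1]]  # downsample_factors[0]
num_levels = len(factors) + 1  # 4 downsampling steps -> 5 resolution levels

# One entry per level; each entry lists the kernel shape of every conv at that level.
kernel_size_down = [[[3, 3], [3, 3]]] * num_levels
kernel_size_up = [[[3, 3], [3, 3]]] * (num_levels - 1)  # one entry per level above the bottom
print(len(kernel_size_down), len(kernel_size_up))  # 5 4

# `*` repeats references: all levels share one inner list, so an in-place edit
# of one level shows up in every level.
shared = [[[3, 3], [3, 3]]] * num_levels
shared[-1][0] = [5, 5]
print(shared[0][0])  # [5, 5]

# A list comprehension builds an independent list per level instead.
independent = [[[3, 3], [3, 3]] for _ in range(num_levels)]
independent[-1][0] = [5, 5]
print(independent[0][0])  # [3, 3]
```

Reassigning a whole entry (for example `kernel_size_down[-1] = [[1, 1], [1, 1]]`) is safe either way; only in-place mutation of a shared inner list leaks across levels.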

JIT scripting

[3]:
unet = build_unet(downsample_factors[0])
try:
    torch.jit.script(unet)
    print("Successfully scripted tems.UNet")
except RuntimeError:
    print("Failed to script tems.UNet")
/home/runner/work/tems/tems/.venv/lib/python3.12/site-packages/tems/conv_pass.py:81: UserWarning: Using Identity activation with the ConvPass module is assumed to be a test case. The convolutional layer will be initialized with constants.
  warnings.warn(
Successfully scripted tems.UNet
[4]:
funlib_unet = build_funlib_unet(downsample_factors[0])
try:
    torch.jit.script(funlib_unet)
    print("Successfully scripted funlib.UNet")
except RuntimeError:
    print("Failed to script funlib.UNet")
Failed to script funlib.UNet

Cropping

[5]:
def test_unet_comparison(tems_unet: UNet, funlib_unet: FunlibUNet, input_shape):
    # Random input with batch and channel dimensions of size 1.
    in_data = torch.rand(1, 1, *input_shape)
    print("Input shape:", list(in_data.shape[2:]))
    # The caller passes the models in eval mode, which gives the
    # translation-equivariant tems output.
    tems_out_data = tems_unet(in_data)
    # Switch tems_unet to training mode for comparison (the mode change persists).
    tems_out_training = tems_unet.train()(in_data)
    funlib_out_data = funlib_unet(in_data)
    print("Output shape tems (Training):", list(tems_out_training.shape[2:]))
    print("Output shape tems (Translation Equivariant):", list(tems_out_data.shape[2:]))
    print(
        "Output shape funlib (Translation Equivariant):",
        list(funlib_out_data.shape[2:]),
    )
[6]:

for downsample_factor, extra_input in zip(downsample_factors, extra_inputs):
    unet = build_unet(downsample_factor).eval()
    input_shape = unet.min_input_shape + torch.tensor(extra_input)
    funlib_unet = build_funlib_unet(downsample_factor).eval()
    print("Total downsampling factor:", unet.equivariant_step)
    test_unet_comparison(unet, funlib_unet, input_shape)
    print()
Total downsampling factor: tensor([16,  1])
Input shape: [220, 37]
Output shape tems (Training): [36, 1]
Output shape tems (Translation Equivariant): [32, 1]
Output shape funlib (Translation Equivariant): [16, 1]

Total downsampling factor: tensor([48,  1])
Input shape: [612, 37]
Output shape tems (Training): [156, 1]
Output shape tems (Translation Equivariant): [144, 1]
Output shape funlib (Translation Equivariant): [48, 1]

Total downsampling factor: tensor([120,   1])
Input shape: [1220, 37]
Output shape tems (Training): [460, 1]
Output shape tems (Translation Equivariant): [360, 1]
Output shape funlib (Translation Equivariant): [120, 1]
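The training-mode and funlib sizes above follow from simple shape arithmetic. Below is a minimal pure-Python sketch along one axis, assuming the layout used in `build_funlib_unet` (two unpadded 3x3 convolutions per level) and that the size must be divisible by the factor at every downsampling step. The `funlib_crop` branch is our reconstruction of funlib's translation-equivariance crop: after each upsampling, the size is cropped to the largest value s' such that s' minus the shrinkage of the following convolutions is a multiple of the product of the downsampling factors from that level down. It is not copied from either library, and it does not model the tems translation-equivariant crop. `unet_output_size` and `min_input_size` are hypothetical helpers.

```python
import math


def unet_output_size(size, factors, funlib_crop=False, convs_per_level=2, kernel=3):
    """Output size along one axis of a UNet built from unpadded convolutions.

    Assumptions (matching build_funlib_unet): each level applies `convs_per_level`
    unpadded convolutions of width `kernel`; `factors` holds this axis' per-level
    downsampling factors. Returns None for inputs that cannot be processed.
    """
    shrink = convs_per_level * (kernel - 1)  # size lost per level to unpadded convs
    for f in factors:  # down path
        size -= shrink
        if size < 1 or size % f != 0:
            return None  # too small, or not divisible by the downsampling factor
        size //= f
    size -= shrink  # bottom level
    if size < 1:
        return None
    for level in reversed(range(len(factors))):  # up path, deepest level first
        size *= factors[level]
        if funlib_crop:
            # Crop to the largest s' <= size with (s' - shrink) a multiple of the
            # product of the factors from this level down.
            k = math.prod(factors[level:])
            size = (size - shrink) // k * k + shrink
            if size <= shrink:
                return None
        size -= shrink
        if size < 1:
            return None
    return size


def min_input_size(factors, funlib_crop=False, limit=10_000):
    """Smallest input size along one axis for which unet_output_size succeeds."""
    for size in range(1, limit):
        if unet_output_size(size, factors, funlib_crop) is not None:
            return size
    return None


for factors, size in [([2, 2, 2, 2], 220), ([2, 3, 4, 2], 612), ([2, 4, 3, 5], 1220)]:
    print(
        unet_output_size(size, factors),
        unet_output_size(size, factors, funlib_crop=True),
        min_input_size(factors, funlib_crop=True),
    )
# 36 16 220
# 156 48 612
# 460 120 1220
```

Under these assumptions, the sketch reproduces the first-axis training-mode and funlib output sizes printed above, and the brute-force search returns 220, 612 and 1220, the first-axis input sizes used in the loop. Because the crop is applied at every upsampling step, a slightly smaller input can fail at the top level even when the deeper levels succeed, which is why the funlib input sizes are hard to guess.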