Funlib API comparison
This notebook compares the UNet implementation in funlib.learn.torch with the one in tems.
We specifically compare:
- the constructors, to show that migration is easy;
- torch.jit.script, which funlib.UNet fails;
- cropping: funlib.UNet crops more aggressively to maintain translation equivariance.
Setup
[1]:
import torch
from funlib.learn.torch.models import UNet as FunlibUNet
from tems import UNet
# Here we define a set of downsampling factors for demonstration purposes.
# We use 2D downsampling factors to keep the computation small and because the
# funlib UNet does not support 1D.
downsample_factors = [
    [[2, 1], [2, 1], [2, 1], [2, 1]],
    [[2, 1], [3, 1], [4, 1], [2, 1]],
    [[2, 1], [4, 1], [3, 1], [5, 1]],
]
# The extra input is necessary because the funlib UNet crops more aggressively
# than the tems UNet and thus has a larger `min_input_shape`. These values were
# found via manual guess and check. Figuring out the appropriate input shape for
# the funlib UNet is not always easy, which is one of the main motivations for
# `tems`.
extra_inputs = [
    (16, 0),
    (48 * 2, 0),
    (120 * 2, 0),
]
[2]:
def build_unet(downsample_factors):
    return UNet.funlib_api(
        dims=2,
        in_channels=1,
        num_fmaps=1,
        fmap_inc_factor=1,
        downsample_factors=downsample_factors,
        activation="Identity",
    )


def build_funlib_unet(downsample_factors):
    return FunlibUNet(
        in_channels=1,
        num_fmaps=1,
        fmap_inc_factor=1,
        downsample_factors=downsample_factors,
        kernel_size_down=[[[3, 3], [3, 3]]] * (len(downsample_factors) + 1),
        kernel_size_up=[[[3, 3], [3, 3]]] * len(downsample_factors),
        activation="Identity",
    )
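The nested `kernel_size_down` / `kernel_size_up` lists are an easy place to make mistakes when constructing the funlib UNet: it expects one list of per-conv kernel sizes for each down-path level (including the bottleneck) and one for each up-path level. A small sketch of the shapes involved; the helper name is ours, not part of either library, and it assumes two 3×3 convs per level as in the cell above:

```python
def funlib_kernel_sizes(num_levels, k=3, convs_per_level=2, dims=2):
    """Build nested kernel-size lists in the shape the funlib UNet expects.

    `num_levels` is the number of downsampling steps; the down path has one
    extra entry for the bottleneck.
    """
    per_level = [[k] * dims for _ in range(convs_per_level)]
    kernel_size_down = [per_level] * (num_levels + 1)
    kernel_size_up = [per_level] * num_levels
    return kernel_size_down, kernel_size_up
```

With `num_levels=4` this reproduces the `[[[3, 3], [3, 3]]] * 5` and `* 4` literals used in `build_funlib_unet`.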
JIT scripting
[3]:
unet = build_unet(downsample_factors[0])
try:
    torch.jit.script(unet)
    print("Successfully scripted tems.UNet")
except RuntimeError:
    print("Failed to script tems.UNet")
/home/runner/work/tems/tems/.venv/lib/python3.12/site-packages/tems/conv_pass.py:81: UserWarning: Using Identity activation with the ConvPass module is assumed to be a test case. The convolutional layer will be initialized with constants.
warnings.warn(
Successfully scripted tems.UNet
[4]:
funlib_unet = build_funlib_unet(downsample_factors[0])
try:
    torch.jit.script(funlib_unet)
    print("Successfully scripted funlib.UNet")
except RuntimeError:
    print("Failed to script funlib.UNet")
Failed to script funlib.UNet
Cropping
[5]:
def test_unet_comparison(tems_unet: UNet, funlib_unet: FunlibUNet, input_shape):
    in_data = torch.rand(1, 1, *input_shape)
    print("Input shape:", list(in_data.shape[2:]))
    tems_out_data = tems_unet(in_data)
    tems_out_training = tems_unet.train()(in_data)
    funlib_out_data = funlib_unet(in_data)
    print("Output shape tems (Training):", list(tems_out_training.shape[2:]))
    print("Output shape tems (Translation Equivariant):", list(tems_out_data.shape[2:]))
    print(
        "Output shape funlib (Translation Equivariant):",
        list(funlib_out_data.shape[2:]),
    )
[6]:
for downsample_factor, extra_input in zip(downsample_factors, extra_inputs):
    unet = build_unet(downsample_factor).eval()
    input_shape = unet.min_input_shape + torch.tensor(extra_input)
    funlib_unet = build_funlib_unet(downsample_factor).eval()
    print("Total downsampling factor:", unet.equivariant_step)
    test_unet_comparison(unet, funlib_unet, input_shape)
    print()
Total downsampling factor: tensor([16, 1])
Input shape: [220, 37]
Output shape tems (Training): [36, 1]
Output shape tems (Translation Equivariant): [32, 1]
Output shape funlib (Translation Equivariant): [16, 1]

Total downsampling factor: tensor([48, 1])
Input shape: [612, 37]
Output shape tems (Training): [156, 1]
Output shape tems (Translation Equivariant): [144, 1]
Output shape funlib (Translation Equivariant): [48, 1]

Total downsampling factor: tensor([120, 1])
Input shape: [1220, 37]
Output shape tems (Training): [460, 1]
Output shape tems (Translation Equivariant): [360, 1]
Output shape funlib (Translation Equivariant): [120, 1]
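The printed numbers fit a simple pattern: the total downsampling factor in each dimension is the product of that dimension's per-level factors, and the translation-equivariant tems output is the training-mode output floored to a multiple of that step. The sketch below only checks that arithmetic against the values above; it is not how tems computes its crop internally.

```python
from math import prod


def equivariant_step(downsample_factors):
    """Per-dimension product of the downsampling factors."""
    dims = len(downsample_factors[0])
    return [prod(level[d] for level in downsample_factors) for d in range(dims)]


def equivariant_crop(train_len, step):
    """Largest multiple of `step` that fits inside the training-mode output."""
    return (train_len // step) * step


# The three factor sets used in this notebook:
factor_sets = [
    [[2, 1], [2, 1], [2, 1], [2, 1]],
    [[2, 1], [3, 1], [4, 1], [2, 1]],
    [[2, 1], [4, 1], [3, 1], [5, 1]],
]
print([equivariant_step(fs) for fs in factor_sets])
# → [[16, 1], [48, 1], [120, 1]], matching the printed tensors above
```

Flooring the training outputs 36, 156, and 460 to multiples of 16, 48, and 120 gives 32, 144, and 360, matching the translation-equivariant shapes printed above.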