{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "68ac8b96",
   "metadata": {},
   "source": [
    "# Funlib API comparison\n",
    "This notebook compares the `UNet` implementation in `funlib.learn.torch`\n",
    "with the one in `tems`.\n",
    "\n",
    "Specifically, we compare:\n",
    "- the constructors, to show that migrating from `funlib` to `tems` is easy;\n",
    "- scripting with `torch.jit.script`, which fails for `funlib.UNet`;\n",
    "- cropping: `funlib.UNet` crops more aggressively to maintain translation equivariance."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecc366bf",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3121ec65",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from funlib.learn.torch.models import UNet as FunlibUNet\n",
    "\n",
    "from tems import UNet\n",
    "\n",
    "# Fix the seed so the random test inputs are reproducible across runs.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Here we define a set of downsampling factors for demonstration purposes.\n",
    "# I use 2D downsampling factors to simplify compute and because the\n",
    "# funlib UNet does not support 1D.\n",
    "downsample_factors = [\n",
    "    [[2, 1], [2, 1], [2, 1], [2, 1]],\n",
    "    [[2, 1], [3, 1], [4, 1], [2, 1]],\n",
    "    [[2, 1], [4, 1], [3, 1], [5, 1]],\n",
    "]\n",
    "# The extra input is necessary because the funlib UNet crops more aggressively\n",
    "# than the tems UNet and thus has a larger `min_input_shape`. These values were\n",
    "# found by manual guess and check. Figuring out the appropriate input shape for\n",
    "# the funlib UNet is not always easy and is one of the main motivators for `tems`.\n",
    "extra_inputs = [\n",
    "    (16, 0),\n",
    "    (48 * 2, 0),\n",
    "    (120 * 2, 0),\n",
    "]\n",
    "# `zip` silently drops unmatched entries, so make sure every set of\n",
    "# downsampling factors has a matching extra input.\n",
    "assert len(extra_inputs) == len(downsample_factors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b251516f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_unet(downsample_factors):\n",
    "    \"\"\"Build a 2D `tems.UNet` through its funlib-compatible constructor.\n",
    "\n",
    "    Uses one input channel, one feature map, no feature-map growth and an\n",
    "    identity activation to keep the model small.\n",
    "    \"\"\"\n",
    "    return UNet.funlib_api(\n",
    "        dims=2,\n",
    "        in_channels=1,\n",
    "        num_fmaps=1,\n",
    "        fmap_inc_factor=1,\n",
    "        downsample_factors=downsample_factors,\n",
    "        activation=\"Identity\",\n",
    "    )\n",
    "\n",
    "\n",
    "def build_funlib_unet(downsample_factors):\n",
    "    \"\"\"Build the corresponding `funlib` UNet with explicit kernel sizes.\n",
    "\n",
    "    Each level uses two 3x3 convolutions: `len(downsample_factors) + 1`\n",
    "    levels on the way down and `len(downsample_factors)` on the way up\n",
    "    (presumably matching what `UNet.funlib_api` uses by default).\n",
    "    \"\"\"\n",
    "    num_levels = len(downsample_factors)\n",
    "    return FunlibUNet(\n",
    "        in_channels=1,\n",
    "        num_fmaps=1,\n",
    "        fmap_inc_factor=1,\n",
    "        downsample_factors=downsample_factors,\n",
    "        # Comprehensions give every level its own list (no shared aliases).\n",
    "        kernel_size_down=[[[3, 3], [3, 3]] for _ in range(num_levels + 1)],\n",
    "        kernel_size_up=[[[3, 3], [3, 3]] for _ in range(num_levels)],\n",
    "        activation=\"Identity\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a5be94",
   "metadata": {},
   "source": [
    "## JIT scripting\n",
    "Both models are passed to `torch.jit.script`. If compilation fails, the\n",
    "error type and the first line of its message are printed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb533108",
   "metadata": {},
   "outputs": [],
   "source": [
    "unet = build_unet(downsample_factors[0])\n",
    "try:\n",
    "    torch.jit.script(unet)\n",
    "except Exception as err:\n",
    "    # TorchScript frontend errors (e.g. `torch.jit.frontend.NotSupportedError`)\n",
    "    # are not RuntimeErrors, so catch broadly and report why scripting failed.\n",
    "    reason = next(iter(str(err).splitlines()), \"\")\n",
    "    print(f\"Failed to script tems.UNet ({type(err).__name__}): {reason}\")\n",
    "else:\n",
    "    print(\"Successfully scripted tems.UNet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a1a8d2",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "funlib_unet = build_funlib_unet(downsample_factors[0])\n",
    "try:\n",
    "    torch.jit.script(funlib_unet)\n",
    "except Exception as err:\n",
    "    # TorchScript frontend errors (e.g. `torch.jit.frontend.NotSupportedError`)\n",
    "    # are not RuntimeErrors, so catch broadly and report why scripting failed.\n",
    "    reason = next(iter(str(err).splitlines()), \"\")\n",
    "    print(f\"Failed to script funlib.UNet ({type(err).__name__}): {reason}\")\n",
    "else:\n",
    "    print(\"Successfully scripted funlib.UNet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a41772d2",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "## Cropping\n",
    "For each set of downsampling factors, both models receive the same random\n",
    "input and the spatial output shapes are printed. `tems.UNet` is shown both\n",
    "in training mode and in eval (translation-equivariant) mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee9b96ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_unet_comparison(tems_unet: UNet, funlib_unet: FunlibUNet, input_shape):\n",
    "    \"\"\"Run both UNets on one random input and print the spatial output shapes.\n",
    "\n",
    "    The `tems` model is run in eval mode (translation-equivariant output) and\n",
    "    in training mode. Its original train/eval mode is restored afterwards, so\n",
    "    the caller's model is not silently left in training mode. The `funlib`\n",
    "    model is run in whatever mode the caller put it in.\n",
    "\n",
    "    Args:\n",
    "        tems_unet: the `tems` model under test.\n",
    "        funlib_unet: the `funlib` model to compare against.\n",
    "        input_shape: spatial input shape; any iterable of integer-like values\n",
    "            (e.g. a 1-D tensor) is accepted.\n",
    "    \"\"\"\n",
    "    spatial_shape = [int(size) for size in input_shape]\n",
    "    in_data = torch.rand(1, 1, *spatial_shape)\n",
    "    print(\"Input shape:\", spatial_shape)\n",
    "\n",
    "    was_training = tems_unet.training\n",
    "    try:\n",
    "        # Only shapes are inspected, so skip building autograd graphs.\n",
    "        with torch.no_grad():\n",
    "            tems_out_data = tems_unet.eval()(in_data)\n",
    "            tems_out_training = tems_unet.train()(in_data)\n",
    "            funlib_out_data = funlib_unet(in_data)\n",
    "    finally:\n",
    "        # `Module.train()` / `Module.eval()` mutate the module in place.\n",
    "        tems_unet.train(was_training)\n",
    "\n",
    "    print(\"Output shape tems (Training):\", list(tems_out_training.shape[2:]))\n",
    "    print(\"Output shape tems (Translation Equivariant):\", list(tems_out_data.shape[2:]))\n",
    "    print(\n",
    "        \"Output shape funlib (Translation Equivariant):\",\n",
    "        list(funlib_out_data.shape[2:]),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8d4e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "for downsample_factor, extra_input in zip(downsample_factors, extra_inputs):\n",
    "    unet = build_unet(downsample_factor).eval()\n",
    "    input_shape = unet.min_input_shape + torch.tensor(extra_input)\n",
    "    funlib_unet = build_funlib_unet(downsample_factor).eval()\n",
    "\n",
    "    print(\"Total downsampling factor:\", unet.equivariant_step)\n",
    "    test_unet_comparison(unet, funlib_unet, input_shape)\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}