{ "cells": [ { "cell_type": "markdown", "id": "148dab10", "metadata": {}, "source": [ "# CREMI Example\n", "This tutorial demonstrates some simple pipelines built with the `dacapo_toolbox`\n", "dataset utilities on [CREMI data](https://cremi.org/data/). We'll cover a fun method\n", "for instance segmentation using a 2.5D U-Net." ] }, { "cell_type": "markdown", "id": "e56dad7e", "metadata": {}, "source": [ "## Introduction and overview\n", "\n", "In this tutorial we will cover a few basic ML tasks using the DaCapo toolbox. We will:\n", "\n", "- Prepare a dataloader for the CREMI dataset\n", "- Train a simple 2.5D U-Net to predict affinities for instance segmentation\n", "- Post-process the affinities into instance labels with mutex watershed, both in memory and blockwise\n", "- Visualize the results\n" ] }, { "cell_type": "markdown", "id": "ad9c5b5f", "metadata": {}, "source": [ "## Environment setup\n", "If you have not already done so, you will need to install the DaCapo Toolbox. You can do this\n", "by first creating a new environment and then adding the DaCapo Toolbox to it.\n", "\n", "I highly recommend using [uv](https://docs.astral.sh/uv/) for environment management,\n", "but there are many tools to choose from.\n", "\n", "```bash\n", "uv init\n", "uv add git+https://github.com/pattonw/dacapo-toolbox.git\n", "```" ] }, { "cell_type": "markdown", "id": "5b342cc0", "metadata": {}, "source": [ "## Data Preparation\n", "DaCapo works with zarr, so we will download [CREMI Sample A](https://cremi.org/static/data/sample_A%2B_20160601.hdf)\n", "and save it as a zarr file (`cremi.zarr`). The `cremi` helper returns separate training and test arrays for both the raw data and the labels." 
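] }, { "cell_type": "markdown", "id": "roi-math-md", "metadata": {}, "source": [ "The data-preparation code below defines `small_eval_roi` and `large_eval_roi` in voxel coordinates and scales them to world units by the voxel size. To sanity-check that arithmetic, here is a plain-Python sketch that does not use `funlib`. It assumes the CREMI voxel size is `(40, 4, 4)` nm in `(z, y, x)` order. `scale` and `sub` are hypothetical stand-ins for `Coordinate` arithmetic." ] }, { "cell_type": "code", "execution_count": null, "id": "roi-math-code", "metadata": {}, "outputs": [], "source": [ "# Plain-Python sketch of the voxel -> world ROI arithmetic (no funlib).\n", "# Assumption: CREMI voxel size is (40, 4, 4) nm in (z, y, x) order.\n", "voxel_size_nm = (40, 4, 4)\n", "block_vox = (32, 256, 256)\n", "offset_vox = (78, 465, 465)\n", "\n", "\n", "def scale(a, b):\n", "    # hypothetical helper: elementwise product, like Coordinate * Coordinate\n", "    return tuple(x * y for x, y in zip(a, b))\n", "\n", "\n", "def sub(a, b):\n", "    # hypothetical helper: elementwise difference, like Coordinate - Coordinate\n", "    return tuple(x - y for x, y in zip(a, b))\n", "\n", "\n", "# small ROI: one block starting at offset_vox\n", "small_offset_nm = scale(offset_vox, voxel_size_nm)\n", "small_shape_nm = scale(block_vox, voxel_size_nm)\n", "\n", "# large ROI: starts one block earlier and spans 3 blocks in y and x\n", "large_offset_nm = scale(sub(offset_vox, block_vox), voxel_size_nm)\n", "large_shape_nm = scale(scale(block_vox, (1, 3, 3)), voxel_size_nm)\n", "\n", "print('small ROI (nm):', small_offset_nm, small_shape_nm)\n", "print('large ROI (nm):', large_offset_nm, large_shape_nm)" ] }, { "cell_type": "markdown", "id": "roi-math-note", "metadata": {}, "source": [ "In voxels, the large ROI spans z in `[46, 78)` and the small ROI spans z in `[78, 110)`. The two regions are centered on each other in y and x, but they do not overlap along z." 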
] }, { "cell_type": "code", "execution_count": null, "id": "a65bb087", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "import multiprocessing as mp\n", "\n", "# fork lets DataLoader workers inherit objects that cannot be pickled, such as\n", "# the lambda transform used for training below (fork is not available on Windows)\n", "mp.set_start_method(\"fork\", force=True)\n", "import dask\n", "\n", "dask.config.set(scheduler=\"single-threaded\")\n", "\n", "from pathlib import Path\n", "from functools import partial\n", "from tqdm import tqdm\n", "\n", "from funlib.persistence import Array\n", "from funlib.geometry import Coordinate, Roi\n", "from dacapo_toolbox.sample_datasets import cremi\n", "\n", "# directory for the figures generated in this tutorial\n", "Path(\"_static/cremi\").mkdir(parents=True, exist_ok=True)\n", "\n", "raw_train, labels_train, raw_test, labels_test = cremi(Path(\"cremi.zarr\"))\n", "\n", "# define some variables that we will use later\n", "# The number of iterations we will train\n", "NUM_ITERATIONS = 300\n", "# A reasonable block size (in voxels) for processing image data with a U-Net\n", "blocksize = Coordinate(32, 256, 256)\n", "# We choose a small and a large ROI for evaluating the model.\n", "# The small ROI is processed in memory, the large one blockwise.\n", "# Both are defined in voxels and scaled to world units by the voxel size.\n", "offset = Coordinate(78, 465, 465)\n", "small_eval_roi = Roi(offset, blocksize) * raw_test.voxel_size\n", "large_eval_roi = (\n", "    Roi(offset - blocksize, blocksize * Coordinate(1, 3, 3)) * raw_test.voxel_size\n", ")" ] }, { "cell_type": "markdown", "id": "92b97279", "metadata": {}, "source": [ "Let's visualize our training and test data." 
] }, { "cell_type": "markdown", "id": "1c91a5d3", "metadata": {}, "source": [ "### Training data" ] }, { "cell_type": "code", "execution_count": null, "id": "8cca6ded", "metadata": {}, "outputs": [], "source": [ "\n", "from dacapo_toolbox.vis.preview import gif_2d, cube" ] }, { "cell_type": "code", "execution_count": null, "id": "b7f3764a", "metadata": {}, "outputs": [], "source": [ "\n", "# create a 2D gif of the training data\n", "gif_2d(\n", "    arrays={\"Train Raw\": raw_train, \"Train Labels\": 
labels_train},\n", "    array_types={\"Train Raw\": \"raw\", \"Train Labels\": \"labels\"},\n", "    filename=\"_static/cremi/training-data.gif\",\n", "    title=\"Training Data\",\n", "    fps=10,\n", ")\n", "cube(\n", "    arrays={\"Train Raw\": raw_train, \"Train Labels\": labels_train},\n", "    array_types={\"Train Raw\": \"raw\", \"Train Labels\": \"labels\"},\n", "    filename=\"_static/cremi/training-data.jpg\",\n", "    title=\"Training Data\",\n", ")" ] }, { "cell_type": "markdown", "id": "7d6efa46", "metadata": {}, "source": [ "Here we visualize the training data:\n", "![training-data](_static/cremi/training-data.gif)\n", "![training-data-cube](_static/cremi/training-data.jpg)" ] }, { "cell_type": "markdown", "id": "ce69751f", "metadata": {}, "source": [ "### Testing data" ] }, { "cell_type": "code", "execution_count": null, "id": "8e666fd8", "metadata": {}, "outputs": [], "source": [ "gif_2d(\n", "    arrays={\"Test Raw\": raw_test, \"Test Labels\": labels_test},\n", "    array_types={\"Test Raw\": \"raw\", \"Test Labels\": \"labels\"},\n", "    filename=\"_static/cremi/testing-data.gif\",\n", "    title=\"Testing Data\",\n", "    fps=10,\n", ")\n", "cube(\n", "    arrays={\"Test Raw\": raw_test, \"Test Labels\": labels_test},\n", "    array_types={\"Test Raw\": \"raw\", \"Test Labels\": \"labels\"},\n", "    filename=\"_static/cremi/testing-data.jpg\",\n", "    title=\"Testing Data\",\n", ")" ] }, { "cell_type": "markdown", "id": "a8c12607", "metadata": {}, "source": [ "Here we visualize the test data:\n", "![testing-data](_static/cremi/testing-data.gif)\n", "![testing-data-cube](_static/cremi/testing-data.jpg)" ] }, { "cell_type": "markdown", "id": "0e172baa", "metadata": {}, "source": [ "### DaCapo\n", "Now that we have some data, let's look at how we can use DaCapo to interface with it for some common ML use cases." 
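] }, { "cell_type": "markdown", "id": "augment-md", "metadata": {}, "source": [ "The training datasets below restrict `SimpleAugmentConfig` mirroring and transposing to axes 1 and 2 (`mirror_only=(1, 2)`, `transpose_only=(1, 2)`), and `DeformAugmentConfig` rotates only within the y-x plane (`rotation_axes=(1, 2)`). CREMI voxels are strongly anisotropic (assumed here to be 40 nm in z versus 4 nm in y and x), so swapping z with an in-plane axis would produce unrealistic samples. Here is a minimal numpy sketch of in-plane mirror/transpose augmentation. It uses a hypothetical `simple_augment` helper and only illustrates the idea; it is not the `dacapo_toolbox` implementation." ] }, { "cell_type": "code", "execution_count": null, "id": "augment-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def simple_augment(raw, labels, rng):\n", "    # Hypothetical helper: randomly mirror and transpose only the in-plane\n", "    # axes (1, 2). The same random choices are applied to raw and labels so\n", "    # the two stay aligned; axis 0 (z) is never flipped or swapped.\n", "    for axis in (1, 2):\n", "        if rng.random() < 0.5:\n", "            raw = np.flip(raw, axis=axis)\n", "            labels = np.flip(labels, axis=axis)\n", "    if rng.random() < 0.5:\n", "        raw = np.swapaxes(raw, 1, 2)\n", "        labels = np.swapaxes(labels, 1, 2)\n", "    return raw, labels\n", "\n", "\n", "rng = np.random.default_rng(0)\n", "toy_raw = rng.random((13, 64, 64))\n", "toy_labels = (toy_raw > 0.5).astype(np.uint64)\n", "aug_raw, aug_labels = simple_augment(toy_raw, toy_labels, rng)\n", "print(aug_raw.shape, aug_labels.shape)" 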
] }, { "cell_type": "markdown", "id": "cb72f18a", "metadata": {}, "source": [ "### Data Split\n", "We always want to be explicit about our data split, so that we know exactly which data is used for training and which is held out for evaluation. Here the model only ever sees `raw_train` and `labels_train` during training; `raw_test` and `labels_test` are reserved for evaluation." ] }, { "cell_type": "code", "execution_count": null, "id": "0865ca81", "metadata": {}, "outputs": [], "source": [ "from dacapo_toolbox.dataset import (\n", "    iterable_dataset,\n", "    DeformAugmentConfig,\n", "    SimpleAugmentConfig,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "f2706e38", "metadata": {}, "outputs": [], "source": [ "train_dataset = iterable_dataset(\n", "    datasets={\"raw\": raw_train, \"gt\": labels_train},\n", "    shapes={\"raw\": (13, 256, 256), \"gt\": (13, 256, 256)},\n", "    deform_augment_config=DeformAugmentConfig(\n", "        p=0.1,\n", "        control_point_spacing=(2, 10, 10),\n", "        jitter_sigma=(0.5, 2, 2),\n", "        rotate=True,\n", "        subsample=4,\n", "        rotation_axes=(1, 2),\n", "        scale_interval=(1.0, 1.0),\n", "    ),\n", "    simple_augment_config=SimpleAugmentConfig(\n", "        p=1.0,\n", "        mirror_only=(1, 2),\n", "        transpose_only=(1, 2),\n", "    ),\n", "    trim=Coordinate(5, 5, 5),\n", ")\n", "batch_gen = iter(train_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "99c0e3b1", "metadata": {}, "outputs": [], "source": [ "batch = next(batch_gen)\n", "gif_2d(\n", "    arrays={\n", "        \"Raw\": Array(batch[\"raw\"].numpy(), voxel_size=raw_train.voxel_size),\n", "        \"Labels\": Array(batch[\"gt\"].numpy(), voxel_size=labels_train.voxel_size),\n", "    },\n", "    array_types={\"Raw\": \"raw\", \"Labels\": \"labels\"},\n", "    filename=\"_static/cremi/simple-batch.gif\",\n", "    title=\"Simple Batch\",\n", "    fps=10,\n", ")\n", "cube(\n", "    arrays={\n", "        \"Raw\": 
Array(batch[\"raw\"].numpy(), voxel_size=raw_train.voxel_size),\n", "        \"Labels\": Array(batch[\"gt\"].numpy(), voxel_size=labels_train.voxel_size),\n", "    },\n", "    array_types={\"Raw\": \"raw\", \"Labels\": \"labels\"},\n", "    
filename=\"_static/cremi/simple-batch.jpg\",\n", "    title=\"Simple Batch\",\n", ")" ] }, { "cell_type": "markdown", "id": "b1f4d27e", "metadata": { "lines_to_next_cell": 2 }, "source": [ "Here we visualize a single augmented training batch:\n", "![simple-batch](_static/cremi/simple-batch.gif)\n", "![simple-batch-cube](_static/cremi/simple-batch.jpg)" ] }, { "cell_type": "markdown", "id": "46f2af48", "metadata": {}, "source": [ "### Tasks\n", "When training for instance segmentation, we cannot directly predict label IDs. The IDs have to be unique across the full volume, and a U-Net cannot guarantee that from the local context it operates on. Instead, we transform our labels into an intermediate representation that is both easy to predict and easy to post-process. The most common approach we use is to predict [affinities](https://arxiv.org/pdf/1706.00120), optionally together with [LSDs](https://github.com/funkelab/lsd) (local shape descriptors), and then post-process them with [mutex watershed](https://arxiv.org/abs/1904.12654).\n", "\n", "Each affinity channel corresponds to one offset in a neighborhood and indicates whether a voxel and the voxel at that offset belong to the same object. Next we add transforms that compute these affinities (`Affs`) from the ground-truth labels, along with a mask (`AffsMask`) that we later use to weight the loss." 
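] }, { "cell_type": "markdown", "id": "affs-md", "metadata": {}, "source": [ "To make the affinity representation concrete, here is a minimal numpy sketch. It assumes a common convention. The affinity for offset `o` at voxel `x` is 1 when `x` and `x + o` both lie inside the volume and carry the same non-zero (non-background) label, and 0 otherwise. The mask marks where `x + o` lies inside the volume. `compute_affinities` is a hypothetical helper that only handles non-negative offsets, like the `neighborhood` used in this tutorial. The `Affs` and `AffsMask` transforms in `dacapo_toolbox` may differ in offset direction, background handling and masking." ] }, { "cell_type": "code", "execution_count": null, "id": "affs-code", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_affinities(labels, neighborhood):\n", "    # Hypothetical helper: affs[c, x] = 1 if x and x + offset lie inside the\n", "    # volume and share the same non-zero label. mask[c, x] = 1 wherever\n", "    # x + offset lies inside the volume. Assumes non-negative offsets.\n", "    affs = np.zeros((len(neighborhood),) + labels.shape, dtype=np.uint8)\n", "    mask = np.zeros_like(affs)\n", "    for c, offset in enumerate(neighborhood):\n", "        # slices selecting voxels x (src) and their partners x + offset (dst)\n", "        src = tuple(slice(0, max(s - o, 0)) for s, o in zip(labels.shape, offset))\n", "        dst = tuple(slice(o, s) for s, o in zip(labels.shape, offset))\n", "        a, b = labels[src], labels[dst]\n", "        affs[c][src] = (a == b) & (a != 0)\n", "        mask[c][src] = 1\n", "    return affs, mask\n", "\n", "\n", "demo_labels = np.array([[[1, 1, 2, 2, 0]]])  # shape (z, y, x) = (1, 1, 5)\n", "demo_affs, demo_mask = compute_affinities(demo_labels, [(0, 0, 1), (0, 0, 2)])\n", "print(demo_affs[0, 0, 0])  # [1 0 1 0 0]\n", "print(demo_mask[0, 0, 0])  # [1 1 1 1 0]" 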
] }, { "cell_type": "code", "execution_count": null, "id": "e2d42581", "metadata": {}, "outputs": [], "source": [ "from dacapo_toolbox.transforms.affs import Affs, AffsMask\n", "from dacapo_toolbox.transforms.weight_balancing import BalanceLabels\n", "import torchvision\n", "\n", "neighborhood = [\n", " (1, 0, 0),\n", " (0, 1, 0),\n", " (0, 0, 1),\n", " (0, 7, 0),\n", " (0, 0, 7),\n", " (0, 23, 0),\n", " (0, 0, 23),\n", "]\n", "train_dataset = iterable_dataset(\n", " datasets={\"raw\": raw_train, \"gt\": labels_train},\n", " shapes={\"raw\": (13, 256, 256), \"gt\": (13, 256, 256)},\n", " transforms={\n", " (\"gt\", \"affs\"): Affs(neighborhood=neighborhood, concat_dim=0),\n", " (\"gt\", \"affs_mask\"): AffsMask(neighborhood=neighborhood),\n", " },\n", " deform_augment_config=DeformAugmentConfig(\n", " p=0.1,\n", " control_point_spacing=(2, 10, 10),\n", " jitter_sigma=(0.5, 2, 2),\n", " rotate=True,\n", " subsample=4,\n", " rotation_axes=(1, 2),\n", " scale_interval=(1.0, 1.0),\n", " ),\n", " simple_augment_config=SimpleAugmentConfig(\n", " p=1.0,\n", " mirror_only=(1, 2),\n", " transpose_only=(1, 2),\n", " ),\n", ")\n", "\n", "batch_gen = iter(train_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "92d7bb06", "metadata": {}, "outputs": [], "source": [ "batch = next(batch_gen)\n", "gif_2d(\n", " arrays={\n", " \"Raw\": Array(batch[\"raw\"].numpy(), voxel_size=raw_train.voxel_size),\n", " \"GT\": Array(batch[\"gt\"].numpy() % 256, voxel_size=raw_train.voxel_size),\n", " \"Affs\": Array(\n", " batch[\"affs\"].float().numpy()[[0, 3, 4]],\n", " voxel_size=raw_train.voxel_size,\n", " ),\n", " \"Affs Mask\": Array(\n", " batch[\"affs_mask\"].float().numpy()[[0, 3, 4]],\n", " voxel_size=raw_train.voxel_size,\n", " ),\n", " },\n", " array_types={\n", " \"Raw\": \"raw\",\n", " \"GT\": \"labels\",\n", " \"Affs\": \"affs\",\n", " \"Affs Mask\": \"affs\",\n", " },\n", " filename=\"_static/cremi/affs-batch.gif\",\n", " title=\"Affinities Batch\",\n", " 
fps=10,\n", ")\n", "cube(\n", "    arrays={\n", "        \"Raw\": Array(batch[\"raw\"].numpy(), voxel_size=raw_train.voxel_size),\n", "        \"GT\": Array(batch[\"gt\"].numpy(), voxel_size=raw_train.voxel_size),\n", "        \"Affs\": Array(\n", "            batch[\"affs\"].float().numpy()[[0, 3, 4]],\n", "            voxel_size=raw_train.voxel_size,\n", "        ),\n", "        \"Affs Mask\": Array(\n", "            batch[\"affs_mask\"].float().numpy()[[0, 3, 4]],\n", "            voxel_size=raw_train.voxel_size,\n", "        ),\n", "    },\n", "    array_types={\n", "        \"Raw\": \"raw\",\n", "        \"GT\": \"labels\",\n", "        \"Affs\": \"affs\",\n", "        \"Affs Mask\": \"affs\",\n", "    },\n", "    filename=\"_static/cremi/affs-batch.jpg\",\n", "    title=\"Affinities Batch\",\n", ")" ] }, { "cell_type": "markdown", "id": "fc69fd2f", "metadata": {}, "source": [ "Here we visualize a training batch for the affinities task: raw data, ground-truth labels, target affinities and the affinities mask. For the affinities and the mask we show channels 0, 3 and 4, which correspond to the offsets `(1, 0, 0)`, `(0, 7, 0)` and `(0, 0, 7)`:\n", "![affs-batch](_static/cremi/affs-batch.gif)\n", "![affs-batch-cube](_static/cremi/affs-batch.jpg)" ] }, { "cell_type": "markdown", "id": "005fd796", "metadata": {}, "source": [ "### Models\n", "Let's define our model: a 2.5D U-Net. Downsampling is only in-plane (`(1, 2, 2)`) and most convolutions use `(1, 3, 3)` kernels. Only one level of the upsampling path uses full 3D `(3, 3, 3)` kernels to combine information across sections. A final `1x1x1` convolution maps the 32 U-Net feature maps to one output channel per neighborhood offset." 
] }, { "cell_type": "code", "execution_count": null, "id": "9ea7fa72", "metadata": {}, "outputs": [], "source": [ "import tems\n", "import torch\n", "\n", "\n", "if torch.cuda.is_available():\n", "    device = torch.device(\"cuda\")\n", "else:\n", "    # Apple's MPS backend is not used here; non-CUDA machines run on the CPU\n", "    device = torch.device(\"cpu\")\n", "\n", "unet = tems.UNet.funlib_api(\n", "    dims=3,\n", "    in_channels=1,\n", "    num_fmaps=32,\n", "    fmap_inc_factor=4,\n", "    downsample_factors=[(1, 2, 2), (1, 2, 2), (1, 2, 2)],\n", "    kernel_size_down=[\n", "        [(1, 3, 3), (1, 3, 3)],\n", "        [(1, 3, 3), (1, 3, 3)],\n", "        [(1, 3, 3), (1, 3, 3)],\n", "        [(1, 3, 3), (1, 3, 3)],\n", "    ],\n", "    kernel_size_up=[\n", "        [(1, 3, 3), (1, 3, 3)],\n", "        [(1, 3, 3), (1, 3, 3)],\n", "        [(3, 3, 3), (3, 3, 3)],\n", "    ],\n", "    activation=\"LeakyReLU\",\n", ")\n", "\n", "# Small sigmoid wrapper to apply sigmoid only when not 
training.\n", "# During training we return raw logits, because logit-based losses (such as\n", "# BCEWithLogitsLoss or the sigmoid focal loss used below) are more numerically\n", "# stable than a sigmoid followed by BCELoss.\n", "class SigmoidWrapper(torch.nn.Module):\n", "    def __init__(self, model):\n", "        super().__init__()\n", "        self.model = model\n", "        self.apply_sigmoid = True\n", "\n", "    def forward(self, x):\n", "        logits = self.model(x)\n", "        if self.apply_sigmoid and not self.training:\n", "            return torch.sigmoid(logits)\n", "        return logits\n", "\n", "\n", "# a 1x1x1 convolution maps the 32 U-Net feature maps to one channel per offset\n", "module = SigmoidWrapper(\n", "    torch.nn.Sequential(unet, torch.nn.Conv3d(32, len(neighborhood), kernel_size=1))\n", ").to(device)" ] }, { "cell_type": "markdown", "id": "7b5a2ff4", "metadata": {}, "source": [ "### Training loop\n", "Now we can bring everything together and train our model." ] }, { "cell_type": "code", "execution_count": null, "id": "a7e380b5", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# grow the minimum U-Net input/output shapes by some extra voxels for larger crops\n", "extra = torch.tensor((2, 64, 64))\n", "train_dataset = iterable_dataset(\n", "    datasets={\"raw\": raw_train, \"gt\": labels_train},\n", "    shapes={\n", "        \"raw\": unet.min_input_shape + extra,\n", "        \"gt\": unet.min_output_shape + extra,\n", "    },\n", "    transforms={\n", "        # add a channel dimension and scale 8-bit intensities to [0, 1]\n", "        \"raw\": torchvision.transforms.Lambda(lambda x: x[None].float() / 255.0),\n", "        (\"gt\", \"affs\"): Affs(neighborhood=neighborhood, concat_dim=0),\n", "        (\"gt\", \"affs_mask\"): AffsMask(neighborhood=neighborhood),\n", "    },\n", "    deform_augment_config=DeformAugmentConfig(\n", "        p=0.1,\n", "        control_point_spacing=(2, 10, 10),\n", "        jitter_sigma=(0.5, 2, 2),\n", "        rotate=True,\n", "        
subsample=4,\n", "        rotation_axes=(1, 2),\n", "        scale_interval=(1.0, 1.0),\n", "    ),\n", "    simple_augment_config=SimpleAugmentConfig(\n", "        p=1.0,\n", "        mirror_only=(1, 2),\n", "        transpose_only=(1, 2),\n", "    ),\n", ")\n", "\n", "# per-voxel focal loss on logits (no reduction) so we can apply the mask ourselves\n", "loss_func = partial(torchvision.ops.sigmoid_focal_loss, reduction=\"none\")\n", "optimizer = torch.optim.Adam(module.parameters(), lr=5e-5)\n", "dataloader = torch.utils.data.DataLoader(\n", "    
train_dataset,\n", "    batch_size=3,\n", "    num_workers=4,\n", ")\n", "losses = []\n", "\n", "# zipping with range(NUM_ITERATIONS) stops after exactly NUM_ITERATIONS batches\n", "for _, batch in tqdm(zip(range(NUM_ITERATIONS), dataloader), total=NUM_ITERATIONS):\n", "    raw, target, affs_mask = (\n", "        batch[\"raw\"].to(device),\n", "        batch[\"affs\"].to(device),\n", "        batch[\"affs_mask\"].to(device),\n", "    )\n", "    optimizer.zero_grad()\n", "\n", "    output = module(raw)\n", "\n", "    # average the per-voxel loss over the voxels selected by the affinities mask\n", "    voxel_loss = loss_func(output, target.float())\n", "    loss = (voxel_loss * affs_mask).sum() / affs_mask.sum()\n", "    loss.backward()\n", "    optimizer.step()\n", "\n", "    losses.append(loss.item())" ] }, { "cell_type": "code", "execution_count": null, "id": "0284905f", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from funlib.geometry import Coordinate\n", "\n", "plt.plot(losses)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Loss\")\n", "plt.title(\"Loss Curve\")\n", "plt.savefig(\"_static/cremi/affs-loss-curve.png\")\n", "plt.show()\n", "plt.close()" ] }, { "cell_type": "code", "execution_count": null, "id": "1cfe73f1", "metadata": {}, "outputs": [], "source": [ "import mwatershed as mws\n", "from funlib.geometry import Roi\n", "import numpy as np\n", "\n", "module = module.eval()\n", "unet = unet.eval()\n", "# half of the U-Net context on each side, converted to world units\n", "context = Coordinate(unet.context // 2) * raw_test.voxel_size" ] }, { "cell_type": "code", "execution_count": null, "id": "213aa31d", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "# read the input grown by the context so the prediction covers small_eval_roi\n", "raw_input = raw_test.to_ndarray(small_eval_roi.grow(context, context))\n", "raw_output = raw_test.to_ndarray(small_eval_roi)\n", "gt = labels_test.to_ndarray(small_eval_roi)\n", "\n", "# 
Predict on the test data (on the CPU)\n", "with torch.no_grad():\n", "    device = torch.device(\"cpu\")\n", "    module = module.to(device)\n", "    pred = (\n", "        module(\n", "            (torch.from_numpy(raw_input).float() / 255.0)\n", "            .to(device)\n", "            .unsqueeze(0)\n", "            .unsqueeze(0)\n", "        )\n", "        .cpu()\n", "        .detach()\n", "        .numpy()\n", "    
)\n", "# mwatershed treats positive values as attractive and negative values as\n", "# repulsive edges, so shift the [0, 1] affinities by -0.5 before agglomerating\n", "pred_labels = mws.agglom(pred[0].astype(np.float64) - 0.5, offsets=neighborhood)" ] }, { "cell_type": "code", "execution_count": null, "id": "8a79de7a", "metadata": {}, "outputs": [], "source": [ "# Plot the results\n", "gif_2d(\n", "    arrays={\n", "        \"Raw\": Array(raw_output, voxel_size=raw_test.voxel_size),\n", "        \"GT\": Array(gt % 256, voxel_size=raw_test.voxel_size),\n", "        \"Pred Affs\": Array(pred[0][[0, 3, 4]], voxel_size=raw_test.voxel_size),\n", "        \"Pred\": Array(pred_labels % 256, voxel_size=raw_test.voxel_size),\n", "    },\n", "    array_types={\n", "        \"Raw\": \"raw\",\n", "        \"GT\": \"labels\",\n", "        \"Pred Affs\": \"affs\",\n", "        \"Pred\": \"labels\",\n", "    },\n", "    filename=\"_static/cremi/affs-prediction.gif\",\n", "    title=\"Prediction\",\n", "    fps=10,\n", ")\n", "cube(\n", "    arrays={\n", "        \"Raw\": Array(raw_output, voxel_size=raw_test.voxel_size),\n", "        \"GT\": Array(gt, voxel_size=raw_test.voxel_size),\n", "        \"Pred Affs\": Array(pred[0][[0, 3, 4]], voxel_size=raw_test.voxel_size),\n", "        \"Pred\": Array(pred_labels, voxel_size=raw_test.voxel_size),\n", "    },\n", "    array_types={\n", "        \"Raw\": \"raw\",\n", "        \"GT\": \"labels\",\n", "        \"Pred Affs\": \"affs\",\n", "        \"Pred\": \"labels\",\n", "    },\n", "    filename=\"_static/cremi/affs-prediction.jpg\",\n", "    title=\"Prediction\",\n", ")" ] }, { "cell_type": "markdown", "id": "e89bb02a", "metadata": {}, "source": [ "Here we visualize the predictions on the small evaluation ROI:\n", "![affs-prediction](_static/cremi/affs-prediction.gif)\n", "![affs-prediction-cube](_static/cremi/affs-prediction.jpg)" ] }, { "cell_type": "markdown", "id": "395ac6da", "metadata": {}, "source": [ "## Blockwise Processing\n", "Now that we have a trained model, we 
can use it to process larger volumes blockwise. Here we process `large_eval_roi`.\n", "We use the `blockwise_predict_mutex` helper from `dacapo_toolbox.postprocessing`, which is built on the\n", "`volara` library. `volara` provides a simple interface for blockwise processing of large volumes,\n", "and the `volara_torch` module wraps our trained (TorchScript) model so it can be used in a blockwise\n", "pipeline. The helper predicts affinities, extracts fragments, agglomerates them and relabels the\n", "result, writing its outputs under `cremi.zarr/test/`." 
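] }, { "cell_type": "markdown", "id": "tiling-md", "metadata": {}, "source": [ "Conceptually, blockwise prediction splits the target ROI into write blocks and reads each block with extra context on every side, so that the U-Net's valid output covers exactly the write block (compare `model_context=unet.context // 2` in the call below). Here is a minimal pure-Python sketch of that tiling in voxel units. It uses a hypothetical `tile_roi` helper. The ROI and block sizes echo `large_eval_roi` and `blocksize` from above, while the context `(4, 40, 40)` is an illustrative value, not the real U-Net context. `volara` handles this tiling for you, along with scheduling the blocks across workers." ] }, { "cell_type": "code", "execution_count": null, "id": "tiling-code", "metadata": {}, "outputs": [], "source": [ "from itertools import product\n", "\n", "\n", "def tile_roi(roi_offset, roi_shape, block_shape, context):\n", "    # Hypothetical helper: yield (read_offset, read_shape, write_offset,\n", "    # write_shape) for each block, all in voxels. Write blocks tile the ROI\n", "    # (clipped at its upper edge); each read block is its write block grown\n", "    # by `context` voxels on both sides.\n", "    starts = [\n", "        range(o, o + s, b) for o, s, b in zip(roi_offset, roi_shape, block_shape)\n", "    ]\n", "    for write_offset in product(*starts):\n", "        write_shape = tuple(\n", "            min(b, o + s - w)\n", "            for w, o, s, b in zip(write_offset, roi_offset, roi_shape, block_shape)\n", "        )\n", "        read_offset = tuple(w - c for w, c in zip(write_offset, context))\n", "        read_shape = tuple(w + 2 * c for w, c in zip(write_shape, context))\n", "        yield read_offset, read_shape, write_offset, write_shape\n", "\n", "\n", "# illustrative sizes: large_eval_roi is (32, 768, 768) voxels; context is made up\n", "demo_blocks = list(tile_roi((0, 0, 0), (32, 768, 768), (32, 256, 256), (4, 40, 40)))\n", "print(len(demo_blocks))  # 9\n", "print(demo_blocks[0])  # ((-4, -40, -40), (40, 336, 336), (0, 0, 0), (32, 256, 256))" ] }, 
{ "cell_type": "markdown", "id": "tiling-note", "metadata": {}, "source": [ "The first read block starts at negative coordinates. Read blocks extend past the ROI, so the raw data must cover the ROI plus context (or be padded)." 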
] }, { "cell_type": "code", "execution_count": null, "id": "c875c138", "metadata": {}, "outputs": [], "source": [ "from dacapo_toolbox.postprocessing import blockwise_predict_mutex\n", "from volara.workers import LocalWorker\n", "\n", "\n", "unet = unet.eval()\n", "scripted_unet = torch.jit.script(module)\n", "torch.jit.save(scripted_unet, \"cremi.zarr/affs_unet.pt\")\n", "torch.save(scripted_unet.state_dict(), \"cremi.zarr/weights.pth\")\n", "\n", "blocksize = Coordinate(unet.min_output_shape) + blocksize\n", "\n", "# default biases:\n", "# interpolate log offset distances to a range of [-0.2, -0.8]\n", "\n", "blockwise_predict_mutex(\n", " raw_store=\"cremi.zarr/test/raw\",\n", " affs_store=\"cremi.zarr/test/affs\", # optional, provided for visualization\n", " frags_store=\"cremi.zarr/test/frags\", # optional, provided for visualization\n", " labels_store=\"cremi.zarr/test/pred_labels\",\n", " neighborhood=neighborhood,\n", " blocksize=blocksize,\n", " model_path=\"cremi.zarr/affs_unet.pt\",\n", " in_channels=1,\n", " model_context=unet.context // 2,\n", " predict_worker=LocalWorker(), # optional, see docstring\n", " extract_frag_bias=[\n", " -0.5,\n", " -0.2,\n", " -0.2,\n", " -0.5,\n", " -0.5,\n", " -0.8,\n", " -0.8,\n", " ], # optional, TODO: defaults not very good yet\n", " edge_scores=[ # optional, TODO: defaults not very good yet\n", " (\"affs_z\", [Coordinate(1, 0, 0)], -0.5),\n", " (\"affs_xy\", [Coordinate(0, 1, 0), Coordinate(0, 0, 1)], -0.2),\n", " (\n", " \"affs_long_xy\",\n", " [\n", " Coordinate(0, 7, 0),\n", " Coordinate(0, 0, 7),\n", " Coordinate(0, 23, 0),\n", " Coordinate(0, 0, 23),\n", " ],\n", " -0.8,\n", " ),\n", " ],\n", " num_extract_frag_workers=3,\n", " num_aff_agglom_workers=3,\n", " num_relabel_workers=3,\n", " roi=large_eval_roi,\n", ")" ] }, { "cell_type": "markdown", "id": "488b7a3b", "metadata": {}, "source": [ "## Visualizing the results" ] }, { "cell_type": "code", "execution_count": null, "id": "f48faff1", "metadata": {}, 
"outputs": [], "source": [ "from funlib.persistence import open_ds\n", "\n", "affs = open_ds(\"cremi.zarr/test/affs\")\n", "affs.lazy_op(lambda x: x[[0, 3, 4]] / 255.0)\n", "raw = open_ds(\"cremi.zarr/test/raw\")\n", "raw.lazy_op(large_eval_roi)\n", "gif_2d(\n", " arrays={\n", " \"Raw\": raw,\n", " \"Affs\": affs,\n", " \"Frags\": open_ds(\"cremi.zarr/test/frags\"),\n", " \"Pred Labels\": open_ds(\"cremi.zarr/test/pred_labels\"),\n", " },\n", " array_types={\n", " \"Raw\": \"raw\",\n", " \"Affs\": \"affs\",\n", " \"Frags\": \"labels\",\n", " \"Pred Labels\": \"labels\",\n", " },\n", " title=\"CREMI Affs Prediction\",\n", " filename=\"_static/cremi/cremi-prediction.gif\",\n", " fps=10,\n", ")\n", "cube(\n", " arrays={\n", " \"raw\": raw,\n", " \"affs\": affs,\n", " \"frags\": open_ds(\"cremi.zarr/test/frags\"),\n", " \"pred_labels\": open_ds(\"cremi.zarr/test/pred_labels\"),\n", " },\n", " array_types={\n", " \"raw\": \"raw\",\n", " \"affs\": \"affs\",\n", " \"frags\": \"labels\",\n", " \"pred_labels\": \"labels\",\n", " },\n", " title=\"CREMI Affs Prediction\",\n", " filename=\"_static/cremi/cremi-prediction.jpg\",\n", ")" ] }, { "cell_type": "markdown", "id": "aa7f3039", "metadata": {}, "source": [ "Here we visualize the prediction results:\n", "![cremi-prediction](_static/cremi/cremi-prediction.gif)\n", "![cremi-prediction-cube](_static/cremi/cremi-prediction.jpg)" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }