{ "cells": [ { "cell_type": "markdown", "id": "ce73e963", "metadata": {}, "source": [ "# Dataset Overview\n", "\n", "## Overview of the `dacapo_toolbox.iterable_dataset` helper function\n", "\n", "The `iterable_dataset` function is a helper that wraps the\n", "`gunpowder` and `funlib` libraries to provide a simple entry point\n", "for creating torch datasets. Its main features are:\n", "1. Properly handling spatial augmentations such as mirroring, transposing,\n", "elastic deformations, rotations and image scaling while handling any\n", "necessary context without excess padding or data reads. (This is a gunpowder feature.)\n", "See the [gunpowder docs](https://funkelab.github.io/gunpowder/) for more details.\n", "2. Robust sampling of input data using a variety of sampling strategies such as\n", "sampling from a set of points, guaranteeing a certain amount of masked-in data,\n", "or sampling uniformly from the input data. (This is also achieved using gunpowder.)\n", "3. Creating a simple torch dataset interface that can be used with any PyTorch\n", "parallelization scheme such as `torch.utils.data.DataLoader`.\n", "4. Handling both arrays and graphs as input and output data.\n", "5. Handling an arbitrary number of dimensions, easily generalizing to 3D plus time." 
] }, { "cell_type": "code", "execution_count": null, "id": "6cf08c76", "metadata": {}, "outputs": [], "source": [ "# ## A simple dataset\n", "\n", "from dacapo_toolbox.dataset import iterable_dataset\n", "from funlib.persistence import Array\n", "from skimage.data import astronaut\n", "import matplotlib.pyplot as plt\n", "\n", "from pathlib import Path\n", "\n", "out_ds = Path(\"_static/dataset_overview\")\n", "if not out_ds.exists():\n", "    out_ds.mkdir(parents=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "787241c7", "metadata": {}, "outputs": [], "source": [ "\n", "dataset = iterable_dataset(\n", "    {\"astronaut\": Array(astronaut().transpose((2, 0, 1)), voxel_size=(1, 1))},\n", "    shapes={\"astronaut\": (256, 256)},\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "f7c055dd", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "plt.imshow(sample[\"astronaut\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "4b8a60f6", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "plt.imshow(sample[\"astronaut\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "3403b784", "metadata": {}, "source": [ "You may notice a couple of things about the above images.\n", "\n", "First, why are we transposing and adding a voxel size? `Gunpowder` expects channel dimensions\n", "to come before spatial dimensions, and the length of the voxel size defines the number of spatial dimensions.\n", "Allowing for a voxel size lets us handle arrays of different resolutions and makes sure we can\n", "handle any non-isotropic data gracefully.\n", "\n", "Second, we see some padding at the side. This is because we treat every array given to us\n", "as an infinite array padded with zeros, and by default only guarantee that the center pixel is\n", "sampled from within the provided array. 
You can adjust this with the `trim` argument." ] }, { "cell_type": "code", "execution_count": null, "id": "c9d148e8", "metadata": {}, "outputs": [], "source": [ "\n", "dataset = iterable_dataset(\n", "    {\"astronaut\": Array(astronaut().transpose((2, 0, 1)), voxel_size=(1, 1))},\n", "    shapes={\"astronaut\": (256, 256)},\n", "    trim=(128, 128),\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "89842fea", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "plt.imshow(sample[\"astronaut\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "8988b59b", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "plt.imshow(sample[\"astronaut\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "a4efea8d", "metadata": {}, "source": [ "Since we now trim, we only choose samples where the\n", "center pixel is at least `trim` pixels away from the edge of the image.\n", "This guarantees that we don't get samples with padding, but it may lead\n", "to errors if your training data is smaller than `2*trim`.\n", "\n", "Next, let's add more arrays. Maybe you have multiple datasets, and multiple\n", "arrays per dataset." 
] }, { "cell_type": "code", "execution_count": null, "id": "4499d1d4", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "856302b6", "metadata": {}, "outputs": [], "source": [ "from skimage.data import coffee\n", "\n", "astronaut_data = astronaut().transpose((2, 0, 1)) / 255\n", "coffee_data = coffee().transpose((2, 0, 1)) / 255\n", "\n", "dataset = iterable_dataset(\n", "    {\n", "        \"image\": [\n", "            Array(astronaut_data, voxel_size=(1, 1)),\n", "            Array(coffee_data, voxel_size=(1, 1)),\n", "        ],\n", "        \"mask\": [\n", "            Array(\n", "                astronaut_data[0] > (astronaut_data[1] + astronaut_data[2])\n", "            ),  # mask in red regions\n", "            Array(\n", "                coffee_data[0] > (coffee_data[1] + coffee_data[2])\n", "            ),  # mask in red regions\n", "        ],\n", "    },\n", "    shapes={\"image\": (256, 256), \"mask\": (256, 256)},\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "d269bb08", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", "ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "bf8d08bb", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", "ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "009ec1cb", "metadata": {}, "source": [ "Let's add a graph to the dataset. We use the `networkx` library\n", "to interface with graphs. Each node in the graph must have a `position`\n", "attribute in *world coordinates*, which means accounting for the\n", "voxel size of the arrays given. In our case the voxel size is `(1, 1)`, so\n", "we can just use the pixel coordinates. 
We'll use a simple grid of points as our graph." ] }, { "cell_type": "code", "execution_count": null, "id": "409f041d", "metadata": {}, "outputs": [], "source": [ "import networkx as nx\n", "from itertools import product\n", "\n", "\n", "def gen_graphs():\n", "    graphs = []\n", "    for img_data in [astronaut_data, coffee_data]:\n", "        graph = nx.Graph()\n", "        for i, j in product(\n", "            range(0, img_data.shape[1], 32), range(0, img_data.shape[2], 32)\n", "        ):\n", "            graph.add_node(i * img_data.shape[2] + j, position=(i + 0.5, j + 0.5))\n", "        graphs.append(graph)\n", "    return graphs" ] }, { "cell_type": "code", "execution_count": null, "id": "e8db8913", "metadata": {}, "outputs": [], "source": [ "# Let's request the graph in a smaller region. It will be centered within the data\n", "# we ask for." ] }, { "cell_type": "code", "execution_count": null, "id": "24609eb5", "metadata": {}, "outputs": [], "source": [ "\n", "dataset = iterable_dataset(\n", "    {\n", "        \"image\": [\n", "            Array(astronaut_data, voxel_size=(1, 1)),\n", "            Array(coffee_data, voxel_size=(1, 1)),\n", "        ],\n", "        \"mask\": [\n", "            Array(\n", "                astronaut_data[0] > (astronaut_data[1] + astronaut_data[2])\n", "            ),  # mask in red regions\n", "            Array(\n", "                coffee_data[0] > (coffee_data[1] + coffee_data[2])\n", "            ),  # mask in red regions\n", "        ],\n", "        \"graph\": gen_graphs(),\n", "    },\n", "    shapes={\"image\": (256, 256), \"mask\": (256, 256), \"graph\": (128, 128)},\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "519c258e", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "sample = next(batch_gen)\n", "\n", "graph = sample[\"graph\"]\n", "xs = np.array([attrs[\"position\"][0] for attrs in graph.nodes.values()])\n", "ys = np.array([attrs[\"position\"][1] for attrs in graph.nodes.values()])\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", "ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[0].scatter(ys, xs, 
c=\"red\", s=10)\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "ax[1].scatter(ys, xs, c=\"red\", s=10)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "d65db8a2", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "graph = sample[\"graph\"]\n", "xs = np.array([attrs[\"position\"][0] for attrs in graph.nodes.values()])\n", "ys = np.array([attrs[\"position\"][1] for attrs in graph.nodes.values()])\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n", "ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[0].scatter(ys, xs, c=\"red\", s=10)\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "ax[1].scatter(ys, xs, c=\"red\", s=10)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "701a4e64", "metadata": {}, "source": [ "## Augmentations/Transformations\n", "\n", "We almost always want to transform our data in some way. The `iterable_dataset`\n", "function differentiates between two types of transformations:\n", "1. **Spatial augmentations**: These are augmentations that change the spatial\n", "   properties of the data. For example, mirroring, transposing, elastic deformations.\n", "2. **Non-spatial augmentations**: These are augmentations that operate on the image\n", "   content itself. For example, adding noise, changing brightness, binarizing, etc." ] }, { "cell_type": "markdown", "id": "934f5947", "metadata": {}, "source": [ "### Spatial augmentations\n", "\n", "We take two config classes that parameterize the spatial augmentations we support:\n", "`DeformAugmentConfig` and `SimpleAugmentConfig`.\n", "1. The `DeformAugmentConfig` handles continuous transforms requiring interpolation.\n", "This includes rotation, scaling and elastic deformation.\n", "2. The `SimpleAugmentConfig` handles discrete transforms that don't require interpolation.\n", "This includes mirroring and transposing." 
] }, { "cell_type": "code", "execution_count": null, "id": "ab823cf1", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "from dacapo_toolbox.dataset import DeformAugmentConfig, SimpleAugmentConfig\n", "\n", "\n", "dataset = iterable_dataset(\n", " {\n", " \"image\": [\n", " Array(astronaut_data, voxel_size=(1, 1)),\n", " Array(coffee_data, voxel_size=(1, 1)),\n", " ],\n", " \"graph\": gen_graphs(),\n", " },\n", " shapes={\"image\": (256, 256), \"graph\": (128, 128)},\n", " simple_augment_config=SimpleAugmentConfig(p=1.0, mirror_only=[1]),\n", " deform_augment_config=DeformAugmentConfig(\n", " p=1.0,\n", " control_point_spacing=(8, 8),\n", " jitter_sigma=(8.0, 8.0),\n", " scale_interval=(0.5, 2.0),\n", " rotate=True,\n", " ),\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "8f9f2daa", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "graph = sample[\"graph\"]\n", "xs = np.array([attrs[\"position\"][0] for attrs in graph.nodes.values()])\n", "ys = np.array([attrs[\"position\"][1] for attrs in graph.nodes.values()])\n", "\n", "plt.imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "plt.scatter(ys, xs, c=\"red\", s=10)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "53c1d409", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "graph = sample[\"graph\"]\n", "xs = np.array([attrs[\"position\"][0] for attrs in graph.nodes.values()])\n", "ys = np.array([attrs[\"position\"][1] for attrs in graph.nodes.values()])\n", "\n", "plt.imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "plt.scatter(ys, xs, c=\"red\", s=10)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "78f4db83", "metadata": {}, "source": [ "### Non-spatial augmentations\n", "\n", "Non-spatial augmentations are handled using the `transforms` argument to the\n", "`iterable_dataset` function. 
This is a dictionary where each key is\n", "a tuple of input and output keys, and each value is a callable that takes the\n", "designated inputs and generates the designated outputs.\n", "\n", "For example, `((\"a\", \"b\"), (\"c\", \"d\")): transform_1` means we expect `transform_1` to take\n", "two tensors, `(\"a\", \"b\")`, and produce two tensors, `(\"c\", \"d\")`.\n", "\n", "`(\"a\", \"c\"): transform_2` is shorthand for a transform that takes the single tensor\n", "`\"a\"` and outputs the single tensor `\"c\"`.\n", "\n", "`\"a\": transform_3` is shorthand for a transform that takes in a single tensor and\n", "produces a single tensor that should replace the input tensor.\n", "\n", "Let's see some examples:" ] }, { "cell_type": "code", "execution_count": null, "id": "be7574d3", "metadata": {}, "outputs": [], "source": [ "from torchvision.transforms import v2 as transforms\n", "\n", "dataset = iterable_dataset(\n", "    {\n", "        \"image\": [\n", "            Array(astronaut_data, voxel_size=(1, 1)),\n", "            Array(coffee_data, voxel_size=(1, 1)),\n", "        ],\n", "    },\n", "    transforms={\n", "        \"image\": transforms.GaussianBlur(3, sigma=(2.0, 2.0)),\n", "        (\"image\", \"mask\"): lambda d: d[0] > d[1] + d[2],  # mask in red regions\n", "        ((\"mask\", \"image\"), \"masked_image\"): lambda mask, image: mask * image,\n", "    },\n", "    shapes={\"image\": (256, 256)},\n", ")\n", "batch_gen = iter(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "9953fcb2", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n", "ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "ax[2].imshow(sample[\"masked_image\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "2ce218aa", "metadata": {}, "outputs": [], "source": [ "sample = next(batch_gen)\n", "\n", "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n", 
"ax[0].imshow(sample[\"image\"].numpy().transpose((1, 2, 0)))\n", "ax[1].imshow(sample[\"mask\"].numpy())\n", "ax[2].imshow(sample[\"masked_image\"].numpy().transpose((1, 2, 0)))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "7aca6d9f", "metadata": {}, "source": [ "The iterable dataset is very flexible and can handle a variety of use cases.\n", "Below is a more complex example showing multiple datasets at various\n", "resolutions, using different sampling strategies, spatial and non-spatial augmentations,\n", "and both arrays and graphs as input and output.\n", "\n", "We generate two synthetic datasets with different sized blobs: `blobs_a` and `blobs_b`.\n", "We create raw and ground truth arrays in scale pyramid fashion for each dataset.\n", "`blobs_a` has two scale levels, s0 and s1, with voxel sizes (raw: (1, 1), gt: (2, 2)) and (raw: (2, 2), gt: (4, 4)) respectively.\n", "`blobs_b` has two scale levels, s0 and s1, with voxel sizes (raw: (2, 2), gt: (4, 4)) and (raw: (4, 4), gt: (8, 8)) respectively.\n", "We also create a mask for each dataset that masks out half the image. The iterable dataset function\n", "has no problem generating samples from a variety of different voxel sizes, different sampling strategies, etc." 
] }, { "cell_type": "code", "execution_count": null, "id": "303a2d87", "metadata": {}, "outputs": [], "source": [ "from dacapo_toolbox.dataset import (\n", "    iterable_dataset,\n", "    SimpleAugmentConfig,\n", "    DeformAugmentConfig,\n", "    MaskedSampling,\n", "    PointSampling,\n", ")\n", "from dacapo_toolbox.transforms.affs import Affs, AffsMask\n", "from funlib.persistence import Array\n", "from skimage import data\n", "from torchvision.transforms import v2 as transforms\n", "import numpy as np\n", "from skimage.measure import label\n", "import networkx as nx  # used for the sample-point graph below\n", "from itertools import product  # used in the plotting loop below\n", "\n", "side_length = 2048\n", "\n", "# two different datasets with vastly different blob sizes\n", "blobs_a = data.binary_blobs(\n", "    length=side_length, blob_size_fraction=20 / side_length, n_dim=2\n", ")\n", "blobs_a_gt = label(blobs_a, connectivity=2)\n", "blobs_b = data.binary_blobs(\n", "    length=side_length, blob_size_fraction=100 / side_length, n_dim=2\n", ")\n", "blobs_b_gt = label(blobs_b, connectivity=2)\n", "mask = np.ones((side_length, side_length), dtype=bool)\n", "mask[side_length // 2 : side_length] = 0\n", "\n", "# raw and gt arrays at various voxel sizes\n", "raw_a_s0 = Array(blobs_a[::1, ::1], offset=(0, 0), voxel_size=(1, 1))\n", "raw_a_s1 = Array(blobs_a[::2, ::2], offset=(0, 0), voxel_size=(2, 2))\n", "raw_b_s0 = Array(blobs_b[::1, ::1], offset=(0, 0), voxel_size=(2, 2))\n", "raw_b_s1 = Array(blobs_b[::2, ::2], offset=(0, 0), voxel_size=(4, 4))\n", "gt_a_s0 = Array(blobs_a_gt[::2, ::2], offset=(0, 0), voxel_size=(2, 2))\n", "gt_a_s1 = Array(blobs_a_gt[::4, ::4], offset=(0, 0), voxel_size=(4, 4))\n", "gt_b_s0 = Array(blobs_b_gt[::2, ::2], offset=(0, 0), voxel_size=(4, 4))\n", "gt_b_s1 = Array(blobs_b_gt[::4, ::4], offset=(0, 0), voxel_size=(8, 8))\n", "mask_a = Array(mask, offset=(0, 0), voxel_size=(1, 1))\n", "mask_b = Array(mask, offset=(0, 0), voxel_size=(2, 2))\n", "\n", "g = nx.Graph()\n", "g.add_nodes_from(\n", "    [\n", "        (i, {\"position\": position})\n", "        for i, position in enumerate(\n", "            [\n", " 
(side_length * 2 - 0.5, side_length * 2 - 0.5),\n", "                (0.5, side_length * 2 - 0.5),\n", "                (side_length * 2 - 0.5, 0.5),\n", "                (0.5, 0.5),\n", "            ]\n", "        )\n", "    ]\n", ")\n", "\n", "# defining the datasets\n", "iter_ds = iterable_dataset(\n", "    {\n", "        \"raw_s0\": [raw_a_s0, raw_b_s0],\n", "        \"gt_s0\": [gt_a_s0, gt_b_s0],\n", "        \"raw_s1\": [raw_a_s1, raw_b_s1],\n", "        \"gt_s1\": [gt_a_s1, gt_b_s1],\n", "        \"mask\": [mask_a, mask_b],\n", "        \"mask_dummy\": [mask_a, mask_b],\n", "        \"sample_points\": [None, g],\n", "    },\n", "    shapes={\n", "        \"raw_s0\": (128 * 5, 128 * 5),\n", "        \"gt_s0\": (64 * 5, 64 * 5),\n", "        \"raw_s1\": (64 * 5, 64 * 5),\n", "        \"gt_s1\": (32 * 5, 32 * 5),\n", "        \"mask\": (128 * 5, 128 * 5),\n", "        \"mask_dummy\": (64 * 5, 64 * 5),\n", "        \"sample_points\": (128 * 5, 128 * 5),\n", "    },\n", "    sampling_strategies=[\n", "        MaskedSampling(\"mask_dummy\", 0.8),\n", "        PointSampling(\"sample_points\"),\n", "    ],\n", "    transforms={\n", "        (\"raw_s0\", \"noisy_s0\"): transforms.Compose(\n", "            [transforms.ConvertImageDtype(), transforms.GaussianNoise(sigma=1.0)]\n", "        ),\n", "        (\"raw_s1\", \"noisy_s1\"): transforms.Compose(\n", "            [transforms.ConvertImageDtype(), transforms.GaussianNoise(sigma=0.3)]\n", "        ),\n", "        (\"gt_s0\", \"affs_s0\"): Affs([[4, 0], [0, 4], [4, 4]]),\n", "        (\"gt_s0\", \"affs_mask_s0\"): AffsMask([[4, 0], [0, 4], [4, 4]]),\n", "        (\"gt_s1\", \"affs_s1\"): Affs([[4, 0], [0, 4], [4, 4]]),\n", "        (\"gt_s1\", \"affs_mask_s1\"): AffsMask([[4, 0], [0, 4], [4, 4]]),\n", "    },\n", "    simple_augment_config=SimpleAugmentConfig(\n", "        p=1.0, mirror_probs=[1.0, 0.0], transpose_only=[]\n", "    ),\n", "    deform_augment_config=DeformAugmentConfig(\n", "        p=1.0,\n", "        control_point_spacing=(10, 10),\n", "        jitter_sigma=(5.0, 5.0),\n", "        scale_interval=(0.5, 2.0),\n", "        rotate=True,\n", "    ),\n", ")\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "for i, batch in enumerate(iter_ds):\n", "    if i >= 4:  # Limit to 4 batches for demonstration\n", "        break\n", "    print(f\"Batch {i}\")\n", 
"    points = batch[\"sample_points\"]\n", "    xs = np.array([attrs[\"position\"][0] for attrs in points.nodes.values()])\n", "    ys = np.array([attrs[\"position\"][1] for attrs in points.nodes.values()])\n", "    plt.scatter(xs, ys, c=\"red\", s=10)\n", "\n", "    fig, axs = plt.subplots(2, 5, figsize=(18, 8))\n", "    axs[0, 0].imshow(batch[\"noisy_s0\"], cmap=\"gray\")\n", "    axs[0, 1].imshow(batch[\"gt_s0\"], cmap=\"magma\")\n", "    axs[0, 2].imshow(batch[\"affs_s0\"].permute(1, 2, 0).float())\n", "    axs[0, 3].imshow(batch[\"affs_mask_s0\"].permute(1, 2, 0).float())\n", "    axs[0, 4].imshow(batch[\"mask\"].float(), vmin=0, vmax=1, cmap=\"gray\")\n", "    axs[1, 0].imshow(batch[\"noisy_s1\"], cmap=\"gray\")\n", "    axs[1, 1].imshow(batch[\"gt_s1\"], cmap=\"magma\")\n", "    axs[1, 2].imshow(batch[\"affs_s1\"].permute(1, 2, 0).float())\n", "    axs[1, 3].imshow(batch[\"affs_mask_s1\"].permute(1, 2, 0).float())\n", "    axs[1, 4].imshow(batch[\"mask\"][::2, ::2].float(), vmin=0, vmax=1, cmap=\"gray\")\n", "    for a, b in product(range(2), range(5)):\n", "        s = 2 ** (a + (b % 4 != 0))\n", "        axs[a, b].scatter(ys / s, xs / s, c=\"red\", s=10)\n", "\n", "    axs[0, 0].set_title(\"Raw\")\n", "    axs[0, 1].set_title(\"GT\")\n", "    axs[0, 2].set_title(\"Affs\")\n", "    axs[0, 3].set_title(\"Affs Mask\")\n", "    axs[0, 4].set_title(\"Mask\")\n", "\n", "    plt.show()" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }