# Source code for torchflare.datasets.image_segmentation

import glob
import pathlib
from typing import List, Optional, Tuple, Union

import albumentations as A
import pandas as pd

from torchflare.datasets.core_utils import (
    apply_image_augmentations,
    apply_segmentation_augs,
    get_iloc_cols,
    join_paths,
    make_masks,
    open_image,
)
from torchflare.datasets.data_core import BaseDataset, ItemReader


def get_rle(df: pd.DataFrame, image_col: str, mask_cols: List[str], name: str) -> List:
    """Get the run-length encodings (RLEs) recorded for a single image.

    Args:
        df : The dataframe containing image names and mask RLEs.
        image_col: The column containing the image names/ids.
        mask_cols : The column/columns containing the mask RLEs.
        name: The image name/id whose rows should be collected.

    Returns:
        ``df[mask_cols].values.tolist()`` restricted to the rows where
        ``df[image_col] == name``, in their original row order. With a list of
        mask columns each entry is a per-row list; an unknown ``name`` yields ``[]``.
    """
    # Select with a positional boolean array rather than via the matching index
    # labels: ``df.loc[labels]`` returns *every* row carrying those labels, so a
    # dataframe with duplicate index labels would leak other images' RLEs.
    matches = (df[image_col] == name).values
    return df.loc[matches, mask_cols].values.tolist()


def create_rle_list(
    df: pd.DataFrame,
    image_col: str,
    mask_cols: List[str],
) -> List[List[str]]:
    """Create one list of mask RLEs per dataframe row.

    For every row of ``df`` this collects the RLEs of *all* rows that share the
    same image name, so rows belonging to the same image receive equal lists.

    Args:
        df : The dataframe.
        image_col: The column containing image_names. A single-element list of
            column names is unwrapped to that column.
        mask_cols : The column/columns containing mask rle's.

    Returns:
        A list with one entry per row of ``df`` (in row order). Each entry is
        ``df[mask_cols].values.tolist()`` restricted to the rows of that image.
    """
    if isinstance(image_col, list) and len(image_col) == 1:
        image_col = image_col[0]

    # Group rows by image name once instead of re-scanning the whole dataframe
    # for every row, which made this quadratic in the number of rows.
    mask_values = df[mask_cols].values
    row_positions = df.groupby(image_col, sort=False).indices

    mask_list = []
    for name in df[image_col].values.tolist():
        positions = row_positions.get(name)
        if positions is None:
            mask_list.append([])
        else:
            # Fancy indexing copies, so every row gets its own independent list.
            mask_list.append(mask_values[positions].tolist())

    return mask_list


class MaskDataset(BaseDataset):
    """Dataset for image segmentation.

    Pairs the image at ``item_reader.items[idx]`` with label ``y[idx]``. When
    the labels are lists (RLEs) the mask is decoded with ``make_masks``;
    otherwise the label is treated as a mask file and read with ``open_image``.

    Args:
        image_convert_mode: Mode passed to ``open_image`` for the input images.
        mask_convert_mode: Mode passed to ``open_image`` for mask files.
        shape: Mask shape passed to ``make_masks`` (RLE labels only).
        num_classes: Number of classes passed to ``make_masks`` (RLE labels only).
        **kwargs: Forwarded to ``BaseDataset`` (e.g. ``item_reader``, ``y``,
            ``target_transforms``).
    """

    def __init__(
        self,
        image_convert_mode: str,
        mask_convert_mode: str,
        shape: Optional[Tuple] = None,
        num_classes: Optional[int] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.mask_convert_mode = mask_convert_mode
        self.image_convert_mode = image_convert_mode
        self.shape = shape
        self.num_classes = num_classes
        # Whether the labels are RLE lists. Computed lazily on first use: the
        # scan over all labels is O(n), so it must not run on every item fetch.
        self._y_is_rle: Optional[bool] = None

    def _labels_are_rle(self) -> bool:
        """Return True if any label is a list, i.e. labels are RLEs rather than paths."""
        if self._y_is_rle is None:
            self._y_is_rle = any(isinstance(ele, list) for ele in self.y)
        return self._y_is_rle

    def _get_labels(self, idx):
        """Return the mask for item ``idx``, decoded from RLE or read from file."""
        if self._labels_are_rle():
            mask = make_masks(rle=self.y[idx], shape=self.shape, num_classes=self.num_classes)
        else:
            mask = open_image(self.y[idx], convert_mode=self.mask_convert_mode)
        return mask

    def __getitem__(self, idx):
        """Return ``(image, mask)`` after augmentation, or only the image when there are no labels."""
        image = open_image(self.item_reader.items[idx], convert_mode=self.image_convert_mode)
        if self.is_y_none:
            # No labels (inference): only the image augmentations are applied.
            return apply_image_augmentations(image, transforms=self.item_reader.transforms)
        mask = self._get_labels(idx=idx)
        image, mask = apply_segmentation_augs(
            image=image, transforms=self.item_reader.transforms, mask=mask
        )
        return image, mask


class SegmentationDataset(ItemReader):
    """PyTorch style dataset for image segmentation.

    Build the image side with ``from_df`` or ``from_folders``, then attach the
    targets with ``masks_from_rle``, ``masks_from_folders`` or ``add_test``;
    each of those returns a ``MaskDataset``.
    """

    def __init__(self, input_cols, image_convert_mode, **kwargs):
        super().__init__(**kwargs)
        self.image_convert_mode = image_convert_mode
        self.input_cols = input_cols
        # Dataset class instantiated by the masks_from_* / add_test builders.
        self.mask_dataset = MaskDataset

    def apply_target_transforms(self, transforms, item):
        """Method to apply transforms to targets (not implemented; raises NotImplementedError)."""
        raise NotImplementedError

    def apply_input_transforms(self, transforms, item):
        """Method to apply transforms to inputs (not implemented; raises NotImplementedError)."""
        raise NotImplementedError

    @staticmethod
    def _list_folder(folder: Union[str, pathlib.Path]) -> List[str]:
        """List the files directly inside ``folder`` in sorted order.

        Accepts ``str`` or ``pathlib.Path``. Sorting makes the order
        deterministic (``glob`` order is arbitrary), so image and mask folders
        whose files share names line up index by index.
        """
        return sorted(glob.glob(str(pathlib.Path(folder) / "*")))

    # skipcq : PYL-W0221
    @classmethod
    def from_df(
        cls,
        df: pd.DataFrame,
        path: Union[str, pathlib.Path],
        input_columns: List[str],
        transforms: A.Compose = None,
        image_convert_mode: str = "RGB",
        extension: Optional[str] = None,
        **kwargs
    ):
        """Method to read images from dataframe.

        Args:
            df : The dataframe containing the image names/ids.
            input_columns : A list containing columns which have names of images.
            path: The path where images are saved.
            transforms: The transforms to be used on the inputs.
            image_convert_mode: The mode to be passed to PIL.Image.convert.
            extension : The extension of image file.

        Example:

            .. code-block:: python

                from torchflare.datasets import SegmentationDataset
                ds = SegmentationDataset.from_df(
                    df=df,
                    path="/train/images",
                    input_columns=["image_id"],
                    extension=".jpg",
                    transforms=augs,
                    image_convert_mode="RGB",
                ).masks_from_rle(mask_columns=["EncodedPixels"], shape=(320, 320), num_classes=4)
        """
        path = pathlib.Path(path)
        files = get_iloc_cols(df, input_columns)
        files = join_paths(path=path, files=files, extension=extension)
        return cls(
            items=files,
            transforms=transforms,
            df=df,
            path=path,
            image_convert_mode=image_convert_mode,
            input_cols=input_columns,
            **kwargs
        )

    # skipcq : PYL-W0221
    @classmethod
    def from_folders(
        cls,
        image_path: Union[str, pathlib.Path],
        transforms: Optional[A.Compose] = None,
        image_convert_mode: str = "RGB",
        extension: str = None,
        **kwargs
    ):
        """Classmethod to create pytorch dataset from folders.

        Args:
            image_path: The path where images are stored (``str`` or ``pathlib.Path``).
            transforms: The transforms to apply on images and masks.
            image_convert_mode: The mode to be passed to PIL.Image.convert for input images
            extension : The extension for image like .jpg, etc

        Example:

            .. code-block:: python

                from torchflare.datasets import SegmentationDataset
                ds = SegmentationDataset.from_folders(
                    image_path="/train/images",
                    transforms=augs,
                    image_convert_mode="L",
                ).masks_from_folders(mask_path="/train/masks", mask_convert_mode="L")
        """
        # Sorted listing: masks_from_folders pairs images and masks by position.
        files = cls._list_folder(image_path)
        return cls(
            items=files,
            path=image_path,
            transforms=transforms,
            input_cols=None,
            image_convert_mode=image_convert_mode,
            extension=extension,
            **kwargs
        )

    def _create_mask_dataset(
        self, labels, shape=None, num_classes=None, mask_convert_mode=None, target_transforms=None
    ):
        """Wrap this item reader and ``labels`` into a ``self.mask_dataset`` instance."""
        return self.mask_dataset(
            item_reader=self,
            y=labels,
            target_transforms=target_transforms,
            mask_convert_mode=mask_convert_mode,
            image_convert_mode=self.image_convert_mode,
            shape=shape,
            num_classes=num_classes,
        )

    def masks_from_rle(
        self, shape: Tuple[int, int], num_classes: int, mask_columns: Optional[List[str]]
    ):
        """Create masks from run-length encoding.

        Intended for datasets built with ``from_df``, which supplies ``df``.

        Args:
            mask_columns : The list of columns containing the run-length encodings.
            shape : The shape for masks.
            num_classes: The number of classes.

        Returns:
            A MaskDataset whose labels are the per-row RLE lists.
        """
        masks = create_rle_list(
            df=self.df,
            image_col=self.input_cols,
            mask_cols=mask_columns,
        )
        return self._create_mask_dataset(
            labels=masks,
            shape=shape,
            num_classes=num_classes,
            mask_convert_mode=None,
            target_transforms=None,
        )

    def masks_from_folders(self, mask_path: Union[str, pathlib.Path], mask_convert_mode: str):
        """Read masks from folders.

        Masks are listed in sorted order and paired with the images by index,
        so mask files should sort in the same order as their images.

        Args:
            mask_path: The path where masks are stored (``str`` or ``pathlib.Path``).
            mask_convert_mode: The mode to be passed to PIL.Image.convert for masks.
        """
        masks = self._list_folder(mask_path)
        return self._create_mask_dataset(
            labels=masks, mask_convert_mode=mask_convert_mode, target_transforms=None
        )

    def add_test(self):
        """Method to create dataset for inference (no masks; items yield images only)."""
        return self._create_mask_dataset(
            labels=None, mask_convert_mode=None, target_transforms=None
        )
# Public API of this module: only SegmentationDataset is exported by ``import *``.
__all__ = ["SegmentationDataset"]