Source code for bliss.data_augmentation

import random

import torch

from bliss.catalog import TileCatalog


class RotateFlipTransform(torch.nn.Module):
    """Augment a datum with a 90-degree rotation followed by an optional flip.

    The same transform is applied to the image tensor, to every tensor in the
    tile catalog, and to the within-tile source locations so that they stay
    consistent with one another.
    """

    def __call__(self, datum, rotate_id=None, do_flip=None):
        """Return a rotated (and possibly flipped) copy of ``datum``.

        Args:
            datum: dict with keys ``"images"``, ``"psf_params"`` and
                ``"tile_catalog"`` (itself a dict of tensors). Images are
                rotated over dims 1 and 2, tile-catalog tensors over dims 0
                and 1 (presumably the tile-grid height/width; confirm against
                the caller).
            rotate_id: number of 90-degree rotations to apply. Any integer is
                accepted and is reduced modulo 4. Drawn uniformly from
                ``{0, 1, 2, 3}`` when ``None``.
            do_flip: whether to flip after rotating. Drawn uniformly from
                ``{True, False}`` when ``None``.

        Returns:
            A new dict with the same three keys. ``"psf_params"`` is passed
            through untouched; the input ``datum`` is not modified.
        """
        # problematic if the psf isn't rotationally invariant
        datum_out = {"psf_params": datum["psf_params"]}

        # apply rotation
        if rotate_id is None:
            # randint is inclusive on both ends; 0..3 gives each of the four
            # distinct rotations equal probability (k=4 would duplicate k=0).
            rotate_id = random.randint(0, 3)  # noqa: S311
        # Normalise so that negative or large values rotate the locations by
        # the same amount as rot90 rotates the tensors.
        rotate_id = rotate_id % 4

        datum_out["images"] = datum["images"].rot90(rotate_id, [1, 2])
        tile_cat = {
            key: value.rot90(rotate_id, [0, 1])
            for key, value in datum["tile_catalog"].items()
        }

        # apply flip
        if do_flip is None:
            do_flip = random.choice([True, False])  # noqa: S311
        if do_flip:
            datum_out["images"] = datum_out["images"].flip([1])
            tile_cat = {key: value.flip([0]) for key, value in tile_cat.items()}

        # locations require special logic: moving tiles around is not enough,
        # the coordinates stored inside each tile must be transformed as well.
        if "locs" in tile_cat:
            locs = tile_cat["locs"]
            for _ in range(rotate_id):
                # Rotate 90 degrees clockwise (in pixel coordinates)
                locs = torch.stack((1 - locs[..., 1], locs[..., 0]), dim=-1)
            if do_flip:
                # Mirror the first coordinate, matching the flip applied above.
                locs = torch.stack((1 - locs[..., 0], locs[..., 1]), dim=-1)
            tile_cat["locs"] = locs

        datum_out["tile_catalog"] = tile_cat
        return datum_out
class RandomShiftTransform(torch.nn.Module):
    """Augment a datum by translating the image and its catalog together.

    The image is rolled along dims 1 and 2 and the catalog's ``plocs`` are
    offset by the same amounts, then the catalog is re-tiled.
    """

    def __init__(self, tile_slen, max_sources_per_tile):
        """Store the tiling parameters used when re-tiling the catalog.

        Args:
            tile_slen: tile side length; must be even and greater than 1.
            max_sources_per_tile: forwarded to ``to_tile_catalog`` when the
                shifted catalog is converted back to tiles.
        """
        super().__init__()
        assert tile_slen % 2 == 0
        assert tile_slen > 1
        self.tile_slen = tile_slen
        self.max_sources_per_tile = max_sources_per_tile

    def __call__(self, datum, vertical_shift=None, horizontal_shift=None):
        """Return a shifted copy of ``datum``.

        Args:
            datum: dict with keys ``"images"``, ``"psf_params"`` and
                ``"tile_catalog"``.
            vertical_shift: shift applied along image dim 1 and to
                ``plocs[:, :, 0]``. Drawn at random when ``None``.
            horizontal_shift: shift applied along image dim 2 and to
                ``plocs[:, :, 1]``. Drawn at random when ``None``.

        Returns:
            A new dict with the same three keys; ``"psf_params"`` is passed
            through unchanged.
        """
        psf_params = datum["psf_params"]

        # Random shifts are drawn from the inclusive range
        # [-(tile_slen // 2 - 1), tile_slen // 2].
        upper = self.tile_slen // 2
        lower = 1 - upper
        if vertical_shift is None:
            vertical_shift = random.randint(lower, upper)  # noqa: S311
        if horizontal_shift is None:
            horizontal_shift = random.randint(lower, upper)  # noqa: S311

        # Circularly shift the pixels: rows first, then columns.
        shifted_images = datum["images"].roll(vertical_shift, 1).roll(horizontal_shift, 2)

        # Move the sources by the same offsets in full-image coordinates.
        catalog = TileCatalog.from_dict(datum["tile_catalog"]).to_full_catalog(self.tile_slen)
        plocs = catalog["plocs"]
        plocs[:, :, 0] += vertical_shift
        plocs[:, :, 1] += horizontal_shift

        # filter_oob=True presumably drops sources shifted outside the image;
        # confirm against to_tile_catalog.
        retiled = catalog.to_tile_catalog(
            self.tile_slen, self.max_sources_per_tile, filter_oob=True
        )

        return {
            "psf_params": psf_params,
            "images": shifted_images,
            "tile_catalog": retiled.to_dict(),
        }