Source code for bliss.surveys.des

import warnings
from pathlib import Path
from typing import List, Tuple, TypedDict
from urllib.error import HTTPError

import galsim
import numpy as np
import torch
from astropy.io import fits
from astropy.table import Table
from astropy.wcs import WCS, FITSFixedWarning
from galsim import des as galsim_des
from numpy import char as defchararray
from omegaconf import DictConfig
from pyvo.dal import sia
from scipy.interpolate import RectBivariateSpline
from scipy.ndimage import zoom

from bliss.catalog import FullCatalog, SourceType
from bliss.simulator.psf import ImagePSF, PSFConfig
from bliss.surveys.download_utils import download_file_to_dst
from bliss.surveys.sdss import column_to_tensor
from bliss.surveys.survey import Survey, SurveyDownloader

SkyCoord = TypedDict(
    "SkyCoord",
    {"ra": float, "dec": float},
)
DESImageID = TypedDict(
    "DESImageID",
    {
        "sky_coord": SkyCoord,  # TODO: keep one of {this, decals_brickname}
        "decals_brickname": str,
        "ccdname": str,
        "g": str,
        "r": str,
        "i": str,
        "z": str,
    },
    total=False,
)


class DarkEnergySurvey(Survey):
    BANDS = ("g", "r", "i", "z")

    # cf. https://noirlab.edu/science/programs/ctio/instruments/Dark-Energy-Camera/characteristics
    GAIN = 4.0  # e-/ADU
    EXPTIME = 90.0  # s

    @staticmethod
    def zpt_to_scale(zpt):
        """Converts a magnitude zero point per sec to nelec/nmgy scale.

        See also https://github.com/dstndstn/tractor/blob/ \
            cdb82000422e85c9c97b134edadff31d68bced0c/tractor/brightness.py#L217C6-L217C6

        Args:
            zpt (float): magnitude zero point per sec

        Returns:
            float: nelec/nmgy scale
        """
        return 10.0 ** ((zpt - 22.5) / 2.5)
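
    # For instance, zpt_to_scale(22.5) == 1.0 and zpt_to_scale(25.0) == 10.0: each
    # additional 2.5 mag in the zero point multiplies the nelec/nmgy scale by 10.
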
    def __init__(
        self,
        psf_config: PSFConfig,
        dir_path="data/des",
        image_ids: Tuple[DESImageID] = (
            # TODO: maybe find a better/more general image_id representation?
            {
                "sky_coord": {"ra": 336.6643042496718, "dec": -0.9316385797930247},
                "decals_brickname": "3366m010",
                "ccdname": "S28",
                "g": "decam/CP/V4.8.2a/CP20171108/c4d_171109_002003_ooi_g_ls9",
                "r": "decam/CP/V4.8.2a/CP20170926/c4d_170927_025457_ooi_r_ls9",
                "i": "",
                "z": "decam/CP/V4.8.2a/CP20170926/c4d_170927_025655_ooi_z_ls9",
            },
        ),
        load_image_data: bool = False,
    ):
        super().__init__()

        self.des_path = Path(dir_path)
        self.load_image_data = load_image_data
        self.image_id_list = self.process_image_ids(image_ids)
        self.bands = tuple(range(len(self.BANDS)))
        self.n_bands = len(self.BANDS)

        self.downloader = DESDownloader(self.image_id_list, self.des_path)
        self.prepare_data()

        self.psf = DES_PSF(dir_path, self.image_ids(), self.bands, psf_config)

        if self.load_image_data:
            self._predict_batch = {
                "images": self[0]["image"],
                "background": self[0]["background"],
            }

    def prepare_data(self):
        self.downloader.download_images()
        self.downloader.download_backgrounds()
        self.downloader.download_psfexs()

    def __len__(self):
        return len(self.image_id_list)

    def __getitem__(self, idx):
        return self.get_from_disk(idx)

    def image_id(self, idx) -> DESImageID:
        return self.image_id_list[idx]

    def idx(self, image_id: DESImageID) -> int:
        return self.image_id_list.index(self.to_dictconfig(image_id))

    def image_ids(self) -> List[DESImageID]:
        return self.image_id_list

    def get_from_disk(self, idx):
        des_image_id = self.image_id(idx)

        image_list = [{} for _ in self.BANDS]
        # first get structure of image data for a present band
        # (get first present band by checking des_image_id[bl] for bl in DES.BANDS)
        first_present_bl = next(bl for bl in DES.BANDS if des_image_id[bl])
        first_present_bl_obj = self.read_image_for_band(des_image_id, first_present_bl)
        image_list[DES.BANDS.index(first_present_bl)] = first_present_bl_obj
        img_shape = first_present_bl_obj["background"].shape

        for b, bl in enumerate(self.BANDS):
            if bl != first_present_bl and des_image_id[bl]:
                image_list[b] = self.read_image_for_band(des_image_id, bl)
            elif bl != first_present_bl:
                image_list[b] = {
                    "background": np.random.rand(*img_shape).astype(np.float32),
                    "wcs": first_present_bl_obj["wcs"],  # NOTE: junk; just for format
                    "flux_calibration": np.ones((1,)),
                }
                if self.load_image_data:
                    image_list[b].update(
                        {"image": np.zeros(img_shape).astype(np.float32), "sig1": 0.0}
                    )

        ret = {}
        for k in image_list[0]:
            data_per_band = [image[k] for image in image_list]
            if isinstance(data_per_band[0], np.ndarray):
                ret[k] = np.stack(data_per_band)
            else:
                ret[k] = data_per_band
        ret["psf_params"] = self.psf.psf_params[self.image_id(idx)]
        ret["psf_galsim"] = self.psf.psf_galsim[self.image_id(idx)]
        return ret

    def read_image_for_band(self, des_image_id, band):
        brickname = des_image_id["decals_brickname"]
        ccdname = des_image_id["ccdname"]
        image_basename = DESDownloader.image_basename_from_filename(des_image_id[band], band)
        img_fits_filename = self.des_path / brickname[:3] / brickname / f"{image_basename}.fits"
        hr = fits.getheader(img_fits_filename, 0)
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=FITSFixedWarning)
            wcs = WCS(hr)
        image_shape = (hr["NAXIS2"], hr["NAXIS1"])

        flux_calibration = self.GAIN * hr["EXPTIME"]
        background_nelec = (
            self.splinesky_level_for_band(brickname, ccdname, des_image_id[band], image_shape)
            * flux_calibration
        )
        d = {
            "background": background_nelec,
            "wcs": wcs,
            "flux_calibration": np.array([flux_calibration]),
        }

        if self.load_image_data:
            image = fits.getdata(img_fits_filename, 0)
            # TODO: don't use image data to compute sig1 - so DECaLS gen won't load DES images
            diffs = image[:-5:10, :-5:10] - image[5::10, 5::10]
            mad = np.median(np.abs(diffs).ravel())
            zpscale = DES.zpt_to_scale(hr["MAGZPT"])
            sig1 = (1.4826 * mad / np.sqrt(2.0)) / zpscale
            image_nelec = image.astype(np.float32) * flux_calibration
            d.update({"image": image_nelec, "sig1": sig1})

        return d

    def splinesky_level_for_band(self, brickname, ccdname, image_filename, image_shape):
        save_filename = DESDownloader.save_filename_from_image_filename(image_filename)
        background_fits_filename = (
            self.des_path / brickname[:3] / brickname / f"{save_filename}-splinesky.fits"
        )
        background_fits = fits.open(background_fits_filename)
        background_table_hdu = background_fits[1]
        background_table = Table.read(background_table_hdu)

        # Get `row` corresponding to DECam image (i.e., CCD)
        rows = np.where(background_table["ccdname"] == ccdname)[0]
        assert len(rows) == 1
        row = rows[0]
        splinesky_params = background_table[row]

        gridw = splinesky_params["gridw"]
        gridh = splinesky_params["gridh"]
        gridvals = splinesky_params["gridvals"]
        xgrid = splinesky_params["xgrid"]
        ygrid = splinesky_params["ygrid"]
        order = splinesky_params["order"]

        # Meshgrid for pixel coordinates on smaller grid
        x, y = np.meshgrid(np.arange(gridw), np.arange(gridh))

        # Initialize the B-spline sky model with the extracted parameters
        splinesky_x = RectBivariateSpline(ygrid, xgrid, gridvals, kx=order, ky=order)
        splinesky_y = RectBivariateSpline(ygrid, xgrid, gridvals, kx=order, ky=order)

        # Evaluate the sky model at the given pixel coordinates
        background_values_grid_x = splinesky_x(y.flatten(), x.flatten(), grid=False).reshape(
            gridh, gridw
        )
        background_values_grid_y = splinesky_y(y.flatten(), x.flatten(), grid=False).reshape(
            gridh, gridw
        )

        # Upscale the background values from the smaller grid to the original image size
        # using bi-`order` interpolation
        background_values_x = zoom(
            background_values_grid_x,
            zoom=(image_shape[0] / gridh, image_shape[1] / gridw),
            order=order,
            mode="nearest",
        ).astype(np.float32)
        background_values_y = zoom(
            background_values_grid_y,
            zoom=(image_shape[0] / gridh, image_shape[1] / gridw),
            order=order,
            mode="nearest",
        ).astype(np.float32)

        # Take the mean of the x and y components
        return (background_values_x + background_values_y) / 2

    def to_dictconfig(self, image_id):
        # convert sky_coord["ra"], sky_coord["dec"] to Python floats
        image_id["sky_coord"] = {
            "ra": float(image_id["sky_coord"]["ra"]),
            "dec": float(image_id["sky_coord"]["dec"]),
        }
        return DictConfig(image_id)

    def process_image_ids(self, image_ids) -> List[DictConfig]:
        im_ids = list(image_ids)
        for im_id in im_ids:
            for b in self.BANDS:
                im_id[b] = im_id.get(b, "")
        return [self.to_dictconfig(im_id2) for im_id2 in im_ids]

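
# Minimal sketch (illustrative only, not part of the module) of the grid-to-image
# upsampling performed in `splinesky_level_for_band`: a coarse per-cell sky grid is
# blown up to the full image size with `scipy.ndimage.zoom` at the spline order. The
# grid shape and image shape below are arbitrary assumptions.
def _example_sky_grid_upsample(image_shape=(512, 1024)):  # hypothetical helper
    gridh, gridw, order = 8, 16, 3
    gridvals = np.random.rand(gridh, gridw)  # synthetic coarse per-cell sky values
    return zoom(
        gridvals,
        zoom=(image_shape[0] / gridh, image_shape[1] / gridw),
        order=order,
        mode="nearest",
    ).astype(np.float32)
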
DES = DarkEnergySurvey
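
# Example usage (illustrative sketch, not part of BLISS). The PSF configuration keys
# shown here are assumptions (check the BLISS config for the actual PSFConfig fields),
# and constructing the survey downloads data into `dir_path` (default "data/des").
def _example_des_frame():  # hypothetical helper
    psf_config = {"psf_slen": 25, "pixel_scale": 0.263}  # assumed PSFConfig values
    survey = DES(psf_config=psf_config)  # uses the default image_id above
    frame = survey[0]
    background = frame["background"]  # (4, H, W) per-band sky level (nelec)
    wcs_per_band = frame["wcs"]  # list of astropy WCS objects, one per band in DES.BANDS
    psf_params = frame["psf_params"]  # (4, 10) tensor of PSFEx parameters per band
    return background, wcs_per_band, psf_params
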
class DESDownloader(SurveyDownloader):
    """Class for downloading DECaLS data."""

    URLBASE = "https://portal.nersc.gov/cfs/cosmo/data/legacysurvey/dr9"
    DEF_ACCESS_URL = "https://datalab.noirlab.edu/sia/calibrated_all"
    DECaLS_URLBASE = "https://portal.nersc.gov/cfs/cosmo/data/legacysurvey/dr9"

    @staticmethod
    def image_basename_from_filename(image_filename, bl):
        return f"{image_filename.split('/')[-1].split(f'_{bl}')[0]}_{bl}"

    @staticmethod
    def save_filename_from_image_filename(image_filename):
        return image_filename.split("/")[-1]

    @staticmethod
    def download_catalog_from_filename(tractor_filename: str):
        """Download tractor catalog given tractor-<brick_name>.fits filename."""
        basename = Path(tractor_filename).name
        brickname = basename.split("-")[1].split(".")[0]
        download_file_to_dst(
            f"{DESDownloader.DECaLS_URLBASE}/south/tractor/{brickname[:3]}/{basename}",
            tractor_filename,
        )
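
    # For example, a destination path ending in "tractor-3366m010.fits" resolves to the
    # source URL f"{DECaLS_URLBASE}/south/tractor/336/tractor-3366m010.fits": the brick
    # name is parsed from the basename and its first three characters form the directory
    # prefix.
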
    def __init__(self, image_ids: List[DESImageID], download_dir):
        self.band_image_filenames = image_ids
        self.bricknames = [image_id["decals_brickname"] for image_id in image_ids]
        self.download_dir = download_dir
        self.svc = sia.SIAService(self.DEF_ACCESS_URL)

    def download_images(self):
        """Download images for all bands, for all image_ids."""
        for image_id in self.band_image_filenames:
            brickname = image_id["decals_brickname"]
            for bl in DES.BANDS:
                if image_id[bl]:
                    image_basename = DESDownloader.image_basename_from_filename(image_id[bl], bl)
                    self.download_image(brickname, image_id["sky_coord"], image_basename)

    def download_image(self, brickname, sky_coord, image_basename):
        """Download image for specified band, for this brick/ccd."""
        image_table = self.svc.search((sky_coord["ra"], sky_coord["dec"])).to_table()
        access_urls = image_table["access_url"].filled("").astype(str)
        sel = defchararray.find(access_urls, image_basename) != -1
        image_dst_filename = (
            self.download_dir / brickname[:3] / brickname / f"{image_basename}.fits"
        )
        try:
            download_file_to_dst(image_table[sel][0]["access_url"], image_dst_filename)
        except IndexError as e:
            warnings.warn(
                f"Desired image with basename {image_basename} not found in SIA database.",
                stacklevel=2,
            )
            raise e
        except HTTPError as e:
            warnings.warn(
                f"No {image_basename} image for brick {brickname} at sky position "
                f"({sky_coord['ra']}, {sky_coord['dec']}). Check cfg.datasets.des.image_ids.",
                stacklevel=2,
            )
            raise e

    def download_psfexs(self):
        """Download psfex files for all image_ids."""
        for image_id in self.band_image_filenames:
            brickname = image_id["decals_brickname"]
            for bl in DES.BANDS:
                if image_id[bl]:
                    self.download_psfex(brickname, image_id[bl])

    def download_psfex(self, brickname, image_filename_no_ext):
        """Download psfex file for specified band, for this brick/ccd."""
        save_filename = self.save_filename_from_image_filename(image_filename_no_ext)
        psfex_filename = (
            self.download_dir / brickname[:3] / brickname / f"{save_filename}-psfex.fits"
        )
        try:
            download_file_to_dst(
                f"{DESDownloader.URLBASE}/calib/psfex/{image_filename_no_ext}-psfex.fits",
                psfex_filename,
            )
        except HTTPError as e:
            warnings.warn(
                f"No {psfex_filename} image for brick {brickname}. Check "
                "cfg.datasets.des.image_ids.",
                stacklevel=2,
            )
            raise e
        return str(psfex_filename)

    def download_backgrounds(self):
        """Download sky params for all image_ids."""
        for image_id in self.band_image_filenames:
            brickname = image_id["decals_brickname"]
            for bl in DES.BANDS:
                if image_id[bl]:
                    self.download_background(brickname, image_id[bl])

    def download_background(self, brickname, image_filename_no_ext):
        """Download sky params for specified band, for this brick/ccd."""
        save_filename = self.save_filename_from_image_filename(image_filename_no_ext)
        background_filename = (
            self.download_dir / brickname[:3] / brickname / f"{save_filename}-splinesky.fits"
        )
        try:
            download_file_to_dst(
                f"{DESDownloader.URLBASE}/calib/sky/{image_filename_no_ext}-splinesky.fits",
                background_filename,
            )
        except HTTPError as e:
            warnings.warn(
                f"No {background_filename} file for brick {brickname}. Check "
                "cfg.datasets.des.image_ids.",
                stacklevel=2,
            )
            raise e
        return str(background_filename)

    def download_catalog(self, des_image_id) -> str:
        """Download tractor-<brick_name>.fits catalog for this image_id's brick."""
        brickname = des_image_id["decals_brickname"]
        tractor_filename = str(
            self.download_dir / brickname[:3] / brickname / f"tractor-{brickname}.fits"
        )
        basename = Path(tractor_filename).name
        download_file_to_dst(
            f"{DESDownloader.DECaLS_URLBASE}/south/tractor/{brickname[:3]}/{basename}",
            tractor_filename,
        )
        return tractor_filename

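
# Example (illustrative sketch, not part of BLISS): using the downloader directly,
# outside of DarkEnergySurvey. The image filenames and sky coordinate are the defaults
# from `DarkEnergySurvey.__init__`; the download directory is an assumption. Bands with
# no image must be present as empty strings so the per-band loops skip them.
def _example_standalone_download():  # hypothetical helper
    image_id = {
        "sky_coord": {"ra": 336.6643042496718, "dec": -0.9316385797930247},
        "decals_brickname": "3366m010",
        "ccdname": "S28",
        "g": "decam/CP/V4.8.2a/CP20171108/c4d_171109_002003_ooi_g_ls9",
        "r": "decam/CP/V4.8.2a/CP20170926/c4d_170927_025457_ooi_r_ls9",
        "i": "",
        "z": "decam/CP/V4.8.2a/CP20170926/c4d_170927_025655_ooi_z_ls9",
    }
    downloader = DESDownloader([image_id], Path("data/des"))
    downloader.download_images()
    downloader.download_backgrounds()
    downloader.download_psfexs()
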
# NOTE: No DES catalog; re-use DecalsFullCatalog
class DES_PSF(ImagePSF):  # noqa: N801
    # PSF parameters for encoder to learn
    PARAM_NAMES = [
        "chi2",
        "fit_original",
        "moffat_alpha",
        "moffat_beta",
        "polscal1",
        "polscal2",
        "polzero1",
        "polzero2",
        "psf_fwhm",
        "sum_diff",
    ]

    def __init__(self, survey_data_dir, image_ids, bands, psf_config: PSFConfig):
        super().__init__(bands, **psf_config)

        self.survey_data_dir = survey_data_dir
        # NOTE: pass `method="no_pixel"` to galsim.drawImage to avoid double-convolution
        # see https://galsim-developers.github.io/GalSim/_build/html/des.html#des-psf-models
        self.psf_draw_method = "no_pixel"

        self.psf_galsim = {}
        self.psf_params = {}
        for image_id in image_ids:
            psf_params = torch.zeros(len(DES.BANDS), len(DES_PSF.PARAM_NAMES))
            for b, bl in enumerate(DES.BANDS):
                if image_id[bl]:
                    psf_params[b] = self._psf_params_for_band(image_id, bl)
            self.psf_params[image_id] = psf_params
            self.psf_galsim[image_id] = self.get_psf_via_despsfex(image_id)

    def _psfex_hdu_for_band_image(self, des_image_id, bl):
        brickname = des_image_id["decals_brickname"]
        ccdname = des_image_id["ccdname"]
        save_filename = DESDownloader.save_filename_from_image_filename(des_image_id[bl])
        psfex_fits_filename = (
            Path(self.survey_data_dir) / brickname[:3] / brickname / f"{save_filename}-psfex.fits"
        )
        psfex_fits = fits.open(psfex_fits_filename)
        psfex_table_hdu = psfex_fits[1]

        # Get `row` corresponding to DECam image (i.e., CCD)
        rows = np.where(psfex_table_hdu.data["ccdname"] == ccdname)[0]
        assert len(rows) == 1, f"Found {len(rows)} rows for ccdname {ccdname}; expected 1."
        psfex_table_hdu.data = psfex_table_hdu.data[rows[0] : rows[0] + 1]

        return psfex_table_hdu

    def _psf_params_for_band(self, des_image_id, bl):
        band_psfex_table_hdu = self._psfex_hdu_for_band_image(des_image_id, bl)
        psf_params = np.zeros(len(DES_PSF.PARAM_NAMES))
        for i, param in enumerate(DES_PSF.PARAM_NAMES):
            psf_params[i] = band_psfex_table_hdu.data[param]
        return torch.tensor(psf_params, dtype=torch.float32)

    def get_psf_via_despsfex(self, des_image_id, px=0.0, py=0.0):
        """Construct PSF image from PSFEx FITS files.

        Args:
            des_image_id (DESImageID): image_id for this image
            px (float): x image pixel coordinate for PSF center
            py (float): y image pixel coordinate for PSF center

        Returns:
            images (List[InterpolatedImage]): list of psf transformations for each band
        """
        brickname = des_image_id["decals_brickname"]

        # Filler PSF for bands not in `bands`
        fake_psf = galsim.InterpolatedImage(
            galsim.Image(np.random.rand(self.psf_slen, self.psf_slen), scale=1)
        ).withFlux(1)
        images = [fake_psf for _ in range(len(DES.BANDS))]

        for b, bl in enumerate(DES.BANDS):
            if des_image_id[bl]:
                psfex_table_hdu = self._psfex_hdu_for_band_image(des_image_id, bl)
                fmt_psfex_table_hdu = self._format_psfex_table_hdu_for_galsim(psfex_table_hdu)

                image_basename = DESDownloader.image_basename_from_filename(des_image_id[bl], bl)
                image_filename = (
                    Path(self.survey_data_dir)
                    / brickname[:3]
                    / brickname
                    / f"{image_basename}.fits"
                )
                des_psfex_band = galsim_des.DES_PSFEx(
                    fmt_psfex_table_hdu,
                    str(image_filename),
                )
                # TODO: use an appropriate image position for the PSF
                psf_image = des_psfex_band.getPSF(galsim.PositionD(px, py))
                images[b] = psf_image

        return images

    def _format_psfex_table_hdu_for_galsim(self, psfex_table_hdu):
        """Format PSFEx table HDU for use with `galsim.des`."""
        # Get single values for the following parameters
        param_names = [
            "polnaxis",
            "polzero1",
            "polzero2",
            "polscal1",
            "polscal2",
            "polname1",
            "polname2",
            "polngrp",
            "polgrp1",
            "polgrp2",
            "poldeg1",
            "psfnaxis",
            "psfaxis1",
            "psfaxis2",
            "psfaxis3",
            "psf_samp",
        ]
        # Add to HDU header
        for param in param_names:
            psfex_table_hdu.header[param.upper()] = psfex_table_hdu.data[0][param]
        psfex_table_hdu.header["NAXIS2"] = 1
        psfex_table_hdu.header["NAXIS1"] = len(psfex_table_hdu.columns)

        return psfex_table_hdu

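
# Example (illustrative sketch, not part of BLISS): drawing one of the per-band PSFEx
# PSFs returned by `DES_PSF.get_psf_via_despsfex` into a postage stamp. The stamp size
# and pixel scale are assumptions (DECam is roughly 0.263 arcsec/pixel); drawing with
# `method="no_pixel"` matches `DES_PSF.psf_draw_method` and avoids convolving with the
# pixel response a second time.
def _example_draw_psf_stamp(psf_obj):  # hypothetical helper; `psf_obj` is a galsim GSObject
    stamp = psf_obj.drawImage(nx=25, ny=25, scale=0.263, method="no_pixel")
    return stamp.array  # (25, 25) numpy array containing the PSF image
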
class TractorFullCatalog(FullCatalog):
    """Class for the DECaLS Tractor Catalog.

    Some resources:
    - https://portal.nersc.gov/cfs/cosmo/data/legacysurvey/dr9/south/sweep/9.0/
    - https://www.legacysurvey.org/dr9/files/#sweep-catalogs-region-sweep
    - https://www.legacysurvey.org/dr5/description/#photometry
    - https://www.legacysurvey.org/dr9/bitmasks/
    """

    @staticmethod
    def _flux_to_mag(flux):
        return 22.5 - 2.5 * torch.log10(flux)

    @classmethod
    def from_file(
        cls,
        cat_path,
        wcs: WCS,
        height,
        width,
        band: str = "r",
    ):
        """Loads DECaLS catalog from FITS file.

        Args:
            cat_path (str): Path to .fits file.
            wcs (WCS): WCS object for the image.
            height (int): Height of the image.
            width (int): Width of the image.
            band (str): Band to read from. Defaults to "r".

        Returns:
            A TractorFullCatalog containing data from the provided file.
        """
        catalog_path = Path(cat_path)
        if not catalog_path.exists():
            DESDownloader.download_catalog_from_filename(catalog_path.name)
        assert catalog_path.exists(), f"File {catalog_path} does not exist"

        table = Table.read(catalog_path, format="fits", unit_parse_strict="silent")
        table = {k.upper(): v for k, v in table.items()}  # uppercase keys
        band = band.capitalize()

        # Filter out sources that aren't in the primary region, had issues with source
        # fitting, are in an SGA large galaxy, or are in a globular cluster. In the
        # future this should probably be an input parameter.
        bitmask = 0b0011010000000001

        objid = column_to_tensor(table, "OBJID")
        objc_type = table["TYPE"].data.astype(str)
        bits = table["MASKBITS"].data.astype(int)
        is_galaxy = torch.from_numpy(
            (objc_type == "DEV")
            | (objc_type == "REX")
            | (objc_type == "EXP")
            | (objc_type == "SER")
        )
        is_star = torch.from_numpy(objc_type == "PSF")
        ras = column_to_tensor(table, "RA")
        decs = column_to_tensor(table, "DEC")
        fluxes = column_to_tensor(table, f"FLUX_{band}")
        mask = torch.from_numpy((bits & bitmask) == 0).bool()

        galaxy_bools = is_galaxy & mask & (fluxes > 0)
        star_bools = is_star & mask & (fluxes > 0)

        # true light source mask
        keep = galaxy_bools | star_bools

        # filter quantities
        objid = objid[keep]
        galaxy_bools = galaxy_bools[keep]
        star_bools = star_bools[keep]
        ras = ras[keep]
        decs = decs[keep]
        fluxes = fluxes[keep]
        mags = cls._flux_to_mag(fluxes)
        nobj = objid.shape[0]

        # get pixel coordinates
        plocs = cls.plocs_from_ra_dec(ras, decs, wcs)

        # Verify each tile contains either a star or a galaxy
        assert torch.all(star_bools + galaxy_bools)
        source_type = SourceType.STAR * star_bools + SourceType.GALAXY * galaxy_bools

        d = {
            "plocs": plocs.reshape(1, nobj, 2),
            "objid": objid.reshape(1, nobj, 1),
            "n_sources": torch.tensor((nobj,)),
            "source_type": source_type.reshape(1, nobj, 1),
            "fluxes": fluxes.reshape(1, nobj, 1),
            "mags": mags.reshape(1, nobj, 1),
            "ra": ras.reshape(1, nobj, 1),
            "dec": decs.reshape(1, nobj, 1),
        }

        return cls(height, width, d)
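
# Example (illustrative sketch, not part of BLISS): loading a Tractor catalog for a
# downloaded frame. The catalog path is an assumption based on where `download_catalog`
# saves files; the frame's "wcs" entry is a list indexed by position in DES.BANDS.
def _example_load_tractor_catalog(survey: DarkEnergySurvey):  # hypothetical helper
    frame = survey[0]
    r_band = DES.BANDS.index("r")
    height, width = frame["background"].shape[-2:]
    return TractorFullCatalog.from_file(
        "data/des/336/3366m010/tractor-3366m010.fits",  # assumed download location
        wcs=frame["wcs"][r_band],
        height=height,
        width=width,
        band="r",
    )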