import bz2
import gzip
import warnings
from pathlib import Path
from typing import List, Tuple, TypedDict
import numpy as np
import torch
from astropy.io import fits
from astropy.table import Table
from astropy.utils.data import download_file
from astropy.wcs import WCS, FITSFixedWarning
from einops import rearrange
from scipy.interpolate import RegularGridInterpolator
from bliss.catalog import FullCatalog, SourceType
from bliss.simulator.psf import ImagePSF, PSFConfig
from bliss.surveys.download_utils import download_file_to_dst
from bliss.surveys.survey import Survey
SDSSFields = List[TypedDict("SDSSField", {"run": int, "camcol": int, "fields": List[int]})]
[docs]
class SloanDigitalSkySurvey(Survey):
BANDS = ("u", "g", "r", "i", "z")
[docs]
@staticmethod
def radec_for_rcf(run, camcol, field) -> Tuple[float, float]:
"""Get center (RA, DEC) for a given run, camcol, field."""
extents = SDSSDownloader.field_extents()
row = extents[
(extents["run"] == run) & (extents["camcol"] == camcol) & (extents["field"] == field)
][0]
ra_center = row["ramin"] + (row["ramax"] - row["ramin"]) / 2
dec_center = row["decmin"] + (row["decmax"] - row["decmin"]) / 2
return (ra_center, dec_center)
[docs]
@staticmethod
def rcf_for_radec(ra, dec) -> Tuple[int, int, int]:
"""Get run, camcol, field for a given RA, DEC."""
extents = SDSSDownloader.field_extents()
row = extents[
(extents["ramin"] <= ra)
& (extents["ramax"] >= ra)
& (extents["decmin"] <= dec)
& (extents["decmax"] >= dec)
][0]
return (row["run"], row["camcol"], row["field"])
def __init__(
self,
psf_config: PSFConfig,
fields,
dir_path="data/sdss",
load_image_data: bool = False,
background_offset=0.0,
align_to_band=None,
crop_to_bands=None,
crop_to_hw=None,
):
super().__init__()
self.sdss_path = Path(dir_path)
self.sdss_fields = fields
self.load_image_data = load_image_data
self.background_offset = background_offset
self.align_to_band = align_to_band
num_frames = sum(len(rcf_conf["fields"]) for rcf_conf in fields)
self.items = [None for _ in range(num_frames)]
self.rcfgcs = []
self.downloader = SDSSDownloader(self.image_ids(), download_dir=str(self.sdss_path))
self.psf = SDSS_PSF(dir_path, self.image_ids(), range(len(self.BANDS)), psf_config)
self.crop_to_bands = crop_to_bands
self.crop_to_hw = crop_to_hw
[docs]
def prepare_data(self):
self.downloader.download_pfs()
for rcf_conf in self.sdss_fields:
run, camcol, fields = rcf_conf["run"], rcf_conf["camcol"], rcf_conf["fields"]
pf_file = f"photoField-{run:06d}-{camcol:d}.fits"
pf_path = self.sdss_path / str(run) / str(camcol) / pf_file
msg = (
f"{pf_path} does not exist. "
+ "Make sure data files are available for specified (run, camcol)."
)
assert Path(pf_path).exists(), msg
pf_fits = fits.getdata(pf_path)
assert pf_fits is not None, f"Could not load fits file {pf_path}."
fieldnums = pf_fits["FIELD"]
fieldgains = pf_fits["GAIN"]
if fields:
for field in fields:
gain = fieldgains[fieldnums == field][0]
self.rcfgcs.append((run, camcol, field, gain))
else:
for field, gain in zip(fieldnums, fieldgains, strict=True):
self.rcfgcs.append((run, camcol, field, gain))
self.downloader.download_images()
for run, camcol, field, _ in self.rcfgcs:
field_path = self.sdss_path / f"{run}/{camcol}/{field}"
for bl in SloanDigitalSkySurvey.BANDS:
frame_name = f"frame-{bl}-{run:06d}-{camcol:d}-{field:04d}.fits"
frame_path = field_path / frame_name
assert Path(frame_path).exists(), f"{frame_path} does not exist."
def __len__(self):
return len(self.rcfgcs)
def __getitem__(self, idx):
if self.items[idx] is None:
self.items[idx] = self.get_from_disk(idx)
return self.items[idx]
def get_from_disk(self, idx):
run, camcol, field, gain = self.rcfgcs[idx]
camcol_dir = self.sdss_path.joinpath(str(run), str(camcol))
field_dir = camcol_dir.joinpath(str(field))
frame_list = []
item = {
"field": field,
"psf_params": self.psf.psf_params[self.image_id(idx)],
"psf_galsim": self.psf.psf_galsim[self.image_id(idx)],
}
for b, bl in enumerate(self.BANDS):
frame = self.read_frame_for_band(bl, field_dir, run, camcol, field, gain[b])
frame_list.append(frame)
for k in frame_list[0]:
band_data = [frame[k] for frame in frame_list]
item[k] = np.stack(band_data) if isinstance(band_data[0], np.ndarray) else band_data
# a hack to deal with underestimated backgrounds
item["background"] += self.background_offset
return item
[docs]
def image_id(self, idx) -> Tuple[int, int, int]:
"""Return the image_id for the given index."""
return self.rcfgcs[idx][:3]
[docs]
def idx(self, image_id: Tuple[int, int, int]) -> int:
"""Return the index for the given image_id."""
r, c, f = image_id
# Return first index that matches r, c, f
return next(
i
for i, (run, camcol, field, _) in enumerate(self.rcfgcs)
if (run, camcol, field) == (r, c, f)
)
[docs]
def image_ids(self) -> List[Tuple[int, int, int]]:
"""Return all image_ids.
Note: Parallel to `rcfgcs`.
Returns:
List[Tuple[int, int, int]]: List of (run, camcol, field) image_ids.
"""
rcfs = []
for rcf_conf in self.sdss_fields:
run, camcol, fields = rcf_conf["run"], rcf_conf["camcol"], rcf_conf["fields"]
for field in fields:
rcfs.append((run, camcol, field))
return rcfs
def read_frame_for_band(self, bl, field_dir, run, camcol, field, gain):
frame_name = f"frame-{bl}-{run:06d}-{camcol:d}-{field:04d}.fits"
frame_path = str(field_dir.joinpath(frame_name))
calibration = fits.getdata(frame_path, 1)
nelec_per_nmgy = gain / calibration
sky_data = fits.getdata(frame_path, 2)
sky_small = sky_data["ALLSKY"][0]
sky_x = sky_data["XINTERP"][0]
sky_y = sky_data["YINTERP"][0]
small_rows = np.mgrid[0 : sky_small.shape[0]]
small_cols = np.mgrid[0 : sky_small.shape[1]]
small_rcs = (small_rows, small_cols)
sky_interp = RegularGridInterpolator(small_rcs, sky_small, method="nearest")
sky_y = sky_y.clip(0, sky_small.shape[0] - 1)
sky_x = sky_x.clip(0, sky_small.shape[1] - 1)
large_points = rearrange(np.stack(np.meshgrid(sky_y, sky_x), axis=0), "n x y -> y x n")
large_sky = sky_interp(large_points)
large_sky_nelec = large_sky * gain
if self.load_image_data:
pixels_ss_nmgy = fits.getdata(frame_path, 0)
pixels_ss_nelec = pixels_ss_nmgy * nelec_per_nmgy
pixels_nelec = pixels_ss_nelec + large_sky_nelec
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FITSFixedWarning)
wcs = WCS(fits.getheader(frame_path, 0))
d = {
"background": large_sky_nelec,
"gain": np.array(gain),
"flux_calibration": nelec_per_nmgy,
"wcs": wcs,
}
if self.load_image_data:
d["image"] = pixels_nelec
return d
[docs]
class SDSSDownloader:
"""Class for downloading SDSS data."""
URLBASE = "https://data.sdss.org/sas/dr12/boss"
@staticmethod
def stripped(val):
return str(val).lstrip("0")
@staticmethod
def run6(run) -> str:
return f"{int(SDSSDownloader.stripped(run)):06d}"
@staticmethod
def field4(field) -> str:
return f"{int(SDSSDownloader.stripped(field)):04d}"
@staticmethod
def subdir2(run, camcol) -> str:
return f"{SDSSDownloader.stripped(run)}/{camcol}"
@staticmethod
def subdir3(run, camcol, field) -> str:
return f"{SDSSDownloader.subdir2(run, camcol)}/{SDSSDownloader.stripped(field)}"
def __init__(self, image_ids, download_dir):
self.image_ids = image_ids
self.download_dir = download_dir
@classmethod
def download_field_extents(cls):
# Download and use field-extents in memory
field_extents_filename = download_file(
"https://portal.nersc.gov/project/dasrepo/celeste/field_extents.fits",
cache=True,
timeout=10,
)
cls._field_extents = Table.read(field_extents_filename)
[docs]
@classmethod
def field_extents(cls) -> Table:
"""Get field extents table."""
if not getattr(cls, "_field_extents", None):
cls.download_field_extents()
return cls._field_extents
def download_pfs(self):
for image_id in self.image_ids:
run, camcol, _ = image_id
self.download_pf(run, camcol)
def download_pf(self, run, camcol):
download_file_to_dst(
f"{SDSSDownloader.URLBASE}/photoObj/301/{self.stripped(run)}/"
f"photoField-{self.run6(run)}-{camcol}.fits",
f"{self.download_dir}/{self.subdir2(run, camcol)}/"
f"photoField-{self.run6(run)}-{camcol}.fits",
)
def download_catalogs(self):
for image_id in self.image_ids:
run, camcol, field = image_id
self.download_catalog((run, camcol, field))
def download_catalog(self, rcf) -> str:
run, camcol, field = rcf
cat_path = (
f"{self.download_dir}/{self.subdir3(run, camcol, field)}/"
f"photoObj-{self.run6(run)}-{camcol}-{self.field4(field)}.fits"
)
download_file_to_dst(
f"{SDSSDownloader.URLBASE}/photoObj/301/{self.stripped(run)}/{camcol}/"
f"photoObj-{self.run6(run)}-{camcol}-{self.field4(field)}.fits",
cat_path,
)
return cat_path
def download_images(self):
for image_id in self.image_ids:
run, camcol, field = image_id
for bl in SloanDigitalSkySurvey.BANDS:
self.download_image(run, camcol, field, bl)
def download_image(self, run, camcol, field, band="r"):
download_file_to_dst(
f"{SDSSDownloader.URLBASE}/photo/redux/301/{self.stripped(run)}/objcs/{camcol}/"
f"fpM-{self.run6(run)}-{band}{camcol}-{self.field4(field)}.fit.gz",
f"{self.download_dir}/{self.subdir3(run, camcol, field)}/"
f"fpM-{self.run6(run)}-{band}{camcol}-{self.field4(field)}.fits",
gzip.decompress,
)
download_file_to_dst(
f"{SDSSDownloader.URLBASE}/photoObj/frames/301/{self.stripped(run)}/{camcol}/"
f"frame-{band}-{self.run6(run)}-{camcol}-{self.field4(field)}.fits.bz2",
f"{self.download_dir}/{self.subdir3(run, camcol, field)}/"
f"frame-{band}-{self.run6(run)}-{camcol}-{self.field4(field)}.fits",
bz2.decompress,
)
def download_psfields(self):
for image_id in self.image_ids:
run, camcol, field = image_id
self.download_psfield(run, camcol, field)
def download_psfield(self, run, camcol, field):
download_file_to_dst(
f"{SDSSDownloader.URLBASE}/photo/redux/301/{self.stripped(run)}/objcs/{camcol}/"
f"psField-{self.run6(run)}-{camcol}-{self.field4(field)}.fit",
f"{self.download_dir}/{self.subdir3(run, camcol, field)}/"
f"psField-{self.run6(run)}-{camcol}-{self.field4(field)}.fits",
)
def download_all(self):
if not Path(self.download_dir).exists():
Path(self.download_dir).mkdir(parents=True, exist_ok=True)
self.download_pfs()
self.download_catalogs()
self.download_images()
self.download_psfields()
[docs]
class PhotoFullCatalog(FullCatalog):
"""Class for the SDSS PHOTO Catalog.
Some resources:
- https://www.sdss.org/dr12/algorithms/classify/
- https://www.sdss.org/dr12/algorithms/resolve/
"""
[docs]
@classmethod
def from_file(cls, cat_path, wcs: WCS, height, width):
"""Instantiates PhotoFullCatalog with RCF and WCS information from disk."""
assert Path(cat_path).exists(), f"File {cat_path} does not exist"
table = fits.getdata(cat_path)
# Convert table entries to tensors
objc_type = column_to_tensor(table, "objc_type")
thing_id = column_to_tensor(table, "thing_id")
ras = column_to_tensor(table, "ra")
decs = column_to_tensor(table, "dec")
galaxy_bools = (objc_type == 3) & (thing_id != -1)
star_bools = (objc_type == 6) & (thing_id != -1)
# Combine light source parameters to one tensor
star_fluxes = column_to_tensor(table, "psfflux") * star_bools.reshape(-1, 1)
galaxy_fluxes = column_to_tensor(table, "cmodelflux") * galaxy_bools.reshape(-1, 1)
fluxes = star_fluxes + galaxy_fluxes
# true light source mask
keep = galaxy_bools | star_bools
galaxy_bools = galaxy_bools[keep]
star_bools = star_bools[keep]
ras = ras[keep]
decs = decs[keep]
fluxes = fluxes[keep]
nobj = ras.shape[0]
# We require all 5 bands for computing loss on predictions.
n_bands = len(SloanDigitalSkySurvey.BANDS)
# get pixel coordinates
plocs = cls.plocs_from_ra_dec(ras, decs, wcs)
# Verify each tile contains either a star or a galaxy
assert torch.all(star_bools + galaxy_bools)
source_type = SourceType.STAR * star_bools + SourceType.GALAXY * galaxy_bools
d = {
"plocs": plocs.reshape(1, nobj, 2),
"n_sources": torch.tensor((nobj,)),
"source_type": source_type.reshape(1, nobj, 1),
"fluxes": fluxes.reshape(1, nobj, n_bands),
"ra": ras.reshape(1, nobj, 1),
"dec": decs.reshape(1, nobj, 1),
}
return cls(height, width, d)
[docs]
def restrict_by_ra_dec(self, ra_lim, dec_lim):
"""Helper function to restrict photo catalog to within RA and DEC limits."""
ra = self["ra"].squeeze()
dec = self["dec"].squeeze()
keep = (ra > ra_lim[0]) & (ra < ra_lim[1]) & (dec >= dec_lim[0]) & (dec <= dec_lim[1])
plocs = self["plocs"][:, keep]
n_sources = torch.tensor([plocs.size()[1]])
d = {"n_sources": n_sources}
for key, val in self.items():
if key != "n_sources":
d[key] = val[:, keep]
return PhotoFullCatalog(
int(plocs[0, :, 0].max() - plocs[0, :, 0].min()), # new height
int(plocs[0, :, 1].max() - plocs[0, :, 1].min()), # new width
d,
)
class SDSS_PSF(ImagePSF): # noqa: N801
@staticmethod
def _get_fit_file_psf_params(psf_fit_file: str, bands: Tuple[int, ...]):
"""Load psf parameters from fits file.
See https://data.sdss.org/datamodel/files/PHOTO_REDUX/RERUN/RUN/objcs/CAMCOL/psField.html
for details on the parameters.
Args:
psf_fit_file (str): file to load from
bands (Tuple[int, ...]): SDSS bands to load
Returns:
psf_params: tensor of parameters for each band
"""
msg = (
f"{psf_fit_file} does not exist. "
f"Make sure data files are available for fields specified in config."
)
assert Path(psf_fit_file).exists(), msg
# HDU 6 contains the PSF header (after primary and eigenimages)
data = fits.open(psf_fit_file, ignore_missing_end=True).pop(6).data
psf_params = torch.zeros(len(bands), 6)
for i, band in enumerate(bands):
sigma1 = data["psf_sigma1"][0][band] ** 2
sigma2 = data["psf_sigma2"][0][band] ** 2
sigmap = data["psf_sigmap"][0][band] ** 2
beta = data["psf_beta"][0][band]
b = data["psf_b"][0][band]
p0 = data["psf_p0"][0][band]
psf_params[i] = torch.tensor([sigma1, sigma2, sigmap, beta, b, p0])
return psf_params
def __init__(self, survey_data_dir, image_ids, bands, psf_config: PSFConfig):
super().__init__(bands, **psf_config)
self.psf_galsim = {}
self.psf_params = {}
SDSSDownloader(image_ids, download_dir=survey_data_dir).download_psfields()
for run, camcol, field in image_ids:
# load raw params from file
field_dir = f"{survey_data_dir}/{run}/{camcol}/{field}"
filename = f"{field_dir}/psField-{run:06}-{camcol}-{field:04}.fits"
assert Path(filename).exists(), f"psField file {filename} not found"
psf_params = self._get_fit_file_psf_params(filename, bands)
# load psf image from params
self.psf_params[(run, camcol, field)] = psf_params
self.psf_galsim[(run, camcol, field)] = self._get_psf(psf_params)
def _psf_fun(self, r, sigma1, sigma2, sigmap, beta, b, p0):
"""Generate the PSF from the parameters using the power-law model.
See https://data.sdss.org/datamodel/files/PHOTO_REDUX/RERUN/RUN/objcs/CAMCOL/psField.html
for details on the parameters and the equation used.
Args:
r: radius
sigma1: Inner gaussian sigma for the composite fit
sigma2: Outer gaussian sigma for the composite fit
sigmap: Width parameter for power law (pixels)
beta: Slope of power law.
b: Ratio of the outer PSF to the inner PSF at the origin
p0: The value of the power law at the origin.
Returns:
The psf function evaluated at r.
"""
term1 = torch.exp(-(r**2) / (2 * sigma1))
term2 = b * torch.exp(-(r**2) / (2 * sigma2))
term3 = p0 * (1 + r**2 / (beta * sigmap)) ** (-beta / 2)
return (term1 + term2 + term3) / (1 + b + p0)
def nelec_to_nmgy_for_catalog(est_cat, nelec_per_nmgy_per_band):
fluxes_suffix = "_fluxes"
# reshape nelec_per_nmgy_per_band to (1, 1, 1, 1, {n_bands}) to broadcast
nelec_per_nmgy_per_band = torch.tensor(nelec_per_nmgy_per_band, device=est_cat.device)
nelec_per_nmgy_per_band = nelec_per_nmgy_per_band.view(1, 1, 1, 1, -1)
for key in est_cat:
if key.endswith(fluxes_suffix):
est_cat[key] = est_cat[key] / nelec_per_nmgy_per_band
return est_cat
def column_to_tensor(table, colname):
dtypes = {
np.dtype(">i2"): int,
np.dtype(">i4"): int,
np.dtype(">i8"): int,
np.dtype("bool"): bool,
np.dtype(">f4"): np.float32,
np.dtype(">f8"): np.float32,
np.dtype("float32"): np.float32,
np.dtype("float64"): np.dtype("float64"),
}
x = np.array(table[colname])
dtype = dtypes[x.dtype]
x = x.astype(dtype)
return torch.from_numpy(x)