BLISS User API

Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources.

Installation

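[ ]:
# Install the released package from PyPI
!pip install bliss-deblender
[ ]:
# Alternatively (development): editable install from a local clone of the BLISS repository
!pip install -e /home/zhteoh/770-bulk-predict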

Tutorial

[6]:
from bliss.api import BlissClient

# Working directory where generated data and outputs are stored
bliss_client = BlissClient(cwd="/data/scratch/zhteoh/tutorial")

Train the model

Generate synthetic image data

[5]:
bliss_client.generate(
    n_batches=3,
    batch_size=64,
    max_images_per_file=128
)
Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset
Simulating images in batches for file: 100%|██████████| 2/2 [04:34<00:00, 137.09s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [04:41<00:00, 140.52s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [09:15<00:00, 277.74s/it]
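The two progress-bar steps above match a simple packing rule: the requested images (n_batches × batch_size) are split into files of at most max_images_per_file images. The helper below is a hypothetical sketch of that arithmetic, assuming this rule; it is not part of the BLISS API and BLISS may round differently internally.

```python
import math


def expected_file_count(n_batches: int, batch_size: int, max_images_per_file: int) -> int:
    """Estimate how many cached files generate() writes.

    Assumption: the requested images are packed into files holding at most
    max_images_per_file images each.
    """
    total_images = n_batches * batch_size
    return math.ceil(total_images / max_images_per_file)


# Parameters from the cell above: 3 * 64 = 192 requested images, at most 128 per file
print(expected_file_count(n_batches=3, batch_size=64, max_images_per_file=128))  # 2
```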
Pass additional custom configuration parameters
[7]:
# Alter default cached_data_path
bliss_client.cached_data_path = "/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02"

bliss_client.generate(
    n_batches=3,  # required
    batch_size=64,  # required
    max_images_per_file=128,  # required
    simulator={"survey": {"prior_config": {"mean_sources": 0.02}}},  # optional
    generate={"file_prefix": "dataset"},  # optional
)
Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
Simulating images in batches for file: 100%|██████████| 2/2 [01:01<00:00, 30.78s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [01:06<00:00, 33.06s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [02:07<00:00, 63.95s/it]
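The optional arguments are partial nested dictionaries: only the keys you pass change, and their siblings keep their defaults. How BLISS applies these overrides internally is not shown in this tutorial; the sketch below illustrates the idea with a plain recursive dictionary merge. The function deep_merge and the default values are hypothetical.

```python
import copy
from typing import Any, Dict


def deep_merge(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of defaults with overrides merged in recursively.

    Nested dicts are merged key by key; any other override value
    replaces the default outright. Neither input is modified.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical defaults, for illustration only (not BLISS's real configuration)
defaults = {
    "simulator": {"survey": {"prior_config": {"mean_sources": 0.5, "max_sources": 1}}},
    "generate": {"file_prefix": "default_prefix"},
}
# Same shape as the overrides passed to generate() above
overrides = {
    "simulator": {"survey": {"prior_config": {"mean_sources": 0.02}}},
    "generate": {"file_prefix": "dataset"},
}

cfg = deep_merge(defaults, overrides)
print(cfg["simulator"]["survey"]["prior_config"])  # {'mean_sources': 0.02, 'max_sources': 1}
```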
[ ]:
# Point the client at an existing dataset (e.g. if you skipped the generation cell above)
bliss_client.cached_data_path = "/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02"
[ ]:
# Check that the dataset is generated
!ls /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
!du -sh /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
# !cat /data/scratch/zhteoh/tutorial/dataset/hparams.yaml

print("Dataset:", bliss_client.cached_data_path)
dataset_0 = bliss_client.get_dataset_file(filename="dataset_0.pt")
print(" Size:", len(dataset_0))
print(" Shape:", dataset_0[0]["images"].shape)
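Indexing dataset_0[0]["images"] suggests that a cached file loads as a list of dictionaries of tensors. Assuming that structure, the hypothetical helper below counts the distinct shapes stored under each key across all items, which is a quick way to confirm every image in a file has the expected size. In the notebook you would call summarize_items(dataset_0); the demo uses a stand-in class so the sketch runs without BLISS or PyTorch.

```python
from collections import Counter
from typing import Any, Dict, List


def summarize_items(items: List[Dict[str, Any]]) -> Dict[str, Counter]:
    """Count the distinct shapes seen under each key of a list of dicts.

    Values without a .shape attribute (e.g. plain Python numbers) are
    recorded under their type name instead.
    """
    summary: Dict[str, Counter] = {}
    for item in items:
        for key, value in item.items():
            shape = getattr(value, "shape", None)
            label = tuple(shape) if shape is not None else type(value).__name__
            summary.setdefault(key, Counter())[label] += 1
    return summary


class FakeTensor:
    """Stand-in for a tensor: only carries a shape."""

    def __init__(self, shape):
        self.shape = shape


demo = [
    {"images": FakeTensor((1, 4, 4)), "n_sources": 2},
    {"images": FakeTensor((1, 4, 4)), "n_sources": 0},
]
print(summarize_items(demo))
```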

Train the encoder

Without pretrained weights
[ ]:
bliss_client.train(weight_save_path="tutorial_encoder/0.pt")
With pretrained weights

Download the pretrained weights for your sky survey.

[ ]:
import os
assert os.path.exists("/data/scratch/zhteoh/tutorial/data/pretrained_models")

bliss_client.load_pretrained_weights_for_survey(survey="sdss", filename="sdss_pretrained.pt")

!ls /data/scratch/zhteoh/tutorial/data/pretrained_models
Train on the cached dataset generated on disk
[ ]:
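bliss_client.train_on_cached_data(
    weight_save_path="tutorial_encoder/0.pt",
    train_n_batches=2,
    batch_size=64,
    val_split_file_idxs=[1],  # index of the cached file(s) held out for validation
    pretrained_weights_filename=None,  # or e.g. "sdss_pretrained.pt" to start from the downloaded weights
)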
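The tutorial does not spell out how val_split_file_idxs is interpreted. Presumably it selects cached files by index for validation, with the rest used for training. The sketch below shows that split under two assumptions: files are named {prefix}_{i}.pt (consistent with dataset_0.pt above), and indices refer to positions in that sequence. split_files is a hypothetical helper, not a BLISS function.

```python
from typing import List, Sequence, Tuple


def split_files(
    n_files: int, val_file_idxs: Sequence[int], prefix: str = "dataset"
) -> Tuple[List[str], List[str]]:
    """Partition cached files named {prefix}_{i}.pt into train and validation lists."""
    val_set = set(val_file_idxs)
    bad = sorted(i for i in val_set if not 0 <= i < n_files)
    if bad:
        raise ValueError(f"validation file indices out of range: {bad}")
    train = [f"{prefix}_{i}.pt" for i in range(n_files) if i not in val_set]
    val = [f"{prefix}_{i}.pt" for i in range(n_files) if i in val_set]
    return train, val


# Two cached files with val_split_file_idxs=[1], as in the cell above
train_files, val_files = split_files(n_files=2, val_file_idxs=[1])
print(train_files, val_files)  # ['dataset_0.pt'] ['dataset_1.pt']
```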

Run the model

Using sample SDSS dataset

Get predictions for the sample dataset
[3]:
est_cat, est_cat_table, pred_tables = bliss_client.predict_sdss(
    weight_save_path="tutorial_encoder/zscore_five_band.pt",
    # predict={"dataset": {"sdss_fields": [{"run": 94, "camcol": 1, "fields": [12]}, {"run": 3900, "camcol": 6, "fields": [269]}]}},
)

                 from  n    params  module                                  arguments
  0                -1  1     16128  yolov5.models.common.Conv               [10, 64, 5, 1]
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]
  8                -1  1   4720640  yolov5.models.common.Conv               [512, 1024, 3, 2]
  9                -1  3   9971712  yolov5.models.common.C3                 [1024, 1024, 3]
 10                -1  1   2624512  yolov5.models.common.SPPF               [1024, 1024, 5]
 11                -1  1    525312  yolov5.models.common.Conv               [1024, 512, 1, 1]
 12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 13           [-1, 6]  1         0  yolov5.models.common.Concat             [1]
 14                -1  3   2757632  yolov5.models.common.C3                 [1024, 512, 3, False]
 15                -1  1    131584  yolov5.models.common.Conv               [512, 256, 1, 1]
 16                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 17        [-1, 4, 5]  1         0  yolov5.models.common.Concat             [1]
 18                -1  3    756224  yolov5.models.common.C3                 [768, 256, 3, False]
 19              [17]  1     29222  yolov5.models.yolo.Detect               [33, [[4, 4]], [768]]
Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

[5]:
bliss_client.plot_predictions_in_notebook()
(Interactive Bokeh plot of the predictions)
[4]:
print("Number of entries:", len(est_cat_table))
est_cat_table[:5].show_in_notebook(display_length=5)
Number of entries: 254
[4]:
Table length=5
idx | plocs | source_type | star_flux_u (nmgy) | galaxy_flux_u (nmgy) | star_flux_g (nmgy) | galaxy_flux_g (nmgy) | star_flux_r (nmgy) | galaxy_flux_r (nmgy) | star_flux_i (nmgy) | galaxy_flux_i (nmgy) | star_flux_z (nmgy) | galaxy_flux_z (nmgy) | galaxy_disk_frac | galaxy_beta_radians | galaxy_disk_q | galaxy_a_d | galaxy_bulge_q | galaxy_a_b
0tensor([266.37012, 101.83371])tensor([1])1.9210918e-071.88252950.651014574.0829490.83070425.49383263.483197723.9213753.790694242.905990.0041288760.0034716730.000761442240.00214500030.00328061381.7661792
1tensor([549.19055, 274.24149])tensor([1])1.3584375e-050.263266950.1219470050.70436280.4644341.75929521.95451758.9221662.766725516.5474130.00283737530.00255063460.00073144650.0054546640.00453693931.13643
2tensor([236.75182, 576.83881])tensor([1])0.000119498620.085143840.369022550.80455710.599738061.08202061.48342373.12111570.0099486162.36346460.00207926150.00251216980.00059752040.00253470310.00413356260.96826047
3tensor([248.72098, 260.99545])tensor([1])4.0953696e-060.36880590.429168340.979606870.58036841.01144361.19781122.89758940.00262587753.12889080.002359160.0027015970.000583882850.00323049210.00374381221.0822673
4tensor([ 7.23273, 211.16307])tensor([1])0.170723780.885299150.26931830.64269190.51498930.953418430.89364082.16016530.0152116163.360910.00249141780.00450364030.000305592520.00282802670.00305669081.2435206
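The flux columns are in nanomaggies (nmgy), the SDSS linear flux unit in which 1 nmgy corresponds to magnitude 22.5. The sketch below applies the standard Pogson conversion m = 22.5 - 2.5 log10(f) to a single flux value; nmgy_to_mag is a hypothetical helper, and SDSS's own catalog magnitudes (asinh magnitudes) differ slightly for faint sources.

```python
import math


def nmgy_to_mag(flux_nmgy: float) -> float:
    """Convert a flux in nanomaggies to a magnitude: m = 22.5 - 2.5 * log10(f).

    Returns NaN for non-positive fluxes, which have no magnitude.
    """
    if flux_nmgy <= 0:
        return math.nan
    return 22.5 - 2.5 * math.log10(flux_nmgy)


print(nmgy_to_mag(1.0))    # 22.5
print(nmgy_to_mag(100.0))  # 17.5
```

Applied element-wise to a flux column, this turns the catalog fluxes into magnitudes for comparison with other catalogs.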
Inspect probabilistic predictions

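BLISS produces probability distributions over the latent variables it predicts, not just point estimates. In each prediction table, the on_prob_* and galaxy_prob_* columns give the probabilities that a source is present and that it is a galaxy, and each *_mean / *_std pair summarizes the distribution of a galaxy-shape or flux parameter.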

[5]:
print("Number of entries (RCF (94, 1, 12)):", len(pred_tables[(94, 1, 12)]))
pred_tables[(94, 1, 12)][:5].show_in_notebook(display_length=5)
Number of entries (RCF (94, 1, 12)): 24964
[5]:
Table length=5
idx | on_prob_false | on_prob_true | galaxy_prob_false | galaxy_prob_true | galsim_disk_frac_mean | galsim_disk_frac_std | galsim_beta_radians_mean (rad) | galsim_beta_radians_std (rad) | galsim_disk_q_mean | galsim_disk_q_std | galsim_a_d_mean (arcsec) | galsim_a_d_std (arcsec) | galsim_bulge_q_mean | galsim_bulge_q_std | galsim_a_b_mean (arcsec) | galsim_a_b_std (arcsec) | star_flux_u_mean (nmgy) | star_flux_u_std (nmgy) | star_flux_g_mean (nmgy) | star_flux_g_std (nmgy) | star_flux_r_mean (nmgy) | star_flux_r_std (nmgy) | star_flux_i_mean (nmgy) | star_flux_i_std (nmgy) | star_flux_z_mean (nmgy) | star_flux_z_std (nmgy) | galaxy_flux_u_mean (nmgy) | galaxy_flux_u_std (nmgy) | galaxy_flux_g_mean (nmgy) | galaxy_flux_g_std (nmgy) | galaxy_flux_r_mean (nmgy) | galaxy_flux_r_std (nmgy) | galaxy_flux_i_mean (nmgy) | galaxy_flux_i_std (nmgy) | galaxy_flux_z_mean (nmgy) | galaxy_flux_z_std (nmgy)
00.99678020.00321976540.000432789330.99956720.57243181.85244510.37481451.66494390.470053432.07418161.23134020.81039680.066831352.09495470.454636571.23431223.34724663.89838984.31438541.18222395.6566250.493259255.8699220.752895954.53218656.9602565.04342751.52036065.5714180.858202466.75043870.376854037.3013840.460732466.12694452.2173305
10.99980330.000196710480.0001000165940.99991.59988572.190470.50159981.75457870.66870571.86497261.63008430.697890160.434260371.99820510.648199561.35815811.56567624.78707363.52552181.03309994.5661420.936750773.79787452.5735153.591493.69056994.38897132.08299145.90586950.68660566.60502150.334762286.94263650.465261376.59285831.2505308
20.99991e-040.000172615050.99982741.86665422.04445770.437796121.72137580.65551781.74150441.84740090.644446550.181624891.83131380.681199551.27925161.21534016.09186653.17501641.34800454.14894960.99908193.6976182.75835084.22460272.578333.95683032.859255.93399050.815708046.5072230.430997976.77234360.703732437.22697161.1791977
30.999846460.00015353810.00010395050.999896051.62800341.8893498-0.230698592.16578960.479752781.80081022.15197730.781493070.60241221.81481230.90130571.46621541.84037594.16382843.77111480.99872534.81160550.430488084.63790461.15260224.84519961.14406314.0424293.65394545.92075060.97080376.78137870.50634097.14950940.90572317.5742961.3889303
40.99973310.00026693050.0001000165940.99991.42483881.7584034-0.00107550621.95953180.631301641.49456482.07881470.804047170.585586551.81606280.755789761.33613012.65422825.42983343.81303170.97062625.10223770.412487184.7252271.69422765.42944430.879893545.010332.47703775.8588990.79444186.77067570.430732677.10946850.73547147.78264330.8656995
[9]:
print("Number of entries (RCF (3900, 6, 269)):", len(pred_tables[(3900, 6, 269)]))
pred_tables[(3900, 6, 269)][:5].show_in_notebook(display_length=5)
Number of entries (RCF (3900, 6, 269)): 24964
[9]:
Table length=5
idx | on_prob_false | on_prob_true | galaxy_prob_false | galaxy_prob_true | galsim_disk_frac_mean | galsim_disk_frac_std | galsim_beta_radians_mean (rad) | galsim_beta_radians_std (rad) | galsim_disk_q_mean | galsim_disk_q_std | galsim_a_d_mean (arcsec) | galsim_a_d_std (arcsec) | galsim_bulge_q_mean | galsim_bulge_q_std | galsim_a_b_mean (arcsec) | galsim_a_b_std (arcsec) | star_flux_u_mean (nmgy) | star_flux_u_std (nmgy) | star_flux_g_mean (nmgy) | star_flux_g_std (nmgy) | star_flux_r_mean (nmgy) | star_flux_r_std (nmgy) | star_flux_i_mean (nmgy) | star_flux_i_std (nmgy) | star_flux_z_mean (nmgy) | star_flux_z_std (nmgy) | galaxy_flux_u_mean (nmgy) | galaxy_flux_u_std (nmgy) | galaxy_flux_g_mean (nmgy) | galaxy_flux_g_std (nmgy) | galaxy_flux_r_mean (nmgy) | galaxy_flux_r_std (nmgy) | galaxy_flux_i_mean (nmgy) | galaxy_flux_i_std (nmgy) | galaxy_flux_z_mean (nmgy) | galaxy_flux_z_std (nmgy)
00.99989070.000109314220.0001000165940.99991.84275461.89667670.17593551.53407651.01268172.11461621.56248310.75048670.0134491921.8136890.74573851.43309412.0786724.3719253.13351541.34688234.68619251.0258933.80023433.33322934.45268733.04674484.95785142.14614345.6807641.05425816.62231450.635017346.690810.90154777.42443661.308013
10.99965440.00034560130.00107365850.998926341.15804361.89814640.103563311.51382-0.00156497961.75042211.35188320.73613540.0235483652.10106330.389216421.19684793.09337473.34775354.41250130.78151085.56520840.423210625.57749030.881161875.874461.55818214.57809352.01080065.57082650.66346946.45614430.377527336.81918530.53798257.3068221.0409187
20.999850030.000149940450.000612974170.9993871.35574171.8018293-0.263671641.79469860.308890582.03052351.45481280.69167155-0.01733472.20347360.38862181.15886843.28078752.61924154.02236840.911264845.18662640.530872465.03069161.36111965.0993452.6539875.01789951.6534545.66885380.62892766.5837280.349233666.8162720.56851187.05228231.2957809
30.99963860.00036139740.000129461290.999870542.03241321.7666548-0.174253461.89105331.19250941.56565361.69150380.63886320.0470974452.09073070.455917841.46029161.69576122.82567332.98597720.97972883.79730250.755215173.61664341.58537462.90772275.69611025.23357872.46308786.19858260.8754326.8619920.542729267.11725330.742677336.76400572.2090604
40.99889680.00110319640.0001000165940.99991.74303221.6583972-0.0157330041.48242011.11489131.67240631.83293460.69212830.238222122.1990880.67237711.46318842.63429453.11351283.82552340.806065265.1348380.3358475.23295550.84052855.0726951.71613415.0337253.0645265.80013940.96344746.84038540.549091647.3483440.728391237.51123331.2816015
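Because each row carries probabilities rather than hard labels, you choose how to turn them into a catalog. A common choice is to keep rows whose on_prob_true exceeds 0.5 and then label each kept row by the larger of galaxy_prob_true and galaxy_prob_false (presumably the galaxy probability is conditional on a source being present). The sketch below works on plain dicts with hypothetical values; with an astropy Table the same filter is table[table["on_prob_true"] > 0.5].

```python
from typing import Dict, List


def detected_rows(rows: List[Dict[str, float]], threshold: float = 0.5) -> List[Dict[str, float]]:
    """Keep rows whose probability of containing a source exceeds threshold."""
    return [row for row in rows if row["on_prob_true"] > threshold]


def classify(row: Dict[str, float]) -> str:
    """Label a detected row as galaxy or star by the larger class probability."""
    return "galaxy" if row["galaxy_prob_true"] > row["galaxy_prob_false"] else "star"


# Stand-in rows with hypothetical values, using the column names of the tables above
rows = [
    {"on_prob_true": 0.003, "galaxy_prob_true": 0.9996, "galaxy_prob_false": 0.0004},
    {"on_prob_true": 0.91, "galaxy_prob_true": 0.20, "galaxy_prob_false": 0.80},
    {"on_prob_true": 0.75, "galaxy_prob_true": 0.65, "galaxy_prob_false": 0.35},
]
hits = detected_rows(rows)
print([classify(r) for r in hits])  # ['star', 'galaxy']
```

Raising the threshold trades completeness for purity: fewer spurious detections, more missed faint sources.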
Save the predicted catalog to a FITS file
[ ]:
est_cat_table.write("est_cat.fits", format="fits", overwrite=True)
[ ]:
# Check that catalog is saved as intended
from astropy.table import Table

est_cat_table = Table.read("est_cat.fits", format="fits")
print("Number of entries:", len(est_cat_table))
est_cat_table.show_in_notebook(display_length=5)
Evaluate the predictions
[ ]:
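import torch

from bliss.metrics import BlissMetrics
from bliss.surveys.sdss import PhotoFullCatalog, SloanDigitalSkySurvey

sdss_data_path = "/data/scratch/zhteoh/tutorial/data/sdss"
sdss = SloanDigitalSkySurvey()
# Load the SDSS PHOTO reference catalog for the same field(s) as the predictions.
# NOTE: from_file() needs to know which data path and field (run, camcol, field) to load;
# fill in its arguments according to bliss.surveys.sdss in your installed version.
photo_cat = PhotoFullCatalog.from_file()

# Put both catalogs on the same device (CPU) before computing metrics
est_cat_cpu = est_cat.to(torch.device("cpu"))
photo_cat_cpu = photo_cat.to(torch.device("cpu"))

# Compare the BLISS estimates against the PHOTO catalog
metrics = BlissMetrics()
results = metrics(est_cat_cpu, photo_cat_cpu)

print(results)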
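How BlissMetrics computes its scores is not described in this tutorial. As a simplified illustration of what detection metrics usually measure, the sketch below greedily matches estimated source locations to reference locations within a distance tolerance, then reports precision (fraction of estimates that match a reference source) and recall (fraction of reference sources that are recovered). match_detections is hypothetical, and greedy matching can under-count matches in crowded fields compared with an optimal (Hungarian) assignment.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def match_detections(est: List[Point], truth: List[Point], tol: float) -> Tuple[float, float]:
    """Greedily match estimated to reference locations within tol and return (precision, recall).

    Each reference source is matched at most once; for each estimate in turn,
    the closest still-unmatched reference source within tol is taken.
    """
    unmatched = list(range(len(truth)))
    n_matched = 0
    for ex, ey in est:
        best, best_d = None, tol
        for j in unmatched:
            d = math.hypot(ex - truth[j][0], ey - truth[j][1])
            if d <= best_d:
                best, best_d = j, d
        if best is not None:
            unmatched.remove(best)
            n_matched += 1
    precision = n_matched / len(est) if est else 0.0
    recall = n_matched / len(truth) if truth else 0.0
    return precision, recall


est_locs = [(10.0, 10.0), (50.0, 50.0), (90.0, 10.0)]
true_locs = [(10.5, 9.8), (51.0, 50.0)]
print(match_detections(est_locs, true_locs, tol=2.0))  # (0.6666666666666666, 1.0)
```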

Using user-specified SDSS dataset

Query SDSS for the field containing a sky position and download it
[ ]:
from astropy.coordinates import SkyCoord
from astroquery.sdss import SDSS
from pathlib import Path

pos = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') # 1011/3/44
# pos = SkyCoord("1h8m05.73s +13d10m20.3s", frame="icrs") # 4829/5/27
# pos = SkyCoord("1h2m05.83s -2d11m20.3s", frame="icrs") # 2699/4/71
region = SDSS.query_region(pos, radius="5 arcsec")
run, camcol, field = region["run"][0], region["camcol"][0], region["field"][0]
print("run:", run, "camcol:", camcol, "field:", field)
bliss_client.load_survey("sdss", run, camcol, field, download_dir=Path("data/sdss"))
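SkyCoord parses the sexagesimal strings above for you. To make the conversion explicit: right ascension in hours, minutes, and seconds becomes degrees via RA = 15 × (h + m/60 + s/3600), since one hour of RA is 15 degrees, and declination is d + m/60 + s/3600 with the sign applied to the whole value. The helpers below are a hypothetical sketch that only accepts the exact string format used in this cell.

```python
import re


def hms_to_deg(ra: str) -> float:
    """Convert an RA string like '0h8m05.63s' to degrees (1 hour = 15 degrees)."""
    match = re.fullmatch(r"(\d+)h(\d+)m([\d.]+)s", ra)
    if match is None:
        raise ValueError(f"unrecognized RA: {ra!r}")
    h, m, s = (float(x) for x in match.groups())
    return 15.0 * (h + m / 60.0 + s / 3600.0)


def dms_to_deg(dec: str) -> float:
    """Convert a Dec string like '+14d50m23.3s' or '-2d11m20.3s' to degrees."""
    match = re.fullmatch(r"([+-]?)(\d+)d(\d+)m([\d.]+)s", dec)
    if match is None:
        raise ValueError(f"unrecognized Dec: {dec!r}")
    sign, d, m, s = match.groups()
    value = float(d) + float(m) / 60.0 + float(s) / 3600.0
    return -value if sign == "-" else value


print(round(hms_to_deg("0h8m05.63s"), 6), round(dms_to_deg("+14d50m23.3s"), 6))  # 2.023458 14.839806
```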
Get predictions for the downloaded dataset
[ ]:
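est_cat_dl, est_cat_table_dl, pred_tables_dl = bliss_client.predict_sdss(
    data_path="data/sdss",
    weight_save_path="tutorial_encoder/0.pt",
    # Predict on the field found by the query above (run 1011, camcol 3, field 44)
    predict={"dataset": {"sdss_fields": [{"run": int(run), "camcol": int(camcol), "fields": [int(field)]}]}},
)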
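When you want predictions for several fields, the predict override can list them in the sdss_fields format shown in the commented-out predict argument of the sample-dataset cell, where fields sharing a run and camcol are grouped. The hypothetical helper below builds that nested dictionary from (run, camcol, field) triples; it assumes that format is what predict_sdss expects in your BLISS version.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def sdss_fields_config(rcfs: Iterable[Tuple[int, int, int]]) -> Dict[str, dict]:
    """Group (run, camcol, field) triples into a nested predict override.

    Values are cast to int so NumPy integers (e.g. from an astroquery
    result table) become plain Python ints.
    """
    fields_by_rc: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for run, camcol, field in rcfs:
        fields_by_rc[(int(run), int(camcol))].append(int(field))
    sdss_fields = [
        {"run": run, "camcol": camcol, "fields": sorted(fields)}
        for (run, camcol), fields in fields_by_rc.items()
    ]
    return {"dataset": {"sdss_fields": sdss_fields}}


print(sdss_fields_config([(1011, 3, 44), (94, 1, 12)]))
```

The result can be passed as predict=sdss_fields_config(...) in place of a hand-written dictionary.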
[ ]:
bliss_client.plot_predictions_in_notebook()
Inspect probabilistic predictions
[ ]:
print("Number of entries:", len(pred_tables_dl[(1011, 3, 44)]))
pred_tables_dl[(1011, 3, 44)][:5].show_in_notebook(display_length=5)

Using sample DECaLS dataset

[4]:
est_cat, est_cat_table, pred_tables = bliss_client.predict_decals(
    weight_save_path="tutorial_encoder/single_band_base.pt",
    predict={
        "dataset": {
            "sky_coords": [
                # brick '3366m010' corresponds to SDSS RCF 94-1-12
                {"ra": 336.6643042496718, "dec": -0.9316385797930247},
                # brick '1358p297' corresponds to SDSS RCF 3635-1-169
                {"ra": 135.95496736941683, "dec": 29.646883837721347},
            ]
        }
    },
)

                 from  n    params  module                                  arguments
  0                -1  1      3328  yolov5.models.common.Conv               [2, 64, 5, 1]
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]
  8                -1  1   4720640  yolov5.models.common.Conv               [512, 1024, 3, 2]
  9                -1  3   9971712  yolov5.models.common.C3                 [1024, 1024, 3]
 10                -1  1   2624512  yolov5.models.common.SPPF               [1024, 1024, 5]
 11                -1  1    525312  yolov5.models.common.Conv               [1024, 512, 1, 1]
 12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 13           [-1, 6]  1         0  yolov5.models.common.Concat             [1]
 14                -1  3   2757632  yolov5.models.common.C3                 [1024, 512, 3, False]
 15                -1  1    131584  yolov5.models.common.Conv               [512, 256, 1, 1]
 16                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 17        [-1, 4, 5]  1         0  yolov5.models.common.Concat             [1]
 18                -1  3    756224  yolov5.models.common.C3                 [768, 256, 3, False]
 19              [17]  1     29222  yolov5.models.yolo.Detect               [33, [[4, 4]], [768]]
Model summary: 275 layers, 30782630 parameters, 30782630 gradients, 363.8 GFLOPs

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
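The brick names in the comments of the DECaLS cell encode approximate sky positions: Legacy Surveys brick names have the form RRRRsDDD, where RRRR is RA × 10, s is p (positive) or m (negative), and DDD is |Dec| × 10. The hypothetical helper below decodes a name, which is a quick sanity check that the ra and dec values passed to predict_decals fall near the intended brick. The decoded values are approximate, since bricks span a fraction of a degree.

```python
import re
from typing import Tuple


def brick_center(brickname: str) -> Tuple[float, float]:
    """Approximate (RA, Dec) in degrees encoded in a Legacy Surveys brick name.

    Names look like 'RRRRsDDD': RRRR is RA * 10, s is 'p' (+) or 'm' (-),
    and DDD is |Dec| * 10.
    """
    match = re.fullmatch(r"(\d{4})([pm])(\d{3})", brickname)
    if match is None:
        raise ValueError(f"not a brick name: {brickname!r}")
    ra = int(match.group(1)) / 10.0
    dec = int(match.group(3)) / 10.0
    return ra, (-dec if match.group(2) == "m" else dec)


print(brick_center("3366m010"))  # (336.6, -1.0)
print(brick_center("1358p297"))  # (135.8, 29.7)
```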
[ ]:
bliss_client.plot_predictions_in_notebook()