BLISS User API
Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources.
Installation
[ ]:
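# Editable install from a local clone of the repository (this path is specific to the author's machine)
!pip install -e /home/zhteoh/770-bulk-predict
[ ]:
# Or install the released package from PyPI
!pip install bliss-deblender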
Tutorial
[6]:
from bliss.api import BlissClient
# Working directory for the client; the outputs below show generated data written to <cwd>/data/
bliss_client = BlissClient(cwd="/data/scratch/zhteoh/tutorial")
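The paths printed later in this tutorial suggest how the client lays out files under `cwd`: generated data goes to `<cwd>/data/cached_dataset` and pretrained weights to `<cwd>/data/pretrained_models`. The sketch below derives those locations from a working directory using a hypothetical `expected_paths` helper. The layout is inferred from this notebook's output, not taken from the BLISS source, so treat it as an assumption.

```python
from pathlib import Path


def expected_paths(cwd: str) -> dict:
    """Hypothetical helper: locations inferred from this tutorial's printed output."""
    root = Path(cwd)
    return {
        # Default target of bliss_client.generate() (assumed from the log message)
        "cached_dataset": root / "data" / "cached_dataset",
        # Directory checked before downloading pretrained weights
        "pretrained_models": root / "data" / "pretrained_models",
    }


paths = expected_paths("/data/scratch/zhteoh/tutorial")
for name, path in paths.items():
    print(f"{name}: {path}")
```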
Train the model
Generate synthetic image data
[5]:
bliss_client.generate(
n_batches=3,
batch_size=64,
max_images_per_file=128
)
Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset
Simulating images in batches for file: 100%|██████████| 2/2 [04:34<00:00, 137.09s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [04:41<00:00, 140.52s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [09:15<00:00, 277.74s/it]
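The outer progress bar reports 2 files written. That count is consistent with simple arithmetic on the arguments: 3 batches of 64 images is 192 images, and at most 128 images per file needs 2 files. The sketch below assumes images are packed into files of at most `max_images_per_file` each; it is not BLISS's actual file-splitting code. Both per-file bars read 2/2, so how the final, partially filled file is handled is not visible from this output alone.

```python
import math

# Arguments passed to bliss_client.generate() above
n_batches = 3
batch_size = 64
max_images_per_file = 128

total_images = n_batches * batch_size  # simulated images requested
# Assumption: images are packed into files holding at most max_images_per_file each
n_files = math.ceil(total_images / max_images_per_file)

print(f"{total_images} images -> {n_files} file(s)")
```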
Pass additional custom configuration parameters
[7]:
# Alter default cached_data_path
bliss_client.cached_data_path = "/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02"
bliss_client.generate(
n_batches=3, # required
batch_size=64, # required
max_images_per_file=128, # required
simulator={"survey": {"prior_config": {"mean_sources": 0.02}}}, # optional
generate={"file_prefix": "dataset"}, # optional
)
Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
Simulating images in batches for file: 100%|██████████| 2/2 [01:01<00:00, 30.78s/it]
Simulating images in batches for file: 100%|██████████| 2/2 [01:06<00:00, 33.06s/it]
Generating and writing cached dataset files: 100%|██████████| 2/2 [02:07<00:00, 63.95s/it]
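The optional keyword arguments appear to be nested dictionaries that mirror BLISS's configuration tree, with each keyword (`simulator`, `generate`) naming a top-level config section. Assuming that mapping, the hypothetical helper below spells out the dotted config key that each override targets. Note that `file_prefix="dataset"` is consistent with the `dataset_0.pt` file loaded further down.

```python
from typing import Any, Dict


def flatten_overrides(overrides: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: flatten nested override dicts into dotted config keys."""
    flat: Dict[str, Any] = {}
    for key, value in overrides.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into nested sections, extending the dotted path
            flat.update(flatten_overrides(value, dotted))
        else:
            flat[dotted] = value
    return flat


# The optional keyword arguments passed to generate() above
overrides = {
    "simulator": {"survey": {"prior_config": {"mean_sources": 0.02}}},
    "generate": {"file_prefix": "dataset"},
}
flat = flatten_overrides(overrides)
for key, value in flat.items():
    print(f"{key}={value}")
```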
[ ]:
# Point the client at the dataset generated above (useful when resuming in a fresh kernel)
bliss_client.cached_data_path = "/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02"
[ ]:
# Check that the dataset is generated
!ls /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
!du -sh /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02
# !cat /data/scratch/zhteoh/tutorial/dataset/hparams.yaml
print("Dataset:", bliss_client.cached_data_path)
# Load one cached file and report its number of samples and the image shape of the first sample
dataset_0 = bliss_client.get_dataset_file(filename="dataset_0.pt")
print(" Size:", len(dataset_0))
print(" Shape:", dataset_0[0]["images"].shape)
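The cell above inspects only the first sample. Assuming, as `dataset_0[0]["images"]` suggests, that each cached file is a sequence of dicts with an `"images"` entry, a small helper can tally the image shapes across every sample in the file. `shape_counts` is a hypothetical helper, and `_FakeTensor` is a stand-in with made-up shapes so the sketch runs without torch or a BLISS session.

```python
from collections import Counter
from typing import Any, Iterable, Mapping


def shape_counts(items: Iterable[Mapping[str, Any]], key: str = "images") -> Counter:
    """Hypothetical helper: count samples per shape for one key.

    Works with any value exposing a ``.shape`` attribute, such as a torch.Tensor.
    """
    return Counter(tuple(item[key].shape) for item in items)


class _FakeTensor:
    """Stand-in for a tensor so this sketch runs without torch; shapes are made up."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


demo = [{"images": _FakeTensor(1, 4, 4)} for _ in range(3)]
counts = shape_counts(demo)
print(counts)

# With the cached file loaded above:
# print(shape_counts(dataset_0))
```

A single entry in the result means every sample in the file shares one image shape.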
Train the model
Without pretrained weights
[ ]:
bliss_client.train(weight_save_path="tutorial_encoder/0.pt")
With pretrained weights
Download the pretrained weights for your sky survey (SDSS in this example).
[ ]:
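import os

# The pretrained-models directory must already exist under the working directory
assert os.path.exists("/data/scratch/zhteoh/tutorial/data/pretrained_models")
# Fetch the SDSS pretrained weights and save them as sdss_pretrained.pt
bliss_client.load_pretrained_weights_for_survey(survey="sdss", filename="sdss_pretrained.pt")
# List the directory to confirm the weights file is present
!ls /data/scratch/zhteoh/tutorial/data/pretrained_models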
Train on the cached dataset generated on disk
[ ]:
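bliss_client.train_on_cached_data(
    weight_save_path="tutorial_encoder/0.pt",
    train_n_batches=2,
    batch_size=64,
    val_split_file_idxs=[1],  # index of the cached file(s) used for validation
    pretrained_weights_filename=None,  # assumed: e.g. "sdss_pretrained.pt" to start from the downloaded weights
)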
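The argument name `val_split_file_idxs` suggests that whole cached files, selected by their index, are held out for validation. Under that assumption (not verified against the BLISS source), the two files generated earlier would be partitioned as shown below. `split_files_by_index` is a hypothetical helper, and it assumes the file list is ordered by file number.

```python
from typing import List, Sequence, Tuple


def split_files_by_index(
    filenames: Sequence[str], val_idxs: Sequence[int]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: partition cached files into (train, val) lists by index."""
    val_set = set(val_idxs)
    train = [f for i, f in enumerate(filenames) if i not in val_set]
    val = [f for i, f in enumerate(filenames) if i in val_set]
    return train, val


# File names follow the "dataset" prefix used when generating the data
files = ["dataset_0.pt", "dataset_1.pt"]
train_files, val_files = split_files_by_index(files, val_idxs=[1])
print("train:", train_files)
print("val:  ", val_files)
```

Splitting by file rather than by sample keeps each validation file entirely unseen during training.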
Run the model
Using the sample SDSS dataset
Get predictions for the sample dataset
[3]:
est_cat, est_cat_table, pred_tables = bliss_client.predict_sdss(
    weight_save_path="tutorial_encoder/zscore_five_band.pt",
    # Optional: choose which SDSS fields to predict on, by run, camcol, and field number
    # predict={"dataset": {"sdss_fields": [{"run": 94, "camcol": 1, "fields": [12]}, {"run": 3900, "camcol": 6, "fields": [296]}]}},
)
                 from  n    params  module                                  arguments
  0                -1  1     16128  yolov5.models.common.Conv               [10, 64, 5, 1]
  1                -1  3     12672  yolov5.models.common.Conv               [64, 64, 1, 1]
  2                -1  1     73984  yolov5.models.common.Conv               [64, 128, 3, 2]
  3                -1  1    147712  yolov5.models.common.Conv               [128, 128, 3, 1]
  4                -1  1    295424  yolov5.models.common.Conv               [128, 256, 3, 2]
  5                -1  6   1118208  yolov5.models.common.C3                 [256, 256, 6]
  6                -1  1   1180672  yolov5.models.common.Conv               [256, 512, 3, 2]
  7                -1  9   6433792  yolov5.models.common.C3                 [512, 512, 9]
  8                -1  1   4720640  yolov5.models.common.Conv               [512, 1024, 3, 2]
  9                -1  3   9971712  yolov5.models.common.C3                 [1024, 1024, 3]
 10                -1  1   2624512  yolov5.models.common.SPPF               [1024, 1024, 5]
 11                -1  1    525312  yolov5.models.common.Conv               [1024, 512, 1, 1]
 12                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 13           [-1, 6]  1         0  yolov5.models.common.Concat             [1]
 14                -1  3   2757632  yolov5.models.common.C3                 [1024, 512, 3, False]
 15                -1  1    131584  yolov5.models.common.Conv               [512, 256, 1, 1]
 16                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 17        [-1, 4, 5]  1         0  yolov5.models.common.Concat             [1]
 18                -1  3    756224  yolov5.models.common.C3                 [768, 256, 3, False]
 19              [17]  1     29222  yolov5.models.yolo.Detect               [33, [[4, 4]], [768]]
Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
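The commented-out `predict` argument in the prediction cell lists SDSS fields as dicts keyed by `run`, `camcol`, and `fields`. When selecting many fields, it can be convenient to build that list from plain `(run, camcol, field)` triples. The sketch below does this with a hypothetical `build_sdss_fields` helper that copies the structure of the commented example; it is not part of the BLISS API, and whether `predict_sdss` accepts other layouts is not covered here.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def build_sdss_fields(fields: Iterable[Tuple[int, int, int]]) -> List[Dict[str, object]]:
    """Hypothetical helper: group (run, camcol, field) triples into the sdss_fields format."""
    grouped: Dict[Tuple[int, int], List[int]] = defaultdict(list)
    for run, camcol, field in fields:
        grouped[(run, camcol)].append(field)
    # One entry per (run, camcol) pair, with its field numbers sorted
    return [
        {"run": run, "camcol": camcol, "fields": sorted(fs)}
        for (run, camcol), fs in sorted(grouped.items())
    ]


# The two fields from the commented-out example
sdss_fields = build_sdss_fields([(94, 1, 12), (3900, 6, 296)])
predict_override = {"dataset": {"sdss_fields": sdss_fields}}
print(predict_override)

# With a live client, this could be passed like the commented-out argument:
# bliss_client.predict_sdss(
#     weight_save_path="tutorial_encoder/zscore_five_band.pt",
#     predict=predict_override,
# )
```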
[5]:
bliss_client.plot_predictions_in_notebook()