{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# BLISS User API\n", "\n", "Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "!pip install -e /home/zhteoh/770-bulk-predict" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "!pip install bliss-deblender" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "from bliss.api import BlissClient\n", "\n", "# bliss_client = BlissClient(cwd=\"/data/scratch/zhteoh/tutorial\")\n", "bliss_client = BlissClient(cwd=\"/tmp/pytest-of-zhteoh/pytest-417\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Generate synthetic image data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Simulating images in batches for file: 100%|██████████| 2/2 [04:34<00:00, 137.09s/it]\n", "Simulating images in batches for file: 100%|██████████| 2/2 [04:41<00:00, 140.52s/it]3s/it]\n", "Generating and writing cached dataset files: 100%|██████████| 2/2 [09:15<00:00, 277.74s/it]\n" ] } ], "source": [ "bliss_client.generate(\n", " n_batches=3, \n", " batch_size=64,\n", " max_images_per_file=128\n", ")" ] }, { 
"attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Pass additional custom configuration parameters" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Simulating images in batches for file: 100%|██████████| 2/2 [01:01<00:00, 30.78s/it]\n", "Simulating images in batches for file: 100%|██████████| 2/2 [01:06<00:00, 33.06s/it]7s/it]\n", "Generating and writing cached dataset files: 100%|██████████| 2/2 [02:07<00:00, 63.95s/it]\n" ] } ], "source": [ "# Alter default cached_data_path\n", "bliss_client.cached_data_path = \"/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\"\n", "\n", "bliss_client.generate(\n", " n_batches=3, # required\n", " batch_size=64, # required\n", " max_images_per_file=128, # required\n", " simulator={\"survey\": {\"prior_config\": {\"mean_sources\": 0.02}}}, # optional\n", " generate={\"file_prefix\": \"dataset\"}, # optional\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.cached_data_path = \"/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check that the dataset is generated\n", "!ls /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n", "!du -sh /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n", "# !cat /data/scratch/zhteoh/tutorial/dataset/hparams.yaml\n", "\n", "print(\"Dataset:\", bliss_client.cached_data_path)\n", "dataset_0 = bliss_client.get_dataset_file(filename=\"dataset_0.pt\")\n", "print(\" Size:\", len(dataset_0))\n", "print(\" Shape:\", dataset_0[0][\"images\"].shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Train the encoder" ] }, { 
"attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Without pretrained weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.train(weight_save_path=\"tutorial_encoder/0.pt\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### With pretrained weights\n", "\n", "Download the pretrained weights that match your sky survey." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "assert os.path.exists(\"/data/scratch/zhteoh/tutorial/data/pretrained_models\")\n", "\n", "bliss_client.load_pretrained_weights_for_survey(survey=\"sdss\", filename=\"sdss_pretrained.pt\")\n", "\n", "!ls /data/scratch/zhteoh/tutorial/data/pretrained_models" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Train on the generated dataset cached on disk" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.train_on_cached_data(\n", " weight_save_path=\"tutorial_encoder/0.pt\",\n", " train_n_batches=2,\n", " batch_size=64,\n", " val_split_file_idxs=[1],\n", " pretrained_weights_filename=None,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Run the model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Using sample SDSS dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Get predictions for the sample dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " from n params module arguments \n", " 0 -1 1 16128 yolov5.models.common.Conv [10, 64, 5, 1] \n", " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", " 3 -1 1 147712 yolov5.models.common.Conv 
[128, 128, 3, 1] \n", " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", " 19 [17] 1 29222 yolov5.models.yolo.Detect [33, [[4, 4]], [768]] \n", "Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs\n", "\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "480b4977d65b42c1afbca0d3ce02e0c9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\n", " from n params module arguments \n", " 0 -1 1 16128 yolov5.models.common.Conv [10, 64, 5, 1] \n", " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", " 4 -1 1 
295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", " 19 [17] 1 29222 yolov5.models.yolo.Detect [33, [[4, 4]], [768]] \n", "Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs\n", "\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b844c98de046471c81406f36002e8b94", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "est_cat, est_cat_table, pred_tables = bliss_client.predict_sdss(\n", " weight_save_path=\"tutorial_encoder/zscore_five_band.pt\",\n", " # predict={\"dataset\": {\"sdss_fields\": [{\"run\": 94, \"camcol\": 1, \"fields\": [12]}, {\"run\": 3900, \"camcol\": 6, \"fields\": [296]}]}},\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"data": { "text/html": [ "\n", "\n", " \n", " \n", " Bokeh Plot\n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "bliss_client.plot_predictions_in_notebook()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of entries: 254\n" ] }, { "data": { "text/html": [ "Table length=5\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
idxplocssource_typestar_flux_ugalaxy_flux_ustar_flux_ggalaxy_flux_gstar_flux_rgalaxy_flux_rstar_flux_igalaxy_flux_istar_flux_zgalaxy_flux_zgalaxy_disk_fracgalaxy_beta_radiansgalaxy_disk_qgalaxy_a_dgalaxy_bulge_qgalaxy_a_b
nmgynmgynmgynmgynmgynmgynmgynmgynmgynmgy
0tensor([266.37012, 101.83371])tensor([1])1.9210918e-071.88252950.651014574.0829490.83070425.49383263.483197723.9213753.790694242.905990.0041288760.0034716730.000761442240.00214500030.00328061381.7661792
1tensor([549.19055, 274.24149])tensor([1])1.3584375e-050.263266950.1219470050.70436280.4644341.75929521.95451758.9221662.766725516.5474130.00283737530.00255063460.00073144650.0054546640.00453693931.13643
2tensor([236.75182, 576.83881])tensor([1])0.000119498620.085143840.369022550.80455710.599738061.08202061.48342373.12111570.0099486162.36346460.00207926150.00251216980.00059752040.00253470310.00413356260.96826047
3tensor([248.72098, 260.99545])tensor([1])4.0953696e-060.36880590.429168340.979606870.58036841.01144361.19781122.89758940.00262587753.12889080.002359160.0027015970.000583882850.00323049210.00374381221.0822673
4tensor([ 7.23273, 211.16307])tensor([1])0.170723780.885299150.26931830.64269190.51498930.953418430.89364082.16016530.0152116163.360910.00249141780.00450364030.000305592520.00282802670.00305669081.2435206
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Number of entries:\", len(est_cat_table))\n", "est_cat_table[:5].show_in_notebook(display_length=5)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Inspect probabilistic predictions\n", "\n", "BLISS produces probability distributions on the predicted latent variables." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of entries (RCF (94, 1, 12)): 24964\n" ] }, { "data": { "text/html": [ "Table length=5\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
idxon_prob_falseon_prob_truegalaxy_prob_falsegalaxy_prob_truegalsim_disk_frac_meangalsim_disk_frac_stdgalsim_beta_radians_meangalsim_beta_radians_stdgalsim_disk_q_meangalsim_disk_q_stdgalsim_a_d_meangalsim_a_d_stdgalsim_bulge_q_meangalsim_bulge_q_stdgalsim_a_b_meangalsim_a_b_stdstar_flux_u_meanstar_flux_u_stdstar_flux_g_meanstar_flux_g_stdstar_flux_r_meanstar_flux_r_stdstar_flux_i_meanstar_flux_i_stdstar_flux_z_meanstar_flux_z_stdgalaxy_flux_u_meangalaxy_flux_u_stdgalaxy_flux_g_meangalaxy_flux_g_stdgalaxy_flux_r_meangalaxy_flux_r_stdgalaxy_flux_i_meangalaxy_flux_i_stdgalaxy_flux_z_meangalaxy_flux_z_std
radradarcsecarcsecarcsecarcsecnmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgy
00.99678020.00321976540.000432789330.99956720.57243181.85244510.37481451.66494390.470053432.07418161.23134020.81039680.066831352.09495470.454636571.23431223.34724663.89838984.31438541.18222395.6566250.493259255.8699220.752895954.53218656.9602565.04342751.52036065.5714180.858202466.75043870.376854037.3013840.460732466.12694452.2173305
10.99980330.000196710480.0001000165940.99991.59988572.190470.50159981.75457870.66870571.86497261.63008430.697890160.434260371.99820510.648199561.35815811.56567624.78707363.52552181.03309994.5661420.936750773.79787452.5735153.591493.69056994.38897132.08299145.90586950.68660566.60502150.334762286.94263650.465261376.59285831.2505308
20.99991e-040.000172615050.99982741.86665422.04445770.437796121.72137580.65551781.74150441.84740090.644446550.181624891.83131380.681199551.27925161.21534016.09186653.17501641.34800454.14894960.99908193.6976182.75835084.22460272.578333.95683032.859255.93399050.815708046.5072230.430997976.77234360.703732437.22697161.1791977
30.999846460.00015353810.00010395050.999896051.62800341.8893498-0.230698592.16578960.479752781.80081022.15197730.781493070.60241221.81481230.90130571.46621541.84037594.16382843.77111480.99872534.81160550.430488084.63790461.15260224.84519961.14406314.0424293.65394545.92075060.97080376.78137870.50634097.14950940.90572317.5742961.3889303
40.99973310.00026693050.0001000165940.99991.42483881.7584034-0.00107550621.95953180.631301641.49456482.07881470.804047170.585586551.81606280.755789761.33613012.65422825.42983343.81303170.97062625.10223770.412487184.7252271.69422765.42944430.879893545.010332.47703775.8588990.79444186.77067570.430732677.10946850.73547147.78264330.8656995
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Number of entries (RCF (94, 1, 12)):\", len(pred_tables[(94, 1, 12)]))\n", "pred_tables[(94, 1, 12)][:5].show_in_notebook(display_length=5)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of entries (RCF (3900, 6, 269)): 24964\n" ] }, { "data": { "text/html": [ "Table length=5\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
idxon_prob_falseon_prob_truegalaxy_prob_falsegalaxy_prob_truegalsim_disk_frac_meangalsim_disk_frac_stdgalsim_beta_radians_meangalsim_beta_radians_stdgalsim_disk_q_meangalsim_disk_q_stdgalsim_a_d_meangalsim_a_d_stdgalsim_bulge_q_meangalsim_bulge_q_stdgalsim_a_b_meangalsim_a_b_stdstar_flux_u_meanstar_flux_u_stdstar_flux_g_meanstar_flux_g_stdstar_flux_r_meanstar_flux_r_stdstar_flux_i_meanstar_flux_i_stdstar_flux_z_meanstar_flux_z_stdgalaxy_flux_u_meangalaxy_flux_u_stdgalaxy_flux_g_meangalaxy_flux_g_stdgalaxy_flux_r_meangalaxy_flux_r_stdgalaxy_flux_i_meangalaxy_flux_i_stdgalaxy_flux_z_meangalaxy_flux_z_std
radradarcsecarcsecarcsecarcsecnmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgynmgy
00.99989070.000109314220.0001000165940.99991.84275461.89667670.17593551.53407651.01268172.11461621.56248310.75048670.0134491921.8136890.74573851.43309412.0786724.3719253.13351541.34688234.68619251.0258933.80023433.33322934.45268733.04674484.95785142.14614345.6807641.05425816.62231450.635017346.690810.90154777.42443661.308013
10.99965440.00034560130.00107365850.998926341.15804361.89814640.103563311.51382-0.00156497961.75042211.35188320.73613540.0235483652.10106330.389216421.19684793.09337473.34775354.41250130.78151085.56520840.423210625.57749030.881161875.874461.55818214.57809352.01080065.57082650.66346946.45614430.377527336.81918530.53798257.3068221.0409187
20.999850030.000149940450.000612974170.9993871.35574171.8018293-0.263671641.79469860.308890582.03052351.45481280.69167155-0.01733472.20347360.38862181.15886843.28078752.61924154.02236840.911264845.18662640.530872465.03069161.36111965.0993452.6539875.01789951.6534545.66885380.62892766.5837280.349233666.8162720.56851187.05228231.2957809
30.99963860.00036139740.000129461290.999870542.03241321.7666548-0.174253461.89105331.19250941.56565361.69150380.63886320.0470974452.09073070.455917841.46029161.69576122.82567332.98597720.97972883.79730250.755215173.61664341.58537462.90772275.69611025.23357872.46308786.19858260.8754326.8619920.542729267.11725330.742677336.76400572.2090604
40.99889680.00110319640.0001000165940.99991.74303221.6583972-0.0157330041.48242011.11489131.67240631.83293460.69212830.238222122.1990880.67237711.46318842.63429453.11351283.82552340.806065265.1348380.3358475.23295550.84052855.0726951.71613415.0337253.0645265.80013940.96344746.84038540.549091647.3483440.728391237.51123331.2816015
\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Number of entries (RCF (3900, 6, 269)):\", len(pred_tables[(3900, 6, 269)]))\n", "pred_tables[(3900, 6, 269)][:5].show_in_notebook(display_length=5)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Save predicted catalog to FITS file" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "est_cat_table.write(\"est_cat.fits\", format=\"fits\", overwrite=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check that catalog is saved as intended\n", "from astropy.table import Table\n", "\n", "est_cat_table = Table.read(\"est_cat.fits\", format=\"fits\")\n", "print(\"Number of entries:\", len(est_cat_table))\n", "est_cat_table.show_in_notebook(display_length=5)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Evaluate prediction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from bliss.metrics import BlissMetrics\n", "from bliss.surveys.sdss import PhotoFullCatalog\n", "\n", "sdss_data_path = \"/data/scratch/zhteoh/tutorial/data/sdss\"\n", "sdss = SloanDigitalSkySurvey()\n", "photo_cat = PhotoFullCatalog.from_file()\n", "\n", "est_cat_cuda = est_cat.to(torch.device(\"cpu\"))\n", "photo_cat_cuda = photo_cat.to(torch.device(\"cpu\"))\n", "\n", "metrics = BlissMetrics()\n", "results = metrics(est_cat_cuda, photo_cat_cuda)\n", "\n", "print(results)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Using user-specified SDSS dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Download online dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from astropy.coordinates 
import SkyCoord\n", "from astroquery.sdss import SDSS\n", "from pathlib import Path\n", "\n", "pos = SkyCoord('0h8m05.63s +14d50m23.3s', frame='icrs') # 1011/3/44\n", "# pos = SkyCoord(\"1h8m05.73s +13d10m20.3s\", frame=\"icrs\") # 4829/5/27\n", "# pos = SkyCoord(\"1h2m05.83s -2d11m20.3s\", frame=\"icrs\") # 2699/4/71\n", "region = SDSS.query_region(pos, radius=\"5 arcsec\")\n", "run, camcol, field = region[\"run\"][0], region[\"camcol\"][0], region[\"field\"][0]\n", "print(\"run:\", run, \"camcol:\", camcol, \"field:\", field)\n", "bliss_client.load_survey(\"sdss\", run, camcol, field, download_dir=Path(\"data/sdss\"))" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Get predictions for the downloaded dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "est_cat_dl, est_cat_table_dl, pred_tables_dl = bliss_client.predict_sdss(\n", " data_path=\"data/sdss\",\n", " weight_save_path=\"tutorial_encoder/0.pt\",\n", " predict={\"dataset\": {\"run\": 1011, \"camcol\": 3, \"fields\": [44]}}\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.plot_predictions_in_notebook()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Inspect probabilistic predictions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Number of entries:\", len(pred_tables_dl[(1011, 3, 44)]))\n", "pred_tables_dl[(1011, 3, 44)][:5].show_in_notebook(display_length=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Using sample DECaLS dataset" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " from n params module arguments \n", " 0 -1 1 3328 yolov5.models.common.Conv [2, 64, 5, 1] \n", " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", " 2 -1 1 
73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", " 19 [17] 1 29222 yolov5.models.yolo.Detect [33, [[4, 4]], [768]] \n", "Model summary: 275 layers, 30782630 parameters, 30782630 gradients, 363.8 GFLOPs\n", "\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32190fec10d448dbac15bdcd9cfa7369", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "est_cat, est_cat_table, pred_tables = bliss_client.predict_decals(\n", " weight_save_path=\"tutorial_encoder/single_band_base.pt\",\n", " predict={\n", " \"dataset\": {\n", " \"sky_coords\": [\n", " # brick '3366m010' corresponds to SDSS RCF 
94-1-12\n", " {\"ra\": 336.6643042496718, \"dec\": -0.9316385797930247},\n", " # brick '1358p297' corresponds to SDSS RCF 3635-1-169\n", " {\"ra\": 135.95496736941683, \"dec\": 29.646883837721347},\n", " ]\n", " }\n", " },\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.plot_predictions_in_notebook()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 }