{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# BLISS User API\n", "\n", "Bayesian Light Source Separator (BLISS) is a Bayesian method for deblending and cataloging light sources." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Installation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "!pip install -e /home/zhteoh/770-bulk-predict" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "!pip install bliss-deblender" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Tutorial" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "from bliss.api import BlissClient\n", "\n", "# bliss_client = BlissClient(cwd=\"/data/scratch/zhteoh/tutorial\")\n", "bliss_client = BlissClient(cwd=\"/tmp/pytest-of-zhteoh/pytest-417\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Generate synthetic image data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Simulating images in batches for file: 100%|██████████| 2/2 [04:34<00:00, 137.09s/it]\n", "Simulating images in batches for file: 100%|██████████| 2/2 [04:41<00:00, 140.52s/it]3s/it]\n", "Generating and writing cached dataset files: 100%|██████████| 2/2 [09:15<00:00, 277.74s/it]\n" ] } ], "source": [ "bliss_client.generate(\n", " n_batches=3, \n", " batch_size=64,\n", " max_images_per_file=128\n", ")" ] }, { 
"attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Pass additional custom configuration parameters" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Data will be saved to /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Simulating images in batches for file: 100%|██████████| 2/2 [01:01<00:00, 30.78s/it]\n", "Simulating images in batches for file: 100%|██████████| 2/2 [01:06<00:00, 33.06s/it]7s/it]\n", "Generating and writing cached dataset files: 100%|██████████| 2/2 [02:07<00:00, 63.95s/it]\n" ] } ], "source": [ "# Alter default cached_data_path\n", "bliss_client.cached_data_path = \"/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\"\n", "\n", "bliss_client.generate(\n", " n_batches=3, # required\n", " batch_size=64, # required\n", " max_images_per_file=128, # required\n", " simulator={\"survey\": {\"prior_config\": {\"mean_sources\": 0.02}}}, # optional\n", " generate={\"file_prefix\": \"dataset\"}, # optional\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.cached_data_path = \"/data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check that the dataset is generated\n", "!ls /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n", "!du -sh /data/scratch/zhteoh/tutorial/data/cached_dataset_ms0.02\n", "# !cat /data/scratch/zhteoh/tutorial/dataset/hparams.yaml\n", "\n", "print(\"Dataset:\", bliss_client.cached_data_path)\n", "dataset_0 = bliss_client.get_dataset_file(filename=\"dataset_0.pt\")\n", "print(\" Size:\", len(dataset_0))\n", "print(\" Shape:\", dataset_0[0][\"images\"].shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Train the model" ] }, { 
"attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Without pretrained weights" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.train(weight_save_path=\"tutorial_encoder/0.pt\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### With pretrained weights\n", "\n", "Download the pretrained weights that match your sky survey." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "assert os.path.exists(\"/data/scratch/zhteoh/tutorial/data/pretrained_models\")\n", "\n", "bliss_client.load_pretrained_weights_for_survey(survey=\"sdss\", filename=\"sdss_pretrained.pt\")\n", "\n", "!ls /data/scratch/zhteoh/tutorial/data/pretrained_models" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Train on the cached dataset generated on disk" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bliss_client.train_on_cached_data(\n", " weight_save_path=\"tutorial_encoder/0.pt\",\n", " train_n_batches=2,\n", " batch_size=64,\n", " val_split_file_idxs=[1],\n", " pretrained_weights_filename=None,\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Run the model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "#### Using the sample SDSS dataset" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "##### Get predictions for the sample dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", " from n params module arguments \n", " 0 -1 1 16128 yolov5.models.common.Conv [10, 64, 5, 1] \n", " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", " 3 -1 1 147712 yolov5.models.common.Conv 
[128, 128, 3, 1] \n", " 4 -1 1 295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", " 19 [17] 1 29222 yolov5.models.yolo.Detect [33, [[4, 4]], [768]] \n", "Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs\n", "\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "480b4977d65b42c1afbca0d3ce02e0c9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\n", " from n params module arguments \n", " 0 -1 1 16128 yolov5.models.common.Conv [10, 64, 5, 1] \n", " 1 -1 3 12672 yolov5.models.common.Conv [64, 64, 1, 1] \n", " 2 -1 1 73984 yolov5.models.common.Conv [64, 128, 3, 2] \n", " 3 -1 1 147712 yolov5.models.common.Conv [128, 128, 3, 1] \n", " 4 -1 1 
295424 yolov5.models.common.Conv [128, 256, 3, 2] \n", " 5 -1 6 1118208 yolov5.models.common.C3 [256, 256, 6] \n", " 6 -1 1 1180672 yolov5.models.common.Conv [256, 512, 3, 2] \n", " 7 -1 9 6433792 yolov5.models.common.C3 [512, 512, 9] \n", " 8 -1 1 4720640 yolov5.models.common.Conv [512, 1024, 3, 2] \n", " 9 -1 3 9971712 yolov5.models.common.C3 [1024, 1024, 3] \n", " 10 -1 1 2624512 yolov5.models.common.SPPF [1024, 1024, 5] \n", " 11 -1 1 525312 yolov5.models.common.Conv [1024, 512, 1, 1] \n", " 12 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 13 [-1, 6] 1 0 yolov5.models.common.Concat [1] \n", " 14 -1 3 2757632 yolov5.models.common.C3 [1024, 512, 3, False] \n", " 15 -1 1 131584 yolov5.models.common.Conv [512, 256, 1, 1] \n", " 16 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", " 17 [-1, 4, 5] 1 0 yolov5.models.common.Concat [1] \n", " 18 -1 3 756224 yolov5.models.common.C3 [768, 256, 3, False] \n", " 19 [17] 1 29222 yolov5.models.yolo.Detect [33, [[4, 4]], [768]] \n", "Model summary: 275 layers, 30795430 parameters, 30795430 gradients, 374.3 GFLOPs\n", "\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b844c98de046471c81406f36002e8b94", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Predicting: 0it [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "est_cat, est_cat_table, pred_tables = bliss_client.predict_sdss(\n", " weight_save_path=\"tutorial_encoder/zscore_five_band.pt\",\n", " # predict={\"dataset\": {\"sdss_fields\": [{\"run\": 94, \"camcol\": 1, \"fields\": [12]}, {\"run\": 3900, \"camcol\": 6, \"fields\": [296]}]}},\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"data": { "text/html": [ "\n", "\n", "
\n", " \n", "