# orca/transform/spectrum_v3.py
"""
Dynamic-spectrum helpers (v3)
* 100% read-only MeasurementSet access
* Optionally respects MS FLAG column
* bcal optional; if present, used at native resolution (NO rebin)
* NO spectral averaging: 192 chans/subband → 16 SPWs → 3072 total chans
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Iterable
from dataclasses import dataclass
from datetime import datetime
import logging
import os
import numpy as np
from numba import njit
from casacore.tables import table
from astropy.io import fits
import redis
from orca.celery import app
from orca.utils.datetimeutils import STAGE_III_INTEGRATION_TIME
from orca.configmanager import queue_config
# NOTE(review): the bare `[docs]` lines in this module look like residue from a
# Sphinx "view source" HTML page.  As statements they evaluate the name `docs`,
# which nothing visible here defines (NameError at import), and the ones placed
# between an `@app.task(...)` decorator and its `def` are a SyntaxError.
# Confirm against the real file and delete them if they are actually present.
[docs]
logger = logging.getLogger(__name__)
# --- Constants: NON-averaged mode ----------------------------------------
[docs]
N_CHAN_OUT = 192 # full SPW width: channels per subband; the reducer skips blobs of any other size
_CHAN_WIDTH_MHZ = 0.023926 # native per-channel width; written to the FITS header as CDELT2
_TRANSPORT_DTYPE = np.float32  # dtype of the Redis blobs and FITS cubes (map and reduce must agree)
[docs]
REDIS_URL = queue_config.result_backend_uri  # Redis instance shared by the map and reduce steps
[docs]
REDIS_EXPIRE_S = 3600 * 10  # TTL of each spectrum blob: 10 hours
[docs]
REDIS_KEY_PREFIX = "spec-v3-" # distinct keyspace; keys are <prefix><MS stem>-sum / -<row index>
# Selected cross-correlation rows (unchanged), as (label, row index).
# The row index is into the cross-correlation-only rows (ANTENNA1 != ANTENNA2)
# in the MS's stored order; rows beyond the end are skipped with a warning.
# Labels double as spectrum types and as the reducer's output subdirectories.
[docs]
ROW_NUMS = [
    ('LWA-128&LWA-160', 54282),
    ('LWA-048&LWA-051', 33335),
    ('LWA-018&LWA-237', 30360),
    ('LWA-065&LWA-094', 41689),
    ('LWA-286&LWA-333', 28524),
]
# --- Calibration + flagging at native channelization ---------------------
@njit
def _applycal_cross_with_flags_native(data, flags, bcal):
    """
    Flag, and optionally calibrate, cross-correlation visibilities at native resolution.

    data  : complex64[nRow, nChan, 4]  correlations in the order the reducer labels
            XX, XY, YX, YY (viewed here as the 2x2 matrix [[XX, XY], [YX, YY]])
    flags : bool     [nRow, nChan, 4]  must have the same shape as ``data``
    bcal  : complex64[nAnt, nChan, 2] or None  per-antenna gains for the two pols

    Returns complex64[nRow, nChan, 4].  A (row, chan) cell with ANY of its four
    flags set becomes NaN in all four correlations.  Otherwise, without bcal the
    data are copied through; with bcal the diagonal 2x2 gains are removed:

        out[p, q] = (1 / g_i[p]) * V[p, q] * conj(1 / g_j[q])

    so a NaN gain (e.g. a flagged solution) only affects the correlations that
    actually use it.

    With bcal, rows MUST be cross-correlations in canonical baseline order
    (0,1), (0,2), ..., (0,nAnt-1), (1,2), ... for nAnt = bcal.shape[0]; row k is
    taken to be the k-th such pair.  Numba does not bounds-check by default, so
    shape mismatches are rejected up front with ValueError instead of silently
    reading/writing out of range or returning uninitialised rows.
    """
    n_row, n_chan, _ = data.shape
    if flags.shape != data.shape:
        raise ValueError("flags must have the same shape as data")
    data4 = data.reshape(n_row, n_chan, 2, 2)
    out4 = np.empty_like(data4)

    if bcal is None:
        # Flag-only path: no antenna bookkeeping needed, any row order is fine.
        for r in range(n_row):
            for c in range(n_chan):
                if flags[r, c, 0] or flags[r, c, 1] or flags[r, c, 2] or flags[r, c, 3]:
                    out4[r, c, 0, 0] = np.nan
                    out4[r, c, 0, 1] = np.nan
                    out4[r, c, 1, 0] = np.nan
                    out4[r, c, 1, 1] = np.nan
                else:
                    out4[r, c, 0, 0] = data4[r, c, 0, 0]
                    out4[r, c, 0, 1] = data4[r, c, 0, 1]
                    out4[r, c, 1, 0] = data4[r, c, 1, 0]
                    out4[r, c, 1, 1] = data4[r, c, 1, 1]
        return out4.reshape(n_row, n_chan, 4)

    n_ant = bcal.shape[0]
    # Guard every index used below: no bounds checking happens in njit code.
    if bcal.shape[1] != n_chan or bcal.shape[2] != 2:
        raise ValueError("bcal must have shape (nAnt, nChan, 2) matching data")
    if n_ant * (n_ant - 1) // 2 != n_row:
        raise ValueError("data rows must be the nAnt*(nAnt-1)/2 cross baselines of bcal")

    bcal_inv = (1.0 / bcal).astype(np.complex64)
    row = 0
    for i in range(n_ant):
        for j in range(i + 1, n_ant):
            for c in range(n_chan):
                if flags[row, c, 0] or flags[row, c, 1] or flags[row, c, 2] or flags[row, c, 3]:
                    out4[row, c, 0, 0] = np.nan
                    out4[row, c, 0, 1] = np.nan
                    out4[row, c, 1, 0] = np.nan
                    out4[row, c, 1, 1] = np.nan
                else:
                    # Diagonal gains: scalar products replace diag() @ V @ diag(),
                    # avoiding per-cell allocations, BLAS calls and 0*NaN spreading.
                    gi_x = bcal_inv[i, c, 0]
                    gi_y = bcal_inv[i, c, 1]
                    gj_x = np.conj(bcal_inv[j, c, 0])
                    gj_y = np.conj(bcal_inv[j, c, 1])
                    out4[row, c, 0, 0] = gi_x * data4[row, c, 0, 0] * gj_x
                    out4[row, c, 0, 1] = gi_x * data4[row, c, 0, 1] * gj_y
                    out4[row, c, 1, 0] = gi_y * data4[row, c, 1, 0] * gj_x
                    out4[row, c, 1, 1] = gi_y * data4[row, c, 1, 1] * gj_y
            row += 1
    return out4.reshape(n_row, n_chan, 4)
# --- Payload dataclass (same shape as v2 to make downstream code simple) --
@dataclass
class _SnapshotSpectrumV3:
    """Pointer to one spectrum blob that the map step stored in Redis.

    ``type`` is the spectrum label ("incoherent-sum" or a ROW_NUMS name),
    ``subband_no``/``scan_no`` locate it in the reduced cube, and ``key`` is the
    Redis key holding the raw bytes.
    """
    type: str
    subband_no: int
    scan_no: int
    key: str

    def to_json(self):
        """Serialise to a plain dict suitable for Celery's JSON transport."""
        return {
            field_name: getattr(self, field_name)
            for field_name in ("type", "subband_no", "scan_no", "key")
        }

    @classmethod
    def from_json(cls, d):
        """Rebuild an instance from a :meth:`to_json` dict (extra keys are ignored)."""
        return cls(
            type=d["type"],
            subband_no=d["subband_no"],
            scan_no=d["scan_no"],
            key=d["key"],
        )
# --- Tasks ----------------------------------------------------------------
@app.task(name="orca.transform.spectrum_v3.dynspec_map_v3")
[docs]
def dynspec_map_v3(subband_no: int,
scan_no: int,
ms: str,
bcal: str | None = None,
use_ms_flags: bool = True) -> List[_SnapshotSpectrumV3]:
"""
Map step on a SINGLE MS:
• Reads DATA & FLAG (cross-correlations only)
• Optionally applies bcal at native resolution (192 channels) — NO rebin
• Stores incoherent-sum and selected-baseline spectra in Redis
Returns JSONable _SnapshotSpectrumV3 list (for the reducer).
"""
with table(ms, ack=False) as t:
tcross = t.query("ANTENNA1!=ANTENNA2")
data = tcross.getcol("DATA") # complex64, (nRow, nChan, 4)
flags = tcross.getcol("FLAG") # bool, (nRow, nChan, 4)
ant1 = tcross.getcol("ANTENNA1")
ant2 = tcross.getcol("ANTENNA2")
n_ant = int(max(ant1.max(), ant2.max()) + 1)
# bcal handling (NO rebin)
if bcal is None:
bcal_dat = None
else:
with table(bcal, ack=False) as tb:
bcal_raw = tb.getcol("CPARAM") # (nAnt, nChan, 2)
bcal_raw_flag = tb.getcol("FLAG")
bcal_raw = bcal_raw.astype(np.complex64)
# respect flagged gains → NaN (they’ll propagate)
bcal_raw[bcal_raw_flag] = np.nan
bcal_dat = bcal_raw
flags_in = flags if use_ms_flags else np.zeros_like(flags, dtype=np.bool_)
calibrated = _applycal_cross_with_flags_native(
data.astype(np.complex64), flags_in, bcal_dat
)
amp = np.abs(calibrated).astype(_TRANSPORT_DTYPE) # keep NaNs from flags
# incoherent sum across baselines
incoh_sum = np.nanmean(amp, axis=0).astype(_TRANSPORT_DTYPE) # (nChan, 4)
r = redis.Redis.from_url(REDIS_URL)
out : list[_SnapshotSpectrumV3] = []
key_sum = f"{REDIS_KEY_PREFIX}{Path(ms).stem}-sum"
r.set(key_sum, incoh_sum.tobytes(), ex=REDIS_EXPIRE_S)
out.append(_SnapshotSpectrumV3("incoherent-sum", subband_no, scan_no, key_sum))
for name, row_idx in ROW_NUMS:
if row_idx >= amp.shape[0]:
logger.warning("%s missing row %d in %s", name, row_idx, ms)
continue
key = f"{REDIS_KEY_PREFIX}{Path(ms).stem}-{row_idx}"
r.set(key, amp[row_idx].tobytes(), ex=REDIS_EXPIRE_S)
out.append(_SnapshotSpectrumV3(name, subband_no, scan_no, key))
return [x.to_json() for x in out]
@app.task(name="orca.transform.spectrum_v3.dynspec_reduce_v3")
[docs]
def dynspec_reduce_v3(spectra: Iterable[List["_SnapshotSpectrumV3"]],
                      start_ts: datetime,
                      out_dir: str) -> None:
    """
    Reduce step:
      • Gathers Redis blobs from dynspec_map_v3 outputs (each key is deleted once read)
      • Builds 4×(time×freq) cubes per type, freq = 192×16 = 3072
      • Writes FITS into {out_dir}/{type}/{DATE}-{corr}.fits, corr in XX, XY, YX, YY

    Parameters
    ----------
    spectra : one list per map task.  Items may be _SnapshotSpectrumV3 objects or
        their ``to_json()`` dicts (as delivered by Celery JSON).  Any iterable works.
    start_ts : observation start.  An ISO-8601 string (optionally with a trailing
        'Z') is also accepted.  Used as the TIME-axis origin (hours since that
        day's midnight) and for {DATE}.
    out_dir : root output directory; subdirectories are created as needed.

    Cells without data stay 0; flagged data arrive as NaN.  Entries with an
    unknown type are ignored.  Missing/corrupted blobs and out-of-range
    subband/scan numbers are skipped with a warning.  If there are no spectra at
    all, a warning is logged and nothing is written.
    """
    start_ts = _as_datetime(start_ts)

    # Celery JSON delivers dicts; direct callers may pass objects.  Normalise per item.
    specs = [
        _SnapshotSpectrumV3.from_json(s) if isinstance(s, dict) else s
        for snap in spectra
        for s in snap
    ]
    if not specs:
        logger.warning("No spectra to reduce; no FITS written to %s", out_dir)
        return

    n_subbands = 16
    n_scans = max(spec.scan_no for spec in specs) + 1
    n_freqs = N_CHAN_OUT * n_subbands  # 192 × 16 = 3072
    types = ['incoherent-sum'] + [name for name, _ in ROW_NUMS]
    cubes = {t: np.zeros((4, n_scans, n_freqs), dtype=_TRANSPORT_DTYPE) for t in types}

    r = redis.Redis.from_url(REDIS_URL)
    for spec in specs:
        if spec.type not in cubes:
            continue
        buf = r.get(spec.key)
        r.delete(spec.key)
        if buf is None:
            logger.warning("Missing Redis key: %s", spec.key)
            continue
        arr = np.frombuffer(buf, dtype=_TRANSPORT_DTYPE)
        if arr.size != N_CHAN_OUT * 4:
            logger.warning("Corrupted Redis entry: %s (%d values)", spec.key, arr.size)
            continue
        # A bad subband would otherwise raise on a shape mismatch (too large) or
        # write to the wrong place (negative index wraps around).
        if not (0 <= spec.subband_no < n_subbands and spec.scan_no >= 0):
            logger.warning("Out-of-range entry %s: subband=%s scan=%s",
                           spec.key, spec.subband_no, spec.scan_no)
            continue
        k = spec.subband_no * N_CHAN_OUT
        # (chan, corr) -> (corr, freq) slice of the (corr, time, freq) cube
        cubes[spec.type][:, spec.scan_no, k:k + N_CHAN_OUT] = arr.reshape(N_CHAN_OUT, 4).T

    # --- Write FITS (XX, XY, YX, YY) ---
    date_str = start_ts.date().isoformat()
    # replace() keeps tzinfo, so tz-aware timestamps don't fail on subtraction.
    midnight = start_ts.replace(hour=0, minute=0, second=0, microsecond=0)
    start_hours = (start_ts - midnight).total_seconds() / 3600
    step_hours = STAGE_III_INTEGRATION_TIME.total_seconds() / 3600
    for name, dat in cubes.items():
        out_sub = f"{out_dir}/{name}"
        os.makedirs(out_sub, exist_ok=True)
        for i, corr in enumerate(['XX', 'XY', 'YX', 'YY']):
            # dat[i] is (time, freq); transpose so FITS axis 1 = time, axis 2 = freq.
            hdu = fits.PrimaryHDU(dat[i].T)
            hdr = hdu.header
            hdr['CTYPE2'] = 'FREQ'
            hdr['CUNIT2'] = 'MHz'
            hdr['CRVAL2'] = 13.398  # first chan MHz (unchanged baseline)
            hdr['CDELT2'] = _CHAN_WIDTH_MHZ  # 0.023926 MHz native
            hdr['CRPIX2'] = 1
            hdr['CTYPE1'] = 'TIME'
            hdr['CUNIT1'] = 'HOUR'
            hdr['CRVAL1'] = start_hours
            hdr['CDELT1'] = step_hours
            hdr['CRPIX1'] = 1
            hdr['DATE-OBS'] = date_str
            fits.HDUList([hdu]).writeto(
                f"{out_sub}/{date_str}-{corr}.fits",
                overwrite=True
            )


def _as_datetime(value):
    """Return *value* as a datetime, parsing ISO-8601 strings.

    Celery's JSON serialisation may deliver datetimes as strings, sometimes with
    a trailing 'Z', which ``datetime.fromisoformat`` rejects before Python 3.11.
    Non-string values are returned unchanged.
    """
    if isinstance(value, str):
        text = value[:-1] + "+00:00" if value.endswith("Z") else value
        return datetime.fromisoformat(text)
    return value