"""Flagging transforms for measurement sets.
Provides high-level flagging operations that combine detection and application,
including AOFlagger integration and autocorrelation-based antenna flagging.

Functions
---------
flag_with_aoflagger
    Run AOFlagger RFI detection on a measurement set.
flag_ants
    Flag specified antennas in a measurement set.
flag_ant_chan_from_autocorr
    Flag antennas and channels based on autocorrelation anomalies.
flag_on_autocorr
    Identify and flag bad antennas from autocorrelation statistics.
identify_bad_ants
    Identify bad antennas without applying flags.
save_flag_metadata
    Save flag statistics in a compact binary format.
"""
from typing import List, Optional
import subprocess
import logging
from matplotlib.pyplot import flag
import numpy as np
import datetime
from casacore import tables
from orca.configmanager import execs, telescope as tele
from orca.utils.maths import core_outrigger_slices
from orca.utils import flagutils
import os
from casatools import table
# Module-level logger, named after this module so its records can be filtered per module.
log = logging.getLogger(__name__)
def flag_with_aoflagger(ms: str, strategy: str='/opt/share/aoflagger/strategies/nenufar-lite.lua',
                        in_memory : bool =False, n_threads : int = 5) -> str:
    """Invoke the external AOFlagger executable on a measurement set.

    The command line is assembled as
    ``<aoflagger> -strategy <strategy> -j <n_threads> [-direct-read] <ms>``,
    where ``-direct-read`` is added whenever ``in_memory`` is False. Both
    output streams of the child process are captured; they are only logged
    (at error level) when the process fails.

    Args:
        ms: Path to the measurement set.
        strategy: Path to the AOFlagger Lua strategy file.
        in_memory: If False, pass ``-direct-read`` to AOFlagger.
        n_threads: Value passed to AOFlagger's ``-j`` option.

    Returns:
        The ``ms`` path that was passed in.

    Raises:
        RuntimeError: If AOFlagger exits with a non-zero return code.
    """
    # TODO use the API
    read_mode = [] if in_memory else ['-direct-read']
    command = [execs.aoflagger, '-strategy', strategy, '-j', str(n_threads), *read_mode, ms]

    aoflagger_proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        captured_out, captured_err = aoflagger_proc.communicate()
        if aoflagger_proc.returncode != 0:
            if captured_err:
                log.error(f'Error in aoflagger: {captured_err.decode()}')
            if captured_out:
                log.error(captured_out.decode())
            raise RuntimeError(f'Error in aoflagger for {ms}.')
    finally:
        # Signal the child on every exit path, e.g. if communicate() is interrupted.
        aoflagger_proc.terminate()
    return ms
def flag_ants(ms: str, ants: List[int]) -> str:
    """Set FLAG to True on every row that involves one of the given antennas.

    Issues a single TaQL ``UPDATE`` against ``ms`` that matches rows whose
    ``ANTENNA1`` or ``ANTENNA2`` is in ``ants``. When ``ants`` is empty no
    query is executed at all.

    Args:
        ms: Path to the measurement set.
        ants: Antenna indices to flag (any sized iterable of values whose
            ``str()`` is the antenna number).

    Returns:
        The ``ms`` path that was passed in.
    """
    # Use len() rather than truthiness: flag_on_autocorr in this module passes
    # a NumPy array, whose bool() is ambiguous for more than one element.
    if len(ants) == 0:
        return ms

    antenna_set = '(' + ', '.join(map(str, ants)) + ')'
    query = 'UPDATE %s SET FLAG = True WHERE ANTENNA1 IN %s OR ANTENNA2 IN %s' % (ms, antenna_set, antenna_set)
    tables.taql(query)
    return ms
def flag_ant_chan_from_autocorr(ms: str, threshold: float=5.0) -> str:
    """Placeholder for autocorrelation-based antenna/channel flagging.

    NOTE(review): as written, this function performs no detection and writes
    no flags; it returns ``ms`` untouched and ``threshold`` is never read.
    Confirm whether the implementation is missing or intentionally disabled.

    Args:
        ms: Path to the measurement set.
        threshold: Intended sigma threshold for outlier detection (currently
            unused).

    Returns:
        The ``ms`` path that was passed in.

    Note:
        The intended algorithm is documented as single-spectral-window only.
    """
    return ms
def flag_on_autocorr(ms, date: Optional[datetime.date] = None, thresh: float=7.0, column='DATA') -> str:
    """Flag antennas whose autocorrelations look anomalous.

    Two passes are made over the measurement set:

    1. If ``date`` is given, the antennas returned by
       ``flagutils.get_bad_ants(date)`` are flagged first.
    2. The table is opened, ``identify_bad_ants`` is run on it, and the
       antennas it reports are flagged once the table has been closed.

    Args:
        ms: Path to the measurement set.
        date: Observation date used to look up a priori bad antennas; the
            lookup is skipped when this is None.
        thresh: Score threshold forwarded to ``identify_bad_ants``.
        column: Data column forwarded to ``identify_bad_ants``.

    Returns:
        The ``ms`` path, as returned by ``flag_ants``.
    """
    if date:
        flag_ants(ms, flagutils.get_bad_ants(date))

    # The table is opened with readonly=False even though it is only handed
    # to identify_bad_ants here.
    with tables.table(ms, readonly=False, ack=False) as ms_table:
        detected = identify_bad_ants(ms_table, thresh, column)

    return flag_ants(ms, detected)
def identify_bad_ants(t: tables.table, thresh: float=7, column='DATA') -> List[int]:
    """Score each antenna's autocorrelation bandpass and report the outliers.

    Procedure:

    1. Select the autocorrelation rows (``ANTENNA1=ANTENNA2``) and keep
       entries 0 and 3 of the last axis of the data and flag columns
       (presumably the two parallel-hand polarizations; confirm).
    2. Split the antennas into two groups with ``core_outrigger_slices`` and
       build a per-group median template over antennas.
    3. Compute each antenna's fractional deviation from its group template,
       zero the flagged samples, and average along axis 1.
    4. Divide by the group median of those averages to obtain a score; an
       antenna is reported if either of its two scores exceeds ``thresh``.

    Assumes the selected arrays are (row, channel, 2); the ``axis=1``
    reductions below rely on that layout.

    Args:
        t: Open casacore table object for the measurement set.
        thresh: Score threshold above which an antenna is reported.
        column: Data column to analyze ('DATA' or 'CORRECTED_DATA').

    Returns:
        The ``ANTENNA1`` values of the antennas identified as bad. Note that
        this is the NumPy array produced by boolean indexing, not a list.

    Raises:
        AssertionError: If the number of autocorrelation rows differs from
            ``tele.n_ant`` (multiple timestamps or spectral windows).
    """
    auto_rows = t.query('ANTENNA1=ANTENNA2')
    pol_pair = (0, 3)
    amplitudes: np.ndarray = np.abs(auto_rows.getcol(column))[..., pol_pair]
    flagged: np.ndarray = auto_rows.getcol('FLAG')[..., pol_pair]
    ant_ids: np.ndarray = auto_rows.getcol('ANTENNA1')

    expected_rows = tele.n_ant
    assert amplitudes.shape[0] == expected_rows, 'Doing multiple timestamp or multiple spws is too complicated for now.'

    core_sel, outrigger_sel = core_outrigger_slices(ant_ids, tele.outriggers)

    # Fractional deviation of every antenna from the median bandpass of its own group.
    rel_dev = np.zeros_like(amplitudes)
    for group in (core_sel, outrigger_sel):
        template = np.median(amplitudes[group], axis=0)
        rel_dev[group] = np.abs(amplitudes[group] - template) / template

    # Flagged samples contribute zero to the mean. Kept as a non-in-place
    # multiply so NumPy's dtype promotion of the product is unchanged.
    unflagged_weight = 1 - flagged
    rel_dev = rel_dev * unflagged_weight
    per_ant_dev = rel_dev.mean(axis=1)

    # Normalize by the group median of the per-antenna deviations. Every
    # antenna is first scaled by the core value, then the second group is
    # overwritten with its own scaling.
    core_scale = np.median(per_ant_dev[core_sel], axis=0)
    outrigger_scale = np.median(per_ant_dev[outrigger_sel], axis=0)
    scores = per_ant_dev / core_scale
    scores[outrigger_sel] = per_ant_dev[outrigger_sel] / outrigger_scale

    exceeds = (scores > thresh).any(axis=1)
    return ant_ids[exceeds]
def _mask_bad_rows(autos_corrected, autos_flags, n_chans):
    """Build a boolean mask of rows with unusually many flagged channels.

    Internal helper for autocorrelation analysis. For each of the two entries
    on the last axis of ``autos_flags``, the flagged count per row is compared
    against the median count over all rows. A row is marked bad when its count
    exceeds ``FLAG_COUNT_FACTOR`` times that median, but only if the median
    itself is below ``n_chans // FLAG_COUNT_FACTOR``.

    NOTE(review): ``FLAG_COUNT_FACTOR`` is a module-level name that is not
    defined in the visible part of this file; confirm it exists.

    Args:
        autos_corrected: Autocorrelation data; only ``shape[0]`` (the number
            of rows) is used.
        autos_flags: Flag array indexed as ``[..., pol]`` and summed along
            axis 1 (assumed to be (row, channel, pol); confirm).
        n_chans: Number of channels.

    Returns:
        Boolean array of length ``autos_corrected.shape[0]``; True marks a
        bad row.
    """
    n_rows = autos_corrected.shape[0]
    bad_row_mask = np.zeros(n_rows, dtype=bool)

    for pol in (0, 1):
        n_flagged = autos_flags[..., pol].sum(axis=1)
        typical = np.median(n_flagged)
        # Skip this polarization when the typical row is already heavily flagged.
        if typical < n_chans // FLAG_COUNT_FACTOR:
            too_many = n_flagged > FLAG_COUNT_FACTOR * typical
            bad_row_mask[too_many] = True

    return bad_row_mask
def _calculate_avg_bandpass(autos_corrected, bad_row_mask, n_ants, n_ints, n_chans) -> np.ndarray:
    """Prepare autocorrelation data for bandpass averaging.

    Internal helper for multi-integration autocorrelation analysis.

    NOTE(review): as written, no averaging takes place. The function only
    NaN-fills the masked rows of ``autos_corrected`` and returns an all-zero
    array; ``n_ints`` is never read. Confirm whether the averaging step is
    missing.

    Args:
        autos_corrected: Autocorrelation data. Modified in place: rows
            selected by ``bad_row_mask`` are overwritten with NaN.
        bad_row_mask: Index (e.g. boolean mask) selecting the rows to blank.
        n_ants: Size of the first axis of the returned array.
        n_ints: Number of integrations (currently unused).
        n_chans: Size of the second axis of the returned array.

    Returns:
        Zero-filled array of shape ``(n_ants, n_chans, 2)``.
    """
    out_shape = (n_ants, n_chans, 2)
    avg_bandpass = np.zeros(out_shape)
    # Side effect on the caller's array.
    autos_corrected[bad_row_mask] = np.nan
    return avg_bandpass