Source code for orca.flagging.flagoperations

"""Low-level flag manipulation operations for measurement sets.

Provides functions for directly manipulating FLAG columns in CASA
measurement sets, including flagging antennas, baselines, merging
flags between observations, and I/O with numpy arrays.

Originally adapted from Marin Anderson's code (3/8/2019).

Functions
---------
flag_ants
    Flag all visibilities involving specified antennas.
merge_flags
    OR-merge flags between two measurement sets.
merge_group_flags
    OR-merge flags across a group of measurement sets.
write_to_flag_column
    Apply flags from a numpy array to a measurement set.
save_to_flag_npy
    Save the FLAG column to a numpy file.
flag_bls
    Apply baseline flags from a text file.
"""
"""
Copy from Marin Anderson 3/8/2019
"""

import casacore.tables as pt
import numpy as np

from typing import Tuple, List

[docs] DATA = 'DATA'
[docs] CORRECTED_DATA = 'CORRECTED_DATA'
[docs] def flag_ants(ms: str, ants: List[int]) -> str: """Flag all visibilities involving specified antennas. Args: ms: Path to the measurement set. ants: List of antenna indices (0-indexed) to flag. Returns: Path to the modified measurement set. """ if len(ants) > 0 : pt.taql('UPDATE %s SET FLAG = True WHERE ANTENNA1 IN %s OR ANTENNA2 IN %s' % (ms, tuple(ants), tuple(ants))) return ms
[docs] def merge_flags(ms1: str, ms2: str) -> Tuple[str, str]: """OR-merge flags between two measurement sets. Updates both measurement sets so their FLAG columns contain the logical OR of the original flags from both files. Args: ms1: Path to the first measurement set. ms2: Path to the second measurement set. Returns: Tuple of (ms1, ms2) paths. """ with pt.table(ms1, readonly=False, ack=False) as t_prev, pt.table(ms2, readonly=False, ack=False) as t: flagcol1 = t_prev.getcol('FLAG') flagcol2 = t.getcol('FLAG') flagcol = flagcol1 | flagcol2 t.putcol('FLAG', flagcol) t_prev.putcol('FLAG', flagcol) return ms1, ms2
[docs] def merge_group_flags(ms_list: List[str]) -> List[str]: """OR-merge flags across a group of measurement sets. Updates all measurement sets so they share the same merged FLAG column, containing the logical OR of flags from all input files. Args: ms_list: List of measurement set paths. Returns: The input list of paths (all now updated). """ with pt.table(ms_list[0], readonly=True, ack=False) as table: merged_flags = table.getcol('FLAG') for ms in ms_list[1:]: with pt.table(ms, readonly=True, ack=False) as tt: merged_flags = merged_flags | tt.getcol('FLAG') for ms in ms_list: with pt.table(ms, readonly=False, ack=False) as tt: tt.putcol('FLAG', merged_flags) return ms_list
[docs] def write_to_flag_column(ms: str, flag_npy: str) -> str: """Apply flags from a numpy array to a measurement set. The numpy array is OR-merged with existing flags. Args: ms: Path to the measurement set. flag_npy: Path to the numpy file containing boolean flag array. Returns: Path to the modified measurement set. Raises: AssertionError: If the flag array shape doesn't match the MS. """ with pt.table(ms, readonly=False, ack=False) as t: flagcol = np.load(flag_npy) assert flagcol.shape == t.getcol('FLAG').shape, 'Flag file and measurement set have different shapes' t.putcol('FLAG', flagcol | t.getcol('FLAG')) return ms
[docs] def save_to_flag_npy(ms: str, flag_npy: str) -> str: """Save the FLAG column from a measurement set to a numpy file. Args: ms: Path to the measurement set. flag_npy: Output path for the numpy file. Returns: Path to the saved numpy file. """ with pt.table(ms, ack=False) as t: flagcol = t.getcol('FLAG') np.save(flag_npy, flagcol) return flag_npy
[docs] def flag_bls(msfile: str, blfile: str) -> str: """Apply baseline flags from a text file to a measurement set. Args: msfile: Path to the measurement set. blfile: Path to the baseline flag file. Format is one baseline per line as 'ant1&ant2' (0-indexed). Returns: Path to the modified measurement set. Raises: ValueError: If the number of visibilities is unexpected. """ with pt.table(msfile, readonly=False, ack=False) as t: flagcol = t.getcol('FLAG') # flagcol.shape = (N*(N-1)/2 + N)*Nspw*Nints,Nchans,Ncorrs Nants = t.getcol('ANTENNA1')[-1] + 1 Nbls = int((Nants*(Nants-1)/2.) + Nants) if not (flagcol.shape[0] >= Nbls): raise ValueError(f'Unexpected number of visibilities in flagcol {flagcol.shape}') Nspw = int(flagcol.shape[0]/Nbls) Nchans = flagcol.shape[1] Ncorrs = flagcol.shape[2] # make the correlation matrix flagmat = np.zeros((Nants,Nants,Nspw,Nchans,Ncorrs),dtype=bool) tiuinds = np.triu_indices(Nants) # put the FLAG column into the correlation matrix flagmat[tiuinds] = flagcol.reshape(Nspw,Nbls,Nchans,Ncorrs).transpose(1,0,2,3) # read in baseline flags ant1,ant2 = np.genfromtxt(blfile,delimiter='&',unpack=True,dtype=int) # flag the correlation matrix flagmat[(ant1,ant2)] = 1 # reshape correlation matrix into FLAG column newflagcol = flagmat[tiuinds].transpose(1,0,2,3).reshape(Nbls*Nspw,Nchans,Ncorrs) # t.putcol('FLAG',newflagcol) return msfile