Source code for orca.flagging.flag_bad_chans

"""Channel flagging based on visibility amplitude statistics.

Identifies bad frequency channels by analyzing visibility amplitude
and SNR distributions. Uses median filtering to detect outliers.

Originally adapted from Marin Anderson's code (3/8/2019).

Functions
---------
flag_bad_chans
    Identify and optionally apply channel flags.
"""
#!/usr/bin/env python

"""
Copy from Marin Anderson 3/8/2019
"""
from __future__ import division
from typing import Optional
import numpy as np
import casacore.tables as pt
import os,argparse
import numpy.ma as ma
import logging
from scipy.ndimage import filters

from orca.configmanager import telescope as tele

[docs] logger = logging.getLogger(__name__)
[docs] def flag_bad_chans(msfile: str, band: str, usedatacol=False, generate_plot=False, apply_flag=False, crosshand=False, uvcut_m: Optional[float] = None): """Flag bad channels. Finds remaining bad channels and flags those in the measurement set. Also writes out text file that lists flags that were applied. Args: msfile: measurement set to flag. band: Subband number, must be convertable to integer. usedatacol: If True, uses DATA column, else use CORRECTED_DATA. generate_plot: generate a plot or not. apply_flag: Whether to apply the flags. crosshand: If true, it will use the XY and YX correlations when determining flags. Otherwise, it will ignore the flags that are in flaglist[:,1] and flaglist[:,2]. uvcut_m: uvcut in meters before doing thresholding to suppress short baseline flux """ with pt.table(msfile, readonly=False) as t: tcross = t.query('ANTENNA1!=ANTENNA2') if usedatacol: datacol = tcross.getcol('DATA') else: datacol = tcross.getcol('CORRECTED_DATA') flagcol = tcross.getcol('FLAG') if uvcut_m: uvw = tcross.getcol('UVW') uvdist = np.sqrt( uvw[:,0]**2. + uvw[:,1]**2. ) indsbyuvdist = np.where(uvdist > uvcut_m) datacol = datacol[indsbyuvdist] flagcol = flagcol[indsbyuvdist] datacolamp = np.abs(datacol) datacolamp_mask = ma.masked_array(datacolamp, mask=flagcol, fill_value=np.nan) maxamps = np.ma.max(datacolamp_mask, axis=0) meanamps = np.ma.mean(datacolamp_mask, axis=0) maxamps_medfilt = filters.median_filter(maxamps, size=(25,1)) #10,1)) maxamps_norm = maxamps / maxamps_medfilt maxamps_norm_stdfilt = filters.generic_filter(maxamps_norm, np.std, size=(25,1)) threshold_vec = np.array([10,6,6,10]) maxamps_lower = 1 - threshold_vec*np.ma.min(maxamps_norm_stdfilt, axis=0) maxamps_upper = 1 + threshold_vec*np.ma.min(maxamps_norm_stdfilt, axis=0) meanamps_stdfilt = filters.generic_filter(meanamps, np.std, size=(25,1)) flaglist = np.where( (maxamps_norm < maxamps_lower) | (maxamps_norm > maxamps_upper) | (meanamps > np.ma.median(meanamps, axis=0)+100*np.ma.min(meanamps_stdfilt, axis=0)) ) if not crosshand: flaglist = np.unique(flaglist[0][np.where( (flaglist[1] == 0) | (flaglist[1] == 3) )]) else: flaglist = np.unique(flaglist[0]) ################################################# #this is for testing purposes only #generate plot of visibilities for quick check of how well flagging performed if generate_plot: import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt plt.figure(figsize=(5,10)) chans = np.arange(0, tele.n_chan) for chan in chans: if chan not in flaglist: chanpts = np.zeros(len(datacolamp_mask[:,chan,0]))+chan plt.plot(datacolamp_mask[:,chan,0],chanpts, '.', color='Blue', markersize=0.5) plt.plot(datacolamp_mask[:,chan,3],chanpts, '.', color='Green', markersize=0.5) plt.ylim([0, tele.n_chan - 1]) plt.ylabel('channel') plt.xlabel('Amp') plt.gca().invert_yaxis() plotfile = os.path.splitext(os.path.abspath(msfile))[0]+'.png' plt.savefig(plotfile) ################################################ logger.info('Flaglist size is %i' % flaglist.size) if flaglist.size > 0: # turn flaglist into text file of channel flags textfile = os.path.splitext(os.path.abspath(msfile))[0]+'.chans' chans = np.arange(0,tele.n_chan) chanlist = chans[flaglist] with open(textfile, 'w') as f: for chan in chanlist: f.write('%02d:%03d\n' % (int(band),chan)) # write flags into FLAG column if apply_flag: flagcol_altered = t.getcol('FLAG') flagcol_altered[:,flaglist,:] = 1 t.putcol('FLAG', flagcol_altered) return msfile
[docs] def main(): parser = argparse.ArgumentParser(description="Flag bad channels and write out list of channels that were \ flagged into text file of same name as ms. MUST BE RUN ON \ SINGLE SUBBAND MS.") parser.add_argument("msfile", help="Measurement set.") parser.add_argument("band", help="Subband number.") parser.add_argument("--usedatacol", action="store_true", default=False, help="Grab DATA column, not CORRECTED_DATA.") parser.add_argument('--plot', action='store_true', default=False, help='Generate plot of amp vs channel.') parser.add_argument('--apply-flag', action='store_true', default=False, help='Apply flags to measurement set.') parser.add_argument('--crosshand', action='store_true', default=False, help='Use the cross-hand visibilities also.') parser.add_argument('--uvcut_m', action='store', type=float, default=None, help='Only use visibilities greater than {uvcut_m} in meters when determining channel flags. Default is None.') args = parser.parse_args() flag_bad_chans(args.msfile, args.band, usedatacol=args.usedatacol, generate_plot=args.plot, apply_flag=args.apply_flag, crosshand=args.crosshand, uvcut_m=args.uvcut_m)
if __name__ == '__main__': main()