import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
from matplotlib.patches import Ellipse, Rectangle
from typing import Any, TypeVar, Callable
from plotastrodata.analysis_utils import AstroData, AstroFrame
from plotastrodata.coord_utils import (coord2xy, xy2coord,
get_hmdm, get_min, get_sec)
from plotastrodata.fitting_utils import gaussian1d
from plotastrodata.noise_utils import estimate_rms
from plotastrodata.other_utils import (close_figure, listing,
reform_grid, reform_data)
plt.ioff()  # force to turn off interactive mode
# Type variable for kwargs2instance(): the returned instance has the type of
# the class passed as its first argument.
T = TypeVar('T')
[docs]
def set_rcparams(fontsize: int = 18, nancolor: str = 'w',
                 dpi: int = 256) -> None:
    """Apply a consistent set of matplotlib rcParams for figures.

    Args:
        fontsize (int, optional): Value for plt.rcParams['font.size']. Defaults to 18.
        nancolor (str, optional): Value for plt.rcParams['axes.facecolor'],
            i.e. the color seen where nothing is drawn (masked/NaN pixels).
            Defaults to 'w'.
        dpi (int, optional): Value for plt.rcParams['savefig.dpi']. Defaults to 256.
    """
    settings = {
        'axes.facecolor': nancolor,
        'font.size': fontsize,
        'savefig.dpi': dpi,
        'legend.fontsize': 15,
        'axes.linewidth': 1.5,
        'xtick.top': True,
        'ytick.right': True,
    }
    # Both axes share the same tick styling.
    for axis in ('x', 'y'):
        settings[f'{axis}tick.direction'] = 'inout'
        settings[f'{axis}tick.major.size'] = 10
        settings[f'{axis}tick.minor.size'] = 6
        settings[f'{axis}tick.major.width'] = 1.5
        settings[f'{axis}tick.minor.width'] = 1.5
    plt.rcParams.update(settings)
[docs]
def logticks(ticks: list[float], lim: list[float, float]
             ) -> tuple[list[float], list[str]]:
    """Build tick positions and labels suited to a logarithmic axis.

    A "round" value just above lim[0] and a "round" value at or below
    lim[1] (each snapped to a multiple of its own decade) are merged with
    ``ticks``.

    Args:
        ticks (list): Candidate tick positions.
        lim (list): Axis range as [min, max]; both must be positive.

    Returns:
        tuple: (sorted unique ticks, labels). Ticks below 1 are labeled with
        str() of the float; the others are labeled as integers.
    """
    def _snap(value, step):
        # Multiple of value's decade: floor(value / unit) + step units.
        order = int(np.floor(np.log10(value)))
        unit = 10**order
        snapped = (value // unit + step) * unit
        return np.round(snapped, max(-order, 0))

    lower = _snap(lim[0], 1)
    upper = _snap(lim[1], 0)
    newticks = np.sort(np.unique(np.r_[lower, ticks, upper]))
    newlabels = [str(t) if t < 1 else str(int(t)) for t in newticks]
    return newticks, newlabels
[docs]
def logcbticks(vmin: float = 1e-3, vmax: float = 1e3
               ) -> tuple[np.ndarray, np.ndarray]:
    """Make nice ticks for a log color bar.

    Ticks are 1..9 times every power of ten covering [vmin, vmax]; only
    the multiples 1, 2 and 5 get a text label, the rest get ''.

    Args:
        vmin (float, optional): Minimum value. Defaults to 1e-3.
        vmax (float, optional): Maximum value. Defaults to 1e3.

    Returns:
        tuple: (ticks, ticklabels), restricted to vmin <= tick <= vmax.
    """
    first = int(np.floor(np.log10(vmin)))
    last = int(np.ceil(np.log10(vmax)))
    decades = np.logspace(first, last, last - first + 1)
    ticks = np.ravel(np.outer(decades, np.arange(1, 10)))
    labels = []
    for exponent in range(first, last + 1):
        # Decimal places needed for this decade (none at or above 1).
        digits = -exponent if exponent < 0 else 0
        for mantissa in range(1, 10):
            text = ''
            if mantissa in (1, 2, 5):
                text = f'{mantissa * 10**exponent:.{digits}f}'
            labels.append(text)
    labels = np.ravel(labels)
    inside = (vmin <= ticks) & (ticks <= vmax)
    return ticks[inside], labels[inside]
[docs]
def get_figsize(xmin: float, xmax: float, ymin: float, ymax: float,
figsize: tuple | None = None,
ncols: int = 1, nrows: int = 1, nchan: int = 1
) -> tuple[float, float]:
"""Get a nice figsize (tuple) with the given x and y ranges.
Args:
xmin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
xmax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
ymin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
ymax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
figsize (tuple | None, optional): If this is not None, this will be the output as is. Defaults to None.
ncols (int, optional): The number of columns for the channel map. Defaults to 1.
nrows (int, optional): The number of rows for the channel map. Defaults to 1.
nchan (int, optional): The number of total channels for the channel map. Defaults to 1.
Returns:
tuple[float, float]: figsize for matplotlib.pyplot.Figure.
"""
if figsize is not None:
return figsize
sqrt_a = (ymax - ymin) / (xmax - xmin)
sqrt_a = np.sqrt(np.abs(sqrt_a))
if nchan == 1:
figsize = (7 / sqrt_a, 5 * sqrt_a)
else:
figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
return figsize
def _get_gridwidth(mode: str, rmax: float, cos_dec: float
) -> tuple[float, int]:
# Length in the units of s for R.A. and " for Dec., respectively.
length = 2 * rmax / (15 * cos_dec if mode == 'ra' else 1)
p = np.floor(np.log10(length))
b = length / 10**p
if b <= 2:
base, order = 5, p - 1
elif b <= 4:
base, order = 1, p
else:
base, order = 2, p
return base * 10**order, int(order)
def _get_v(p: Any, v: np.ndarray | None = None,
restfreq: float | None = None,
vskip: int = 1) -> np.ndarray:
if p.fitsimage is not None and v is None:
p.read(d := AstroData(fitsimage=p.fitsimage,
restfreq=restfreq, sigma=None))
v = d.v
if v is None:
v = np.array([0])
if len(v) > 1:
v = reform_grid(v=v, vmin=p.vmin, vmax=p.vmax)
v = v[::vskip]
return v
def _get_nij2ch(nrows: int = 1, ncols: int = 1) -> Callable:
def nij2ch(n: int, i: int, j: int) -> int:
return n*nrows*ncols + i*ncols + j
return nij2ch
def _get_ch2nij(nrows: int = 1, ncols: int = 1) -> Callable:
def ch2nij(ch: int) -> tuple[int, int, int]:
n = ch // (nrows*ncols)
i = (ch - n*nrows*ncols) // ncols
j = ch % ncols
return n, i, j
return ch2nij
def _get_vskipfill(nv: int, v_org: np.ndarray, vskip: int,
channelnumber: int | None) -> Callable:
def vskipfill(c: np.ndarray, v_in: np.ndarray) -> np.ndarray:
c = reform_data(c=c, v_in=v_in, nv=nv, v_org=v_org, vskip=vskip)
if isinstance(channelnumber, int):
c = [c[channelnumber]]
return c
return vskipfill
[docs]
@dataclass
class Stretcher():
"""Arguments and methods related to the stretch in PlotAstroData.add_color() and add_rgb().
Args:
stretch (str, optional): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
stretchscale (float, optional): The output is asinh(data / stretchscale). Defaults to None.
stretchpower (float, optional): The output is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 0.5.
vmin (float, optional): The minimum value for Axes.pcolormesh() of matplotlib. Defaults to None.
vmax (float, optional): The maximum value for Axes.pcolormesh() of matplotlib. Defaults to None.
sigma (float, optional): Noise level. Defaults to 0.
"""
stretch: str = 'linear'
stretchscale: float | None = None
stretchpower: float = 0.5
vmin: float | None = None
vmax: float | None = None
sigma: float = 0
def __post_init__(self) -> None:
self.n = 1 if isinstance(self.stretch, str) else len(self.stretch)
stretch = self.stretch
stsc = self.stretchscale
vmin = self.vmin
sigma = self.sigma
if self.n == 1:
if stsc is None:
self.stretchscale = sigma
if (stretch == 'log' or stretch == 'power') and vmin is None:
self.vmin = sigma
else:
getsigma = np.equal(stsc, None)
self.stretchscale = np.where(getsigma, sigma, stsc)
islog = np.equal(stretch, 'log')
ispower = np.equal(stretch, 'power')
novmin = np.equal(vmin, None)
getsigma = (islog + ispower) * novmin
self.vmin = np.where(getsigma, sigma, vmin)
[docs]
def do(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
"""Get the stretched values.
Args:
x (list | np.ndarray): Input array in the linear scale.
i (int): Which element is used in the case where the stretch parameters are lists.
Returns:
np.ndarray: Output stretched array.
"""
st = self.stretch[i] if self.n > 1 else self.stretch
stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
t = np.array(x)
match st:
case 'log':
t = np.log10(t) # To be consistent with logcbticks().
case 'asinh':
t = np.arcsinh(t / stsc)
case 'power':
p = 1e-6 if stpw == 0 else stpw
t = t**p / p
return t
[docs]
def undo(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
"""Get the linear values from the stretched values.
Args:
x (list | np.ndarray): Input stretched array.
i (int): Which element is used in the case where the stretch parameters are lists.
Returns:
np.ndarray: Output array in the linear scale.
"""
st = self.stretch[i] if self.n > 1 else self.stretch
stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
t = np.array(x)
match st:
case 'log':
t = 10**t # To be consistent with logcbticks().
case 'asinh':
t = np.sinh(t) * stsc
case 'power':
p = 1e-6 if stpw == 0 else stpw
t = (t * p)**(1 / p)
return t
[docs]
def set_minmax(self, data: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Set vmin and vmax for color pcolormesh and RGB maps.
Args:
data (np.ndarray): 2D/3D data to plot.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: (Clipped stretched data, new vmin, new vmax).
"""
single = self.n == 1
vminout = [self.vmin] if single else self.vmin
vmaxout = [self.vmax] if single else self.vmax
dataout = [data] if single else data
for i, (c, v0, v1) in enumerate(zip(dataout, vminout, vmaxout)):
dataout[i] = cout = self.do(c.clip(v0, v1), i)
vminout[i] = np.nanmin(cout)
vmaxout[i] = np.nanmax(cout)
if single:
dataout = dataout[0]
vminout = vminout[0]
vmaxout = vmaxout[0]
self.vmin = vminout
self.vmax = vmaxout
return dataout, vminout, vmaxout
[docs]
class Beam():
"""Arguments for PlotAstroData.add_beam().
Args:
show_beam (bool, optional): Defaults to True.
beam (list, optional): [bmaj, bmin, bpa]. This may be a list of list. Defaults to [None, None, None].
beamcolor (str, optional): matplotlib color. This may be a list of str. Defaults to 'gray'.
beampos (list, optional): Relative position. This may be a list of list or a list of None. Defaults to None.
beam_kwargs (dict, optional): Additional arguments for matplotlib.patches. Defaults to {}.
"""
def __init__(self,
show_beam: bool = True,
beam: list[float | None] = [None] * 3,
beamcolor: str = 'gray',
beampos: list[float] | None = None,
beam_kwargs: dict = {}) -> None:
self.show_beam = show_beam
self.beam = beam
self.beamcolor = beamcolor
self.beampos = beampos
self.beam_kwargs = beam_kwargs
[docs]
def todict(self) -> dict[str, Any]:
tmp = {'show_beam': self.show_beam,
'beam': self.beam,
'beamcolor': self.beamcolor,
'beampos': self.beampos}
tmp.update(self.beam_kwargs)
return tmp
[docs]
@dataclass
class PlotAxes2D():
"""Use Axes.set_* to adjust x and y axes.
Args:
samexy (bool, optional): True supports same ticks between x and y. Defaults to True.
loglog (float, optional): If a float is given, plot on a log-log plane, and xim=(xmax / loglog, xmax) and so does ylim. Defaults to None.
xscale (str, optional): Defaults to None.
yscale (str, optional): Defaults to None.
xlim (list, optional): Defaults to None.
ylim (list, optional): Defaults to None.
xlabel (str, optional): Defaults to None.
ylabel (str, optional): Defaults to None.
xticks (list, optional): Defaults to None.
yticks (list, optional): Defaults to None.
xticklabels (list, optional): Defaults to None.
yticklabels (list, optional): Defaults to None.
xticksminor (list or int, optional): If int, int times more than xticks. Defaults to None.
yticksminor (list or int, optional): Defaults to None. If int, int times more than xticks. Defaults to None.
grid (dict, optional): True means merely grid(). Defaults to None.
aspect (dict or float, optional): Defaults to None.
"""
samexy: bool = True
loglog: float | None = None
xscale: str = 'linear'
yscale: str = 'linear'
xlim: list | None = None
ylim: list | None = None
xlabel: str | None = None
ylabel: str | None = None
xticks: list | None = None
yticks: list | None = None
xticklabels: list | None = None
yticklabels: list | None = None
xticksminor: list | int | None = None
yticksminor: list | int | None = None
grid: dict | None = None
aspect: dict | float | None = None
def _set_scale(self) -> None:
ax = self.ax
if self.loglog is not None:
self.xscale = self.yscale = 'log'
self.samexy = True
for axis in ['x', 'y']:
attr = f'{axis}lim'
lim = getattr(self, attr)
if lim is not None:
lim[0] = lim[1] / self.loglog
setattr(self, attr, lim)
ax.set_xscale(self.xscale)
ax.set_yscale(self.yscale)
def _init_ticks(self, axis: str) -> None:
ax = self.ax
ticks_attr = f'{axis}ticks'
ticklabels_attr = f'{axis}ticklabels'
scale = getattr(self, f'{axis}scale')
lim = getattr(self, f'{axis}lim')
ticks = getattr(self, ticks_attr)
if ticks is None:
ticks = getattr(ax, f'get_{axis}ticks')()
if scale == 'log':
ticks, ticklabels = logticks(ticks, lim)
setattr(self, ticklabels_attr, ticklabels)
setattr(self, ticks_attr, ticks)
def _make_ticks(self, ticks: np.ndarray, ticksminor: int) -> np.ndarray:
dt = ticks[1] - ticks[0]
t = np.r_[ticks[0] - dt, ticks, ticks[-1] + dt]
num = ticksminor * (len(t) - 1) + 1
return np.linspace(t[0], t[-1], num)
def _set_ticks(self, axis: str) -> None:
ax = self.ax
attr = f'{axis}ticks'
ticks = getattr(self, attr)
getattr(ax, f'set_{attr}')(ticks)
ticksminor = getattr(self, f'{attr}minor')
if ticksminor is not None:
if isinstance(ticksminor, int):
ticksminor = self._make_ticks(ticks, ticksminor)
getattr(ax, f'set_{attr}')(ticksminor, minor=True)
def _apply_if_not_none(self, axis: str, attr: str) -> None:
ax = self.ax
method = getattr(ax, f'set_{axis}{attr}')
value = getattr(self, f'{axis}{attr}')
if value is not None:
if attr == 'lim':
method(*value)
else:
method(value)
[docs]
def set_xyaxes(self, ax: Any) -> None:
self.ax = ax
self._set_scale()
if self.samexy:
ax.set_xticks(ax.get_yticks())
ax.set_yticks(ax.get_xticks())
ax.set_aspect(1)
for axis in ['x', 'y']:
self._init_ticks(axis)
self._set_ticks(axis)
for attr in ['ticklabels', 'label', 'lim']:
self._apply_if_not_none(axis, attr)
if self.grid is not None:
ax.grid(**({} if self.grid is True else self.grid))
if self.aspect is not None:
if isinstance(self.aspect, dict):
ax.set_aspect(**self.aspect)
else:
ax.set_aspect(self.aspect)
[docs]
def kwargs2instance(cls: type[T], kw: dict[str, Any]) -> T:
    """Get an instance of ``cls`` and remove its arguments from ``kw``.

    Every attribute name of a default-constructed ``cls`` that is present in
    ``kw`` is passed to the constructor. Those keys are then removed from
    ``kw`` in place, except a few per class that later consumers of the same
    kwargs still need (e.g. vmin/vmax, which add_color also passes to
    pcolormesh).

    Args:
        cls (class): Class to make the instance.
        kw (dict): Parameters to make the instance; modified in place.

    Returns:
        instance: an instance of cls made from the parameters in kw.
    """
    # AstroData is probed with placeholder data (presumably it cannot be
    # constructed without data or a FITS file).
    probe_args = {'data': np.zeros((2, 2))} if cls == AstroData else {}
    if cls == AstroFrame:
        keep = {'fitsimage', 'center'}
    elif cls == Stretcher:
        keep = {'vmin', 'vmax'}
    elif cls == Beam:
        keep = {'beam'}
    else:
        keep = set()
    names = vars(cls(**probe_args)).keys()
    selected = {name: kw[name] for name in names if name in kw}
    for name in names - keep:
        kw.pop(name, None)
    return cls(**selected)
[docs]
class PlotAstroData(AstroFrame):
"""Make a figure from 2D/3D FITS files or 2D/3D arrays.
Basic rules ---
For 3D data, a 1D velocity array or a FITS file with a velocity axis must be given to set up channels in each page.
For 2D/3D data, the spatial center can be read from a FITS file or manually given.
len(v)=1 (default) means to make a 2D figure.
Spatial lengths are in the unit of arcsec, or au if dist (!= 1) is given.
Angles are in the unit of degree.
For region, line, arrow, label, and marker, a single input can be treated without a list, e.g., anglelist=60, as well as anglelist=[60].
Each element of poslist supposes a text coordinate like '01h23m45.6s 01d23m45.6s' or a list of relative x and y like [0.2, 0.3] (0 is left or bottom, 1 is right or top).
Parameters for original methods in matplotlib.axes.Axes can be used as kwargs; see the default _kw for reference.
Position-velocity diagrams (pv=True) do not yet support region, line, arrow, and segment because the units of abscissa and ordinate are different.
kwargs is the arguments of AstroFrame to define plotting ranges.
Args:
v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to None.
vskip (int, optional): How many channels are skipped. Defaults to 1.
veldigit (int, optional): How many digits after the decimal point. Defaults to 2.
restfreq (float, optional): Used for velocity and brightness T. Defaults to None.
channelnumber (int, optional): Specify a channel number to make 2D maps. Defaults to None.
nrows (int, optional): Used for channel maps. Defaults to 4.
ncols (int, optional): Used for channel maps. Defaults to 6.
fontsize (int, optional): rcParams['font.size']. None means 18 (2D) or 12 (3D). Defaults to None.
nancolor (str, optional): Color for masked regions. Defaults to white.
dpi (int, optional): Dot per inch for plotting an image. Defaults to 256.
figsize (tuple, optional): Defaults to None.
fig (optional): External plt.figure(). Defaults to None.
ax (optional): External fig.add_subplot(). Defaults to None.
"""
def __init__(self,
             v: np.ndarray | None = None, vskip: int = 1,
             veldigit: int = 2, restfreq: float | None = None,
             channelnumber: int | None = None,
             nrows: int = 4, ncols: int = 6,
             fontsize: int | None = None,
             nancolor: str = 'w', dpi: int = 256,
             figsize: tuple[float, float] | None = None,
             fig: object | None = None, ax: object | None = None,
             **kwargs: Any) -> None:
    """Set up the frame, the figures (one per page), and the panel axes.

    The arguments are documented on the class; kwargs are forwarded to
    AstroFrame.__init__ to define the plotting ranges.
    """
    super().__init__(**kwargs)
    internalfig = fig is None
    internalax = ax is None
    # An int channelnumber selects a single channel ("animation" mode).
    animation = isinstance(channelnumber, int)
    v = _get_v(p=self, v=v, restfreq=restfreq, vskip=vskip)
    nv = len(v)  # number of channels with a label
    if self.pv or len(v) == 1 or animation:
        # One panel: PV diagram, 2D image, or a single selected channel.
        nrows = ncols = npages = nchan = 1
    else:
        npages = int(np.ceil(nv / nrows / ncols))
        nchan = npages * nrows * ncols
        # Presumably extends v to cover the nchan - nv unused panels on the
        # last page — TODO confirm against reform_grid().
        v = reform_grid(v, k1=nchan - nv)
    nij2ch = _get_nij2ch(nrows=nrows, ncols=ncols)
    ch2nij = _get_ch2nij(nrows=nrows, ncols=ncols)
    if fontsize is None:
        fontsize = 18 if nchan == 1 else 12
    set_rcparams(fontsize=fontsize, nancolor=nancolor, dpi=dpi)
    # One Axes slot per panel, or just the externally supplied Axes.
    # NOTE(review): with an external ax, this list has one element, so a
    # multi-panel layout would index past it below — confirm callers only
    # pass ax for single-panel plots.
    ax = np.empty(nchan, dtype=object) if internalax else [ax]
    figsize = get_figsize(xmin=self.xmin, xmax=self.xmax,
                          ymin=self.ymin, ymax=self.ymax,
                          figsize=figsize,
                          ncols=ncols, nrows=nrows, nchan=nchan)
    need_vlabel = nchan > 1 or animation
    for ch in range(nchan):
        n, i, j = ch2nij(ch)
        # Page n is drawn on figure number n, created on first use.
        # NOTE(review): if figure n already exists before this call, fig is
        # still None here and add_subplot() below would fail — confirm.
        if internalfig and n not in plt.get_fignums():
            fig = plt.figure(n, figsize=figsize)
            if need_vlabel:
                # Panels touch each other; margins stay free on the right/top.
                fig.subplots_adjust(hspace=0, wspace=0,
                                    right=0.87, top=0.87)
        if internalax:
            # Share x with the panel above and y with the panel to the left.
            sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
            sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
            ax[ch] = fig.add_subplot(nrows, ncols, i*ncols + j + 1,
                                     sharex=sharex, sharey=sharey)
        if need_vlabel and ch < nv:
            # In animation mode nchan == 1 (ch == 0), so this picks
            # v[channelnumber]; channelnumber == 0 falls back to ch == 0.
            vlabel = v[channelnumber or ch]
            ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
                        rf'${vlabel:.{veldigit}f}$', color='black',
                        backgroundcolor='white', zorder=20)
    # None means figures are managed internally (one per page).
    self.fig = None if internalfig else fig
    self.ax = ax
    self.rowcol = nrows * ncols
    self.npages = npages
    self.allchan = np.arange(nv)
    # Bottom-left panel index of each page (used for beam, scale bar, and
    # colorbar).
    self.bottomleft = nij2ch(np.arange(npages), nrows - 1, 0)
    self.channelnumber = channelnumber
    self.animation = animation
    self.v = v
    self.vskipfill = _get_vskipfill(nv, v, vskip, channelnumber)
def _map_init(self, kw: dict[str, Any]) -> tuple:
    """Common process for add_color, add_contour, add_segment, and add_rgb.

    Beam-related keys are taken out of ``kw`` first. The remaining keys are
    merged into ``self._kw``, ``xskip``/``yskip`` (spatial pixel skips,
    default 1) are popped, and an AstroData is built and read. Sets
    ``self.sigma`` and ``self.beam`` and draws the beam with add_beam().

    Args:
        kw (dict): kwargs input for each method.

    Returns:
        tuple: (data, x, y, v, sigma, bunit, remaining kwargs, singlepix),
        where singlepix is True when dx or dy is unknown.
    """
    b = kwargs2instance(Beam, kw)
    self._kw.update(kw)
    xskip = self._kw.pop('xskip', 1)
    yskip = self._kw.pop('yskip', 1)
    d = kwargs2instance(AstroData, self._kw)
    self.read(d, xskip, yskip)
    self.sigma = d.sigma
    singlepix = d.dx is None or d.dy is None
    if len(d.beam) == 4:
        # Presumably one beam per input image (e.g. add_segment's four
        # inputs): use the first fully defined one. Fall back to an
        # undefined beam instead of raising StopIteration when none is
        # defined; add_beam() then just reports that there is no beam.
        beam = next((bm for bm in d.beam if None not in bm), [None] * 3)
    else:
        beam = d.beam
    b.beam = self.beam = beam
    self.add_beam(**b.todict())
    return (d.data, d.x, d.y, d.v, d.sigma, d.bunit,
            self._kw, singlepix)
def _validchan(self, include_chan: list[int] | None
) -> np.ndarray | list[int]:
chans = self.allchan if include_chan is None else include_chan
if self.animation:
chans = [0] if self.channelnumber in include_chan else [1]
return chans
[docs]
def add_region(self, patch: str = 'ellipse',
               poslist: list[str | list[float, float]] | None = None,
               majlist: list[float] | None = None,
               minlist: list[float] | None = None,
               palist: list[float] | None = None,
               include_chan: list[int] | None = None,
               **kwargs: Any) -> None:
    """Use add_patch() and Rectangle or Ellipse of matplotlib.

    Default keyword values:
        Matplotlib patch: ``facecolor='none'``, ``edgecolor='gray'``, ``linewidth=1.5``, and ``zorder=10``. User-supplied keyword arguments override these values.

    Args:
        patch (str, optional): 'ellipse' or 'rectangle'. Defaults to 'ellipse'.
        poslist (list, optional): Text or relative center. Defaults to [].
        majlist (list, optional): Major axis, used as the patch height. Defaults to [].
        minlist (list, optional): Minor axis, used as the patch width. Defaults to [].
        palist (list, optional): Position angle (north to east). Defaults to [].
        include_chan (list, optional): None means all. Defaults to None.

    None for poslist, majlist, minlist, or palist is treated as [].
    """
    _kw = {'facecolor': 'none', 'edgecolor': 'gray',
           'linewidth': 1.5, 'zorder': 10}
    _kw.update(kwargs)
    if patch not in ['rectangle', 'ellipse']:
        print('Only patch=\'rectangle\' or \'ellipse\' supported. ')
        return
    # None defaults replace shared mutable default lists.
    poslist = [] if poslist is None else poslist
    majlist = [] if majlist is None else majlist
    minlist = [] if minlist is None else minlist
    palist = [] if palist is None else palist
    z = listing(*self.pos2xy(poslist), minlist, majlist, palist)
    for x, y, width, height, angle in zip(*z):
        # The channel selection is the same for every panel; compute once.
        validchan = self._validchan(include_chan)
        for ch, axnow in enumerate(self.ax):
            if ch not in validchan:
                continue
            if self.fig is None:
                plt.figure(ch // self.rowcol)
            if patch == 'rectangle':
                # Rectangle is anchored at a corner: shift from the centre.
                # NOTE(review): this offset rotates by -angle, which matches
                # the patch rotation (angle * xdir) only when xdir == -1 —
                # confirm for xdir == +1.
                a = np.radians(angle)
                xp = x - (width*np.cos(a) + height*np.sin(a)) / 2.
                yp = y - (-width*np.sin(a) + height*np.cos(a)) / 2.
                p = Rectangle
            else:
                xp, yp = x, y
                p = Ellipse
            p = p((xp, yp), width=width, height=height,
                  angle=angle * self.xdir, **_kw)
            axnow.add_patch(p)
[docs]
def add_beam(self, **kwargs: Any) -> None:
    """Use add_region() to plot the beam.

    kwargs may include the arguments of Beam, except for beam_kwargs, to specify the beam appearance. Those arguments may be a list of each format.
    The beam is drawn on the bottom-left panel of each page (on every
    channel in animation mode), as a rectangle for PV diagrams and as an
    ellipse otherwise. Beams with an undefined element are skipped with the
    message 'No beam to plot.'

    Default keyword values:
        Beam patch: ``facecolor=beamcolor`` and ``edgecolor=None``. Other keyword arguments override these values and are passed to ``add_region``.
    """
    b = kwargs2instance(Beam, kwargs)
    show_beam, beamcolor, beampos = b.show_beam, b.beamcolor, b.beampos
    beam = b.beam
    # kwargs2instance() leaves 'beam' in kwargs; remove it so it does not
    # reach the patch. It is absent when no beam was passed, hence pop().
    kwargs.pop('beam', None)
    if not show_beam:
        return
    include_chan = self.allchan if self.animation else self.bottomleft
    patch = 'rectangle' if self.pv else 'ellipse'
    # A single [bmaj, bmin, bpa] or a list of them.
    blist = [beam] if np.ndim(beam) == 1 else beam
    n = len(blist)
    bclist = beamcolor if isinstance(beamcolor, list) else [beamcolor] * n
    islist = beampos == [None] * 3 or np.ndim(beampos) == 2
    bplist = beampos if islist else [beampos] * n
    for (bmaj, bmin, bpa), bc, bp in zip(blist, bclist, bplist):
        if None in [bmaj, bmin, bpa]:
            print('No beam to plot.')
            continue
        # Default relative position scales with the beam size (minimum 0.1).
        a = max(0.35 * bmaj / self.rmax, 0.1)
        if bp is None:
            bp = [a, 0.1 if self.pv else a]
        if self.swapxy:
            # NOTE(review): np.transpose leaves a 1-D [x, y] position
            # unchanged; confirm whether swapping x and y was intended.
            bp = np.transpose(bp)
            bpa = 90 - bpa
        _kw = {'facecolor': bc, 'edgecolor': None}
        _kw.update(kwargs)
        self.add_region(patch=patch, poslist=bp,
                        majlist=bmaj, minlist=bmin, palist=bpa,
                        include_chan=include_chan,
                        **_kw)
[docs]
def add_marker(self, poslist: list[str | list[float, float]] | None = None,
               include_chan: list[int] | None = None,
               **kwargs: Any) -> None:
    """Use Axes.plot of matplotlib.

    Default keyword values:
        Matplotlib: ``marker='+'``, ``ms=10``, ``mfc='gray'``, ``mec='gray'``, ``mew=2``, ``alpha=1``, and ``zorder=10``. User-supplied keyword arguments override these values.

    Args:
        poslist (list, optional): Text or relative. Defaults to [] (None is treated as []).
        include_chan (list, optional): None means all. Defaults to None.
    """
    _kw = {'marker': '+', 'ms': 10, 'mfc': 'gray',
           'mec': 'gray', 'mew': 2, 'alpha': 1, 'zorder': 10}
    _kw.update(kwargs)
    poslist = [] if poslist is None else poslist
    # Channel selection and positions are the same for every panel, so
    # compute them once; the positions are materialized for reuse.
    validchan = self._validchan(include_chan)
    positions = list(zip(*self.pos2xy(poslist)))
    for ch, axnow in enumerate(self.ax):
        if ch not in validchan:
            continue
        for x, y in positions:
            axnow.plot(x, y, **_kw)
[docs]
def add_text(self, poslist: list[str | list[float, float]] | None = None,
             slist: list[str] | None = None,
             include_chan: list[int] | None = None,
             **kwargs: Any) -> None:
    """Use Axes.text of matplotlib.

    Default keyword values:
        Matplotlib: ``color='gray'``, ``fontsize=15``, ``ha='center'``, ``va='center'``, and ``zorder=10``. User-supplied keyword arguments override these values.

    Aliases: ``horizontalalignment`` and ``verticalalignment`` are accepted as aliases for ``ha`` and ``va``.

    Args:
        poslist (list, optional): Text or relative. Defaults to [] (None is treated as []).
        slist (list, optional): List of text. Defaults to [] (None is treated as []).
        include_chan (list, optional): None means all. Defaults to None.
    """
    _kw = {'color': 'gray', 'fontsize': 15, 'ha': 'center',
           'va': 'center', 'zorder': 10}
    subkeys = {'ha': 'horizontalalignment',
               'va': 'verticalalignment'}
    # Normalize long names to the short ones (the long name wins if both).
    for short, long in subkeys.items():
        if long in kwargs:
            kwargs[short] = kwargs.pop(long)
    _kw.update(kwargs)
    poslist = [] if poslist is None else poslist
    slist = [] if slist is None else slist
    # Channel selection and text items do not depend on the panel.
    validchan = self._validchan(include_chan)
    items = list(zip(*listing(*self.pos2xy(poslist), slist)))
    for ch, axnow in enumerate(self.ax):
        if ch not in validchan:
            continue
        for x, y, s in items:
            axnow.text(x=x, y=y, s=s, **_kw)
[docs]
def add_line(self, poslist: list[str | list[float, float]] | None = None,
             anglelist: list[float] | None = None,
             rlist: list[float] | None = None,
             include_chan: list[int] | None = None,
             **kwargs: Any) -> None:
    """Use Axes.plot of matplotlib.

    Default keyword values:
        Matplotlib: ``color='gray'``, ``linewidth=1.5``, ``linestyle='-'``, and ``zorder=10``. User-supplied keyword arguments override these values.

    Args:
        poslist (list, optional): Text or relative. Defaults to [] (None is treated as []).
        anglelist (list, optional): North to east. Defaults to [] (None is treated as []).
        rlist (list, optional): List of radius. Defaults to [] (None is treated as []).
        include_chan (list, optional): None means all. Defaults to None.
    """
    _kw = {'color': 'gray', 'linewidth': 1.5,
           'linestyle': '-', 'zorder': 10}
    _kw.update(kwargs)
    poslist = [] if poslist is None else poslist
    anglelist = [] if anglelist is None else anglelist
    rlist = [] if rlist is None else rlist
    # Channel selection and line geometry do not depend on the panel.
    validchan = self._validchan(include_chan)
    alist = np.radians(anglelist)
    segments = list(zip(*listing(*self.pos2xy(poslist), alist, rlist)))
    for ch, axnow in enumerate(self.ax):
        if ch not in validchan:
            continue
        for x, y, a, r in segments:
            # Angle measured from +y: sin gives the x offset, cos the y offset.
            axnow.plot([x, x + r * np.sin(a)],
                       [y, y + r * np.cos(a)], **_kw)
[docs]
def add_arrow(self, poslist: list[str | list[float, float]] | None = None,
              anglelist: list[float] | None = None,
              rlist: list[float] | None = None,
              include_chan: list[int] | None = None,
              **kwargs: Any) -> None:
    """Use Axes.quiver of matplotlib.

    Default keyword values:
        Matplotlib: ``color='gray'``, ``width=0.012``, ``headwidth=5``, ``headlength=5``, and ``zorder=10``. User-supplied keyword arguments override these values.

    Args:
        poslist (list, optional): Text or relative. Defaults to [] (None is treated as []).
        anglelist (list, optional): North to east. Defaults to [] (None is treated as []).
        rlist (list, optional): List of radius. Defaults to [] (None is treated as []).
        include_chan (list, optional): None means all. Defaults to None.
    """
    _kw = {'color': 'gray', 'width': 0.012,
           'headwidth': 5, 'headlength': 5, 'zorder': 10}
    _kw.update(kwargs)
    poslist = [] if poslist is None else poslist
    anglelist = [] if anglelist is None else anglelist
    rlist = [] if rlist is None else rlist
    # Channel selection and arrow geometry do not depend on the panel.
    validchan = self._validchan(include_chan)
    alist = np.radians(anglelist)
    arrows = list(zip(*listing(*self.pos2xy(poslist), alist, rlist)))
    for ch, axnow in enumerate(self.ax):
        if ch not in validchan:
            continue
        for x, y, a, r in arrows:
            # angles/scale_units='xy' with scale=1 draw the vector in data
            # units, so the arrow length is r.
            axnow.quiver(x, y, r * np.sin(a), r * np.cos(a),
                         angles='xy', scale_units='xy', scale=1,
                         **_kw)
[docs]
def add_scalebar(self, length: float = 0, label: str = '',
                 color: str = 'gray', barpos: tuple[float, float] = (0.8, 0.12),
                 fontsize: float | None = None, linewidth: float = 3,
                 bbox: dict | None = None) -> None:
    """Use Axes.text and Axes.plot of matplotlib.

    The bar and label are drawn on the bottom-left panel of each page.

    Args:
        length (float, optional): In the unit of arcsec. Defaults to 0.
        label (str, optional): Text like '100 au'. Defaults to ''.
        color (str, optional): Same for bar and label. Defaults to 'gray'.
        barpos (tuple, optional): Relative position. Defaults to (0.8, 0.12).
        fontsize (float, optional): None means 20 if there is one panel else 15. Defaults to None.
        linewidth (float, optional): Width of the bar. Defaults to 3.
        bbox (dict, optional): Keyword arguments for the text bounding box. None means {'alpha': 0}. Defaults to None.
    """
    if length == 0:
        print('No length is given. Skip add_scalebar().')
        return
    if fontsize is None:
        fontsize = 20 if len(self.ax) == 1 else 15
    if bbox is None:
        bbox = {'alpha': 0}
    # The label hangs just below barpos and the bar sits just above it; both
    # positions are the same for every panel.
    xt, yt = self.pos2xy([barpos[0], barpos[1] - 0.012])
    xb, yb = self.pos2xy([barpos[0], barpos[1] + 0.012])
    for ch, axnow in enumerate(self.ax):
        if ch not in self.bottomleft:
            continue
        axnow.text(xt[0], yt[0], label, color=color, size=fontsize,
                   ha='center', va='top', bbox=bbox, zorder=10)
        axnow.plot([xb[0] - length/2., xb[0] + length/2.], [yb[0], yb[0]],
                   '-', linewidth=linewidth, color=color)
def _set_colorbar(self, mappable: list[Any], ch: int, show_cbar: bool,
                  cblabel: str, cbformat: str,
                  cbticks: list | None, cbticklabels: list | None,
                  cblocation: str,
                  cblabelfontsize: int, cbtickfontsize: int,
                  st: Stretcher) -> None:
    """Draw and format the colorbar for panel ``ch``.

    Args:
        mappable: Per-panel mappables; only mappable[ch] is used.
        ch: Panel index (a bottom-left panel when called from add_color).
        show_cbar: If False, nothing is drawn.
        cblabel, cbformat, cbticks, cbticklabels, cblocation,
        cblabelfontsize, cbtickfontsize: See add_color(). cbformat is a
            '%'-style format such as '%.1e'.
        st: Stretcher of the image; after set_minmax() its vmin/vmax are in
            stretched units.
    """
    if not show_cbar:
        return
    if self.fig is None:
        fig = plt.figure(ch // self.rowcol)
    else:
        fig = self.fig
    if len(self.ax) == 1:
        ax = self.ax[ch]
        cb = fig.colorbar(mappable[ch], ax=ax, label=cblabel,
                          format=cbformat, location=cblocation)
    else:
        # Dedicated colorbar axes in the right margin of a channel-map page.
        # Create it on ``fig`` itself: plt.axes() would attach it to the
        # current figure, which need not be an external self.fig.
        cax = fig.add_axes([0.88, 0.105, 0.015, 0.77])
        cb = fig.colorbar(mappable[ch], cax=cax, label=cblabel,
                          format=cbformat)
    cb.ax.tick_params(labelsize=cbtickfontsize)
    font = mpl.font_manager.FontProperties(size=cblabelfontsize)
    # The label is on the y axis for vertical and on the x axis for
    # horizontal ('top'/'bottom') colorbars; style both.
    cb.ax.xaxis.label.set_font_properties(font)
    cb.ax.yaxis.label.set_font_properties(font)
    if cbticks is None and st.stretch == 'log':
        # st.vmin/st.vmax are log10 values here.
        cbticks, cbticklabels = logcbticks(10**st.vmin, 10**st.vmax)
    # Given ticks are in linear units; map them to stretched units.
    cbticks = cb.get_ticks() if cbticks is None else st.do(cbticks)
    cond = (st.vmin <= cbticks) * (cbticks <= st.vmax)
    cbticks = cbticks[cond]
    cb.set_ticks(cbticks)
    if cbticklabels is None:
        # Label with the linear values; '%.1e' -> '.1e' format spec.
        cbticklabels = [f'{t:{cbformat[1:]}}' for t in st.undo(cbticks)]
    else:
        cbticklabels = np.array(cbticklabels)[cond]
    cb.set_ticklabels(cbticklabels)
[docs]
def add_color(self,
              show_cbar: bool = True,
              cblabel: str | None = None,
              cbformat: str = '%.1e',
              cbticks: list[float] | None = None,
              cbticklabels: list[str] | None = None,
              cblocation: str = 'right',
              cblabelfontsize: int = 16,
              cbtickfontsize: int = 14,
              **kwargs: Any) -> None:
    """Use Axes.pcolormesh of matplotlib.

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata, such as ``fitsimage``, ``data``, ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, ``pvpa``, ``pv``, and ``bunit``.
        Stretcher: Color scaling, such as ``stretch``, ``stretchscale``, ``stretchpower``, ``vmin``, and ``vmax``.
        Beam: Beam display, such as ``show_beam``, ``beamcolor``, ``beampos``, and ``beam_kwargs``.
        Sampling: ``xskip`` and ``yskip``.
        Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.pcolormesh``.

    Default keyword values:
        Matplotlib: ``cmap='cubehelix'``, ``alpha=1``, ``edgecolors='none'``, ``zorder=1``, ``vmin=None``, and ``vmax=None``. User-supplied keyword arguments override these values.

    Args:
        show_cbar (bool, optional): Show color bar. Defaults to True.
        cblabel (str, optional): Colorbar label. None means the data unit (bunit). Defaults to None.
        cbformat (str, optional): Format for ticklabels of colorbar. Defaults to '%.1e'.
        cbticks (list, optional): Ticks of colorbar. Defaults to None.
        cbticklabels (list, optional): Ticklabels of colorbar. Defaults to None.
        cblocation (str, optional): 'left', 'right', 'top', or 'bottom'. Only for 2D images. Defaults to 'right'.
        cblabelfontsize (int, optional): Fontsize for the colorbar label. This is independent of set_rcparams().
        cbtickfontsize (int, optional): Fontsize for the colorbar ticks. This is independent of set_rcparams().
    """
    self._kw = {'cmap': 'cubehelix', 'alpha': 1,
                'edgecolors': 'none', 'zorder': 1,
                'vmin': None, 'vmax': None}
    (c, x, y, v, sigma, bunit,
     _kw, singlepix) = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_color.')
        return
    if cblabel is None:
        cblabel = bunit
    # The Stretcher consumes its own keys (including sigma) from _kw;
    # vmin/vmax stay and are replaced by the stretched limits.
    _kw['sigma'] = sigma
    stretcher = kwargs2instance(Stretcher, _kw)
    c, _kw['vmin'], _kw['vmax'] = stretcher.set_minmax(c)
    c = self.vskipfill(c, v)
    mappables = [None] * len(self.ax)
    for ch, (axnow, cnow) in enumerate(zip(self.ax, c)):
        mesh = axnow.pcolormesh(x, y, cnow, **_kw)
        if ch in self.bottomleft:
            mappables[ch] = mesh
    for ch in self.bottomleft:
        self._set_colorbar(mappables, ch, show_cbar, cblabel, cbformat,
                           cbticks, cbticklabels, cblocation,
                           cblabelfontsize, cbtickfontsize, stretcher)
[docs]
def add_contour(self,
                levels: list[float] = [-12, -6, -3, 3, 6, 12, 24, 48, 96, 192, 384],
                **kwargs: Any) -> None:
    """Draw contours of the map (or of every channel) with Axes.contour.

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata, such as ``fitsimage``, ``data``, ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, ``pvpa``, ``pv``, and ``bunit``.
        Beam: Beam display, such as ``show_beam``, ``beamcolor``, ``beampos``, and ``beam_kwargs``.
        Sampling: ``xskip`` and ``yskip``.
        Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.contour``.

    Default keyword values:
        Matplotlib: ``colors='gray'``, ``linewidths=1.0``, and ``zorder=2``. User-supplied keyword arguments override these values.

    Args:
        levels (list, optional): Contour levels in the unit of sigma. They are sorted and multiplied by the sigma returned from ``_map_init`` before drawing. Defaults to [-12,-6,-3,3,6,12,24,48,96,192,384].
    """
    self._kw = {'colors': 'gray', 'linewidths': 1.0, 'zorder': 2}
    maps, x, y, v, sigma, _, contour_kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_contour.')
        return
    panels = self.vskipfill(maps, v)
    for panel_ax, panel in zip(self.ax, panels):
        # Levels are given in sigma units; contour() needs data units.
        thresholds = np.sort(levels) * sigma
        panel_ax.contour(x, y, panel, thresholds, **contour_kw)
[docs]
def add_segment(self,
ampfits: str | None = None, angfits: str | None = None,
Ufits: str | None = None, Qfits: str | None = None,
amp: list[np.ndarray] | None = None,
ang: list[np.ndarray] | None = None,
stU: list[np.ndarray] | None = None,
stQ: list[np.ndarray] | None = None,
ampfactor: float = 1., angonly: bool = False,
rotation: float = 0.,
cutoff: float = 3.,
**kwargs: Any) -> None:
"""Use Axes.quiver of matplotlib.
``fitsimage`` is built from ``[ampfits, angfits, Ufits, Qfits]``, and ``data`` is built from ``[amp, ang, stU, stQ]``.
Keyword groups accepted in ``**kwargs``:
AstroData: Data input and metadata, such as ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, ``pvpa``, ``pv``, and ``bunit``. The ``fitsimage`` and ``data`` arguments are assembled from the segment-specific arguments above.
Beam: Beam display, such as ``show_beam``, ``beamcolor``, ``beampos``, and ``beam_kwargs``.
Sampling: ``xskip`` and ``yskip``.
Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.quiver``.
Default keyword values:
Matplotlib: ``angles='xy'``, ``scale_units='xy'``, ``color='gray'``, ``pivot='mid'``, ``headwidth=0``, ``headlength=0``, ``headaxislength=0``, ``width=0.007``, and ``zorder=3``. User-supplied keyword arguments override these values.
Args:
ampfits (str, optional): Input FITS file name. Length of segment. Defaults to None.
angfits (str, optional): Input FITS file name. North to east. Defaults to None.
Ufits (str, optional): Input FITS file name. Stokes U. Defaults to None.
Qfits (str, optional): Input FITS file name. Stokes Q. Defaults to None.
amp (list, optional): Length of segment. Defaults to None.
ang (list, optional): North to east. Defaults to None.
stU (list, optional): Stokes U. Defaults to None.
stQ (list, optional): Stokes Q. Defaults to None.
ampfactor (float, optional): Length of segment is amp times ampfactor. Defaults to 1..
angonly (bool, optional): True means amp=1 for all. Defaults to False.
rotation (float, optional): Segment angle is ang + rotation. Defaults to 0..
cutoff (float, optional): Used when amp and ang are calculated from Stokes U and Q. In the unit of sigma. Defaults to 3..
"""
self._kw = {'angles': 'xy', 'scale_units': 'xy', 'color': 'gray',
'pivot': 'mid', 'headwidth': 0, 'headlength': 0,
'headaxislength': 0, 'width': 0.007, 'zorder': 3,
'fitsimage': [ampfits, angfits, Ufits, Qfits],
'data': [amp, ang, stU, stQ]}
c, x, y, v, sigma, _, _kw, singlepix = self._map_init(kwargs)
if singlepix:
print('No pixel size. Skip add_segment.')
return
amp, ang, stU, stQ = c
sigmaU, sigmaQ = sigma[2:]
if stU is not None and stQ is not None:
self.sigma = sigma = (sigmaU + sigmaQ) / 2.
ang = np.degrees(np.arctan2(stU, stQ) / 2.)
amp = np.hypot(stU, stQ)
amp[amp < cutoff * sigma] = np.nan
if amp is None:
amp = np.ones_like(ang)
if angonly:
amp = np.sign(amp)**2
amp = amp / np.nanmax(amp)
U = ampfactor * amp * np.sin(np.radians(ang + rotation))
V = ampfactor * amp * np.cos(np.radians(ang + rotation))
U = self.vskipfill(U, v)
V = self.vskipfill(V, v)
_kw['scale'] = 1. / np.abs(x[1] - x[0])
for axnow, unow, vnow in zip(self.ax, U, V):
axnow.quiver(x, y, unow, vnow, **_kw)
[docs]
def add_rgb(self, **kwargs: Any) -> None:
    """Draw an RGB composite of three maps with PIL.Image and Axes.imshow.

    A three-element array ([red, green, blue]) is expected for most data, stretch, and beam arguments, including ``vmin`` and ``vmax``.
    ``xskip``, ``yskip``, and ``show_beam`` are single values.
    Each band is scaled linearly from [vmin, vmax] (after stretching) to [0, 255] and clipped to that range. Pixels with a non-finite value in any band are drawn gray (128, 128, 128).

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata, such as ``fitsimage``, ``data``, ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, ``pvpa``, ``pv``, and ``bunit``.
        Stretcher: RGB scaling, such as ``stretch``, ``stretchscale``, ``stretchpower``, ``vmin``, and ``vmax``.
        Beam: Beam display, such as ``show_beam``, ``beamcolor``, ``beampos``, and ``beam_kwargs``.
        Sampling: ``xskip`` and ``yskip``.
        Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.imshow``.

    Default keyword values:
        Stretcher: ``vmin=[None] * 3``, ``vmax=[None] * 3``, ``stretch=['linear'] * 3``, ``stretchscale=[None] * 3``, and ``stretchpower=[0.5] * 3``. User-supplied keyword arguments override these values.
    """
    from PIL import Image
    self._kw = {'vmin': [None] * 3, 'vmax': [None] * 3,
                'stretch': ['linear'] * 3,
                'stretchscale': [None] * 3,
                'stretchpower': [0.5] * 3}
    c, x, y, v, _, _, _kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_rgb.')
        return
    if not (np.shape(c[0]) == np.shape(c[1]) == np.shape(c[2])):
        print('RGB shapes mismatch. Skip add_rgb.')
        return
    st = kwargs2instance(Stretcher, _kw)
    c, cmin, cmax = st.set_minmax(c)
    for i in range(st.n):
        if cmax[i] > cmin[i]:
            c[i] = (c[i] - cmin[i]) / (cmax[i] - cmin[i]) * 255
        c[i] = self.vskipfill(c[i], v)
    # (band, channel, y, x) -> (channel, band, y, x); flip so that row 0 of
    # the image is the top of the plot and columns follow the x direction.
    c = np.moveaxis(c, 1, 0)[:, :, ::-self.ydir, ::-self.xdir]
    for axnow, rgb in zip(self.ax, c):
        # (band, y, x) -> (y, x, band), the layout Image.fromarray expects.
        pixels = np.moveaxis(np.asarray(rgb, dtype=float), 0, -1)
        # int() on NaN/inf raises, so such pixels keep a gray background.
        bad = ~np.all(np.isfinite(pixels), axis=-1)
        # Clip before the uint8 cast so out-of-range values saturate
        # instead of wrapping around.
        pixels = np.clip(pixels, 0, 255)
        pixels[bad] = 128
        im = Image.fromarray(pixels.astype(np.uint8))
        axnow.imshow(im, extent=[x[0], x[-1], y[0], y[-1]])
        axnow.set_aspect(np.abs((x[-1]-x[0]) / (y[-1]-y[0])))
def _set_axis_shared(self, pa2: PlotAxes2D,
                     title: dict | str | None) -> None:
    """Apply prepared axis settings to every panel and add an optional title.

    Shared by set_axis() and set_axis_radec().
    - Every axes is formatted with ``pa2.set_xyaxes``.
    - ``self.Xlim``/``self.Ylim`` are updated from ``pa2``.
    - Panels whose index is not in ``self.bottomleft`` lose their tick labels and axis labels.

    Args:
        pa2 (PlotAxes2D): Axis settings built in set_axis() or set_axis_radec().
        title (dict | str | None): None adds no title.
            With several panels, a str becomes ``fig.suptitle(t=str, y=0.9)`` on every page, and a dict is merged over ``{'y': 0.9}``.
            With one panel, a str becomes ``set_title(label=str)``, and a dict is passed to ``set_title`` as is.
    """
    for ch, axnow in enumerate(self.ax):
        pa2.set_xyaxes(axnow)
        self.Xlim, self.Ylim = pa2.xlim, pa2.ylim
        if ch in self.bottomleft:
            continue
        # Only the bottom-left panel(s) keep tick labels and axis labels.
        plt.setp(axnow.get_xticklabels(), visible=False)
        plt.setp(axnow.get_yticklabels(), visible=False)
        axnow.set_xlabel('')
        axnow.set_ylabel('')
    if len(self.ax) == 1 and self.fig is None:
        plt.figure(0).tight_layout()
    if title is None:
        return
    if len(self.ax) <= 1:
        # Single panel: axnow is the axes left over from the loop above.
        title_kw = {'label': title} if isinstance(title, str) else title
        axnow.set_title(**title_kw)
        return
    suptitle_kw = {'y': 0.9}
    suptitle_kw.update({'t': title} if isinstance(title, str) else title)
    for page in range(self.npages):
        plt.figure(page).suptitle(**suptitle_kw)
[docs]
def set_axis(self, title: dict | str | None = None,
             **kwargs: Any) -> None:
    """Format the axes of every panel. kwargs can include the arguments of PlotAxes2D to adjust x and y axis.

    Default keyword values:
        PlotAxes2D: ``xlabel``, ``ylabel``, ``xlim``, and ``ylim`` are filled from the current frame when they are omitted. In PV mode, ``samexy=False`` is also set internally and overrides a user value.

    Args:
        title (dict | str | None): str means set_title(str) for 2D or fig.suptitle(str) for 3D. Defaults to None.
        **kwargs: Arguments of PlotAxes2D.
    """
    _kw = dict(kwargs)
    offunit = '(arcsec)' if self.dist == 1 else '(au)'
    if self.pv:
        offlabel = f'Offset {offunit}'
        # The closing parenthesis stays outside mathtext, as in plotprofile().
        vellabel = r'Velocity (km s$^{-1}$)'
        if 'xlabel' not in _kw:
            _kw['xlabel'] = vellabel if self.swapxy else offlabel
        if 'ylabel' not in _kw:
            _kw['ylabel'] = offlabel if self.swapxy else vellabel
        # Offset and velocity axes are never drawn with equal scaling
        # (presumably what samexy controls in PlotAxes2D).
        _kw['samexy'] = False
    else:
        ralabel, declabel = f'R.A. {offunit}', f'Dec. {offunit}'
        if 'xlabel' not in _kw:
            _kw['xlabel'] = declabel if self.swapxy else ralabel
        if 'ylabel' not in _kw:
            _kw['ylabel'] = ralabel if self.swapxy else declabel
    _kw.setdefault('xlim', self.Xlim)
    _kw.setdefault('ylim', self.Ylim)
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    self._set_axis_shared(pa2=pa2, title=title)
[docs]
def set_axis_radec(self, center: str | None = None,
                   xlabel: str = 'R.A. (ICRS)',
                   ylabel: str = 'Dec. (ICRS)',
                   nticksminor: int = 2,
                   grid: dict | None = None, title: dict | str | None = None
                   ) -> None:
    """Format the axes with sexagesimal R.A./Dec. tick labels.

    Tick placement:
    - Each axis gets seven major ticks on a grid. The spacing comes from
      ``_get_gridwidth(mode, self.rmax, cos_dec)``, and the grid is aligned
      to the seconds field of ``center``.
    - Minor ticks split the full major range into ``6 * nticksminor`` intervals.

    Label prefixes:
    - For R.A., the highest-index tick with ``|offset| < rmax`` also carries
      the hours and minutes.
    - For Dec., the tick that carries the degrees and arcminutes is the last
      such tick when the declination is positive, otherwise the first.

    Minute scale (``self.rmax >= 60``):
    - The R.A. seconds of the center are floored to a multiple of 5.
    - The Dec. seconds of the center are set to 0.
    - Dec. labels show arcminutes instead of arcseconds.

    Unlike set_axis(), this method takes no PlotAxes2D keyword arguments.
    The limits come from ``self.Xlim``/``self.Ylim``.

    Args:
        center (str, optional): Center coordinate. If it has three words, the first one is dropped. Defaults to None, which means self.center, or '00h00m00s 00d00m00s' if that is also None.
        xlabel (str, optional): Defaults to 'R.A. (ICRS)'.
        ylabel (str, optional): Defaults to 'Dec. (ICRS)'.
        nticksminor (int, optional): Number of minor intervals per major interval. Defaults to 2.
        grid (dict, optional): Passed to PlotAxes2D as its grid setting. Defaults to None.
        title (dict | str | None): str means set_title(str) for 2D or fig.suptitle(str) for 3D. Defaults to None.
    """
    if center is None:
        center = self.center
    if center is None:
        center = '00h00m00s 00d00m00s'
    # Drop a leading word (presumably a frame name) from 3-word centers.
    if len(csplit := center.split()) == 3:
        center = f'{csplit[1]} {csplit[2]}'
    if on_min_scale := (self.rmax >= 60.0):
        # On a 5-second grid.
        ra_s = np.floor(float(get_sec(center, 0)) / 5) * 5
        dec_s = 0.0
        ra = get_hmdm(center, 'ra') + f'{ra_s:.1f}s'
        dec = get_hmdm(center, 'dec') + f'{dec_s:.1f}s'
        center = f'{ra} {dec}'

    def get_tickvalues(ticks: np.ndarray, mode: str, no_sec: bool
                       ) -> np.ndarray:
        # Convert tick offsets (arcsec along the axis given by mode) into
        # the seconds (or minutes if no_sec) field of the coordinate.
        xy = [np.zeros_like(ticks), ticks / 3600.]
        if mode == 'ra':
            xy.reverse()
        tickvalues = xy2coord(xy, center)
        getter = get_min if no_sec else get_sec
        tickvalues = [getter(t, mode) for t in tickvalues]  # str
        tickvalues = np.array(tickvalues, dtype=float)
        # 7-digit precision for practical use.
        tickvalues = np.round(tickvalues, 7)
        return tickvalues
    # Mathtext unit markers. The 's' forms start with '.', so the unit is
    # written above the decimal point (e.g. 05.s25 with a superscript s).
    units = {'ra': {'h': r'$^\mathrm{h}$',
                    'm': r'$^\mathrm{m}$',
                    's': r'.$\hspace{-0.4}^\mathrm{s}$'},
             'dec': {'d': r'$^{\circ}$',
                     'm': r'$^{\prime}$',
                     's': r'.$\hspace{-0.4}^{\prime\prime}$'}}
    dec_center = coord2xy(center)[1]
    sign_dec = np.sign(dec_center)
    cos_dec = np.cos(np.radians(dec_center))
    intgrid = np.array([-3, -2, -1, 0, 1, 2, 3])

    def makegrid(mode: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
        # Returns (major tick offsets in arcsec, minor tick offsets, labels).
        second = float(get_sec(center, mode))
        no_sec = on_min_scale and (mode == 'dec')
        # gridwidth is a float like 2 x 10^order (arcsec).
        gridwidth, order = _get_gridwidth(mode, self.rmax, cos_dec)
        # ndigits = -1 is the largest case for 10", 20", ...
        decimals = str(max(-order, 0))
        rounded = round(second, ndigits=max(-order, -1))
        # Get a grid point closest to the input second.
        rounded = round(rounded / gridwidth) * gridwidth
        # R.A. seconds of time -> arcsec on the sky (15 cos(dec)); Dec.
        # seconds grow away from zero, hence the sign for negative Dec.
        factor = 15 * cos_dec if mode == 'ra' else sign_dec
        ticks = (intgrid * gridwidth - second + rounded) * factor
        ticksminor = np.linspace(ticks[0], ticks[-1], 6*nticksminor + 1)
        tickvalues = get_tickvalues(ticks, mode, no_sec)
        whole, frac = np.divmod(tickvalues, 1)
        u = units[mode]['m' if no_sec else 's']
        # Label = two-digit integer part (mod 60) + unit + fraction digits
        # without the leading '0.'.
        # NOTE(review): a frac that formats as '1.00...' would print as
        # '00' without carrying into whole; np.round above likely makes
        # this rare — confirm.
        ticklabels = [f'{int(i):02d}{u}' + f'{j:.{decimals}f}'[2:]
                      for i, j in zip(whole % 60, frac)]
        return ticks, ticksminor, ticklabels
    xticks, xticksminor, xticklabels = makegrid('ra')
    yticks, yticksminor, yticklabels = makegrid('dec')
    # Reference ticks that get the h/m (R.A.) and d/m (Dec.) prefixes.
    # NOTE(review): raises IndexError if no tick lies within rmax.
    i_ref = np.where(np.abs(xticks) < self.rmax)[0][-1]
    idx_top = -1 if sign_dec > 0 else 0
    j_ref = np.where(np.abs(yticks) < self.rmax)[0][idx_top]
    ra_hm = get_hmdm(xy2coord([xticks[i_ref] / 3600., 0], center), 'ra')
    dec_dm = get_hmdm(xy2coord([0, yticks[j_ref] / 3600.], center), 'dec')
    if on_min_scale:
        # Dec. labels already show arcminutes; keep only the degrees.
        dec_dm = dec_dm.split('d')[0] + 'd'
    # Replace the h/m/d letters with the mathtext unit markers.
    ra_hm = ra_hm.translate(str.maketrans(units['ra']))
    dec_dm = dec_dm.translate(str.maketrans(units['dec']))
    textpad = ' ' * 12  # To shift the tick label to left.
    xticklabels[i_ref] = ra_hm + xticklabels[i_ref] + textpad
    yticklabels[j_ref] = dec_dm + '\n' + yticklabels[j_ref]
    pa2 = PlotAxes2D(True, None, 'linear', 'linear',
                     self.Xlim, self.Ylim, xlabel, ylabel,
                     xticks, yticks, xticklabels, yticklabels,
                     xticksminor, yticksminor, grid)
    self._set_axis_shared(pa2=pa2, title=title)
[docs]
def savefig(self, filename: str | None = None,
            show: bool = False, **kwargs: Any) -> None:
    """Save every page with Figure.savefig, optionally show, then close all.

    If ``filename`` is provided, existing files with the same name are overwritten by Matplotlib.
    With several pages, ``_<page>`` is inserted before the extension (``map.png`` -> ``map_0.png``, ``map_1.png``, ...). A name without an extension gets the suffix at its end.
    This method closes all Matplotlib figures with ``plt.close('all')`` after optional saving/showing.
    Before saving, every axes is reset to ``self.Xlim``/``self.Ylim``, and each figure patch is made fully transparent.

    Default keyword values:
        Figure.savefig: ``transparent=True`` and ``bbox_inches='tight'``. User-supplied keyword arguments override these values.

    Args:
        filename (str, optional): Output image file name. Existing files may be overwritten, and all Matplotlib figures are closed after saving/showing. Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.
    """
    import os
    _kw = {'transparent': True, 'bbox_inches': 'tight'}
    _kw.update(kwargs)
    for axnow in self.ax:
        axnow.set_xlim(*self.Xlim)
        axnow.set_ylim(*self.Ylim)
    if isinstance(filename, str):
        # Split the extension off the end of the basename only. A plain
        # str.replace would also rewrite matching text elsewhere in the
        # path, and would give every page the same name when there is no
        # extension.
        root, ext = os.path.splitext(filename)
        for i in range(self.npages):
            ver = '' if self.npages == 1 else f'_{i:d}'
            fig = plt.figure(i)
            fig.patch.set_alpha(0)
            fig.savefig(f'{root}{ver}{ext}', **_kw)
    if show:
        plt.show()
    plt.close('all')
[docs]
def get_figax(self) -> tuple[object, object] | None:
"""Output the external fig and ax after plotting.
Returns:
tuple: (fig, ax)
"""
if len(self.ax) > 1:
print('PlotAstroData.get_figax() is not supported'
+ ' with channel maps')
return
fig = plt.figure(0) if self.fig is None else self.fig
return fig, self.ax[0]
def _get_ylabel_profile(_kw: dict, Tb: bool, flux: bool, bunit: str
) -> str:
if 'ylabel' in _kw:
return _kw['ylabel']
if Tb:
return r'$T_b$ (K)'
if flux:
return 'Flux (Jy)'
return bunit
def _prep_plotprofile(width: int, coords: list | str,
                      xlist: list, ylist: list, ellipse: list,
                      ninterp: int, flux: bool, gaussfit: bool,
                      _kw: dict) -> tuple:
    """Read the cube and extract the profiles for plotprofile().

    ``_kw`` is handed to kwargs2instance() for AstroFrame, AstroData, and PlotAxes2D, in that order. A default 'xlim' covering the velocity range is added to it first.

    Returns:
        tuple: (v, prof, gfitres, pa2, ylabel). ylabel is a list with one label per profile.
    """
    coord_list = [coords] if isinstance(coords, str) else coords
    # Read Tb before the instances below consume the keyword arguments.
    Tb = _kw.get('Tb', False)
    frame = kwargs2instance(AstroFrame, _kw)
    cube = kwargs2instance(AstroData, _kw)
    frame.read(cube)
    cube.binning([width, 1, 1])
    v, prof, gfitres = cube.profile(coords=coord_list, xlist=xlist,
                                    ylist=ylist, ellipse=ellipse,
                                    ninterp=ninterp, flux=flux,
                                    gaussfit=gaussfit)
    label = _get_ylabel_profile(_kw, Tb, flux, cube.bunit)
    ylabels = [label] * len(prof) if isinstance(label, str) else label
    _kw.setdefault('xlim', [v.min(), v.max()])
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    return v, prof, gfitres, pa2, ylabels
def _set_figax_plotprofile(fig: object | None, ax: object | None,
nrows: int, ncols: int,
nprof: int) -> tuple:
if ncols == 1:
nrows = nprof
if fig is None:
fig = plt.figure(figsize=(6 * ncols, 3 * nrows))
if nprof > 1 and ax is not None:
print('External ax is supported only when len(coords)=1.')
ax = None
ax = np.empty(nprof, dtype=object) if ax is None else [ax]
for i in range(nprof):
sharex = None if i < nrows - 1 else ax[i - 1]
ax[i] = fig.add_subplot(nrows, ncols, i + 1, sharex=sharex)
return fig, ax
[docs]
def plotprofile(coords: list[str] | str = [],
                xlist: list[float] = [], ylist: list[float] = [],
                ellipse: list[float, float, float] | None = None,
                ninterp: int = 1,
                flux: bool = False, width: int = 1,
                gaussfit: bool = False, gauss_kwargs: dict = {},
                title: list[str | dict] | None = None,
                text: list[dict] | None = None,
                nrows: int = 0, ncols: int = 1,
                fig: object | None = None, ax: object | None = None,
                getfigax: bool = False,
                savefig: dict | str | None = None, show: bool = False,
                **kwargs: Any) -> tuple[object, object] | None:
    """Plot line profiles extracted from a spectral cube.

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata for the cube, such as ``fitsimage``, ``data``, ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, and ``bunit``.
        AstroFrame: Data trimming and coordinate-frame options, such as ``rmax``, ``center``, ``dist``, ``xoff``, ``yoff``, ``vsys``, ``vmin``, and ``vmax``.
        PlotAxes2D: Axis formatting, such as ``xlim``, ``ylim``, ``xlabel``, ``ylabel``, ``grid``, ``xscale``, ``yscale``, and tick options.
        Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.plot`` for the profile curve.

    Default keyword values:
        Profile curve: ``drawstyle='steps-mid'``, ``color='k'``, ``xlabel=r'Velocity (km s$^{-1}$)'``, and ``samexy=False``. User-supplied keyword arguments override these values.
        Gaussian overlay: ``drawstyle='default'`` and ``color='g'``. User-supplied values in ``gauss_kwargs`` override these values.

    Args:
        coords (list, optional): Coordinates. Defaults to [].
        xlist (list, optional): Offset from the center. Defaults to [].
        ylist (list, optional): Offset from the center. Defaults to [].
        ellipse (list, optional): [major, minor, pa], For average. Defaults to None.
        ninterp (int, optional): Number of points for interpolation. Defaults to 1.
        flux (bool, optional): y axis is flux density. Defaults to False.
        width (int, optional): Rebinning step along v. Defaults to 1.
        gaussfit (bool, optional): Fit the profiles. Defaults to False.
        gauss_kwargs (dict, optional): Kwargs for Axes.plot of the Gaussian fit. Defaults to {}.
        title (list, optional): One str or set_title() kwargs dict per plot. The list is not modified. Defaults to None.
        text (list, optional): One Axes.text() kwargs dict per plot. Defaults to None.
        nrows (int, optional): Number of subplot rows. When ncols=1, one row per profile is used. Defaults to 0.
        ncols (int, optional): Number of subplot columns. Defaults to 1.
        fig (object, optional): External plt.figure(). Defaults to None.
        ax (object, optional): External fig.add_subplot(). Defaults to None.
        getfigax (bool, optional): Defaults to False.
        savefig (dict or str, optional): Passed to ``close_figure``. Existing files may be overwritten, and the figure is closed after saving/showing. Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.

    Returns:
        tuple: (fig, ax), where ax is a list, if getfigax=True. Otherwise, no return.
    """
    _kw = {'drawstyle': 'steps-mid', 'color': 'k',
           'xlabel': r'Velocity (km s$^{-1}$)', 'samexy': False}
    _kw.update(kwargs)
    _kwgauss = {'drawstyle': 'default', 'color': 'g'}
    _kwgauss.update(gauss_kwargs)
    # _prep_plotprofile consumes the AstroFrame/AstroData/PlotAxes2D
    # entries of _kw; what remains goes to Axes.plot below.
    v, prof, gfitres, pa2, ylabel \
        = _prep_plotprofile(width, coords, xlist, ylist, ellipse,
                            ninterp, flux, gaussfit, _kw)
    nprof = len(prof)
    set_rcparams(20, 'w')
    fig, ax = _set_figax_plotprofile(fig, ax, nrows, ncols, nprof)
    for i in range(nprof):
        if gaussfit:
            ax[i].plot(v, gaussian1d(v, *gfitres['best'][i]), **_kwgauss)
        ax[i].plot(v, prof[i], **_kw)
        ax[i].hlines([0], v.min(), v.max(), linestyle='dashed', color='k')
        ax[i].set_ylabel(ylabel[i])
        pa2.set_xyaxes(ax[i])
        if text is not None:
            ax[i].text(**text[i])
        if title is not None:
            # Normalize into a local so the caller's list is not modified.
            title_kw = title[i]
            if isinstance(title_kw, str):
                title_kw = {'label': title_kw}
            ax[i].set_title(**title_kw)
        if i <= nprof - ncols - 1:
            # Only the bottom row keeps its x tick labels.
            plt.setp(ax[i].get_xticklabels(), visible=False)
    if getfigax:
        return fig, ax
    close_figure(fig, savefig, show)
[docs]
def plotslice(length: float, dx: float | None = None, pa: float = 0,
              txtfile: str | None = None,
              fig: object | None = None, ax: object | None = None,
              getfigax: bool = False,
              savefig: str | dict | None = None, show: bool = False,
              **kwargs: Any) -> tuple[object, object] | None:
    """Plot a one-dimensional spatial slice through a 2D map.

    When ``d.sigma`` is available, a dashed black line marks 3 sigma.

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata for the 2D map, such as ``fitsimage``, ``data``, ``x``, ``y``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, and ``bunit``.
        AstroFrame: Data trimming and coordinate-frame options, such as ``rmax``, ``center``, ``dist``, ``xoff``, ``yoff``, ``xflip``, ``yflip``, and ``swapxy``.
        PlotAxes2D: Axis formatting, such as ``xlim``, ``ylim``, ``xlabel``, ``ylabel``, ``grid``, ``xscale``, ``yscale``, and tick options.
        Matplotlib: Additional keyword arguments are passed to ``matplotlib.axes.Axes.plot`` for the slice curve.

    Default keyword values:
        Slice curve: ``linestyle='-'`` and ``marker='o'``. User-supplied keyword arguments override these values.
        Axis setup: ``rmax=length / 2`` and ``samexy=False`` are set internally. ``xlabel``, ``ylabel``, and ``xlim`` are filled from the slice geometry if they are omitted.

    Args:
        length (float): Slice length.
        dx (float, optional): Grid increment. Defaults to None.
        pa (float, optional): Degree. Position angle. Defaults to 0.
        txtfile (str, optional): File name for ``numpy.savetxt``. Existing files with the same name are overwritten. Defaults to None.
        fig (object, optional): External plt.figure(). Defaults to None, which means ``ax.figure`` if ax is given, otherwise a new figure.
        ax (object, optional): External fig.add_subplot(). Defaults to None.
        getfigax (bool, optional): Defaults to False.
        savefig (dict or str, optional): Passed to ``close_figure``. Existing files may be overwritten, and the figure is closed after saving/showing. Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.

    Returns:
        tuple: (fig, ax), where ax is the single Axes used, if getfigax=True. Otherwise, no return.
    """
    _kw = {'linestyle': '-', 'marker': 'o'}
    _kw.update(kwargs)
    _kw['rmax'] = length / 2
    f = kwargs2instance(AstroFrame, _kw)
    d = kwargs2instance(AstroData, _kw)
    f.read(d)
    if np.ndim(d.data) > 2:
        print('Only 2D map is supported.')
        return
    r, z = d.slice(length=length, pa=pa, dx=dx)
    xunit = 'arcsec' if f.dist == 1 else 'au'
    yunit = 'K' if d.Tb else d.bunit
    yquantity = r'$T_b$' if d.Tb else 'intensity'
    if txtfile is not None:
        np.savetxt(txtfile, np.c_[r, z],
                   header=f'x ({xunit}), {yquantity} ({yunit}); '
                   + f'positive x is pa={pa:.2f} deg.')
    if 'xlabel' not in _kw:
        _kw['xlabel'] = f'Offset ({xunit})'
    if 'ylabel' not in _kw:
        _kw['ylabel'] = f'Intensity ({yunit})'
    if 'xlim' not in _kw:
        _kw['xlim'] = [r.min(), r.max()]
    _kw['samexy'] = False
    set_rcparams()
    if fig is None:
        # An external ax already lives in a figure; use that one so
        # savefig/show act on what was actually drawn.
        fig = plt.figure() if ax is None else ax.figure
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)
    # PlotAxes2D consumes its entries of _kw; the rest goes to Axes.plot.
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    ax.plot(r, z, **_kw)
    if d.sigma is not None:
        ax.plot(r, r * 0 + 3 * d.sigma, 'k--')
    pa2.set_xyaxes(ax)
    if getfigax:
        return fig, ax
    close_figure(fig, savefig, show)
def _plot_on_wall(d: AstroData, x: np.ndarray, y: np.ndarray, v: np.ndarray,
                  measure: object, datalist: list,
                  sign: int, axis: int, **kwargs: Any) -> None:
    """Append flat contour meshes of a 2D map drawn on one wall of the 3D box.

    Used by plot3d(). The steps are:
    1. The map in ``kwargs['data']`` is stacked into a two-slice volume along
       ``axis``: one slice holds the map and the other holds zeros.
    2. Surfaces are extracted with ``measure.marching_cubes`` at
       ``level * sigma`` for each level.
    3. The vertex coordinate along ``axis`` is pinned to the box face at the
       maximum (``sign == 1``) or minimum coordinate.
    4. Each surface is appended to ``datalist`` as a plotly ``mesh3d`` dict.

    Args:
        d (AstroData): Cube as read by plot3d(). ``d.data`` is indexed as
            (v, y, x), and the signs of ``d.dx``/``d.dy``/``d.dv`` give the
            original axis directions.
        x, y, v (np.ndarray): Ascending coordinate arrays from plot3d().
        measure (object): The ``skimage.measure`` module, as passed by plot3d().
        datalist (list): Plotly trace dicts; modified in place.
        sign (int): 1 for the positive edge, -1 for the negative edge.
        axis (int): Volume axis normal to the wall: 2 for x, 1 for y, 0 for v.
        **kwargs: Wall settings. Empty means nothing is drawn.
            - Must contain ``data``: a 2D map shaped like ``d.data`` with
              ``axis`` removed.
            - May contain ``levels`` (in units of sigma), ``sigma`` (passed to
              estimate_rms), ``cmap``, and ``alpha``.
    """
    dx, dy, dv = x[1] - x[0], y[1] - y[0], v[1] - v[0]
    s, ds = [x, y, v], [dx, dy, dv]
    # No settings for this wall.
    if kwargs == {}:
        return
    # Expected 2D shape: the cube shape without the wall-normal axis.
    match axis:
        case 2:
            shape = np.shape(d.data[:, :, 0])
        case 1:
            shape = np.shape(d.data[:, 0, :])
        case 0:
            shape = np.shape(d.data[0, :, :])
    if np.shape(kwargs['data']) != shape:
        print('The shape of the 2D data is inconsistent'
              + ' with the shape of the 3D data.')
        return
    _kw = {'levels': [3, 6, 12, 24, 48, 96, 192, 384],
           'sigma': 'hist', 'cmap': 'Jet', 'alpha': 0.3}
    _kw.update(kwargs)
    volume = np.array(_kw['data'], copy=True)
    levels = _kw['levels']
    cmap = _kw['cmap']
    alpha = _kw['alpha']
    # Noise is estimated before NaNs are zeroed.
    sigma = estimate_rms(data=volume, sigma=_kw['sigma'])
    volume[np.isnan(volume)] = 0
    # The map goes into slice 1 for the positive wall and slice 0 for the
    # negative wall; the other slice is zero.
    a = int(sign == -1)
    b = int(sign == 1)
    volume = np.moveaxis([volume * a, volume * b], 0, axis)
    # Match the ascending x, y, v that plot3d() passes in.
    if d.dx < 0:
        volume = volume[:, :, ::-1]
    if d.dy < 0:
        volume = volume[:, ::-1, :]
    if d.dv < 0:
        volume = volume[::-1, :, :]
    for lev in levels:
        if lev * sigma > np.max(volume):
            continue
        vertices, simplices, _, _ = measure.marching_cubes(volume, lev * sigma)
        # Vertex columns are (v, y, x) indices; reverse to (x, y, v)
        # and convert to coordinates.
        Xg, Yg, Zg = [t[0] + i * dt for t, i, dt
                      in zip(s, vertices.T[::-1], ds)]
        # Collapse the slab onto the wall plane.
        match axis:
            case 2:
                Xg = Xg * 0 + (x[-1] if sign == 1 else x[0])
            case 1:
                Yg = Yg * 0 + (y[-1] if sign == 1 else y[0])
            case 0:
                Zg = Zg * 0 + (v[-1] if sign == 1 else v[0])
        i, j, k = simplices.T
        mesh2d = dict(type='mesh3d', x=Xg, y=Yg, z=Zg,
                      i=i, j=j, k=k,
                      intensity=Zg * 0 + lev,
                      colorscale=cmap, reversescale=False,
                      cmin=np.min(levels), cmax=np.max(levels),
                      opacity=alpha, name='', showscale=False)
        datalist.append(mesh2d)
[docs]
def plot3d(levels: list[float] = [3, 6, 12],
           cmap: str = 'jet', alpha: float = 0.08,
           xlabel: str = 'R.A. (arcsec)',
           ylabel: str = 'Dec. (arcsec)',
           vlabel: str = 'Velocity (km/s)',
           xskip: int = 1, yskip: int = 1,
           eye_p: float = 0, eye_i: float = 180,
           xplus: dict = {}, xminus: dict = {},
           yplus: dict = {}, yminus: dict = {},
           vplus: dict = {}, vminus: dict = {},
           outname: str = 'plot3d', show: bool = False,
           return_data_layout: bool = False,
           **kwargs: Any) -> None | dict:
    """Create an interactive Plotly 3D isosurface visualization of a spectral cube.

    For each level, an isosurface at ``level * sigma`` is drawn, together with a faint wireframe of its triangles. Optional 2D maps are drawn on the walls of the box.

    Keyword groups accepted in ``**kwargs``:
        AstroData: Data input and metadata for the cube, such as ``fitsimage``, ``data``, ``x``, ``y``, ``v``, ``beam``, ``Tb``, ``sigma``, ``center``, ``restfreq``, ``cfactor``, and ``bunit``.
        AstroFrame: Data trimming and coordinate-frame options, such as ``rmax``, ``center``, ``dist``, ``xoff``, ``yoff``, ``vsys``, ``vmin``, ``vmax``, ``xflip``, ``yflip``, and ``swapxy``.

    Default keyword values:
        Wall maps: The dictionaries ``xplus``, ``xminus``, ``yplus``, ``yminus``, ``vplus``, and ``vminus`` use ``levels=[3, 6, 12, 24, 48, 96, 192, 384]``, ``sigma='hist'``, ``cmap='Jet'``, and ``alpha=0.3`` when those keys are omitted.

    Args:
        levels (list, optional): Contour levels in the unit of sigma. Defaults to [3,6,12].
        cmap (str, optional): Plotly color scale name. Defaults to 'jet'.
        alpha (float, optional): opacity in plotly. Defaults to 0.08.
        xlabel (str, optional): Defaults to 'R.A. (arcsec)'.
        ylabel (str, optional): Defaults to 'Dec. (arcsec)'.
        vlabel (str, optional): Defaults to 'Velocity (km/s)'.
        xskip (int, optional): Number of pixel to skip. Defaults to 1.
        yskip (int, optional): Number of pixel to skip. Defaults to 1.
        eye_p (float, optional): Azimuthal angle of camera in degrees. Defaults to 0.
        eye_i (float, optional): Inclination angle of camera in degrees. Defaults to 180.
        xplus (dict, optional): 2D data to be plotted on the y-v plane at the positive edge of x. This dictionary must have a key of data and can have keys of levels, sigma, cmap, and alpha. Defaults to {}.
        xminus (dict, optional): See xplus. Defaults to {}.
        yplus (dict, optional): See xplus. Defaults to {}.
        yminus (dict, optional): See xplus. Defaults to {}.
        vplus (dict, optional): See xplus. Defaults to {}.
        vminus (dict, optional): See xplus. Defaults to {}.
        outname (str, optional): Output HTML file name, with or without a trailing '.html'. Existing files with the same name are overwritten by Plotly. Defaults to 'plot3d'.
        show (bool, optional): Passed to ``Figure.write_html`` as ``auto_play``. Defaults to False.
        return_data_layout (bool, optional): Whether to return data and layout for plotly.graph_objs.Figure instead of writing HTML. Defaults to False.

    Returns:
        dict: {'data': data, 'layout': layout}, if return_data_layout=True. Otherwise, no return.
    """
    import plotly.graph_objs as go
    from skimage import measure
    f = kwargs2instance(AstroFrame, kwargs)
    d = kwargs2instance(AstroData, kwargs)
    f.read(d, xskip, yskip)
    volume = np.array(d.data, copy=True)
    x, y, v, sigma = d.x, d.y, d.v, d.sigma
    dx, dy, dv = d.dx, d.dy, d.dv
    volume[np.isnan(volume)] = 0
    # Make every axis ascending so that a vertex index maps to
    # t[0] + index * step with a positive step.
    if dx < 0:
        x, dx, volume = x[::-1], -dx, volume[:, :, ::-1]
    if dy < 0:
        y, dy, volume = y[::-1], -dy, volume[:, ::-1, :]
    if dv < 0:
        v, dv, volume = v[::-1], -dv, volume[::-1, :, :]
    s, ds = [x, y, v], [dx, dy, dv]
    deg = np.radians(1)
    xeye = -np.sin(eye_i * deg) * np.sin(eye_p * deg)
    yeye = -np.sin(eye_i * deg) * np.cos(eye_p * deg)
    zeye = np.cos(eye_i * deg)
    margin = dict(l=0, r=0, b=0, t=0)
    camera = dict(eye=dict(x=xeye, y=yeye, z=zeye), up=dict(x=0, y=1, z=0))
    xaxis = dict(range=[x[0], x[-1]], title=xlabel)
    yaxis = dict(range=[y[0], y[-1]], title=ylabel)
    zaxis = dict(range=[v[0], v[-1]], title=vlabel)
    scene = dict(aspectmode='cube', camera=camera,
                 xaxis=xaxis, yaxis=yaxis, zaxis=zaxis)
    layout = go.Layout(margin=margin, scene=scene, showlegend=False)
    data = []
    for lev in levels:
        if lev * sigma > np.max(volume):
            continue
        vertices, simplices, _, _ = measure.marching_cubes(volume, lev * sigma)
        # Vertex columns are (v, y, x) indices; reverse to (x, y, v)
        # and convert to coordinates.
        Xg, Yg, Zg = [t[0] + i * dt for t, i, dt
                      in zip(s, vertices.T[::-1], ds)]
        i, j, k = simplices.T
        mesh = dict(type='mesh3d', x=Xg, y=Yg, z=Zg, i=i, j=j, k=k,
                    intensity=Zg * 0 + lev,
                    colorscale=cmap, reversescale=False,
                    cmin=np.min(levels), cmax=np.max(levels),
                    opacity=alpha, name='', showscale=False)
        data.append(mesh)
        # Closed outline of each triangle (corners 0, 1, 2, 0). The None
        # entries break the plotly line between triangles.
        Xe, Ye, Ze = [], [], []
        for t in vertices[simplices]:
            Xe += [x[0] + dx * t[n % 3][2] for n in range(4)] + [None]
            Ye += [y[0] + dy * t[n % 3][1] for n in range(4)] + [None]
            Ze += [v[0] + dv * t[n % 3][0] for n in range(4)] + [None]
        lines = dict(type='scatter3d', x=Xe, y=Ye, z=Ze,
                     mode='lines', opacity=0.04, visible=True,
                     name='', line=dict(color='rgb(0,0,0)', width=1))
        data.append(lines)
    klist = [xplus, xminus, yplus, yminus, vplus, vminus]
    slist = [1, -1, 1, -1, 1, -1]
    alist = [2, 2, 1, 1, 0, 0]
    for kw, sign, axis in zip(klist, slist, alist):
        _plot_on_wall(d, x, y, v, measure, data, sign, axis, **kw)
    if return_data_layout:
        return {'data': data, 'layout': layout}
    fig = go.Figure(data=data, layout=layout)
    # Strip only a trailing '.html'; str.replace would also remove '.html'
    # from the middle of the name.
    fig.write_html(file=outname.removesuffix('.html') + '.html',
                   auto_play=show)