import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from dataclasses import dataclass
from matplotlib.patches import Ellipse, Rectangle
from typing import TypeVar
from plotastrodata.analysis_utils import AstroData, AstroFrame
from plotastrodata.coord_utils import (coord2xy, xy2coord,
get_hmdm, get_min, get_sec)
from plotastrodata.noise_utils import estimate_rms
from plotastrodata.other_utils import (close_figure, listing,
reform_grid, reform_data)
# Generic type variable: kwargs2instance(cls, ...) returns an instance of cls.
T = TypeVar('T')
plt.ioff()  # force to turn off interactive mode
[docs]
def set_rcparams(fontsize: int = 18, nancolor: str = 'w',
                 dpi: int = 256) -> None:
    """Apply the package's preferred matplotlib rcParams.

    Args:
        fontsize (int, optional): Value for plt.rcParams['font.size']. Defaults to 18.
        nancolor (str, optional): Value for plt.rcParams['axes.facecolor'], the color seen where data are masked. Defaults to 'w'.
        dpi (int, optional): Value for plt.rcParams['savefig.dpi']. Defaults to 256.
    """
    # Keys are written in this exact order; ticks go on all four sides,
    # crossing the axis line ('inout').
    settings = {
        'axes.facecolor': nancolor,
        'font.size': fontsize,
        'savefig.dpi': dpi,
        'legend.fontsize': 15,
        'axes.linewidth': 1.5,
        'xtick.direction': 'inout',
        'ytick.direction': 'inout',
        'xtick.top': True,
        'ytick.right': True,
        'xtick.major.size': 10,
        'ytick.major.size': 10,
        'xtick.minor.size': 6,
        'ytick.minor.size': 6,
        'xtick.major.width': 1.5,
        'ytick.major.width': 1.5,
        'xtick.minor.width': 1.5,
        'ytick.minor.width': 1.5,
    }
    for key, value in settings.items():
        plt.rcParams[key] = value
[docs]
def logticks(ticks: list[float], lim: list[float, float]
             ) -> tuple[list[float], list[str]]:
    """Build tick positions and labels for a logarithmic axis.

    Two extra ticks are added to ``ticks``:
    - the first multiple of the leading power of ten above lim[0];
    - the last multiple of the leading power of ten at or below lim[1].

    Args:
        ticks (list): List of ticks.
        lim (list): [min, max].

    Returns:
        tuple: (new ticks sorted without duplicates, new labels). Ticks
        below 1 are labelled with str(value); the others with str(int(value)).
    """
    def _edge(value, offset):
        # Snap value to a multiple of its leading power of ten (floor + offset).
        order = int(np.floor(np.log10(value)))
        step = 10**order
        snapped = (value // step + offset) * step
        return np.round(snapped, max(-order, 0))

    lower = _edge(lim[0], 1)
    upper = _edge(lim[1], 0)
    newticks = np.sort(np.unique(np.r_[lower, ticks, upper]))
    newlabels = [str(t) if t < 1 else str(int(t)) for t in newticks]
    return newticks, newlabels
[docs]
def logcbticks(vmin: float = 1e-3, vmax: float = 1e3
               ) -> tuple[np.ndarray, np.ndarray]:
    """Make nice ticks for a log color bar.

    Ticks are 1..9 times every power of ten spanning [vmin, vmax]. Only the
    1, 2 and 5 multiples get a text label (with as many decimals as the
    negative exponent needs); the others get ''.

    Args:
        vmin (float, optional): Minimum value. Defaults to 1e-3.
        vmax (float, optional): Maximum value. Defaults to 1e3.

    Returns:
        tuple: (ticks, ticklabels), both restricted to vmin <= tick <= vmax.
    """
    lo = int(np.floor(np.log10(vmin)))
    hi = int(np.ceil(np.log10(vmax)))
    decades = np.logspace(lo, hi, hi - lo + 1)
    ticks = np.ravel(np.outer(decades, np.arange(1, 10)))
    labels = []
    for power in range(lo, hi + 1):
        ndigits = abs(min(power, 0))
        for mantissa in range(1, 10):
            if mantissa in (1, 2, 5):
                labels.append(f'{mantissa * 10**power:.{ndigits}f}')
            else:
                labels.append('')
    labels = np.ravel(labels)
    keep = (vmin <= ticks) * (ticks <= vmax)
    return ticks[keep], labels[keep]
[docs]
def get_figsize(xmin: float, xmax: float, ymin: float, ymax: float,
figsize: tuple | None = None,
ncols: int = 1, nrows: int = 1, nchan: int = 1
) -> tuple[float, float]:
"""Get a nice figsize (tuple) with the given x and y ranges.
Args:
xmin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
xmax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
ymin (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
ymax (float): The figsize is based on the aspect ratio of (ymax - ymin) / (xmax - xmin).
figsize (tuple | None, optional): If this is not None, this will be the output as is. Defaults to None.
ncols (int, optional): The number of columns for the channel map. Defaults to 1.
nrows (int, optional): The number of rows for the channel map. Defaults to 1.
nchan (int, optional): The number of total channels for the channel map. Defaults to 1.
Returns:
tuple[float, float]: figsize for matplotlib.pyplot.Figure.
"""
if figsize is not None:
return figsize
sqrt_a = (ymax - ymin) / (xmax - xmin)
sqrt_a = np.sqrt(np.abs(sqrt_a))
if nchan == 1:
figsize = (7 / sqrt_a, 5 * sqrt_a)
else:
figsize = (ncols * 2 / sqrt_a, max(nrows*2, 3) * sqrt_a)
return figsize
def _get_gridwidth(mode: str, rmax: float) -> tuple[float, int]:
# 10^1.45 / 15 ~ 2 grids for R.A.
# 10^0.45 ~ 3 grids for Dec.
scale = 1.45 if mode == 'ra' else 0.45
x = np.log10(2. * rmax) - scale
order = np.floor(x)
frac = x - order
if frac <= 0.33:
base = 1
elif frac <= 0.68:
base = 2
else:
base = 5
return base * 10**order, int(order)
def _get_v(p, v: np.ndarray | None = None,
           restfreq: float | None = None,
           vskip: int = 1) -> np.ndarray:
    """Return the channel velocities used to lay out the panels.

    If p.fitsimage is set, the velocities are read from it (through
    p.read) and override ``v``. None becomes np.array([0]). Longer arrays
    are passed through reform_grid with p.vmin/p.vmax, then thinned by vskip.
    """
    if p.fitsimage is not None:
        data = AstroData(fitsimage=p.fitsimage,
                         restfreq=restfreq, sigma=None)
        p.read(data)
        v = data.v
    if v is None:
        return np.array([0])
    if len(v) <= 1:
        return v
    v = reform_grid(v=v, vmin=p.vmin, vmax=p.vmax)
    return v[::vskip]
def _get_nij2ch(nrows: int = 1, ncols: int = 1) -> object:
def nij2ch(n: int, i: int, j: int) -> int:
return n*nrows*ncols + i*ncols + j
return nij2ch
def _get_ch2nij(nrows: int = 1, ncols: int = 1) -> object:
def ch2nij(ch: int) -> tuple[int, int, int]:
n = ch // (nrows*ncols)
i = (ch - n*nrows*ncols) // ncols
j = ch % ncols
return n, i, j
return ch2nij
def _get_vskipfill(nv: float, v_org: np.ndarray, vskip: int) -> object:
    """Return a closure calling reform_data with nv, v_org and vskip fixed."""
    fixed = {'nv': nv, 'v_org': v_org, 'vskip': vskip}

    def vskipfill(c: np.ndarray, v_in: np.ndarray) -> np.ndarray:
        return reform_data(c=c, v_in=v_in, **fixed)

    return vskipfill
[docs]
@dataclass
class Stretcher():
"""Arguments and methods related to the stretch in PlotAstroData.add_color() and add_rgb().
Args:
stretch (str, optional): 'log', 'asinh', 'power', or 'linear'. Any other means 'linear'. 'log' means the mapped data are logarithmic. 'asinh' means the mapped data are arc sin hyperbolic. 'power' means the mapped data are power-law (see also stretchpower). Defaults to 'linear'.
stretchscale (float, optional): The output is asinh(data / stretchscale). Defaults to None.
stretchpower (float, optional): The output is data**stretchpower / stretchpower. 1 means the linear scale, while 0 means the logarithmic scale. Defaults to 0.5.
vmin (float, optional): The minimum value for Axes.pcolormesh() of matplotlib. Defaults to None.
vmax (float, optional): The maximum value for Axes.pcolormesh() of matplotlib. Defaults to None.
sigma (float, optional): Noise level. Defaults to 0.
"""
stretch: str = 'linear'
stretchscale: float | None = None
stretchpower: float = 0.5
vmin: float | None = None
vmax: float | None = None
sigma: float = 0
def __post_init__(self):
self.n = 1 if type(self.stretch) is str else len(self.stretch)
stretch = self.stretch
stsc = self.stretchscale
vmin = self.vmin
sigma = self.sigma
if self.n == 1:
if stsc is None:
self.stretchscale = sigma
if (stretch == 'log' or stretch == 'power') and vmin is None:
self.vmin = sigma
else:
getsigma = np.equal(stsc, None)
self.stretchscale = np.where(getsigma, sigma, stsc)
islog = np.equal(stretch, 'log')
ispower = np.equal(stretch, 'power')
novmin = np.equal(vmin, None)
getsigma = (islog + ispower) * novmin
self.vmin = np.where(getsigma, sigma, vmin)
[docs]
def do(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
"""Get the stretched values.
Args:
x (list | np.ndarray): Input array in the linear scale.
i (int): Which element is used in the case where the stretch parameters are lists.
Returns:
np.ndarray: Output stretched array.
"""
st = self.stretch[i] if self.n > 1 else self.stretch
stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
t = np.array(x)
match st:
case 'log':
t = np.log10(t) # To be consistent with logcbticks().
case 'asinh':
t = np.arcsinh(t / stsc)
case 'power':
p = 1e-6 if stpw == 0 else stpw
t = t**p / p
return t
[docs]
def undo(self, x: list | np.ndarray, i: int = 0) -> np.ndarray:
"""Get the linear values from the stretched values.
Args:
x (list | np.ndarray): Input stretched array.
i (int): Which element is used in the case where the stretch parameters are lists.
Returns:
np.ndarray: Output array in the linear scale.
"""
st = self.stretch[i] if self.n > 1 else self.stretch
stsc = self.stretchscale[i] if self.n > 1 else self.stretchscale
stpw = self.stretchpower[i] if self.n > 1 else self.stretchpower
t = np.array(x)
match st:
case 'log':
t = 10**t # To be consistent with logcbticks().
case 'asinh':
t = np.sinh(t) * stsc
case 'power':
p = 1e-6 if stpw == 0 else stpw
t = (t * p)**(1 / p)
return t
[docs]
def set_minmax(self, data: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Set vmin and vmax for color pcolormesh and RGB maps.
Args:
data (np.ndarray): 2D/3D data to plot.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: (Clipped stretched data, new vmin, new vmax).
"""
single = self.n == 1
vminout = [self.vmin] if single else self.vmin
vmaxout = [self.vmax] if single else self.vmax
dataout = [data] if single else data
for i, (c, v0, v1) in enumerate(zip(dataout, vminout, vmaxout)):
dataout[i] = cout = self.do(c.clip(v0, v1), i)
vminout[i] = np.nanmin(cout)
vmaxout[i] = np.nanmax(cout)
if single:
dataout = dataout[0]
vminout = vminout[0]
vmaxout = vmaxout[0]
self.vmin = vminout
self.vmax = vmaxout
return dataout, vminout, vmaxout
[docs]
class Beam():
"""Arguments for PlotAstroData.add_beam().
Args:
show_beam (bool, optional): Defaults to True.
beam (list, optional): [bmaj, bmin, bpa]. This may be a list of list. Defaults to [None, None, None].
beamcolor (str, optional): matplotlib color. This may be a list of str. Defaults to 'gray'.
beampos (list, optional): Relative position. This may be a list of list or a list of None. Defaults to None.
beam_kwargs (dict, optional): Additional arguments for matplotlib.patches. Defaults to {}.
"""
def __init__(self,
show_beam: bool = True,
beam: list[float | None] = [None] * 3,
beamcolor: str = 'gray',
beampos: list[float] | None = None,
beam_kwargs: dict = {}):
self.show_beam = show_beam
self.beam = beam
self.beamcolor = beamcolor
self.beampos = beampos
self.beam_kwargs = beam_kwargs
[docs]
def todict(self):
tmp = {'show_beam': self.show_beam,
'beam': self.beam,
'beamcolor': self.beamcolor,
'beampos': self.beampos}
tmp.update(self.beam_kwargs)
return tmp
[docs]
@dataclass
class PlotAxes2D():
    """Use Axes.set_* to adjust x and y axes.

    set_xyaxes() stores the ticks and labels it derives back on the instance.

    Args:
        samexy (bool, optional): True supports same ticks between x and y. Defaults to True.
        loglog (float, optional): If a float is given, plot on a log-log plane, and xlim=(xmax / loglog, xmax) and so does ylim. Defaults to None.
        xscale (str, optional): Defaults to 'linear'.
        yscale (str, optional): Defaults to 'linear'.
        xlim (list, optional): Defaults to None.
        ylim (list, optional): Defaults to None.
        xlabel (str, optional): Defaults to None.
        ylabel (str, optional): Defaults to None.
        xticks (list, optional): Defaults to None.
        yticks (list, optional): Defaults to None.
        xticklabels (list, optional): Defaults to None.
        yticklabels (list, optional): Defaults to None.
        xticksminor (list or int, optional): If int, int times more than xticks. Defaults to None.
        yticksminor (list or int, optional): If int, int times more than yticks. Defaults to None.
        grid (dict, optional): True means merely grid(). Defaults to None.
        aspect (dict or float, optional): Defaults to None.
    """
    samexy: bool = True
    loglog: float | None = None
    xscale: str = 'linear'
    yscale: str = 'linear'
    xlim: list | None = None
    ylim: list | None = None
    xlabel: str | None = None
    ylabel: str | None = None
    xticks: list | None = None
    yticks: list | None = None
    xticklabels: list | None = None
    yticklabels: list | None = None
    xticksminor: list | int | None = None
    yticksminor: list | int | None = None
    grid: dict | None = None
    aspect: dict | float | None = None

    @staticmethod
    def _is_count(value) -> bool:
        """True for an integer subdivision factor (bool excluded)."""
        return (isinstance(value, (int, np.integer))
                and not isinstance(value, bool))

    @staticmethod
    def _minor_ticks(major, factor: int) -> np.ndarray:
        """Split each major interval into `factor` parts, one interval beyond each end."""
        dt = major[1] - major[0]
        t = np.r_[major[0] - dt, major, major[-1] + dt]
        num = factor * (len(t) - 1) + 1
        return np.linspace(t[0], t[-1], num)

    def set_xyaxes(self, ax):
        """Apply the stored axis settings to the matplotlib Axes `ax`."""
        if self.loglog is not None:
            self.xscale = 'log'
            self.yscale = 'log'
            self.samexy = True
            # New lists: do not write into the caller's xlim/ylim (which may
            # also be tuples).
            if self.xlim is not None:
                self.xlim = [self.xlim[1] / self.loglog, self.xlim[1]]
            if self.ylim is not None:
                self.ylim = [self.ylim[1] / self.loglog, self.ylim[1]]
        ax.set_xscale(self.xscale)
        ax.set_yscale(self.yscale)
        if self.samexy:
            ax.set_xticks(ax.get_yticks())
            ax.set_yticks(ax.get_xticks())
            ax.set_aspect(1)
        if self.xticks is None:
            self.xticks = ax.get_xticks()
            # logticks() needs the limits; without them keep automatic ticks.
            if self.xscale == 'log' and self.xlim is not None:
                self.xticks, self.xticklabels \
                    = logticks(self.xticks, self.xlim)
        if self.yticks is None:
            self.yticks = ax.get_yticks()
            if self.yscale == 'log' and self.ylim is not None:
                self.yticks, self.yticklabels \
                    = logticks(self.yticks, self.ylim)
        ax.set_xticks(self.xticks)
        ax.set_yticks(self.yticks)
        if self.xticksminor is not None:
            if self._is_count(self.xticksminor):
                self.xticksminor = self._minor_ticks(ax.get_xticks(),
                                                     self.xticksminor)
            ax.set_xticks(self.xticksminor, minor=True)
        if self.yticksminor is not None:
            if self._is_count(self.yticksminor):
                self.yticksminor = self._minor_ticks(ax.get_yticks(),
                                                     self.yticksminor)
            ax.set_yticks(self.yticksminor, minor=True)
        if self.xticklabels is not None:
            ax.set_xticklabels(self.xticklabels)
        if self.yticklabels is not None:
            ax.set_yticklabels(self.yticklabels)
        if self.xlabel is not None:
            ax.set_xlabel(self.xlabel)
        if self.ylabel is not None:
            ax.set_ylabel(self.ylabel)
        if self.xlim is not None:
            ax.set_xlim(*self.xlim)
        if self.ylim is not None:
            ax.set_ylim(*self.ylim)
        if self.grid is not None:
            ax.grid(**({} if self.grid is True else self.grid))
        if self.aspect is not None:
            if type(self.aspect) is dict:
                ax.set_aspect(**self.aspect)
            else:
                ax.set_aspect(self.aspect)
[docs]
def kwargs2instance(cls: type[T], kw: dict) -> T:
    """Get an instance and remove its arguments from kwargs.

    The accepted keys are the instance attributes of a default-constructed
    ``cls`` (``vars(cls(...))``). Matching entries of ``kw`` are passed to
    ``cls``. They are then popped from ``kw`` in place, except the keys listed
    in ``exkeys`` below, which stay in ``kw``.

    Args:
        cls (class): Class to make the instance.
        kw (dict): Parameters to make the instance. Modified in place.

    Returns:
        instance: an instance of cls made from the parameters in kwargs.
    """
    # AstroData gets a placeholder data array for its default construction
    # (presumably it needs data or a FITS file; not visible here).
    kw0 = {'data': np.zeros((2, 2))} if cls is AstroData else {}
    # Keys intentionally left in kw, presumably for later consumers.
    if cls is AstroFrame:
        exkeys = {'fitsimage', 'center'}
    elif cls is Stretcher:
        exkeys = {'vmin', 'vmax'}
    elif cls is Beam:
        exkeys = {'beam'}
    else:
        exkeys = set()
    keys = vars(cls(**kw0)).keys()
    tmp = {k: kw[k] for k in keys if k in kw}
    for k in keys - exkeys:
        kw.pop(k, None)
    return cls(**tmp)
[docs]
class PlotAstroData(AstroFrame):
"""Make a figure from 2D/3D FITS files or 2D/3D arrays.
Basic rules ---
For 3D data, a 1D velocity array or a FITS file with a velocity axis must be given to set up channels in each page.
For 2D/3D data, the spatial center can be read from a FITS file or manually given.
len(v)=1 (default) means to make a 2D figure.
Spatial lengths are in the unit of arcsec, or au if dist (!= 1) is given.
Angles are in the unit of degree.
For region, line, arrow, label, and marker, a single input can be treated without a list, e.g., anglelist=60, as well as anglelist=[60].
Each element of poslist supposes a text coordinate like '01h23m45.6s 01d23m45.6s' or a list of relative x and y like [0.2, 0.3] (0 is left or bottom, 1 is right or top).
Parameters for original methods in matplotlib.axes.Axes can be used as kwargs; see the default _kw for reference.
Position-velocity diagrams (pv=True) does not yet suppot region, line, arrow, and segment because the units of abscissa and ordinate are different.
kwargs is the arguments of AstroFrame to define plotting ranges.
Args:
v (np.ndarray, optional): Used to set up channels if fitsimage not given. Defaults to None.
vskip (int, optional): How many channels are skipped. Defaults to 1.
veldigit (int, optional): How many digits after the decimal point. Defaults to 2.
restfreq (float, optional): Used for velocity and brightness T. Defaults to None.
channelnumber (int, optional): Specify a channel number to make 2D maps. Defaults to None.
nrows (int, optional): Used for channel maps. Defaults to 4.
ncols (int, optional): Used for channel maps. Defaults to 6.
fontsize (int, optional): rc_Params['font.size']. None means 18 (2D) or 12 (3D). Defaults to None.
nancolor (str, optional): Color for masked regions. Defaults to white.
dpi (int, optional): Dot per inch for plotting an image. Defaults to 256.
figsize (tuple, optional): Defaults to None.
fig (optional): External plt.figure(). Defaults to None.
ax (optional): External fig.add_subplot(). Defaults to None.
"""
def __init__(self,
             v: np.ndarray | None = None, vskip: int = 1,
             veldigit: int = 2, restfreq: float | None = None,
             channelnumber: int | None = None,
             nrows: int = 4, ncols: int = 6,
             fontsize: int | None = None,
             nancolor: str = 'w', dpi: int = 256,
             figsize: tuple[float, float] | None = None,
             fig: object | None = None, ax: object | None = None,
             **kwargs) -> None:
    """Set up the plotting frame, the figure page(s), and one Axes per channel.

    kwargs are passed to AstroFrame.__init__() to define the plotting
    ranges. The other arguments are described in the class docstring.
    """
    super().__init__(**kwargs)
    internalfig = fig is None
    internalax = ax is None
    v = _get_v(p=self, v=v, restfreq=restfreq, vskip=vskip)
    nv = len(v)  # number of channels with a label
    if self.pv or len(v) == 1 or channelnumber is not None:
        # A single panel: PV diagram, 2D map, or one selected channel.
        nrows = ncols = npages = nchan = 1
    else:
        npages = int(np.ceil(nv / nrows / ncols))
        nchan = npages * nrows * ncols
        # Presumably extends v by nchan - nv entries to fill the last page.
        v = reform_grid(v, k1=nchan - nv)
    nij2ch = _get_nij2ch(nrows=nrows, ncols=ncols)
    ch2nij = _get_ch2nij(nrows=nrows, ncols=ncols)
    if fontsize is None:
        fontsize = 18 if nchan == 1 else 12
    set_rcparams(fontsize=fontsize, nancolor=nancolor, dpi=dpi)
    ax = np.empty(nchan, dtype=object) if internalax else [ax]
    figsize = get_figsize(xmin=self.xmin, xmax=self.xmax,
                          ymin=self.ymin, ymax=self.ymax,
                          figsize=figsize,
                          ncols=ncols, nrows=nrows, nchan=nchan)
    need_vlabel = nchan > 1 or type(channelnumber) is int
    for ch in range(nchan):
        n, i, j = ch2nij(ch)
        if internalfig and i == 0 and j == 0:
            # Figure number n is page n. Fetch it at the first panel of each
            # page so that `fig` always is the current page, even when a
            # figure with that number already exists. The former check
            # `n not in plt.get_fignums()` then left `fig` as None or as the
            # previous page.
            fig = plt.figure(n, figsize=figsize)
            if need_vlabel:
                fig.subplots_adjust(hspace=0, wspace=0,
                                    right=0.87, top=0.87)
        if internalax:
            # Share x with the panel above and y with the panel to the left.
            sharex = ax[nij2ch(n, i - 1, j)] if i > 0 else None
            sharey = ax[nij2ch(n, i, j - 1)] if j > 0 else None
            ax[ch] = fig.add_subplot(nrows, ncols, i*ncols + j + 1,
                                     sharex=sharex, sharey=sharey)
        if need_vlabel and ch < nv:
            vlabel = v[ch if channelnumber is None else channelnumber]
            ax[ch].text(0.9 * self.rmax, 0.7 * self.rmax,
                        rf'${vlabel:.{veldigit}f}$', color='black',
                        backgroundcolor='white', zorder=20)
    self.fig = None if internalfig else fig
    self.ax = ax
    self.rowcol = nrows * ncols
    self.npages = npages
    self.allchan = np.arange(nv)
    self.bottomleft = nij2ch(np.arange(npages), nrows - 1, 0)
    self.channelnumber = channelnumber
    self.v = v
    self.vskipfill = _get_vskipfill(nv=nv, v_org=v, vskip=vskip)
def _map_init(self, kw: dict) -> tuple:
    """Common process for add_color, add_contour, add_segment, and add_rgb.

    The steps are:
    - consume the Beam arguments in kw;
    - merge the rest into self._kw, which the caller sets beforehand;
    - read the data with self.read();
    - draw the beam with add_beam().

    xskip and yskip (int) mean spatial pixel skips, which default to 1.

    Args:
        kw (dict): kwargs input for each method.

    Returns:
        tuple: (data, x, y, v, sigma, bunit, self._kw, singlepix), where
        singlepix is True when the data have no dx or dy.
    """
    beam_args = kwargs2instance(Beam, kw)
    self._kw.update(kw)
    xskip = self._kw.pop('xskip', 1)
    yskip = self._kw.pop('yskip', 1)
    d = kwargs2instance(AstroData, self._kw)
    self.read(d, xskip, yskip)
    self.sigma = d.sigma
    singlepix = d.dx is None or d.dy is None
    if len(d.beam) == 4:
        # One beam per input (presumably amp/ang/U/Q of add_segment). Use
        # the first complete one. Fall back to an empty beam instead of
        # leaking StopIteration when none is complete.
        beam_args.beam = self.beam = next(
            (bm for bm in d.beam if bm is not None and None not in bm),
            [None] * 3)
    else:
        beam_args.beam = self.beam = d.beam
    self.add_beam(**beam_args.todict())
    return (d.data, d.x, d.y, d.v, d.sigma, d.bunit,
            self._kw, singlepix)
[docs]
def add_region(self, patch: str = 'ellipse',
               poslist: list[str | list[float, float]] = [],
               majlist: list[float] = [], minlist: list[float] = [],
               palist: list[float] = [],
               include_chan: list[int] | None = None,
               **kwargs) -> None:
    """Use add_patch() and Rectangle or Ellipse of matplotlib.

    Each region is centred on its position. The patch is rotated by
    pa * self.xdir degrees, with the minor axis as width and the major axis
    as height.

    Args:
        patch (str, optional): 'ellipse' or 'rectangle'. Defaults to 'ellipse'.
        poslist (list, optional): Text or relative center. Defaults to [].
        majlist (list, optional): Ellipse major axis. Defaults to [].
        minlist (list, optional): Ellipse minor axis. Defaults to [].
        palist (list, optional): Position angle (north to east). Defaults to [].
        include_chan (list, optional): None means all. Defaults to None.
    """
    _kw = {'facecolor': 'none', 'edgecolor': 'gray',
           'linewidth': 1.5, 'zorder': 10}
    _kw.update(kwargs)
    if include_chan is None:
        include_chan = self.allchan
    if patch not in ['rectangle', 'ellipse']:
        print('Only patch=\'rectangle\' or \'ellipse\' supported. ')
        return
    z = listing(*self.pos2xy(poslist), minlist, majlist, palist)
    for x, y, width, height, angle in zip(*z):
        rot = angle * self.xdir  # rotation given to matplotlib (degrees)
        for idx, axnow in enumerate(self.ax):
            ch = self.channelnumber if type(self.channelnumber) is int else idx
            if ch not in include_chan:
                continue
            if self.fig is None:
                # Select this Axes' page. Use the Axes index, not
                # channelnumber, so that no stray figure is created.
                plt.figure(idx // self.rowcol)
            if patch == 'rectangle':
                # Rectangle rotates by `rot` about its corner xy. Shift xy so
                # that the rotated rectangle is centred on (x, y).
                a = np.radians(rot)
                xp = x - (width * np.cos(a) - height * np.sin(a)) / 2.
                yp = y - (width * np.sin(a) + height * np.cos(a)) / 2.
                p = Rectangle
            else:
                xp, yp = x, y
                p = Ellipse
            p = p((xp, yp), width=width, height=height, angle=rot, **_kw)
            axnow.add_patch(p)
[docs]
def add_beam(self, **kwargs) -> None:
    """Use add_region(). kwargs may include the arguments of Beam, except for beam_kwargs, to specify the beam appearance. Those arguments may be a list of each format.

    Remaining kwargs are passed to add_region() for the patch. The beam is
    drawn in every channel for an animation (channelnumber given),
    otherwise only in the bottom-left panel of each page.
    """
    b = kwargs2instance(Beam, kwargs)
    show_beam, beamcolor, beampos = b.show_beam, b.beamcolor, b.beampos
    beam = b.beam
    # kwargs2instance leaves 'beam' in kwargs. Drop it so it is not passed
    # to the patch; pop() also works when no beam was given.
    kwargs.pop('beam', None)
    if not show_beam:
        return
    if beam is None:
        beam = [None] * 3
    animation = self.channelnumber is not None
    include_chan = self.allchan if animation else self.bottomleft
    patch = 'rectangle' if self.pv else 'ellipse'
    blist = [beam] if np.ndim(beam) == 1 else beam
    n = len(blist)
    bclist = beamcolor if type(beamcolor) is list else [beamcolor] * n
    # beampos is per beam when it is a list of None or a 2D list;
    # otherwise one position is shared by all beams.
    all_none = (isinstance(beampos, (list, tuple)) and len(beampos) > 0
                and all(p is None for p in beampos))
    islist = all_none or np.ndim(beampos) == 2
    bplist = beampos if islist else [beampos] * n
    for (bmaj, bmin, bpa), bc, bp in zip(blist, bclist, bplist):
        if None in [bmaj, bmin, bpa]:
            print('No beam to plot.')
            continue
        a = max(0.35 * bmaj / self.rmax, 0.1)
        if bp is None:
            bp = [a, 0.1 if self.pv else a]
        if self.swapxy:
            bp = np.transpose(bp)
            bpa = 90 - bpa
        _kw = {'facecolor': bc, 'edgecolor': None}
        _kw.update(kwargs)
        self.add_region(patch=patch, poslist=bp,
                        majlist=bmaj, minlist=bmin, palist=bpa,
                        include_chan=include_chan,
                        **_kw)
[docs]
def add_marker(self, poslist: list[str | list[float, float]] = [],
               include_chan: list[int] | None = None, **kwargs) -> None:
    """Draw markers with Axes.plot of matplotlib.

    Args:
        poslist (list, optional): Text or relative positions. Defaults to [].
        include_chan (list, optional): Channels to draw in; None means all. Defaults to None.
    """
    style = {'marker': '+', 'ms': 10, 'mfc': 'gray',
             'mec': 'gray', 'mew': 2, 'alpha': 1, 'zorder': 10}
    style.update(kwargs)
    targets = self.allchan if include_chan is None else include_chan
    fixed = self.channelnumber if type(self.channelnumber) is int else None
    for idx, axnow in enumerate(self.ax):
        chan = idx if fixed is None else fixed
        if chan not in targets:
            continue
        for x, y in zip(*self.pos2xy(poslist)):
            axnow.plot(x, y, **style)
[docs]
def add_text(self, poslist: list[str | list[float, float]] = [],
             slist: list[str] = [],
             include_chan: list[int] | None = None, **kwargs) -> None:
    """Write text with Axes.text of matplotlib.

    Args:
        poslist (list, optional): Text or relative positions. Defaults to [].
        slist (list, optional): List of text. Defaults to [].
        include_chan (list, optional): Channels to draw in; None means all. Defaults to None.
    """
    style = {'color': 'gray', 'fontsize': 15, 'ha': 'center',
             'va': 'center', 'zorder': 10}
    # Store the long alignment names under the short aliases so that they
    # override the defaults above.
    for long_name, short_name in (('horizontalalignment', 'ha'),
                                  ('verticalalignment', 'va')):
        if long_name in kwargs:
            kwargs[short_name] = kwargs.pop(long_name)
    style.update(kwargs)
    targets = self.allchan if include_chan is None else include_chan
    fixed = self.channelnumber if type(self.channelnumber) is int else None
    for idx, axnow in enumerate(self.ax):
        chan = idx if fixed is None else fixed
        if chan not in targets:
            continue
        for x, y, s in zip(*listing(*self.pos2xy(poslist), slist)):
            axnow.text(x=x, y=y, s=s, **style)
[docs]
def add_line(self, poslist: list[str | list[float, float]] = [],
             anglelist: list[float] = [],
             rlist: list[float] = [], include_chan: list[int] | None = None,
             **kwargs) -> None:
    """Draw line segments with Axes.plot of matplotlib.

    Each segment starts at its position and extends by r toward its angle.

    Args:
        poslist (list, optional): Text or relative positions. Defaults to [].
        anglelist (list, optional): North to east, in degrees. Defaults to [].
        rlist (list, optional): List of radius. Defaults to [].
        include_chan (list, optional): Channels to draw in; None means all. Defaults to None.
    """
    style = {'color': 'gray', 'linewidth': 1.5,
             'linestyle': '-', 'zorder': 10}
    style.update(kwargs)
    targets = self.allchan if include_chan is None else include_chan
    fixed = self.channelnumber if type(self.channelnumber) is int else None
    for idx, axnow in enumerate(self.ax):
        chan = idx if fixed is None else fixed
        if chan not in targets:
            continue
        rad = np.radians(anglelist)
        for x, y, a, r in zip(*listing(*self.pos2xy(poslist), rad, rlist)):
            dx, dy = r * np.sin(a), r * np.cos(a)
            axnow.plot([x, x + dx], [y, y + dy], **style)
[docs]
def add_arrow(self, poslist: list[str | list[float, float]] = [],
              anglelist: list[float] = [],
              rlist: list[float] = [], include_chan: list[int] | None = None,
              **kwargs) -> None:
    """Draw arrows with Axes.quiver of matplotlib.

    Each arrow starts at its position and has length r toward its angle.

    Args:
        poslist (list, optional): Text or relative positions. Defaults to [].
        anglelist (list, optional): North to east, in degrees. Defaults to [].
        rlist (list, optional): List of radius. Defaults to [].
        include_chan (list, optional): Channels to draw in; None means all. Defaults to None.
    """
    style = {'color': 'gray', 'width': 0.012,
             'headwidth': 5, 'headlength': 5, 'zorder': 10}
    style.update(kwargs)
    targets = self.allchan if include_chan is None else include_chan
    fixed = self.channelnumber if type(self.channelnumber) is int else None
    for idx, axnow in enumerate(self.ax):
        chan = idx if fixed is None else fixed
        if chan not in targets:
            continue
        rad = np.radians(anglelist)
        for x, y, a, r in zip(*listing(*self.pos2xy(poslist), rad, rlist)):
            # Data-unit vector: scale=1 with 'xy' units keeps length r.
            axnow.quiver(x, y, r * np.sin(a), r * np.cos(a),
                         angles='xy', scale_units='xy', scale=1,
                         **style)
[docs]
def add_scalebar(self, length: float = 0, label: str = '',
                 color: str = 'gray', barpos: tuple[float, float] = (0.8, 0.12),
                 fontsize: float | None = None, linewidth: float = 3,
                 bbox: dict | None = None) -> None:
    """Use Axes.text and Axes.plot of matplotlib.

    The bar and its label are drawn only in the bottom-left panel of each page.

    Args:
        length (float, optional): In the unit of arcsec. Defaults to 0.
        label (str, optional): Text like '100 au'. Defaults to ''.
        color (str, optional): Same for bar and label. Defaults to 'gray'.
        barpos (tuple, optional): Relative position. Defaults to (0.8, 0.12).
        fontsize (float, optional): None means 20 if there is only one panel, otherwise 15. Defaults to None.
        linewidth (float, optional): Width of the bar. Defaults to 3.
        bbox (dict, optional): bbox of the label text. None means {'alpha': 0}. Defaults to None.
    """
    if length == 0:
        print('No length is given. Skip add_scalebar().')
        return
    if bbox is None:
        bbox = {'alpha': 0}  # fresh dict; avoids a shared mutable default
    if fontsize is None:
        fontsize = 20 if len(self.ax) == 1 else 15
    for ch, axnow in enumerate(self.ax):
        if ch not in self.bottomleft:
            continue
        # Label just below barpos, bar just above it (relative offsets).
        x, y = self.pos2xy([barpos[0], barpos[1] - 0.012])
        axnow.text(x[0], y[0], label, color=color, size=fontsize,
                   ha='center', va='top', bbox=bbox, zorder=10)
        x, y = self.pos2xy([barpos[0], barpos[1] + 0.012])
        axnow.plot([x[0] - length/2., x[0] + length/2.], [y[0], y[0]],
                   '-', linewidth=linewidth, color=color)
def _set_colorbar(self, mappable, ch: int, show_cbar: bool,
                  cblabel: str, cbformat: str,
                  cbticks: list | None, cbticklabels: list | None,
                  cblocation: str,
                  cblabelfontsize: int, cbtickfontsize: int,
                  st: Stretcher):
    """Draw and format the colorbar for the page that holds channel ch.

    Ticks are given in linear units and converted with st.do(). They are
    kept only within [st.vmin, st.vmax] (stretched units). Default labels
    are st.undo(ticks) formatted with cbformat.
    """
    if not show_cbar:
        return
    fig = plt.figure(ch // self.rowcol) if self.fig is None else self.fig
    if len(self.ax) == 1:
        ax = self.ax[ch]
        cb = fig.colorbar(mappable[ch], ax=ax, label=cblabel,
                          format=cbformat, location=cblocation)
    else:
        # Shared colorbar axes in the right margin of this page. Add it to
        # `fig` itself: plt.axes() would use the current figure, which is
        # not `fig` when an external figure was given.
        cax = fig.add_axes([0.88, 0.105, 0.015, 0.77])
        cb = fig.colorbar(mappable[ch], cax=cax, label=cblabel,
                          format=cbformat)
    cb.ax.tick_params(labelsize=cbtickfontsize)
    font = mpl.font_manager.FontProperties(size=cblabelfontsize)
    cb.ax.yaxis.label.set_font_properties(font)
    if cbticks is None and st.stretch == 'log':
        # st.vmin/st.vmax are log10 values here (see Stretcher.do).
        cbticks, cbticklabels = logcbticks(10**st.vmin, 10**st.vmax)
    cbticks = cb.get_ticks() if cbticks is None else st.do(cbticks)
    cond = (st.vmin <= cbticks) * (cbticks <= st.vmax)
    cbticks = cbticks[cond]
    cb.set_ticks(cbticks)
    if cbticklabels is None:
        cbticklabels = [f'{t:{cbformat[1:]}}' for t in st.undo(cbticks)]
    else:
        cbticklabels = np.array(cbticklabels)[cond]
    cb.set_ticklabels(cbticklabels)
[docs]
def add_color(self,
              show_cbar: bool = True,
              cblabel: str | None = None,
              cbformat: str = '%.1e',
              cbticks: list[float] | None = None,
              cbticklabels: list[str] | None = None,
              cblocation: str = 'right',
              cblabelfontsize: int = 16,
              cbtickfontsize: int = 14,
              **kwargs) -> None:
    """Use Axes.pcolormesh of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs may include the arguments for Stretcher (stretch, stretchscale, and stretchpower) to specify the stretch parameters. kwargs may include arguments of Beam; a dict of beam_kwargs specifies the beam patch in more detail. kwargs may include xskip and yskip.

    Args:
        show_cbar (bool, optional): Show color bar. Defaults to True.
        cblabel (str, optional): Colorbar label. None means the data's bunit. Defaults to None.
        cbformat (str, optional): Format for ticklabels of colorbar. Defaults to '%.1e'.
        cbticks (list, optional): Ticks of colorbar, in linear units. Defaults to None.
        cbticklabels (list, optional): Ticklabels of colorbar. Defaults to None.
        cblocation (str, optional): 'left', 'right', 'top', or 'bottom'. Only for 2D images. Defaults to 'right'.
        cblabelfontsize (int, optional): Fontsize for the colorbar label. This is independent of set_rcparams(). Defaults to 16.
        cbtickfontsize (int, optional): Fontsize for the colorbar ticks. This is independent of set_rcparams(). Defaults to 14.
    """
    self._kw = {'cmap': 'cubehelix', 'alpha': 1,
                'edgecolors': 'none', 'zorder': 1,
                'vmin': None, 'vmax': None}
    c, x, y, v, sigma, bunit, _kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_color.')
        return
    cblabel = bunit if cblabel is None else cblabel
    # Stretcher fills its sigma-based defaults from the noise level.
    _kw['sigma'] = sigma
    st = kwargs2instance(Stretcher, _kw)
    # The data are stretched here, so pcolormesh gets the stretched limits.
    c, cmin, cmax = st.set_minmax(c)
    _kw['vmin'] = cmin
    _kw['vmax'] = cmax
    c = self.vskipfill(c, v)
    if type(self.channelnumber) is int:
        c = [c[self.channelnumber]]
    p = [None] * len(self.ax)
    for ch, (axnow, cnow) in enumerate(zip(self.ax, c)):
        pnow = axnow.pcolormesh(x, y, cnow, **_kw)
        if ch in self.bottomleft:
            p[ch] = pnow
    for ch in self.bottomleft:
        self._set_colorbar(p, ch, show_cbar, cblabel, cbformat,
                           cbticks, cbticklabels, cblocation,
                           cblabelfontsize, cbtickfontsize, st)
[docs]
def add_contour(self,
                levels: list[float] | tuple[float, ...] = (-12, -6, -3, 3, 6, 12, 24, 48, 96, 192, 384),
                **kwargs) -> None:
    """Use Axes.contour of matplotlib. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs may include arguments of Beam; a dict of beam_kwargs specifies the beam patch in more detail. kwargs may include xskip and yskip.

    Args:
        levels (list, optional): Contour levels in the unit of sigma; they are sorted before use. Defaults to (-12,-6,-3,3,6,12,24,48,96,192,384).
    """
    self._kw = {'colors': 'gray', 'linewidths': 1.0, 'zorder': 2}
    c, x, y, v, sigma, _, _kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_contour.')
        return
    c = self.vskipfill(c, v)
    if type(self.channelnumber) is int:
        c = [c[self.channelnumber]]
    for axnow, cnow in zip(self.ax, c):
        axnow.contour(x, y, cnow, np.sort(levels) * sigma, **_kw)
[docs]
def add_segment(self,
                ampfits: str = None, angfits: str = None,
                Ufits: str = None, Qfits: str = None,
                amp: list[np.ndarray] | None = None,
                ang: list[np.ndarray] | None = None,
                stU: list[np.ndarray] | None = None,
                stQ: list[np.ndarray] | None = None,
                ampfactor: float = 1., angonly: bool = False,
                rotation: float = 0.,
                cutoff: float = 3.,
                **kwargs) -> None:
    """Draw line segments (e.g., polarization vectors) with Axes.quiver of matplotlib.

    kwargs must include the arguments of AstroData to specify the data to
    be plotted; fitsimage = [ampfits, angfits, Ufits, Qfits] and
    data = [amp, ang, stU, stQ]. kwargs may include arguments of Beam (a
    dict of beam_kwargs describes the beam patch in more detail), xskip,
    and yskip. When both Stokes U and Q are available, amp and ang are
    recomputed from them.

    Args:
        ampfits (str, optional): Input fits name. Length of segment. Defaults to None.
        angfits (str, optional): Input fits name. North to east. Defaults to None.
        Ufits (str, optional): Input fits name. Stokes U. Defaults to None.
        Qfits (str, optional): Input fits name. Stokes Q. Defaults to None.
        amp (list, optional): Length of segment. Defaults to None.
        ang (list, optional): North to east. Defaults to None.
        stU (list, optional): Stokes U. Defaults to None.
        stQ (list, optional): Stokes Q. Defaults to None.
        ampfactor (float, optional): Length of segment is amp times ampfactor. Defaults to 1..
        angonly (bool, optional): True means amp=1 for all. Defaults to False.
        rotation (float, optional): Segment angle is ang + rotation. Defaults to 0..
        cutoff (float, optional): Used when amp and ang are calculated from Stokes U and Q. In the unit of sigma. Defaults to 3..
    """
    # Defaults for Axes.quiver plus the data sources; _map_init() consumes them with kwargs.
    self._kw = {'angles': 'xy', 'scale_units': 'xy', 'color': 'gray',
                'pivot': 'mid', 'headwidth': 0, 'headlength': 0,
                'headaxislength': 0, 'width': 0.007, 'zorder': 3,
                'fitsimage': [ampfits, angfits, Ufits, Qfits],
                'data': [amp, ang, stU, stQ]}
    maps, x, y, v, sigma, _, quiver_kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_segment.')
        return
    amp, ang, stU, stQ = maps
    sigmaU, sigmaQ = sigma[2:]
    have_stokes = stU is not None and stQ is not None
    if have_stokes:
        # Polarization angle and intensity from Stokes U and Q; faint pixels are masked.
        sigma = (sigmaU + sigmaQ) / 2.
        self.sigma = sigma
        ang = np.degrees(np.arctan2(stU, stQ) / 2.)
        amp = np.hypot(stU, stQ)
        amp[amp < cutoff * sigma] = np.nan
    if amp is None:
        amp = np.ones_like(ang)
    if angonly:
        # 1 for finite nonzero amplitudes, NaN stays NaN.
        amp = np.sign(amp)**2
    amp = amp / np.nanmax(amp)
    length = ampfactor * amp
    theta = np.radians(ang + rotation)
    # Angle is measured from north (+y) toward east (+x direction of sin).
    east = length * np.sin(theta)
    north = length * np.cos(theta)
    east, north = self.vskipfill(east, v), self.vskipfill(north, v)
    if type(self.channelnumber) is int:
        east = [east[self.channelnumber]]
        north = [north[self.channelnumber]]
    # With scale_units='xy', a segment of length 1 spans |x[1] - x[0]| in data units.
    quiver_kw['scale'] = 1. / np.abs(x[1] - x[0])
    for axis, unow, vnow in zip(self.ax, east, north):
        axis.quiver(x, y, unow, vnow, **quiver_kw)
[docs]
def add_rgb(self, **kwargs) -> None:
    """Draw an RGB composite with PIL.Image and Axes.imshow of matplotlib.

    kwargs must include the arguments of AstroData to specify the data to
    be plotted. A three-element array ([red, green, blue]) is expected for
    all arguments, including vmax and vmin, except for xskip, yskip and
    show_beam. kwargs may include the arguments for Stretcher (stretch,
    stretchscale, and stretchpower; a three-element array for each) to
    specify the stretch parameters. kwargs may include arguments of Beam
    (three-element arrays); a single dict of beam_kwargs specifies the beam
    patch in more detail. kwargs may include xskip and yskip.

    Each channel is mapped linearly from [vmin, vmax] onto 0-255 and
    clipped to that range. Pixels that are not finite in every channel
    are drawn in gray (128, 128, 128).
    """
    from PIL import Image
    self._kw = {'vmin': [None] * 3, 'vmax': [None] * 3,
                'stretch': ['linear'] * 3,
                'stretchscale': [None] * 3,
                'stretchpower': [0.5] * 3}
    c, x, y, v, _, _, _kw, singlepix = self._map_init(kwargs)
    if singlepix:
        print('No pixel size. Skip add_rgb.')
        return
    if not (np.shape(c[0]) == np.shape(c[1]) == np.shape(c[2])):
        print('RGB shapes mismatch. Skip add_rgb.')
        return
    st = kwargs2instance(Stretcher, _kw)
    c, cmin, cmax = st.set_minmax(c)
    for i in range(st.n):
        # Map [cmin, cmax] onto the 8-bit range [0, 255].
        if cmax[i] > cmin[i]:
            c[i] = (c[i] - cmin[i]) / (cmax[i] - cmin[i]) * 255
        c[i] = self.vskipfill(c[i], v)
    # (color, channel, y, x) -> (channel, color, y, x), flipped for image row/column order.
    c = np.moveaxis(c, 1, 0)[:, :, ::-self.ydir, ::-self.xdir]
    if type(self.channelnumber) is int:
        # Same channel selection as add_color/add_contour/add_segment.
        c = [c[self.channelnumber]]
    for axnow, rgb in zip(self.ax, c):
        # (color, y, x) -> (y, x, color), the layout Image.fromarray expects.
        pixels = np.moveaxis(np.asarray(rgb, dtype=float), 0, -1)
        bad = ~np.all(np.isfinite(pixels), axis=-1)
        pixels = np.clip(pixels, 0, 255)
        # int() of NaN/inf would raise; show such pixels in gray instead.
        pixels[bad] = 128
        im = Image.fromarray(pixels.astype(np.uint8))
        axnow.imshow(im, extent=[x[0], x[-1], y[0], y[-1]])
        axnow.set_aspect(np.abs((x[-1]-x[0]) / (y[-1]-y[0])))
def _set_axis_shared(self, pa2: PlotAxes2D, title: dict | str | None):
    """Apply pa2 to every panel and optionally add a title.

    Internal helper shared by set_axis() and set_axis_radec().

    Args:
        pa2 (PlotAxes2D): Axis settings built in set_axis() or set_axis_radec().
        title (dict | str | None): str means set_title(str) for 2D or fig.suptitle(str) for 3D; a dict is passed as keyword arguments instead. None means no title.
    """
    for ch, axnow in enumerate(self.ax):
        pa2.set_xyaxes(axnow)
        self.Xlim = pa2.xlim
        self.Ylim = pa2.ylim
        if ch in self.bottomleft:
            continue
        # Panels not listed in self.bottomleft lose tick labels and axis labels.
        plt.setp(axnow.get_xticklabels(), visible=False)
        plt.setp(axnow.get_yticklabels(), visible=False)
        axnow.set_xlabel('')
        axnow.set_ylabel('')
    if len(self.ax) == 1 and self.fig is None:
        plt.figure(0).tight_layout()
    if title is None:
        return
    if len(self.ax) > 1:
        # Channel maps: one suptitle per page.
        suptitle_kw = {'y': 0.9}
        suptitle_kw.update({'t': title} if type(title) is str else title)
        for page in range(self.npages):
            plt.figure(page).suptitle(**suptitle_kw)
    else:
        title_kw = {'label': title} if type(title) is str else title
        self.ax[0].set_title(**title_kw)
[docs]
def set_axis(self, title: dict | str | None = None, **kwargs) -> None:
    """Set offset axes with Axes.set_* of matplotlib.

    kwargs can include the arguments of PlotAxes2D to adjust the x and y
    axes. Unless given in kwargs, the axis labels default to offsets in
    arcsec (when self.dist == 1) or au (otherwise), or to velocity for
    position-velocity diagrams, and the limits default to self.Xlim and
    self.Ylim.

    Args:
        title (dict | str | None): str means set_title(str) for 2D or fig.suptitle(str) for 3D. Defaults to None.
    """
    _kw = dict(kwargs)
    offunit = '(arcsec)' if self.dist == 1 else '(au)'
    if self.pv:
        offlabel = f'Offset {offunit}'
        # Superscript closes before the parenthesis, as in plotprofile().
        vellabel = r'Velocity (km s$^{-1}$)'
        _kw.setdefault('xlabel', vellabel if self.swapxy else offlabel)
        _kw.setdefault('ylabel', offlabel if self.swapxy else vellabel)
        # Offset and velocity have different units, so no equal aspect.
        _kw['samexy'] = False
    else:
        ralabel, declabel = f'R.A. {offunit}', f'Dec. {offunit}'
        _kw.setdefault('xlabel', declabel if self.swapxy else ralabel)
        _kw.setdefault('ylabel', ralabel if self.swapxy else declabel)
    _kw.setdefault('xlim', self.Xlim)
    _kw.setdefault('ylim', self.Ylim)
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    self._set_axis_shared(pa2=pa2, title=title)
[docs]
def set_axis_radec(self, center: str | None = None,
                   xlabel: str = 'R.A. (ICRS)',
                   ylabel: str = 'Dec. (ICRS)',
                   nticksminor: int = 2,
                   grid: dict | None = None,
                   title: dict | str | None = None
                   ) -> None:
    """Label the axes in sexagesimal R.A. and Dec. with Axes.set_* of matplotlib.

    Seven major ticks are placed on each axis around the center. The
    middle tick label is prefixed with the hours/degrees part of the
    coordinate. When self.rmax >= 60, the R.A. seconds of the center are
    floored to a multiple of 5, the Dec. seconds are set to 0, and the
    Dec. ticks are labelled in arcminutes.

    Args:
        center (str, optional): Reference coordinate. If it has three whitespace-separated parts, only the last two are used. Defaults to None, which means self.center, or '00h00m00s 00d00m00s' if that is also None.
        xlabel (str, optional): Defaults to 'R.A. (ICRS)'.
        ylabel (str, optional): Defaults to 'Dec. (ICRS)'.
        nticksminor (int, optional): Interval ratio of major and minor ticks. Defaults to 2.
        grid (dict, optional): True means merely grid(). Defaults to None.
        title (dict | str | None): str means set_title(str) for 2D or fig.suptitle(str) for 3D. Defaults to None.
    """
    if center is None:
        center = self.center
    if center is None:
        center = '00h00m00s 00d00m00s'
    if len(csplit := center.split()) == 3:
        center = f'{csplit[1]} {csplit[2]}'
    if on_min_scale := (self.rmax >= 60.0):
        # On a 5-second grid.
        ra_s = np.floor(float(get_sec(center, 0)) / 5) * 5
        dec_s = 0.0
        ra = get_hmdm(center, 'ra') + f'{ra_s:.1f}s'
        dec = get_hmdm(center, 'dec') + f'{dec_s:.1f}s'
        center = f'{ra} {dec}'

    def get_tickvalues(ticks: np.ndarray, mode: str, no_sec: bool
                       ) -> np.ndarray:
        # Seconds (or minutes when no_sec) of the coordinates at the tick offsets.
        xy = [np.zeros_like(ticks), ticks / 3600.]
        if mode == 'ra':
            xy.reverse()
        tickvalues = xy2coord(xy, center)
        getter = get_min if no_sec else get_sec
        tickvalues = [getter(t, mode) for t in tickvalues]  # str
        tickvalues = np.array(tickvalues, dtype=float)
        # 7-digit precision for practical use.
        tickvalues = np.round(tickvalues, 7)
        return tickvalues

    units = {'ra': {'h': r'$^\mathrm{h}$',
                    'm': r'$^\mathrm{m}$',
                    's': r'.$\hspace{-0.4}^\mathrm{s}$'},
             'dec': {'d': r'$^{\circ}$',
                     'm': r'$^{\prime}$',
                     's': r'.$\hspace{-0.4}^{\prime\prime}$'}}
    dec_center = coord2xy(center)[1]
    # Dec. seconds grow southward for negative declinations. np.sign() would
    # give 0 at exactly Dec. = 0 and put every Dec. tick at offset 0.
    sign_dec = -1.0 if dec_center < 0 else 1.0
    cos_dec = np.cos(np.radians(dec_center))
    intgrid = np.array([-3, -2, -1, 0, 1, 2, 3])
    i_mid = (len(intgrid) - 1) // 2

    def makegrid(mode: str):
        # Major ticks (offsets), minor ticks, and labels for 'ra' or 'dec'.
        second = float(get_sec(center, mode))
        no_sec = on_min_scale and (mode == 'dec')
        # gridwidth is a float like 2 x 10^order (arcsec).
        gridwidth, order = _get_gridwidth(mode, self.rmax)
        # ndigits = -1 is the largest case for 10", 20", ...
        decimals = str(max(-order, 0))
        rounded = round(second, ndigits=max(-order, -1))
        # Get a grid point closest to the input second.
        rounded = round(rounded / gridwidth) * gridwidth
        # Seconds of time -> arcsec on the sky for R.A.; direction sign for Dec.
        factor = 15 * cos_dec if mode == 'ra' else sign_dec
        ticks = (intgrid * gridwidth - second + rounded) * factor
        # nticksminor minor intervals in each of the six major intervals.
        ticksminor = np.linspace(ticks[0], ticks[-1], 6*nticksminor + 1)
        tickvalues = get_tickvalues(ticks, mode, no_sec)
        whole, frac = np.divmod(tickvalues, 1)
        u = units[mode]['m' if no_sec else 's']
        # Zero-padded integer part, unit symbol, then the fractional digits.
        ticklabels = [f'{int(i):02d}{u}' + f'{j:.{decimals}f}'[2:]
                      for i, j in zip(whole % 60, frac)]
        return ticks, ticksminor, ticklabels

    xticks, xticksminor, xticklabels = makegrid('ra')
    yticks, yticksminor, yticklabels = makegrid('dec')
    ra_hm = get_hmdm(xy2coord([xticks[i_mid] / 3600., 0], center), 'ra')
    dec_dm = get_hmdm(xy2coord([0, yticks[i_mid] / 3600.], center), 'dec')
    if on_min_scale:
        # Minutes are already in the Dec. tick labels; keep only degrees.
        dec_dm = dec_dm.split('d')[0] + 'd'
    ra_hm = ra_hm.translate(str.maketrans(units['ra']))
    dec_dm = dec_dm.translate(str.maketrans(units['dec']))
    xticklabels[i_mid] = ra_hm + xticklabels[i_mid]
    yticklabels[i_mid] = dec_dm + '\n' + yticklabels[i_mid]
    pa2 = PlotAxes2D(True, None, 'linear', 'linear',
                     self.Xlim, self.Ylim, xlabel, ylabel,
                     xticks, yticks, xticklabels, yticklabels,
                     xticksminor, yticksminor, grid)
    self._set_axis_shared(pa2=pa2, title=title)
[docs]
def savefig(self, filename: str | None = None,
            show: bool = False, **kwargs) -> None:
    """Save every page with Figure.savefig of matplotlib, then close all figures.

    kwargs are passed to Figure.savefig and override the defaults
    transparent=True and bbox_inches='tight'.

    Args:
        filename (str, optional): Output image file name. With more than one page, '_0', '_1', ... is inserted before the extension (or appended when there is none). None means no file is written. Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.
    """
    import os
    _kw = {'transparent': True, 'bbox_inches': 'tight'}
    _kw.update(kwargs)
    # Re-apply the stored limits to every panel before saving.
    for axnow in self.ax:
        axnow.set_xlim(*self.Xlim)
        axnow.set_ylim(*self.Ylim)
    if isinstance(filename, str):
        # splitext looks only at the last path component, so dotted
        # directory names and extension-less names are handled correctly
        # (str.replace would leave them unchanged and pages would overwrite
        # each other).
        root, ext = os.path.splitext(filename)
        for i in range(self.npages):
            ver = '' if self.npages == 1 else f'_{i:d}'
            fig = plt.figure(i)
            fig.patch.set_alpha(0)
            fig.savefig(f'{root}{ver}{ext}', **_kw)
    if show:
        plt.show()
    plt.close('all')
[docs]
def get_figax(self) -> tuple[object, object]:
    """Hand back the figure and Axes so that plotting can continue outside.

    Only a single panel is supported. With more than one Axes (channel
    maps), a message is printed and None is returned.

    Returns:
        tuple: (fig, ax), where fig is self.fig if it is set and
        plt.figure(0) otherwise, and ax is self.ax[0].
    """
    if len(self.ax) > 1:
        print('PlotAstroData.get_figax() is not supported with channel maps')
        return None
    figure = self.fig
    if figure is None:
        figure = plt.figure(0)
    return figure, self.ax[0]
[docs]
def plotprofile(coords: list[str] | str | None = None,
                xlist: list[float] | None = None,
                ylist: list[float] | None = None,
                ellipse: list[float, float, float] | None = None,
                ninterp: int = 1,
                flux: bool = False, width: int = 1,
                gaussfit: bool = False, gauss_kwargs: dict | None = None,
                title: list[str | dict] | None = None,
                text: list[dict] | None = None,
                nrows: int = 0, ncols: int = 1,
                fig: object | None = None,
                ax: object | None = None,
                getfigax: bool = False,
                savefig: dict | str | None = None,
                show: bool = False,
                **kwargs) -> tuple[object, object]:
    """Use Axes.plot of matplotlib to plot line profiles at given coordinates.

    kwargs must include the arguments of AstroData to specify the data to
    be plotted. kwargs must include the arguments of AstroFrame to specify
    the ranges and so on for plotting. kwargs can include the arguments of
    PlotAxes2D to adjust x and y axes.

    Args:
        coords (list, optional): Coordinates. Defaults to None (no coordinates).
        xlist (list, optional): Offset from the center. Defaults to None (empty).
        ylist (list, optional): Offset from the center. Defaults to None (empty).
        ellipse (list, optional): [major, minor, pa], For average. Defaults to None.
        ninterp (int, optional): Number of points for interpolation. Defaults to 1.
        flux (bool, optional): y axis is flux density. Defaults to False.
        width (int, optional): Rebinning step along v. Defaults to 1.
        gaussfit (bool, optional): Fit the profiles. Defaults to False.
        gauss_kwargs (dict, optional): Kwargs for Axes.plot of the fitted Gaussian. Defaults to None.
        title (list, optional): For each plot; a str is used as the label, a dict is passed to Axes.set_title. Defaults to None.
        text (list, optional): For each plot; each element is a dict of kwargs for Axes.text. Defaults to None.
        nrows (int, optional): Number of subplot rows. Ignored when ncols=1; when ncols > 1 and nrows <= 0, just enough rows for all profiles are used. Defaults to 0.
        ncols (int, optional): Number of subplot columns. Defaults to 1.
        fig (object, optional): External plt.figure(). Defaults to None, which means the figure of the external ax if given, otherwise a new figure.
        ax (object, optional): External fig.add_subplot(), used as is when there is a single profile. Defaults to None.
        getfigax (bool, optional): Defaults to False.
        savefig (dict or str, optional): For plt.figure().savefig(). Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.

    Returns:
        tuple: (fig, ax), where ax is a list, if getfigax=True. Otherwise, no return.
    """
    coords = [] if coords is None else coords
    xlist = [] if xlist is None else xlist
    ylist = [] if ylist is None else ylist
    _kw = {'drawstyle': 'steps-mid', 'color': 'k'}
    _kw.update(kwargs)
    _kwgauss = {'drawstyle': 'default', 'color': 'g'}
    _kwgauss.update(gauss_kwargs or {})
    if isinstance(coords, str):
        coords = [coords]
    # NOTE(review): the keys left in _kw after the kwargs2instance calls are
    # passed to Axes.plot, which presumes kwargs2instance removes the keys
    # it consumes — confirm before reordering these calls.
    f = kwargs2instance(AstroFrame, _kw)
    d = kwargs2instance(AstroData, _kw)
    f.read(d)
    d.binning([width, 1, 1])
    v, prof, gfitres = d.profile(coords=coords, xlist=xlist, ylist=ylist,
                                 ellipse=ellipse, ninterp=ninterp,
                                 flux=flux, gaussfit=gaussfit)
    nprof = len(prof)
    if 'ylabel' in _kw:
        ylabel = _kw['ylabel']
    elif d.Tb:
        ylabel = r'$T_b$ (K)'
    elif flux:
        ylabel = 'Flux (Jy)'
    else:
        ylabel = d.bunit
    # isinstance also catches str subclasses such as numpy.str_.
    if isinstance(ylabel, str):
        ylabel = [ylabel] * nprof

    def gauss(x, p, c, w):
        # Gaussian with peak p, center c, and full width at half maximum w.
        return p * np.exp(-4. * np.log(2.) * ((x - c) / w)**2)

    set_rcparams(20, 'w')
    if ncols == 1:
        nrows = nprof
    elif nrows <= 0:
        # Just enough rows to hold every profile.
        nrows = int(np.ceil(nprof / ncols))
    if nprof > 1 and ax is not None:
        print('External ax is supported only when len(coords)=1.')
        ax = None
    if fig is None:
        # Draw on the figure that owns an external ax, if one was given.
        fig = plt.figure(figsize=(6 * ncols, 3 * nrows)) if ax is None else ax.figure
    ax = np.empty(nprof, dtype='object') if ax is None else [ax]
    if 'xlabel' not in _kw:
        _kw['xlabel'] = 'Velocity (km s$^{-1}$)'
    if 'xlim' not in _kw:
        _kw['xlim'] = [v.min(), v.max()]
    _kw['samexy'] = False
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    for i in range(nprof):
        if ax[i] is None:
            # Only create a subplot when no external ax is used.
            sharex = None if i < nrows - 1 else ax[i - 1]
            ax[i] = fig.add_subplot(nrows, ncols, i + 1, sharex=sharex)
        if gaussfit:
            ax[i].plot(v, gauss(v, *gfitres['best'][i]), **_kwgauss)
        ax[i].plot(v, prof[i], **_kw)
        ax[i].hlines([0], v.min(), v.max(), linestyle='dashed', color='k')
        ax[i].set_ylabel(ylabel[i])
        pa2.set_xyaxes(ax[i])
        if text is not None:
            ax[i].text(**text[i])
        if title is not None:
            # Build a local dict instead of overwriting the caller's list.
            t = {'label': title[i]} if isinstance(title[i], str) else title[i]
            ax[i].set_title(**t)
        if i <= nprof - ncols - 1:
            # Hide x tick labels except on the bottom row.
            plt.setp(ax[i].get_xticklabels(), visible=False)
    if getfigax:
        return fig, ax
    close_figure(fig, savefig, show)
[docs]
def plotslice(length: float, dx: float | None = None, pa: float = 0,
              txtfile: str | None = None,
              fig: object | None = None, ax: object | None = None,
              getfigax: bool = False,
              savefig: str | dict | None = None, show: bool = False,
              **kwargs) -> tuple[object, object]:
    """Use Axes.plot of matplotlib to plot a 1D spatial slice in a 2D map.

    kwargs must include the arguments of AstroData to specify the data to
    be plotted. kwargs must include the arguments of AstroFrame to specify
    the ranges and so on for plotting. kwargs can include the arguments of
    PlotAxes2D to adjust x and y axes. A dashed line at 3 sigma is drawn
    when the noise level is known.

    Args:
        length (float): Slice length.
        dx (float, optional): Grid increment. Defaults to None.
        pa (float, optional): Degree. Position angle. Defaults to 0.
        txtfile (str, optional): File name for numpy.savetxt(). Defaults to None.
        fig (object, optional): External plt.figure(). Defaults to None, which means the figure of the external ax if given, otherwise a new figure.
        ax (object, optional): External fig.add_subplot(). Defaults to None.
        getfigax (bool, optional): Defaults to False.
        savefig (dict or str, optional): For plt.figure().savefig(). Defaults to None.
        show (bool, optional): True means doing plt.show(). Defaults to False.

    Returns:
        tuple: (fig, ax), where ax is a single Axes, if getfigax=True. Otherwise, no return.
    """
    _kw = {'linestyle': '-', 'marker': 'o'}
    _kw.update(kwargs)
    # rmax = length / 2, presumably so that AstroFrame reads just enough area for the slice.
    _kw['rmax'] = length / 2
    f = kwargs2instance(AstroFrame, _kw)
    d = kwargs2instance(AstroData, _kw)
    f.read(d)
    if np.ndim(d.data) > 2:
        print('Only 2D map is supported.')
        return
    r, z = d.slice(length=length, pa=pa, dx=dx)
    xunit = 'arcsec' if f.dist == 1 else 'au'
    yunit = 'K' if d.Tb else d.bunit
    yquantity = r'$T_b$' if d.Tb else 'intensity'
    if txtfile is not None:
        np.savetxt(txtfile, np.c_[r, z],
                   header=f'x ({xunit}), {yquantity} ({yunit}); '
                   + f'positive x is pa={pa:.2f} deg.')
    _kw.setdefault('xlabel', f'Offset ({xunit})')
    _kw.setdefault('ylabel', f'Intensity ({yunit})')
    _kw.setdefault('xlim', [r.min(), r.max()])
    _kw['samexy'] = False
    set_rcparams()
    if fig is None:
        # Draw on the figure that owns an external ax, if one was given.
        fig = plt.figure() if ax is None else ax.figure
    if ax is None:
        ax = fig.add_subplot(1, 1, 1)
    pa2 = kwargs2instance(PlotAxes2D, _kw)
    ax.plot(r, z, **_kw)
    if d.sigma is not None:
        # Dashed reference line at 3 sigma.
        ax.plot(r, r * 0 + 3 * d.sigma, 'k--')
    pa2.set_xyaxes(ax)
    if getfigax:
        return fig, ax
    close_figure(fig, savefig, show)
[docs]
def plot3d(levels: list[float] = [3, 6, 12],
cmap: str = 'Jet', alpha: float = 0.08,
xlabel: str = 'R.A. (arcsec)',
ylabel: str = 'Dec. (arcsec)',
vlabel: str = 'Velocity (km/s)',
xskip: int = 1, yskip: int = 1,
eye_p: float = 0, eye_i: float = 180,
xplus: dict = {}, xminus: dict = {},
yplus: dict = {}, yminus: dict = {},
vplus: dict = {}, vminus: dict = {},
outname: str = 'plot3d', show: bool = False,
return_data_layout: bool = False,
**kwargs) -> None | dict:
"""Use Plotly. kwargs must include the arguments of AstroData to specify the data to be plotted. kwargs must include the arguments of AstroFrame to specify the ranges and so on for plotting.
Args:
levels (list, optional): Contour levels. Defaults to [3,6,12].
cmap (str, optional): Color map name. Defaults to 'Jet'.
alpha (float, optional): opacity in plotly. Defaults to 0.08.
xlabel (str, optional): Defaults to 'R.A. (arcsec)'.
ylabel (str, optional): Defaults to 'Dec. (arcsec)'.
vlabel (str, optional): Defaults to 'Velocity (km/s)'.
xskip (int, optional): Number of pixel to skip. Defaults to 1.
yskip (int, optional): Number of pixel to skip. Defaults to 1.
eye_p (float, optional): Azimuthal angle of camera. Defaults to 0.
eye_i (float, optional): Inclination angle of camera. Defaults to 180.
xplus (dict, optional): 2D data to be plotted on the y-v plane at the positive edge of x. This dictionary must have a key of data and can have keys of levels, sigma, cmap, and alpha. Defaults to {}.
xminus (dict, optional): See xplus. Defaults to {}.
yplus (dict, optional): See xplus. Defaults to {}.
yminus (dict, optional): See xplus. Defaults to {}.
vplus (dict, optional): See xplus. Defaults to {}.
vminus (dict, optional): See xplus. Defaults to {}.
outname (str, optional): Output file name. Defaults to 'plot3d'.
show (bool, optional): auto_play in plotly. Defaults to False.
return_data_layout (bool, optional): Whether to return data and layout for plotly.graph_objs.Figure. Defaults to False.
Returns:
dict: {'data': data, 'layout': layout}, if return_data_layout=True. Otherwise, no return.
"""
import plotly.graph_objs as go
from skimage import measure
f = kwargs2instance(AstroFrame, kwargs)
d = kwargs2instance(AstroData, kwargs)
f.read(d, xskip, yskip)
volume, x, y, v, sigma = d.data, d.x, d.y, d.v, d.sigma
dx, dy, dv = d.dx, d.dy, d.dv
volume[np.isnan(volume)] = 0
if dx < 0:
x, dx, volume = x[::-1], -dx, volume[:, :, ::-1]
if dy < 0:
y, dy, volume = y[::-1], -dy, volume[:, ::-1, :]
if dv < 0:
v, dv, volume = v[::-1], -dv, volume[::-1, :, :]
s, ds = [x, y, v], [dx, dy, dv]
deg = np.radians(1)
xeye = -np.sin(eye_i * deg) * np.sin(eye_p * deg)
yeye = -np.sin(eye_i * deg) * np.cos(eye_p * deg)
zeye = np.cos(eye_i * deg)
margin = dict(l=0, r=0, b=0, t=0)
camera = dict(eye=dict(x=xeye, y=yeye, z=zeye), up=dict(x=0, y=1, z=0))
xaxis = dict(range=[x[0], x[-1]], title=xlabel)
yaxis = dict(range=[y[0], y[-1]], title=ylabel)
zaxis = dict(range=[v[0], v[-1]], title=vlabel)
scene = dict(aspectmode='cube', camera=camera,
xaxis=xaxis, yaxis=yaxis, zaxis=zaxis)
layout = go.Layout(margin=margin, scene=scene, showlegend=False)
data = []
for lev in levels:
if lev * sigma > np.max(volume):
continue
vertices, simplices, _, _ = measure.marching_cubes(volume, lev * sigma)
Xg, Yg, Zg = [t[0] + i * dt for t, i, dt
in zip(s, vertices.T[::-1], ds)]
i, j, k = simplices.T
mesh = dict(type='mesh3d', x=Xg, y=Yg, z=Zg, i=i, j=j, k=k,
intensity=Zg * 0 + lev,
colorscale=cmap, reversescale=False,
cmin=np.min(levels), cmax=np.max(levels),
opacity=alpha, name='', showscale=False)
data.append(mesh)
Xe, Ye, Ze = [], [], []
for t in vertices[simplices]:
Xe += [x[0] + dx * t[k % 3][2] for k in range(4)] + [None]
Ye += [y[0] + dy * t[k % 3][1] for k in range(4)] + [None]
Ze += [v[0] + dv * t[k % 3][0] for k in range(4)] + [None]
lines = dict(type='scatter3d', x=Xe, y=Ye, z=Ze,
mode='lines', opacity=0.04, visible=True,
name='', line=dict(color='rgb(0,0,0)', width=1))
data.append(lines)
def plot_on_wall(sign: int, axis: int, **kwargs):
if kwargs == {}:
return
match axis:
case 2:
shape = np.shape(d.data[:, :, 0])
case 1:
shape = np.shape(d.data[:, 0, :])
case 0:
shape = np.shape(d.data[0, :, :])
if np.shape(kwargs['data']) != shape:
print('The shape of the 2D data is inconsistent'
+ ' with the shape of the 3D data.')
return
_kw = {'levels': [3, 6, 12, 24, 48, 96, 192, 384],
'sigma': 'hist', 'cmap': 'Jet', 'alpha': 0.3}
_kw.update(kwargs)
volume = _kw['data']
levels = _kw['levels']
cmap = _kw['cmap']
alpha = _kw['alpha']
sigma = estimate_rms(data=volume, sigma=_kw['sigma'])
volume[np.isnan(volume)] = 0
a = int(sign == -1)
b = int(sign == 1)
volume = np.moveaxis([volume * a, volume * b], 0, axis)
if d.dx < 0:
volume = volume[:, :, ::-1]
if d.dy < 0:
volume = volume[:, ::-1, :]
if d.dv < 0:
volume = volume[::-1, :, :]
for lev in levels:
if lev * sigma > np.max(volume):
continue
vertices, simplices, _, _ = measure.marching_cubes(volume, lev * sigma)
Xg, Yg, Zg = [t[0] + i * dt for t, i, dt
in zip(s, vertices.T[::-1], ds)]
match axis:
case 2:
Xg = Xg * 0 + (x[-1] if sign == 1 else x[0])
case 1:
Yg = Yg * 0 + (y[-1] if sign == 1 else y[0])
case 0:
Zg = Zg * 0 + (v[-1] if sign == 1 else v[0])
i, j, k = simplices.T
mesh2d = dict(type='mesh3d', x=Xg, y=Yg, z=Zg,
i=i, j=j, k=k,
intensity=Zg * 0 + lev,
colorscale=cmap, reversescale=False,
cmin=np.min(levels), cmax=np.max(levels),
opacity=alpha, name='', showscale=False)
data.append(mesh2d)
klist = [xplus, xminus, yplus, yminus, vplus, vminus]
slist = [1, -1, 1, -1, 1, -1]
alist = [2, 2, 1, 1, 0, 0]
for kw, sign, axis in zip(klist, slist, alist):
plot_on_wall(sign=sign, axis=axis, **kw)
if return_data_layout:
return {'data': data, 'layout': layout}
else:
fig = go.Figure(data=data, layout=layout)
fig.write_html(file=outname.replace('.html', '') + '.html',
auto_play=show)