"""
TODO
"""
import logging
from typing import Any, List, Optional, Tuple
import matplotlib.colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as _pd
from matplotlib.ticker import FixedFormatter, FixedLocator, MultipleLocator
from ..units import ureg as _ureg
from .artist import PALETTE
from .artist import Artist as _Artist
from .artist import ArtistException as _ArtistException
# RGB triples (components in [0, 1]) defining the "fluka" colormap, ordered from
# white through violets, blues, greens, yellows and reds down to black.
FLUKA_COLORS = [
    (1.0, 1.0, 1.0),  # white
    (0.9, 0.6, 0.9),  # light violet
    (1.0, 0.4, 1.0),  # pink-magenta
    (0.9, 0.0, 1.0),  # magenta
    (0.7, 0.0, 1.0),  # violet
    (0.5, 0.0, 0.8),  # purple
    (0.0, 0.0, 0.8),  # dark blue
    (0.0, 0.0, 1.0),  # blue
    (0.0, 0.6, 1.0),  # azure
    (0.0, 0.8, 1.0),  # sky blue
    (0.0, 0.7, 0.5),  # teal
    (0.0, 0.9, 0.2),  # green
    (0.5, 1.0, 0.0),  # chartreuse
    (0.8, 1.0, 0.0),  # lime
    (1.0, 1.0, 0.0),  # yellow
    (1.0, 0.8, 0.0),  # amber
    (1.0, 0.5, 0.0),  # orange
    (1.0, 0.0, 0.0),  # red
    (0.8, 0.0, 0.0),  # dark red
    (0.6, 0.0, 0.0),  # maroon
    (0.0, 0.0, 0.0),  # black
]
# Colormap named "fluka" interpolating linearly through FLUKA_COLORS (first color at the
# low end of the scale), quantized into 300 levels.
FlukaColormap = matplotlib.colors.LinearSegmentedColormap.from_list(
    name="fluka",
    colors=FLUKA_COLORS,
    N=300,
)
# Default color palette: the shared "solarized" palette from `.artist`.
# NOTE: this is the same mapping object, not a copy, so the entries added below are
# visible to every other user of PALETTE["solarized"].
palette = PALETTE["solarized"]

# "Logical" colors: each element class is mapped to one of the palette's base colors.
for _element_class, _base_color in (
    ("BEND", "blue"),
    ("QUADRUPOLE", "red"),
    ("SEXTUPOLE", "yellow"),
    ("OCTUPOLE", "green"),
    ("MULTIPOLE", "gray"),
    ("DEGRADER", "base02"),
    ("RECTANGULARCOLLIMATOR", "darkgreen"),
    ("CIRCULARCOLLIMATOR", "magenta"),
    ("ELLIPTICALCOLLIMATOR", "orange"),
    ("COLLIMATOR", "magenta"),
    ("HKICKER", "magenta"),
    ("VKICKER", "violet"),
    ("SCATTERER", "base02"),
    ("MATRIX", "cyan"),
    ("ELEMENT", "base0"),
):
    palette[_element_class] = palette[_base_color]
# Keep the module namespace identical to a sequence of plain assignments.
del _element_class, _base_color
class MatplotlibArtist(_Artist):
    """Artist drawing beamlines on Matplotlib axes.

    It draws the element "cartouche" (a colored box per element along the beamline) and the element
    apertures, beam chambers and magnet/collimator outlines. Drawing happens on the primary axes
    (``ax``); `plot_cartouche` additionally creates a twin axes (``ax2``).
    """

    # Element classes drawn in the cartouche with their own entry of the module-level `palette`
    # (bends are handled separately and share the "BEND" color).
    _CARTOUCHE_CLASSES = (
        "SEXTUPOLE",
        "QUADRUPOLE",
        "OCTUPOLE",
        "MULTIPOLE",
        "HKICKER",
        "RECTANGULARCOLLIMATOR",
        "VKICKER",
        "DEGRADER",
        "CIRCULARCOLLIMATOR",
        "ELLIPTICALCOLLIMATOR",
        "MATRIX",
        "SCATTERER",
        "ELEMENT",
    )

    # Collimator classes: drawn by `draw_coll` directly at the aperture (they have no chamber).
    _COLLIMATOR_CLASSES = ("RECTANGULARCOLLIMATOR", "CIRCULARCOLLIMATOR", "ELLIPTICALCOLLIMATOR")

    # Element classes considered by `draw_aperture`.
    _APERTURE_CLASSES = ("QUADRUPOLE", "SBEND", "RBEND") + _COLLIMATOR_CLASSES

    # Bend outlines are clipped at this distance (mm) from the axis; larger values log a warning.
    _BEND_APERTURE_LIMIT_MM = 55

    def __init__(self, ax: Optional[plt.Axes] = None, **kwargs: Any):
        """
        Args:
            ax: the Matplotlib axes used for plotting. If None, a new figure and axes are created
                with `init_plot`, which receives the ``figsize`` and ``subplots`` entries of
                ``kwargs`` when present.
            kwargs: forwarded to `Artist` (e.g. ``with_frames``, ``with_centers``); ``figsize`` and
                ``subplots`` are also forwarded to `init_plot` when ``ax`` is None.
        """
        super().__init__(**kwargs)
        if ax is None:
            # Only forward the options init_plot accepts: passing the Artist options as well would
            # raise a TypeError (unexpected keyword argument).
            init_kwargs = {key: kwargs[key] for key in ("figsize", "subplots") if key in kwargs}
            self.init_plot(**init_kwargs)
        else:
            self._ax = ax
            self._fig = ax.figure
        self._ax2: Optional[plt.Axes] = None

    @property
    def ax(self) -> plt.Axes:
        """Current Matplotlib axes.

        Returns:
            the Matplotlib axes.
        """
        return self._ax

    @ax.setter
    def ax(self, ax: plt.Axes) -> None:
        self._ax = ax

    @property
    def ax2(self) -> Optional[plt.Axes]:
        """Secondary (twin) axes holding the cartouche.

        Returns:
            the axes created by `plot_cartouche`, or None if it has not been called.
        """
        return self._ax2

    @property
    def figure(self) -> plt.Figure:
        """Current Matplotlib figure.

        Returns:
            the Matplotlib figure.
        """
        return self._fig

    def init_plot(self, figsize: Tuple[int, int] = (12, 8), subplots: int = 111) -> None:
        """Create a new Matplotlib figure and one axes on it.

        Args:
            figsize: figure size (width, height) in inches, passed to `plt.figure`.
            subplots: subplot position specification passed to `Figure.add_subplot`
                (three-digit code, e.g. 111 for a single axes filling the figure).
        """
        self._fig = plt.figure(figsize=figsize)
        self._ax = self._fig.add_subplot(subplots)

    def plot(self, *args: Any, **kwargs: Any) -> None:
        """Proxy for `Axes.plot` on the current axes; forwards all arguments."""
        self._ax.plot(*args, **kwargs)

    @staticmethod
    def beamline_get_ticks_locations(o: _pd.DataFrame) -> List[float]:
        """Return the element centers (``AT_CENTER`` column, pint quantities) in meters."""
        return [at_center.m_as("m") for at_center in o["AT_CENTER"]]

    @staticmethod
    def beamline_get_ticks_labels(o: _pd.DataFrame) -> List[str]:
        """Return the tick labels, i.e. the index of ``o`` (element names)."""
        return list(o.index)

    def plot_cartouche(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        labels: Optional[_pd.DataFrame] = None,
        vertical_position: float = 1.15,
    ) -> None:
        """Draw the beamline cartouche (one colored box per element) on a twin axes.

        A twin axes (``ax2``) with a [0, 1] y range is created; a black line and the element boxes
        are drawn at ``vertical_position`` in that range (with ``clip_on=False`` so they remain
        visible above the axes when ``vertical_position`` > 1). Drifts and markers are skipped;
        element classes without a color are logged and skipped.

        Args:
            beamline: beamline DataFrame; reads CLASS, AT_ENTRY, AT_EXIT and L (pint quantities)
                and, with ``print_label`` and no ``labels``, AT_CENTER and NAME.
            print_label: if True, put element names as the x tick labels of ``ax2``.
            labels: optional DataFrame used instead of ``beamline`` for the tick labels (its index)
                and tick positions (its AT_CENTER column).
            vertical_position: vertical position of the cartouche, in ``ax2`` data coordinates.

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")
        self._ax2 = self._ax.twinx()
        self._ax2.set_ylim([0, 1])
        self._ax2.set_yticks([])
        self._ax2.hlines(
            vertical_position,
            0,
            beamline.iloc[-1]["AT_EXIT"].m_as("m"),
            clip_on=False,
            colors="black",
            lw=1,
        )
        for _, element in beamline.iterrows():
            element_class = element["CLASS"].upper()
            if element_class in ("DRIFT", "MARKER"):
                continue
            if element_class in ("SBEND", "RBEND"):
                box_style = {"facecolor": palette["BEND"]}
            elif element_class in self._CARTOUCHE_CLASSES:
                box_style = {"facecolor": palette[element_class], "ec": palette[element_class]}
            else:
                logging.warning(f"colors are not implemented for {element['CLASS']}")
                continue
            # Box of height 0.1 centered on the cartouche line.
            self._ax2.add_patch(
                patches.Rectangle(
                    (element["AT_ENTRY"].m_as("m"), vertical_position - 0.05),
                    element["L"].m_as("m"),
                    0.1,
                    hatch="",
                    clip_on=False,
                    **box_style,
                ),
            )
        if print_label:  # For beamline losses or Twiss
            tick_source = labels
            if tick_source is None:
                # Label every element except drifts (class compared case-insensitively, as in the loop).
                tick_source = beamline.reset_index()
                tick_source = tick_source[tick_source["CLASS"].str.upper() != "DRIFT"].set_index("NAME")
            self._ax2.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(tick_source)))
            self._ax2.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(tick_source)))
            self._ax2.get_xaxis().set_tick_params(direction="out")
            plt.setp(self._ax2.xaxis.get_majorticklabels(), rotation=-90)

    def plot_beamline(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        with_aperture: bool = True,
        labels: Optional[_pd.DataFrame] = None,
        **kwargs: Any,
    ) -> None:
        """Set up the main axes for a beamline plot and optionally draw the apertures.

        The x range spans from the entry of the first to the exit of the last element whose name does
        not contain "DRIFT" (positions in meters); the y axis gets major ticks every 10 and minor
        ticks every 5 units, and a light grid.

        Args:
            beamline: beamline DataFrame; element names are read from ``NAME`` (index or column).
            print_label: if True, label the x ticks with element names.
            with_aperture: if True, draw apertures and chambers with `draw_aperture`.
            labels: optional DataFrame used instead of the non-drift elements for the tick labels
                (its index) and tick positions (its AT_CENTER column).
            kwargs: ``ylim`` (default [-60, 60]) sets the y limits; all kwargs are also forwarded to
                `draw_aperture` (e.g. ``plane``).

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")
        bl_short = beamline.reset_index()
        # na=True: rows without a name are dropped together with the drifts.
        is_drift = bl_short["NAME"].str.contains("DRIFT", na=True)
        bl_short = bl_short[~is_drift].set_index("NAME")
        plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=-45)
        self._ax.set_xlim([bl_short.iloc[0]["AT_ENTRY"].m_as("m"), bl_short.iloc[-1]["AT_EXIT"].m_as("m")])
        self._ax.get_xaxis().set_tick_params(direction="out")
        self._ax.yaxis.set_major_locator(MultipleLocator(10))
        self._ax.yaxis.set_minor_locator(MultipleLocator(5))
        self._ax.set_ylim(kwargs.get("ylim", [-60, 60]))
        self._ax.grid(True, alpha=0.25)
        if with_aperture:
            self.draw_aperture(beamline, **kwargs)
        if print_label:
            tick_source = bl_short if labels is None else labels
            self._ax.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(tick_source)))
            self._ax.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(tick_source)))

    def draw_aperture(self, bl: _pd.DataFrame, **kwargs: Any) -> None:
        """Draw apertures, beam chambers and magnet/collimator outlines on the main axes.

        Only QUADRUPOLE, SBEND, RBEND and rectangular/circular/elliptical collimator elements with a
        non-null APERTYPE are drawn. Aperture values are converted to mm. The caller's DataFrame
        (including the objects stored in its APERTURE column) is not modified.

        Args:
            bl: beamline DataFrame with CLASS, APERTYPE, APERTURE (indexable, items are pint
                quantities), AT_ENTRY and L columns, and optionally CHAMBER (plain numbers, mm).
            kwargs: ``plane`` selects the aperture components: "X" (default) uses the first component
                above and below the axis, "Y" the second, "both" the second above and the first below.

        Raises:
            ArtistException: if ``plane`` is not 'X', 'Y' or 'both'.
        """
        if "APERTURE" not in bl:
            logging.warning("No APERTURE defined in the beamline")
            return
        if "APERTYPE" not in bl:
            logging.warning("No APERTYPE defined in the beamline")
            return
        plane = kwargs.get("plane", "X")
        if plane == "X":
            up_index, down_index = 0, 0
        elif plane == "Y":
            up_index, down_index = 1, 1
        elif plane == "both":
            up_index, down_index = 1, 0
        else:
            raise _ArtistException("Plane must be 'X', 'Y' or 'both'.")

        # Copy after filtering so the columns written below never reach the caller's DataFrame.
        bl = bl[bl["APERTYPE"].notnull()].copy()
        for column in ("APERTYPE", "CLASS"):
            bl[column] = bl[column].map(lambda value: value.upper())
        bl = bl[bl["CLASS"].isin(self._APERTURE_CLASSES)].copy()

        # Half-apertures (mm) above and below the axis. A circular aperture uses its first component in
        # both planes; a new tuple is built instead of writing into the stored aperture object, which is
        # shared with the caller (DataFrame.copy does not copy nested Python objects).
        # NOTE(review): the previous index-based update carried a TODO about failures with
        # DontSplitSBends = 0; confirm whether that case now works.
        aperture_up, aperture_down = [], []
        for aperture, apertype in zip(bl["APERTURE"], bl["APERTYPE"]):
            if apertype == "CIRCULAR":
                aperture = (aperture[0], aperture[0])
            aperture_up.append(aperture[up_index].m_as("mm"))
            aperture_down.append(aperture[down_index].m_as("mm"))
        bl["APERTURE_UP"] = aperture_up
        bl["APERTURE_DOWN"] = aperture_down

        # Draw the collimators as they have no chamber.
        for collimator_class in self._COLLIMATOR_CLASSES:
            for _, element in bl[bl["CLASS"] == collimator_class].iterrows():
                self.draw_coll(element)
        bl = bl[~bl["CLASS"].isin(self._COLLIMATOR_CLASSES)].copy()

        if "CHAMBER" not in bl:
            bl["CHAMBER"] = 0
        # CHAMBER holds plain numbers interpreted as mm.
        chamber_mm = bl["CHAMBER"].apply(lambda thickness: (thickness * _ureg.mm).m_as("mm"))
        bl["CHAMBER_UP"] = chamber_mm
        bl["CHAMBER_DOWN"] = chamber_mm

        for element_class, draw in (
            ("QUADRUPOLE", self.draw_quad),
            ("SBEND", self.draw_bend),
            ("RBEND", self.draw_bend),
        ):
            for _, element in bl[bl["CLASS"] == element_class].iterrows():
                draw(element)

    def draw_quad(self, e: _pd.Series) -> None:
        """Draw a quadrupole: two pole boxes outside the aperture and chamber, then its chamber.

        Args:
            e: element row with AT_ENTRY and L (pint quantities) and APERTURE_UP/DOWN,
                CHAMBER_UP/DOWN (mm), as prepared by `draw_aperture`.
        """
        s_entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        # Boxes 100 units tall, extending away from the beam axis.
        self._ax.add_patch(
            patches.Rectangle(
                (s_entry, e["APERTURE_UP"] + e["CHAMBER_UP"]),
                length,
                100,
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self._ax.add_patch(
            patches.Rectangle(
                (s_entry, -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]),
                length,
                -100,
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self.draw_chamber(self._ax, e)

    def draw_coll(self, e: _pd.Series) -> None:
        """Draw a collimator as two jaws starting directly at the aperture (no chamber).

        Args:
            e: element row with AT_ENTRY and L (pint quantities) and APERTURE_UP/DOWN (mm).
        """
        s_entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        self._ax.add_patch(
            patches.Rectangle((s_entry, e["APERTURE_UP"]), length, 100, facecolor=palette["COLLIMATOR"]),
        )
        self._ax.add_patch(
            patches.Rectangle((s_entry, -e["APERTURE_DOWN"]), length, -100, facecolor=palette["COLLIMATOR"]),
        )

    def draw_bend(self, e: _pd.Series) -> None:
        """Draw a bend: two yoke boxes outside the aperture and chamber, then its chamber.

        The yoke boxes start at most `_BEND_APERTURE_LIMIT_MM` from the axis; a warning is logged
        when the aperture plus chamber exceeds that limit.

        Args:
            e: element row with AT_ENTRY and L (pint quantities) and APERTURE_UP/DOWN,
                CHAMBER_UP/DOWN (mm), as prepared by `draw_aperture`.
        """
        limit = self._BEND_APERTURE_LIMIT_MM
        s_entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")

        upper = e["APERTURE_UP"] + e["CHAMBER_UP"]
        if upper > limit:
            logging.warning(f"Aperture are bigger than {limit} mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (s_entry, upper if upper < limit else limit),
                length,
                100,
                facecolor=palette["BEND"],
            ),
        )

        lower = -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]
        if lower < -limit:
            logging.warning(f"Aperture are bigger than {limit} mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (s_entry, lower if abs(lower) < limit else -limit),
                length,
                -100,
                facecolor=palette["BEND"],
            ),
        )
        self.draw_chamber(self._ax, e)

    @staticmethod
    def draw_chamber(ax: plt.Axes, e: _pd.Series) -> None:
        """Draw the beam chamber walls of an element on ``ax``.

        Args:
            ax: axes to draw on.
            e: element row with AT_ENTRY and L (pint quantities) and APERTURE_UP/DOWN,
                CHAMBER_UP/DOWN (mm).
        """
        s_entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        # Upper wall: from +APERTURE_UP upwards by CHAMBER_UP.
        ax.add_patch(
            patches.Rectangle(
                (s_entry, e["APERTURE_UP"]),
                length,
                e["CHAMBER_UP"],
                hatch="",
                facecolor=palette["base01"],
            ),
        )
        # Lower wall: from -APERTURE_DOWN downwards by CHAMBER_DOWN.
        ax.add_patch(
            patches.Rectangle(
                (s_entry, -e["APERTURE_DOWN"]),
                length,
                -e["CHAMBER_DOWN"],
                hatch="",
                facecolor=palette["base01"],
            ),
        )
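# Legacy helpers (BPM size markers, scattering slabs and measuring planes) not yet converted to the
# Matplotlib artist; kept commented out for reference.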
# @staticmethod
# def draw_bpm_size(ax, s, x):
# ax.add_patch(
# patches.Rectangle(
# (s - 0.05, -x),
# 0.1,
# 2 * x,
# )
# )
#
# def bpm(self, ax, bl, **kwargs):
# """TODO."""
# if kwargs.get('plane') is None:
# raise Exception("'plane' argument must be provided.")
# bl.line[bl.line[f"BPM_STD_{kwargs.get('plane')}"].notnull()].apply(
# lambda x: self.draw_bpm_size(ax, x['AT_CENTER'], x[f"BPM_STD_{kwargs.get('plane')}"]),
# axis=1
# )
#
# # THIS IS THE OLD scattering.py
# @staticmethod
# def draw_slab(ax, e):
# materials_colors = {
# 'graphite': 'g',
# 'beryllium': 'r',
# 'water': 'b',
# 'lexan': 'y',
# }
# ax.add_patch(
# patches.Rectangle(
# (e['AT_ENTRY'], -1), # (x,y)
# e['LENGTH'], # width
# 2, # height
# hatch='', facecolor=materials_colors.get(e['MATERIAL'], 'k')
# )
# )
#
# @staticmethod
# def draw_measuring_plane(ax, e):
# ax.add_patch(
# patches.Rectangle(
# (e['AT_ENTRY'] - 0.005, -1), # (x,y)
# 0.01, # width
# 2, # height
# hatch='', facecolor='k'
# )
# )
#
# def scattering(self, ax, bl, **kwargs):
# bl.line.query("TYPE == 'slab'").apply(lambda e: self.draw_slab(ax, e), axis=1)
# bl.line.query("TYPE == 'mp'").apply(lambda e: self.draw_measuring_plane(ax, e), axis=1)