Source code for georges_core.vis.matplotlib

"""
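Matplotlib backend for the georges_core visualisation artists.

This module provides `MatplotlibArtist`, an `Artist` that draws beamline
elements (cartouche, apertures, vacuum chambers) on Matplotlib axes, together
with the `FlukaColormap` colormap and the module-level `palette` that maps
beamline element classes to colours.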
"""
import logging
from typing import Any, List, Optional, Tuple

import matplotlib.colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as _pd
from matplotlib.ticker import FixedFormatter, FixedLocator, MultipleLocator

from ..units import ureg as _ureg
from .artist import PALETTE
from .artist import Artist as _Artist
from .artist import ArtistException as _ArtistException

# Anchor colours of the "fluka" colormap, as (red, green, blue) tuples with
# components in [0, 1]. The 21 entries run from white (first) to black (last).
# NOTE(review): the name suggests this mimics the colour scale used by the
# FLUKA code - confirm against the intended reference.
FLUKA_COLORS = [
    (1.0, 1.0, 1.0),
    (0.9, 0.6, 0.9),
    (1.0, 0.4, 1.0),
    (0.9, 0.0, 1.0),
    (0.7, 0.0, 1.0),
    (0.5, 0.0, 0.8),
    (0.0, 0.0, 0.8),
    (0.0, 0.0, 1.0),
    (0.0, 0.6, 1.0),
    (0.0, 0.8, 1.0),
    (0.0, 0.7, 0.5),
    (0.0, 0.9, 0.2),
    (0.5, 1.0, 0.0),
    (0.8, 1.0, 0.0),
    (1.0, 1.0, 0.0),
    (1.0, 0.8, 0.0),
    (1.0, 0.5, 0.0),
    (1.0, 0.0, 0.0),
    (0.8, 0.0, 0.0),
    (0.6, 0.0, 0.0),
    (0.0, 0.0, 0.0),
]
# Colormap named "fluka", built from the anchor colours above and created
# with N=300.
FlukaColormap = matplotlib.colors.LinearSegmentedColormap.from_list("fluka", FLUKA_COLORS, N=300)

# Default colour palette used by the drawing routines of this module.
# NOTE: no copy is made, so `palette` is the very same mapping object as
# PALETTE["solarized"]; the entries added below are therefore also visible
# through PALETTE.
palette = PALETTE["solarized"]

# "Logical" colours: associate each beamline element class with one of the
# palette's base colours.
palette.update(
    {
        "BEND": palette["blue"],
        "QUADRUPOLE": palette["red"],
        "SEXTUPOLE": palette["yellow"],
        "OCTUPOLE": palette["green"],
        "MULTIPOLE": palette["gray"],
        "DEGRADER": palette["base02"],
        "RECTANGULARCOLLIMATOR": palette["darkgreen"],
        "CIRCULARCOLLIMATOR": palette["magenta"],
        "ELLIPTICALCOLLIMATOR": palette["orange"],
        "COLLIMATOR": palette["magenta"],
        "HKICKER": palette["magenta"],
        "VKICKER": palette["violet"],
        "SCATTERER": palette["base02"],
        "MATRIX": palette["cyan"],
        "ELEMENT": palette["base0"],
    },
)


class MatplotlibArtist(_Artist):
    """Artist drawing beamline elements on Matplotlib axes.

    The artist owns a main ax (``ax``) on which apertures, chambers and data
    are drawn, and optionally a twin ax (``ax2``) created by `plot_cartouche`
    to draw a schematic of the beamline elements.
    """

    # Element classes drawn in the cartouche using the palette entry that
    # carries their own (upper-case) class name.
    _CARTOUCHE_CLASSES = frozenset(
        {
            "SEXTUPOLE",
            "QUADRUPOLE",
            "OCTUPOLE",
            "MULTIPOLE",
            "HKICKER",
            "VKICKER",
            "DEGRADER",
            "COLLIMATOR",
            "RECTANGULARCOLLIMATOR",
            "CIRCULARCOLLIMATOR",
            "ELLIPTICALCOLLIMATOR",
            "MATRIX",
            "SCATTERER",
            "ELEMENT",
        },
    )

    # Element classes drawn as two jaws without a vacuum chamber.
    _COLLIMATOR_CLASSES = frozenset(
        {"RECTANGULARCOLLIMATOR", "CIRCULARCOLLIMATOR", "ELLIPTICALCOLLIMATOR"},
    )

    # Element classes for which `draw_aperture` draws something.
    _APERTURE_CLASSES = _COLLIMATOR_CLASSES | {"QUADRUPOLE", "SBEND", "RBEND"}

    def __init__(self, ax: Optional[plt.Axes] = None, **kwargs: Any):
        """
        Args:
            ax: the matplotlib ax used for plotting. If None it will be created
                with `init_plot`.
            kwargs: forwarded to `Artist`. When `ax` is None, the subset of
                `kwargs` accepted by `init_plot` (e.g. `figsize`, `subplots`)
                is also forwarded to `init_plot`.
        """
        import inspect  # local import: only needed to split the keyword arguments

        super().__init__(**kwargs)
        if ax is None:
            # `kwargs` mixes Artist options with figure options. Only hand to
            # `init_plot` what it can accept, otherwise any Artist option would
            # make it fail with a TypeError.
            accepted = inspect.signature(self.init_plot).parameters
            takes_any_keyword = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in accepted.values())
            if takes_any_keyword:
                plot_kwargs = dict(kwargs)
            else:
                plot_kwargs = {key: value for key, value in kwargs.items() if key in accepted}
            self.init_plot(**plot_kwargs)
        else:
            self._ax = ax
            self._fig = ax.figure
        # The twin ax only exists once `plot_cartouche` has been called.
        self._ax2: Optional[plt.Axes] = None

    @property
    def ax(self) -> plt.Axes:
        """Current Matplotlib ax.

        Returns:
            the Matplotlib ax.
        """
        return self._ax

    @ax.setter
    def ax(self, ax: plt.Axes) -> None:
        self._ax = ax

    @property
    def ax2(self) -> Optional[plt.Axes]:
        """Twin ax holding the cartouche.

        Returns:
            the twin ax created by `plot_cartouche`, or None if it was not called yet.
        """
        return self._ax2

    @property
    def figure(self) -> plt.Figure:
        """Current Matplotlib figure.

        Returns:
            the Matplotlib figure.
        """
        return self._fig

    def init_plot(self, figsize: Tuple[int, int] = (12, 8), subplots: int = 111) -> None:
        """Initialize the Matplotlib figure and ax.

        Args:
            figsize: figure size, forwarded to `matplotlib.pyplot.figure`
            subplots: subplot specification forwarded to `Figure.add_subplot`
        """
        self._fig = plt.figure(figsize=figsize)
        self._ax = self._fig.add_subplot(subplots)

    def plot(self, *args: Any, **kwargs: Any) -> None:
        """Proxy for the `plot` method of the current ax.

        All arguments are forwarded unchanged.
        """
        self._ax.plot(*args, **kwargs)

    @staticmethod
    def beamline_get_ticks_locations(o: _pd.DataFrame) -> List[float]:
        """Return the `AT_CENTER` value of each row, converted to meters."""
        return [center.m_as("m") for center in o["AT_CENTER"]]

    @staticmethod
    def beamline_get_ticks_labels(o: _pd.DataFrame) -> List[str]:
        """Return the index labels of `o` as a list."""
        return list(o.index)

    def plot_cartouche(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        labels: Optional[_pd.DataFrame] = None,
        vertical_position: float = 1.15,
    ) -> None:
        """Draw a schematic of the beamline elements on a twin ax.

        Each element other than drifts and markers is drawn as a coloured box
        centred on a horizontal line at `vertical_position`.

        Args:
            beamline: the beamline; rows need the `CLASS`, `AT_ENTRY`, `AT_EXIT`
                and `L` columns (plus `NAME` and `AT_CENTER` for the labels).
            print_label: if True, put ticks and labels at the element centers.
            labels: optional frame used instead of the beamline for the ticks
                (positions from `AT_CENTER`, labels from its index).
            vertical_position: vertical position of the cartouche, in the
                [0, 1] data coordinates of the twin ax.

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")

        self._ax2 = self._ax.twinx()
        self._ax2.set_yticks([])
        self._ax2.set_ylim([0, 1])
        self._ax2.hlines(
            vertical_position,
            0,
            beamline.iloc[-1]["AT_EXIT"].m_as("m"),
            clip_on=False,
            colors="black",
            lw=1,
        )

        for _, element in beamline.iterrows():
            element_class = element["CLASS"].upper()
            if element_class in ("DRIFT", "MARKER"):
                continue
            if element_class in ("SBEND", "RBEND"):
                box_style = {"facecolor": palette["BEND"]}
            elif element_class in self._CARTOUCHE_CLASSES:
                box_style = {"facecolor": palette[element_class], "ec": palette[element_class]}
            else:
                logging.warning(f"colors are not implemented for {element['CLASS']}")
                continue
            self._ax2.add_patch(
                patches.Rectangle(
                    (element["AT_ENTRY"].m_as("m"), vertical_position - 0.05),
                    element["L"].m_as("m"),
                    0.1,
                    hatch="",
                    clip_on=False,
                    **box_style,
                ),
            )

        if print_label:
            if labels is not None:
                label_source = labels
            else:
                # Label every element but the drifts; the class comparison is
                # case-insensitive, like the drawing loop above.
                frame = beamline.reset_index()
                label_source = frame[frame["CLASS"].str.upper() != "DRIFT"].set_index("NAME")
            self._ax2.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(label_source)))
            self._ax2.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(label_source)))

        self._ax2.get_xaxis().set_tick_params(direction="out")
        self._ax2.tick_params(axis="both", which="major")
        self._ax2.tick_params(axis="x")
        plt.setp(self._ax2.xaxis.get_majorticklabels(), rotation=-90)

    def plot_beamline(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        with_aperture: bool = True,
        labels: Optional[_pd.DataFrame] = None,
        **kwargs: Any,
    ) -> None:
        """Configure the main ax for a beamline and optionally draw the apertures.

        Args:
            beamline: the beamline; rows need the `NAME`, `AT_ENTRY` and `AT_EXIT`
                columns (`NAME` may be the index).
            print_label: if True, put ticks and labels at the element centers.
            with_aperture: if True, call `draw_aperture` on the beamline.
            labels: optional frame used instead of the beamline for the ticks.
            **kwargs: `ylim` sets the vertical limits (default ``[-60, 60]``);
                all of them are forwarded to `draw_aperture`.

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")

        # Elements whose name contains "DRIFT" are ignored for the limits and
        # labels; rows without a usable name are dropped as well.
        frame = beamline.reset_index()
        is_drift = frame["NAME"].str.contains("DRIFT", regex=False, na=True)
        bl_short = frame[~is_drift].set_index("NAME")

        self._ax.tick_params(axis="both", which="major")
        self._ax.tick_params(axis="x")
        plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=-45)
        self._ax.set_xlim([bl_short.iloc[0]["AT_ENTRY"].m_as("m"), bl_short.iloc[-1]["AT_EXIT"].m_as("m")])
        self._ax.get_xaxis().set_tick_params(direction="out")
        self._ax.yaxis.set_major_locator(MultipleLocator(10))
        self._ax.yaxis.set_minor_locator(MultipleLocator(5))
        self._ax.set_ylim(kwargs.get("ylim", [-60, 60]))
        self._ax.grid(True, alpha=0.25)

        if with_aperture:
            self.draw_aperture(beamline, **kwargs)

        if print_label:
            # Tick positions are only computed when needed, so `AT_CENTER` is
            # not required when no label is requested.
            label_source = labels if labels is not None else bl_short
            self._ax.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(label_source)))
            self._ax.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(label_source)))

    @staticmethod
    def _half_aperture_mm(aperture: Any, apertype: str, axis: int) -> float:
        """Return the aperture along `axis` in millimeters.

        For a ``CIRCULAR`` aperture the first value is used for both axes.
        """
        return aperture[0 if apertype == "CIRCULAR" else axis].m_as("mm")

    @staticmethod
    def _chamber_mm(thickness: Any) -> Any:
        """Return a chamber thickness in millimeters.

        Quantities are converted; plain numbers are taken to be millimeters.
        """
        return thickness.m_as("mm") if hasattr(thickness, "m_as") else thickness

    def draw_aperture(self, bl: _pd.DataFrame, **kwargs: Any) -> None:
        """Draw the apertures of the quadrupoles, bends and collimators.

        The input frame is not modified. Collimators are drawn with `draw_coll`,
        quadrupoles with `draw_quad` and bends with `draw_bend`.

        Args:
            bl: the beamline; rows need the `CLASS`, `APERTYPE`, `APERTURE`,
                `AT_ENTRY` and `L` columns. An optional `CHAMBER` column gives
                the chamber thickness.
            **kwargs: `plane` selects the aperture values to draw: ``'X'``
                (default), ``'Y'`` or ``'both'`` (Y above the axis, X below).

        Raises:
            ArtistException: if `plane` is not 'X', 'Y' or 'both'.
        """
        if "APERTURE" not in bl:
            logging.warning("No APERTURE defined in the beamline")
            return
        if "APERTYPE" not in bl:
            logging.warning("No APERTYPE defined in the beamline")
            return

        planes = kwargs.get("plane", "X")
        if planes == "X":
            upper_axis, lower_axis = 0, 0
        elif planes == "Y":
            upper_axis, lower_axis = 1, 1
        elif planes == "both":
            upper_axis, lower_axis = 1, 0
        else:
            raise _ArtistException("Plane must be 'X', 'Y' or 'both'.")

        # Work on copies so that the columns added below never reach the caller.
        frame = bl[bl["APERTYPE"].notnull()].copy()
        for column in ("APERTYPE", "CLASS"):
            frame[column] = frame[column].map(lambda value: value.upper())
        frame = frame[frame["CLASS"].isin(self._APERTURE_CLASSES)].copy()

        # The aperture objects are shared with the caller's frame (a DataFrame
        # copy does not duplicate the objects stored in its cells), so the
        # values are read here and never written back.
        shapes = list(zip(frame["APERTURE"], frame["APERTYPE"]))
        frame["APERTURE_UP"] = [self._half_aperture_mm(a, kind, upper_axis) for a, kind in shapes]
        frame["APERTURE_DOWN"] = [self._half_aperture_mm(a, kind, lower_axis) for a, kind in shapes]

        # Draw the collimators first as they have no chamber.
        is_collimator = frame["CLASS"].isin(self._COLLIMATOR_CLASSES)
        for _, collimator in frame[is_collimator].iterrows():
            self.draw_coll(collimator)

        magnets = frame[~is_collimator].copy()
        if "CHAMBER" in magnets:
            chamber = [self._chamber_mm(thickness) for thickness in magnets["CHAMBER"]]
        else:
            chamber = [0.0] * len(magnets)
        magnets["CHAMBER_UP"] = chamber
        magnets["CHAMBER_DOWN"] = chamber

        for _, magnet in magnets.iterrows():
            if magnet["CLASS"] == "QUADRUPOLE":
                self.draw_quad(magnet)
            elif magnet["CLASS"] in ("SBEND", "RBEND"):
                self.draw_bend(magnet)

    def draw_quad(self, e: _pd.Series) -> None:
        """Draw a quadrupole: two blocks outside the aperture and chamber.

        Args:
            e: beamline row with `AT_ENTRY`, `L`, `APERTURE_UP`, `APERTURE_DOWN`,
                `CHAMBER_UP` and `CHAMBER_DOWN` (apertures and chambers in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"] + e["CHAMBER_UP"]),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self._ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self.draw_chamber(self._ax, e)

    def draw_coll(self, e: _pd.Series) -> None:
        """Draw a collimator: two blocks starting at the aperture.

        Args:
            e: beamline row with `AT_ENTRY`, `L`, `APERTURE_UP` and
                `APERTURE_DOWN` (apertures in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"]),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["COLLIMATOR"],
            ),
        )
        self._ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"]),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["COLLIMATOR"],
            ),
        )

    def draw_bend(self, e: _pd.Series) -> None:
        """Draw a bend: two blocks outside the aperture and chamber.

        The inner edge of each block is kept within 55 mm of the axis; a
        warning is logged when aperture plus chamber exceeds that value.

        Args:
            e: beamline row with `AT_ENTRY`, `L`, `APERTURE_UP`, `APERTURE_DOWN`,
                `CHAMBER_UP` and `CHAMBER_DOWN` (apertures and chambers in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")

        upper = e["APERTURE_UP"] + e["CHAMBER_UP"]
        if upper > 55:
            logging.warning(f"Aperture are bigger than 55 mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, upper if upper < 55 else 55),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["BEND"],
            ),
        )

        # The lower block uses the lower chamber value, as in `draw_quad`.
        lower = -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]
        if lower < -55:
            logging.warning(f"Aperture are bigger than 55 mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, lower if abs(lower) < 55 else -55),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["BEND"],
            ),
        )
        self.draw_chamber(self._ax, e)

    @staticmethod
    def draw_chamber(ax: plt.Axes, e: _pd.Series) -> None:
        """Draw the vacuum chamber walls just outside the aperture.

        Args:
            ax: the ax on which the chamber is drawn.
            e: beamline row with `AT_ENTRY`, `L`, `APERTURE_UP`, `APERTURE_DOWN`,
                `CHAMBER_UP` and `CHAMBER_DOWN` (apertures and chambers in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"]),  # (x,y)
                length,  # width
                e["CHAMBER_UP"],  # height
                hatch="",
                facecolor=palette["base01"],
            ),
        )
        ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"]),  # (x,y)
                length,  # width
                -e["CHAMBER_DOWN"],  # height
                hatch="",
                facecolor=palette["base01"],
            ),
        )
# Old method to convert for matplotlib # @staticmethod # def draw_bpm_size(ax, s, x): # ax.add_patch( # patches.Rectangle( # (s - 0.05, -x), # 0.1, # 2 * x, # ) # ) # # def bpm(self, ax, bl, **kwargs): # """TODO.""" # if kwargs.get('plane') is None: # raise Exception("'plane' argument must be provided.") # bl.line[bl.line[f"BPM_STD_{kwargs.get('plane')}"].notnull()].apply( # lambda x: self.draw_bpm_size(ax, x['AT_CENTER'], x[f"BPM_STD_{kwargs.get('plane')}"]), # axis=1 # ) # # # THIS IS THE OLD scattering.py # @staticmethod # def draw_slab(ax, e): # materials_colors = { # 'graphite': 'g', # 'beryllium': 'r', # 'water': 'b', # 'lexan': 'y', # } # ax.add_patch( # patches.Rectangle( # (e['AT_ENTRY'], -1), # (x,y) # e['LENGTH'], # width # 2, # height # hatch='', facecolor=materials_colors.get(e['MATERIAL'], 'k') # ) # ) # # @staticmethod # def draw_measuring_plane(ax, e): # ax.add_patch( # patches.Rectangle( # (e['AT_ENTRY'] - 0.005, -1), # (x,y) # 0.01, # width # 2, # height # hatch='', facecolor='k' # ) # ) # # def scattering(self, ax, bl, **kwargs): # bl.line.query("TYPE == 'slab'").apply(lambda e: self.draw_slab(ax, e), axis=1) # bl.line.query("TYPE == 'mp'").apply(lambda e: self.draw_measuring_plane(ax, e), axis=1)