"""georges_core.vis.matplotlib: Matplotlib artist for drawing beamlines and their apertures."""

import logging
from typing import Any, List, Optional, Tuple

import matplotlib.colors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as _pd
from matplotlib.ticker import FixedFormatter, FixedLocator, MultipleLocator

from ..units import ureg as _ureg
from .artist import PALETTE
from .artist import Artist as _Artist
from .artist import ArtistException as _ArtistException

    # RGB control points (components in [0, 1]) of the "fluka" colormap, ordered from white to black.
FLUKA_COLORS = [
    (1.0, 1.0, 1.0),
    (0.9, 0.6, 0.9),
    (1.0, 0.4, 1.0),
    (0.9, 0.0, 1.0),
    (0.7, 0.0, 1.0),
    (0.5, 0.0, 0.8),
    (0.0, 0.0, 0.8),
    (0.0, 0.0, 1.0),
    (0.0, 0.6, 1.0),
    (0.0, 0.8, 1.0),
    (0.0, 0.7, 0.5),
    (0.0, 0.9, 0.2),
    (0.5, 1.0, 0.0),
    (0.8, 1.0, 0.0),
    (1.0, 1.0, 0.0),
    (1.0, 0.8, 0.0),
    (1.0, 0.5, 0.0),
    (1.0, 0.0, 0.0),
    (0.8, 0.0, 0.0),
    (0.6, 0.0, 0.0),
    (0.0, 0.0, 0.0),
]
# Colormap interpolating linearly between the FLUKA_COLORS control points, quantized to 300 levels.
FlukaColormap = matplotlib.colors.LinearSegmentedColormap.from_list("fluka", FLUKA_COLORS, N=300)

# Default color palette: the "solarized" entry of PALETTE. `palette` is that same object, so the
# logical entries added below are also visible through PALETTE["solarized"].
palette = PALETTE["solarized"]

# "Logical" colors: each element class (key) is an alias of a base palette color (value).
# Aliases are applied in this order, each one reading the base color just before assigning it.
_LOGICAL_COLOR_ALIASES = {
    "BEND": "blue",
    "QUADRUPOLE": "red",
    "SEXTUPOLE": "yellow",
    "OCTUPOLE": "green",
    "MULTIPOLE": "gray",
    "DEGRADER": "base02",
    "RECTANGULARCOLLIMATOR": "darkgreen",
    "CIRCULARCOLLIMATOR": "magenta",
    "ELLIPTICALCOLLIMATOR": "orange",
    "COLLIMATOR": "magenta",
    "HKICKER": "magenta",
    "VKICKER": "violet",
    "SCATTERER": "base02",
    "MATRIX": "cyan",
    "ELEMENT": "base0",
}
for _element_class, _base_color in _LOGICAL_COLOR_ALIASES.items():
    palette[_element_class] = palette[_base_color]
# Do not leak the loop variables into the module namespace.
del _element_class, _base_color

class MatplotlibArtist(_Artist):
    """Artist drawing beamlines, their element layout and their apertures with Matplotlib."""

    # Element classes drawn as a colored box by `plot_cartouche`; each one is also a `palette` key.
    # Bends (SBEND / RBEND) are handled separately and use palette["BEND"].
    _CARTOUCHE_CLASSES = frozenset(
        (
            "SEXTUPOLE",
            "QUADRUPOLE",
            "OCTUPOLE",
            "MULTIPOLE",
            "HKICKER",
            "VKICKER",
            "RECTANGULARCOLLIMATOR",
            "CIRCULARCOLLIMATOR",
            "ELLIPTICALCOLLIMATOR",
            "COLLIMATOR",
            "DEGRADER",
            "MATRIX",
            "SCATTERER",
            "ELEMENT",
        ),
    )

    # Collimator classes drawn by `draw_aperture` without a chamber, in drawing order.
    _COLLIMATOR_CLASSES = ("RECTANGULARCOLLIMATOR", "CIRCULARCOLLIMATOR", "ELLIPTICALCOLLIMATOR")

    # Element classes whose aperture is handled by `draw_aperture`.
    _APERTURE_CLASSES = ("QUADRUPOLE", "SBEND", "RBEND") + _COLLIMATOR_CLASSES

    # Bend patches are clamped to +/- this value (mm) by `draw_bend`.
    _BEND_APERTURE_LIMIT_MM = 55.0

    def __init__(self, ax: Optional[plt.Axes] = None, **kwargs: Any):
        """
        Args:
            ax: the matplotlib ax used for plotting. If None it is created with `init_plot`, which
                only receives the `figsize` and `subplots` keyword arguments (when given).
            kwargs: forwarded to `Artist` (e.g. `with_frames`, `with_centers`); `figsize` and
                `subplots` are additionally forwarded to `init_plot`.
        """
        super().__init__(**kwargs)
        if ax is None:
            # `init_plot` only accepts these two options: forwarding every keyword argument
            # (e.g. Artist options) would raise a TypeError.
            init_kwargs = {key: kwargs[key] for key in ("figsize", "subplots") if key in kwargs}
            self.init_plot(**init_kwargs)
        else:
            self._ax = ax
            self._fig = ax.figure
        # Secondary (twin) ax, created by `plot_cartouche`.
        self._ax2: Optional[plt.Axes] = None

    @property
    def ax(self) -> plt.Axes:
        """Current Matplotlib ax.

        Returns:
            the Matplotlib ax.
        """
        return self._ax

    @ax.setter
    def ax(self, ax: plt.Axes) -> None:
        self._ax = ax

    @property
    def ax2(self) -> Optional[plt.Axes]:
        """Twin ax created by `plot_cartouche`.

        Returns:
            the twin Matplotlib ax, or None if `plot_cartouche` has not been called.
        """
        return self._ax2

    @property
    def figure(self) -> plt.Figure:
        """Current Matplotlib figure.

        Returns:
            the Matplotlib figure.
        """
        return self._fig

    def init_plot(self, figsize: Tuple[int, int] = (12, 8), subplots: int = 111) -> None:
        """
        Initialize the Matplotlib figure and ax.

        Args:
            figsize: figure size (width, height) in inches, forwarded to `plt.figure`.
            subplots: subplot specification forwarded to `Figure.add_subplot` (e.g. 111).
        """
        self._fig = plt.figure(figsize=figsize)
        self._ax = self._fig.add_subplot(subplots)

    def plot(self, *args: Any, **kwargs: Any) -> None:
        """Proxy for matplotlib.pyplot.plot

        Same as `matplotlib.pyplot.plot`, forwards all arguments to the current ax.
        """
        self._ax.plot(*args, **kwargs)

    @staticmethod
    def beamline_get_ticks_locations(o: _pd.DataFrame) -> List[float]:
        """Return the AT_CENTER column of `o` converted to meters."""
        return list(o["AT_CENTER"].apply(lambda e: e.m_as("m")).values)

    @staticmethod
    def beamline_get_ticks_labels(o: _pd.DataFrame) -> List[str]:
        """Return the index labels of `o` (used as tick labels)."""
        return list(o.index)

    def plot_cartouche(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        labels: Optional[_pd.DataFrame] = None,
        vertical_position: float = 1.15,
    ) -> None:
        """
        Draw a schematic of the beamline elements on a twin ax of the main ax.

        The twin ax y-range is [0, 1] and patches are not clipped, so with the default
        `vertical_position` the schematic is drawn above the plotting area.

        Args:
            beamline: beamline with the CLASS, AT_ENTRY, AT_EXIT and L columns (lengths as quantities).
            print_label: if True, label the x ticks with element names.
            labels: optional elements (AT_CENTER column, names as index) used for the ticks instead
                of every non-drift element of `beamline`.
            vertical_position: height of the schematic in twin-ax coordinates.

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")

        self._ax2 = self._ax.twinx()
        self._ax2.set_ylim([0, 1])
        self._ax2.set_yticks([])
        self._ax2.hlines(
            vertical_position,
            0,
            beamline.iloc[-1]["AT_EXIT"].m_as("m"),
            clip_on=False,
            colors="black",
            lw=1,
        )

        for _, e in beamline.iterrows():
            element_class = e["CLASS"].upper()
            if element_class in ("DRIFT", "MARKER"):
                continue
            if element_class in ("SBEND", "RBEND"):
                # Bends are drawn without an explicit edge color.
                style = {"facecolor": palette["BEND"]}
            elif element_class in self._CARTOUCHE_CLASSES:
                style = {"facecolor": palette[element_class], "ec": palette[element_class]}
            else:
                logging.warning(f"colors are not implemented for {e['CLASS']}")
                continue
            self._ax2.add_patch(
                patches.Rectangle(
                    (e["AT_ENTRY"].m_as("m"), vertical_position - 0.05),
                    e["L"].m_as("m"),
                    0.1,
                    hatch="",
                    clip_on=False,
                    **style,
                ),
            )

        if print_label:  # For beamline losses or Twiss
            if labels is not None:
                ticks_source = labels
            else:
                bl_short = beamline.reset_index()
                # Case-insensitive, consistent with the CLASS checks of the drawing loop above.
                bl_short = bl_short[bl_short["CLASS"].str.upper() != "DRIFT"]
                ticks_source = bl_short.set_index("NAME")
            self._ax2.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(ticks_source)))
            self._ax2.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(ticks_source)))

        self._ax2.get_xaxis().set_tick_params(direction="out")
        plt.setp(self._ax2.xaxis.get_majorticklabels(), rotation=-90)

    def plot_beamline(
        self,
        beamline: Optional[_pd.DataFrame] = None,
        print_label: bool = False,
        with_aperture: bool = True,
        labels: Optional[_pd.DataFrame] = None,
        **kwargs: Any,
    ) -> None:
        """
        Configure the main ax for a beamline plot and optionally draw the apertures.

        The x range spans the elements whose NAME does not contain "DRIFT"; y ticks are placed
        every 10 (major) and 5 (minor) units.

        Args:
            beamline: beamline with a NAME index (or column) and AT_ENTRY, AT_EXIT, AT_CENTER columns.
            print_label: if True, label the x ticks with element names.
            with_aperture: if True, draw the apertures with `draw_aperture` (kwargs are forwarded).
            labels: optional elements used for the x ticks instead of the non-drift elements.
            **kwargs: `ylim` for the main ax (default [-60, 60]); also forwarded to `draw_aperture`.

        Raises:
            ArtistException: if no beamline is provided.
        """
        if beamline is None:
            raise _ArtistException("No beamline provided")

        bl_short = beamline.reset_index()
        # Missing names are dropped too (na=True then negated).
        bl_short = bl_short[~bl_short["NAME"].str.contains("DRIFT", regex=False, na=True)]
        bl_short = bl_short.set_index("NAME")

        plt.setp(self._ax.xaxis.get_majorticklabels(), rotation=-45)
        self._ax.set_xlim([bl_short.iloc[0]["AT_ENTRY"].m_as("m"), bl_short.iloc[-1]["AT_EXIT"].m_as("m")])
        self._ax.get_xaxis().set_tick_params(direction="out")
        self._ax.yaxis.set_major_locator(MultipleLocator(10))
        self._ax.yaxis.set_minor_locator(MultipleLocator(5))
        self._ax.set_ylim(kwargs.get("ylim", [-60, 60]))
        self._ax.grid(True, alpha=0.25)

        if with_aperture:
            self.draw_aperture(beamline, **kwargs)

        if print_label:
            ticks_source = labels if labels is not None else bl_short
            self._ax.xaxis.set_major_locator(FixedLocator(self.beamline_get_ticks_locations(ticks_source)))
            self._ax.xaxis.set_major_formatter(FixedFormatter(self.beamline_get_ticks_labels(ticks_source)))

    def draw_aperture(self, bl: _pd.DataFrame, **kwargs: Any) -> None:
        """
        Draw the apertures of quadrupoles, bends and collimators on the main ax.

        Elements without an APERTYPE are ignored. Collimators are drawn without chamber; quadrupoles
        and bends are drawn around their chamber (CHAMBER column, 0 when absent). The input
        DataFrame and the aperture lists it holds are not modified.

        Args:
            bl: beamline with the CLASS, APERTYPE, APERTURE, AT_ENTRY and L columns (optionally CHAMBER).
            kwargs: `plane` selects the aperture components used for the upper/lower halves:
                "X" (default, APERTURE[0] for both), "Y" (APERTURE[1] for both) or
                "both" (APERTURE[1] above, APERTURE[0] below).

        Raises:
            ArtistException: if `plane` is not "X", "Y" or "both".
        """
        if "APERTURE" not in bl:
            logging.warning("No APERTURE defined in the beamline")
            return

        planes = kwargs.get("plane", "X")
        if planes == "X":
            up_index, down_index = 0, 0
        elif planes == "Y":
            up_index, down_index = 1, 1
        elif planes == "both":
            up_index, down_index = 1, 0
        else:
            raise _ArtistException("Plane must be 'X', 'Y' or 'both'.")

        # Work on filtered copies so that assignments never touch the caller's DataFrame.
        bl = bl[~bl["APERTYPE"].isnull()].copy()
        for column in ("APERTYPE", "CLASS"):
            bl[column] = bl[column].map(lambda value: value.upper())
        bl = bl[bl["CLASS"].isin(self._APERTURE_CLASSES)].copy()

        # Circular apertures: use the first aperture value for both planes. New lists are built
        # because DataFrame.copy() does not copy the list objects held in the cells, so assigning
        # into them would modify the caller's beamline.
        bl["APERTURE"] = _pd.Series(
            [
                [aperture[0], aperture[0]] if apertype == "CIRCULAR" else aperture
                for aperture, apertype in zip(bl["APERTURE"], bl["APERTYPE"])
            ],
            index=bl.index,
            dtype=object,
        )
        bl["APERTURE_UP"] = bl["APERTURE"].apply(lambda a: a[up_index].m_as("mm"))
        bl["APERTURE_DOWN"] = bl["APERTURE"].apply(lambda a: a[down_index].m_as("mm"))

        # Draw the collimators as they have no chamber, then drop them.
        for collimator_class in self._COLLIMATOR_CLASSES:
            for _, element in bl[bl["CLASS"] == collimator_class].iterrows():
                self.draw_coll(element)
        bl = bl[~bl["CLASS"].isin(self._COLLIMATOR_CLASSES)].copy()

        if "CHAMBER" not in bl:
            bl["CHAMBER"] = 0
        # Plain numbers are interpreted as millimeters; values that already carry a unit are kept.
        bl["CHAMBER"] = bl["CHAMBER"].apply(lambda a: a if hasattr(a, "m_as") else a * _ureg.mm)
        bl["CHAMBER_UP"] = bl["CHAMBER"].apply(lambda a: a.m_as("mm"))
        bl["CHAMBER_DOWN"] = bl["CHAMBER"].apply(lambda a: a.m_as("mm"))

        for element_class, draw in (
            ("QUADRUPOLE", self.draw_quad),
            ("SBEND", self.draw_bend),
            ("RBEND", self.draw_bend),
        ):
            for _, element in bl[bl["CLASS"] == element_class].iterrows():
                draw(element)

    def draw_quad(self, e: _pd.Series) -> None:
        """
        Draw a quadrupole outside its chamber (above and below), then the chamber itself.

        Args:
            e: element row with AT_ENTRY, L, APERTURE_UP/DOWN and CHAMBER_UP/DOWN (in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        # Heights of +/-100 extend past the default y-limits (+/-60) set by `plot_beamline`.
        self._ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"] + e["CHAMBER_UP"]),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self._ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["QUADRUPOLE"],
            ),
        )
        self.draw_chamber(self._ax, e)

    def draw_coll(self, e: _pd.Series) -> None:
        """
        Draw a collimator directly at its aperture (collimators have no chamber).

        Args:
            e: element row with AT_ENTRY, L and APERTURE_UP/DOWN (in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"]),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["COLLIMATOR"],
            ),
        )
        self._ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"]),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["COLLIMATOR"],
            ),
        )

    def draw_bend(self, e: _pd.Series) -> None:
        """
        Draw a bend outside its chamber, then the chamber itself.

        The inner edges of the bend patches are clamped to +/- `_BEND_APERTURE_LIMIT_MM`; a warning
        is logged when the aperture plus chamber exceeds that limit.

        Args:
            e: element row with AT_ENTRY, L, APERTURE_UP/DOWN and CHAMBER_UP/DOWN (in mm).
        """
        limit = self._BEND_APERTURE_LIMIT_MM
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")

        top = e["APERTURE_UP"] + e["CHAMBER_UP"]
        if top > limit:
            logging.warning(f"Aperture are bigger than {limit:g} mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, top if top < limit else limit),  # (x,y)
                length,  # width
                100,  # height
                facecolor=palette["BEND"],
            ),
        )

        bottom = -e["APERTURE_DOWN"] - e["CHAMBER_DOWN"]
        if bottom < -limit:
            logging.warning(f"Aperture are bigger than {limit:g} mm for {e.name}.")
        self._ax.add_patch(
            patches.Rectangle(
                (entry, bottom if abs(bottom) < limit else -limit),  # (x,y)
                length,  # width
                -100,  # height
                facecolor=palette["BEND"],
            ),
        )
        self.draw_chamber(self._ax, e)

    @staticmethod
    def draw_chamber(ax: plt.Axes, e: _pd.Series) -> None:
        """
        Draw the chamber walls between the aperture and the magnet.

        Args:
            ax: the Matplotlib ax to draw on.
            e: element row with AT_ENTRY, L, APERTURE_UP/DOWN and CHAMBER_UP/DOWN (in mm).
        """
        entry = e["AT_ENTRY"].m_as("m")
        length = e["L"].m_as("m")
        ax.add_patch(
            patches.Rectangle(
                (entry, e["APERTURE_UP"]),  # (x,y)
                length,  # width
                e["CHAMBER_UP"],  # height
                hatch="",
                facecolor=palette["base01"],
            ),
        )
        ax.add_patch(
            patches.Rectangle(
                (entry, -e["APERTURE_DOWN"]),  # (x,y)
                length,  # width
                -e["CHAMBER_DOWN"],  # height
                hatch="",
                facecolor=palette["base01"],
            ),
        )
# Old method to convert for matplotlib # @staticmethod # def draw_bpm_size(ax, s, x): # ax.add_patch( # patches.Rectangle( # (s - 0.05, -x), # 0.1, # 2 * x, # ) # ) # # def bpm(self, ax, bl, **kwargs): # """TODO.""" # if kwargs.get('plane') is None: # raise Exception("'plane' argument must be provided.") # bl.line[bl.line[f"BPM_STD_{kwargs.get('plane')}"].notnull()].apply( # lambda x: self.draw_bpm_size(ax, x['AT_CENTER'], x[f"BPM_STD_{kwargs.get('plane')}"]), # axis=1 # ) # # # THIS IS THE OLD # @staticmethod # def draw_slab(ax, e): # materials_colors = { # 'graphite': 'g', # 'beryllium': 'r', # 'water': 'b', # 'lexan': 'y', # } # ax.add_patch( # patches.Rectangle( # (e['AT_ENTRY'], -1), # (x,y) # e['LENGTH'], # width # 2, # height # hatch='', facecolor=materials_colors.get(e['MATERIAL'], 'k') # ) # ) # # @staticmethod # def draw_measuring_plane(ax, e): # ax.add_patch( # patches.Rectangle( # (e['AT_ENTRY'] - 0.005, -1), # (x,y) # 0.01, # width # 2, # height # hatch='', facecolor='k' # ) # ) # # def scattering(self, ax, bl, **kwargs): # bl.line.query("TYPE == 'slab'").apply(lambda e: self.draw_slab(ax, e), axis=1) # bl.line.query("TYPE == 'mp'").apply(lambda e: self.draw_measuring_plane(ax, e), axis=1)