"""
TODO
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as _np
import plotly.graph_objs as go
import plotly.offline
from .artist import Artist as _Artist
if TYPE_CHECKING:
import pandas as _pd
[docs]
class PlotlyArtist(_Artist):
"""
TODO
"""
def __init__(
self,
config: Optional[Dict[str, Any]] = None,
layout: Optional[Dict[str, Any]] = None,
width: Optional[float] = None,
height: Optional[float] = None,
**kwargs: Any,
) -> None:
"""
Args:
config:
layout:
width:
height:
**kwargs:
"""
super().__init__(**kwargs)
self._data: List[Any] = []
self._config = config or {
"showLink": False,
"scrollZoom": True,
"displayModeBar": False,
"editable": False,
}
self._layout: Dict[str, Any] = layout or {
"font": {"family": "serif", "size": 18},
"plot_bgcolor": "rgba(0,0,0,0)",
"xaxis": {
"showgrid": True,
"linecolor": "black",
"linewidth": 1,
"mirror": True,
"gridcolor": "grey",
"gridwidth": 0.1,
},
"yaxis": {
"linecolor": "black",
"linewidth": 1,
"gridcolor": "grey",
"gridwidth": 0.1,
"mirror": True,
"exponentformat": "power",
},
"height": 600,
"width": 600,
}
if height is not None:
self._layout["height"] = height
if width is not None:
self._layout["width"] = width
self._shapes: List[Any] = []
self._n_y_axis = len([ax for ax in self._layout.keys() if ax.startswith("yaxis")])
def _init_plot(self) -> None:
pass
@property
def fig(self) -> Dict[str, Any]: # pragma: no cover
"""Provides the plotly figure."""
return {
"data": self.data,
"layout": self.layout,
}
@property
def config(self) -> Dict[str, Any]: # pragma: no cover
return self._config
@property
def data(self) -> List[Any]: # pragma: no cover
return self._data
@property
def layout(self) -> Dict[str, Any]: # pragma: no cover
self._layout["shapes"] = self._shapes
return self._layout
@property
def shapes(self) -> List[Any]: # pragma: no cover
return self._shapes
[docs]
def __iadd__(self, other: Any) -> PlotlyArtist: # pragma: no cover
"""Add a trace to the figure."""
self._data.append(other)
return self
[docs]
def add_axis(self, title: str = "", axis: Optional[Dict[str, Any]] = None) -> None:
"""
Args:
title:
axis:
Returns:
"""
self._n_y_axis += 1
self.layout[f"yaxis{self._n_y_axis if self._n_y_axis > 1 else ''}"] = axis or {
"title": title,
"titlefont": dict(color="black"),
"tickfont": dict(color="black"),
"linewidth": 1,
"exponentformat": "power",
"overlaying": "y",
"side": "left",
}
[docs]
def add_secondary_axis(self, title: str = "", axis: Optional[Dict[str, Any]] = None) -> None:
"""
Args:
title:
axis:
Returns:
"""
self._n_y_axis += 1
self.layout[f"yaxis{self._n_y_axis}"] = axis or {
"title": title,
"titlefont": dict(color="black"),
"tickfont": dict(color="black"),
"linewidth": 1,
"exponentformat": "power",
"overlaying": "y",
"side": "right",
}
[docs]
def render(self) -> None: # pragma: no cover
if len(self._data) == 0:
self._data.append(go.Scatter(x=[0.0], y=[0.0]))
plotly.offline.iplot(self.fig, config=self.config)
[docs]
def save(self, file: str, file_format: str = "png") -> None: # pragma: no cover
plotly.io.write_image(self.fig, file=file, format=file_format)
[docs]
def save_html(self, file: str) -> Any: # pragma: no cover
return plotly.offline.plot(self.fig, config=self.config, auto_open=False, filename=file)
[docs]
def histogram(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Histogram"""
self._data.append(go.Histogram(*args, **kwargs))
[docs]
def uproot_histogram(self, histogram: Any, **kwargs: Any) -> None: # pragma: no cover
_ = histogram.numpy()
self.bar(x=_[1], y=_[0], error_y={"array": _np.sqrt(histogram.variances)}, **kwargs)
[docs]
def histogram2d(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Histogram2d"""
self._data.append(go.Histogram2d(*args, **kwargs))
[docs]
def bar(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
self._data.append(go.Bar(*args, **kwargs))
[docs]
def scatter(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Scatter ."""
self._data.append(go.Scatter(*args, **kwargs))
[docs]
def scatter3d(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Scatter3d ."""
self._data.append(go.Scatter3d(*args, **kwargs))
[docs]
def surface(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Surface ."""
self._data.append(go.Surface(*args, **kwargs))
[docs]
def heatmap(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover
"""A proxy for plotly.graph_objs.Surface ."""
self._data.append(go.Heatmap(*args, **kwargs))
[docs]
def plot_cartouche(
self,
beamline_survey: _pd.DataFrame,
vertical_position: float = 1.2,
unsplit_bends: bool = True,
skip_elements: Optional[List[str]] = None,
) -> None:
"""
Args:
beamline_survey:
vertical_position:
unsplit_bends:
skip_elements:
Returns:
"""
skip_elements = skip_elements or []
def do_sbend(at_entry: float, at_exit: float, polarity: float) -> None:
length = at_exit - at_entry
if polarity >= 0.0:
path = (
f"M{at_entry},{vertical_position + 0.1} "
f"H{at_exit} "
f"L{at_exit - 0.15 * length},{vertical_position - 0.1} "
f"H{at_exit - 0.85 * length} "
f"Z"
)
else:
path = (
f"M{at_entry + 0.15 * length},{vertical_position + 0.1} "
f"H{at_exit - 0.15 * length} "
f"L{at_exit},{vertical_position - 0.1} "
f"H{at_entry} "
f"Z"
)
self.shapes.append(
{
"type": "path",
"xref": "x",
"yref": "paper",
"path": path,
"line": {
"width": 0,
},
"fillcolor": "#0000FF",
},
)
colors = {"ELEMENT": "#AAAAAA", "RCOL": "#11EE11", "ECOL": "#1111EE"}
self.shapes.append(
{
"type": "line",
"xref": "paper",
"yref": "paper",
"x0": 0,
"y0": vertical_position,
"x1": 1,
"y1": vertical_position,
"line": {
"color": "rgb(150, 150, 150)",
"width": 2,
},
},
)
accumulate = False
accumulator: Dict[str, Any] = {}
for i, e in beamline_survey.iterrows():
if e["CLASS"].upper() not in ("QUADRUPOLE", "SBEND", "ELEMENT", "RCOL", "ECOL"):
continue
if i in skip_elements:
continue
if unsplit_bends and accumulate and (e["CLASS"].upper() not in ("SBEND",) or i != accumulator["name"]):
accumulate = False
do_sbend(accumulator["at_entry"], accumulator["at_exit"], accumulator["polarity"])
accumulator = {}
if e["CLASS"].upper() in ("ELEMENT", "RCOL", "ECOL"):
self.shapes.append(
{
"type": "rect",
"xref": "x",
"yref": "paper",
"x0": e["AT_ENTRY"].m_as("m"),
"y0": vertical_position + 0.1,
"x1": e["AT_EXIT"].m_as("m"),
"y1": vertical_position - 0.1,
"line": {
"width": 0,
},
"fillcolor": colors[e["CLASS"].upper()],
},
)
if e["CLASS"].upper() == "QUADRUPOLE":
try:
field_magnitude = e["K1"].magnitude
except KeyError:
field_magnitude = e["K1L"].magnitude
self.shapes.append(
{
"type": "rect",
"xref": "x",
"yref": "paper",
"x0": e["AT_ENTRY"].m_as("m"),
"y0": vertical_position if field_magnitude > 0 else vertical_position - 0.1,
"x1": e["AT_EXIT"].m_as("m"),
"y1": vertical_position + 0.1 if field_magnitude > 0 else vertical_position,
"line": {
"width": 0,
},
"fillcolor": "#FF0000",
},
)
if e["CLASS"].upper() == "HKICKER" or e["CLASS"].upper() == "VKICKER":
self.shapes.append(
{
"type": "rect",
"xref": "x",
"yref": "paper",
"x0": e["AT_ENTRY"].m_as("m"),
"y0": vertical_position,
"x1": e["AT_EXIT"].m_as("m"),
"y1": vertical_position + 0.1,
"line": {
"width": 0,
},
"fillcolor": "Green",
},
)
if e["CLASS"].upper() == "SBEND":
if unsplit_bends:
if accumulate is False:
accumulate = True
accumulator["name"] = i
accumulator["polarity"] = _np.sign(e["ANGLE"].m_as("radians"))
accumulator["at_entry"] = e["AT_ENTRY"].m_as("m")
accumulator["at_exit"] = e["AT_EXIT"].m_as("m")
continue
if accumulate is True:
accumulator["at_exit"] = e["AT_EXIT"].m_as("m")
else:
do_sbend(
e["AT_ENTRY"].m_as("m"),
e["AT_EXIT"].m_as("m"),
polarity=_np.sign(e["ANGLE"].m_as("radians")),
)