Source code for biosiglive.gui.plot

"""
This file contains all the plot functions to plot the data in live or offline.
"""
try:
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl
    from PyQt5.QtWidgets import QProgressBar
except ModuleNotFoundError:
    pass
import numpy as np
from typing import Union
import matplotlib.pyplot as plt
from math import ceil
from ..enums import PlotType
import time


# TODO: Add plot class to be able to plot several curves in the same plot. And do several plot in the same app.
[docs] class LivePlot: def __init__( self, nb_subplots: int = 1, plot_type: Union[PlotType, str] = PlotType.Curve, name: str = None, channel_names: list = None, rate: int = None, ): """ Initialize the plot class. Parameters ---------- plot_type : Union[PlotType, str] type of the plot (curve, progress bar, 3D scatter, skeleton...). name : str name of the plot channel_names : list list of the channel names (subplots names) nb_subplots : int number of subplots """ if isinstance(plot_type, str): if plot_type not in [t.value for t in PlotType]: raise ValueError("Plot type not recognized") plot_type = PlotType(plot_type) self.plot_type = plot_type if nb_subplots and channel_names: if len(channel_names) != nb_subplots: raise ValueError("The number of subplots is not equal to the number of channel names") self.channel_names = channel_names self.figure_name = name self.nb_subplot = nb_subplots self.rate = rate self.layout = None self.app = None self.viz = None self.resize = (400, 400) self.move = (0, 0) self.plot_windows = None self.plot_buffer = None self.msk_model = None self.last_plot = None self.once_update = False self.plots = [] self.curves = [] self.ptr = [] self.unit = "" self.size_to_append = []
[docs] def init( self, plot_windows: Union[int, list] = None, **kwargs, ): """ This function is used to initialize the qt app. Parameters ---------- plot_windows: Union[int, list] The number of frames ti plot. If is a list, the number of frames to plot for each subplot. """ self.plot_buffer = [None] * self.nb_subplot if isinstance(plot_windows, int): plot_windows = [plot_windows] * self.nb_subplot self.plot_windows = plot_windows if self.plot_type == PlotType.Curve: self._init_curve(self.figure_name, self.channel_names, self.nb_subplot, **kwargs) elif self.plot_type == PlotType.ProgressBar: self._init_progress_bar(self.figure_name, self.nb_subplot, **kwargs) elif self.plot_type == PlotType.Scatter3D: if self.nb_subplot != 1: raise ValueError("The number of subplots should be 1 for 3DScatter plot.") self._init_3d_scatter(self.figure_name, **kwargs) elif self.plot_type == PlotType.Skeleton: self.viz = self._init_skeleton(**kwargs) else: raise ValueError(f"The plot type ({self.plot_type}) is not supported.")
[docs] def update(self, data: Union[np.ndarray, list], **kwargs): """ This function is used to update the qt app. Parameters ---------- data: Union[np.ndarray, list] The data to plot. If it is a list, the data to plot for each subplot. """ update = True if self.plot_type != PlotType.Scatter3D and self.plot_type != PlotType.Skeleton: if isinstance(data, list): if len(data) != self.nb_subplot: raise ValueError("The number of subplots is not equal to the number of data.") for d in data: if isinstance(d, np.ndarray): if len(d.shape) != 2: raise ValueError("The data should be a 2D array.") else: raise ValueError("The data should be a 2D array.") if isinstance(data, np.ndarray): data_mat = data data = [] if data_mat.shape[0] != self.nb_subplot: raise ValueError("The number of subplots is not equal to the number of data.") for d in data_mat: data.append(d[np.newaxis, :]) if self.plot_windows: for i in range(self.nb_subplot): if self.plot_buffer[i] is None: self.plot_buffer[i] = data[i][..., -self.plot_windows[i] :] if self.plot_buffer[i].shape[1] < self.plot_windows[i]: size = self.plot_windows[i] - self.plot_buffer[i].shape[1] self.plot_buffer[i] = np.append( np.zeros((self.plot_buffer[i].shape[0], size)), self.plot_buffer[i], axis=-1 ) elif self.plot_buffer[i].shape[1] < self.plot_windows[i]: self.plot_buffer[i] = np.append(self.plot_buffer[i], data[i], axis=-1) elif self.plot_buffer[i].shape[1] >= self.plot_windows[i]: size = data[i].shape[1] self.plot_buffer[i] = np.append(self.plot_buffer[i][..., size:], data[i], axis=-1) data = self.plot_buffer if self.rate and self.once_update: plot_time = time.time() - self.last_plot if plot_time != 0 and 1 / plot_time > self.rate: update = False else: update = True if update: self.once_update = True if self.plot_type == PlotType.ProgressBar: self._update_progress_bar(data) elif self.plot_type == PlotType.Curve: self._update_curve(data) elif self.plot_type == PlotType.Skeleton: self._update_skeleton(data, self.viz) elif self.plot_type == PlotType.Scatter3D: self._update_3d_scatter(data, **kwargs) else: raise ValueError(f"The plot type ({self.plot_type}) is not supported.") self.last_plot = time.time()
def _init_curve( self, figure_name: str = "Figure", subplot_labels: Union[list, str] = None, nb_subplot: int = None, x_labels: Union[list, str] = None, y_labels: Union[list, str] = None, grid: bool = True, colors: Union[list, tuple] = None, ): """ This function is used to initialize the curve plot. Parameters ---------- figure_name: str The name of the figure. subplot_labels: Union[list, str] The labels of the subplots. nb_subplot: int The number of subplot. x_labels: Union[list, str] The labels of the x axis. y_labels: Union[list, str] The labels of the y axis. grid: bool If True, the grid is displayed. colors: Union[list, tuple] The colors of the curves. """ # --- Curve graph --- # self.app = pg.mkQApp("Curve_plot") pg.setConfigOption("background", "w") pg.setConfigOption("foreground", "k") self.win = pg.GraphicsLayoutWidget(show=True) self.win.setWindowTitle(figure_name) nb_line = 4 nb_col = ceil(nb_subplot / nb_line) line_count = 0 self.win.resize(self.resize[0], self.resize[1]) self.win.move(self.move[0], self.move[1]) if colors: if isinstance(colors, tuple): colors = [colors] * nb_subplot elif isinstance(colors, list): if len(colors) != nb_subplot: raise ValueError("The number of colors is not equal to the number of subplots.") else: colors = [(0, 128, 232)] * nb_subplot # Blue if not x_labels: x_labels = ["Frames"] * nb_subplot else: if isinstance(x_labels, str): x_labels = [x_labels] * nb_subplot elif isinstance(x_labels, list): if len(x_labels) != nb_subplot: raise ValueError("The number of x labels is not equal to the number of subplots.") if not y_labels: y_labels = ["Amplitude"] * nb_subplot else: if isinstance(y_labels, str): y_labels = [y_labels] * nb_subplot elif isinstance(y_labels, list): if len(y_labels) != nb_subplot: raise ValueError("The number of y labels is not equal to the number of subplots.") if not subplot_labels: subplot_labels = [f"Subplot {i}" for i in range(nb_subplot)] else: if isinstance(subplot_labels, list): if len(subplot_labels) != nb_subplot: raise ValueError("The number of subplot labels is not equal to the number of subplots.") for subplot in range(nb_subplot): self.ptr.append(0) self.size_to_append.append(0) if line_count == nb_col: self.win.nextRow() line_count = 0 self.plots.append(self.win.addPlot(title=subplot_labels[subplot])) self.plots[-1].setDownsampling(mode="peak") self.plots[-1].setClipToView(False) self.curves.append(self.plots[-1].plot([], pen=colors[subplot], name="Blue curve")) self.plots[-1].setLabel("bottom", x_labels[subplot]) self.plots[-1].setLabel("left", y_labels[subplot]) self.plots[-1].showGrid(x=grid, y=grid) line_count += 1 def _init_progress_bar( self, figure_name: str = "Figure", nb_subplot: int = None, bar_graph_max_value: Union[int, list] = 100, unit: Union[str, list] = "", ): """ This function is used to initialize the curve plot. Parameters ---------- figure_name: str The name of the figure. nb_subplot: int The number of subplot. bar_graph_max_value: int or list The maximum value of the bar graph. unit: str or list The unit of the bar graph. """ # --- Progress bar graph --- # if isinstance(unit, str): self.unit = [unit] * nb_subplot self.layout, self.app = self._init_layout(figure_name, resize=self.resize, move=self.move) row_count = 0 if bar_graph_max_value is None: bar_graph_max_value = [100] * nb_subplot if isinstance(bar_graph_max_value, int): bar_graph_max_value = [bar_graph_max_value] * nb_subplot for plot in range(nb_subplot): self.plots.append(QProgressBar()) self.plots[-1].setMaximum(bar_graph_max_value[plot]) self.layout.addWidget(self.plots[-1], row=plot, col=0) self.layout.show() row_count += 1 def _init_3d_scatter( self, figure_name: str = "Figure", colors: Union[list, tuple] = (1.0, 0.0, 0.0, 0.5), size: Union[int, list] = 0.03, ): """ This function is used to initialize the 3d scatter plot. Parameters ---------- figure_name: str The name of the figure. colors: Union[list, tuple] The color of the scatter. size: Union[int, list] The size of the scatters. """ # --- 3D scatter graph --- # self.app = pg.mkQApp("3D_scatter_plot") w = gl.GLViewWidget() w.opts["bgcolor"] = (0.2, 0.2, 0.2, 10) w.opts["distance"] = 8 w.show() w.setWindowTitle(figure_name) g = gl.GLGridItem() # g.setColor((1, 1, 1, 100)) w.addItem(g) pos = np.zeros((1, 3)) self.plots.append(gl.GLScatterPlotItem(pos=pos, color=colors, size=size, pxMode=False)) w.addItem(self.plots[-1]) def _update_3d_scatter( self, data: Union[np.ndarray, list], colors: Union[list, tuple] = (0, 1.0, 0.0, 50), size: Union[list, float] = 0.03, ): """ This function is used to update the 3d scatter plot. Parameters ---------- data: np.ndarray The data to plot. (N, 3) colors: Union[list, tuple] The color of the scatter. size: float The size of the scatter. """ if isinstance(data, np.ndarray): if len(data.shape) != 2: raise ValueError("The data must be a 2D array.") if data.shape[1] != 3: raise ValueError("The data must be a (N, 3) array.") if isinstance(colors, list): if len(colors) != len(data): raise ValueError("The number of colors is not equal to the number of data.") if isinstance(size, list): if len(size) != len(data): raise ValueError("The number of size is not equal to the number of data.") for plot in self.plots: plot.setData(pos=data, color=colors, size=size) self.app.processEvents() def _update_curve(self, data: list): """ This function is used to update the curve plot. Parameters ---------- data: list The data to plot. """ if len(data) != len(self.curves): raise ValueError( f"The number of data ({len(data)}) is different from the number of curves ({len(self.curves)})." ) for i in range(len(data)): if self.ptr[i] == 0: self.size_to_append[i] = data[i].shape[1] self.ptr[i] += self.size_to_append[i] * 2 self.curves[i].setData(data[i][0, :]) # self.curves[i].setPos(self.ptr[i], 0) self.app.processEvents() def _update_progress_bar(self, data: list): """ This function is used to update the progress bar plot. Parameters ---------- data: list The data to plot. """ if self.channel_names and len(self.channel_names) != len(data): raise RuntimeError( f"The length of Subplot labels ({len(self.channel_names)}) is different than" f" the first dimension of your data ({len(data)})." ) for i in range(len(data)): value = np.mean(data[i][0, :]) self.plots[i].setValue(int(value)) name = self.channel_names[i] if self.channel_names else f"plot_{i}" self.plots[i].setFormat(f"{name}: {int(value)} {self.unit[i]}") self.app.processEvents() @staticmethod def _update_skeleton(data: list, viz): """ This function is used to update the skeleton plot. Parameters ---------- data : list The data to plot. list of length degree of freedom. viz: Viz3D The plot. Returns ------- """ viz.set_q(data[:, -1], refresh_window=True) @staticmethod def _init_skeleton(**kwargs): try: import bioviz except ImportError: raise ImportError("Please install bioviz (github.com/pyomeca/bioviz) to use the skeleton plot.") if not "model_path" in kwargs or "model" in kwargs: raise ValueError( "You must provide a model_path or a model to use the skeleton plot through" " the keyword arguments 'model_path' or 'model' respectively." ) plot = bioviz.Viz(**kwargs) return plot @staticmethod def _init_layout(figure_name: str = "Figure", resize: tuple = (400, 400), move: tuple = (0, 0)): """ This function is used to initialize the qt app layout. Parameters ---------- figure_name: str The name of the figure. resize: tuple The size of the figure. move: tuple The position of the figure. Returns ------- layout: QVBoxLayout The layout of the qt app. app: QApplication The qt app. """ app = pg.mkQApp(figure_name) layout = pg.LayoutWidget() layout.resize(resize[0], resize[1]) layout.move(move[0], move[1]) return layout, app
[docs] def disconnect(self): self.app.disconnect() try: self.app.closeAllWindows() except RuntimeError: pass
[docs] class OfflinePlot: """ This class is used to plot data offline. """
[docs] @staticmethod def multi_plot( data: Union[list, np.ndarray], x: Union[list, np.ndarray] = None, nb_column: int = None, y_label: str = None, x_label: str = None, legend: Union[list, str] = None, subplot_title: Union[str, list] = None, figure_name: str = None, ): """ This function is used to plot multiple data in one figure. Parameters ---------- data: list or np.ndarray The data to plot. x: list or np.ndarray The x-axis data. nb_column: int The number of columns in the figure. y_label: str The y-axis label. x_label: str The x-axis label. legend: list or str The legend of the data. subplot_title: str or list The title of the subplot. figure_name: str The name of the figure. """ if not isinstance(data, list): data = [data] nb_data = len(data) plt.figure(figure_name) size_police = 12 if nb_column: col = nb_column else: col = data[0].shape[0] if data[0].shape[0] <= 4 else 4 line = ceil(data[0].shape[0] / col) if isinstance(legend, str): legend = [legend] for i in range(data[0].shape[0]): plt.subplot(line, col, i + 1) if y_label and i % 4 == 0: plt.ylabel(y_label, fontsize=size_police) if x_label: plt.xlabel(x_label, fontsize=size_police) for j in range(nb_data): if legend: legend_tmp = legend[j] else: legend_tmp = None if x is not None: plt.plot(x, data[j][i, :], label=legend_tmp) else: plt.plot(data[j][i, :], label=legend_tmp) plt.legend() if subplot_title: plt.title(subplot_title[i], fontsize=size_police) plt.show()