# Source code for hypergraphx.viz.plot_motifs

import math

from hypergraphx.motifs.utils import generate_motifs

# Bar colours for non-negative (_POS_COLOR) and negative (_NEG_COLOR) scores.
_POS_COLOR = "#4C72B0"
_NEG_COLOR = "#C44E52"
# Fill colours for hyperedge polygons in motif glyphs, keyed by hyperedge
# size; the "default" entry is the fallback for sizes without their own key.
_DEFAULT_BLOB_COLORS = {
    3: "#9BB8E8",
    4: "#F3C7A6",
    "default": "#A8D5BA",
}


def _sort_for_visualization(motifs: list, shift: int = 3):
    """
    Reorder motifs for visualization.

    The sequence is rotated cyclically to the right by ``shift`` positions,
    so its last ``shift`` entries move to the front: with the default
    ``shift=3``, ``[1, 2, 3, 4, 5, 6]`` becomes ``[4, 5, 6, 1, 2, 3]``.
    For the canonical 6-motif ordering this is intended to show lower-order
    motifs first, then higher-order ones.

    Parameters
    ----------
    motifs : list or numpy.ndarray
        Items to reorder. Any iterable is accepted; anything that is not a
        numpy array is materialized into a list first.
    shift : int, optional
        Number of positions to rotate to the right (negative values rotate
        to the left). Defaults to 3.

    Returns
    -------
    list or numpy.ndarray
        A new array (via ``numpy.roll``) when ``motifs`` is a numpy array,
        otherwise a new list. Empty input yields an empty result.
    """
    # numpy is optional here: it is only needed to recognise and roll arrays.
    try:
        import numpy as np
    except ImportError:
        np = None

    if np is not None and isinstance(motifs, np.ndarray):
        return np.roll(motifs, shift)

    items = list(motifs)
    if not items:
        # Nothing to rotate; also avoids ``shift % 0`` (ZeroDivisionError).
        return items
    offset = shift % len(items)
    if offset == 0:
        return items
    return items[-offset:] + items[:-offset]


def _default_motif_patterns(order: int = 3):
    """
    Return the canonical motif patterns of the given order.

    The patterns are the keys of the mapping returned (as the first element
    of a 2-tuple) by ``generate_motifs(order)``, in ascending sorted order.
    """
    motif_map, _unused = generate_motifs(order)
    patterns = list(motif_map.keys())
    patterns.sort()
    return patterns


def _style_axes(ax, grid_axis="y"):
    """
    Give ``ax`` a light, minimal look (modified in place).

    Hides the top and right spines, optionally adds a faint dotted grid,
    draws grid lines and ticks beneath the plotted data, and uses 9-pt tick
    labels on both axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to restyle.
    grid_axis : str or None, optional
        Value passed as ``axis`` to ``ax.grid``; ``None`` skips the grid.
        Defaults to ``"y"``.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    if grid_axis is not None:
        grid_style = {"alpha": 0.3, "linestyle": ":", "linewidth": 0.6}
        ax.grid(axis=grid_axis, **grid_style)
    ax.set_axisbelow(True)
    ax.tick_params(axis="both", which="both", labelsize=9)


def _node_layout(k: int):
    """
    Return ``k`` node positions inside the unit square.

    Small motifs use hand-picked layouts: one centred point, a horizontal
    pair, a triangle, or the corners of a square. Five or more nodes are
    spread evenly on a circle of radius 0.35 around (0.5, 0.5), starting at
    angle 0. ``k <= 0`` yields an empty list.

    Returns
    -------
    list of tuple(float, float)
        A fresh list of ``(x, y)`` pairs, one per node.
    """
    if k == 3:
        half_root3 = math.sqrt(3) / 2.0
        return [(0.2, 0.25), (0.8, 0.25), (0.5, 0.25 + half_root3 * 0.4)]

    # Built per call so callers always receive a list they may mutate freely.
    fixed_layouts = {
        1: [(0.5, 0.5)],
        2: [(0.2, 0.5), (0.8, 0.5)],
        4: [(0.2, 0.2), (0.8, 0.2), (0.8, 0.8), (0.2, 0.8)],
    }
    if k in fixed_layouts:
        return fixed_layouts[k]

    cx = cy = 0.5
    radius = 0.35
    return [
        (cx + radius * math.cos(theta), cy + radius * math.sin(theta))
        for theta in (2 * math.pi * i / k for i in range(k))
    ]


def _expand_polygon(xs, ys, scale: float = 1.2):
    """
    Turn a set of points into an inflated polygon outline.

    Points are ordered by their polar angle around the centroid (ascending
    ``atan2``, i.e. counter-clockwise; ties keep input order) and each one is
    pushed away from the centroid by the factor ``scale``.

    Parameters
    ----------
    xs, ys : sequence of float
        Point coordinates.
    scale : float, optional
        Radial expansion factor relative to the centroid. Defaults to 1.2.

    Returns
    -------
    list of tuple(float, float)
        Polygon vertices; empty if ``xs`` or ``ys`` is empty.
    """
    if not (xs and ys):
        return []
    mid_x = sum(xs) / len(xs)
    mid_y = sum(ys) / len(ys)

    def _polar_angle(point):
        return math.atan2(point[1] - mid_y, point[0] - mid_x)

    ordered = sorted(zip(xs, ys), key=_polar_angle)
    return [
        (mid_x + (px - mid_x) * scale, mid_y + (py - mid_y) * scale)
        for px, py in ordered
    ]


def _draw_motif_icon(
    ax,
    events,
    center_x: float,
    icon_width: float = 0.85,
    aggregate: bool = False,
    y_center: float = 0.9,
    y_scale: float = 0.3,
    blob_colors: dict | None = None,
    node_color: str = "#222222",
    node_edge: str = "white",
    node_size: float = 30,
    edge_color: str = "#333333",
):
    """
    Draw a small glyph of one motif on ``ax``, centred horizontally at ``center_x``.

    All distinct nodes of the motif are placed with ``_node_layout``.
    Hyperedges with three or more nodes are drawn as translucent filled
    polygons, two-node hyperedges as line segments; single-node hyperedges
    add nothing of their own. Nothing is drawn if ``events`` is empty or
    contains only empty hyperedges.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    events : sequence of iterables
        Hyperedges of the motif, each an iterable of node ids. Node ids must
        be hashable and mutually sortable.
    center_x : float
        Horizontal centre of the glyph in data coordinates.
    icon_width : float
        Nominal glyph width; shrunk by ``(3 / max(3, k)) ** 0.25`` for a
        motif with ``k`` nodes.
    aggregate : bool
        If True, overlay all hyperedges in one glyph centred at ``y_center``
        with height ``y_scale``. If False, draw one row per hyperedge,
        stacked top-to-bottom over y in [0, 1]; that mode uses its own fixed
        palette and ignores ``blob_colors``, ``y_center``, ``y_scale`` and the
        node/edge styling arguments.
    y_center, y_scale : float
        Vertical centre and height of the glyph (aggregate mode only).
    blob_colors : dict, optional
        Polygon fill colours keyed by hyperedge size, with ``"default"`` as
        the fallback key. Defaults to ``_DEFAULT_BLOB_COLORS``.
    node_color, node_edge, node_size, edge_color
        Styling of nodes and two-node edges (aggregate mode only).

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    """
    if not events:
        return
    # Lazy import of Polygon to keep this module import-safe without matplotlib.
    try:
        from matplotlib.patches import Polygon  # type: ignore
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            "plot_motifs requires matplotlib. Install with `pip install hypergraphx[viz]`."
        ) from exc

    node_ids = sorted({n for ev in events for n in ev})
    k = len(node_ids)
    if k == 0:
        # Every hyperedge is empty: nothing to draw.
        return
    node_to_xy_local = dict(zip(node_ids, _node_layout(k)))
    # Slightly narrow the glyphs of larger motifs so neighbouring icons stay apart.
    icon_width = icon_width * (3 / max(3, k)) ** 0.25
    n_events = len(events)
    if blob_colors is None:
        blob_colors = _DEFAULT_BLOB_COLORS

    if aggregate:
        # Axes position of every node, computed once and shared by all hyperedges.
        node_xy = {
            node: (
                center_x + (x_local - 0.5) * icon_width,
                y_center + (y_local - 0.5) * y_scale,
            )
            for node, (x_local, y_local) in node_to_xy_local.items()
        }

        for ev in events:
            ev_set = set(ev)
            present = [node_xy[node] for node in node_ids if node in ev_set]
            ev_size = len(present)
            if ev_size >= 3:
                poly_pts = _expand_polygon(
                    [x for x, _ in present], [y for _, y in present], scale=1.35
                )
                if poly_pts:
                    color = blob_colors.get(
                        ev_size, blob_colors.get("default", "#55A868")
                    )
                    ax.add_patch(
                        Polygon(
                            poly_pts,
                            closed=True,
                            facecolor=color,
                            alpha=0.22,
                            edgecolor="#FFFFFF",
                            linewidth=0.4,
                        )
                    )
            elif ev_size == 2:
                (x1, y1), (x2, y2) = present
                ax.plot(
                    [x1, x2],
                    [y1, y2],
                    linewidth=1.1,
                    color=edge_color,
                    alpha=0.9,
                    solid_capstyle="round",
                )

        xs_all = [node_xy[node][0] for node in node_ids]
        ys_all = [node_xy[node][1] for node in node_ids]
        # A larger, faint marker at zorder 2 under each node (zorder 4).
        ax.scatter(
            xs_all,
            ys_all,
            s=node_size * 2.2,
            color=node_edge,
            alpha=0.25,
            linewidths=0,
            zorder=2,
        )
        ax.scatter(
            xs_all,
            ys_all,
            s=node_size,
            color=node_color,
            edgecolors=node_edge,
            linewidths=0.6,
            zorder=4,
        )
        return

    # Non-aggregate mode: one row per hyperedge, first hyperedge at the top.
    row_height = 1.0 / n_events
    for e_idx, ev in enumerate(events):
        ev_set = set(ev)
        row_center_y = 1.0 - (e_idx + 0.5) / n_events

        xs_all, ys_all = [], []
        xs_present, ys_present = [], []
        for node in node_ids:
            x_local, y_local = node_to_xy_local[node]
            x_scaled = center_x + (x_local - 0.5) * icon_width
            y_scaled = row_center_y + (y_local - 0.5) * (row_height * 0.7)
            xs_all.append(x_scaled)
            ys_all.append(y_scaled)
            if node in ev_set:
                xs_present.append(x_scaled)
                ys_present.append(y_scaled)

        ev_size = len(xs_present)
        if ev_size >= 3:
            poly_pts = _expand_polygon(xs_present, ys_present, scale=1.25)
            if poly_pts:
                # This mode keeps its own fixed palette (blob_colors is not consulted).
                if ev_size == 3:
                    color = "#4C72B0"
                elif ev_size == 4:
                    color = "#DD8452"
                else:
                    color = "#55A868"
                ax.add_patch(
                    Polygon(
                        poly_pts,
                        closed=True,
                        facecolor=color,
                        alpha=0.22,
                        edgecolor="none",
                    )
                )
        elif ev_size == 2:
            ax.plot(xs_present, ys_present, linewidth=1.1, color="black", alpha=0.9)

        # All nodes as light-grey dots; members of this hyperedge in black on top.
        ax.scatter(xs_all, ys_all, s=8, color="0.8", zorder=3)
        if xs_present:
            ax.scatter(xs_present, ys_present, s=12, color="black", zorder=4)


def _draw_score_bars(ax_bar, motifs, cols, annotate_bars, annotate_fmt):
    """Draw styled score bars, the zero line and optional value labels; return the bars."""
    idx = list(range(len(motifs)))
    _style_axes(ax_bar, grid_axis=None)
    bars = ax_bar.bar(
        idx,
        motifs,
        color=cols,
        width=0.85,
        edgecolor="white",
        linewidth=0.6,
        zorder=3,
    )
    ax_bar.axhline(0, color="black", linewidth=0.5)
    ax_bar.set_ylim(-1, 1)
    ax_bar.set_ylabel("Motif abundance score")
    ax_bar.set_xticks(idx)
    ax_bar.set_xlim(-0.5, len(motifs) - 0.5)
    if annotate_bars:
        for i, val in enumerate(motifs):
            # Just beyond the bar end: above non-negative bars, below negative ones.
            ax_bar.text(
                i,
                val + (0.03 if val >= 0 else -0.03),
                annotate_fmt.format(val),
                ha="center",
                va="bottom" if val >= 0 else "top",
                fontsize=8,
                color="#333333",
            )
    return bars


def _attach_hover_labels(bars, motifs, labels):
    """Best-effort: show "<label>: <score>" on hover via mplcursors, warning on failure."""
    import warnings

    try:
        import mplcursors  # type: ignore
    except ImportError:
        warnings.warn(
            "hover_labels=True requires the optional 'mplcursors' package; "
            "hover labels were not added.",
            stacklevel=3,
        )
        return
    try:
        cursor = mplcursors.cursor(bars, hover=True)

        @cursor.connect("add")
        def _on_add(sel):
            i = sel.index
            sel.annotation.set_text(f"{labels[i]}: {motifs[i]:.3f}")

    except Exception as exc:  # hover labels are optional; report instead of failing
        warnings.warn(f"Could not add hover labels: {exc}", stacklevel=3)


def plot_motifs(
    motifs: list,
    save_name: str | None = None,
    show: bool = False,
    roman_numbers: bool = False,
    motif_patterns: list | None = None,
    pos_color: str = _POS_COLOR,
    neg_color: str = _NEG_COLOR,
    blob_colors: dict | None = None,
    icon_width: float = 0.85,
    icon_y: float = 0.86,
    icon_scale: float = 0.95,
    icon_row_ylim: tuple | None = None,
    node_color: str = "#222222",
    node_edge: str = "white",
    node_size: float = 30,
    edge_color: str = "#333333",
    show_motif_labels: bool = False,
    motif_labels: list | None = None,
    annotate_bars: bool = False,
    annotate_fmt: str = "{:+.2f}",
    hover_labels: bool = False,
):
    """
    Plot a motif profile as a bar chart.

    The six scores are reordered with ``_sort_for_visualization`` before
    plotting (intended to show lower-order motifs first, then higher-order
    ones), and the motif patterns are reordered the same way so each glyph
    stays under its bar.

    Parameters
    ----------
    motifs : list
        Six motif scores, or six ``(motif, score)`` pairs; for pairs only the
        score (second item) is used. Scores are converted with ``float``.
    save_name : str or path-like, optional
        If given, the figure is saved there with ``bbox_inches="tight"``.
    show : bool
        If True, call ``plt.show()``.
    roman_numbers : bool
        If True, label the x-axis with roman numerals I-VI instead of motif
        glyphs. In this mode the bars are added to the current figure
        (``plt.gcf()``); otherwise a new figure is created.
    motif_patterns : list, optional
        Motif patterns drawn as glyphs when ``roman_numbers`` is False.
        Defaults to the sorted 3-node motif patterns from
        ``generate_motifs(3)``. Its length must match ``motifs`` (checked in
        both modes).
    pos_color, neg_color : str
        Bar colours for non-negative / negative scores.
    blob_colors : dict, optional
        Hyperedge polygon colours keyed by hyperedge size (3, 4, ``"default"``).
    icon_width, icon_y, icon_scale : float
        Width, vertical centre and height of each motif glyph.
    icon_row_ylim : tuple, optional
        ``(ymin, ymax)`` of the glyph row. Defaults to
        ``(icon_y - 0.5 * icon_scale - 0.1, icon_y + 0.5 * icon_scale + 0.1)``.
    node_color, node_edge, node_size, edge_color : styling of glyph nodes/edges.
    show_motif_labels : bool
        If True, write a label under each glyph (glyph mode only).
    motif_labels : list, optional
        Six labels used for glyph captions and hover text. They are applied
        in display order, i.e. after the reordering described above; they are
        not reordered themselves. Defaults to ``M1`` ... ``M6``.
    annotate_bars : bool
        If True, write each score just beyond the end of its bar.
    annotate_fmt : str
        ``str.format`` template for the bar annotations.
    hover_labels : bool
        If True, add interactive hover labels via ``mplcursors``. If that
        package is missing or fails, a warning is issued and plotting continues.

    Returns
    -------
    matplotlib.axes.Axes
        The axes containing the bar chart.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    ValueError
        If ``motifs`` is empty or does not have exactly six entries, or if
        ``motif_labels`` / ``motif_patterns`` do not match its length.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            "plot_motifs requires matplotlib. Install with `pip install hypergraphx[viz]`."
        ) from exc

    if len(motifs) == 0:
        raise ValueError("Motifs must be a non-empty list.")
    # Accept (motif, score) pairs by keeping only the scores.
    if all(isinstance(m, (list, tuple)) and len(m) >= 2 for m in motifs):
        motifs = [m[1] for m in motifs]
    motifs = [float(x) for x in motifs]
    if len(motifs) != 6:
        raise ValueError("Motifs must be a list of length 6.")
    if motif_labels is not None and len(motif_labels) != len(motifs):
        raise ValueError("motif_labels length must match motifs length.")

    motifs = _sort_for_visualization(motifs)
    if motif_patterns is None:
        motif_patterns = _default_motif_patterns(order=3)
    if len(motif_patterns) != len(motifs):
        raise ValueError("motif_patterns length must match motifs length.")
    motif_patterns = _sort_for_visualization(motif_patterns)

    cols = [neg_color if x < 0 else pos_color for x in motifs]
    display_labels = (
        motif_labels
        if motif_labels is not None
        else [f"M{i + 1}" for i in range(len(motifs))]
    )

    if roman_numbers:
        # Roman-numeral mode adds a single bar axes to the current figure.
        fig = plt.gcf()
        ax_bar = fig.add_subplot(111)
        bars = _draw_score_bars(ax_bar, motifs, cols, annotate_bars, annotate_fmt)
        ax_bar.set_xticklabels(["I", "II", "III", "IV", "V", "VI"])
        if hover_labels:
            _attach_hover_labels(bars, motifs, display_labels)
    else:
        fig = plt.figure(figsize=(max(8, 0.6 * len(motifs)), 5))
        gs = fig.add_gridspec(2, 1, height_ratios=[2.4, 0.6], hspace=0.01)
        ax_bar = fig.add_subplot(gs[0])
        bars = _draw_score_bars(ax_bar, motifs, cols, annotate_bars, annotate_fmt)
        # The glyph row below replaces tick labels; keep short tick marks as anchors.
        ax_bar.set_xticklabels([""] * len(motifs))
        ax_bar.tick_params(axis="x", bottom=True, length=4, width=0.8, color="#333333")
        if hover_labels:
            _attach_hover_labels(bars, motifs, display_labels)

        ax_icon = fig.add_subplot(gs[1])
        if icon_row_ylim is None:
            icon_row_ylim = (
                icon_y - 0.5 * icon_scale - 0.1,
                icon_y + 0.5 * icon_scale + 0.1,
            )
        ax_icon.set_ylim(*icon_row_ylim)
        ax_icon.set_xlim(-0.5, len(motif_patterns) - 0.5)
        ax_icon.set_yticks([])
        ax_icon.set_xticks([])
        ax_icon.tick_params(axis="x", bottom=False, top=False, labelbottom=False)
        for spine in ax_icon.spines.values():
            spine.set_visible(False)
        # Transparent glyph axes, stacked beneath the bar axes.
        ax_icon.patch.set_alpha(0)
        ax_icon.set_zorder(0)
        ax_bar.set_zorder(2)

        for i, pattern in enumerate(motif_patterns):
            _draw_motif_icon(
                ax_icon,
                pattern,
                center_x=i,
                icon_width=icon_width,
                aggregate=True,
                y_center=icon_y,
                y_scale=icon_scale,
                blob_colors=blob_colors,
                node_color=node_color,
                node_edge=node_edge,
                node_size=node_size,
                edge_color=edge_color,
            )

        if show_motif_labels:
            label_y = icon_row_ylim[0] + 0.02
            for i, label in enumerate(display_labels):
                ax_icon.text(
                    i,
                    label_y,
                    label,
                    ha="center",
                    va="bottom",
                    fontsize=8,
                    color="#333333",
                )
        fig.subplots_adjust(hspace=0.01, bottom=0.08, top=0.97)

    if save_name is not None:
        # Pass the target through unchanged so path-like and file-like objects work.
        fig.savefig(save_name, bbox_inches="tight")
    if show:
        plt.show()
    return ax_bar