# Copyright (C) 2022 The Qt Company Ltd.
# SPDX-License-Identifier: LicenseRef-Qt-Commercial OR BSD-3-Clause

import math
import sys

from PySide6.QtCore import (QEasingCurve, QLineF,
                            QParallelAnimationGroup, QPointF,
                            QPropertyAnimation, QRectF, Qt)
from PySide6.QtGui import QBrush, QColor, QPainter, QPen, QPolygonF
from PySide6.QtWidgets import (QApplication, QComboBox, QGraphicsItem,
                               QGraphicsObject, QGraphicsScene, QGraphicsView,
                               QStyleOptionGraphicsItem, QVBoxLayout, QWidget)

import networkx as nx


class Node(QGraphicsObject):
    """A graphics item drawing one graph node as a labelled, movable disc."""

    def __init__(self, name: str, parent=None):
        """Node constructor

        Args:
            name (str): Node label, drawn centered inside the disc
            parent: Optional parent item, forwarded to QGraphicsObject
        """
        super().__init__(parent)
        self._name = name
        # Edges attached to this node; each one is re-adjusted when the node moves.
        self._edges = []
        self._color = "#5AD469"
        self._radius = 30
        # Local-coordinate square enclosing the disc (origin at its top-left corner).
        self._rect = QRectF(0, 0, self._radius * 2, self._radius * 2)

        self.setFlag(QGraphicsItem.ItemIsMovable)
        # Required so that itemChange() receives ItemPositionHasChanged.
        self.setFlag(QGraphicsItem.ItemSendsGeometryChanges)
        self.setCacheMode(QGraphicsItem.DeviceCoordinateCache)

    def boundingRect(self) -> QRectF:
        """Override from QGraphicsItem

        Returns:
            QRectF: The node rectangle in item coordinates
        """
        return self._rect

    def paint(self, painter: QPainter, option: QStyleOptionGraphicsItem, widget: QWidget = None):
        """Override from QGraphicsItem

        Draw the filled disc with a darker outline, then the node label.

        Args:
            painter (QPainter)
            option (QStyleOptionGraphicsItem)
            widget (QWidget): unused
        """
        fill_color = QColor(self._color)
        rect = self.boundingRect()

        painter.setRenderHints(QPainter.Antialiasing)
        painter.setPen(
            QPen(
                fill_color.darker(),
                2,
                Qt.SolidLine,
                Qt.RoundCap,
                Qt.RoundJoin,
            )
        )
        painter.setBrush(QBrush(fill_color))
        painter.drawEllipse(rect)
        painter.setPen(QPen(QColor("white")))
        painter.drawText(rect, Qt.AlignCenter, self._name)

    def add_edge(self, edge):
        """Register an edge connected to this node

        Args:
            edge (Edge): edge to re-adjust whenever this node moves
        """
        self._edges.append(edge)

    def itemChange(self, change: QGraphicsItem.GraphicsItemChange, value):
        """Override from QGraphicsItem

        Re-adjust every attached edge once the node position has changed.

        Args:
            change (QGraphicsItem.GraphicsItemChange)
            value (Any)

        Returns:
            Any: the result of the base class implementation
        """
        if change == QGraphicsItem.ItemPositionHasChanged:
            for edge in self._edges:
                edge.adjust()

        return super().itemChange(change, value)


class Edge(QGraphicsItem):
    """A graphics item drawing a directed edge (line plus arrow head) between two nodes."""

    def __init__(self, source: Node, dest: Node, parent: QGraphicsItem = None):
        """Edge constructor

        Registers itself on both nodes so that it follows them when they move.

        Args:
            source (Node): source node
            dest (Node): destination node
            parent (QGraphicsItem): optional parent item
        """
        super().__init__(parent)
        self._source = source
        self._dest = dest

        self._thickness = 2
        self._color = "#2BB53C"
        self._arrow_size = 20

        self._source.add_edge(self)
        self._dest.add_edge(self)

        self._line = QLineF()
        # Draw edges below the nodes.
        self.setZValue(-1)
        self.adjust()

    def boundingRect(self) -> QRectF:
        """Override from QGraphicsItem

        Returns:
            QRectF: The rectangle spanned by the edge line, grown on every
            side by the pen thickness plus the arrow size
        """
        margin = self._thickness + self._arrow_size
        line_rect = QRectF(self._line.p1(), self._line.p2()).normalized()
        return line_rect.adjusted(-margin, -margin, margin, margin)

    def adjust(self):
        """
        Update edge position from source and destination node.
        This method is called from Node::itemChange
        """
        # Must be called before the geometry (bounding rect) changes.
        self.prepareGeometryChange()
        self._line = QLineF(
            self._source.pos() + self._source.boundingRect().center(),
            self._dest.pos() + self._dest.boundingRect().center(),
        )

    def _draw_arrow(self, painter: QPainter, start: QPointF, end: QPointF):
        """Draw arrow from start point to end point.

        Draws the segment between the two points and a filled triangular
        head whose tip is at ``end``.

        Args:
            painter (QPainter)
            start (QPointF): start position
            end (QPointF): end position
        """
        # Wrap the color string in a QColor, as the other paint code does.
        painter.setBrush(QBrush(QColor(self._color)))

        # The line runs from the tip back to the start, so p1() is the tip.
        line = QLineF(end, start)
        tip = line.p1()

        angle = math.atan2(-line.dy(), line.dx())
        first_angle = angle + math.pi / 3
        second_angle = angle + math.pi - math.pi / 3

        arrow_p1 = tip + QPointF(
            math.sin(first_angle) * self._arrow_size,
            math.cos(first_angle) * self._arrow_size,
        )
        arrow_p2 = tip + QPointF(
            math.sin(second_angle) * self._arrow_size,
            math.cos(second_angle) * self._arrow_size,
        )

        arrow_head = QPolygonF()
        arrow_head.append(tip)
        arrow_head.append(arrow_p1)
        arrow_head.append(arrow_p2)
        painter.drawLine(line)
        painter.drawPolygon(arrow_head)

    def _arrow_target(self) -> QPointF:
        """Calculate the position of the arrow taking into account the size of the destination node

        The arrow tip is placed on the line, at a distance of the destination
        node radius from the destination center, on the source side.

        Returns:
            QPointF: the arrow tip position; the line start point when both
            line ends coincide
        """
        source_center = self._line.p1()
        dest_center = self._line.p2()
        radius = self._dest._radius

        # Vector pointing from the destination center towards the source.
        offset = source_center - dest_center
        length = math.hypot(offset.x(), offset.y())
        if length == 0:
            # Degenerate line: no direction to offset along.
            return source_center

        scale = radius / length
        return QPointF(
            dest_center.x() + offset.x() * scale,
            dest_center.y() + offset.y() * scale,
        )

    def paint(self, painter: QPainter, option: QStyleOptionGraphicsItem, widget=None):
        """Override from QGraphicsItem

        Draw the edge line and its arrow head.

        Args:
            painter (QPainter)
            option (QStyleOptionGraphicsItem)
            widget: unused
        """
        if not (self._source and self._dest):
            return

        painter.setRenderHints(QPainter.Antialiasing)
        painter.setPen(
            QPen(
                QColor(self._color),
                self._thickness,
                Qt.SolidLine,
                Qt.RoundCap,
                Qt.RoundJoin,
            )
        )
        painter.drawLine(self._line)
        self._draw_arrow(painter, self._line.p1(), self._arrow_target())


class GraphView(QGraphicsView):
    """A graphics view displaying a networkx directed graph with animated layouts."""

    def __init__(self, graph: nx.DiGraph, parent=None):
        """GraphView constructor

        This widget can display a directed graph

        Args:
            graph (nx.DiGraph): a networkx directed graph
            parent: optional parent widget
        """
        # Forward the parent so the view is actually parented when one is given.
        super().__init__(parent)
        self._graph = graph
        self._scene = QGraphicsScene()
        self.setScene(self._scene)

        # Used to add space between nodes
        self._graph_scale = 200

        # Map node name to Node object {str=>Node}
        self._nodes_map = {}

        # Map layout name to networkx layout function
        self._nx_layout = {
            "circular": nx.circular_layout,
            "planar": nx.planar_layout,
            "random": nx.random_layout,
            "shell_layout": nx.shell_layout,
            "kamada_kawai_layout": nx.kamada_kawai_layout,
            "spring_layout": nx.spring_layout,
            "spiral_layout": nx.spiral_layout,
        }

        self._load_graph()
        self.set_nx_layout("circular")

    def get_nx_layouts(self) -> list:
        """Return all layout names

        Returns:
            list: layout name (str)
        """
        # Return a real list, as annotated, rather than a dict view.
        return list(self._nx_layout)

    def set_nx_layout(self, name: str):
        """Set networkx layout and start animation

        Unknown layout names are ignored.

        Args:
            name (str): Layout name
        """
        if name not in self._nx_layout:
            return

        self._nx_layout_function = self._nx_layout[name]

        # Compute node position from layout function
        positions = self._nx_layout_function(self._graph)

        # Change position of all nodes using an animation.
        # The group is kept on self so it stays referenced while running.
        self.animations = QParallelAnimationGroup()
        for node, (x, y) in positions.items():
            end_point = QPointF(x * self._graph_scale, y * self._graph_scale)

            animation = QPropertyAnimation(self._nodes_map[node], b"pos")
            animation.setDuration(1000)
            animation.setEndValue(end_point)
            animation.setEasingCurve(QEasingCurve.OutExpo)
            self.animations.addAnimation(animation)

        self.animations.start()

    def _load_graph(self):
        """Load graph into QGraphicsScene using Node class and Edge class"""
        scene = self.scene()
        scene.clear()
        self._nodes_map.clear()

        # Add nodes
        for node in self._graph:
            item = Node(node)
            scene.addItem(item)
            self._nodes_map[node] = item

        # Add edges
        for a, b in self._graph.edges:
            scene.addItem(Edge(self._nodes_map[a], self._nodes_map[b]))


class MainWindow(QWidget):
    """Demo window: a layout selector combo box above a GraphView."""

    def __init__(self, parent=None):
        """MainWindow constructor

        Builds a small sample directed graph and wires the combo box to the
        view so that picking a layout name re-arranges the nodes.

        Args:
            parent: optional parent widget
        """
        # Forward the parent so the window is actually parented when one is given.
        super().__init__(parent)

        self.graph = nx.DiGraph()
        self.graph.add_edges_from(
            [
                ("1", "2"),
                ("2", "3"),
                ("3", "4"),
                ("1", "5"),
                ("1", "6"),
                ("1", "7"),
            ]
        )

        self.view = GraphView(self.graph)
        self.choice_combo = QComboBox()
        self.choice_combo.addItems(self.view.get_nx_layouts())
        v_layout = QVBoxLayout(self)
        v_layout.addWidget(self.choice_combo)
        v_layout.addWidget(self.view)
        self.choice_combo.currentTextChanged.connect(self.view.set_nx_layout)


if __name__ == "__main__":
    app = QApplication(sys.argv)

    # Create and show the main window holding the graph view
    widget = MainWindow()
    widget.show()
    widget.resize(800, 600)
    sys.exit(app.exec())