From 2ec9f0af2e639ac0835462df206fa9954758e0e4 Mon Sep 17 00:00:00 2001 From: yari-dewalt Date: Fri, 13 Sep 2024 12:53:20 -0700 Subject: [PATCH] Add support for new renderer --- .../mermaid/src/diagrams/class/classDb.ts | 167 +++++++++++++++++- .../src/diagrams/class/classDiagram.ts | 2 +- .../class/classRenderer-v3-unified.ts | 23 +-- 3 files changed, 171 insertions(+), 21 deletions(-) diff --git a/packages/mermaid/src/diagrams/class/classDb.ts b/packages/mermaid/src/diagrams/class/classDb.ts index 32d007634..3dd6dabdf 100644 --- a/packages/mermaid/src/diagrams/class/classDb.ts +++ b/packages/mermaid/src/diagrams/class/classDb.ts @@ -1,9 +1,8 @@ -import type { Selection } from 'd3'; -import { select } from 'd3'; +import { select, type Selection } from 'd3'; import { log } from '../../logger.js'; import { getConfig } from '../../diagram-api/diagramAPI.js'; import common from '../common/common.js'; -import utils from '../../utils.js'; +import utils, { getEdgeId } from '../../utils.js'; import { setAccTitle, getAccTitle, @@ -21,12 +20,15 @@ import type { ClassMap, NamespaceMap, NamespaceNode, + StyleClass, } from './classTypes.js'; +import type { Node, Edge } from '../../rendering-util/types.js'; const MERMAID_DOM_ID_PREFIX = 'classId-'; let relations: ClassRelation[] = []; let classes = new Map(); +const styleClasses = new Map(); let notes: ClassNote[] = []; let classCounter = 0; let namespaces = new Map(); @@ -59,7 +61,7 @@ export const setClassLabel = function (_id: string, label: string) { const { className } = splitClassNameAndType(id); classes.get(className)!.label = label; classes.get(className)!.text = - `${label}${classes.get(className)!.type ? `<${classes.get(className)!.type}>` : ''}`; + `${label}${classes.get(className)!.type ? `<${classes.get(className)!.type}>` : ''}`; }; /** @@ -82,7 +84,7 @@ export const addClass = function (_id: string) { id: name, type: type, label: name, - text: `${name}${type ? `<${type}>` : ''}`, + text: `${name}${type ? `<${type}>` : ''}`, shape: 'classBox', cssClasses: [], methods: [], @@ -238,6 +240,36 @@ export const setCssClass = function (ids: string, className: string) { }); }; +export const defineClass = function (ids: string[], style: string[]) { + for (const id of ids) { + let styleClass = styleClasses.get(id); + if (styleClass === undefined) { + styleClass = { id, styles: [], textStyles: [] }; + styleClasses.set(id, styleClass); + } + + if (style !== undefined && style !== null) { + style.forEach(function (s) { + if (/color/.exec(s)) { + const newStyle = s.replace('fill', 'bgFill'); // .replace('color', 'fill'); + styleClass.textStyles.push(newStyle); + } + styleClass.styles.push(s); + }); + } + + for (const [, value] of classes) { + if (value.cssClasses.includes(id)) { + for (const s of style) { + for (const k of s.split(',')) { + value.styles.push(k); + } + } + } + } + } +}; + /** * Called by parser when a tooltip is found, e.g. a clickable element. * @@ -476,9 +508,131 @@ export const setCssStyle = function (id: string, styles: string[]) { } }; +/** + * Gets the arrow marker for a type index + * + * @param type - The type to look for + * @returns The arrow marker + */ +function getArrowMarker(type: number) { + let marker; + switch (type) { + case 0: + marker = 'aggregation'; + break; + case 1: + marker = 'extension'; + break; + case 2: + marker = 'composition'; + break; + case 3: + marker = 'dependency'; + break; + case 4: + marker = 'lollipop'; + break; + default: + marker = 'none'; + } + return marker; +} + export const getData = () => { + const nodes: Node[] = []; + const edges: Edge[] = []; const config = getConfig(); - return { nodes: classes, edges: relations, other: {}, config, direction: getDirection() }; + + for (const namespaceKey of namespaces.keys()) { + const namespace = namespaces.get(namespaceKey); + if (namespace) { + const node: Node = { + id: namespace.id, + label: namespace.id, + isGroup: false, + // parent node must be one of [rect, roundedWithTitle, noteGroup, divider] + shape: 'rect', + cssStyles: ['fill: none', 'stroke: black'], + look: config.look, + }; + nodes.push(node); + } + } + + for (const classKey of classes.keys()) { + const classNode = classes.get(classKey); + if (classNode) { + const node = classNode as unknown as Node; + node.parentId = classNode.parent; + nodes.push(node); + } + } + + let cnt = 0; + for (const note of notes) { + cnt++; + const noteNode: Node = { + id: note.id, + label: note.text.replaceAll('\\n', '
'), // 'rect' shape label sanitizes these newlines so must change to
manually + isGroup: false, + shape: 'rect', + padding: config.class!.padding ?? 6, + cssStyles: ['text-align: left'], + look: config.look, + }; + nodes.push(noteNode); + + const noteClassId = classes.get(note.class)?.id ?? ''; + + if (noteClassId) { + const edge: Edge = { + id: `edgeNote${cnt}`, + start: note.id, + end: noteClassId, + type: 'normal', + thickness: 'normal', + classes: 'relation', + arrowTypeStart: 'none', + arrowTypeEnd: 'none', + arrowheadStyle: '', + labelStyle: [''], + style: ['fill: none'], + pattern: 'dotted', + look: config.look, + }; + edges.push(edge); + } + } + + cnt = 0; + for (const relation of relations) { + cnt++; + const edge: Edge = { + id: getEdgeId(relation.id1, relation.id2, { + prefix: 'id', + counter: cnt, + }), + start: relation.id1, + end: relation.id2, + type: 'normal', + label: relation.title, + labelpos: 'c', + thickness: 'normal', + classes: 'relation', + arrowTypeStart: getArrowMarker(relation.relation.type1), + arrowTypeEnd: getArrowMarker(relation.relation.type2), + startLabelRight: relation.relationTitle1 === 'none' ? '' : relation.relationTitle1, + endLabelLeft: relation.relationTitle2 === 'none' ? '' : relation.relationTitle2, + arrowheadStyle: '', + labelStyle: ['display: inline-block'], + style: relation.style || '', + pattern: relation.relation.lineType == 1 ? 'dashed' : 'solid', + look: config.look, + }; + edges.push(edge); + } + + return { nodes, edges, other: {}, config, direction: getDirection() }; }; export default { @@ -506,6 +660,7 @@ export default { relationType, setClickEvent, setCssClass, + defineClass, setLink, getTooltip, setTooltip, diff --git a/packages/mermaid/src/diagrams/class/classDiagram.ts b/packages/mermaid/src/diagrams/class/classDiagram.ts index 7f027c186..6a3747e41 100644 --- a/packages/mermaid/src/diagrams/class/classDiagram.ts +++ b/packages/mermaid/src/diagrams/class/classDiagram.ts @@ -3,7 +3,7 @@ import type { DiagramDefinition } from '../../diagram-api/types.js'; import parser from './parser/classDiagram.jison'; import db from './classDb.js'; import styles from './styles.js'; -import renderer from './classRenderer.js'; +import renderer from './classRenderer-v3-unified.js'; export const diagram: DiagramDefinition = { parser, diff --git a/packages/mermaid/src/diagrams/class/classRenderer-v3-unified.ts b/packages/mermaid/src/diagrams/class/classRenderer-v3-unified.ts index 404a53c6c..f17b7312c 100644 --- a/packages/mermaid/src/diagrams/class/classRenderer-v3-unified.ts +++ b/packages/mermaid/src/diagrams/class/classRenderer-v3-unified.ts @@ -1,8 +1,8 @@ import { getConfig } from '../../diagram-api/diagramAPI.js'; import type { DiagramStyleClassDef } from '../../diagram-api/types.js'; import { log } from '../../logger.js'; -import { getDiagramElements } from '../../rendering-util/insertElementsForSize.js'; -import { render } from '../../rendering-util/render.js'; +import { getDiagramElement } from '../../rendering-util/insertElementsForSize.js'; +import { getRegisteredLayoutAlgorithm, render } from '../../rendering-util/render.js'; import { setupViewPortForSVG } from '../../rendering-util/setupViewPortForSVG.js'; import type { LayoutData } from '../../rendering-util/types.js'; import utils from '../../utils.js'; @@ -16,7 +16,7 @@ import utils from '../../utils.js'; * @param defaultDir - the direction to use if none is found * @returns The direction to use */ -export const getDir = (parsedItem: any, defaultDir = DEFAULT_NESTED_DOC_DIR) => { +export const getDir = (parsedItem: any, defaultDir = 'TB') => { if (!parsedItem.doc) { return defaultDir; } @@ -36,7 +36,6 @@ export const getClasses = function ( text: string, diagramObj: any ): Map { - // diagramObj.db.extract(diagramObj.db.getRootDocV2()); return diagramObj.db.getClasses(); }; @@ -48,29 +47,25 @@ export const draw = async function (text: string, id: string, _version: string, // Not related to the refactoring, but this is the first step in the rendering process // diag.db.extract(diag.db.getRootDocV2()); - //const DIR = getDir(diag.db.getRootDocV2()); - // The getData method provided in all supported diagrams is used to extract the data from the parsed structure // into the Layout data format const data4Layout = diag.db.getData() as LayoutData; // Create the root SVG - the element is the div containing the SVG element - const { element, svg } = getDiagramElements(id, securityLevel); + const svg = getDiagramElement(id, securityLevel); data4Layout.type = diag.type; - data4Layout.layoutAlgorithm = layout; - - // TODO: Should we move these two to baseConfig? These types are not there in StateConfig. + data4Layout.layoutAlgorithm = getRegisteredLayoutAlgorithm(layout); data4Layout.nodeSpacing = conf?.nodeSpacing || 50; data4Layout.rankSpacing = conf?.rankSpacing || 50; - data4Layout.markers = ['barb']; + data4Layout.markers = ['aggregation', 'extension', 'composition', 'dependency', 'lollipop']; data4Layout.diagramId = id; - await render(data4Layout, svg, element); + await render(data4Layout, svg); const padding = 8; utils.insertTitle( - element, - 'statediagramTitleText', + svg, + 'classDiagramTitleText', conf?.titleTopMargin ?? 25, diag.db.getDiagramTitle() );