diff --git a/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts index 3d6f085a4..ff3e2998f 100644 --- a/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts +++ b/packages/mermaid/src/rendering-util/rendering-elements/shapes/choice.ts @@ -3,20 +3,17 @@ import type { Node } from '$root/rendering-util/types.d.ts'; import type { SVG } from '$root/diagram-api/types.js'; // @ts-ignore TODO: Fix rough typings import rough from 'roughjs'; -import { solidStateFill, styles2String } from './handDrawnShapeStyles.js'; -import { getConfig } from '$root/diagram-api/diagramAPI.js'; +import { styles2String, userNodeOverrides } from './handDrawnShapeStyles.js'; +import { createPathFromPoints, getNodeClasses, labelHelper } from './util.js'; -export const choice = (parent: SVG, node: Node) => { - const { labelStyles, nodeStyles } = styles2String(node); - node.labelStyle = labelStyles; - const { themeVariables } = getConfig(); - const { lineColor } = themeVariables; - const shapeSvg = parent - .insert('g') - .attr('class', 'node default') - .attr('id', node.domId || node.id); +export const choice = async (parent: SVG, node: Node) => { + const { nodeStyles } = styles2String(node); + node.label = ''; + const { shapeSvg } = await labelHelper(parent, node, getNodeClasses(node)); + const { cssStyles } = node; + + const s = Math.max(28, node.width ?? 0); - const s = 28; const points = [ { x: 0, y: s / 2 }, { x: s / 2, y: 0 }, @@ -24,40 +21,34 @@ export const choice = (parent: SVG, node: Node) => { { x: -s / 2, y: 0 }, ]; - let choice; - if (node.look === 'handDrawn') { - // @ts-ignore TODO: Fix rough typings - const rc = rough.svg(shapeSvg); - const pointArr = points.map(function (d) { - return [d.x, d.y]; - }); - const roughNode = rc.polygon(pointArr, solidStateFill(lineColor)); - choice = shapeSvg.insert(() => roughNode); - } else { - choice = shapeSvg.insert('polygon', ':first-child').attr( - 'points', - points - .map(function (d) { - return d.x + ',' + d.y; - }) - .join(' ') - ); + // @ts-ignore TODO: Fix rough typings + const rc = rough.svg(shapeSvg); + const options = userNodeOverrides(node, {}); + + if (node.look !== 'handDrawn') { + options.roughness = 0; + options.fillStyle = 'solid'; } - // center the circle around its coordinate - choice - .attr('class', 'state-start') - // @ts-ignore TODO: Fix rough typings - .attr('r', 7) - .attr('width', 28) - .attr('height', 28) - .attr('style', nodeStyles); + const choicePath = createPathFromPoints(points); + const roughNode = rc.path(choicePath, options); + const choiceShape = shapeSvg.insert(() => roughNode, ':first-child'); + + choiceShape.attr('class', 'basic label-container'); + + if (cssStyles && node.look !== 'handDrawn') { + choiceShape.selectAll('path').attr('style', cssStyles); + } + + if (nodeStyles && node.look !== 'handDrawn') { + choiceShape.selectAll('path').attr('style', nodeStyles); + } node.width = 28; node.height = 28; node.intersect = function (point) { - return intersect.circle(node, 14, point); + return intersect.polygon(node, points, point); }; return shapeSvg;