diff --git a/packages/mermaid/src/diagrams/er/erDb.ts b/packages/mermaid/src/diagrams/er/erDb.ts index 7a1e2f025..5e9ba20b6 100644 --- a/packages/mermaid/src/diagrams/er/erDb.ts +++ b/packages/mermaid/src/diagrams/er/erDb.ts @@ -2,7 +2,6 @@ import { log } from '../../logger.js'; import { getConfig } from '../../diagram-api/diagramAPI.js'; import type { Edge, Node } from '../../rendering-util/types.js'; import type { EntityNode, Attribute, Relationship, EntityClass, RelSpec } from './erTypes.js'; - import { setAccTitle, getAccTitle, @@ -13,228 +12,239 @@ import { getDiagramTitle, } from '../common/commonDb.js'; import { getEdgeId } from '../../utils.js'; +import type { DiagramDB } from '../../diagram-api/types.js'; -let entities = new Map(); -let relationships: Relationship[] = []; -let classes = new Map(); -let direction = 'TB'; +export class ErDB implements DiagramDB { + private entities = new Map(); + private relationships: Relationship[] = []; + private classes = new Map(); + private direction = 'TB'; -const Cardinality = { - ZERO_OR_ONE: 'ZERO_OR_ONE', - ZERO_OR_MORE: 'ZERO_OR_MORE', - ONE_OR_MORE: 'ONE_OR_MORE', - ONLY_ONE: 'ONLY_ONE', - MD_PARENT: 'MD_PARENT', -}; - -const Identification = { - NON_IDENTIFYING: 'NON_IDENTIFYING', - IDENTIFYING: 'IDENTIFYING', -}; -/** - * Add entity - * @param name - The name of the entity - * @param alias - The alias of the entity - */ -const addEntity = function (name: string, alias = ''): EntityNode { - if (!entities.has(name)) { - entities.set(name, { - id: `entity-${name}-${entities.size}`, - label: name, - attributes: [], - alias, - shape: 'erBox', - look: getConfig().look || 'default', - cssClasses: 'default', - cssStyles: [], - }); - log.info('Added new entity :', name); - } else if (!entities.get(name)?.alias && alias) { - entities.get(name)!.alias = alias; - log.info(`Add alias '${alias}' to entity '${name}'`); - } - - return entities.get(name)!; -}; - -const getEntity = function (name: string) { - return entities.get(name); -}; - -const getEntities = () => entities; - -const getClasses = () => classes; - -const addAttributes = function (entityName: string, attribs: Attribute[]) { - const entity = addEntity(entityName); // May do nothing (if entity has already been added) - - // Process attribs in reverse order due to effect of recursive construction (last attribute is first) - let i; - for (i = attribs.length - 1; i >= 0; i--) { - if (!attribs[i].keys) { - attribs[i].keys = []; - } - if (!attribs[i].comment) { - attribs[i].comment = ''; - } - entity.attributes.push(attribs[i]); - log.debug('Added attribute ', attribs[i].name); - } -}; - -/** - * Add a relationship - * - * @param entA - The first entity in the relationship - * @param rolA - The role played by the first entity in relation to the second - * @param entB - The second entity in the relationship - * @param rSpec - The details of the relationship between the two entities - */ -const addRelationship = function (entA: string, rolA: string, entB: string, rSpec: RelSpec) { - const entityA = entities.get(entA); - const entityB = entities.get(entB); - if (!entityA || !entityB) { - return; - } - - const rel = { - entityA: entityA.id, - roleA: rolA, - entityB: entityB.id, - relSpec: rSpec, + private Cardinality = { + ZERO_OR_ONE: 'ZERO_OR_ONE', + ZERO_OR_MORE: 'ZERO_OR_MORE', + ONE_OR_MORE: 'ONE_OR_MORE', + ONLY_ONE: 'ONLY_ONE', + MD_PARENT: 'MD_PARENT', }; - relationships.push(rel); - log.debug('Added new relationship :', rel); -}; + private Identification = { + NON_IDENTIFYING: 'NON_IDENTIFYING', + IDENTIFYING: 'IDENTIFYING', + }; -const getRelationships = () => relationships; + constructor() { + this.clear(); + this.addEntity = this.addEntity.bind(this); + this.addAttributes = this.addAttributes.bind(this); + this.addRelationship = this.addRelationship.bind(this); + this.setDirection = this.setDirection.bind(this); + this.addCssStyles = this.addCssStyles.bind(this); + this.addClass = this.addClass.bind(this); + this.setClass = this.setClass.bind(this); + this.setAccTitle = this.setAccTitle.bind(this); + this.setAccDescription = this.setAccDescription.bind(this); + } -export const getDirection = () => direction; -const setDirection = (dir: string) => { - direction = dir; -}; + /** + * Add entity + * @param name - The name of the entity + * @param alias - The alias of the entity + */ + public addEntity(name: string, alias = ''): EntityNode { + if (!this.entities.has(name)) { + this.entities.set(name, { + id: `entity-${name}-${this.entities.size}`, + label: name, + attributes: [], + alias, + shape: 'erBox', + look: getConfig().look ?? 'default', + cssClasses: 'default', + cssStyles: [], + }); + log.info('Added new entity :', name); + } else if (!this.entities.get(name)?.alias && alias) { + this.entities.get(name)!.alias = alias; + log.info(`Add alias '${alias}' to entity '${name}'`); + } -const clear = function () { - entities = new Map(); - classes = new Map(); - relationships = []; - commonClear(); -}; + return this.entities.get(name)!; + } -export const getData = function () { - const nodes: Node[] = []; - const edges: Edge[] = []; - const config = getConfig(); + public getEntity(name: string) { + return this.entities.get(name); + } - for (const entityKey of entities.keys()) { - const entityNode = entities.get(entityKey); - if (entityNode) { - entityNode.cssCompiledStyles = getCompiledStyles(entityNode.cssClasses!.split(' ')); - nodes.push(entityNode as unknown as Node); + public getEntities() { + return this.entities; + } + + public getClasses() { + return this.classes; + } + + public addAttributes(entityName: string, attribs: Attribute[]) { + const entity = this.addEntity(entityName); // May do nothing (if entity has already been added) + + // Process attribs in reverse order due to effect of recursive construction (last attribute is first) + let i; + for (i = attribs.length - 1; i >= 0; i--) { + if (!attribs[i].keys) { + attribs[i].keys = []; + } + if (!attribs[i].comment) { + attribs[i].comment = ''; + } + entity.attributes.push(attribs[i]); + log.debug('Added attribute ', attribs[i].name); } } - let count = 0; - for (const relationship of relationships) { - const edge: Edge = { - id: getEdgeId(relationship.entityA, relationship.entityB, { prefix: 'id', counter: count++ }), - type: 'normal', - start: relationship.entityA, - end: relationship.entityB, - label: relationship.roleA, - labelpos: 'c', - thickness: 'normal', - classes: 'relationshipLine', - arrowTypeStart: relationship.relSpec.cardB.toLowerCase(), - arrowTypeEnd: relationship.relSpec.cardA.toLowerCase(), - pattern: relationship.relSpec.relType == 'IDENTIFYING' ? 'solid' : 'dashed', - look: config.look, - }; - edges.push(edge); - } - return { nodes, edges, other: {}, config, direction: 'TB' }; -}; - -export const addCssStyles = function (ids: string[], styles: string[]) { - for (const id of ids) { - const entity = entities.get(id); - if (!styles || !entity) { + /** + * Add a relationship + * + * @param entA - The first entity in the relationship + * @param rolA - The role played by the first entity in relation to the second + * @param entB - The second entity in the relationship + * @param rSpec - The details of the relationship between the two entities + */ + public addRelationship(entA: string, rolA: string, entB: string, rSpec: RelSpec) { + const entityA = this.entities.get(entA); + const entityB = this.entities.get(entB); + if (!entityA || !entityB) { return; } - for (const style of styles) { - entity.cssStyles!.push(style); - } + + const rel = { + entityA: entityA.id, + roleA: rolA, + entityB: entityB.id, + relSpec: rSpec, + }; + + this.relationships.push(rel); + log.debug('Added new relationship :', rel); } -}; -export const addClass = function (ids: string[], style: string[]) { - ids.forEach(function (id) { - let classNode = classes.get(id); - if (classNode === undefined) { - classNode = { id, styles: [], textStyles: [] }; - classes.set(id, classNode); + public getRelationships() { + return this.relationships; + } + + public getDirection() { + return this.direction; + } + + public setDirection(dir: string) { + this.direction = dir; + } + + private getCompiledStyles(classDefs: string[]) { + let compiledStyles: string[] = []; + for (const customClass of classDefs) { + const cssClass = this.classes.get(customClass); + if (cssClass?.styles) { + compiledStyles = [...compiledStyles, ...(cssClass.styles ?? [])].map((s) => s.trim()); + } + if (cssClass?.textStyles) { + compiledStyles = [...compiledStyles, ...(cssClass.textStyles ?? [])].map((s) => s.trim()); + } } + return compiledStyles; + } - if (style) { - style.forEach(function (s) { - if (/color/.exec(s)) { - const newStyle = s.replace('fill', 'bgFill'); - classNode.textStyles.push(newStyle); - } - classNode.styles.push(s); - }); - } - }); -}; - -export const setClass = function (ids: string[], classNames: string[]) { - for (const id of ids) { - const entity = entities.get(id); - if (entity) { - for (const className of classNames) { - entity.cssClasses += ' ' + className; + public addCssStyles(ids: string[], styles: string[]) { + for (const id of ids) { + const entity = this.entities.get(id); + if (!styles || !entity) { + return; + } + for (const style of styles) { + entity.cssStyles!.push(style); } } } -}; -function getCompiledStyles(classDefs: string[]) { - let compiledStyles: string[] = []; - for (const customClass of classDefs) { - const cssClass = classes.get(customClass); - if (cssClass?.styles) { - compiledStyles = [...compiledStyles, ...(cssClass.styles ?? [])].map((s) => s.trim()); - } - if (cssClass?.textStyles) { - compiledStyles = [...compiledStyles, ...(cssClass.textStyles ?? [])].map((s) => s.trim()); + public addClass(ids: string[], style: string[]) { + ids.forEach((id) => { + let classNode = this.classes.get(id); + if (classNode === undefined) { + classNode = { id, styles: [], textStyles: [] }; + this.classes.set(id, classNode); + } + + if (style) { + style.forEach(function (s) { + if (/color/.exec(s)) { + const newStyle = s.replace('fill', 'bgFill'); + classNode.textStyles.push(newStyle); + } + classNode.styles.push(s); + }); + } + }); + } + + public setClass(ids: string[], classNames: string[]) { + for (const id of ids) { + const entity = this.entities.get(id); + if (entity) { + for (const className of classNames) { + entity.cssClasses += ' ' + className; + } + } } } - return compiledStyles; -} -export default { - Cardinality, - Identification, - getConfig: () => getConfig().er, - addEntity, - addAttributes, - getEntities, - getEntity, - getClasses, - addRelationship, - getRelationships, - clear, - getDirection, - setDirection, - setAccTitle, - getAccTitle, - setAccDescription, - getAccDescription, - setDiagramTitle, - getDiagramTitle, - getData, - addCssStyles, - addClass, - setClass, -}; + public clear() { + this.entities = new Map(); + this.classes = new Map(); + this.relationships = []; + commonClear(); + } + + public getData() { + const nodes: Node[] = []; + const edges: Edge[] = []; + const config = getConfig(); + + for (const entityKey of this.entities.keys()) { + const entityNode = this.entities.get(entityKey); + if (entityNode) { + entityNode.cssCompiledStyles = this.getCompiledStyles(entityNode.cssClasses!.split(' ')); + nodes.push(entityNode as unknown as Node); + } + } + + let count = 0; + for (const relationship of this.relationships) { + const edge: Edge = { + id: getEdgeId(relationship.entityA, relationship.entityB, { + prefix: 'id', + counter: count++, + }), + type: 'normal', + start: relationship.entityA, + end: relationship.entityB, + label: relationship.roleA, + labelpos: 'c', + thickness: 'normal', + classes: 'relationshipLine', + arrowTypeStart: relationship.relSpec.cardB.toLowerCase(), + arrowTypeEnd: relationship.relSpec.cardA.toLowerCase(), + pattern: relationship.relSpec.relType == 'IDENTIFYING' ? 'solid' : 'dashed', + look: config.look, + }; + edges.push(edge); + } + return { nodes, edges, other: {}, config, direction: 'TB' }; + } + + public setAccTitle = setAccTitle; + public getAccTitle = getAccTitle; + public setAccDescription = setAccDescription; + public getAccDescription = getAccDescription; + public setDiagramTitle = setDiagramTitle; + public getDiagramTitle = getDiagramTitle; + public getConfig = () => getConfig().er; +} diff --git a/packages/mermaid/src/diagrams/er/erDiagram.ts b/packages/mermaid/src/diagrams/er/erDiagram.ts index 1647f181b..29bd36a05 100644 --- a/packages/mermaid/src/diagrams/er/erDiagram.ts +++ b/packages/mermaid/src/diagrams/er/erDiagram.ts @@ -1,12 +1,14 @@ // @ts-ignore: TODO: Fix ts errors import erParser from './parser/erDiagram.jison'; -import erDb from './erDb.js'; +import { ErDB } from './erDb.js'; import erRenderer from './erRenderer-unified.js'; import erStyles from './styles.js'; export const diagram = { parser: erParser, - db: erDb, + get db() { + return new ErDB(); + }, renderer: erRenderer, styles: erStyles, }; diff --git a/packages/mermaid/src/diagrams/er/erRenderer-unified.ts b/packages/mermaid/src/diagrams/er/erRenderer-unified.ts index 902d9829f..7611747fe 100644 --- a/packages/mermaid/src/diagrams/er/erRenderer-unified.ts +++ b/packages/mermaid/src/diagrams/er/erRenderer-unified.ts @@ -4,7 +4,6 @@ import { getDiagramElement } from '../../rendering-util/insertElementsForSize.js import { getRegisteredLayoutAlgorithm, render } from '../../rendering-util/render.js'; import { setupViewPortForSVG } from '../../rendering-util/setupViewPortForSVG.js'; import type { LayoutData } from '../../rendering-util/types.js'; -import { getDirection } from './erDb.js'; import utils from '../../utils.js'; import { select } from 'd3'; @@ -26,7 +25,7 @@ export const draw = async function (text: string, id: string, _version: string, // Workaround as when rendering and setting up the graph it uses flowchart spacing before data4Layout spacing? data4Layout.config.flowchart!.nodeSpacing = conf?.nodeSpacing || 140; data4Layout.config.flowchart!.rankSpacing = conf?.rankSpacing || 80; - data4Layout.direction = getDirection(); + data4Layout.direction = diag.db.getDirection(); data4Layout.markers = ['only_one', 'zero_or_one', 'one_or_more', 'zero_or_more']; data4Layout.diagramId = id; diff --git a/packages/mermaid/src/diagrams/er/parser/erDiagram.spec.js b/packages/mermaid/src/diagrams/er/parser/erDiagram.spec.js index 48cd3edce..3bd2339ba 100644 --- a/packages/mermaid/src/diagrams/er/parser/erDiagram.spec.js +++ b/packages/mermaid/src/diagrams/er/parser/erDiagram.spec.js @@ -1,5 +1,5 @@ import { setConfig } from '../../../config.js'; -import erDb from '../erDb.js'; +import { ErDb } from '../erDb.js'; import erDiagram from './erDiagram.jison'; // jison file setConfig({ @@ -7,6 +7,7 @@ setConfig({ }); describe('when parsing ER diagram it...', function () { + const erDb = new ErDb(); beforeEach(function () { erDiagram.parser.yy = erDb; erDiagram.parser.yy.clear();