
multi-graph-former

A TensorFlow/Keras experiment showing that transformer attention can be read as graph neural network message passing, then generalized from seq2seq to graph2graph.

  • Sequences are just a special class of graphs. The multi-graph-former can process any kind of graph.
  • Supports intra- and inter-graph attention, vertex updates, and edge updates with dynamic structure.
  • Implements a gated update mechanism for both vertices and edges using einsum-based operations.
  • The example applies multi-graph-former to build a graph-structured hidden state for a recurrent neural network that encodes a sequence of words.

Transformer attention as graph neural network message passing

Problem

Transformers are usually introduced as sequence models, while graph neural networks are introduced as models for structured relational data. That separation is useful pedagogically, but it hides the shared geometry. A transformer layer over a token sequence is already doing message passing on a graph: each token is a vertex, attention weights are directed edge weights, and each token's updated embedding is a weighted aggregation of messages from its neighboring vertices.

multi-graph-former starts from that observation and asks what happens if we stop pretending the input must be a line. If a sentence can be treated as a graph, then sequence-to-sequence becomes one special case of graph-to-graph.

The Isometry

A transformer block can be written as:

$$Q = XW_Q,\qquad K = XW_K,\qquad V = XW_V$$

$$A = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right), \qquad X' = A V W_O$$

Read graph-theoretically:

  • $X_i$ is the feature vector on vertex $i$
  • $A_{ij}$ is a learned directed edge weight from token $j$ to token $i$
  • $V_j$ is the message emitted by source vertex $j$
  • $X'_i = \sum_j A_{ij} V_j W_O$ is the destination vertex update
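As a small numerical check of this reading, the sketch below computes single-head attention once in matrix form and once as an explicit per-vertex sum over incoming messages. It uses NumPy rather than the repo's TensorFlow layers, and the sizes and random weights are illustrative assumptions, not values from the project.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d_model, d_k = 5, 8, 4               # illustrative sizes, not from the repo
X = rng.normal(size=(n, d_model))       # one feature vector per vertex (token)
W_Q = rng.normal(size=(d_model, d_k))
W_K = rng.normal(size=(d_model, d_k))
W_V = rng.normal(size=(d_model, d_k))
W_O = rng.normal(size=(d_k, d_model))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
A = softmax(Q @ K.T / np.sqrt(d_k))     # A[i, j]: edge weight from vertex j to vertex i

# Matrix form of the transformer block.
X_matrix = A @ V @ W_O

# Message-passing form: vertex i sums the messages V[j], weighted by A[i, j].
X_message = np.stack([
    sum(A[i, j] * V[j] for j in range(n)) @ W_O
    for i in range(n)
])

matches = np.allclose(X_matrix, X_message)
```

Each row of `A` sums to one, so every destination vertex receives a convex combination of the messages on its incoming edges.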

That is message passing on a complete directed graph:

$$h_i^{t+1} = \phi\left( h_i^t,\; \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^t\, \psi(h_j^t, e_{ji}^t) \right)$$

For a vanilla transformer, $\mathcal{N}(i)$ is every token position and $e_{ji}$ is mostly implicit. For multi-graph-former, $\mathcal{N}(i)$ can come from an arbitrary adjacency tensor, and $e_{ji}$ can be a real learned edge embedding.

The "isometry" is not that every implementation detail is identical. It is that the same computational shape is preserved:

| Transformer object | Graph neural network object | Multi-graph-former generalization |
| --- | --- | --- |
| token | vertex | vertex in any named graph |
| sequence position | node index | arbitrary graph topology |
| attention matrix | weighted adjacency | alive/dead adjacency from edge state |
| value vector | source message | relation-aware message from source vertex plus edge |
| self-attention | intragraph message passing | graph self-update |
| cross-attention | bipartite message passing | intergraph update |
| decoder hidden state | recurrent state | working-memory graph |

Solution

The repo implements the idea as a small TensorFlow/Keras package with four reusable layers:

  • Graph_Attention performs relation-aware, vertex-centric attention. Destination vertices make queries; keys come from edge embeddings; values come from source vertex data concatenated with edge data.
  • Graph_Multihead_Attention repeats that relation-aware attention over multiple heads and concatenates the outputs.
  • Edge_Update updates edge embeddings from source vertices, destination vertices, existing edge features, and adjacency.
  • Smart_Update gates writes into vertices or edges, allowing the model to partially erase old state and write new state.
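To make the gated-write idea concrete, here is a minimal NumPy sketch of an erase/write gate. It is not the repo's Smart_Update code: the function name `gated_update`, the choice to compute both gates from the concatenated old state and proposal, and the bias-free weight matrices are all assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gated_update(state, proposal, W_erase, W_write):
    """Blend a proposed update into existing vertex or edge state.

    state, proposal: arrays of shape (..., d).
    W_erase, W_write: arrays of shape (2 * d, d).
    Computing the gates from [state; proposal] is an assumption of this sketch.
    """
    joint = np.concatenate([state, proposal], axis=-1)            # (..., 2d)
    erase = sigmoid(np.einsum("...i,ij->...j", joint, W_erase))   # how much old state to drop
    write = sigmoid(np.einsum("...i,ij->...j", joint, W_write))   # how much new state to add
    return (1.0 - erase) * state + write * proposal

rng = np.random.default_rng(0)
d = 4
W_erase = rng.normal(size=(2 * d, d))
W_write = rng.normal(size=(2 * d, d))

# The same function handles vertices (N, d) and edges (N_S, N_D, d).
vertices = gated_update(rng.normal(size=(3, d)), rng.normal(size=(3, d)), W_erase, W_write)
edges = gated_update(rng.normal(size=(3, 2, d)), rng.normal(size=(3, 2, d)), W_erase, W_write)
```

Because the ellipsis in the einsum broadcasts over leading axes, one gate definition covers both vertex tensors and edge tensors, which is one plausible reason to write such updates with einsum.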

The main model, WM_Graph_Former, connects three graphs:

  • an input graph with fixed sequence-style edges
  • a working-memory graph initialized with a seed vertex and learned internal structure
  • an output graph that decodes from working memory

Encoding updates let input vertices self-attend, update edges from input to working memory, let working memory attend to input, and then let working memory self-attend. Decoding updates form working-memory-to-output edges, let output vertices attend to working memory, and then let output vertices self-attend.
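The encoding and decoding passes can also be written down as an ordered schedule. The listing below only restates the paragraph above as data; the graph names and operation labels are hypothetical and are not identifiers from the repo.

```python
# Each step is (operation, source graph, destination graph).
ENCODE_SCHEDULE = [
    ("vertex_attention", "input", "input"),                    # input self-attends
    ("edge_update", "input", "working_memory"),                # refresh input -> memory edges
    ("vertex_attention", "input", "working_memory"),           # memory attends to input
    ("vertex_attention", "working_memory", "working_memory"),  # memory self-attends
]

DECODE_SCHEDULE = [
    ("edge_update", "working_memory", "output"),               # form memory -> output edges
    ("vertex_attention", "working_memory", "output"),          # output attends to memory
    ("vertex_attention", "output", "output"),                  # output self-attends
]

def is_intragraph(step):
    # A step is intragraph when messages stay inside a single graph.
    _, source, destination = step
    return source == destination

for phase, schedule in (("encode", ENCODE_SCHEDULE), ("decode", DECODE_SCHEDULE)):
    for step in schedule:
        op, source, destination = step
        kind = "intra" if is_intragraph(step) else "inter"
        print(f"{phase}: {op:<16} {source} -> {destination} ({kind}-graph)")
```

Read this way, self-attention and cross-attention are the same operation applied to different (source, destination) pairs.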

How The Graph Attention Works

For a source graph SS and destination graph DD, the code carries:

$$H_S \in \mathbb{R}^{N_S \times d_S},\quad H_D \in \mathbb{R}^{N_D \times d_D},\quad E_{S,D} \in \mathbb{R}^{N_S \times N_D \times d_E},\quad A_{S,D} \in \mathbb{R}^{N_S \times N_D}$$

The attention layer first localizes source vertices to each destination using the adjacency:

$$\tilde{H}_{s,d} = A_{s,d} H_s$$

Then it computes destination queries, edge-derived keys, and source-edge values:

$$q_d = W_Q h_d,\qquad k_{s,d} = W_K e_{s,d},\qquad v_{s,d} = W_V[\tilde{h}_{s,d}; e_{s,d}]$$

and pools messages into each destination vertex:

$$\alpha_{s,d} = \operatorname{softmax}_s \left( \frac{q_d^\top k_{s,d}}{\sqrt{N_S}} \right), \qquad m_d = \sum_s \alpha_{s,d}v_{s,d}$$

This is the transformer attention equation with the relation made explicit. Instead of attention being only a token-token similarity table, the edge itself participates in key and value construction.
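The equations in this section translate almost directly into einsum calls. The following is a single-head NumPy sketch under stated assumptions, not the repo's Graph_Attention layer: the function name `graph_attention`, the row-vector weight convention (`h @ W` instead of `W h`), and all sizes are chosen for illustration.

```python
import numpy as np

def softmax(z, axis):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def graph_attention(H_S, H_D, E, A, W_Q, W_K, W_V):
    """Relation-aware attention from source graph S into destination graph D.

    H_S: (N_S, d_S)  H_D: (N_D, d_D)  E: (N_S, N_D, d_E)  A: (N_S, N_D)
    W_Q: (d_D, d_k)  W_K: (d_E, d_k)  W_V: (d_S + d_E, d_v)
    """
    n_s = H_S.shape[0]
    H_tilde = np.einsum("sd,sf->sdf", A, H_S)              # localize sources per destination
    q = H_D @ W_Q                                          # destination queries (N_D, d_k)
    k = np.einsum("sde,ek->sdk", E, W_K)                   # edge-derived keys
    v_in = np.concatenate([H_tilde, E], axis=-1)           # [h~_{s,d}; e_{s,d}]
    v = np.einsum("sdc,cv->sdv", v_in, W_V)                # source-edge values
    scores = np.einsum("dk,sdk->sd", q, k) / np.sqrt(n_s)  # scaled by sqrt(N_S), as in the text
    alpha = softmax(scores, axis=0)                        # normalize over sources s
    m = np.einsum("sd,sdv->dv", alpha, v)                  # pooled message per destination
    return m, alpha

rng = np.random.default_rng(0)
N_S, N_D, d_S, d_D, d_E, d_k, d_v = 4, 3, 6, 5, 2, 7, 8
H_S = rng.normal(size=(N_S, d_S))
H_D = rng.normal(size=(N_D, d_D))
E = rng.normal(size=(N_S, N_D, d_E))
A = (rng.random((N_S, N_D)) > 0.5).astype(float)           # alive/dead edges
W_Q = rng.normal(size=(d_D, d_k))
W_K = rng.normal(size=(d_E, d_k))
W_V = rng.normal(size=(d_S + d_E, d_v))

m, alpha = graph_attention(H_S, H_D, E, A, W_Q, W_K, W_V)
```

In this form the adjacency zeroes the source features on dead edges rather than masking the softmax, so a dead edge can still contribute through its edge embedding.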

What The Demo Shows

Language_WM_Graph_Former wraps the working-memory graph former for token sequences. It embeds input token ids, constructs a doubly linked sequence graph with seq_edges, creates an empty output graph, runs the graph former, and decodes output vertices into token logits.
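For intuition, a doubly linked sequence graph has a very simple adjacency pattern. The helper below, `chain_adjacency`, is a hypothetical stand-in written in NumPy; the repo's seq_edges may have a different signature and may also build edge embeddings, which this sketch does not attempt.

```python
import numpy as np

def chain_adjacency(n):
    """Adjacency of a doubly linked chain over n sequence positions."""
    adj = np.zeros((n, n), dtype=np.float32)
    idx = np.arange(n - 1)
    adj[idx, idx + 1] = 1.0   # forward link: position i -> i + 1
    adj[idx + 1, idx] = 1.0   # backward link: position i + 1 -> i
    return adj

adj = chain_adjacency(4)
# adj is:
# [[0. 1. 0. 0.]
#  [1. 0. 1. 0.]
#  [0. 1. 0. 1.]
#  [0. 0. 1. 0.]]
```

Each interior token is connected only to its immediate neighbors, so any longer-range interaction has to pass through the working-memory graph or through repeated updates.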

That demo matters because it shows the intended bridge: a sequence can enter as a simple chain graph, but the model can route through a learned working-memory graph before producing another sequence-shaped output graph. The sequence interface is familiar; the internal state is relational.

Results

This project was an architectural prototype rather than a benchmarked production model. Its result is the working code path and the conceptual compression: transformers, graph attention, cross-attention, and recurrent working memory can be described with one graph update vocabulary.

The later multigraph-nn work expands the same direction into a more general multigraph tensor container and relation-update framework.

Lessons

The useful insight is that "attention" and "graph message passing" are not separate species. A transformer is a graph neural network on a mostly implicit complete graph. Once that graph is made explicit, it becomes natural to ask which edges should exist, which edges should carry features, which graph should store memory, and which graph should receive the decoded answer.

The hard part is engineering discipline. A fully general graph-to-graph recurrent architecture grows a large design surface quickly: vertex updates, edge updates, adjacency thresholds, update schedules, memory size, and decoding strategy. The project is strongest as a clear prototype of the transformer/GNN equivalence and as a stepping stone toward better-scoped graph-memory architectures.
