multi-graph-former
A TensorFlow/Keras experiment showing that transformer attention can be read as graph neural network message passing, then generalized from seq2seq to graph2graph.
- Sequences are just a special class of graphs. The multi-graph-former can process any kind of graph.
- Supports intra- and inter-graph attention, vertex updates, and edge updates with dynamic structure.
- Implements a gated update mechanism for both vertices and edges using einsum-based operations.
- The example applies multi-graph-former to build a graph-structured hidden state for a recurrent neural network that encodes a sequence of words.
Problem
Transformers are usually introduced as sequence models, while graph neural networks are introduced as models for structured relational data. That separation is useful pedagogically, but it hides the shared geometry. A transformer layer over a token sequence is already doing message passing on a graph: each token is a vertex, attention weights are directed edge weights, and the next token embedding is a weighted aggregation of messages from neighboring vertices.
multi-graph-former starts from that observation and asks what happens if we stop pretending the input must be a line. If a sentence can be treated as a graph, then sequence-to-sequence becomes one special case of graph-to-graph.
The Isometry
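A transformer block can be written as:

$$q_i = W_Q h_i, \qquad k_j = W_K h_j, \qquad v_j = W_V h_j$$

$$\alpha_{ij} = \operatorname{softmax}_j\!\left(\frac{q_i^\top k_j}{\sqrt{d}}\right), \qquad h_i' = \sum_j \alpha_{ij}\, v_j$$

Read graph-theoretically:
- $h_i$ is the feature vector on vertex $i$
- $\alpha_{ij}$ is a learned directed edge weight from token $j$ to token $i$
- $v_j$ is the message emitted by source vertex $j$
- $h_i'$ is the destination vertex update

That is message passing on a complete directed graph:

$$h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, m(h_j, e_{ij})$$

For a vanilla transformer, $\mathcal{N}(i)$ is every token position and $e_{ij}$ is mostly implicit. For multi-graph-former, $\mathcal{N}(i)$ can come from an arbitrary adjacency tensor, and $e_{ij}$ can be a real learned edge embedding.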
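The reading above can be checked numerically. The following is a minimal NumPy sketch, not the repo's TensorFlow code; the names `H`, `W_q`, `W_k`, `W_v` and the toy sizes are illustrative assumptions. It computes single-head self-attention once as a dense matrix product and once as an explicit loop over the edges of a complete directed graph, then confirms the two agree.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 4, 8                       # 4 tokens (vertices), feature width 8
H = rng.normal(size=(n, d))       # one feature vector per vertex
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

Q, K, V = H @ W_q, H @ W_k, H @ W_v

# Transformer view: attention as one dense matrix product.
alpha = softmax(Q @ K.T / np.sqrt(d), axis=-1)   # alpha[i, j]: weight of edge j -> i
H_attn = alpha @ V

# Graph view: every destination vertex sums weighted messages from its sources.
H_graph = np.zeros_like(H_attn)
for i in range(n):                # destination vertex
    for j in range(n):            # source vertex (complete graph: every j)
        H_graph[i] += alpha[i, j] * V[j]

assert np.allclose(H_attn, H_graph)
```

Restricting the inner loop to a subset of sources is what turns the complete graph into an arbitrary topology.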
The "isometry" is not that every implementation detail is identical. It is that the same computational shape is preserved:
| Transformer object | Graph neural network object | Multi-graph-former generalization |
|---|---|---|
| token | vertex | vertex in any named graph |
| sequence position | node index | arbitrary graph topology |
| attention matrix | weighted adjacency | alive/dead adjacency from edge state |
| value vector | source message | relation-aware message from source vertex plus edge |
| self-attention | intragraph message passing | graph self-update |
| cross-attention | bipartite message passing | intergraph update |
| decoder hidden state | recurrent state | working-memory graph |
Solution
The repo implements the idea as a small TensorFlow/Keras package with four reusable layers:
- Graph_Attention performs relation-aware, vertex-centric attention. Destination vertices make queries; keys come from edge embeddings; values come from source vertex data concatenated with edge data.
- Graph_Multihead_Attention repeats that relation-aware attention over multiple heads and concatenates the outputs.
- Edge_Update updates edge embeddings from source vertices, destination vertices, existing edge features, and adjacency.
- Smart_Update gates writes into vertices or edges, allowing the model to partially erase old state and write new state.
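To make the gated write concrete, here is a minimal NumPy sketch of one common erase/write gating form. It is an assumption about the general shape of such an update, not a transcription of Smart_Update; the function name `gated_update` and the weights `W_erase` and `W_write` are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_update(old, candidate, W_erase, W_write):
    """Blend old state with a candidate write, feature by feature.

    old, candidate:   (n, d) current state and proposed new state
    W_erase, W_write: (2 * d, d) hypothetical gate weights
    """
    gate_in = np.concatenate([old, candidate], axis=-1)          # (n, 2d)
    erase = sigmoid(np.einsum("nf,fd->nd", gate_in, W_erase))    # how much old state to drop
    write = sigmoid(np.einsum("nf,fd->nd", gate_in, W_write))    # how much new state to admit
    return (1.0 - erase) * old + write * candidate

rng = np.random.default_rng(0)
n, d = 3, 4
old = rng.normal(size=(n, d))
candidate = rng.normal(size=(n, d))
new = gated_update(old, candidate,
                   rng.normal(size=(2 * d, d)),
                   rng.normal(size=(2 * d, d)))
assert new.shape == (n, d)
```

The same function applies to edges if the edge tensor is flattened to shape `(num_edges, d)` first. Keeping separate erase and write gates lets a slot be cleared without being overwritten, or overwritten without being cleared.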
The main model, WM_Graph_Former, connects three graphs:
- an input graph with fixed sequence-style edges
- a working-memory graph initialized with a seed vertex and learned internal structure
- an output graph that decodes from working memory
Encoding updates let input vertices self-attend, update edges from input to working memory, let working memory attend to input, and then let working memory self-attend. Decoding updates form working-memory-to-output edges, let output vertices attend to working memory, and then let output vertices self-attend.
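The schedule described above can be written down as data, which makes the direction of each step explicit. This is an illustrative sketch only: the step tuples, the graph names, and `run_schedule` are hypothetical and do not correspond to identifiers in the repo.

```python
from typing import Callable, Dict, List, Tuple

Step = Tuple[str, str, str]  # (operation, source graph, destination graph)

ENCODE_STEPS: List[Step] = [
    ("attend", "input", "input"),        # input vertices self-attend
    ("edge_update", "input", "memory"),  # refresh input -> working-memory edges
    ("attend", "input", "memory"),       # working memory attends to input
    ("attend", "memory", "memory"),      # working memory self-attends
]

DECODE_STEPS: List[Step] = [
    ("edge_update", "memory", "output"),  # form working-memory -> output edges
    ("attend", "memory", "output"),       # output vertices attend to working memory
    ("attend", "output", "output"),       # output vertices self-attend
]

def run_schedule(steps: List[Step], ops: Dict[str, Callable], state):
    """Apply each step in order; every op takes and returns the full state."""
    for op, src, dst in steps:
        state = ops[op](state, src, dst)
    return state

# Stand-in ops that only record what would happen, in order.
trace_ops = {
    "attend": lambda log, src, dst: log + [f"{dst} attends to {src}"],
    "edge_update": lambda log, src, dst: log + [f"update edges {src} -> {dst}"],
}
log = run_schedule(ENCODE_STEPS + DECODE_STEPS, trace_ops, [])
```

Swapping the stand-in ops for real attention and edge-update layers would leave the schedule itself unchanged.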
How The Graph Attention Works
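For a source graph $S$ and destination graph $D$, writing $j$ for a source vertex and $i$ for a destination vertex, the code carries:

- $x_j$, the features of source vertex $j$
- $y_i$, the features of destination vertex $i$
- $e_{ij}$, the embedding of the edge from $j$ to $i$
- $a_{ij} \in \{0, 1\}$, the adjacency entry marking whether that edge is alive

The attention layer first localizes source vertices to each destination using the adjacency:

$$x_{ij} = a_{ij}\, x_j$$

Then it computes destination queries, edge-derived keys, and source-edge values:

$$q_i = W_Q\, y_i, \qquad k_{ij} = W_K\, e_{ij}, \qquad v_{ij} = W_V\, [\, x_{ij} \,;\, e_{ij} \,]$$

and pools messages into each destination vertex:

$$\alpha_{ij} = \operatorname{softmax}_j\!\left(\frac{q_i^\top k_{ij}}{\sqrt{d}}\right), \qquad y_i' = \sum_j \alpha_{ij}\, v_{ij}$$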
This is the transformer attention equation with the relation made explicit. Instead of attention being only a token-token similarity table, the edge itself participates in key and value construction.
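The three steps (localize, project, pool) can be sketched in NumPy with einsum. This is a single-head illustration under assumed tensor layouts, not the repo's Graph_Attention layer: the argument names, shapes, and weight matrices are all hypothetical. Whether the repo also masks attention scores on dead edges is not stated here, so the sketch does not.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def graph_attention(src, dst, edges, adj, W_q, W_k, W_v):
    """Relation-aware attention from a source graph into a destination graph.

    src:   (S, dv)     source vertex features
    dst:   (D, dv)     destination vertex features
    edges: (S, D, de)  edge embeddings, source -> destination
    adj:   (S, D)      1.0 where an edge is alive, 0.0 otherwise
    """
    # Localize: a per-destination copy of each source vertex, zeroed on dead edges.
    src_local = np.einsum("sv,sd->sdv", src, adj)             # (S, D, dv)
    q = np.einsum("dv,vk->dk", dst, W_q)                      # queries from destinations
    k = np.einsum("sde,ek->sdk", edges, W_k)                  # keys from edges
    src_edge = np.concatenate([src_local, edges], axis=-1)    # (S, D, dv + de)
    v = np.einsum("sdf,fo->sdo", src_edge, W_v)               # values from source + edge
    scores = np.einsum("dk,sdk->sd", q, k) / np.sqrt(q.shape[-1])
    weights = softmax(scores, axis=0)                         # normalize over sources
    return np.einsum("sd,sdo->do", weights, v)                # pool into destinations

rng = np.random.default_rng(0)
S, D, dv, de, dk, do = 5, 3, 4, 2, 6, 4
out = graph_attention(
    src=rng.normal(size=(S, dv)),
    dst=rng.normal(size=(D, dv)),
    edges=rng.normal(size=(S, D, de)),
    adj=(rng.random((S, D)) > 0.5).astype(float),
    W_q=rng.normal(size=(dv, dk)),
    W_k=rng.normal(size=(de, dk)),
    W_v=rng.normal(size=(dv + de, do)),
)
assert out.shape == (D, do)
```

Note that the attention scores depend only on the destination query and the edge key, so two source vertices with identical features can still receive different weights if their edges differ.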
What The Demo Shows
Language_WM_Graph_Former wraps the working-memory graph former for token sequences. It embeds input token ids, constructs a doubly linked sequence graph with seq_edges, creates an empty output graph, runs the graph former, and decodes output vertices into token logits.
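For intuition about the input graph, a doubly linked sequence graph connects each token to its immediate predecessor and successor. The helper below is a hypothetical stand-in written for illustration; it is not the repo's seq_edges, whose signature and conventions (for example, whether self-loops are included) are not shown here.

```python
import numpy as np

def chain_adjacency(n):
    """Adjacency of a doubly linked chain: 1.0 wherever two positions differ by one."""
    idx = np.arange(n)
    return (np.abs(idx[:, None] - idx[None, :]) == 1).astype(np.float32)

adj = chain_adjacency(4)
# adj is:
# [[0, 1, 0, 0],
#  [1, 0, 1, 0],
#  [0, 1, 0, 1],
#  [0, 0, 1, 0]]
```

The matrix is symmetric because every link runs in both directions, and a chain of `n` tokens has `2 * (n - 1)` directed edges.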
That demo matters because it shows the intended bridge: a sequence can enter as a simple chain graph, but the model can route through a learned working-memory graph before producing another sequence-shaped output graph. The sequence interface is familiar; the internal state is relational.
Results
This project was an architectural prototype rather than a benchmarked production model. Its result is the working code path and the conceptual compression: transformers, graph attention, cross-attention, and recurrent working memory can be described with one graph update vocabulary.
The later multigraph-nn work expands the same direction into a more general multigraph tensor container and relation-update framework.
Lessons
The useful insight is that "attention" and "graph message passing" are not separate species. A transformer is a graph neural network on a mostly implicit complete graph. Once that graph is made explicit, it becomes natural to ask which edges should exist, which edges should carry features, which graph should store memory, and which graph should receive the decoded answer.
The hard part is engineering discipline. A fully general graph-to-graph recurrent architecture grows a large design surface quickly: vertex updates, edge updates, adjacency thresholds, update schedules, memory size, and decoding strategy. The project is strongest as a clear prototype of the transformer/GNN equivalence and as a stepping stone toward better-scoped graph-memory architectures.
Neighborhood