TransformerBlocks.jl
Simple, blazing-fast transformer components.
Basic Usage
using TransformerBlocks
# C: channel size (embedding dimension)
# T: block size (sequence length)
# B: batch size
C, T, B = 10, 5, 3
x = rand(Float32, C, T, B)
# Example 1: Transformer block
block = Block(C)
@assert size(block(x)) == (C, T, B)
# Example 2: Block with mask
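using LinearAlgebra
# tril(A, -1) keeps only the entries strictly below the main diagonal, so this
# mask holds -Inf below the diagonal and 0 on and above it
mask = tril(fill(-Inf, T, T), -1)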
@assert size(block(x; mask=mask)) == (C, T, B)
# Example 3: Sequential blocks
num_layers = 3
blocks = BlockList([Block(C) for _ in 1:num_layers])
@assert size(blocks(x)) == (C, T, B)
API index
TransformerBlocks.Block
TransformerBlocks.BlockList
TransformerBlocks.FeedForward
TransformerBlocks.Head
TransformerBlocks.MultiheadAttention
Components
TransformerBlocks.Head — Type
Head(input_dim, head_size; dropout=0)
Initializes an instance of the Head type, representing one head of self-attention.
The constructor supports the following keyword argument:
dropout (defaults to 0)
A Head instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (HS, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size. "HS" is the head size.
When calling a Head instance, the following keyword argument is supported:
mask (defaults to nothing; must be of dimensions (T, T))
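A single head's computation can be sketched in plain Julia as scaled dot-product attention. This is not the package's code: the projection matrices Wq, Wk and Wv (each of size (HS, C)), the helper names head_sketch and softmax_cols, and the (keys × queries) layout of the score matrix are assumptions made for illustration. Under that assumed layout, adding the tril(fill(-Inf, T, T), -1) mask from Basic Usage to the scores blocks every key that comes after its query.

```julia
using LinearAlgebra: tril

# Column-wise softmax: each column of `s` is normalized to sum to 1.
function softmax_cols(s::AbstractMatrix)
    e = exp.(s .- maximum(s; dims=1))   # subtract the column max for numerical stability
    return e ./ sum(e; dims=1)
end

# Hypothetical single attention head (not the TransformerBlocks.jl implementation).
# Wq, Wk, Wv: (HS, C) projections; x: (C, T, B) input; returns an (HS, T, B) array.
function head_sketch(Wq, Wk, Wv, x::AbstractArray{<:Real,3}; mask=nothing)
    HS = size(Wq, 1)
    _, T, B = size(x)
    out = similar(x, HS, T, B)
    for b in 1:B
        xb = view(x, :, :, b)                  # one sequence, (C, T)
        q, k, v = Wq * xb, Wk * xb, Wv * xb    # each (HS, T)
        s = (k' * q) ./ sqrt(Float32(HS))      # (T, T): s[j, i] scores key j for query i
        if mask !== nothing
            s = s .+ mask                      # -Inf entries end up with zero attention weight
        end
        out[:, :, b] = v * softmax_cols(s)     # each query gets a weighted sum of the values
    end
    return out
end

C, T, B, HS = 8, 3, 4, 10
x = rand(Float32, C, T, B)
Wq, Wk, Wv = randn(Float32, HS, C), randn(Float32, HS, C), randn(Float32, HS, C)
y = head_sketch(Wq, Wk, Wv, x)                                          # (HS, T, B)
y_masked = head_sketch(Wq, Wk, Wv, x; mask=tril(fill(-Inf, T, T), -1))  # (HS, T, B)
```

Looping over the batch keeps the sketch easy to read; a tuned implementation would more likely use batched matrix multiplication. The same (T, T) mask is reused for every sequence in the batch, matching the documented shape requirement.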
Examples:
C,T,B = 8,3,4
HS = 10
head = Head(C,HS)
@assert size(head(rand(Float32, C,T,B))) == (HS,T,B)
TransformerBlocks.MultiheadAttention — Type
MultiheadAttention(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)
Initializes an instance of the MultiheadAttention type, representing multiple heads of self-attention running in parallel.
The constructor supports the following keyword arguments:
head_size (defaults to input_dim ÷ num_heads)
dropout (defaults to 0)
A MultiheadAttention instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.
When calling a MultiheadAttention instance, the following keyword argument is supported:
mask (defaults to nothing; must be of dimensions (T, T))
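With the default head_size = input_dim ÷ num_heads, the NH heads together produce NH × HS rows, which equals C when C is divisible by NH. One common way to combine the heads, assumed here rather than taken from the package source, is to stack the head outputs along the channel dimension and apply an output projection Wo that maps back to C channels, which is consistent with the documented (C, T, B) output. The names attend, multihead_sketch, heads and Wo are hypothetical.

```julia
# Column-wise softmax: each column sums to 1.
function softmax_cols(s::AbstractMatrix)
    e = exp.(s .- maximum(s; dims=1))
    return e ./ sum(e; dims=1)
end

# One attention head on a single (C, T) sequence; W = (Wq, Wk, Wv), each (HS, C).
function attend(W, xb; mask=nothing)
    Wq, Wk, Wv = W
    q, k, v = Wq * xb, Wk * xb, Wv * xb
    s = (k' * q) ./ sqrt(Float32(size(q, 1)))   # (T, T) scores, keys × queries
    if mask !== nothing
        s = s .+ mask
    end
    return v * softmax_cols(s)                  # (HS, T)
end

# Hypothetical multi-head attention: run every head on the same input, stack the
# (HS, T) outputs into an (NH*HS, T) matrix, then project back to C channels with Wo.
function multihead_sketch(heads, Wo, x::AbstractArray{<:Real,3}; mask=nothing)
    _, T, B = size(x)
    out = similar(x, size(Wo, 1), T, B)
    for b in 1:B
        xb = view(x, :, :, b)
        stacked = reduce(vcat, [attend(W, xb; mask=mask) for W in heads])
        out[:, :, b] = Wo * stacked
    end
    return out
end

C, T, B, NH = 8, 3, 4, 4
HS = C ÷ NH   # mirrors the documented default head_size = input_dim ÷ num_heads
heads = [Tuple(randn(Float32, HS, C) for _ in 1:3) for _ in 1:NH]
Wo = randn(Float32, C, NH * HS)
y = multihead_sketch(heads, Wo, rand(Float32, C, T, B))   # (C, T, B)
```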
Examples:
C,T,B = 8,3,4
NH = 4 # number of heads
multihead = MultiheadAttention(C,NH)
@assert size(multihead(rand(Float32, C, T, B))) == (C, T, B)
TransformerBlocks.FeedForward — Type
FeedForward(input_dim::Integer; dropout=0)
Initializes an instance of the FeedForward type, representing a simple linear layer followed by a non-linearity.
The constructor supports the following keyword argument:
dropout (defaults to 0)
A FeedForward instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.
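A shape-preserving "linear layer followed by a non-linearity" can be sketched as a position-wise C × C affine map followed by ReLU. The choice of ReLU, the square weight matrix W, the bias vector and the name feedforward_sketch are assumptions for illustration; the package may use a different activation or hidden width.

```julia
# ReLU non-linearity (assumed; the docs only say "a non-linearity").
relu(z) = max(z, zero(z))

# Hypothetical position-wise feed-forward layer: W is (C, C), bias is a length-C vector.
function feedforward_sketch(W, bias, x::AbstractArray{<:Real,3})
    C, T, B = size(x)
    tokens = reshape(x, C, T * B)       # every token becomes one column
    y = relu.(W * tokens .+ bias)       # (C, T*B); bias is broadcast across columns
    return reshape(y, C, T, B)          # restore the (C, T, B) layout
end

C, T, B = 8, 3, 4
W, bias = randn(Float32, C, C), randn(Float32, C)
x = rand(Float32, C, T, B)
y = feedforward_sketch(W, bias, x)      # (C, T, B)
```

Reshaping to (C, T*B) lets one matrix product handle every token at once; since the layer acts on each token independently, the result does not depend on how tokens are grouped into sequences or batches.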
Examples:
C,T,B = 8,3,4
ff = FeedForward(C)
@assert size(ff(rand(Float32, C, T, B))) == (C, T, B)
TransformerBlocks.Block — Type
Block(input_dim; num_heads=1, head_size=(input_dim ÷ num_heads), dropout=0)
Initializes an instance of the Block type, representing a transformer block.
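The constructor supports the following keyword arguments:
num_heads (defaults to 1)
head_size (defaults to input_dim ÷ num_heads)
dropout (defaults to 0)
A Block instance accepts an input array x of dimensions (C, T, B) and outputs an array of the same dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.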
When calling a Block instance, the following keyword argument is supported:
mask (defaults to nothing; must be of dimensions (T, T))
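The docs do not spell out how a Block wires its sublayers together. A common layout, assumed here and not confirmed by the package documentation, is the pre-LayerNorm residual block: normalize, apply attention, add the result back to the input, then do the same with the feed-forward layer. Because both residual additions need matching shapes, this layout yields the (C, T, B) output the Block example asserts. layernorm and block_sketch are hypothetical names, and the attention and feed-forward sublayers are passed in as arbitrary shape-preserving functions.

```julia
# Layer normalization over the channel dimension (dim 1), without learnable scale or shift.
function layernorm(x::AbstractArray{<:Real,3}; epsilon=1f-5)
    C = size(x, 1)
    mu = sum(x; dims=1) ./ C                  # per-token mean
    v = sum((x .- mu) .^ 2; dims=1) ./ C      # per-token variance
    return (x .- mu) ./ sqrt.(v .+ epsilon)
end

# Hypothetical pre-LayerNorm transformer block (not the package's implementation).
# `attn` and `ff` map (C, T, B) arrays to (C, T, B) arrays; each sublayer's output is
# added to its input (a residual connection), so the block preserves the input shape.
function block_sketch(attn, ff, x::AbstractArray{<:Real,3})
    h = x .+ attn(layernorm(x))
    return h .+ ff(layernorm(h))
end

C, T, B = 8, 3, 4
x = rand(Float32, C, T, B)
toy_attn(z) = 0.5f0 .* z     # stand-in for multi-head attention
toy_ff(z) = max.(z, 0f0)     # stand-in for the feed-forward layer (ReLU only)
y = block_sketch(toy_attn, toy_ff, x)   # (C, T, B)
```

If both sublayers return zeros, the residual connections make the block an identity map, which is one reason this layout is easy to stack deeply.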
Examples:
C,T,B = 8,3,4
block = Block(C)
@assert size(block(rand(Float32, C,T,B))) == (C,T,B)
TransformerBlocks.BlockList — Type
BlockList(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)
Initializes an instance of the BlockList type, representing a sequence of transformer blocks composed together.
The constructor supports the following keyword arguments:
head_size (defaults to input_dim ÷ num_heads)
dropout (defaults to 0)
A BlockList instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size. In the examples, a BlockList is constructed from a vector of Block instances.
When calling a BlockList instance, the following keyword argument is supported:
mask (defaults to nothing; must be of dimensions (T, T))
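Running blocks in sequence amounts to a left fold: the input goes through the first block, its output through the second, and so on. The sketch assumes each block is callable as blk(x; mask=...) and that the same mask is forwarded to every block; neither detail is confirmed by the docs above, and AddBlock and blocklist_sketch are hypothetical stand-ins. Because every block maps (C, T, B) to (C, T, B), the whole chain does too.

```julia
# Toy stand-in for a Block: shape-preserving, adds a constant, ignores the mask.
struct AddBlock
    c::Float32
end
(blk::AddBlock)(h; mask=nothing) = h .+ blk.c

# Hypothetical sequential composition: feed each block's output into the next block,
# passing the same mask to every block (an assumption about how a mask would propagate).
function blocklist_sketch(blocks, x; mask=nothing)
    return foldl((h, blk) -> blk(h; mask=mask), blocks; init=x)
end

C, T, B = 8, 3, 4
x = rand(Float32, C, T, B)
y = blocklist_sketch([AddBlock(1f0), AddBlock(2f0), AddBlock(3f0)], x)   # (C, T, B)
```

With an empty list of blocks the fold simply returns its input unchanged.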
Examples:
C,T,B = 8,3,4
blocklist = BlockList([Block(C), Block(C)])
@assert size(blocklist(rand(Float32, C,T,B))) == (C,T,B)