TransformerBlocks.jl
Simple, blazing-fast transformer components.
Basic Usage
using TransformerBlocks
# C: channel size (embedding dimension)
# T: block size (sequence length)
# B: batch size
C, T, B = 10, 5, 3
x = rand(Float32, C, T, B)
# Example 1: Transformer block
block = Block(C)
@assert size(block(x)) == (C, T, B)
# Example 2: Block with mask
using LinearAlgebra
mask = tril(fill(-Inf, T, T), -1)  # (T, T): -Inf strictly below the diagonal, 0 elsewhere
@assert size(block(x; mask=mask)) == (C, T, B)
# Example 3: Sequential blocks
num_layers = 3
blocks = BlockList([Block(C) for _ in 1:num_layers])
@assert size(blocks(x)) == (C, T, B)
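The following is a minimal plain-Julia sketch that shows why the mask in Example 2 is built from -Inf. It makes no TransformerBlocks calls. It assumes three things this page does not confirm: the mask is added to the raw (T, T) attention scores, those scores are laid out as (key position, query position), and the softmax runs over keys (dims=1). Under that layout, tril(fill(-Inf, T, T), -1) acts as a causal mask. The helper softmax_cols and the random scores are hypothetical.

```julia
using LinearAlgebra  # tril

# Numerically stable softmax over each column (one column per query).
function softmax_cols(s::AbstractMatrix)
    e = exp.(s .- maximum(s; dims=1))
    return e ./ sum(e; dims=1)
end

T = 5
mask = tril(fill(-Inf, T, T), -1)   # -Inf strictly below the diagonal, 0 elsewhere

# Random stand-in scores, assumed laid out as (key position, query position).
scores = rand(T, T)

# exp(-Inf) == 0, so masked key/query pairs receive exactly zero weight.
weights = softmax_cols(scores .+ mask)

# Query j attends only to keys 1..j, and each query's weights sum to 1.
@assert all(weights[i, j] == 0 for j in 1:T for i in j+1:T)
@assert all(isapprox.(sum(weights; dims=1), 1))
```

If the package instead stored scores as (query, key), the same matrix would hide past tokens rather than future ones. So the layout assumption matters when you build custom masks.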
API index
TransformerBlocks.Block
TransformerBlocks.BlockList
TransformerBlocks.FeedForward
TransformerBlocks.Head
TransformerBlocks.MultiheadAttention
Components
TransformerBlocks.Head (Type)
Head(input_dim, head_size; dropout=0)
Initializes an instance of the Head type, representing one head of self-attention. The constructor's dropout keyword defaults to 0.
A Head instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (HS, T, B), where C is the channel size (embedding dimension), T is the block size (number of input tokens), B is the batch size, and HS is the head size.
When calling a Head instance, the following keyword argument is supported:
mask: Defaults to nothing. Must be of dimensions (T, T).
Examples:
C,T,B = 8,3,4
HS = 10
head = Head(C,HS)
@assert size(head(rand(Float32, C,T,B))) == (HS,T,B)
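The shape contract above, with (C, T, B) in and (HS, T, B) out, can be illustrated with a minimal scaled dot-product attention head written in plain Julia. HeadSketch, its fields Wq/Wk/Wv, and softmax_cols are hypothetical names and are not part of TransformerBlocks. Dropout is omitted. The (key, query) score layout and the additive mask are assumptions, not the package's confirmed implementation.

```julia
using LinearAlgebra  # tril, used only to build the example mask

# Numerically stable softmax over each column (one column per query).
function softmax_cols(s::AbstractMatrix)
    e = exp.(s .- maximum(s; dims=1))
    return e ./ sum(e; dims=1)
end

# Hypothetical single attention head: projections map C channels to HS channels.
struct HeadSketch
    Wq::Matrix{Float32}  # (HS, C) query projection
    Wk::Matrix{Float32}  # (HS, C) key projection
    Wv::Matrix{Float32}  # (HS, C) value projection
end

HeadSketch(C::Integer, HS::Integer) =
    HeadSketch(randn(Float32, HS, C), randn(Float32, HS, C), randn(Float32, HS, C))

function (h::HeadSketch)(x::AbstractArray{<:Real,3}; mask=nothing)
    _, T, B = size(x)
    HS = size(h.Wq, 1)
    out = similar(x, Float32, HS, T, B)
    for b in 1:B
        xb = x[:, :, b]                                # (C, T): one sequence
        q, k, v = h.Wq * xb, h.Wk * xb, h.Wv * xb      # each (HS, T)
        scores = (k' * q) ./ sqrt(Float32(HS))         # (T, T): keys x queries
        if mask !== nothing
            scores = scores .+ mask                    # -Inf entries get zero weight
        end
        out[:, :, b] = v * softmax_cols(scores)        # (HS, T): weighted values per query
    end
    return out
end

# Usage mirroring the docstring example above.
C, T, B = 8, 3, 4
HS = 10
head = HeadSketch(C, HS)
x = rand(Float32, C, T, B)
mask = tril(fill(-Inf, T, T), -1)
@assert size(head(x)) == (HS, T, B)
@assert size(head(x; mask=mask)) == (HS, T, B)
```

With the causal mask, the first token can only attend to itself. Its output is therefore just its own value vector Wv * x[:, 1, b].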
TransformerBlocks.MultiheadAttention (Type)
MultiheadAttention(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)
Initializes an instance of the MultiheadAttention type, representing multiple heads of self-attention running in parallel.
The constructor supports the following keyword arguments:
head_size: Defaults to input_dim ÷ num_heads.
dropout: Defaults to 0.
A MultiheadAttention instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B), where C is the channel size (embedding dimension), T is the block size (number of input tokens), and B is the batch size.
When calling a MultiheadAttention instance, the following keyword argument is supported:
mask: Defaults to nothing. Must be of dimensions (T, T).
Examples:
C,T,B = 8,3,4
NH = 4 # Num heads
multihead = MultiheadAttention(C,NH)
@assert size(multihead(rand(Float32, C, T, B))) == (C, T, B)
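A common way to build multi-head attention runs NH heads of size head_size side by side, concatenates their outputs along the channel dimension into (NH * head_size, T), and projects the result back to C channels with an output matrix. The sketch below follows that approach as an assumption; it is not the package's actual code. In this sketch, the projection is what brings the output back to (C, T, B), even when input_dim is not divisible by num_heads. MultiheadSketch, attend, softmax_cols, and Wo are hypothetical names, and dropout is omitted.

```julia
# Numerically stable softmax over each column (one column per query).
function softmax_cols(s::AbstractMatrix)
    e = exp.(s .- maximum(s; dims=1))
    return e ./ sum(e; dims=1)
end

# One attention head on one sequence xb of size (C, T); returns (HS, T).
function attend(Wq, Wk, Wv, xb; mask=nothing)
    q, k, v = Wq * xb, Wk * xb, Wv * xb
    scores = (k' * q) ./ sqrt(Float32(size(Wq, 1)))   # (T, T): keys x queries
    if mask !== nothing
        scores = scores .+ mask
    end
    return v * softmax_cols(scores)
end

# Hypothetical multi-head layer: NH heads plus an output projection.
struct MultiheadSketch
    heads::Vector{NTuple{3,Matrix{Float32}}}   # (Wq, Wk, Wv) per head, each (HS, C)
    Wo::Matrix{Float32}                        # (C, NH * HS) output projection
end

function MultiheadSketch(C::Integer, NH::Integer; head_size::Integer = C ÷ NH)
    heads = [ntuple(_ -> randn(Float32, head_size, C), 3) for _ in 1:NH]
    return MultiheadSketch(heads, randn(Float32, C, NH * head_size))
end

function (m::MultiheadSketch)(x::AbstractArray{<:Real,3}; mask=nothing)
    _, T, B = size(x)
    out = similar(x, Float32, size(m.Wo, 1), T, B)
    for b in 1:B
        xb = x[:, :, b]
        # Concatenate head outputs along channels: (NH * HS, T).
        stacked = reduce(vcat, [attend(W..., xb; mask=mask) for W in m.heads])
        out[:, :, b] = m.Wo * stacked                  # project back to (C, T)
    end
    return out
end

C, T, B = 8, 3, 4
NH = 4  # number of heads; default head_size is 8 ÷ 4 = 2
multihead = MultiheadSketch(C, NH)
x = rand(Float32, C, T, B)
@assert size(multihead(x)) == (C, T, B)
```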
TransformerBlocks.FeedForward (Type)
FeedForward(input_dim::Integer; dropout=0)
Initializes an instance of the FeedForward type, representing a simple linear layer followed by a non-linearity.
The constructor supports the following keyword argument:
dropout: Defaults to 0.
A FeedForward instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B), where C is the channel size (embedding dimension), T is the block size (number of input tokens), and B is the batch size.
Examples:
C,T,B = 8,3,4
ff = FeedForward(C)
@assert size(ff(rand(Float32, C, T, B))) == (C, T, B)
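Read literally, the description above is a linear layer followed by a non-linearity, applied to each token independently. A feed-forward layer of that kind can be sketched as follows. The square (C, C) weight matrix, ReLU as the non-linearity, and the names FeedForwardSketch and relu are all assumptions. The package may use a different hidden width or activation, and dropout is omitted here.

```julia
# Hypothetical position-wise feed-forward layer: Dense(C => C) followed by ReLU.
struct FeedForwardSketch
    W::Matrix{Float32}   # (C, C) weights
    b::Vector{Float32}   # (C,) bias
end

FeedForwardSketch(C::Integer) =
    FeedForwardSketch(randn(Float32, C, C) ./ sqrt(Float32(C)), zeros(Float32, C))

relu(z) = max(z, zero(z))

function (ff::FeedForwardSketch)(x::AbstractArray{<:Real,3})
    C, T, B = size(x)
    tokens = reshape(x, C, T * B)          # each column is one token
    y = relu.(ff.W * tokens .+ ff.b)       # same weights applied to every token
    return reshape(y, C, T, B)             # restore (C, T, B)
end

C, T, B = 8, 3, 4
ff = FeedForwardSketch(C)
x = rand(Float32, C, T, B)
@assert size(ff(x)) == (C, T, B)
```

Reshaping to (C, T * B) lets a single matrix multiply handle every token in the batch. This is also why the layer preserves the (C, T, B) shape.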
TransformerBlocks.Block (Type)
Block(input_dim; num_heads=1, head_size=(input_dim ÷ num_heads), dropout=0)
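Initializes an instance of the Block type, representing a transformer block. The constructor's keyword arguments num_heads, head_size, and dropout default to 1, input_dim ÷ num_heads, and 0, respectively.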
A Block instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B), where C is the channel size (embedding dimension), T is the block size (number of input tokens), and B is the batch size.
When calling a Block instance, the following keyword argument is supported:
mask: Defaults to nothing. Must be of dimensions (T, T).
Examples:
C,T,B = 8,3,4
block = Block(C)
@assert size(block(rand(Float32, C,T,B))) == (C,T,B)
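A transformer block usually wraps its attention and feed-forward sublayers in residual connections with layer normalization. That is why its output keeps the input's (C, T, B) shape. The sketch below uses the pre-norm arrangement found in GPT-2-style models. This page does not say whether Block uses pre-norm, post-norm, or learned normalization parameters, so treat the ordering as an assumption. layernorm and block_sketch are hypothetical names. The sublayers are passed in as plain functions so the example runs without the package.

```julia
using Statistics  # mean, var

# Layer normalization over the channel dimension, without learned scale or shift.
function layernorm(x::AbstractArray; epsilon=1f-5)
    mu = mean(x; dims=1)
    sigma2 = var(x; dims=1, corrected=false)
    return (x .- mu) ./ sqrt.(sigma2 .+ epsilon)
end

# Hypothetical pre-norm block: normalize, apply a sublayer, add the result back
# onto the residual stream. Both sublayers must preserve the (C, T, B) shape.
function block_sketch(x, attn, ff; mask=nothing)
    h = x .+ attn(layernorm(x); mask=mask)   # residual around attention
    return h .+ ff(layernorm(h))             # residual around feed-forward
end

# Toy stand-ins so the sketch runs without any model weights.
toy_attn(y; mask=nothing) = 0.5f0 .* y
toy_ff(y) = max.(y, 0f0)

x = rand(Float32, 8, 3, 4)
@assert size(block_sketch(x, toy_attn, toy_ff)) == size(x)
```

If both sublayers return zeros, the residual connections pass x through unchanged. This is one quick way to check the wiring.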
TransformerBlocks.BlockList (Type)
BlockList(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)
Initializes an instance of the BlockList type, representing a sequence of transformer blocks composed together.
The constructor supports the following keyword arguments:
head_size: Defaults to input_dim ÷ num_heads.
dropout: Defaults to 0.
A BlockList instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B), where C is the channel size (embedding dimension), T is the block size (number of input tokens), and B is the batch size.
When calling a BlockList instance, the following keyword argument is supported:
mask: Defaults to nothing. Must be of dimensions (T, T).
Examples:
C,T,B = 8,3,4
blocklist = BlockList([Block(C), Block(C)])
@assert size(blocklist(rand(Float32, C,T,B))) == (C,T,B)
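Composing blocks amounts to a left fold over the list. Each block's output is passed to the next block, and the same mask is reused for every block. BlockListSketch is a hypothetical stand-in that accepts any callables with the signature blk(x; mask). It is not the package's BlockList, and it assumes the mask is forwarded unchanged to each block.

```julia
# Hypothetical BlockList stand-in: any callables with signature blk(x; mask).
struct BlockListSketch{V<:AbstractVector}
    blocks::V
end

# Apply the blocks left to right, handing the same mask to every block.
function (bl::BlockListSketch)(x; mask=nothing)
    return foldl((h, blk) -> blk(h; mask=mask), bl.blocks; init=x)
end

# Toy blocks: one doubles its input, the other adds 1 only when given a mask.
double_block(y; mask=nothing) = 2 .* y
shift_block(y; mask=nothing) = mask === nothing ? y : y .+ 1f0

x = rand(Float32, 8, 3, 4)
bl = BlockListSketch([double_block, double_block, double_block])
@assert bl(x) == 8 .* x
shifted = BlockListSketch([shift_block, shift_block])
@assert shifted(x) == x
@assert shifted(x; mask=zeros(Float32, 3, 3)) ≈ x .+ 2f0
```

Because every block maps (C, T, B) to (C, T, B), the fold can chain any number of them without reshaping.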