TransformerBlocks.jl

Simple, blazing-fast transformer components.

Basic Usage

using TransformerBlocks

# C: channel size (embedding dimension)
# T: block size (sequence length)
# B: batch size
C, T, B = 10, 5, 3
x = rand(Float32, C, T, B)

# Example 1: Transformer block
block = Block(C)
@assert size(block(x)) == (C, T, B)

# Example 2: Block with mask
using LinearAlgebra
# -Inf strictly below the diagonal, 0 elsewhere; -Inf32 keeps the mask in Float32 like x
mask = tril(fill(-Inf32, T, T), -1)
@assert size(block(x; mask=mask)) == (C, T, B)

# Example 3: Sequential blocks
num_layers = 3
blocks = BlockList([Block(C) for _ in 1:num_layers])
@assert size(blocks(x)) == (C, T, B)
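
To make the mask in Example 2 concrete, the sketch below builds it for T = 3 and shows how an additive mask of this shape behaves under a column-wise softmax. It uses only LinearAlgebra. `colsoftmax` is a hypothetical helper, and the idea that the mask is added to a (T, T) score matrix before a softmax over each column is an assumption for illustration, not a statement about the package's internals.

```julia
using LinearAlgebra

T = 3
# -Inf strictly below the diagonal, 0 on and above it.
mask = tril(fill(-Inf32, T, T), -1)
# mask == [ 0     0     0
#          -Inf   0     0
#          -Inf  -Inf   0 ]

# Hypothetical helper: numerically stable softmax over each column.
function colsoftmax(a)
    e = exp.(a .- maximum(a; dims=1))
    return e ./ sum(e; dims=1)
end

scores = zeros(Float32, T, T)          # stand-in for attention scores
weights = colsoftmax(scores .+ mask)

# Entries masked with -Inf receive exactly zero weight,
# and every column still sums to one.
@assert all(weights[mask .== -Inf32] .== 0)
@assert all(sum(weights; dims=1) .≈ 1)
```

Under this assumption, column j only receives weight from rows 1 through j, which is the usual causal pattern.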

API index

Components

TransformerBlocks.Head (Type)
Head(input_dim, head_size; dropout=0)

Initializes an instance of the Head type, representing one head of self-attention.

A Head instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (HS, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size. "HS" is the head size.

The following keyword arguments are supported:

  • mask (Defaults to nothing. Must be of dimensions (T, T).)

Examples:

C,T,B = 8,3,4
HS = 10
head = Head(C,HS)
@assert size(head(rand(Float32, C,T,B))) == (HS,T,B)
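
For intuition about the (C, T, B) to (HS, T, B) mapping, here is a minimal sketch of one head of scaled dot-product self-attention in plain Julia, on a single sequence without a batch dimension. It is not the package's implementation: `toy_head` and the projection matrices `Wq`, `Wk`, `Wv` are hypothetical names, and the score layout (keys on rows, queries on columns) is an assumption.

```julia
# Hypothetical stand-in for one attention head on a single sequence (no batch).
# x is (C, T); Wq, Wk, Wv are (HS, C) projection matrices.
function toy_head(x, Wq, Wk, Wv; mask=nothing)
    q, k, v = Wq * x, Wk * x, Wv * x                 # each (HS, T)
    scores = (k' * q) ./ sqrt(Float32(size(q, 1)))   # (T, T): keys on rows, queries on columns
    if mask !== nothing
        scores = scores .+ mask                      # additive (T, T) mask
    end
    e = exp.(scores .- maximum(scores; dims=1))      # softmax over each column
    weights = e ./ sum(e; dims=1)
    return v * weights                               # (HS, T)
end

C, T, HS = 8, 3, 10
x = rand(Float32, C, T)
Wq, Wk, Wv = (randn(Float32, HS, C) for _ in 1:3)
y = toy_head(x, Wq, Wk, Wv)
@assert size(y) == (HS, T)
```

The output has HS rows because the value projection maps each token from C channels to HS channels before the weighted sum.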
TransformerBlocks.MultiheadAttention (Type)
MultiheadAttention(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)

Initializes an instance of the MultiheadAttention type, representing multiple heads of parallel self-attention.

The constructor supports the following keyword arguments:

  • head_size (Defaults to input_dim ÷ num_heads, i.e. integer division)
  • dropout (Defaults to 0)

A MultiheadAttention instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.

When called, a MultiheadAttention instance supports the following keyword arguments:

  • mask (Defaults to nothing. Must be of dimensions (T, T).)

Examples:

C,T,B = 8,3,4
NH = 4 # Num heads
multihead = MultiheadAttention(C,NH)
@assert size(multihead(rand(Float32, C, T, B))) == (C, T, B)
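
The default head size and the (C, T, B) output shape can be checked with a little arithmetic. The sketch below is a hedged illustration in plain Julia: the concatenate-then-project step and the matrix `Wo` are assumptions about how several (HS, T) head outputs could be brought back to C channels, not a description of the package's code.

```julia
input_dim, num_heads = 8, 4
head_size = input_dim ÷ num_heads        # integer division, as in the signature
@assert head_size == 2
@assert num_heads * head_size == input_dim

# When num_heads does not divide input_dim, ÷ truncates:
@assert 10 ÷ 4 == 2                      # 4 heads of size 2 give 8 channels, not 10

# Hypothetical sketch: concatenate per-head outputs, then project back to C.
T = 3
head_outputs = [rand(Float32, head_size, T) for _ in 1:num_heads]  # each (HS, T)
concatenated = reduce(vcat, head_outputs)                          # (NH * HS, T)
Wo = randn(Float32, input_dim, num_heads * head_size)              # assumed output projection
y = Wo * concatenated
@assert size(y) == (input_dim, T)
```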
TransformerBlocks.FeedForward (Type)
FeedForward(input_dim::Integer; dropout=0)

Initializes an instance of the FeedForward type, representing a simple linear layer followed by a non-linearity.

The following keyword arguments are supported:

  • dropout (Defaults to 0)

A FeedForward instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.

Examples:

C,T,B = 8,3,4
ff = FeedForward(C)
@assert size(ff(rand(Float32, C, T, B))) == (C, T, B)
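
As a rough picture of "a linear layer followed by a non-linearity" on a (C, T, B) array, the sketch below applies one dense layer and a ReLU to every token independently. `toy_feedforward`, `W`, and `b` are hypothetical, and the choice of ReLU is an assumption; the package may use a different activation or layer width.

```julia
# Hypothetical stand-in: one dense layer followed by ReLU, applied to every token.
function toy_feedforward(x, W, b)
    C, T, B = size(x)
    flat = reshape(x, C, T * B)            # treat each token as a column
    out = max.(W * flat .+ b, 0f0)         # linear layer, then ReLU
    return reshape(out, size(W, 1), T, B)
end

C, T, B = 8, 3, 4
W, b = randn(Float32, C, C), zeros(Float32, C)
x = rand(Float32, C, T, B)
y = toy_feedforward(x, W, b)
@assert size(y) == (C, T, B)
# Each token is transformed independently of its neighbours.
@assert y[:, 2, 1] ≈ max.(W * x[:, 2, 1] .+ b, 0f0)
```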
TransformerBlocks.Block (Type)
Block(input_dim; num_heads=1, head_size=(input_dim÷num_heads), dropout=0)

Initializes an instance of the Block type, representing a transformer block.

A Block instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.

The following keyword arguments are supported:

  • mask (Defaults to nothing. Must be of dimensions (T, T).)

Examples:

C,T,B = 8,3,4
block = Block(C)
@assert size(block(rand(Float32, C,T,B))) == (C,T,B)
TransformerBlocks.BlockList (Type)
BlockList(input_dim, num_heads; head_size=(input_dim ÷ num_heads), dropout=0)

Initializes an instance of the BlockList type, representing a sequence of transformer blocks composed together.

The constructor supports the following keyword arguments:

  • head_size (Defaults to input_dim ÷ num_heads, i.e. integer division)
  • dropout (Defaults to 0)

A BlockList instance accepts an input array x of dimensions (C, T, B) and outputs an array of dimensions (C, T, B). "C" is the channel size (embedding dimension). "T" is the block size (number of input tokens). "B" is the batch size.

When called, a BlockList instance supports the following keyword arguments:

  • mask (Defaults to nothing. Must be of dimensions (T, T).)

Examples:

C,T,B = 8,3,4
blocklist = BlockList([Block(C), Block(C)])
@assert size(blocklist(rand(Float32, C,T,B))) == (C,T,B)
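
Because every block maps (C, T, B) to (C, T, B), a sequence of blocks can be applied one after another. The sketch below shows that composition with a left fold in plain Julia. `toy_block` and `apply_blocks` are hypothetical stand-ins, and forwarding the same mask to every block is an assumption about how a BlockList might behave.

```julia
# Hypothetical stand-in for a shape-preserving block:
# any callable that takes x and a mask keyword.
toy_block(scale) = (x; mask=nothing) -> scale .* x

# Apply each block in order, passing the same mask to all of them.
function apply_blocks(blocks, x; mask=nothing)
    return foldl((h, f) -> f(h; mask=mask), blocks; init=x)
end

C, T, B = 8, 3, 4
x = rand(Float32, C, T, B)
blocks = [toy_block(2f0), toy_block(3f0)]
y = apply_blocks(blocks, x)
@assert size(y) == (C, T, B)
@assert y ≈ 6f0 .* x
```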