Simplifying Transformer Blocks: Implementation Details

19 Jun 2024

Authors:

(1) Bobby He, Department of Computer Science, ETH Zurich (Correspondence to: [email protected].);

(2) Thomas Hofmann, Department of Computer Science, ETH Zurich.

D IMPLEMENTATION DETAILS

This section provides the remaining implementation details that were not discussed in the main paper. We split them into two subsections: one for the next-token prediction task on CodeParrot, and one for our Crammed BERT (Geiping & Goldstein, 2023) masked language modelling experiments, which are pretrained on the Pile dataset (Gao et al., 2020) and fine-tuned on the downstream GLUE benchmark (Wang et al., 2019). To avoid repetition, any details that are mentioned in one subsection but not the other are shared between both subsections. All runtime results on CodeParrot were obtained on a single A5000 GPU.

D.1 CODEPARROT NEXT-TOKEN PREDICTION

As mentioned, much of our setup is derived from https://huggingface.co/learn/nlp-course/chapter7/6.

Model The model is an 18-layer GPT-style auto-regressive decoder-only transformer. We use width d = 768 and H = 12 heads in multi-head attention. We remove dropout entirely because our focus is on training speed and we are always in a single-epoch regime, where regularisation hurts training speed. The MLP uses ReLU activation unless stated otherwise, with hidden dimension 3072 = 4d. The only exception is Fig. 6, where we reduce the MLP hidden dimension to 1536 = 2d to account for the increased memory requirements of larger depths.
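
As a minimal sketch of the shapes listed above, the stated hyperparameters can be collected in a small configuration object. The class name DecoderConfig and its field names are hypothetical and are not taken from the authors' code; the even split of the width across heads is also an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DecoderConfig:
    """Hypothetical container for the CodeParrot model shape described above."""

    n_layers: int = 18
    width: int = 768          # d
    n_heads: int = 12         # H
    mlp_ratio: int = 4        # MLP hidden dimension = mlp_ratio * d
    dropout: float = 0.0      # dropout is removed entirely
    activation: str = "relu"  # ReLU unless stated otherwise

    @property
    def head_dim(self) -> int:
        # Assumption: the width is split evenly across attention heads.
        assert self.width % self.n_heads == 0
        return self.width // self.n_heads

    @property
    def mlp_hidden(self) -> int:
        return self.mlp_ratio * self.width


base = DecoderConfig()
narrow_mlp = DecoderConfig(mlp_ratio=2)  # the Fig. 6 exception: 1536 = 2d

print(base.head_dim, base.mlp_hidden, narrow_mlp.mlp_hidden)  # 64 3072 1536
```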

For all of our simplified models we initialise βFF = 0.1 in Eq. (1) to account for the lack of skip connection, apart from the 18-layer models in Fig. 6, where βFF = 0.2 due to the narrower MLP width.

We use RMSNorm (Zhang & Sennrich, 2019) where applicable with epsilon 1e−8, and add a final normalisation after the decoder. Sinusoidal positional encodings are used and added at the embedding level.
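
For illustration, the two components named above can be sketched in plain Python on a single vector. This is a sketch under stated assumptions, not the implementation used in the experiments: the placement of epsilon inside the square root and the standard sine/cosine form of the positional encoding (with base 10000) are assumptions, and the function names are hypothetical.

```python
import math


def rms_norm(x, gain=None, eps=1e-8):
    """Divide a vector by its root-mean-square, then apply a per-element gain.

    Assumption: eps is added to the mean square inside the square root.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    if gain is None:
        gain = [1.0] * len(x)
    return [g * v * scale for g, v in zip(gain, x)]


def sinusoidal_encoding(position, width):
    """Sine/cosine encoding for one position (width assumed even).

    Assumption: the standard interleaved form with base 10000.
    """
    enc = []
    for i in range(0, width, 2):
        angle = position / (10000 ** (i / width))
        enc.extend([math.sin(angle), math.cos(angle)])
    return enc


normalised = rms_norm([3.0, 4.0])
position_vector = sinusoidal_encoding(5, 768)  # added to the token embedding
```

After normalisation the mean square of the vector is approximately 1, regardless of the input scale.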

Training We use the AdamW optimiser (Loshchilov & Hutter, 2017) with weight decay 0.1, which we tuned on a small grid and found to work well for both the baselines and our models. We do not apply weight decay to any scalar gain parameter. We clip gradients with clipping parameter 1, and use epsilon of 1e−8 and default betas of (0.9, 0.999) in AdamW. As discussed, we use a linearly decaying learning rate schedule, with 5% of all steps used for linear warmup. The learning rate was tuned in all cases; for our best (SAS and SAS-P) models the optimum was 1e−3, which exactly matched that of the default Pre-LN. This also held when we scaled to 72 layers. V-SkipInit needed a lower learning rate for the depth scaling experiments (3e−4 and 1e−4 for depths 18 and 72 respectively). We use a batch size of 128 with microbatches of size 32.
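
The schedule and the microbatching described above can be sketched as follows. The function name lr_at is hypothetical, and decaying to exactly zero at the final step is an assumption, since the text only states that the decay is linear.

```python
def lr_at(step, total_steps, peak_lr=1e-3, warmup_frac=0.05):
    """Linear warmup over the first 5% of steps, then linear decay.

    Assumption: the learning rate decays to zero at total_steps.
    """
    warmup_steps = max(1, int(warmup_frac * total_steps))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / decay_steps)


# A batch of 128 processed in microbatches of 32 means gradients are
# accumulated over 4 microbatches before each optimiser step.
batch_size = 128
microbatch_size = 32
accumulation_steps = batch_size // microbatch_size

total = 43_000
schedule = [lr_at(s, total) for s in (0, 2_150, 21_500, 43_000)]
```

With these settings the peak value 1e−3 is reached at the end of warmup (step 2,150 of 43,000).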

Dataset The CodeParrot dataset is a large corpus of 20 million Python files from GitHub. We take the dataset, pre-processing and tokeniser from https://huggingface.co/learn/nlp-course/chapter7/6. We use sequence length T = 128 throughout, and our tokeniser has a vocabulary size of 50K. Our base experiments train for around 43K steps with batch size 128 and sequence length 128, which is around 700M tokens. In Fig. 8 we scale this to 2B tokens.
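
The token budget quoted above follows directly from the step count, batch size and sequence length. The short check below makes the arithmetic explicit; the helper steps_for_tokens is hypothetical, and because the text gives the step count only approximately, the results should be read as approximate too.

```python
def tokens_seen(steps, batch_size=128, seq_len=128):
    """Number of tokens processed after a given number of optimiser steps."""
    return steps * batch_size * seq_len


def steps_for_tokens(tokens, batch_size=128, seq_len=128):
    """Smallest number of steps that covers at least `tokens` tokens."""
    per_step = batch_size * seq_len
    return -(-tokens // per_step)  # integer ceiling division


base_tokens = tokens_seen(43_000)
print(f"{base_tokens:,}")                  # 704,512,000
print(steps_for_tokens(2_000_000_000))     # 122071
```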

Task The model is trained on next-token prediction using cross-entropy loss.

D.2 BERT ENCODER-ONLY

As discussed in Sec. 5, we inherit many of our hyperparameters from the Cramming setup of Geiping & Goldstein (2023), and also base our implementation on their excellent codebase.[9] We highlight important implementation details here.

Model We use a 16-layer encoder-only model, with width d = 768 and 12 heads. We use MLP width 3072 = 4d, but now we use GLU (Dauphin et al., 2017) with GeLU activation, which essentially halves the hidden dimension. We use LayerNorm (Ba et al., 2016) for normalisation where applicable, with epsilon 1e−12 as taken from Geiping & Goldstein (2023); we always use a final LN after all the layers. Again, we remove all dropout, and use a sequence length of 128. We found our simplified skipless models preferred smaller MLP block scales, and so initialise βFF = 0.05.
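
To see why a gated linear unit halves the hidden dimension, consider the following plain-Python sketch acting on a single pre-activation vector of width 3072. The function names are hypothetical, and which half serves as the gate is an assumption of this sketch rather than a detail given in the text.

```python
import math


def gelu(x):
    """Exact GeLU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def glu_gelu(pre_activation):
    """Split a vector in half and gate one half with GeLU of the other.

    Assumption: the second half is the gate and the first half is the value.
    """
    assert len(pre_activation) % 2 == 0
    half = len(pre_activation) // 2
    value, gate = pre_activation[:half], pre_activation[half:]
    return [v * gelu(g) for v, g in zip(value, gate)]


hidden = glu_gelu([0.1] * 3072)
print(len(hidden))  # 1536
```

The first MLP projection still outputs 3072 features, but after gating only 1536 remain as input to the second projection.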

Parameter Initialisation The initialisations are identical to those used for CodeParrot, and are detailed above.

Datasets Like Geiping & Goldstein (2023), we train on the Pile dataset (Gao et al., 2020), with a WordPiece tokeniser of vocabulary size 32768, and a sequence length of 128. Our fastest runs took around 600K steps with microbatch size 64 in 24 hours, which corresponds to around 5B tokens.

Training We again train with the AdamW optimiser, with weight decay 0.1. AdamW has hyperparameters [β1, β2] = [0.9, 0.98] and epsilon 1e−12. We use a microbatch size of 64 (to fit on an RTX-2080Ti), and scale the batch size linearly to reach 8192 after 60% of total training, as in Geiping & Goldstein (2023). We use the same aggressive learning rate schedule as Geiping & Goldstein (2023), which increases linearly to its maximum value after 75% of all training steps before decaying linearly, and tune the maximum learning rate to 3e−3 for our SAS and SAS-P models. This was slightly too large for the SAS-P model without normalisation, so we reduce it to 2e−3 in that case. We inherit the clipping parameter of 0.5 from Geiping & Goldstein (2023).
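
The two schedules described above can be sketched as simple functions of training progress. This is an illustration under assumptions, not the Cramming implementation: the function names are hypothetical, the learning rate is assumed to start at and decay to zero, and the batch size is assumed to start at one microbatch and to be rounded down to a whole number of microbatches.

```python
def one_cycle_lr(step, total_steps, max_lr=3e-3, peak_frac=0.75):
    """Linear increase to max_lr at 75% of training, then linear decay.

    Assumption: the schedule starts at zero and ends at zero.
    """
    peak_step = peak_frac * total_steps
    if step <= peak_step:
        return max_lr * step / peak_step
    return max_lr * (total_steps - step) / (total_steps - peak_step)


def ramped_batch_size(step, total_steps, microbatch=64, final=8192,
                      ramp_frac=0.6):
    """Batch size that grows linearly until 60% of training, then stays fixed.

    Assumptions: the ramp starts at one microbatch, and the result is
    rounded down to a multiple of the microbatch size.
    """
    progress = min(1.0, step / (ramp_frac * total_steps))
    target = microbatch + progress * (final - microbatch)
    return max(microbatch, int(target) // microbatch * microbatch)


total = 600_000
for s in (0, 180_000, 450_000, 600_000):
    print(s, one_cycle_lr(s, total), ramped_batch_size(s, total))
```

At the final batch size, each optimiser update accumulates gradients over 8192 / 64 = 128 microbatches.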

Fine-tuning We follow the same protocol as Geiping & Goldstein (2023). In particular, we fine-tune for 5 epochs with fixed hyperparameters across tasks. We found dropout to be important for good downstream performance (unlike during pre-training), and set the dropout probability to p = 0.1. We use batch size 32, with a maximum learning rate of 1.5e−4. We keep other hyperparameters, e.g. the choice of cosine decay and AdamW epsilon 1e−6, as in Geiping & Goldstein (2023).

Task The model is trained on the masked language modelling task with masking probability 0.25, as in Geiping & Goldstein (2023).
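
A minimal sketch of how masked language modelling inputs and targets might be constructed with masking probability 0.25 is given below. The function name, the IGNORE_INDEX value and the choice to replace every selected token with the mask token are assumptions for illustration; the text specifies only the masking probability.

```python
import random

IGNORE_INDEX = -100  # hypothetical label for positions excluded from the loss


def mask_tokens(token_ids, mask_id, mask_prob=0.25, rng=None):
    """Select each token independently with probability mask_prob.

    Assumption: every selected token is replaced by mask_id, and the loss
    is computed only on the selected positions.
    """
    if rng is None:
        rng = random.Random()
    inputs, labels = [], []
    for tok in token_ids:
        if rng.random() < mask_prob:
            inputs.append(mask_id)
            labels.append(tok)           # the model must predict the original
        else:
            inputs.append(tok)
            labels.append(IGNORE_INDEX)  # no loss on unmasked positions
    return inputs, labels


tokens = list(range(10_000))
inputs, labels = mask_tokens(tokens, mask_id=32_767, rng=random.Random(0))
masked_fraction = sum(l != IGNORE_INDEX for l in labels) / len(labels)
```

Over a long sequence the masked fraction is close to 0.25, and the cross-entropy loss would be evaluated only where the label is not IGNORE_INDEX.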

This paper is available on arXiv under a CC 4.0 license.


[9] https://github.com/JonasGeiping/cramming