Torch nn.modulelist

What is nn.ModuleList?

nn.ModuleList takes in multiple layers, just as nn.Sequential does.

import torch

# Head, block_size, embed_size, head_size and n_heads are defined elsewhere
heads = torch.nn.ModuleList([
    Head(block_size, embed_size, head_size)
    for _ in range(n_heads)
])

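However, unlike nn.Sequential, it does not connect the layers. nn.ModuleList defines no forward() method, so it cannot pass an input through the layers it holds; it only stores them and registers their parameters with the parent module.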
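A minimal sketch of the difference, assuming a toy stack of nn.Linear and nn.ReLU layers (the layer sizes are arbitrary and chosen only for illustration). Calling the nn.Sequential works, while calling the nn.ModuleList directly raises NotImplementedError, even though its parameters are still registered.

```python
import torch
from torch import nn

x = torch.randn(2, 4)  # toy batch: 2 samples, 4 features (illustrative sizes)

# nn.Sequential chains its layers, so it can be called on an input.
seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
out_seq = seq(x)

# nn.ModuleList only stores the layers; it has no forward() of its own.
layers = nn.ModuleList([nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3)])
try:
    layers(x)
    called_directly = True
except NotImplementedError:
    called_directly = False

# The parameters of the stored layers are still registered.
n_params = sum(p.numel() for p in layers.parameters())

# To use the stored layers, we apply them ourselves.
out_manual = x
for layer in layers:
    out_manual = layer(out_manual)
```

Here the manual loop reproduces what nn.Sequential does automatically; with nn.ModuleList, how the layers are combined is left to us.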

So when do we use this? nn.ModuleList is useful when we want to store layers side by side, in parallel, and decide ourselves how each one is applied instead of chaining them one after another.

For example, Transformers have multiple attention heads. Each head receives the same input x, and the outputs of all heads are concatenated afterwards.

We can implement this with the nn.ModuleList defined above:

torch.cat([head(x) for head in heads], dim=2)
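To show how this fits together end to end, here is a rough sketch. The Head class below is a hypothetical stand-in for the post's Head (a single causal self-attention head written for this example), and MultiHead and the hyperparameter values are likewise assumptions, not the author's code.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Head(nn.Module):
    """Hypothetical single causal self-attention head (stand-in for the post's Head)."""

    def __init__(self, block_size, embed_size, head_size):
        super().__init__()
        self.key = nn.Linear(embed_size, head_size, bias=False)
        self.query = nn.Linear(embed_size, head_size, bias=False)
        self.value = nn.Linear(embed_size, head_size, bias=False)
        # Lower-triangular mask so each position only attends to earlier ones.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q, v = self.key(x), self.query(x), self.value(x)
        scores = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5  # (B, T, T)
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        weights = F.softmax(scores, dim=-1)
        return weights @ v  # (B, T, head_size)


class MultiHead(nn.Module):
    """Hypothetical wrapper: runs the heads in parallel and concatenates them."""

    def __init__(self, n_heads, block_size, embed_size, head_size):
        super().__init__()
        self.heads = nn.ModuleList([
            Head(block_size, embed_size, head_size)
            for _ in range(n_heads)
        ])

    def forward(self, x):
        # Every head sees the same x; outputs are joined on the feature axis.
        return torch.cat([head(x) for head in self.heads], dim=2)


# Illustrative hyperparameters (assumed values).
block_size, embed_size, head_size, n_heads = 8, 32, 8, 4
mha = MultiHead(n_heads, block_size, embed_size, head_size)
x = torch.randn(2, block_size, embed_size)
out = mha(x)  # shape (2, block_size, n_heads * head_size)
```

Because the heads live in an nn.ModuleList rather than a plain Python list, their weights show up in mha.parameters() and are picked up by an optimizer.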

nn.ModuleList is also iterable, and it supports indexing and len() like a regular Python list.
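A short sketch of the list-like behaviour, using a few arbitrary nn.Linear layers as placeholders:

```python
from torch import nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

# Iterate over the stored layers like a normal list.
names = [type(layer).__name__ for layer in layers]

# Indexing and len() work as well.
first = layers[0]

# New layers can be appended after construction.
layers.append(nn.ReLU())
count = len(layers)
last_name = type(layers[-1]).__name__
```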

This post is licensed under CC BY 4.0 by the author.
