PyTorch dynamic forward pass

Data Science Asked by Andreas Look on August 29, 2021

Is there a fast and convenient way to handle a problem like this one?

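from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20) for _ in range(10)])

    def forward(self, x, indices):
        # Desired: pick layers according to indices. As written, this only works
        # for a single int index; nn.ModuleList raises a TypeError for a list.
        x = self.linears[indices](x)
        return x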

I want to access different layers of the network depending on an additional input, indices, which is itself a list. I also want to process the whole batch at once, and output.shape != input.shape (each layer maps 10 input features to 20 output features).
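
To pin down the intended behavior, here is a minimal per-sample reference in plain PyTorch. It assumes that `indices` holds one layer index per row of `x`; the class name `PerSampleLinear` is hypothetical. It loops in Python, so it is a correctness baseline rather than a fast batched solution.

```python
import torch
from torch import nn


class PerSampleLinear(nn.Module):
    # Hypothetical name; holds the same ten layers as MyModule above.
    def __init__(self, num_layers=10, in_features=10, out_features=20):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_layers)]
        )

    def forward(self, x, indices):
        # Assumption: len(indices) == x.shape[0], one layer index per sample.
        # Row i of x goes through self.linears[indices[i]]; nn.Linear accepts
        # a 1-D input of shape (in_features,) and returns (out_features,).
        rows = [self.linears[int(k)](row) for row, k in zip(x, indices)]
        # Stack the per-row outputs back into shape (batch, out_features).
        return torch.stack(rows)


net = PerSampleLinear()
x = torch.randn(4, 10)
out = net(x, [0, 3, 3, 9])
print(out.shape)  # torch.Size([4, 20])
```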

One Answer

Here is my understanding of your problem:

# Import
from torch import nn

# Define custom class
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20) for _ in range(10)])

    def forward(self, x, indices):
        # Works for a single int index; nn.ModuleList does not accept a list here.
        x = self.linears[indices](x)
        return x

# Initialize custom class
net = MyModule()

# Access the network's layers based on an additional input
additional_input = 1
if additional_input == 1:
    idx = 0
    print(net.linears[idx].in_features)

Answered by Brian Spiering on August 29, 2021
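
Neither the question nor the answer shows a fully batched version. One possible approach, sketched here as an editorial addition rather than part of the answer above, keeps the nn.ModuleList from the question, stacks the layers' weights and biases, gathers one weight matrix per sample, and applies them all in a single `torch.einsum` call. The class name `BatchedPerSampleLinear` is hypothetical, and the sketch assumes every layer has the same shape and a bias.

```python
import torch
from torch import nn


class BatchedPerSampleLinear(nn.Module):
    # Hypothetical name; same ModuleList layout as MyModule in the question.
    def __init__(self, num_layers=10, in_features=10, out_features=20):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_layers)]
        )

    def forward(self, x, indices):
        # Assumption: one layer index per sample, converted to a LongTensor.
        idx = torch.as_tensor(indices, dtype=torch.long, device=x.device)
        # Stack all layers: weight (num_layers, out, in), bias (num_layers, out).
        weight = torch.stack([layer.weight for layer in self.linears])
        bias = torch.stack([layer.bias for layer in self.linears])
        # Gather per-sample parameters: (batch, out, in) and (batch, out).
        w = weight[idx]
        b = bias[idx]
        # y[n, o] = sum_i w[n, o, i] * x[n, i] + b[n, o], for the whole batch at once.
        return torch.einsum("noi,ni->no", w, x) + b


net = BatchedPerSampleLinear()
x = torch.randn(4, 10)
out = net(x, [0, 3, 3, 9])
print(out.shape)  # torch.Size([4, 20])
```

Stacking the weights on every call copies them, which is cheap at this size; with many or large layers, storing a single (num_layers, out, in) parameter directly would avoid the copy. Layers not selected in a batch still take part in the stack, so after backward() they receive zero gradients rather than None.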
