Data Science Asked on December 19, 2021
Is there a fast and convenient way to handle the following problem?
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20) for _ in range(10)])

    def forward(self, x, indices):
        # Indexing an nn.ModuleList with a list of indices raises a TypeError
        x = self.linears[indices](x)
        return x
I want to pick a different layer of the network for each sample, conditioned on an additional input (indices), which is also a list. I also want to process the whole batch at once, and output.shape != input.shape (each layer maps 10 input features to 20 output features).
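One way to select a layer per sample while still handling the whole batch is sketched below. It assumes x has shape (batch, 10) and indices holds one integer layer id per sample. Instead of one call per sample, forward loops only over the distinct ids present in the batch and applies each layer to the rows that select it, writing into an output buffer with 20 columns. This is a hypothetical rewrite of forward, not code from the question or the answer.

```python
import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, num_layers: int = 10, in_features: int = 10, out_features: int = 20):
        super().__init__()
        self.out_features = out_features
        self.linears = nn.ModuleList(
            [nn.Linear(in_features, out_features) for _ in range(num_layers)]
        )

    def forward(self, x: torch.Tensor, indices) -> torch.Tensor:
        # Accept a Python list or a tensor of per-sample layer ids, shape (batch,)
        indices = torch.as_tensor(indices, device=x.device)
        # Output buffer has out_features columns, so its shape differs from x
        out = x.new_zeros((x.shape[0], self.out_features))
        # One iteration per distinct layer id in the batch, not per sample
        for idx in indices.unique().tolist():
            mask = indices == idx
            out[mask] = self.linears[idx](x[mask])
        return out


net = MyModule()
x = torch.randn(4, 10)
y = net(x, [0, 3, 3, 9])
print(y.shape)  # torch.Size([4, 20])
```

The loop runs at most once per layer (10 iterations here), so the overhead stays small when the number of layers is small, and gradients flow only into the layers that were actually used.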
Here is my understanding of your problem:
# Import
from torch import nn

# Define custom class
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 20) for _ in range(10)])

    def forward(self, x, indices):
        x = self.linears[indices](x)
        return x

# Initialize custom class
net = MyModule()

# Access network layers based on additional input
additional_input = 1
if additional_input == 1:
    idx = 0
    print(net.linears[idx].in_features)
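The answer accesses a single layer with a scalar index. If every sample in the batch needs its own layer, an alternative sketch (an assumption, not the answerer's method) is to store the ten weight matrices as one parameter tensor of shape (10, 20, 10), gather the entries named by indices, and compute all per-sample products with a single torch.einsum call. The class name BatchedLinears is hypothetical.

```python
import math

import torch
from torch import nn


class BatchedLinears(nn.Module):
    """Hypothetical module: num_layers independent linear layers stored as one tensor."""

    def __init__(self, num_layers: int = 10, in_features: int = 10, out_features: int = 20):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_layers, out_features, in_features))
        self.bias = nn.Parameter(torch.empty(num_layers, out_features))
        # U(-1/sqrt(in_features), 1/sqrt(in_features)), the range nn.Linear's default init uses
        bound = 1 / math.sqrt(in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor, indices) -> torch.Tensor:
        indices = torch.as_tensor(indices, device=x.device)
        w = self.weight[indices]  # (batch, out_features, in_features)
        b = self.bias[indices]    # (batch, out_features)
        # Per-sample matrix-vector product: y[n] = w[n] @ x[n] + b[n]
        return torch.einsum("noi,ni->no", w, x) + b


layer = BatchedLinears()
y = layer(torch.randn(5, 10), [2, 2, 7, 0, 9])
print(y.shape)  # torch.Size([5, 20])
```

Gathering the weights materializes a (batch, 20, 10) tensor, which is cheap for layers this small but can become memory-heavy for large layers; in that case looping over the distinct indices and applying each layer to its subset of rows avoids the copy.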
Answered by Brian Spiering on December 19, 2021