Examples

[1]:
# Reload edited Python modules (e.g. a local torchextractor checkout) before each cell runs.
# %reload_ext is idempotent: unlike %load_ext, re-running this cell does not complain
# that the extension is already loaded.
%reload_ext autoreload
%autoreload 2
[2]:
# %pip installs into the environment of the running kernel; !pip would use whichever
# pip happens to be first on PATH, which may belong to a different interpreter.
%pip install torch torchvision

# Pick exactly one way to get torchextractor (the GitHub version is enabled by default).
# Stable release from PyPI:
# %pip install torchextractor
# Latest version from GitHub:
%pip install git+https://github.com/antoinebrl/torchextractor.git
# Local source tree two directories up (no install):
# import sys, os; sys.path.insert(0, os.path.abspath("../.."))

Requirement already satisfied: torch in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (1.8.0)
Requirement already satisfied: typing-extensions in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (3.7.4.3)
Requirement already satisfied: dataclasses in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (0.8)
Requirement already satisfied: numpy in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch) (1.19.5)
Requirement already satisfied: torchvision in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (0.9.0)
Requirement already satisfied: numpy in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torchvision) (1.19.5)
Requirement already satisfied: torch==1.8.0 in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torchvision) (1.8.0)
Requirement already satisfied: pillow>=4.1.1 in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torchvision) (8.1.2)
Requirement already satisfied: typing-extensions in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch==1.8.0->torchvision) (3.7.4.3)
Requirement already satisfied: dataclasses in /home/antoine/Projects/torchextractor/.env/lib/python3.6/site-packages (from torch==1.8.0->torchvision) (0.8)
[3]:
import torch
import torchvision
import torchextractor as tx
[4]:
# Shared by the examples below: a ResNet-18 (no pretrained weights requested) and a
# random batch of 7 three-channel 224x224 inputs.
model = torchvision.models.resnet18()
dummy_input = torch.rand((7, 3, 224, 224))

List module names

[5]:
# named_modules() yields the root module first under the empty name "", which is why
# the listing starts with a blank line.
module_names = [qualified for qualified, _ in model.named_modules()]
print(*module_names, sep="\n")

conv1
bn1
relu
maxpool
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.relu
layer1.0.conv2
layer1.0.bn2
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.relu
layer1.1.conv2
layer1.1.bn2
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.relu
layer2.0.conv2
layer2.0.bn2
layer2.0.downsample
layer2.0.downsample.0
layer2.0.downsample.1
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.relu
layer2.1.conv2
layer2.1.bn2
layer3
layer3.0
layer3.0.conv1
layer3.0.bn1
layer3.0.relu
layer3.0.conv2
layer3.0.bn2
layer3.0.downsample
layer3.0.downsample.0
layer3.0.downsample.1
layer3.1
layer3.1.conv1
layer3.1.bn1
layer3.1.relu
layer3.1.conv2
layer3.1.bn2
layer4
layer4.0
layer4.0.conv1
layer4.0.bn1
layer4.0.relu
layer4.0.conv2
layer4.0.bn2
layer4.0.downsample
layer4.0.downsample.0
layer4.0.downsample.1
layer4.1
layer4.1.conv1
layer4.1.bn1
layer4.1.relu
layer4.1.conv2
layer4.1.bn2
avgpool
fc

Extract features

[6]:
# Wrap a freshly built ResNet-18 so re-running this cell never wraps an Extractor
# inside another Extractor.
model = tx.Extractor(
    torchvision.models.resnet18(),
    ["layer1", "layer2", "layer3", "layer4"],
)

model_output, features = model(dummy_input)
{layer: fmap.shape for layer, fmap in features.items()}
[6]:
{'layer1': torch.Size([7, 64, 56, 56]),
 'layer2': torch.Size([7, 128, 28, 28]),
 'layer3': torch.Size([7, 256, 14, 14]),
 'layer4': torch.Size([7, 512, 7, 7])}

Extract features from nested modules

[7]:
# Dotted names select nested submodules (e.g. `layer2.1.conv1`), not just top-level layers.
nested_layers = ["layer1", "layer2.1.conv1", "layer3.0.downsample.0", "layer4.0"]
model = tx.Extractor(torchvision.models.resnet18(), nested_layers)

model_output, features = model(dummy_input)
dict((layer, fmap.shape) for layer, fmap in features.items())
[7]:
{'layer1': torch.Size([7, 64, 56, 56]),
 'layer2.1.conv1': torch.Size([7, 128, 28, 28]),
 'layer3.0.downsample.0': torch.Size([7, 256, 14, 14]),
 'layer4.0': torch.Size([7, 512, 7, 7])}

Filter modules

[8]:
# Start from a fresh ResNet-18 so re-running this cell never wraps an Extractor in another Extractor.
model = torchvision.models.resnet18()
def module_filter_fn(module: torch.nn.Module, name: str) -> bool:
    """Select every ``torch.nn.Conv2d`` for extraction, regardless of its name.

    Passed as ``tx.Extractor(..., module_filter_fn=...)``: it receives a submodule and
    its dotted name and returns True for the submodules whose outputs should be captured.

    Args:
        module: Candidate submodule.
        name: Dotted name of ``module`` (unused; kept for the expected signature).

    Returns:
        True if ``module`` is a 2-D convolution, False otherwise.
    """
    return isinstance(module, torch.nn.Conv2d)
# No explicit layer list: the filter alone decides which submodules are captured.
model = tx.Extractor(model, module_filter_fn=module_filter_fn)
model_output, features = model(dummy_input)

shapes_by_name = {layer_name: fmap.shape for layer_name, fmap in features.items()}
shapes_by_name
[8]:
{'conv1': torch.Size([7, 64, 112, 112]),
 'layer1.0.conv1': torch.Size([7, 64, 56, 56]),
 'layer1.0.conv2': torch.Size([7, 64, 56, 56]),
 'layer1.1.conv1': torch.Size([7, 64, 56, 56]),
 'layer1.1.conv2': torch.Size([7, 64, 56, 56]),
 'layer2.0.conv1': torch.Size([7, 128, 28, 28]),
 'layer2.0.conv2': torch.Size([7, 128, 28, 28]),
 'layer2.0.downsample.0': torch.Size([7, 128, 28, 28]),
 'layer2.1.conv1': torch.Size([7, 128, 28, 28]),
 'layer2.1.conv2': torch.Size([7, 128, 28, 28]),
 'layer3.0.conv1': torch.Size([7, 256, 14, 14]),
 'layer3.0.conv2': torch.Size([7, 256, 14, 14]),
 'layer3.0.downsample.0': torch.Size([7, 256, 14, 14]),
 'layer3.1.conv1': torch.Size([7, 256, 14, 14]),
 'layer3.1.conv2': torch.Size([7, 256, 14, 14]),
 'layer4.0.conv1': torch.Size([7, 512, 7, 7]),
 'layer4.0.conv2': torch.Size([7, 512, 7, 7]),
 'layer4.0.downsample.0': torch.Size([7, 512, 7, 7]),
 'layer4.1.conv1': torch.Size([7, 512, 7, 7]),
 'layer4.1.conv2': torch.Size([7, 512, 7, 7])}

ONNX export with named output nodes

[9]:
# Output names are given in order: the wrapped model's own output first, then each
# extracted layer (assumed to follow the Extractor's (model_output, features) return).
model = tx.Extractor(torchvision.models.resnet18(), ["layer3", "layer4"])

onnx_output_names = ["classifier", "layer3", "layer4"]
torch.onnx.export(model, dummy_input, "resnet.onnx", output_names=onnx_output_names)

Custom capture function

[10]:
# Fresh ResNet-18 for this example; the Extractor below is attached to it, and the loop
# then calls `model` itself rather than the extractor.
model = torchvision.models.resnet18()

# Concatenate outputs of every runs
def capture_fn(module, input, output, module_name, feature_maps):
    if module_name not in feature_maps:
        feature_maps[module_name] = []
    feature_maps[module_name].append(output)


extractor = tx.Extractor(model, ["layer3", "layer4"], capture_fn=capture_fn)

# Calling `model` directly is enough: the extractor captures every forward run.
# no_grad: capture_fn keeps the outputs of all 10 runs; with autograd enabled each kept
# output would also pin its whole computation graph (and saved activations) in memory.
with torch.no_grad():
    for _ in range(10):
        x = torch.rand(7, 3, 224, 224)
        model(x)

feature_maps = extractor.collect()
for name, captured in feature_maps.items():
    print(f"{name}: {len(captured)} items")
    for i, f in enumerate(captured, start=1):
        print(f"    {i} - {f.shape}")
layer3: 10 items
    1 - torch.Size([7, 256, 14, 14])
    2 - torch.Size([7, 256, 14, 14])
    3 - torch.Size([7, 256, 14, 14])
    4 - torch.Size([7, 256, 14, 14])
    5 - torch.Size([7, 256, 14, 14])
    6 - torch.Size([7, 256, 14, 14])
    7 - torch.Size([7, 256, 14, 14])
    8 - torch.Size([7, 256, 14, 14])
    9 - torch.Size([7, 256, 14, 14])
    10 - torch.Size([7, 256, 14, 14])
layer4: 10 items
    1 - torch.Size([7, 512, 7, 7])
    2 - torch.Size([7, 512, 7, 7])
    3 - torch.Size([7, 512, 7, 7])
    4 - torch.Size([7, 512, 7, 7])
    5 - torch.Size([7, 512, 7, 7])
    6 - torch.Size([7, 512, 7, 7])
    7 - torch.Size([7, 512, 7, 7])
    8 - torch.Size([7, 512, 7, 7])
    9 - torch.Size([7, 512, 7, 7])
    10 - torch.Size([7, 512, 7, 7])