API
Extractor
class torchextractor.Extractor(model: torch.nn.modules.module.Module, module_names: Optional[Iterable[str]] = None, module_filter_fn: Optional[Callable] = None, capture_fn: Optional[Callable] = None)
Bases: torch.nn.modules.module.Module
Capture the intermediate feature maps of a model.
- Parameters
- model: nn.Module
The model to extract features from.
- module_names: list of str, default None
The fully qualified names of the modules producing the relevant feature maps.
- module_filter_fn: callable, default None
A filtering function. Takes a module and module name as input and returns True for modules producing the relevant features. Either module_names or module_filter_fn should be provided but not both at the same time.
Example:
    def module_filter_fn(module, name):
        return isinstance(module, torch.nn.Conv2d)
- capture_fn: callable, default None
Operation to carry out at each forward pass. The function should comply with the following interface.
Example:
    def capture_fn(
        module: nn.Module,
        input: Any,
        output: Any,
        module_name: str,
        feature_maps: Dict[str, Any],
    ):
        feature_maps[module_name] = output
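The capture_fn interface above can also be used to change how feature maps are stored. The following is a minimal sketch, not part of the library: `accumulate_capture_fn` is a hypothetical name, and the direct calls with placeholder values only stand in for what the Extractor is assumed to do for each selected module on every forward pass.

```python
from typing import Any, Dict, List


def accumulate_capture_fn(
    module: Any,
    input: Any,
    output: Any,
    module_name: str,
    feature_maps: Dict[str, List[Any]],
) -> None:
    """Append each output to a per-module list instead of overwriting it."""
    feature_maps.setdefault(module_name, []).append(output)


# Simulate two forward passes by calling the function directly.
# Strings stand in for real modules, inputs, and output tensors.
feature_maps: Dict[str, List[Any]] = {}
accumulate_capture_fn(None, ("x0",), "out0", "layer1", feature_maps)
accumulate_capture_fn(None, ("x1",), "out1", "layer1", feature_maps)
print(feature_maps)  # {'layer1': ['out0', 'out1']}
```

Such a function would be passed as `capture_fn=accumulate_capture_fn` when constructing the Extractor. Because it keeps every output, memory use grows with each forward pass, so the stored structure should be reset periodically (the documented `clear_placeholder()` method is the likely place to do so).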
Methods
T_destination
add_module(name, module): Adds a child module to the current module.
apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16(): Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]): Returns an iterator over module buffers.
children(): Returns an iterator over immediate children modules.
clear_placeholder(): Resets the structure holding captured feature maps.
collect(): Returns the structure holding the most recent feature maps.
cpu(): Moves all model parameters and buffers to the CPU.
cuda([device]): Moves all model parameters and buffers to the GPU.
double(): Casts all floating point parameters and buffers to double datatype.
eval(): Sets the module in evaluation mode.
extra_repr(): Set the extra representation of the module.
float(): Casts all floating point parameters and buffers to float datatype.
forward(*args, **kwargs): Performs model computations and collects feature maps.
half(): Casts all floating point parameters and buffers to half datatype.
load_state_dict(state_dict[, strict]): Copies parameters and buffers from state_dict into this module and its descendants.
modules(): Returns an iterator over all modules in the network.
named_buffers([prefix, recurse]): Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children(): Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix]): Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse]): Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]): Returns an iterator over module parameters.
register_backward_hook(hook): Registers a backward hook on the module.
register_buffer(name, tensor[, persistent]): Adds a buffer to the module.
register_forward_hook(hook): Registers a forward hook on the module.
register_forward_pre_hook(hook): Registers a forward pre-hook on the module.
register_full_backward_hook(hook): Registers a backward hook on the module.
register_parameter(name, param): Adds a parameter to the module.
requires_grad_([requires_grad]): Change if autograd should record operations on parameters in this module.
state_dict([destination, prefix, keep_vars]): Returns a dictionary containing a whole state of the module.
to(*args, **kwargs): Moves and/or casts the parameters and buffers.
train([mode]): Sets the module in training mode.
type(dst_type): Casts all parameters and buffers to dst_type.
xpu([device]): Moves all model parameters and buffers to the XPU.
zero_grad([set_to_none]): Sets gradients of all model parameters to zero.
__call__
share_memory
Utils
torchextractor.list_module_names(model: torch.nn.modules.module.Module) → List[str]
List names of modules and submodules.
- Parameters
- model: nn.Module
PyTorch model to examine.
- Returns
- list[str]:
List of names

torchextractor.find_modules_by_names(model: torch.nn.modules.module.Module, names: Iterable[str]) → Dict[str, torch.nn.modules.module.Module]
Find modules given their fully qualified names.
- Parameters
- model: nn.Module
PyTorch model to examine.
- names: list of str
List of fully qualified names.
- Returns
- dict: name -> module
If no match is found for a name, it is not added to the returned structure.
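The lookup behavior described above (unmatched names are silently omitted) can be illustrated with a small sketch. This is not the library's implementation: `find_by_names_sketch` is a hypothetical helper, and the plain dictionary is an assumed stand-in for `dict(model.named_modules())`, with strings in place of real modules.

```python
from typing import Any, Dict, Iterable


def find_by_names_sketch(
    named_modules: Dict[str, Any], names: Iterable[str]
) -> Dict[str, Any]:
    """Return name -> module for requested names that exist; skip the rest."""
    return {name: named_modules[name] for name in names if name in named_modules}


# Stand-in for dict(model.named_modules()); strings replace nn.Module objects.
named_modules = {
    "": "root",
    "layer1": "Sequential",
    "layer1.0": "BasicBlock",
    "fc": "Linear",
}

found = find_by_names_sketch(named_modules, ["layer1.0", "fc", "does_not_exist"])
print(found)  # {'layer1.0': 'BasicBlock', 'fc': 'Linear'}
```

Since a missing name raises no error, comparing the keys of the returned dictionary with the requested names is a simple way to catch typos in module names.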