API
Extractor
class torchextractor.Extractor(model: torch.nn.modules.module.Module, module_names: Optional[Iterable[str]] = None, module_filter_fn: Optional[Callable] = None, capture_fn: Optional[Callable] = None)
Bases: torch.nn.modules.module.Module
Capture the intermediate feature maps of a model.
- Parameters
- model: nn.Module
The model to extract features from.
- module_names: list of str, default None
The fully qualified names of the modules producing the relevant feature maps.
- module_filter_fn: callable, default None
A filtering function. Takes a module and module name as input and returns True for modules producing the relevant features. Either module_names or module_filter_fn should be provided but not both at the same time.
Example:

    import torch

    def module_filter_fn(module, name):
        return isinstance(module, torch.nn.Conv2d)
- capture_fn: callable, default None
Operation to carry out at each forward pass. The function should comply with the following interface.
Example:

    from typing import Any, Dict

    import torch.nn as nn

    def capture_fn(
        module: nn.Module,
        input: Any,
        output: Any,
        module_name: str,
        feature_maps: Dict[str, Any],
    ):
        # Store the output under the module's name.
        feature_maps[module_name] = output
Methods
- T_destination
- add_module(name, module): Adds a child module to the current module.
- apply(fn): Applies fn recursively to every submodule (as returned by .children()) as well as self.
- bfloat16(): Casts all floating point parameters and buffers to bfloat16 datatype.
- buffers([recurse]): Returns an iterator over module buffers.
- children(): Returns an iterator over immediate children modules.
- clear_placeholder(): Resets the structure holding captured feature maps.
- collect(): Returns the structure holding the most recent feature maps.
- cpu(): Moves all model parameters and buffers to the CPU.
- cuda([device]): Moves all model parameters and buffers to the GPU.
- double(): Casts all floating point parameters and buffers to double datatype.
- eval(): Sets the module in evaluation mode.
- extra_repr(): Set the extra representation of the module.
- float(): Casts all floating point parameters and buffers to float datatype.
- forward(*args, **kwargs): Performs model computations and collects feature maps.
- half(): Casts all floating point parameters and buffers to half datatype.
- load_state_dict(state_dict[, strict]): Copies parameters and buffers from state_dict into this module and its descendants.
- modules(): Returns an iterator over all modules in the network.
- named_buffers([prefix, recurse]): Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- named_children(): Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- named_modules([memo, prefix]): Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- named_parameters([prefix, recurse]): Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- parameters([recurse]): Returns an iterator over module parameters.
- register_backward_hook(hook): Registers a backward hook on the module.
- register_buffer(name, tensor[, persistent]): Adds a buffer to the module.
- register_forward_hook(hook): Registers a forward hook on the module.
- register_forward_pre_hook(hook): Registers a forward pre-hook on the module.
- register_full_backward_hook(hook): Registers a backward hook on the module.
- register_parameter(name, param): Adds a parameter to the module.
- requires_grad_([requires_grad]): Change if autograd should record operations on parameters in this module.
- state_dict([destination, prefix, keep_vars]): Returns a dictionary containing a whole state of the module.
- to(*args, **kwargs): Moves and/or casts the parameters and buffers.
- train([mode]): Sets the module in training mode.
- type(dst_type): Casts all parameters and buffers to dst_type.
- xpu([device]): Moves all model parameters and buffers to the XPU.
- zero_grad([set_to_none]): Sets gradients of all model parameters to zero.
- __call__
- share_memory
Utils
torchextractor.list_module_names(model: torch.nn.modules.module.Module) → List[str]
List names of modules and submodules.
- Parameters
- model: nn.Module
PyTorch model to examine.
- Returns
- list[str]:
List of names
torchextractor.find_modules_by_names(model: torch.nn.modules.module.Module, names: Iterable[str]) → Dict[str, torch.nn.modules.module.Module]
Find modules given their fully qualified names.
- Parameters
- model: nn.Module
PyTorch model to examine.
- names: list of str
List of fully qualified names.
- Returns
- dict: name -> module
If no match is found for a name, that name is omitted from the returned dictionary.