PAIProcessor
class perforatedai.PAIProcessor()
Abstract base class for processing neuron and dendrite operations.
Processors handle state management and data flow between neurons and dendrites. They allow custom pre- and post-processing for modules that take or return multiple values, rather than the default single tensor in and single tensor out. Subclasses must implement all five methods listed below (post_n1, post_n2, pre_d, post_d, and clear_processor) to handle their specific state management needs.
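As a rough illustration of the interface described above, the five methods can be mirrored with Python's abc module. The names ProcessorSketch and PassThrough are hypothetical stand-ins written for this example, not the library's source; a real processor would subclass PAIProcessor itself.

```python
from abc import ABC, abstractmethod


class ProcessorSketch(ABC):
    """Hypothetical stand-in that mirrors the five PAIProcessor methods."""

    @abstractmethod
    def post_n1(self, *args, **kwargs): ...

    @abstractmethod
    def post_n2(self, *args, **kwargs): ...

    @abstractmethod
    def pre_d(self, *args, **kwargs): ...

    @abstractmethod
    def post_d(self, *args, **kwargs): ...

    @abstractmethod
    def clear_processor(self): ...


class PassThrough(ProcessorSketch):
    """Trivial subclass: one value in, one value out, no stored state."""

    def post_n1(self, *args, **kwargs):
        return args[0]

    def post_n2(self, *args, **kwargs):
        return args[0]

    def pre_d(self, *args, **kwargs):
        # Returned as (processed_args, processed_kwargs).
        return args, kwargs

    def post_d(self, *args, **kwargs):
        return args[0]

    def clear_processor(self):
        pass  # nothing is stored, so there is nothing to clear
```

Because every method is marked abstract, instantiating ProcessorSketch directly raises TypeError, while PassThrough can be instantiated because it implements all five.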
Methods
post_n1
post_n1(*args, **kwargs) [Abstract]
Post-process neuron output before dendrite processing.
Called immediately after the main module/neuron is executed and before any dendrite processing occurs. This method should extract and return only the part of the neuron output (typically a single tensor) that dendrite operations should see.
Parameters:
*args (tuple) – Positional arguments, typically containing the neuron output
**kwargs (dict) – Keyword arguments from the neuron output
Returns:
Any – The filtered output to be passed to dendrite processing
post_n2
post_n2(*args, **kwargs) [Abstract]
Post-process dendrite-modified output before final return.
Called after dendrite processing is complete and before passing the final value forward in the network. This method should combine the dendrite-modified output with any stored state to produce the complete output that matches the expected format of the main module.
Parameters:
*args (tuple) – Positional arguments containing the dendrite-modified output
**kwargs (dict) – Keyword arguments from the processing chain
Returns:
Any – The complete output in the format expected by downstream components
pre_d
pre_d(*args, **kwargs) [Abstract]
Pre-process input before dendrite operations.
Filters and prepares inputs for dendrite processing. This method handles special cases such as initial time steps vs. subsequent iterations, ensuring dendrites receive the appropriate inputs (e.g., external inputs vs. internal recurrent state).
Parameters:
*args (tuple) – Positional arguments containing inputs to the PAI module
**kwargs (dict) – Keyword arguments containing inputs to the PAI module
Returns:
tuple – A tuple of (processed_args, processed_kwargs) to pass to dendrite
post_d
post_d(*args, **kwargs) [Abstract]
Post-process dendrite output and manage state.
Processes the output from dendrite operations, storing any state needed for future iterations and returning only the portion that should be combined with the neuron output. This is where recurrent state is saved for the next time step.
Parameters:
*args (tuple) – Positional arguments containing the dendrite output
**kwargs (dict) – Keyword arguments from the dendrite output
Returns:
Any – The filtered dendrite output to be added to the neuron output
clear_processor
clear_processor() [Abstract]
Clear all internal processor state.
Resets the processor by removing all stored state variables. Call this before saving the model; otherwise safe_tensors will run into errors. Implementations should check that each attribute exists before deleting it, so that clearing an unused processor does not raise.
Returns:
None
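The existence check recommended above can be sketched as follows. The class StatefulSketch, its remember method, and the attribute name saved_state are hypothetical and exist only for this example.

```python
class StatefulSketch:
    """Hypothetical processor fragment that stores one piece of state."""

    # Names of every attribute this processor may create at run time.
    _state_names = ("saved_state",)

    def remember(self, value):
        self.saved_state = value

    def clear_processor(self):
        # Check first, so that clearing twice (or clearing before any
        # forward pass has stored state) does not raise AttributeError.
        for name in self._state_names:
            if hasattr(self, name):
                delattr(self, name)
```

Keeping the attribute names in one tuple means a new piece of state only has to be registered in one place to be cleared.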
Notes
Processors make it simpler to add dendrites to modules whose forward() is not one tensor in and one tensor out. The main module has a single processor instance, which uses post_n1 and post_n2; each new Dendrite node gets its own instance, which uses pre_d and post_d.
For details on creating custom processors, see customization.md in the API directory.
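The call order described in these notes can be sketched with a pass-through processor that records which method ran. LoggingProcessor and forward_with_dendrites are hypothetical names, plain functions stand in for modules, and summing the dendrite outputs directly onto the neuron output is a simplifying assumption; the library may weight or otherwise transform dendrite outputs before combining them.

```python
class LoggingProcessor:
    """Hypothetical pass-through processor that records which method ran."""

    def __init__(self, log):
        self.log = log

    def post_n1(self, *args, **kwargs):
        self.log.append("post_n1")
        return args[0]

    def post_n2(self, *args, **kwargs):
        self.log.append("post_n2")
        return args[0]

    def pre_d(self, *args, **kwargs):
        self.log.append("pre_d")
        return args, kwargs

    def post_d(self, *args, **kwargs):
        self.log.append("post_d")
        return args[0]


def forward_with_dendrites(neuron, dendrites, neuron_proc, dendrite_procs,
                           *args, **kwargs):
    # 1. Run the main module, then keep only what dendrites should see.
    out = neuron_proc.post_n1(neuron(*args, **kwargs))
    # 2. Each dendrite uses its own processor for its inputs and outputs.
    for dendrite, proc in zip(dendrites, dendrite_procs):
        d_args, d_kwargs = proc.pre_d(*args, **kwargs)
        # Assumption: a plain sum combines dendrite and neuron outputs.
        out = out + proc.post_d(dendrite(*d_args, **d_kwargs))
    # 3. Restore the output format the rest of the network expects.
    return neuron_proc.post_n2(out)


log = []
result = forward_with_dendrites(
    lambda x: 2 * x,          # stand-in for the main module
    [lambda x: x + 1],        # one stand-in dendrite
    LoggingProcessor(log),    # the main module's processor
    [LoggingProcessor(log)],  # one processor per dendrite
    3,
)
```

With the input 3, the neuron contributes 6 and the dendrite contributes 4, so result is 10 and log reads post_n1, pre_d, post_d, post_n2.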
LSTMCellProcessor
class perforatedai.LSTMCellProcessor()
Processor for LSTM cells to handle hidden and cell states.
Manages the separation and recombination of LSTM hidden states (h_t) and cell states (c_t) during dendrite processing. Ensures dendrites only modify the hidden state while preserving cell state integrity across time steps.
Inherits From
PAIProcessor
Methods
post_n1
post_n1(*args, **kwargs)
Extract hidden state from LSTM output for dendrite processing.
Separates the hidden state (h_t) from the cell state (c_t) in the LSTM output tuple. Stores the cell state temporarily since only the hidden state should be modified by dendrites.
Parameters:
*args (tuple) – Contains LSTM output tuple (h_t, c_t) as first element
**kwargs (dict) – Unused keyword arguments
Returns:
torch.Tensor – Hidden state h_t to be passed to dendrite processing
post_n2
post_n2(*args, **kwargs)
Recombine dendrite-modified hidden state with cell state.
Takes the hidden state that has been modified by dendrite operations and combines it with the stored cell state to produce the complete LSTM output tuple.
Parameters:
*args (tuple) – Contains the dendrite-modified hidden state h_t
**kwargs (dict) – Unused keyword arguments
Returns:
tuple – Complete LSTM output (h_t, c_t) where h_t has been modified
pre_d
pre_d(*args, **kwargs)
Filter LSTM input for dendrite based on initialization state.
Checks if this is the first time step (all zeros in h_t) or a subsequent step. For the first step, passes through the original inputs. For subsequent steps, replaces the neuron's hidden state with the dendrite's own internal state from the previous iteration.
Parameters:
*args (tuple) – Contains (input, (h_t, c_t)) where input is the external input and (h_t, c_t) is the neuron's recurrent state
**kwargs (dict) – Keyword arguments to pass through
Returns:
tuple – ((processed_input, processed_state), kwargs) for dendrite call
post_d
post_d(*args, **kwargs)
Extract and store dendrite's LSTM state for next iteration.
Separates the dendrite's hidden and cell states from its output tuple, stores both for use in the next time step, and returns only the hidden state to be combined with the neuron's output.
Parameters:
*args (tuple) – Contains dendrite LSTM output tuple (h_t, c_t)
**kwargs (dict) – Unused keyword arguments
Returns:
torch.Tensor – Hidden state h_t to be added to the neuron output
clear_processor
clear_processor()
Clear all stored LSTM states.
Removes dendrite hidden state (h_t_d), dendrite cell state (c_t_d), and temporarily stored neuron cell state (c_t_n). Safe to call even if attributes don't exist.
Returns:
None
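The behaviour documented for the five methods above can be sketched without PyTorch by letting flat lists of floats stand in for tensors. LSTMCellProcessorSketch is a hypothetical re-creation from these descriptions, not the library's implementation; only the attribute names h_t_d, c_t_d, and c_t_n are taken from the text, and the zero check is one plausible reading of "all zeros in h_t".

```python
class LSTMCellProcessorSketch:
    """Hypothetical re-creation of the documented behaviour."""

    # Neuron side
    def post_n1(self, *args, **kwargs):
        h_t, c_t = args[0]
        self.c_t_n = c_t        # park the neuron's cell state
        return h_t              # dendrites only see the hidden state

    def post_n2(self, *args, **kwargs):
        h_t = args[0]
        return h_t, self.c_t_n  # rebuild the (h_t, c_t) tuple

    # Dendrite side
    def pre_d(self, *args, **kwargs):
        x, (h_t, c_t) = args
        # Assumption: an all-zero hidden state marks the first time step.
        if all(v == 0 for v in h_t):
            return (x, (h_t, c_t)), kwargs
        # Later steps: substitute the dendrite's own state from last time.
        return (x, (self.h_t_d, self.c_t_d)), kwargs

    def post_d(self, *args, **kwargs):
        h_t, c_t = args[0]
        self.h_t_d, self.c_t_d = h_t, c_t  # saved for the next time step
        return h_t

    def clear_processor(self):
        for name in ("h_t_d", "c_t_d", "c_t_n"):
            if hasattr(self, name):
                delattr(self, name)
```

In this sketch a single object plays both roles for brevity; as described for PAIProcessor, the main module and each dendrite would hold separate instances.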
MultiOutputProcessor
class perforatedai.MultiOutputProcessor()
Processor for modules with multiple outputs, where dendrites modify only the first output.
General processor for modules that return tuples where only the first element should be modified by dendrites. Stores additional outputs and recombines them after dendrite processing.
Methods
post_n1
post_n1(*args, **kwargs)
Save extra outputs and return the first output.
Extracts the first tensor from the output tuple for dendrite processing while storing all additional outputs for later recombination.
Parameters:
*args (tuple) – Contains the module's output tuple
**kwargs (dict) – Unused keyword arguments
Returns:
torch.Tensor – The first tensor of the tuple
post_n2
post_n2(*args, **kwargs)
Combine output with stored extra outputs.
Recombines the dendrite-modified first output with the stored additional outputs to reconstruct the complete output tuple.
Parameters:
*args (tuple) – The first tensor combined with dendrite output
**kwargs (dict) – Unused keyword arguments
Returns:
tuple – The recombined output tuple, with the dendrite-modified tensor as its first element followed by the stored extra outputs
pre_d
pre_d(*args, **kwargs)
Pass through arguments unchanged for dendrite preprocessing.
Parameters:
*args (tuple) – Positional arguments containing inputs to the PAI module
**kwargs (dict) – Keyword arguments containing inputs to the PAI module
Returns:
tuple – (args, kwargs) passed through unchanged
post_d
post_d(*args, **kwargs)
Extract first output for dendrite postprocessing.
Parameters:
*args (tuple) – Contains the dendrite module's output tuple
**kwargs (dict) – Unused keyword arguments
Returns:
torch.Tensor – The first tensor of the tuple
clear_processor
clear_processor()
Clear stored processor state.
Removes the stored extra_out attribute if it exists.
Returns:
None
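A minimal sketch of the store-and-recombine pattern described for this class, using plain Python values in place of tensors. MultiOutputProcessorSketch is a hypothetical name and the code is reconstructed from the descriptions above rather than copied from the library; only the attribute name extra_out comes from the text.

```python
class MultiOutputProcessorSketch:
    """Hypothetical re-creation: dendrites touch only the first output."""

    def post_n1(self, *args, **kwargs):
        out = args[0]
        self.extra_out = tuple(out[1:])  # keep everything after the first
        return out[0]

    def post_n2(self, *args, **kwargs):
        # Put the (possibly modified) first output back in front.
        return (args[0],) + self.extra_out

    def pre_d(self, *args, **kwargs):
        return args, kwargs              # inputs pass through unchanged

    def post_d(self, *args, **kwargs):
        return args[0][0]                # first element of dendrite output

    def clear_processor(self):
        if hasattr(self, "extra_out"):
            delattr(self, "extra_out")


proc = MultiOutputProcessorSketch()
first = proc.post_n1((10, "attention", "mask"))  # module returned a 3-tuple
rebuilt = proc.post_n2(first + 5)                # 5 stands in for dendrite output
```

Here first is 10 and rebuilt is (15, "attention", "mask"): the extra outputs come back untouched around the modified first element.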
ResNetPAI
class perforatedai.ResNetPAI(other_resnet)
PAI-compatible ResNet wrapper.
Converts a standard torchvision ResNet model into a format compatible with Perforated AI by wrapping normalization layers in PAISequential containers. This wrapper demonstrates how to adapt predefined models for use with the modules_to_replace configuration option.
Parameters
other_resnet (torchvision.models.resnet.ResNet) – An existing ResNet model to convert to PAI-compatible format.
Attributes
_norm_layer – Normalization layer type from original model.
inplanes (int) – Number of input planes from original model.
dilation (int) – Dilation rate from original model.
groups (int) – Number of groups from original model.
base_width (int) – Base width from original model.
b1 (PAISequential) – Combined first convolution and batch normalization layer.
relu (nn.ReLU) – ReLU activation from original model.
maxpool (nn.MaxPool2d) – Max pooling layer from original model.
layer1 through layer4 (nn.Sequential) – ResNet layer blocks converted to PAI format.
avgpool (nn.AdaptiveAvgPool2d) – Average pooling layer from original model.
fc (nn.Linear) – Fully connected output layer from original model.
Methods
forward
forward(x)
Forward pass through the network.
Parameters:
x (torch.Tensor) – Input tensor to the network
Returns:
torch.Tensor – Output tensor from the network
Examples
>>> import torchvision.models as models
>>> from perforatedai import ResNetPAI
>>>
>>> # Load a standard ResNet (torchvision >= 0.13; older releases use pretrained=True)
>>> original_resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
>>>
>>> # Convert to PAI-compatible format
>>> pai_resnet = ResNetPAI(original_resnet)
Notes
All normalization layers should be wrapped in a PAISequential or another wrapped module. This class is an example of how to create a replacement module for the modules_to_replace configuration option when working with a predefined model.