Module SpykeTorch.functional

Expand source code
import torch
import torch.nn as nn
import torch.nn.functional as fn
import numpy as np
from .utils import to_pair

# padding
# pad = (padLeft, padRight, padTop, padBottom)
def pad(input, pad, value=0):
    r"""Applies 2D padding on the input tensor.

    Args:
        input (Tensor): The input tensor.
        pad (tuple): A tuple of 4 integers in the form of (padLeft, padRight, padTop, padBottom)
        value (int or float): The value of padding. Default: 0

    Returns:
        Tensor: Padded tensor.
    """
    return fn.pad(input, pad, value=value)

# pooling
def pooling(input, kernel_size, stride=None, padding=0):
    r"""Performs a 2D max-pooling over an input signal (spike-wave or potentials) composed of several input
    planes.

    Args:
        input (Tensor): The input tensor.
        kernel_size (int or tuple): Size of the pooling window.
        stride (int or tuple, optional): Stride of the pooling window. Default: None
        padding (int or tuple, optional): Size of the padding. Default: 0

    Returns:
        Tensor: The result of the max-pooling operation.
    """
    return fn.max_pool2d(input, kernel_size, stride, padding)

def fire(potentials, threshold=None, return_thresholded_potentials=False):
    r"""Computes the spike-wave tensor from tensor of potentials. If :attr:`threshold` is :attr:`None`, all the neurons
    emit one spike (if the potential is greater than zero) in the last time step.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): Firing threshold. Default: None
        return_thresholded_potentials (boolean): If True, the tensor of thresholded potentials will be returned
        as well as the tensor of spike-wave. Default: False

    Returns:
        Tensor: Spike-wave tensor.
    """
    thresholded = potentials.clone().detach()
    if threshold is None:
        thresholded[:-1]=0
    else:
        fn.threshold_(thresholded, threshold, 0)
    if return_thresholded_potentials:
        return thresholded.sign(), thresholded
    return thresholded.sign()

def fire_(potentials, threshold=None):
    r"""The inplace version of :func:`~fire`
    """
    if threshold is None:
        potentials[:-1]=0
    else:
        fn.threshold_(potentials, threshold, 0)
    potentials.sign_()

def threshold(potentials, threshold=None):
    r"""Applies a threshold on potentials by which all of the values lower or equal to the threshold becomes zero.
    If :attr:`threshold` is :attr:`None`, only the potentials corresponding to the final time step will survive.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): The threshold value. Default: None

    Returns:
        Tensor: Thresholded potentials.
    """
    outputs = potentials.clone().detach()
    if threshold is None:
        outputs[:-1]=0
    else:
        fn.threshold_(outputs, threshold, 0)
    return outputs

def threshold_(potentials, threshold=None):
    r"""The inplace version of :func:`~threshold`
    """
    if threshold is None:
        potentials[:-1]=0
    else:
        fn.threshold_(potentials, threshold, 0)

# in each position, the best-fitting feature survives (earliest spike first, then maximum potential)
# it is assumed that the threshold function has already been applied to the input potentials
def pointwise_inhibition(thresholded_potentials):
    r"""Performs point-wise inhibition between feature maps. After inhibition, at most one neuron is allowed to fire at each
    position, which is the neuron with the earliest spike time. If the spike times are the same, the neuron with the maximum
    potential will be chosen. As a result, the potential of all of the inhibited neurons will be reset to zero.

    Args:
        thresholded_potentials (Tensor): The tensor of thresholded input potentials.

    Returns:
        Tensor: Inhibited potentials.
    """
    # maximum of each position in each time step
    maximum = torch.max(thresholded_potentials, dim=1, keepdim=True)
    # compute signs for detection of the earliest spike
    clamp_pot = maximum[0].sign()
    # maximum of clamped values is the indices of the earliest spikes
    clamp_pot_max_1 = (clamp_pot.size(0) - clamp_pot.sum(dim = 0, keepdim=True)).long()
    clamp_pot_max_1.clamp_(0,clamp_pot.size(0)-1)
    clamp_pot_max_0 = clamp_pot[-1:,:,:,:]
    # finding winners (maximum potentials between early spikes)
    winners = maximum[1].gather(0, clamp_pot_max_1)
    # generating inhibition coefficient
    coef = torch.zeros_like(thresholded_potentials[0]).unsqueeze_(0)
    coef.scatter_(1, winners,clamp_pot_max_0)
    # applying inhibition to potentials (broadcasting multiplication)
    return torch.mul(thresholded_potentials, coef)

# inhibiting particular features, preventing them from being winners
# inhibited_features is a list of features numbers to be inhibited
def feature_inhibition_(potentials, inhibited_features):
    r"""The inplace version of :func:`~feature_inhibition`
    """
    if len(inhibited_features) != 0:
        potentials[:, inhibited_features, :, :] = 0

def feature_inhibition(potentials, inhibited_features):
    r"""Inhibits specified features (reset the corresponding neurons' potentials to zero).

    Args:
        potentials (Tensor): The tensor of input potentials.
        inhibited_features (List): The list of features to be inhibited.

    Returns:
        Tensor: Inhibited potentials.
    """
    potentials_copy = potentials.clone().detach()
    if len(inhibited_features) != 0:
        feature_inhibition_(potentials_copy, inhibited_features)
    return potentials_copy

# returns list of winners
# inhibition_radius is to increase the chance of diversity among features (if needed)
def get_k_winners(potentials, kwta = 1, inhibition_radius = 0, spikes = None):
    r"""Finds at most :attr:`kwta` winners first based on the earliest spike time, then based on the maximum potential.
    It returns a list of winners, each in a tuple of form (feature, row, column).

    .. note::

        Winners are selected sequentially. Each winner inhibits surrounding neurons within a specific radius in all of the
        other feature maps. Note that only one winner can be selected from each feature map.

    Args:
        potentials (Tensor): The tensor of input potentials.
        kwta (int, optional): The number of winners. Default: 1
        inhibition_radius (int, optional): The radius of lateral inhibition. Default: 0
        spikes (Tensor, optional): Spike-wave corresponding to the input potentials. Default: None

    Returns:
        List: List of winners.
    """
    if spikes is None:
        spikes = potentials.sign()
    # finding earliest potentials for each position in each feature
    maximum = (spikes.size(0) - spikes.sum(dim = 0, keepdim=True)).long()
    maximum.clamp_(0,spikes.size(0)-1)
    values = potentials.gather(dim=0, index=maximum) # gathering values
    # propagating the earliest potential through the whole timesteps
    truncated_pot = spikes * values
    # summation with a high enough value (maximum of potential summation over timesteps) at spike positions
    v = truncated_pot.max() * potentials.size(0)
    truncated_pot.addcmul_(spikes,v)
    # summation over all timesteps
    total = truncated_pot.sum(dim=0,keepdim=True)
    
    total.squeeze_(0)
    global_pooling_size = tuple(total.size())
    winners = []
    for k in range(kwta):
        max_val,max_idx = total.view(-1).max(0)
        if max_val.item() != 0:
            # finding the 3d position of the maximum value
            max_idx_unraveled = np.unravel_index(max_idx.item(),global_pooling_size)
            # adding to the winners list
            winners.append(max_idx_unraveled)
            # preventing the same feature to be the next winner
            total[max_idx_unraveled[0],:,:] = 0
            # columnar inhibition (increasing the chance of learning diverse features)
            if inhibition_radius != 0:
                rowMin,rowMax = max(0,max_idx_unraveled[-2]-inhibition_radius),min(total.size(-2),max_idx_unraveled[-2]+inhibition_radius+1)
                colMin,colMax = max(0,max_idx_unraveled[-1]-inhibition_radius),min(total.size(-1),max_idx_unraveled[-1]+inhibition_radius+1)
                total[:,rowMin:rowMax,colMin:colMax] = 0
        else:
            break
    return winners

# decrease lateral intensities by factors given in the inhibition_kernel
def intensity_lateral_inhibition(intencities, inhibition_kernel):
    r"""Applies lateral inhibition on intensities. For each location, this inhibition decreases the intensity of the
    surrounding cells that has lower intensities by a specific factor. This factor is relative to the distance of the
    neighbors and are put in the :attr:`inhibition_kernel`.

    Args:
        intencities (Tensor): The tensor of input intensities.
        inhibition_kernel (Tensor): The tensor of inhibition factors.

    Returns:
        Tensor: Inhibited intensities.
    """
    intencities.squeeze_(0)
    intencities.unsqueeze_(1)

    inh_win_size = inhibition_kernel.size(-1)
    rad = inh_win_size//2
    # repeat each value
    values = intencities.reshape(intencities.size(0),intencities.size(1),-1,1)
    values = values.repeat(1,1,1,inh_win_size)
    values = values.reshape(intencities.size(0),intencities.size(1),-1,intencities.size(-1)*inh_win_size)
    values = values.repeat(1,1,1,inh_win_size)
    values = values.reshape(intencities.size(0),intencities.size(1),-1,intencities.size(-1)*inh_win_size)
    # extend patches
    padded = fn.pad(intencities,(rad,rad,rad,rad))
    # column-wise
    patches = padded.unfold(-1,inh_win_size,1)
    patches = patches.reshape(patches.size(0),patches.size(1),patches.size(2),-1,patches.size(3)*patches.size(4))
    patches.squeeze_(-2)
    # row-wise
    patches = patches.unfold(-2,inh_win_size,1).transpose(-1,-2)
    patches = patches.reshape(patches.size(0),patches.size(1),1,-1,patches.size(-1))
    patches.squeeze_(-3)
    # compare each element by its neighbors
    coef = values - patches
    coef.clamp_(min=0).sign_() # "ones" are neighbors greater than center
    # convolution with full stride to get the accumulated inhibition factor
    factors = fn.conv2d(coef, inhibition_kernel, stride=inh_win_size)
    result = intencities + intencities * factors

    intencities.squeeze_(1)
    intencities.unsqueeze_(0)
    result.squeeze_(1)
    result.unsqueeze_(0)
    return result

# performs local normalization
# on each region (of size radius*2 + 1) the mean value is computed and 
# intensities will be divided by the mean value
# input is a 4D tensor
def local_normalization(input, normalization_radius, eps=1e-12):
    r"""Applies local normalization. on each region (of size radius*2 + 1) the mean value is computed and the
    intensities will be divided by the mean value. The input is a 4D tensor.

    Args:
        input (Tensor): The input tensor of shape (timesteps, features, height, width).
        normalization_radius (int): The radius of normalization window.

    Returns:
        Tensor: Locally normalized tensor.
    """
    # computing local mean by 2d convolution
    kernel = torch.ones(1,1,normalization_radius*2+1,normalization_radius*2+1,device=input.device).float()/((normalization_radius*2+1)**2)
    # rearrange 4D tensor so input channels will be considered as minibatches
    y = input.squeeze(0) # removes minibatch dim which was 1
    y.unsqueeze_(1)  # adds a dimension after channels so previous channels are now minibatches
    means = fn.conv2d(y,kernel,padding=normalization_radius) + eps # computes means
    y = y/means # normalization
    # swap minibatch with channels
    y.squeeze_(1)
    y.unsqueeze_(0)
    return y

Functions

def pad(input, pad, value=0)

Applies 2D padding on the input tensor.

Args

input : Tensor
The input tensor.
pad : tuple
A tuple of 4 integers in the form of (padLeft, padRight, padTop, padBottom)
value : int or float
The value of padding. Default: 0

Returns

Tensor
Padded tensor.
Expand source code
def pad(input, pad, value=0):
    r"""Applies 2D padding on the input tensor.

    Args:
        input (Tensor): The input tensor.
        pad (tuple): A tuple of 4 integers in the form of (padLeft, padRight, padTop, padBottom)
        value (int or float): The value of padding. Default: 0

    Returns:
        Tensor: Padded tensor.
    """
    return fn.pad(input, pad, value=value)
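
A minimal usage sketch (the spike-wave shape and padding sizes below are illustrative, and the module is assumed to be imported as sf):

import torch
import SpykeTorch.functional as sf

# a spike-wave of 15 time steps, 4 feature maps, and 28x28 positions
spike_wave = (torch.rand(15, 4, 28, 28) > 0.5).float()
# add 2 columns on the left and right and 1 row on the top and bottom
padded = sf.pad(spike_wave, (2, 2, 1, 1))
print(padded.shape)  # torch.Size([15, 4, 30, 32])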
def pooling(input, kernel_size, stride=None, padding=0)

Performs a 2D max-pooling over an input signal (spike-wave or potentials) composed of several input planes.

Args

input : Tensor
The input tensor.
kernel_size : int or tuple
Size of the pooling window.
stride : int or tuple, optional
Stride of the pooling window. Default: None
padding : int or tuple, optional
Size of the padding. Default: 0

Returns

Tensor
The result of the max-pooling operation.
Expand source code
def pooling(input, kernel_size, stride=None, padding=0):
    r"""Performs a 2D max-pooling over an input signal (spike-wave or potentials) composed of several input
    planes.

    Args:
        input (Tensor): The input tensor.
        kernel_size (int or tuple): Size of the pooling window.
        stride (int or tuple, optional): Stride of the pooling window. Default: None
        padding (int or tuple, optional): Size of the padding. Default: 0

    Returns:
        Tensor: The result of the max-pooling operation.
    """
    return fn.max_pool2d(input, kernel_size, stride, padding)
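
A minimal usage sketch (shapes are illustrative); since the input is 4D, the time dimension plays the role of the batch dimension of max_pool2d:

import torch
import SpykeTorch.functional as sf

spike_wave = (torch.rand(15, 4, 28, 28) > 0.5).float()
# non-overlapping 2x2 max-pooling over each feature map at every time step
pooled = sf.pooling(spike_wave, kernel_size=2, stride=2)
print(pooled.shape)  # torch.Size([15, 4, 14, 14])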
def fire(potentials, threshold=None, return_thresholded_potentials=False)

Computes the spike-wave tensor from the tensor of potentials. If threshold is None, all the neurons emit one spike (if their potential is greater than zero) in the last time step.

Args

potentials : Tensor
The tensor of input potentials.
threshold : float
Firing threshold. Default: None
return_thresholded_potentials : boolean
If True, the tensor of thresholded potentials will be returned as well as the spike-wave tensor. Default: False

Returns

Tensor
Spike-wave tensor.
Expand source code
def fire(potentials, threshold=None, return_thresholded_potentials=False):
    r"""Computes the spike-wave tensor from tensor of potentials. If :attr:`threshold` is :attr:`None`, all the neurons
    emit one spike (if the potential is greater than zero) in the last time step.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): Firing threshold. Default: None
        return_thresholded_potentials (boolean): If True, the tensor of thresholded potentials will be returned
        as well as the tensor of spike-wave. Default: False

    Returns:
        Tensor: Spike-wave tensor.
    """
    thresholded = potentials.clone().detach()
    if threshold is None:
        thresholded[:-1]=0
    else:
        fn.threshold_(thresholded, threshold, 0)
    if return_thresholded_potentials:
        return thresholded.sign(), thresholded
    return thresholded.sign()
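
A minimal usage sketch (the potential values and the threshold of 15 are arbitrary):

import torch
import SpykeTorch.functional as sf

potentials = torch.rand(15, 4, 28, 28) * 30  # accumulated membrane potentials
spikes, thresholded = sf.fire(potentials, threshold=15.0,
                              return_thresholded_potentials=True)
print(spikes.unique())  # typically tensor([0., 1.])
# with threshold=None, spikes are emitted only in the last time step
last_step_spikes = sf.fire(potentials)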
def fire_(potentials, threshold=None)

The inplace version of fire().

Expand source code
def fire_(potentials, threshold=None):
    r"""The inplace version of :func:`~fire`
    """
    if threshold is None:
        potentials[:-1]=0
    else:
        fn.threshold_(potentials, threshold, 0)
    potentials.sign_()
def threshold(potentials, threshold=None)

Applies a threshold on potentials by which all of the values lower than or equal to the threshold become zero. If threshold is None, only the potentials corresponding to the final time step will survive.

Args

potentials : Tensor
The tensor of input potentials.
threshold : float
The threshold value. Default: None

Returns

Tensor
Thresholded potentials.
Expand source code
def threshold(potentials, threshold=None):
    r"""Applies a threshold on potentials by which all of the values lower or equal to the threshold becomes zero.
    If :attr:`threshold` is :attr:`None`, only the potentials corresponding to the final time step will survive.

    Args:
        potentials (Tensor): The tensor of input potentials.
        threshold (float): The threshold value. Default: None

    Returns:
        Tensor: Thresholded potentials.
    """
    outputs = potentials.clone().detach()
    if threshold is None:
        outputs[:-1]=0
    else:
        fn.threshold_(outputs, threshold, 0)
    return outputs
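
A minimal usage sketch (arbitrary values):

import torch
import SpykeTorch.functional as sf

potentials = torch.rand(15, 4, 28, 28) * 30
# values lower than or equal to 15 become zero; the rest are kept unchanged
thresholded = sf.threshold(potentials, 15.0)
# with threshold=None, only the final time step survives
last_only = sf.threshold(potentials)
print(last_only[:-1].abs().sum())  # tensor(0.)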
def threshold_(potentials, threshold=None)

The inplace version of threshold().

Expand source code
def threshold_(potentials, threshold=None):
    r"""The inplace version of :func:`~threshold`
    """
    if threshold is None:
        potentials[:-1]=0
    else:
        fn.threshold_(potentials, threshold, 0)
def pointwise_inhibition(thresholded_potentials)

Performs point-wise inhibition between feature maps. After inhibition, at most one neuron is allowed to fire at each position, which is the neuron with the earliest spike time. If the spike times are the same, the neuron with the maximum potential will be chosen. As a result, the potential of all of the inhibited neurons will be reset to zero.

Args

thresholded_potentials : Tensor
The tensor of thresholded input potentials.

Returns

Tensor
Inhibited potentials.
Expand source code
def pointwise_inhibition(thresholded_potentials):
    r"""Performs point-wise inhibition between feature maps. After inhibition, at most one neuron is allowed to fire at each
    position, which is the neuron with the earliest spike time. If the spike times are the same, the neuron with the maximum
    potential will be chosen. As a result, the potential of all of the inhibited neurons will be reset to zero.

    Args:
        thresholded_potentials (Tensor): The tensor of thresholded input potentials.

    Returns:
        Tensor: Inhibited potentials.
    """
    # maximum of each position in each time step
    maximum = torch.max(thresholded_potentials, dim=1, keepdim=True)
    # compute signs for detection of the earliest spike
    clamp_pot = maximum[0].sign()
    # maximum of clamped values is the indices of the earliest spikes
    clamp_pot_max_1 = (clamp_pot.size(0) - clamp_pot.sum(dim = 0, keepdim=True)).long()
    clamp_pot_max_1.clamp_(0,clamp_pot.size(0)-1)
    clamp_pot_max_0 = clamp_pot[-1:,:,:,:]
    # finding winners (maximum potentials between early spikes)
    winners = maximum[1].gather(0, clamp_pot_max_1)
    # generating inhibition coefficient
    coef = torch.zeros_like(thresholded_potentials[0]).unsqueeze_(0)
    coef.scatter_(1, winners,clamp_pot_max_0)
    # applying inhibition to potentials (broadcasting multiplication)
    return torch.mul(thresholded_potentials, coef)
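
A minimal sketch (arbitrary values) showing that after inhibition at most one feature map stays active at each position:

import torch
import SpykeTorch.functional as sf

potentials = torch.rand(15, 4, 28, 28) * 30
thresholded = sf.threshold(potentials, 15.0)
inhibited = sf.pointwise_inhibition(thresholded)
# count, for every (row, column), how many feature maps still carry spikes
active_features = inhibited.sign().max(dim=0)[0].sum(dim=0)
print((active_features <= 1).all())  # tensor(True)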
def feature_inhibition_(potentials, inhibited_features)

The inplace version of feature_inhibition().

Expand source code
def feature_inhibition_(potentials, inhibited_features):
    r"""The inplace version of :func:`~feature_inhibition`
    """
    if len(inhibited_features) != 0:
        potentials[:, inhibited_features, :, :] = 0
def feature_inhibition(potentials, inhibited_features)

Inhibits specified features (reset the corresponding neurons' potentials to zero).

Args

potentials : Tensor
The tensor of input potentials.
inhibited_features : List
The list of features to be inhibited.

Returns

Tensor
Inhibited potentials.
Expand source code
def feature_inhibition(potentials, inhibited_features):
    r"""Inhibits specified features (reset the corresponding neurons' potentials to zero).

    Args:
        potentials (Tensor): The tensor of input potentials.
        inhibited_features (List): The list of features to be inhibited.

    Returns:
        Tensor: Inhibited potentials.
    """
    potentials_copy = potentials.clone().detach()
    if len(inhibited_features) != 0:
        feature_inhibition_(potentials_copy, inhibited_features)
    return potentials_copy
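
A minimal sketch (the inhibited feature indices are arbitrary); a detached copy is returned and the input itself is left untouched:

import torch
import SpykeTorch.functional as sf

potentials = torch.rand(15, 4, 28, 28) * 30
# silence feature maps 1 and 3
inhibited = sf.feature_inhibition(potentials, [1, 3])
print(inhibited[:, [1, 3]].abs().sum())       # tensor(0.)
print(potentials[:, [1, 3]].abs().sum() > 0)  # tensor(True)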
def get_k_winners(potentials, kwta=1, inhibition_radius=0, spikes=None)

Finds at most kwta winners, first based on the earliest spike time, then based on the maximum potential. It returns a list of winners, each as a tuple of the form (feature, row, column).

Note

Winners are selected sequentially. Each winner inhibits surrounding neurons within a specific radius in all of the other feature maps. Note that only one winner can be selected from each feature map.

Args

potentials : Tensor
The tensor of input potentials.
kwta : int, optional
The number of winners. Default: 1
inhibition_radius : int, optional
The radius of lateral inhibition. Default: 0
spikes : Tensor, optional
Spike-wave corresponding to the input potentials. Default: None

Returns

List
List of winners.
Expand source code
def get_k_winners(potentials, kwta = 1, inhibition_radius = 0, spikes = None):
    r"""Finds at most :attr:`kwta` winners first based on the earliest spike time, then based on the maximum potential.
    It returns a list of winners, each in a tuple of form (feature, row, column).

    .. note::

        Winners are selected sequentially. Each winner inhibits surrounding neurons within a specific radius in all of the
        other feature maps. Note that only one winner can be selected from each feature map.

    Args:
        potentials (Tensor): The tensor of input potentials.
        kwta (int, optional): The number of winners. Default: 1
        inhibition_radius (int, optional): The radius of lateral inhibition. Default: 0
        spikes (Tensor, optional): Spike-wave corresponding to the input potentials. Default: None

    Returns:
        List: List of winners.
    """
    if spikes is None:
        spikes = potentials.sign()
    # finding earliest potentials for each position in each feature
    maximum = (spikes.size(0) - spikes.sum(dim = 0, keepdim=True)).long()
    maximum.clamp_(0,spikes.size(0)-1)
    values = potentials.gather(dim=0, index=maximum) # gathering values
    # propagating the earliest potential through the whole timesteps
    truncated_pot = spikes * values
    # summation with a high enough value (maximum of potential summation over timesteps) at spike positions
    v = truncated_pot.max() * potentials.size(0)
    truncated_pot.addcmul_(spikes,v)
    # summation over all timesteps
    total = truncated_pot.sum(dim=0,keepdim=True)
    
    total.squeeze_(0)
    global_pooling_size = tuple(total.size())
    winners = []
    for k in range(kwta):
        max_val,max_idx = total.view(-1).max(0)
        if max_val.item() != 0:
            # finding the 3d position of the maximum value
            max_idx_unraveled = np.unravel_index(max_idx.item(),global_pooling_size)
            # adding to the winners list
            winners.append(max_idx_unraveled)
            # preventing the same feature to be the next winner
            total[max_idx_unraveled[0],:,:] = 0
            # columnar inhibition (increasing the chance of learning diverse features)
            if inhibition_radius != 0:
                rowMin,rowMax = max(0,max_idx_unraveled[-2]-inhibition_radius),min(total.size(-2),max_idx_unraveled[-2]+inhibition_radius+1)
                colMin,colMax = max(0,max_idx_unraveled[-1]-inhibition_radius),min(total.size(-1),max_idx_unraveled[-1]+inhibition_radius+1)
                total[:,rowMin:rowMax,colMin:colMax] = 0
        else:
            break
    return winners
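
A minimal sketch (kwta, radius, and threshold are arbitrary); the thresholded potentials and the spike-wave come from fire():

import torch
import SpykeTorch.functional as sf

potentials = torch.rand(15, 4, 28, 28) * 30
spikes, thresholded = sf.fire(potentials, 15.0, return_thresholded_potentials=True)
winners = sf.get_k_winners(thresholded, kwta=3, inhibition_radius=2, spikes=spikes)
for feature, row, column in winners:
    print(feature, row, column)  # each winner is a (feature, row, column) tuple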
def intensity_lateral_inhibition(intencities, inhibition_kernel)

Applies lateral inhibition on intensities. For each location, this inhibition decreases the intensity of the surrounding cells that have lower intensities by a specific factor. This factor depends on the distance to the neighbor and is given by inhibition_kernel.

Args

intencities : Tensor
The tensor of input intensities.
inhibition_kernel : Tensor
The tensor of inhibition factors.

Returns

Tensor
Inhibited intensities.
Expand source code
def intensity_lateral_inhibition(intencities, inhibition_kernel):
    r"""Applies lateral inhibition on intensities. For each location, this inhibition decreases the intensity of the
    surrounding cells that has lower intensities by a specific factor. This factor is relative to the distance of the
    neighbors and are put in the :attr:`inhibition_kernel`.

    Args:
        intencities (Tensor): The tensor of input intensities.
        inhibition_kernel (Tensor): The tensor of inhibition factors.

    Returns:
        Tensor: Inhibited intensities.
    """
    intencities.squeeze_(0)
    intencities.unsqueeze_(1)

    inh_win_size = inhibition_kernel.size(-1)
    rad = inh_win_size//2
    # repeat each value
    values = intencities.reshape(intencities.size(0),intencities.size(1),-1,1)
    values = values.repeat(1,1,1,inh_win_size)
    values = values.reshape(intencities.size(0),intencities.size(1),-1,intencities.size(-1)*inh_win_size)
    values = values.repeat(1,1,1,inh_win_size)
    values = values.reshape(intencities.size(0),intencities.size(1),-1,intencities.size(-1)*inh_win_size)
    # extend patches
    padded = fn.pad(intencities,(rad,rad,rad,rad))
    # column-wise
    patches = padded.unfold(-1,inh_win_size,1)
    patches = patches.reshape(patches.size(0),patches.size(1),patches.size(2),-1,patches.size(3)*patches.size(4))
    patches.squeeze_(-2)
    # row-wise
    patches = patches.unfold(-2,inh_win_size,1).transpose(-1,-2)
    patches = patches.reshape(patches.size(0),patches.size(1),1,-1,patches.size(-1))
    patches.squeeze_(-3)
    # compare each element by its neighbors
    coef = values - patches
    coef.clamp_(min=0).sign_() # "ones" are neighbors greater than center
    # convolution with full stride to get the accumulated inhibition factor
    factors = fn.conv2d(coef, inhibition_kernel, stride=inh_win_size)
    result = intencities + intencities * factors

    intencities.squeeze_(1)
    intencities.unsqueeze_(0)
    result.squeeze_(1)
    result.unsqueeze_(0)
    return result
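
A minimal sketch; the input is assumed to have a leading dimension of size 1 (matching the squeeze in the source), and the 3x3 kernel values below are purely illustrative:

import torch
import SpykeTorch.functional as sf

# intensities of 6 feature maps (e.g. filter responses) with a leading dim of size 1
intensities = torch.rand(1, 6, 28, 28)
# illustrative inhibition factors, one weight per neighbor, shape (1, 1, 3, 3)
inhibition_kernel = torch.tensor([[[[0.15, 0.20, 0.15],
                                    [0.20, 0.00, 0.20],
                                    [0.15, 0.20, 0.15]]]])
inhibited = sf.intensity_lateral_inhibition(intensities, inhibition_kernel)
print(inhibited.shape)  # torch.Size([1, 6, 28, 28])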
def local_normalization(input, normalization_radius, eps=1e-12)

Applies local normalization. On each region (of size radius*2 + 1), the mean value is computed and the intensities are divided by the mean value. The input is a 4D tensor.

Args

input : Tensor
The input tensor of shape (timesteps, features, height, width).
normalization_radius : int
The radius of normalization window.

Returns

Tensor
Locally normalized tensor.
Expand source code
def local_normalization(input, normalization_radius, eps=1e-12):
    r"""Applies local normalization. on each region (of size radius*2 + 1) the mean value is computed and the
    intensities will be divided by the mean value. The input is a 4D tensor.

    Args:
        input (Tensor): The input tensor of shape (timesteps, features, height, width).
        normalization_radius (int): The radius of normalization window.

    Returns:
        Tensor: Locally normalized tensor.
    """
    # computing local mean by 2d convolution
    kernel = torch.ones(1,1,normalization_radius*2+1,normalization_radius*2+1,device=input.device).float()/((normalization_radius*2+1)**2)
    # rearrange 4D tensor so input channels will be considered as minibatches
    y = input.squeeze(0) # removes minibatch dim which was 1
    y.unsqueeze_(1)  # adds a dimension after channels so previous channels are now minibatches
    means = fn.conv2d(y,kernel,padding=normalization_radius) + eps # computes means
    y = y/means # normalization
    # swap minibatch with channels
    y.squeeze_(1)
    y.unsqueeze_(0)
    return y
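
A minimal sketch; the leading dimension is assumed to be of size 1 (which is what the squeeze in the source expects), and the radius is illustrative:

import torch
import SpykeTorch.functional as sf

# filtered intensities with a leading dimension of size 1
filtered = torch.rand(1, 6, 28, 28) * 255
# divide each value by the mean of its 17x17 (radius 8) neighborhood
normalized = sf.local_normalization(filtered, normalization_radius=8)
print(normalized.shape)  # torch.Size([1, 6, 28, 28])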