Runtime Error coccures when using torchsummary

ghz 8months ago ⋅ 96 views

My code is like below.

import numpy as np
import torch
import torch.nn as nn
import cupy as cp
from torchviz import make_dot
from torchinfo import summary
from torchsummary import summary as summary_
    
def get_filter_torch(*args, **kwargs):
    
    class TraversabilityFilter(nn.Module):
        def __init__(self, w1, w2, w3, w_out, device="cuda", use_bias=False):
            super(TraversabilityFilter, self).__init__()
            self.conv1 = nn.Conv2d(1, 4, 3, dilation=1, padding=0, bias=use_bias)
            self.conv2 = nn.Conv2d(1, 4, 3, dilation=2, padding=0, bias=use_bias)
            self.conv3 = nn.Conv2d(1, 4, 3, dilation=3, padding=0, bias=use_bias)
            self.conv_out = nn.Conv2d(12, 1, 1, bias=use_bias)

            # Set weights.
            self.conv1.weight = nn.Parameter(torch.from_numpy(w1).float())
            self.conv2.weight = nn.Parameter(torch.from_numpy(w2).float())
            self.conv3.weight = nn.Parameter(torch.from_numpy(w3).float())
            self.conv_out.weight = nn.Parameter(torch.from_numpy(w_out).float())

        def __call__(self, elevation_cupy):
            # Convert cupy tensor to pytorch.
            elevation_cupy = elevation_cupy.astype(cp.float32)
            elevation = torch.as_tensor(elevation_cupy, device=self.conv1.weight.device)
            print("input: ",elevation.shape)

            with torch.no_grad():
                out1 = self.conv1(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
                out2 = self.conv2(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
                out3 = self.conv3(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))

                out1 = out1[:, :, 2:-2, 2:-2]
                out2 = out2[:, :, 1:-1, 1:-1]
                out = torch.cat((out1, out2, out3), dim=1)
                out = self.conv_out(out.abs())
                out = torch.exp(-out)
                print("output: ",out.shape)

            return out

    traversability_filter = TraversabilityFilter(*args, **kwargs).cuda().eval()
    return traversability_filter
   

# Define the weight values (you need to provide actual weight values)
w1 = np.random.randn(4, 1, 3, 3)  # Shape: (out_channels, in_channels, kernel_height, kernel_width)
w2 = np.random.randn(4, 1, 3, 3)  # Shape: (out_channels, in_channels, kernel_height, kernel_width)
w3 = np.random.randn(4, 1, 3, 3)  # Shape: (out_channels, in_channels, kernel_height, kernel_width)
w_out = np.random.randn(1, 12, 1, 1)  # Shape: (out_channels, in_channels, kernel_height, kernel_width)

model = get_filter_torch(w1, w2, w3, w_out)

cell_n = 200
x = cp.random.randn(cell_n, cell_n, dtype=cp.float32)
output = model(x)
print(model)

input_size=(200,200)
summary(model)
summary_(model, input_size)

When I run the code, the result is like below.

input:  torch.Size([200, 200])
output:  torch.Size([1, 1, 194, 194])
TraversabilityFilter(
  (conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (conv2): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(2, 2), bias=False)
  (conv3): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), dilation=(3, 3), bias=False)
  (conv_out): Conv2d(12, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
TraversabilityFilter                     --
├─Conv2d: 1-1                            36
├─Conv2d: 1-2                            36
├─Conv2d: 1-3                            36
├─Conv2d: 1-4                            12
=================================================================
Total params: 120
Trainable params: 120
Non-trainable params: 0
=================================================================
input:  torch.Size([2, 200, 200])
Traceback (most recent call last):
  File "/home/Documents/relay/temp/visual.py", line 72, in <module>
    summary_(model, input_size)
  File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torchsummary/torchsummary.py", line 72, in summary
    model(*x)
  File "/home/Documents/relay/temp/visual.py", line 34, in __call__
    out1 = self.conv1(elevation.view(-1, 1, elevation.shape[0], elevation.shape[1]))
  File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/Documents/2levation_ws/venv/myenv/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (2 x 200). Kernel size: (3 x 3). Kernel size can't be greater than actual input size

The error occurs because the input size specified in the summary function doesn't match the actual input size used when applying the model. Even though the input size during usage is (200, 200), specifying input_size=(1, 200, 200) still results in the same error. How can this be resolved? My pytorch version is like below.

torch==1.13.1+cu116
torchaudio==0.13.1+cu116
torchinfo==1.8.0
torchsummary==1.5.1
torchvision==0.14.1+cu116
torchviz==0.0.2

And I use python 3.8.10.

Answers

The torchsummary package requires input tensors to have batch dimension, channel dimension, and spatial dimensions. In your case, you're providing an input size of (200, 200), which is missing the batch dimension.

To fix this issue, you need to add the batch dimension to your input size. Since your input size is (200, 200), you can specify it as (1, 1, 200, 200), indicating a batch size of 1 and a single channel:

input_size = (1, 1, 200, 200)
summary_(model, input_size)

This should resolve the error you're encountering and allow the torchsummary package to generate the summary of your model correctly.