import torch
import conv_relu_cuda as cr_cuda
from msd_pytorch.msd_module import MSDFinalLayer, init_convolution_weights
import numpy as np
# Index of the first weight tensor in ctx.needs_input_grad; the arguments to
# MSDBlockImpl2d.forward are (input, dilations, bias, *weights).
IDX_WEIGHT_START = 3


class MSDBlockImpl2d(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dilations, bias, *weights):
        depth = len(dilations)
        assert depth == len(weights), "number of weights does not match depth"
        num_out_channels = sum(w.shape[0] for w in weights)
        assert (
            len(bias) == num_out_channels
        ), "number of biases does not match number of output channels from weights"

        ctx.dilations = dilations
        ctx.depth = depth

        result = input.new_empty(
            input.shape[0], input.shape[1] + num_out_channels, *input.shape[2:]
        )
        # Copy input into result buffer
        result[:, : input.shape[1]] = input

        result_start = input.shape[1]
        bias_start = 0
        for i in range(depth):
            # Extract variables
            sub_input = result[:, :result_start]
            sub_weight = weights[i]
            blocksize = sub_weight.shape[0]
            sub_bias = bias[bias_start : bias_start + blocksize]
            sub_result = result[:, result_start : result_start + blocksize]
            dilation = ctx.dilations[i]
            # Compute convolution. conv_relu_forward computes the
            # convolution and relu in one pass and stores the
            # output in sub_result.
            cr_cuda.conv_relu_forward(
                sub_input, sub_weight, sub_bias, sub_result, dilation
            )
            # Update positions
            result_start += blocksize
            bias_start += blocksize

        ctx.save_for_backward(bias, result, *weights)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        bias, result, *weights = ctx.saved_tensors
        depth = ctx.depth

        grad_bias = torch.zeros_like(bias)
        # XXX: Could we just overwrite grad_output instead of clone?
        gradients = grad_output.clone()
        grad_weights = []

        result_end = result.shape[1]
        bias_end = len(bias)

        # Walk through the layers in reverse order.
        for i in range(depth):
            idx = depth - 1 - i
            # Get subsets
            sub_weight = weights[idx]
            blocksize = sub_weight.shape[0]
            result_start = result_end - blocksize
            bias_start = bias_end - blocksize
            sub_grad_output = gradients[:, result_start:result_end]
            sub_grad_input = gradients[:, :result_start]
            sub_result = result[:, result_start:result_end]
            sub_input = result[:, :result_start]
            dilation = ctx.dilations[idx]
            # Gradient w.r.t. input: conv_relu_backward_x computes the
            # gradient wrt sub_input and adds the gradient to
            # sub_grad_input.
            cr_cuda.conv_relu_backward_x(
                sub_result, sub_grad_output, sub_weight, sub_grad_input, dilation
            )
            # Gradient w.r.t. weights (this iteration computes the gradient
            # for weights[idx]).
            if ctx.needs_input_grad[idx + IDX_WEIGHT_START]:
                sub_grad_weight = torch.zeros_like(sub_weight)
                cr_cuda.conv_relu_backward_k(
                    sub_result, sub_grad_output, sub_input, sub_grad_weight, dilation
                )
                grad_weights.insert(0, sub_grad_weight)
            else:
                grad_weights.insert(0, None)
            # Gradient of bias
            if ctx.needs_input_grad[2]:
                sub_grad_bias = grad_bias[bias_start:bias_end]
                cr_cuda.conv_relu_backward_bias(
                    sub_result, sub_grad_output, sub_grad_bias
                )
            # Update positions etc
            result_end -= blocksize
            bias_end -= blocksize

        grad_input = gradients[:, : weights[0].shape[1]]
        return (grad_input, None, grad_bias, *grad_weights)


msdblock2d = MSDBlockImpl2d.apply
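
# The sketch below shows how the raw autograd function could be exercised
# directly. It is illustrative only: it assumes a CUDA device and the compiled
# conv_relu_cuda extension, and the shapes are arbitrary. In practice, use
# MSDBlock2d (defined below) instead.
#
#     x = torch.randn(1, 2, 32, 32, device="cuda")
#     w = torch.randn(1, 2, 3, 3, device="cuda", requires_grad=True)
#     b = torch.zeros(1, device="cuda", requires_grad=True)
#     out = msdblock2d(x, [1], b, w)   # 2 input channels + 1 new channel = 3
#     out.sum().backward()             # populates w.grad and b.grad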


class MSDBlock2d(torch.nn.Module):
    def __init__(self, in_channels, dilations, width=1):
        """Multi-scale dense block

        Parameters
        ----------
        in_channels : int
            Number of input channels
        dilations : tuple of int
            Dilation for each convolution-block
        width : int
            Number of channels per convolution.

        Notes
        -----
        The number of output channels is in_channels + depth * width
        """
        super().__init__()

        self.kernel_size = (3, 3)
        self.width = width
        self.dilations = dilations
        depth = len(self.dilations)

        self.bias = torch.nn.Parameter(torch.Tensor(depth * width))
        self.weights = []
        for i in range(depth):
            n_in = in_channels + width * i
            weight = torch.nn.Parameter(torch.Tensor(width, n_in, *self.kernel_size))
            self.register_parameter("weight{}".format(i), weight)
            self.weights.append(weight)

        self.reset_parameters()

    def reset_parameters(self):
        for weight in self.weights:
            torch.nn.init.kaiming_uniform_(weight, a=np.sqrt(5))
        if self.bias is not None:
            # TODO: improve
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights[0])
            bound = 1 / np.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # This is a bit of a hack: we require, but cannot assume, that
        # self.parameters() remains sorted in the order in which we added
        # the parameters.
        #
        # However, we need to obtain the weights in this way, because
        # self.weights may become obsolete in multi-gpu settings where the
        # weights are automatically transferred (by, e.g.,
        # torch.nn.DataParallel). In that case, self.weights may continue
        # to point to the weight parameters on the original device, even
        # when the weight parameters have been transferred to a different
        # gpu.
        bias, *weights = self.parameters()
        return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)
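
# Usage sketch for MSDBlock2d (illustrative values; assumes a CUDA device and
# the compiled conv_relu_cuda extension):
#
#     block = MSDBlock2d(in_channels=1, dilations=(1, 2, 4, 8), width=1).cuda()
#     x = torch.randn(2, 1, 64, 64, device="cuda")
#     y = block(x)
#     # output channels = in_channels + depth * width = 1 + 4 * 1 = 5;
#     # the spatial dimensions are preserved.
#     assert y.shape == (2, 5, 64, 64)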


class MSDModule2d(torch.nn.Module):
    def __init__(
        self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    ):
        """Create a 2-dimensional MSD Module

        :param c_in: # of input channels
        :param c_out: # of output channels
        :param depth: # of layers
        :param width: width of the module (# of channels per layer)
        :param dilations: `list(int)`

            A list of dilations to use. Default is ``[1, 2, ..., 10]``. A
            good alternative is ``[1, 2, 4, 8]``. The dilations are
            repeated when depth exceeds the length of this list.
        :returns: an MSD module
        :rtype: MSDModule2d
        """
        super(MSDModule2d, self).__init__()

        self.c_in = c_in
        self.c_out = c_out
        self.depth = depth
        self.width = width
        self.dilations = [dilations[i % len(dilations)] for i in range(depth)]

        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weights for hidden layers:
        for w in self.msd_block.weights:
            init_convolution_weights(
                w.data, self.c_in, self.c_out, self.width, self.depth
            )
        self.msd_block.bias.data.zero_()
        self.final_layer.reset_parameters()

    def forward(self, input):
        output = self.msd_block(input)
        output = self.final_layer(output)
        return output
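
# Usage sketch for MSDModule2d (illustrative values; assumes a CUDA device and
# the compiled conv_relu_cuda extension):
#
#     net = MSDModule2d(c_in=1, c_out=1, depth=10, width=1).cuda()
#     x = torch.randn(1, 1, 128, 128, device="cuda")
#     y = net(x)   # MSD block followed by the final layer mapping to c_out channels
#     assert y.shape == (1, 1, 128, 128)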