[nnUNetv2 Advanced] Part 4: Modifying the nnUNetv2 Network, a First Try: Adding a ChannelAttention Module

nnUNet is a self-configuring deep learning framework designed specifically for medical image segmentation. Its main characteristics are:

Self-configuring framework: nnUNet automatically adapts the model architecture and training parameters to the segmentation task at hand, which removes most of the tedious manual hyperparameter tuning.
Automated pipeline: nnUNet automates the whole workflow from data preprocessing through model training, validation and testing, greatly simplifying medical image segmentation with deep learning.
Adaptive architecture configuration: based on the properties of the input dataset, nnUNet selects a suitable network depth, width and other hyperparameters, balancing model complexity against performance.
Patch-based training and inference: nnUNet trains on patches and, at inference time, slides a window across the whole image and assembles the patch predictions into the full segmentation (see the sketch after this list). This works well for large images or limited GPU memory.
Ensembling and cross-validation: nnUNet uses cross-validation to make the most of limited data and combines the resulting models in an ensemble to improve stability and accuracy.
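To make the patch-based idea concrete, below is a minimal sketch of 2D sliding-window inference. It only averages overlapping patch predictions; nnUNet's actual implementation additionally uses Gaussian weighting, test-time mirroring and border padding, so treat this as an illustration, not the real code.

import torch
from torch import nn


@torch.no_grad()
def sliding_window_predict(model: nn.Module, image: torch.Tensor,
                           patch_size=(64, 64), step=32) -> torch.Tensor:
    # image: (1, C, H, W). Predict overlapping patches and average them.
    _, _, h, w = image.shape
    logits, counts = None, torch.zeros(1, 1, h, w)
    for y in range(0, h - patch_size[0] + 1, step):
        for x in range(0, w - patch_size[1] + 1, step):
            out = model(image[:, :, y:y + patch_size[0], x:x + patch_size[1]])
            if logits is None:  # allocate once the number of classes is known
                logits = torch.zeros(1, out.shape[1], h, w)
            logits[:, :, y:y + patch_size[0], x:x + patch_size[1]] += out
            counts[:, :, y:y + patch_size[0], x:x + patch_size[1]] += 1
    return logits / counts.clamp_min(1)  # average where patches overlap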
Beyond that, nnUNet ships with extensive documentation and examples. To use it, install Python and the required deep learning framework, then follow the steps in the official documentation.

In short, nnUNet is a powerful, easy-to-use framework that is particularly well suited to medical image segmentation. Its self-configuration, automated pipeline and training strategies let users build and train models efficiently while getting strong performance.

Earlier posts covered installing nnUNet, using it, and defining custom networks. This post shows how to add ChannelAttention to nnUNet; before reading on, make sure you have worked through:

[nnUNetv2 in Practice] Part 1: Installing nnUNetv2

[nnUNetv2 in Practice] Part 2: nnUNetv2 Quick Start: Training, Validation, Inference and Ensembling in One Pass

[nnUNetv2 Advanced] Part 3: Custom Networks in nnUNetv2, a Must-Know for Publishing Papers

ChannelAttention is a very simple attention mechanism, which makes it a good first exercise in modifying the network; more advanced modification tutorials will follow.

I. ChannelAttention

ChannelAttention is the standard channel attention mechanism. Its 2D code is very short; the underlying principle is not covered here and is easy to look up.

import torch
from torch import nn


class ChannelAttention(nn.Module):
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)                          # global average pool -> (B, C, 1, 1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)  # 1x1 conv mixes channel statistics
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # scale each channel of x by its sigmoid-gated attention score
        return x * self.act(self.fc(self.pool(x)))
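Continuing from the block above, a quick sanity check (a sketch): channel attention only re-weights channels, so the output shape equals the input shape.

x = torch.randn(2, 32, 64, 64)   # (batch, channels, H, W)
ca = ChannelAttention(channels=32)
print(ca(x).shape)               # torch.Size([2, 32, 64, 64])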

II. Adding ChannelAttention to nnUNet

As mentioned in the previous tutorial, nnUNet networks are defined in dynamic-network-architectures; to train with a custom network you modify the code there and point the dataset's plans file at the new class.

1. Network modification

Create caunet.py in the architectures directory of dynamic-network-architectures. The code of caunet is as follows:


from typing import Union, Type, List, Tuple

import torch
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
import numpy as np
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp


class CAPlainConvUNet(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one fewer entry " \
                                                                f"than there are resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = CAPlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                        n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                        dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
                                        nonlin_first=nonlin_first)
        self.decoder = CAUNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                   nonlin_first=nonlin_first)
        print('using ca unet...')

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)

    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)


class CAPlainConvEncoder(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):

        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                             "Important: first entry is recommended to be 1, else we run strided conv directly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(CAStackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output
    

class CAUNetDecoder(nn.Module):
    def __init__(self,
                 encoder: CAPlainConvEncoder,
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision,
                 nonlin_first: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 conv_bias: bool = None
                 ):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
        :param encoder:
        :param num_classes:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                          "resolution stages - 1 (n_stages in encoder - 1), " \
                                                          "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
        norm_op = encoder.norm_op if norm_op is None else norm_op
        norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
        dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
        dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
        nonlin = encoder.nonlin if nonlin is None else nonlin
        nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs


        # we start with the bottleneck and work our way up
        stages = []
        transpconvs = []
        seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(CAStackedConvBlocks(
                n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1,
                conv_bias,
                norm_op,
                norm_op_kwargs,
                dropout_op,
                dropout_op_kwargs,
                nonlin,
                nonlin_kwargs,
                nonlin_first
            ))

            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
                seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]

        if not self.deep_supervision:
            r = seg_outputs[0]
        else:
            r = seg_outputs
        return r

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at
        # least have the size of the skip above that (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]
        # print(skip_sizes)

        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)
            # segmentation
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)
        return output


class CAStackedConvBlocks(nn.Module):
    def __init__(self,
                 num_convs: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        """

        :param conv_op:
        :param num_convs:
        :param input_channels:
        :param output_channels: can be int or a list/tuple of int. If list/tuple are provided, each entry is for
        one conv. The length of the list/tuple must then naturally be num_convs
        :param kernel_size:
        :param initial_stride:
        :param conv_bias:
        :param norm_op:
        :param norm_op_kwargs:
        :param dropout_op:
        :param dropout_op_kwargs:
        :param nonlin:
        :param nonlin_kwargs:
        """
        super().__init__()
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * num_convs
        # the last conv of the stack is replaced by a CA block below, so at least two convs are needed
        assert num_convs >= 2, "CAStackedConvBlocks requires num_convs >= 2 because the final conv is replaced by a CA block"

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(
                conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ),
            *[
                ConvDropoutNormReLU(
                    conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                )
                for i in range(1, num_convs-1)
            ],
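            # the final conv of the stack is swapped for a CA block (conv -> dropout -> norm -> channel attention)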

            CA(
                conv_op, output_channels[-2], output_channels[-1], kernel_size, 1, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            )

        )

        
        # the CA block ends with norm and no nonlin, so the nonlinearity is applied after the stack in forward()
        self.act = nonlin(**nonlin_kwargs)

        self.output_channels = output_channels[-1]
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)

    def forward(self, x):
        out = self.convs(x)
        out = self.act(out)
        return out

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        output = self.convs[0].compute_conv_feature_map_size(input_size)
        size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
        for b in self.convs[1:]:
            output += b.compute_conv_feature_map_size(size_after_stride)
        return output



class ConvDropoutNormReLU(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)

        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)



class ConvDropoutNorm(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNorm, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)
    


class CA(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(CA, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride

        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)
        self.ca = ChannelAttention(conv_op=conv_op, channels=output_channels)

    def forward(self, x):
        x = self.all_modules(x)
        # ChannelAttention already returns the channel-reweighted tensor,
        # so no extra multiplication by x is needed here
        x = self.ca(x)
        return x

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)
    


class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, conv_op, channels: int) -> None:
        """Builds pooling and fc layers that match the dimensionality of conv_op (2D or 3D)."""
        super().__init__()
        if conv_op == nn.Conv2d:
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        elif conv_op == nn.Conv3d:
            self.pool = nn.AdaptiveAvgPool3d(1)
            self.fc = nn.Conv3d(channels, channels, 1, 1, 0, bias=True)
        else:
            raise ValueError(f'ChannelAttention only supports Conv2d and Conv3d, got {conv_op}')
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Re-weights each channel of x with a sigmoid gate computed from globally pooled features."""
        return x * self.act(self.fc(self.pool(x)))





The idea of the modification in one sentence: insert a ChannelAttention module into PlainConvUNet's StackedConvBlocks, here by replacing the final conv of each stack with a CA block that ends in channel attention.
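Before touching the plans file, a quick smoke test is worthwhile. The sketch below builds a small 2D CAPlainConvUNet and runs one forward pass; the constructor values are illustrative, not taken from any plans file. If the wiring is correct it prints 'using ca unet...' followed by the output shape.

import torch
from torch import nn
from dynamic_network_architectures.architectures.caunet import CAPlainConvUNet

# illustrative hyperparameters; nnUNet normally fills these in from the plans file
net = CAPlainConvUNet(
    input_channels=1, n_stages=4, features_per_stage=(32, 64, 128, 256),
    conv_op=nn.Conv2d, kernel_sizes=3, strides=(1, 2, 2, 2),
    n_conv_per_stage=2, num_classes=3, n_conv_per_stage_decoder=2,
    norm_op=nn.InstanceNorm2d, norm_op_kwargs={'eps': 1e-5, 'affine': True},
    nonlin=nn.LeakyReLU, nonlin_kwargs={'inplace': True},
    deep_supervision=False,
)
with torch.no_grad():
    out = net(torch.randn(1, 1, 64, 64))
print(out.shape)  # expected: torch.Size([1, 3, 64, 64])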

2. Plans file modification

With the model code in place, we again use the Task04_Hippocampus dataset from the previous tutorial for validation (if you skipped that tutorial, run its data preparation first). Edit the configuration file nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json and change network_class_name to dynamic_network_architectures.architectures.caunet.CAPlainConvUNet, as in the sketch below:
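A sketch of the relevant fragment of nnUNetPlans.json. The exact nesting varies with the nnUNet version (in recent v2 releases network_class_name sits under each configuration's architecture block); only network_class_name changes, once per configuration you plan to train, and every other generated key stays as-is.

{
  "configurations": {
    "2d": {
      "architecture": {
        "network_class_name": "dynamic_network_architectures.architectures.caunet.CAPlainConvUNet",
        "arch_kwargs": "... unchanged ..."
      }
    }
  }
}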

III. Training

With the model and the dataset's plans file modified, we can train. The dataset is again Task04_Hippocampus, and the code above supports both 2d and 3d models. The arguments to nnUNetv2_train are dataset ID (4 is Hippocampus), configuration and fold:

nnUNetv2_train 4 2d 0  
nnUNetv2_train 4 2d 1 
nnUNetv2_train 4 2d 2  
nnUNetv2_train 4 2d 3 
nnUNetv2_train 4 2d 4  

nnUNetv2_train 4 3d_fullres 0 
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2 
nnUNetv2_train 4 3d_fullres 3 
nnUNetv2_train 4 3d_fullres 4 

If everything is set up correctly, training starts and the console shows the 'using ca unet...' message printed by our constructor, confirming the custom network is in use.

nnUNet training takes a long time and compute for these experiments was limited, so full training was not run; the goal of this post is the code modification itself and verifying that it trains.
