注意力机制 YOLOv8添加注意力机制_yolov8引入注意力机制-程序员宅基地

技术标签: YOLO  python  yolov8  人工智能  开发语言  

一、注意力机制介绍:

注意力机制(Attention Mechanism)是深度学习中一种重要的技术,它可以帮助模型更好地关注输入数据中的关键信息,从而提高模型的性能。注意力机制最早在自然语言处理领域的序列到序列(seq2seq)模型中得到广泛应用,后来逐渐扩展到了计算机视觉、语音识别等多个领域。 

注意力机制的基本思想是为输入数据的每个部分分配一个权重,这个权重表示该部分对于当前任务的重要程度。在自然语言处理任务中,这通常意味着对输入句子中的每个单词分配一个权重,而在计算机视觉任务中,这可能意味着为输入图像的每个像素或区域分配一个权重。

添加方法

总结:1.在conv.py加入注意力代码

           2.在__init.oy__和tasks.py引用GAM

           3.修改yaml文件

1.在conv.py代码中加入注意力代码

conv.py的路径:ultralytics-main\ultralytics\nn\modules\conv.py 

如图下所示:

在conv.py的最下面添加注意力代码:

代码如下:

#-----------注意力机制代码-----------------
import torch.nn as nn
import torch
 
class GAM_Attention(nn.Module):
    def __init__(self, in_channels,c2, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)
        )
 
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(in_channels)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    x = torch.randn(1, 64, 20, 20)
    b, c, h, w = x.shape
    net = GAM_Attention(in_channels=c)
    y = net(x)
    print(y.size())

效果如图下所示:

 

 2.注册及引用GAM_Attention

__init__.py文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\modules\__init__.py

如图下:

在__init__.py文件中,在导包里面找到from .conv import和__all__,最后面添加GAM_Attention。

如图下所示:

tasks.py 文件中引用GAM_Attention

路径:ultralytics-main\ultralytics\nn\tasks.py

如图下:

在tasks.py文件中,在导包里面找到from ultralytics.nn.modules最后面添加GAM_Attention

如图下所示:

 在tasks.py里写入调用方式

打开tasks.py,Ctrl键+F查找n = 1(有空格)就可以找到添加的位置,如效果图:

        # """**************add Attention***************"""
        elif m in {GAM_Attention}:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if not output
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]

效果如图下所示:

 3.修改自己的yolov8.yaml文件:

路径如下:ultralytics-main\ultralytics\cfg\models\v8\my_yolov8.yaml

如图下所示:

 修改后的代码如下(可以直接复制到自己的yaml里面):

# Ultralytics YOLO , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8-SPPCSPC.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAM_Attention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

自己修改可以根据下图(修改后的图)红色箭头是需要修改的:

 

 完成以上就可以训练了

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/2301_79152843/article/details/132555870

智能推荐

关于idea将模块修改为Resources报错:Two modules in a project cannot share the same content root.-程序员宅基地

文章浏览阅读1.8k次。点击: file-->Project Structure-->选择:Models----> 右击文件点击Delete最后修改即可_two modules in a project cannot share the same content root

异步FIFO设计-程序员宅基地

文章浏览阅读590次,点赞26次,收藏26次。同步后的写指针与读指针进行比较,如果它们相等或满足其他预定的条件,就表明FIFO为空,从而产生空逻辑信号。产生空状态信号时,实际FIFO中有数据,相当于提前判断了空状态信号,此时不再进行读FIFO数据操作也是安全的。此时经常使用多余的1bit分别当做读写地址的拓展位,来区分读写地址相同的时候,FIFO的状态是空还是满状态。(格雷码是一种二进制编码方式,其相邻的两个数值只有一个位的差异,这种特性使得格雷码在变化时只涉及到一个位的翻转,从而减少了由于多位同时变化可能带来的不稳定性和错误。

Cannot load configuration class_cannot load configuration class: com.jsh.erp.erpap-程序员宅基地

文章浏览阅读3.0k次。将SDK从16设置为1.8,如下图_cannot load configuration class: com.jsh.erp.erpapplication

java csv 复杂表头_POIUtil(动态复杂表头、动态数据、多线程、合并数据列的POI导出成Excel工具附带生成csv文件)...-程序员宅基地

文章浏览阅读426次。package com.sckj.base.util;import java.io.IOException;import java.lang.reflect.Constructor;import java.util.ArrayList;import java.util.HashMap;import java.util.List;import java.util.Map;import java.ut..._csv能复杂表头吗

【Python】Python中input的使用_python中input的用法-程序员宅基地

文章浏览阅读1.8w次。input有类似c中的scanf函数。Python2中input使用如下:>>>x = input("x: ")x: 3>>>y = input("y" )y: 4>>> print x*y12但是Python3中input使用会有如下的提示:>>> x=input("x:")x:3>>> y=input("y:")y:5>>> print (x*y)Traceback (mos_python中input的用法

sql server异地备份_sql2008r2 异地备份-程序员宅基地

文章浏览阅读562次。服务器名为:jiliangserver 备份的数据库为:JLSDB declare @strsql varchar(1000) declare @strdirname varchar(50) declare @strcmd varchar(50) dec_sql2008r2 异地备份

随便推点

第9讲:使用ajax技术实现增删改查及分页显示功能(jQuery)_ajax实现修改功能-程序员宅基地

文章浏览阅读1.2k次。本讲内容首先讲解jQuery对ajax的支持,分别讲解$.post,$.get,$.ajax等方法,这些方法的参数,使用方法及区别。最后对ajax的综合应用举例:在同一个页面实现新增,修改,删除学校资料,分页列表等功能,前端使用html静态页面,使用MySQL数据库,后台使用servlet技术实现。_ajax实现修改功能

找回word文件的两种密码_word文档保护默认密码是多少-程序员宅基地

文章浏览阅读773次。Word文档的密码也有两种:一种是打开密码,一种是编辑限制两种密码加密后的效果也是不一样的:设置了打开密码的Word文档,是在打开文件的时候需要输入密码,保护文件内容不被其他人看到。当我们输入了正确的word密码,进入到文件之后,就一些都正常了,可以正常阅读、正常编辑word文件。设置了编辑限制的Word文档,打开文件的时候不需要输入密码,打开之后能够查看Word文档内容,但是想要编辑WORD文件的时候,保护文件内容不被修改,需要输入正确的Word密码,将限制编辑取消才能够正常编辑Word文档。两种密码如果_word文档保护默认密码是多少

Cocos2d场景编辑器CocosBuilder使用教程-程序员宅基地

文章浏览阅读162次。在使用Cocos2d-iPhone框架开发iOS游戏的时候,对于每一个场景(CCScene)的编辑是比较麻烦的,好在有外国的牛人提供了非常棒的场景编辑器----CocosBuilder。下面我将详细介绍CocosBuilder结合Cocos2d-iPhone框架的使用。 框架:Cocos2d-iPhone v2.0-rc2 工具:CocosBuilde..._coco2d场景编辑

el-input输入保留两位小数_el-input保留两位小数-程序员宅基地

文章浏览阅读1k次。el-input输入保留两位小数_el-input保留两位小数

MyBatis多条件查询_mybatis if test 多条件-程序员宅基地

文章浏览阅读1.8k次。在MyBatis中进行多条件查询可以使用动态SQL来构建查询语句。_mybatis if test 多条件

Chrome 您的连接不是私密连接 解决办法_chrome您的连接不是私密连接-程序员宅基地

文章浏览阅读1.9w次,点赞8次,收藏25次。您的连接不是私密连接今天打开b站出现不是私密连接被拦截的情况,试了网上好多种方法都没有效果,最后刷新DNS给解决了,特此记录一下先附上错误截图攻击者可能会试图从 x.x.x.x 窃取您的信息(例如:密码、通讯内容或信用卡信息)。了解详情NET::ERR_CERT_INVALID将您访问的部分网页的网址、有限的系统信息以及部分网页内容发送给 Google,以帮助我们提升 Chrome 的安全性。隐私权政策x.x.x.x 通常会使用加密技术来保护您的信息。Google Chrome 此次尝试连接到_chrome您的连接不是私密连接

推荐文章

热门文章

相关标签