行人重识别02-05:fast-reid(BoT)-pytorch编程规范(fast-reid为例)2-DefaultTrainer解析-程序员宅基地

技术标签: # 目标追踪  Bot  ReID  pytorch  # 行人重识别  fast-reid  行人重识别  

以下链接是个人关于fast-reid(BoT行人重识别) 所有见解,如有错误欢迎大家指出,我会第一时间纠正。有兴趣的朋友可以加微信:17575010159 相互讨论技术。若是帮助到了你什么,一定要记得点赞!因为这是对我最大的鼓励。 文末附带 \color{blue}{文末附带} 文末附带 公众号 − \color{blue}{公众号 -} 公众号 海量资源。 \color{blue}{ 海量资源}。 海量资源

行人重识别02-00:fast-reid(BoT)-目录-史上最新无死角讲解

极度推荐的商业级项目: \color{red}{极度推荐的商业级项目:} 极度推荐的商业级项目:这是本人落地的行为分析项目,主要包含(1.行人检测,2.行人追踪,3.行为识别三大模块):行为分析(商用级别)00-目录-史上最新无死角讲解

前言

通过上一篇博客。我们已经知道继承于 HookBase 的类,都存在以下几个方法:

    def before_train(self):    # 在第一次迭代之前调用
    def after_train(self):      # 在最后一次迭代之后调用
    def before_step(self):   # 在每次迭代之前调用
    def after_step(self):      # 在每次迭代之后调用

并且已经知道他是在什么时候被带哦用,同时知道了训练的大致过程。但是hooks是如何创建的,我们需要那些hooks,我们不是很清楚,接下来我们会为大家进行讲解

DefaultTrainer

在 fastreid\engine\defaults.py 文件中,我们可以看到如下源码:

class DefaultTrainer(SimpleTrainer):
    """
    具有默认训练逻辑的培训师,继承于SimpleTrainer.主要包含了以下逻辑
    A trainer with default training logic. Compared to `SimpleTrainer`, it
    contains the following logic in addition:

    # 根据配置文件创建optimizer, scheduler, dataloader
    1. Create model, optimizer, scheduler, dataloader from the given config.

    # 如果指定了模型权重文件,则加载模型权重
    2. Load a checkpoint or `cfg.MODEL.WEIGHTS`, if exists.

    # 注册一些通用的hooks
    3. Register a few common hooks.

    #这是一个标准的简单训练模型流程,可以减少只需要标准培训工作流程的用户的代码样板,
    这意味着这门课对你的训练逻辑做了很多假设,这些假设在新的研究中很容易变得无效
    事实上,任何超出班级:‘SimpleTrainer’太多了,不适合研究。这个类的代码已经注释了它所产生的限制性假设
    It is created to simplify the **standard model training workflow** and reduce code boilerplate
    for users who only need the standard training workflow, with standard features.
    It means this class makes *many assumptions* about your training logic that
    may easily become invalid in a new research. In fact, any assumptions beyond those made in the
    :class:`SimpleTrainer` are too much for research.
    The code of this class has been annotated about restrictive assumptions it mades.
    When they do not work for you, you're encouraged to:
    # 覆盖类方法
    1. Overwrite methods of this class, OR:
    # 用法:class:`SimpleTrainer`,它只进行最小的SGD培训,而不进行其他任何操作。如果需要,可以添加自己的钩子。或者:
    2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
       nothing else. You can then add your own hooks if needed. OR:
    #编写类似`tools/plain_train_net.py`的训练循环`.
    3. Write your own training loop similar to `tools/plain_train_net.py`.

    # 还要注意这个类的属性,就像这个文件中的其他函数/类一样,他是不稳定的。因为它是用来表示
    “常见的默认行为”。它只能保证与fastreid中的标准模型和培训工作流一起工作。为了获得更稳定的行为,
    可以使用其他公共api编写自己的训练逻辑。
    Also note that the behavior of this class, like other functions/classes in
    this file, is not stable, since it is meant to represent the "common default behavior".
    It is only guaranteed to work well with the standard models and training workflow in fastreid.
    To obtain more stable behavior, write your own training logic with other public APIs.
    Attributes:
        scheduler: # 学习策略
        checkpointer (DetectionCheckpointer): # 模型参数检测加载
        cfg (CfgNode):cfg配置文件
    Examples:
    .. code-block:: python
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load()  # load last checkpoint or MODEL.WEIGHTS
        trainer.train()
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        # 创建记录答应日志的类对象
        logger = logging.getLogger("fastreid")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
            setup_logger()

        # Assume these objects must be constructed in this order.
        # 创建训练数据及迭代器
        data_loader = self.build_train_loader(cfg)
        # 自动计算一些配置参数,如共迭代多少次
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # 根据配置参数构建模型
        model = self.build_model(cfg)
        # 根据配置构建优化器
        optimizer = self.build_optimizer(cfg, model)

        # For training, wrap with DDP. But don't need this for inference.
        # 对于培训,用DDP包装。但不需要这个来推断。
        if comm.get_world_size() > 1:
            # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
            # for part of the parameters is not updated.
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        super().__init__(model, data_loader, optimizer)

        # 设置学习率衰减策略
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        # Assume no other objects need to be checkpointed.
        # We can later make it checkpoint the stateful hooks
        # 加载指定的模型参数
        self.checkpointer = Checkpointer(
            # Assume you want to save checkpoints together with logs/statistics
            model,
            cfg.OUTPUT_DIR,
            save_to_disk=comm.is_main_process(),
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        # 初始化迭代次数
        self.start_iter = 0
        if cfg.SOLVER.SWA.ENABLED:
            self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
        else:
            self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        # 创建hooks,并且注册hooks
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """
        如果resume==True表示接着之前的迭代次数训练,否则从0开始训练
        If `resume==True`, and last checkpoint exists, resume from it.
        Otherwise, load a model specified by the config.
        Args:
            resume (bool): whether to do resume or not
        """
		......
    def build_hooks(self):
        """
        构建一个默认的hooks列表,包含了timing,checkpointing, lr scheduling, precise BN, writing events
        可以理解为把这些类,或者函数放入到一个容器中,需要他的时候再把他取出来进行调用
        Build a list of default hooks, including timing, evaluation,
        checkpointing, lr scheduling, precise BN, writing events.
        Returns:
            list[HookBase]:
        """
		......

        return ret

    def build_writers(self):
        """
        主要用于写入log日志等等
        Build a list of writers to be used. By default it contains
        writers that write metrics to the screen,
        a json file, and a tensorboard event file respectively.
        If you'd like a different list of writers, you can overwrite it in
        your trainer.
        Returns:
            list[EventWriter]: a list of :class:`EventWriter` objects.
        It is now implemented by:
        .. code-block:: python
            return [
                CommonMetricPrinter(self.max_iter),
                JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
                TensorboardXWriter(self.cfg.OUTPUT_DIR),
            ]
        """
        # Assume the default print/log frequency.
        return [
            # It may not always print what you want to see, since it prints "common" metrics only.
            CommonMetricPrinter(self.max_iter),
            JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")),
            TensorboardXWriter(self.cfg.OUTPUT_DIR),
        ]

    def train(self):
        """
        Run training.
        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        # 调用父类的训练函数
        super().train(self.start_iter, self.max_iter)
        # 等待训练完成之后进行,返回最后一次的评估结果
        if comm.is_main_process():
            assert hasattr(
                self, "_last_eval_results"
            ), "No evaluation results obtained during training!"
            # verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    @classmethod
    def build_model(cls, cfg):
        """
        根据配置信息cfg构建模型
        Returns:
            torch.nn.Module:
        It now calls :func:`fastreid.modeling.build_model`.
        Overwrite it if you'd like a different model.
        """
        model = build_model(cfg)
        # logger = logging.getLogger(__name__)
        # logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """
        根据配置参数构建优化器
        Returns:
            torch.optim.Optimizer:
        It now calls :func:`fastreid.solver.build_optimizer`.
        Overwrite it if you'd like a different optimizer.
        """
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        根据配置参数指定学习率衰减策略
        It now calls :func:`fastreid.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_train_loader(cls, cfg):
        """
        构建一个训练数据迭代器
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_train_loader`.
        Overwrite it if you'd like a different data loader.
        """
        logger = logging.getLogger(__name__)
        logger.info("Prepare training set")
        return build_reid_train_loader(cfg)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        构建一个测试数据迭代器
        Returns:
            iterable
        It now calls :func:`fastreid.data.build_detection_test_loader`.
        Overwrite it if you'd like a different data loader.
        """
        return build_reid_test_loader(cfg, dataset_name)

    @classmethod
    def build_evaluator(cls, cfg, num_query, output_dir=None):
        """
        构建评估器
        """
        return ReidEvaluator(cfg, num_query, output_dir)

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        """
        对模型进行评估
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator`. Otherwise, must have the same length as
                `cfg.DATASETS.TEST`.
        Returns:
            dict: a dict of result metrics
        """
        # 用于log日志的保存
        logger = logging.getLogger(__name__)
        # 检测是evaluators是否为正确的评估器
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]

        # 如果evaluators不为none,则对evaluators的长度进行检测
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
        # 创建一个字典,用于结果保存
        results = OrderedDict()
        
        # 对多个数据集进行评估 
        for idx, dataset_name in enumerate(cfg.DATASETS.TESTS):
            # 进行log打印,并且创建评估数据迭代器
            logger.info("Prepare testing set")
            data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
            
            # When evaluators are passed in as arguments,
            # implicitly assume that evaluators can be created before data_loader.
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, num_query)
                except NotImplementedError:
                    logger.warn(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {
    }
                    continue
            # 对单个评估数据集进行推断,并且获得推断结果
            results_i = inference_on_dataset(model, data_loader, evaluator)
            # 保存数据集对应的推断结果
            results[dataset_name] = results_i

        # 如果为主进程,则返回一个评估之后的字典
        if comm.is_main_process():
            assert isinstance(
                results, dict
            ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                results
            )
            # 使用csv的格式打印评估结果
            print_csv_format(results)

        if len(results) == 1: results = list(results.values())[0]

        return results

    @staticmethod
    def auto_scale_hyperparams(cfg, data_loader):
        r"""
        根据传入的cfg,推算出一些cfg配置参数,如总迭代次数等等
        This is used for auto-computation actual training iterations,
        because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
        so we need to convert specific hyper-param to training iterations.
        """

        cfg = cfg.clone()
		......
        return cfg

代码领读

从上面代码的注释中,我们可以看到如下:

        # Assume these objects must be constructed in this order.
        # 创建训练数据及迭代器
        data_loader = self.build_train_loader(cfg)
        # 自动计算一些配置参数,如共迭代多少次
        cfg = self.auto_scale_hyperparams(cfg, data_loader)
        # 根据配置参数构建模型
        model = self.build_model(cfg)
        # 根据配置构建优化器
        optimizer = self.build_optimizer(cfg, model)

这里就是创建数据迭代器,优化器,以及模型的过程。对于 def build_hooks(self) 函数,其会去构建所有的 hooks,如 timing,checkpointing, lr scheduling, precise BN, writing events 等等。这些都是hooks,继承于HookBase。

在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43013761/article/details/108051799

智能推荐

攻防世界_难度8_happy_puzzle_攻防世界困难模式攻略图文-程序员宅基地

文章浏览阅读645次。这个肯定是末尾的IDAT了,因为IDAT必须要满了才会开始一下个IDAT,这个明显就是末尾的IDAT了。,对应下面的create_head()代码。,对应下面的create_tail()代码。不要考虑爆破,我已经试了一下,太多情况了。题目来源:UNCTF。_攻防世界困难模式攻略图文

达梦数据库的导出(备份)、导入_达梦数据库导入导出-程序员宅基地

文章浏览阅读2.9k次,点赞3次,收藏10次。偶尔会用到,记录、分享。1. 数据库导出1.1 切换到dmdba用户su - dmdba1.2 进入达梦数据库安装路径的bin目录,执行导库操作  导出语句:./dexp cwy_init/[email protected]:5236 file=cwy_init.dmp log=cwy_init_exp.log 注释:   cwy_init/init_123..._达梦数据库导入导出

js引入kindeditor富文本编辑器的使用_kindeditor.js-程序员宅基地

文章浏览阅读1.9k次。1. 在官网上下载KindEditor文件,可以删掉不需要要到的jsp,asp,asp.net和php文件夹。接着把文件夹放到项目文件目录下。2. 修改html文件,在页面引入js文件:<script type="text/javascript" src="./kindeditor/kindeditor-all.js"></script><script type="text/javascript" src="./kindeditor/lang/zh-CN.js"_kindeditor.js

STM32学习过程记录11——基于STM32G431CBU6硬件SPI+DMA的高效WS2812B控制方法-程序员宅基地

文章浏览阅读2.3k次,点赞6次,收藏14次。SPI的详情简介不必赘述。假设我们通过SPI发送0xAA,我们的数据线就会变为10101010,通过修改不同的内容,即可修改SPI中0和1的持续时间。比如0xF0即为前半周期为高电平,后半周期为低电平的状态。在SPI的通信模式中,CPHA配置会影响该实验,下图展示了不同采样位置的SPI时序图[1]。CPOL = 0,CPHA = 1:CLK空闲状态 = 低电平,数据在下降沿采样,并在上升沿移出CPOL = 0,CPHA = 0:CLK空闲状态 = 低电平,数据在上升沿采样,并在下降沿移出。_stm32g431cbu6

计算机网络-数据链路层_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输-程序员宅基地

文章浏览阅读1.2k次,点赞2次,收藏8次。数据链路层习题自测问题1.数据链路(即逻辑链路)与链路(即物理链路)有何区别?“电路接通了”与”数据链路接通了”的区别何在?2.数据链路层中的链路控制包括哪些功能?试讨论数据链路层做成可靠的链路层有哪些优点和缺点。3.网络适配器的作用是什么?网络适配器工作在哪一层?4.数据链路层的三个基本问题(帧定界、透明传输和差错检测)为什么都必须加以解决?5.如果在数据链路层不进行帧定界,会发生什么问题?6.PPP协议的主要特点是什么?为什么PPP不使用帧的编号?PPP适用于什么情况?为什么PPP协议不_接收方收到链路层数据后,使用crc检验后,余数为0,说明链路层的传输时可靠传输

软件测试工程师移民加拿大_无证移民,未受过软件工程师的教育(第1部分)-程序员宅基地

文章浏览阅读587次。软件测试工程师移民加拿大 无证移民,未受过软件工程师的教育(第1部分) (Undocumented Immigrant With No Education to Software Engineer(Part 1))Before I start, I want you to please bear with me on the way I write, I have very little gen...

随便推点

Thinkpad X250 secure boot failed 启动失败问题解决_安装完系统提示secureboot failure-程序员宅基地

文章浏览阅读304次。Thinkpad X250笔记本电脑,装的是FreeBSD,进入BIOS修改虚拟化配置(其后可能是误设置了安全开机),保存退出后系统无法启动,显示:secure boot failed ,把自己惊出一身冷汗,因为这台笔记本刚好还没开始做备份.....根据错误提示,到bios里面去找相关配置,在Security里面找到了Secure Boot选项,发现果然被设置为Enabled,将其修改为Disabled ,再开机,终于正常启动了。_安装完系统提示secureboot failure

C++如何做字符串分割(5种方法)_c++ 字符串分割-程序员宅基地

文章浏览阅读10w+次,点赞93次,收藏352次。1、用strtok函数进行字符串分割原型: char *strtok(char *str, const char *delim);功能:分解字符串为一组字符串。参数说明:str为要分解的字符串,delim为分隔符字符串。返回值:从str开头开始的一个个被分割的串。当没有被分割的串时则返回NULL。其它:strtok函数线程不安全,可以使用strtok_r替代。示例://借助strtok实现split#include <string.h>#include <stdio.h&_c++ 字符串分割

2013第四届蓝桥杯 C/C++本科A组 真题答案解析_2013年第四届c a组蓝桥杯省赛真题解答-程序员宅基地

文章浏览阅读2.3k次。1 .高斯日记 大数学家高斯有个好习惯:无论如何都要记日记。他的日记有个与众不同的地方,他从不注明年月日,而是用一个整数代替,比如:4210后来人们知道,那个整数就是日期,它表示那一天是高斯出生后的第几天。这或许也是个好习惯,它时时刻刻提醒着主人:日子又过去一天,还有多少时光可以用于浪费呢?高斯出生于:1777年4月30日。在高斯发现的一个重要定理的日记_2013年第四届c a组蓝桥杯省赛真题解答

基于供需算法优化的核极限学习机(KELM)分类算法-程序员宅基地

文章浏览阅读851次,点赞17次,收藏22次。摘要:本文利用供需算法对核极限学习机(KELM)进行优化,并用于分类。

metasploitable2渗透测试_metasploitable2怎么进入-程序员宅基地

文章浏览阅读1.1k次。一、系统弱密码登录1、在kali上执行命令行telnet 192.168.26.1292、Login和password都输入msfadmin3、登录成功,进入系统4、测试如下:二、MySQL弱密码登录:1、在kali上执行mysql –h 192.168.26.129 –u root2、登录成功,进入MySQL系统3、测试效果:三、PostgreSQL弱密码登录1、在Kali上执行psql -h 192.168.26.129 –U post..._metasploitable2怎么进入

Python学习之路:从入门到精通的指南_python人工智能开发从入门到精通pdf-程序员宅基地

文章浏览阅读257次。本文将为初学者提供Python学习的详细指南,从Python的历史、基础语法和数据类型到面向对象编程、模块和库的使用。通过本文,您将能够掌握Python编程的核心概念,为今后的编程学习和实践打下坚实基础。_python人工智能开发从入门到精通pdf

推荐文章

热门文章

相关标签