mxnet multi-task_cpu mxnet multitask-程序员宅基地

技术标签: muiti-task_mx  multi-task-mx  

import argparse
import os, sys

# Make the MXNet build shipped in the Docker image importable
mxnet_root = "/mxnet/"
sys.path.insert(0, mxnet_root + 'python')
import mxnet as mx

import importlib
import find_mxnet
import time

sys.path.insert(0, "./settings")
sys.path.insert(0, "../")

import logging

# Configure the root logger: INFO and above are echoed to the console
# (a StreamHandler), each record prefixed with its timestamp.
# `formatter` stays module-level because the entry point reuses it for the
# train.log file handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

console = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(message)s')
console.setFormatter(formatter)
logger.addHandler(console)


class MultiTask_iterator(mx.io.DataIter):

    '''Split a multi-column label into one label array per task.

    The wrapped iterator's first label is expected to be a
    (batch, num_tasks) matrix (e.g. an ``ImageIter`` built with
    ``label_width=num_tasks``). Column ``i`` is emitted as the label named
    ``softmax_multitask<i+1>_label``, so it feeds the
    ``softmax_multitask<i+1>`` output of the multi-task network.

    Args:
        data_iter: the underlying ``mx.io.DataIter``.
        num_tasks: number of task heads / label columns (default 5).
    '''
    def __init__(self, data_iter, num_tasks=5):
        # The base DataIter constructor takes the batch size as its argument.
        super(MultiTask_iterator, self).__init__(data_iter.batch_size)
        self.data_iter  = data_iter
        self.batch_size = self.data_iter.batch_size
        self.num_tasks  = num_tasks
        self._label_names = ['softmax_multitask%d_label' % (i + 1)
                             for i in range(num_tasks)]

    @property
    def provide_data(self):
        '''Data descriptors are passed through from the wrapped iterator.'''
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        '''One (name, (batch_size,)) descriptor per task label.'''
        # provide_label[0] is (name, shape); shape[0] is the batch dimension.
        batch_size = self.data_iter.provide_label[0][1][0]
        # Shapes are tuples, matching the (name, shape) descriptor convention.
        return [(name, (batch_size,)) for name in self._label_names]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        '''Fetch the next batch and split its label matrix column-wise.'''
        batch = self.data_iter.next()

        # Copy to host once, then slice out one float32 vector per task.
        label_matrix = batch.label[0].asnumpy()
        labels = [mx.nd.array(label_matrix[:, task], dtype='float32')
                  for task in range(self.num_tasks)]

        return mx.io.DataBatch(data  = batch.data,
                               label = labels,
                               pad   = batch.pad,
                               index = batch.index)


# define multi task accuracy
class MultiTask_Accuracy(mx.metric.EvalMetric):
    '''Classification accuracy tracked separately for each task head.

    Args:
        num: number of tasks. When ``None``, a single accuracy pooled over all
            outputs is tracked instead of one value per task.
        output_names: network output names to evaluate, in task order.
    '''

    def __init__(self, num = None, output_names = None):
        # Set before the base constructor, which presumably calls reset()
        # (reset() reads self.num) -- confirm against the MXNet version used.
        self.num = num
        super(MultiTask_Accuracy, self).__init__('multi_accuracy', num)
        # NOTE(review): the base constructor's second positional parameter
        # differs between MXNet releases; assigning output_names explicitly
        # keeps the selected outputs correct either way.
        self.output_names = output_names

    def reset(self):
        ''' Resets the internal evaluation result to initial state.'''
        self.num_inst   = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num

    def update(self, labels, preds):
        '''Accumulate correct-prediction counts for one batch.

        Args:
            labels: one NDArray of class ids per task.
            preds: one NDArray of per-class scores per task (argmax is taken
                over the channel axis).

        Raises:
            ValueError: if ``num`` is set and the number of labels differs.
        '''
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None and len(labels) != self.num:
            raise ValueError('expected %d labels, got %d' % (self.num, len(labels)))

        for i, (label_nd, pred_nd) in enumerate(zip(labels, preds)):
            pred_label = mx.nd.argmax_channel(pred_nd).asnumpy().astype('int32')
            label = label_nd.asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            correct = (pred_label.flat == label.flat).sum()
            count = len(pred_label.flat)
            if self.num is None:
                self.sum_metric += correct
                self.num_inst   += count
            else:
                self.sum_metric[i] += correct
                self.num_inst[i]   += count

    def get(self):
        '''Return ``(names, values)``; per-task lists when ``num`` is set.

        A task with no samples yet reports NaN.
        '''
        if self.num is None:
            return super(MultiTask_Accuracy, self).get()
        # Build concrete lists: a bare zip() is a one-shot iterator in Python 3.
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [float('nan') if self.num_inst[i] == 0
                  else self.sum_metric[i] / self.num_inst[i]
                  for i in range(self.num)]
        return names, values


# for fine-tuning for the MLMT
def get_fine_tune_model(sym, arg_params, num_classes_mt1, num_classes_mt2, num_classes_mt3, num_classes_mt4, num_classes_mt5, layer_name):
    """Attach five independent classification heads to a pretrained network.

    The network is cut at ``layer_name``. For each task k in 1..5, a
    ``FullyConnected`` layer ``fc_multitask<k>`` followed by a ``SoftmaxOutput``
    named ``softmax_multitask<k>`` is stacked on that shared feature.

    Args:
        sym: symbol of the pretrained model.
        arg_params: pretrained parameters. Accepted for interface compatibility;
            not used here.
        num_classes_mt1 .. num_classes_mt5: number of classes of each task head.
        layer_name: internal layer whose ``<layer_name>_output`` all heads share.

    Returns:
        ``mx.symbol.Group`` of the five softmax outputs, in task order.
    """
    shared_feature = sym.get_internals()[layer_name + '_output']

    # (fully-connected name, softmax name, class count) for each task, in order.
    head_specs = (
        ('fc_multitask1', 'softmax_multitask1', num_classes_mt1),
        ('fc_multitask2', 'softmax_multitask2', num_classes_mt2),
        ('fc_multitask3', 'softmax_multitask3', num_classes_mt3),
        ('fc_multitask4', 'softmax_multitask4', num_classes_mt4),
        ('fc_multitask5', 'softmax_multitask5', num_classes_mt5),
    )

    outputs = []
    for fc_name, softmax_name, n_classes in head_specs:
        logits = mx.symbol.FullyConnected(data = shared_feature, num_hidden = n_classes, name = fc_name)
        outputs.append(mx.symbol.SoftmaxOutput(data = logits, name = softmax_name))

    return mx.symbol.Group(outputs)


# learning-rate step-decay schedule setup
def multi_factor_scheduler(begin_epoch, epoch_size, step=(5, 10, 15), factor=0.1):
    """Build a step-decay learning-rate schedule for a (possibly resumed) run.

    Args:
        begin_epoch: epoch that training starts or resumes from.
        epoch_size: number of updates (batches) per epoch.
        step: absolute epochs at which the learning rate is multiplied by
            ``factor``. Steps at or before ``begin_epoch`` are dropped.
            A tuple default is used to avoid a shared mutable default argument.
        factor: multiplicative decay applied at each remaining step.

    Returns:
        ``mx.lr_scheduler.MultiFactorScheduler`` whose steps are counted in
        updates relative to ``begin_epoch``, or ``None`` if no step remains.
    """
    # Convert remaining epoch boundaries into update counts since begin_epoch.
    step_ = [epoch_size * (x - begin_epoch) for x in step if x > begin_epoch]
    if not step_:
        return None
    return mx.lr_scheduler.MultiFactorScheduler(step = step_, factor = factor)


def _make_multitask_iter(lst_path, root, data_shape, batch_size, num_label,
                         kv, label_names, train):
    """Create a ``MultiTask_iterator`` over one ``.lst`` image list.

    The underlying ``mx.image.ImageIter`` is sharded by ``kv.rank`` /
    ``kv.num_workers``. When ``train`` is true the list is shuffled and random
    crop + mirror augmentation is applied; otherwise neither is.
    """
    aug_list = mx.image.CreateAugmenter(data_shape,
                                        resize=max(data_shape[1:]),
                                        rand_crop=train,
                                        rand_mirror=train,
                                        mean=True,
                                        std=True)
    image_iter = mx.image.ImageIter(
        batch_size   = batch_size,
        data_shape   = data_shape,
        label_width  = num_label,
        path_imglist = lst_path,
        path_root    = root,
        part_index   = kv.rank,
        num_parts    = kv.num_workers,
        shuffle      = train,
        data_name    = 'data',
        label_name   = list(label_names),
        aug_list     = aug_list
    )
    return MultiTask_iterator(image_iter)


def train_model(model, gpus, batch_size, image_shape, num_label, epoch = 0, num_epoch = 20, kv = 'device'):
    """Fine-tune a pretrained checkpoint into a five-head multi-task classifier.

    Dataset paths, per-task class counts, optimizer hyper-parameters and the
    save prefix are read from the module-level ``args`` namespace created by
    the command-line entry point. This function only works once ``args`` exists.

    Args:
        model: checkpoint prefix passed to ``mx.model.load_checkpoint``.
        gpus: comma-separated GPU ids (e.g. ``'0,1'``); empty means CPU.
        batch_size: images per batch.
        image_shape: ``'C,H,W'`` string or a ``(C, H, W)`` sequence giving the
            network input shape.
        num_label: number of label columns per image in the ``.lst`` files.
        epoch: checkpoint epoch to load; training also resumes from it.
        num_epoch: epoch at which training stops.
        kv: a kvstore object, or a kvstore type string (e.g. ``'device'``)
            that is turned into one with ``mx.kvstore.create``.
    """
    # Create the kvstore at most once: data sharding and parameter sync
    # must use the same store.
    if isinstance(kv, str):
        kv = mx.kvstore.create(kv)

    if isinstance(image_shape, str):
        data_shape = tuple(int(d) for d in image_shape.split(','))
    else:
        data_shape = tuple(image_shape)

    label_names  = ['softmax_multitask%d_label' % (i + 1) for i in range(5)]
    output_names = ['softmax_multitask%d_output' % (i + 1) for i in range(5)]

    train = _make_multitask_iter(args.data_train, args.image_train, data_shape,
                                 batch_size, num_label, kv, label_names, train=True)
    val   = _make_multitask_iter(args.data_val, args.image_val, data_shape,
                                 batch_size, num_label, kv, label_names, train=False)

    sym, arg_params, aux_params = mx.model.load_checkpoint(model, epoch)

    # flatten0: layer name valid for resnext-50-symbol.json; all heads share it.
    new_sym = get_fine_tune_model(sym,
                                  arg_params,
                                  args.num_classes_mt1,
                                  args.num_classes_mt2,
                                  args.num_classes_mt3,
                                  args.num_classes_mt4,
                                  args.num_classes_mt5,
                                  'flatten0')

    epoch_size = max(int(args.num_examples / batch_size / kv.num_workers), 1)
    lr_scheduler = multi_factor_scheduler(epoch, epoch_size)

    optimizer_params = {
        'learning_rate': args.lr,
        'momentum': args.mom,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler}

    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

    if not gpus:
        devs = mx.cpu()
    else:
        devs = [mx.gpu(int(i)) for i in gpus.split(',')]

    mod = mx.mod.Module(
        context     = devs,
        symbol      = new_sym,
        data_names  = ['data'],
        label_names = list(label_names)
    )

    checkpoint = mx.callback.do_checkpoint(args.save_result)

    eval_metric = mx.metric.CompositeEvalMetric()
    eval_metric.add(MultiTask_Accuracy(num = 5, output_names = output_names))

    # allow_missing: the new fc_multitask* layers have no pretrained weights
    # and are filled in by `initializer`.
    mod.fit(
        train_data         = train,
        begin_epoch        = epoch,
        num_epoch          = num_epoch,
        eval_data          = val,
        eval_metric        = eval_metric,
        validation_metric  = eval_metric,
        kvstore            = kv,
        optimizer          = 'sgd',
        optimizer_params   = optimizer_params,
        arg_params         = arg_params,
        aux_params         = aux_params,
        initializer        = initializer,
        allow_missing      = True,
        batch_end_callback = mx.callback.Speedometer(batch_size, 20),
        epoch_end_callback = checkpoint
    )


if __name__ == '__main__':
    # Command-line entry point. `args` must remain module-global because
    # train_model() reads dataset paths and hyper-parameters from it.
    parser = argparse.ArgumentParser(description = 'train a model on a dataset')
    parser.add_argument('--model',       type = str,   required = True,
                        help = 'pretrained checkpoint prefix, e.g. /root/mxnet_dpn/models/models_org/resnext-50')
    parser.add_argument('--gpus',        type = str,   default = '0')
    parser.add_argument('--batch-size',  type = int,   default = 32)
    parser.add_argument('--epoch',       type = int,   default = 0)
    parser.add_argument('--image-shape', type = str,   default = '3,224,224')
    parser.add_argument('--data-train',  type = str,   default = '/root/mxnet_dpn/mxnet/tools/mnist224_train.lst')
    parser.add_argument('--image-train', type = str,   default = '/root/mxnet_datasets/')
    parser.add_argument('--data-val',    type = str,   default = '/root/mxnet_dpn/mxnet/tools/mnist224_test.lst')
    parser.add_argument('--image-val',   type = str,   default = '/root/mxnet_datasets/')
    parser.add_argument('--num-classes-mt1', type = int,   default = 26)
    parser.add_argument('--num-classes-mt2', type = int,   default = 11)
    parser.add_argument('--num-classes-mt3', type = int,   default = 8)
    parser.add_argument('--num-classes-mt4', type = int,   default = 8)
    parser.add_argument('--num-classes-mt5', type = int,   default = 4)
    parser.add_argument('--num-labels',  type = int,   default = 5)
    parser.add_argument('--lr',          type = float, default = 0.01)
    parser.add_argument('--num-epoch',   type = int,   default = 30)
    parser.add_argument('--kv-store',    type = str,   default = 'device', help = 'the kvstore type')
    parser.add_argument('--save-result', type = str,   default = '/root/mxnet_dpn/models/mnist224_resnext50_SLMT/resnext50',
                        help = 'the save path')
    parser.add_argument('--num-examples',type = int,   default = 60000)
    parser.add_argument('--mom',         type = float, default = 0.9,    help = 'momentum for sgd')
    parser.add_argument('--wd',          type = float, default = 0.0005, help = 'weight decay for sgd')
    args = parser.parse_args()

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    kv = mx.kvstore.create(args.kv_store)

    # save_result is also used as the directory holding train.log; makedirs
    # creates missing parent directories, which os.mkdir would not.
    if not os.path.exists(args.save_result):
        os.makedirs(args.save_result)

    hdlr = logging.FileHandler(os.path.join(args.save_result, 'train.log'))
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logging.info(args)

    train_model(
        model       = args.model,
        gpus        = args.gpus,
        batch_size  = args.batch_size,
        image_shape = args.image_shape,
        epoch       = args.epoch,
        num_epoch   = args.num_epoch,
        kv          = kv,
        num_label   = args.num_labels
    )
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/u013381011/article/details/80570848

智能推荐

hdu2094——set的应用_hdu2094 为什么集合a-集合b能判断是否产生-程序员宅基地

文章浏览阅读54次。题目n场比赛,每场比赛对应为一行输出,格式为A队战胜B队,经历过n场比赛,请判断是否有冠军产生。算法由题意得但凡是输过的队伍都不能作为冠军,我们只需要把所有队伍放进一个集合,再把比赛失败的队伍放入另一个集合,当全部比赛结束后,比较两个集合的大小,如果A-B=1,说明已经产生冠军了,如果=0则没有冠军,如果>1则说明冠军还未产生。数据结构放入set中进行代码#include<iostream>#include<string>#include<set>_hdu2094 为什么集合a-集合b能判断是否产生

用strtol函数进行进制的转换_stroll函数-程序员宅基地

文章浏览阅读599次。1、strtol函数:(将字符串转换为长整型)可以按任意进制(基数为0或2~36)解析 第一个参数:字符串开始地址; 第二个参数:二级指针,返回字符串解析时停下来的位置 第三个参数:进制(基数)int main(){ const char* arr = "100!"; long flg = strtol(arr, NULL, 8);//把"100"按8进制解析, 结果为64 printf("%ld\n", flg); return 0;}..._stroll函数

Hbase shell 常用命令笔记_hbase shell 常用命令报错-程序员宅基地

文章浏览阅读7.3k次。下面我们看看HBase Shell的一些基本操作命令,我列出了几个常用的HBase Shell命令,如下:名称命令表达式创建表create '表名称', '列名称1','列名称2','列名称N'添加记录 put '表名称', '行名称', '列名称:', '值'_hbase shell 常用命令报错

利用阿里云OSS开发一个私人网盘/外链系统,php+js实现_云脚本外链系统-程序员宅基地

文章浏览阅读1.7k次。https://segmentfault.com/a/1190000013357533转自上述博客,这里记录了主要的代码段,阿里云的配置可看作者原文。 亲测可用。我这里就利用官方的DEMO来制作一个简单的文件上传系统,非常简单的网盘。下面是截图:代码:index.html<!DOCTYPE html><html><head>..._云脚本外链系统

苹果手机支持鸿蒙,全球第三大手机系统「鸿蒙」上线,这19款能抢先用…-程序员宅基地

文章浏览阅读5k次。官宣!鸿蒙手机这回真要来了!就在昨日,华为公布一条重磅消息:将于6月2日正式举办鸿蒙产品发布会。“等等党”们终于迎来了胜利~随后,在5月25日上午华为宣布,华为EMUI微博正式更名为@HarmonyOS。并且发布出了鸿蒙手机操作系统开机画面视频。文案为16字口号:“‘鸿’鹄志远,一举千里。承‘蒙’厚爱,不负期待”。寓意接下来鸿蒙对手机设备的覆盖,将进一步在操作系统层面,实现万物互联。消息一出可谓是..._苹果手机可以支持鸿蒙系统吗

html点击触发loading,jQuery按钮点击loading加载-程序员宅基地

文章浏览阅读910次。在前加入如下代码body {font-size: 100%;margin: 1em 5em 5em;font-family: 'Lucida Grande', sans-serif;text-align: center;}h1 {margin-top: 1.2em;margin-bottom: 0;padding-bottom: 0;}h2 {font-size: 1em;color: ..._jq 点击按钮使数据变为loding

随便推点

【Java源码】ArrayList源码(上)关于get方法的遗留问题_java arraylist.get为什么返回@-程序员宅基地

文章浏览阅读751次。 问题重现 在ArrayList源码的get方法中,传参为负的异常是如何产生的? 源码只判断了index >= size public E get(int index) { rangeCheck(index); return elementData(index); } private void rangeCheck(in..._java arraylist.get为什么返回@

《梦断代码》读书感悟一-程序员宅基地

文章浏览阅读68次。其实早就应该发这一篇了,我周围的人又读了两章就发感悟的,而我整本书快看完了却还没发过。我总是觉得没什么感悟,只是机械的看书。想来,我还是没能真正地把自己当做一个合格的计算机行业从业者,我很难体会到书中程序员平凡中的不平凡。 在本书中,只要是开发就必然有一个团队,虽然没有特意的强调,但它很明确的告诉我程序的开发是一个团队的项目,单打独斗并不适合现在的潮流。我一向比较倾向于团队合作,只有..._阅读资料《梦断代码》?为何发出了『梦断』的感叹?

Oracle sql developer中调试存储过程_oracle sql developer调试存储过程-程序员宅基地

文章浏览阅读2k次。  很奇怪, 网上相关信息很少. 是我搜索的关键词不对吗?  进入过程(Procedures)或程序包(Packages) :点击选择需要调试的程序后, 在右边代码编辑区域 的工具栏找到两个齿轮的图标, 选择"编辑以进行调试(Compile for debug)":点击甲虫(Debug)图标, 填入参数, 即可进行调试:记住, 调试之前, 请先打上debug!注意: 如果配置了VPN, 在debug时, SQL Developer可能无法取得正确的调试主机(Debug _oracle sql developer调试存储过程

css实现元素居中的6种方法_css居中-程序员宅基地

文章浏览阅读3.7w次,点赞29次,收藏113次。相信大家在面试的时候也会经常碰到css实现元素居中的方法,下面我介绍6种方法给大家,欢迎大家评论区交流。给定两个元素,这两个元素是父子级关系并且两个元素的大小都是不确定的,那么这时候如何让子级在父级中上下左右都居中?(暂且设定父级比子级要大一些)。_css居中

Unity 模拟键盘按键事件_unity 模拟触发按键时间-程序员宅基地

文章浏览阅读7.3k次,点赞4次,收藏11次。参考文章:http://blog.csdn.net/crazyape/article/details/70666598有时候我们将一些逻辑绑定在了一个键盘事件上,而在别处我们又需要调用这段代码,我们可以选择将之前的代码写成方法调用一次,也可以选择模拟之前的键盘事件,让这个按键假装被按下了或抬起了。using System.Collections;using System.Collection_unity 模拟触发按键时间

Linux中kill命令杀不掉进程的解决办法_linux进程杀不掉-程序员宅基地

文章浏览阅读1.3w次,点赞3次,收藏25次。Linux中kill命令杀不掉进程的解决办法_linux进程杀不掉

推荐文章

热门文章

相关标签