Introduction to torchvision classification
Recent versions of torchvision support a range of SOTA image-classification models and let you switch between pretrained weights for different datasets, which makes them very quick and convenient to use. PyTorch supports two transfer-learning modes:
- Finetune mode: start from the pretrained model and fine-tune all parameters end to end.
- Frozen feature-extraction mode: only the parameters of the output layer are modified; the CNN backbone parameters stay frozen.

The two modes suit large and small datasets respectively. The first needs more compute and training time than the second, but tends to give better results on large custom classification datasets.
Modifying and training a custom classification model
After loading the model, feature_extracting=True selects the frozen (feature-extraction) mode; otherwise the model runs in finetune mode. The relevant code is shown below:
```python
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
```

Taking resnet18 as an example, the modified custom training code looks like this:
```python
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 5.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, 5)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
```
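The train_model helper called at the end is not defined in this excerpt. Below is a minimal sketch in the spirit of the official PyTorch transfer-learning tutorial, assuming dataloaders, dataset_sizes and device objects like the ones built in the data-loading sketch further down; the author's actual implementation may differ:

```python
import copy
import time

import torch


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    # Assumes global dataloaders / dataset_sizes dicts keyed by "train" and "valid",
    # plus a global device, as in the data-loading sketch below.
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs - 1}")
        print("-" * 10)

        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()   # enable dropout / update batch-norm statistics
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Only track gradients during the training phase.
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    preds = outputs.argmax(dim=1)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if phase == "train":
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            # Keep the weights with the best validation accuracy.
            if phase == "valid" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    elapsed = time.time() - since
    print(f"Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s")
    print(f"Best val Acc: {best_acc:.6f}")

    model.load_state_dict(best_model_wts)
    return model
```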
The dataset is flowers-dataset, with five classes (a data-loading sketch follows the list):
- daisy
- dandelion
- roses
- sunflowers
- tulips
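Both training runs below assume the usual torchvision ImageFolder layout, with one sub-directory per class under train/ and valid/. A minimal data-loading sketch; the directory names, batch size and augmentation choices are illustrative assumptions, not the author's exact settings:

```python
import torch
from torchvision import datasets, transforms

# Illustrative transforms: 224x224 crops for resnet18, normalized with the
# ImageNet statistics that the pretrained weights were trained with.
data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    "valid": transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Assumed layout: flowers-dataset/train/<class>/*.jpg and flowers-dataset/valid/<class>/*.jpg
data_dir = "flowers-dataset"
image_datasets = {x: datasets.ImageFolder(f"{data_dir}/{x}", data_transforms[x])
                  for x in ["train", "valid"]}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32,
                                              shuffle=(x == "train"), num_workers=4)
               for x in ["train", "valid"]}
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "valid"]}
class_names = image_datasets["train"].classes  # ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
```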
Full finetune: transfer learning that also updates the CNN backbone weights
```
Epoch  0/24  train Loss: 1.3993 Acc: 0.5597  valid Loss: 1.8571 Acc: 0.7073
Epoch  1/24  train Loss: 1.0903 Acc: 0.6580  valid Loss: 0.6150 Acc: 0.7805
Epoch  2/24  train Loss: 0.9095 Acc: 0.6991  valid Loss: 0.4386 Acc: 0.8049
Epoch  3/24  train Loss: 0.7628 Acc: 0.7349  valid Loss: 0.9111 Acc: 0.7317
Epoch  4/24  train Loss: 0.7107 Acc: 0.7669  valid Loss: 0.4854 Acc: 0.8049
Epoch  5/24  train Loss: 0.6231 Acc: 0.7793  valid Loss: 0.6822 Acc: 0.8049
Epoch  6/24  train Loss: 0.5768 Acc: 0.8033  valid Loss: 0.2748 Acc: 0.8780
Epoch  7/24  train Loss: 0.5448 Acc: 0.8110  valid Loss: 0.4440 Acc: 0.7561
Epoch  8/24  train Loss: 0.5037 Acc: 0.8170  valid Loss: 0.2900 Acc: 0.9268
Epoch  9/24  train Loss: 0.4836 Acc: 0.8360  valid Loss: 0.7108 Acc: 0.7805
Epoch 10/24  train Loss: 0.4663 Acc: 0.8369  valid Loss: 0.5868 Acc: 0.8049
Epoch 11/24  train Loss: 0.4276 Acc: 0.8504  valid Loss: 0.6998 Acc: 0.8293
Epoch 12/24  train Loss: 0.4299 Acc: 0.8529  valid Loss: 0.6449 Acc: 0.8049
Epoch 13/24  train Loss: 0.4256 Acc: 0.8567  valid Loss: 0.7897 Acc: 0.7805
Epoch 14/24  train Loss: 0.4062 Acc: 0.8559  valid Loss: 0.5855 Acc: 0.8293
Epoch 15/24  train Loss: 0.4030 Acc: 0.8545  valid Loss: 0.7336 Acc: 0.7805
Epoch 16/24  train Loss: 0.3786 Acc: 0.8730  valid Loss: 1.0429 Acc: 0.7561
Epoch 17/24  train Loss: 0.3699 Acc: 0.8763  valid Loss: 0.4549 Acc: 0.8293
Epoch 18/24  train Loss: 0.3394 Acc: 0.8788  valid Loss: 0.2828 Acc: 0.9024
Epoch 19/24  train Loss: 0.3300 Acc: 0.8834  valid Loss: 0.6766 Acc: 0.8537
Epoch 20/24  train Loss: 0.3136 Acc: 0.8906  valid Loss: 0.5893 Acc: 0.8537
Epoch 21/24  train Loss: 0.3110 Acc: 0.8901  valid Loss: 0.4909 Acc: 0.8537
Epoch 22/24  train Loss: 0.3141 Acc: 0.8931  valid Loss: 0.3930 Acc: 0.9024
Epoch 23/24  train Loss: 0.3106 Acc: 0.8887  valid Loss: 0.3079 Acc: 0.9024
Epoch 24/24  train Loss: 0.3143 Acc: 0.8923  valid Loss: 0.5122 Acc: 0.8049

Training complete in 25m 34s
Best val Acc: 0.926829
```
Frozen CNN backbone: training only the fully connected classifier weights
```
Params to learn:
    fc.weight
    fc.bias

Epoch  0/24  train Loss: 1.0217 Acc: 0.6465  valid Loss: 1.5317 Acc: 0.8049
Epoch  1/24  train Loss: 0.9569 Acc: 0.6947  valid Loss: 1.2450 Acc: 0.6829
Epoch  2/24  train Loss: 1.0280 Acc: 0.6999  valid Loss: 1.5677 Acc: 0.7805
Epoch  3/24  train Loss: 0.8344 Acc: 0.7426  valid Loss: 1.1053 Acc: 0.7317
Epoch  4/24  train Loss: 0.9110 Acc: 0.7250  valid Loss: 1.1148 Acc: 0.7561
Epoch  5/24  train Loss: 0.9049 Acc: 0.7346  valid Loss: 1.1541 Acc: 0.6341
Epoch  6/24  train Loss: 0.8538 Acc: 0.7465  valid Loss: 1.4098 Acc: 0.8293
Epoch  7/24  train Loss: 0.9041 Acc: 0.7349  valid Loss: 0.9604 Acc: 0.7561
Epoch  8/24  train Loss: 0.8885 Acc: 0.7468  valid Loss: 1.2603 Acc: 0.7561
Epoch  9/24  train Loss: 0.9257 Acc: 0.7333  valid Loss: 1.0751 Acc: 0.7561
Epoch 10/24  train Loss: 0.8637 Acc: 0.7492  valid Loss: 0.9748 Acc: 0.7317
Epoch 11/24  train Loss: 0.8686 Acc: 0.7517  valid Loss: 1.0194 Acc: 0.8049
Epoch 12/24  train Loss: 0.8492 Acc: 0.7572  valid Loss: 1.0378 Acc: 0.7317
Epoch 13/24  train Loss: 0.8773 Acc: 0.7432  valid Loss: 0.7224 Acc: 0.8049
Epoch 14/24  train Loss: 0.8919 Acc: 0.7473  valid Loss: 1.3564 Acc: 0.7805
Epoch 15/24  train Loss: 0.8634 Acc: 0.7490  valid Loss: 0.7822 Acc: 0.7805
Epoch 16/24  train Loss: 0.8069 Acc: 0.7644  valid Loss: 1.4132 Acc: 0.7561
Epoch 17/24  train Loss: 0.8589 Acc: 0.7492  valid Loss: 0.9812 Acc: 0.8049
Epoch 18/24  train Loss: 0.7677 Acc: 0.7688  valid Loss: 0.7176 Acc: 0.8293
Epoch 19/24  train Loss: 0.8044 Acc: 0.7514  valid Loss: 1.4486 Acc: 0.7561
Epoch 20/24  train Loss: 0.7916 Acc: 0.7564  valid Loss: 1.0575 Acc: 0.8049
Epoch 21/24  train Loss: 0.7922 Acc: 0.7647  valid Loss: 1.0406 Acc: 0.7805
Epoch 22/24  train Loss: 0.8187 Acc: 0.7647  valid Loss: 1.0965 Acc: 0.7561
Epoch 23/24  train Loss: 0.8443 Acc: 0.7503  valid Loss: 1.6163 Acc: 0.7317
Epoch 24/24  train Loss: 0.8165 Acc: 0.7583  valid Loss: 1.1680 Acc: 0.7561

Training complete in 20m 7s
Best val Acc: 0.829268
```
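The setup script for this frozen-feature run is not shown in the original post; it differs from the finetune code only in freezing the backbone first and handing the optimizer just the parameters that still require gradients. A minimal sketch, reusing set_parameter_requires_grad, device and train_model from above (variable names such as model_conv and optimizer_conv are my own):

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models

model_conv = models.resnet18(pretrained=True)
set_parameter_requires_grad(model_conv, feature_extracting=True)  # freeze the CNN backbone

num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 5)  # a newly created layer has requires_grad=True by default
model_conv = model_conv.to(device)

# Only parameters that still require gradients are handed to the optimizer,
# which is what produces the "Params to learn: fc.weight, fc.bias" lines in the log above.
print("Params to learn:")
params_to_update = []
for name, param in model_conv.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)
        print("   ", name)

criterion = nn.CrossEntropyLoss()
optimizer_conv = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)
```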
Test results:
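A minimal sketch of how the trained model could be run on a single test image to produce such results; the file name test_flower.jpg and the reuse of model_ft, device and class_names from the snippets above are illustrative assumptions:

```python
import torch
from PIL import Image
from torchvision import transforms

# Same preprocessing as the validation transforms above.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open("test_flower.jpg").convert("RGB")   # hypothetical test image
x = preprocess(img).unsqueeze(0).to(device)          # add a batch dimension

model_ft.eval()
with torch.no_grad():
    probs = torch.softmax(model_ft(x), dim=1)
    pred = probs.argmax(dim=1).item()

print(class_names[pred], f"{probs[0, pred].item():.4f}")
```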
Zero-code training demo
I have finished wrapping the torchvision classification models for transfer learning on custom datasets, so that a model can be trained from a collected dataset with zero code. The workflow is illustrated in the figure below:
Original title: torchvision easily supports transfer learning for ten image-classification models
Source: WeChat official account OpenCV学堂 (WeChat ID: CVSCHOOL). Please credit the source when reposting.