Introduction
MATLAB's Deep Network Designer greatly lowers the barrier to developing deep neural networks: a network can be designed by drag and drop, and even training can be driven entirely from the GUI.
Example
Take hyperspectral image classification as an example (following the referenced paper). We construct a convolutional neural network whose input is a 9×9×B image patch, where B is the number of spectral bands; the label is that of the patch's center pixel.
Network design
In MATLAB's APPS tab, search for the Deep Network Designer toolbox. After opening it, choose New to create a network; the dialog that appears lets you start either from a blank network or from a pretrained one.
In the design view, drag the required blocks from the left-hand panel, then name them, set their parameters, and connect them together. Once the network is assembled, click Analyze to check it for errors; if none are found, the network can be exported as code.
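The Analyze step is also available from the command line: once a layer array has been exported, `analyzeNetwork` produces the same error report as the Analyze button. A minimal sketch (the layer sizes here are arbitrary, just to illustrate the call):

```matlab
% Run the same checks as the Analyze button on an exported layer array.
layers = [
    imageInputLayer([9 9 270])
    convolution2dLayer(3, 32)
    reluLayer
    fullyConnectedLayer(22)
    softmaxLayer
    classificationLayer];
analyzeNetwork(layers);   % opens the Network Analyzer window with the report
```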
Code
The main training script, "train_cnn.m", loads the data, randomly extracts small image patches, and builds the training, validation, and test sets. Note that the ground-truth label variable must be converted with the categorical function.
load('../data/WHU_Hi_HongHu_preprocessing_tensor_edgemap_7.mat')
rng(2022);
% In the experiments, the patch sizes of the three datasets were set as
% 9 × 9 × d, where d denotes the band number of the remote sensing image.
Ntrain = 1000;
Nvalid = 500;
Ntest = 200;
ptcsize = [9, 9];
M = ones(size(Label));
nclass = length(unique(Label));
[X, Y] = sample_patchs(X, Label, M, ptcsize, nclass, Ntrain+Nvalid+Ntest);
% split into disjoint training/validation/test sets
Xtrain = X(:, :, :, 1:Ntrain); Ytrain = Y(1:Ntrain);
Xvalid = X(:, :, :, Ntrain+1:Ntrain+Nvalid); Yvalid = Y(Ntrain+1:Ntrain+Nvalid);
Xtest = X(:, :, :, Ntrain+Nvalid+1:end); Ytest = Y(Ntrain+Nvalid+1:end);
Ytrain = categorical(Ytrain);
Yvalid = categorical(Yvalid);
Ytest = categorical(Ytest);
layers = [
imageInputLayer([9 9 270],"Name","imageinput")
convolution2dLayer([3 3],128,"Name","conv1")
batchNormalizationLayer("Name","batchnorm1")
reluLayer("Name","relu1")
convolution2dLayer([3 3],256,"Name","conv2")
batchNormalizationLayer("Name","batchnorm2")
reluLayer("Name","relu2")
convolution2dLayer([3 3],256,"Name","conv3","Padding","same")
batchNormalizationLayer("Name","batchnorm3")
reluLayer("Name","relu3")
convolution2dLayer([3 3],128,"Name","conv4")
batchNormalizationLayer("Name","batchnorm4")
reluLayer("Name","relu4")
fullyConnectedLayer(128,"Name","fc1")
batchNormalizationLayer("Name","batchnorm5")
reluLayer("Name","relu5")
fullyConnectedLayer(64,"Name","fc2")
batchNormalizationLayer("Name","batchnorm6")
reluLayer("Name","relu6")
fullyConnectedLayer(nclass,"Name","fc3")
softmaxLayer("Name","softmax")
classificationLayer("Name", "classoutput")];
% plot(layerGraph(layers));
options = trainingOptions('adam', ...
'ValidationData', {Xvalid, Yvalid}, ...
'Plots', 'training-progress', ...
'MaxEpochs', 100, ...
'Shuffle', 'every-epoch', ...
'InitialLearnRate', 1e-3, ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropFactor', 0.1, ...
'LearnRateDropPeriod', 20, ...
'ExecutionEnvironment', 'gpu', ...
'MiniBatchSize', 32);
net = trainNetwork(Xtrain, Ytrain, layers, options);
Ptest = classify(net, Xtest);
accuracy = sum(Ptest==Ytest) / numel(Ptest); % overall accuracy on the test set
disp(accuracy)
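Beyond overall accuracy, a confusion matrix gives per-class results, which matters for hyperspectral data where classes are often imbalanced. A sketch using `confusionmat` (variable names follow the script above):

```matlab
% Per-class accuracy from the confusion matrix of the test predictions.
C = confusionmat(Ytest, Ptest);        % rows: true classes, columns: predictions
classAcc = diag(C) ./ sum(C, 2);       % per-class recall (NaN if a class is absent)
disp(table(categories(Ytest), classAcc, ...
    'VariableNames', {'Class', 'Accuracy'}));
```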
Patch-sampling script, "sample_patchs.m":
function [Xp, Yp] = sample_patchs(X, Y, M, ptcsize, nclass, nptcs)
% X: Data image
% Y: Label image
% M: mask: 1: candidate
% ptcsize: size (h, w) of patch
% nclass: number of classes
% nptcs: number of patchs
if isempty(ptcsize)
ptcsize = [9, 9];
end
if isempty(nptcs)
nptcs = 100;
end
pH = ptcsize(1);
pW = ptcsize(2);
pH2 = floor(pH / 2.);
pW2 = floor(pW / 2.);
[xH, xW, C] = size(X);
M(1:pH2, :) = 0; % boundary
M(xH-pH2:xH, :) = 0; % boundary
M(:, 1:pW2) = 0; % boundary
M(:, xW-pW2:xW) = 0; % boundary
[rows, cols] = find(M==1);
npixel = length(rows);
idx = randi([1, npixel], nptcs, 1); % sampling with replacement; duplicates possible
idxH = rows(idx);
idxW = cols(idx);
Xp = zeros(ptcsize(1), ptcsize(2), C, nptcs);
Yp = zeros(nptcs, 1);
% Yp = zeros(nptcs, nclass); % one-hot
for i = 1:nptcs
Xp(:, :, :, i) = X(idxH(i) - pH2:idxH(i) + pH2, idxW(i) - pW2:idxW(i) + pW2, :);
Yp(i, 1) = Y(idxH(i), idxW(i));
% Yp(i, Y(idxH(i), idxW(i)) + 1) = 1; % one-hot
end
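As a quick sanity check, the sampler can be exercised on synthetic data (a sketch; the image size, band count, and class count are arbitrary):

```matlab
% Exercise sample_patchs on a random 50x60 image with 8 bands and 5 classes.
X = rand(50, 60, 8);
Label = randi([1, 5], 50, 60);
M = ones(size(Label));
[Xp, Yp] = sample_patchs(X, Label, M, [9, 9], 5, 20);
disp(size(Xp));   % 9 9 8 20
disp(size(Yp));   % 20 1
```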
Results
The figure below shows the training-progress log. The curves and summary statistics are drawn automatically by MATLAB; no extra plotting code is needed.
The MATLAB command window also prints the corresponding information:
>> train_cnn
Initializing input data normalization.
|======================================================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Validation | Mini-batch | Validation | Base Learning |
| | | (hh:mm:ss) | Accuracy | Accuracy | Loss | Loss | Rate |
|======================================================================================================================|
| 1 | 1 | 00:00:08 | 0.00% | 30.60% | 3.6983 | 2.8994 | 0.0010 |
| 2 | 50 | 00:00:11 | 75.00% | 64.80% | 1.0919 | 1.2310 | 0.0010 |
| 4 | 100 | 00:00:14 | 65.62% | 65.20% | 0.9713 | 1.0583 | 0.0010 |
| 5 | 150 | 00:00:17 | 62.50% | 74.80% | 1.1589 | 0.8747 | 0.0010 |
| 7 | 200 | 00:00:19 | 62.50% | 73.00% | 0.9210 | 0.8468 | 0.0010 |
| 9 | 250 | 00:00:23 | 78.12% | 76.60% | 0.6505 | 0.7860 | 0.0010 |
| 10 | 300 | 00:00:25 | 78.12% | 77.80% | 0.7985 | 0.7317 | 0.0010 |
| 12 | 350 | 00:00:28 | 81.25% | 80.00% | 0.6691 | 0.6691 | 0.0010 |
| 13 | 400 | 00:00:31 | 71.88% | 80.20% | 0.9969 | 0.6473 | 0.0010 |
| 15 | 450 | 00:00:34 | 87.50% | 80.20% | 0.4374 | 0.6442 | 0.0010 |
| 17 | 500 | 00:00:37 | 84.38% | 81.20% | 0.4327 | 0.6272 | 0.0010 |
| 18 | 550 | 00:00:39 | 84.38% | 83.80% | 0.3872 | 0.5438 | 0.0010 |
| 20 | 600 | 00:00:42 | 81.25% | 83.00% | 0.6669 | 0.5028 | 0.0010 |
| 21 | 650 | 00:00:45 | 81.25% | 86.40% | 0.4656 | 0.4147 | 0.0001 |
| 23 | 700 | 00:00:48 | 78.12% | 88.00% | 0.6784 | 0.3880 | 0.0001 |
| 25 | 750 | 00:00:51 | 96.88% | 88.40% | 0.2379 | 0.3900 | 0.0001 |
| 26 | 800 | 00:00:53 | 93.75% | 88.20% | 0.3173 | 0.4199 | 0.0001 |
| 28 | 850 | 00:00:56 | 87.50% | 89.00% | 0.3716 | 0.3864 | 0.0001 |
| 30 | 900 | 00:00:59 | 87.50% | 89.20% | 0.3112 | 0.3499 | 0.0001 |
| 31 | 950 | 00:01:01 | 81.25% | 90.60% | 0.4589 | 0.3472 | 0.0001 |
| 33 | 1000 | 00:01:04 | 90.62% | 90.20% | 0.2410 | 0.3030 | 0.0001 |
| 34 | 1050 | 00:01:07 | 96.88% | 91.00% | 0.2589 | 0.3052 | 0.0001 |
| 36 | 1100 | 00:01:10 | 84.38% | 92.00% | 0.5322 | 0.2920 | 0.0001 |
| 38 | 1150 | 00:01:12 | 96.88% | 91.20% | 0.2072 | 0.2998 | 0.0001 |
| 39 | 1200 | 00:01:15 | 90.62% | 92.20% | 0.2447 | 0.2759 | 0.0001 |
| 41 | 1250 | 00:01:18 | 93.75% | 92.00% | 0.1627 | 0.2724 | 1.0000e-05 |
| 42 | 1300 | 00:01:20 | 96.88% | 92.40% | 0.1265 | 0.2751 | 1.0000e-05 |
| 44 | 1350 | 00:01:23 | 93.75% | 90.80% | 0.1679 | 0.3054 | 1.0000e-05 |
| 46 | 1400 | 00:01:26 | 96.88% | 93.40% | 0.1650 | 0.2544 | 1.0000e-05 |
| 47 | 1450 | 00:01:29 | 93.75% | 92.20% | 0.2000 | 0.2709 | 1.0000e-05 |
| 49 | 1500 | 00:01:32 | 93.75% | 92.40% | 0.1877 | 0.2520 | 1.0000e-05 |
| 50 | 1550 | 00:01:34 | 93.75% | 92.20% | 0.1618 | 0.2842 | 1.0000e-05 |
| 52 | 1600 | 00:01:37 | 93.75% | 91.80% | 0.3416 | 0.2809 | 1.0000e-05 |
| 54 | 1650 | 00:01:40 | 96.88% | 91.60% | 0.1159 | 0.2628 | 1.0000e-05 |
| 55 | 1700 | 00:01:43 | 90.62% | 94.00% | 0.2882 | 0.2346 | 1.0000e-05 |
| 57 | 1750 | 00:01:46 | 93.75% | 93.00% | 0.1924 | 0.2571 | 1.0000e-05 |
| 59 | 1800 | 00:01:48 | 100.00% | 94.40% | 0.0592 | 0.2273 | 1.0000e-05 |
| 60 | 1850 | 00:01:51 | 93.75% | 91.40% | 0.1993 | 0.2669 | 1.0000e-05 |
| 62 | 1900 | 00:01:54 | 87.50% | 91.00% | 0.3692 | 0.2943 | 1.0000e-06 |
| 63 | 1950 | 00:01:57 | 96.88% | 92.80% | 0.2041 | 0.2607 | 1.0000e-06 |
| 65 | 2000 | 00:02:00 | 93.75% | 91.60% | 0.2100 | 0.2653 | 1.0000e-06 |
| 67 | 2050 | 00:02:03 | 87.50% | 92.60% | 0.3792 | 0.2715 | 1.0000e-06 |
| 68 | 2100 | 00:02:06 | 93.75% | 91.80% | 0.1791 | 0.2868 | 1.0000e-06 |
| 70 | 2150 | 00:02:08 | 96.88% | 92.60% | 0.2040 | 0.2728 | 1.0000e-06 |
| 71 | 2200 | 00:02:11 | 90.62% | 93.20% | 0.2053 | 0.2353 | 1.0000e-06 |
| 73 | 2250 | 00:02:14 | 93.75% | 93.60% | 0.2120 | 0.2299 | 1.0000e-06 |
| 75 | 2300 | 00:02:17 | 90.62% | 93.20% | 0.2796 | 0.2405 | 1.0000e-06 |
| 76 | 2350 | 00:02:19 | 93.75% | 92.60% | 0.2731 | 0.2586 | 1.0000e-06 |
| 78 | 2400 | 00:02:22 | 93.75% | 91.80% | 0.1932 | 0.2732 | 1.0000e-06 |
| 80 | 2450 | 00:02:25 | 96.88% | 92.80% | 0.1315 | 0.2484 | 1.0000e-06 |
| 81 | 2500 | 00:02:28 | 93.75% | 93.60% | 0.2221 | 0.2730 | 1.0000e-07 |
| 83 | 2550 | 00:02:31 | 93.75% | 92.20% | 0.1957 | 0.2558 | 1.0000e-07 |
| 84 | 2600 | 00:02:34 | 96.88% | 91.80% | 0.1457 | 0.2807 | 1.0000e-07 |
| 86 | 2650 | 00:02:36 | 87.50% | 93.20% | 0.4540 | 0.2724 | 1.0000e-07 |
| 88 | 2700 | 00:02:39 | 93.75% | 93.40% | 0.2235 | 0.2315 | 1.0000e-07 |
| 89 | 2750 | 00:02:42 | 100.00% | 93.40% | 0.0892 | 0.2506 | 1.0000e-07 |
| 91 | 2800 | 00:02:45 | 93.75% | 92.00% | 0.2005 | 0.2666 | 1.0000e-07 |
| 92 | 2850 | 00:02:48 | 100.00% | 91.20% | 0.1301 | 0.2748 | 1.0000e-07 |
| 94 | 2900 | 00:02:51 | 96.88% | 92.20% | 0.1594 | 0.2691 | 1.0000e-07 |
| 96 | 2950 | 00:02:53 | 93.75% | 93.00% | 0.1665 | 0.2548 | 1.0000e-07 |
| 97 | 3000 | 00:02:56 | 93.75% | 94.00% | 0.2878 | 0.2366 | 1.0000e-07 |
| 99 | 3050 | 00:02:59 | 90.62% | 92.00% | 0.1891 | 0.2761 | 1.0000e-07 |
| 100 | 3100 | 00:03:02 | 93.75% | 92.00% | 0.1937 | 0.2665 | 1.0000e-07 |
|======================================================================================================================|
0.9500