To investigate which variables affect a model's runtime, we previously used parameter count as the metric.
Parameter counts are easy to obtain in the TF framework:
calling
tf.keras.models.Model().summary()
generates a one-shot report listing every layer of the model together with its output shape and parameter count.
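For reference, the parameter counts that summary() reports follow directly from each layer's shape. A minimal sketch of the two most common cases (the helper names below are my own, not a TF API):

```python
def dense_params(cin, cout):
    # weight matrix (cin * cout) plus one bias per output unit
    return cin * cout + cout

def conv2d_params(k, cin, cout):
    # one k x k x cin kernel per output channel, plus biases
    return k * k * cin * cout + cout

print(dense_params(1280, 5))    # 5-class head fed by 1280 features -> 6405
print(conv2d_params(3, 3, 32))  # 3x3 conv from RGB to 32 channels -> 896
```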
Looking for other metrics that reflect a model's runtime, we found three candidates online:
parameter count (parameters)
floating-point operations (FLOPs)
memory access count (MAC)
This time we investigate how FLOPs affect model latency.
The first pitfall:
TF 2.x
removed the Profiler
interface's FLOPs statistics.
Even following the TF 1.x recipes found online
and calling them through the TF 2.x compat.v1
interface, we still could not get a FLOPs value back.
So the only option left was to write a program and compute it by hand.
FLOPs are essentially the multiplications and additions performed by the model.
The layers in the model are:
Input layer
Zero Padding layer
Conv2D
BatchNormalization
Activation
Depthwise Conv2D
Dense
Of these, the layers whose floating-point cost is small enough to ignore are:
Input,
Zero Padding,
BatchNormalization,
Activation
The layers that remain to be counted are:
Conv2D,
Depthwise Conv2D,
Dense
Conv2D:
FLOPs = Cin * K * K * H * W * Cout
Cin is the number of input channels
K*K is the kernel size
H*W is the output size
Cout is the number of output channels
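Plugging concrete shapes into the formula, a minimal sketch (the example shapes are my own, taken from the first convolution of the network below):

```python
def conv2d_flops(cin, k, h, w, cout):
    # Cin * K * K multiply-adds per output element, H * W * Cout output elements
    return cin * k * k * h * w * cout

# 3x3 conv, RGB input, 32 output channels, 112x112 output feature map
print(conv2d_flops(3, 3, 112, 112, 32))  # 10838016
```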
Depthwise Conv2D:
FLOPs = Cin * H * W * K * K / S / S
Cin is the number of input channels
K*K is the kernel size
H*W is the input size
S*S is the strides
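Because a depthwise convolution applies one K x K kernel per channel, the formula has no Cout factor; dividing by the strides converts the input H*W to the number of output positions. A sketch with example shapes of my own choosing:

```python
def depthwise_conv2d_flops(cin, h, w, k, s):
    # one k x k kernel per channel; h * w is the INPUT size,
    # so dividing by s * s gives the number of output positions
    return cin * h * w * k * k // (s * s)

print(depthwise_conv2d_flops(32, 112, 112, 3, 1))  # 3612672
print(depthwise_conv2d_flops(96, 112, 112, 3, 2))  # 2709504
```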
Dense:
FLOPs = 2 * Cin * Cout
Cin is the number of input features
Cout is the number of output features
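The factor of 2 counts each multiply and each add separately. For the 5-class head fed by 1280 pooled features used below:

```python
def dense_flops(cin, cout):
    # cin multiplies plus cin adds per output unit -> 2 * cin * cout
    return 2 * cin * cout

print(dense_flops(1280, 5))  # 12800
```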
With the basic layers defined, we assemble the FLOPs calculation following the model's structure;
each layer's output shape matches the corresponding layer of the model.
def conv2DDRB(filters, kernel, cin):
    # 1x1 expansion/projection conv inside a bottleneck block
    # cin = (H, W, C); FLOPs = Cin * K * K * H * W * Cout
    H = cin[0]
    W = cin[1]
    conv2dDRB_FLOPs = cin[2] * kernel[0] * kernel[1] * filters * H * W
    cout = [H, W, filters]
    return conv2dDRB_FLOPs, cout

def DepthwiseConv2DDRB(kernel, stride, cin):
    # FLOPs = Cin * H * W * K * K / S / S, where H, W are the INPUT size
    out = [0, 0, 0]
    if stride[0] == 1:
        DepConv2dDRB_FLOPs = cin[0] * cin[1] * cin[2] * kernel[0] * kernel[1]
        out = cin
    else:
        DepConv2dDRB_FLOPs = cin[0] * cin[1] * cin[2] * kernel[0] * kernel[1] // (stride[0] * stride[1])
        out[0] = cin[0] // stride[0]
        out[1] = cin[1] // stride[1]
        out[2] = cin[2]
    return DepConv2dDRB_FLOPs, out

def DRB(cin, filters, kernel, stride, t):
    # Inverted residual block: 1x1 expand -> depthwise 3x3 -> 1x1 project
    exp_channel = cin[2] * t  # t is the expansion factor
    alpha = filters
    block_counter = 0
    conv_flop, cout = conv2DDRB(exp_channel, (1, 1), cin)
    block_counter += conv_flop
    dep_flop, cout = DepthwiseConv2DDRB(kernel, stride, cout)
    block_counter += dep_flop
    conv_flop, cout = conv2DDRB(alpha, (1, 1), cout)
    block_counter += conv_flop
    print('This Block FLOPs:', block_counter, 'Output:', cout)
    return block_counter, cout

def conv(filters, kernel, stride, cin):
    # Standard convolution; output size is the input size divided by the stride
    H = cin[0] // stride[0]
    W = cin[1] // stride[1]
    conv_FLOPs = cin[2] * kernel[0] * kernel[1] * filters * H * W
    cout = [H, W, filters]
    return conv_FLOPs, cout

def GlobalAvgPool(cin):
    # Pooling cost is negligible; only the feature count is carried forward
    features = cin[2]
    print('Output:', features)
    return features

def den(classes, fin):
    # Dense layer: FLOPs = 2 * Cin * Cout
    flop = 2 * fin * classes
    print('This layer FLOPs:', flop, 'Output:', classes)
    return flop, classes

def MNV2(classes):
    # MobileNetV2 with a 224x224x3 input
    total_count = 0
    img_input = (224, 224, 3)
    conv_flop, cout = conv(32, (3, 3), (2, 2), img_input)
    total_count += conv_flop
    print('This layer FLOPs:', conv_flop, 'Output:', cout)
    Dep_flop, cout = DepthwiseConv2DDRB((3, 3), (1, 1), cout)
    total_count += Dep_flop
    print('This layer FLOPs:', Dep_flop, 'Output:', cout)
    conv_flop, cout = conv(16, (1, 1), (1, 1), cout)
    total_count += conv_flop
    print('This layer FLOPs:', conv_flop, 'Output:', cout)
    # 16 bottleneck blocks, all with expansion factor t = 6
    DRB_flops, cout = DRB(cout, 24, (3, 3), (2, 2), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 24, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 32, (3, 3), (2, 2), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 32, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 32, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 64, (3, 3), (2, 2), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 64, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 64, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 64, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 96, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 96, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 96, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 160, (3, 3), (2, 2), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 160, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 160, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    DRB_flops, cout = DRB(cout, 320, (3, 3), (1, 1), 6)
    total_count += DRB_flops
    conv_flop, cout = conv(1280, (1, 1), (1, 1), cout)
    total_count += conv_flop
    cout = GlobalAvgPool(cout)
    Den_flops, classes = den(classes, cout)
    total_count += Den_flops
    print('TOTAL FLOPs:', total_count, 'Output:', classes)
    return total_count, classes

def main():
    MNV2(5)

if __name__ == '__main__':
    main()
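As a quick sanity check of the block arithmetic (a standalone recomputation, independent of the script above): the first stride-2 bottleneck takes a (112, 112, 16) input, expands it by t = 6 to 96 channels, applies a stride-2 depthwise 3x3, and projects down to 24 channels.

```python
# Recompute the first stride-2 bottleneck block by hand
h, w, cin, t, cout_ch, k, s = 112, 112, 16, 6, 24, 3, 2
exp = cin * t                                          # 96 expanded channels
expand = cin * 1 * 1 * exp * h * w                     # 1x1 expansion conv
depthwise = exp * h * w * k * k // (s * s)             # stride-2 depthwise 3x3
project = exp * 1 * 1 * cout_ch * (h // s) * (w // s)  # 1x1 projection conv
print(expand + depthwise + project)  # 29202432
```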