YOLOv2代码分析_读取labels[by zhangzexuan]
- YOLOv2代码分析_读取labelsby zhangzexuan
嗯……现在参与的项目要求在人脸检测步骤直接连同人脸特征点一起预测出来,因此需要我做标题说明的工作,但是看了很久的YOLOv2代码也没有什么头绪,初步的想法是需要修改以下几个地方的代码:
- 读取ground-truth标签文档
- 损失函数
- 非极大值抑制
但是YOLOv2的代码我读了好几天,都没有找到在什么地方读取的ground-truth标签,就决定把看到的想到的东西记录下来,慢慢研究,好记性不如烂笔头嘛。
YOLOv2的输入
yolov2在训练的时候需要输入的材料如下:
- .data文件,其中记录了检测类别数,训练集图片路径文档路径,验证集图片路径文档路径,类别名文件路径,权重文件存储路径
- 训练集图片路径文档
- 验证集图片路径文档
- 类别名文件
- ground-truth标签文档
由以上可以看出.data 文件指示了图片地址但没有指示labels的地址,那么yolo是如何找到labels文档的呢。其实这里yolo使用的是标准voc数据格式,数据的部署采用voc的目录结构,那么知道了图片地址,就可以在图片路径的上级目录中的labels子目录下找到所有的label文档。当然这些肯定是在代码中实现的了,要想找到读取label文档的代码,那么首先就是要找到读取.data文件的代码,看看代码是在什么地方通过图片路径去寻找label文档的。
代码阅读
在detector.c的最后部分,函数run_detector对终端指令参数进行分析,指示train参数转到函数train_detector,因此接下来我们主要看的就是train_detector的内容。
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
list *options = read_data_cfg(datacfg);
char *train_images = option_find_str(options, "train", "data/train.list");
char *backup_directory = option_find_str(options, "backup", "/backup/");
上面是函数train_detector的头三行代码,参数datacfg就是终端命令参数中的.data文件路径,read_data_cfg()将.data的内容读取出来存储在一个名为options的list类型变量中,option_find_str()将options变量中的关键词为train的内容读取到字符串train_images中,就是说字符串train_images的内容为训练图片路径文档的路径。
list *plist = get_paths(train_images);
char **paths = (char **)list_to_array(plist);
load_args args = get_base_args(net);
args.coords = l.coords;
args.paths = paths;
这里get_paths()通过train_images中存储的路径将训练图片路径文档中的每一行都存入一个字符串中,然后将每个字符串地址存入一个node结构体中,并返回一个list结构体指针,这个list结构体中包含指向第一个和最后一个node结构体的指针。总的来说就是把所有的图片路径都存进一个链表中,list结构体中的指针分别指向这个链表的头结点和尾结点。
函数list_ro_array()只是单纯将存在list中的图片路径全部存入到字符串数组paths中去。
可以看到paths的值又被赋给了args.paths,接下来我们看args.paths的动向。
pthread_t load_thread = load_data(args);
clock_t time;
int count = 0;
while(get_current_batch(net) < net.max_batches){
if(l.random && count++%10 == 0){
printf("Resizing\n");
int dim = (rand() % 10 + 10) * 32;
if (get_current_batch(net)+200 > net.max_batches) dim = 608;
printf("%d\n", dim);
args.w = dim;
args.h = dim;
pthread_join(load_thread, 0);
train = buffer;
free_data(train);
load_thread = load_data(args);
for(i = 0; i < ngpus; ++i){
resize_network(nets + i, dim, dim);
}
net = nets[0];
}
time=clock();
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data(args);
args.paths变量在train_detector()中没有被使用,那么就只有到函数load_data(args)中去寻找args.paths的蛛丝马迹。
pthread_t load_data(load_args args)
{
pthread_t thread;
struct load_args *ptr = calloc(1, sizeof(struct load_args));
*ptr = args;
if(pthread_create(&thread, 0, load_threads, ptr)) error("Thread creation failed");
return thread;
}
由load_data的定义可以看出这个函数只是创建了一个新线程,这个线程的函数入口是load_threads,那么我们继续转到load_threads去追寻args.paths的足迹。
void *load_threads(void *ptr)
{
int i;
load_args args = *(load_args *)ptr;
if (args.threads == 0) args.threads = 1;
data *out = args.d;
int total = args.n;
free(ptr);
data *buffers = calloc(args.threads, sizeof(data));
pthread_t *threads = calloc(args.threads, sizeof(pthread_t));
for(i = 0; i < args.threads; ++i){
args.d = buffers + i;
args.n = (i+1) * total/args.threads - i * total/args.threads;
threads[i] = load_data_in_thread(args);
}
for(i = 0; i < args.threads; ++i){
pthread_join(threads[i], 0);
}
*out = concat_datas(buffers, args.threads);
out->shallow = 0;
for(i = 0; i < args.threads; ++i){
buffers[i].shallow = 1;
free_data(buffers[i]);
}
free(buffers);
free(threads);
return 0;
}
可以看到,在load_threads()中仍然没有使用args.paths,只是其中的load_data_in_thread()函数将args作为参数,那么继续去查看该函数的定义。
pthread_t load_data_in_thread(load_args args)
{
pthread_t thread;
struct load_args *ptr = calloc(1, sizeof(struct load_args));
*ptr = args;
if(pthread_create(&thread, 0, load_thread, ptr)) error("Thread creation failed");
return thread;
}
可见其又创建了一个新线程,入口为load_thread,继续我们追寻的脚步。
void *load_thread(void *ptr)
{
load_args a = *(struct load_args*)ptr;
if(a.exposure == 0) a.exposure = 1;
if(a.saturation == 0) a.saturation = 1;
if(a.aspect == 0) a.aspect = 1;
if (a.type == OLD_CLASSIFICATION_DATA){
*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
} else if (a.type == REGRESSION_DATA){
*a.d = load_data_regression(a.paths, a.n, a.m, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
} else if (a.type == CLASSIFICATION_DATA){
*a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.center);
} else if (a.type == SUPER_DATA){
*a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
} else if (a.type == WRITING_DATA){
*a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
} else if (a.type == INSTANCE_DATA){
*a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
} else if (a.type == SEGMENTATION_DATA){
*a.d = load_data_seg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.scale);
} else if (a.type == REGION_DATA){
*a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
} else if (a.type == SWAG_DATA){
*a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
} else if (a.type == COMPARE_DATA){
*a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
} else if (a.type == IMAGE_DATA){
*(a.im) = load_image_color(a.path, 0, 0);
*(a.resized) = resize_image(*(a.im), a.w, a.h);
} else if (a.type == LETTERBOX_DATA){
*(a.im) = load_image_color(a.path, 0, 0);
*(a.resized) = letterbox_image(*(a.im), a.w, a.h);
} else if (a.type == TAG_DATA){
*a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
}
free(ptr);
return 0;
}
终于在这里我们看到了args.paths的踪迹,args作为参数被拷贝给变量a,a.paths的值就是我们一直追寻的args.paths的值,从detector.c中我们看到args.type的值为DETECTION_DATA,因此我们继续看将a.paths作为参数的函数load_data_detection()。
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure)
{
char **random_paths = get_random_paths(paths, n, m);
int i;
data d = {0};
d.shallow = 0;
d.X.rows = n;
d.X.vals = calloc(d.X.rows, sizeof(float*));
d.X.cols = h*w*3;
d.y = make_matrix(n, 5*boxes);
for(i = 0; i < n; ++i){
image orig = load_image_color(random_paths[i], 0, 0);
image sized = make_image(w, h, orig.c);
fill_image(sized, .5);
float dw = jitter * orig.w;
float dh = jitter * orig.h;
float new_ar = (orig.w + rand_uniform(-dw, dw)) / (orig.h + rand_uniform(-dh, dh));
float scale = rand_uniform(.25, 2);
float nw, nh;
if(new_ar < 1){
nh = scale * h;
nw = nh * new_ar;
} else {
nw = scale * w;
nh = nw / new_ar;
}
float dx = rand_uniform(0, w - nw);
float dy = rand_uniform(0, h - nh);
place_image(orig, nw, nh, dx, dy, sized);
random_distort_image(sized, hue, saturation, exposure);
int flip = rand()%2;
if(flip) flip_image(sized);
d.X.vals[i] = sized.data;
fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, -dx/w, -dy/h, nw/w, nh/h);
free_image(orig);
}
free(random_paths);
return d;
}
a.paths是作为第二个参数传入的,也就是函数定义中的char **paths,我们看看函数中都对paths做了什么事情。
可以看到,这段代码中只有get_random_paths()将paths作为参数引用了,这个函数返回一个打乱顺序的图片路径数组,即random_paths。那么random_paths在这段代码中被使用了两次,分别作为参数传入函数load_image_color()以及函数fill_truth_detection(),我们按顺序来看。
image load_image_color(char *filename, int w, int h)
{
return load_image(filename, w, h, 3);
}
image load_image(char *filename, int w, int h, int c)
{
#ifdef OPENCV
image out = load_image_cv(filename, c);
#else
image out = load_image_stb(filename, c);
#endif
if((h && w) && (h != out.h || w != out.w)){
image resized = resize_image(out, w, h);
free_image(out);
out = resized;
}
return out;
}
以上代码说明函数load_image_color()以random_paths为参数将原始训练图片读取到变量orig中。我们继续看函数fill_truth_detection()做了什么。
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
{
char labelpath[4096];
find_replace(path, "images", "labels", labelpath);
find_replace(labelpath, "JPEGImages", "labels", labelpath);
find_replace(labelpath, "raw", "labels", labelpath);
find_replace(labelpath, ".jpg", ".txt", labelpath);
find_replace(labelpath, ".png", ".txt", labelpath);
find_replace(labelpath, ".JPG", ".txt", labelpath);
find_replace(labelpath, ".JPEG", ".txt", labelpath);
int count = 0;
box_label *boxes = read_boxes(labelpath, &count);
randomize_boxes(boxes, count);
correct_boxes(boxes, count, dx, dy, sx, sy, flip);
if(count > num_boxes) count = num_boxes;
float x,y,w,h;
int id;
int i;
for (i = 0; i < count; ++i) {
x = boxes[i].x;
y = boxes[i].y;
w = boxes[i].w;
h = boxes[i].h;
id = boxes[i].id;
if ((w < .001 || h < .001)) continue;
truth[i*5+0] = x;
truth[i*5+1] = y;
truth[i*5+2] = w;
truth[i*5+3] = h;
truth[i*5+4] = id;
}
free(boxes);
}
嘿嘿嘿哈哈哈!终于让我找着啦!这段代码就是yolo根据图片路径读取标签文档功能的所在之处!
首先开头7个find_replace()的作用是根据voc数据集目录结构将图片路径变换成为labels文档的路径。接下来是函数read_boxes(),让我们看看这个函数都做了什么。
box_label *read_boxes(char *filename, int *n)
{
box_label *boxes = calloc(1, sizeof(box_label));
FILE *file = fopen(filename, "r");
if(!file) file_error(filename);
float x, y, h, w;
int id;
int count = 0;
while(fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5){
boxes = realloc(boxes, (count+1)*sizeof(box_label));
boxes[count].id = id;
boxes[count].x = x;
boxes[count].y = y;
boxes[count].h = h;
boxes[count].w = w;
boxes[count].left = x - w/2;
boxes[count].right = x + w/2;
boxes[count].top = y - h/2;
boxes[count].bottom = y + h/2;
++count;
}
fclose(file);
*n = count;
return boxes;
}
哎,到了这里,就有一种拨开云雾见天明的感觉有木有,这个函数用labels文档的路径将所有的ground truth boxes信息存入一个box_label结构体的数组中,此结构体包含的成员变量为id[目标类别代码],x[bbox中心归一化横坐标],y[bbox中心归一化纵坐标],h[bbox归一化高度],w[bbox归一化宽度],left[归一化左边横坐标],right[归一化右边横坐标],top[归一化上边纵坐标],bottom[归一化下边纵坐标]。
emmm……找到了读取labels文档的代码,就可以很容易的修改了,但是不能只修改这一部分的代码嘛,还需要修改loss的计算部分,以及非极大值抑制的那一部分,但是想一下觉得非极大值抑制那部分通过一定的投机取巧可以不用修改,哈哈,接下来就是去分析loss的计算那部分代码啦。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)