forked from shouxieai/tensorRT_Pro
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdbface.cpp
More file actions
290 lines (240 loc) · 12.5 KB
/
dbface.cpp
File metadata and controls
290 lines (240 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#include "dbface.hpp"
#include <atomic>
#include <condition_variable>
#include <cstring>
#include <mutex>
#include <queue>
#include <utility>
#include <infer/trt_infer.hpp>
#include <common/ilogger.hpp>
#include <common/infer_controller.hpp>
#include <common/preprocess_kernel.cuh>
#include <common/monopoly_allocator.hpp>
#include <common/cuda_tools.hpp>
namespace DBFace{
using namespace cv;
using namespace std;
// Launcher for the DBFace head decoder; only declared here, the definition lives
// in another translation unit (NOTE(review): presumably the .cu file — confirm).
// As used by InferImpl::worker:
//   pool_hm_ptr / hm_ptr / tlrb_ptr / landmark_ptr : one image's slice of the
//                        engine outputs "pool_hm", "hm", "tlrb", "landmark"
//   fm_width, fm_height : network input size divided by `stride`
//   conf_T, nms_threshold : the thresholds given to InferImpl::startup
//   invert_affine_matrix : that image's d2i (network -> image) matrix
//   parray               : the image's output slot; element 0 is a counter that
//                          the caller zeroes before launching
//   max_objects          : capacity of `parray` in boxes (caller passes MAX_IMAGE_BBOX)
//   stream               : the engine stream the caller works on
void decode_kernel_invoker(
    float* pool_hm_ptr,
    float* hm_ptr,
    float* tlrb_ptr,
    float* landmark_ptr,
    int fm_width,
    int fm_height,
    int stride,
    float conf_T,
    float nms_threshold,
    float* invert_affine_matrix,
    float* parray,
    int max_objects,
    cudaStream_t stream
);
// Letterbox transform between a source image and the network input: a uniform
// scale (the smaller of the width/height ratios) that keeps the scaled image
// centred inside the destination, plus its inverse.
struct AffineMatrix{
    float i2d[6];   // image -> dst (network), 2x3 row-major
    float d2i[6];   // dst (network) -> image, 2x3 row-major, inverse of i2d

    // Fill i2d for mapping an image of size `from` onto `to`, then derive d2i.
    void compute(const cv::Size& from, const cv::Size& to){
        const float ratio_w = to.width  / static_cast<float>(from.width);
        const float ratio_h = to.height / static_cast<float>(from.height);
        const float s = std::min(ratio_w, ratio_h);

        // row 0: x' = s * x + tx, where tx centres the scaled width
        i2d[0] = s;
        i2d[1] = 0;
        i2d[2] = -s * from.width * 0.5 + to.width * 0.5;

        // row 1: y' = s * y + ty, where ty centres the scaled height
        i2d[3] = 0;
        i2d[4] = s;
        i2d[5] = -s * from.height * 0.5 + to.height * 0.5;

        // Both headers wrap the member arrays without copying, so the inverse
        // computed by OpenCV is written straight into d2i.
        cv::Mat forward_view(2, 3, CV_32F, i2d);
        cv::Mat inverse_view(2, 3, CV_32F, d2i);
        cv::invertAffineTransform(forward_view, inverse_view);
    }

    // Non-owning 2x3 CV_32F view of i2d; it aliases this object's storage.
    cv::Mat i2d_mat(){
        cv::Mat view(2, 3, CV_32F, i2d);
        return view;
    }
};
// Controller specialisation used by InferImpl. Template arguments, in order:
//   input        : cv::Mat                       (one image per commit)
//   output       : BoxArray                      (detected faces per image)
//   start param  : std::tuple<std::string, int>  (engine file, gpu id)
//   additional   : AffineMatrix                  (per-job letterbox transform)
using ControllerImpl = InferController<
    cv::Mat,
    BoxArray,
    std::tuple<std::string, int>,
    AffineMatrix
>;
class InferImpl : public Infer, public ControllerImpl{
public:
virtual bool startup(const string& file, int gpuid, float confidence_threshold, float nms_threshold){
float mean[] = {0.408, 0.447, 0.470};
float std[] = {0.289, 0.274, 0.278};
normalize_ = CUDAKernel::Norm::mean_std(mean, std, 1/255.0f, CUDAKernel::ChannelType::None);
confidence_threshold_ = confidence_threshold;
nms_threshold_ = nms_threshold;
return ControllerImpl::startup(make_tuple(file, gpuid));
}
virtual void worker(promise<bool>& result) override{
string file = get<0>(start_param_);
int gpuid = get<1>(start_param_);
TRT::set_device(gpuid);
auto engine = TRT::load_infer(file);
if(engine == nullptr){
INFOE("Engine %s load failed", file.c_str());
result.set_value(false);
return;
}
engine->print();
const int MAX_IMAGE_BBOX = 1024;
const int NUM_BOX_ELEMENT = 17; // left, top, right, bottom, confidence, class, keepflag + 10landmark
TRT::Tensor affin_matrix_device(TRT::DataType::Float);
TRT::Tensor output_array_device(TRT::DataType::Float);
int max_batch_size = engine->get_max_batch_size();
auto input = engine->input(0);
auto pool_hm = engine->tensor("pool_hm");
auto hm = engine->tensor("hm");
auto tlrb = engine->tensor("tlrb");
auto landmark = engine->tensor("landmark");
const int stride = 4;
input_width_ = input->size(3);
input_height_ = input->size(2);
int fm_width = input_width_ / stride;
int fm_height = input_height_ / stride;
tensor_allocator_ = make_shared<MonopolyAllocator<TRT::Tensor>>(max_batch_size * 2);
stream_ = engine->get_stream();
gpu_ = gpuid;
result.set_value(true);
input->resize_single_dim(0, max_batch_size).to_gpu();
affin_matrix_device.set_stream(stream_);
// the nubmer 8 here means 8 * sizeof(float) % 32 == 0
affin_matrix_device.resize(max_batch_size, 8).to_gpu();
// 1 image has n detected bboxexs, which can be expressed as [counter, bbox0, bbox1...bboxn]
output_array_device.resize(max_batch_size, 1 + MAX_IMAGE_BBOX * NUM_BOX_ELEMENT).to_gpu();
vector<Job> fetch_jobs;
while(get_jobs_and_wait(fetch_jobs, max_batch_size)){
int infer_batch_size = fetch_jobs.size();
input->resize_single_dim(0, infer_batch_size);
for(int ibatch = 0; ibatch < infer_batch_size; ++ibatch){
auto& job = fetch_jobs[ibatch];
auto& mono = job.mono_tensor->data();
affin_matrix_device.copy_from_gpu(affin_matrix_device.offset(ibatch), mono->get_workspace()->gpu(), 6);
input->copy_from_gpu(input->offset(ibatch), mono->gpu(), mono->count());
job.mono_tensor->release();
}
engine->forward(false);
output_array_device.to_gpu(false);
for(int ibatch = 0; ibatch < infer_batch_size; ++ibatch){
auto& job = fetch_jobs[ibatch];
float* pool_hm_ptr = pool_hm->gpu<float>(ibatch);
float* hm_ptr = hm->gpu<float>(ibatch);
float* tlrb_ptr = tlrb->gpu<float>(ibatch);
float* landmark_ptr = landmark->gpu<float>(ibatch);
float* output_array_ptr = output_array_device.gpu<float>(ibatch);
auto affine_matrix = affin_matrix_device.gpu<float>(ibatch);
checkCudaRuntime(cudaMemsetAsync(output_array_ptr, 0, sizeof(int), stream_));
decode_kernel_invoker(
pool_hm_ptr, hm_ptr, tlrb_ptr, landmark_ptr,
fm_width, fm_height,
stride,
confidence_threshold_,
nms_threshold_,
affine_matrix,
output_array_ptr,
MAX_IMAGE_BBOX,
stream_
);
}
output_array_device.to_cpu();
for(int ibatch = 0; ibatch < infer_batch_size; ++ibatch){
float* parray = output_array_device.cpu<float>(ibatch);
int count = min(MAX_IMAGE_BBOX, (int)*parray);
auto& job = fetch_jobs[ibatch];
auto& image_based_boxes = job.output;
for(int i = 0; i < count; ++i){
float* pbox = parray + 1 + i * NUM_BOX_ELEMENT;
int label = pbox[5];
int keepflag = pbox[6];
if(keepflag == 1){
FaceDetector::Box box;
box.left = pbox[0];
box.top = pbox[1];
box.right = pbox[2];
box.bottom = pbox[3];
box.confidence = pbox[4];
memcpy(box.landmark, pbox + 7, sizeof(box.landmark));
image_based_boxes.emplace_back(box);
}
}
job.pro->set_value(image_based_boxes);
}
fetch_jobs.clear();
}
INFO("Engine destroy.");
}
virtual bool preprocess(Job& job, const Mat& image) override{
job.mono_tensor = tensor_allocator_->query();
if(job.mono_tensor == nullptr){
INFOE("Tensor allocator query failed.");
return false;
}
CUDATools::AutoDevice auto_device(gpu_);
auto& tensor = job.mono_tensor->data();
if(tensor == nullptr){
// not init
tensor = make_shared<TRT::Tensor>();
tensor->set_workspace(make_shared<TRT::MixMemory>());
}
Size input_size(input_width_, input_height_);
job.additional.compute(image.size(), input_size);
tensor->set_stream(stream_);
tensor->resize(1, 3, input_height_, input_width_);
size_t size_image = image.cols * image.rows * 3;
size_t size_matrix = iLogger::upbound(sizeof(job.additional.d2i), 32);
auto workspace = tensor->get_workspace();
uint8_t* gpu_workspace = (uint8_t*)workspace->gpu(size_matrix + size_image);
float* affine_matrix_device = (float*)gpu_workspace;
uint8_t* image_device = size_matrix + gpu_workspace;
uint8_t* cpu_workspace = (uint8_t*)workspace->cpu(size_matrix + size_image);
float* affine_matrix_host = (float*)cpu_workspace;
uint8_t* image_host = size_matrix + cpu_workspace;
memcpy(image_host, image.data, size_image);
memcpy(affine_matrix_host, job.additional.d2i, sizeof(job.additional.d2i));
checkCudaRuntime(cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_));
checkCudaRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(job.additional.d2i), cudaMemcpyHostToDevice, stream_));
CUDAKernel::warp_affine_bilinear_and_normalize_plane(
image_device, image.cols * 3, image.cols, image.rows,
tensor->gpu<float>(), input_width_, input_height_,
affine_matrix_device, 0, // note
normalize_, stream_
);
return true;
}
virtual vector<shared_future<BoxArray>> commits(const vector<Mat>& images) override{
return ControllerImpl::commits(images);
}
virtual std::shared_future<BoxArray> commit(const Mat& image) override{
return ControllerImpl::commit(image);
}
private:
int input_width_ = 0;
int input_height_ = 0;
int gpu_ = 0;
float confidence_threshold_ = 0;
float nms_threshold_ = 0;
TRT::CUStream stream_ = nullptr;
CUDAKernel::Norm normalize_;
};
// Factory: build an InferImpl and start it with the given engine and thresholds.
// Returns an empty pointer when InferImpl::startup reports failure.
shared_ptr<Infer> create_infer(const string& engine_file, int gpuid, float confidence_threshold, float nms_threshold){
    auto impl = make_shared<InferImpl>();
    const bool started = impl->startup(engine_file, gpuid, confidence_threshold, nms_threshold);
    if(!started)
        return nullptr;
    return impl;
}
// Letterbox `image` into batch slot `ibatch` of `tensor`, using the same mean/std
// normalisation as InferImpl::startup. Pixels and the d2i matrix are staged
// through the tensor's workspace and uploaded on the tensor's stream; the
// function ends with tensor->synchronize(). Invalid input is logged and the
// tensor is left untouched.
void image_to_tensor(const cv::Mat& image, shared_ptr<TRT::Tensor>& tensor, int ibatch){
    if(tensor == nullptr){
        INFOE("Tensor is nullptr");
        return;
    }
    if(image.empty()){
        INFOE("Image is empty");
        return;
    }
    // The upload and the warp kernel below assume 3 bytes per pixel.
    if(image.type() != CV_8UC3){
        INFOE("Unsupported image type %d, expected CV_8UC3", image.type());
        return;
    }

    float mean[]    = {0.408, 0.447, 0.470};
    float std_dev[] = {0.289, 0.274, 0.278};
    auto normalize  = CUDAKernel::Norm::mean_std(mean, std_dev, 1/255.0f, CUDAKernel::ChannelType::None);

    // The pixel upload is one flat memcpy, so ROI / strided views are made contiguous first.
    cv::Mat src = image;
    if(!src.isContinuous())
        src = src.clone();

    Size input_size(tensor->size(3), tensor->size(2));
    AffineMatrix affine;
    affine.compute(src.size(), input_size);

    // Workspace layout: [d2i matrix, padded to 32 bytes][packed 8-bit 3-channel pixels]
    size_t size_image  = static_cast<size_t>(src.cols) * src.rows * 3;
    size_t size_matrix = iLogger::upbound(sizeof(affine.d2i), 32);
    auto workspace              = tensor->get_workspace();
    uint8_t* gpu_workspace      = (uint8_t*)workspace->gpu(size_matrix + size_image);
    float* affine_matrix_device = (float*)gpu_workspace;
    uint8_t* image_device       = gpu_workspace + size_matrix;
    uint8_t* cpu_workspace      = (uint8_t*)workspace->cpu(size_matrix + size_image);
    float* affine_matrix_host   = (float*)cpu_workspace;
    uint8_t* image_host         = cpu_workspace + size_matrix;
    auto stream = tensor->get_stream();

    memcpy(image_host, src.data, size_image);
    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
    checkCudaRuntime(cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream));
    checkCudaRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i), cudaMemcpyHostToDevice, stream));

    CUDAKernel::warp_affine_bilinear_and_normalize_plane(
        image_device, src.cols * 3, src.cols, src.rows,
        tensor->gpu<float>(ibatch), input_size.width, input_size.height,
        affine_matrix_device, 0,
        normalize, stream
    );
    tensor->synchronize();
}
};