forked from shouxieai/tensorRT_Pro
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp_plugin.cpp
More file actions
90 lines (69 loc) · 2.55 KB
/
app_plugin.cpp
File metadata and controls
90 lines (69 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <builder/trt_builder.hpp>
#include <infer/trt_infer.hpp>
#include <common/ilogger.hpp>
#include "app_yolo/yolo.hpp"
using namespace std;
// Compiles hswish.plugin.onnx into a TensorRT engine for `mode`, loads it,
// fills both inputs with constants and logs output(0, 0) next to a CPU
// reference value relu(hswish(input0) * input1).
//
// hswish.plugin.onnx is expected to be produced beforehand by:
//   cd workspace
//   python test_plugin.py
//
// @param mode  precision mode passed to TRT::compile and used in the engine file name.
static void test_hswish(TRT::Mode mode){
    iLogger::set_log_level(iLogger::LogLevel::Verbose);
    TRT::set_device(0);

    auto mode_name = TRT::mode_string(mode);
    auto engine_name = iLogger::format("hswish.plugin.%s.trtmodel", mode_name);

    // Bail out instead of dereferencing a null engine further down.
    if(!TRT::compile(mode, 3, "hswish.plugin.onnx", engine_name, {})){
        INFO("Compile hswish.plugin.onnx failed");
        return;
    }

    auto engine = TRT::load_infer(engine_name);
    if(engine == nullptr){
        INFO("Load engine %s failed", engine_name.c_str());
        return;
    }
    engine->print();

    auto input0 = engine->input(0);
    auto input1 = engine->input(1);
    auto output = engine->output(0);
    INFO("offset %d", output->offset(1, 0));
    INFO("input0: %s", input0->shape_string());
    INFO("input1: %s", input1->shape_string());
    INFO("output: %s", output->shape_string());

    const float input0_val = 0.8f;
    const float input1_val = 2.0f;
    input0->set_to(input0_val);
    input1->set_to(input1_val);

    // CPU reference: hswish(x) = x * clamp(x + 3, 0, 6) / 6, followed by relu.
    auto hswish = [](float x){
        float a = x + 3;
        a = a < 0 ? 0 : (a >= 6 ? 6 : a);
        return x * a / 6;
    };
    auto relu = [](float x){ return std::max(0.0f, x); };
    const float output_real = relu(hswish(input0_val) * input1_val);

    engine->forward(true);
    INFO("output %f, output_real = %f", output->at<float>(0, 0), output_real);
}
// Compiles dcnv2.plugin.onnx into a TensorRT engine for `mode`, loads it,
// fills both inputs with 1.0 and logs every element of output 0.
//
// dcnv2.plugin.onnx is expected to be produced beforehand by:
//   cd workspace
//   python test_plugin.py
//
// @param mode  precision mode passed to TRT::compile and used in the engine file name.
static void test_dcnv2(TRT::Mode mode){
    iLogger::set_log_level(iLogger::LogLevel::Verbose);
    TRT::set_device(0);

    auto mode_name = TRT::mode_string(mode);
    auto engine_name = iLogger::format("dcnv2.plugin.%s.trtmodel", mode_name);

    // Bail out instead of dereferencing a null engine further down.
    if(!TRT::compile(mode, 1, "dcnv2.plugin.onnx", engine_name, {})){
        INFO("Compile dcnv2.plugin.onnx failed");
        return;
    }

    auto engine = TRT::load_infer(engine_name);
    if(engine == nullptr){
        INFO("Load engine %s failed", engine_name.c_str());
        return;
    }
    engine->print();

    auto input0 = engine->input(0);
    auto input1 = engine->input(1);
    auto output = engine->output(0);
    INFO("input0: %s", input0->shape_string());
    INFO("input1: %s", input1->shape_string());
    INFO("output: %s", output->shape_string());

    const float input0_val = 1.0f;
    const float input1_val = 1.0f;
    input0->set_to(input0_val);
    input1->set_to(input1_val);
    engine->forward(true);

    // Fetch the host pointer and element count once rather than per element.
    const float* output_host = output->cpu<float>();
    const int output_count = output->count();
    for(int i = 0; i < output_count; ++i)
        INFO("output[%d] = %f", i, output_host[i]);
}
// Entry point for the plugin demos. Only the FP32 DCNv2 test runs by default;
// uncomment the other calls to exercise the remaining variants.
// @return 0 unconditionally.
int app_plugin(){
    // test_hswish(TRT::Mode::FP32);
    test_dcnv2(TRT::Mode::FP32);
    // test_hswish(TRT::Mode::FP16);
    // test_dcnv2(TRT::Mode::FP16);
    return 0;
}