Mercurial > hg > Game > Cerium
Comparison of example/cuda_fft/main.cc @ 2008:2c8eab01cc78 (draft)
Commit message: implement fft using cuda
author | Shohei KOKUBO <e105744@ie.u-ryukyu.ac.jp> |
---|---|
date | Tue, 03 Jun 2014 18:10:19 +0900 |
parents | bc2121b09cbc |
children | 6fced32f85fd |
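Comparison legend (row states): equal, deleted, inserted, replaced.
Left column is the parent revision (2007:bc2121b09cbc); right column is this revision (2008:2c8eab01cc78).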
2007:bc2121b09cbc | 2008:2c8eab01cc78 |
---|---|
1 #include <stdio.h> | 1 #include <stdio.h> |
2 #include <sys/time.h> | 2 #include <sys/time.h> |
3 #include <string.h> | 3 #include <string.h> |
4 #include <cuda.h> | 4 #include <cuda.h> |
5 #include <vector_types.h> | |
5 | 6 |
6 #include "pgm.h" | 7 #include "pgm.h" |
7 | 8 |
8 #define PI 3.14159265358979 | 9 #define PI 3.14159265358979 |
9 #define MAX_SOURCE_SIZE (0x100000) | 10 #define MAX_SOURCE_SIZE (0x100000) |
14 enum Mode { | 15 enum Mode { |
15 forward = 0, | 16 forward = 0, |
16 inverse = 1 | 17 inverse = 1 |
17 }; | 18 }; |
18 | 19 |
19 struct int2 { | |
20 int x; | |
21 int y; | |
22 }; | |
23 | |
24 struct float2 { | |
25 float x; | |
26 float y; | |
27 }; | |
28 | |
29 CUmodule module; | 20 CUmodule module; |
30 | 21 |
31 static double | 22 static double |
32 getTime() { | 23 getTime() { |
33 struct timeval tv; | 24 struct timeval tv; |
38 int | 29 int |
39 setWorkSize(int* block, int* thread, int x, int y) | 30 setWorkSize(int* block, int* thread, int x, int y) |
40 { | 31 { |
41 switch(y) { | 32 switch(y) { |
42 case 1: | 33 case 1: |
43 block = x; | 34 *block = x; |
44 thread = 1; | 35 *thread = 1; |
45 break; | 36 break; |
46 default: | 37 default: |
47 block = x; | 38 *block = x; |
48 thread = y; | 39 *thread = y; |
49 break; | 40 break; |
50 } | 41 } |
51 | 42 |
52 return 0; | 43 return 0; |
53 } | 44 } |
64 int n = 1<<m; | 55 int n = 1<<m; |
65 int block, thread; | 56 int block, thread; |
66 setWorkSize(&block, &thread, n, n); | 57 setWorkSize(&block, &thread, n, n); |
67 | 58 |
68 CUfunction bitReverse; | 59 CUfunction bitReverse; |
69 cuModuleGetFunction(bitReverse, module, "bitReverse"); | 60 cuModuleGetFunction(&bitReverse, module, "bitReverse"); |
70 | 61 |
71 void* kernel_args[] = {&dst, &src, &m, &n}; | 62 void* bitReverse_args[] = {&dst, &src, &m, &n}; |
72 | 63 |
73 cuLaunchKernel(bitReverse, | 64 cuLaunchKernel(bitReverse, |
74 block, 1, 1, | 65 block, 1, 1, |
75 thread, 1, 1, | 66 thread, 1, 1, |
76 0, NULL, kernel_args, NULL); | 67 0, NULL, bitReverse_args, NULL); |
77 | 68 |
78 CUfunction butterfly; | 69 CUfunction butterfly; |
79 cuModuleGetFunction(butterfly, module, "butterfly"); | 70 cuModuleGetFunction(&butterfly, module, "butterfly"); |
80 | 71 |
81 setWorkSize(&block, &thread, n/2, n); | 72 setWorkSize(&block, &thread, n/2, n); |
73 void* butterfly_args[] = {&dst, &spin, &m, &n, 0, &flag}; | |
82 for (int i=1;i<=m;i++) { | 74 for (int i=1;i<=m;i++) { |
83 kernel_args[] = {&dst, &spin, &m, &n, &i, &flag}; | 75 butterfly_args[4] = &i; |
84 cuLaunchKernel(butterfly, | 76 cuLaunchKernel(butterfly, |
85 block, 1, 1, | 77 block, 1, 1, |
86 thread, 1, 1, | 78 thread, 1, 1, |
87 0, NULL, kernel_args, NULL); | 79 0, NULL, butterfly_args, NULL); |
88 } | 80 } |
89 | 81 |
90 CUfunction norm; | 82 CUfunction norm; |
91 cuModuleGetFunction(norm, module, "norm"); | 83 cuModuleGetFunction(&norm, module, "norm"); |
92 | 84 |
85 void* norm_args[] = {&dst, &m}; | |
93 if (direction == inverse) { | 86 if (direction == inverse) { |
94 setWorkSize(&block, &thread, n, n); | 87 setWorkSize(&block, &thread, n, n); |
95 kernel_args[] = {&dst, &m}; | |
96 cuLaunchKernel(norm, | 88 cuLaunchKernel(norm, |
97 block, 1, 1, | 89 block, 1, 1, |
98 thread, 1, 1, | 90 thread, 1, 1, |
99 0, NULL, kernel_args, NULL); | 91 0, NULL, norm_args, NULL); |
100 } | 92 } |
101 | 93 |
102 return 0; | 94 return 0; |
103 } | 95 } |
104 | 96 |
130 CUcontext context; | 122 CUcontext context; |
131 cuCtxCreate(&context, CU_CTX_SCHED_SPIN, device); | 123 cuCtxCreate(&context, CU_CTX_SCHED_SPIN, device); |
132 | 124 |
133 cuModuleLoad(&module, "fft.ptx"); | 125 cuModuleLoad(&module, "fft.ptx"); |
134 | 126 |
135 char* pgm_file = init(argc, argv); | 127 char* pgm_file = init(args, argv); |
136 | 128 |
137 pgm_t ipgm; | 129 pgm_t ipgm; |
138 int err = readPGM(&ipgm, pgm_file); | 130 int err = readPGM(&ipgm, pgm_file); |
139 if (err<0) { | 131 if (err<0) { |
140 fprintf(stderr, "Failed to read image file.\n"); | 132 fprintf(stderr, "Failed to read image file.\n"); |
169 | 161 |
170 // Synchronous data transfer(host to device) | 162 // Synchronous data transfer(host to device) |
171 cuMemcpyHtoD(xmobj, xm, n*n*sizeof(float2)); | 163 cuMemcpyHtoD(xmobj, xm, n*n*sizeof(float2)); |
172 | 164 |
173 CUfunction spinFact; | 165 CUfunction spinFact; |
174 cuModuleGetFunction(spinFact, module, "spinFact"); | 166 cuModuleGetFunction(&spinFact, module, "spinFact"); |
175 | 167 |
176 int block, thread; | 168 int block, thread; |
177 setWorkSize(&block, &thread, n/2, 1); | 169 setWorkSize(&block, &thread, n/2, 1); |
178 | 170 |
179 void* kernel_args[] = {&xmobj, &n}; | 171 void* spinFact_args[] = {&xmobj, &n}; |
180 cuLaunchKernel(spinFact, | 172 cuLaunchKernel(spinFact, |
181 block, 1, 1, | 173 block, 1, 1, |
182 thread, 1, 1, | 174 thread, 1, 1, |
183 0, NULL, kernel_args, NULL); | 175 0, NULL, spinFact_args, NULL); |
184 | 176 |
185 fftCore(rmobj, xmobj, wmobj, m, forward); | 177 fftCore(rmobj, xmobj, wmobj, m, forward); |
186 | 178 |
187 CUfunction transfer; | 179 CUfunction transpose; |
188 cuModuleGetFunction(transfer, module, "transfer"); | 180 cuModuleGetFunction(&transpose, module, "transpose"); |
189 | 181 |
190 setWorkSize(&block, &thread, n, n); | 182 setWorkSize(&block, &thread, n, n); |
191 | 183 |
192 kernel_args[] = {&xmobj, &rmobj, &n}; | 184 void* transpose_args[] = {&xmobj, &rmobj, &n}; |
193 cuLaunchKernel(transfer, | 185 cuLaunchKernel(transpose, |
194 block, 1, 1, | 186 block, 1, 1, |
195 thread, 1, 1, | 187 thread, 1, 1, |
196 0, NULL, kernel_args, NULL); | 188 0, NULL, transpose_args, NULL); |
197 | 189 |
198 fftCore(rmobj, xmobj, wmobj, m, forward); | 190 fftCore(rmobj, xmobj, wmobj, m, forward); |
199 | 191 |
200 CUfunction highPassFilter; | 192 CUfunction highPassFilter; |
201 cuModuleGetFunction(transfer, module, "highPassFilter"); | 193 cuModuleGetFunction(&highPassFilter, module, "highPassFilter"); |
202 | 194 |
203 setWorkSize(&block, &thread, n, n); | 195 setWorkSize(&block, &thread, n, n); |
204 | 196 |
205 int radius = n/8; | 197 int radius = n/8; |
206 kernel_args[] = {&rmobj, &n, &radius}; | 198 void*highPassFilter_args[] = {&rmobj, &n, &radius}; |
207 cuLaunchKernel(highPassFilter, | 199 cuLaunchKernel(highPassFilter, |
208 block, 1, 1, | 200 block, 1, 1, |
209 thread, 1, 1, | 201 thread, 1, 1, |
210 0, NULL, kernel_args, NULL); | 202 0, NULL, highPassFilter_args, NULL); |
211 | 203 |
212 fftCore(xmobj, rmobj, wmobj, m, inverse); | 204 fftCore(xmobj, rmobj, wmobj, m, inverse); |
213 | 205 |
214 setWorkSize(&block, &thread, n, n); | 206 setWorkSize(&block, &thread, n, n); |
215 | 207 |
216 kernel_args[] = {&rmobj, &xmobj}; | 208 void* transpose2_args[] = {&rmobj, &xmobj, &n}; |
217 cuLaunchKernel(transfer, | 209 cuLaunchKernel(transpose, |
218 block, 1, 1, | 210 block, 1, 1, |
219 thread, 1, 1, | 211 thread, 1, 1, |
220 0, NULL, kernel_args, NULL); | 212 0, NULL, transpose2_args, NULL); |
221 | 213 |
222 fftCore(xmobj, rmobj, wmobj, m, inverse); | 214 fftCore(xmobj, rmobj, wmobj, m, inverse); |
223 | 215 |
224 | 216 cuMemcpyDtoH(xm, xmobj, n*n*sizeof(float2)); |
217 | |
218 float* ampd; | |
219 ampd = (float*)malloc(n*n*sizeof(float)); | |
220 | |
221 for (int i=0;i<n*n;i++) | |
222 ampd[i] = (AMP(xm[i].x, xm[i].y)); | |
223 | |
224 opgm.width = n; | |
225 opgm.height = n; | |
226 normalizeF2PGM(&opgm, ampd); | |
227 free(ampd); | |
228 | |
229 ed_time = getTime(); | |
230 | |
231 writePGM(&opgm, "output.pgm"); | |
225 | 232 |
226 // memory release | 233 // memory release |
227 cuMemFree(devA); | 234 cuMemFree(xmobj); |
228 for (int i=0;i<num_exec;i++) { | 235 cuMemFree(rmobj); |
229 cuMemFree(devB[i]); | 236 cuMemFree(wmobj); |
230 cuMemFree(devOut[i]); | |
231 } | |
232 for (int i=0;i<num_stream;i++) | |
233 cuStreamDestroy(stream[i]); | |
234 cuModuleUnload(module); | 237 cuModuleUnload(module); |
235 cuCtxDestroy(context); | 238 cuCtxDestroy(context); |
236 | 239 |
237 delete[] A; | 240 destroyPGM(&ipgm); |
238 delete[] B; | 241 destroyPGM(&opgm); |
239 for (int i=0;i<num_exec;i++) | 242 |
240 delete[] result[i]; | 243 free(xm); |
241 delete[] result; | 244 free(rm); |
245 free(wm); | |
246 | |
247 printf("Time: %0.6f\n", ed_time-st_time); | |
242 | 248 |
243 return 0; | 249 return 0; |
244 } | 250 } |