1
+ import pycuda .autoinit
2
+ import pycuda .driver as drv
3
+ import numpy , math , sys
4
+ from pycuda .compiler import DynamicSourceModule
5
+
# Select the floating-point precision from the command line: passing
# '-double' as the first argument runs everything in float64/double,
# otherwise float32/float is used.
#
# BUG FIX: the original tested `len(sys.argv) > 2`, which requires TWO
# extra arguments, so `python script.py -double` (len(sys.argv) == 2)
# never enabled double precision.  The correct guard is `> 1`.
if len(sys.argv) > 1 and sys.argv[1] == '-double':
    real_py = 'float64'   # numpy dtype name used on the host side
    real_cpp = 'double'   # C type name substituted into the kernel source
else:
    real_py = 'float32'
    real_cpp = 'float'
+ mod = DynamicSourceModule (r"""
14
+ #include <cooperative_groups.h>
15
+ using namespace cooperative_groups;
16
+ const unsigned FULL_MASK = 0xffffffff;
17
+
18
+ extern "C"{void __global__ reduce_syncwarp(const real *d_x, real *d_y, const int N)
19
+ {
20
+ const int tid = threadIdx.x;
21
+ const int bid = blockIdx.x;
22
+ const int n = bid * blockDim.x + tid;
23
+ extern __shared__ real s_y[];
24
+ s_y[tid] = (n < N) ? d_x[n] : 0.0;
25
+ __syncthreads();
26
+
27
+ for (int offset = blockDim.x >> 1; offset >= 32; offset >>= 1)
28
+ {
29
+ if (tid < offset)
30
+ {
31
+ s_y[tid] += s_y[tid + offset];
32
+ }
33
+ __syncthreads();
34
+ }
35
+
36
+ for (int offset = 16; offset > 0; offset >>= 1)
37
+ {
38
+ if (tid < offset)
39
+ {
40
+ s_y[tid] += s_y[tid + offset];
41
+ }
42
+ __syncwarp();
43
+ }
44
+
45
+ if (tid == 0)
46
+ {
47
+ atomicAdd(d_y, s_y[0]);
48
+ }
49
+ }}
50
+
51
+ extern "C"{void __global__ reduce_shfl(const real *d_x, real *d_y, const int N)
52
+ {
53
+ const int tid = threadIdx.x;
54
+ const int bid = blockIdx.x;
55
+ const int n = bid * blockDim.x + tid;
56
+ extern __shared__ real s_y[];
57
+ s_y[tid] = (n < N) ? d_x[n] : 0.0;
58
+ __syncthreads();
59
+
60
+ for (int offset = blockDim.x >> 1; offset >= 32; offset >>= 1)
61
+ {
62
+ if (tid < offset)
63
+ {
64
+ s_y[tid] += s_y[tid + offset];
65
+ }
66
+ __syncthreads();
67
+ }
68
+
69
+ real y = s_y[tid];
70
+
71
+ for (int offset = 16; offset > 0; offset >>= 1)
72
+ {
73
+ y += __shfl_down_sync(FULL_MASK, y, offset);
74
+ }
75
+
76
+ if (tid == 0)
77
+ {
78
+ atomicAdd(d_y, y);
79
+ }
80
+ }
81
+ }
82
+
83
+ extern "C"{void __global__ reduce_cp(const real *d_x, real *d_y, const int N)
84
+ {
85
+ const int tid = threadIdx.x;
86
+ const int bid = blockIdx.x;
87
+ const int n = bid * blockDim.x + tid;
88
+ extern __shared__ real s_y[];
89
+ s_y[tid] = (n < N) ? d_x[n] : 0.0;
90
+ __syncthreads();
91
+
92
+ for (int offset = blockDim.x >> 1; offset >= 32; offset >>= 1)
93
+ {
94
+ if (tid < offset)
95
+ {
96
+ s_y[tid] += s_y[tid + offset];
97
+ }
98
+ __syncthreads();
99
+ }
100
+
101
+ real y = s_y[tid];
102
+
103
+ thread_block_tile<32> g = tiled_partition<32>(this_thread_block());
104
+ for (int i = g.size() >> 1; i > 0; i >>= 1)
105
+ {
106
+ y += g.shfl_down(y, i);
107
+ }
108
+
109
+ if (tid == 0)
110
+ {
111
+ atomicAdd(d_y, y);
112
+ }
113
+ }
114
+ }
115
+ """ .replace ('real' , real_cpp ), no_extern_c = True )
116
+ reduce_syncwarp = mod .get_function ("reduce_syncwarp" )
117
+ reduce_shfl = mod .get_function ("reduce_shfl" )
118
+ reduce_cp = mod .get_function ("reduce_cp" )
119
+
120
+
121
+
122
+ def timing (method ):
123
+ NUM_REPEATS = 10
124
+ N = 100000000
125
+ BLOCK_SIZE = 128
126
+ grid_size = (N - 1 )// 128 + 1
127
+ h_x = numpy .full ((N ,1 ), 1.23 , dtype = real_py )
128
+ d_x = drv .mem_alloc (h_x .nbytes )
129
+ drv .memcpy_htod (d_x , h_x )
130
+ t_sum = 0
131
+ t2_sum = 0
132
+ for repeat in range (NUM_REPEATS + 1 ):
133
+ start = drv .Event ()
134
+ stop = drv .Event ()
135
+ start .record ()
136
+
137
+ h_y = numpy .zeros ((1 ,1 ), dtype = real_py )
138
+ d_y = drv .mem_alloc (h_y .nbytes )
139
+ drv .memcpy_htod (d_y , h_y )
140
+ if method == 0 :
141
+ reduce_syncwarp (d_x , d_y , numpy .int32 (N ), grid = (grid_size , 1 ), block = (128 ,1 ,1 ), shared = numpy .zeros ((1 ,1 ),dtype = real_py ).nbytes * BLOCK_SIZE )
142
+ elif method == 1 :
143
+ reduce_shfl (d_x , d_y , numpy .int32 (N ), grid = ((N - 1 )// 128 + 1 , 1 ), block = (128 ,1 ,1 ), shared = numpy .zeros ((1 ,1 ),dtype = real_py ).nbytes * BLOCK_SIZE )
144
+ elif method == 2 :
145
+ reduce_cp (d_x , d_y , numpy .int32 (N ), grid = ((N - 1 )// 128 + 1 , 1 ), block = (128 ,1 ,1 ), shared = numpy .zeros ((1 ,1 ),dtype = real_py ).nbytes * BLOCK_SIZE )
146
+ else :
147
+ print ("Error: wrong method" )
148
+ break
149
+ drv .memcpy_dtoh (h_y , d_y )
150
+ v_sum = h_y [0 ,0 ]
151
+
152
+ stop .record ()
153
+ stop .synchronize ()
154
+ elapsed_time = start .time_till (stop )
155
+ print ("Time = {:.6f} ms." .format (elapsed_time ))
156
+ if repeat > 0 :
157
+ t_sum += elapsed_time
158
+ t2_sum += elapsed_time * elapsed_time
159
+ t_ave = t_sum / NUM_REPEATS
160
+ t_err = math .sqrt (t2_sum / NUM_REPEATS - t_ave * t_ave )
161
+ print ("Time = {:.6f} +- {:.6f} ms." .format (t_ave , t_err ))
162
+ print ("sum = " , v_sum )
163
+
164
+
165
+ print ("\n using syncwarp:" )
166
+ timing (0 )
167
+ print ("\n using shfl:" )
168
+ timing (1 )
169
+ print ("\n using cooperative group:" )
170
+ timing (2 )
0 commit comments