173
|
1 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
|
|
2
|
|
3 // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
|
|
4 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
|
|
5
|
|
// Elementwise add on vector<4x2xf32>. The expected output splits both addf
// operands into two 2x2 slices, adds them slice-by-slice, and reassembles the
// result with vector.insert_slices.
// CHECK-LABEL: func @add4x2
// CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.tuple %[[A1]], %[[A2]] : vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.insert_slices %[[R1]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>> into vector<4x2xf32>
// The return must *use* the insert_slices result captured as R2 above; writing
// a fresh pattern definition here would accept any returned value.
// CHECK-NEXT: return %[[R2]] : vector<4x2xf32>

func @add4x2(%0: vector<4x2xf32>) -> vector<4x2xf32> {
  %1 = addf %0, %0: vector<4x2xf32>
  return %1: vector<4x2xf32>
}
|
|
23
|
|
24 // CHECK-LABEL: func @add4x4
|
|
25 // CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
26 // CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
27
|
|
28 // CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
29 // CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
30 // CHECK-NEXT: %[[A1:.*]] = addf %[[TG1]], %[[TG2]] : vector<2x2xf32>
|
|
31
|
|
32 // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
33 // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
34 // CHECK-NEXT: %[[A2:.*]] = addf %[[TG3]], %[[TG4]] : vector<2x2xf32>
|
|
35
|
|
36 // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
37 // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
38 // CHECK-NEXT: %[[A3:.*]] = addf %[[TG5]], %[[TG6]] : vector<2x2xf32>
|
|
39
|
|
40 // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
41 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
42 // CHECK-NEXT: %[[A4:.*]] = addf %[[TG7]], %[[TG8]] : vector<2x2xf32>
|
|
43
|
|
44 // CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
45
|
|
46 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
47 // CHECK-NEXT: %[[A5:.*]] = addf %[[TG9]], %[[A1]] : vector<2x2xf32>
|
|
48
|
|
49 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
50 // CHECK-NEXT: %[[A6:.*]] = addf %[[TG11]], %[[A2]] : vector<2x2xf32>
|
|
51
|
|
52 // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
53 // CHECK-NEXT: %[[A7:.*]] = addf %[[TG13]], %[[A3]] : vector<2x2xf32>
|
|
54
|
|
55 // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
56 // CHECK-NEXT: %[[A8:.*]] = addf %[[TG15]], %[[A4]] : vector<2x2xf32>
|
|
57
|
|
58 // CHECK-NEXT: %[[R3:.*]] = vector.tuple %[[A5]], %[[A6]], %[[A7]], %[[A8]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
|
59 // CHECK-NEXT: %[[R4:.*]] = vector.insert_slices %[[R3]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
|
|
60 // CHECK-NEXT: return %[[R4]] : vector<4x4xf32>
|
|
61
|
|
// Two chained elementwise adds on 4x4 vectors: the second add consumes the
// second argument together with the result of the first add.
func @add4x4(%0 : vector<4x4xf32>, %1 : vector<4x4xf32>)
    -> vector<4x4xf32> {
  %sum = addf %0, %1 : vector<4x4xf32>
  %total = addf %1, %sum : vector<4x4xf32>
  return %total : vector<4x4xf32>
}
|
|
67
|
|
// Indexing maps over the iteration space (i, j, k), listed in operand order:
// the first map selects (i, k), the second (k, j), and the third (i, j).
#contraction_accesses0 = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]

// Contraction trait for the (i, j, k) ordering: i and j are parallel and the
// trailing iterator k is the reduction dimension.
#contraction_trait0 = {
  iterator_types = ["parallel", "parallel", "reduction"],
  indexing_maps = #contraction_accesses0
}
|
|
77
|
|
78 // CHECK-LABEL: func @contraction4x4_ijk
|
|
79
|
|
80 // CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 6] : vector<4x6xi1>
|
|
81 // CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [6, 4] : vector<6x4xi1>
|
|
82
|
|
83 // Reducing output vector [0, 0]
|
|
84
|
|
85 // CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x6xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
86 // CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<6x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
87 // CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
88 // CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %[[LMASK]], [2, 2], [1, 1] : vector<4x6xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
89 // CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %[[RMASK]], [2, 2], [1, 1] : vector<6x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
90
|
|
91 // CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
92 // CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
93 // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
94 // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
95 // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
96 // CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
97
|
|
98 // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
99 // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
100 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
101 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
102 // CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
103
|
|
104 // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
105 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
106 // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
107 // CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
108 // CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
109
|
|
110 // Reducing output vector [0, 2]
|
|
111
|
|
112 // CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
113 // CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
114 // CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
115 // CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
116
|
|
117 // CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
118 // CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
119 // CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
120
|
|
121 // CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
122 // CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
123 // CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
124
|
|
125 // Reducing output vector [2, 0]
|
|
126
|
|
127 // CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
128 // CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
129 // CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
130 // CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
131
|
|
132 // CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
133 // CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
134 // CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
135
|
|
136 // CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
137 // CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
|
|
138 // CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
139
|
|
140 // Reducing output vector [2, 2]
|
|
141
|
|
142 // CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
143 // CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
144 // CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
145 // CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
146
|
|
147 // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
|
148 // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
|
|
149 // CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
|
|
150
|
|
// Masked vector.contract of a 4x6 and a 6x4 operand into a 4x4 vector using
// #contraction_trait0. The constant mask sizes equal the full shapes of the
// two operands they guard. %arg3 is not used by the body.
func @contraction4x4_ijk(
    %arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
    %arg2 : vector<4x4xf32>, %arg3 : index) -> (vector<4x4xf32>) {
  %lhs_mask = vector.constant_mask [4, 6] : vector<4x6xi1>
  %rhs_mask = vector.constant_mask [6, 4] : vector<6x4xi1>
  %result = vector.contract #contraction_trait0
      %arg0, %arg1, %arg2, %lhs_mask, %rhs_mask
      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
  return %result : vector<4x4xf32>
}
|
|
161
|
|
// Indexing maps over the iteration space (i, k, j), listed in operand order:
// the first map selects (i, k), the second (k, j), and the third (i, j).
#contraction_accesses1 = [
  affine_map<(i, k, j) -> (i, k)>,
  affine_map<(i, k, j) -> (k, j)>,
  affine_map<(i, k, j) -> (i, j)>
]

// Contraction trait for the (i, k, j) ordering: the middle iterator k is the
// reduction dimension, while i and j are parallel.
#contraction_trait1 = {
  iterator_types = ["parallel", "reduction", "parallel"],
  indexing_maps = #contraction_accesses1
}
|
|
171
|
|
172 // CHECK-LABEL: func @contraction4x4_ikj
|
|
173
|
|
174
|
|
175 // CHECK: %[[LMASK:.*]] = vector.constant_mask [4, 2] : vector<4x2xi1>
|
|
176 // CHECK-NEXT: %[[RMASK:.*]] = vector.constant_mask [2, 4] : vector<2x4xi1>
|
|
177
|
|
178 // Reducing output vector [0, 0]
|
|
179
|
|
180 // CHECK-NEXT: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
181 // CHECK-NEXT: %[[ES2:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
182 // CHECK-NEXT: %[[ES3:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
183 // CHECK-NEXT: %[[ES4:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
184 // CHECK-NEXT: %[[ES5:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<2x4xi1> into tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
185
|
|
186 // CHECK-NEXT: %[[TG1:.*]] = vector.tuple_get %[[ES1]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
187 // CHECK-NEXT: %[[TG2:.*]] = vector.tuple_get %[[ES2]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
188 // CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
189 // CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
190 // CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
191 // CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
192
|
|
193 // Reducing output vector [0, 2]
|
|
194
|
|
195 // CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
196 // CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
197 // CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
198 // CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
199
|
|
200 // Reducing output vector [2, 0]
|
|
201
|
|
202 // CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
|
|
203 // CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
204 // CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
|
|
205 // CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
206
|
|
207 // Reducing output vector [2, 2]
|
|
208
|
|
209 // CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
|
|
210 // CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
211
|
|
212 // CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
|
|
213 // CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
|
|
214 // CHECK-NEXT: return %[[RES1]] : vector<4x4xf32>
|
|
215
|
|
// Masked vector.contract of a 4x2 and a 2x4 operand into a 4x4 vector using
// #contraction_trait1. The constant mask sizes equal the full shapes of the
// two operands they guard. %arg3 is not used by the body.
func @contraction4x4_ikj(
    %arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
    %arg2 : vector<4x4xf32>, %arg3 : index) -> (vector<4x4xf32>) {
  %lhs_mask = vector.constant_mask [4, 2] : vector<4x2xi1>
  %rhs_mask = vector.constant_mask [2, 4] : vector<2x4xi1>
  %result = vector.contract #contraction_trait1
      %arg0, %arg1, %arg2, %lhs_mask, %rhs_mask
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
  return %result : vector<4x4xf32>
}
|
|
226
|
|
227 // CHECK-LABEL: func @contraction4x4_ikj_xfer_read
|
|
228
|
|
229 // CHECK: %[[C0:.*]] = constant 0 : index
|
|
230 // CHECK: %[[C2:.*]] = constant 2 : index
|
|
231
|
|
232 // Check LHS vector.transfer_read is split for each user.
|
|
233
|
|
234 // CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
|
|
235 // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x2xf32>, vector<2x2xf32>
|
|
236
|
|
237 // CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
|
|
238 // CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<2x4xf32>, vector<2x2xf32>
|
|
239
|
|
240 // CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
|
241 // CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
|
242 // CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
|
243 // CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
|
|
244
|
|
245 // CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
246 // CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
247 // CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
248 // CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
|
249
|
|
250 // CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
|
|
251 // CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
|
|
252 // CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32>
|
|
253 // CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32>
|
|
254 // CHECK-NEXT: return
|
|
255
|
|
// Reads all three contraction operands from memrefs with identity-permutation
// transfer_reads starting at [0, 0], contracts them with #contraction_trait1,
// and writes the 4x4 result back into %arg2.
func @contraction4x4_ikj_xfer_read(
    %arg0 : memref<4x2xf32>,
    %arg1 : memref<2x4xf32>,
    %arg2 : memref<4x4xf32>) {
  %zero = constant 0 : index
  // Padding value supplied to every transfer_read below.
  %pad = constant 0.0 : f32

  %lhs = vector.transfer_read %arg0[%zero, %zero], %pad
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<4x2xf32>, vector<4x2xf32>

  %rhs = vector.transfer_read %arg1[%zero, %zero], %pad
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<2x4xf32>, vector<2x4xf32>

  %acc = vector.transfer_read %arg2[%zero, %zero], %pad
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : memref<4x4xf32>, vector<4x4xf32>

  %product = vector.contract #contraction_trait1 %lhs, %rhs, %acc
      : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

  vector.transfer_write %product, %arg2[%zero, %zero]
      {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
      : vector<4x4xf32>, memref<4x4xf32>
  return
}
|
|
282
|
|
283 // TODO(andydavis) Update test with VTR split transform.
|
|
284 // CHECK-LABEL: func @vector_transfers
|
|
285 // CHECK-COUNT-8: vector.transfer_read
|
|
286 // CHECK-COUNT-4: addf
|
|
287 // CHECK-COUNT-4: vector.transfer_write
|
|
288
|
|
// Allocates three dynamically sized 2-D buffers and, in a 4x4-stepped affine
// loop nest, reads a 4x4 tile from the first two, adds them, and writes the
// sum to the third.
func @vector_transfers(%arg0: index, %arg1: index) {
  // Padding value for the transfer_reads.
  %pad = constant 0.000000e+00 : f32
  %in_a = alloc(%arg0, %arg1) : memref<?x?xf32>
  %in_b = alloc(%arg0, %arg1) : memref<?x?xf32>
  %out = alloc(%arg0, %arg1) : memref<?x?xf32>
  // These two constants have no users in this function.
  %one = constant 1.000000e+00 : f32
  %two = constant 2.000000e+00 : f32
  affine.for %i = 0 to %arg0 step 4 {
    affine.for %j = 0 to %arg1 step 4 {
      %tile_a = vector.transfer_read %in_a[%i, %j], %pad
          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
          : memref<?x?xf32>, vector<4x4xf32>
      %tile_b = vector.transfer_read %in_b[%i, %j], %pad
          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
          : memref<?x?xf32>, vector<4x4xf32>
      %tile_sum = addf %tile_a, %tile_b : vector<4x4xf32>
      vector.transfer_write %tile_sum, %out[%i, %j]
          {permutation_map = affine_map<(d0, d1) -> (d0, d1)>}
          : vector<4x4xf32>, memref<?x?xf32>
    }
  }
  return
}
|
|
306
|
|
307 // CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
|
|
308 // CHECK: return %arg1
|
|
309
|
|
// Packs both arguments into a tuple and immediately extracts element 1, so
// the returned value is the second argument.
func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
    -> vector<8xf32> {
  %pair = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
  %second = vector.tuple_get %pair, 1 : tuple<vector<4xf32>, vector<8xf32>>
  return %second : vector<8xf32>
}
|
|
315
|
|
316 // CHECK-LABEL: func @tuple_get_producer_consumer
|
|
317 // CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
|
|
318 // CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
|
|
319 // CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
|
|
320 // CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
|
|
321 // CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
|
|
322 // CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
|
|
323 // CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
|
|
324 // CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
|
|
325 // CHECK: return %[[A7]] : vector<2x4xf32>
|
|
326
|
|
// Routes %arg7 through a tuple / insert_slices / extract_slices / shape_cast
// chain and extracts it again at the end. The comment after each op records
// where %arg7 lives inside that op's result.
func @tuple_get_producer_consumer(
    %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
    %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
    %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
    %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
  %packed = vector.tuple %arg0, %arg1, %arg2, %arg3,
                         %arg4, %arg5, %arg6, %arg7
      : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
        vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
  // %arg7 == %packed at tupleIndex = 7, offsets = [0, 0]

  %full = vector.insert_slices %packed, [2, 4], [1, 1]
      : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
        into vector<4x16xf32>
  // %arg7 == %full at tupleIndex = -1, offsets = [2, 12]

  %halves = vector.extract_slices %full, [4, 8], [1, 1]
      : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %halves at tupleIndex = 1, offsets = [2, 4]

  %halves_4d = vector.shape_cast %halves
      : tuple<vector<4x8xf32>, vector<4x8xf32>> to
        tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %halves_4d at tupleIndex = 1, offsets = [0, 0, 2, 4]

  %right_4d = vector.tuple_get %halves_4d, 1
      : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %right_4d at tupleIndex = -1, offsets = [0, 0, 2, 4]

  %right = vector.shape_cast %right_4d : vector<1x1x4x8xf32> to vector<4x8xf32>
  // %arg7 == %right at tupleIndex = -1, offsets = [2, 4]

  %quads = vector.extract_slices %right, [2, 4], [1, 1]
      : vector<4x8xf32> into
        tuple<vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %quads at tupleIndex = 3, offsets = [0, 0]

  %last = vector.tuple_get %quads, 3
      : tuple<vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %last
  return %last : vector<2x4xf32>
}
|
|
360
|
|
361 // CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
|
|
362 // CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
|
|
363 // CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
|
|
364 // CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
|
|
365 // CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
|
|
366 // CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
|
|
367 // CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
|
|
368 // CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
|
|
369 // CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
|
|
370 // CHECK: return %[[A7]] : vector<2x4xf32>
|
|
371
|
|
// Same producer/consumer chain as @tuple_get_producer_consumer, but the two
// 4-D tuple elements are extracted and re-packed in swapped order before the
// final slicing. The comment after each op records where %arg7 lives inside
// that op's result.
func @tuple_get_producer_consumer_swizzle(
    %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
    %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
    %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
    %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
  %packed = vector.tuple %arg0, %arg1, %arg2, %arg3,
                         %arg4, %arg5, %arg6, %arg7
      : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
        vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
  // %arg7 == %packed at tupleIndex = 7, offsets = [0, 0]

  %full = vector.insert_slices %packed, [2, 4], [1, 1]
      : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
        into vector<4x16xf32>
  // %arg7 == %full at tupleIndex = -1, offsets = [2, 12]

  %halves = vector.extract_slices %full, [4, 8], [1, 1]
      : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %halves at tupleIndex = 1, offsets = [2, 4]

  %halves_4d = vector.shape_cast %halves
      : tuple<vector<4x8xf32>, vector<4x8xf32>> to
        tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %halves_4d at tupleIndex = 1, offsets = [0, 0, 2, 4]

  // Pull both tuple elements out individually.
  %left_4d = vector.tuple_get %halves_4d, 0
      : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  %right_4d = vector.tuple_get %halves_4d, 1
      : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
  // %arg7 == %right_4d at tupleIndex = -1, offsets = [0, 0, 2, 4]

  // Re-pack them in swapped order, so element 1 becomes element 0.
  %swapped_4d = vector.tuple %right_4d, %left_4d
      : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
  // %arg7 == %swapped_4d at tupleIndex = 0, offsets = [0, 0, 2, 4]

  %swapped = vector.shape_cast %swapped_4d
      : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
        tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %swapped at tupleIndex = 0, offsets = [2, 4]

  %first = vector.tuple_get %swapped, 0
      : tuple<vector<4x8xf32>, vector<4x8xf32>>
  // %arg7 == %first at tupleIndex = -1, offsets = [2, 4]

  %quads = vector.extract_slices %first, [2, 4], [1, 1]
      : vector<4x8xf32> into
        tuple<vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %quads at tupleIndex = 3, offsets = [0, 0]

  %last = vector.tuple_get %quads, 3
      : tuple<vector<2x4xf32>, vector<2x4xf32>,
              vector<2x4xf32>, vector<2x4xf32>>
  // %arg7 == %last
  return %last : vector<2x4xf32>
}
|
|
415
|
|
416 // CHECK-LABEL: func @cancelling_shape_cast_ops
|
|
417 // CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
|
|
418 // CHECK: return %[[A0]] : vector<2x4xf32>
|
|
// Casts a 2x4 vector to a flat 8-element vector and straight back to 2x4; the
// two shape_casts are inverses of each other.
func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>)
    -> vector<2x4xf32> {
  %flat = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
  %restored = vector.shape_cast %flat : vector<8xf32> to vector<2x4xf32>
  return %restored : vector<2x4xf32>
}
|
|
424
|
|
425 // CHECK-LABEL: func @vector_transfers_vector_element_type
|
|
426 // CHECK: %[[C0:.*]] = constant 0 : index
|
|
427 // CHECK: %[[C1:.*]] = constant 1 : index
|
|
428 // CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
|
429 // CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
|
|
430 // CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
|
431 // CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
|
|
432
|
|
func @vector_transfers_vector_element_type() {
  // Index and value operands shared by the transfers below.
  %idx0 = constant 0 : index
  %zero_f32 = constant 0.000000e+00 : f32
  %zero_vec = splat %zero_f32 : vector<2x4xf32>

  // Memref whose element type is itself a vector<2x4xf32>.
  %buf = alloc() : memref<6x2x1xvector<2x4xf32>>

  // Read a 2x1 block of vector elements starting at [0, 0, 0].
  %loaded = vector.transfer_read %buf[%idx0, %idx0, %idx0], %zero_vec
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>

  // Split the loaded value into two 1x1x2x4 slices and reassemble them
  // unchanged (same slice sizes, same order).
  %slices = vector.extract_slices %loaded, [1, 1, 2, 4], [1, 1, 1, 1]
      : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %slice0 = vector.tuple_get %slices, 0
      : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %slice1 = vector.tuple_get %slices, 1
      : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
  %repacked = vector.tuple %slice0, %slice1
      : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
  %merged = vector.insert_slices %repacked, [1, 1, 2, 4], [1, 1, 1, 1]
      : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>

  // Store the reassembled value back to the location it was read from.
  vector.transfer_write %merged, %buf[%idx0, %idx0, %idx0]
      {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
      : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>

  return
}
|
|
458
|
|
459 // Test that ShapeCastOp on tuple of vectors, decomposes to multiple
|
|
460 // ShapeCastOps on vectors.
|
|
461 // CHECK-LABEL: func @shape_cast_decomposition
|
|
462 // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<5x4x2xf32> to vector<20x2xf32>
|
|
463 // CHECK-NEXT: %[[V1:.*]] = vector.shape_cast %{{.*}} : vector<3x4x2xf32> to vector<12x2xf32>
|
|
464 // CHECK-NEXT: return %[[V0]], %[[V1]] : vector<20x2xf32>, vector<12x2xf32>
|
|
465
|
|
func @shape_cast_decomposition(%arg0 : vector<5x4x2xf32>,
                               %arg1 : vector<3x4x2xf32>)
    -> (vector<20x2xf32>, vector<12x2xf32>) {
  // Pack both arguments into one tuple so a single shape_cast covers them.
  %packed = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>
  // Tuple-level cast: each element has its two leading dimensions collapsed
  // (5x4 -> 20 and 3x4 -> 12).
  %reshaped = vector.shape_cast %packed
      : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
        tuple<vector<20x2xf32>, vector<12x2xf32>>
  // Unpack the reshaped elements and return them individually.
  %first = vector.tuple_get %reshaped, 0
      : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %second = vector.tuple_get %reshaped, 1
      : tuple<vector<20x2xf32>, vector<12x2xf32>>
  return %first, %second : vector<20x2xf32>, vector<12x2xf32>
}
|
|
476
|
|
477 // Test that cancelling ShapeCastOps are canonicalized away.
|
|
478 // EX:
|
|
479 //
|
|
480 // The following MLIR with cancelling ShapeCastOps:
|
|
481 //
|
|
482 // %0 = source : vector<5x4x2xf32>
|
|
483 // %1 = shape_cast %0 : vector<5x4x2xf32> to vector<20x2xf32>
|
|
484 // %2 = shape_cast %1 : vector<20x2xf32> to vector<5x4x2xf32>
|
|
485 // %3 = user %2 : vector<5x4x2xf32>
|
|
486 //
|
|
487 // Should canonicalize to the following:
|
|
488 //
|
|
489 //
|
|
490 // %0 = source : vector<5x4x2xf32>
|
|
491 // %1 = user %0 : vector<5x4x2xf32>
|
|
492 //
|
|
493
|
|
494 // In this test the cancelling ShapeCastOps operate on tuples of vectors.
|
|
495 // CHECK-LABEL: func @shape_cast_fold
|
|
496 // CHECK: return %{{.*}}, %{{.*}} : vector<5x4x2xf32>, vector<3x4x2xf32>
|
|
497
|
|
func @shape_cast_fold(%arg0 : vector<5x4x2xf32>, %arg1 : vector<3x4x2xf32>)
    -> (vector<5x4x2xf32>, vector<3x4x2xf32>) {
  // Pack the arguments into a tuple.
  %packed = vector.tuple %arg0, %arg1 : vector<5x4x2xf32>, vector<3x4x2xf32>

  // Forward cast: collapse the two leading dimensions of each tuple element.
  %collapsed = vector.shape_cast %packed
      : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>> to
        tuple<vector<20x2xf32>, vector<12x2xf32>>

  // Unpack and immediately repack the collapsed elements in the same order.
  %c0 = vector.tuple_get %collapsed, 0
      : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %c1 = vector.tuple_get %collapsed, 1
      : tuple<vector<20x2xf32>, vector<12x2xf32>>
  %repacked = vector.tuple %c0, %c1 : vector<20x2xf32>, vector<12x2xf32>

  // Reverse cast: restore the original element shapes, undoing the forward
  // cast above.
  %expanded = vector.shape_cast %repacked
      : tuple<vector<20x2xf32>, vector<12x2xf32>> to
        tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  // Unpack the restored elements and return them.
  %r0 = vector.tuple_get %expanded, 0
      : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>
  %r1 = vector.tuple_get %expanded, 1
      : tuple<vector<5x4x2xf32>, vector<3x4x2xf32>>

  return %r0, %r1 : vector<5x4x2xf32>, vector<3x4x2xf32>
}
|