// Source path: mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
// Blob: 477755b66c020c72aad61254b6a3a5ed17b7847b
// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s

// Fully static, in-bounds read from a strided memref whose innermost dimension
// has size 1. The expected output reads through a rank-reduced memref.subview
// (trailing unit dim removed), produces a vector<1x8xf32>, and returns the
// result of a vector.shape_cast applied to that vector.
func.func @contiguous_inner_most_view(%in: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>) -> vector<1x8x1xf32>{
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = vector.transfer_read %in[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>, vector<1x8x1xf32>
  return %0 : vector<1x8x1xf32>
}
//      CHECK: func @contiguous_inner_most_view(%[[SRC:.+]]: memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME:    memref<1x1x8x1xf32, strided<[3072, 8, 1, 1], offset: ?>> to memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>
//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK-SAME:    memref<1x1x8xf32, strided<[3072, 8, 1], offset: ?>>, vector<1x8xf32>
//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
//      CHECK:   return %[[RESULT]]

// -----

// Same collapse as above, but the two outer memref dimensions are dynamic.
// The expected subview takes its two dynamic sizes from memref.dim ops on the
// source and drops the trailing unit dim; the read then yields vector<8xf32>,
// which is shape_cast before being returned.
func.func @contiguous_outer_dyn_inner_most_view(%a: index, %b: index, %memref: memref<?x?x8x1xf32>) -> vector<8x1xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %v = vector.transfer_read %memref[%a, %b, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x8x1xf32>, vector<8x1xf32>
  return %v : vector<8x1xf32>
}
// CHECK: func.func @contiguous_outer_dyn_inner_most_view(
// CHECK-SAME:   %[[IDX0:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[IDX1:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:    %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:    %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
// CHECK:        %[[D0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK:        %[[D1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK:        %[[VIEW:.+]] = memref.subview %[[SRC]][0, 0, 0, 0] [%[[D0]], %[[D1]], 8, 1] [1, 1, 1, 1]
// CHECK-SAME:     memref<?x?x8x1xf32> to memref<?x?x8xf32, strided<[?, 8, 1], offset: ?>>
// CHECK:        %[[VEC:.+]] = vector.transfer_read %[[VIEW]]
// CHECK-SAME:     memref<?x?x8xf32, strided<[?, 8, 1], offset: ?>>, vector<8xf32>
// CHECK:        %[[RESULT:.+]] = vector.shape_cast %[[VEC]]
// CHECK:        return %[[RESULT]]

// -----

// Read from memref<16x1xf32> with non-constant indices and no in_bounds
// attribute. The expected output still reads through a rank-reduced subview
// (memref<16xf32>) and shape_casts the vector<8xf32> result back to
// vector<8x1xf32> before returning it.
func.func @contiguous_inner_most_dim(%A: memref<16x1xf32>, %i:index, %j:index) -> (vector<8x1xf32>) {
  %c0 = arith.constant 0 : index
  %f0 = arith.constant 0.0 : f32
  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<16x1xf32>, vector<8x1xf32>
  return %1 : vector<8x1xf32>
}
//      CHECK: func @contiguous_inner_most_dim(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index, %[[J:.+]]: index) -> vector<8x1xf32>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME:     memref<16x1xf32> to memref<16xf32, strided<[1]>>
//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_0]]
// RESULT must be (re)captured here; a bare use would silently match whatever
// value an earlier test in this file bound to the same name.
//      CHECK:   %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
//      CHECK:   return %[[RESULT]]

// -----

// The transfer_read source is itself a memref.subview of the function
// argument. The expected output contains a second subview stacked on the
// first, and a transfer_read from it with a single in_bounds entry producing
// vector<4xf32>.
func.func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, strided<[1, 1], offset: ?>>
  %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, strided<[1, 1], offset: ?>>, vector<4x1xf32>
  return %1 : vector<4x1xf32>
}
// Capture names mirror the argument names (%i -> I, %ii -> II).
//      CHECK: func @contiguous_inner_most_dim_bounds(%[[SRC:.+]]: memref<1000x1xf32>, %[[I:.+]]: index, %[[II:.+]]: index) -> vector<4x1xf32>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
//      CHECK:   %[[SRC_1:.+]] = memref.subview %[[SRC_0]]
//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_1]]
// CHECK-SAME:       {in_bounds = [true]}
// CHECK-SAME:       vector<4xf32>

// -----

// Variant of the previous test with two trailing unit dimensions. The expected
// output reads vector<4xf32> with a single in_bounds entry from a subview
// stacked on the original subview.
func.func @contiguous_inner_most_dim_bounds_2d(%A: memref<1000x1x1xf32>, %i:index, %ii:index) -> (vector<4x1x1xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.0 : f32
  %0 = memref.subview %A[%i, 0, 0] [40, 1, 1] [1, 1, 1] : memref<1000x1x1xf32> to memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>
  %1 = vector.transfer_read %0[%ii, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<40x1x1xf32, strided<[1, 1, 1], offset: ?>>, vector<4x1x1xf32>
  return %1 : vector<4x1x1xf32>
}
// Capture names mirror the argument names (%i -> I, %ii -> II).
//      CHECK: func @contiguous_inner_most_dim_bounds_2d(%[[SRC:.+]]: memref<1000x1x1xf32>, %[[I:.+]]: index, %[[II:.+]]: index) -> vector<4x1x1xf32>
//      CHECK:   %[[SRC_0:.+]] = memref.subview %[[SRC]]
//      CHECK:   %[[SRC_1:.+]] = memref.subview %[[SRC_0]]
//      CHECK:   %[[V:.+]] = vector.transfer_read %[[SRC_1]]
// CHECK-SAME:       {in_bounds = [true]}
// CHECK-SAME:       vector<4xf32>

// -----

// Negative test: a vector<4x8xf32> is read from memref<1x1xf32> with no
// in_bounds attribute, so the vector's innermost dimension is not a unit dim.
func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) -> vector<4x8xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0], %cst : memref<1x1xf32>, vector<4x8xf32>
  return %0 : vector<4x8xf32>
}
// The innermost unit dim cannot be dropped. In this case, no rank-reduced
// memref.subview op is generated and the read stays on the original memref.
//      CHECK: func.func @contiguous_inner_most_dim_out_of_bounds_2d
// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
//  CHECK-NOT:   memref.subview
//      CHECK:   %[[READ:.+]] = vector.transfer_read %[[SRC]]
//      CHECK:   return %[[READ]] : vector<4x8xf32>

// -----

// Write-side test with two trailing unit dimensions. The expected output
// drops both from the destination (rank-reduced memref.subview) and from the
// vector (vector.shape_cast), then writes using only the three leading
// indices.
func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true, true]}
    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
  return
}
// CHECK:      func.func @drop_two_inner_most_dim_for_transfer_write
// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
// CHECK-SAME:     memref<1x512x16x1x1xf32> to memref<1x512x16xf32, strided<[8192, 16, 1]>>
// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]

// -----

// Write-side test with one trailing unit dimension on a strided destination
// with a dynamic offset. The expected subview keeps the leading strides and
// the dynamic offset while dropping the unit dim; the vector is shape_cast to
// the matching rank before the write.
func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
    {in_bounds = [true, true, true, true]}
    : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
  return
}
// CHECK:      func.func @drop_inner_most_dim_for_transfer_write
// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
// CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]

// -----

// Write-side test where the outermost destination dimension is dynamic. The
// expected subview takes its dynamic size from a memref.dim on the
// destination, drops the trailing unit dim, and the write uses the three
// leading indices.
func.func @outer_dyn_drop_inner_most_dim_for_transfer_write(%arg0: memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0, %c0]
    {in_bounds = [true, true, true, true]}
    : vector<1x16x16x1xf32>, memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
  return
}
// CHECK:      func.func @outer_dyn_drop_inner_most_dim_for_transfer_write
// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
// The dim is queried on DEST, the memref captured for this function; no other
// memref capture is defined in this test.
//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//  CHECK-DAG:   %[[D0:.+]] = memref.dim %[[DEST]], %[[C0]]
// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]][0, 0, 0, 0] [%[[D0]], 512, 16, 1]
// CHECK-SAME:     memref<?x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<?x512x16xf32, strided<[8192, 16, 1], offset: ?>>
// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME:     [%[[IDX]], %[[C0]], %[[C0]]]

// -----

// Negative test: the destination's innermost dimension has size 1 but a
// stride of 4.
func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
  return
}
// The innermost unit dims cannot be dropped if their strides are not one, so
// no memref.subview is expected in the output.
// CHECK:     func.func @non_unit_strides
// CHECK-NOT:   memref.subview