summaryrefslogtreecommitdiffstats
path: root/mlir/test/Dialect/Affine/loop-fusion-transformation.mlir
blob: c8e0918e6bfd6f7b3a6162caab428c88b6e692c8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// RUN: mlir-opt %s -allow-unregistered-dialect -test-loop-fusion -test-loop-fusion-transformation -split-input-file -canonicalize | FileCheck %s

// Producer/consumer pair with different trip counts: the first loop stores 16
// elements of %0, the second loads only the first 5. The directives at the end
// of the function expect a single loop over [0, 5) that holds the store
// followed by the load.
// CHECK-LABEL: func @slice_depth1_loop_nest() {
func.func @slice_depth1_loop_nest() {
  %0 = memref.alloc() : memref<100xf32>
  %cst = arith.constant 7.000000e+00 : f32
  // Producer loop: writes %cst to %0[%i0] for %i0 in [0, 16).
  affine.for %i0 = 0 to 16 {
    affine.store %cst, %0[%i0] : memref<100xf32>
  }
  // Consumer loop: reads %0[%i1] for %i1 in [0, 5). "prevent.dce" consumes the
  // loaded value (presumably so the load is not removed as dead code).
  affine.for %i1 = 0 to 5 {
    %1 = affine.load %0[%i1] : memref<100xf32>
    "prevent.dce"(%1) : (f32) -> ()
  }
  // Expected output: one loop using the consumer's bounds, with the store and
  // the load indexed by the same induction variable.
  // CHECK:      affine.for %[[IV0:.*]] = 0 to 5 {
  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%[[IV0]]] : memref<100xf32>
  // CHECK-NEXT:   affine.load %{{.*}}[%[[IV0]]] : memref<100xf32>
  // CHECK-NEXT:   "prevent.dce"(%{{.*}}) : (f32) -> ()
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}

// -----

// A two-deep nest accumulates a[%i0, %i1] into b[%i0] over %i1; a second loop
// then copies b into c element by element. The directives at the end of the
// function expect both to share one outer loop, with the inner accumulation
// loop kept whole.
// CHECK-LABEL: func @should_fuse_reduction_to_pointwise() {
func.func @should_fuse_reduction_to_pointwise() {
  %a = memref.alloc() : memref<10x10xf32>
  %b = memref.alloc() : memref<10xf32>
  %c = memref.alloc() : memref<10xf32>

  // NOTE(review): %cf7 is not referenced by any operation in this function.
  %cf7 = arith.constant 7.0 : f32

  // Source nest: b[%i0] is read, added to a[%i0, %i1], and written back on
  // every iteration of the inner loop.
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      %v0 = affine.load %b[%i0] : memref<10xf32>
      %v1 = affine.load %a[%i0, %i1] : memref<10x10xf32>
      %v3 = arith.addf %v0, %v1 : f32
      affine.store %v3, %b[%i0] : memref<10xf32>
    }
  }
  // Destination loop: pointwise copy of b[%i2] into c[%i2].
  affine.for %i2 = 0 to 10 {
    %v4 = affine.load %b[%i2] : memref<10xf32>
    affine.store %v4, %c[%i2] : memref<10xf32>
  }

  // Match on the fused loop nest.
  // Should fuse in entire inner loop on %i1 from source loop nest, as %i1
  // is not used in the access function of the store/load on %b.
  // CHECK:       affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:    affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:      affine.load %{{.*}}[%{{.*}}, %{{.*}}] : memref<10x10xf32>
  // CHECK-NEXT:      arith.addf %{{.*}}, %{{.*}} : f32
  // CHECK-NEXT:      affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    }
  // CHECK-NEXT:    affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:    affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:  }
  // CHECK-NEXT:  return
  return
}

// -----

// Three single-level loops linked by dependences on %a, %b and %c. The
// directives at the end of the function expect all three bodies in one loop,
// in the order loop0, loop1, loop2.
// CHECK-LABEL: func @should_fuse_avoiding_dependence_cycle() {
func.func @should_fuse_avoiding_dependence_cycle() {
  %a = memref.alloc() : memref<10xf32>
  %b = memref.alloc() : memref<10xf32>
  %c = memref.alloc() : memref<10xf32>

  %cf7 = arith.constant 7.0 : f32

  // Set up the following dependences:
  // 1) loop0 -> loop1 on memref '%a' (loop0 loads it, loop1 stores to it)
  // 2) loop0 -> loop2 on memref '%b' (loop0 stores to it, loop2 loads it)
  // 3) loop1 -> loop2 on memref '%c' (loop1 loads it, loop2 stores to it)
  // loop0: copies a[%i0] into b[%i0].
  affine.for %i0 = 0 to 10 {
    %v0 = affine.load %a[%i0] : memref<10xf32>
    affine.store %v0, %b[%i0] : memref<10xf32>
  }
  // loop1: overwrites a[%i1] with %cf7 and reads c[%i1].
  affine.for %i1 = 0 to 10 {
    affine.store %cf7, %a[%i1] : memref<10xf32>
    %v1 = affine.load %c[%i1] : memref<10xf32>
    "prevent.dce"(%v1) : (f32) -> ()
  }
  // loop2: copies b[%i2] into c[%i2].
  affine.for %i2 = 0 to 10 {
    %v2 = affine.load %b[%i2] : memref<10xf32>
    affine.store %v2, %c[%i2] : memref<10xf32>
  }
  // Fusing the first loop into the last would create a cycle:
  //   {1} <--> {0, 2}
  // However, we can avoid the dependence cycle if we first fuse loop0 into
  // loop1:
  //   {0, 1} --> {2}
  // Then fuse this loop nest with loop2:
  //   {0, 1, 2}
  //
  // CHECK:      affine.for %{{.*}} = 0 to 10 {
  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:   "prevent.dce"
  // CHECK-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT:   affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
  // CHECK-NEXT: }
  // CHECK-NEXT: return
  return
}