aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h6
-rw-r--r--mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h2
-rw-r--r--mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td27
-rw-r--r--mlir/include/mlir/Dialect/Affine/IR/AffineOps.td6
-rw-r--r--mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td115
-rw-r--r--mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td50
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/SCFOps.td9
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td6
-rw-r--r--mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td2
-rw-r--r--mlir/include/mlir/IR/Diagnostics.h2
-rw-r--r--mlir/include/mlir/IR/Operation.h1
-rw-r--r--mlir/include/mlir/IR/Region.h2
-rw-r--r--mlir/include/mlir/Interfaces/ControlFlowInterfaces.h104
-rw-r--r--mlir/include/mlir/Interfaces/ControlFlowInterfaces.td108
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp325
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp9
-rw-r--r--mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp4
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp6
-rw-r--r--mlir/lib/Analysis/SliceWalk.cpp2
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp175
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp54
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp10
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp50
-rw-r--r--mlir/lib/Dialect/Async/IR/Async.cpp11
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp11
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp8
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp3
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp23
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp52
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp1
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp1
-rw-r--r--mlir/lib/Dialect/Shape/IR/Shape.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp4
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp15
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h6
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp2
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp37
-rw-r--r--mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp5
-rw-r--r--mlir/lib/IR/Diagnostics.cpp4
-rw-r--r--mlir/lib/IR/Region.cpp15
-rw-r--r--mlir/lib/Interfaces/ControlFlowInterfaces.cpp305
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp25
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir89
-rw-r--r--mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir40
-rw-r--r--mlir/test/Dialect/AMDGPU/invalid.mlir60
-rw-r--r--mlir/test/Dialect/AMDGPU/ops.mlir35
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir21
-rw-r--r--mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir20
-rw-r--r--mlir/test/Dialect/LLVMIR/rocdl.mlir30
-rw-r--r--mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir18
-rw-r--r--mlir/test/Dialect/OpenACC/ops.mlir73
-rw-r--r--mlir/test/Dialect/SCF/invalid.mlir8
-rw-r--r--mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir32
-rw-r--r--mlir/test/Target/LLVMIR/ptr.mlir28
-rw-r--r--mlir/test/Target/LLVMIR/rocdl.mlir30
-rw-r--r--mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir42
-rw-r--r--mlir/test/Target/SPIRV/decorations.mlir77
-rw-r--r--mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp4
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td2
-rw-r--r--mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp38
63 files changed, 1718 insertions, 564 deletions
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 8bcfe51..3c87c45 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -397,7 +397,7 @@ protected:
/// itself.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
}
@@ -526,7 +526,7 @@ public:
/// and "to" regions.
virtual void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const LatticeT &after, LatticeT *before) {
+ RegionSuccessor regionTo, const LatticeT &after, LatticeT *before) {
AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, after, before);
}
@@ -571,7 +571,7 @@ protected:
}
void visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionForm,
- RegionBranchPoint regionTo, const AbstractDenseLattice &after,
+ RegionSuccessor regionTo, const AbstractDenseLattice &after,
AbstractDenseLattice *before) final {
visitRegionBranchControlFlowTransfer(branch, regionForm, regionTo,
static_cast<const LatticeT &>(after),
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1a33ecf..9855734 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -286,7 +286,7 @@ private:
/// and propagating therefrom.
virtual void
visitRegionSuccessors(ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor,
+ RegionSuccessor successor,
ArrayRef<AbstractSparseLattice *> lattices);
};
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 37db096..45cb67f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
- VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
- VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+ VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+ VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
Arguments<(ins
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
- ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
let description = [{
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
instructions in the AMDGPU architecture, which perform matrix multiplication.
- Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
- dimensions.
+
+ On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+ On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+ all element types, and K=32 for i4 sources.
+
+ On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
+ depending on the element types.
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :
Example:
```mlir
- %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+ %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<8xf32>, vector<8xf32>
+
+ %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+ %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
```
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 12a7935..409bd05 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -714,7 +714,7 @@ def AffineParallelOp : Affine_Op<"parallel",
operand_range getUpperBoundsOperands();
AffineValueMap getUpperBoundsValueMap();
- /// Sets elements fo the loop upper bound.
+ /// Sets elements of the loop upper bound.
void setUpperBounds(ValueRange operands, AffineMap map);
void setSteps(ArrayRef<int64_t> newSteps);
@@ -999,7 +999,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
elemental type, supplied as its second operand.
The index for each memref dimension is an affine expression of loop
induction variables and symbols. These indices determine the start position
- of the write within the memref. The shape of th input vector determines the
+ of the write within the memref. The shape of the input vector determines the
shape of the slice written to the memref. This slice is contiguous along the
respective dimensions of the shape. Strided vector stores will be supported
in the future.
@@ -1188,7 +1188,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
- As a convenience, and for symmetry with `getPaddedBasis()`, ifg the first
+ As a convenience, and for symmetry with `getPaddedBasis()`, if the first
element of a set of `OpFoldResult`s passed to the builders of this operation is
`nullptr`, that element is ignored.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d2df244..5241f9a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -147,6 +147,35 @@ class ROCDL_DimGetterFunctionOp<string mnemonic, string device_function,
}
//===----------------------------------------------------------------------===//
+// ROCDL vector types definitions
+//===----------------------------------------------------------------------===//
+
+class ROCDL_ConcreteVector<Type elem, int length> :
+ FixedVectorOfLengthAndType<[length], [elem]>,
+ BuildableType<
+ "::mlir::VectorType::get({" # length # "} ,"
+ # elem.builderCall # ")">;
+
+def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
+def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
+def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
+def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
+def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
+def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
+def ROCDL_V4I32Type : ROCDL_ConcreteVector<I32, 4>;
+def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
+def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
+def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
+def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
+def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
+def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
+def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
+def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
+def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
+def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
+def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
+
+//===----------------------------------------------------------------------===//
// Wave-level primitives
//===----------------------------------------------------------------------===//
@@ -664,6 +693,68 @@ def ROCDL_GlobalLoadLDSOp :
}
//===---------------------------------------------------------------------===//
+// Tensor load/store intrinsics (available in GFX1250)
+//===---------------------------------------------------------------------===//
+
+// Base class for tensor load/store operations with 4 descriptor groups.
+class ROCDL_TensorLDSIntrOp<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [4], ["cachePolicy"]> {
+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+ ROCDL_V4I32Type:$dgroup2, ROCDL_V4I32Type:$dgroup3,
+ I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let summary = "Base class for ROCDL tensor load/store to/from LDS.";
+ let description = [{
+ Moves tiles of tensor data between global memory and LDS. The tile is
+ described by the $dgroup descriptors. 4 $dgroup descriptors allows for
+ movement of up to 5D tensors. $cachePolicy describes the memory scope and an
+ indicator of expected data re-use.
+
+ This op is for gfx1250+ architectures.
+ }];
+ let assemblyFormat = [{
+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+ }];
+ let extraClassDefinition = [{
+ SmallVector<Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1(), getDgroup2(), getDgroup3()};
+ }
+ }];
+}
+
+// Base class for tensor load/store operations with 2 descriptor groups
+// (D2 variant).
+class ROCDL_TensorLDSIntrD2Op<string mnemonic> :
+ ROCDL_IntrOp<mnemonic, [], [], [], 0, 0, 1, 0, [2], ["cachePolicy"]> {
+ dag args = (ins ROCDL_V4I32Type:$dgroup0, ROCDL_V8I32Type:$dgroup1,
+ I32Attr:$cachePolicy);
+ let arguments = !con(args, baseArgs);
+ let summary = "Base class for ROCDL tensor load/store to/from LDS (D2 variant).";
+ let description = [{
+ Moves tiles of tensor data between global memory and LDS. The tile is
+ described by the $dgroup descriptors. 2 $dgroup descriptors allows for
+ movement of up to 2D tensors. $cachePolicy describes the memory scope and an
+ indicator of expected data re-use.
+
+ This op is for gfx1250+ architectures.
+ }];
+ let assemblyFormat = [{
+ attr-dict operands `cachepolicy` $cachePolicy `:` type($dgroup0) `,` type($dgroup1)
+ }];
+ let extraClassDefinition = [{
+ SmallVector<Value> $cppClass::getAccessedOperands() {
+ return {getDgroup0(), getDgroup1()};
+ }
+ }];
+}
+
+// Tensor load and store operations
+def ROCDL_TensorLoadToLDSOp : ROCDL_TensorLDSIntrOp<"tensor.load.to.lds">;
+def ROCDL_TensorStoreFromLDSOp : ROCDL_TensorLDSIntrOp<"tensor.store.from.lds">;
+def ROCDL_TensorLoadToLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.load.to.lds.d2">;
+def ROCDL_TensorStoreFromLDSD2Op : ROCDL_TensorLDSIntrD2Op<"tensor.store.from.lds.d2">;
+
+//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
// raw buffer mode).
//===---------------------------------------------------------------------===//
@@ -932,30 +1023,6 @@ def ROCDL_Permlane32SwapOp : ROCDL_IntrOp<"permlane32.swap", [], [],
}];
}
-class ROCDL_ConcreteVector<Type elem, int length> :
- FixedVectorOfLengthAndType<[length], [elem]>,
- BuildableType<
- "::mlir::VectorType::get({" # length # "} ,"
- # elem.builderCall # ")">;
-
-def ROCDL_V2I16Type : ROCDL_ConcreteVector<I16, 2>;
-def ROCDL_V2F16Type : ROCDL_ConcreteVector<F16, 2>;
-def ROCDL_V2I32Type : ROCDL_ConcreteVector<I32, 2>;
-def ROCDL_V2BF16Type : ROCDL_ConcreteVector<BF16, 2>;
-def ROCDL_V2F32Type : ROCDL_ConcreteVector<F32, 2>;
-def ROCDL_V3I32Type : ROCDL_ConcreteVector<I32, 3>;
-def ROCDL_V6I32Type : ROCDL_ConcreteVector<I32, 6>;
-def ROCDL_V8I32Type : ROCDL_ConcreteVector<I32, 8>;
-def ROCDL_V8BF16Type : ROCDL_ConcreteVector<BF16, 8>;
-def ROCDL_V8F16Type : ROCDL_ConcreteVector<F16, 8>;
-def ROCDL_V8F32Type : ROCDL_ConcreteVector<F32, 8>;
-def ROCDL_V16BF16Type : ROCDL_ConcreteVector<BF16, 16>;
-def ROCDL_V16F16Type : ROCDL_ConcreteVector<F16, 16>;
-def ROCDL_V16F32Type : ROCDL_ConcreteVector<F32, 16>;
-def ROCDL_V32F16Type : ROCDL_ConcreteVector<F16, 32>;
-def ROCDL_V32BF16Type : ROCDL_ConcreteVector<BF16, 32>;
-def ROCDL_V32F32Type : ROCDL_ConcreteVector<F32, 32>;
-
//===---------------------------------------------------------------------===//
// 16-bit float intrinsics
//===---------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 2f87975..a18c18a 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -2117,6 +2117,56 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
}
//===----------------------------------------------------------------------===//
+// acc.kernel_environment
+//===----------------------------------------------------------------------===//
+
+def OpenACC_KernelEnvironmentOp : OpenACC_Op<"kernel_environment",
+ [AttrSizedOperandSegments, RecursiveMemoryEffects, SingleBlock,
+ NoTerminator,
+ MemoryEffects<[MemWrite<OpenACC_ConstructResource>,
+ MemRead<OpenACC_CurrentDeviceIdResource>]>]> {
+ let summary = "Decomposition of compute constructs to capture data mapping "
+ "and asynchronous behavior information";
+ let description = [{
+ The `acc.kernel_environment` operation represents a decomposition of
+ any OpenACC compute construct (acc.kernels, acc.parallel, or
+ acc.serial) that captures data mapping and asynchronous behavior:
+ - data clause operands
+ - async clause operands
+ - wait clause operands
+
+ This allows kernel execution parallelism and privatization to be
+ handled separately, facilitating eventual lowering to GPU dialect where
+ kernel launching and compute offloading are handled separately.
+ }];
+
+ let arguments = (ins
+ Variadic<AnyType>:$dataClauseOperands,
+ Variadic<IntOrIndex>:$asyncOperands,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOperandsDeviceType,
+ OptionalAttr<DeviceTypeArrayAttr>:$asyncOnly,
+ Variadic<IntOrIndex>:$waitOperands,
+ OptionalAttr<DenseI32ArrayAttr>:$waitOperandsSegments,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOperandsDeviceType,
+ OptionalAttr<BoolArrayAttr>:$hasWaitDevnum,
+ OptionalAttr<DeviceTypeArrayAttr>:$waitOnly);
+
+ let regions = (region SizedRegion<1>:$region);
+
+ let assemblyFormat = [{
+ oilist(
+ `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ | `async` `` custom<DeviceTypeOperandsWithKeywordOnly>($asyncOperands,
+ type($asyncOperands), $asyncOperandsDeviceType, $asyncOnly)
+ | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
+ $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
+ $waitOnly)
+ )
+ $region attr-dict
+ }];
+}
+
+//===----------------------------------------------------------------------===//
// 2.6.5 data Construct
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index fadd3fc..cd033c1 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -77,7 +77,7 @@ def ConditionOp : SCF_Op<"condition", [
//===----------------------------------------------------------------------===//
def ExecuteRegionOp : SCF_Op<"execute_region", [
- DeclareOpInterfaceMethods<RegionBranchOpInterface>]> {
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>, RecursiveMemoryEffects]> {
let summary = "operation that executes its region exactly once";
let description = [{
The `scf.execute_region` operation is used to allow multiple blocks within SCF
@@ -644,6 +644,13 @@ def ForallOp : SCF_Op<"forall", [
/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();
+
+ /// RegionBranchOpInterface
+
+ OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInits();
+ }
+
}];
}
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 62e66b3..ed69287 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -25,7 +25,7 @@ include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
def AlternativesOp : TransformDialectOp<"alternatives",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -624,7 +624,7 @@ def ForeachOp : TransformDialectOp<"foreach",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface, [
- "getSuccessorRegions", "getEntrySuccessorOperands"]>,
+ "getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">
]> {
let summary = "Executes the body for each element of the payload";
@@ -1237,7 +1237,7 @@ def SelectOp : TransformDialectOp<"select",
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
MatchOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
index d095659..4079848 100644
--- a/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td
@@ -63,7 +63,7 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorRegions",
+ ["getEntrySuccessorOperands",
"getRegionInvocationBounds"]>,
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 7ff718a..a0a99f4 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -29,6 +29,7 @@ class MLIRContext;
class Operation;
class OperationName;
class OpPrintingFlags;
+class OpWithFlags;
class Type;
class Value;
@@ -199,6 +200,7 @@ public:
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
+ Diagnostic &operator<<(OpWithFlags op);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 5569392c..b201957 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -1114,6 +1114,7 @@ public:
: op(op), theFlags(flags) {}
OpPrintingFlags &flags() { return theFlags; }
const OpPrintingFlags &flags() const { return theFlags; }
+ Operation *getOperation() const { return op; }
private:
Operation *op;
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 1fcb316..53d461d 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -379,6 +379,8 @@ private:
friend RangeBaseT;
};
+llvm::raw_ostream &operator<<(llvm::raw_ostream &os, Region &region);
+
} // namespace mlir
#endif // MLIR_IR_REGION_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index d63800c..47afd25 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -15,10 +15,16 @@
#define MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
class BranchOpInterface;
class RegionBranchOpInterface;
+class RegionBranchTerminatorOpInterface;
/// This class models how operands are forwarded to block arguments in control
/// flow. It consists of a number, denoting how many of the successors block
@@ -186,27 +192,40 @@ class RegionSuccessor {
public:
/// Initialize a successor that branches to another region of the parent
/// operation.
+ /// TODO: the default value for the regionInputs is somehow broken.
+ /// A region successor should have its input correctly set.
RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {})
- : region(region), inputs(regionInputs) {}
+ : successor(region), inputs(regionInputs) {
+ assert(region && "Region must not be null");
+ }
/// Initialize a successor that branches back to/out of the parent operation.
- RegionSuccessor(Operation::result_range results)
- : inputs(ValueRange(results)) {}
- /// Constructor with no arguments.
- RegionSuccessor() : inputs(ValueRange()) {}
+ /// The target must be one of the recursive parent operations.
+ RegionSuccessor(Operation *successorOp, Operation::result_range results)
+ : successor(successorOp), inputs(ValueRange(results)) {
+ assert(successorOp && "Successor op must not be null");
+ }
/// Return the given region successor. Returns nullptr if the successor is the
/// parent operation.
- Region *getSuccessor() const { return region; }
+ Region *getSuccessor() const { return dyn_cast<Region *>(successor); }
/// Return true if the successor is the parent operation.
- bool isParent() const { return region == nullptr; }
+ bool isParent() const { return isa<Operation *>(successor); }
/// Return the inputs to the successor that are remapped by the exit values of
/// the current region.
ValueRange getSuccessorInputs() const { return inputs; }
+ bool operator==(RegionSuccessor rhs) const {
+ return successor == rhs.successor && inputs == rhs.inputs;
+ }
+
+ friend bool operator!=(RegionSuccessor lhs, RegionSuccessor rhs) {
+ return !(lhs == rhs);
+ }
+
private:
- Region *region{nullptr};
+ llvm::PointerUnion<Region *, Operation *> successor{nullptr};
ValueRange inputs;
};
@@ -214,64 +233,67 @@ private:
/// `RegionBranchOpInterface`.
/// One can branch from one of two kinds of places:
/// * The parent operation (aka the `RegionBranchOpInterface` implementation)
-/// * A region within the parent operation.
+/// * A RegionBranchTerminatorOpInterface inside a region within the parent
+// operation.
class RegionBranchPoint {
public:
/// Returns an instance of `RegionBranchPoint` representing the parent
/// operation.
static constexpr RegionBranchPoint parent() { return RegionBranchPoint(); }
- /// Creates a `RegionBranchPoint` that branches from the given region.
- /// The pointer must not be null.
- RegionBranchPoint(Region *region) : maybeRegion(region) {
- assert(region && "Region must not be null");
- }
-
- RegionBranchPoint(Region &region) : RegionBranchPoint(&region) {}
+ /// Creates a `RegionBranchPoint` that branches from the given terminator.
+ inline RegionBranchPoint(RegionBranchTerminatorOpInterface predecessor);
/// Explicitly stops users from constructing with `nullptr`.
RegionBranchPoint(std::nullptr_t) = delete;
- /// Constructs a `RegionBranchPoint` from the the target of a
- /// `RegionSuccessor` instance.
- RegionBranchPoint(RegionSuccessor successor) {
- if (successor.isParent())
- maybeRegion = nullptr;
- else
- maybeRegion = successor.getSuccessor();
- }
-
- /// Assigns a region being branched from.
- RegionBranchPoint &operator=(Region &region) {
- maybeRegion = &region;
- return *this;
- }
-
/// Returns true if branching from the parent op.
- bool isParent() const { return maybeRegion == nullptr; }
+ bool isParent() const { return predecessor == nullptr; }
- /// Returns the region if branching from a region.
+ /// Returns the terminator if branching from a region.
/// A null pointer otherwise.
- Region *getRegionOrNull() const { return maybeRegion; }
+ Operation *getTerminatorPredecessorOrNull() const { return predecessor; }
/// Returns true if the two branch points are equal.
friend bool operator==(RegionBranchPoint lhs, RegionBranchPoint rhs) {
- return lhs.maybeRegion == rhs.maybeRegion;
+ return lhs.predecessor == rhs.predecessor;
}
private:
// Private constructor to encourage the use of `RegionBranchPoint::parent`.
- constexpr RegionBranchPoint() : maybeRegion(nullptr) {}
+ constexpr RegionBranchPoint() = default;
/// Internal encoding. Uses nullptr for representing branching from the parent
- /// op and the region being branched from otherwise.
- Region *maybeRegion;
+ /// op and the region terminator being branched from otherwise.
+ Operation *predecessor = nullptr;
};
inline bool operator!=(RegionBranchPoint lhs, RegionBranchPoint rhs) {
return !(lhs == rhs);
}
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionBranchPoint point) {
+ if (point.isParent())
+ return os << "<from parent>";
+ return os << "<region #"
+ << point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber()
+ << ", terminator "
+ << OpWithFlags(point.getTerminatorPredecessorOrNull(),
+ OpPrintingFlags().skipRegions())
+ << ">";
+}
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+ RegionSuccessor successor) {
+ if (successor.isParent())
+ return os << "<to parent>";
+ return os << "<to region #" << successor.getSuccessor()->getRegionNumber()
+ << " with " << successor.getSuccessorInputs().size() << " inputs>";
+}
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
@@ -348,4 +370,10 @@ struct ReturnLike : public TraitBase<ConcreteType, ReturnLike> {
/// Include the generated interface declarations.
#include "mlir/Interfaces/ControlFlowInterfaces.h.inc"
+namespace mlir {
+inline RegionBranchPoint::RegionBranchPoint(
+ RegionBranchTerminatorOpInterface predecessor)
+ : predecessor(predecessor.getOperation()) {}
+} // namespace mlir
+
#endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index b8d08cc..94242e3 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -117,7 +117,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let description = [{
- This interface provides information for region operations that exhibit
+ This interface provides information for region-holding operations that exhibit
branching behavior between held regions. I.e., this interface allows for
expressing control flow information for region holding operations.
@@ -126,12 +126,12 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
be side-effect free.
A "region branch point" indicates a point from which a branch originates. It
- can indicate either a region of this op or `RegionBranchPoint::parent()`. In
- the latter case, the branch originates from outside of the op, i.e., when
- first executing this op.
+ can indicate either a terminator in any of the immediately nested region of
+ this op or `RegionBranchPoint::parent()`. In the latter case, the branch
+ originates from outside of the op, i.e., when first executing this op.
A "region successor" indicates the target of a branch. It can indicate
- either a region of this op or this op. In the former case, the region
+ either a region of this op or this op itself. In the former case, the region
successor is a region pointer and a range of block arguments to which the
"successor operands" are forwarded to. In the latter case, the control flow
leaves this op and the region successor is a range of results of this op to
@@ -151,10 +151,10 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}
```
- `scf.for` has one region. The region has two region successors: the region
- itself and the `scf.for` op. %b is an entry successor operand. %c is a
- successor operand. %a is a successor block argument. %r is a successor
- result.
+ `scf.for` has one region. The `scf.yield` has two region successors: the
+ region body itself and the `scf.for` op. `%b` is an entry successor
+ operand. `%c` is a successor operand. `%a` is a successor block argument.
+ `%r` is a successor result.
}];
let cppNamespace = "::mlir";
@@ -162,16 +162,16 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
InterfaceMethod<[{
Returns the operands of this operation that are forwarded to the region
successor's block arguments or this operation's results when branching
- to `point`. `point` is guaranteed to be among the successors that are
+ to `successor`. `successor` is guaranteed to be among the successors that are
returned by `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
- `scf.for` op, regardless of the value of `point`. I.e., this op always
+ `scf.for` op, regardless of the value of `successor`. I.e., this op always
forwards the same operands, regardless of whether the loop has 0 or more
iterations.
}],
"::mlir::OperandRange", "getEntrySuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point), [{}],
+ (ins "::mlir::RegionSuccessor":$successor), [{}],
/*defaultImplementation=*/[{
auto operandEnd = this->getOperation()->operand_end();
return ::mlir::OperandRange(operandEnd, operandEnd);
@@ -225,6 +225,80 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions)
>,
InterfaceMethod<[{
+ Returns the potential region successors when branching from any
+ terminator in `region`.
+ These are the regions that may be selected during the flow of control.
+ }],
+ "void", "getSuccessorRegions",
+ (ins "::mlir::Region&":$region,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
+ /*defaultImplementation=*/[{
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ regions);
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential branching point (predecessors) for a given successor.
+ }],
+ "void", "getPredecessors",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "::llvm::SmallVectorImpl<::mlir::RegionBranchPoint> &":$predecessors),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint::parent(),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(RegionBranchPoint::parent());
+ for (Region &region : $_op->getRegions()) {
+ for (::mlir::Block &block : region) {
+ if (block.empty())
+ continue;
+ if (auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back())) {
+ ::llvm::SmallVector<::mlir::RegionSuccessor> successors;
+ $_op.getSuccessorRegions(RegionBranchPoint(terminator),
+ successors);
+ if (llvm::any_of(successors, [&] (const RegionSuccessor & succ) {
+ return succ.getSuccessor() == successor.getSuccessor() ||
+ (succ.isParent() && successor.isParent());
+ }))
+ predecessors.push_back(terminator);
+ }
+ }
+ }
+ }]>,
+ InterfaceMethod<[{
+ Returns the potential values across all (predecessors) for a given successor
+ input, modeled by its index (its position in the list of values).
+ }],
+ "void", "getPredecessorValues",
+ (ins "::mlir::RegionSuccessor":$successor,
+ "int":$index,
+ "::llvm::SmallVectorImpl<::mlir::Value> &":$predecessorValues),
+ [{}],
+ /*defaultImplementation=*/[{
+ ::llvm::SmallVector<::mlir::RegionBranchPoint> predecessors;
+ $_op.getPredecessors(successor, predecessors);
+ for (auto predecessor : predecessors) {
+ if (predecessor.isParent()) {
+ predecessorValues.push_back($_op.getEntrySuccessorOperands(successor)[index]);
+ continue;
+ }
+ auto terminator = cast<RegionBranchTerminatorOpInterface>(predecessor.getTerminatorPredecessorOrNull());
+ predecessorValues.push_back(terminator.getSuccessorOperands(successor)[index]);
+ }
+ }]>,
+ InterfaceMethod<[{
Populates `invocationBounds` with the minimum and maximum number of
times this operation will invoke the attached regions (assuming the
regions yield normally, i.e. do not abort or invoke an infinite loop).
@@ -298,7 +372,7 @@ def RegionBranchTerminatorOpInterface :
passing them to the region successor indicated by `point`.
}],
"::mlir::MutableOperandRange", "getMutableSuccessorOperands",
- (ins "::mlir::RegionBranchPoint":$point)
+ (ins "::mlir::RegionSuccessor":$point)
>,
InterfaceMethod<[{
Returns the potential region successors that are branched to after this
@@ -317,7 +391,7 @@ def RegionBranchTerminatorOpInterface :
/*defaultImplementation=*/[{
::mlir::Operation *op = $_op;
::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(op->getParentRegion(), regions);
+ .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
}]
>,
];
@@ -337,8 +411,8 @@ def RegionBranchTerminatorOpInterface :
// them to the region successor given by `index`. If `index` is None, this
// function returns the operands that are passed as a result to the parent
// operation.
- ::mlir::OperandRange getSuccessorOperands(::mlir::RegionBranchPoint point) {
- return getMutableSuccessorOperands(point);
+ ::mlir::OperandRange getSuccessorOperands(::mlir::RegionSuccessor successor) {
+ return getMutableSuccessorOperands(successor);
}
}];
}
@@ -504,7 +578,7 @@ def ReturnLike : TraitList<[
/*extraOpDeclaration=*/"",
/*extraOpDefinition=*/[{
::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands(
- ::mlir::RegionBranchPoint point) {
+ ::mlir::RegionSuccessor successor) {
return ::mlir::MutableOperandRange(*this);
}
}]
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index a84d10d..24cb123 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -16,19 +16,21 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#include <utility>
using namespace mlir;
+#define DEBUG_TYPE "local-alias-analysis"
+
//===----------------------------------------------------------------------===//
// Underlying Address Computation
//===----------------------------------------------------------------------===//
@@ -42,81 +44,47 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output);
-/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of
-/// the underlying values being addressed by one of the successor inputs. If the
-/// provided `region` is null, as per `RegionBranchOpInterface` this represents
-/// the parent operation.
-static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
- Region *region, Value inputValue,
- unsigned inputIndex,
- unsigned maxDepth,
- DenseSet<Value> &visited,
- SmallVectorImpl<Value> &output) {
- // Given the index of a region of the branch (`predIndex`), or std::nullopt to
- // represent the parent operation, try to return the index into the outputs of
- // this region predecessor that correspond to the input values of `region`. If
- // an index could not be found, std::nullopt is returned instead.
- auto getOperandIndexIfPred =
- [&](RegionBranchPoint pred) -> std::optional<unsigned> {
- SmallVector<RegionSuccessor, 2> successors;
- branch.getSuccessorRegions(pred, successors);
- for (RegionSuccessor &successor : successors) {
- if (successor.getSuccessor() != region)
- continue;
- // Check that the successor inputs map to the given input value.
- ValueRange inputs = successor.getSuccessorInputs();
- if (inputs.empty()) {
- output.push_back(inputValue);
- break;
- }
- unsigned firstInputIndex, lastInputIndex;
- if (region) {
- firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
- lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
- } else {
- firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
- lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
- }
- if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
- output.push_back(inputValue);
- break;
- }
- return inputIndex - firstInputIndex;
- }
- return std::nullopt;
- };
-
- // Check branches from the parent operation.
- auto branchPoint = RegionBranchPoint::parent();
- if (region)
- branchPoint = region;
-
- if (std::optional<unsigned> operandIndex =
- getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
- collectUnderlyingAddressValues(
- branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
- visited, output);
+/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue`
+/// which is an input for the provided successor (`initialSuccessor`), try to
+/// find the possible sources for the value along the control flow edges.
+static void collectUnderlyingAddressValues2(
+ RegionBranchOpInterface branch, RegionSuccessor initialSuccessor,
+ Value inputValue, unsigned inputIndex, unsigned maxDepth,
+ DenseSet<Value> &visited, SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues2: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " with initialSuccessor " << initialSuccessor;
+ LDBG() << " inputValue: " << inputValue;
+ LDBG() << " inputIndex: " << inputIndex;
+ LDBG() << " maxDepth: " << maxDepth;
+ ValueRange inputs = initialSuccessor.getSuccessorInputs();
+ if (inputs.empty()) {
+ LDBG() << " input is empty, enqueue value";
+ output.push_back(inputValue);
+ return;
}
- // Check branches from each child region.
- Operation *op = branch.getOperation();
- for (Region &region : op->getRegions()) {
- if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
- for (Block &block : region) {
- // Try to determine possible region-branch successor operands for the
- // current region.
- if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
- block.getTerminator())) {
- collectUnderlyingAddressValues(
- term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
- visited, output);
- } else if (block.getNumSuccessors()) {
- // Otherwise, if this terminator may exit the region we can't make
- // any assumptions about which values get passed.
- output.push_back(inputValue);
- return;
- }
- }
- }
+ unsigned firstInputIndex, lastInputIndex;
+ if (isa<BlockArgument>(inputs[0])) {
+ firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
+ lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
+ } else {
+ firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
+ lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
+ }
+ if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
+ LDBG() << " !! Input index " << inputIndex << " out of range "
+ << firstInputIndex << " to " << lastInputIndex
+ << ", adding input value to output";
+ output.push_back(inputValue);
+ return;
+ }
+ SmallVector<Value> predecessorValues;
+ branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex,
+ predecessorValues);
+ LDBG() << " Found " << predecessorValues.size() << " predecessor values";
+ for (Value predecessorValue : predecessorValues) {
+ LDBG() << " Processing predecessor value: " << predecessorValue;
+ collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output);
}
}
@@ -124,22 +92,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues (OpResult): " << result;
+ LDBG() << " maxDepth: " << maxDepth;
+
Operation *op = result.getOwner();
// If this is a view, unwrap to the source.
if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
if (result == view.getViewDest()) {
+ LDBG() << " Unwrapping view to source: " << view.getViewSource();
return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
visited, output);
}
}
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
- result.getResultNumber(), maxDepth,
- visited, output);
+ LDBG() << " Processing region branch operation";
+ return collectUnderlyingAddressValues2(
+ branch, RegionSuccessor(op, op->getResults()), result,
+ result.getResultNumber(), maxDepth, visited, output);
}
+ LDBG() << " Adding result to output: " << result;
output.push_back(result);
}
@@ -148,14 +122,23 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg;
+ LDBG() << " maxDepth: " << maxDepth;
+ LDBG() << " argNumber: " << arg.getArgNumber();
+ LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock();
+
Block *block = arg.getOwner();
unsigned argNumber = arg.getArgNumber();
// Handle the case of a non-entry block.
if (!block->isEntryBlock()) {
+ LDBG() << " Processing non-entry block with "
+ << std::distance(block->pred_begin(), block->pred_end())
+ << " predecessors";
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
if (!branch) {
+ LDBG() << " Cannot analyze control flow, adding argument to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
@@ -165,10 +148,12 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
unsigned index = it.getSuccessorIndex();
Value operand = branch.getSuccessorOperands(index)[argNumber];
if (!operand) {
+ LDBG() << " No operand found for argument, adding to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
}
+ LDBG() << " Processing operand from predecessor: " << operand;
collectUnderlyingAddressValues(operand, maxDepth, visited, output);
}
return;
@@ -178,10 +163,35 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
Region *region = block->getParent();
Operation *op = region->getParentOp();
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- return collectUnderlyingAddressValues(branch, region, arg, argNumber,
- maxDepth, visited, output);
+ LDBG() << " Processing region branch operation for entry block";
+ // We have to find the successor matching the region, so that the input
+ // arguments are correctly set.
+ // TODO: this isn't comprehensive: the successor may not be reachable from
+ // the entry block.
+ SmallVector<RegionSuccessor> successors;
+ branch.getSuccessorRegions(RegionBranchPoint::parent(), successors);
+ RegionSuccessor regionSuccessor(region);
+ bool found = false;
+ for (RegionSuccessor &successor : successors) {
+ if (successor.getSuccessor() == region) {
+ LDBG() << " Found matching region successor: " << successor;
+ found = true;
+ regionSuccessor = successor;
+ break;
+ }
+ }
+ if (!found) {
+ LDBG()
+ << " No matching region successor found, adding argument to output";
+ output.push_back(arg);
+ return;
+ }
+ return collectUnderlyingAddressValues2(
+ branch, regionSuccessor, arg, argNumber, maxDepth, visited, output);
}
+ LDBG()
+ << " Cannot reason about underlying address, adding argument to output";
// We can't reason about the underlying address of this argument.
output.push_back(arg);
}
@@ -190,17 +200,26 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues: " << value;
+ LDBG() << " maxDepth: " << maxDepth;
+
// Check that we don't infinitely recurse.
- if (!visited.insert(value).second)
+ if (!visited.insert(value).second) {
+ LDBG() << " Value already visited, skipping";
return;
+ }
if (maxDepth == 0) {
+ LDBG() << " Max depth reached, adding value to output";
output.push_back(value);
return;
}
--maxDepth;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value))
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
+ LDBG() << " Processing as BlockArgument";
return collectUnderlyingAddressValues(arg, maxDepth, visited, output);
+ }
+ LDBG() << " Processing as OpResult";
collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited,
output);
}
@@ -208,9 +227,11 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
/// Given a value, collect all of the underlying values being addressed.
static void collectUnderlyingAddressValues(Value value,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues: " << value;
DenseSet<Value> visited;
collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited,
output);
+ LDBG() << " Collected " << output.size() << " underlying values";
}
//===----------------------------------------------------------------------===//
@@ -227,19 +248,33 @@ static LogicalResult
getAllocEffectFor(Value value,
std::optional<MemoryEffects::EffectInstance> &effect,
Operation *&allocScopeOp) {
+ LDBG() << "getAllocEffectFor: " << value;
+
// Try to get a memory effect interface for the parent operation.
Operation *op;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value))
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
op = arg.getOwner()->getParentOp();
- else
+ LDBG() << " BlockArgument, parent op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ } else {
op = cast<OpResult>(value).getOwner();
+ LDBG() << " OpResult, owner op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ }
+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface)
+ if (!interface) {
+ LDBG() << " No memory effect interface found";
return failure();
+ }
// Try to find an allocation effect on the resource.
- if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value)))
+ if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) {
+ LDBG() << " No allocation effect found on value";
return failure();
+ }
+
+ LDBG() << " Found allocation effect";
// If we found an allocation effect, try to find a scope for the allocation.
// If the resource of this allocation is automatically scoped, find the parent
@@ -247,6 +282,12 @@ getAllocEffectFor(Value value,
if (llvm::isa<SideEffects::AutomaticAllocationScopeResource>(
effect->getResource())) {
allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+ if (allocScopeOp) {
+ LDBG() << " Automatic allocation scope found: "
+ << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
+ } else {
+ LDBG() << " Automatic allocation scope found: null";
+ }
return success();
}
@@ -255,6 +296,12 @@ getAllocEffectFor(Value value,
// For now assume allocation scope to the function scope (we don't care if
// pointer escape outside function).
allocScopeOp = op->getParentOfType<FunctionOpInterface>();
+ if (allocScopeOp) {
+ LDBG() << " Function scope found: "
+ << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
+ } else {
+ LDBG() << " Function scope found: null";
+ }
return success();
}
@@ -293,33 +340,44 @@ static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
- if (lhs == rhs)
+ LDBG() << "aliasImpl: " << lhs << " vs " << rhs;
+
+ if (lhs == rhs) {
+ LDBG() << " Same value, must alias";
return AliasResult::MustAlias;
+ }
+
Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr;
std::optional<MemoryEffects::EffectInstance> lhsAlloc, rhsAlloc;
// Handle the case where lhs is a constant.
Attribute lhsAttr, rhsAttr;
if (matchPattern(lhs, m_Constant(&lhsAttr))) {
+ LDBG() << " lhs is constant";
// TODO: This is overly conservative. Two matching constants don't
// necessarily map to the same address. For example, if the two values
// correspond to different symbols that both represent a definition.
- if (matchPattern(rhs, m_Constant(&rhsAttr)))
+ if (matchPattern(rhs, m_Constant(&rhsAttr))) {
+ LDBG() << " rhs is also constant, may alias";
return AliasResult::MayAlias;
+ }
// Try to find an alloc effect on rhs. If an effect was found we can't
// alias, otherwise we might.
- return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope))
- ? AliasResult::NoAlias
- : AliasResult::MayAlias;
+ bool rhsHasAlloc =
+ succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
+ LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
+ return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
// Handle the case where rhs is a constant.
if (matchPattern(rhs, m_Constant(&rhsAttr))) {
+ LDBG() << " rhs is constant";
// Try to find an alloc effect on lhs. If an effect was found we can't
// alias, otherwise we might.
- return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope))
- ? AliasResult::NoAlias
- : AliasResult::MayAlias;
+ bool lhsHasAlloc =
+ succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
+ LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
+ return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
@@ -329,9 +387,14 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
+ LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
+ LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
+
if (lhsHasAlloc == rhsHasAlloc) {
// If both values have an allocation effect we know they don't alias, and if
// neither have an effect we can't make an assumptions.
+ LDBG() << " Both have same alloc status: "
+ << (lhsHasAlloc ? "NoAlias" : "MayAlias");
return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
@@ -339,6 +402,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// and one without. Move the one with the effect to the lhs to make the next
// checks simpler.
if (rhsHasAlloc) {
+ LDBG() << " Swapping lhs and rhs to put alloc effect on lhs";
std::swap(lhs, rhs);
lhsAlloc = rhsAlloc;
lhsAllocScope = rhsAllocScope;
@@ -347,49 +411,74 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// If the effect has a scoped allocation region, check to see if the
// non-effect value is defined above that scope.
if (lhsAllocScope) {
+ LDBG() << " Checking allocation scope: "
+ << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions());
// If the parent operation of rhs is an ancestor of the allocation scope, or
// if rhs is an entry block argument of the allocation scope we know the two
// values can't alias.
Operation *rhsParentOp = rhs.getParentRegion()->getParentOp();
- if (rhsParentOp->isProperAncestor(lhsAllocScope))
+ if (rhsParentOp->isProperAncestor(lhsAllocScope)) {
+ LDBG() << " rhs parent is ancestor of alloc scope, no alias";
return AliasResult::NoAlias;
+ }
if (rhsParentOp == lhsAllocScope) {
BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs);
- if (rhsArg && rhs.getParentBlock()->isEntryBlock())
+ if (rhsArg && rhs.getParentBlock()->isEntryBlock()) {
+ LDBG() << " rhs is entry block arg of alloc scope, no alias";
return AliasResult::NoAlias;
+ }
}
}
// If we couldn't reason about the relationship between the two values,
// conservatively assume they might alias.
+ LDBG() << " Cannot reason about relationship, may alias";
return AliasResult::MayAlias;
}
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
- if (lhs == rhs)
+ LDBG() << "alias: " << lhs << " vs " << rhs;
+
+ if (lhs == rhs) {
+ LDBG() << " Same value, must alias";
return AliasResult::MustAlias;
+ }
// Get the underlying values being addressed.
SmallVector<Value, 8> lhsValues, rhsValues;
collectUnderlyingAddressValues(lhs, lhsValues);
collectUnderlyingAddressValues(rhs, rhsValues);
+ LDBG() << " lhs underlying values: " << lhsValues.size();
+ LDBG() << " rhs underlying values: " << rhsValues.size();
+
// If we failed to collect for either of the values somehow, conservatively
// assume they may alias.
- if (lhsValues.empty() || rhsValues.empty())
+ if (lhsValues.empty() || rhsValues.empty()) {
+ LDBG() << " Failed to collect underlying values, may alias";
return AliasResult::MayAlias;
+ }
// Check the alias results against each of the underlying values.
std::optional<AliasResult> result;
for (Value lhsVal : lhsValues) {
for (Value rhsVal : rhsValues) {
+ LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal;
AliasResult nextResult = aliasImpl(lhsVal, rhsVal);
+ LDBG() << " Result: "
+ << (nextResult == AliasResult::MustAlias ? "MustAlias"
+ : nextResult == AliasResult::NoAlias ? "NoAlias"
+ : "MayAlias");
result = result ? result->merge(nextResult) : nextResult;
}
}
// We should always have a valid result here.
+ LDBG() << " Final result: "
+ << (result->isMust() ? "MustAlias"
+ : result->isNo() ? "NoAlias"
+ : "MayAlias");
return *result;
}
@@ -398,8 +487,12 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
//===----------------------------------------------------------------------===//
ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
+ LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " on location " << location;
+
// Check to see if this operation relies on nested side effects.
if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+ LDBG() << " Operation has recursive memory effects, returning ModAndRef";
// TODO: To check recursive operations we need to check all of the nested
// operations, which can result in a quadratic number of queries. We should
// introduce some caching of some kind to help alleviate this, especially as
@@ -410,38 +503,64 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
// Otherwise, check to see if this operation has a memory effect interface.
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface)
+ if (!interface) {
+ LDBG() << " No memory effect interface, returning ModAndRef";
return ModRefResult::getModAndRef();
+ }
// Build a ModRefResult by merging the behavior of the effects of this
// operation.
SmallVector<MemoryEffects::EffectInstance> effects;
interface.getEffects(effects);
+ LDBG() << " Found " << effects.size() << " memory effects";
ModRefResult result = ModRefResult::getNoModRef();
for (const MemoryEffects::EffectInstance &effect : effects) {
- if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
+ if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) {
+ LDBG() << " Skipping alloc/free effect";
continue;
+ }
// Check for an alias between the effect and our memory location.
// TODO: Add support for checking an alias with a symbol reference.
AliasResult aliasResult = AliasResult::MayAlias;
- if (Value effectValue = effect.getValue())
+ if (Value effectValue = effect.getValue()) {
+ LDBG() << " Checking alias between effect value " << effectValue
+ << " and location " << location;
aliasResult = alias(effectValue, location);
+ LDBG() << " Alias result: "
+ << (aliasResult.isMust() ? "MustAlias"
+ : aliasResult.isNo() ? "NoAlias"
+ : "MayAlias");
+ } else {
+ LDBG() << " No effect value, assuming MayAlias";
+ }
// If we don't alias, ignore this effect.
- if (aliasResult.isNo())
+ if (aliasResult.isNo()) {
+ LDBG() << " No alias, ignoring effect";
continue;
+ }
// Merge in the corresponding mod or ref for this effect.
if (isa<MemoryEffects::Read>(effect.getEffect())) {
+ LDBG() << " Adding Ref to result";
result = result.merge(ModRefResult::getRef());
} else {
assert(isa<MemoryEffects::Write>(effect.getEffect()));
+ LDBG() << " Adding Mod to result";
result = result.merge(ModRefResult::getMod());
}
- if (result.isModAndRef())
+ if (result.isModAndRef()) {
+ LDBG() << " Result is now ModAndRef, breaking";
break;
+ }
}
+
+ LDBG() << " Final ModRef result: "
+ << (result.isModAndRef() ? "ModAndRef"
+ : result.isMod() ? "Mod"
+ : result.isRef() ? "Ref"
+ : "NoModRef");
return result;
}
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 377f7eb..0fc5b44 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -501,11 +501,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
return;
SmallVector<RegionSuccessor> successors;
- if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
- terminator.getSuccessorRegions(*operands, successors);
- else
- branch.getSuccessorRegions(op->getParentRegion(), successors);
-
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
+ if (!terminator)
+ return;
+ terminator.getSuccessorRegions(*operands, successors);
visitRegionBranchEdges(branch, op, successors);
}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index daa3db5..0682e5f 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -588,7 +588,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
LDBG() << " Exit block of region branch operation";
- visitRegionBranchOperation(point, branch, block->getParent(), before);
+ auto terminator =
+ cast<RegionBranchTerminatorOpInterface>(block->getTerminator());
+ visitRegionBranchOperation(point, branch, terminator, before);
return;
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0d2e2ed..8e63ae8 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionSuccessors(getProgramPointAfter(branch), branch,
- /*successor=*/RegionBranchPoint::parent(),
+ /*successor=*/{branch, branch->getResults()},
resultLattices);
return success();
}
@@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
+ RegionSuccessor successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(
- branch->getResults().slice(firstIndex, inputs.size())),
+ branch, branch->getResults().slice(firstIndex, inputs.size())),
lattices, firstIndex);
} else {
if (!inputs.empty())
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 817d71a..863f260 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) {
if (!regionOp)
return std::nullopt;
// Add the control flow predecessor operands to the work list.
- RegionSuccessor region(regionOp->getResults());
+ RegionSuccessor region(regionOp, regionOp->getResults());
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
regionOp, region, opResult.getResultNumber());
return predecessorOperands;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aa..1eca43d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
smfma.getN(), smfma.getK(), 1u, chipset);
}
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
- Chipset chipset) {
- auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
- Type elemSourceType = sourceVectorType.getElementType();
- Type elemBSourceType = sourceBVectorType.getElementType();
- Type elemDestType = destVectorType.getElementType();
-
- const uint32_t k = wmma.getK();
-
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for RDNA3/4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+ Type elemDestType, uint32_t k, bool isRDNA3) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ // Handle k == 16 for RDNA3/4.
if (k == 16) {
+ // Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
+
+ // RDNA3 specific patterns.
+ if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ return std::nullopt;
}
- }
- if (chipset.majorVersion < 12)
- return std::nullopt;
- // gfx12+
- if (k == 16) {
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ // RDNA4 specific patterns (fp8/bf8).
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
- if (k == 32) {
+
+ // Handle k == 32 for RDNA4.
+ if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ }
+
+ llvm_unreachable("Unsupported k value");
+}
+
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for the gfx1250 architecture.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+ Type elemBSourceType,
+ Type elemDestType,
+ uint32_t k) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ if (k == 4) {
+ if (elemSourceType.isF32() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
return std::nullopt;
}
+ if (k == 32) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 64) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+ }
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 128) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+ }
+
+ return std::nullopt;
+ }
+
+ llvm_unreachable("Unsupported k value");
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+ Chipset chipset) {
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+ const bool isRDNA3 = chipset.majorVersion == 11;
+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+ // Handle RDNA3 and RDNA4.
+ if (isRDNA3 || isRDNA4)
+ return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+ k, isRDNA3);
+
+ // Handle gfx1250.
+ if (chipset == Chipset{12, 5, 0})
+ return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+ elemDestType, k);
+
llvm_unreachable("unhandled WMMA case");
}
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0fe7239..9e46b7d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,25 +313,53 @@ private:
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
+ // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
+ // Handle special cases as StableHLO implementation does:
+ // 1. When b == 0, set imag(exp(z)) = 0
+ // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
LogicalResult
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
- Value real =
- complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value imag =
- complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
- Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
+ auto ET = cast<FloatType>(type.getElementType());
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+ const auto &floatSemantics = ET.getFloatSemantics();
+ ImplicitLocOpBuilder b(loc, rewriter);
+
+ Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
+ Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
+ Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
+ Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
+ Value inf = arith::ConstantOp::create(
+ b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
+
+ Value exp = math::ExpOp::create(b, x, fmf);
+ Value xHalf = arith::MulFOp::create(b, x, half, fmf);
+ Value expHalf = math::ExpOp::create(b, xHalf, fmf);
+ Value cos = math::CosOp::create(b, y, fmf);
+ Value sin = math::SinOp::create(b, y, fmf);
+
+ Value expIsInf =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
+ Value yIsZero =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
+
+ // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
+ Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
+ Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
+ Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
Value resultReal =
- arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
- Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
- Value resultImag =
- arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
+ arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
+
+ // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
+ // exp(x/2)*sin(y)*exp(x/2)
+ Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
+ Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
+ Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
+ Value imagNonZero =
+ arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
+ Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 585b6da..df955fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
- "source element types much match (except for fp8) but have ")
+ "source element types must match (except for fp8/bf8) but have ")
<< sourceAType << " and " << sourceBType;
}
- if (!sourceAElemType.isInteger(4) && getK() != 16) {
- return emitOpError("K dimension must be 16 for source element type ")
- << sourceAElemType;
+ if (isSrcFloat) {
+ if (getClamp())
+ return emitOpError("clamp flag is not supported for float types");
+ if (getUnsignedA() || getUnsignedB())
+ return emitOpError("unsigned flags are not supported for float types");
}
return success();
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e0a53cd..0c35921 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
return success(folded);
}
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getRegion()) && "invalid region point");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
+ "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((point.isParent() || point == getRegion()) && "expected loop region");
+ assert((point.isParent() ||
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getRegion()) &&
+ "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (point.isParent() && tripCount.has_value()) {
- if (tripCount.value() > 0) {
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- return;
- }
- if (tripCount.value() == 0) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
+ if (tripCount.has_value()) {
+ if (!point.isParent()) {
+ // From the loop body, if the trip count is one, we can only branch back
+ // to the parent.
+ if (tripCount == 1) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
+ if (tripCount == 0)
+ return;
+ } else {
+ if (tripCount.value() > 0) {
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ return;
+ }
+ if (tripCount.value() == 0) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
}
}
- // From the loop body, if the trip count is one, we can only branch back to
- // the parent.
- if (!point.isParent() && tripCount == 1) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
- }
-
// In all other cases, the loop may branch back to itself or the parent
// operation.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
AffineBound AffineForOp::getLowerBound() {
@@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
// If the "else" region is empty, branch bach into parent.
if (getElseRegion().empty()) {
- regions.push_back(getResults());
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
} else {
regions.push_back(
RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
@@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions(
// If the predecessor is the `else`/`then` region, then branching into parent
// op is valid.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
LogicalResult AffineIfOp::verify() {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index dc7b07d..8e4a49d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -36,8 +36,9 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBodyRegion() && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBodyRegion() &&
+ "invalid region index");
return getBodyOperands();
}
@@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) {
- regions.push_back(RegionSuccessor(getBodyResults()));
+ if (!point.isParent() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBodyRegion()) {
+ regions.push_back(RegionSuccessor(getOperation(), getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index b593cca..36a759c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -562,8 +562,11 @@ LogicalResult
BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) {
SmallVector<TypeRange> returnOperandTypes(llvm::map_range(
op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(),
- [](RegionBranchTerminatorOpInterface op) {
- return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes();
+ [&](RegionBranchTerminatorOpInterface branchOp) {
+ return branchOp
+ .getSuccessorOperands(RegionSuccessor(
+ op.getOperation(), op.getOperation()->getResults()))
+ .getTypes();
}));
if (!llvm::all_equal(returnOperandTypes))
return op->emitError(
@@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// about, but we would need to check how many successors there are and under
// which condition they are taken, etc.
- MutableOperandRange operands =
- op.getMutableSuccessorOperands(RegionBranchPoint::parent());
+ MutableOperandRange operands = op.getMutableSuccessorOperands(
+ RegionSuccessor(op.getOperation(), op.getOperation()->getResults()));
SmallVector<Value> updatedOwnerships;
auto result = deallocation_impl::insertDeallocOpForReturnLike(
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4754f0b..0992ce14 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
return;
}
@@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b5f8dda..6c6d8d2 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
void WarpExecuteOnLane0Op::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb2d825..bd25e94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -495,13 +495,14 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
if (failed(maybePackedDimForEachOperand))
return failure();
packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
- listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
LDBG() << "maps: " << llvm::interleaved(indexingMaps);
LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
LDBG() << "packedDimForEachOperand: "
<< llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
+
+ listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
}
// Step 2. Propagate packing to all LinalgOp operands.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c551fba..1c21a2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 6fa8ce4..69afbca 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -98,6 +98,27 @@ struct RankOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ memref::CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<memref::CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+ *ctx);
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1ab01d8..2946b53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(
}
// Otherwise, the region branches back to the parent operation.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
//===----------------------------------------------------------------------===//
@@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getParentOp().getAfter()) &&
- "condition op can only exit the loop or branch to the after"
- "region");
+ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ assert(
+ (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
+ "condition op can only exit the loop or branch to the after"
+ "region");
// Pass all operands except the condition to the successor region.
return getArgsMutable();
}
@@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(&whileOp.getAfter(),
whileOp.getAfter().getArguments());
if (!boolAttr || !boolAttr.getValue())
- regions.emplace_back(whileOp.getResults());
+ regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
}
//===----------------------------------------------------------------------===//
@@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
@@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
- regions.push_back(RegionSuccessor(&getRegion()));
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
else
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
}
//===----------------------------------------------------------------------===//
@@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) {
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // The `then` and the `else` region branch back to the parent operation.
+ // The `then` and the `else` region branch back to the parent operation or one
+ // of the recursive parent operations (early exit case).
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
}
@@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions(
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion()));
- regions.push_back(RegionSuccessor());
+ regions.push_back(RegionSuccessor(
+ getOperation(), ResultRange{getResults().end(), getResults().end()}));
}
//===----------------------------------------------------------------------===//
@@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() {
}
MutableOperandRange
-ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
// No operands are forwarded to the next iteration.
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
@@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
@@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
return;
}
- assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ assert(llvm::is_contained(
+ {&getAfter(), &getBefore()},
+ point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point == getAfter()) {
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
regions.emplace_back(&getAfter(), getAfter().getArguments());
}
@@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (!point.isParent()) {
- successors.emplace_back(getResults());
+ successors.emplace_back(getOperation(), getResults());
return;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ae52af5..ddcbda8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -23,7 +23,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1..00bef70 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -21,7 +21,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::LoopNest;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5ba8289..f0f22e5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions(
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1a9d9e1..3962e3e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
-OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
// or back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index f53d272..ffa8b40 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -152,19 +152,20 @@ IterationGraphSorter IterationGraphSorter::fromGenericOp(
}
IterationGraphSorter::IterationGraphSorter(
- SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
- AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+ SmallVector<Value> &&insArg, SmallVector<AffineMap> &&loop2InsLvlArg,
+ Value out, AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypesArg,
sparse_tensor::LoopOrderingStrategy strategy)
- : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
- loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+ : ins(std::move(insArg)), loop2InsLvl(std::move(loop2InsLvlArg)), out(out),
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypesArg)),
strategy(strategy) {
// One map per tensor.
- assert(this->loop2InsLvl.size() == this->ins.size());
+ assert(loop2InsLvl.size() == ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
auto [m, v] = mvPair;
// For ranked types the rank must match.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index b2a16e9..35e58ed 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -59,10 +59,10 @@ public:
private:
// Private constructor.
- IterationGraphSorter(SmallVector<Value> &&ins,
- SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ IterationGraphSorter(SmallVector<Value> &&insArg,
+ SmallVector<AffineMap> &&loop2InsLvlArg, Value out,
AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes,
+ SmallVector<utils::IteratorType> &&iterTypesArg,
sparse_tensor::LoopOrderingStrategy strategy =
sparse_tensor::LoopOrderingStrategy::kDefault);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 1e3b377..549ac7a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -77,7 +77,7 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
if (!consumerOp)
return failure();
- for (auto opOperand : consumerOperands.drop_front()) {
+ for (auto *opOperand : consumerOperands.drop_front()) {
if (opOperand->getOwner() != consumerOp) {
LLVM_DEBUG({
llvm::dbgs()
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 365afab..062606e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange
-transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- if (!point.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
+ RegionSuccessor successor) {
+ if (!successor.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void transform::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(),
- point.isParent() ? 0
- : point.getRegionOrNull()->getRegionNumber() + 1)) {
+ getAlternatives(), point.isParent()
+ ? 0
+ : point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber() +
+ 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (!point.isParent())
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::AlternativesOp::getRegionInvocationBounds(
@@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions(
}
// Branch back to the region or the parent.
- assert(point == getBody() && "unexpected region index");
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
// Each block argument handle is mapped to a subset (one op to be precise)
// of the payload of the corresponding `targets` operand of ForeachOp.
- assert(point == getBody() && "unexpected region index");
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects(
}
OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody() && "unexpected region index");
+transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(point == getBody() && "unexpected region index");
- regions.emplace_back(getOperation()->getResults());
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::SequenceOp::getRegionInvocationBounds(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index c627158..f727118 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
@@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
}
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
- RegionBranchPoint point) {
+ RegionSuccessor successor) {
// No operands will be forwarded to the region(s).
return getOperands().slice(0, 0);
}
@@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions(
for (Region &alternative : getAlternatives())
regions.emplace_back(&alternative, Block::BlockArgListType());
else
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::tune::AlternativesOp::getRegionInvocationBounds(
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..f4c9242 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -138,6 +138,10 @@ Diagnostic &Diagnostic::operator<<(Operation &op) {
return appendOp(op, OpPrintingFlags());
}
+Diagnostic &Diagnostic::operator<<(OpWithFlags op) {
+ return appendOp(*op.getOperation(), op.flags());
+}
+
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
std::string str;
llvm::raw_string_ostream os(str);
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 46b6298..15a941f 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -253,6 +253,21 @@ void Region::OpIterator::skipOverBlocksWithNoOps() {
operation = block->begin();
}
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region &region) {
+ if (!region.getParentOp()) {
+ os << "Region has no parent op";
+ } else {
+ os << "Region #" << region.getRegionNumber() << " in operation "
+ << region.getParentOp()->getName();
+ }
+ for (auto it : llvm::enumerate(region.getBlocks())) {
+ os << "\n Block #" << it.index() << ":";
+ for (Operation &op : it.value().getOperations())
+ os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ }
+ return os;
+}
+
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ca3f766..1e56810 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,7 +9,9 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
+ LDBG() << "Getting branch successor argument for operand index "
+ << operandIndex << " in successor block";
+
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
- if (forwardedOperands.empty())
+ if (forwardedOperands.empty()) {
+ LDBG() << "No forwarded operands, returning nullopt";
return std::nullopt;
+ }
// Check to ensure that this operand is within the range.
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
- operandIndex >= (operandsStart + forwardedOperands.size()))
+ operandIndex >= (operandsStart + forwardedOperands.size())) {
+ LDBG() << "Operand index " << operandIndex << " out of range ["
+ << operandsStart << ", "
+ << (operandsStart + forwardedOperands.size())
+ << "), returning nullopt";
return std::nullopt;
+ }
// Index the successor.
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
+ LDBG() << "Computed argument index " << argIndex << " for successor block";
return successor->getArgument(argIndex);
}
@@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
+ LDBG() << "Verifying branch successor operands for successor #" << succNo
+ << " in operation " << op->getName();
+
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
+ LDBG() << "Branch has " << operandCount << " operands, target block has "
+ << destBB->getNumArguments() << " arguments";
+
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
@@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
+ LDBG() << "Checking type compatibility for "
+ << (operandCount - operands.getProducedOperandCount())
+ << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
- if (!cast<BranchOpInterface>(op).areTypesCompatible(
- operands[i].getType(), destBB->getArgument(i).getType()))
+ Type operandType = operands[i].getType();
+ Type argType = destBB->getArgument(i).getType();
+ LDBG() << "Checking type compatibility: operand type " << operandType
+ << " vs argument type " << argType;
+
+ if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
+
+ LDBG() << "Branch successor operand verification successful";
return success();
}
@@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
- RegionBranchPoint succRegionNo) {
+ RegionSuccessor succRegionNo) {
diag << "from ";
- if (Region *region = sourceNo.getRegionOrNull())
- diag << "Region #" << region->getRegionNumber();
+ if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
+ diag << "Operation " << op->getName();
else
diag << "parent operands";
diag << " to ";
- if (Region *region = succRegionNo.getRegionOrNull())
+ if (Region *region = succRegionNo.getSuccessor())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
@@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
-verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
- function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
+ RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionSuccessor)>
getInputsTypesForRegion) {
- auto regionInterface = cast<RegionBranchOpInterface>(op);
-
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourcePoint, successors);
+ branchOp.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
@@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
- InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("region control flow edge ");
+ std::string succStr;
+ llvm::raw_string_ostream os(succStr);
+ os << succ;
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
- << " operands, but target successor needs "
+ << " operands, but target successor " << os.str() << " needs "
<< succInputsTypes.size();
}
@@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
- if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
+
+ if (!branchOp.areTypesCompatible(sourceType, inputType)) {
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("along control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
@@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
}
}
}
+
return success();
}
@@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
- return regionInterface.getEntrySuccessorOperands(point).getTypes();
+ auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
+ return regionInterface.getEntrySuccessorOperands(successor).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
- inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(
+ regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
return failure();
- auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
- if (lhs.size() != rhs.size())
- return false;
- for (auto types : llvm::zip(lhs, rhs)) {
- if (!regionInterface.areTypesCompatible(std::get<0>(types),
- std::get<1>(types))) {
- return false;
- }
- }
- return true;
- };
-
// Verify types along control flow edges originating from each region.
for (Region &region : op->getRegions()) {
-
- // Since there can be multiple terminators implementing the
- // `RegionBranchTerminatorOpInterface`, all should have the same operand
- // types when passing them to the same region.
-
+ // Collect all return-like terminators in the region.
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (!block.empty())
@@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (regionReturnOps.empty())
continue;
- auto inputTypesForRegion =
- [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
- std::optional<OperandRange> regionReturnOperands;
- for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
- auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
-
- if (!regionReturnOperands) {
- regionReturnOperands = terminatorOperands;
- continue;
- }
-
- // Found more than one ReturnLike terminator. Make sure the operand
- // types match with the first one.
- if (!areTypesCompatible(regionReturnOperands->getTypes(),
- terminatorOperands.getTypes())) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge");
- return printRegionEdgeName(diag, region, point)
- << " operands mismatch between return-like terminators";
- }
- }
-
- // All successors get the same set of operand types.
- return TypeRange(regionReturnOperands->getTypes());
- };
-
- if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
- return failure();
+ // Verify types along control flow edges originating from each return-like
+ // terminator.
+ for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+
+ auto inputTypesForRegion =
+ [&](RegionSuccessor successor) -> FailureOr<TypeRange> {
+ OperandRange terminatorOperands =
+ regionReturnOp.getSuccessorOperands(successor);
+ return TypeRange(terminatorOperands.getTypes());
+ };
+ if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
+ inputTypesForRegion)))
+ return failure();
+ }
}
return success();
@@ -272,31 +277,74 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
static bool traverseRegionGraph(Region *begin,
StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
+ LDBG() << "Starting region graph traversal from region #"
+ << begin->getRegionNumber() << " in operation " << op->getName();
+
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
+ LDBG() << "Initialized visited array with " << op->getNumRegions()
+ << " regions";
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
- SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(region, successors);
- for (RegionSuccessor successor : successors)
- if (!successor.isParent())
- worklist.push_back(successor.getSuccessor());
+ LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
+ SmallVector<Attribute> operandAttributes(op->getNumOperands());
+ for (Block &block : *region) {
+ if (block.empty())
+ continue;
+ auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
+ if (!terminator)
+ continue;
+ SmallVector<RegionSuccessor> successors;
+ operandAttributes.resize(terminator->getNumOperands());
+ terminator.getSuccessorRegions(operandAttributes, successors);
+ LDBG() << "Found " << successors.size()
+ << " successors from terminator in block";
+ for (RegionSuccessor successor : successors) {
+ if (!successor.isParent()) {
+ worklist.push_back(successor.getSuccessor());
+ LDBG() << "Added region #"
+ << successor.getSuccessor()->getRegionNumber()
+ << " to worklist";
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+ }
};
enqueueAllSuccessors(begin);
+ LDBG() << "Initial worklist size: " << worklist.size();
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (stopConditionFn(nextRegion, visited))
+ LDBG() << "Processing region #" << nextRegion->getRegionNumber()
+ << " from worklist (remaining: " << worklist.size() << ")";
+
+ if (stopConditionFn(nextRegion, visited)) {
+ LDBG() << "Stop condition met for region #"
+ << nextRegion->getRegionNumber() << ", returning true";
return true;
- if (visited[nextRegion->getRegionNumber()])
+ }
+ llvm::dbgs() << "Region: " << nextRegion << "\n";
+ if (!nextRegion->getParentOp()) {
+ llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
+ return false;
+ }
+ if (visited[nextRegion->getRegionNumber()]) {
+ LDBG() << "Region #" << nextRegion->getRegionNumber()
+ << " already visited, skipping";
continue;
+ }
visited[nextRegion->getRegionNumber()] = true;
+ LDBG() << "Marking region #" << nextRegion->getRegionNumber()
+ << " as visited";
enqueueAllSuccessors(nextRegion);
}
+ LDBG() << "Traversal completed, returning false";
return false;
}
@@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) {
/// mutually exclusive if they are not reachable from each other as per
/// RegionBranchOpInterface::getSuccessorRegions.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+ LDBG() << "Checking if operations are in mutually exclusive regions: "
+ << a->getName() << " and " << b->getName();
+
assert(a && "expected non-empty operation");
assert(b && "expected non-empty operation");
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
while (branchOp) {
+ LDBG() << "Checking branch operation " << branchOp->getName();
+
// Check if b is inside branchOp. (We already know that a is.)
if (!branchOp->isProperAncestor(b)) {
+ LDBG() << "Operation b is not inside branchOp, checking next ancestor";
// Check next enclosing RegionBranchOpInterface.
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
continue;
}
+ LDBG() << "Both operations are inside branchOp, finding their regions";
+
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
// are contained.
Region *regionA = nullptr, *regionB = nullptr;
@@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
if (r.findAncestorOpInRegion(*a)) {
assert(!regionA && "already found a region for a");
regionA = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
}
if (r.findAncestorOpInRegion(*b)) {
assert(!regionB && "already found a region for b");
regionB = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
}
}
assert(regionA && regionB && "could not find region of op");
+ LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
+ << regionB->getRegionNumber();
+
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
- return regionA != regionB && !isRegionReachable(regionA, regionB) &&
- !isRegionReachable(regionB, regionA);
+ bool regionsAreDistinct = (regionA != regionB);
+ bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
+ bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
+
+ LDBG() << "Regions distinct: " << regionsAreDistinct
+ << ", A not reachable from B: " << aNotReachableFromB
+ << ", B not reachable from A: " << bNotReachableFromA;
+
+ bool mutuallyExclusive =
+ regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
+ LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
+
+ return mutuallyExclusive;
}
// Could not find a common RegionBranchOpInterface among a's and b's
// ancestors.
+ LDBG() << "No common RegionBranchOpInterface found, operations are not "
+ "mutually exclusive";
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+ LDBG() << "Checking if region #" << index << " is repetitive in operation "
+ << getOperation()->getName();
+
Region *region = &getOperation()->getRegion(index);
- return isRegionReachable(region, region);
+ bool isRepetitive = isRegionReachable(region, region);
+
+ LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
+ return isRepetitive;
}
bool RegionBranchOpInterface::hasLoop() {
+ LDBG() << "Checking if operation " << getOperation()->getName()
+ << " has loops";
+
SmallVector<RegionSuccessor> entryRegions;
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
- for (RegionSuccessor successor : entryRegions)
- if (!successor.isParent() &&
- traverseRegionGraph(successor.getSuccessor(),
- [](Region *nextRegion, ArrayRef<bool> visited) {
- // Interrupt traversal if the region was already
- // visited.
- return visited[nextRegion->getRegionNumber()];
- }))
- return true;
+ LDBG() << "Found " << entryRegions.size() << " entry regions";
+
+ for (RegionSuccessor successor : entryRegions) {
+ if (!successor.isParent()) {
+ LDBG() << "Checking entry region #"
+ << successor.getSuccessor()->getRegionNumber() << " for loops";
+
+ bool hasLoop =
+ traverseRegionGraph(successor.getSuccessor(),
+ [](Region *nextRegion, ArrayRef<bool> visited) {
+ // Interrupt traversal if the region was already
+ // visited.
+ return visited[nextRegion->getRegionNumber()];
+ });
+
+ if (hasLoop) {
+ LDBG() << "Found loop in entry region #"
+ << successor.getSuccessor()->getRegionNumber();
+ return true;
+ }
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+
+ LDBG() << "No loops found in operation";
return false;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+ LDBG() << "Finding enclosing repetitive region for operation "
+ << op->getName();
+
while (Region *region = op->getParentRegion()) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
}
+
+ LDBG() << "No enclosing repetitive region found";
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+ LDBG() << "Finding enclosing repetitive region for value";
+
Region *region = value.getParentRegion();
while (region) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
Operation *op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
region = op->getParentRegion();
}
+
+ LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0..41f3f9d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
- auto getSuccessors = [&](Region *region = nullptr) {
- auto point = region ? region : RegionBranchPoint::parent();
+ auto getSuccessors = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
@@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// `nonForwardedOperands`.
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
for (OpOperand *opOperand : getForwardedOpOperands(successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
@@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
+ // TODO: this isn't correct in face of multiple terminators.
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
for (OpOperand *opOperand :
getForwardedOpOperands(successor, terminator))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
@@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
+ RegionBranchPoint point =
+ terminator
+ ? RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator))
+ : RegionBranchPoint::parent();
- for (const RegionSuccessor &successor : getSuccessors(region)) {
+ for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
@@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
resultsOrArgsToKeepChanged = false;
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor),
@@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000..bcbdef0
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+ // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+ amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+ %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+ // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+ amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+ amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+ %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+ // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+ amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+ // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ func.return
+}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index dec62f9..7a82236 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -211,11 +211,25 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32
-// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] : f32
+// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] : f32
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32
-// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] : f32
+// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
+// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] : f32
+// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] : f32
+// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
@@ -832,11 +846,25 @@ func.func @complex_exp_with_fmf(%arg: complex<f32>) -> complex<f32> {
}
// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
-// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[HALF:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[INF:.*]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[RESULT_REAL:.]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[REAL_HALF:.*]] = arith.mulf %[[REAL]], %[[HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF:.*]] = math.exp %[[REAL_HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] fastmath<nnan,contract> : f32
// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] fastmath<nnan,contract> : f32
-// CHECK-DAG: %[[RESULT_IMAG:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IS_INF:.*]] = arith.cmpf oeq, %[[EXP_REAL]], %[[INF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IS_IMAG_ZERO:.*]] = arith.cmpf oeq, %[[IMAG]], %[[ZERO]] : f32
+// CHECK-DAG: %[[REAL_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF_COS:.*]] = arith.mulf %[[EXP_HALF]], %[[COS_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[REAL_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_COS]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
+// CHECK: %[[RESULT_REAL:.*]] = arith.select %[[IS_INF]], %[[REAL_OVERFLOW]], %[[REAL_NORMAL]] : f32
+// CHECK-DAG: %[[IMAG_NORMAL:.*]] = arith.mulf %[[EXP_REAL]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[EXP_HALF_SIN:.*]] = arith.mulf %[[EXP_HALF]], %[[SIN_IMAG]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IMAG_OVERFLOW:.*]] = arith.mulf %[[EXP_HALF_SIN]], %[[EXP_HALF]] fastmath<nnan,contract> : f32
+// CHECK-DAG: %[[IMAG_NONZERO:.*]] = arith.select %[[IS_INF]], %[[IMAG_OVERFLOW]], %[[IMAG_NORMAL]] : f32
+// CHECK: %[[RESULT_IMAG:.*]] = arith.select %[[IS_IMAG_ZERO]], %[[ZERO]], %[[IMAG_NONZERO]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 5784764..4c6f62a 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
// -----
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
- %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
- func.return %0 : vector<8xi32>
-}
-
-// -----
-
func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
// expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
%0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
// -----
func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
%0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
func.return %0 : vector<8xi32>
}
// -----
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
// expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
%ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index a330967..09134cb 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -586,6 +586,41 @@ func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) ->
func.return %0 : vector<8xi32>
}
+// CHECK-LABEL: func @wmma_f32_16x16x4_f32
+func.func @wmma_f32_16x16x4_f32(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: amdgpu.wmma 16x16x4
+ %0 = amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_f8
+func.func @wmma_f32_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: amdgpu.wmma 16x16x64
+ %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f32_16x16x64_bf8
+func.func @wmma_f32_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // CHECK: amdgpu.wmma 16x16x64
+ %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_bf8
+func.func @wmma_f16_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma 16x16x64
+ %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+ func.return %0 : vector<8xf16>
+}
+
+// CHECK-LABEL: func @wmma_f16_16x16x64_f8
+func.func @wmma_f16_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+ // CHECK: amdgpu.wmma 16x16x64
+ %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+ func.return %0 : vector<8xf16>
+}
+
// CHECK-LABEL: func @swizzle_bitmode
func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
// CHECK: amdgpu.swizzle_bitmode
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
index 40a57b9..e8bb0c0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-memoryeffect-interface.mlir
@@ -156,3 +156,24 @@ func.func @manual_deallocation(%c: i1, %f: f32, %idx: index) -> f32 {
// CHECK: cf.assert %[[true]], "expected that the block does not have ownership"
// CHECK: memref.dealloc %[[manual_alloc]]
// CHECK: bufferization.dealloc (%[[managed_alloc]] : memref<5xf32>) if (%[[true]])
+
+// -----
+
+// CHECK-LABEL: func.func private @properly_creates_deallocations_in_execute_region(
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: scf.execute_region no_inline {
+// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x63x378x16xui8>
+// CHECK: bufferization.dealloc (%[[alloc]] : memref<1x63x378x16xui8>) if (%[[true]])
+
+func.func private @properly_creates_deallocations_in_execute_region(%arg1: memref<1x16x252x380xui8> ) -> (memref<1x250x378x16xui8> ) {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x250x378x16xui8>
+ scf.execute_region no_inline {
+ %subview = memref.subview %arg1[0, 0, 0, 0] [1, 16, 65, 380] [1, 1, 1, 1] : memref<1x16x252x380xui8> to memref<1x16x65x380xui8, strided<[1532160, 95760, 380, 1]>>
+ %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<1x63x378x16xui8>
+ test.buffer_based in(%subview: memref<1x16x65x380xui8, strided<[1532160, 95760, 380, 1]>>) out(%alloc_3: memref<1x63x378x16xui8>)
+ %subview_7 = memref.subview %alloc[0, 0, 0, 0] [1, 63, 378, 16] [1, 1, 1, 1] : memref<1x250x378x16xui8> to memref<1x63x378x16xui8, strided<[1512000, 6048, 16, 1]>>
+ test.copy(%alloc_3, %subview_7) : (memref<1x63x378x16xui8>, memref<1x63x378x16xui8, strided<[1512000, 6048, 16, 1]>>)
+ scf.yield
+ }
+ return %alloc : memref<1x250x378x16xui8>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d5f834b..8db1ebb 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -381,15 +381,19 @@ func.func private @execute_region_test(%t1 : tensor<?xf32>)
// -----
// CHECK-LABEL: func @no_inline_execute_region_not_canonicalized
-func.func @no_inline_execute_region_not_canonicalized() {
- %c = arith.constant 42 : i32
- // CHECK: scf.execute_region
- // CHECK-SAME: no_inline
- %v = scf.execute_region -> i32 no_inline {
- scf.yield %c : i32
+module {
+ func.func private @foo()->()
+ func.func @no_inline_execute_region_not_canonicalized() {
+ %c = arith.constant 42 : i32
+ // CHECK: scf.execute_region
+ // CHECK-SAME: no_inline
+ %v = scf.execute_region -> i32 no_inline {
+ func.call @foo():()->()
+ scf.yield %c : i32
+ }
+ // CHECK: return
+ return
}
- // CHECK: return
- return
}
// -----
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index d270ee8..e703600 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -664,6 +664,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL @rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: rocdl.tensor.load.to.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL @rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: rocdl.tensor.store.from.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL @rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: rocdl.tensor.load.to.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL @rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: rocdl.tensor.store.from.lds.d2 %{{.*}}, %{{.*}} cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index f9b81df..d0aec68 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -77,6 +77,24 @@ func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
// -----
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-LABEL: func @memref_collapse(
+// CHECK-SAME: %[[sz0:.*]]: index
+// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index
+// CHECK: %[[dim:.*]] = memref.dim %{{.*}}, %[[c2]] : memref<3x4x?x2xf32>
+// CHECK: %[[mul:.*]] = affine.apply #[[$MAP]]()[%[[dim]]]
+// CHECK: return %[[c12]], %[[mul]]
+func.func @memref_collapse(%sz0: index) -> (index, index) {
+ %0 = memref.alloc(%sz0) : memref<3x4x?x2xf32>
+ %1 = memref.collapse_shape %0 [[0, 1], [2, 3]] : memref<3x4x?x2xf32> into memref<12x?xf32>
+ %2 = "test.reify_bound"(%1) {dim = 0} : (memref<12x?xf32>) -> (index)
+ %3 = "test.reify_bound"(%1) {dim = 1} : (memref<12x?xf32>) -> (index)
+ return %2, %3 : index, index
+}
+
+// -----
+
// CHECK-LABEL: func @memref_get_global(
// CHECK: %[[c4:.*]] = arith.constant 4 : index
// CHECK: return %[[c4]]
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 77d18da..042ee25 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -2243,3 +2243,76 @@ func.func @test_firstprivate_map(%arg0: memref<10xf32>) {
// CHECK-NEXT: acc.yield
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+func.func @test_kernel_environment(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+ %c1 = arith.constant 1 : index
+ %c1024 = arith.constant 1024 : index
+
+ // Create data clause operands for the kernel environment
+ %copyin = acc.copyin varPtr(%arg0 : memref<1024xf32>) -> memref<1024xf32>
+ %create = acc.create varPtr(%arg1 : memref<1024xf32>) -> memref<1024xf32>
+
+ // Kernel environment wraps gpu.launch and captures data mapping
+ acc.kernel_environment dataOperands(%copyin, %create : memref<1024xf32>, memref<1024xf32>) {
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) {
+ // Kernel body uses the mapped data
+ %val = memref.load %copyin[%tx] : memref<1024xf32>
+ %result = arith.mulf %val, %val : f32
+ memref.store %result, %create[%tx] : memref<1024xf32>
+ gpu.terminator
+ }
+ }
+
+ // Copy results back to host and deallocate device memory
+ acc.copyout accPtr(%create : memref<1024xf32>) to varPtr(%arg1 : memref<1024xf32>)
+ acc.delete accPtr(%copyin : memref<1024xf32>)
+
+ return
+}
+
+// CHECK-LABEL: func @test_kernel_environment
+// CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32>
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) -> memref<1024xf32>
+// CHECK: acc.kernel_environment dataOperands(%[[COPYIN]], %[[CREATE]] : memref<1024xf32>, memref<1024xf32>) {
+// CHECK: gpu.launch
+// CHECK: memref.load %[[COPYIN]]
+// CHECK: memref.store %{{.*}}, %[[CREATE]]
+// CHECK: }
+// CHECK: }
+// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) to varPtr(%{{.*}} : memref<1024xf32>)
+// CHECK: acc.delete accPtr(%[[COPYIN]] : memref<1024xf32>)
+
+// -----
+
+func.func @test_kernel_environment_with_async(%arg0: memref<1024xf32>) {
+ %c1 = arith.constant 1 : index
+ %c1024 = arith.constant 1024 : index
+ %async_val = arith.constant 1 : i32
+
+ %create = acc.create varPtr(%arg0 : memref<1024xf32>) async(%async_val : i32) -> memref<1024xf32>
+
+ // Kernel environment with async clause
+ acc.kernel_environment dataOperands(%create : memref<1024xf32>) async(%async_val : i32) {
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c1024, %block_y = %c1, %block_z = %c1) {
+ %f0 = arith.constant 0.0 : f32
+ memref.store %f0, %create[%tx] : memref<1024xf32>
+ gpu.terminator
+ }
+ }
+
+ acc.copyout accPtr(%create : memref<1024xf32>) async(%async_val : i32) to varPtr(%arg0 : memref<1024xf32>)
+
+ return
+}
+
+// CHECK-LABEL: func @test_kernel_environment_with_async
+// CHECK: %[[ASYNC:.*]] = arith.constant 1 : i32
+// CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : memref<1024xf32>) async(%[[ASYNC]] : i32) -> memref<1024xf32>
+// CHECK: acc.kernel_environment dataOperands(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32)
+// CHECK: gpu.launch
+// CHECK: memref.store %{{.*}}, %[[CREATE]]
+// CHECK: acc.copyout accPtr(%[[CREATE]] : memref<1024xf32>) async(%[[ASYNC]] : i32) to varPtr(%{{.*}} : memref<1024xf32>)
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 37fc86b..3f481ad 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -373,7 +373,7 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
{
- // expected-error@+1 {{region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 2}}
+ // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}}
%x, %y = scf.if %arg0 -> (f32, f32) {
%0 = arith.addf %arg1, %arg1 : f32
scf.yield %0 : f32
@@ -544,7 +544,7 @@ func.func @while_invalid_terminator() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}}
+ // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}}
scf.while : () -> () {
scf.condition(%true)
} do {
@@ -557,7 +557,7 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}}
+ // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}}
%0 = scf.while : () -> (i1) {
scf.condition(%true) %true : i1
} do {
@@ -570,7 +570,7 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_result_type_mismatch() {
%true = arith.constant true
- // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}}
+ // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}}
scf.while : () -> () {
scf.condition(%true) %true : i1
} do {
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 1bcef0a..ea587e9 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -49,6 +49,11 @@ func.func @conj(%arg: complex<f32>) -> complex<f32> {
func.return %conj : complex<f32>
}
+func.func @exp(%arg: complex<f32>) -> complex<f32> {
+ %exp = complex.exp %arg : complex<f32>
+ func.return %exp : complex<f32>
+}
+
// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
%func: (complex<f32>, complex<f32>) -> complex<f32>) {
@@ -353,5 +358,32 @@ func.func @entry() {
call @test_element_f64(%abs_test_cast, %abs_func)
: (tensor<?xcomplex<f64>>, (complex<f64>) -> f64) -> ()
+ // complex.exp test
+ %exp_test = arith.constant dense<[
+ (1.0, 2.0),
+ // CHECK: -1.1312
+ // CHECK-NEXT: 2.4717
+
+ // The first case to consider is overflow of exp(real_part). If computed
+ // directly, this yields inf * 0 = NaN, which is incorrect.
+ (500.0, 0.0),
+ // CHECK-NEXT: inf
+ // CHECK-NOT: nan
+ // CHECK-NEXT: 0
+
+ // In this case, the overflow of exp(real_part) is compensated when
+ // sin(imag_part) is close to zero, yielding a finite imaginary part.
+ (90.0238094, 5.900613e-39)
+ // CHECK-NEXT: inf
+ // CHECK-NOT: inf
+ // CHECK-NEXT: 7.3746
+ ]> : tensor<3xcomplex<f32>>
+ %exp_test_cast = tensor.cast %exp_test
+ : tensor<3xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %exp_func = func.constant @exp : (complex<f32>) -> complex<f32>
+ call @test_unary(%exp_test_cast, %exp_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
func.return
}
diff --git a/mlir/test/Target/LLVMIR/ptr.mlir b/mlir/test/Target/LLVMIR/ptr.mlir
index 473ac05..94b6628 100644
--- a/mlir/test/Target/LLVMIR/ptr.mlir
+++ b/mlir/test/Target/LLVMIR/ptr.mlir
@@ -284,8 +284,8 @@ llvm.func @ptr_add_cst() -> !ptr.ptr<#llvm.address_space<0>> {
// CHECK-LABEL: define i64 @ptr_diff_scalar
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
@@ -296,8 +296,8 @@ llvm.func @ptr_diff_scalar(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !ptr.
// CHECK-LABEL: define i32 @ptr_diff_scalar_i32
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: %[[TRUNC:.*]] = trunc i64 %[[DIFF]] to i32
// CHECK-NEXT: ret i32 %[[TRUNC]]
@@ -309,8 +309,8 @@ llvm.func @ptr_diff_scalar_i32(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr2: !
// CHECK-LABEL: define <4 x i64> @ptr_diff_vector
// CHECK-SAME: (<4 x ptr> %[[PTRS1:.*]], <4 x ptr> %[[PTRS2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <4 x ptr> %[[PTRS1]] to <4 x i64>
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <4 x ptr> %[[PTRS2]] to <4 x i64>
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr <4 x ptr> %[[PTRS1]] to <4 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr <4 x ptr> %[[PTRS2]] to <4 x i64>
// CHECK-NEXT: %[[DIFF:.*]] = sub <4 x i64> %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret <4 x i64> %[[DIFF]]
// CHECK-NEXT: }
@@ -321,8 +321,8 @@ llvm.func @ptr_diff_vector(%ptrs1: vector<4x!ptr.ptr<#llvm.address_space<0>>>, %
// CHECK-LABEL: define <8 x i32> @ptr_diff_vector_i32
// CHECK-SAME: (<8 x ptr> %[[PTRS1:.*]], <8 x ptr> %[[PTRS2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint <8 x ptr> %[[PTRS1]] to <8 x i64>
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint <8 x ptr> %[[PTRS2]] to <8 x i64>
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr <8 x ptr> %[[PTRS1]] to <8 x i64>
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr <8 x ptr> %[[PTRS2]] to <8 x i64>
// CHECK-NEXT: %[[DIFF:.*]] = sub <8 x i64> %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: %[[TRUNC:.*]] = trunc <8 x i64> %[[DIFF]] to <8 x i32>
// CHECK-NEXT: ret <8 x i32> %[[TRUNC]]
@@ -344,8 +344,8 @@ llvm.func @ptr_diff_with_constants() -> i64 {
// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nsw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
@@ -356,8 +356,8 @@ llvm.func @ptr_diff_with_flags_nsw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr
// CHECK-LABEL: define i64 @ptr_diff_with_flags_nuw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nuw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
@@ -368,8 +368,8 @@ llvm.func @ptr_diff_with_flags_nuw(%ptr1: !ptr.ptr<#llvm.address_space<0>>, %ptr
// CHECK-LABEL: define i64 @ptr_diff_with_flags_nsw_nuw
// CHECK-SAME: (ptr %[[PTR1:.*]], ptr %[[PTR2:.*]]) {
-// CHECK-NEXT: %[[P1INT:.*]] = ptrtoint ptr %[[PTR1]] to i64
-// CHECK-NEXT: %[[P2INT:.*]] = ptrtoint ptr %[[PTR2]] to i64
+// CHECK-NEXT: %[[P1INT:.*]] = ptrtoaddr ptr %[[PTR1]] to i64
+// CHECK-NEXT: %[[P2INT:.*]] = ptrtoaddr ptr %[[PTR2]] to i64
// CHECK-NEXT: %[[DIFF:.*]] = sub nuw nsw i64 %[[P1INT]], %[[P2INT]]
// CHECK-NEXT: ret i64 %[[DIFF]]
// CHECK-NEXT: }
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 30126f6..8a848221 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -1040,6 +1040,36 @@ llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: rocdl.tensor.load.to.lds
+llvm.func @rocdl.tensor.load.to.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds
+llvm.func @rocdl.tensor.store.from.lds(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>,
+ %dgroup2 : vector<4xi32>, %dgroup3 : vector<4xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds %dgroup0, %dgroup1, %dgroup2, %dgroup3 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.load.to.lds.d2
+llvm.func @rocdl.tensor.load.to.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.load.to.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.load.to.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
+// CHECK-LABEL: rocdl.tensor.store.from.lds.d2
+llvm.func @rocdl.tensor.store.from.lds.d2(%dgroup0 : vector<4xi32>, %dgroup1 : vector<8xi32>) {
+ // CHECK: call void @llvm.amdgcn.tensor.store.from.lds.d2(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i32 0)
+ rocdl.tensor.store.from.lds.d2 %dgroup0, %dgroup1 cachepolicy 0 : vector<4xi32>, vector<8xi32>
+ llvm.return
+}
+
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
%stride : i16,
%numRecords : i64,
diff --git a/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir
new file mode 100644
index 0000000..62d15de
--- /dev/null
+++ b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
+ %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+ %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+ %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 90ba690e..712fd17 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -1,27 +1,32 @@
-// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: location = 0 : i32
spirv.GlobalVariable @var {location = 0 : i32} : !spirv.ptr<vector<4xf32>, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: no_perspective
spirv.GlobalVariable @var {no_perspective} : !spirv.ptr<vector<4xf32>, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: flat
spirv.GlobalVariable @var {flat} : !spirv.ptr<si32, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: aliased
// CHECK: aliased
spirv.GlobalVariable @var1 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
@@ -30,28 +35,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: non_readable
spirv.GlobalVariable @var bind(0, 0) {non_readable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: non_writable
spirv.GlobalVariable @var bind(0, 0) {non_writable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: restrict
spirv.GlobalVariable @var bind(0, 0) {restrict} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: relaxed_precision
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
}
@@ -84,7 +89,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], [SPV_KHR_no_integer_wrap_decoration]> {
spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
%0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
@@ -94,7 +99,7 @@ spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], []> {
spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
%0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
@@ -104,7 +109,7 @@ spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FMul %{{.*}}, %{{.*}} {no_contraction}
%0 = spirv.FMul %arg, %arg {no_contraction} : f32
@@ -114,7 +119,7 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Float16], []> {
spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
%0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
@@ -124,51 +129,7 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
// -----
-// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls() "None" {
- // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
- %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
- // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
- %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
- %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
- %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
- %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: spirv.func @relaxed_precision_arg({{%.*}}: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) "None" attributes {relaxed_precision} {
spirv.func @relaxed_precision_arg(%arg0: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) -> () "None" attributes {relaxed_precision} {
spirv.Return
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index eb0d980..7a7a583 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -66,7 +66,7 @@ public:
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo,
+ RegionSuccessor regionTo,
const NextAccess &after,
NextAccess *before) override;
@@ -240,7 +240,7 @@ void NextAccessAnalysis::visitCallControlFlowTransfer(
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
- RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
+ RegionSuccessor regionTo, const NextAccess &after, NextAccess *before) {
LDBG() << "visitRegionBranchControlFlowTransfer: "
<< OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
LDBG() << " regionFrom: " << (regionFrom.isParent() ? "parent" : "region");
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b211e24..4d4ec02 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()},
+ successor.getSuccessor()) &&
"invalid region index");
return getOperands();
}
@@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (!point.isParent()) {
- if (point != getJoinRegion())
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() !=
+ &getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
if (point.isParent())
regions.emplace_back(&getRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
void AnyCondOp::getRegionInvocationBounds(
@@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions(
if (point.isParent())
return;
- regions.emplace_back((*this)->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody());
return MutableOperandRange(getInitMutable());
}
@@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
//===----------------------------------------------------------------------===//
MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) {
+ if (successor.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions(
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
@@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions(
// enter the body.
regions.emplace_back(
RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 05a33cf..a3430ba 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
NoTerminator,
- DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>
]> {
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
diff --git a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
index f1aae15..2e6950f 100644
--- a/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
+++ b/mlir/unittests/Interfaces/ControlFlowInterfacesTest.cpp
@@ -13,17 +13,24 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/Support/DebugLog.h"
#include <gtest/gtest.h>
using namespace mlir;
/// A dummy op that is also a terminator.
-struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator> {
+struct DummyOp : public Op<DummyOp, OpTrait::IsTerminator, OpTrait::ZeroResults,
+ OpTrait::ZeroSuccessors,
+ RegionBranchTerminatorOpInterface::Trait> {
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
static StringRef getOperationName() { return "cftest.dummy_op"; }
+
+ MutableOperandRange getMutableSuccessorOperands(RegionSuccessor point) {
+ return MutableOperandRange(getOperation(), 0, 0);
+ }
};
/// All regions of this op are mutually exclusive.
@@ -39,6 +46,8 @@ struct MutuallyExclusiveRegionsOp
// Regions have no successors.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {}
+ using RegionBranchOpInterface::Trait<
+ MutuallyExclusiveRegionsOp>::getSuccessorRegions;
};
/// All regions of this op call each other in a large circle.
@@ -53,13 +62,18 @@ struct LoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
- if (point == (*this)->getRegion(1))
+ if (point.getTerminatorPredecessorOrNull()) {
+ Region *region =
+ point.getTerminatorPredecessorOrNull()->getParentRegion();
+ if (region == &(*this)->getRegion(1))
// This region also branches back to the parent.
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation()->getParentOp(),
+ getOperation()->getParentOp()->getResults()));
regions.push_back(RegionSuccessor(region));
}
}
+ using RegionBranchOpInterface::Trait<LoopRegionsOp>::getSuccessorRegions;
};
/// Each region branches back it itself or the parent.
@@ -75,11 +89,17 @@ struct DoubleLoopRegionsOp
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (Region *region = point.getRegionOrNull()) {
- regions.push_back(RegionSuccessor());
+ if (point.getTerminatorPredecessorOrNull()) {
+ Region *region =
+ point.getTerminatorPredecessorOrNull()->getParentRegion();
+ regions.push_back(
+ RegionSuccessor(getOperation()->getParentOp(),
+ getOperation()->getParentOp()->getResults()));
regions.push_back(RegionSuccessor(region));
}
}
+ using RegionBranchOpInterface::Trait<
+ DoubleLoopRegionsOp>::getSuccessorRegions;
};
/// Regions are executed sequentially.
@@ -93,11 +113,15 @@ struct SequentialRegionsOp
// Region 0 has Region 1 as a successor.
void getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- if (point == (*this)->getRegion(0)) {
+ if (point.getTerminatorPredecessorOrNull() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &(*this)->getRegion(0)) {
Operation *thisOp = this->getOperation();
regions.push_back(RegionSuccessor(&thisOp->getRegion(1)));
}
}
+ using RegionBranchOpInterface::Trait<
+ SequentialRegionsOp>::getSuccessorRegions;
};
/// A dialect putting all the above together.