diff options
Diffstat (limited to 'mlir/lib')
25 files changed, 242 insertions, 450 deletions
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp index 51fa773..fb5649e 100644 --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #define DEBUG_TYPE "constant-propagation" @@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const { LogicalResult SparseConstantPropagation::visitOperation( Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, ArrayRef<Lattice<ConstantValue> *> results) { - LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); + LDBG() << "SCP: Visiting operation: " << *op; // Don't try to simulate the results of a region operation as we can't // guarantee that folding will be out-of-place. We don't allow in-place @@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation( // Merge in the result of the fold, either a constant or a value. OpFoldResult foldResult = std::get<1>(it); if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { - LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + LDBG() << "Folded to constant: " << attr; propagateIfChanged(lattice, lattice->join(ConstantValue(attr, op->getDialect()))); } else { - LLVM_DEBUG(llvm::dbgs() - << "Folded to value: " << cast<Value>(foldResult) << "\n"); + LDBG() << "Folded to value: " << cast<Value>(foldResult); AbstractSparseForwardDataFlowAnalysis::join( lattice, *getLatticeElement(cast<Value>(foldResult))); } diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 197f97f..509f520 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -294,7 +294,7 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { solver.load<LivenessAnalysis>(symbolTable); LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG() << "Dumping liveness state for op"; + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 176d53e..16f7033 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -14,7 +14,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/Config/abi-breaking.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "dataflow" @@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent, (void)inserted; DATAFLOW_DEBUG({ if (inserted) { - llvm::dbgs() << "Creating dependency between " << debugName << " of " - << anchor << "\nand " << debugName << " on " << dependent - << "\n"; + LDBG() << "Creating dependency between " << debugName << " of " << anchor + << "\nand " << debugName << " on " << dependent; } }); } @@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { // Initialize the analyses. for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { - DATAFLOW_DEBUG(llvm::dbgs() - << "Priming analysis: " << analysis.debugName << "\n"); + DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName); if (failed(analysis.initialize(top))) return failure(); } @@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { auto [point, analysis] = worklist.front(); worklist.pop(); - DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName - << "' on: " << point << "\n"); + DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName + << "' on: " << point); if (failed(analysis->visit(point))) return failure(); } @@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state, assert(isRunning && "DataFlowSolver is not running, should not use propagateIfChanged"); if (changed == ChangeResult::Change) { - DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->anchor << "\n" - << "Value: " << *state << "\n"); + DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName + << " of " << state->anchor << "\n" + << "Value: " << *state); state->onUpdate(this); } } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 4307bc6..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1070,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1206,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index b1af5f0..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index 3c00b32..6265f46 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -15,13 +15,13 @@ #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::affine; #define DEBUG_TYPE "decompose-affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") /// Count the number of loops surrounding `operand` such that operand could be /// hoisted above. @@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, return rewriter.notifyMatchFailure( op, "only add or mul binary expr can be reassociated"); - LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n"); + LDBG() << "Start decomposeIntoFinerGrainedOps: " << op; // 2. Iteratively extract the RHS subexpressions while the top-level binary // expr kind remains the same. @@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp); if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) { subExpressions.push_back(remainingExp); - LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n"); + LDBG() << "--terminal: " << subExpressions.back(); break; } subExpressions.push_back(currentBinExpr.getRHS()); - LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n"); + LDBG() << "--subExpr: " << subExpressions.back(); remainingExp = currentBinExpr.getLHS(); } @@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) { return getMaxSymbol(e1) < getMaxSymbol(e2); }); - LLVM_DEBUG( - llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: "); - llvm::dbgs() << "\n"); + LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions); // 4. Merge sorted subExpressions iteratively, thus achieving reassociation. auto s0 = getAffineSymbolExpr(0, ctx); @@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter, Value tmp = createSubApply(rewriter, op, subExpressions[i]); current = AffineApplyOp::create(rewriter, op.getLoc(), binMap, ValueRange{current, tmp}); - LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n"); + LDBG() << "--reassociate into: " << current; } // 5. Replace original op. diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index 8493b60..2521512 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -19,11 +19,10 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/IntEqClasses.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "affine-min-max" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::affine; @@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { ValueRange operands = affineOp.getOperands(); static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>; - LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; }); + LDBG() << "analyzing value: `" << affineOp; // Create a `Variable` list with values corresponding to each of the results // in the affine affineMap. @@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { [&](unsigned i) { return Variable(affineMap.getSliceMap(i, 1), operands); }); - LLVM_DEBUG({ - DBGS() << "- constructed variables are: " - << llvm::interleaved_array(llvm::map_range( - variables, [](const Variable &v) { return v.getMap(); })) - << "`\n"; - }); + LDBG() << "- constructed variables are: " + << llvm::interleaved_array(llvm::map_range( + variables, [](const Variable &v) { return v.getMap(); })); // Get the comparison operation. ComparisonOperator cmpOp = @@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Initialize the bound. Variable *bound = &v; - LLVM_DEBUG({ - DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() - << "`\n"; - }); + LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() + << "`\n"; // Check against the other variables. for (size_t j = i + 1; j < variables.size(); ++j) { @@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Get the bound of the equivalence class or itself. Variable *nv = bounds.lookup_or(jEqClass, &variables[j]); - LLVM_DEBUG({ - DBGS() << "- comparing with variable: #" << jEqClass - << ", with map: " << nv->getMap() << "\n"; - }); + LDBG() << "- comparing with variable: #" << jEqClass + << ", with map: " << nv->getMap(); // Compare the variables. FailureOr<bool> cmpResult = @@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // The variables cannot be compared. if (failed(cmpResult)) { - LLVM_DEBUG({ - DBGS() << "-- classes: #" << i << ", #" << jEqClass - << " cannot be merged\n"; - }); + LDBG() << "-- classes: #" << i << ", #" << jEqClass + << " cannot be merged"; continue; } // Join the equivalent classes and update the bound if necessary. - LLVM_DEBUG({ - DBGS() << "-- merging classes: #" << i << ", #" << jEqClass - << ", is cmp(lhs, rhs): " << *cmpResult << "`\n"; - }); + LDBG() << "-- merging classes: #" << i << ", #" << jEqClass + << ", is cmp(lhs, rhs): " << *cmpResult << "`"; if (*cmpResult) { boundedClasses.join(eqClass, jEqClass); } else { @@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Return if there's no simplification. if (bounds.size() >= affineMap.getNumResults()) { - LLVM_DEBUG( - { DBGS() << "- the affine operation couldn't get simplified\n"; }); + LDBG() << "- the affine operation couldn't get simplified"; return false; } @@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { for (auto [k, bound] : bounds) results.push_back(bound->getMap().getResult(0)); - LLVM_DEBUG({ - DBGS() << "- starting from map: " << affineMap << "\n"; - DBGS() << "- creating new map with: \n"; - DBGS() << "--- dims: " << affineMap.getNumDims() << "\n"; - DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n"; - DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n"; - }); + LDBG() << "- starting from map: " << affineMap; + LDBG() << "- creating new map with:"; + LDBG() << "--- dims: " << affineMap.getNumDims(); + LDBG() << "--- syms: " << affineMap.getNumSymbols(); + LDBG() << "--- res: " << llvm::interleaved_array(results); affineMap = AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(), @@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Update the affine op. rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); - LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; }); + LDBG() << "- simplified affine op: `" << affineOp << "`"; return true; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 27b6617..b56a212 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos, }); } -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. -static bool areAllInBound(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; - }); -} - template <typename OpTy> static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, @@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // represents full tiles. RankedTensorType expectedPackedType = PackOp::inferPackedType( unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } + if (failed(verifyCompatibleShape(expectedPackedType.getShape(), + packedType.getShape()))) { + return op->emitError("expected ") + << expectedPackedType << " for the packed domain value, got " + << packedType; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp index c926dfb..5c8c2de 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" @@ -21,7 +22,6 @@ using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") static Attribute linearId0(MLIRContext *ctx) { return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0); @@ -81,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, this->threadMapping = llvm::to_vector(ArrayRef(allThreadMappings) .take_back(this->smallestBoundingTileSizes.size())); - LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n"); + LDBG() << *this; } int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 2fe72a3..d4a3e5f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -15,14 +15,13 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InterleavedRange.h" using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") //===----------------------------------------------------------------------===// // StructuredMatchOp @@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( return emitSilenceableError() << "expected a Linalg op"; } // If errors are suppressed, succeed and set all results to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); + LDBG() << "optional nested matcher expected a Linalg op"; results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); return DiagnosedSilenceableFailure::success(); } @@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( // When they are defined in this block, we additionally check if we have // already applied the operation that defines them. If not, the // corresponding results will be set to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() - << "\n"); + LDBG() << "optional nested matcher failed: " << diag.getMessage(); (void)diag.silence(); SmallVector<OpOperand *> undefinedOperands; for (OpOperand &terminatorOperand : diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 277e50b..9d7f4e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index dad3526..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 0170837..793eec7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1913,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), sourceShape.end()); - ReifiedRankedShapedTypeDims reifiedRetShapes; - LogicalResult status = - cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) - .reifyResultShapes(rewriter, reifiedRetShapes); - if (status.failed()) { - LDBG() << "Unable to reify result shapes of " << unpackOp; - return failure(); - } Location loc = unpackOp->getLoc(); auto padValue = arith::ConstantOp::create( diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 106c3b4..cce80db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), sourceOp.getMixedStrides(), op.getMixedSizes())) { - // We only support static sizes. - if (isa<Value>(opSize)) { - return failure(); - } sizes.push_back(opSize); Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), sourceOffsetAttr = diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index ecd93ff..3cafb19 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } +static void printInitializationList(OpAsmPrinter &parser, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + parser << prefix << '('; + llvm::interleaveComma( + llvm::zip(blocksArgs, initializers), parser, + [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); + parser << ")"; +} + // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. @@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); - auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + + if (parser.parseOperand(cond)) return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) + + SmallVector<OpAsmParser::Argument, 4> regionArgs; + SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; + + // Parse the optional block arguments + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.has_value() && failed(listResult.value())) return failure(); + + // Parse a colon. + if (failed(parser.parseColon())) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Parse the type of the condition operand + Type condType; + if (failed(parser.parseType(condType))) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Resolve operand with provided type + if (failed(parser.resolveOperand(cond, condType, result.operands))) + return failure(); + + // Parse optional block arg types + if (listResult.has_value()) { + FunctionType functionType; + + if (failed(parser.parseType(functionType))) + return parser.emitError(parser.getCurrentLocation()) + << "expected list of types for block arguments " + << "followed by arrow type and list of return types"; + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + if (failed(parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + } else { + // Parse optional results type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } + // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; + + printInitializationList(p, getThenGraph().front().getArguments(), + getInputList(), " "); + p << " : "; + p << getCondition().getType(); + + if (!getInputList().empty()) { + p << " ("; + llvm::interleaveComma(getInputList().getTypes(), p); + p << ")"; } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printArrowTypeList(getResultTypes()); + p << " "; + + p.printRegion(getThenGraph()); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printRegion(elseRegion); } p.printOptionalAttrDict((*this)->getAttrs()); @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOptionalAttrDictWithKeyword(result.attributes)); } -static void printInitializationList(OpAsmPrinter &parser, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - parser << prefix << '('; - llvm::interleaveComma( - llvm::zip(blocksArgs, initializers), parser, - [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); - parser << ")"; -} - void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCondGraph().front().getArguments(), getInputList(), " "); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 32b5fb6..8ec7765 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) { // }) // // Simplified: - // %0 = tosa.cond_if %arg2 { - // tosa.yield %arg0 + // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 // } else { - // tosa.yield %arg1 + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 // } - // - // Unfortunately, the simplified syntax does not encapsulate values - // used in then/else regions (see 'simplified' example above), so it - // must be rewritten to use the generic syntax in order to be conformant - // to the specification. + return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index c0d20d4..14a4fdf 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -21,17 +21,8 @@ #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "transform-dialect" -#define DEBUG_TYPE_FULL "transform-dialect-full" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") -#ifndef NDEBUG -#define FULL_LDBG(X) \ - DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL) -#else -#define FULL_LDBG(X) \ - for (bool _c = false; _c; _c = false) \ - ::llvm::nulls() -#endif +#define FULL_LDBG() LDBG(4) using namespace mlir; @@ -818,16 +809,14 @@ void transform::TransformState::compactOpHandles() { DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { - LLVM_DEBUG({ - DBGS() << "applying: "; - transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"; - }); + LDBG() << "applying: " + << OpWithFlags(transform, OpPrintingFlags().skipRegions()); FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel(); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; - LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( - llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); + LDBG() << "Failing Top-level payload:\n" + << OpWithFlags(getTopLevel(), + OpPrintingFlags().printGenericOpForm()); }); // Set current transform op. @@ -995,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { - DBGS() << "Top-level payload:\n"; - getTopLevel()->print(llvm::dbgs()); + LDBG() << "Top-level payload:\n" << *getTopLevel(); }); return result; } @@ -1273,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - DBGS() << "Match Failure : " << diag.str(); + LDBG() << "Match Failure : " << diag.str(); }); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bce358d..8789f55 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5c98417..9332f55 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present<Listener>(listener); + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0db9808..7094c8e 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; - pipelineKey = pipelineInitializationKey; + pipelineInitializationKey = pipelineKey; } // Construct a top level analysis manager for the pipeline. diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp index 01ad910..304253c 100644 --- a/mlir/lib/Support/TypeID.cpp +++ b/mlir/lib/Support/TypeID.cpp @@ -27,9 +27,6 @@ namespace { struct ImplicitTypeIDRegistry { /// Lookup or insert a TypeID for the given type name. TypeID lookupOrInsert(StringRef typeName) { - LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert(" - << typeName << ")\n"); - // Perform a heuristic check to see if this type is in an anonymous // namespace. String equality is not valid for anonymous types, so we try to // abort whenever we see them. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 58e5353..a8a2b2e 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type, LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector<StringRef> &serializationCtx) { + + // Map unsigned integer types to singless integer types. + // This is needed otherwise the generated spirv assembly will contain + // twice a type declaration (like OpTypeInt 32 0) which is no permitted and + // such module fails validation. Indeed at MLIR level the two types are + // different and lookup in the cache below misses. + // Note: This conversion needs to happen here before the type is looked up in + // the cache. + if (type.isUnsignedInteger()) { + type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + typeID = getTypeID(type); if (typeID) return success(); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index df255cf..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -1700,6 +1701,7 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( }); assert(!wasOpReplaced(newParentOp) && "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); @@ -1758,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1773,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1781,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1887,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -2216,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2245,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. @@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, Merge the new uses into the source node. |