Diffstat (limited to 'mlir')
28 files changed, 773 insertions, 120 deletions
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 9753dca..b67e4cb 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -973,6 +973,7 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", custom<ShuffleType>(ref(type($v1)), type($res), ref($mask)) }]; + let hasFolder = 1; let hasVerifier = 1; string llvmInstName = "ShuffleVector"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index ce9ff7e..d959464 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -263,6 +263,7 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; + let hasVerifier = 1; // Backwards-compatibility builder for an unspecified range. let builders = [ @@ -279,6 +280,11 @@ class NVVM_PureSpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = SetIntRangeFn setResultRanges) { nvvmInferResultRanges(getOperation(), getResult(), argRanges, setResultRanges); } + + // Verify the range attribute satisfies LLVM ConstantRange constructor requirements. + ::llvm::LogicalResult $cppClass::verify() { + return verifyConstantRangeAttr(getOperation(), getRange()); + } }]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 6925cec..68f31e6 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -412,6 +412,32 @@ def ROCDL_WaitExpcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.expcnt", [], 0, [0], let assemblyFormat = "$count attr-dict"; } +def ROCDL_WaitAsynccntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.asynccnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until ASYNCCNT is less than or equal to `count`"; + let description = [{ + Wait for the specified counter to be less than or equal to `count` + before continuing. + + Available on gfx1250+. + }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + +def ROCDL_WaitTensorcntOp: ROCDL_ConcreteNonMemIntrOp<"s.wait.tensorcnt", [], 0, [0], ["count"]>, + Arguments<(ins I16Attr:$count)> { + let summary = "Wait until TENSORCNT is less than or equal to `count`"; + let description = [{ + Wait for the specified counter to be less than or equal to `count` + before continuing. + + Available on gfx1250+.
+ }]; + let results = (outs); + let assemblyFormat = "$count attr-dict"; +} + def ROCDL_SetPrioOp : ROCDL_ConcreteNonMemIntrOp<"s.setprio", [], 0, [0], ["priority"]>, Arguments<(ins I16Attr:$priority)> { let assemblyFormat = "$priority attr-dict"; diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6e79085..6e15b1e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [ }]; let results = (outs VectorOfRankAndType<[1], [Index]>:$result); let assemblyFormat = "attr-dict `:` type($result)"; + let hasCanonicalizer = 1; } def Vector_YieldOp : Vector_Op<"yield", [ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 5355909..41d8d53 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering return success(); } - // For 1-d vector, we additionally do a `vectorshuffle`. auto v = LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType, poison, adaptor.getSource(), zero); + // For 1-d vector, we additionally do a `shufflevector`. int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0); SmallVector<int32_t> zeroValues(width, 0); // Shuffle the value across the desired number of elements. - rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison, - zeroValues); + auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>( + broadcast.getLoc(), v, poison, zeroValues); + rewriter.replaceOp(broadcast, shuffle); return success(); } }; diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp index 71687b1..ddcbc44 100644 --- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" @@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { // Load result or store value type can be vector or scalar. Type valOrResTy; if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) - valOrResTy = op.getResult().getType(); + valOrResTy = + this->getTypeConverter()->convertType(op.getResult().getType()); else valOrResTy = adaptor.getValue().getType(); VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); @@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass } return {}; }; - typeConverter.addSourceMaterialization(memrefMaterializationCast); - typeConverter.addSourceMaterialization(ui64MaterializationCast); - typeConverter.addSourceMaterialization(ui32MaterializationCast); - typeConverter.addSourceMaterialization(vectorMaterializationCast); + + // If the result type of the original op is a single-element vector and the + // lowered type is a scalar, this materialization cast creates the + // single-element vector by broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast = + [](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType().isIntOrIndexOrFloat()) { + // If the input is a scalar and the target type is a single-element + // vector, create the vector by broadcasting the scalar. + if (auto vecTy = dyn_cast<VectorType>(type)) { + if (vecTy.getNumElements() == 1) { + return vector::BroadcastOp::create(builder, loc, vecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization( + singleElementVectorMaterializationCast); typeConverter.addTargetMaterialization(memrefMaterializationCast); typeConverter.addTargetMaterialization(ui32MaterializationCast); typeConverter.addTargetMaterialization(ui64MaterializationCast); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 7ca09d9..3eae67f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -2826,6 +2826,20 @@ LogicalResult ShuffleVectorOp::verify() { return success(); } +// Fold the shufflevector op when v1 is a single-element 1D vector +// and the mask is a single zero. The fold result is v1 in this case. +OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) { + // Check if operand 0 is a single-element vector. + auto vecType = llvm::dyn_cast<VectorType>(getV1().getType()); + if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1) + return {}; + // Check if the mask is a single zero. + // Note: The mask is guaranteed to be non-empty. + if (getMask().size() != 1 || getMask()[0] != 0) + return {}; + return getV1(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index ab54183..2a8c330 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -798,6 +798,26 @@ LogicalResult MmaOp::verify() { " attribute"); } + // Validate layout combinations. According to the operation description, most + // MMA operations require layoutA=row and layoutB=col. Only m8n8k4 with f16 + // can use other layout combinations. + bool isM8N8K4_F16 = + (mmaShape[0] == 8 && mmaShape[1] == 8 && mmaShape[2] == 4 && + getMultiplicandAPtxType() == MMATypes::f16); + + if (!isM8N8K4_F16) { + // For all other shapes/types, layoutA must be row and layoutB must be col. + if (getLayoutA() != MMALayout::row || getLayoutB() != MMALayout::col) { + return emitOpError("requires layoutA = #nvvm.mma_layout<row> and " + "layoutB = #nvvm.mma_layout<col> for shape <") + << mmaShape[0] << ", " << mmaShape[1] << ", " << mmaShape[2] + << "> with element types " + << stringifyEnum(*getMultiplicandAPtxType()) << " and " + << stringifyEnum(*getMultiplicandBPtxType()) + << ". Only m8n8k4 with f16 supports other layouts."; + } + } + return success(); } @@ -2334,6 +2354,32 @@ static void nvvmInferResultRanges(Operation *op, Value result, } } +/// Verify the range attribute satisfies LLVM ConstantRange constructor +/// requirements for NVVM SpecialRangeableRegisterOp.
+static LogicalResult +verifyConstantRangeAttr(Operation *op, + std::optional<LLVM::ConstantRangeAttr> rangeAttr) { + if (!rangeAttr) + return success(); + + const llvm::APInt &lower = rangeAttr->getLower(); + const llvm::APInt &upper = rangeAttr->getUpper(); + + // Check the LLVM ConstantRange constructor condition. + if (lower == upper && !lower.isMaxValue() && !lower.isMinValue()) { + unsigned bitWidth = lower.getBitWidth(); + llvm::APInt minVal = llvm::APInt::getMinValue(bitWidth); + llvm::APInt maxVal = llvm::APInt::getMaxValue(bitWidth); + return op->emitOpError( + "invalid range attribute: Lower == Upper, but they aren't min (") + << llvm::toString(minVal, 10, false) << ") or max (" + << llvm::toString(maxVal, 10, false) + << ") value! This is an invalid constant range."; + } + + return success(); +} + static llvm::Value *getAsPackedI32(llvm::Value *arg, llvm::IRBuilderBase &builder) { return builder.CreateBitCast(arg, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58256b0..45c54c7 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), result); } +namespace { + +/// Fold `vector.step -> arith.cmpi` when the step value is compared to a +/// constant large enough such that the result is the same at all indices. +/// +/// For example, rewrite the 'greater than' comparison below, +/// +/// ```mlir +/// %cst = arith.constant dense<7> : vector<3xindex> +/// %stp = vector.step : vector<3xindex> +/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex> +/// ``` +/// +/// as, +/// +/// ```mlir +/// %out = arith.constant dense<false> : vector<3xi1> +/// ``` +/// +/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result +/// is false at ALL indices, we fold. If the constant was 1, then +/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold, +/// conservatively preferring the 'compact' vector.step representation. +/// +/// Note: this folder only works for the case where the constant (`%cst` above) +/// is the second operand of the comparison. The arith.cmpi canonicalizer will +/// ensure that constants are always second (on the right). +struct StepCompareFolder : public OpRewritePattern<StepOp> { + using Base::Base; + + LogicalResult matchAndRewrite(StepOp stepOp, + PatternRewriter &rewriter) const override { + const int64_t stepSize = stepOp.getResult().getType().getNumElements(); + + for (OpOperand &use : stepOp.getResult().getUses()) { + auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner()); + if (!cmpiOp) + continue; + + // arith.cmpi canonicalizer makes constants final operands. + const unsigned stepOperandNumber = use.getOperandNumber(); + if (stepOperandNumber != 0) + continue; + + // Check that operand 1 is a constant. + unsigned constOperandNumber = 1; + Value otherOperand = cmpiOp.getOperand(constOperandNumber); + std::optional<int64_t> maybeConstValue = + getConstantIntValue(otherOperand); + if (!maybeConstValue.has_value()) + continue; + + int64_t constValue = maybeConstValue.value(); + arith::CmpIPredicate pred = cmpiOp.getPredicate(); + + auto maybeSplat = [&]() -> std::optional<bool> { + // Handle ult (unsigned less than) and uge (unsigned greater than or equal).
+ if ((pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::uge) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ult; + + // Handle ule and ugt. + if ((pred == arith::CmpIPredicate::ule || + pred == arith::CmpIPredicate::ugt) && + stepSize - 1 <= constValue) { + return pred == arith::CmpIPredicate::ule; + } + + // Handle eq and ne. + if ((pred == arith::CmpIPredicate::eq || + pred == arith::CmpIPredicate::ne) && + stepSize <= constValue) + return pred == arith::CmpIPredicate::ne; + + return std::nullopt; + }(); + + if (!maybeSplat.has_value()) + continue; + + rewriter.setInsertionPointAfter(cmpiOp); + + auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType()); + if (!type) + continue; + + auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value()); + Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(), + type, boolAttr); + + rewriter.replaceOp(cmpiOp, splat); + return success(); + } + + return failure(); + } +}; +} // namespace + +void StepOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<StepCompareFolder>(context); +} + //===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 14639c5..fbae098 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -465,26 +465,33 @@ struct UnrollElementwisePattern : public RewritePattern { auto targetShape = getTargetShape(options, op); if (!targetShape) return failure(); + int64_t targetShapeRank = targetShape->size(); auto dstVecType = cast<VectorType>(op->getResult(0).getType()); SmallVector<int64_t> originalSize = *cast<VectorUnrollOpInterface>(op).getShapeForUnroll(); - // Bail-out if rank(source) != rank(target). The main limitation here is the - // fact that `ExtractStridedSlice` requires the rank for the input and - // output to match. If needed, we can relax this later. - if (originalSize.size() != targetShape->size()) - return rewriter.notifyMatchFailure( - op, "expected input vector rank to match target shape rank"); + int64_t originalShapeRank = originalSize.size(); + Location loc = op->getLoc(); + + // Handle rank mismatch by adding leading unit dimensions to targetShape + SmallVector<int64_t> adjustedTargetShape(originalShapeRank); + int64_t rankDiff = originalShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), + adjustedTargetShape.begin() + rankDiff, 1); + std::copy(targetShape->begin(), targetShape->end(), + adjustedTargetShape.begin() + rankDiff); + + int64_t adjustedTargetShapeRank = adjustedTargetShape.size(); // Prepare the result vector. Value result = arith::ConstantOp::create(rewriter, loc, dstVecType, rewriter.getZeroAttr(dstVecType)); - SmallVector<int64_t> strides(targetShape->size(), 1); - VectorType newVecType = + SmallVector<int64_t> strides(adjustedTargetShapeRank, 1); + VectorType unrolledVecType = VectorType::get(*targetShape, dstVecType.getElementType()); // Create the unrolled computation. 
for (SmallVector<int64_t> offsets : - StaticTileOffsetRange(originalSize, *targetShape)) { + StaticTileOffsetRange(originalSize, adjustedTargetShape)) { SmallVector<Value> extractOperands; for (OpOperand &operand : op->getOpOperands()) { auto vecType = dyn_cast<VectorType>(operand.get().getType()); @@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern { extractOperands.push_back(operand.get()); continue; } - extractOperands.push_back( - rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, operand.get(), offsets, *targetShape, strides)); + Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, operand.get(), offsets, adjustedTargetShape, strides); + + // Reshape to remove leading unit dims if needed + if (adjustedTargetShapeRank > targetShapeRank) { + extracted = rewriter.createOrFold<vector::ShapeCastOp>( + loc, VectorType::get(*targetShape, vecType.getElementType()), + extracted); + } + extractOperands.push_back(extracted); } + Operation *newOp = cloneOpWithOperandsAndTypes( - rewriter, loc, op, extractOperands, newVecType); + rewriter, loc, op, extractOperands, unrolledVecType); + + Value computeResult = newOp->getResult(0); + + // Use strides sized to targetShape for proper insertion + SmallVector<int64_t> insertStrides = + (adjustedTargetShapeRank > targetShapeRank) + ? SmallVector<int64_t>(targetShapeRank, 1) + : strides; + result = rewriter.createOrFold<vector::InsertStridedSliceOp>( - loc, newOp->getResult(0), result, offsets, strides); + loc, computeResult, result, offsets, insertStrides); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6..4d81918 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -378,8 +378,10 @@ struct SourceMgrDiagnosticHandlerImpl { } // Otherwise, try to load the source file. - std::string ignored; - unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored); + auto bufferOrErr = llvm::MemoryBuffer::getFile(filename); + if (!bufferOrErr) + return 0; + unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc()); filenameToBufId[filename] = id; return id; } diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index c883baa..3236b4f 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -27,6 +27,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Parser.h" #include <optional> @@ -828,6 +829,7 @@ LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc, llvm::SourceMgr tdSrcMgr; tdSrcMgr.AddNewSourceBuffer(std::move(*includeBuffer), SMLoc()); tdSrcMgr.setIncludeDirs(parserSrcMgr.getIncludeDirs()); + tdSrcMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); // This class provides a context argument for the llvm::SourceMgr diagnostic // handler. 
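The llvm::SourceMgr changes in the hunk above and in the tool diffs that follow all share one shape: point the source manager at the include directories, give it an explicit virtual filesystem to resolve includes against, then hand it the main buffer. A minimal C++ sketch of that shared setup; the helper name `setUpSourceMgr` is hypothetical, and the buffer and include directories stand in for whatever the caller already has:

```c++
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/VirtualFileSystem.h"
#include <memory>
#include <string>
#include <vector>

// Hypothetical helper mirroring the setup added across these diffs.
static void setUpSourceMgr(llvm::SourceMgr &sourceMgr,
                           std::unique_ptr<llvm::MemoryBuffer> buffer,
                           const std::vector<std::string> &includeDirs) {
  sourceMgr.setIncludeDirs(includeDirs);
  // Resolve includes through the real filesystem, matching the behavior
  // these tools had before the explicit VFS hook.
  sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem());
  sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
}
```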
diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 60b9567..1dbe7eca 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -31,6 +31,7 @@ #include "llvm/Support/FileSystem.h" #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include <optional> using namespace mlir; @@ -402,6 +403,7 @@ PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, llvm::append_range(includeDirs, extraDirs); sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); astContext.getDiagEngine().setHandlerFn([&](const ast::Diagnostic &diag) { diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 3080b78..2d817be 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/LSP/Logging.h" #include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" +#include "llvm/Support/VirtualFileSystem.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" #include <optional> @@ -448,6 +449,7 @@ void TableGenTextFile::initialize( return; } sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); // This class provides a context argument for the SourceMgr diagnostic diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 2d33888..d669a3b 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -76,6 +76,18 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { // ----- +func.func @broadcast_single_elem_vec1d_from_f32(%arg0: f32) -> vector<1xf32> { + %0 = vector.broadcast %arg0 : f32 to vector<1xf32> + return %0 : vector<1xf32> +} +// CHECK-LABEL: @broadcast_single_elem_vec1d_from_f32 +// CHECK-SAME: %[[A:.*]]: f32) +// CHECK: %[[T0:.*]] = llvm.insertelement %[[A]] +// CHECK-NOT: llvm.shufflevector +// CHECK: return %[[T0]] : vector<1xf32> + +// ----- + func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> { %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32> return %0 : vector<[2]xf32> diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir index 0b150e9..9c552d8 100644 --- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir +++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir @@ -14,19 +14,36 @@ gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64 // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1> // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) { - // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16> - // CHECK: %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16> - // CHECK: scf.yield %[[VAR8]] : f16 - // CHECK: } else { - // CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16> - // CHECK: %[[VAR7:.*]] = 
vector.extract %[[CST_0]][0] : f16 from vector<1xf16> + // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> f16 // CHECK: scf.yield %[[VAR7]] : f16 + // CHECK: } else { + // CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f16 + // CHECK: scf.yield %[[CST_0]] : f16 // CHECK: } %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> gpu.return } } + +// ----- +gpu.module @test { +// CHECK-LABEL: @source_materialize_single_elem_vec +// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: memref<1xf16> +gpu.func @source_materialize_single_elem_vec(%src: i64, %offset: vector<1xindex>, %dst: memref<1xf16>) { + %1 = arith.constant dense<1>: vector<1xi1> + %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> + : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16> + // CHECK: %[[VAR_IF:.*]] = scf.if + // CHECK: %[[VAR_RET:.*]] = vector.broadcast %[[VAR_IF]] : f16 to vector<1xf16> + // CHECK: %[[C0:.*]] = arith.constant 0 : index + // CHECK: vector.store %[[VAR_RET]], %[[ARG2]][%[[C0]]] : memref<1xf16>, vector<1xf16> + %c0 = arith.constant 0 : index + vector.store %3, %dst[%c0] : memref<1xf16>, vector<1xf16> + gpu.return +} +} + // ----- gpu.module @test { diff --git a/mlir/test/Dialect/LLVMIR/canonicalize.mlir b/mlir/test/Dialect/LLVMIR/canonicalize.mlir index 8accf6e..755e3a3 100644 --- a/mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ b/mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -235,6 +235,17 @@ llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { // ----- +// CHECK-LABEL: fold_shufflevector +// CHECK-SAME: %[[ARG1:[[:alnum:]]+]]: vector<1xf32>, %[[ARG2:[[:alnum:]]+]]: vector<1xf32> +llvm.func @fold_shufflevector(%v1 : vector<1xf32>, %v2 : vector<1xf32>) -> vector<1xf32> { + // CHECK-NOT: llvm.shufflevector + %c = llvm.shufflevector %v1, %v2 [0] : vector<1xf32> + // CHECK: llvm.return %[[ARG1]] + llvm.return %c : vector<1xf32> +} + +// ----- + // Check that LLVM constants participate in cross-dialect constant folding. The // resulting constant is created in the arith dialect because the last folded // operation belongs to it. 
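The `fold_shufflevector` test above exercises the new `ShuffleVectorOp::fold` through the canonicalizer; the broadcast lowering in ConvertVectorToLLVM.cpp reaches the same fold via `createOrFold` during rewriting. A short sketch of that call-site shape; `splatFirstLane` is a hypothetical helper, and the rewriter and values are assumed to come from a surrounding conversion pattern:

```c++
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Splat lane 0 of `v` across `width` lanes. With a size-1 all-zero mask,
// ShuffleVectorOp::fold returns `v` directly, so no llvm.shufflevector is
// materialized for single-element vectors; wider masks still create one.
static Value splatFirstLane(RewriterBase &rewriter, Location loc, Value v,
                            Value poison, int64_t width) {
  SmallVector<int32_t> zeroValues(width, 0);
  return rewriter.createOrFold<LLVM::ShuffleVectorOp>(loc, v, poison,
                                                      zeroValues);
}
```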
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir index 358bd33..242c04f 100644 --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -1035,6 +1035,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK: rocdl.s.wait.asynccnt 0 + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK: rocdl.s.wait.tensorcnt 0 + rocdl.s.wait.tensorcnt 0 + llvm.return +} + // ----- llvm.func @rocdl.readfirstlane(%src : f32) -> f32 { diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir new file mode 100644 index 0000000..023a0e5 --- /dev/null +++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir @@ -0,0 +1,311 @@ +// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s + +///===----------------------------------------------===// +/// Tests of `StepCompareFolder` +///===----------------------------------------------===// + + +///===------------------------------------===// +/// Tests of `ugt` (unsigned greater than) +///===------------------------------------===// + +// CHECK-LABEL: @ugt_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 3 > [0, 1, 2] => [true, true, true] => true for all indices => fold + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ugt_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ugt_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 > [0, 1, 2] => [true, true, false] => not same for all indices => don't fold + %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ugt_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 3 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ugt_constant_max_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_max_rhs() -> vector<3xi1> { + // The largest i64 possible: + %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + + +// ----- + +// CHECK-LABEL: @ugt_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ugt_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 2 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ugt, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// 
CHECK-LABEL: @negative_ugt_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] > 1 => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi ugt, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `uge` (unsigned greater than or equal) +///===------------------------------------===// + + +// CHECK-LABEL: @uge_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @uge_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 >= [0, 1, 2] => [true, true, true] => true for all indices => fold + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_uge_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_uge_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 >= [0, 1, 2] => [true, false, false] => not same for all indices => don't fold + %1 = arith.cmpi uge, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @uge_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @uge_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 3 => [false, false, false] => false for all indices => fold + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_uge_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_uge_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] >= 2 => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi uge, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + + +///===------------------------------------===// +/// Tests of `ult` (unsigned less than) +///===------------------------------------===// + + +// CHECK-LABEL: @ult_constant_2_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ult_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 2 < [0, 1, 2] => [false, false, false] => false for all indices => fold + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ult_constant_1_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ult_constant_1_lhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // 1 < [0, 1, 2] => [false, false, true] => not same for all indices => don't fold + %1 = arith.cmpi ult, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ult_constant_3_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> 
+func.func @ult_constant_3_rhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] < 3 => [true, true, true] => true for all indices => fold + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ult_constant_2_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ult_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + // [0, 1, 2] < 2 => [true, true, false] => not same for all indices => don't fold + %1 = arith.cmpi ult, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ule` (unsigned less than or equal) +///===------------------------------------===// + +// CHECK-LABEL: @ule_constant_3_lhs +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ule_constant_3_lhs() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ule_constant_2_lhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ule_constant_2_lhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @ule_constant_2_rhs +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ule_constant_2_rhs() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst : vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ule_constant_1_rhs +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ule_constant_1_rhs() -> vector<3xi1> { + %cst = arith.constant dense<1> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ule, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `eq` (equal) +///===------------------------------------===// + +// CHECK-LABEL: @eq_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @eq_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_eq_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_eq_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi eq, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +///===------------------------------------===// +/// Tests of `ne` (not equal) +///===------------------------------------===// + +// CHECK-LABEL: @ne_constant_3 +// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1> +// CHECK: return %[[CST]] : vector<3xi1> +func.func @ne_constant_3() -> vector<3xi1> { + %cst = arith.constant dense<3> : vector<3xindex> 
+ %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + +// ----- + +// CHECK-LABEL: @negative_ne_constant_2 +// CHECK: %[[CMP:.*]] = arith.cmpi +// CHECK: return %[[CMP]] +func.func @negative_ne_constant_2() -> vector<3xi1> { + %cst = arith.constant dense<2> : vector<3xindex> + %0 = vector.step : vector<3xindex> + %1 = arith.cmpi ne, %0, %cst: vector<3xindex> + return %1 : vector<3xi1> +} + diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 35db14e..e5a98b5 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -188,15 +188,38 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf // CHECK-LABEL: func @vector_fma // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> -// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern. -func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{ +func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{ %0 = vector.fma %a, %a, %a : vector<3x2x2xf32> return %0 : vector<3x2x2xf32> } -// CHECK-LABEL: func @negative_vector_fma_3d -// CHECK-NOT: vector.extract_strided_slice -// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32> -// CHECK: return +// CHECK-LABEL: func @vector_fma_3d +// CHECK-SAME: (%[[SRC:.*]]: vector<3x2x2xf32>) -> vector<3x2x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32> +// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_0:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_0:.*]] = vector.shape_cast %[[E_OUT_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA0:.*]] = vector.fma %[[S_LHS_0]], %[[S_RHS_0]], %[[S_OUT_0]] : vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[FMA0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_1:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_1:.*]] = vector.shape_cast %[[E_OUT_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA1:.*]] = vector.fma 
%[[S_LHS_1]], %[[S_RHS_1]], %[[S_OUT_1]] : vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[FMA1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: %[[E_LHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_2:.*]] = vector.shape_cast %[[E_LHS_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_2:.*]] = vector.shape_cast %[[E_RHS_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_OUT_2:.*]] = vector.extract_strided_slice %[[SRC]] {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_OUT_2:.*]] = vector.shape_cast %[[E_OUT_2]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[FMA2:.*]] = vector.fma %[[S_LHS_2]], %[[S_RHS_2]], %[[S_OUT_2]] : vector<2x2xf32> +// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[FMA2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<3x2x2xf32> +// CHECK: return %[[I2]] : vector<3x2x2xf32> func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> { %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32> @@ -440,3 +463,36 @@ func.func @vector_step() -> vector<32xindex> { // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex> // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex> // CHECK: return %[[INS3]] : vector<32xindex> + + +func.func @elementwise_3D_to_2D(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> { + %0 = arith.addf %v1, %v2 : vector<2x2x2xf32> + return %0 : vector<2x2x2xf32> +} +// CHECK-LABEL: func @elementwise_3D_to_2D +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x2x2xf32>, %[[ARG1:.*]]: vector<2x2x2xf32>) -> vector<2x2x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32> +// CHECK: %[[E_LHS_0:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_0:.*]] = vector.shape_cast %[[E_LHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_0:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_0:.*]] = vector.shape_cast %[[E_RHS_0]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[ADD0:.*]] = arith.addf %[[S_LHS_0]], %[[S_RHS_0]] : vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[ADD0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32> +// CHECK: %[[E_LHS_1:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_LHS_1:.*]] = vector.shape_cast %[[E_LHS_1]] : vector<1x2x2xf32> to vector<2x2xf32> +// CHECK: %[[E_RHS_1:.*]] = vector.extract_strided_slice %[[ARG1]] {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32> +// CHECK: %[[S_RHS_1:.*]] = vector.shape_cast %[[E_RHS_1]] : vector<1x2x2xf32> 
to vector<2x2xf32> +// CHECK: %[[ADD1:.*]] = arith.addf %[[S_LHS_1]], %[[S_RHS_1]] : vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[ADD1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<2x2x2xf32> +// CHECK: return %[[I1]] : vector<2x2x2xf32> + + +func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf32>) -> vector<2x2x2x2xf32> { + %0 = arith.addf %v1, %v2 : vector<2x2x2x2xf32> + return %0 : vector<2x2x2x2xf32> +} + +// CHECK-LABEL: func @elementwise_4D_to_2D +// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> +// CHECK-NOT: arith.addf +// CHECK: return diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 78e1e659..6cccfe4 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -567,3 +567,25 @@ llvm.func @clusterlaunchcontrol_query_cancel_get_first_cta_id_invalid_return_typ %res = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %try_cancel_response : i1 llvm.return } + +// ----- + +// Test that ensures invalid row/col layouts for matrices A and B are not accepted +llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32, %b0 : i32, %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> { + // expected-error@+1 {{Only m8n8k4 with f16 supports other layouts.}} + %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3] + {layoutA = #nvvm.mma_layout<col>, layoutB = #nvvm.mma_layout<col>, + multiplicandAPtxType = #nvvm.mma_type<s4>, multiplicandBPtxType = #nvvm.mma_type<s4>, + intOverflowBehavior=#nvvm.mma_int_overflow<satfinite>, + shape = #nvvm.shape<m = 16, n = 8, k = 32>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)> + llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)> +} + +// ----- + +// Test for range validation - invalid range where lower == upper but not at extremes +func.func @invalid_range_equal_bounds() { + // expected-error @below {{invalid range attribute: Lower == Upper, but they aren't min (0) or max (4294967295) value! 
This is an invalid constant range.}} + %0 = nvvm.read.ptx.sreg.warpsize range <i32, 32, 32> : i32 + return +} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 00a479d..594ae48 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -152,6 +152,10 @@ llvm.func @nvvm_special_regs() -> i32 { %74 = nvvm.read.ptx.sreg.lanemask.ge : i32 //CHECK: call i32 @llvm.nvvm.read.ptx.sreg.lanemask.gt %75 = nvvm.read.ptx.sreg.lanemask.gt : i32 + // CHECK: %76 = call range(i32 0, 0) i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %76 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 0> : i32 + // CHECK: %77 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %77 = nvvm.read.ptx.sreg.tid.x range <i32, 4294967295, 4294967295> : i32 llvm.return %1 : i32 } diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index fdd2c91..6536fac 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -276,6 +276,20 @@ llvm.func @rocdl.s.wait.expcnt() { llvm.return } +llvm.func @rocdl.s.wait.asynccnt() { + // CHECK-LABEL: rocdl.s.wait.asynccnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.asynccnt(i16 0) + rocdl.s.wait.asynccnt 0 + llvm.return +} + +llvm.func @rocdl.s.wait.tensorcnt() { + // CHECK-LABEL: rocdl.s.wait.tensorcnt + // CHECK-NEXT: call void @llvm.amdgcn.s.wait.tensorcnt(i16 0) + rocdl.s.wait.tensorcnt 0 + llvm.return +} + llvm.func @rocdl.setprio() { // CHECK: call void @llvm.amdgcn.s.setprio(i16 0) rocdl.s.setprio 0 diff --git a/mlir/test/mlir-tblgen/cpp-class-comments.td b/mlir/test/mlir-tblgen/cpp-class-comments.td index a896888..9dcf975 100644 --- a/mlir/test/mlir-tblgen/cpp-class-comments.td +++ b/mlir/test/mlir-tblgen/cpp-class-comments.td @@ -96,17 +96,14 @@ def EncodingTrait : AttrInterface<"EncodingTrait"> { }]; let methods = [ ]; -// ATTR-INTERFACE: namespace mlir -// ATTR-INTERFACE-NEXT: namespace a -// ATTR-INTERFACE-NEXT: namespace traits +// ATTR-INTERFACE: namespace mlir::a::traits { // ATTR-INTERFACE-NEXT: /// Common trait for all layouts. 
// ATTR-INTERFACE-NEXT: class EncodingTrait; } def SimpleEncodingTrait : AttrInterface<"SimpleEncodingTrait"> { let cppNamespace = "a::traits"; -// ATTR-INTERFACE: namespace a { -// ATTR-INTERFACE-NEXT: namespace traits { +// ATTR-INTERFACE: namespace a::traits { // ATTR-INTERFACE-NEXT: class SimpleEncodingTrait; } @@ -116,8 +113,7 @@ def SimpleOpInterface : OpInterface<"SimpleOpInterface"> { Simple Op Interface description }]; -// OP-INTERFACE: namespace a { -// OP-INTERFACE-NEXT: namespace traits { +// OP-INTERFACE: namespace a::traits { // OP-INTERFACE-NEXT: /// Simple Op Interface description // OP-INTERFACE-NEXT: class SimpleOpInterface; } diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp index f99dcdb..76122a0 100644 --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -19,6 +19,7 @@ #include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/VirtualFileSystem.h" #include <set> using namespace mlir; @@ -41,6 +42,7 @@ processBuffer(raw_ostream &os, std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, bool dumpODS, std::set<std::string> *includedFiles) { llvm::SourceMgr sourceMgr; sourceMgr.setIncludeDirs(includeDirs); + sourceMgr.setVirtualFileSystem(llvm::vfs::getRealFileSystem()); sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); // If we are dumping ODS information, also enable documentation to ensure the diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index d55ad482..11bf9ce 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -701,11 +702,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef underlyingToSymFnName = enumInfo.getUnderlyingToSymbolFnName(); auto enumerants = enumInfo.getAllCases(); - SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, cppNamespace); // Emit the enum class definition emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os); @@ -766,8 +763,7 @@ public: os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName); } - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; + ns.close(); // Generate a generic parser and printer for the enum. 
std::string qualName = @@ -790,13 +786,8 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); if (enumInfo.isBitEnum()) { emitSymToStrFnForBitEnum(enumDef, os); @@ -810,10 +801,6 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { if (enumInfo.genSpecializedAttr()) emitSpecializedAttrDef(enumDef, os); - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 730b5b2..ab8d534 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -342,11 +343,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { } void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); for (auto &method : interface.getMethods()) { os << "template<typename " << valueTemplate << ">\n"; emitCPPType(method.getReturnType(), os); @@ -442,18 +439,11 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { method.isStatic() ? 
&ctx : &nonStaticMethodFmt); os << "\n}\n"; } - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; - - os << "namespace detail {\n"; + auto cppNamespace = (interface.getCppNamespace() + "::detail").str(); + llvm::NamespaceEmitter ns(os, cppNamespace); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -504,10 +494,6 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - os << "}// namespace detail\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -533,10 +519,7 @@ static void emitInterfaceDeclMethods(const Interface &interface, } void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); // Emit a forward declaration of the interface class so that it becomes usable // in the signature of its methods. @@ -545,16 +528,10 @@ void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { StringRef interfaceName = interface.getName(); os << "class " << interfaceName << ";\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, interface.getCppNamespace()); StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); @@ -631,9 +608,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { } os << "};\n"; - - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } bool InterfaceGenerator::emitInterfaceDecls() { diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 3ead2f0..ca291b5 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -259,8 +259,8 @@ static void emitInterfaceDecl(const Availability &availability, std::string interfaceTraitsName = std::string(formatv("{0}Traits", interfaceName)); - StringRef cppNamespace = availability.getInterfaceClassNamespace(); - llvm::NamespaceEmitter nsEmitter(os, cppNamespace); + llvm::NamespaceEmitter nsEmitter(os, + availability.getInterfaceClassNamespace()); os << "class " << interfaceName << ";\n\n"; // Emit the traits struct containing the concept and model declarations. 
@@ -418,15 +418,9 @@ static void emitAvailabilityQueryForBitEnum(const Record &enumDef, static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); StringRef enumName = enumInfo.getEnumClassName(); - StringRef cppNamespace = enumInfo.getCppNamespace(); auto enumerants = enumInfo.getAllCases(); - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; - + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); llvm::StringSet<> handledClasses; // Place all availability specifications to their corresponding @@ -441,9 +435,6 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { enumName); handledClasses.insert(className); } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; } static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { @@ -459,31 +450,19 @@ static bool emitEnumDecls(const RecordKeeper &records, raw_ostream &os) { static void emitEnumDef(const Record &enumDef, raw_ostream &os) { EnumInfo enumInfo(enumDef); - StringRef cppNamespace = enumInfo.getCppNamespace(); - - llvm::SmallVector<StringRef, 2> namespaces; - llvm::SplitString(cppNamespace, namespaces, "::"); - - for (auto ns : namespaces) - os << "namespace " << ns << " {\n"; + llvm::NamespaceEmitter ns(os, enumInfo.getCppNamespace()); - if (enumInfo.isBitEnum()) { + if (enumInfo.isBitEnum()) emitAvailabilityQueryForBitEnum(enumDef, os); - } else { + else emitAvailabilityQueryForIntEnum(enumDef, os); - } - - for (auto ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; - os << "\n"; } static bool emitEnumDefs(const RecordKeeper &records, raw_ostream &os) { llvm::emitSourceFileHeader("SPIR-V Enum Availability Definitions", os, records); - auto defs = records.getAllDerivedDefinitions("EnumInfo"); - for (const auto *def : defs) + for (const Record *def : records.getAllDerivedDefinitions("EnumInfo")) emitEnumDef(*def, os); return false;
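The tablegen-backend hunks above all replace hand-rolled `SplitString`/reverse namespace loops with `llvm::NamespaceEmitter`, which is also why the cpp-class-comments.td test now expects the C++17 nested form `namespace mlir::a::traits {`. A minimal sketch of the RAII behavior; the function name and namespace string are illustrative:

```c++
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/CodeGenHelpers.h"

void emitForwardDecl(llvm::raw_ostream &os) {
  // Prints `namespace mlir::a::traits {` immediately, and the matching
  // `} // namespace mlir::a::traits` when `ns` is destroyed or when
  // ns.close() is called explicitly (as the EnumsGen changes do).
  llvm::NamespaceEmitter ns(os, "mlir::a::traits");
  os << "class EncodingTrait;\n";
}
```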