diff options
Diffstat (limited to 'mlir/lib')
48 files changed, 1035 insertions, 327 deletions
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e681..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. + dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 855c582..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -22,7 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -32,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. @@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 6ba5bfe4..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 5d13353..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -26,13 +26,12 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index fd40e7c..fa9e544 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -36,7 +36,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "shard-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace mlir { #define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index a425eff..1d1904f 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -31,10 +31,9 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-to-gpu" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTVECTORTOGPU @@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op, // by all operations. if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { if (!supportsMMaMatrixType(op, useNvGpu)) { - LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n"); + LDBG() << "cannot convert op: " << *op; return true; } return false; @@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } @@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, isTranspose ? rewriter.getUnitAttr() : UnitAttr()); valueMapping[mappingResult] = load; - LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n"); + LDBG() << "transfer read to: " << load; return success(); } @@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, std::optional<int64_t> stride = getStaticallyKnownRowStride(op.getShapedType()); if (!stride.has_value()) { - LLVM_DEBUG(DBGS() << "no stride\n"); + LDBG() << "no stride"; return rewriter.notifyMatchFailure(op, "no stride"); } auto it = valueMapping.find(op.getVector()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no mapping\n"); + LDBG() << "no mapping"; return rewriter.notifyMatchFailure(op, "no mapping"); } @@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr()); (void)store; - LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n"); + LDBG() << "transfer write to: " << store; - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); auto dense = dyn_cast<SplatElementsAttr>(op.getValue()); if (!dense) { - LLVM_DEBUG(DBGS() << "not a splat\n"); + LDBG() << "not a splat"; return rewriter.notifyMatchFailure(op, "not a splat"); } @@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { mlir::AffineMap map = op.getPermutationMap(); if (map.getNumResults() != 2) { - LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` " - "is not a 2d operand\n"); + LDBG() << "Failed because the result of `vector.transfer_read` " + "is not a 2d operand"; return failure(); } @@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) { auto exprN = dyn_cast<AffineDimExpr>(dN); if (!exprM || !exprN) { - LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim " - "expressions, then transpose cannot be determined.\n"); + LDBG() << "Failed because expressions are not affine dim " + "expressions, then transpose cannot be determined."; return failure(); } @@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo = nvgpu::getWarpMatrixInfo(op); if (failed(warpMatrixInfo)) { - LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n"); + LDBG() << "no warpMatrixInfo"; return rewriter.notifyMatchFailure(op, "no warpMatrixInfo"); } FailureOr<nvgpu::FragmentElementInfo> regInfo = nvgpu::getMmaSyncRegisterType(*warpMatrixInfo); if (failed(regInfo)) { - LLVM_DEBUG(DBGS() << "not mma sync reg info\n"); + LDBG() << "not mma sync reg info"; return rewriter.notifyMatchFailure(op, "not mma sync reg info"); } FailureOr<bool> transpose = isTransposed(op); if (failed(transpose)) { - LLVM_DEBUG(DBGS() << "failed to determine the transpose\n"); + LDBG() << "failed to determine the transpose"; return rewriter.notifyMatchFailure( op, "Op should likely not be converted to a nvgpu.ldmatrix call."); } @@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose); if (failed(params)) { - LLVM_DEBUG( - DBGS() - << "failed to convert vector.transfer_read to ldmatrix. " - << "Op should likely not be converted to a nvgpu.ldmatrix call.\n"); + LDBG() << "failed to convert vector.transfer_read to ldmatrix. " + << "Op should likely not be converted to a nvgpu.ldmatrix call."; return rewriter.notifyMatchFailure( op, "failed to convert vector.transfer_read to ldmatrix; this op " "likely should not be converted to a nvgpu.ldmatrix call."); @@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, FailureOr<AffineMap> offsets = nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params); if (failed(offsets)) { - LLVM_DEBUG(DBGS() << "no offsets\n"); + LDBG() << "no offsets"; return rewriter.notifyMatchFailure(op, "no offsets"); } @@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices); } - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, loop.getNumResults()))) rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n"); - LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n"); - LLVM_DEBUG(DBGS() << "erase: " << loop); + LDBG() << "newLoop now: " << newLoop; + LDBG() << "stripped scf.for: " << loop; + LDBG() << "erase: " << loop; rewriter.eraseOp(loop); return newLoop; @@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, for (const auto &operand : llvm::enumerate(op.getInitArgs())) { auto it = valueMapping.find(operand.value()); if (it == valueMapping.end()) { - LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n"); + LDBG() << "no value mapping for: " << operand.value(); continue; } argMapping.push_back(std::make_pair( @@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, loopBody.getArgument(mapping.second + newForOp.getNumInductionVars()); } - LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n"); + LDBG() << "scf.for to: " << newForOp; return success(); } @@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, } scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands); - LLVM_DEBUG(DBGS() << "erase: " << op << "\n"); + LDBG() << "erase: " << op; rewriter.eraseOp(op); return success(); } @@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter, auto globalRes = LogicalResult::success(); for (Operation *op : ops) { - LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n"); + LDBG() << "Process op: " << *op; // Apparently callers do not want to early exit on failure here. auto res = LogicalResult::success(); if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) { diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 8d7053c..22608a1 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -26,7 +26,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include <numeric> @@ -40,7 +40,6 @@ using llvm::divideFloorSigned; using llvm::mod; #define DEBUG_TYPE "affine-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc" @@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineMap *map, ValueRange dims, ValueRange syms) { + LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`"; AffineMap affineMinMap = minOp.getAffineMap(); - LLVM_DEBUG({ - DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n"; - }); - // Check the value is positive. for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) { // Compare each expression in the minimum against 0. diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index cffe310..52cd0ce 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -30,6 +30,7 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 935aa3c..b951df8 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -22,6 +22,8 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + #define DEBUG_TYPE "llvm-inliner" using namespace mlir; @@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { bool wouldBeCloned) const final { auto callOp = dyn_cast<LLVM::CallOp>(call); if (!callOp) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '" - << LLVM::CallOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: call is not an '" + << LLVM::CallOp::getOperationName() << "' op"; return false; } if (callOp.getNoInline()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n"); + LDBG() << "Cannot inline: call is marked no_inline"; return false; } auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable); if (!funcOp) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: callable is not an '" - << LLVM::LLVMFuncOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: callable is not an '" + << LLVM::LLVMFuncOp::getOperationName() << "' op"; return false; } if (funcOp.isNoInline()) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: function is marked no_inline\n"); + LDBG() << "Cannot inline: function is marked no_inline"; return false; } if (funcOp.isVarArg()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n"); + LDBG() << "Cannot inline: callable is variadic"; return false; } // TODO: Generate aliasing metadata from noalias result attributes. if (auto attrs = funcOp.getArgAttrs()) { for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) { if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": inalloca arguments not supported\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": inalloca arguments not supported"; return false; } } } // TODO: Handle exceptions. if (funcOp.getPersonality()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": unhandled function personality\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": unhandled function personality"; return false; } if (funcOp.getPassthrough()) { @@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { if (!stringAttr) return false; if (disallowedFunctionAttrs.contains(stringAttr)) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline " << funcOp.getSymName() - << ": found disallowed function attribute " - << stringAttr << "\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": found disallowed function attribute " << stringAttr; return true; } return false; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 7f9ba1b..bf66ed0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { } ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape(); + ArrayRef<int64_t> resultShape = padOp.getResultType().getShape(); int64_t padRank = sourceShape.size(); auto isStaticZero = [](OpFoldResult f) { @@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { allowedUnitDims.end()); llvm::SmallDenseSet<unsigned> unitDims; SmallVector<int64_t> newShape; + SmallVector<int64_t> newResultShape; SmallVector<OpFoldResult> newLowPad; SmallVector<OpFoldResult> newHighPad; - for (const auto [dim, size, low, high] : - zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, - padOp.getMixedLowPad(), padOp.getMixedHighPad())) { + for (const auto [dim, size, outSize, low, high] : zip_equal( + llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, + resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) { if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) && isStaticZero(high)) { unitDims.insert(dim); } else { newShape.push_back(size); + newResultShape.push_back(outSize); newLowPad.push_back(low); newHighPad.push_back(high); } @@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape, reassociationMap, options.rankReductionStrategy); - auto newPadOp = tensor::PadOp::create( - rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad, + auto newResultType = RankedTensorType::get( + newResultShape, padOp.getResultType().getElementType()); + auto newPadOp = rewriter.create<tensor::PadOp>( + padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 2c62cb6..2e62523 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, return paddingSizes; } +/// Extracts the constant multiplier from an affine expression of the form +/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an +/// AffineConstantExpr. Returns 1 if the expression is not a simple +/// multiplication of a dimension and a constant. +static int64_t extractConstantMultiplier(AffineExpr expr) { + if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) { + if (binOp.getKind() == AffineExprKind::Mul) { + auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS()); + auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS()); + if (lhsD && rhsC) { + return rhsC.getValue(); + } + auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS()); + auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS()); + if (lhsC && rhsD) { + return lhsC.getValue(); + } + } + } + return 1; +} + /// Compute the padded shape of the given value `v` of `RankedTensorType` given /// - `indexingSizes` a list of OpFoldResult. /// - an `indexingMap` that encodes how the shape of varies with increases @@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps. /// The implementaiton below iteratively combines increases from contributing /// dimensions using affine.apply operations. +/// The padded shape is computed by evaluating the maximum accessed index per +/// dimension, which may involve multiplying by constant factors derived from +/// the affine indexing expressions. Currently, only a limited set of projected +/// permutation indexing maps are supported, such as +/// - affine_map<(d0, d1, d2) -> (d0, d1)> +/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)> +/// - affine_map<(d0, d1) -> (d0 * 3 + d1)> /// In the future, more general interfaces can be devised to encode similar /// shape evolutions and map between an op and its operands. SmallVector<OpFoldResult> linalg::computePaddedShape( @@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( /*compressDims=*/true); // If we are padding to the next multiple of, compose with ceil(sz) * sz. + OpFoldResult paddingDimOfr; if (options.padToMultipleOf) { AffineExpr d0, s0; bindDims(rewriter.getContext(), d0); bindSymbols(rewriter.getContext(), s0); AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0); AffineMap composedMap = projectedMap.compose(ceilMap); - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, composedMap, {indexingSizes[paddingDim], paddingSize}, /*composeAffineMin=*/true); - terms.push_back(paddingDimOfr); } else { // Otherwise just set to paddingSize. - OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply( + paddingDimOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, projectedMap, paddingSize); - terms.push_back(paddingDimOfr); } + // Adjust for the maximum accessed index, which is (paddingSize - 1) * + // multiplier. + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0)); + AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier); + OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply( + rewriter, loc, subtractMap, {paddingDimOfr}); + terms.push_back(maxAccessIdx); + LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n"); } @@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape( AffineExpr sumExpr = dims.front(); for (unsigned i = 1; i < dims.size(); ++i) sumExpr = sumExpr + dims[i]; - OpFoldResult paddedDimOfr = - affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms); + // Add 1 to the maximum accessed index and get the final padded size. + OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 793eec7..ea68b1a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1946,12 +1946,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( rewriter, loc, vecCollapsedType, transposeOp->getResult(0)); - // writeVectorSizes had to match the shapecast shape for dynamic sizes, - // otherwise the validator complains that the mask size is invalid. - SmallVector<int64_t> writeVectorSizes( - unpackOp.getDestType().hasStaticShape() - ? vectorSizes - : shapeCastOp.getResultVectorType().getShape()); Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e73bdd3..9d5dfc1 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() { getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static); } +acc::LoopParMode +acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) { + if (hasSeq(deviceType)) + return LoopParMode::loop_seq; + if (hasAuto(deviceType)) + return LoopParMode::loop_auto; + if (hasIndependent(deviceType)) + return LoopParMode::loop_independent; + if (hasSeq()) + return LoopParMode::loop_seq; + if (hasAuto()) + return LoopParMode::loop_auto; + assert(hasIndependent() && + "loop must have default auto, seq, or independent"); + return LoopParMode::loop_independent; +} + void acc::LoopOp::addGangOperands( MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes, llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) { diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 759e58b..0262a1b 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, if (parser.parseOptionalArrowTypeList(result.types)) return failure(); + if (succeeded(parser.parseOptionalKeyword("no_inline"))) + result.addAttribute("no_inline", parser.getBuilder().getUnitAttr()); + // Introduce the body region and parse it. Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printOptionalArrowTypeList(getResultTypes()); - p << ' '; + if (getNoInline()) + p << "no_inline "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock()) + if (!op.getRegion().hasOneBlock() || op.getNoInline()) return failure(); replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index 9bee200..fcf1526 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations( // `!spirv.struct<` (id `,`)? // `(` // (spirv-type (`[` struct-member-decoration `]`)?)* -// `)>` +// `)` +// (`,` struct-decoration)? +// `>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { // TODO: This function is quite lengthy. Break it down into smaller chunks. @@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect, return Type(); } - if (failed(parser.parseRParen()) || failed(parser.parseGreater())) + if (failed(parser.parseRParen())) + return Type(); + + SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo; + + auto parseStructDecoration = [&]() { + std::optional<spirv::Decoration> decoration = + parseAndVerify<spirv::Decoration>(dialect, parser); + if (!decoration) + return failure(); + + // Parse decoration value if it exists. + if (succeeded(parser.parseOptionalEqual())) { + Attribute decorationValue; + if (failed(parser.parseAttribute(decorationValue))) + return failure(); + + structDecorationInfo.emplace_back(decoration.value(), decorationValue); + } else { + structDecorationInfo.emplace_back(decoration.value(), + UnitAttr::get(dialect.getContext())); + } + return success(); + }; + + while (succeeded(parser.parseOptionalComma())) + if (failed(parseStructDecoration())) + return Type(); + + if (failed(parser.parseGreater())) return Type(); if (!identifier.empty()) { if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, - memberDecorationInfo))) + memberDecorationInfo, + structDecorationInfo))) return Type(); return idStructTy; } - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo, + structDecorationInfo); } // spirv-type ::= array-type @@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) { }; llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, printMember); - os << ")>"; + os << ")"; + + SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations; + type.getStructDecorations(decorations); + if (!decorations.empty()) { + os << ", "; + auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); + } + }; + llvm::interleaveComma(decorations, os, eachFn); + } + + os << ">"; } static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 46739bc..ddb3426 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -835,12 +835,14 @@ void SampledImageType::getCapabilities( /// - for literal structs: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. /// /// Identified structures only have a mutable component consisting of: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. struct spirv::detail::StructTypeStorage : public TypeStorage { /// Construct a storage object for an identified struct type. A struct type /// associated with such storage must call StructType::trySetBody(...) later @@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage(StringRef identifier) : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), + numStructDecorations(0), structDecorationsInfo(nullptr), identifier(identifier) {} /// Construct a storage object for a literal struct type. A struct type @@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, - StructType::MemberDecorationInfo const *memberDecorationsInfo) + StructType::MemberDecorationInfo const *memberDecorationsInfo, + unsigned numStructDecorations, + StructType::StructDecorationInfo const *structDecorationsInfo) : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), + numStructDecorations(numStructDecorations), + structDecorationsInfo(structDecorationsInfo) {} /// A storage key is divided into 2 parts: /// - for identified structs: @@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - an ArrayRef<Type> for member types; /// - an ArrayRef<StructType::OffsetInfo> for member offset info; /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration + /// info; + /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration /// info. /// /// An identified struct type is uniqued only by the first part (field 0) /// of the key. /// - /// A literal struct type is uniqued only by the second part (fields 1, 2, and - /// 3) of the key. The identifier field (field 0) must be empty. + /// A literal struct type is uniqued only by the second part (fields 1, 2, 3 + /// and 4) of the key. The identifier field (field 0) must be empty. using KeyTy = std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, - ArrayRef<StructType::MemberDecorationInfo>>; + ArrayRef<StructType::MemberDecorationInfo>, + ArrayRef<StructType::StructDecorationInfo>>; /// For identified structs, return true if the given key contains the same /// identifier. @@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), - getMemberDecorationsInfo()); + getMemberDecorationsInfo(), getStructDecorationsInfo()); } /// If the given key contains a non-empty identifier, this method constructs @@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } - return new (allocator.allocate<StructTypeStorage>()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, - numMemberDecorations, memberDecorationList); + const StructType::StructDecorationInfo *structDecorationList = nullptr; + unsigned numStructDecorations = 0; + if (!std::get<4>(key).empty()) { + auto keyStructDecorations = std::get<4>(key); + numStructDecorations = keyStructDecorations.size(); + structDecorationList = allocator.copyInto(keyStructDecorations).data(); + } + + return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage( + keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, + memberDecorationList, numStructDecorations, structDecorationList); } ArrayRef<Type> getMemberTypes() const { @@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return {}; } + ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const { + if (structDecorationsInfo) + return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo, + numStructDecorations); + return {}; + } + StringRef getIdentifier() const { return identifier; } bool isIdentified() const { return !identifier.empty(); } @@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - If called for an identified struct whose body was set before (through a /// call to this method) but with different contents from the passed /// arguments. - LogicalResult mutate( - TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, - ArrayRef<StructType::OffsetInfo> structOffsetInfo, - ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { + LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, + ArrayRef<StructType::OffsetInfo> structOffsetInfo, + ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo, + ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) { if (!isIdentified()) return failure(); if (memberTypesAndIsBodySet.getInt() && (getMemberTypes() != structMemberTypes || getOffsetInfo() != structOffsetInfo || - getMemberDecorationsInfo() != structMemberDecorationInfo)) + getMemberDecorationsInfo() != structMemberDecorationInfo || + getStructDecorationsInfo() != structDecorationInfo)) return failure(); memberTypesAndIsBodySet.setInt(true); @@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { allocator.copyInto(structMemberDecorationInfo).data(); } + if (!structDecorationInfo.empty()) { + numStructDecorations = structDecorationInfo.size(); + structDecorationsInfo = allocator.copyInto(structDecorationInfo).data(); + } + return success(); } @@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + unsigned numStructDecorations; + StructType::StructDecorationInfo const *structDecorationsInfo; StringRef identifier; }; StructType StructType::get(ArrayRef<Type> memberTypes, ArrayRef<StructType::OffsetInfo> offsetInfo, - ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { + ArrayRef<StructType::MemberDecorationInfo> memberDecorations, + ArrayRef<StructType::StructDecorationInfo> structDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. - SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( + SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations( memberDecorations); - llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); + llvm::array_pod_sort(sortedMemberDecorations.begin(), + sortedMemberDecorations.end()); + SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations( + structDecorations); + llvm::array_pod_sort(sortedStructDecorations.begin(), + sortedStructDecorations.end()); + return Base::get(memberTypes.vec().front().getContext(), /*identifier=*/StringRef(), memberTypes, offsetInfo, - sortedDecorations); + sortedMemberDecorations, sortedStructDecorations); } StructType StructType::getIdentified(MLIRContext *context, @@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context, return Base::get(context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); } StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { StructType newStructType = Base::get( context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); // Set an empty body in case this is a identified struct. if (newStructType.isIdentified() && failed(newStructType.trySetBody( ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()))) + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()))) return StructType(); return newStructType; @@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const { bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasDecoration(spirv::Decoration decoration) const { + for (StructType::StructDecorationInfo info : + getImpl()->getStructDecorationsInfo()) + if (info.decoration == decoration) + return true; + + return false; +} + uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); return getImpl()->offsetInfo[index]; @@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations( } } +void StructType::getStructDecorations( + SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations) + const { + structDecorations.clear(); + auto implDecorations = getImpl()->getStructDecorationsInfo(); + structDecorations.append(implDecorations.begin(), implDecorations.end()); +} + LogicalResult StructType::trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo, - ArrayRef<MemberDecorationInfo> memberDecorations) { - return Base::mutate(memberTypes, offsetInfo, memberDecorations); + ArrayRef<MemberDecorationInfo> memberDecorations, + ArrayRef<StructDecorationInfo> structDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations, + structDecorations); } void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value( memberDecorationInfo.decoration); } +llvm::hash_code spirv::hash_value( + const StructType::StructDecorationInfo &structDecorationInfo) { + return llvm::hash_value(structDecorationInfo.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec019..8f4c4cc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + return std::nullopt; + } + if (auto complexType = dyn_cast<ComplexType>(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); + return nullptr; +} + +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not the 8-bit float type or emulation flag is set to false. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return type; + Type srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. @@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional<spirv::StorageClass> storageClass = {}) { type = cast<VectorType>(convertIndexElementType(type, options)); + type = cast<VectorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast<TensorType>(convertIndexElementType(type, options)); + type = cast<TensorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast<IndexType>(elementType)) { type = cast<MemRefType>(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + // Hnadle 8 bit float types. + type = cast<MemRefType>(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional<Type> { if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951..a53d0a7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + // Update min version requirement for capabilities after deducing them. + for (spirv::Capability cap : deducedCapabilities) { + if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + module.emitError("Capability '") + << spirv::stringifyCapability(cap) << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + return signalPassFailure(); + } + } + } + // TODO: verify that the deduced version is consistent with // SPIR-V ops' maximal version requirements. diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index e5a3b5d..08fccfa 100644 --- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -38,7 +38,6 @@ #include <utility> #define DEBUG_TYPE "shard-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::shard; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 88b0f36..9543fa1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { CheckCondition condition = CheckCondition::invalid; const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + if (failed(maybeProfDef) && failed(maybeExtDef)) + return success(); - if (!failed(maybeProfDef) && !failed(maybeExtDef) && - !maybeProfDef.value().size() && !maybeExtDef.value().size()) { + const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->empty()); + if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); os << "illegal: operation operand/result data types did not align with any " diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8789f55..86fbb76 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6316,6 +6316,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { return llvm::to_vector<4>(getResultVectorType().getShape()); } +void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. @@ -7198,6 +7203,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast<VectorType>(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + +//===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index cb8e566..dedc3b3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -28,7 +28,10 @@ using namespace mlir; using namespace mlir::vector; namespace { -/// Progressive lowering of BroadcastOp. + +/// Convert a vector.broadcast with a vector operand to a lower rank +/// vector.broadcast. vector.broadcast with a scalar operand is expected to be +/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: using OpRewritePattern::OpRewritePattern; @@ -40,20 +43,23 @@ public: VectorType srcType = dyn_cast<VectorType>(op.getSourceType()); Type eltType = dstType.getElementType(); - // Scalar to any vector can use splat. - if (!srcType) { - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource()); - return success(); - } + // A broadcast from a scalar is considered to be in the lowered form. + if (!srcType) + return rewriter.notifyMatchFailure( + op, "broadcast from scalar already in lowered form"); // Determine rank of source and destination. int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + // Here we are broadcasting to a rank-1 vector. Ensure that the source is a + // scalar. if (srcRank <= 1 && dstRank == 1) { - Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); + SmallVector<int64_t> fullRankPosition(srcRank, 0); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), + fullRankPosition); + assert(!isa<VectorType>(ext.getType()) && "expected scalar"); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4baeb11..2cf8f0b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = vector::SplatOp::create( + Value fill = vector::BroadcastOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); res = vector::MaskedLoadOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 72352d7..cbb9d4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -303,7 +303,7 @@ public: // Extract/insert on a lower ranked extract strided slice op. Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, dstType, zero); + Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 48d680c..c707f38 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -25,12 +25,10 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-transfer-opt" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") - using namespace mlir; /// Return the ancestor op in the region or nullptr if the region is not @@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) { /// transfer_write is dead if all reads that can be reached from the potentially /// dead transfer_write are dominated by the overwriting transfer_write. void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { - LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() - << "\n"); + LDBG() << "Candidate for dead store: " << *write.getOperation(); llvm::SmallVector<Operation *, 8> blockingAccesses; Operation *firstOverwriteCandidate = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase())); @@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { !isReachable(writeAncestor, accessAncestor)) continue; if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) { - LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " - << *accessAncestor << "\n"); + LDBG() << "Store may not be dead due to op: " << *accessAncestor; return; } } - LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() - << " overwritten by: " << *firstOverwriteCandidate << "\n"); + LDBG() << "Found dead store: " << *write.getOperation() + << " overwritten by: " << *firstOverwriteCandidate; opToErase.push_back(write.getOperation()); } @@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (read.hasOutOfBoundsDim()) return; - LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() - << "\n"); + LDBG() << "Candidate for Forwarding: " << *read.getOperation(); SmallVector<Operation *, 8> blockingWrites; vector::TransferWriteOp lastwrite = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase())); @@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) continue; if (!postDominators.postDominates(lastwrite, write)) { - LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " - << *write << "\n"); + LDBG() << "Fail to do write to read forwarding due to op: " << *write; return; } } - LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() - << " to: " << *read.getOperation() << "\n"); + LDBG() << "Forward value from " << *lastwrite.getOperation() + << " to: " << *read.getOperation(); read.replaceAllUsesWith(lastwrite.getVector()); opToErase.push_back(read.getOperation()); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 8de87fe..2269a40 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -939,7 +939,7 @@ public: Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, castDstType, zero); + Value res = BroadcastOp::create(rewriter, loc, castDstType, zero); SmallVector<int64_t> sliceShape = {castDstLastDim}; SmallVector<int64_t> strides = {1}; @@ -965,6 +965,45 @@ private: std::function<bool(BitCastOp)> controlFn; }; +static bool haveSameShapeAndScaling(Type t, Type u) { + auto tVec = dyn_cast<VectorType>(t); + auto uVec = dyn_cast<VectorType>(u); + if (!tVec) { + return !uVec; + } + if (!uVec) { + return false; + } + return tVec.getShape() == uVec.getShape() && + tVec.getScalableDims() == uVec.getScalableDims(); +} + +/// If `type` is shaped, clone it with `newElementType`. Otherwise, +/// return `newElementType`. +static Type cloneOrReplace(Type type, Type newElementType) { + if (auto shapedType = dyn_cast<ShapedType>(type)) { + return shapedType.clone(newElementType); + } + return newElementType; +} + +/// If `value` is the result of a splat or broadcast operation, return the input +/// of the splat/broadcast operation. +static Value getBroadcastLikeSource(Value value) { + + Operation *op = value.getDefiningOp(); + if (!op) + return {}; + + if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) + return broadcast.getSource(); + + if (auto splat = dyn_cast<vector::SplatOp>(op)) + return splat.getInput(); + + return {}; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// /// Example: @@ -988,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); - if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) + auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); + if (!resultType) return failure(); if (!OpTrait::hasElementwiseMappableTraits(op)) return rewriter.notifyMatchFailure( op, "Op doesn't have ElementwiseMappableTraits"); if (op->getNumOperands() == 0) return failure(); - if (op->getResults()[0].getType() != op->getOperand(0).getType()) - return rewriter.notifyMatchFailure(op, - "result and operand type mismatch"); if (isa<vector::FMAOp>(op)) { return rewriter.notifyMatchFailure( op, @@ -1005,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } - // Get the type of the lhs operand - auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); - if (!lhsBcastOrSplat || - !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) + Type resultElemType = resultType.getElementType(); + + // Get the type of the first non-constant operand + Value splatSource; + for (Value operand : op->getOperands()) { + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + if (definingOp->hasTrait<OpTrait::ConstantLike>()) + continue; + splatSource = getBroadcastLikeSource(operand); + break; + } + if (!splatSource) return failure(); - auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); + Type unbroadcastResultType = + cloneOrReplace(splatSource.getType(), resultElemType); - // Make sure that all operands are broadcast from identical types: + // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. - if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { - auto bcast = val.getDefiningOp<vector::BroadcastOp>(); - if (bcast) - return (bcast.getOperand().getType() == lhsBcastOrSplatType); - auto splat = val.getDefiningOp<vector::SplatOp>(); - if (splat) - return (splat.getOperand().getType() == lhsBcastOrSplatType); - return false; + if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (auto source = getBroadcastLikeSource(val)) + return haveSameShapeAndScaling(source.getType(), + splatSource.getType()); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); })) { - return failure(); + return rewriter.notifyMatchFailure( + op, + "not all operands are constants or broadcasts from the same type"); } // Collect the source values before broadcasting SmallVector<Value> srcValues; srcValues.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + SplatElementsAttr splatConst; + if (matchPattern(operand, m_Constant(&splatConst))) { + Attribute newConst; + Type elementType = getElementTypeOrSelf(operand.getType()); + Type newType = cloneOrReplace(unbroadcastResultType, elementType); + if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) { + newConst = splatConst.resizeSplat(newTypeShaped); + } else { + newConst = splatConst.getSplatValue<Attribute>(); + } + Operation *newConstOp = + operand.getDefiningOp()->getDialect()->materializeConstant( + rewriter, newConst, newType, operand.getLoc()); + srcValues.push_back(newConstOp->getResult(0)); + } else { + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + } } // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - lhsBcastOrSplatType, op->getAttrs()); + unbroadcastResultType, op->getAttrs()); // Replace the original Op with the elementwise Op - auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - op, vectorType, elementwiseOp->getResults()); + op, resultType, elementwiseOp->getResults()); return success(); } @@ -1239,15 +1302,17 @@ public: return rewriter.notifyMatchFailure( op, "only 1-element vectors are supported"); - Operation *splat = op.getValueToStore().getDefiningOp(); - if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat)) - return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast"); + Value toStore = op.getValueToStore(); + Value source = getBroadcastLikeSource(toStore); + if (!source) + return rewriter.notifyMatchFailure( + op, "value to store is not from a broadcast"); // Checking for single use so we can remove splat. + Operation *splat = toStore.getDefiningOp(); if (!splat->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); - Value source = splat->getOperand(0); Value base = op.getBase(); ValueRange indices = op.getIndices(); @@ -1297,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o); indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound); return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 704deea..33450f3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, return success(); } +static LogicalResult +isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, + int64_t chunkSize, + function_ref<InFlightDiagnostic()> emitError) { + + if (!valueTy) + return emitError() << "Expecting a vector type result."; + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + + // a valid shape for SIMT case + if (valueTy.getRank() == 1) { + if (valueTy.getNumElements() != chunkSize) + return emitError() << "value elements must match chunk size " << chunkSize + << " for SIMT code."; + return success(); + } + + llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) + return emitError() << "Mask should match value except the chunk size dim."; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() { //===----------------------------------------------------------------------===// LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (!tdescTy.isScattered()) + + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() { return success(); } +void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_LoadGatherOp //===----------------------------------------------------------------------===// @@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + auto srcTy = getSourceType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(srcTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, valueType, source, Value(), mask, IntegerAttr(), + l1_hint, l2_hint, l3_hint); } //===----------------------------------------------------------------------===// @@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc.\n"); + + if (!tdescTy && getRankOf(getDest()) > 1) + return emitOpError( + "Expecting the dest is a 1D memref or pointer (uint64_t)."); + if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() { if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + + auto destTy = getDestType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(destTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, + l2_hint, l3_hint); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index ec8fad4..c793b71 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructure source case (!tdesTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad29..de52fbd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" @@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, return failure(); }); if (failed(verify(op))) { - LLVM_DEBUG(llvm::dbgs() - << DEBUG_TYPE << ": '" << op->getName() - << "' failed to verify and will be printed in generic form\n"); + LDBG() << op->getName() + << "' failed to verify and will be printed in generic form"; printerFlags.printGenericOpForm(); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e9b5e92..310680b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -17,14 +17,32 @@ using namespace mlir; +static std::pair<int64_t, int64_t> +getLineAndColStart(const llvm::SourceMgr &sourceMgr) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + if (lastFileID == 1) + return {0, 0}; + + auto bufferID = sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + // Exclude same start. + if (main->getBufferStart() < last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + return sourceMgr.getLineAndColumn( + llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID); + } + return {0, 0}; +} + LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(*sourceBuf, block, config); @@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, const auto *sourceBuf = sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(*sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(sourceMgr, block, config); diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index af22a7f..9ea5c683 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRROCDLToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation + MLIRXeVMToLLVMIRTranslation ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index f030fa7..86c731a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(SPIRV) add_subdirectory(VCIX) +add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index ff34a08..0f675a0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands, return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); } -static LogicalResult -convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, - ArrayAttr resAttrsArray, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - if (argAttrsArray) { - for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr); - !argAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, argAttrs); - if (failed(attrBuilder)) - return failure(); - call->addParamAttrs(argIdx, *attrBuilder); - } - } - } - - if (resAttrsArray && resAttrsArray.size() > 0) { - if (resAttrsArray.size() != 1) - return mlir::emitError(loc, "llvm.func cannot have multiple results"); - if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); - !resAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, resAttrs); - if (failed(attrBuilder)) - return failure(); - call->addRetAttrs(*attrBuilder); - } - } - return success(); -} - -static LogicalResult -convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - return convertParameterAndResultAttrs( - callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call, - moduleTranslation); -} - /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, @@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), moduleTranslation)); - if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(), - op.getResAttrsAttr(), inst, - moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst))) return failure(); if (op.getNumResults() == 1) @@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getInlineHintAttr()) call->addFnAttr(llvm::Attribute::InlineHint); - if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call))) return failure(); if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { @@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); - if (failed( - convertParameterAndResultAttrs(invOp, result, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result))) return failure(); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp index 1c9e226..55e73e8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Target/LLVMIR/ModuleImport.h" +#include "llvm/IR/ConstantRange.h" using namespace mlir; using namespace mlir::NVVM; diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt new file mode 100644 index 0000000..6308d7e --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + XeVMToLLVMIRTranslation.cpp +) + +add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation + XeVMToLLVMIRTranslation.cpp + + DEPENDS + MLIRXeVMConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRLLVMDialect + MLIRXeVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp new file mode 100644 index 0000000..73b166d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -0,0 +1,103 @@ +//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR XeVM dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the XeVM dialect to LLVM IR. +class XeVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(instructions.front(), + cacheControlsArray.getValue()); + } + auto func = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!func) + return failure(); + + return success(); + } + +private: + static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst, + ArrayRef<Attribute> attrs) { + SmallVector<llvm::Metadata *> decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue(); + std::array<llvm::Metadata *, 4> metadata; + llvm::transform( + valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + i32Ty, cast<IntegerAttr>(valueAttr).getValue())); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } +}; +} // namespace + +void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry ®istry) { + registry.insert<xevm::XeVMDialect>(); + registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) { + dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) { + DialectRegistry registry; + registerXeVMDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp index 580afdd..cb1f234 100644 --- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp @@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs))) + llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false, + /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands, + mlirAttrs))) return failure(); Type resultType = moduleImport.convertType(inst->getType()); @@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( ValueRange{mlirOperands}, FastmathFlagsAttr{}); moduleImport.setFastmathFlagsAttr(inst, op); - - ArrayAttr argsAttr, resAttr; - moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder); - op.setArgAttrsAttr(argsAttr); - op.setResAttrsAttr(resAttr); + moduleImport.convertArgAndResultAttrs(inst, op); // Update importer tracking of results. unsigned numRes = op.getNumResults(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 58e3c44..6325480 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" @@ -1063,6 +1064,18 @@ void ModuleImport::convertTargetTriple() { builder.getStringAttr(llvmModule->getTargetTriple().str())); } +void ModuleImport::convertModuleLevelAsm() { + llvm::StringRef asmStr = llvmModule->getModuleInlineAsm(); + llvm::SmallVector<mlir::Attribute> asmArrayAttr; + + for (llvm::StringRef line : llvm::split(asmStr, '\n')) + if (!line.empty()) + asmArrayAttr.push_back(builder.getStringAttr(line)); + + mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(), + builder.getArrayAttr(asmArrayAttr)); +} + LogicalResult ModuleImport::convertFunctions() { for (llvm::Function &func : llvmModule->functions()) if (failed(processFunction(&func))) @@ -2267,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // call. if (!isIncompatibleCall) - convertParameterAttributes(callInst, callOp, builder); + convertArgAndResultAttrs(callInst, callOp); return callOp.getOperation(); }(); @@ -2364,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // invoke. if (!isIncompatibleInvoke) - convertParameterAttributes(invokeInst, invokeOp, builder); + convertArgAndResultAttrs(invokeInst, invokeOp); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); @@ -2730,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, } DictionaryAttr -ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder) { +ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) { SmallVector<NamedAttribute> paramAttrs; for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { - auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind); + auto llvmAttr = llvmAttrSet.getAttribute(llvmKind); // Skip attributes that are not attached. if (!llvmAttr.isValid()) continue; @@ -2769,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, return builder.getDictionaryAttr(paramAttrs); } -void ModuleImport::convertParameterAttributes(llvm::Function *func, - LLVMFuncOp funcOp, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs(llvm::Function *func, + LLVMFuncOp funcOp) { auto llvmAttrs = func->getAttributes(); for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) { llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i); - funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder)); + funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs)); } // Convert the result attributes and attach them wrapped in an ArrayAttribute // to the funcOp. @@ -2783,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, if (!llvmResAttr.hasAttributes()) return; funcOp.setResAttrsAttr( - builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); + builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)})); } -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - ArrayAttr &argsAttr, - ArrayAttr &resAttr, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs( + llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions) { + // Compute the set of immediate argument positions. + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + // Convert the argument attributes and filter out immediate arguments. llvm::AttributeList llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + // Skip immediate arguments. + if (immArgPositionsSet.contains(i)) + continue; llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); if (llvmArgAttrsSet.back().hasAttributes()) anyArgAttrs = true; @@ -2807,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, if (anyArgAttrs) { SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) - argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - argsAttr = getArrayAttr(argAttrs); + argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs)); + attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs)); } + // Convert the result attributes. llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); if (!llvmResAttr.hasAttributes()) return; - DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder); - resAttr = getArrayAttr({resAttrs}); -} - -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - CallOpInterface callOp, - OpBuilder &builder) { - ArrayAttr argsAttr, resAttr; - convertParameterAttributes(call, argsAttr, resAttr, builder); - callOp.setArgAttrsAttr(argsAttr); - callOp.setResAttrsAttr(resAttr); + DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr); + attrsOp.setResAttrsAttr(getArrayAttr({resAttrs})); } template <typename Op> @@ -2892,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { builder, loc, func->getName(), functionType, convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv); - convertParameterAttributes(func, funcOp, builder); + convertArgAndResultAttrs(func, funcOp); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func)) funcOp.setPersonalityAttr(personality); @@ -3199,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule( if (failed(moduleImport.convertIFuncs())) return {}; moduleImport.convertTargetTriple(); + moduleImport.convertModuleLevelAsm(); return module; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b997e55..b3a06e2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +LogicalResult ModuleTranslation::convertArgAndResultAttrs( + ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions) { + // Convert the argument attributes. + if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) { + unsigned argAttrIdx = 0; + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) { + if (argAttrIdx >= argAttrsArray.size()) + break; + // Skip immediate arguments (they have no entries in argAttrsArray). + if (immArgPositionsSet.contains(argIdx)) + continue; + // Skip empty argument attributes. + auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]); + if (argAttrs.empty()) + continue; + // Convert and add attributes to the call instruction. + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + // Convert the result attributes. + if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) { + if (!resAttrsArray.empty()) { + auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + } + + return success(); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(Location loc, DictionaryAttr paramAttrs) { @@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, llvmModule->setTargetTriple( llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue())); + if (auto asmAttr = m->getDiscardableAttr( + LLVM::LLVMDialect::getModuleLevelAsmAttrName())) { + auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr); + if (!asmArrayAttr) { + m->emitError("expected an array attribute for a module level asm"); + return nullptr; + } + + for (Attribute elt : asmArrayAttr) { + auto asmStrAttr = dyn_cast<StringAttr>(elt); + if (!asmStrAttr) { + m->emitError( + "expected a string attribute for each entry of a module level asm"); + return nullptr; + } + llvmModule->appendModuleInlineAsm(asmStrAttr.getValue()); + } + } + return llvmModule; } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index e5934bb..88931b5 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; } - // Block decoration does not affect spirv.struct type, but is still stored - // for verification. - // TODO: Update StructType to contain this information since - // it is needed for many validation rules. decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: @@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, - deferredStructIt->memberDecorationsInfo))) + deferredStructIt->memberDecorationsInfo, + deferredStructIt->structDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); @@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { } } + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; + if (decorations.count(operands[0])) { + NamedAttrList &allDecorations = decorations[operands[0]]; + for (NamedAttribute &decorationAttr : allDecorations) { + std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration( + llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true)); + assert(decoration.has_value()); + structDecorationsInfo.emplace_back(decoration.value(), + decorationAttr.getValue()); + } + } + uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); - typeMap[structID] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + typeMap[structID] = spirv::StructType::get( + memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) - deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, - memberTypes, offsetInfo, - memberDecorationsInfo}); + deferredStructTypesInfos.push_back( + {structTy, unresolvedMemberTypes, memberTypes, offsetInfo, + memberDecorationsInfo, structDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, - memberDecorationsInfo))) + memberDecorationsInfo, + structDecorationsInfo))) return failure(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 20482bd..db1cc3f 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -95,6 +95,7 @@ struct DeferredStructTypeInfo { SmallVector<Type, 4> memberTypes; SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; }; /// A struct that collects the info needed to materialize/emit a diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index a8a2b2e..737f296 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -318,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: case spirv::Decoration::Constant: + case spirv::Decoration::Block: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. if (isa<UnitAttr, DecorationAttr>(attr)) @@ -630,11 +631,16 @@ LogicalResult Serializer::prepareBasicType( operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + // TODO: Now struct decorations are supported this code may not be + // necessary. However, it is left to support backwards compatibility. + // Ideally, Block decorations should be inserted when converting to SPIR-V. if (isInterfaceStructPtrType(ptrType)) { - if (failed(emitDecoration(getTypeID(pointeeStruct), - spirv::Decoration::Block))) - return emitError(loc, "cannot decorate ") - << pointeeStruct << " with Block decoration"; + auto structType = cast<spirv::StructType>(ptrType.getPointeeType()); + if (!structType.hasDecoration(spirv::Decoration::Block)) + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; } return success(); @@ -704,6 +710,20 @@ LogicalResult Serializer::prepareBasicType( } } + SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations; + structType.getStructDecorations(structDecorations); + + for (spirv::StructType::StructDecorationInfo &structDecoration : + structDecorations) { + if (failed(processDecorationAttr(loc, resultID, + structDecoration.decoration, + structDecoration.decorationValue))) { + return emitError(loc, "cannot decorate struct ") + << structType << " with " + << stringifyDecoration(structDecoration.decoration); + } + } + typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) @@ -938,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } + } else if (isa<spirv::TensorArmType>(constType)) { + numberOfConstituents = shapedType.getNumElements(); + operands.reserve(numberOfConstituents + 2); + for (int i = 0; i < numberOfConstituents; ++i) { + uint32_t elementID = 0; + if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { + elementID = + elementType.isInteger(1) + ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i]) + : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]); + } + if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { + elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]); + } + if (!elementID) { + return 0; + } + operands.push_back(elementID); + } } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { |