Diffstat (limited to 'mlir/lib')
47 files changed, 1475 insertions, 386 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 85f0fd1d..9b15435 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1927,16 +1927,16 @@ struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
     else
       llvm_unreachable("unsupported row length");
 
-    const Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
-    const Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
+    Value vdst0 = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
+    Value vdst1 = LLVM::ExtractValueOp::create(rewriter, loc, res, {1});
 
-    const Value isEqual =
-        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, vdst0, v);
+    Value isEqual = LLVM::ICmpOp::create(rewriter, loc,
+                                         LLVM::ICmpPredicate::eq, vdst0, v);
 
     // Per `permlane(16|32)` semantics: if the first extracted element equals
     // 'v', the result is the second element; otherwise it is the first.
     Value vdstNew =
-        rewriter.create<LLVM::SelectOp>(loc, isEqual, vdst1, vdst0);
+        LLVM::SelectOp::create(rewriter, loc, isEqual, vdst1, vdst0);
 
     permuted.emplace_back(vdstNew);
   }
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 42099aa..12adfe1 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -93,11 +93,11 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
     Location loc = op.getLoc();
     Value exponentReal =
-        rewriter.create<arith::SIToFPOp>(loc, exponentFloatType, op.getRhs());
-    Value zeroImag = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(exponentFloatType));
-    Value exponent = rewriter.create<complex::CreateOp>(
-        loc, op.getLhs().getType(), exponentReal, zeroImag);
+        arith::SIToFPOp::create(rewriter, loc, exponentFloatType, op.getRhs());
+    Value zeroImag = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getZeroAttr(exponentFloatType));
+    Value exponent = complex::CreateOp::create(
+        rewriter, loc, op.getLhs().getType(), exponentReal, zeroImag);
 
     rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
                                                 exponent, op.getFastmathAttr());
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 5613e02..0fe7239 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -937,14 +937,14 @@ struct PowiOpConversion : public OpConversionPattern<complex::PowiOp> {
     auto elementType = cast<FloatType>(type.getElementType());
 
     Value floatExponent =
-        builder.create<arith::SIToFPOp>(elementType, adaptor.getRhs());
+        arith::SIToFPOp::create(builder, elementType, adaptor.getRhs());
     Value zero = arith::ConstantOp::create(
         builder, elementType, builder.getFloatAttr(elementType, 0.0));
     Value complexExponent =
         complex::CreateOp::create(builder, type, floatExponent, zero);
-    auto pow = builder.create<complex::PowOp>(
-        type, adaptor.getLhs(), complexExponent, op.getFastmathAttr());
+    auto pow = complex::PowOp::create(builder, type, adaptor.getLhs(),
+                                      complexExponent, op.getFastmathAttr());
     rewriter.replaceOp(op, pow.getResult());
     return success();
   }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d26..eb662a1 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
                                   /*isVarArg=*/true);
   LLVM::LLVMFuncOp printfDecl =
-      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+      getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+  printfDecl.setCConv(callingConvention);
 
   // Create the global op or find an existing one.
   LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   printfArgs.push_back(stringStart);
   printfArgs.append(argsRange.begin(), argsRange.end());
 
-  LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  call.setCConv(callingConvention);
   rewriter.eraseOp(gpuPrintfOp);
   return success();
 }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb4..ec74787 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 
 namespace mlir {
@@ -142,13 +143,23 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
 /// This pass will add a declaration of printf() to the GPUModule if needed
 /// and separate out the format strings into global constants. For some
 /// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// However, not all backends use the same calling convention and function
+/// naming.
+/// For example, the LLVM SPIRV backend requires the calling convention
+/// LLVM::cconv::CConv::SPIR_FUNC, and the function name needs to be
+/// mangled as "_Z6printfPU3AS2Kcz".
+/// The default callingConvention is LLVM::cconv::CConv::C and the default
+/// funcName is "printf", but they can be customized as needed.
 struct GPUPrintfOpToLLVMCallLowering
     : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
-  GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
-                                int addressSpace = 0)
+  GPUPrintfOpToLLVMCallLowering(
+      const LLVMTypeConverter &converter, int addressSpace = 0,
+      LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+      StringRef funcName = "printf")
       : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
-        addressSpace(addressSpace) {}
+        addressSpace(addressSpace), callingConvention(callingConvention),
+        funcName(funcName) {}
 
   LogicalResult
   matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +167,8 @@ struct GPUPrintfOpToLLVMCallLowering
 
 private:
   int addressSpace;
+  LLVM::cconv::CConv callingConvention;
+  StringRef funcName;
 };
 
 /// Lowering of gpu.printf to a vprintf standard library.
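
Most of the mechanical changes above (and in the files that follow) migrate the builder API: instead of going through the rewriter template, ops are built through the op class itself, with the builder as the first argument. A minimal sketch of the two spellings, assuming `rewriter`, `loc`, and two values `lhs` and `rhs` are in scope, as they would be inside a matchAndRewrite():

    // Old builder style: the rewriter is the entry point.
    Value cmpOld =
        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, lhs, rhs);

    // New builder style: the op class is the entry point and the rewriter is
    // passed first. Operands, attributes, and result types are unchanged.
    Value cmpNew =
        LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, lhs, rhs);

The same substitution applies uniformly to the `arith`, `complex`, `emitc`, `tensor`, and `bufferization` ops touched throughout this diff.
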
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index c2363a1..25f1e1b 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
                         gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                         gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                         gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
-                        gpu::ThreadIdOp>();
+                        gpu::ThreadIdOp, gpu::PrintfOp>();
 
     populateGpuToLLVMSPVConversionPatterns(converter, patterns);
     populateGpuMemorySpaceAttributeConversions(converter);
+    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
+                                                LLVM::cconv::CConv::SPIR_FUNC,
+                                                "_Z6printfPU3AS2Kcz");
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 852c50c..d64c4d6 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -500,19 +500,19 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
           op->getParentWithTrait<mlir::OpTrait::AutomaticAllocationScope>();
       assert(scope && "Expected op to be inside automatic allocation scope");
       rewriter.setInsertionPointToStart(&scope->getRegion(0).front());
-      auto one = rewriter.create<LLVM::ConstantOp>(
-          loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1));
+      auto one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
+                                          rewriter.getI32IntegerAttr(1));
       sinPtr =
-          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
       cosPtr =
-          rewriter.create<LLVM::AllocaOp>(loc, ptrType, computeType, one, 0);
+          LLVM::AllocaOp::create(rewriter, loc, ptrType, computeType, one, 0);
     }
 
     createSincosCall(rewriter, loc, sincosFunc, convertedInput, sinPtr, cosPtr,
                      op);
 
-    auto sinResult = rewriter.create<LLVM::LoadOp>(loc, computeType, sinPtr);
-    auto cosResult = rewriter.create<LLVM::LoadOp>(loc, computeType, cosPtr);
+    auto sinResult = LLVM::LoadOp::create(rewriter, loc, computeType, sinPtr);
+    auto cosResult = LLVM::LoadOp::create(rewriter, loc, computeType, cosPtr);
 
     rewriter.replaceOp(op, {maybeTrunc(sinResult, inputType, rewriter),
                             maybeTrunc(cosResult, inputType, rewriter)});
@@ -522,14 +522,15 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
 private:
   Value maybeExt(Value operand, PatternRewriter &rewriter) const {
     if (isa<Float16Type, BFloat16Type>(operand.getType()))
-      return rewriter.create<LLVM::FPExtOp>(
-          operand.getLoc(), Float32Type::get(rewriter.getContext()), operand);
+      return LLVM::FPExtOp::create(rewriter, operand.getLoc(),
+                                   Float32Type::get(rewriter.getContext()),
+                                   operand);
     return operand;
   }
 
   Value maybeTrunc(Value operand, Type type, PatternRewriter &rewriter) const {
     if (operand.getType() != type)
-      return rewriter.create<LLVM::FPTruncOp>(operand.getLoc(), type, operand);
+      return LLVM::FPTruncOp::create(rewriter, operand.getLoc(), type, operand);
     return operand;
   }
@@ -556,7 +557,7 @@ private:
     }
 
     SmallVector<Value> callOperands = {input, sinPtr, cosPtr};
-    rewriter.create<LLVM::CallOp>(loc, funcOp, callOperands);
+    LLVM::CallOp::create(rewriter, loc, funcOp, callOperands);
   }
 };
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 229e40e..7cce324 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -142,8 +142,8 @@ struct SincosOpLowering : public ConvertOpToLLVMPattern<math::SincosOp> {
     auto structType = LLVM::LLVMStructType::getLiteral(
         rewriter.getContext(), {llvmOperandType, llvmOperandType});
 
-    auto sincosOp = rewriter.create<LLVM::SincosOp>(
-        loc, structType, adaptor.getOperand(), attrs.getAttrs());
+    auto sincosOp = LLVM::SincosOp::create(
+        rewriter, loc, structType, adaptor.getOperand(), attrs.getAttrs());
 
     auto sinValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 0);
     auto cosValue = LLVM::ExtractValueOp::create(rewriter, loc, sincosOp, 1);
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 519d9c8..71e3f88 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -394,9 +394,9 @@ private:
     if (!convertedType)
       return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
 
-    emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
-        loc, emitc::LValueType::get(convertedType), noInit);
-    rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+    auto var = emitc::VariableOp::create(
+        rewriter, loc, emitc::LValueType::get(convertedType), noInit);
+    emitc::AssignOp::create(rewriter, loc, var.getResult(), init);
     loopVars.push_back(var);
   }
@@ -411,11 +411,11 @@ private:
     // Create a global boolean variable to store the loop condition state.
     Type i1Type = IntegerType::get(context, 1);
     auto globalCondition =
-        rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
-                                           emitc::OpaqueAttr::get(context, ""));
+        emitc::VariableOp::create(rewriter, loc, emitc::LValueType::get(i1Type),
+                                  emitc::OpaqueAttr::get(context, ""));
     Value conditionVal = globalCondition.getResult();
 
-    auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+    auto loweredDo = emitc::DoOp::create(rewriter, loc);
 
     // Convert region types to match the target dialect type system.
     if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
@@ -450,12 +450,12 @@ private:
     // Convert scf.condition to condition variable assignment.
     Value condition = rewriter.getRemappedValue(condOp.getCondition());
-    rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+    emitc::AssignOp::create(rewriter, loc, conditionVal, condition);
 
     // Wrap body region in conditional to preserve scf semantics. Only create
     // ifOp if after-region is non-empty.
     if (whileOp.getAfterBody()->getOperations().size() > 1) {
-      auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+      auto ifOp = emitc::IfOp::create(rewriter, loc, condition, false, false);
 
       // Prepare the after region (loop body) for merging.
       Block *afterBlock = &whileOp.getAfter().front();
@@ -480,8 +480,8 @@ private:
     Block *condBlock = rewriter.createBlock(&condRegion);
     rewriter.setInsertionPointToStart(condBlock);
 
-    auto exprOp = rewriter.create<emitc::ExpressionOp>(
-        loc, i1Type, conditionVal, /*do_not_inline=*/false);
+    auto exprOp = emitc::ExpressionOp::create(
+        rewriter, loc, i1Type, conditionVal, /*do_not_inline=*/false);
     Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
 
     // Set up the expression block to load the condition variable.
@@ -490,12 +490,12 @@ private:
     // Load the condition value and yield it as the expression result.
     Value cond =
-        rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
-    rewriter.create<emitc::YieldOp>(loc, cond);
+        emitc::LoadOp::create(rewriter, loc, i1Type, exprBlock->getArgument(0));
+    emitc::YieldOp::create(rewriter, loc, cond);
 
     // Yield the expression as the condition region result.
     rewriter.setInsertionPointToEnd(condBlock);
-    rewriter.create<emitc::YieldOp>(loc, exprOp);
+    emitc::YieldOp::create(rewriter, loc, exprOp);
 
     return success();
   }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 00df14b1..29afdc2 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -232,16 +232,16 @@ static Value createLinalgBodyCalculationForElementwiseOp(
       }
 
       intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-      zpAddValue = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+      zpAddValue = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
     } else {
       intermediateType = rewriter.getIntegerType(intermediateBitWidth);
       auto arg1 =
-          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
+          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
       auto arg2 =
-          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
+          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
       zpAddValue =
-          rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
+          arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
     }
 
     // The negation can be applied by doing:
@@ -1402,8 +1402,8 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
   auto elemType = inputType.getElementType();
   auto collapsedType = RankedTensorType::get({}, elemType);
   // Emit the collapse op
-  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
-                                                  reassociation);
+  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
+                                         reassociation);
 }
 
 static llvm::SmallVector<int8_t>
@@ -1443,7 +1443,7 @@ static void setupLinalgGenericOpInputAndIndexingMap(
     IntegerAttr intAttr = isShift
                               ? rewriter.getI8IntegerAttr(values.front())
                               : rewriter.getI32IntegerAttr(values.front());
-    constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
+    constant = arith::ConstantOp::create(rewriter, loc, intAttr);
   } else {
     auto elementType =
         isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
@@ -1511,14 +1511,14 @@ static Value getExtendZp(OpBuilder &builder, Type valueTy,
               .getResult(0);
     }
     if (zpTy.isUnsignedInteger()) {
-      return builder.create<arith::ExtUIOp>(loc, extendType, result);
+      return arith::ExtUIOp::create(builder, loc, extendType, result);
     } else {
-      return builder.create<arith::ExtSIOp>(loc, extendType, result);
+      return arith::ExtSIOp::create(builder, loc, extendType, result);
     }
   }
 } else {
-    return builder.create<arith::ConstantOp>(
-        loc, IntegerAttr::get(extendType, *maybeZp));
+    return arith::ConstantOp::create(builder, loc,
+                                     IntegerAttr::get(extendType, *maybeZp));
 }
 return result;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 316721b..60ae78b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -45,18 +45,15 @@ struct LoopUnroll : public affine::impl::AffineLoopUnrollBase<LoopUnroll> {
   const std::function<unsigned(AffineForOp)> getUnrollFactor;
 
   LoopUnroll() : getUnrollFactor(nullptr) {}
-  LoopUnroll(const LoopUnroll &other)
-
-      = default;
+  LoopUnroll(const LoopUnroll &other) = default;
   explicit LoopUnroll(
       std::optional<unsigned> unrollFactor = std::nullopt,
-      bool unrollUpToFactor = false, bool unrollFull = false,
+      bool unrollUpToFactor = false,
       const std::function<unsigned(AffineForOp)> &getUnrollFactor = nullptr)
       : getUnrollFactor(getUnrollFactor) {
     if (unrollFactor)
       this->unrollFactor = *unrollFactor;
     this->unrollUpToFactor = unrollUpToFactor;
-    this->unrollFull = unrollFull;
   }
 
   void runOnOperation() override;
@@ -85,11 +82,17 @@ static void gatherInnermostLoops(FunctionOpInterface f,
 }
 
 void LoopUnroll::runOnOperation() {
+  if (!(unrollFactor.getValue() > 0 || unrollFactor.getValue() == -1)) {
+    emitError(UnknownLoc::get(&getContext()),
+              "Invalid option: 'unroll-factor' should be greater than 0 or "
+              "equal to -1");
+    return signalPassFailure();
+  }
   FunctionOpInterface func = getOperation();
   if (func.isExternal())
     return;
 
-  if (unrollFull && unrollFullThreshold.hasValue()) {
+  if (unrollFactor.getValue() == -1 && unrollFullThreshold.hasValue()) {
     // Store short loops as we walk.
     SmallVector<AffineForOp, 4> loops;
@@ -130,7 +133,7 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
     return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
                               /*annotateFn=*/nullptr, cleanUpUnroll);
   // Unroll completely if full loop unroll was specified.
-  if (unrollFull)
+  if (unrollFactor.getValue() == -1)
     return loopUnrollFull(forOp);
   // Otherwise, unroll by the given unroll factor.
   if (unrollUpToFactor)
@@ -141,9 +144,9 @@ LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
 
 std::unique_ptr<InterfacePass<FunctionOpInterface>>
 mlir::affine::createLoopUnrollPass(
-    int unrollFactor, bool unrollUpToFactor, bool unrollFull,
+    int unrollFactor, bool unrollUpToFactor,
     const std::function<unsigned(AffineForOp)> &getUnrollFactor) {
   return std::make_unique<LoopUnroll>(
       unrollFactor == -1 ? std::nullopt
                          : std::optional<unsigned>(unrollFactor),
-      unrollUpToFactor, unrollFull, getUnrollFactor);
+      unrollUpToFactor, getUnrollFactor);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index a6159ee..f0ddb50 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -14,13 +14,6 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
 
-namespace mlir {
-namespace bufferization {
-#define GEN_PASS_DEF_TENSORCOPYINSERTION
-#include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
-} // namespace bufferization
-} // namespace mlir
-
 using namespace mlir;
 using namespace mlir::bufferization;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 2a8c330..f0de4db 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -320,6 +320,51 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float8E4M3FNType::get(ctx) << " and "
+           << mlir::Float8E5M2Type::get(ctx)
+           << " types are supported for conversions from f8x2 to f16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float8E8M0FNUType::get(ctx)
+           << " type is supported for conversions from f8x2 to bf16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from f6x2 to f16x2.";
+
+  return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f4x2 to f16x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2187,6 +2232,98 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+          })
+          .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+          })
+          .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+  bool hasRelu = curOp.getRelu();
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+            return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+                           : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *extendedI16 =
+      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+                         llvm::Type::getInt16Ty(builder.getContext()));
+
+  return {intId, {extendedI16}};
+}
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
index d4ff095..37a45d4 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt
@@ -18,4 +18,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms
   MLIRPass
   MLIRTransforms
   MLIRNVVMDialect
+  MLIROpenMPDialect
 )
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9a8a63e..794dda9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -437,13 +437,15 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
     for (auto [pos, dim] : llvm::enumerate(type.getShape())) {
       if (!ShapedType::isDynamic(dim))
         continue;
-      Value cst = rewriter.create<arith::ConstantIndexOp>(tensor.getLoc(), pos);
-      auto dimOp = rewriter.create<tensor::DimOp>(tensor.getLoc(), tensor, cst);
+      Value cst =
+          arith::ConstantIndexOp::create(rewriter, tensor.getLoc(), pos);
+      auto dimOp =
+          tensor::DimOp::create(rewriter, tensor.getLoc(), tensor, cst);
       preservedOps.insert(dimOp);
       dynamicDims.push_back(dimOp);
     }
-    auto allocation = rewriter.create<bufferization::AllocTensorOp>(
-        tensor.getLoc(), type, dynamicDims);
+    auto allocation = bufferization::AllocTensorOp::create(
+        rewriter, tensor.getLoc(), type, dynamicDims);
     // Set memory space if provided.
     if (getMemorySpaceAttr())
       allocation.setMemorySpaceAttr(getMemorySpaceAttr());
@@ -452,8 +454,8 @@ transform::PromoteTensorOp::apply(transform::TransformRewriter &rewriter,
     // Only insert a materialization (typically bufferizes to a copy) when the
     // value may be read from.
     if (needsMaterialization) {
-      auto copy = rewriter.create<bufferization::MaterializeInDestinationOp>(
-          tensor.getLoc(), tensor, allocated);
+      auto copy = bufferization::MaterializeInDestinationOp::create(
+          rewriter, tensor.getLoc(), tensor, allocated);
       preservedOps.insert(copy);
       promoted.push_back(copy.getResult());
     } else {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index 15eb51a..5e10ba3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -43,6 +44,33 @@ struct StructuredOpInterface
     auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
     auto one = arith::ConstantIndexOp::create(builder, loc, 1);
 
+    Value iterationDomainIsNonDegenerate;
+    for (auto [start, end] : llvm::zip(starts, ends)) {
+      auto startValue = getValueOrCreateConstantIndexOp(builder, loc, start);
+      auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
+
+      // Loop trip count > 0 iff start < end.
+      Value dimensionHasNonZeroTripCount = index::CmpOp::create(
+          builder, loc, index::IndexCmpPredicate::SLT, startValue, endValue);
+
+      if (!iterationDomainIsNonDegenerate) {
+        iterationDomainIsNonDegenerate = dimensionHasNonZeroTripCount;
+      } else {
+        // The iteration domain is non-degenerate iff all dimensions have a
+        // loop trip count > 0.
+        iterationDomainIsNonDegenerate =
+            arith::AndIOp::create(builder, loc, iterationDomainIsNonDegenerate,
+                                  dimensionHasNonZeroTripCount);
+      }
+    }
+
+    if (!iterationDomainIsNonDegenerate)
+      return;
+
+    auto ifOp = scf::IfOp::create(builder, loc, iterationDomainIsNonDegenerate,
+                                  /*withElseRegion=*/false);
+    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
     // Subtract one from the loop ends before composing with the indexing map
     transform(ends, ends.begin(), [&](OpFoldResult end) {
       auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end);
@@ -110,6 +138,7 @@ struct StructuredOpInterface
         builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg);
       }
     }
+    builder.setInsertionPointAfter(ifOp);
   }
 };
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 94947b7..c551fba 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1437,6 +1437,13 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
   atLeastOneReplacement |= replaceConstantUsesOf(
       builder, getLoc(), getStrides(), getConstifiedMixedStrides());
 
+  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
+  if (auto prev = getSource().getDefiningOp<CastOp>())
+    if (isa<MemRefType>(prev.getSource().getType())) {
+      getSourceMutable().assign(prev.getSource());
+      atLeastOneReplacement = true;
+    }
+
   return success(atLeastOneReplacement);
 }
@@ -1744,11 +1751,11 @@ OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getSource());
+  return getSource();
 }
 
 TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
-  return cast<TypedValue<PtrLikeTypeInterface>>(getDest());
+  return getDest();
 }
 
 bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 11400de..a15bf89 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -59,6 +59,17 @@ struct DimOpInterface
   }
 };
 
+struct ExpandShapeOpInterface
+    : public ValueBoundsOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                   memref::ExpandShapeOp> {
+  void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+                                       ValueBoundsConstraintSet &cstr) const {
+    auto expandOp = cast<memref::ExpandShapeOp>(op);
+    assert(value == expandOp.getResult() && "invalid value");
+    cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+  }
+};
+
 struct GetGlobalOpInterface
     : public ValueBoundsOpInterface::ExternalModel<GetGlobalOpInterface,
                                                    GetGlobalOp> {
@@ -123,6 +134,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
         memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
     memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
     memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+    memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
+        *ctx);
     memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
     memref::RankOp::attachInterface<memref::RankOpInterface>(*ctx);
     memref::SubViewOp::attachInterface<memref::SubViewOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index d35566a..bd02516 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -1033,91 +1033,6 @@ class ExtractStridedMetadataOpReinterpretCastFolder
   }
 };
 
-/// Replace `base, offset, sizes, strides =
-///              extract_strided_metadata(
-///                 cast(src) to dstTy)`
-/// With
-/// ```
-/// base, ... = extract_strided_metadata(src)
-/// offset = !dstTy.srcOffset.isDynamic()
-///            ? dstTy.srcOffset
-///            : extract_strided_metadata(src).offset
-/// sizes = for each srcSize in dstTy.srcSizes:
-///           !srcSize.isDynamic()
-///             ? srcSize
-//              : extract_strided_metadata(src).sizes[i]
-/// strides = for each srcStride in dstTy.srcStrides:
-///             !srcStrides.isDynamic()
-///               ? srcStrides
-///               : extract_strided_metadata(src).strides[i]
-/// ```
-///
-/// In other words, consume the `cast` and apply its effects
-/// on the offset, sizes, and strides or compute them directly from `src`.
-class ExtractStridedMetadataOpCastFolder
-    : public OpRewritePattern<memref::ExtractStridedMetadataOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult
-  matchAndRewrite(memref::ExtractStridedMetadataOp extractStridedMetadataOp,
-                  PatternRewriter &rewriter) const override {
-    Value source = extractStridedMetadataOp.getSource();
-    auto castOp = source.getDefiningOp<memref::CastOp>();
-    if (!castOp)
-      return failure();
-
-    Location loc = extractStridedMetadataOp.getLoc();
-    // Check if the source is suitable for extract_strided_metadata.
-    SmallVector<Type> inferredReturnTypes;
-    if (failed(extractStridedMetadataOp.inferReturnTypes(
-            rewriter.getContext(), loc, {castOp.getSource()},
-            /*attributes=*/{}, /*properties=*/nullptr, /*regions=*/{},
-            inferredReturnTypes)))
-      return rewriter.notifyMatchFailure(castOp,
-                                         "cast source's type is incompatible");
-
-    auto memrefType = cast<MemRefType>(source.getType());
-    unsigned rank = memrefType.getRank();
-    SmallVector<OpFoldResult> results;
-    results.resize_for_overwrite(rank * 2 + 2);
-
-    auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create(
-        rewriter, loc, castOp.getSource());
-
-    // Register the base_buffer.
-    results[0] = newExtractStridedMetadata.getBaseBuffer();
-
-    auto getConstantOrValue = [&rewriter](int64_t constant,
-                                          OpFoldResult ofr) -> OpFoldResult {
-      return ShapedType::isStatic(constant)
-                 ? OpFoldResult(rewriter.getIndexAttr(constant))
-                 : ofr;
-    };
-
-    auto [sourceStrides, sourceOffset] = memrefType.getStridesAndOffset();
-    assert(sourceStrides.size() == rank && "unexpected number of strides");
-
-    // Register the new offset.
-    results[1] =
-        getConstantOrValue(sourceOffset, newExtractStridedMetadata.getOffset());
-
-    const unsigned sizeStartIdx = 2;
-    const unsigned strideStartIdx = sizeStartIdx + rank;
-    ArrayRef<int64_t> sourceSizes = memrefType.getShape();
-
-    SmallVector<OpFoldResult> sizes = newExtractStridedMetadata.getSizes();
-    SmallVector<OpFoldResult> strides = newExtractStridedMetadata.getStrides();
-    for (unsigned i = 0; i < rank; ++i) {
-      results[sizeStartIdx + i] = getConstantOrValue(sourceSizes[i], sizes[i]);
-      results[strideStartIdx + i] =
-          getConstantOrValue(sourceStrides[i], strides[i]);
-    }
-    rewriter.replaceOp(extractStridedMetadataOp,
-                       getValueOrCreateConstantIndexOp(rewriter, loc, results));
-    return success();
-  }
-};
-
 /// Replace `base, offset, sizes, strides = extract_strided_metadata(
 ///           memory_space_cast(src) to dstTy)`
 /// with
@@ -1209,7 +1124,6 @@ void memref::populateExpandStridedMetadataPatterns(
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
                ExtractStridedMetadataOpSubviewFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
@@ -1226,7 +1140,6 @@ void memref::populateResolveExtractStridedMetadataPatterns(
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
                ExtractStridedMetadataOpReinterpretCastFolder,
-               ExtractStridedMetadataOpCastFolder,
                ExtractStridedMetadataOpMemorySpaceCastFolder,
                ExtractStridedMetadataOpAssumeAlignmentFolder,
                ExtractStridedMetadataOpExtractStridedMetadataFolder>(
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
new file mode 100644
index 0000000..f305068
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIROpenACCAnalysis
+  OpenACCSupport.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenACC
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIROpenACCDialect
+  MLIROpenACCUtils
+  MLIRSupport
+)
+
diff --git a/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
new file mode 100644
index 0000000..f6b4534
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Analysis/OpenACCSupport.cpp
@@ -0,0 +1,26 @@
+//===- OpenACCSupport.cpp - OpenACCSupport Implementation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the OpenACCSupport analysis interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtils.h"
+
+namespace mlir {
+namespace acc {
+
+std::string OpenACCSupport::getVariableName(Value v) {
+  if (impl)
+    return impl->getVariableName(v);
+  return acc::getVariableName(v);
+}
+
+} // namespace acc
+} // namespace mlir
diff --git a/mlir/lib/Dialect/OpenACC/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
index 7117520..e8a916e 100644
--- a/mlir/lib/Dialect/OpenACC/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/CMakeLists.txt
@@ -1,3 +1,4 @@
+add_subdirectory(Analysis)
 add_subdirectory(IR)
 add_subdirectory(Utils)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 5ca0100..ca46629 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -610,6 +610,20 @@ LogicalResult acc::FirstprivateOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
+// FirstprivateMapInitialOp
+//===----------------------------------------------------------------------===//
+LogicalResult acc::FirstprivateMapInitialOp::verify() {
+  if (getDataClause() != acc::DataClause::acc_firstprivate)
+    return emitError("data clause associated with firstprivate operation must "
+                     "match its intent");
+  if (failed(checkVarAndVarType(*this)))
+    return failure();
+  if (failed(checkNoModifier(*this)))
+    return failure();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::ReductionOp::verify() {
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
index 1223325..89adda82 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtils.cpp
@@ -9,6 +9,7 @@
 
 #include "mlir/Dialect/OpenACC/OpenACCUtils.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
@@ -78,3 +79,30 @@ mlir::acc::VariableTypeCategory mlir::acc::getTypeCategory(mlir::Value var) {
       pointerLikeTy.getElementType());
   return typeCategory;
 }
+
+std::string mlir::acc::getVariableName(mlir::Value v) {
+  Value current = v;
+
+  // Walk through view operations until a name is found or we can't go further.
+  while (Operation *definingOp = current.getDefiningOp()) {
+    // Check for the `acc.var_name` attribute.
+    if (auto varNameAttr =
+            definingOp->getAttrOfType<VarNameAttr>(getVarNameAttrName()))
+      return varNameAttr.getName().str();
+
+    // If it is a data entry operation, get the name via getVarName.
+    if (isa<ACC_DATA_ENTRY_OPS>(definingOp))
+      if (auto name = acc::getVarName(definingOp))
+        return name->str();
+
+    // If it's a view operation, continue to the source.
+    if (auto viewOp = dyn_cast<ViewLikeOpInterface>(definingOp)) {
+      current = viewOp.getViewSource();
+      continue;
+    }
+
+    break;
+  }
+
+  return "";
+}
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d34..f3c02da 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,3 +1,5 @@
+add_subdirectory(Transforms)
+
 add_mlir_dialect_library(MLIROpenMPDialect
   IR/OpenMPDialect.cpp
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index fd4cabbad..1b069c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -32,7 +32,6 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/ADT/bit.h"
-#include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Support/InterleavedRange.h"
 #include <cstddef>
 #include <iterator>
@@ -1737,10 +1736,10 @@ static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
 // Parser, printer and verifier for Target
 //===----------------------------------------------------------------------===//
 
-// Helper function to get bitwise AND of `value` and 'flag'
-static uint64_t mapTypeToBitFlag(uint64_t value,
-                                 llvm::omp::OpenMPOffloadMappingFlags flag) {
-  return value & llvm::to_underlying(flag);
+// Helper function to get the bitwise AND of `value` and `flag`, returned as a
+// boolean.
+static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
+  return (value & flag) == flag;
 }
 
 /// Parses a map_entries map type from a string format back into its numeric
 /// value.
 ///
 /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
 /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
-static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
-  llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
-      llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
-
+static ParseResult parseMapClause(OpAsmParser &parser,
+                                  ClauseMapFlagsAttr &mapType) {
+  ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
   // This simply verifies the correct keyword is read in, the
   // keyword itself is stored inside of the operation
   auto parseTypeAndMod = [&]() -> ParseResult {
@@ -1760,35 +1758,64 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
       return failure();
 
     if (mapTypeMod == "always")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+      mapTypeBits |= ClauseMapFlags::always;
 
     if (mapTypeMod == "implicit")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+      mapTypeBits |= ClauseMapFlags::implicit;
 
     if (mapTypeMod == "ompx_hold")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
+      mapTypeBits |= ClauseMapFlags::ompx_hold;
 
     if (mapTypeMod == "close")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
+      mapTypeBits |= ClauseMapFlags::close;
 
    if (mapTypeMod == "present")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
+      mapTypeBits |= ClauseMapFlags::present;
 
     if (mapTypeMod == "to")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+      mapTypeBits |= ClauseMapFlags::to;
 
     if (mapTypeMod == "from")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      mapTypeBits |= ClauseMapFlags::from;
 
     if (mapTypeMod == "tofrom")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
-                     llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
 
     if (mapTypeMod == "delete")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+      mapTypeBits |= ClauseMapFlags::del;
+
+    if (mapTypeMod == "storage")
+      mapTypeBits |= ClauseMapFlags::storage;
 
     if (mapTypeMod == "return_param")
-      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+      mapTypeBits |= ClauseMapFlags::return_param;
+
+    if (mapTypeMod == "private")
+      mapTypeBits |= ClauseMapFlags::priv;
+
+    if (mapTypeMod == "literal")
+      mapTypeBits |= ClauseMapFlags::literal;
+
+    if (mapTypeMod == "attach")
+      mapTypeBits |= ClauseMapFlags::attach;
+
+    if (mapTypeMod == "attach_always")
+      mapTypeBits |= ClauseMapFlags::attach_always;
+
+    if (mapTypeMod == "attach_none")
+      mapTypeBits |= ClauseMapFlags::attach_none;
+
+    if (mapTypeMod == "attach_auto")
+      mapTypeBits |= ClauseMapFlags::attach_auto;
+
+    if (mapTypeMod == "ref_ptr")
+      mapTypeBits |= ClauseMapFlags::ref_ptr;
+
+    if (mapTypeMod == "ref_ptee")
+      mapTypeBits |= ClauseMapFlags::ref_ptee;
+
+    if (mapTypeMod == "ref_ptr_ptee")
+      mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
 
     return success();
   };
@@ -1796,9 +1823,8 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
   if (parser.parseCommaSeparatedList(parseTypeAndMod))
     return failure();
 
-  mapType = parser.getBuilder().getIntegerAttr(
-      parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
-      llvm::to_underlying(mapTypeBits));
+  mapType =
+      parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
 
   return success();
 }
@@ -1806,60 +1832,62 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
 /// Prints a map_entries map type from its numeric value out into its string
 /// format.
 static void printMapClause(OpAsmPrinter &p, Operation *op,
-                           IntegerAttr mapType) {
-  uint64_t mapTypeBits = mapType.getUInt();
-
-  bool emitAllocRelease = true;
+                           ClauseMapFlagsAttr mapType) {
   llvm::SmallVector<std::string, 4> mapTypeStrs;
+  ClauseMapFlags mapFlags = mapType.getValue();
 
   // handling of always, close, present placed at the beginning of the string
   // to aid readability
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
     mapTypeStrs.push_back("always");
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
    mapTypeStrs.push_back("implicit");
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
     mapTypeStrs.push_back("ompx_hold");
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
     mapTypeStrs.push_back("close");
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
    mapTypeStrs.push_back("present");
 
   // special handling of to/from/tofrom/delete and release/alloc, release +
   // alloc are the absence of one of the other flags, whereas tofrom requires
   // both the to and from flag to be set.
-  bool to = mapTypeToBitFlag(mapTypeBits,
-                             llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
-  bool from = mapTypeToBitFlag(
-      mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
-  if (to && from) {
-    emitAllocRelease = false;
+  bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
+  bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
+
+  if (to && from)
     mapTypeStrs.push_back("tofrom");
-  } else if (from) {
-    emitAllocRelease = false;
+  else if (from)
     mapTypeStrs.push_back("from");
-  } else if (to) {
-    emitAllocRelease = false;
+  else if (to)
     mapTypeStrs.push_back("to");
-  }
-  if (mapTypeToBitFlag(mapTypeBits,
-                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
-    emitAllocRelease = false;
+
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
     mapTypeStrs.push_back("delete");
-  }
-  if (mapTypeToBitFlag(
-          mapTypeBits,
-          llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
-    emitAllocRelease = false;
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
     mapTypeStrs.push_back("return_param");
-  }
-  if (emitAllocRelease)
-    mapTypeStrs.push_back("exit_release_or_enter_alloc");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
+    mapTypeStrs.push_back("storage");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
+    mapTypeStrs.push_back("private");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
+    mapTypeStrs.push_back("literal");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
+    mapTypeStrs.push_back("attach");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
+    mapTypeStrs.push_back("attach_always");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none))
+    mapTypeStrs.push_back("attach_none");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
+    mapTypeStrs.push_back("attach_auto");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
+    mapTypeStrs.push_back("ref_ptr");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
+    mapTypeStrs.push_back("ref_ptee");
+  if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
+    mapTypeStrs.push_back("ref_ptr_ptee");
+  if (mapFlags == ClauseMapFlags::none)
+    mapTypeStrs.push_back("none");
 
   for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
     p << mapTypeStrs[i];
@@ -1963,21 +1991,15 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
       return emitError(op->getLoc(), "missing map operation");
 
     if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
-      uint64_t mapTypeBits = mapInfoOp.getMapType();
-
-      bool to = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
-      bool from = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
-      bool del = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
-
-      bool always = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
-      bool close = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
-      bool implicit = mapTypeToBitFlag(
-          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+      mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
+
+      bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
+      bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
+      bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
+
+      bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
+      bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
+      bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
 
       if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
         return emitError(op->getLoc(),
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
new file mode 100644
index 0000000..b9b8eda
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIROpenMPTransforms
+  OpenMPOffloadPrivatizationPrepare.cpp
+
+  DEPENDS
+  MLIROpenMPPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIROpenMPDialect
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 0000000..a9125ec
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,445 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+    : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+          PrepareForOMPOffloadPrivatizationPass> {
+
+  void runOnOperation() override {
+    ModuleOp mod = getOperation();
+
+    // In this pass, we make host-allocated privatized variables persist for
+    // deferred target tasks by copying them to the heap. Once the target task
+    // is done, this heap memory is freed. Since all of this happens on the
+    // host, we can skip device modules.
+    auto offloadModuleInterface =
+        dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+    if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+      return;
+
+    getOperation()->walk([&](omp::TargetOp targetOp) {
+      if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+        return;
+      IRRewriter rewriter(&getContext());
+      OperandRange privateVars = targetOp.getPrivateVars();
+      SmallVector<mlir::Value> newPrivVars;
+      Value fakeDependVar;
+      omp::TaskOp cleanupTaskOp;
+
+      newPrivVars.reserve(privateVars.size());
+      std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+      for (auto [privVarIdx, privVarSymPair] :
+           llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+        Value privVar = std::get<0>(privVarSymPair);
+        Attribute privSym = std::get<1>(privVarSymPair);
+
+        omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+        if (!privatizer.needsMap()) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+        bool isFirstPrivate = privatizer.getDataSharingType() ==
+                              omp::DataSharingClauseType::FirstPrivate;
+
+        Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+        auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+        if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+          newPrivVars.push_back(privVar);
+          continue;
+        }
+
+        // For deferred target tasks (!$omp target nowait), we need to keep
+        // a copy of the original (i.e. host) variable being privatized so
+        // that it is available when the target task is eventually executed.
+        // We do this by first allocating as much heap memory as is needed by
+        // the original variable. Then, we use the init and copy regions of the
+        // privatizer (an instance of omp::PrivateClauseOp) to set up the
+        // heap-allocated copy.
+        // After the target task is done, we need to use the dealloc region
+        // of the privatizer to clean up everything. We also need to free
+        // the heap memory we allocated. But due to the deferred nature
+        // of the target task, we cannot simply deallocate right after the
+        // omp.target operation, else we may end up freeing memory before
+        // its eventual use by the target task. So, we create a dummy
+        // dependence between the target task and a new omp.task. In the
+        // omp.task, we do all the cleanup. So, we end up with the following
+        // structure:
+        //
+        // omp.target map_entries(..) ... nowait depend(out: fakeDependVar) {
+        //   ...
+        //   omp.terminator
+        // }
+        // omp.task depend(in: fakeDependVar) {
+        //   /*cleanup_code*/
+        //   omp.terminator
+        // }
+        // fakeDependVar is the address of the first heap-allocated copy of the
+        // host variable being privatized.
+
+        bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+        // Allocate heap memory that corresponds to the type of memory
+        // pointed to by varPtr.
+        // For boxchars this won't be a pointer. But MapsForPrivatizedSymbols
+        // should have mapped the pointer to the boxchar, so use that as
+        // varPtr.
+        Value varPtr = mapInfoOp.getVarPtr();
+        Type varType = mapInfoOp.getVarType();
+        bool isPrivatizedByValue =
+            !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+        assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+        Value heapMem =
+            allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+        if (!heapMem)
+          targetOp.emitError(
+              "Unable to allocate heap memory when trying to move "
+              "a private variable out of the stack and into the "
+              "heap for use by a deferred target task");
+
+        if (needsCleanupTask && !fakeDependVar)
+          fakeDependVar = heapMem;
+
+        // The types of private vars should match before and after the
+        // transformation. In particular, if the type is a pointer,
+        // simply record the newly allocated malloc location as the
+        // new private variable. If, however, the type is not a pointer,
+        // then we need to load the value from the newly allocated
+        // location. We'll insert that load later, after we have updated
+        // the malloc'd location with the contents of the original
+        // variable.
+        if (!isPrivatizedByValue)
+          newPrivVars.push_back(heapMem);
+
+        // We now need to copy the original private variable into the newly
+        // allocated location in the heap.
+        // Find the earliest insertion point for the copy. This will be before
+        // the first of the omp::MapInfoOp instances that use varPtr.
+        // After the copy, these omp::MapInfoOp instances will refer to heapMem
+        // instead.
+        Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+        DenseSet<Operation *> users;
+        if (varPtrDefiningOp) {
+          users.insert(varPtrDefiningOp->user_begin(),
+                       varPtrDefiningOp->user_end());
+        } else {
+          auto blockArg = cast<BlockArgument>(varPtr);
+          users.insert(blockArg.user_begin(), blockArg.user_end());
+        }
+        auto usesVarPtr = [&users](Operation *op) -> bool {
+          return users.count(op);
+        };
+
+        SmallVector<Operation *> chainOfOps;
+        chainOfOps.push_back(mapInfoOp);
+        for (auto member : mapInfoOp.getMembers()) {
+          omp::MapInfoOp memberMap =
+              cast<omp::MapInfoOp>(member.getDefiningOp());
+          if (usesVarPtr(memberMap))
+            chainOfOps.push_back(memberMap);
+          if (memberMap.getVarPtrPtr()) {
+            Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+            if (defOp && usesVarPtr(defOp))
+              chainOfOps.push_back(defOp);
+          }
+        }
+
+        DominanceInfo dom;
+        llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+          return dom.dominates(l, r);
+        });
+
+        rewriter.setInsertionPoint(chainOfOps.front());
+
+        Operation *firstOp = chainOfOps.front();
+        Location loc = firstOp->getLoc();
+
+        // Create a llvm.func for 'region' that is marked always_inline and
+        // call it.
+ auto createAlwaysInlineFuncAndCallIt = + [&](Region ®ion, llvm::StringRef funcName, + llvm::ArrayRef<Value> args, bool returnsValue) -> Value { + assert(!region.empty() && "region cannot be empty"); + LLVM::LLVMFuncOp func = createFuncOpForRegion( + loc, mod, region, funcName, rewriter, returnsValue); + auto call = LLVM::CallOp::create(rewriter, loc, func, args); + return call.getResult(); + }; + + Value moldArg, newArg; + if (isPrivatizedByValue) { + moldArg = LLVM::LoadOp::create(rewriter, loc, varType, varPtr); + newArg = LLVM::LoadOp::create(rewriter, loc, varType, heapMem); + } else { + moldArg = varPtr; + newArg = heapMem; + } + + Value initializedVal; + if (!privatizer.getInitRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getInitRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(), + {moldArg, newArg}, /*returnsValue=*/true); + else + initializedVal = newArg; + + if (isFirstPrivate && !privatizer.getCopyRegion().empty()) + initializedVal = createAlwaysInlineFuncAndCallIt( + privatizer.getCopyRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(), + {moldArg, initializedVal}, /*returnsValue=*/true); + + if (isPrivatizedByValue) + (void)LLVM::StoreOp::create(rewriter, loc, initializedVal, heapMem); + + // clone origOp, replace all uses of varPtr with heapMem and + // erase origOp. + auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * { + Operation *clonedOp = rewriter.clone(*origOp); + rewriter.replaceAllOpUsesWith(origOp, clonedOp); + rewriter.modifyOpInPlace(clonedOp, [&]() { + clonedOp->replaceUsesOfWith(varPtr, heapMem); + }); + rewriter.eraseOp(origOp); + return clonedOp; + }; + + // Now that we have set up the heap-allocated copy of the private + // variable, rewrite all the uses of the original variable with + // the heap-allocated variable. + rewriter.setInsertionPoint(targetOp); + mapInfoOp = cast<omp::MapInfoOp>(cloneModifyAndErase(mapInfoOp)); + rewriter.setInsertionPoint(mapInfoOp); + + // Fix any members that may use varPtr to now use heapMem + for (auto member : mapInfoOp.getMembers()) { + auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp()); + if (!usesVarPtr(memberMapInfoOp)) + continue; + memberMapInfoOp = + cast<omp::MapInfoOp>(cloneModifyAndErase(memberMapInfoOp)); + rewriter.setInsertionPoint(memberMapInfoOp); + + if (memberMapInfoOp.getVarPtrPtr()) { + Operation *varPtrPtrdefOp = + memberMapInfoOp.getVarPtrPtr().getDefiningOp(); + rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp)); + } + } + + // If the type of the private variable is not a pointer, + // which is typically the case with !fir.boxchar types, then + // we need to ensure that the new private variable is also + // not a pointer. Insert a load from heapMem right before + // targetOp. 
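// --------------------------------------------------------------------------
// Aside: the by-value round trip described in the comment above, reduced to
// plain C++ (a sketch; names invented): store the initialized value into the
// heap slot, then load it back so the private variable keeps its non-pointer
// type.

#include <cstdio>
#include <cstdlib>

int main() {
  double original = 3.5; // a variable privatized by value
  double *heapMem = static_cast<double *>(std::malloc(sizeof(double)));
  *heapMem = original;          // the StoreOp emitted after init/copy
  double newPrivVar = *heapMem; // the LoadOp inserted before omp.target
  std::printf("%f\n", newPrivVar);
  std::free(heapMem);
  return 0;
}
// --------------------------------------------------------------------------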
+ if (isPrivatizedByValue) { + rewriter.setInsertionPoint(targetOp); + auto newPrivVar = LLVM::LoadOp::create(rewriter, mapInfoOp.getLoc(), + varType, heapMem); + newPrivVars.push_back(newPrivVar); + } + + // Deallocate + if (needsCleanupTask) { + if (!cleanupTaskOp) { + assert(fakeDependVar && + "Need a valid value to set up a dependency"); + rewriter.setInsertionPointAfter(targetOp); + omp::TaskOperands taskOperands; + auto inDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependin); + taskOperands.dependKinds.push_back(inDepend); + taskOperands.dependVars.push_back(fakeDependVar); + cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands); + Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion()); + rewriter.setInsertionPointToEnd(taskBlock); + omp::TerminatorOp::create(rewriter, cleanupTaskOp.getLoc()); + } + rewriter.setInsertionPointToStart( + &*cleanupTaskOp.getRegion().getBlocks().begin()); + (void)createAlwaysInlineFuncAndCallIt( + privatizer.getDeallocRegion(), + llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc") + .str(), + {initializedVal}, /*returnsValue=*/false); + llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc = + LLVM::lookupOrCreateFreeFn(rewriter, mod); + assert(llvm::succeeded(freeFunc) && + "Could not find free in the module"); + (void)LLVM::CallOp::create(rewriter, loc, freeFunc.value(), + ValueRange{heapMem}); + } + } + assert(newPrivVars.size() == privateVars.size() && + "The number of private variables must match before and after " + "transformation"); + if (fakeDependVar) { + omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get( + rewriter.getContext(), omp::ClauseTaskDepend::taskdependout); + SmallVector<Attribute> newDependKinds; + if (!targetOp.getDependVars().empty()) { + std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds(); + assert(dependKinds && "bad depend clause in omp::TargetOp"); + llvm::copy(*dependKinds, std::back_inserter(newDependKinds)); + } + newDependKinds.push_back(outDepend); + ArrayAttr newDependKindsAttr = + ArrayAttr::get(rewriter.getContext(), newDependKinds); + targetOp.getDependVarsMutable().append(fakeDependVar); + targetOp.setDependKindsAttr(newDependKindsAttr); + } + rewriter.setInsertionPoint(targetOp); + targetOp.getPrivateVarsMutable().clear(); + targetOp.getPrivateVarsMutable().assign(newPrivVars); + }); + } + +private: + bool hasPrivateVars(omp::TargetOp targetOp) const { + return !targetOp.getPrivateVars().empty(); + } + + bool isTargetTaskDeferred(omp::TargetOp targetOp) const { + return targetOp.getNowait(); + } + + template <typename OpTy> + omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const { + SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym); + omp::PrivateClauseOp privatizer = + SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>( + op, privatizerName); + return privatizer; + } + + // Get the (compile-time constant) size of varType as per the + // given DataLayout dl. 
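// --------------------------------------------------------------------------
// Aside: llvm::alignTo in the helper below rounds the type size up to a
// multiple of its ABI alignment, so the malloc'd block is no smaller than an
// aligned stack slot would be. The same arithmetic as a plain-C++ sketch:

#include <cstdint>
#include <cstdio>

static std::int64_t alignTo(std::int64_t size, std::int64_t align) {
  return (size + align - 1) / align * align;
}

int main() {
  // e.g. a 5-byte type with 4-byte ABI alignment occupies 8 bytes
  std::printf("%lld\n", static_cast<long long>(alignTo(5, 4)));
  return 0;
}
// --------------------------------------------------------------------------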
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const { + llvm::TypeSize size = dl.getTypeSize(varType); + unsigned short alignment = dl.getTypeABIAlignment(varType); + return llvm::alignTo(size, alignment); + } + + LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const { + llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall = + LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type()); + assert(llvm::succeeded(mallocCall) && + "Could not find malloc in the module"); + return mallocCall.value(); + } + + Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType, + ModuleOp mod, IRRewriter &rewriter) const { + OpBuilder::InsertionGuard guard(rewriter); + Value varPtr = privVar; + Operation *definingOp = varPtr.getDefiningOp(); + BlockArgument blockArg; + if (!definingOp) { + blockArg = mlir::dyn_cast<BlockArgument>(varPtr); + rewriter.setInsertionPointToStart(blockArg.getParentBlock()); + } else { + rewriter.setInsertionPoint(definingOp); + } + Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc(); + LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter); + + assert(mod.getDataLayoutSpec() && + "MLIR module with no datalayout spec not handled yet"); + + const DataLayout &dl = DataLayout(mod); + std::int64_t distance = getSizeInBytes(dl, varType); + + Value sizeBytes = LLVM::ConstantOp::create( + rewriter, loc, mallocFn.getFunctionType().getParamType(0), distance); + + auto mallocCallOp = + LLVM::CallOp::create(rewriter, loc, mallocFn, ValueRange{sizeBytes}); + return mallocCallOp.getResult(); + } + + // Create a function for srcRegion and mark it always_inline. + // The big assumption here is that srcRegion is one of the init, copy or + // dealloc regions of an omp::PrivateClauseOp. Accordingly, the return type + // is assumed to either be the same as the types of the two arguments of the + // region (for init and copy regions) or void, as would be the case for + // dealloc regions. + LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod, + Region &srcRegion, + llvm::StringRef funcName, + IRRewriter &rewriter, + bool returnsValue = false) { + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end()); + Region clonedRegion; + IRMapping mapper; + srcRegion.cloneInto(&clonedRegion, mapper); + + SmallVector<Type> paramTypes; + llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes)); + Type resultType = returnsValue + ?
srcRegion.getArgument(0).getType() + : LLVM::LLVMVoidType::get(rewriter.getContext()); + LLVM::LLVMFunctionType funcType = + LLVM::LLVMFunctionType::get(resultType, paramTypes); + + LLVM::LLVMFuncOp func = + LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType); + func.setAlwaysInline(true); + rewriter.inlineRegionBefore(clonedRegion, func.getRegion(), + func.getRegion().end()); + for (auto &block : func.getRegion().getBlocks()) { + if (isa<omp::YieldOp>(block.getTerminator())) { + omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator()); + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(), + yieldOp.getOperands()); + } + } + return func; + } +}; +} // namespace diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a9da6c2..744a595 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -27,6 +27,7 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Casting.h" #include "llvm/Support/DebugLog.h" @@ -291,9 +292,102 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { } }; +// Pattern to eliminate ExecuteRegionOp results which forward external +// values from the region. In case there are multiple yield operations, +// all of them must have the same operands in order for the pattern to be +// applicable. +struct ExecuteRegionForwardingEliminator + : public OpRewritePattern<ExecuteRegionOp> { + using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(ExecuteRegionOp op, + PatternRewriter &rewriter) const override { + if (op.getNumResults() == 0) + return failure(); + + SmallVector<Operation *> yieldOps; + for (Block &block : op.getRegion()) { + if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) + yieldOps.push_back(yield.getOperation()); + } + + if (yieldOps.empty()) + return failure(); + + // Check if all yield operations have the same operands. + auto yieldOpsOperands = yieldOps[0]->getOperands(); + for (auto *yieldOp : yieldOps) { + if (yieldOp->getOperands() != yieldOpsOperands) + return failure(); + } + + SmallVector<Value> externalValues; + SmallVector<Value> internalValues; + SmallVector<Value> opResultsToReplaceWithExternalValues; + SmallVector<Value> opResultsToKeep; + for (auto [index, yieldedValue] : llvm::enumerate(yieldOpsOperands)) { + if (isValueFromInsideRegion(yieldedValue, op)) { + internalValues.push_back(yieldedValue); + opResultsToKeep.push_back(op.getResult(index)); + } else { + externalValues.push_back(yieldedValue); + opResultsToReplaceWithExternalValues.push_back(op.getResult(index)); + } + } + // No yielded external values - nothing to do. + if (externalValues.empty()) + return failure(); + + // There are yielded external values - create a new execute_region returning + // just the internal values. + SmallVector<Type> resultTypes; + for (Value value : internalValues) + resultTypes.push_back(value.getType()); + auto newOp = + ExecuteRegionOp::create(rewriter, op.getLoc(), TypeRange(resultTypes)); + newOp->setAttrs(op->getAttrs()); + + // Move old op's region to the new operation. + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); + + // Replace all yield operations with a new yield operation with updated + // results. 
scf.execute_region must have at least one yield operation. + for (auto *yieldOp : yieldOps) { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, + ValueRange(internalValues)); + } + + // Replace the old operation with the external values directly. + rewriter.replaceAllUsesWith(opResultsToReplaceWithExternalValues, + externalValues); + // Replace the old operation's remaining results with the new operation's + // results. + rewriter.replaceAllUsesWith(opResultsToKeep, newOp.getResults()); + rewriter.eraseOp(op); + return success(); + } + +private: + bool isValueFromInsideRegion(Value value, + ExecuteRegionOp executeRegionOp) const { + // Check if the value is defined within the execute_region + if (Operation *defOp = value.getDefiningOp()) + return &executeRegionOp.getRegion() == defOp->getParentRegion(); + + // If it's a block argument, check if it's from within the region + if (BlockArgument blockArg = dyn_cast<BlockArgument>(value)) + return &executeRegionOp.getRegion() == blockArg.getParentRegion(); + + return false; // Value is from outside the region + } +}; + void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context); + results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner, + ExecuteRegionForwardingEliminator>(context); } void ExecuteRegionOp::getSuccessorRegions( @@ -2490,8 +2584,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantTrue); }); @@ -2500,8 +2594,8 @@ struct ConditionPropagation : public OpRewritePattern<IfOp> { changed = true; if (!constantFalse) - constantFalse = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); + constantFalse = arith::ConstantOp::create( + rewriter, op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); rewriter.modifyOpInPlace(use.getOwner(), [&]() { use.set(constantFalse); }); diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 5dc61a2..335ca1a 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -69,10 +69,10 @@ splitLastAxisInResharding(ImplicitLocOpBuilder &builder, Sharding sourceSharding, TypedValue<ShapedType> sourceShard, GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( + TypedValue<ShapedType> targetShard = AllSliceOp::create(builder, sourceShard, grid, ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) - .getResult()); + .getResult(); Sharding targetSharding = targetShardingInSplitLastAxis( builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; @@ -204,9 +204,8 @@ static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding( APInt(64, splitTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allGatherResult) - .getResult()); + TypedValue<ShapedType> targetShard = 
+ tensor::CastOp::create(builder, targetShape, allGatherResult).getResult(); return {targetShard, targetSharding}; } @@ -336,8 +335,8 @@ moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = shardShapedType(sourceUnshardedShape, grid, targetSharding); - TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); + TypedValue<ShapedType> targetShard = + tensor::CastOp::create(builder, targetShape, allToAllResult).getResult(); return {targetShard, targetSharding}; } @@ -510,8 +509,7 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source, auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, - cast<TypedValue<ShapedType>>(source.getSrc()), - sourceShardValue); + source.getSrc(), sourceShardValue); } TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index 1cba1bb..32eb286 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -12,6 +12,96 @@ namespace mlir { namespace tosa { +llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { + return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); +} + +TosaSpecificationVersion getMinVersion(const Profile &profile) { + switch (profile) { + case Profile::pro_int: + case Profile::pro_fp: + return TosaSpecificationVersion(1, 0); + case Profile::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA profile"); +} + +TosaSpecificationVersion getMinVersion(const Extension &extension) { + switch (extension) { + case Extension::int16: + case Extension::int4: + case Extension::bf16: + case Extension::fp8e4m3: + case Extension::fp8e5m2: + case Extension::fft: + case Extension::variable: + case Extension::controlflow: + case Extension::doubleround: + case Extension::inexactround: + case Extension::dynamic: + return TosaSpecificationVersion(1, 0); + case Extension::mxfp: + return TosaSpecificationVersion(1, 1); + case Extension::none: + return TosaSpecificationVersion(0, 0); + } + llvm_unreachable("Unknown TOSA extension"); +} + +TosaSpecificationVersion getMinVersion(const Level &level) { + switch (level) { + case Level::eightK: + case Level::none: + return TosaSpecificationVersion(1, 0); + } + llvm_unreachable("Unknown TOSA level"); +} + +FailureOr<TargetEnv> +TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr, + Location targetEnvAttrLoc) { + if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc))) + return failure(); + + return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(), + targetAttr.getProfiles(), targetAttr.getExtensions()); +} + +LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr, + Location targetAttrLoc) { + TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion()); + + const auto isCompatibleWithTargetVersion = + [&](const auto &targetEnum, Location targetAttrLoc, + StringRef enumName) -> LogicalResult { + const TosaSpecificationVersion minRequiredVersion = + getMinVersion(targetEnum); + if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion)) + return emitError(targetAttrLoc, enumName) + << " '" << 
stringifyEnum(targetEnum) + << "' is not compatible with the target version " + << stringifyVersion(targetVersion) + << ", minimum required version is " + << stringifyVersion(minRequiredVersion); + return success(); + }; + + for (const auto &profile : targetAttr.getProfiles()) + if (failed( + isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile"))) + return failure(); + for (const auto &extension : targetAttr.getExtensions()) + if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc, + "extension"))) + return failure(); + if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc, + "level"))) + return failure(); + + return success(); +} + TargetEnvAttr lookupTargetEnv(Operation *op) { while (op) { op = SymbolTable::getNearestSymbolTable(op); @@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) { return getDefaultTargetEnv(op->getContext()); } -llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) { - return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor()); -} - } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index caf8016..a85ff10a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -76,28 +76,6 @@ template <typename OpTy> struct PoolPadFoldAdaptor; template <> -struct PoolPadFoldAdaptor<tosa::AvgPool2dOp> { - using OpTy = tosa::AvgPool2dOp; - static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { - const llvm::ArrayRef<int64_t> kernel = op.getKernel(); - if (newPad[2] >= kernel[1] || newPad[3] >= kernel[1] || - newPad[0] >= kernel[0] || newPad[1] >= kernel[0]) - return false; - return true; - } - static bool checkPadConstCompliance(OpTy op, Value padConst) { - return checkMatchingPadConstAndZp(padConst, op.getInputZp()); - } - static void replaceOpWithNewPad(PatternRewriter &rewriter, OpTy op, - Value padInput, ArrayRef<int64_t> newPad) { - rewriter.replaceOpWithNewOp<tosa::AvgPool2dOp>( - op, op.getType(), padInput, op.getInputZp(), op.getOutputZp(), - op.getKernel(), op.getStride(), rewriter.getDenseI64ArrayAttr(newPad), - op.getAccType()); - } -}; - -template <> struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> { using OpTy = tosa::MaxPool2dOp; static bool checkKernelCompliance(OpTy op, const ArrayRef<int64_t> newPad) { @@ -245,13 +223,6 @@ struct FoldPadToTensorOp : public OpRewritePattern<OpTy> { }; } // namespace -void AvgPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add<FoldPadToTensorOp<tosa::AvgPool2dOp, - PoolPadFoldAdaptor<tosa::AvgPool2dOp>>>( - context); -} - void Conv2DOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add< @@ -1001,8 +972,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { !outputTy.hasStaticShape()) return {}; - if (inputTy.getDimSize(getAxis()) == 1) - return DenseElementsAttr::get(outputTy, 0); + const Type outputElementTy = getElementTypeOrSelf(outputTy); + if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) { + const auto outputElemIntTy = cast<IntegerType>(outputElementTy); + const APInt zero = APInt::getZero(outputElemIntTy.getWidth()); + return DenseElementsAttr::get(outputTy, zero); + } return {}; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 00f84bc..6cd0eae 100644 --- 
a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, } } + + // special handling: block_size accepts a *bare* BlockSize enum + if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) { + if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeBlockSize(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid block_size value: " << kw; + auto attr = BlockSizeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // Default path: parse any normal attribute literal, including fully qualified // enum keyword Attribute attr; @@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { } else if (auto nanPropagationModeAttr = dyn_cast<tosa::NanPropagationModeAttr>(attr)) { parser << nanPropagationModeAttr.getValue(); + } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) { + parser << blockSizeAttr.getValue(); } else { parser.printAttribute(attr); } @@ -508,6 +523,15 @@ void ReduceMinOp::print(OpAsmPrinter &parser) { printWithNanPropagationHandling(parser, *this); } +ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -933,32 +957,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) { // verify that aType and bType have same element types template <typename T> -static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { - auto inputType = llvm::dyn_cast<TensorType>(inType); - auto outputType = llvm::dyn_cast<TensorType>(outType); - if (!inputType) { - op.emitOpError("expect shaped tensor for input, got ") << inType; +static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, + StringRef aName = "input", + StringRef bName = "output") { + auto aTType = llvm::dyn_cast<TensorType>(aType); + auto bTType = llvm::dyn_cast<TensorType>(bType); + if (!aTType) { + op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType; return failure(); } - if (!outputType) { - op.emitOpError("expect shaped tensor for output, got ") << outType; + if (!bTType) { + op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType; return failure(); } - auto inputElementType = inputType.getElementType(); - auto outputElementType = outputType.getElementType(); - auto inputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType); - auto outputQuantType = - llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType); - if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && - (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && - inputElementType != outputElementType) { + auto aElementType = aTType.getElementType(); + auto bElementType = bTType.getElementType(); + auto aQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType); + auto bQuantType = + llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType); + if ((aElementType.isIntOrIndexOrFloat() || aQuantType) && +
(bElementType.isIntOrIndexOrFloat() || bQuantType) && + aElementType != bElementType) { // only check if both element types are int/index/float/UniformQuantized // eg, not sure how to check quant::QuantizedType // this happens in test_conv2d_q_grouped_convolution in // tfl-to-tosa-pipeline.mlir - op.emitOpError("expect input and output to have same element type, got ") - << inputElementType << " and " << outputElementType; + op.emitOpError("expect ") + << aName << " and " << bName << " to have same element type, got " + << aElementType << " and " << bElementType; return failure(); } return success(); @@ -1846,6 +1873,161 @@ LogicalResult MatMulOp::verify() { return success(); } +LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + MatmulTBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic); + + const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType()); + if (aDataShape.hasRank()) { + outShape[0] = aDataShape.getDimSize(0); + outShape[1] = aDataShape.getDimSize(1); + } + + const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType()); + if (aScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0) + : outShape[0]; + outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1) + : outShape[1]; + } + + // If B batch size is 1, it is broadcast across A's batch size + const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType()); + if (bDataShape.hasRank()) { + const int64_t bDataBatchSize = bDataShape.getDimSize(0); + if (bDataBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0]; + outShape[2] = bDataShape.getDimSize(1); + } + + const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType()); + if (bScaleShape.hasRank()) { + const int64_t bScaleBatchSize = bScaleShape.getDimSize(0); + if (bScaleBatchSize != 1) + outShape[0] = + ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0]; + outShape[2] = ShapedType::isDynamic(outShape[2]) ? 
bScaleShape.getDimSize(1) + : outShape[2]; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult MatmulTBlockScaledOp::verify() { + // Verify same input data types + const Type aDataType = getAData().getType(); + const Type bDataType = getBData().getType(); + if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data", + "B_data"))) + return failure(); + + auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, + const StringRef operandName, + const StringRef dimName) -> LogicalResult { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); + }; + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t D = ShapedType::kDynamic; + int64_t H = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + int64_t multiplesOfC = ShapedType::kDynamic; + + const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType); + if (aDataShape.hasRank()) { + N = aDataShape.getDimSize(0); + H = aDataShape.getDimSize(1); + C = aDataShape.getDimSize(2); + } + + const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); + if (aScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", + "batch")) || + failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", + "height"))) + return failure(); + multiplesOfC = aScaleShape.getDimSize(2); + } + + const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); + if (bDataShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", + "batch")) || + failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", + "channels"))) + return failure(); + W = bDataShape.getDimSize(1); + } + + const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); + if (bScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", + "batch")) || + failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", + "width")) || + failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), + "b_scale", "C/block_size"))) + return failure(); + } + + // Verify batch size is broadcast compatible + if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1) + return emitOpError("expect B matrix batch size to be broadcast compatible " + "with A, got D=") + << D << " vs N=" << N; + + // Verify C is a multiple of block size + const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if (ShapedType::isStatic(C) && C % blockSize != 0) + return emitOpError("expect C to be a multiple of block size, got C=") + << C << ", block_size=" << blockSize; + + // Verify multiplesOfC is C / block size + if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) && + multiplesOfC != C / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal C/block_size (") + << C << "/" << blockSize << ")" + << ", got " << multiplesOfC; + + // Verify output shape + N = ShapedType::isDynamic(N) ? 
D : N; + const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W}; + const auto outputType = cast<ShapedType>(getResult().getType()); + if (outputType.hasRank() && + failed( + verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) { + InFlightDiagnostic opError = emitOpError("expected output shape "); + auto stringifyDim = [&](int64_t d) { + if (ShapedType::isDynamic(d)) + opError << "?"; + else + opError << d; + }; + llvm::interleaveComma(outputType.getShape(), opError, stringifyDim); + opError << " to be compatible with expected output shape "; + llvm::interleaveComma(expectedOutputShape, opError, stringifyDim); + return opError; + } + + return success(); +} + LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, PadOp::Adaptor adaptor, diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index f072e3e..e965ae0 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -25,6 +25,12 @@ TosaProfileCompliance::TosaProfileCompliance() { const TypeInfo fp8e4m3T = {mlir::Float8E4M3FNType::getTypeID(), 8}; const TypeInfo fp8e5m2T = {mlir::Float8E5M2Type::getTypeID(), 8}; + // micro-scaling formats + const TypeInfo fp6e2m3T = {mlir::Float6E2M3FNType::getTypeID(), 6}; + const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6}; + const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4}; + const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8}; + // The profile-based compliance content below is auto-generated by a script // in https://git.mlplatform.org/tosa/specification.git #include "mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc" @@ -269,6 +275,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) { // For most of the tosa operators, all operands are profile/extension related // and hence are all considered in this profile-based compliance check.
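// --------------------------------------------------------------------------
// Aside: the shape constraints enforced by MatmulTBlockScaledOp::verify()
// above, reduced to plain C++ (example values invented): A_data is [N, H, C],
// B_data is [D, W, C], both scales carry C/block_size in dim 2, D must be 1
// or N, and C must divide evenly by the block size.

#include <cstdio>

int main() {
  const long N = 4, H = 8, C = 64;   // A_data: [N, H, C]
  const long D = 1, W = 16;          // B_data: [D, W, C], batch broadcasts
  const long blockSize = 32;
  const bool batchOk = (D == N || D == 1);
  const bool blockOk = (C % blockSize) == 0;
  const long scaleC = C / blockSize; // expected dim 2 of A_scale and B_scale
  std::printf("batchOk=%d blockOk=%d scaleC=%ld out=[%ld,%ld,%ld]\n",
              batchOk, blockOk, scaleC, N, H, W);
  return 0;
}
// --------------------------------------------------------------------------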
+ POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled) POPULATE_PROFILE_INFO_COMMON(Cast) POPULATE_PROFILE_INFO_COMMON(Const) POPULATE_PROFILE_INFO_COMMON(ArgMax) @@ -623,6 +630,14 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) { return {"fp8e4m3"}; } else if (typeInfo.typeID == mlir::Float8E5M2Type::getTypeID()) { return {"fp8e5m2"}; + } else if (typeInfo.typeID == mlir::Float6E2M3FNType::getTypeID()) { + return {"fp6e2m3"}; + } else if (typeInfo.typeID == mlir::Float6E3M2FNType::getTypeID()) { + return {"fp6e3m2"}; + } else if (typeInfo.typeID == mlir::Float4E2M1FNType::getTypeID()) { + return {"fp4e2m1"}; + } else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) { + return {"fp8e8m0"}; } llvm_unreachable("unknown type"); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 82f2f7e..3f874d9 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -657,6 +657,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) { CHECK_SIZES(TransposeConv2D); CHECK_SIZES(FFT2d); CHECK_SIZES(MatMul); + CHECK_SIZES(MatmulTBlockScaled); CHECK_SIZES(MaxPool2d); CHECK_SIZES(RFFT2d); // Scatter/Gather Operators @@ -1192,9 +1193,9 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa<FloatType>(type)) { return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, - Float8E5M2Type>(type); - } - if (auto intTy = dyn_cast<IntegerType>(type)) { + Float8E5M2Type, Float4E2M1FNType, Float6E2M3FNType, + Float6E3M2FNType, Float8E8M0FNUType>(type); + } else if (auto intTy = dyn_cast<IntegerType>(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: @@ -1220,13 +1221,19 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { } void TosaValidation::runOnOperation() { + ModuleOp modOp = getOperation(); + const TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(modOp); + const auto maybeTargetEnv = + tosa::TargetEnv::createTargetEnvFromAttr(targetEnvAttr, modOp.getLoc()); + if (failed(maybeTargetEnv)) + return signalPassFailure(); + targetEnv = *maybeTargetEnv; + TosaDialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>(); if (!tosaDialect) return; - targetEnv = tosa::TargetEnv(lookupTargetEnvOrDefault(getOperation())); - - getOperation().walk([&](Operation *op) { + modOp.walk([&](Operation *op) { if (op->getDialect() != tosaDialect) return; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 8f46ad6..ef49c86 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -74,9 +74,9 @@ struct MixedSizeInputShuffleOpRewrite final for (int64_t i = 0; i < origNumElems; ++i) promoteMask[i] = i; - Value promotedInput = rewriter.create<vector::ShuffleOp>( - shuffleOp.getLoc(), promotedType, inputToPromote, inputToPromote, - promoteMask); + Value promotedInput = + vector::ShuffleOp::create(rewriter, shuffleOp.getLoc(), promotedType, + inputToPromote, inputToPromote, promoteMask); // Create the final shuffle with the promoted inputs. Value promotedV1 = promoteV1 ? 
promotedInput : shuffleOp.getV1(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7c019e7..8b5e950 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -341,13 +341,18 @@ private: /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the -/// distributed dimensions. The number of results should be equal to the number +/// distributed dimensions. If the number of results is zero, there is no +/// distribution (i.e. the original type is returned). +/// Otherwise, the number of results should be equal to the number /// of warp sizes which is currently limited to 1. /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) /// and a warp size of 16 would distribute the second dimension (associated to /// d1) and return vector<16x2x64> static VectorType getDistributedType(VectorType originalType, AffineMap map, int64_t warpSize) { + // If the map has zero results, return the original type. + if (map.getNumResults() == 0) + return originalType; SmallVector<int64_t> targetShape(originalType.getShape()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index 1599ae9..24e9095 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -736,7 +736,7 @@ OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder) { auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); - return builder.create<ArithOp>(loc, aVal, bVal).getResult(); + return ArithOp::create(builder, loc, aVal, bVal).getResult(); } // a helper utility to perform division operation on OpFoldResult and int64_t. diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 26770b3..d09dc19 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -1505,14 +1505,19 @@ void XeGPUSubgroupDistributePass::runOnOperation() { return AffineMap::get(val.getContext()); // Get the layout of the vector type. xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(val); - // If no layout is specified, assume the inner most dimension is distributed - // for now. + // If no layout is specified, that means no distribution. if (!layout) - return AffineMap::getMultiDimMapWithTargets( - vecRank, {static_cast<unsigned int>(vecRank - 1)}, val.getContext()); + return AffineMap::getMultiDimMapWithTargets(vecRank, {}, + val.getContext()); + // Expecting vector and layout rank to match. + assert(layout.getRank() == vecRank && + "Expecting vector and layout rank to match"); + // A dimension is distributed only if the layout suggests there are + // multiple lanes assigned to this dimension and the shape can be evenly + // distributed to those lanes.
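// --------------------------------------------------------------------------
// Aside: the distribution rule from the comment above as a plain-C++ sketch
// (shapes invented): a dimension is divided by its lane count only when more
// than one lane is assigned to it and the division is exact.

#include <cstdio>

int main() {
  const long shape[3] = {16, 33, 64};
  const long laneLayout[3] = {1, 16, 16};
  for (int i = 0; i < 3; ++i)
    if (laneLayout[i] > 1 && shape[i] % laneLayout[i] == 0)
      std::printf("dim %d: %ld -> %ld\n", i, shape[i],
                  shape[i] / laneLayout[i]);
  // only dim 2 qualifies (64 -> 4); dim 1 stays whole since 33 % 16 != 0
  return 0;
}
// --------------------------------------------------------------------------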
SmallVector<unsigned int> distributedDims; for (auto [i, v] : llvm::enumerate(layout.getEffectiveLaneLayoutAsInt())) { - if (v > 1) + if (v > 1 && vecType.getShape()[i] % v == 0) distributedDims.push_back(i); } return AffineMap::getMultiDimMapWithTargets(vecRank, distributedDims, @@ -1525,15 +1530,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { auto warpReduction = [](Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size) { // First reduce on a single thread to get per lane reduction value. - Value laneVal = builder.create<vector::ReductionOp>(loc, kind, input); + Value laneVal = vector::ReductionOp::create(builder, loc, kind, input); // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { - Value shuffled = - builder - .create<gpu::ShuffleOp>(loc, laneVal, i, - /*width=*/size, - /*mode=*/gpu::ShuffleMode::XOR) - .getShuffleResult(); + Value shuffled = gpu::ShuffleOp::create(builder, loc, laneVal, i, + /*width=*/size, + /*mode=*/gpu::ShuffleMode::XOR) + .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 31a967d..9fc5ad9 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -825,7 +825,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType), baseTileValues); - auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr); + auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr); // Get subgroup id Value sgId = @@ -837,25 +837,26 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { SmallVector<Value, 2> strideConsts; strideConsts.push_back( - rewriter.create<arith::ConstantIndexOp>(loc, colStride)); + arith::ConstantIndexOp::create(rewriter, loc, colStride)); if (rows > 1) strideConsts.insert( strideConsts.begin(), - rewriter.create<arith::ConstantIndexOp>(loc, rowStride)); + arith::ConstantIndexOp::create(rewriter, loc, rowStride)); SmallVector<Value> newConstOps; for (auto offsets : *sgOffsets) { // Multiply offset with stride, broadcast it and add to baseConstVec - Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); for (size_t i = 0; i < strideConsts.size(); ++i) { - Value mul = rewriter.create<arith::MulIOp>( - loc, rewriter.getIndexType(), offsets[i], strideConsts[i]); - mulOffset = rewriter.create<arith::AddIOp>( - loc, rewriter.getIndexType(), mulOffset, mul); + Value mul = + arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(), + offsets[i], strideConsts[i]); + mulOffset = arith::AddIOp::create( + rewriter, loc, rewriter.getIndexType(), mulOffset, mul); } // Broadcast to baseConstVec size - auto bcastOffset = rewriter.create<vector::BroadcastOp>( - loc, baseConstVec.getType(), mulOffset); + auto bcastOffset = vector::BroadcastOp::create( + rewriter, loc, baseConstVec.getType(), mulOffset); auto finalConst = arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset); setLayoutIfNeeded(baseConstVec); @@ -1138,8 +1139,8 @@ struct WgToSgVectorShapeCastOp SmallVector<Value> newShapeCastOps; for (auto src : adaptor.getSource()) { - auto newShapeCast = - 
rewriter.create<vector::ShapeCastOp>(op.getLoc(), newResultType, src); + auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(), + newResultType, src); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0), @@ -1201,9 +1202,9 @@ struct WgToSgMultiDimReductionOp SmallVector<Value> newReductions; for (auto sgSrc : adaptor.getSource()) { - auto newOp = rewriter.create<vector::MultiDimReductionOp>( - op.getLoc(), newDstType, op.getKind(), sgSrc, adaptor.getAcc()[0], - op.getReductionDims()); + auto newOp = vector::MultiDimReductionOp::create( + rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc, + adaptor.getAcc()[0], op.getReductionDims()); if (!layout.getEffectiveLaneLayoutAsInt().empty() || !layout.getEffectiveInstDataAsInt().empty()) xegpu::setDistributeLayoutAttr(newOp->getResult(0), diff --git a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt index 6ef1529..c712c64b 100644 --- a/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt +++ b/mlir/lib/ExecutionEngine/SparseTensor/CMakeLists.txt @@ -21,6 +21,6 @@ set_property(TARGET MLIRSparseTensorRuntime PROPERTY CXX_STANDARD 17) check_cxx_compiler_flag(-Wweak-vtables COMPILER_SUPPORTS_WARNING_WEAK_VTABLES) if(COMPILER_SUPPORTS_WARNING_WEAK_VTABLES) - target_compile_options(MLIRSparseTensorRuntime PUBLIC + target_compile_options(MLIRSparseTensorRuntime PRIVATE "-Wweak-vtables") endif() diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 4d81918..776b5c6 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -378,10 +378,8 @@ struct SourceMgrDiagnosticHandlerImpl { } // Otherwise, try to load the source file. 
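// --------------------------------------------------------------------------
// Aside: the handler below memoizes buffer ids per filename so each source
// file is loaded at most once. The caching pattern as a plain-C++ sketch
// (names invented; the id assignment stands in for the actual file load):

#include <cstdio>
#include <map>
#include <string>

static std::map<std::string, unsigned> filenameToBufId;

static unsigned getOrLoad(const std::string &filename) {
  auto it = filenameToBufId.find(filename);
  if (it != filenameToBufId.end())
    return it->second; // cache hit: the file is not read again
  unsigned id = static_cast<unsigned>(filenameToBufId.size()) + 1;
  filenameToBufId[filename] = id;
  return id;
}

int main() {
  unsigned a = getOrLoad("a.mlir");
  unsigned b = getOrLoad("b.mlir");
  std::printf("%u %u %u\n", a, b, getOrLoad("a.mlir")); // prints: 1 2 1
  return 0;
}
// --------------------------------------------------------------------------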
- auto bufferOrErr = llvm::MemoryBuffer::getFile(filename); - if (!bufferOrErr) - return 0; - unsigned id = mgr.AddNewSourceBuffer(std::move(*bufferOrErr), SMLoc()); + std::string ignored; + unsigned id = mgr.AddIncludeFile(std::string(filename), SMLoc(), ignored); filenameToBufId[filename] = id; return id; } diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp index dd413d2de..d7e321a 100644 --- a/mlir/lib/RegisterAllPasses.cpp +++ b/mlir/lib/RegisterAllPasses.cpp @@ -33,6 +33,7 @@ #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/NVGPU/Transforms/Passes.h" #include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/OpenMP/Transforms/Passes.h" #include "mlir/Dialect/Quant/Transforms/Passes.h" #include "mlir/Dialect/SCF/Transforms/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" @@ -80,6 +81,7 @@ void mlir::registerAllPasses() { memref::registerMemRefPasses(); shard::registerShardPasses(); ml_program::registerMLProgramPasses(); + omp::registerOpenMPPasses(); quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 8de49dd..f284540 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -357,14 +357,8 @@ static LogicalResult checkImplementationStatus(Operation &op) { result = todo("priority"); }; auto checkPrivate = [&todo](auto op, LogicalResult &result) { - if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) { - // Privatization is supported only for included target tasks. - if (!op.getPrivateVars().empty() && op.getNowait()) - result = todo("privatization for deferred target tasks"); - } else { - if (!op.getPrivateVars().empty() || op.getPrivateSyms()) - result = todo("privatization"); - } + if (!op.getPrivateVars().empty() || op.getPrivateSyms()) + result = todo("privatization"); }; auto checkReduction = [&todo](auto op, LogicalResult &result) { if (isa<omp::TeamsOp>(op)) @@ -451,7 +445,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { checkDevice(op, result); checkInReduction(op, result); checkIsDevicePtr(op, result); - checkPrivate(op, result); }) .Default([](Operation &) { // Assume all clauses for an operation can be translated unless they are @@ -3833,6 +3826,58 @@ static llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, return builder.getInt64(dl.getTypeSizeInBits(type) / 8); } +// Convert the MLIR map flag set to the runtime map flag set for embedding +// in LLVM-IR. This is important as the two bit-flag lists do not correspond +// 1-to-1: there are flags the runtime doesn't care about and vice versa. +// Certain flags, such as RefPtee, are discarded here.
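// --------------------------------------------------------------------------
// Aside: the translation scheme used below, reduced to a plain-C++ sketch
// with two invented flag sets. A dialect bit with no runtime counterpart is
// simply never tested, so it drops out of the result.

#include <cstdint>
#include <cstdio>

enum class DialectFlags : uint32_t { to = 1, from = 2, internal = 4 };
enum class RuntimeFlags : uint64_t { to = 1, from = 2 };

static bool has(DialectFlags v, DialectFlags f) {
  return (uint32_t(v) & uint32_t(f)) == uint32_t(f);
}

static uint64_t convert(DialectFlags v) {
  uint64_t out = 0;
  if (has(v, DialectFlags::to))
    out |= uint64_t(RuntimeFlags::to);
  if (has(v, DialectFlags::from))
    out |= uint64_t(RuntimeFlags::from);
  return out; // 'internal' has no runtime bit and is discarded
}

int main() {
  auto f = DialectFlags(uint32_t(DialectFlags::to) |
                        uint32_t(DialectFlags::internal));
  std::printf("%llu\n", (unsigned long long)convert(f)); // prints: 1
  return 0;
}
// --------------------------------------------------------------------------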
+static llvm::omp::OpenMPOffloadMappingFlags +convertClauseMapFlags(omp::ClauseMapFlags mlirFlags) { + auto mapTypeToBool = [&mlirFlags](omp::ClauseMapFlags flag) { + return (mlirFlags & flag) == flag; + }; + + llvm::omp::OpenMPOffloadMappingFlags mapType = + llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; + + if (mapTypeToBool(omp::ClauseMapFlags::to)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; + + if (mapTypeToBool(omp::ClauseMapFlags::from)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; + + if (mapTypeToBool(omp::ClauseMapFlags::always)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; + + if (mapTypeToBool(omp::ClauseMapFlags::del)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; + + if (mapTypeToBool(omp::ClauseMapFlags::return_param)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; + + if (mapTypeToBool(omp::ClauseMapFlags::priv)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRIVATE; + + if (mapTypeToBool(omp::ClauseMapFlags::literal)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL; + + if (mapTypeToBool(omp::ClauseMapFlags::implicit)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; + + if (mapTypeToBool(omp::ClauseMapFlags::close)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; + + if (mapTypeToBool(omp::ClauseMapFlags::present)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; + + if (mapTypeToBool(omp::ClauseMapFlags::ompx_hold)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; + + if (mapTypeToBool(omp::ClauseMapFlags::attach)) + mapType |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ATTACH; + + return mapType; +} + static void collectMapDataFromMapOperands( MapInfoData &mapData, SmallVectorImpl<Value> &mapVars, LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl, @@ -3880,8 +3925,7 @@ static void collectMapDataFromMapOperands( getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(), mapData.BaseType.back(), builder, moduleTranslation)); mapData.MapClause.push_back(mapOp.getOperation()); - mapData.Types.push_back( - llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType())); + mapData.Types.push_back(convertClauseMapFlags(mapOp.getMapType())); mapData.Names.push_back(LLVM::createMappingInformation( mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder())); mapData.DevicePointers.push_back(llvm::OpenMPIRBuilder::DeviceInfoTy::None); @@ -3950,8 +3994,7 @@ static void collectMapDataFromMapOperands( Value offloadPtr = mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr(); llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr); - auto mapType = - static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType()); + auto mapType = convertClauseMapFlags(mapOp.getMapType()); auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; mapData.OriginalValue.push_back(origValue); @@ -4299,8 +4342,7 @@ static void processMapMembersWithParent( // in part as we currently have substantially less information on the data // being mapped at this stage. 
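// --------------------------------------------------------------------------
// Aside: the member fix-ups below re-derive each member's flags from its own
// map clause, clear the TARGET_PARAM bit (members are not kernel arguments
// themselves), and stamp in the parent's MEMBER_OF bits. The clear-then-stamp
// pattern as a plain-C++ sketch (bit values invented for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  const uint64_t targetParam = 0x20;
  const uint64_t memberOfBits = uint64_t(3) << 48; // parent-position field
  uint64_t mapFlag = 0x1 | targetParam;            // the member's own flags
  mapFlag &= ~targetParam;                         // drop TARGET_PARAM
  mapFlag |= memberOfBits;                         // link to the parent
  std::printf("%#llx\n", (unsigned long long)mapFlag);
  return 0;
}
// --------------------------------------------------------------------------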
if (checkIfPointerMap(memberClause)) { - auto mapFlag = - llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType()); + auto mapFlag = convertClauseMapFlags(memberClause.getMapType()); mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); @@ -4319,8 +4361,7 @@ static void processMapMembersWithParent( // Same MemberOfFlag to indicate its link with parent and other members // of. - auto mapFlag = - llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType()); + auto mapFlag = convertClauseMapFlags(memberClause.getMapType()); mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM; mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF; ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index d9ad8fb..6492708 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -702,8 +702,8 @@ spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) { // RAII guard to reset the insertion point to previous value when done. OpBuilder::InsertionGuard insertionGuard(opBuilder); opBuilder.setInsertionPoint(graphARM); - opBuilder.create<spirv::GraphEntryPointARMOp>( - unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name), + spirv::GraphEntryPointARMOp::create( + opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name), opBuilder.getArrayAttr(interface)); return success(); @@ -736,7 +736,7 @@ spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) { std::string graphName = getGraphSymbol(graphID); auto graphOp = - opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType); + spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType); curGraph = graphMap[graphID] = graphOp; Block *entryBlock = graphOp.addEntryBlock(); LLVM_DEBUG({ @@ -844,7 +844,7 @@ spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) { LogicalResult spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) { // Create GraphOutputsARM instruction. - opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs); + spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs); // Process OpGraphEndARM. 
if (!operands.empty()) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index b56e778..b88fbaa 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -260,9 +260,9 @@ static std::string getDecorationName(StringRef attrName) { } template <typename AttrTy, typename EmitF> -LogicalResult processDecorationList(Location loc, Decoration decoration, - Attribute attrList, StringRef attrName, - EmitF emitter) { +static LogicalResult processDecorationList(Location loc, Decoration decoration, + Attribute attrList, + StringRef attrName, EmitF emitter) { auto arrayAttr = dyn_cast<ArrayAttr>(attrList); if (!arrayAttr) { return emitError(loc, "expecting array attribute of ") diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp index 366ba8f..048e964 100644 --- a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -406,7 +406,7 @@ private: auto returnOperands = popOperands(resTypes); if (failed(returnOperands)) return failure(); - builder.create<BlockReturnOp>(opLoc, *returnOperands); + BlockReturnOp::create(builder, opLoc, *returnOperands); LDBG() << "end of parsing of a block"; return bodyParsingRes->endingByte; } @@ -1000,7 +1000,7 @@ parsed_inst_t ExpressionParser::parseBlockLikeOp(OpBuilder &builder) { builder.createBlock(curRegion, curRegion->end(), resTypes, locations); builder.setInsertionPointToEnd(curBlock); auto blockOp = - builder.create<OpToCreate>(*currentOpLoc, *inputOps, successor); + OpToCreate::create(builder, *currentOpLoc, *inputOps, successor); auto *blockBody = blockOp.createBlock(); if (failed(parseBlockContent(builder, blockBody, resTypes, *opLoc, blockOp))) return failure(); @@ -1047,8 +1047,8 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< auto *successor = builder.createBlock(curRegion, curRegion->end(), resTypes, locations); builder.setInsertionPointToEnd(curBlock); - auto ifOp = builder.create<IfOp>(*currentOpLoc, conditionValue->front(), - *inputOps, successor); + auto ifOp = IfOp::create(builder, *currentOpLoc, conditionValue->front(), + *inputOps, successor); auto *ifEntryBlock = ifOp.createIfBlock(); constexpr auto ifElseFilter = ByteSequence<WasmBinaryEncoding::endByte, @@ -1091,9 +1091,9 @@ inline parsed_inst_t ExpressionParser::parseSpecificInstruction< auto branchArgs = popOperands(inputTypes); if (failed(branchArgs)) return failure(); - builder.create<BranchIfOp>(*currentOpLoc, condition->front(), - builder.getUI32IntegerAttr(*level), *branchArgs, - elseBlock); + BranchIfOp::create(builder, *currentOpLoc, condition->front(), + builder.getUI32IntegerAttr(*level), *branchArgs, + elseBlock); builder.setInsertionPointToStart(elseBlock); return {*branchArgs}; } @@ -1115,7 +1115,7 @@ ExpressionParser::parseSpecificInstruction<WasmBinaryEncoding::OpCode::call>( if (failed(inOperands)) return failure(); auto callOp = - builder.create<FuncCallOp>(loc, resTypes, callee.symbol, *inOperands); + FuncCallOp::create(builder, loc, resTypes, callee.symbol, *inOperands); return {callOp.getResults()}; } @@ -1391,8 +1391,8 @@ inline parsed_inst_t ExpressionParser::buildConvertOp(OpBuilder &builder, auto operand = popOperands(intype); if (failed(operand)) return failure(); - auto op = builder.create<opType>(*currentOpLoc, outType, operand->front(), - extraArgs...); + auto op = opType::create(builder, *currentOpLoc, outType, operand->front(), + 
extraArgs...); LDBG() << "Built operation: " << op; return {{op.getResult()}}; }
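// --------------------------------------------------------------------------
// Aside: the mechanical change recurring throughout this diff replaces
// rewriter.create<OpTy>(loc, ...) with OpTy::create(rewriter, loc, ...).
// A toy plain-C++ model of the two spellings (all types invented):

#include <cstdio>

struct Builder;

struct ConstOp {
  int value;
  // New spelling: the op class owns its creation entry point.
  static ConstOp create(Builder &, int v) { return ConstOp{v}; }
};

struct Builder {
  // Old spelling: a generic templated entry point on the builder.
  template <typename OpTy, typename... Args>
  OpTy create(Args &&...args) {
    return OpTy::create(*this, args...);
  }
};

int main() {
  Builder b;
  ConstOp oldStyle = b.create<ConstOp>(7);  // forwards to the static create
  ConstOp newStyle = ConstOp::create(b, 7); // the preferred direct spelling
  std::printf("%d %d\n", oldStyle.value, newStyle.value);
  return 0;
}
// --------------------------------------------------------------------------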
