diff options
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r-- | mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 177 | ||||
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 85 |
2 files changed, 256 insertions, 6 deletions
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 1f239aa..519d9c8 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { #define GEN_PASS_DEF_SCFTOEMITC @@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables, emitc::AssignOp::create(rewriter, loc, var, value); } -SmallVector<Value> loadValues(const SmallVector<Value> &variables, +SmallVector<Value> loadValues(ArrayRef<Value> variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast<emitc::LValueType>(var.getType()).getValueType(); @@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables, static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, ConversionPatternRewriter &rewriter, - scf::YieldOp yield) { + scf::YieldOp yield, bool createYield = true) { Location loc = yield.getLoc(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); SmallVector<Value> yieldOperands; - if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); - } assignValues(yieldOperands, resultVariables, rewriter, loc); @@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( return success(); } +// Lower scf::while to emitc::do using mutable variables to maintain loop state +// across iterations. The do-while structure ensures the condition is evaluated +// after each iteration, matching SCF while semantics. 
+struct WhileLowering : public OpConversionPattern<WhileOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = whileOp.getLoc(); + MLIRContext *context = loc.getContext(); + + // Create an emitc::variable op for each result. These variables will be + // assigned to by emitc::assign ops within the loop body. + SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(whileOp, + "Failed to create result variables"); + + // Create variable storage for loop-carried values to enable imperative + // updates while maintaining SSA semantics at conversion boundaries. + SmallVector<Value> loopVariables; + if (failed(createVariablesForLoopCarriedValues( + whileOp, rewriter, loopVariables, loc, context))) + return failure(); + + if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context, + rewriter, loc))) + return failure(); + + rewriter.setInsertionPointAfter(whileOp); + + // Load the final result values from result variables. + SmallVector<Value> finalResults = + loadValues(resultVariables, rewriter, loc); + rewriter.replaceOp(whileOp, finalResults); + + return success(); + } + +private: + // Initialize variables for loop-carried values to enable state updates + // across iterations without SSA argument passing. 
+ LogicalResult createVariablesForLoopCarriedValues( + WhileOp whileOp, ConversionPatternRewriter &rewriter, + SmallVectorImpl<Value> &loopVars, Location loc, + MLIRContext *context) const { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(whileOp); + + emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); + + for (Value init : whileOp.getInits()) { + Type convertedType = getTypeConverter()->convertType(init.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure(whileOp, "type conversion failed"); + + emitc::VariableOp var = rewriter.create<emitc::VariableOp>( + loc, emitc::LValueType::get(convertedType), noInit); + rewriter.create<emitc::AssignOp>(loc, var.getResult(), init); + loopVars.push_back(var); + } + + return success(); + } + + // Lower scf.while to emitc.do. + LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars, + ArrayRef<Value> resultVars, MLIRContext *context, + ConversionPatternRewriter &rewriter, + Location loc) const { + // Create a global boolean variable to store the loop condition state. + Type i1Type = IntegerType::get(context, 1); + auto globalCondition = + rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type), + emitc::OpaqueAttr::get(context, "")); + Value conditionVal = globalCondition.getResult(); + + auto loweredDo = rewriter.create<emitc::DoOp>(loc); + + // Convert region types to match the target dialect type system. + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + *getTypeConverter(), nullptr)) || + failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(whileOp, + "region types conversion failed"); + } + + // Prepare the before region (condition evaluation) for merging. 
+ Block *beforeBlock = &whileOp.getBefore().front(); + Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion()); + rewriter.setInsertionPointToStart(bodyBlock); + + // Load current variable values to use as initial arguments for the + // condition block. + SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc); + rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues); + + Operation *condTerminator = + loweredDo.getBodyRegion().back().getTerminator(); + scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator); + rewriter.setInsertionPoint(condOp); + + // Update result variables with values from scf::condition. + SmallVector<Value> conditionArgs; + for (Value arg : condOp.getArgs()) { + conditionArgs.push_back(rewriter.getRemappedValue(arg)); + } + assignValues(conditionArgs, resultVars, rewriter, loc); + + // Convert scf.condition to condition variable assignment. + Value condition = rewriter.getRemappedValue(condOp.getCondition()); + rewriter.create<emitc::AssignOp>(loc, conditionVal, condition); + + // Wrap body region in conditional to preserve scf semantics. Only create + // ifOp if after-region is non-empty. + if (whileOp.getAfterBody()->getOperations().size() > 1) { + auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false); + + // Prepare the after region (loop body) for merging. + Block *afterBlock = &whileOp.getAfter().front(); + Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion()); + + // Replacement values for after block using condition op arguments. 
+ SmallVector<Value> afterReplacingValues; + for (Value arg : condOp.getArgs()) + afterReplacingValues.push_back(rewriter.getRemappedValue(arg)); + + rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues); + + if (failed(lowerYield(whileOp, loopVars, rewriter, + cast<scf::YieldOp>(ifBodyBlock->getTerminator())))) + return failure(); + } + + rewriter.eraseOp(condOp); + + // Create condition region that loads from the flag variable. + Region &condRegion = loweredDo.getConditionRegion(); + Block *condBlock = rewriter.createBlock(&condRegion); + rewriter.setInsertionPointToStart(condBlock); + + auto exprOp = rewriter.create<emitc::ExpressionOp>( + loc, i1Type, conditionVal, /*do_not_inline=*/false); + Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion()); + + // Set up the expression block to load the condition variable. + exprBlock->addArgument(conditionVal.getType(), loc); + rewriter.setInsertionPointToStart(exprBlock); + + // Load the condition value and yield it as the expression result. + Value cond = + rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0)); + rewriter.create<emitc::YieldOp>(loc, cond); + + // Yield the expression as the condition region result. + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<emitc::YieldOp>(loc, exprOp); + + return success(); + } +}; + void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add<ForLowering>(typeConverter, patterns.getContext()); patterns.add<IfLowering>(typeConverter, patterns.getContext()); patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); + patterns.add<WhileLowering>(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { @@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() { // Configure conversion to lower out SCF operations. 
ConversionTarget target(getContext()); - target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(); + target + .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 57877b8..f449d90 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) { return op.getCacheControl(); } +static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) { return op.getCacheControl(); } @@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) { return op.getCacheControl(); } +static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) { if (op->hasAttr("cache_control")) { auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"); @@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> || std::is_same_v<OpType, BlockPrefetch2dOp> || std::is_same_v<OpType, LLVM::LoadOp> || + std::is_same_v<OpType, BlockLoadOp> || std::is_same_v<OpType, PrefetchOp>; const int32_t controlKey{isLoad ? 
loadCacheControlKey : storeCacheControlKey}; SmallVector<int32_t, decorationCacheControlArity> decorationsL1{ @@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { return success(); } }; + +template <typename OpType> +class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>; + // Get OpenCL function name + // https://registry.khronos.org/OpenCL/extensions/ + // intel/cl_intel_subgroup_local_block_io.html + std::string funcName{"intel_sub_group_block_"}; + // Value or Result type can be vector or scalar + Type valOrResTy; + if constexpr (isStore) { + funcName += "write_u"; + valOrResTy = op.getVal().getType(); + } else { + funcName += "read_u"; + valOrResTy = op.getType(); + } + // Get element type of the vector/scalar + VectorType vecTy = dyn_cast<VectorType>(valOrResTy); + Type elemType = vecTy ? 
vecTy.getElementType() : valOrResTy; + funcName += getTypeMangling(elemType); + if (vecTy) + funcName += std::to_string(vecTy.getNumElements()); + SmallVector<Type, 2> argTypes{}; + // XeVM BlockLoad/StoreOp always use signless integer types + // but OpenCL builtins expect unsigned types + // use unsigned types for mangling + SmallVector<bool, 2> isUnsigned{}; + // arg0: pointer to the src/dst address + // arg1 - only if store : vector to store + // Prepare arguments + SmallVector<Value, 2> args{}; + args.push_back(op.getPtr()); + argTypes.push_back(op.getPtr().getType()); + isUnsigned.push_back(true); + Type retType; + if constexpr (isStore) { + args.push_back(op.getVal()); + argTypes.push_back(op.getVal().getType()); + isUnsigned.push_back(true); + retType = LLVM::LLVMVoidType::get(rewriter.getContext()); + } else { + retType = valOrResTy; + } + funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName + + "PU3AS" + + std::to_string(op.getPtr().getType().getAddressSpace()); + funcName += getTypeMangling(elemType, /*isUnsigned=*/true); + if constexpr (isStore) + funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true); + LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; + + LLVM::CallOp call = + createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args, + {}, funcAttr, op.getOperation()); + if (std::optional<ArrayAttr> optCacheControls = + getCacheControlMetadata(rewriter, op)) { + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + } + if constexpr (isStore) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, call->getResult(0)); + return success(); + } +}; + template <typename OpType> class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { using OpConversionPattern<OpType>::OpConversionPattern; @@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, MMAToOCLPattern, MemfenceToOCLPattern, 
PrefetchToOCLPattern, LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, - LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext()); + LLVMLoadStoreToOCLPattern<LLVM::StoreOp>, + BlockLoadStore1DToOCLPattern<BlockLoadOp>, + BlockLoadStore1DToOCLPattern<BlockStoreOp>>( + patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) { |