diff options
Diffstat (limited to 'mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 202 |
1 files changed, 103 insertions, 99 deletions
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 1d92b5d..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -94,13 +94,13 @@ static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter) { if (isa<VectorType>(srcType)) { - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<ShapedType>(srcType), minusOneIntegerAttribute(srcType, rewriter))); } - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + minusOneIntegerAttribute(srcType, rewriter)); } /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value. @@ -108,14 +108,14 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType, PatternRewriter &rewriter, double value) { if (auto vecType = dyn_cast<VectorType>(srcType)) { auto floatType = cast<FloatType>(vecType.getElementType()); - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, + return LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(vecType, rewriter.getFloatAttr(floatType, value))); } auto floatType = cast<FloatType>(srcType); - return rewriter.create<LLVM::ConstantOp>( - loc, dstType, rewriter.getFloatAttr(floatType, value)); + return LLVM::ConstantOp::create(rewriter, loc, dstType, + rewriter.getFloatAttr(floatType, value)); } /// Utility function for bitfield ops: @@ -134,13 +134,13 @@ static Value optionallyTruncateOrExtend(Location loc, Value value, : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) - return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value); + return LLVM::ZExtOp::create(rewriter, loc, llvmType, value); // If the bit widths of `Count` and `Offset` are greater than the bit width // of the target type, they are truncated. Truncation is safe since `Count` // and `Offset` must be no more than 64 for op behaviour to be defined. Hence, // both values can be expressed in 8 bits. if (valueBitWidth > targetBitWidth) - return rewriter.create<LLVM::TruncOp>(loc, llvmType, value); + return LLVM::TruncOp::create(rewriter, loc, llvmType, value); return value; } @@ -151,12 +151,12 @@ static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); - Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType); + Value broadcasted = LLVM::PoisonOp::create(rewriter, loc, llvmVectorType); for (unsigned i = 0; i < numElements; ++i) { - auto index = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); - broadcasted = rewriter.create<LLVM::InsertElementOp>( - loc, llvmVectorType, broadcasted, toBroadcast, index); + auto index = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, + rewriter.getI32IntegerAttr(i)); + broadcasted = LLVM::InsertElementOp::create( + rewriter, loc, llvmVectorType, broadcasted, toBroadcast, index); } return broadcasted; } @@ -217,8 +217,8 @@ static Type convertStructTypePacked(spirv::StructType type, /// Creates LLVM dialect constant with the given value. static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter, unsigned value) { - return rewriter.create<LLVM::ConstantOp>( - loc, IntegerType::get(rewriter.getContext(), 32), + return LLVM::ConstantOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), 32), rewriter.getIntegerAttr(rewriter.getI32Type(), value)); } @@ -322,8 +322,9 @@ public: auto llvmIndexType = getTypeConverter()->convertType(indexType); if (!llvmIndexType) return rewriter.notifyMatchFailure(op, "type conversion failed"); - Value zero = rewriter.create<LLVM::ConstantOp>( - op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); + Value zero = + LLVM::ConstantOp::create(rewriter, op.getLoc(), llvmIndexType, + rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); auto elementType = getTypeConverter()->convertType( @@ -375,20 +376,20 @@ public: // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); - Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType, - maskShiftedByCount, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value negated = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCount, minusOne); Value maskShiftedByCountAndOffset = - rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset); - Value mask = rewriter.create<LLVM::XOrOp>( - loc, dstType, maskShiftedByCountAndOffset, minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, negated, offset); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, + maskShiftedByCountAndOffset, minusOne); // Extract unchanged bits from the `Base` that are outside of // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. Value baseAndMask = - rewriter.create<LLVM::AndOp>(loc, dstType, op.getBase(), mask); + LLVM::AndOp::create(rewriter, loc, dstType, op.getBase(), mask); Value insertShiftedByOffset = - rewriter.create<LLVM::ShlOp>(loc, dstType, op.getInsert(), offset); + LLVM::ShlOp::create(rewriter, loc, dstType, op.getInsert(), offset); rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, insertShiftedByOffset); return success(); @@ -470,23 +471,23 @@ public: auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType)); Value size = isa<VectorType>(srcType) - ? rewriter.create<LLVM::ConstantOp>( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<ShapedType>(srcType), baseSize)) - : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize); + : LLVM::ConstantOp::create(rewriter, loc, dstType, baseSize); // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit // at Offset + Count - 1 is the most significant bit now. Value countPlusOffset = - rewriter.create<LLVM::AddOp>(loc, dstType, count, offset); + LLVM::AddOp::create(rewriter, loc, dstType, count, offset); Value amountToShiftLeft = - rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset); - Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>( - loc, dstType, op.getBase(), amountToShiftLeft); + LLVM::SubOp::create(rewriter, loc, dstType, size, countPlusOffset); + Value baseShiftedLeft = LLVM::ShlOp::create( + rewriter, loc, dstType, op.getBase(), amountToShiftLeft); // Shift the result right, filling the bits with the sign bit. Value amountToShiftRight = - rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft); + LLVM::AddOp::create(rewriter, loc, dstType, offset, amountToShiftLeft); rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft, amountToShiftRight); return success(); @@ -516,13 +517,13 @@ public: // Create a mask with bits set at [0, Count - 1]. Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); Value maskShiftedByCount = - rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count); - Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount, - minusOne); + LLVM::ShlOp::create(rewriter, loc, dstType, minusOne, count); + Value mask = LLVM::XOrOp::create(rewriter, loc, dstType, maskShiftedByCount, + minusOne); // Shift `Base` by `Offset` and apply the mask on it. Value shiftedBase = - rewriter.create<LLVM::LShrOp>(loc, dstType, op.getBase(), offset); + LLVM::LShrOp::create(rewriter, loc, dstType, op.getBase(), offset); rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask); return success(); } @@ -694,8 +695,8 @@ public: auto structType = LLVM::LLVMStructType::getLiteral(context, fields); // Create `llvm.mlir.global` with initializer region containing one block. - auto global = rewriter.create<LLVM::GlobalOp>( - UnknownLoc::get(context), structType, /*isConstant=*/true, + auto global = LLVM::GlobalOp::create( + rewriter, UnknownLoc::get(context), structType, /*isConstant=*/true, LLVM::Linkage::External, executionModeInfoName, Attribute(), /*alignment=*/0); Location loc = global.getLoc(); @@ -704,22 +705,23 @@ public: // Initialize the struct and set the execution mode value. rewriter.setInsertionPointToStart(block); - Value structValue = rewriter.create<LLVM::PoisonOp>(loc, structType); - Value executionMode = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, + Value structValue = LLVM::PoisonOp::create(rewriter, loc, structType); + Value executionMode = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getI32IntegerAttr( static_cast<uint32_t>(executionModeAttr.getValue()))); - structValue = rewriter.create<LLVM::InsertValueOp>(loc, structValue, - executionMode, 0); + SmallVector<int64_t> position{0}; + structValue = LLVM::InsertValueOp::create(rewriter, loc, structValue, + executionMode, position); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { auto attr = values.getValue()[i]; - Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr); - structValue = rewriter.create<LLVM::InsertValueOp>( - loc, structValue, entry, ArrayRef<int64_t>({1, i})); + Value entry = LLVM::ConstantOp::create(rewriter, loc, llvmI32Type, attr); + structValue = LLVM::InsertValueOp::create( + rewriter, loc, structValue, entry, ArrayRef<int64_t>({1, i})); } - rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue})); + LLVM::ReturnOp::create(rewriter, loc, ArrayRef<Value>({structValue})); rewriter.eraseOp(op); return success(); } @@ -913,7 +915,7 @@ public: Location loc = op.getLoc(); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); - Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.getOperand()); + Value sqrt = LLVM::SqrtOp::create(rewriter, loc, dstType, op.getOperand()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt); return success(); } @@ -973,10 +975,10 @@ public: IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter); auto mask = isa<VectorType>(srcType) - ? rewriter.create<LLVM::ConstantOp>( - loc, dstType, + ? LLVM::ConstantOp::create( + rewriter, loc, dstType, SplatElementsAttr::get(cast<VectorType>(srcType), minusOne)) - : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne); + : LLVM::ConstantOp::create(rewriter, loc, dstType, minusOne); rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType, notOp.getOperand(), mask); return success(); @@ -1034,8 +1036,8 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, return func; OpBuilder b(symbolTable->getRegion(0)); - func = b.create<LLVM::LLVMFuncOp>( - symbolTable->getLoc(), name, + func = LLVM::LLVMFuncOp::create( + b, symbolTable->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes)); func.setCConv(LLVM::cconv::CConv::SPIR_FUNC); func.setConvergent(convergent); @@ -1047,7 +1049,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable, static LLVM::CallOp createSPIRVBuiltinCall(Location loc, OpBuilder &builder, LLVM::LLVMFuncOp func, ValueRange args) { - auto call = builder.create<LLVM::CallOp>(loc, func, args); + auto call = LLVM::CallOp::create(builder, loc, func, args); call.setCConv(func.getCConv()); call.setConvergentAttr(func.getConvergentAttr()); call.setNoUnwindAttr(func.getNoUnwindAttr()); @@ -1078,12 +1080,12 @@ public: lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy); Location loc = controlBarrierOp->getLoc(); - Value execution = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); - Value memory = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); - Value semantics = rewriter.create<LLVM::ConstantOp>( - loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); + Value execution = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getExecutionScope())); + Value memory = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemoryScope())); + Value semantics = LLVM::ConstantOp::create( + rewriter, loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics())); auto call = createSPIRVBuiltinCall(loc, rewriter, func, {execution, memory, semantics}); @@ -1255,10 +1257,12 @@ public: lookupOrCreateSPIRVFn(symbolTable, funcName, paramTypes, retTy); Location loc = op.getLoc(); - Value scope = rewriter.create<LLVM::ConstantOp>( - loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope())); - Value groupOp = rewriter.create<LLVM::ConstantOp>( - loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation())); + Value scope = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast<int32_t>(adaptor.getExecutionScope())); + Value groupOp = LLVM::ConstantOp::create( + rewriter, loc, i32Ty, + static_cast<int32_t>(adaptor.getGroupOperation())); SmallVector<Value> operands{scope, groupOp}; operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); @@ -1368,7 +1372,7 @@ public: return failure(); Block *headerBlock = loopOp.getHeaderBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock); + LLVM::BrOp::create(rewriter, loc, brOp.getBlockArguments(), headerBlock); rewriter.eraseBlock(entryBlock); // Branch from merge block to end block. @@ -1376,7 +1380,7 @@ public: Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, endBlock); rewriter.inlineRegionBefore(loopOp.getBody(), endBlock); rewriter.replaceOp(loopOp, endBlock->getArguments()); @@ -1434,16 +1438,15 @@ public: Operation *terminator = mergeBlock->getTerminator(); ValueRange terminatorOperands = terminator->getOperands(); rewriter.setInsertionPointToEnd(mergeBlock); - rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock); + LLVM::BrOp::create(rewriter, loc, terminatorOperands, continueBlock); // Link current block to `true` and `false` blocks within the selection. Block *trueBlock = condBrOp.getTrueBlock(); Block *falseBlock = condBrOp.getFalseBlock(); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create<LLVM::CondBrOp>(loc, condBrOp.getCondition(), trueBlock, - condBrOp.getTrueTargetOperands(), - falseBlock, - condBrOp.getFalseTargetOperands()); + LLVM::CondBrOp::create(rewriter, loc, condBrOp.getCondition(), trueBlock, + condBrOp.getTrueTargetOperands(), falseBlock, + condBrOp.getFalseTargetOperands()); rewriter.eraseBlock(headerBlock); rewriter.inlineRegionBefore(op.getBody(), continueBlock); @@ -1490,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1502,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } @@ -1521,8 +1524,8 @@ public: return rewriter.notifyMatchFailure(tanOp, "type conversion failed"); Location loc = tanOp.getLoc(); - Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.getOperand()); - Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.getOperand()); + Value sin = LLVM::SinOp::create(rewriter, loc, dstType, tanOp.getOperand()); + Value cos = LLVM::CosOp::create(rewriter, loc, dstType, tanOp.getOperand()); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos); return success(); } @@ -1549,13 +1552,13 @@ public: Location loc = tanhOp.getLoc(); Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0); Value multiplied = - rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.getOperand()); - Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied); + LLVM::FMulOp::create(rewriter, loc, dstType, two, tanhOp.getOperand()); + Value exponential = LLVM::ExpOp::create(rewriter, loc, dstType, multiplied); Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0); Value numerator = - rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one); + LLVM::FSubOp::create(rewriter, loc, dstType, exponential, one); Value denominator = - rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one); + LLVM::FAddOp::create(rewriter, loc, dstType, exponential, one); rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator, denominator); return success(); @@ -1594,8 +1597,8 @@ public: if (!elementType) return rewriter.notifyMatchFailure(varOp, "type conversion failed"); Value allocated = - rewriter.create<LLVM::AllocaOp>(loc, dstType, elementType, size); - rewriter.create<LLVM::StoreOp>(loc, adaptor.getInitializer(), allocated); + LLVM::AllocaOp::create(rewriter, loc, dstType, elementType, size); + LLVM::StoreOp::create(rewriter, loc, adaptor.getInitializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } @@ -1656,7 +1659,7 @@ public: // Create a new `LLVMFuncOp` Location loc = funcOp.getLoc(); StringRef name = funcOp.getName(); - auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType); + auto newFuncOp = LLVM::LLVMFuncOp::create(rewriter, loc, name, llvmType); // Convert SPIR-V Function Control to equivalent LLVM function attribute MLIRContext *context = funcOp.getContext(); @@ -1710,7 +1713,7 @@ public: ConversionPatternRewriter &rewriter) const override { auto newModuleOp = - rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName()); + ModuleOp::create(rewriter, spvModuleOp.getLoc(), spvModuleOp.getName()); rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody()); // Remove the terminator block that was automatically added by builder @@ -1751,7 +1754,7 @@ public: auto componentsArray = components.getValue(); auto *context = rewriter.getContext(); auto llvmI32Type = IntegerType::get(context, 32); - Value targetOp = rewriter.create<LLVM::PoisonOp>(loc, dstType); + Value targetOp = LLVM::PoisonOp::create(rewriter, loc, dstType); for (unsigned i = 0; i < componentsArray.size(); i++) { if (!isa<IntegerAttr>(componentsArray[i])) return op.emitError("unable to support non-constant component"); @@ -1767,16 +1770,17 @@ public: baseVector = vector2; } - Value dstIndex = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i)); - Value index = rewriter.create<LLVM::ConstantOp>( - loc, llvmI32Type, + Value dstIndex = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, + rewriter.getIntegerAttr(rewriter.getI32Type(), i)); + Value index = LLVM::ConstantOp::create( + rewriter, loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal)); - auto extractOp = rewriter.create<LLVM::ExtractElementOp>( - loc, scalarType, baseVector, index); - targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp, - extractOp, dstIndex); + auto extractOp = LLVM::ExtractElementOp::create(rewriter, loc, scalarType, + baseVector, index); + targetOp = LLVM::InsertElementOp::create(rewriter, loc, dstType, targetOp, + extractOp, dstIndex); } rewriter.replaceOp(op, targetOp); return success(); |