diff options
author | Christian Sigg <csigg@google.com> | 2020-11-11 09:42:23 +0100 |
---|---|---|
committer | Christian Sigg <csigg@google.com> | 2020-11-11 10:27:20 +0100 |
commit | 5bdb21df21c6c78554a99754ef14da06a85f9910 (patch) | |
tree | 092a39beb550b4da5b02b4aef32751fa5477fa9d | |
parent | a20b04c984f142d07098ccbc6ea7f73776b20fc6 (diff) | |
download | llvm-5bdb21df21c6c78554a99754ef14da06a85f9910.zip llvm-5bdb21df21c6c78554a99754ef14da06a85f9910.tar.gz llvm-5bdb21df21c6c78554a99754ef14da06a85f9910.tar.bz2 |
[mlir] Use assemblyFormat in AllocLikeOp.
Split operands into dynamicSizes and symbolOperands.
Reviewed By: nicolasvasilache
Differential Revision: https://reviews.llvm.org/D90589
-rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/IR/Ops.h | 10 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/StandardOps/IR/Ops.td | 54 | ||||
-rw-r--r-- | mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 27 | ||||
-rw-r--r-- | mlir/lib/Dialect/StandardOps/IR/Ops.cpp | 115 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/Utils.cpp | 6 | ||||
-rw-r--r-- | mlir/test/IR/invalid-ops.mlir | 6 |
6 files changed, 81 insertions, 137 deletions
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h index d878692..902220b 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -313,16 +313,6 @@ llvm::Optional<SmallVector<bool, 4>> computeRankReductionMask(ArrayRef<int64_t> originalShape, ArrayRef<int64_t> reducedShape); -/// Prints dimension and symbol list. -void printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, unsigned numDims, - OpAsmPrinter &p); - -/// Parses dimension and symbol list and returns true if parsing failed. -ParseResult parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl<Value> &operands, - unsigned &numDims); - /// Determines whether MemRefCastOp casts to a more dynamic version of the /// source memref. This is useful to to fold a memref_cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index d15f06b..652efa7 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -165,20 +165,38 @@ class ComplexFloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> : class AllocLikeOp<string mnemonic, Resource resource, list<OpTrait> traits = []> : - Std_Op<mnemonic, !listconcat([MemoryEffects<[MemAlloc<resource>]>], traits)> { - - let arguments = (ins Variadic<Index>:$value, + Std_Op<mnemonic, + !listconcat([ + MemoryEffects<[MemAlloc<resource>]>, + AttrSizedOperandSegments + ], traits)> { + + let arguments = (ins Variadic<Index>:$dynamicSizes, + // The symbolic operands (the ones in square brackets) bind + // to the symbols of the memref's layout map. + Variadic<Index>:$symbolOperands, Confined<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$alignment); - let results = (outs Res<AnyMemRef, "", [MemAlloc<resource>]>); + let results = (outs Res<AnyMemRef, "", [MemAlloc<resource>]>:$memref); let builders = [ - OpBuilderDAG<(ins "MemRefType":$memrefType), [{ - $_state.types.push_back(memrefType); + OpBuilderDAG<(ins "MemRefType":$memrefType, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, {}, alignment); }]>, - OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$operands, - CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ - $_state.addOperands(operands); + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, dynamicSizes, {}, alignment); + }]>, + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + "ValueRange":$symbolOperands, + CArg<"IntegerAttr", "{}">:$alignment), [{ $_state.types.push_back(memrefType); + $_state.addOperands(dynamicSizes); + $_state.addOperands(symbolOperands); + $_state.addAttribute(getOperandSegmentSizeAttr(), + $_builder.getI32VectorAttr({ + static_cast<int32_t>(dynamicSizes.size()), + static_cast<int32_t>(symbolOperands.size())})); if (alignment) $_state.addAttribute(getAlignmentAttrName(), alignment); }]>]; @@ -188,23 +206,13 @@ class AllocLikeOp<string mnemonic, MemRefType getType() { return getResult().getType().cast<MemRefType>(); } - /// Returns the number of symbolic operands (the ones in square brackets), - /// which bind to the symbols of the memref's layout map. - unsigned getNumSymbolicOperands() { - return getNumOperands() - getType().getNumDynamicDims(); - } - - /// Returns the symbolic operands (the ones in square brackets), which bind - /// to the symbols of the memref's layout map. - operand_range getSymbolicOperands() { - return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; - } - /// Returns the dynamic sizes for this alloc operation if specified. - operand_range getDynamicSizes() { return getOperands(); } + operand_range getDynamicSizes() { return dynamicSizes(); } }]; - let parser = [{ return ::parseAllocLikeOp(parser, result); }]; + let assemblyFormat = [{ + `(`$dynamicSizes`)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref) + }]; let hasCanonicalizer = 1; } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 4cb6821..4c2196f 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -297,6 +297,33 @@ static bool isValidAffineIndexOperand(Value value, Region *region) { return isValidDim(value, region) || isValidSymbol(value, region); } +/// Prints dimension and symbol list. +static void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter &printer) { + OperandRange operands(begin, end); + printer << '(' << operands.take_front(numDims) << ')'; + if (operands.size() > numDims) + printer << '[' << operands.drop_front(numDims) << ']'; +} + +/// Parses dimension and symbol list and returns true if parsing failed. +static ParseResult parseDimAndSymbolList(OpAsmParser &parser, + SmallVectorImpl<Value> &operands, + unsigned &numDims) { + SmallVector<OpAsmParser::OperandType, 8> opInfos; + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + return failure(); + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto indexTy = parser.getBuilder().getIndexType(); + return failure(parser.parseOperandList( + opInfos, OpAsmParser::Delimiter::OptionalSquare) || + parser.resolveOperands(opInfos, indexTy, operands)); +} + /// Utility function to verify that a set of operands are valid dimension and /// symbol identifiers. The operands should be laid out such that the dimension /// operands are before the symbol operands. This function returns failure if diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index d333ddc..bc584bd 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -169,36 +169,6 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, return builder.create<ConstantOp>(loc, type, value); } -void mlir::printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, - unsigned numDims, OpAsmPrinter &p) { - Operation::operand_range operands(begin, end); - p << '(' << operands.take_front(numDims) << ')'; - if (operands.size() != numDims) - p << '[' << operands.drop_front(numDims) << ']'; -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl<Value> &operands, - unsigned &numDims) { - SmallVector<OpAsmParser::OperandType, 8> opInfos; - if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) - return failure(); - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. - auto indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(opInfos, - OpAsmParser::Delimiter::OptionalSquare) || - parser.resolveOperands(opInfos, indexTy, operands)) - return failure(); - return success(); -} - /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. @@ -404,90 +374,37 @@ static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { //===----------------------------------------------------------------------===// template <typename AllocLikeOp> -static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) { - static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, - "applies to only alloc or alloca"); - p << name; - - // Print dynamic dimension operands. - MemRefType type = op.getType(); - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - type.getNumDynamicDims(), p); - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); - p << " : " << type; -} - -static void print(OpAsmPrinter &p, AllocOp op) { - printAllocLikeOp(p, op, "alloc"); -} - -static void print(OpAsmPrinter &p, AllocaOp op) { - printAllocLikeOp(p, op, "alloca"); -} - -static ParseResult parseAllocLikeOp(OpAsmParser &parser, - OperationState &result) { - MemRefType type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type.getNumDynamicDims()) - return parser.emitError(parser.getNameLoc()) - << "dimension operand count does not equal memref dynamic dimension " - "count"; - result.types.push_back(type); - return success(); -} - -template <typename AllocLikeOp> -static LogicalResult verify(AllocLikeOp op) { +static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value, "applies to only alloc or alloca"); auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>(); if (!memRefType) return op.emitOpError("result must be a memref"); - unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) { - // Store number of symbols used in affine map (used in subsequent check). - AffineMap affineMap = memRefType.getAffineMaps()[0]; - numSymbols = affineMap.getNumSymbols(); - } + if (static_cast<int64_t>(op.dynamicSizes().size()) != + memRefType.getNumDynamicDims()) + return op.emitOpError("dimension operand count does not equal memref " + "dynamic dimension count"); - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. - unsigned numDynamicDims = memRefType.getNumDynamicDims(); - if (op.getNumOperands() != numDynamicDims + numSymbols) + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) + numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); + if (op.symbolOperands().size() != numSymbols) return op.emitOpError( - "operand count does not equal dimension plus symbol operand count"); + "symbol operand count does not equal memref symbol count"); - // Verify that all operands are of type Index. - for (auto operandType : op.getOperandTypes()) - if (!operandType.isIndex()) - return op.emitOpError("requires operands to be of type Index"); + return success(); +} - if (std::is_same<AllocLikeOp, AllocOp>::value) - return success(); +static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } +static LogicalResult verify(AllocaOp op) { // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op.template getParentWithTrait<OpTrait::AutomaticAllocationScope>()) + if (!op.getParentWithTrait<OpTrait::AutomaticAllocationScope>()) return op.emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); - return success(); + return verifyAllocLikeOp(op); } namespace { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 516f8c0..cef0a827 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -401,7 +401,7 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { // Fetch a new memref type after normalizing the old memref to have an // identity map layout. MemRefType newMemRefType = - normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands()); + normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size()); if (newMemRefType == memrefType) // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map. @@ -409,9 +409,9 @@ LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { Value oldMemRef = allocOp.getResult(); - SmallVector<Value, 4> symbolOperands(allocOp.getSymbolicOperands()); + SmallVector<Value, 4> symbolOperands(allocOp.symbolOperands()); AllocOp newAlloc = b.create<AllocOp>(allocOp.getLoc(), newMemRefType, - llvm::None, allocOp.alignmentAttr()); + allocOp.alignmentAttr()); AffineMap layoutMap = memrefType.getAffineMaps().front(); // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 76aff5c..eb24774 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -87,7 +87,8 @@ func @bad_alloc_wrong_dynamic_dim_count() { ^bb0: %0 = constant 7 : index // Test alloc with wrong number of dynamic dimensions. - %1 = alloc(%0)[%1] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{op 'std.alloc' dimension operand count does not equal memref dynamic dimension count}} + // expected-error@+1 {{dimension operand count does not equal memref dynamic dimension count}} + %1 = alloc(%0)[%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } @@ -97,7 +98,8 @@ func @bad_alloc_wrong_symbol_count() { ^bb0: %0 = constant 7 : index // Test alloc with wrong number of symbols - %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{operand count does not equal dimension plus symbol operand count}} + // expected-error@+1 {{symbol operand count does not equal memref symbol count}} + %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } |