diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 274aaca..de582f0 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -66,6 +66,7 @@ endif () # Directories with MLIR dialects add_subdirectory(include/graphalg) +add_subdirectory(include/garel) add_subdirectory(src) add_subdirectory(test) diff --git a/compiler/include/garel/CMakeLists.txt b/compiler/include/garel/CMakeLists.txt new file mode 100644 index 0000000..f221bc4 --- /dev/null +++ b/compiler/include/garel/CMakeLists.txt @@ -0,0 +1,15 @@ +include_directories(SYSTEM ${MLIR_INCLUDE_DIRS}) +include_directories(SYSTEM ${PROJECT_BINARY_DIR}/include) + +set(LLVM_TARGET_DEFINITIONS GARelAttr.td) +mlir_tablegen(GARelEnumAttr.h.inc -gen-enum-decls) +mlir_tablegen(GARelEnumAttr.cpp.inc -gen-enum-defs) +mlir_tablegen(GARelAttr.h.inc --gen-attrdef-decls) +mlir_tablegen(GARelAttr.cpp.inc -gen-attrdef-defs) + +set(LLVM_TARGET_DEFINITIONS GARelOps.td) +add_mlir_dialect(GARelOps garel) + +set(LLVM_TARGET_DEFINITIONS GARelPasses.td) +mlir_tablegen(GARelPasses.h.inc --gen-pass-decls) +add_public_tablegen_target(MLIRGARelPassesIncGen) diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h new file mode 100644 index 0000000..7a5f505 --- /dev/null +++ b/compiler/include/garel/GARelAttr.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include + +#include "garel/GARelEnumAttr.h.inc" + +namespace garel { + +/** Reference to a column inside of \c RelationType or \c TupleType. 
*/ +using ColumnIdx = std::int32_t; + +} // namespace garel + +#define GET_ATTRDEF_CLASSES +#include "garel/GARelAttr.h.inc" diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td new file mode 100644 index 0000000..c09797c --- /dev/null +++ b/compiler/include/garel/GARelAttr.td @@ -0,0 +1,72 @@ +#ifndef GAREL_ATTR +#define GAREL_ATTR + +include "mlir/IR/BuiltinAttributeInterfaces.td" + +include "GARelDialect.td" + +class GARel_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { + let summary = "A binary equality join predicate"; + + let parameters = (ins + "std::int32_t":$lhsRelIdx, + "ColumnIdx":$lhsColIdx, + "std::int32_t":$rhsRelIdx, + "ColumnIdx":$rhsColIdx); + + let assemblyFormat = [{ + `<` $lhsRelIdx `[` $lhsColIdx `]` `=` $rhsRelIdx `[` $rhsColIdx `]` `>` + }]; +} + +def JoinPredicates : ArrayOfAttr< + GARel_Dialect, + "JoinPredicates", + "join_preds", + "JoinPredicateAttr">; + +def AggregateFunc : I64EnumAttr< + "AggregateFunc", "", + [ + I64EnumAttrCase<"SUM", 0>, + I64EnumAttrCase<"MIN", 1>, + I64EnumAttrCase<"MAX", 2>, + I64EnumAttrCase<"LOR", 3>, /* Logical OR (over i1) */ + I64EnumAttrCase<"ARGMIN", 4>, + ] +> { + let cppNamespace = "::garel"; +} + +// NOTE: assumes an aggregator produces exactly one output column. +def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { + let summary = "Aggregate function with bound input columns"; + + let parameters = (ins + "AggregateFunc":$func, + ArrayRefParameter<"ColumnIdx">:$inputs); + + let assemblyFormat = [{ + `<` $func $inputs `>` + }]; + + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + /** Type for values in the output column. 
*/ + mlir::Type getResultType(::mlir::Type relType); + }]; +} + +def Aggregators : ArrayOfAttr< + GARel_Dialect, + "Aggregators", + "aggregators", + "AggregatorAttr">; + +#endif // GAREL_ATTR diff --git a/compiler/include/garel/GARelDialect.h b/compiler/include/garel/GARelDialect.h new file mode 100644 index 0000000..0149f92 --- /dev/null +++ b/compiler/include/garel/GARelDialect.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +#include "garel/GARelOpsDialect.h.inc" diff --git a/compiler/include/garel/GARelDialect.td b/compiler/include/garel/GARelDialect.td new file mode 100644 index 0000000..725499d --- /dev/null +++ b/compiler/include/garel/GARelDialect.td @@ -0,0 +1,27 @@ +#ifndef GAREL_DIALECT +#define GAREL_DIALECT + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" + +def GARel_Dialect : Dialect { + let name = "garel"; + let cppNamespace = "::garel"; + + let extraClassDeclaration = [{ + private: + void registerAttributes(); + void registerTypes(); + }]; + + let usePropertiesForAttributes = 1; + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; +} + +// NOTE: GARel ops are always 'Pure' +class GARel_Op traits = []> : + Op; + +#endif // GAREL_DIALECT diff --git a/compiler/include/garel/GARelOps.h b/compiler/include/garel/GARelOps.h new file mode 100644 index 0000000..5c158fe --- /dev/null +++ b/compiler/include/garel/GARelOps.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelTypes.h" + +#define GET_OP_CLASSES +#include "garel/GARelOps.h.inc" diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td new file mode 100644 index 0000000..f8ccaa3 --- /dev/null +++ b/compiler/include/garel/GARelOps.td @@ -0,0 +1,224 @@ +#ifndef GAREL_OPS +#define GAREL_OPS + +/** + * Relation-level ops in the GARel dialect. 
+ * + * See ./GARelTupleOps.td for the tuple-level ops. + */ + +include "mlir/Interfaces/InferTypeOpInterface.td" + +include "GARelDialect.td" +include "GARelTypes.td" + +// Per-tuple ops kept in a separate file. +include "GARelTupleOps.td" + +def ProjectOp : GARel_Op<"project", [IsolatedFromAbove]> { + let summary = "Remaps, reorders, drops and computes columns"; + + let arguments = (ins Relation:$input); + + let regions = (region SizedRegion<1>:$projections); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) `->` type($result) $projections attr-dict + }]; + + let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Block& createProjectionsBlock(); + ProjectReturnOp getTerminator(); + }]; +} + +def ProjectReturnOp : GARel_Op<"project.return", [ + Terminator, + HasParent<"ProjectOp">]> { + let summary = "The output projections"; + + let arguments = (ins Variadic:$projections); + + let assemblyFormat = [{ + $projections `:` type($projections) attr-dict + }]; + + // NOTE: verification performed by ProjectOp +} + +def SelectOp : GARel_Op<"select", [ + SameOperandsAndResultType, + IsolatedFromAbove]> { + let summary = "Removes tuples that fail (one of) the predicates"; + + let arguments = (ins Relation:$input); + let regions = (region SizedRegion<1>:$predicates); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) $predicates attr-dict + }]; + + let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + mlir::Block& createPredicatesBlock(); + SelectReturnOp getTerminator(); + }]; +} + +def SelectReturnOp : GARel_Op<"select.return", [ + Terminator, + HasParent<"SelectOp">]> { + let summary = "Return the select predicates"; + + let arguments = (ins Variadic:$predicates); + + let assemblyFormat = [{ + $predicates attr-dict + }]; +} + +def JoinOp : GARel_Op<"join", [InferTypeOpAdaptor]> { + let summary = "Natural (equi)join of relations"; + + let arguments = (ins + 
// NOTE: All inputs must have distinct columns + Variadic:$inputs, + // NOTE: Equality predicates only + JoinPredicates:$predicates); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $inputs `:` type($inputs) + $predicates + attr-dict + }]; + + let hasVerifier = 1; + let hasFolder = 1; +} + +def UnionOp : GARel_Op<"union", [SameOperandsAndResultType]> { + let summary = "Union of relations"; + + let arguments = (ins Variadic:$inputs); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $inputs `:` type($inputs) + attr-dict + }]; + + let hasFolder = 1; +} + +def AggregateOp : GARel_Op<"aggregate", [InferTypeOpAdaptor]> { + let summary = "Groups tuples by key columns, aggregating values of other columns"; + + let arguments = (ins + Relation:$input, + DenseI32ArrayAttr:$groupBy, + Aggregators:$aggregators); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) + `group_by` `` `=` `` $groupBy + `aggregators` `` `=` `` $aggregators + attr-dict + }]; +} + +def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { + let summary = "Bounded iteration"; + + let arguments = (ins + Variadic:$init, + I64Attr:$iters, + I64Attr:$resultIdx); + + let regions = (region + SizedRegion<1>:$body, + MaxSizedRegion<1>:$until); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $init `:` type($init) + `iters` `` `=` `` $iters + `result_idx` `` `=` `` $resultIdx + $body + (`until` $until^)? + attr-dict + }]; + + let hasVerifier = 1; + let hasRegionVerifier = 1; +} + +def ForYieldOp : GARel_Op<"for.yield", [ + Terminator, + HasParent<"ForOp">]> { + let summary = "Produces the iter args for the next iteration"; + + let arguments = (ins Variadic:$inputs); + + let assemblyFormat = [{ + $inputs `:` type($inputs) + attr-dict + }]; + + // Note: verification performed by parent ForOp. 
+} + +def RangeOp : GARel_Op<"range", [InferTypeOpAdaptor]> { + let summary = "generates a range of `index` values from `[0, size)`"; + + let arguments = (ins I64Attr:$size); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $size attr-dict + }]; +} + +def RemapOp : GARel_Op<"remap", [InferTypeOpAdaptor]> { + let summary = "reorders or drop columns"; + + let arguments = (ins Relation:$input, DenseI32ArrayAttr:$remap); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) + $remap + attr-dict + }]; + + let hasVerifier = 1; + let hasFolder = 1; +} + +def ConstantOp : GARel_Op<"const", [InferTypeOpAdaptor]> { + let summary = "A relation with one constant-value tuple"; + + let arguments = (ins TypedAttrInterface:$value); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $value attr-dict + }]; +} + +#endif // GAREL_OPS diff --git a/compiler/include/garel/GARelPasses.h b/compiler/include/garel/GARelPasses.h new file mode 100644 index 0000000..e432976 --- /dev/null +++ b/compiler/include/garel/GARelPasses.h @@ -0,0 +1,11 @@ +#pragma once + +namespace garel { + +#define GEN_PASS_DECL +#include "garel/GARelPasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "garel/GARelPasses.h.inc" + +} // namespace garel diff --git a/compiler/include/garel/GARelPasses.td b/compiler/include/garel/GARelPasses.td new file mode 100644 index 0000000..e378828 --- /dev/null +++ b/compiler/include/garel/GARelPasses.td @@ -0,0 +1,15 @@ +#ifndef GAREL_PASSES +#define GAREL_PASSES + +include "mlir/Pass/PassBase.td" + +def GraphAlgToRel : Pass<"graphalg-to-rel", "::mlir::ModuleOp"> { + let summary = "Convert GraphAlg Core IR to relational ops"; + + let dependentDialects = [ + "garel::GARelDialect", + "graphalg::GraphAlgDialect", + ]; +} + +#endif // GAREL_PASSES diff --git a/compiler/include/garel/GARelTupleOps.td b/compiler/include/garel/GARelTupleOps.td new file mode 100644 index 0000000..0244442 --- 
/dev/null +++ b/compiler/include/garel/GARelTupleOps.td @@ -0,0 +1,28 @@ +#ifndef GAREL_TUPLE_OPS +#define GAREL_TUPLE_OPS +/** + * Tuple-level ops in the GARel dialect. + * + * See ./GARelOps.td for the relation-level ops. + */ + +include "mlir/Interfaces/InferTypeOpInterface.td" + +include "GARelDialect.td" +include "GARelTypes.td" + +def ExtractOp : GARel_Op<"extract", [InferTypeOpAdaptor]> { + let summary = "Extract the value of one column from a tuple"; + + let arguments = (ins I32Attr:$column, Tuple:$tuple); + + let results = (outs ColumnType:$result); + + let assemblyFormat = [{ + $column $tuple `:` type($tuple) attr-dict + }]; + + let hasVerifier = 1; +} + +#endif // GAREL_TUPLE_OPS diff --git a/compiler/include/garel/GARelTypes.h b/compiler/include/garel/GARelTypes.h new file mode 100644 index 0000000..e6a4c90 --- /dev/null +++ b/compiler/include/garel/GARelTypes.h @@ -0,0 +1,14 @@ +#pragma once + +#include + +#include "garel/GARelAttr.h" + +#define GET_TYPEDEF_CLASSES +#include "garel/GARelOpsTypes.h.inc" + +namespace garel { + +bool isColumnType(mlir::Type t); + +} // namespace garel diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td new file mode 100644 index 0000000..5d848b5 --- /dev/null +++ b/compiler/include/garel/GARelTypes.td @@ -0,0 +1,37 @@ + +#ifndef GAREL_TYPES +#define GAREL_TYPES + +include "mlir/IR/AttrTypeBase.td" + +include "GARelDialect.td" +include "GARelAttr.td" + +class GARel_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def Relation : GARel_Type<"Relation", "rel"> { + let summary = "A set of tuples"; + + let parameters = (ins OptionalArrayRefParameter<"mlir::Type">:$columns); + + let assemblyFormat = [{ + `<` $columns `>` + }]; +} + +def Tuple : GARel_Type<"Tuple", "tuple"> { + let summary = "A single tuple"; + + let parameters = (ins OptionalArrayRefParameter<"mlir::Type">:$columns); + + let assemblyFormat = [{ + `<` $columns `>` + }]; +} + +def ColumnType : Type, "column 
type">; + +#endif // GAREL_TYPES diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index de9518a..8217f22 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(graphalg) +add_subdirectory(garel) diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt new file mode 100644 index 0000000..6c8622e --- /dev/null +++ b/compiler/src/garel/CMakeLists.txt @@ -0,0 +1,37 @@ +add_library(GARelIR + GARelAttr.cpp + GARelDialect.cpp + GARelOps.cpp + GARelTupleOps.cpp + GARelTypes.cpp +) +target_include_directories(GARelIR PUBLIC ../../include) +target_include_directories(GARelIR SYSTEM PUBLIC ${PROJECT_BINARY_DIR}/include) +add_dependencies(GARelIR + MLIRGARelOpsIncGen + MLIRGARelPassesIncGen +) +# Suppress -Wdangling-assignment-gsl for generated code from MLIR tablegen +# The warning is triggered by template instantiations in GARelOps.cpp.inc +target_compile_options(GARelIR PRIVATE -Wno-dangling-assignment-gsl) +if(NOT GRAPHALG_ENABLE_RTTI) + target_compile_options(GARelIR PUBLIC -fno-rtti) +endif() +target_link_libraries( + GARelIR + PRIVATE + MLIRInferTypeOpInterface + MLIRIR + MLIRSupport +) + +add_library(GraphAlgToRel + GraphAlgToRel.cpp +) +target_link_libraries( + GraphAlgToRel + PRIVATE + GraphAlgIR + GARelIR + MLIRPass +) diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp new file mode 100644 index 0000000..abb69cc --- /dev/null +++ b/compiler/src/garel/GARelAttr.cpp @@ -0,0 +1,58 @@ +#include +#include +#include +#include +#include +#include + +#include "garel/GARelAttr.h" +#include "garel/GARelDialect.h" + +#include "garel/GARelEnumAttr.cpp.inc" +#include "garel/GARelTypes.h" +#define GET_ATTRDEF_CLASSES +#include "garel/GARelAttr.cpp.inc" + +namespace garel { + +mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) { + switch (getFunc()) { + case AggregateFunc::SUM: + case AggregateFunc::MIN: + case AggregateFunc::MAX: + case 
AggregateFunc::LOR: + case AggregateFunc::ARGMIN: + // NOTE: argmin(arg, val) also uses first input column as output type. + return llvm::cast(inputRel).getColumns()[getInputs()[0]]; + } +} + +mlir::LogicalResult +AggregatorAttr::verify(llvm::function_ref emitError, + AggregateFunc func, llvm::ArrayRef inputs) { + if (func == AggregateFunc::ARGMIN) { + if (inputs.size() != 2) { + return emitError() << stringifyAggregateFunc(func) + << " expects exactly two inputs (arg, val), got " + << inputs.size(); + } + } else { + if (inputs.size() != 1) { + return emitError() << stringifyAggregateFunc(func) + << " expects exactly one input, got " << inputs.size(); + } + } + + return mlir::success(); +} + +// Need to define this here to avoid depending on GARelAttr in +// GARelDialect and creating a cycle. +void GARelDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "garel/GARelAttr.cpp.inc" + >(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelDialect.cpp b/compiler/src/garel/GARelDialect.cpp new file mode 100644 index 0000000..7d2fae3 --- /dev/null +++ b/compiler/src/garel/GARelDialect.cpp @@ -0,0 +1,17 @@ +#include "garel/GARelDialect.h" +#include "garel/GARelOps.h" + +#include "garel/GARelOpsDialect.cpp.inc" + +namespace garel { + +void GARelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "garel/GARelOps.cpp.inc" + >(); + registerAttributes(); + registerTypes(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp new file mode 100644 index 0000000..b70abcd --- /dev/null +++ b/compiler/src/garel/GARelOps.cpp @@ -0,0 +1,345 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelAttr.h" +#include "garel/GARelDialect.h" +#include "garel/GARelOps.h" +#include "garel/GARelTypes.h" +#include "llvm/ADT/ArrayRef.h" + +#define GET_OP_CLASSES +#include "garel/GARelOps.cpp.inc" + +namespace garel { + +// === ProjectOp 
=== +mlir::LogicalResult ProjectOp::verifyRegions() { + if (getProjections().getNumArguments() != 1) { + return emitOpError("projections block should have exactly one argument"); + } + + auto blockArg = getProjections().getArgument(0); + auto blockType = llvm::dyn_cast(blockArg.getType()); + if (!blockType) { + return emitOpError("projections block arg must be of type tuple"); + } + + if (getInput().getType().getColumns() != blockType.getColumns()) { + return emitOpError("projections block columns do not match input columns"); + } + + auto terminator = getProjections().front().getTerminator(); + if (!terminator) { + return emitOpError("missing return from projections block"); + } + + auto returnOp = llvm::dyn_cast(terminator); + if (!returnOp) { + return emitOpError("projections block not terminated by project.return"); + } + + if (returnOp.getProjections().size() != getType().getColumns().size()) { + return emitOpError("projections block returns a different number of " + "values than specified in the projection return type"); + } + + for (const auto &[val, col] : + llvm::zip_equal(returnOp.getProjections(), getType().getColumns())) { + if (val.getType() != col) { + return emitOpError("projections block return types do not match the " + "projection output column types"); + } + } + + return mlir::success(); +} + +mlir::Block &ProjectOp::createProjectionsBlock() { + assert(getProjections().empty() && "Already have a projections block"); + auto &block = getProjections().emplaceBlock(); + // Same columns as the input, but as a tuple. 
+ block.addArgument( + TupleType::get(getContext(), getInput().getType().getColumns()), + getInput().getLoc()); + return block; +} + +ProjectReturnOp ProjectOp::getTerminator() { + return llvm::cast(getProjections().front().getTerminator()); +} + +// === SelectOp === +mlir::LogicalResult SelectOp::verifyRegions() { + if (getPredicates().getNumArguments() != 1) { + return emitOpError("predicates block should have exactly one argument"); + } + + auto blockArg = getPredicates().getArgument(0); + auto blockType = llvm::dyn_cast(blockArg.getType()); + if (!blockType) { + return emitOpError("predicates block arg must be of type tuple"); + } + + if (getInput().getType().getColumns() != blockType.getColumns()) { + return emitOpError("predicates block slots do not match child slots"); + } + + auto terminator = getPredicates().front().getTerminator(); + if (!terminator || !llvm::isa(terminator)) { + return emitOpError("predicates block not terminated with select.return"); + } + + return mlir::success(); +} + +mlir::Block &SelectOp::createPredicatesBlock() { + assert(getPredicates().empty() && "Already have a predicates block"); + auto &block = getPredicates().emplaceBlock(); + // Same columns as the input, but as a tuple. + block.addArgument( + TupleType::get(getContext(), getInput().getType().getColumns()), + getInput().getLoc()); + return block; +} + +SelectReturnOp SelectOp::getTerminator() { + return llvm::cast(getPredicates().front().getTerminator()); +} + +// === JoinOp === +mlir::OpFoldResult JoinOp::fold(FoldAdaptor adaptor) { + if (getInputs().size() == 1) { + // no-op if we only have one input. 
+ assert(getInputs()[0].getType() == getType()); + return getInputs()[0]; + } + + return nullptr; +} + +mlir::LogicalResult JoinOp::verify() { + for (auto pred : getPredicates()) { + // Valid input relation + if (pred.getLhsRelIdx() >= getInputs().size()) { + return emitOpError("predicate refers to input relation ") + << pred.getLhsRelIdx() << ", but there are only " + << getInputs().size() << " input relations: " << pred; + } else if (pred.getRhsRelIdx() >= getInputs().size()) { + return emitOpError("predicate refers to input relation ") + << pred.getRhsRelIdx() << ", but there are only " + << getInputs().size() << " input relations: " << pred; + } + + if (pred.getLhsRelIdx() == pred.getRhsRelIdx()) { + return emitOpError("predicate between columns of the same relation: ") + << pred; + } + + // Valid column on LHS relation. + auto lhsInputType = + llvm::cast(getInputs()[pred.getLhsRelIdx()].getType()); + if (pred.getLhsColIdx() >= lhsInputType.getColumns().size()) { + auto diag = emitOpError("predicate refers to column ") + << pred.getLhsColIdx() << ", but there are only " + << lhsInputType.getColumns().size() + << " input columns: " << pred; + diag.attachNote(getInputs()[pred.getLhsRelIdx()].getLoc()) + << "input relation defined here"; + return diag; + } + + // Valid column on RHS relation. 
+ auto rhsInputType = + llvm::cast(getInputs()[pred.getRhsRelIdx()].getType()); + if (pred.getRhsColIdx() >= rhsInputType.getColumns().size()) { + auto diag = emitOpError("predicate refers to column ") + << pred.getRhsColIdx() << ", but there are only " + << rhsInputType.getColumns().size() + << " input columns: " << pred; + diag.attachNote(getInputs()[pred.getRhsRelIdx()].getLoc()) + << "input relation defined here"; + return diag; + } + } + + return mlir::success(); +} + +mlir::LogicalResult JoinOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + for (auto input : adaptor.getInputs()) { + auto inputColumns = llvm::cast(input.getType()).getColumns(); + outputColumns.append(inputColumns.begin(), inputColumns.end()); + } + + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +// === UnionOp === +mlir::OpFoldResult UnionOp::fold(FoldAdaptor adaptor) { + if (getInputs().size() == 1) { + // no-op if we only have one input. 
+ assert(getInputs()[0].getType() == getType()); + return getInputs()[0]; + } + + return nullptr; +} + +// === AggregateOp === +mlir::LogicalResult AggregateOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + + auto inputType = llvm::cast(adaptor.getInput().getType()); + auto inputColumns = inputType.getColumns(); + + // Key columns + for (auto key : adaptor.getGroupBy()) { + outputColumns.push_back(inputColumns[key]); + } + + // Aggregator outputs + for (auto agg : adaptor.getAggregators()) { + outputColumns.push_back(agg.getResultType(inputType)); + } + + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +// === ForOp === +static mlir::LogicalResult +verifyResultIdx(llvm::function_ref emitError, + mlir::ValueRange initArgs, std::uint64_t resultIdx) { + // resultIdx is within bounds of init args. + if (initArgs.size() <= resultIdx) { + return emitError() << "has result_idx=" << resultIdx + << ", but there are only " << initArgs.size() + << " init args"; + } + + return mlir::success(); +} + +mlir::LogicalResult ForOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + auto loc = location ? 
*location : mlir::UnknownLoc::get(ctx); + if (mlir::failed(verifyResultIdx( + [&]() { + return mlir::emitError(loc) + << ForOp::getOperationName() << " to build with init args " + << adaptor.getInit() << " "; + }, + adaptor.getInit(), adaptor.getResultIdx()))) { + return mlir::failure(); + } + + auto resultType = adaptor.getInit()[adaptor.getResultIdx()].getType(); + inferredReturnTypes.emplace_back(resultType); + return mlir::success(); +} + +mlir::LogicalResult ForOp::verify() { + return verifyResultIdx([this]() { return emitOpError(); }, getInit(), + getResultIdx()); +} + +mlir::LogicalResult ForOp::verifyRegions() { + auto initTypes = getInit().getTypes(); + + // Body arg types match init args + auto argTypes = getBody().front().getArgumentTypes(); + if (initTypes != argTypes) { + return emitOpError("body arg types do not match the initial value types"); + } + + // Body result types match init args + auto yieldOp = llvm::cast(getBody().front().getTerminator()); + auto resTypes = yieldOp.getInputs().getTypes(); + if (initTypes != resTypes) { + auto diag = + emitOpError("body result types do not match the initial value types"); + diag.attachNote(yieldOp.getLoc()) << "body result is here"; + return diag; + } + + return mlir::success(); +} + +// === RangeOp === +mlir::LogicalResult RangeOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back( + RelationType::get(ctx, {mlir::IndexType::get(ctx)})); + return mlir::success(); +} + +// === RemapOp === +mlir::LogicalResult RemapOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + auto inputType = llvm::cast(adaptor.getInput().getType()); + for (auto inputCol : adaptor.getRemap()) { + outputColumns.push_back(inputType.getColumns()[inputCol]); + } + + 
inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +mlir::LogicalResult RemapOp::verify() { + auto inputColumns = getInput().getType().getColumns(); + for (auto inputCol : getRemap()) { + if (inputCol >= inputColumns.size()) { + return emitOpError("remap refers to input column ") + << inputCol << ", but input only has " << inputColumns.size() + << " columns"; + } + } + + return mlir::success(); +} + +// Checks for mapping [0, 1, 2, ...] +static bool isIdentityRemap(llvm::ArrayRef indexes) { + for (auto [i, idx] : llvm::enumerate(indexes)) { + if (i != idx) { + return false; + } + } + + return true; +} + +mlir::OpFoldResult RemapOp::fold(FoldAdaptor adaptor) { + if (isIdentityRemap(getRemap())) { + assert(getInput().getType() == getType()); + return getInput(); + } + + return nullptr; +} + +// === ConstantOp === +mlir::LogicalResult ConstantOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns{adaptor.getValue().getType()}; + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelTupleOps.cpp b/compiler/src/garel/GARelTupleOps.cpp new file mode 100644 index 0000000..a78f269 --- /dev/null +++ b/compiler/src/garel/GARelTupleOps.cpp @@ -0,0 +1,28 @@ +#include + +#include "garel/GARelOps.h" +#include "garel/GARelTypes.h" + +namespace garel { + +mlir::LogicalResult ExtractOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + // TODO: Do we need this cast? 
+ auto tupleType = llvm::cast(adaptor.getTuple().getType()); + auto columnTypes = tupleType.getColumns(); + inferredReturnTypes.push_back(columnTypes[adaptor.getColumn()]); + return mlir::success(); +} + +mlir::LogicalResult ExtractOp::verify() { + auto columns = getTuple().getType().getColumns(); + if (getColumn() >= getTuple().getType().getColumns().size()) { + return emitOpError("column ") + << getColumn() << " not included in tuple " << getTuple().getType(); + } + + return mlir::success(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp new file mode 100644 index 0000000..618642d --- /dev/null +++ b/compiler/src/garel/GARelTypes.cpp @@ -0,0 +1,30 @@ +#include +#include +#include +#include +#include + +#include "garel/GARelDialect.h" +#include "garel/GARelTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "garel/GARelOpsTypes.cpp.inc" + +namespace garel { + +bool isColumnType(mlir::Type t) { + // Allow i1, i64, f64, index + return t.isSignlessInteger(1) || t.isSignlessInteger(64) || t.isF64() || + t.isIndex(); +} + +// Need to define this here to avoid depending on IPRTypes in +// IPRDialect and creating a cycle. 
+void GARelDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "garel/GARelOpsTypes.cpp.inc" + >(); +} + +} // namespace garel diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp new file mode 100644 index 0000000..ee60821 --- /dev/null +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -0,0 +1,1325 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelAttr.h" +#include "garel/GARelDialect.h" +#include "garel/GARelOps.h" +#include "garel/GARelTypes.h" +#include "graphalg/GraphAlgAttr.h" +#include "graphalg/GraphAlgCast.h" +#include "graphalg/GraphAlgDialect.h" +#include "graphalg/GraphAlgOps.h" +#include "graphalg/GraphAlgTypes.h" +#include "graphalg/SemiringTypes.h" + +namespace garel { + +#define GEN_PASS_DEF_GRAPHALGTOREL +#include "garel/GARelPasses.h.inc" + +namespace { + +/** + * Converts all GraphAlg IR ops into relation ops from the GARel dialect. + * + * Matrices, vectors and scalars are converted into relations: + * - Matrices => (row, column, value) tuples + * - Vectors => (row, value) tuples + * - Scalars => a single (value) tuple. For consistency, they are still + * relations. + * + * Top-level operations are converted into relational ops such as \c ProjectOp, + * \c JoinOp and \c AggregateOp. + * + * Scalar operations inside of \c ApplyOp are converted into ops from the arith + * dialect. + */ +class GraphAlgToRel : public impl::GraphAlgToRelBase { +public: + using impl::GraphAlgToRelBase::GraphAlgToRelBase; + + void runOnOperation() final; +}; + +/** Converts semiring types into their relational equivalents. 
*/ +class SemiringTypeConverter : public mlir::TypeConverter { +private: + static mlir::Type convertSemiringType(graphalg::SemiringTypeInterface type); + +public: + SemiringTypeConverter(); +}; + +/** Converts matrix types into relations. */ +class MatrixTypeConverter : public mlir::TypeConverter { +private: + SemiringTypeConverter _semiringConverter; + + mlir::FunctionType convertFunctionType(mlir::FunctionType type) const; + RelationType convertMatrixType(graphalg::MatrixType type) const; + +public: + MatrixTypeConverter(mlir::MLIRContext *ctx, + const SemiringTypeConverter &semiringConverter); +}; + +/** + * Convenient wrapper around a matrix value and its relation equivalent + * after type conversion. + * + * This class is particularly useful for retrieving the relation column for + * the rows, columns or values of the matrix. + */ +class MatrixAdaptor { +private: + mlir::TypedValue _matrix; + + RelationType _relType; + // May be null for outputs, in which case only the relation type is available. + mlir::TypedValue _relation; + +public: + // For output matrices, where we only have the desired output type. + MatrixAdaptor(mlir::Value matrix, mlir::Type relType) + : _matrix(llvm::cast>(matrix)), + _relType(llvm::cast(relType)) {} + + // For input matrices, where the OpAdaptor provides the relation value. 
+ MatrixAdaptor(mlir::Value matrix, mlir::Value relation) + : MatrixAdaptor(matrix, relation.getType()) { + this->_relation = llvm::cast>(relation); + } + + graphalg::MatrixType matrixType() { return _matrix.getType(); } + + RelationType relType() { return _relType; } + + mlir::TypedValue relation() { + assert(!!_relation && "No relation value (only type)"); + return _relation; + } + + auto columns() const { return _relType.getColumns(); } + + bool isScalar() const { return _matrix.getType().isScalar(); } + + bool hasRowColumn() const { return !_matrix.getType().getRows().isOne(); } + + bool hasColColumn() const { return !_matrix.getType().getCols().isOne(); } + + ColumnIdx rowColumn() const { + assert(hasRowColumn()); + return 0; + } + + ColumnIdx colColumn() const { + assert(hasColColumn()); + // Follow row column, if there is one. + return hasRowColumn() ? 1 : 0; + } + + ColumnIdx valColumn() const { + // Last column in the relation. + return columns().size() - 1; + } + + graphalg::SemiringTypeInterface semiring() { + return llvm::cast( + _matrix.getType().getSemiring()); + } +}; + +template class OpConversion : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(T op, + typename mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; +}; + +class ApplyOpConversion : public mlir::OpConversionPattern { +private: + const SemiringTypeConverter &_bodyArgConverter; + + mlir::LogicalResult + matchAndRewrite(graphalg::ApplyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; + +public: + ApplyOpConversion(const SemiringTypeConverter &bodyArgConverter, + const MatrixTypeConverter &typeConverter, + mlir::MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx), + _bodyArgConverter(bodyArgConverter) {} +}; + +struct InputColumnRef { + std::size_t relIdx; + ColumnIdx colIdx; + ColumnIdx outIdx; +}; 
+ +} // namespace + +// ============================================================================= +// =============================== Class Methods =============================== +// ============================================================================= + +mlir::Type +SemiringTypeConverter::convertSemiringType(graphalg::SemiringTypeInterface t) { + auto *ctx = t.getContext(); + // To i1 + if (t == graphalg::SemiringTypes::forBool(ctx)) { + return mlir::IntegerType::get(ctx, 1); + } + + // To i64 + if (t == graphalg::SemiringTypes::forInt(ctx) || + t == graphalg::SemiringTypes::forTropInt(ctx) || + t == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + return mlir::IntegerType::get(ctx, 64); + } + + // To f64 + if (t == graphalg::SemiringTypes::forReal(ctx) || + t == graphalg::SemiringTypes::forTropReal(ctx)) { + return mlir::Float64Type::get(ctx); + } + + return nullptr; +} + +SemiringTypeConverter::SemiringTypeConverter() { + addConversion(convertSemiringType); +} + +mlir::FunctionType +MatrixTypeConverter::convertFunctionType(mlir::FunctionType type) const { + llvm::SmallVector inputs; + if (mlir::failed(convertTypes(type.getInputs(), inputs))) { + return {}; + } + + llvm::SmallVector results; + if (mlir::failed(convertTypes(type.getResults(), results))) { + return {}; + } + + return mlir::FunctionType::get(type.getContext(), inputs, results); +} + +RelationType +MatrixTypeConverter::convertMatrixType(graphalg::MatrixType type) const { + llvm::SmallVector columns; + auto *ctx = type.getContext(); + if (!type.getRows().isOne()) { + columns.push_back(mlir::IndexType::get(ctx)); + } + + if (!type.getCols().isOne()) { + columns.push_back(mlir::IndexType::get(ctx)); + } + + auto valueType = _semiringConverter.convertType(type.getSemiring()); + if (!valueType) { + return {}; + } + + columns.push_back(valueType); + return RelationType::get(ctx, columns); +} + +MatrixTypeConverter::MatrixTypeConverter( + mlir::MLIRContext *ctx, const SemiringTypeConverter 
&semiringConverter) + : _semiringConverter(semiringConverter) { + addConversion( + [this](mlir::FunctionType t) { return convertFunctionType(t); }); + + addConversion( + [this](graphalg::MatrixType t) { return convertMatrixType(t); }); +} + +// ============================================================================= +// ============================== Helper Methods =============================== +// ============================================================================= + +/** + * Create a relation with all indices for a matrix dimension. + * + * Used to broadcast scalar values to a larger matrix. + */ +static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + return builder.create(loc, dim.getConcreteDim()); +} + +static void +buildApplyJoinPredicates(mlir::MLIRContext *ctx, + llvm::SmallVectorImpl &predicates, + llvm::ArrayRef columnsToJoin) { + if (columnsToJoin.size() < 2) { + return; + } + + auto first = columnsToJoin.front(); + for (auto other : columnsToJoin.drop_front()) { + predicates.push_back(JoinPredicateAttr::get(ctx, first.relIdx, first.colIdx, + other.relIdx, other.colIdx)); + } +} + +static mlir::FailureOr convertConstant(mlir::Operation *op, + mlir::TypedAttr attr) { + auto *ctx = attr.getContext(); + auto type = attr.getType(); + if (type == graphalg::SemiringTypes::forBool(ctx)) { + return attr; + } else if (type == graphalg::SemiringTypes::forInt(ctx)) { + return attr; + } else if (type == graphalg::SemiringTypes::forReal(ctx)) { + return attr; + } else if (type == graphalg::SemiringTypes::forTropInt(ctx)) { + std::int64_t value; + if (llvm::isa(attr)) { + // Positive infinity, kind of. 
+ value = std::numeric_limits::max(); + } else { + auto intAttr = llvm::cast(attr); + value = intAttr.getValue().getValue().getSExtValue(); + } + + return mlir::TypedAttr( + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), value)); + } else if (type == graphalg::SemiringTypes::forTropReal(ctx)) { + double value; + if (llvm::isa(attr)) { + // Has a proper positive infinity value + value = std::numeric_limits::infinity(); + } else { + auto floatAttr = llvm::cast(attr); + value = floatAttr.getValue().getValueAsDouble(); + } + + return mlir::TypedAttr( + mlir::FloatAttr::get(mlir::Float64Type::get(ctx), value)); + } else if (type == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + std::int64_t value; + if (llvm::isa(attr)) { + // Negative infinity, kind of. + value = std::numeric_limits::min(); + } else { + auto intAttr = llvm::cast(attr); + value = intAttr.getValue().getValue().getSExtValue(); + } + + return mlir::TypedAttr( + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), value)); + } + + return op->emitOpError("cannot convert constant ") << attr; +} + +static bool isTropicalnessCast(graphalg::SemiringTypeInterface inRing, + graphalg::SemiringTypeInterface outRing) { + assert(inRing != outRing && "No-op cast"); + // If the relational types match, it is purely a 'tropicalness' cast such as + // i64 -> !graphalg.trop_i64. + SemiringTypeConverter conv; + return conv.convertType(inRing) == conv.convertType(outRing); +} + +static mlir::Value preserveAdditiveIdentity(graphalg::CastScalarOp op, + mlir::Value input, + mlir::Value defaultOutput, + mlir::OpBuilder &builder) { + // Return defaultOutput, except when input is the additive identity, which + // we need to remap to the additive identity of the target type. 
+ auto inRing = + llvm::cast(op.getInput().getType()); + auto outRing = llvm::cast(op.getType()); + + auto inIdent = convertConstant(op, inRing.addIdentity()); + assert(mlir::succeeded(inIdent)); + auto outIdent = convertConstant(op, outRing.addIdentity()); + assert(mlir::succeeded(outIdent)); + + auto inIdentOp = + builder.create(input.getLoc(), *inIdent); + auto outIdentOp = + builder.create(input.getLoc(), *outIdent); + + // Compare input == inIdent + mlir::Value identCompare; + if (input.getType().isF64()) { + assert(inIdentOp.getType().isF64()); + identCompare = builder.create( + op.getLoc(), mlir::arith::CmpFPredicate::OEQ, input, inIdentOp); + } else { + assert(input.getType().isSignlessInteger(64)); + assert(inIdentOp.getType().isSignlessInteger(64)); + identCompare = builder.create( + op.getLoc(), mlir::arith::CmpIPredicate::eq, input, inIdentOp); + } + + return builder.create(op.getLoc(), identCompare, + outIdentOp, defaultOutput); +} + +static mlir::FailureOr +createAggregator(mlir::Operation *op, graphalg::SemiringTypeInterface sring, + ColumnIdx input, mlir::OpBuilder &builder) { + auto *ctx = builder.getContext(); + AggregateFunc func; + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + func = AggregateFunc::LOR; + } else if (sring.isIntOrFloat()) { + func = AggregateFunc::SUM; + } else if (llvm::isa(sring)) { + func = AggregateFunc::MIN; + } else if (llvm::isa(sring)) { + func = AggregateFunc::MAX; + } else { + return op->emitOpError("aggregation with semiring ") + << sring << " is not supported"; + } + + std::array inputs{input}; + return AggregatorAttr::get(ctx, func, inputs); +} + +static mlir::IntegerAttr tryGetConstantInt(mlir::Value v) { + mlir::Attribute attr; + if (!mlir::matchPattern(v, mlir::m_Constant(&attr))) { + return nullptr; + } + + return llvm::cast(attr); +} + +static mlir::FailureOr +createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, + mlir::Value lhs, mlir::Value rhs, mlir::OpBuilder &builder) { + auto 
*ctx = builder.getContext(); + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forInt(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forReal(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forTropInt(ctx) || + sring == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forTropReal(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } + + return op->emitOpError("multiplication with semiring ") + << sring << " is not supported"; +} + +// ============================================================================= +// =============================== Op Conversion =============================== +// ============================================================================= + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + mlir::func::FuncOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto funcType = llvm::cast_if_present( + typeConverter->convertType(op.getFunctionType())); + if (!funcType) { + return op->emitOpError("function type ") + << op.getFunctionType() << " cannot be converted"; + } + + rewriter.modifyOpInPlace(op, [&]() { + // Update function type. + op.setFunctionType(funcType); + }); + + // Convert block args. 
+ mlir::TypeConverter::SignatureConversion newSig(funcType.getNumInputs()); + if (mlir::failed( + rewriter.convertRegionTypes(&op.getFunctionBody(), *typeConverter))) { + return mlir::failure(); + } + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + mlir::func::ReturnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::TransposeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto projectOp = rewriter.replaceOpWithNewOp(op, output.relType(), + input.relation()); + + auto &body = projectOp.createProjectionsBlock(); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector columns(input.columns().size()); + std::iota(columns.begin(), columns.end(), 0); + assert(columns.size() <= 3); + // Transpose is a no-op if there are fewer than 3 columns. 
+ if (columns.size() == 3) { + // Swap row and column + std::swap(columns[0], columns[1]); + } + + // Return the input columns (after row and column have been swapped) + llvm::SmallVector results; + for (auto col : columns) { + results.push_back( + rewriter.create(op.getLoc(), col, body.getArgument(0))); + } + + rewriter.create(op.getLoc(), results); + return mlir::success(); +} + +static constexpr llvm::StringLiteral APPLY_ROW_IDX_ATTR_KEY = + "garel.apply.row_idx"; +static constexpr llvm::StringLiteral APPLY_COL_IDX_ATTR_KEY = + "garel.apply.col_idx"; + +mlir::LogicalResult ApplyOpConversion::matchAndRewrite( + graphalg::ApplyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector inputs; + for (auto [matrix, relation] : + llvm::zip_equal(op.getInputs(), adaptor.getInputs())) { + auto &input = inputs.emplace_back(matrix, relation); + } + + llvm::SmallVector joinChildren; + llvm::SmallVector rowColumns; + llvm::SmallVector colColumns; + llvm::SmallVector valColumns; + ColumnIdx nextColumnIdx = 0; + for (const auto &[idx, input] : llvm::enumerate(inputs)) { + joinChildren.push_back(input.relation()); + + if (input.hasRowColumn()) { + rowColumns.push_back(InputColumnRef{ + .relIdx = idx, + .colIdx = input.rowColumn(), + .outIdx = nextColumnIdx + input.rowColumn(), + }); + } + + if (input.hasColColumn()) { + colColumns.push_back(InputColumnRef{ + .relIdx = idx, + .colIdx = input.colColumn(), + .outIdx = nextColumnIdx + input.colColumn(), + }); + } + + valColumns.push_back(nextColumnIdx + input.valColumn()); + nextColumnIdx += input.columns().size(); + } + + auto outputType = typeConverter->convertType(op.getType()); + MatrixAdaptor output(op.getResult(), outputType); + if (rowColumns.empty() && output.hasRowColumn()) { + // None of the inputs have a row column, but we need it in the output. + // Broadcast to all rows. 
+ auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + joinChildren.push_back(rowsOp); + rowColumns.push_back(InputColumnRef{ + .relIdx = joinChildren.size() - 1, + .colIdx = 0, + .outIdx = nextColumnIdx++, + }); + } + + if (colColumns.empty() && output.hasColColumn()) { + // None of the inputs have a col column, but we need it in the output. + // Broadcast to all columns. + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + joinChildren.push_back(colsOp); + colColumns.push_back(InputColumnRef{ + .relIdx = joinChildren.size() - 1, + .colIdx = 0, + .outIdx = nextColumnIdx++, + }); + } + + mlir::Value joined; + if (joinChildren.size() == 1) { + joined = joinChildren.front(); + } else { + llvm::SmallVector predicates; + buildApplyJoinPredicates(rewriter.getContext(), predicates, rowColumns); + buildApplyJoinPredicates(rewriter.getContext(), predicates, colColumns); + joined = rewriter.create(op.getLoc(), joinChildren, predicates); + } + + auto projectOp = rewriter.create(op->getLoc(), outputType, joined); + + // Convert old body + if (mlir::failed( + rewriter.convertRegionTypes(&op.getBody(), _bodyArgConverter))) { + return op->emitOpError("failed to convert body argument types"); + } + + // Read value columns, to be used as arg replacements for the old body. + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto &body = projectOp.createProjectionsBlock(); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector columnReads; + for (auto col : valColumns) { + columnReads.push_back( + rewriter.create(op->getLoc(), col, body.getArgument(0))); + } + + // Inline into new body + rewriter.inlineBlockBefore(&op.getBody().front(), &body, body.end(), + columnReads); + + rewriter.replaceOp(op, projectOp); + + // Attach the row and column indexes to the return op. 
+ auto returnOp = llvm::cast(body.getTerminator()); + if (!rowColumns.empty()) { + returnOp->setAttr(APPLY_ROW_IDX_ATTR_KEY, + rewriter.getI32IntegerAttr(rowColumns[0].outIdx)); + } + + if (!colColumns.empty()) { + returnOp->setAttr(APPLY_COL_IDX_ATTR_KEY, + rewriter.getI32IntegerAttr(colColumns[0].outIdx)); + } + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ApplyReturnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector results; + + // Note: conversion is done top-down, so the ApplyOp is converted to + // ProjectOp before we reach this op in its body. + auto inputTuple = op->getBlock()->getArgument(0); + + if (auto idx = op->getAttrOfType(APPLY_ROW_IDX_ATTR_KEY)) { + results.push_back( + rewriter.create(op->getLoc(), idx, inputTuple)); + } + + if (auto idx = op->getAttrOfType(APPLY_COL_IDX_ATTR_KEY)) { + results.push_back( + rewriter.create(op->getLoc(), idx, inputTuple)); + } + + // The value column + results.push_back(adaptor.getValue()); + + rewriter.replaceOpWithNewOp(op, results); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::BroadcastOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + llvm::SmallVector joinChildren; + ColumnIdx currentColIdx = 0; + + std::optional rowColumnIdx; + std::optional colColumnIdx; + + if (input.hasRowColumn()) { + // Already have a row column. + rowColumnIdx = input.rowColumn(); + } else if (output.hasRowColumn()) { + // Broadcast over all rows. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + rowColumnIdx = currentColIdx++; + } + + if (input.hasColColumn()) { + // Already have a col column. 
+ colColumnIdx = input.colColumn(); + } else if (output.hasColColumn()) { + // Broadcast over all columns. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + colColumnIdx = currentColIdx++; + } + + joinChildren.push_back(input.relation()); + auto valColumnIdx = currentColIdx + input.valColumn(); + + auto joinOp = + rewriter.create(op.getLoc(), joinChildren, + // on join predicates (cartesian product) + llvm::ArrayRef{}); + + // Remap to correctly order as (row, col, val). + llvm::SmallVector outputColumns; + if (rowColumnIdx) { + outputColumns.push_back(*rowColumnIdx); + } + + if (colColumnIdx) { + outputColumns.push_back(*colColumnIdx); + } + + outputColumns.push_back(valColumnIdx); + + // NOTE: folds if the remapping is unncessary. + auto remapped = + rewriter.createOrFold(op.getLoc(), joinOp, outputColumns); + rewriter.replaceOp(op, remapped); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ConstantMatrixOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto constantValue = convertConstant(op, op.getValue()); + if (mlir::failed(constantValue)) { + return mlir::failure(); + } + + auto constantOp = rewriter.create(op.getLoc(), *constantValue); + + // Broadcast to rows/columns if needed. + llvm::SmallVector joinChildren; + if (!output.matrixType().getRows().isOne()) { + // Broadcast over all rows. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + } + + if (!output.matrixType().getCols().isOne()) { + // Broadcast over all columns. 
+ joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + } + + joinChildren.push_back(constantOp); + auto joinOp = rewriter.createOrFold( + op.getLoc(), joinChildren, llvm::ArrayRef{}); + + rewriter.replaceOp(op, joinOp); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::DeferredReduceOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInputs()[0], adaptor.getInputs()[0]); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + // Group by keys + llvm::SmallVector groupBy; + if (output.hasRowColumn()) { + groupBy.push_back(input.rowColumn()); + } + + if (output.hasColColumn()) { + groupBy.push_back(input.colColumn()); + } + + // Aggregators + auto aggregator = + createAggregator(op, input.semiring(), input.valColumn(), rewriter); + if (mlir::failed(aggregator)) { + return mlir::failure(); + } + + std::array aggregators{*aggregator}; + + // union the inputs and then aggregate. 
+ auto unionOp = + rewriter.createOrFold(op.getLoc(), adaptor.getInputs()); + rewriter.replaceOpWithNewOp(op, unionOp, groupBy, aggregators); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::DiagOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + std::array mapping{0, 0, 1}; + rewriter.replaceOpWithNewOp(op, adaptor.getInput(), mapping); + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForConstOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto rangeBegin = tryGetConstantInt(op.getRangeBegin()); + auto rangeEnd = tryGetConstantInt(op.getRangeEnd()); + if (!rangeBegin || !rangeEnd) { + return op->emitOpError("iter range is not constant"); + } + + auto iters = rangeEnd.getInt() - rangeBegin.getInt(); + + llvm::SmallVector initArgs{adaptor.getRangeBegin()}; + initArgs.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); + + auto blockSignature = + typeConverter->convertBlockSignature(&op.getBody().front()); + if (!blockSignature) { + return op->emitOpError("Failed to convert iter args"); + } + + // The relational version of this op can only have a single output value. + // For loops with multiple results, duplicate. + llvm::SmallVector resultValues; + for (auto i : llvm::seq(op->getNumResults())) { + auto result = op->getResult(i); + if (result.use_empty()) { + // Not used. Take init arg as a dummy value. + resultValues.push_back(adaptor.getInitArgs()[i]); + continue; + } + + // We are adding the iteration count variable as a first argument, so offset + // the result index accordingly. 
+ std::int64_t resultIdx = i + 1; + auto resultType = adaptor.getInitArgs()[i].getType(); + auto forOp = rewriter.create(op.getLoc(), resultType, initArgs, + iters, resultIdx); + // body block + rewriter.cloneRegionBefore(op.getBody(), forOp.getBody(), + forOp.getBody().begin()); + rewriter.applySignatureConversion(&forOp.getBody().front(), + *blockSignature); + + // until block + if (!op.getUntil().empty()) { + rewriter.cloneRegionBefore(op.getUntil(), forOp.getUntil(), + forOp.getUntil().begin()); + rewriter.applySignatureConversion(&forOp.getUntil().front(), + *blockSignature); + } + + resultValues.push_back(forOp); + } + + rewriter.replaceOp(op, resultValues); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::YieldOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector inputs; + + auto *block = op->getBlock(); + auto forOp = llvm::cast(block->getParentOp()); + if (block == &forOp.getBody().front()) { + // Main body + auto iterVar = op->getBlock()->getArgument(0); + // Increment the iteration counter using a garel.project op. + auto loc = forOp.getLoc(); + auto projOp = rewriter.create(loc, iterVar.getType(), iterVar); + auto &projBlock = projOp.createProjectionsBlock(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&projBlock); + auto iterOp = rewriter.create(loc, 0, projBlock.getArgument(0)); + auto oneOp = rewriter.create( + loc, 1, rewriter.getI64Type()); + auto addOp = rewriter.create(loc, iterOp, oneOp); + rewriter.create(loc, mlir::ValueRange{addOp}); + + inputs.push_back(projOp); + } else { + // No changes needed for 'until' block. 
+ } + + inputs.append(adaptor.getInputs().begin(), adaptor.getInputs().end()); + + rewriter.replaceOpWithNewOp(op, inputs); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::MatMulJoinOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor lhs(op.getLhs(), adaptor.getLhs()); + MatrixAdaptor rhs(op.getRhs(), adaptor.getRhs()); + MatrixAdaptor result(op, typeConverter->convertType(op.getType())); + + // Join matrices. + llvm::SmallVector predicates; + if (lhs.hasColColumn() && rhs.hasRowColumn()) { + predicates.push_back(rewriter.getAttr( + /*lhsRelIdx=*/0, lhs.colColumn(), /*rhsRelIdx=*/1, rhs.rowColumn())); + } + + auto joinOp = rewriter.create( + op->getLoc(), mlir::ValueRange{lhs.relation(), rhs.relation()}, + predicates); + + // Project the multiplied values. + auto projOp = + rewriter.create(op.getLoc(), result.relType(), joinOp); + { + auto &body = projOp.createProjectionsBlock(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector projections; + + if (lhs.hasRowColumn()) { + projections.push_back(rewriter.create( + op.getLoc(), lhs.rowColumn(), body.getArgument(0))); + } + + if (rhs.hasColColumn()) { + // In the join output, rhs columns come after lhs columns. + auto colIdx = lhs.columns().size() + rhs.colColumn(); + projections.push_back( + rewriter.create(op.getLoc(), colIdx, body.getArgument(0))); + } + + // Get the value columns. 
+ auto lhsVal = rewriter.create(op.getLoc(), lhs.valColumn(), + body.getArgument(0)); + auto rhsVal = rewriter.create( + op.getLoc(), lhs.columns().size() + rhs.valColumn(), + body.getArgument(0)); + + // Perform the multiplication + auto mulOp = createMul( + op, + llvm::cast(op.getType().getSemiring()), + lhsVal, rhsVal, rewriter); + if (mlir::failed(mulOp)) { + return mlir::failure(); + } + + projections.push_back(*mulOp); + rewriter.create(op.getLoc(), projections); + } + + rewriter.replaceOp(op, projOp); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::PickAnyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto *ctx = rewriter.getContext(); + + // Remove rows where value is the additive identity. + auto selectOp = rewriter.create(op.getLoc(), input.relation()); + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto &body = selectOp.createPredicatesBlock(); + rewriter.setInsertionPointToStart(&body); + + auto valOp = rewriter.create(op.getLoc(), input.valColumn(), + body.getArgument(0)); + auto addIdent = convertConstant(op, input.semiring().addIdentity()); + if (mlir::failed(addIdent)) { + return mlir::failure(); + } + + auto addIdentOp = + rewriter.create(op.getLoc(), *addIdent); + mlir::Value cmpOp; + if (addIdentOp.getType().isF64()) { + cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpFPredicate::ONE, valOp, addIdentOp); + } else { + cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpIPredicate::ne, valOp, addIdentOp); + } + + rewriter.create(op.getLoc(), mlir::ValueRange{cmpOp}); + } + + llvm::SmallVector groupBy; + if (input.hasRowColumn()) { + groupBy.push_back(input.rowColumn()); + } + + assert(input.hasColColumn()); + std::array aggregators{ + // Minimum column + rewriter.getAttr( + AggregateFunc::MIN, 
std::array{input.colColumn()}), + // Value for minimum column + rewriter.getAttr( + AggregateFunc::ARGMIN, + std::array{input.valColumn(), input.colColumn()}), + }; + rewriter.replaceOpWithNewOp(op, selectOp, groupBy, aggregators); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::TrilOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + + if (!input.hasRowColumn() || !input.hasColColumn()) { + return op->emitOpError( + "only works on full matrices (not scalars or vector)"); + } + + auto selectOp = rewriter.replaceOpWithNewOp(op, input.relation()); + + auto &body = selectOp.createPredicatesBlock(); + rewriter.setInsertionPointToStart(&body); + + auto row = rewriter.create(op.getLoc(), input.rowColumn(), + body.getArgument(0)); + auto col = rewriter.create(op.getLoc(), input.colColumn(), + body.getArgument(0)); + // col < row + auto cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpIPredicate::ult, col, row); + rewriter.create(op.getLoc(), mlir::ValueRange{cmpOp}); + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::UnionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto targetType = + llvm::cast(typeConverter->convertType(op.getType())); + MatrixAdaptor target(op, targetType); + + llvm::SmallVector inputs; + for (auto [matrix, rel] : + llvm::zip_equal(op.getInputs(), adaptor.getInputs())) { + MatrixAdaptor input(matrix, rel); + + // Drop columns we don't want in the output. 
+ llvm::SmallVector remap; + if (target.hasRowColumn()) { + remap.push_back(input.rowColumn()); + } + + if (target.hasColColumn()) { + remap.push_back(input.colColumn()); + } + + remap.push_back(input.valColumn()); + inputs.push_back( + rewriter.createOrFold(op.getLoc(), input.relation(), remap)); + } + + auto newOp = rewriter.createOrFold(op.getLoc(), targetType, inputs); + rewriter.replaceOp(op, newOp); + + return mlir::success(); +} + +// ============================================================================= +// ============================ Tuple Op Conversion ============================ +// ============================================================================= + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ConstantOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto value = convertConstant(op, op.getValue()); + if (mlir::failed(value)) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp(op, *value); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::AddOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto sring = op.getType(); + auto *ctx = rewriter.getContext(); + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forReal(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forTropInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forTropReal(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == 
graphalg::SemiringTypes::forTropMaxInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else { + return op->emitOpError("conversion not supported for semiring ") << sring; + } + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::CastScalarOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto *ctx = rewriter.getContext(); + auto inRing = + llvm::cast(op.getInput().getType()); + auto outRing = llvm::cast(op.getType()); + assert(inRing != outRing && "Identity cast not removed by fold()"); + + if (outRing == graphalg::SemiringTypes::forBool(ctx)) { + // Rewrite to: input != zero(inRing) + auto addIdent = convertConstant(op, inRing.addIdentity()); + if (mlir::failed(addIdent)) { + return mlir::failure(); + } + + auto addIdentOp = + rewriter.create(op.getLoc(), *addIdent); + + if (addIdentOp.getType().isF64()) { + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpFPredicate::ONE, adaptor.getInput(), addIdentOp); + } else { + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpIPredicate::ne, adaptor.getInput(), addIdentOp); + } + + return mlir::success(); + } else if (inRing == graphalg::SemiringTypes::forBool(ctx)) { + // Mapping: + // true -> multiplicative identity + // false -> additive identity + auto trueValue = convertConstant(op, outRing.mulIdentity()); + if (mlir::failed(trueValue)) { + return mlir::failure(); + } + + auto falseValue = convertConstant(op, outRing.addIdentity()); + if (mlir::failed(falseValue)) { + return mlir::failure(); + } + + auto trueOp = + rewriter.create(op.getLoc(), *trueValue); + auto falseOp = + rewriter.create(op.getLoc(), *falseValue); + rewriter.replaceOpWithNewOp(op, adaptor.getInput(), + trueOp, falseOp); + return mlir::success(); + } else if (inRing == graphalg::SemiringTypes::forInt(ctx) && + outRing == graphalg::SemiringTypes::forReal(ctx)) { + // Promote to i64 -> f64 + rewriter.replaceOpWithNewOp(op, 
+                                                       outRing,
+                                                       adaptor.getInput());
+    return mlir::success();
+  } else if (inRing == graphalg::SemiringTypes::forReal(ctx) &&
+             outRing == graphalg::SemiringTypes::forInt(ctx)) {
+    // Truncate to int
+    rewriter.replaceOpWithNewOp<mlir::arith::FPToSIOp>(op, outRing,
+                                                       adaptor.getInput());
+    return mlir::success();
+  } else if (isTropicalnessCast(inRing, outRing)) {
+    // Only cast the 'tropicalness' of the type. The underlying relational type
+    // does not change. Preserve the value unless it is the additive identity,
+    // in which case we remap it to the additive identity of the output ring.
+    auto selectOp = preserveAdditiveIdentity(op, adaptor.getInput(),
+                                             adaptor.getInput(), rewriter);
+    rewriter.replaceOp(op, selectOp);
+    return mlir::success();
+  } else if (inRing == graphalg::SemiringTypes::forTropInt(ctx) &&
+             outRing == graphalg::SemiringTypes::forTropReal(ctx)) {
+    // trop_i64 to trop_f64
+    // Cast the underlying relational type, but preserve the additive identity.
+    auto castOp = rewriter.create<mlir::arith::SIToFPOp>(
+        op.getLoc(), rewriter.getF64Type(), adaptor.getInput());
+    auto selectOp =
+        preserveAdditiveIdentity(op, adaptor.getInput(), castOp, rewriter);
+    rewriter.replaceOp(op, selectOp);
+    return mlir::success();
+  } else if (inRing == graphalg::SemiringTypes::forTropReal(ctx) &&
+             outRing == graphalg::SemiringTypes::forTropInt(ctx)) {
+    // trop_f64 to trop_i64
+    // Cast the underlying relational type, but preserve the additive identity.
+    auto castOp = rewriter.create<mlir::arith::FPToSIOp>(
+        op.getLoc(), rewriter.getI64Type(), adaptor.getInput());
+    auto selectOp =
+        preserveAdditiveIdentity(op, adaptor.getInput(), castOp, rewriter);
+    rewriter.replaceOp(op, selectOp);
+    return mlir::success();
+  }
+
+  return op->emitOpError("cast from ")
+         << op.getInput().getType() << " to " << op.getType()
+         << " not yet supported in " << GARelDialect::getDialectNamespace()
+         << " dialect";
+}
+
+template <>
+mlir::LogicalResult OpConversion<graphalg::EqOp>::matchAndRewrite(
+    graphalg::EqOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto lhs = adaptor.getLhs();
+  auto rhs = adaptor.getRhs();
+  if (lhs.getType().isF64()) {
+    assert(rhs.getType().isF64());
+    rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(
+        op, mlir::arith::CmpFPredicate::OEQ, lhs, rhs);
+  } else {
+    assert(lhs.getType().isSignlessInteger());
+    assert(rhs.getType().isSignlessInteger());
+    rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(
+        op, mlir::arith::CmpIPredicate::eq, lhs, rhs);
+  }
+
+  return mlir::success();
+}
+
+template <>
+mlir::LogicalResult OpConversion<graphalg::MulOp>::matchAndRewrite(
+    graphalg::MulOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto sring = op.getType();
+  auto *ctx = rewriter.getContext();
+  auto mulOp = createMul(op, llvm::cast(sring),
+                         adaptor.getLhs(), adaptor.getRhs(), rewriter);
+  if (mlir::failed(mulOp)) {
+    return mlir::failure();
+  }
+
+  rewriter.replaceOp(op, *mulOp);
+  return mlir::success();
+}
+
+static bool hasRelationSignature(mlir::func::FuncOp op) {
+  // All inputs should be relations
+  auto funcType = op.getFunctionType();
+  for (auto input : funcType.getInputs()) {
+    if (!llvm::isa<RelationType>(input)) {
+      return false;
+    }
+  }
+
+  // There should be exactly one relation result
+  return funcType.getNumResults() == 1 &&
+         llvm::isa<RelationType>(funcType.getResult(0));
+}
+
+static bool hasRelationOperands(mlir::Operation *op) {
+  return llvm::all_of(op->getOperandTypes(),
+                      [](auto t) { return llvm::isa<RelationType>(t); });
+}
+
+void
GraphAlgToRel::runOnOperation() { + mlir::ConversionTarget target(getContext()); + // Eliminate all graphalg ops + target.addIllegalDialect(); + // Turn them into relational ops. + target.addLegalDialect(); + // and arith ops for the scalar operations. + target.addLegalDialect(); + // Keep container module. + target.addLegalOp(); + // Keep functions, but change their signature. + target.addDynamicallyLegalOp(hasRelationSignature); + target.addDynamicallyLegalOp(hasRelationOperands); + + SemiringTypeConverter semiringTypeConverter; + MatrixTypeConverter matrixTypeConverter(&getContext(), semiringTypeConverter); + + mlir::RewritePatternSet patterns(&getContext()); + patterns.add< + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion, OpConversion>( + matrixTypeConverter, &getContext()); + patterns.add(semiringTypeConverter, matrixTypeConverter, + &getContext()); + + // Scalar patterns. 
+ patterns + .add, + OpConversion, OpConversion, + OpConversion, OpConversion, + OpConversion>(semiringTypeConverter, &getContext()); + + if (mlir::failed(mlir::applyFullConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace garel diff --git a/compiler/test/graphalg-to-rel/add.mlir b/compiler/test/graphalg-to-rel/add.mlir new file mode 100644 index 0000000..6069c74 --- /dev/null +++ b/compiler/test/graphalg-to-rel/add.mlir @@ -0,0 +1,109 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @AddBool +func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { +// CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : i1 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : i1 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @AddInt +func.func @AddInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : i64 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @AddReal +func.func @AddReal(%arg0: !graphalg.mat<1 
x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addf %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : f64 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : f64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @AddTropInt +func.func @AddTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg2 : !graphalg.trop_i64, %arg3: !graphalg.trop_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.minsi %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_i64 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @AddTropReal +func.func @AddTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_f64>, !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg2 : !graphalg.trop_f64, %arg3: !graphalg.trop_f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // 
CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.minimumf %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_f64 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @AddTropMaxInt +func.func @AddTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg2 : !graphalg.trop_max_i64, %arg3: !graphalg.trop_max_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.maxsi %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_max_i64 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} diff --git a/compiler/test/graphalg-to-rel/apply.mlir b/compiler/test/graphalg-to-rel/apply.mlir new file mode 100644 index 0000000..827db04 --- /dev/null +++ b/compiler/test/graphalg-to-rel/apply.mlir @@ -0,0 +1,177 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Arity + +// CHECK-LABEL: @ApplyUnary +func.func @ApplyUnary(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] + %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] 
= garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyBinary +func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>, <0[1] = 1[1]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg2: i64, %arg3: i64): + // CHECK: %[[#VAL1:]] = garel.extract 2 + // CHECK: %[[#VAL2:]] = garel.extract 5 + // CHECK: %[[#ADD:]] = arith.addi %[[#VAL1]], %[[#VAL2]] + %1 = graphalg.add %arg2, %arg3 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyTernary +func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>, %arg2: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1, %arg2 : !garel.rel, !garel.rel, !garel.rel [<0[0] = 1[0]>, <0[0] = 2[0]>, <0[1] = 1[1]>, <0[1] = 2[1]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + %0 = graphalg.apply %arg0, %arg1, %arg2 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg3: i64, %arg4: i64, %arg5: i64): + // CHECK: %[[#VAL1:]] = garel.extract 2 + // CHECK: %[[#VAL2:]] = garel.extract 5 + // CHECK: %[[#VAL3:]] = garel.extract 8 + // CHECK: %[[#ADD1:]] = arith.addi %[[#VAL1]], %[[#VAL2]] + %1 = graphalg.add %arg3, %arg4 : i64 + // CHECK: %[[#ADD2:]] = arith.addi 
%[[#ADD1]], %[[#VAL3]] + %2 = graphalg.add %1, %arg5 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD2]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// === Shape + +// CHECK-LABEL: @ApplyMat +func.func @ApplyMat(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] + %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyRowVec +func.func @ApplyRowVec(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] + %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: garel.project.return %[[#ROW]], %[[#ADD]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @ApplyColVec +func.func @ApplyColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> { + ^bb0(%arg42: i64): + %1 = graphalg.const 1 : i64 + // 
CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] + %2 = graphalg.add %1, %arg42 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: garel.project.return %[[#ROW]], %[[#ADD]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ApplyScalar +func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 0 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] + %2 = graphalg.add %1, %arg1 : i64 + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ApplyBroadcastScalar +func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 43 x i64> { + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#COLS:]] = garel.range 43 + // CHECK: %[[#JOIN:]] = garel.join %arg0, %[[#ROWS]], %[[#COLS]] : !garel.rel, !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <42 x 43 x i64> { + ^bb0(%arg1: i64): + // CHECK: %[[#VAL:]] = garel.extract 0 + // CHECK: %[[#ROW:]] = garel.extract 1 + // CHECK: %[[#COL:]] = garel.extract 2 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] + graphalg.apply.return %arg1 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ApplyBroadcastOne +func.func @ApplyBroadcastOne(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // 
CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg2: i64, %arg3: i64): + // CHECK: %[[#VAL1:]] = garel.extract 1 + // CHECK: %[[#VAL2:]] = garel.extract 4 + // CHECK: %[[#ADD:]] = arith.addi %[[#VAL1]], %[[#VAL2]] + %1 = graphalg.add %arg2, %arg3 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 3 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT:]] + return %0 : !graphalg.mat<42 x 42 x i64> +} diff --git a/compiler/test/graphalg-to-rel/broadcast.mlir b/compiler/test/graphalg-to-rel/broadcast.mlir new file mode 100644 index 0000000..918fbc2 --- /dev/null +++ b/compiler/test/graphalg-to-rel/broadcast.mlir @@ -0,0 +1,32 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @BroadcastMat +func.func @BroadcastMat(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#RNG0:]] = garel.range 42 + // CHECK: %[[#RNG1:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG0]], %[[#RNG1]], %arg1 + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 42 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @BroadcastRowVec +func.func @BroadcastRowVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#RNG:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG]], %arg1 + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <1 x 42 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @BroadcastColVec +func.func @BroadcastColVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#RNG:]] = 
garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG]], %arg1 + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 1 x i64> +} diff --git a/compiler/test/graphalg-to-rel/cast.mlir b/compiler/test/graphalg-to-rel/cast.mlir new file mode 100644 index 0000000..e37b00b --- /dev/null +++ b/compiler/test/graphalg-to-rel/cast.mlir @@ -0,0 +1,173 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @CastBoolInt +func.func @CastBoolInt(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i64> { + ^bb0(%arg1 : i1): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[C1:.+]] = arith.constant 1 : i64 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[C1]], %[[C0]] + %1 = graphalg.cast_scalar %arg1 : i1 -> i64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @CastIntReal +func.func @CastIntReal(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x f64> { + ^bb0(%arg1 : i64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.sitofp %[[#EXTRACT]] + %1 = graphalg.cast_scalar %arg1 : i64 -> f64 + + // CHECK: garel.project.return %[[#CAST]] + graphalg.apply.return %1 : f64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @CastRealInt +func.func @CastRealInt(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> 
-> <1 x 1 x i64> { + ^bb0(%arg1 : f64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.fptosi %[[#EXTRACT]] + %1 = graphalg.cast_scalar %arg1 : f64 -> i64 + + // CHECK: garel.project.return %[[#CAST]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @CastBoolTrop +func.func @CastBoolTrop(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg1 : i1): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[C0]], %[[CMAX]] + %1 = graphalg.cast_scalar %arg1 : i1 -> !graphalg.trop_i64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @CastTropBool +func.func @CastTropBool(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#EXTRACT]], %[[CMAX]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> i1 + + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %1 : i1 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @CastTropIntTropReal +func.func @CastTropIntTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 
x 1 x !graphalg.trop_f64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.sitofp %[[#EXTRACT]] + // CHECK: %[[MAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[INF:.+]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[MAX]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[INF]], %[[#CAST]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> !graphalg.trop_f64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @CastTropRealTropInt +func.func @CastTropRealTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.fptosi %[[#EXTRACT]] + // CHECK: %[[INF:.+]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[MAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[#EXTRACT]], %[[INF]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[MAX]], %[[#CAST]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_f64 -> !graphalg.trop_i64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @CastIntToTropMaxInt +func.func @CastIntToTropMaxInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x 
!graphalg.trop_max_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg1 : i64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[MIN:.+]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[C0]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[MIN]], %[[#EXTRACT]] + %1 = graphalg.cast_scalar %arg1 : i64 -> !graphalg.trop_max_i64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} + +// CHECK-LABEL: @CastTropMaxIntToInt +func.func @CastTropMaxIntToInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i64> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[MIN:.+]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[MIN]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[C0]], %[[#EXTRACT]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_max_i64 -> i64 + + // CHECK: garel.project.return %[[#SELECT]] + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/test/graphalg-to-rel/const-mat.mlir b/compiler/test/graphalg-to-rel/const-mat.mlir new file mode 100644 index 0000000..fe8fa40 --- /dev/null +++ b/compiler/test/graphalg-to-rel/const-mat.mlir @@ -0,0 +1,113 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Shapes + +// CHECK-LABEL: @ConstMat +func.func 
@ConstMat() -> !graphalg.mat<42 x 43 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#COLS:]] = garel.range 43 + // CHECK: %[[#JOIN:]] = garel.join %[[#ROWS]], %[[#COLS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <42 x 43 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ConstRowVec +func.func @ConstRowVec() -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#COLS:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#COLS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <1 x 42 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @ConstColVec +func.func @ConstColVec() -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#ROWS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <42 x 1 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ConstScalar +func.func @ConstScalar() -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> + + // CHECK: return %[[#CONST]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @ConstBool +func.func @ConstBool() -> !graphalg.mat<1 x 1 x i1> { + // CHECK: garel.const true + %0 = graphalg.const_mat true -> <1 x 1 x i1> + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstInt +func.func @ConstInt() -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ConstReal +func.func @ConstReal() -> !graphalg.mat<1 x 1 x f64> { + // CHECK: garel.const 1.000000e+00 : f64 + %0 = graphalg.const_mat 1.000000e+00 : f64 -> <1 x 
1 x f64> + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: ConstTropInt +func.func @ConstTropInt() -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat #graphalg.trop_int<1 : i64> : !graphalg.trop_i64 -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: ConstTropReal +func.func @ConstTropReal() -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: garel.const 1.000000e+00 : f64 + %0 = graphalg.const_mat #graphalg.trop_float<1.000000e+00 : f64> : !graphalg.trop_f64 -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: ConstTropMaxInt +func.func @ConstTropMaxInt() -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat #graphalg.trop_int<1 : i64> : !graphalg.trop_max_i64 -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} + +// === Infinity values + +// CHECK-LABEL: ConstTropIntInf +func.func @ConstTropIntInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: garel.const 9223372036854775807 : i64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_i64 -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: ConstTropRealInf +func.func @ConstTropRealInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: garel.const 0x7FF0000000000000 : f64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_f64 -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: ConstTropMaxIntInf +func.func @ConstTropMaxIntInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: garel.const -9223372036854775808 : i64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_max_i64 -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x 
!graphalg.trop_max_i64> +} diff --git a/compiler/test/graphalg-to-rel/const.mlir b/compiler/test/graphalg-to-rel/const.mlir new file mode 100644 index 0000000..0947fce --- /dev/null +++ b/compiler/test/graphalg-to-rel/const.mlir @@ -0,0 +1,125 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// Note: The inputs for these tests have to be fairly complex because: +// - Scalar ops must be inside of an ApplyOp +// - An ApplyOp whose body is reducible to a constant value folds into a +// ConstantMatrixOp. + +// CHECK-LABEL: @ConstBool +func.func @ConstBool(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg1 : i1): + // CHECK: arith.constant false + %1 = graphalg.const false + %2 = graphalg.eq %arg1, %1 : i1 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstInt +func.func @ConstInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : i64): + // CHECK: arith.constant 42 : i64 + %1 = graphalg.const 42 : i64 + %2 = graphalg.eq %arg1, %1 : i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstReal +func.func @ConstReal(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : f64): + // CHECK: arith.constant 4.200000e+01 : f64 + %1 = graphalg.const 42.0 : f64 + %2 = graphalg.eq %arg1, %1 : f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropInt +func.func @ConstTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: 
arith.constant 42 : i64 + %1 = graphalg.const #graphalg.trop_int<42 : i64> : !graphalg.trop_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropReal +func.func @ConstTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: arith.constant 4.200000e+01 : f64 + %1 = graphalg.const #graphalg.trop_float<42.0 : f64> : !graphalg.trop_f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropMaxInt +func.func @ConstTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: arith.constant 42 : i64 + %1 = graphalg.const #graphalg.trop_int<42 : i64> : !graphalg.trop_max_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_max_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// === Infinity values + +// CHECK-LABEL: ConstTropIntInf +func.func @ConstTropIntInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: arith.constant 9223372036854775807 : i64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropRealInf +func.func @ConstTropRealInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : 
!graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: arith.constant 0x7FF0000000000000 : f64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropMaxIntInf +func.func @ConstTropMaxIntInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: arith.constant -9223372036854775808 : i64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_max_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_max_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} diff --git a/compiler/test/graphalg-to-rel/deferred-reduce.mlir b/compiler/test/graphalg-to-rel/deferred-reduce.mlir new file mode 100644 index 0000000..e6c27c9 --- /dev/null +++ b/compiler/test/graphalg-to-rel/deferred-reduce.mlir @@ -0,0 +1,121 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Input/output shapes + +// CHECK-LABEL: @ReduceMatScalar +func.func @ReduceMatScalar(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMatRowVec +func.func @ReduceMatRowVec(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[1] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 43 x i64> + return %0 : !graphalg.mat<1 x 43 x i64> +} + +// CHECK-LABEL: @ReduceMatColVec +func.func @ReduceMatColVec(%arg0: !graphalg.mat<42 x 43 x i64>) -> 
!graphalg.mat<42 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 1 x i64> + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMatMat +func.func @ReduceMatMat(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<42 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0, 1] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 43 x i64> + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ReduceRowVecScalar +func.func @ReduceRowVecScalar(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceRowVecRowVec +func.func @ReduceRowVecRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 43 x i64> -> <1 x 43 x i64> + return %0 : !graphalg.mat<1 x 43 x i64> +} + +// CHECK-LABEL: @ReduceColVecScalar +func.func @ReduceColVecScalar(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceColVecColVec +func.func @ReduceColVecColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ReduceScalarScalar +func.func @ReduceScalarScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> 
!graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMultiple +func.func @ReduceMultiple( + %arg0 : !graphalg.mat<1 x 43 x i64>, + %arg1 : !graphalg.mat<42 x 1 x i64>) + -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 : !garel.rel, !garel.rel + // CHECK: %[[#AGG:]] = garel.aggregate %0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0, %arg1 : !graphalg.mat<1 x 43 x i64>, !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @ReduceBool +func.func @ReduceBool(%arg0: !graphalg.mat<42 x 43 x i1>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i1> -> <1 x 1 x i1> + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ReduceInt +func.func @ReduceInt(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceReal +func.func @ReduceReal(%arg0: !graphalg.mat<42 x 43 x f64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x f64> -> <1 x 1 x f64> + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @ReduceTropInt +func.func @ReduceTropInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @ReduceTropReal +func.func @ReduceTropReal(%arg0: 
!graphalg.mat<42 x 43 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @ReduceTropMaxInt +func.func @ReduceTropMaxInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} diff --git a/compiler/test/graphalg-to-rel/diag.mlir b/compiler/test/graphalg-to-rel/diag.mlir new file mode 100644 index 0000000..846dd8f --- /dev/null +++ b/compiler/test/graphalg-to-rel/diag.mlir @@ -0,0 +1,15 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @DiagCol +func.func @DiagCol(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: garel.remap %arg0 : [0, 0, 1] + %0 = graphalg.diag %arg0 : !graphalg.mat<42 x 1 x i64> + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @DiagRow +func.func @DiagRow(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: garel.remap %arg0 : [0, 0, 1] + %0 = graphalg.diag %arg0 : !graphalg.mat<1 x 42 x i64> + return %0 : !graphalg.mat<42 x 42 x i64> +} diff --git a/compiler/test/graphalg-to-rel/div.mlir b/compiler/test/graphalg-to-rel/div.mlir new file mode 100644 index 0000000..66f8122 --- /dev/null +++ b/compiler/test/graphalg-to-rel/div.mlir @@ -0,0 +1,16 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @DivReal +func.func @DivReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> 
-> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#DIV:]] = arith.divf %2, %3 : f64 + %1 = arith.divf %arg2, %arg3 : f64 + // CHECK: garel.project.return %[[#DIV]] + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} diff --git a/compiler/test/graphalg-to-rel/eq.mlir b/compiler/test/graphalg-to-rel/eq.mlir new file mode 100644 index 0000000..7434046 --- /dev/null +++ b/compiler/test/graphalg-to-rel/eq.mlir @@ -0,0 +1,81 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @EqBool +func.func @EqBool(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg1 : i1): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant false + %1 = graphalg.const false + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i1 + %2 = graphalg.eq %arg1, %1 : i1 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqInt +func.func @EqInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : i64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0 + %1 = graphalg.const 0 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i64 + %2 = graphalg.eq %arg1, %1 : i64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqReal +func.func @EqReal(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : f64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 
0.000000e+00 + %1 = graphalg.const 0.000000e+00 : f64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[LHS]], %[[RHS]] : f64 + %2 = graphalg.eq %arg1, %1 : f64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqTropInt +func.func @EqTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0 + %1 = graphalg.const #graphalg.trop_int<0 : i64> : !graphalg.trop_i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqTropReal +func.func @EqTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0.000000e+00 + %1 = graphalg.const #graphalg.trop_float<0.0 : f64> : !graphalg.trop_f64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[LHS]], %[[RHS]] : f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for-const.mlir new file mode 100644 index 0000000..2d418d8 --- /dev/null +++ b/compiler/test/graphalg-to-rel/for-const.mlir @@ -0,0 +1,66 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @ForConst +func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 
x 1 x i64> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=10 result_idx=1 { + %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): + // CHECK: %[[#PROJ:]] = garel.project %arg1 + // CHECK: %[[#EXT:]] = garel.extract 0 + // CHECK: %[[#ADD:]] = arith.addi %[[#EXT]], %c1_i64 + // CHECK: garel.project.return %[[#ADD]] + // CHECK: garel.for.yield %[[#PROJ]], %arg2 + graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> + } until { + } + + // CHECK: return %[[#FOR]] + return %2 : !graphalg.mat<1 x 1 x i64> + +} + +// CHECK-LABEL: @ForResultUnused +func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0, %arg1 + %2:2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { + ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<1 x 1 x f64>): + // CHECK: %[[#PROJ:]] = garel.project %arg2 + // CHECK: garel.for.yield %[[#PROJ]], %arg3, %arg4 + graphalg.yield %arg3, %arg4 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> + } until { + } + + // CHECK: return %[[#FOR]] + return %2#1 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @Until +func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x i1> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = 
graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 + %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): + // CHECK: %[[#PROJ:]] = garel.project %arg1 + // CHECK: garel.for.yield %[[#PROJ]], %arg2 + graphalg.yield %arg2 : !graphalg.mat<42 x 42 x i1> + // CHECK: } until { + } until { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): + // CHECK: %[[#AGG:]] = garel.aggregate %arg2 + %3 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> + // CHECK: garel.for.yield %[[#AGG]] + graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> + } + return %2 : !graphalg.mat<42 x 42 x i1> +} diff --git a/compiler/test/graphalg-to-rel/mat-mul-join.mlir b/compiler/test/graphalg-to-rel/mat-mul-join.mlir new file mode 100644 index 0000000..bae4363 --- /dev/null +++ b/compiler/test/graphalg-to-rel/mat-mul-join.mlir @@ -0,0 +1,175 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === input/output shapes + +// (a,b) * (b,c) +// CHECK-LABEL: @MatMulABC +func.func @MatMulABC(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[1] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 4 + // CHECK: %[[#LHS:]] = garel.extract 2 + // CHECK: %[[#RHS:]] = garel.extract 5 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// (1,b) * 
(b,c) +// CHECK-LABEL: @MatMulBC +func.func @MatMulBC(%arg0: !graphalg.mat<1 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<1 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 3 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 4 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 43 x i64>, <43 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 44 x i64> +} + +// (a,1) * (1,c) +// CHECK-LABEL: @MatMulAC +func.func @MatMulAC(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<1 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 2 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 3 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 1 x i64>, <1 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// (a,b) * (b,1) +// CHECK-LABEL: @MatMulAB +func.func @MatMulAB(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[1] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#LHS:]] = garel.extract 2 + // CHECK: %[[#RHS:]] = garel.extract 4 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = 
graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// (a,1) * (1,1) +// CHECK-LABEL: @MatMulA +func.func @MatMulA(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 2 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 1 x i64>, <1 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// (1, b) * (b, 1) +// CHECK-LABEL: @MatMulB +func.func @MatMulB(%arg0: !graphalg.mat<1 x 43 x i64>, %arg1: !graphalg.mat<43 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 3 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 43 x i64>, <43 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// (1,1) * (1,c) +// CHECK-LABEL: @MatMulC +func.func @MatMulC(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 44 x i64>) -> !graphalg.mat<1 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 2 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return 
%[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 1 x i64>, <1 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 44 x i64> +} + +// (1,1) * (1,1) +// CHECK-LABEL: @MatMulScalar +func.func @MatMulScalar(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 1 x i64>, <1 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @MatMulBool +func.func @MatMulBool(%arg0: !graphalg.mat<42 x 43 x i1>, %arg1: !graphalg.mat<43 x 44 x i1>) -> !graphalg.mat<42 x 44 x i1> { + // CHECK: arith.andi + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i1>, <43 x 44 x i1> + return %0 : !graphalg.mat<42 x 44 x i1> +} + +// CHECK-LABEL: @MatMulInt +func.func @MatMulInt(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: arith.muli + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 44 x i64> + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// CHECK-LABEL: @MatMulReal +func.func @MatMulReal(%arg0: !graphalg.mat<42 x 43 x f64>, %arg1: !graphalg.mat<43 x 44 x f64>) -> !graphalg.mat<42 x 44 x f64> { + // CHECK: arith.mulf + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x f64>, <43 x 44 x f64> + return %0 : !graphalg.mat<42 x 44 x f64> +} + +// CHECK-LABEL: @MatMulTropInt +func.func @MatMulTropInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_i64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_i64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_i64> { + // CHECK: arith.addi + %0 = graphalg.mxm_join 
%arg0, %arg1 : <42 x 43 x !graphalg.trop_i64>, <43 x 44 x !graphalg.trop_i64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @MatMulTropReal +func.func @MatMulTropReal(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_f64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_f64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_f64> { + // CHECK: arith.addf + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x !graphalg.trop_f64>, <43 x 44 x !graphalg.trop_f64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @MatMulTropMaxInt +func.func @MatMulTropMaxInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_max_i64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_max_i64> { + // CHECK: arith.addi + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x !graphalg.trop_max_i64>, <43 x 44 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_max_i64> +} diff --git a/compiler/test/graphalg-to-rel/mul.mlir b/compiler/test/graphalg-to-rel/mul.mlir new file mode 100644 index 0000000..8d35cf9 --- /dev/null +++ b/compiler/test/graphalg-to-rel/mul.mlir @@ -0,0 +1,91 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @MulBool +func.func @MulBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.andi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : i1 + graphalg.apply.return %1 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @MulInt +func.func @MulInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = 
graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : i64 + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @MulReal +func.func @MulReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.mulf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : f64 + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @MulTropInt +func.func @MulTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg2 : !graphalg.trop_i64, %arg3: !graphalg.trop_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_i64 + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @MulTropReal +func.func @MulTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x 
!graphalg.trop_f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_f64>, !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg2 : !graphalg.trop_f64, %arg3: !graphalg.trop_f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_f64 + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @MulTropMaxInt +func.func @MulTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg2 : !graphalg.trop_max_i64, %arg3: !graphalg.trop_max_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_max_i64 + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} diff --git a/compiler/test/graphalg-to-rel/pick-any.mlir b/compiler/test/graphalg-to-rel/pick-any.mlir new file mode 100644 index 0000000..adefc89 --- /dev/null +++ b/compiler/test/graphalg-to-rel/pick-any.mlir @@ -0,0 +1,27 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +func.func @PickAnyMat(%arg0: !graphalg.mat<42 x 43 x i1>) -> !graphalg.mat<42 x 43 x i1> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#VAL]], %false + // CHECK: garel.select.return %[[#CMP]] + // CHECK: 
%[[#AGG:]] = garel.aggregate %[[#SELECT]] : group_by=[0] aggregators=[, ] + %0 = graphalg.pick_any %arg0 : <42 x 43 x i1> + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<42 x 43 x i1> +} + +func.func @PickAnyRowVec(%arg0: !graphalg.mat<1 x 42 x i1>) -> !graphalg.mat<1 x 42 x i1> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#VAL]], %false + // CHECK: garel.select.return %[[#CMP]] + // CHECK: %[[#AGG:]] = garel.aggregate %[[#SELECT]] : group_by=[] aggregators=[, ] + %0 = graphalg.pick_any %arg0 : <1 x 42 x i1> + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<1 x 42 x i1> +} + +// Note: scalar and column vector cases are folded away diff --git a/compiler/test/graphalg-to-rel/sub.mlir b/compiler/test/graphalg-to-rel/sub.mlir new file mode 100644 index 0000000..7429978 --- /dev/null +++ b/compiler/test/graphalg-to-rel/sub.mlir @@ -0,0 +1,33 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @SubInt +func.func @SubInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#SUB:]] = arith.subi %[[#LHS]], %[[#RHS]] + %1 = arith.subi %arg2, %arg3 : i64 + + // CHECK: garel.project.return %[[#SUB]] + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @SubReal +func.func @SubReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + //
CHECK: %[[#SUB:]] = arith.subf %[[#LHS]], %[[#RHS]] + %1 = arith.subf %arg2, %arg3 : f64 + + // CHECK: garel.project.return %[[#SUB]] + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} diff --git a/compiler/test/graphalg-to-rel/transpose.mlir b/compiler/test/graphalg-to-rel/transpose.mlir new file mode 100644 index 0000000..4aaaab0 --- /dev/null +++ b/compiler/test/graphalg-to-rel/transpose.mlir @@ -0,0 +1,46 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @TransposeMatrix +func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<43 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#COL_SLOT:]] = garel.extract 1 + // CHECK: %[[#ROW_SLOT:]] = garel.extract 0 + // CHECK: %[[#VAL_SLOT:]] = garel.extract 2 + // CHECK: garel.project.return %[[#COL_SLOT]], %[[#ROW_SLOT]], %[[#VAL_SLOT]] + %0 = graphalg.transpose %arg0 : <42 x 43 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<43 x 42 x i64> +} + +// CHECK-LABEL: @TransposeColVec +func.func @TransposeColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] : index, i64 + %0 = graphalg.transpose %arg0 : <42 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @TransposeRowVec +func.func @TransposeRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<43 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#COL:]] = garel.extract 0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#COL]], %[[#VAL]] : index, i64 + %0 = graphalg.transpose %arg0 : <1 x 43 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<43 x 1 x i64> +} + +func.func 
@TransposeScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // NOTE: folded away + %0 = graphalg.transpose %arg0 : <1 x 1 x i64> + + // CHECK: return %arg0 + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/test/graphalg-to-rel/tril.mlir b/compiler/test/graphalg-to-rel/tril.mlir new file mode 100644 index 0000000..de25b1d --- /dev/null +++ b/compiler/test/graphalg-to-rel/tril.mlir @@ -0,0 +1,13 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +func.func @Tril(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: %[[#CMP:]] = arith.cmpi ult, %[[#COL]], %[[#ROW]] + // CHECK: garel.select.return %[[#CMP]] + %0 = graphalg.tril %arg0 : <42 x 42 x i64> + + // CHECK: return %[[#SELECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} diff --git a/compiler/test/graphalg-to-rel/union.mlir b/compiler/test/graphalg-to-rel/union.mlir new file mode 100644 index 0000000..7aa3154 --- /dev/null +++ b/compiler/test/graphalg-to-rel/union.mlir @@ -0,0 +1,61 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @UnionMat +func.func @UnionMat(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @UnionRowVec +func.func @UnionRowVec(%arg0: !graphalg.mat<1 x 42 x i64>, %arg1: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<1 x 42 x i64>, !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> + + // CHECK: return %[[#UNION]] + return %0 : 
!graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @UnionColVec +func.func @UnionColVec(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @UnionScalar +func.func @UnionScalar(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +func.func @UnionFlattenRow(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [1, 2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 43 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<1 x 43 x i64> +} + +func.func @UnionFlattenCol(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [0, 2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +func.func @UnionFlattenAll(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/tools/CMakeLists.txt b/compiler/tools/CMakeLists.txt index a869516..6efaafd 100644 --- a/compiler/tools/CMakeLists.txt +++ b/compiler/tools/CMakeLists.txt @@ -11,6 +11,8 @@ target_link_libraries(graphalg-opt PRIVATE 
${llvm_libs} GraphAlgIR GraphAlgPasses + GARelIR + GraphAlgToRel MLIROptLib ) @@ -18,6 +20,7 @@ add_executable(graphalg-lsp-server graphalg-lsp-server.cpp) target_link_libraries(graphalg-lsp-server PRIVATE ${llvm_libs} GraphAlgIR + GARelIR MLIRLspServerLib ) diff --git a/compiler/tools/graphalg-lsp-server.cpp b/compiler/tools/graphalg-lsp-server.cpp index 4bd4378..eedaad1 100644 --- a/compiler/tools/graphalg-lsp-server.cpp +++ b/compiler/tools/graphalg-lsp-server.cpp @@ -4,6 +4,7 @@ #include #include +#include "garel/GARelDialect.h" #include "graphalg/GraphAlgDialect.h" using namespace mlir; @@ -11,6 +12,7 @@ using namespace mlir; int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); registry.insert(); diff --git a/compiler/tools/graphalg-opt.cpp b/compiler/tools/graphalg-opt.cpp index c25da8b..5bc7f2e 100644 --- a/compiler/tools/graphalg-opt.cpp +++ b/compiler/tools/graphalg-opt.cpp @@ -6,16 +6,20 @@ #include #include +#include "garel/GARelDialect.h" +#include "garel/GARelPasses.h" #include "graphalg/GraphAlgDialect.h" #include "graphalg/GraphAlgPasses.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); graphalg::registerPasses(); graphalg::registerGraphAlgToCorePipeline(); + garel::registerPasses(); mlir::registerCanonicalizerPass(); mlir::registerInlinerPass(); mlir::registerCSEPass();