From 7190a64c29a007089dca4427f395044f1e15c738 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 19 Jan 2026 15:25:57 +0000 Subject: [PATCH 01/32] Initial garel dialect definition. --- compiler/CMakeLists.txt | 1 + compiler/include/garel/CMakeLists.txt | 9 +++++ compiler/include/garel/GARelAttr.h | 4 +++ compiler/include/garel/GARelAttr.td | 38 +++++++++++++++++++++ compiler/include/garel/GARelDialect.h | 6 ++++ compiler/include/garel/GARelDialect.td | 25 ++++++++++++++ compiler/include/garel/GARelOps.h | 14 ++++++++ compiler/include/garel/GARelOps.td | 22 ++++++++++++ compiler/include/garel/GARelTypes.h | 6 ++++ compiler/include/garel/GARelTypes.td | 25 ++++++++++++++ compiler/src/CMakeLists.txt | 1 + compiler/src/garel/CMakeLists.txt | 21 ++++++++++++ compiler/src/garel/GARelAttr.cpp | 23 +++++++++++++ compiler/src/garel/GARelDialect.cpp | 46 ++++++++++++++++++++++++++ compiler/src/garel/GARelOps.cpp | 7 ++++ compiler/src/garel/GARelTypes.cpp | 23 +++++++++++++ compiler/tools/CMakeLists.txt | 2 ++ compiler/tools/graphalg-lsp-server.cpp | 2 ++ compiler/tools/graphalg-opt.cpp | 2 ++ 19 files changed, 277 insertions(+) create mode 100644 compiler/include/garel/CMakeLists.txt create mode 100644 compiler/include/garel/GARelAttr.h create mode 100644 compiler/include/garel/GARelAttr.td create mode 100644 compiler/include/garel/GARelDialect.h create mode 100644 compiler/include/garel/GARelDialect.td create mode 100644 compiler/include/garel/GARelOps.h create mode 100644 compiler/include/garel/GARelOps.td create mode 100644 compiler/include/garel/GARelTypes.h create mode 100644 compiler/include/garel/GARelTypes.td create mode 100644 compiler/src/garel/CMakeLists.txt create mode 100644 compiler/src/garel/GARelAttr.cpp create mode 100644 compiler/src/garel/GARelDialect.cpp create mode 100644 compiler/src/garel/GARelOps.cpp create mode 100644 compiler/src/garel/GARelTypes.cpp diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index 274aaca..de582f0 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -66,6 +66,7 @@ endif () # Directories with MLIR dialects add_subdirectory(include/graphalg) +add_subdirectory(include/garel) add_subdirectory(src) add_subdirectory(test) diff --git a/compiler/include/garel/CMakeLists.txt b/compiler/include/garel/CMakeLists.txt new file mode 100644 index 0000000..883398b --- /dev/null +++ b/compiler/include/garel/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(SYSTEM ${MLIR_INCLUDE_DIRS}) +include_directories(SYSTEM ${PROJECT_BINARY_DIR}/include) + +set(LLVM_TARGET_DEFINITIONS GARelAttr.td) +mlir_tablegen(GARelAttr.h.inc --gen-attrdef-decls) +mlir_tablegen(GARelAttr.cpp.inc -gen-attrdef-defs) + +set(LLVM_TARGET_DEFINITIONS GARelOps.td) +add_mlir_dialect(GARelOps garel) diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h new file mode 100644 index 0000000..cc1db7d --- /dev/null +++ b/compiler/include/garel/GARelAttr.h @@ -0,0 +1,4 @@ +#pragma once + +#define GET_ATTRDEF_CLASSES +#include "garel/GARelAttr.h.inc" diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td new file mode 100644 index 0000000..11250b3 --- /dev/null +++ b/compiler/include/garel/GARelAttr.td @@ -0,0 +1,38 @@ +#ifndef GAREL_ATTR +#define GAREL_ATTR + +include "mlir/IR/BuiltinAttributeInterfaces.td" + +include "GARelDialect.td" + +class GARel_Attr traits = []> + : AttrDef { + let mnemonic = attrMnemonic; +} + +// Also often referred to as 'attribute' in the literature, but that conflicts +// with MLIRs use of the term 'attribute', so name this Column instead. +def ColumnAttr : GARel_Attr<"Column", "column"> { + let summary = "A column (also referred to as 'attribute') of a relation"; + + let parameters = (ins + "::mlir::DistinctAttr":$id, + TypeParameter<"::mlir::Type", "Type of this column">:$type); + + let assemblyFormat = [{ + `<` $id `` `:` `` $type `>` + }]; +} + +def ColumnSetAttr : GARel_Attr<"ColumnSet", "column_set"> { + let summary = "The set of columns in a relation"; + + let parameters = (ins + OptionalArrayRefParameter<"ColumnAttr">:$columns); + + let assemblyFormat = [{ + `<` (`>`) : (`` $columns^ `>`)? + }]; +} + +#endif // GAREL_ATTR diff --git a/compiler/include/garel/GARelDialect.h b/compiler/include/garel/GARelDialect.h new file mode 100644 index 0000000..0149f92 --- /dev/null +++ b/compiler/include/garel/GARelDialect.h @@ -0,0 +1,6 @@ +#pragma once + +#include +#include + +#include "garel/GARelOpsDialect.h.inc" diff --git a/compiler/include/garel/GARelDialect.td b/compiler/include/garel/GARelDialect.td new file mode 100644 index 0000000..f745dbf --- /dev/null +++ b/compiler/include/garel/GARelDialect.td @@ -0,0 +1,25 @@ +#ifndef GAREL_DIALECT +#define GAREL_DIALECT + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" + +def GARel_Dialect : Dialect { + let name = "garel"; + let cppNamespace = "::garel"; + + let extraClassDeclaration = [{ + private: + void registerAttributes(); + void registerTypes(); + }]; + + let usePropertiesForAttributes = 1; + let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; +} + +class GARel_Op traits = []> : + Op; + +#endif // GAREL_DIALECT diff --git a/compiler/include/garel/GARelOps.h b/compiler/include/garel/GARelOps.h new file mode 100644 index 0000000..c05bcd0 --- /dev/null +++ b/compiler/include/garel/GARelOps.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelTypes.h" + +#define GET_OP_CLASSES +#include "garel/GARelOps.h.inc" diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td new file mode 100644 index 0000000..ff9fe06 --- /dev/null +++ b/compiler/include/garel/GARelOps.td @@ -0,0 +1,22 @@ +#ifndef GAREL_OPS +#define GAREL_OPS + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +include "GARelDialect.td" +include "GARelTypes.td" + +def ProjectOp : GARel_Op<"project", [Pure]> { + let summary = "Remaps, reorders, drops and computes columns"; + + let arguments = (ins Relation:$input); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) `->` type($result) attr-dict + }]; +} + +#endif // GAREL_OPS diff --git a/compiler/include/garel/GARelTypes.h b/compiler/include/garel/GARelTypes.h new file mode 100644 index 0000000..3a5e853 --- /dev/null +++ b/compiler/include/garel/GARelTypes.h @@ -0,0 +1,6 @@ +#pragma once + +#include "garel/GARelAttr.h" + +#define GET_TYPEDEF_CLASSES +#include "garel/GARelOpsTypes.h.inc" diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td new file mode 100644 index 0000000..7f7a2e3 --- /dev/null +++ b/compiler/include/garel/GARelTypes.td @@ -0,0 +1,25 @@ + +#ifndef GAREL_TYPES +#define GAREL_TYPES + +include "mlir/IR/AttrTypeBase.td" + +include "GARelDialect.td" +include "GARelAttr.td" + +class GARel_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def Relation : GARel_Type<"Relation", "rel"> { + let summary = "A set of tuples"; + + let parameters = (ins ColumnSetAttr:$columns); + + let assemblyFormat = [{ + `<` $columns `>` + }]; +} + +#endif // GAREL_TYPES diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index de9518a..8217f22 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(graphalg) +add_subdirectory(garel) diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt new file mode 100644 index 0000000..8b3781c --- /dev/null +++ b/compiler/src/garel/CMakeLists.txt @@ -0,0 +1,21 @@ +add_library(GARelIR + GARelAttr.cpp + GARelDialect.cpp + GARelOps.cpp + GARelTypes.cpp +) +target_include_directories(GARelIR PUBLIC ../../include) +target_include_directories(GARelIR SYSTEM PUBLIC ${PROJECT_BINARY_DIR}/include) +add_dependencies(GARelIR + MLIRGARelOpsIncGen +) +if(NOT GRAPHALG_ENABLE_RTTI) + target_compile_options(GARelIR PUBLIC -fno-rtti) +endif() +target_link_libraries( + GARelIR + PRIVATE + MLIRInferTypeOpInterface + MLIRIR + MLIRSupport +) diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp new file mode 100644 index 0000000..7218da4 --- /dev/null +++ b/compiler/src/garel/GARelAttr.cpp @@ -0,0 +1,23 @@ +#include +#include +#include +#include + +#include "garel/GARelAttr.h" +#include "garel/GARelDialect.h" + +#define GET_ATTRDEF_CLASSES +#include "garel/GARelAttr.cpp.inc" + +namespace garel { + +// Need to define this here to avoid depending on GraphAlgAttr in +// GARelDialect and creating a cycle. +void GARelDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "garel/GARelAttr.cpp.inc" + >(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelDialect.cpp b/compiler/src/garel/GARelDialect.cpp new file mode 100644 index 0000000..b008305 --- /dev/null +++ b/compiler/src/garel/GARelDialect.cpp @@ -0,0 +1,46 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelDialect.h" +#include "garel/GARelOps.h" + +#include "garel/GARelOpsDialect.cpp.inc" + +namespace garel { + +namespace { + +class GARelOpAsmDialectInterface : public mlir::OpAsmDialectInterface { +public: + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(mlir::Attribute attr, + mlir::raw_ostream &os) const override { + // Assign aliases to columns. + if (auto colAttr = llvm::dyn_cast(attr)) { + os << "col"; + return AliasResult::FinalAlias; + } + + return AliasResult::NoAlias; + } +}; + +} // namespace + +void GARelDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "garel/GARelOps.cpp.inc" + >(); + registerAttributes(); + registerTypes(); + addInterface(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp new file mode 100644 index 0000000..1615a9e --- /dev/null +++ b/compiler/src/garel/GARelOps.cpp @@ -0,0 +1,7 @@ +#include "garel/GARelOps.h" +#include "garel/GARelDialect.h" + +#define GET_OP_CLASSES +#include "garel/GARelOps.cpp.inc" + +namespace garel {} // namespace garel diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp new file mode 100644 index 0000000..ae693a9 --- /dev/null +++ b/compiler/src/garel/GARelTypes.cpp @@ -0,0 +1,23 @@ +#include +#include +#include +#include + +#include "garel/GARelDialect.h" +#include "garel/GARelTypes.h" + +#define GET_TYPEDEF_CLASSES +#include "garel/GARelOpsTypes.cpp.inc" + +namespace garel { + +// Need to define this here to avoid depending on IPRTypes in +// IPRDialect and creating a cycle. +void GARelDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "garel/GARelOpsTypes.cpp.inc" + >(); +} + +} // namespace garel diff --git a/compiler/tools/CMakeLists.txt b/compiler/tools/CMakeLists.txt index a869516..8f81c9c 100644 --- a/compiler/tools/CMakeLists.txt +++ b/compiler/tools/CMakeLists.txt @@ -11,6 +11,7 @@ target_link_libraries(graphalg-opt PRIVATE ${llvm_libs} GraphAlgIR GraphAlgPasses + GARelIR MLIROptLib ) @@ -18,6 +19,7 @@ add_executable(graphalg-lsp-server graphalg-lsp-server.cpp) target_link_libraries(graphalg-lsp-server PRIVATE ${llvm_libs} GraphAlgIR + GARelIR MLIRLspServerLib ) diff --git a/compiler/tools/graphalg-lsp-server.cpp b/compiler/tools/graphalg-lsp-server.cpp index 4bd4378..eedaad1 100644 --- a/compiler/tools/graphalg-lsp-server.cpp +++ b/compiler/tools/graphalg-lsp-server.cpp @@ -4,6 +4,7 @@ #include #include +#include "garel/GARelDialect.h" #include "graphalg/GraphAlgDialect.h" using namespace mlir; @@ -11,6 +12,7 @@ using namespace mlir; int main(int argc, char **argv) { DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); registry.insert(); diff --git a/compiler/tools/graphalg-opt.cpp b/compiler/tools/graphalg-opt.cpp index c25da8b..f011237 100644 --- a/compiler/tools/graphalg-opt.cpp +++ b/compiler/tools/graphalg-opt.cpp @@ -6,12 +6,14 @@ #include #include +#include "garel/GARelDialect.h" #include "graphalg/GraphAlgDialect.h" #include "graphalg/GraphAlgPasses.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); + registry.insert(); registry.insert(); graphalg::registerPasses(); From bde16924552a2ab372f991af9eb9613805e51e2a Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 19 Jan 2026 16:29:45 +0000 Subject: [PATCH 02/32] Verifiers. --- compiler/include/garel/GARelAttr.td | 14 +++++++++++ compiler/include/garel/GARelOps.h | 1 + compiler/include/garel/GARelOps.td | 35 +++++++++++++++++++++++++++- compiler/include/garel/GARelTypes.h | 8 +++++++ compiler/include/garel/GARelTypes.td | 18 +++++++++++++- compiler/src/garel/CMakeLists.txt | 3 +++ compiler/src/garel/GARelOps.cpp | 25 ++++++++++++++++++-- compiler/src/garel/GARelTypes.cpp | 35 ++++++++++++++++++++++++++++ 8 files changed, 135 insertions(+), 4 deletions(-) diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 11250b3..89041e4 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -24,15 +24,29 @@ def ColumnAttr : GARel_Attr<"Column", "column"> { }]; } +def ColumnList : OptionalArrayRefParameter<"ColumnAttr">; + +/* def ColumnSetAttr : GARel_Attr<"ColumnSet", "column_set"> { let summary = "The set of columns in a relation"; + let description = [{ + An ordered set of columns. + + Order is significant: `` is a different set than ``. + + A column may only appear once in the set: `` is invalid. + }]; + let parameters = (ins OptionalArrayRefParameter<"ColumnAttr">:$columns); let assemblyFormat = [{ `<` (`>`) : (`` $columns^ `>`)? }]; + + let genVerifyDecl = 1; } +*/ #endif // GAREL_ATTR diff --git a/compiler/include/garel/GARelOps.h b/compiler/include/garel/GARelOps.h index c05bcd0..5c158fe 100644 --- a/compiler/include/garel/GARelOps.h +++ b/compiler/include/garel/GARelOps.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index ff9fe06..5a0efc3 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -12,11 +12,44 @@ def ProjectOp : GARel_Op<"project", [Pure]> { let arguments = (ins Relation:$input); + let regions = (region SizedRegion<1>:$projections); + let results = (outs Relation:$result); let assemblyFormat = [{ - $input `:` type($input) `->` type($result) attr-dict + $input `:` type($input) `->` type($result) $projections attr-dict }]; + + // let hasRegionVerifier = 1; +} + +def ProjectReturnOp : GARel_Op<"project.return", [ + Pure, + Terminator, + HasParent<"ProjectOp">]> { + let summary = "The output projections"; + + let arguments = (ins Variadic:$projections); + + let assemblyFormat = [{ + $projections `:` type($projections) attr-dict + }]; + + // NOTE: verification performed by ProjectOp +} + +def ExtractOp : GARel_Op<"extract", [Pure, InferTypeOpAdaptor]> { + let summary = "Extract the value of one column from a tuple"; + + let arguments = (ins ColumnAttr:$column, Tuple:$tuple); + + let results = (outs ColumnType:$result); + + let assemblyFormat = [{ + $column $tuple `:` type($tuple) attr-dict + }]; + + let hasVerifier = 1; } #endif // GAREL_OPS diff --git a/compiler/include/garel/GARelTypes.h b/compiler/include/garel/GARelTypes.h index 3a5e853..e6a4c90 100644 --- a/compiler/include/garel/GARelTypes.h +++ b/compiler/include/garel/GARelTypes.h @@ -1,6 +1,14 @@ #pragma once +#include + #include "garel/GARelAttr.h" #define GET_TYPEDEF_CLASSES #include "garel/GARelOpsTypes.h.inc" + +namespace garel { + +bool isColumnType(mlir::Type t); + +} // namespace garel diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 7f7a2e3..43e84b3 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -15,11 +15,27 @@ class GARel_Type traits = []> def Relation : GARel_Type<"Relation", "rel"> { let summary = "A set of tuples"; - let parameters = (ins ColumnSetAttr:$columns); + let parameters = (ins ColumnList:$columns); let assemblyFormat = [{ `<` $columns `>` }]; + + let genVerifyDecl = 1; +} + +def Tuple : GARel_Type<"Tuple", "tuple"> { + let summary = "A single tuple"; + + let parameters = (ins ColumnList:$columns); + + let assemblyFormat = [{ + `<` $columns `>` + }]; + + let genVerifyDecl = 1; } +def ColumnType : Type, "column type">; + #endif // GAREL_TYPES diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt index 8b3781c..4f11aec 100644 --- a/compiler/src/garel/CMakeLists.txt +++ b/compiler/src/garel/CMakeLists.txt @@ -9,6 +9,9 @@ target_include_directories(GARelIR SYSTEM PUBLIC ${PROJECT_BINARY_DIR}/include) add_dependencies(GARelIR MLIRGARelOpsIncGen ) +# Suppress -Wdangling-assignment-gsl for generated code from MLIR tablegen +# The warning is triggered by template instantiations in GARelOps.cpp.inc +target_compile_options(GARelIR PRIVATE -Wno-dangling-assignment-gsl) if(NOT GRAPHALG_ENABLE_RTTI) target_compile_options(GARelIR PUBLIC -fno-rtti) endif() diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 1615a9e..d3dddf5 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -1,7 +1,28 @@ -#include "garel/GARelOps.h" +#include + #include "garel/GARelDialect.h" +#include "garel/GARelOps.h" #define GET_OP_CLASSES #include "garel/GARelOps.cpp.inc" -namespace garel {} // namespace garel +namespace garel { + +mlir::LogicalResult ExtractOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(adaptor.getColumn().getType()); + return mlir::success(); +} + +mlir::LogicalResult ExtractOp::verify() { + auto columns = getTuple().getType().getColumns(); + if (!llvm::is_contained(columns, getColumn())) { + return emitOpError("column ") + << getColumn() << " not included in tuple " << getTuple().getType(); + } + + return mlir::success(); +} + +} // namespace garel diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp index ae693a9..5e628f5 100644 --- a/compiler/src/garel/GARelTypes.cpp +++ b/compiler/src/garel/GARelTypes.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -11,6 +12,40 @@ namespace garel { +static mlir::LogicalResult +verifyColumnsUnique(llvm::function_ref emitError, + llvm::ArrayRef columns) { + // Columns must be unique + llvm::SmallDenseSet columnSet; + for (auto c : columns) { + auto [_, newlyAdded] = columnSet.insert(c); + if (!newlyAdded) { + return emitError() << "column " << c + << " specified multiple times in the same column set"; + } + } + + return mlir::success(); +} + +mlir::LogicalResult +RelationType::verify(llvm::function_ref emitError, + llvm::ArrayRef columns) { + return verifyColumnsUnique(emitError, columns); +} + +mlir::LogicalResult +TupleType::verify(llvm::function_ref emitError, + llvm::ArrayRef columns) { + return verifyColumnsUnique(emitError, columns); +} + +bool isColumnType(mlir::Type t) { + // Allow i1, si64, f64, index + return t.isSignlessInteger(1) || t.isSignedInteger(64) || t.isF64() || + t.isIndex(); +} + // Need to define this here to avoid depending on IPRTypes in // IPRDialect and creating a cycle. void GARelDialect::registerTypes() { From 087ef659a8bb1fd35834f55ed8002429840d7a03 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 19 Jan 2026 17:36:27 +0000 Subject: [PATCH 03/32] verify ProjectOp. --- compiler/include/garel/GARelOps.td | 2 +- compiler/src/garel/GARelOps.cpp | 41 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 5a0efc3..150515b 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -20,7 +20,7 @@ def ProjectOp : GARel_Op<"project", [Pure]> { $input `:` type($input) `->` type($result) $projections attr-dict }]; - // let hasRegionVerifier = 1; + let hasRegionVerifier = 1; } def ProjectReturnOp : GARel_Op<"project.return", [ diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index d3dddf5..98d43a3 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -8,6 +8,47 @@ namespace garel { +mlir::LogicalResult ProjectOp::verifyRegions() { + if (getProjections().getNumArguments() != 1) { + return emitOpError("projections block should have exactly one argument"); + } + + auto blockArg = getProjections().getArgument(0); + auto blockType = llvm::dyn_cast(blockArg.getType()); + if (!blockType) { + return emitOpError("projections block arg must be of type tuple"); + } + + if (getInput().getType().getColumns() != blockType.getColumns()) { + return emitOpError("projections block columns do not match input columns"); + } + + auto terminator = getProjections().front().getTerminator(); + if (!terminator) { + return emitOpError("missing return from projections block"); + } + + auto returnOp = llvm::dyn_cast(terminator); + if (!returnOp) { + return emitOpError("projections block not terminated by project.return"); + } + + if (returnOp.getProjections().size() != getType().getColumns().size()) { + return emitOpError("projections block returns a different number of " + "values than specified in the projection return type"); + } + + for (const auto &[val, col] : + llvm::zip_equal(returnOp.getProjections(), getType().getColumns())) { + if (val.getType() != col.getType()) { + return emitOpError("projections block return types do not match the " + "projection output column types"); + } + } + + return mlir::success(); +} + mlir::LogicalResult ExtractOp::inferReturnTypes( mlir::MLIRContext *ctx, std::optional location, Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { From d0393a9d2171fd30f5c3bd86e61cce36d5bd8e38 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 20 Jan 2026 16:59:05 +0000 Subject: [PATCH 04/32] Add SelectOp and JoinOp. --- compiler/include/garel/GARelAttr.h | 3 + compiler/include/garel/GARelAttr.td | 16 ++++++ compiler/include/garel/GARelDialect.td | 6 +- compiler/include/garel/GARelOps.td | 75 ++++++++++++++++++++++--- compiler/include/garel/GARelTupleOps.td | 28 +++++++++ compiler/src/garel/CMakeLists.txt | 1 + compiler/src/garel/GARelOps.cpp | 70 ++++++++++++++++++++--- compiler/src/garel/GARelTupleOps.cpp | 24 ++++++++ 8 files changed, 204 insertions(+), 19 deletions(-) create mode 100644 compiler/include/garel/GARelTupleOps.td create mode 100644 compiler/src/garel/GARelTupleOps.cpp diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h index cc1db7d..d0de207 100644 --- a/compiler/include/garel/GARelAttr.h +++ b/compiler/include/garel/GARelAttr.h @@ -1,4 +1,7 @@ #pragma once +#include +#include + #define GET_ATTRDEF_CLASSES #include "garel/GARelAttr.h.inc" diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 89041e4..c0aad08 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -26,6 +26,22 @@ def ColumnAttr : GARel_Attr<"Column", "column"> { def ColumnList : OptionalArrayRefParameter<"ColumnAttr">; +def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { + let summary = "A binary equality join predicate"; + + let parameters = (ins "ColumnAttr":$lhs, "ColumnAttr":$rhs); + + let assemblyFormat = [{ + `<` $lhs `=` $rhs `>` + }]; +} + +def JoinPredicates : ArrayOfAttr< + GARel_Dialect, + "JoinPredicates", + "join_preds", + "JoinPredicateAttr">; + /* def ColumnSetAttr : GARel_Attr<"ColumnSet", "column_set"> { let summary = "The set of columns in a relation"; diff --git a/compiler/include/garel/GARelDialect.td b/compiler/include/garel/GARelDialect.td index f745dbf..725499d 100644 --- a/compiler/include/garel/GARelDialect.td +++ b/compiler/include/garel/GARelDialect.td @@ -1,8 +1,9 @@ #ifndef GAREL_DIALECT #define GAREL_DIALECT -include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" def GARel_Dialect : Dialect { let name = "garel"; @@ -19,7 +20,8 @@ def GARel_Dialect : Dialect { let useDefaultTypePrinterParser = 1; } +// NOTE: GARel ops are always 'Pure' class GARel_Op traits = []> : - Op; + Op; #endif // GAREL_DIALECT diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 150515b..3594eb1 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -1,13 +1,21 @@ #ifndef GAREL_OPS #define GAREL_OPS +/** + * Relation-level ops in the GARel dialect. + * + * See ./GARelTupleOps.td for the tuple-level ops. + */ + include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" include "GARelDialect.td" include "GARelTypes.td" -def ProjectOp : GARel_Op<"project", [Pure]> { +// Per-tuple ops kept in a separate file. +include "GARelTupleOps.td" + +def ProjectOp : GARel_Op<"project", [IsolatedFromAbove]> { let summary = "Remaps, reorders, drops and computes columns"; let arguments = (ins Relation:$input); @@ -21,10 +29,13 @@ def ProjectOp : GARel_Op<"project", [Pure]> { }]; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + ProjectReturnOp getTerminator(); + }]; } def ProjectReturnOp : GARel_Op<"project.return", [ - Pure, Terminator, HasParent<"ProjectOp">]> { let summary = "The output projections"; @@ -38,18 +49,66 @@ def ProjectReturnOp : GARel_Op<"project.return", [ // NOTE: verification performed by ProjectOp } -def ExtractOp : GARel_Op<"extract", [Pure, InferTypeOpAdaptor]> { - let summary = "Extract the value of one column from a tuple"; +def SelectOp : GARel_Op<"select", [ + SameOperandsAndResultType, + IsolatedFromAbove]> { + let summary = "AlgebraSelection"; + + let arguments = (ins Relation:$input); + let regions = (region SizedRegion<1>:$predicates); + + let builders = [OpBuilder<(ins "mlir::Value":$child)>]; + let skipDefaultBuilders = 1; + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) $predicates attr-dict + }]; + + let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + SelectReturnOp getTerminator(); + }]; +} + +def SelectReturnOp : GARel_Op<"select.return", [ + Terminator, + HasParent<"SelectOp">]> { + let summary = "Return the select predicates"; - let arguments = (ins ColumnAttr:$column, Tuple:$tuple); + let arguments = (ins Variadic:$predicates); - let results = (outs ColumnType:$result); + let assemblyFormat = [{ + $predicates attr-dict + }]; +} + +def JoinOp : GARel_Op<"join", [ + Pure, + InferTypeOpAdaptor]> { + let summary = "AlgebraJoin"; + + let arguments = (ins + // NOTE: All inputs must have distinct columns + Variadic:$inputs, + // NOTE: Equality predicates only + JoinPredicates:$predicates); + + let results = (outs Relation:$result); let assemblyFormat = [{ - $column $tuple `:` type($tuple) attr-dict + $inputs `:` type($inputs) + $predicates + attr-dict }]; let hasVerifier = 1; } +// TODO: Aggregate (with sum, min, max, or, and argmin) + +// TODO: Loop Operator + #endif // GAREL_OPS diff --git a/compiler/include/garel/GARelTupleOps.td b/compiler/include/garel/GARelTupleOps.td new file mode 100644 index 0000000..409e702 --- /dev/null +++ b/compiler/include/garel/GARelTupleOps.td @@ -0,0 +1,28 @@ +#ifndef GAREL_TUPLE_OPS +#define GAREL_TUPLE_OPS +/** + * Tuple-level ops in the GARel dialect. + * + * See ./GARelOps.td for the relation-level ops. + */ + +include "mlir/Interfaces/InferTypeOpInterface.td" + +include "GARelDialect.td" +include "GARelTypes.td" + +def ExtractOp : GARel_Op<"extract", [InferTypeOpAdaptor]> { + let summary = "Extract the value of one column from a tuple"; + + let arguments = (ins ColumnAttr:$column, Tuple:$tuple); + + let results = (outs ColumnType:$result); + + let assemblyFormat = [{ + $column $tuple `:` type($tuple) attr-dict + }]; + + let hasVerifier = 1; +} + +#endif // GAREL_TUPLE_OPS diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt index 4f11aec..f8a4570 100644 --- a/compiler/src/garel/CMakeLists.txt +++ b/compiler/src/garel/CMakeLists.txt @@ -2,6 +2,7 @@ add_library(GARelIR GARelAttr.cpp GARelDialect.cpp GARelOps.cpp + GARelTupleOps.cpp GARelTypes.cpp ) target_include_directories(GARelIR PUBLIC ../../include) diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 98d43a3..7afdd0d 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -1,13 +1,18 @@ #include +#include +#include +#include "garel/GARelAttr.h" #include "garel/GARelDialect.h" #include "garel/GARelOps.h" +#include "garel/GARelTypes.h" #define GET_OP_CLASSES #include "garel/GARelOps.cpp.inc" namespace garel { +// === ProjectOp === mlir::LogicalResult ProjectOp::verifyRegions() { if (getProjections().getNumArguments() != 1) { return emitOpError("projections block should have exactly one argument"); @@ -49,20 +54,67 @@ mlir::LogicalResult ProjectOp::verifyRegions() { return mlir::success(); } -mlir::LogicalResult ExtractOp::inferReturnTypes( - mlir::MLIRContext *ctx, std::optional location, - Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back(adaptor.getColumn().getType()); +ProjectReturnOp ProjectOp::getTerminator() { + return llvm::cast(getProjections().front().getTerminator()); +} + +// === SelectOp === +void SelectOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value input) { + auto region = state.addRegion(); + auto &block = region->emplaceBlock(); + auto inputType = llvm::cast(input.getType()); + block.addArgument(builder.getType(inputType.getColumns()), + builder.getUnknownLoc()); + state.addTypes(input.getType()); + state.addOperands(input); +} + +mlir::LogicalResult SelectOp::verifyRegions() { + if (getPredicates().getNumArguments() != 1) { + return emitOpError("predicates block should have exactly one argument"); + } + + auto blockArg = getPredicates().getArgument(0); + auto blockType = llvm::dyn_cast(blockArg.getType()); + if (!blockType) { + return emitOpError("predicates block arg must be of type tuple"); + } + + if (getInput().getType().getColumns() != blockType.getColumns()) { + return emitOpError("predicates block slots do not match child slots"); + } + + auto terminator = getPredicates().front().getTerminator(); + if (!terminator || !llvm::isa(terminator)) { + return emitOpError("predicates block not terminated with select.return"); + } + return mlir::success(); } -mlir::LogicalResult ExtractOp::verify() { - auto columns = getTuple().getType().getColumns(); - if (!llvm::is_contained(columns, getColumn())) { - return emitOpError("column ") - << getColumn() << " not included in tuple " << getTuple().getType(); +SelectReturnOp SelectOp::getTerminator() { + return llvm::cast(getPredicates().front().getTerminator()); +} + +// === JoinOp === +mlir::LogicalResult JoinOp::verify() { + // TODO: Inputs must use distinct columns. + // TODO: Predicates must refer to columns in distinct inputs (and to columns + // present in the input). + return emitOpError("TODO: verify JoinOp"); +} + +mlir::LogicalResult JoinOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + for (auto input : adaptor.getInputs()) { + auto inputColumns = llvm::cast(input.getType()).getColumns(); + outputColumns.append(inputColumns.begin(), inputColumns.end()); } + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); return mlir::success(); } diff --git a/compiler/src/garel/GARelTupleOps.cpp b/compiler/src/garel/GARelTupleOps.cpp new file mode 100644 index 0000000..198cc14 --- /dev/null +++ b/compiler/src/garel/GARelTupleOps.cpp @@ -0,0 +1,24 @@ +#include + +#include "garel/GARelOps.h" + +namespace garel { + +mlir::LogicalResult ExtractOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(adaptor.getColumn().getType()); + return mlir::success(); +} + +mlir::LogicalResult ExtractOp::verify() { + auto columns = getTuple().getType().getColumns(); + if (!llvm::is_contained(columns, getColumn())) { + return emitOpError("column ") + << getColumn() << " not included in tuple " << getTuple().getType(); + } + + return mlir::success(); +} + +} // namespace garel From 96cd4e3c83a5efe2d6dd7fc0f43fbd19f659e82e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Jan 2026 10:08:30 +0000 Subject: [PATCH 05/32] All relational ops. --- compiler/include/garel/CMakeLists.txt | 2 + compiler/include/garel/GARelAttr.h | 2 + compiler/include/garel/GARelAttr.td | 55 ++++++++++++++---- compiler/include/garel/GARelOps.td | 77 ++++++++++++++++++++++--- compiler/include/garel/GARelTypes.td | 4 +- compiler/src/garel/GARelAttr.cpp | 30 +++++++++- compiler/src/garel/GARelOps.cpp | 82 ++++++++++++++++++++++++++- 7 files changed, 229 insertions(+), 23 deletions(-) diff --git a/compiler/include/garel/CMakeLists.txt b/compiler/include/garel/CMakeLists.txt index 883398b..b6a97a1 100644 --- a/compiler/include/garel/CMakeLists.txt +++ b/compiler/include/garel/CMakeLists.txt @@ -2,6 +2,8 @@ include_directories(SYSTEM ${MLIR_INCLUDE_DIRS}) include_directories(SYSTEM ${PROJECT_BINARY_DIR}/include) set(LLVM_TARGET_DEFINITIONS GARelAttr.td) +mlir_tablegen(GARelEnumAttr.h.inc -gen-enum-decls) +mlir_tablegen(GARelEnumAttr.cpp.inc -gen-enum-defs) mlir_tablegen(GARelAttr.h.inc --gen-attrdef-decls) mlir_tablegen(GARelAttr.cpp.inc -gen-attrdef-defs) diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h index d0de207..4319db6 100644 --- a/compiler/include/garel/GARelAttr.h +++ b/compiler/include/garel/GARelAttr.h @@ -3,5 +3,7 @@ #include #include +#include "garel/GARelEnumAttr.h.inc" + #define GET_ATTRDEF_CLASSES #include "garel/GARelAttr.h.inc" diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index c0aad08..7bb0a87 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -22,9 +22,13 @@ def ColumnAttr : GARel_Attr<"Column", "column"> { let assemblyFormat = [{ `<` $id `` `:` `` $type `>` }]; + + let extraClassDeclaration = [{ + static ColumnAttr newOfType(mlir::Type type); + }]; } -def ColumnList : OptionalArrayRefParameter<"ColumnAttr">; +def ColumnListParameter : OptionalArrayRefParameter<"ColumnAttr">; def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { let summary = "A binary equality join predicate"; @@ -42,27 +46,54 @@ def JoinPredicates : ArrayOfAttr< "join_preds", "JoinPredicateAttr">; -/* -def ColumnSetAttr : GARel_Attr<"ColumnSet", "column_set"> { - let summary = "The set of columns in a relation"; +def ColumnList : GARel_Attr<"ColumnList", "columns"> { + let summary = "Array of columns"; - let description = [{ - An ordered set of columns. - - Order is significant: `` is a different set than ``. + let parameters = (ins + ColumnListParameter:$columns); - A column may only appear once in the set: `` is invalid. + let assemblyFormat = [{ + `<` (`>`) : (`` $columns^ `>`)? }]; +} + +def AggregateFunc : I64EnumAttr< + "AggregateFunc", "", + [ + I64EnumAttrCase<"SUM", 0>, + I64EnumAttrCase<"MIN", 1>, + I64EnumAttrCase<"MAX", 2>, + I64EnumAttrCase<"LOR", 3>, /* Logical OR (over i1) */ + I64EnumAttrCase<"ARGMIN", 4>, + ] +> { + let cppNamespace = "::garel"; +} + +// NOTE: assumes an aggregator produces exactly one output column. +def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { + let summary = "Aggregate function with bound input columns"; let parameters = (ins - OptionalArrayRefParameter<"ColumnAttr">:$columns); + "AggregateFunc":$func, + ColumnListParameter:$inputs); let assemblyFormat = [{ - `<` (`>`) : (`` $columns^ `>`)? + `<` $func $inputs `>` }]; let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + /** Type for values in the output column. */ + mlir::Type getResultType(); + }]; } -*/ + +def Aggregators : ArrayOfAttr< + GARel_Dialect, + "Aggregators", + "aggregators", + "AggregatorAttr">; #endif // GAREL_ATTR diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 3594eb1..33d01ea 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -52,7 +52,7 @@ def ProjectReturnOp : GARel_Op<"project.return", [ def SelectOp : GARel_Op<"select", [ SameOperandsAndResultType, IsolatedFromAbove]> { - let summary = "AlgebraSelection"; + let summary = "Removes tuples that fail (one of) the predicates"; let arguments = (ins Relation:$input); let regions = (region SizedRegion<1>:$predicates); @@ -85,10 +85,8 @@ def SelectReturnOp : GARel_Op<"select.return", [ }]; } -def JoinOp : GARel_Op<"join", [ - Pure, - InferTypeOpAdaptor]> { - let summary = "AlgebraJoin"; +def JoinOp : GARel_Op<"join", [InferTypeOpAdaptor]> { + let summary = "Natural (equi)join of relations"; let arguments = (ins // NOTE: All inputs must have distinct columns @@ -107,8 +105,73 @@ def JoinOp : GARel_Op<"join", [ let hasVerifier = 1; } -// TODO: Aggregate (with sum, min, max, or, and argmin) +def UnionOp : GARel_Op<"union", [SameOperandsAndResultType]> { + let summary = "Union of relations"; -// TODO: Loop Operator + let arguments = (ins Variadic:$inputs); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $inputs `:` type($inputs) + attr-dict + }]; +} + +def AggregateOp : GARel_Op<"aggregate", [InferTypeOpAdaptor]> { + let summary = "Groups tuples by key columns, aggregating values of other columns"; + + let arguments = (ins + Relation:$input, + ColumnList:$groupBy, + Aggregators:$aggregators); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) + `group_by` `` `=` `` $groupBy + `aggregators` `` `=` `` $aggregators + attr-dict + }]; +} + +def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { + let summary = "Bounded iteration"; + + let arguments = (ins + Variadic:$init, + I64Attr:$iters, + I64Attr:$resultIdx); + + let regions = (region SizedRegion<1>:$body); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $init `:` type($init) + `iters` `` `=` `` $iters + `result_idx` `` `=` `` $resultIdx + $body + attr-dict + }]; + + let hasVerifier = 1; + let hasRegionVerifier = 1; +} + +def ForYieldOp : GARel_Op<"for.yield", [ + Terminator, + HasParent<"ForOp">]> { + let summary = "Produces the iter args for the next iteration"; + + let arguments = (ins Variadic:$inputs); + + let assemblyFormat = [{ + $inputs `:` type($inputs) attr-dict + }]; + + // Note: verification performed by parent ForOp. +} #endif // GAREL_OPS diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 43e84b3..595359d 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -15,7 +15,7 @@ class GARel_Type traits = []> def Relation : GARel_Type<"Relation", "rel"> { let summary = "A set of tuples"; - let parameters = (ins ColumnList:$columns); + let parameters = (ins ColumnListParameter:$columns); let assemblyFormat = [{ `<` $columns `>` @@ -27,7 +27,7 @@ def Relation : GARel_Type<"Relation", "rel"> { def Tuple : GARel_Type<"Tuple", "tuple"> { let summary = "A single tuple"; - let parameters = (ins ColumnList:$columns); + let parameters = (ins ColumnListParameter:$columns); let assemblyFormat = [{ `<` $columns `>` diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp index 7218da4..b8c162d 100644 --- a/compiler/src/garel/GARelAttr.cpp +++ b/compiler/src/garel/GARelAttr.cpp @@ -1,17 +1,45 @@ +#include #include #include +#include #include #include #include "garel/GARelAttr.h" #include "garel/GARelDialect.h" +#include "garel/GARelEnumAttr.cpp.inc" #define GET_ATTRDEF_CLASSES #include "garel/GARelAttr.cpp.inc" namespace garel { -// Need to define this here to avoid depending on GraphAlgAttr in +ColumnAttr ColumnAttr::newOfType(mlir::Type type) { + auto *ctx = type.getContext(); + auto colId = mlir::DistinctAttr::create(mlir::UnitAttr::get(ctx)); + return ColumnAttr::get(ctx, colId, type); +} + +mlir::Type AggregatorAttr::getResultType() { + switch (getFunc()) { + case AggregateFunc::SUM: + case AggregateFunc::MIN: + case AggregateFunc::MAX: + case AggregateFunc::LOR: + case AggregateFunc::ARGMIN: + // NOTE: argmin(arg, val) also uses first input as output type. + return getInputs()[0].getType(); + } +} + +mlir::LogicalResult +AggregatorAttr::verify(llvm::function_ref emitError, + AggregateFunc func, llvm::ArrayRef inputs) { + // TODO: Verify input column count and type(s) + return mlir::success(); +} + +// Need to define this here to avoid depending on GARelAttr in // GARelDialect and creating a cycle. void GARelDialect::registerAttributes() { addAttributes< diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 7afdd0d..ef81b50 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include "garel/GARelAttr.h" #include "garel/GARelDialect.h" @@ -102,7 +103,7 @@ mlir::LogicalResult JoinOp::verify() { // TODO: Inputs must use distinct columns. // TODO: Predicates must refer to columns in distinct inputs (and to columns // present in the input). - return emitOpError("TODO: verify JoinOp"); + return mlir::success(); } mlir::LogicalResult JoinOp::inferReturnTypes( @@ -118,4 +119,83 @@ mlir::LogicalResult JoinOp::inferReturnTypes( return mlir::success(); } +// === AggregateOp === +mlir::LogicalResult AggregateOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + + // Key columns + auto keyColumns = adaptor.getGroupBy().getColumns(); + outputColumns.append(keyColumns.begin(), keyColumns.end()); + + // Aggregator outputs + for (auto agg : adaptor.getAggregators()) { + outputColumns.push_back(ColumnAttr::newOfType(agg.getResultType())); + } + + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +// === ForOp === +static mlir::LogicalResult +verifyResultIdx(llvm::function_ref emitError, + mlir::ValueRange initArgs, std::uint64_t resultIdx) { + // resultIdx is within bounds of init args. + if (initArgs.size() <= resultIdx) { + return emitError() << "has result_idx=" << resultIdx + << ", but there are only " << initArgs.size() + << " init args"; + } + + return mlir::success(); +} + +mlir::LogicalResult ForOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + auto loc = location ? *location : mlir::UnknownLoc::get(ctx); + if (mlir::failed(verifyResultIdx( + [&]() { + return mlir::emitError(loc) + << ForOp::getOperationName() << " to build with init args " + << adaptor.getInit() << " "; + }, + adaptor.getInit(), adaptor.getResultIdx()))) { + return mlir::failure(); + } + + auto resultType = adaptor.getInit()[adaptor.getResultIdx()].getType(); + inferredReturnTypes.emplace_back(resultType); + return mlir::success(); +} + +mlir::LogicalResult ForOp::verify() { + return verifyResultIdx([this]() { return emitOpError(); }, getInit(), + getResultIdx()); +} + +mlir::LogicalResult ForOp::verifyRegions() { + auto initTypes = getInit().getTypes(); + + // Body arg types match init args + auto argTypes = getBody().front().getArgumentTypes(); + if (initTypes != argTypes) { + return emitOpError("body arg types do not match the initial value types"); + } + + // Body result types match init args + auto yieldOp = llvm::cast(getBody().front().getTerminator()); + auto resTypes = yieldOp.getInputs().getTypes(); + if (initTypes != resTypes) { + auto diag = + emitOpError("body result types do not match the initial value types"); + diag.attachNote(yieldOp.getLoc()) << "body result is here"; + return diag; + } + + return mlir::success(); +} + } // namespace garel From 01bc2f24e1ba2fe50b464cd37fe59ae5cb4c723c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Jan 2026 11:18:17 +0000 Subject: [PATCH 06/32] Start conversion pass. --- compiler/include/garel/CMakeLists.txt | 4 + compiler/include/garel/GARelOps.td | 1 + compiler/include/garel/GARelPasses.h | 11 + compiler/include/garel/GARelPasses.td | 15 + compiler/src/garel/CMakeLists.txt | 12 + compiler/src/garel/GARelOps.cpp | 11 + compiler/src/garel/GraphAlgToRel.cpp | 338 +++++++++++++++++++ compiler/test/graphalg-to-rel/transpose.mlir | 34 ++ compiler/tools/CMakeLists.txt | 1 + compiler/tools/graphalg-opt.cpp | 2 + 10 files changed, 429 insertions(+) create mode 100644 compiler/include/garel/GARelPasses.h create mode 100644 compiler/include/garel/GARelPasses.td create mode 100644 compiler/src/garel/GraphAlgToRel.cpp create mode 100644 compiler/test/graphalg-to-rel/transpose.mlir diff --git a/compiler/include/garel/CMakeLists.txt b/compiler/include/garel/CMakeLists.txt index b6a97a1..f221bc4 100644 --- a/compiler/include/garel/CMakeLists.txt +++ b/compiler/include/garel/CMakeLists.txt @@ -9,3 +9,7 @@ mlir_tablegen(GARelAttr.cpp.inc -gen-attrdef-defs) set(LLVM_TARGET_DEFINITIONS GARelOps.td) add_mlir_dialect(GARelOps garel) + +set(LLVM_TARGET_DEFINITIONS GARelPasses.td) +mlir_tablegen(GARelPasses.h.inc --gen-pass-decls) +add_public_tablegen_target(MLIRGARelPassesIncGen) diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 33d01ea..bc8ab2b 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -31,6 +31,7 @@ def ProjectOp : GARel_Op<"project", [IsolatedFromAbove]> { let hasRegionVerifier = 1; let extraClassDeclaration = [{ + mlir::Block& createProjectionsBlock(); ProjectReturnOp getTerminator(); }]; } diff --git a/compiler/include/garel/GARelPasses.h b/compiler/include/garel/GARelPasses.h new file mode 100644 index 0000000..e432976 --- /dev/null +++ b/compiler/include/garel/GARelPasses.h @@ -0,0 +1,11 @@ +#pragma once + +namespace garel { + +#define GEN_PASS_DECL +#include "garel/GARelPasses.h.inc" + +#define GEN_PASS_REGISTRATION +#include "garel/GARelPasses.h.inc" + +} // namespace garel diff --git a/compiler/include/garel/GARelPasses.td b/compiler/include/garel/GARelPasses.td new file mode 100644 index 0000000..e378828 --- /dev/null +++ b/compiler/include/garel/GARelPasses.td @@ -0,0 +1,15 @@ +#ifndef GAREL_PASSES +#define GAREL_PASSES + +include "mlir/Pass/PassBase.td" + +def GraphAlgToRel : Pass<"graphalg-to-rel", "::mlir::ModuleOp"> { + let summary = "Convert GraphAlg Core IR to relational ops"; + + let dependentDialects = [ + "garel::GARelDialect", + "graphalg::GraphAlgDialect", + ]; +} + +#endif // GAREL_PASSES diff --git a/compiler/src/garel/CMakeLists.txt b/compiler/src/garel/CMakeLists.txt index f8a4570..6c8622e 100644 --- a/compiler/src/garel/CMakeLists.txt +++ b/compiler/src/garel/CMakeLists.txt @@ -9,6 +9,7 @@ target_include_directories(GARelIR PUBLIC ../../include) target_include_directories(GARelIR SYSTEM PUBLIC ${PROJECT_BINARY_DIR}/include) add_dependencies(GARelIR MLIRGARelOpsIncGen + MLIRGARelPassesIncGen ) # Suppress -Wdangling-assignment-gsl for generated code from MLIR tablegen # The warning is triggered by template instantiations in GARelOps.cpp.inc @@ -23,3 +24,14 @@ target_link_libraries( MLIRIR MLIRSupport ) + +add_library(GraphAlgToRel + GraphAlgToRel.cpp +) +target_link_libraries( + GraphAlgToRel + PRIVATE + GraphAlgIR + GARelIR + MLIRPass +) diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index ef81b50..7868e8a 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include "garel/GARelAttr.h" @@ -55,6 +56,16 @@ mlir::LogicalResult ProjectOp::verifyRegions() { return mlir::success(); } +mlir::Block &ProjectOp::createProjectionsBlock() { + assert(getProjections().empty() && "Already have a projections block"); + auto &block = getProjections().emplaceBlock(); + // Same columns as the input, but as a tuple. + block.addArgument( + TupleType::get(getContext(), getInput().getType().getColumns()), + getInput().getLoc()); + return block; +} + ProjectReturnOp ProjectOp::getTerminator() { return llvm::cast(getProjections().front().getTerminator()); } diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp new file mode 100644 index 0000000..caa1147 --- /dev/null +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -0,0 +1,338 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "garel/GARelAttr.h" +#include "garel/GARelDialect.h" +#include "garel/GARelOps.h" +#include "garel/GARelTypes.h" +#include "graphalg/GraphAlgDialect.h" +#include "graphalg/GraphAlgOps.h" +#include "graphalg/GraphAlgTypes.h" +#include "graphalg/SemiringTypes.h" + +namespace garel { + +#define GEN_PASS_DEF_GRAPHALGTOREL +#include "garel/GARelPasses.h.inc" + +namespace { + +class GraphAlgToRel : public impl::GraphAlgToRelBase { +public: + using impl::GraphAlgToRelBase::GraphAlgToRelBase; + + void runOnOperation() final; +}; + +/** Converts semiring types into their relational equivalents. */ +class SemiringTypeConverter : public mlir::TypeConverter { +private: + static mlir::Type convertSemiringType(graphalg::SemiringTypeInterface type); + +public: + SemiringTypeConverter(); +}; + +/** Converts matrix types into relations. */ +class MatrixTypeConverter : public mlir::TypeConverter { +private: + SemiringTypeConverter _semiringConverter; + + mlir::FunctionType convertFunctionType(mlir::FunctionType type) const; + RelationType convertMatrixType(graphalg::MatrixType type) const; + +public: + MatrixTypeConverter(mlir::MLIRContext *ctx, + const SemiringTypeConverter &semiringConverter); +}; + +/** + * Convenient wrapper around a matrix value and its relation equivalent + * after type conversion. + * + * This class is particularly useful for retrieving the relation column for + * the rows, columns or values of the matrix. + */ +class MatrixAdaptor { +private: + mlir::TypedValue _matrix; + + RelationType _relType; + // May be null for outputs, in which case only the relation type is available. + mlir::TypedValue _relation; + +public: + // For output matrices, where we only have the desired output type. + MatrixAdaptor(mlir::Value matrix, mlir::Type streamType) + : _matrix(llvm::cast>(matrix)), + _relType(llvm::cast(streamType)) {} + + // For input matrices, where the OpAdaptor provides the relation value. + MatrixAdaptor(mlir::Value matrix, mlir::Value relation) + : MatrixAdaptor(matrix, relation.getType()) { + this->_relation = llvm::cast>(relation); + } + + graphalg::MatrixType matrixType() { return _matrix.getType(); } + + RelationType relType() { return _relType; } + + mlir::TypedValue relation() { + assert(!!_relation && "No relation value (only type)"); + return _relation; + } + + llvm::ArrayRef columns() { return _relType.getColumns(); } + + ColumnAttr row() { + if (_matrix.getType().getRows().isOne()) { + return {}; + } + + return columns().front(); + } + + ColumnAttr col() { + if (_matrix.getType().getCols().isOne()) { + return {}; + } + + return columns().drop_back().back(); + } + + /** + * The row or column column, depending on which one is present. + * + * NOTE: Vector relations have a row or a column slot, but not both. Row and + * column vectors have identical relational representation of (idx, val). + */ + ColumnAttr vectorIdxColumn() { + assert(columns().size() == 2 && "Not a vector"); + return columns().front(); + } + + ColumnAttr valSlot() { return columns().back(); } + + bool isScalar() { return _matrix.getType().isScalar(); } + + graphalg::SemiringTypeInterface semiring() { + return llvm::cast( + _matrix.getType().getSemiring()); + } +}; + +template class OpConversion : public mlir::OpConversionPattern { + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(T op, + typename mlir::OpConversionPattern::OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace + +mlir::Type +SemiringTypeConverter::convertSemiringType(graphalg::SemiringTypeInterface t) { + auto *ctx = t.getContext(); + // To i1 + if (t == graphalg::SemiringTypes::forBool(ctx)) { + return mlir::IntegerType::get(ctx, 1); + } + + // To i64 + if (t == graphalg::SemiringTypes::forInt(ctx) || + t == graphalg::SemiringTypes::forTropInt(ctx) || + t == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + return mlir::IntegerType::get(ctx, 64, mlir::IntegerType::Signed); + } + + // To f64 + if (t == graphalg::SemiringTypes::forReal(ctx) || + t == graphalg::SemiringTypes::forTropReal(ctx)) { + return mlir::Float64Type::get(ctx); + } + + return nullptr; +} + +SemiringTypeConverter::SemiringTypeConverter() { + addConversion(convertSemiringType); +} + +mlir::FunctionType +MatrixTypeConverter::convertFunctionType(mlir::FunctionType type) const { + llvm::SmallVector inputs; + if (mlir::failed(convertTypes(type.getInputs(), inputs))) { + return {}; + } + + llvm::SmallVector results; + if (mlir::failed(convertTypes(type.getResults(), results))) { + return {}; + } + + return mlir::FunctionType::get(type.getContext(), inputs, results); +} + +RelationType +MatrixTypeConverter::convertMatrixType(graphalg::MatrixType type) const { + llvm::SmallVector columns; + auto *ctx = type.getContext(); + auto indexType = mlir::IndexType::get(ctx); + if (!type.getRows().isOne()) { + columns.push_back(ColumnAttr::newOfType(indexType)); + } + + if (!type.getCols().isOne()) { + columns.push_back(ColumnAttr::newOfType(indexType)); + } + + auto valueType = _semiringConverter.convertType(type.getSemiring()); + if (!valueType) { + return {}; + } + + columns.push_back(ColumnAttr::newOfType(valueType)); + return RelationType::get(ctx, columns); +} + +MatrixTypeConverter::MatrixTypeConverter( + mlir::MLIRContext *ctx, const SemiringTypeConverter &semiringConverter) + : _semiringConverter(semiringConverter) { + addConversion( + [this](mlir::FunctionType t) { return convertFunctionType(t); }); + + addConversion( + [this](graphalg::MatrixType t) { return convertMatrixType(t); }); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + mlir::func::FuncOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // TODO: type caching means that if two function arguments have the same + // matrix type, they will also be assigned the same relation type, and + // therefore have the same set of slots. Consider doing our own conversion + // here to avoid that. + auto funcType = llvm::cast_if_present( + typeConverter->convertType(op.getFunctionType())); + if (!funcType) { + return op->emitOpError("function type ") + << op.getFunctionType() << " cannot be converted"; + } + + auto result = mlir::success(); + rewriter.modifyOpInPlace(op, [&]() { + // Update function type. + op.setFunctionType(funcType); + }); + + // Convert block args. + mlir::TypeConverter::SignatureConversion newSig(funcType.getNumInputs()); + if (mlir::failed( + rewriter.convertRegionTypes(&op.getFunctionBody(), *typeConverter))) { + return mlir::failure(); + } + + return result; +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + mlir::func::ReturnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + rewriter.modifyOpInPlace(op, + [&]() { op->setOperands(adaptor.getOperands()); }); + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::TransposeOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto projectOp = rewriter.replaceOpWithNewOp(op, output.relType(), + input.relation()); + + auto &body = projectOp.createProjectionsBlock(); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector columns(input.columns()); + assert(columns.size() <= 3); + // Transpose is a no-op if there are fewer than 3 columns. + if (columns.size() == 3) { + // Swap row and column + std::swap(columns[0], columns[1]); + } + + // Return the input slots (after row and column have been swapped) + llvm::SmallVector results; + for (auto col : columns) { + results.emplace_back( + rewriter.create(op.getLoc(), col, body.getArgument(0))); + } + + rewriter.create(op.getLoc(), results); + + return mlir::success(); +} + +static bool hasRelationSignature(mlir::func::FuncOp op) { + // All inputs should be relations + auto funcType = op.getFunctionType(); + for (auto input : funcType.getInputs()) { + if (!llvm::isa(input)) { + return false; + } + } + + // There should be exactly one relation result + return funcType.getNumResults() == 1 && + llvm::isa(funcType.getResult(0)); +} + +static bool hasRelationOperands(mlir::Operation *op) { + return llvm::all_of(op->getOperandTypes(), + [](auto t) { return llvm::isa(t); }); +} + +void GraphAlgToRel::runOnOperation() { + mlir::ConversionTarget target(getContext()); + // Eliminate all graphalg ops (and the few ops we use from arith) ... + target.addIllegalDialect(); + target.addIllegalDialect(); + // And turn them into relational ops. + target.addLegalDialect(); + // Keep container module. + target.addLegalOp(); + // Keep functions, but change their signature. + target.addDynamicallyLegalOp(hasRelationSignature); + target.addDynamicallyLegalOp(hasRelationOperands); + + SemiringTypeConverter semiringTypeConverter; + MatrixTypeConverter matrixTypeConverter(&getContext(), semiringTypeConverter); + + mlir::RewritePatternSet patterns(&getContext()); + patterns + .add, OpConversion, + OpConversion>(matrixTypeConverter, + &getContext()); + + // Scalar patterns. + + if (mlir::failed(mlir::applyFullConversion(getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace garel diff --git a/compiler/test/graphalg-to-rel/transpose.mlir b/compiler/test/graphalg-to-rel/transpose.mlir new file mode 100644 index 0000000..f989af6 --- /dev/null +++ b/compiler/test/graphalg-to-rel/transpose.mlir @@ -0,0 +1,34 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @TransposeMatrix +func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<43 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : <[[ROW:#col[0-9]*]], #col1, #col2> -> <#col3, #col4, #col5> { + // CHECK: %[[#COL_SLOT:]] = garel.extract #col1 + // CHECK: %[[#ROW_SLOT:]] = garel.extract [[ROW]] + // CHECK: %[[#VAL_SLOT:]] = garel.extract #col2 + // CHECK: garel.project.return %[[#COL_SLOT]], %[[#ROW_SLOT]], %[[#VAL_SLOT]] + %0 = graphalg.transpose %arg0 : <42 x 43 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<43 x 42 x i64> +} + +// CHECK-LABEL: @TransposeColVec +func.func @TransposeColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { + %0 = graphalg.transpose %arg0 : <42 x 1 x i64> + + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @TransposeRowVec +func.func @TransposeRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<43 x 1 x i64> { + %0 = graphalg.transpose %arg0 : <1 x 43 x i64> + + return %0 : !graphalg.mat<43 x 1 x i64> +} + +func.func @TransposeScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.transpose %arg0 : <1 x 1 x i64> + + return %0 : !graphalg.mat<1 x 1 x i64> +} diff --git a/compiler/tools/CMakeLists.txt b/compiler/tools/CMakeLists.txt index 8f81c9c..6efaafd 100644 --- a/compiler/tools/CMakeLists.txt +++ b/compiler/tools/CMakeLists.txt @@ -12,6 +12,7 @@ target_link_libraries(graphalg-opt PRIVATE GraphAlgIR GraphAlgPasses GARelIR + GraphAlgToRel MLIROptLib ) diff --git a/compiler/tools/graphalg-opt.cpp b/compiler/tools/graphalg-opt.cpp index f011237..5bc7f2e 100644 --- a/compiler/tools/graphalg-opt.cpp +++ b/compiler/tools/graphalg-opt.cpp @@ -7,6 +7,7 @@ #include #include "garel/GARelDialect.h" +#include "garel/GARelPasses.h" #include "graphalg/GraphAlgDialect.h" #include "graphalg/GraphAlgPasses.h" @@ -18,6 +19,7 @@ int main(int argc, char **argv) { graphalg::registerPasses(); graphalg::registerGraphAlgToCorePipeline(); + garel::registerPasses(); mlir::registerCanonicalizerPass(); mlir::registerInlinerPass(); mlir::registerCSEPass(); From 5f8e0a3b8715256ebb68acf6e5e3f135096e80e8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Jan 2026 11:53:01 +0000 Subject: [PATCH 07/32] Use local ids instead of column attributes. --- compiler/include/garel/GARelAttr.td | 43 ++++------------------ compiler/include/garel/GARelOps.td | 2 +- compiler/include/garel/GARelTupleOps.td | 2 +- compiler/include/garel/GARelTypes.td | 8 +--- compiler/src/garel/GARelAttr.cpp | 29 +++++++++------ compiler/src/garel/GARelDialect.cpp | 29 --------------- compiler/src/garel/GARelOps.cpp | 16 +++++--- compiler/src/garel/GARelTupleOps.cpp | 8 +++- compiler/src/garel/GARelTypes.cpp | 28 -------------- compiler/src/garel/GraphAlgToRel.cpp | 49 +++++-------------------- 10 files changed, 55 insertions(+), 159 deletions(-) diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 7bb0a87..33f268b 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -10,33 +10,17 @@ class GARel_Attr traits = []> let mnemonic = attrMnemonic; } -// Also often referred to as 'attribute' in the literature, but that conflicts -// with MLIRs use of the term 'attribute', so name this Column instead. -def ColumnAttr : GARel_Attr<"Column", "column"> { - let summary = "A column (also referred to as 'attribute') of a relation"; - - let parameters = (ins - "::mlir::DistinctAttr":$id, - TypeParameter<"::mlir::Type", "Type of this column">:$type); - - let assemblyFormat = [{ - `<` $id `` `:` `` $type `>` - }]; - - let extraClassDeclaration = [{ - static ColumnAttr newOfType(mlir::Type type); - }]; -} - -def ColumnListParameter : OptionalArrayRefParameter<"ColumnAttr">; - def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { let summary = "A binary equality join predicate"; - let parameters = (ins "ColumnAttr":$lhs, "ColumnAttr":$rhs); + let parameters = (ins + "unsigned":$lhsRelIdx, + "unsigned":$lhsColIdx, + "unsigned":$rhsRelIdx, + "unsigned":$rhsColIdx); let assemblyFormat = [{ - `<` $lhs `=` $rhs `>` + `<` $lhsRelIdx `[` $lhsColIdx `]` `=` $rhsRelIdx `[` $rhsColIdx `]` `>` }]; } @@ -46,17 +30,6 @@ def JoinPredicates : ArrayOfAttr< "join_preds", "JoinPredicateAttr">; -def ColumnList : GARel_Attr<"ColumnList", "columns"> { - let summary = "Array of columns"; - - let parameters = (ins - ColumnListParameter:$columns); - - let assemblyFormat = [{ - `<` (`>`) : (`` $columns^ `>`)? - }]; -} - def AggregateFunc : I64EnumAttr< "AggregateFunc", "", [ @@ -76,7 +49,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { let parameters = (ins "AggregateFunc":$func, - ColumnListParameter:$inputs); + ArrayRefParameter<"unsigned">:$inputs); let assemblyFormat = [{ `<` $func $inputs `>` @@ -86,7 +59,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { let extraClassDeclaration = [{ /** Type for values in the output column. */ - mlir::Type getResultType(); + mlir::Type getResultType(::mlir::Type relType); }]; } diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index bc8ab2b..b05ab79 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -124,7 +124,7 @@ def AggregateOp : GARel_Op<"aggregate", [InferTypeOpAdaptor]> { let arguments = (ins Relation:$input, - ColumnList:$groupBy, + DenseI32ArrayAttr:$groupBy, Aggregators:$aggregators); let results = (outs Relation:$result); diff --git a/compiler/include/garel/GARelTupleOps.td b/compiler/include/garel/GARelTupleOps.td index 409e702..0244442 100644 --- a/compiler/include/garel/GARelTupleOps.td +++ b/compiler/include/garel/GARelTupleOps.td @@ -14,7 +14,7 @@ include "GARelTypes.td" def ExtractOp : GARel_Op<"extract", [InferTypeOpAdaptor]> { let summary = "Extract the value of one column from a tuple"; - let arguments = (ins ColumnAttr:$column, Tuple:$tuple); + let arguments = (ins I32Attr:$column, Tuple:$tuple); let results = (outs ColumnType:$result); diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 595359d..5d848b5 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -15,25 +15,21 @@ class GARel_Type traits = []> def Relation : GARel_Type<"Relation", "rel"> { let summary = "A set of tuples"; - let parameters = (ins ColumnListParameter:$columns); + let parameters = (ins OptionalArrayRefParameter<"mlir::Type">:$columns); let assemblyFormat = [{ `<` $columns `>` }]; - - let genVerifyDecl = 1; } def Tuple : GARel_Type<"Tuple", "tuple"> { let summary = "A single tuple"; - let parameters = (ins ColumnListParameter:$columns); + let parameters = (ins OptionalArrayRefParameter<"mlir::Type">:$columns); let assemblyFormat = [{ `<` $columns `>` }]; - - let genVerifyDecl = 1; } def ColumnType : Type, "column type">; diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp index b8c162d..487551d 100644 --- a/compiler/src/garel/GARelAttr.cpp +++ b/compiler/src/garel/GARelAttr.cpp @@ -9,33 +9,40 @@ #include "garel/GARelDialect.h" #include "garel/GARelEnumAttr.cpp.inc" +#include "garel/GARelTypes.h" #define GET_ATTRDEF_CLASSES #include "garel/GARelAttr.cpp.inc" namespace garel { -ColumnAttr ColumnAttr::newOfType(mlir::Type type) { - auto *ctx = type.getContext(); - auto colId = mlir::DistinctAttr::create(mlir::UnitAttr::get(ctx)); - return ColumnAttr::get(ctx, colId, type); -} - -mlir::Type AggregatorAttr::getResultType() { +mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) { switch (getFunc()) { case AggregateFunc::SUM: case AggregateFunc::MIN: case AggregateFunc::MAX: case AggregateFunc::LOR: case AggregateFunc::ARGMIN: - // NOTE: argmin(arg, val) also uses first input as output type. - return getInputs()[0].getType(); + // NOTE: argmin(arg, val) also uses first input column as output type. + return llvm::cast(inputRel).getColumns()[getInputs()[0]]; } } mlir::LogicalResult AggregatorAttr::verify(llvm::function_ref emitError, - AggregateFunc func, llvm::ArrayRef inputs) { - // TODO: Verify input column count and type(s) + AggregateFunc func, llvm::ArrayRef inputs) { + if (func == AggregateFunc::ARGMIN) { + if (inputs.size() != 2) { + return emitError() << stringifyAggregateFunc(func) + << " expects exactly two inputs (arg, val), got " + << inputs.size(); + } + } else { + if (inputs.size() != 1) { + return emitError() << stringifyAggregateFunc(func) + << " expects exactly one input, got " << inputs.size(); + } + } + return mlir::success(); } diff --git a/compiler/src/garel/GARelDialect.cpp b/compiler/src/garel/GARelDialect.cpp index b008305..7d2fae3 100644 --- a/compiler/src/garel/GARelDialect.cpp +++ b/compiler/src/garel/GARelDialect.cpp @@ -1,11 +1,3 @@ -#include -#include -#include -#include -#include -#include -#include - #include "garel/GARelDialect.h" #include "garel/GARelOps.h" @@ -13,26 +5,6 @@ namespace garel { -namespace { - -class GARelOpAsmDialectInterface : public mlir::OpAsmDialectInterface { -public: - using OpAsmDialectInterface::OpAsmDialectInterface; - - AliasResult getAlias(mlir::Attribute attr, - mlir::raw_ostream &os) const override { - // Assign aliases to columns. - if (auto colAttr = llvm::dyn_cast(attr)) { - os << "col"; - return AliasResult::FinalAlias; - } - - return AliasResult::NoAlias; - } -}; - -} // namespace - void GARelDialect::initialize() { addOperations< #define GET_OP_LIST @@ -40,7 +12,6 @@ void GARelDialect::initialize() { >(); registerAttributes(); registerTypes(); - addInterface(); } } // namespace garel diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 7868e8a..8f3046d 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -47,7 +47,7 @@ mlir::LogicalResult ProjectOp::verifyRegions() { for (const auto &[val, col] : llvm::zip_equal(returnOp.getProjections(), getType().getColumns())) { - if (val.getType() != col.getType()) { + if (val.getType() != col) { return emitOpError("projections block return types do not match the " "projection output column types"); } @@ -120,7 +120,7 @@ mlir::LogicalResult JoinOp::verify() { mlir::LogicalResult JoinOp::inferReturnTypes( mlir::MLIRContext *ctx, std::optional location, Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { - llvm::SmallVector outputColumns; + llvm::SmallVector outputColumns; for (auto input : adaptor.getInputs()) { auto inputColumns = llvm::cast(input.getType()).getColumns(); outputColumns.append(inputColumns.begin(), inputColumns.end()); @@ -134,15 +134,19 @@ mlir::LogicalResult JoinOp::inferReturnTypes( mlir::LogicalResult AggregateOp::inferReturnTypes( mlir::MLIRContext *ctx, std::optional location, Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { - llvm::SmallVector outputColumns; + llvm::SmallVector outputColumns; + + auto inputType = llvm::cast(adaptor.getInput().getType()); + auto inputColumns = inputType.getColumns(); // Key columns - auto keyColumns = adaptor.getGroupBy().getColumns(); - outputColumns.append(keyColumns.begin(), keyColumns.end()); + for (auto key : adaptor.getGroupBy()) { + outputColumns.push_back(inputColumns[key]); + } // Aggregator outputs for (auto agg : adaptor.getAggregators()) { - outputColumns.push_back(ColumnAttr::newOfType(agg.getResultType())); + outputColumns.push_back(agg.getResultType(inputType)); } inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); diff --git a/compiler/src/garel/GARelTupleOps.cpp b/compiler/src/garel/GARelTupleOps.cpp index 198cc14..a78f269 100644 --- a/compiler/src/garel/GARelTupleOps.cpp +++ b/compiler/src/garel/GARelTupleOps.cpp @@ -1,19 +1,23 @@ #include #include "garel/GARelOps.h" +#include "garel/GARelTypes.h" namespace garel { mlir::LogicalResult ExtractOp::inferReturnTypes( mlir::MLIRContext *ctx, std::optional location, Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { - inferredReturnTypes.push_back(adaptor.getColumn().getType()); + // TODO: Do we need this cast? + auto tupleType = llvm::cast(adaptor.getTuple().getType()); + auto columnTypes = tupleType.getColumns(); + inferredReturnTypes.push_back(columnTypes[adaptor.getColumn()]); return mlir::success(); } mlir::LogicalResult ExtractOp::verify() { auto columns = getTuple().getType().getColumns(); - if (!llvm::is_contained(columns, getColumn())) { + if (getColumn() >= getTuple().getType().getColumns().size()) { return emitOpError("column ") << getColumn() << " not included in tuple " << getTuple().getType(); } diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp index 5e628f5..a0fa81c 100644 --- a/compiler/src/garel/GARelTypes.cpp +++ b/compiler/src/garel/GARelTypes.cpp @@ -12,34 +12,6 @@ namespace garel { -static mlir::LogicalResult -verifyColumnsUnique(llvm::function_ref emitError, - llvm::ArrayRef columns) { - // Columns must be unique - llvm::SmallDenseSet columnSet; - for (auto c : columns) { - auto [_, newlyAdded] = columnSet.insert(c); - if (!newlyAdded) { - return emitError() << "column " << c - << " specified multiple times in the same column set"; - } - } - - return mlir::success(); -} - -mlir::LogicalResult -RelationType::verify(llvm::function_ref emitError, - llvm::ArrayRef columns) { - return verifyColumnsUnique(emitError, columns); -} - -mlir::LogicalResult -TupleType::verify(llvm::function_ref emitError, - llvm::ArrayRef columns) { - return verifyColumnsUnique(emitError, columns); -} - bool isColumnType(mlir::Type t) { // Allow i1, si64, f64, index return t.isSignlessInteger(1) || t.isSignedInteger(64) || t.isF64() || diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index caa1147..4a36651 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -87,36 +89,7 @@ class MatrixAdaptor { return _relation; } - llvm::ArrayRef columns() { return _relType.getColumns(); } - - ColumnAttr row() { - if (_matrix.getType().getRows().isOne()) { - return {}; - } - - return columns().front(); - } - - ColumnAttr col() { - if (_matrix.getType().getCols().isOne()) { - return {}; - } - - return columns().drop_back().back(); - } - - /** - * The row or column column, depending on which one is present. - * - * NOTE: Vector relations have a row or a column slot, but not both. Row and - * column vectors have identical relational representation of (idx, val). - */ - ColumnAttr vectorIdxColumn() { - assert(columns().size() == 2 && "Not a vector"); - return columns().front(); - } - - ColumnAttr valSlot() { return columns().back(); } + auto columns() { return _relType.getColumns(); } bool isScalar() { return _matrix.getType().isScalar(); } @@ -182,15 +155,14 @@ MatrixTypeConverter::convertFunctionType(mlir::FunctionType type) const { RelationType MatrixTypeConverter::convertMatrixType(graphalg::MatrixType type) const { - llvm::SmallVector columns; + llvm::SmallVector columns; auto *ctx = type.getContext(); - auto indexType = mlir::IndexType::get(ctx); if (!type.getRows().isOne()) { - columns.push_back(ColumnAttr::newOfType(indexType)); + columns.push_back(mlir::IndexType::get(ctx)); } if (!type.getCols().isOne()) { - columns.push_back(ColumnAttr::newOfType(indexType)); + columns.push_back(mlir::IndexType::get(ctx)); } auto valueType = _semiringConverter.convertType(type.getSemiring()); @@ -198,7 +170,7 @@ MatrixTypeConverter::convertMatrixType(graphalg::MatrixType type) const { return {}; } - columns.push_back(ColumnAttr::newOfType(valueType)); + columns.push_back(valueType); return RelationType::get(ctx, columns); } @@ -216,10 +188,6 @@ template <> mlir::LogicalResult OpConversion::matchAndRewrite( mlir::func::FuncOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // TODO: type caching means that if two function arguments have the same - // matrix type, they will also be assigned the same relation type, and - // therefore have the same set of slots. Consider doing our own conversion - // here to avoid that. auto funcType = llvm::cast_if_present( typeConverter->convertType(op.getFunctionType())); if (!funcType) { @@ -266,7 +234,8 @@ mlir::LogicalResult OpConversion::matchAndRewrite( auto &body = projectOp.createProjectionsBlock(); rewriter.setInsertionPointToStart(&body); - llvm::SmallVector columns(input.columns()); + llvm::SmallVector columns(input.columns().size()); + std::iota(columns.begin(), columns.end(), 0); assert(columns.size() <= 3); // Transpose is a no-op if there are fewer than 3 columns. if (columns.size() == 3) { From b95b73a50e68f41a026d9a6f73c212b48887c25e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Jan 2026 13:07:49 +0000 Subject: [PATCH 08/32] Fix transpose test. --- compiler/include/garel/GARelAttr.h | 7 +++++++ compiler/include/garel/GARelAttr.td | 6 +++--- compiler/src/garel/GARelAttr.cpp | 2 +- compiler/src/garel/GraphAlgToRel.cpp | 2 +- compiler/test/graphalg-to-rel/transpose.mlir | 20 ++++++++++++++++---- 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h index 4319db6..3388dbd 100644 --- a/compiler/include/garel/GARelAttr.h +++ b/compiler/include/garel/GARelAttr.h @@ -5,5 +5,12 @@ #include "garel/GARelEnumAttr.h.inc" +namespace garel { + +/** Reference to a column inside of \c RelationType or \c TupleType. */ +using ColumnIdx = unsigned; + +} // namespace garel + #define GET_ATTRDEF_CLASSES #include "garel/GARelAttr.h.inc" diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 33f268b..5599528 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -15,9 +15,9 @@ def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { let parameters = (ins "unsigned":$lhsRelIdx, - "unsigned":$lhsColIdx, + "ColumnIdx":$lhsColIdx, "unsigned":$rhsRelIdx, - "unsigned":$rhsColIdx); + "ColumnIdx":$rhsColIdx); let assemblyFormat = [{ `<` $lhsRelIdx `[` $lhsColIdx `]` `=` $rhsRelIdx `[` $rhsColIdx `]` `>` @@ -49,7 +49,7 @@ def Aggregator : GARel_Attr<"Aggregator", "aggregator"> { let parameters = (ins "AggregateFunc":$func, - ArrayRefParameter<"unsigned">:$inputs); + ArrayRefParameter<"ColumnIdx">:$inputs); let assemblyFormat = [{ `<` $func $inputs `>` diff --git a/compiler/src/garel/GARelAttr.cpp b/compiler/src/garel/GARelAttr.cpp index 487551d..abb69cc 100644 --- a/compiler/src/garel/GARelAttr.cpp +++ b/compiler/src/garel/GARelAttr.cpp @@ -29,7 +29,7 @@ mlir::Type AggregatorAttr::getResultType(mlir::Type inputRel) { mlir::LogicalResult AggregatorAttr::verify(llvm::function_ref emitError, - AggregateFunc func, llvm::ArrayRef inputs) { + AggregateFunc func, llvm::ArrayRef inputs) { if (func == AggregateFunc::ARGMIN) { if (inputs.size() != 2) { return emitError() << stringifyAggregateFunc(func) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 4a36651..66f5427 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -234,7 +234,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( auto &body = projectOp.createProjectionsBlock(); rewriter.setInsertionPointToStart(&body); - llvm::SmallVector columns(input.columns().size()); + llvm::SmallVector columns(input.columns().size()); std::iota(columns.begin(), columns.end(), 0); assert(columns.size() <= 3); // Transpose is a no-op if there are fewer than 3 columns. diff --git a/compiler/test/graphalg-to-rel/transpose.mlir b/compiler/test/graphalg-to-rel/transpose.mlir index f989af6..ab15c9e 100644 --- a/compiler/test/graphalg-to-rel/transpose.mlir +++ b/compiler/test/graphalg-to-rel/transpose.mlir @@ -2,10 +2,10 @@ // CHECK-LABEL: @TransposeMatrix func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<43 x 42 x i64> { - // CHECK: %[[#PROJECT:]] = garel.project %arg0 : <[[ROW:#col[0-9]*]], #col1, #col2> -> <#col3, #col4, #col5> { - // CHECK: %[[#COL_SLOT:]] = garel.extract #col1 - // CHECK: %[[#ROW_SLOT:]] = garel.extract [[ROW]] - // CHECK: %[[#VAL_SLOT:]] = garel.extract #col2 + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#COL_SLOT:]] = garel.extract 1 + // CHECK: %[[#ROW_SLOT:]] = garel.extract 0 + // CHECK: %[[#VAL_SLOT:]] = garel.extract 2 // CHECK: garel.project.return %[[#COL_SLOT]], %[[#ROW_SLOT]], %[[#VAL_SLOT]] %0 = graphalg.transpose %arg0 : <42 x 43 x i64> @@ -15,20 +15,32 @@ func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat // CHECK-LABEL: @TransposeColVec func.func @TransposeColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] : index, si64 %0 = graphalg.transpose %arg0 : <42 x 1 x i64> + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 42 x i64> } // CHECK-LABEL: @TransposeRowVec func.func @TransposeRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<43 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#COL:]] = garel.extract 0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#COL]], %[[#VAL]] : index, si64 %0 = graphalg.transpose %arg0 : <1 x 43 x i64> + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<43 x 1 x i64> } func.func @TransposeScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // NOTE: folded away %0 = graphalg.transpose %arg0 : <1 x 1 x i64> + // CHECK: return %arg0 return %0 : !graphalg.mat<1 x 1 x i64> } From f53660648a400e85ae2447bc8447bd2025e71b56 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 21 Jan 2026 15:44:34 +0000 Subject: [PATCH 09/32] Start with ApplyOp. --- compiler/include/garel/GARelOps.td | 12 ++ compiler/src/garel/GARelOps.cpp | 9 + compiler/src/garel/GraphAlgToRel.cpp | 232 ++++++++++++++++++++++- compiler/test/graphalg-to-rel/apply.mlir | 118 ++++++++++++ 4 files changed, 367 insertions(+), 4 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/apply.mlir diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index b05ab79..62af3f9 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -175,4 +175,16 @@ def ForYieldOp : GARel_Op<"for.yield", [ // Note: verification performed by parent ForOp. } +def RangeOp : GARel_Op<"range", [InferTypeOpAdaptor]> { + let summary = "generates a range of `index` values from `[0, size)`"; + + let arguments = (ins I64Attr:$size); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $size attr-dict + }]; +} + #endif // GAREL_OPS diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 8f3046d..dc6c89d 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "garel/GARelAttr.h" #include "garel/GARelDialect.h" @@ -213,4 +214,12 @@ mlir::LogicalResult ForOp::verifyRegions() { return mlir::success(); } +// === RangeOp === +mlir::LogicalResult RangeOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back( + RelationType::get(ctx, {mlir::IndexType::get(ctx)})); +} + } // namespace garel diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 66f5427..206884e 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,10 +1,14 @@ #include +#include +#include #include #include +#include #include #include #include +#include #include #include @@ -12,6 +16,7 @@ #include "garel/GARelDialect.h" #include "garel/GARelOps.h" #include "garel/GARelTypes.h" +#include "graphalg/GraphAlgAttr.h" #include "graphalg/GraphAlgDialect.h" #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" @@ -70,9 +75,9 @@ class MatrixAdaptor { public: // For output matrices, where we only have the desired output type. - MatrixAdaptor(mlir::Value matrix, mlir::Type streamType) + MatrixAdaptor(mlir::Value matrix, mlir::Type relType) : _matrix(llvm::cast>(matrix)), - _relType(llvm::cast(streamType)) {} + _relType(llvm::cast(relType)) {} // For input matrices, where the OpAdaptor provides the relation value. MatrixAdaptor(mlir::Value matrix, mlir::Value relation) @@ -89,9 +94,29 @@ class MatrixAdaptor { return _relation; } - auto columns() { return _relType.getColumns(); } + auto columns() const { return _relType.getColumns(); } - bool isScalar() { return _matrix.getType().isScalar(); } + bool isScalar() const { return _matrix.getType().isScalar(); } + + bool hasRowColumn() const { return !_matrix.getType().getRows().isOne(); } + + bool hasColColumn() const { return !_matrix.getType().getCols().isOne(); } + + ColumnIdx rowColumn() const { + assert(hasRowColumn()); + return 0; + } + + ColumnIdx colColumn() const { + assert(hasColColumn()); + // Follow row column, if there is one. + return hasRowColumn() ? 1 : 0; + } + + ColumnIdx valColumn() const { + // Last column in the relation. + return columns().size() - 1; + } graphalg::SemiringTypeInterface semiring() { return llvm::cast( @@ -108,6 +133,28 @@ template class OpConversion : public mlir::OpConversionPattern { mlir::ConversionPatternRewriter &rewriter) const override; }; +class ApplyOpConversion : public mlir::OpConversionPattern { +private: + const SemiringTypeConverter &_bodyArgConverter; + + mlir::LogicalResult + matchAndRewrite(graphalg::ApplyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override; + +public: + ApplyOpConversion(const SemiringTypeConverter &bodyArgConverter, + const MatrixTypeConverter &typeConverter, + mlir::MLIRContext *ctx) + : mlir::OpConversionPattern(typeConverter, ctx), + _bodyArgConverter(bodyArgConverter) {} +}; + +struct InputColumnRef { + unsigned relIdx; + ColumnIdx colIdx; + ColumnIdx outIdx; +}; + } // namespace mlir::Type @@ -255,6 +302,178 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +/** + * Create a relation with all indices for a matrix dimension. + * + * Used to broadcast scalar values to a larger matrix. + */ +static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + return builder.create(loc, dim.getConcreteDim()); +} + +static void +buildApplyJoinPredicates(mlir::MLIRContext *ctx, + llvm::SmallVectorImpl &predicates, + llvm::ArrayRef columnsToJoin) { + if (columnsToJoin.size() < 2) { + return; + } + + auto first = columnsToJoin.front(); + for (auto other : columnsToJoin.drop_front()) { + predicates.push_back(JoinPredicateAttr::get(ctx, first.relIdx, first.colIdx, + other.relIdx, other.colIdx)); + } +} + +static constexpr llvm::StringLiteral APPLY_ROW_IDX_ATTR_KEY = + "garel.apply.row_idx"; +static constexpr llvm::StringLiteral APPLY_COL_IDX_ATTR_KEY = + "garel.apply.col_idx"; + +mlir::LogicalResult ApplyOpConversion::matchAndRewrite( + graphalg::ApplyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector inputs; + for (auto [matrix, relation] : + llvm::zip_equal(op.getInputs(), adaptor.getInputs())) { + auto &input = inputs.emplace_back(matrix, relation); + } + + llvm::SmallVector joinChildren; + llvm::SmallVector rowColumns; + llvm::SmallVector colColumns; + llvm::SmallVector valColumns; + ColumnIdx nextColumnIdx = 0; + for (const auto &[idx, input] : llvm::enumerate(inputs)) { + joinChildren.emplace_back(input.relation()); + + if (input.hasRowColumn()) { + rowColumns.push_back(InputColumnRef{ + .relIdx = static_cast(idx), + .colIdx = input.rowColumn(), + .outIdx = nextColumnIdx + input.rowColumn(), + }); + } + + if (input.hasColColumn()) { + colColumns.push_back(InputColumnRef{ + .relIdx = static_cast(idx), + .colIdx = input.colColumn(), + .outIdx = nextColumnIdx + input.colColumn(), + }); + } + + valColumns.push_back(nextColumnIdx + input.valColumn()); + nextColumnIdx += input.columns().size(); + } + + auto outputType = typeConverter->convertType(op.getType()); + MatrixAdaptor output(op.getResult(), outputType); + if (rowColumns.empty() && output.hasRowColumn()) { + // None of the inputs have a row column, but we need it in the output. + // Broadcast to all rows. + auto rowsOp = + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); + joinChildren.emplace_back(rowsOp); + rowColumns.push_back(InputColumnRef{ + .relIdx = static_cast(joinChildren.size() - 1), + .colIdx = 0, + .outIdx = nextColumnIdx++, + }); + } + + if (colColumns.empty() && output.hasColColumn()) { + // None of the inputs have a col column, but we need it in the output. + // Broadcast to all columns. + auto colsOp = + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); + joinChildren.emplace_back(colsOp); + colColumns.push_back(InputColumnRef{ + .relIdx = static_cast(joinChildren.size() - 1), + .colIdx = 0, + .outIdx = nextColumnIdx++, + }); + } + + mlir::Value joined; + if (joinChildren.size() == 1) { + joined = joinChildren.front(); + } else { + llvm::SmallVector predicates; + buildApplyJoinPredicates(rewriter.getContext(), predicates, rowColumns); + buildApplyJoinPredicates(rewriter.getContext(), predicates, colColumns); + joined = rewriter.create(op.getLoc(), joinChildren); + } + + auto projectOp = rewriter.create(op->getLoc(), outputType, joined); + + // Convert old body + if (mlir::failed( + rewriter.convertRegionTypes(&op.getBody(), _bodyArgConverter))) { + return op->emitOpError("failed to convert body argument types"); + } + + // Read value columns, to be used as arg replacements for the old body. + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto &body = projectOp.createProjectionsBlock(); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector slotReads; + for (auto col : valColumns) { + slotReads.emplace_back( + rewriter.create(op->getLoc(), col, body.getArgument(0))); + } + + // Inline into new body + rewriter.inlineBlockBefore(&op.getBody().front(), &body, body.end(), + slotReads); + + rewriter.replaceOp(op, projectOp); + + // Attach the row and column slot to the return op. + auto returnOp = llvm::cast(body.getTerminator()); + if (!rowColumns.empty()) { + returnOp->setAttr(APPLY_ROW_IDX_ATTR_KEY, + rewriter.getUI32IntegerAttr(rowColumns[0].outIdx)); + } + + if (!colColumns.empty()) { + returnOp->setAttr(APPLY_COL_IDX_ATTR_KEY, + rewriter.getUI32IntegerAttr(colColumns[0].outIdx)); + } + + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ApplyReturnOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + llvm::SmallVector results; + + // Note: conversion is done top-down, so the ApplyOp is converted to + // ProjectOp before we reach this op in its body. + auto inputTuple = op->getBlock()->getArgument(0); + + if (auto idx = op->getAttrOfType(APPLY_ROW_IDX_ATTR_KEY)) { + results.emplace_back( + rewriter.create(op->getLoc(), idx, inputTuple)); + } + + if (auto idx = op->getAttrOfType(APPLY_COL_IDX_ATTR_KEY)) { + results.emplace_back( + rewriter.create(op->getLoc(), idx, inputTuple)); + } + + // The value slot + results.emplace_back(adaptor.getValue()); + + rewriter.replaceOpWithNewOp(op, results); + return mlir::success(); +} + static bool hasRelationSignature(mlir::func::FuncOp op) { // All inputs should be relations auto funcType = op.getFunctionType(); @@ -295,8 +514,13 @@ void GraphAlgToRel::runOnOperation() { .add, OpConversion, OpConversion>(matrixTypeConverter, &getContext()); + patterns.add(semiringTypeConverter, matrixTypeConverter, + &getContext()); // Scalar patterns. + // patterns.add(convertArithConstant); + patterns.add>(semiringTypeConverter, + &getContext()); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { diff --git a/compiler/test/graphalg-to-rel/apply.mlir b/compiler/test/graphalg-to-rel/apply.mlir new file mode 100644 index 0000000..35f3484 --- /dev/null +++ b/compiler/test/graphalg-to-rel/apply.mlir @@ -0,0 +1,118 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Arity + +// CHECK-LABEL: @ApplyUnary +func.func @ApplyUnary(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + %2 = graphalg.add %1, %arg1 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyBinary +func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg2: i64, %arg3: i64): + %1 = graphalg.add %arg2, %arg3 : i64 + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyTernary +func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>, %arg2: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0, %arg1, %arg2 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg3: i64, %arg4: i64, %arg5: i64): + %1 = graphalg.add %arg3, %arg4 : i64 + %2 = graphalg.add %1, %arg5 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// === Shape + +// CHECK-LABEL: @ApplyMat +func.func @ApplyMat(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + %2 = graphalg.add %1, %arg1 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @ApplyRowVec +func.func @ApplyRowVec(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x 42 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + %2 = graphalg.add %1, %arg1 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @ApplyColVec +func.func @ApplyColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> { + ^bb0(%arg42: i64): + %1 = graphalg.const 1 : i64 + %2 = graphalg.add %1, %arg42 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ApplyScalar +func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg1: i64): + %1 = graphalg.const 1 : i64 + %2 = graphalg.add %1, %arg1 : i64 + graphalg.apply.return %2 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ApplyBroadcastScalar +func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 43 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <42 x 43 x i64> { + ^bb0(%arg1: i64): + graphalg.apply.return %arg1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ApplyBroadcastOne +func.func @ApplyBroadcastOne(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { + ^bb0(%arg2: i64, %arg3: i64): + %1 = graphalg.add %arg2, %arg3 : i64 + graphalg.apply.return %1 : i64 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} From a7535f52ddcdeab5114ee084a56b8dac4630d35f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 21 Jan 2026 21:19:46 +0000 Subject: [PATCH 10/32] AddOp, finish apply.mlir test. --- compiler/src/garel/GARelOps.cpp | 2 + compiler/src/garel/GARelTypes.cpp | 4 +- compiler/src/garel/GraphAlgToRel.cpp | 184 +++++++++++++++---- compiler/test/graphalg-to-rel/apply.mlir | 54 +++++- compiler/test/graphalg-to-rel/transpose.mlir | 10 +- 5 files changed, 204 insertions(+), 50 deletions(-) diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index dc6c89d..d8cd6d1 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "garel/GARelAttr.h" #include "garel/GARelDialect.h" @@ -220,6 +221,7 @@ mlir::LogicalResult RangeOp::inferReturnTypes( Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back( RelationType::get(ctx, {mlir::IndexType::get(ctx)})); + return mlir::success(); } } // namespace garel diff --git a/compiler/src/garel/GARelTypes.cpp b/compiler/src/garel/GARelTypes.cpp index a0fa81c..618642d 100644 --- a/compiler/src/garel/GARelTypes.cpp +++ b/compiler/src/garel/GARelTypes.cpp @@ -13,8 +13,8 @@ namespace garel { bool isColumnType(mlir::Type t) { - // Allow i1, si64, f64, index - return t.isSignlessInteger(1) || t.isSignedInteger(64) || t.isF64() || + // Allow i1, i64, f64, index + return t.isSignlessInteger(1) || t.isSignlessInteger(64) || t.isF64() || t.isIndex(); } diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 206884e..2d1ad52 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -157,6 +158,10 @@ struct InputColumnRef { } // namespace +// ============================================================================= +// =============================== Class Methods =============================== +// ============================================================================= + mlir::Type SemiringTypeConverter::convertSemiringType(graphalg::SemiringTypeInterface t) { auto *ctx = t.getContext(); @@ -169,7 +174,7 @@ SemiringTypeConverter::convertSemiringType(graphalg::SemiringTypeInterface t) { if (t == graphalg::SemiringTypes::forInt(ctx) || t == graphalg::SemiringTypes::forTropInt(ctx) || t == graphalg::SemiringTypes::forTropMaxInt(ctx)) { - return mlir::IntegerType::get(ctx, 64, mlir::IntegerType::Signed); + return mlir::IntegerType::get(ctx, 64); } // To f64 @@ -231,6 +236,91 @@ MatrixTypeConverter::MatrixTypeConverter( [this](graphalg::MatrixType t) { return convertMatrixType(t); }); } +// ============================================================================= +// ============================== Helper Methods =============================== +// ============================================================================= + +/** + * Create a relation with all indices for a matrix dimension. + * + * Used to broadcast scalar values to a larger matrix. + */ +static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, + mlir::OpBuilder &builder) { + return builder.create(loc, dim.getConcreteDim()); +} + +static void +buildApplyJoinPredicates(mlir::MLIRContext *ctx, + llvm::SmallVectorImpl &predicates, + llvm::ArrayRef columnsToJoin) { + if (columnsToJoin.size() < 2) { + return; + } + + auto first = columnsToJoin.front(); + for (auto other : columnsToJoin.drop_front()) { + predicates.push_back(JoinPredicateAttr::get(ctx, first.relIdx, first.colIdx, + other.relIdx, other.colIdx)); + } +} + +static mlir::FailureOr convertConstant(mlir::Operation *op, + mlir::TypedAttr attr) { + auto *ctx = attr.getContext(); + auto type = attr.getType(); + if (type == graphalg::SemiringTypes::forBool(ctx)) { + return attr; + } else if (type == graphalg::SemiringTypes::forInt(ctx)) { + // TODO: Need to convert to signed? + return attr; + } else if (type == graphalg::SemiringTypes::forReal(ctx)) { + return attr; + } else if (type == graphalg::SemiringTypes::forTropInt(ctx)) { + std::int64_t value; + if (llvm::isa(attr)) { + // Positive infinity, kind of. + value = std::numeric_limits::max(); + } else { + auto intAttr = llvm::cast(attr); + value = intAttr.getValue().getValue().getSExtValue(); + } + + return mlir::TypedAttr( + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), value)); + } else if (type == graphalg::SemiringTypes::forTropReal(ctx)) { + double value; + if (llvm::isa(attr)) { + // Has a proper positive infinity value + value = std::numeric_limits::infinity(); + } else { + auto floatAttr = llvm::cast(attr); + value = floatAttr.getValue().getValueAsDouble(); + } + + return mlir::TypedAttr( + mlir::FloatAttr::get(mlir::Float64Type::get(ctx), value)); + } else if (type == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + std::int64_t value; + if (llvm::isa(attr)) { + // Negative infinity, kind of. + value = std::numeric_limits::min(); + } else { + auto intAttr = llvm::cast(attr); + value = intAttr.getValue().getValue().getSExtValue(); + } + + return mlir::TypedAttr( + mlir::FloatAttr::get(mlir::Float64Type::get(ctx), value)); + } + + return op->emitOpError("cannot convert constant ") << attr; +} + +// ============================================================================= +// =============================== Op Conversion =============================== +// ============================================================================= + template <> mlir::LogicalResult OpConversion::matchAndRewrite( mlir::func::FuncOp op, OpAdaptor adaptor, @@ -302,31 +392,6 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -/** - * Create a relation with all indices for a matrix dimension. - * - * Used to broadcast scalar values to a larger matrix. - */ -static RangeOp createDimRead(mlir::Location loc, graphalg::DimAttr dim, - mlir::OpBuilder &builder) { - return builder.create(loc, dim.getConcreteDim()); -} - -static void -buildApplyJoinPredicates(mlir::MLIRContext *ctx, - llvm::SmallVectorImpl &predicates, - llvm::ArrayRef columnsToJoin) { - if (columnsToJoin.size() < 2) { - return; - } - - auto first = columnsToJoin.front(); - for (auto other : columnsToJoin.drop_front()) { - predicates.push_back(JoinPredicateAttr::get(ctx, first.relIdx, first.colIdx, - other.relIdx, other.colIdx)); - } -} - static constexpr llvm::StringLiteral APPLY_ROW_IDX_ATTR_KEY = "garel.apply.row_idx"; static constexpr llvm::StringLiteral APPLY_COL_IDX_ATTR_KEY = @@ -404,7 +469,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( llvm::SmallVector predicates; buildApplyJoinPredicates(rewriter.getContext(), predicates, rowColumns); buildApplyJoinPredicates(rewriter.getContext(), predicates, colColumns); - joined = rewriter.create(op.getLoc(), joinChildren); + joined = rewriter.create(op.getLoc(), joinChildren, predicates); } auto projectOp = rewriter.create(op->getLoc(), outputType, joined); @@ -436,12 +501,12 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( auto returnOp = llvm::cast(body.getTerminator()); if (!rowColumns.empty()) { returnOp->setAttr(APPLY_ROW_IDX_ATTR_KEY, - rewriter.getUI32IntegerAttr(rowColumns[0].outIdx)); + rewriter.getI32IntegerAttr(rowColumns[0].outIdx)); } if (!colColumns.empty()) { returnOp->setAttr(APPLY_COL_IDX_ATTR_KEY, - rewriter.getUI32IntegerAttr(colColumns[0].outIdx)); + rewriter.getI32IntegerAttr(colColumns[0].outIdx)); } return mlir::success(); @@ -474,6 +539,54 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +// ============================================================================= +// ============================ Tuple Op Conversion ============================ +// ============================================================================= + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ConstantOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto value = convertConstant(op, op.getValue()); + if (mlir::failed(value)) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp(op, *value); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::AddOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto sring = op.getType(); + auto *ctx = rewriter.getContext(); + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forReal(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forTropInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forTropReal(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else if (sring == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), + adaptor.getRhs()); + } else { + return op->emitOpError("conversion not supported for semiring ") << sring; + } + + return mlir::success(); +} + static bool hasRelationSignature(mlir::func::FuncOp op) { // All inputs should be relations auto funcType = op.getFunctionType(); @@ -495,11 +608,12 @@ static bool hasRelationOperands(mlir::Operation *op) { void GraphAlgToRel::runOnOperation() { mlir::ConversionTarget target(getContext()); - // Eliminate all graphalg ops (and the few ops we use from arith) ... + // Eliminate all graphalg ops target.addIllegalDialect(); - target.addIllegalDialect(); - // And turn them into relational ops. + // Turn them into relational ops. target.addLegalDialect(); + // and arith ops for the scalar operations. + target.addLegalDialect(); // Keep container module. target.addLegalOp(); // Keep functions, but change their signature. @@ -519,8 +633,10 @@ void GraphAlgToRel::runOnOperation() { // Scalar patterns. // patterns.add(convertArithConstant); - patterns.add>(semiringTypeConverter, - &getContext()); + patterns + .add, + OpConversion, OpConversion>( + semiringTypeConverter, &getContext()); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { diff --git a/compiler/test/graphalg-to-rel/apply.mlir b/compiler/test/graphalg-to-rel/apply.mlir index 35f3484..a552970 100644 --- a/compiler/test/graphalg-to-rel/apply.mlir +++ b/compiler/test/graphalg-to-rel/apply.mlir @@ -7,11 +7,15 @@ func.func @ApplyUnary(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 42 x i64> } @@ -19,11 +23,16 @@ func.func @ApplyUnary(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg2: i64, %arg3: i64): + // CHECK: %[[#VAL1:]] = garel.extract 2 + // CHECK: %[[#VAL2:]] = garel.extract 5 + // CHECK: %[[#ADD:]] = arith.addi %[[#VAL1]], %[[#VAL2]] %1 = graphalg.add %arg2, %arg3 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] graphalg.apply.return %1 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 42 x i64> } @@ -31,12 +40,19 @@ func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>, %arg2: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { %0 = graphalg.apply %arg0, %arg1, %arg2 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg3: i64, %arg4: i64, %arg5: i64): + // CHECK: %[[#VAL1:]] = garel.extract 2 + // CHECK: %[[#VAL2:]] = garel.extract 5 + // CHECK: %[[#VAL3:]] = garel.extract 8 + // CHECK: %[[#ADD1:]] = arith.addi %[[#VAL1]], %[[#VAL2]] %1 = graphalg.add %arg3, %arg4 : i64 + // CHECK: %[[#ADD2:]] = arith.addi %[[#ADD1]], %[[#VAL3]] %2 = graphalg.add %1, %arg5 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD2]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 42 x i64> } @@ -47,11 +63,15 @@ func.func @ApplyMat(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 4 %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 42 x i64> } @@ -60,11 +80,14 @@ func.func @ApplyRowVec(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] %2 = graphalg.add %1, %arg1 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: garel.project.return %[[#ROW]], %[[#ADD]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 42 x i64> } @@ -73,11 +96,14 @@ func.func @ApplyColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> { ^bb0(%arg42: i64): %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] %2 = graphalg.add %1, %arg42 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: garel.project.return %[[#ROW]], %[[#ADD]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 1 x i64> } @@ -86,11 +112,13 @@ func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 + // CHECK: %[[#VAL:]] = garel.extract 0 + // CHECK: %[[#ADD:]] = arith.addi %c1_i64, %[[#VAL]] %2 = graphalg.add %1, %arg1 : i64 + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %2 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i64> } @@ -98,10 +126,13 @@ func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 43 x i64> { %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <42 x 43 x i64> { ^bb0(%arg1: i64): + // CHECK: %[[#VAL:]] = garel.extract 0 + // CHECK: %[[#ROW:]] = garel.extract 1 + // CHECK: %[[#COL:]] = garel.extract 2 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] graphalg.apply.return %arg1 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 43 x i64> } @@ -109,10 +140,15 @@ func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg func.func @ApplyBroadcastOne(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg2: i64, %arg3: i64): + // CHECK: %[[#VAL1:]] = garel.extract 1 + // CHECK: %[[#VAL2:]] = garel.extract 4 + // CHECK: %[[#ADD:]] = arith.addi %[[#VAL1]], %[[#VAL2]] %1 = graphalg.add %arg2, %arg3 : i64 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 3 + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#ADD]] graphalg.apply.return %1 : i64 } - // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<42 x 42 x i64> } diff --git a/compiler/test/graphalg-to-rel/transpose.mlir b/compiler/test/graphalg-to-rel/transpose.mlir index ab15c9e..4aaaab0 100644 --- a/compiler/test/graphalg-to-rel/transpose.mlir +++ b/compiler/test/graphalg-to-rel/transpose.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: @TransposeMatrix func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<43 x 42 x i64> { - // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> // CHECK: %[[#COL_SLOT:]] = garel.extract 1 // CHECK: %[[#ROW_SLOT:]] = garel.extract 0 // CHECK: %[[#VAL_SLOT:]] = garel.extract 2 @@ -15,10 +15,10 @@ func.func @TransposeMatrix(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat // CHECK-LABEL: @TransposeColVec func.func @TransposeColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { - // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> // CHECK: %[[#ROW:]] = garel.extract 0 // CHECK: %[[#VAL:]] = garel.extract 1 - // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] : index, si64 + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] : index, i64 %0 = graphalg.transpose %arg0 : <42 x 1 x i64> // CHECK: return %[[#PROJECT]] @@ -27,10 +27,10 @@ func.func @TransposeColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat< // CHECK-LABEL: @TransposeRowVec func.func @TransposeRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<43 x 1 x i64> { - // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> // CHECK: %[[#COL:]] = garel.extract 0 // CHECK: %[[#VAL:]] = garel.extract 1 - // CHECK: garel.project.return %[[#COL]], %[[#VAL]] : index, si64 + // CHECK: garel.project.return %[[#COL]], %[[#VAL]] : index, i64 %0 = graphalg.transpose %arg0 : <1 x 43 x i64> // CHECK: return %[[#PROJECT]] From fc3287afd4c170f0c771b3588ec41792c526ea97 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 21 Jan 2026 21:42:19 +0000 Subject: [PATCH 11/32] Fix apply.mlir test to add joins. --- compiler/test/graphalg-to-rel/apply.mlir | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/compiler/test/graphalg-to-rel/apply.mlir b/compiler/test/graphalg-to-rel/apply.mlir index a552970..827db04 100644 --- a/compiler/test/graphalg-to-rel/apply.mlir +++ b/compiler/test/graphalg-to-rel/apply.mlir @@ -16,11 +16,14 @@ func.func @ApplyUnary(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 42 x i64> } // CHECK-LABEL: @ApplyBinary func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>, <0[1] = 1[1]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg2: i64, %arg3: i64): // CHECK: %[[#VAL1:]] = garel.extract 2 @@ -33,11 +36,14 @@ func.func @ApplyBinary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 42 x i64> } // CHECK-LABEL: @ApplyTernary func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>, %arg2: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1, %arg2 : !garel.rel, !garel.rel, !garel.rel [<0[0] = 1[0]>, <0[0] = 2[0]>, <0[1] = 1[1]>, <0[1] = 2[1]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] %0 = graphalg.apply %arg0, %arg1, %arg2 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg3: i64, %arg4: i64, %arg5: i64): // CHECK: %[[#VAL1:]] = garel.extract 2 @@ -53,6 +59,7 @@ func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.ma graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 42 x i64> } @@ -60,6 +67,7 @@ func.func @ApplyTernary(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.ma // CHECK-LABEL: @ApplyMat func.func @ApplyMat(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 @@ -72,11 +80,13 @@ func.func @ApplyMat(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 4 graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 42 x i64> } // CHECK-LABEL: @ApplyRowVec func.func @ApplyRowVec(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 @@ -88,11 +98,13 @@ func.func @ApplyRowVec(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<1 x 42 x i64> } // CHECK-LABEL: @ApplyColVec func.func @ApplyColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 %0 = graphalg.apply %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> { ^bb0(%arg42: i64): %1 = graphalg.const 1 : i64 @@ -104,11 +116,13 @@ func.func @ApplyColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 1 x i64> } // CHECK-LABEL: @ApplyScalar func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { ^bb0(%arg1: i64): %1 = graphalg.const 1 : i64 @@ -119,11 +133,16 @@ func.func @ApplyScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 graphalg.apply.return %2 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @ApplyBroadcastScalar func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 43 x i64> { + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#COLS:]] = garel.range 43 + // CHECK: %[[#JOIN:]] = garel.join %arg0, %[[#ROWS]], %[[#COLS]] : !garel.rel, !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <42 x 43 x i64> { ^bb0(%arg1: i64): // CHECK: %[[#VAL:]] = garel.extract 0 @@ -133,11 +152,14 @@ func.func @ApplyBroadcastScalar(%arg0 : !graphalg.mat<1 x 1 x i64>) -> !graphalg graphalg.apply.return %arg1 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 43 x i64> } // CHECK-LABEL: @ApplyBroadcastOne func.func @ApplyBroadcastOne(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> { ^bb0(%arg2: i64, %arg3: i64): // CHECK: %[[#VAL1:]] = garel.extract 1 @@ -150,5 +172,6 @@ func.func @ApplyBroadcastOne(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphal graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT:]] return %0 : !graphalg.mat<42 x 42 x i64> } From 43530445cd9fb66dc5d747a94efc3d4c920bb121 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 21 Jan 2026 21:52:26 +0000 Subject: [PATCH 12/32] Port over add.mlir. --- compiler/test/graphalg-to-rel/add.mlir | 97 ++++++++++++++++++++++++++ llm/graphalg-to-rel.md | 4 ++ 2 files changed, 101 insertions(+) create mode 100644 compiler/test/graphalg-to-rel/add.mlir create mode 100644 llm/graphalg-to-rel.md diff --git a/compiler/test/graphalg-to-rel/add.mlir b/compiler/test/graphalg-to-rel/add.mlir new file mode 100644 index 0000000..82835d1 --- /dev/null +++ b/compiler/test/graphalg-to-rel/add.mlir @@ -0,0 +1,97 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @AddBool +func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : i1 + graphalg.apply.return %1 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @AddInt +func.func @AddInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : i64 + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @AddReal +func.func @AddReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.addf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : f64 + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @AddTropInt +func.func @AddTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg2 : !graphalg.trop_i64, %arg3: !graphalg.trop_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.minsi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_i64 + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @AddTropReal +func.func @AddTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_f64>, !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg2 : !graphalg.trop_f64, %arg3: !graphalg.trop_f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.minimumf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_f64 + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @AddTropMaxInt +func.func @AddTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg2 : !graphalg.trop_max_i64, %arg3: !graphalg.trop_max_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.maxsi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ADD]] + %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_max_i64 + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} diff --git a/llm/graphalg-to-rel.md b/llm/graphalg-to-rel.md new file mode 100644 index 0000000..7225cee --- /dev/null +++ b/llm/graphalg-to-rel.md @@ -0,0 +1,4 @@ +Update the CHECK comments in @compiler/test/graphalg-to-rel/add.mlir to match the output of running the command `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/add.mlir`. +Check your work by running `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/add.mlir | FileCheck-20` + +Expect to replace ipr.* ops with either arith.* or garel.*. If you think other changes are needed, check with me first. From 664cb47173be9452c8dcc32b697758364324f992 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Wed, 21 Jan 2026 21:54:14 +0000 Subject: [PATCH 13/32] Notes for LLM. --- llm/graphalg-to-rel.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/llm/graphalg-to-rel.md b/llm/graphalg-to-rel.md index 7225cee..b98be92 100644 --- a/llm/graphalg-to-rel.md +++ b/llm/graphalg-to-rel.md @@ -2,3 +2,5 @@ Update the CHECK comments in @compiler/test/graphalg-to-rel/add.mlir to match th Check your work by running `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/add.mlir | FileCheck-20` Expect to replace ipr.* ops with either arith.* or garel.*. If you think other changes are needed, check with me first. + +Keep CHECK comments where they are in the file, close to the ops that they belong to. Do not move them to the outer scope unless necessary. From 39790e12861148834f6d9dd512356f1139f27c5e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 22 Jan 2026 10:29:09 +0000 Subject: [PATCH 14/32] WIP: broadcast. --- compiler/src/garel/GraphAlgToRel.cpp | 83 ++++++++++++++++++++ compiler/test/graphalg-to-rel/GEMINI.md | 79 +++++++++++++++++++ compiler/test/graphalg-to-rel/add.mlir | 36 ++++++--- compiler/test/graphalg-to-rel/broadcast.mlir | 48 +++++++++++ 4 files changed, 234 insertions(+), 12 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/GEMINI.md create mode 100644 compiler/test/graphalg-to-rel/broadcast.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 2d1ad52..dd4b252 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -539,6 +539,89 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::BroadcastOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + llvm::SmallVector joinChildren; + if (input.hasRowColumn()) { + // Already have a row column. + // TODO: record row column. + } else if (output.hasRowColumn()) { + // Broadcast over all rows. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + // TODO: record row column. + } + + if (input.hasColColumn()) { + // Already have a col column. + // TODO: record col column. + } else if (output.hasColColumn()) { + // Broadcast over all columns. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + // TODO: record col column. + } + + joinChildren.push_back(input.relation()); + // TODO: record val column. + + /* + // Join with a dim read for row/col slots that we want in the output, but do + // not have on the input. + llvm::SmallVector renameSlots; + llvm::SmallVector joinChildren; + if (auto rowSlot = input.rowSlot()) { + // Already have a row slot. + renameSlots.emplace_back(rowSlot.getSlot()); + } else if (auto rowSlot = output.rowSlot()) { + // Broadcast over all rows. + joinChildren.emplace_back( + createDimRead(op.getLoc(), rowSlot, rewriter)); + renameSlots.emplace_back(rowSlot.getSlot()); + } + + if (auto colSlot = input.colSlot()) { + // Already have a col slot. + renameSlots.emplace_back(colSlot.getSlot()); + } else if (auto colSlot = output.colSlot()) { + // Broadcast over all columns. + joinChildren.emplace_back( + createDimRead(op.getLoc(), colSlot, rewriter)); + renameSlots.emplace_back(colSlot.getSlot()); + } + + joinChildren.emplace_back(input.stream()); + renameSlots.emplace_back(input.valSlot().getSlot()); + + auto joinOp = rewriter.create( + op.getLoc(), + joinChildren); + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto& body = joinOp.getPredicates().front(); + rewriter.setInsertionPointToStart(&body); + // No predicates + rewriter.create(op.getLoc(), std::nullopt); + } + + // Rename to the desired output slots. This also handles reordering slots. + // We want (row, col, val) order, but the join output could be e.g. + // (col, row, val) if the input does not have a col slot. + rewriter.replaceOpWithNewOp( + op, + output.streamType(), + joinOp, + rewriter.getAttr(renameSlots)); + return mlir::success(); + */ + return mlir::failure(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= diff --git a/compiler/test/graphalg-to-rel/GEMINI.md b/compiler/test/graphalg-to-rel/GEMINI.md new file mode 100644 index 0000000..6de77e5 --- /dev/null +++ b/compiler/test/graphalg-to-rel/GEMINI.md @@ -0,0 +1,79 @@ +# GraphAlg to Relation Algebra tests +These test files verify the `graphalg-to-rel` pass, which converts ops from the `graphalg` dialect into `garel` and `arith` dialect ops. + +## Running Tests +Tests require the `graphalg-opt` binary, which is built by running `cmake --build compiler/build --target graphalg-opt`. +To get the output for a test file, run `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/.mlir`. +Test files contain `CHECK` comments that are verified using LLVM's FileCheck tool, installed as `FileCheck-20`. + +If you make any changes and have verified that the individual tests are correct, run the integration tests as a final check: `cmake --build compiler/build --target check`. + +## Coding style +### Use CHECK-LABEL For independent test functions +```mlir +// CHECK-LABEL: @AddBool +func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + ... +} + +// CHECK-LABEL: @AddInt +func.func @AddInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + ... +} +``` + +### Keep new op CHECKs close to original ops +Keep CHECK comments for output ops directly before and at the same indentation as the original ops they were generated from. + +**GOOD**: +```mlir +// CHECK-LABEL: @AddBool +func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] + %1 = graphalg.add %arg2, %arg3 : i1 + + // CHECK: garel.project.return %[[#ADD]] + graphalg.apply.return %1 : i1 + } + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i1> +} +``` + +**BAD**: +```mlir +// CHECK-LABEL: @AddBool +// CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> +// CHECK: %[[#LHS:]] = garel.extract 0 +// CHECK: %[[#RHS:]] = garel.extract 1 +// CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] +// CHECK: garel.project.return %[[#ADD]] +// CHECK: return %[[#PROJECT]] +func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + %1 = graphalg.add %arg2, %arg3 : i1 + graphalg.apply.return %1 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} +``` + +## Porting IPR tests +If you are asked to port an IPR testcase, do these things: +1. Change ag-opt in the `RUN` comment to graphalg-opt and the pass from --graphalg-to-ipr to --graphalg-to-rel +2. Run `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/.mlir` to see the expected output. Use that to guide the changes described in (3) and (4). +3. Replace `ipr.tuplestream` types with the corresponding `garel.relation`, and `ipr.tuple` with `garel.tuple` +4. Replace `ipr.*` ops with `garel.*` or `arith.*` ops. +5. Verify your changes with `FileCheck-20` (see guide to running tests above). +6. When you have verified your changes to the file, run the integration tests to double-check. + +Do not make changes to the input IR (the parts not in comments). +If you really think this is necessary, ask first. diff --git a/compiler/test/graphalg-to-rel/add.mlir b/compiler/test/graphalg-to-rel/add.mlir index 82835d1..6069c74 100644 --- a/compiler/test/graphalg-to-rel/add.mlir +++ b/compiler/test/graphalg-to-rel/add.mlir @@ -2,96 +2,108 @@ // CHECK-LABEL: @AddBool func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { - // CHECK: garel.project {{.*}} : -> +// CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { ^bb0(%arg2 : i1, %arg3: i1): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : i1 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : i1 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i1> } // CHECK-LABEL: @AddInt func.func @AddInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - // CHECK: garel.project {{.*}} : -> + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { ^bb0(%arg2 : i64, %arg3: i64): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.addi %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : i64 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @AddReal func.func @AddReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { - // CHECK: garel.project {{.*}} : -> + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { ^bb0(%arg2 : f64, %arg3: f64): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.addf %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : f64 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : f64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x f64> } // CHECK-LABEL: @AddTropInt func.func @AddTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { - // CHECK: garel.project {{.*}} : -> + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> { ^bb0(%arg2 : !graphalg.trop_i64, %arg3: !graphalg.trop_i64): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.minsi %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_i64 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : !graphalg.trop_i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> } // CHECK-LABEL: @AddTropReal func.func @AddTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { - // CHECK: garel.project {{.*}} : -> + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_f64>, !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> { ^bb0(%arg2 : !graphalg.trop_f64, %arg3: !graphalg.trop_f64): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.minimumf %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_f64 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : !graphalg.trop_f64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> } // CHECK-LABEL: @AddTropMaxInt func.func @AddTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { - // CHECK: garel.project {{.*}} : -> + // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> { ^bb0(%arg2 : !graphalg.trop_max_i64, %arg3: !graphalg.trop_max_i64): // CHECK: %[[#LHS:]] = garel.extract 0 // CHECK: %[[#RHS:]] = garel.extract 1 // CHECK: %[[#ADD:]] = arith.maxsi %[[#LHS]], %[[#RHS]] - // CHECK: garel.project.return %[[#ADD]] %1 = graphalg.add %arg2, %arg3 : !graphalg.trop_max_i64 + + // CHECK: garel.project.return %[[#ADD]] graphalg.apply.return %1 : !graphalg.trop_max_i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> } diff --git a/compiler/test/graphalg-to-rel/broadcast.mlir b/compiler/test/graphalg-to-rel/broadcast.mlir new file mode 100644 index 0000000..9a380eb --- /dev/null +++ b/compiler/test/graphalg-to-rel/broadcast.mlir @@ -0,0 +1,48 @@ +// RUN: ag-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @BroadcastMat +// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> +func.func @BroadcastMat(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#ROWS:]] = ipr.access_vertices <[[#R:]]:!ipr.opaque_vertex> + // CHECK-SAME: opaque_vertex=[[#R]] properties=[] + // CHECK: %[[#COLS:]] = ipr.access_vertices <[[#C:]]:!ipr.opaque_vertex> + // CHECK-SAME: opaque_vertex=[[#C]] properties=[] + // CHECK: %[[#JOIN:]] = ipr.join %[[#ROWS]], %[[#COLS]], %arg1 + // CHECK: ipr.join.return + // + // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#R]], [[#C]], [[#V]]] + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 42 x i64> + + // CHECK: return %[[#RENAME]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @BroadcastRowVec +// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> +func.func @BroadcastRowVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#COLS:]] = ipr.access_vertices <[[#C:]]:!ipr.opaque_vertex> + // CHECK-SAME: opaque_vertex=[[#C]] properties=[] + // CHECK: %[[#JOIN:]] = ipr.join %[[#COLS]], %arg1 + // CHECK: ipr.join.return + // + // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#C]], [[#V]]] + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <1 x 42 x i64> + + // CHECK: return %[[#RENAME]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @BroadcastColVec +// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> +func.func @BroadcastColVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#ROWS:]] = ipr.access_vertices <[[#R:]]:!ipr.opaque_vertex> + // CHECK-SAME: opaque_vertex=[[#R]] properties=[] + // CHECK: %[[#JOIN:]] = ipr.join %[[#ROWS]], %arg1 + // CHECK: ipr.join.return + // + // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#R]], [[#V]]] + %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#RENAME]] + return %0 : !graphalg.mat<42 x 1 x i64> +} From d6be611c0aa7b82cca63bc7c944254966b56cf1e Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 22 Jan 2026 16:36:45 +0000 Subject: [PATCH 15/32] Broadcast tests. --- compiler/include/garel/GARelAttr.h | 2 +- compiler/include/garel/GARelAttr.td | 4 +- compiler/include/garel/GARelOps.td | 17 +++ compiler/src/garel/GARelOps.cpp | 48 +++++++++ compiler/src/garel/GraphAlgToRel.cpp | 104 ++++++++----------- compiler/test/graphalg-to-rel/GEMINI.md | 11 ++ compiler/test/graphalg-to-rel/broadcast.mlir | 38 ++----- llm/graphalg-to-rel.md | 6 -- 8 files changed, 131 insertions(+), 99 deletions(-) delete mode 100644 llm/graphalg-to-rel.md diff --git a/compiler/include/garel/GARelAttr.h b/compiler/include/garel/GARelAttr.h index 3388dbd..7a5f505 100644 --- a/compiler/include/garel/GARelAttr.h +++ b/compiler/include/garel/GARelAttr.h @@ -8,7 +8,7 @@ namespace garel { /** Reference to a column inside of \c RelationType or \c TupleType. */ -using ColumnIdx = unsigned; +using ColumnIdx = std::int32_t; } // namespace garel diff --git a/compiler/include/garel/GARelAttr.td b/compiler/include/garel/GARelAttr.td index 5599528..c09797c 100644 --- a/compiler/include/garel/GARelAttr.td +++ b/compiler/include/garel/GARelAttr.td @@ -14,9 +14,9 @@ def JoinPredicate : GARel_Attr<"JoinPredicate", "join_pred"> { let summary = "A binary equality join predicate"; let parameters = (ins - "unsigned":$lhsRelIdx, + "std::int32_t":$lhsRelIdx, "ColumnIdx":$lhsColIdx, - "unsigned":$rhsRelIdx, + "std::int32_t":$rhsRelIdx, "ColumnIdx":$rhsColIdx); let assemblyFormat = [{ diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 62af3f9..6fb8bf3 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -187,4 +187,21 @@ def RangeOp : GARel_Op<"range", [InferTypeOpAdaptor]> { }]; } +def RemapOp : GARel_Op<"remap", [InferTypeOpAdaptor]> { + let summary = "reorders or drop columns"; + + let arguments = (ins Relation:$input, DenseI32ArrayAttr:$remap); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $input `:` type($input) + $remap + attr-dict + }]; + + let hasVerifier = 1; + let hasFolder = 1; +} + #endif // GAREL_OPS diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index d8cd6d1..5bf9518 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -10,6 +10,7 @@ #include "garel/GARelDialect.h" #include "garel/GARelOps.h" #include "garel/GARelTypes.h" +#include "llvm/ADT/ArrayRef.h" #define GET_OP_CLASSES #include "garel/GARelOps.cpp.inc" @@ -224,4 +225,51 @@ mlir::LogicalResult RangeOp::inferReturnTypes( return mlir::success(); } +// === RemapOp === +mlir::LogicalResult RemapOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns; + auto inputType = llvm::cast(adaptor.getInput().getType()); + for (auto inputCol : adaptor.getRemap()) { + outputColumns.push_back(inputType.getColumns()[inputCol]); + } + + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + +mlir::LogicalResult RemapOp::verify() { + auto inputColumns = getInput().getType().getColumns(); + for (auto inputCol : getRemap()) { + if (inputCol >= inputColumns.size()) { + return emitOpError("remap refers to input column ") + << inputCol << ", but input only has " << inputColumns.size() + << " columns"; + } + } + + return mlir::success(); +} + +// Checks for mapping [0, 1, 2, ...] +static bool isIdentityRemap(llvm::ArrayRef indexes) { + for (auto [i, idx] : llvm::enumerate(indexes)) { + if (i != idx) { + return false; + } + } + + return true; +} + +mlir::OpFoldResult RemapOp::fold(FoldAdaptor adaptor) { + if (isIdentityRemap(getRemap())) { + assert(getInput().getType() == getType()); + return getInput(); + } + + return nullptr; +} + } // namespace garel diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index dd4b252..2a9213f 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -22,6 +22,7 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" +#include "mlir/IR/Builders.h" namespace garel { @@ -151,7 +152,7 @@ class ApplyOpConversion : public mlir::OpConversionPattern { }; struct InputColumnRef { - unsigned relIdx; + std::size_t relIdx; ColumnIdx colIdx; ColumnIdx outIdx; }; @@ -416,7 +417,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( if (input.hasRowColumn()) { rowColumns.push_back(InputColumnRef{ - .relIdx = static_cast(idx), + .relIdx = idx, .colIdx = input.rowColumn(), .outIdx = nextColumnIdx + input.rowColumn(), }); @@ -424,7 +425,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( if (input.hasColColumn()) { colColumns.push_back(InputColumnRef{ - .relIdx = static_cast(idx), + .relIdx = idx, .colIdx = input.colColumn(), .outIdx = nextColumnIdx + input.colColumn(), }); @@ -443,7 +444,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); joinChildren.emplace_back(rowsOp); rowColumns.push_back(InputColumnRef{ - .relIdx = static_cast(joinChildren.size() - 1), + .relIdx = joinChildren.size() - 1, .colIdx = 0, .outIdx = nextColumnIdx++, }); @@ -456,7 +457,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); joinChildren.emplace_back(colsOp); colColumns.push_back(InputColumnRef{ - .relIdx = static_cast(joinChildren.size() - 1), + .relIdx = joinChildren.size() - 1, .colIdx = 0, .outIdx = nextColumnIdx++, }); @@ -547,79 +548,56 @@ mlir::LogicalResult OpConversion::matchAndRewrite( MatrixAdaptor output(op, typeConverter->convertType(op.getType())); llvm::SmallVector joinChildren; + ColumnIdx currentColIdx = 0; + + std::optional rowColumnIdx; + std::optional colColumnIdx; + if (input.hasRowColumn()) { // Already have a row column. - // TODO: record row column. + rowColumnIdx = input.rowColumn(); } else if (output.hasRowColumn()) { // Broadcast over all rows. joinChildren.push_back( createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); - // TODO: record row column. + rowColumnIdx = currentColIdx++; } if (input.hasColColumn()) { // Already have a col column. - // TODO: record col column. + colColumnIdx = input.colColumn(); } else if (output.hasColColumn()) { // Broadcast over all columns. joinChildren.push_back( createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); - // TODO: record col column. + colColumnIdx = currentColIdx++; } joinChildren.push_back(input.relation()); - // TODO: record val column. + auto valColumnIdx = currentColIdx + input.valColumn(); - /* - // Join with a dim read for row/col slots that we want in the output, but do - // not have on the input. - llvm::SmallVector renameSlots; - llvm::SmallVector joinChildren; - if (auto rowSlot = input.rowSlot()) { - // Already have a row slot. - renameSlots.emplace_back(rowSlot.getSlot()); - } else if (auto rowSlot = output.rowSlot()) { - // Broadcast over all rows. - joinChildren.emplace_back( - createDimRead(op.getLoc(), rowSlot, rewriter)); - renameSlots.emplace_back(rowSlot.getSlot()); - } - - if (auto colSlot = input.colSlot()) { - // Already have a col slot. - renameSlots.emplace_back(colSlot.getSlot()); - } else if (auto colSlot = output.colSlot()) { - // Broadcast over all columns. - joinChildren.emplace_back( - createDimRead(op.getLoc(), colSlot, rewriter)); - renameSlots.emplace_back(colSlot.getSlot()); - } - - joinChildren.emplace_back(input.stream()); - renameSlots.emplace_back(input.valSlot().getSlot()); - - auto joinOp = rewriter.create( - op.getLoc(), - joinChildren); - { - mlir::OpBuilder::InsertionGuard guard(rewriter); - auto& body = joinOp.getPredicates().front(); - rewriter.setInsertionPointToStart(&body); - // No predicates - rewriter.create(op.getLoc(), std::nullopt); - } - - // Rename to the desired output slots. This also handles reordering slots. - // We want (row, col, val) order, but the join output could be e.g. - // (col, row, val) if the input does not have a col slot. - rewriter.replaceOpWithNewOp( - op, - output.streamType(), - joinOp, - rewriter.getAttr(renameSlots)); + auto joinOp = + rewriter.create(op.getLoc(), joinChildren, + // on join predicates (cartesian product) + llvm::ArrayRef{}); + + // Remap to correctly order as (row, col, val). + // TODO: Skip this if it is unnecessary. + llvm::SmallVector outputColumns; + if (rowColumnIdx) { + outputColumns.push_back(*rowColumnIdx); + } + + if (colColumnIdx) { + outputColumns.push_back(*colColumnIdx); + } + + outputColumns.push_back(valColumnIdx); + + auto remapped = + rewriter.createOrFold(op.getLoc(), joinOp, outputColumns); + rewriter.replaceOp(op, remapped); return mlir::success(); - */ - return mlir::failure(); } // ============================================================================= @@ -707,10 +685,10 @@ void GraphAlgToRel::runOnOperation() { MatrixTypeConverter matrixTypeConverter(&getContext(), semiringTypeConverter); mlir::RewritePatternSet patterns(&getContext()); - patterns - .add, OpConversion, - OpConversion>(matrixTypeConverter, - &getContext()); + patterns.add< + OpConversion, OpConversion, + OpConversion, OpConversion>( + matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/GEMINI.md b/compiler/test/graphalg-to-rel/GEMINI.md index 6de77e5..e0fe9e7 100644 --- a/compiler/test/graphalg-to-rel/GEMINI.md +++ b/compiler/test/graphalg-to-rel/GEMINI.md @@ -66,6 +66,17 @@ func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 } ``` +### Prefer numeric substitution blocks when possible +**GOOD**: +```mlir +// CHECK: %[[#RNG:]] = garel.range 42 +``` + +**BAD**: +```mlir +// CHECK: %[[RNG:.+]] = garel.range 42 +``` + ## Porting IPR tests If you are asked to port an IPR testcase, do these things: 1. Change ag-opt in the `RUN` comment to graphalg-opt and the pass from --graphalg-to-ipr to --graphalg-to-rel diff --git a/compiler/test/graphalg-to-rel/broadcast.mlir b/compiler/test/graphalg-to-rel/broadcast.mlir index 9a380eb..918fbc2 100644 --- a/compiler/test/graphalg-to-rel/broadcast.mlir +++ b/compiler/test/graphalg-to-rel/broadcast.mlir @@ -1,48 +1,32 @@ -// RUN: ag-opt --graphalg-to-rel < %s | FileCheck %s +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s // CHECK-LABEL: @BroadcastMat -// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> func.func @BroadcastMat(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 42 x i64> { - // CHECK: %[[#ROWS:]] = ipr.access_vertices <[[#R:]]:!ipr.opaque_vertex> - // CHECK-SAME: opaque_vertex=[[#R]] properties=[] - // CHECK: %[[#COLS:]] = ipr.access_vertices <[[#C:]]:!ipr.opaque_vertex> - // CHECK-SAME: opaque_vertex=[[#C]] properties=[] - // CHECK: %[[#JOIN:]] = ipr.join %[[#ROWS]], %[[#COLS]], %arg1 - // CHECK: ipr.join.return - // - // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#R]], [[#C]], [[#V]]] + // CHECK: %[[#RNG0:]] = garel.range 42 + // CHECK: %[[#RNG1:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG0]], %[[#RNG1]], %arg1 %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 42 x i64> - // CHECK: return %[[#RENAME]] + // CHECK: return %[[#JOIN]] return %0 : !graphalg.mat<42 x 42 x i64> } // CHECK-LABEL: @BroadcastRowVec -// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> func.func @BroadcastRowVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 42 x i64> { - // CHECK: %[[#COLS:]] = ipr.access_vertices <[[#C:]]:!ipr.opaque_vertex> - // CHECK-SAME: opaque_vertex=[[#C]] properties=[] - // CHECK: %[[#JOIN:]] = ipr.join %[[#COLS]], %arg1 - // CHECK: ipr.join.return - // - // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#C]], [[#V]]] + // CHECK: %[[#RNG:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG]], %arg1 %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <1 x 42 x i64> - // CHECK: return %[[#RENAME]] + // CHECK: return %[[#JOIN]] return %0 : !graphalg.mat<1 x 42 x i64> } // CHECK-LABEL: @BroadcastColVec -// CHECK: %arg1: !ipr.tuplestream<[[#V:]]:si64> func.func @BroadcastColVec(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { - // CHECK: %[[#ROWS:]] = ipr.access_vertices <[[#R:]]:!ipr.opaque_vertex> - // CHECK-SAME: opaque_vertex=[[#R]] properties=[] - // CHECK: %[[#JOIN:]] = ipr.join %[[#ROWS]], %arg1 - // CHECK: ipr.join.return - // - // CHECK: %[[#RENAME:]] = ipr.rename %[[#JOIN]] {{.*}} [[[#R]], [[#V]]] + // CHECK: %[[#RNG:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#RNG]], %arg1 %0 = graphalg.broadcast %arg1 : <1 x 1 x i64> -> <42 x 1 x i64> - // CHECK: return %[[#RENAME]] + // CHECK: return %[[#JOIN]] return %0 : !graphalg.mat<42 x 1 x i64> } diff --git a/llm/graphalg-to-rel.md b/llm/graphalg-to-rel.md deleted file mode 100644 index b98be92..0000000 --- a/llm/graphalg-to-rel.md +++ /dev/null @@ -1,6 +0,0 @@ -Update the CHECK comments in @compiler/test/graphalg-to-rel/add.mlir to match the output of running the command `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/add.mlir`. -Check your work by running `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/add.mlir | FileCheck-20` - -Expect to replace ipr.* ops with either arith.* or garel.*. If you think other changes are needed, check with me first. - -Keep CHECK comments where they are in the file, close to the ops that they belong to. Do not move them to the outer scope unless necessary. From c7f4219d5c8c8c53754def30f748723ddb677420 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Thu, 22 Jan 2026 16:46:03 +0000 Subject: [PATCH 16/32] WIP: casting. --- compiler/src/garel/GraphAlgToRel.cpp | 136 +++++++++++++++++++ compiler/test/graphalg-to-rel/cast.mlir | 170 ++++++++++++++++++++++++ 2 files changed, 306 insertions(+) create mode 100644 compiler/test/graphalg-to-rel/cast.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 2a9213f..d15de34 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -318,6 +318,15 @@ static mlir::FailureOr convertConstant(mlir::Operation *op, return op->emitOpError("cannot convert constant ") << attr; } +static bool isTropicalnessCast(graphalg::SemiringTypeInterface inRing, + graphalg::SemiringTypeInterface outRing) { + assert(inRing != outRing && "No-op cast"); + // If the relational types match, it is purely a 'tropicalness' cast such as + // i64 -> !graphalg.trop_i64. + SemiringTypeConverter conv; + return conv.convertType(inRing) == conv.convertType(outRing); +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -648,6 +657,133 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +/* +static mlir::Value preserveAdditiveIdentity(graphalg::CastScalarOp op, + mlir::Value input, + mlir::Value defaultOutput, + mlir::OpBuilder &builder) { +// Return defaultOutput, except when input is the additive identity, which +// we need to remap to the additive identity of the target type. +auto inRing = llvm::cast(op.getInput().getType()); +auto outRing = llvm::cast(op.getType()); + +auto inIdent = convertConstant(op, inRing.addIdentity()); +assert (mlir::succeeded(inIdent)); +auto outIdent = convertConstant(op, outRing.addIdentity()); +assert (mlir::succeeded(outIdent)); + +auto outIdentOp = builder.create( + input.getLoc(), + builder.getAttr(*outIdent)); + +auto inIdentAttr = builder.getAttr( + *inIdent); +auto keysAttr = builder.getAttr( + llvm::ArrayRef(inIdentAttr)); + +return builder.create( + op.getLoc(), + defaultOutput, + input, + keysAttr, + mlir::ValueRange{ outIdentOp }); +} +*/ + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::CastScalarOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto *ctx = rewriter.getContext(); + auto inRing = + llvm::cast(op.getInput().getType()); + auto outRing = llvm::cast(op.getType()); + assert(inRing != outRing && "Identity cast not removed by fold()"); + + if (outRing == graphalg::SemiringTypes::forBool(ctx)) { + // Rewrite to: input != zero(inRing) + /* + auto addIdent = convertConstant(op, inRing.addIdentity()); + if (mlir::failed(addIdent)) { + return mlir::failure(); + } + + auto addIdentOp = rewriter.create( + op.getLoc(), rewriter.getAttr(*addIdent)); + + rewriter.replaceOpWithNewOp( + op, adaptor.getInput(), ipr::CmpPredicate::NE, addIdentOp); + return mlir::success(); + */ + } else if (inRing == graphalg::SemiringTypes::forBool(ctx)) { + // Mapping: + // true -> multiplicative identity + // false -> additive identity + + /* + auto mulIdent = convertConstant(op, outRing.mulIdentity()); + if (mlir::failed(mulIdent)) { + return mlir::failure(); + } + + auto mulIdentOp = rewriter.create( + op.getLoc(), rewriter.getAttr(*mulIdent)); + + auto selectOp = + preserveAdditiveIdentity(op, adaptor.getInput(), mulIdentOp, rewriter); + rewriter.replaceOp(op, selectOp); + return mlir::success(); + */ + } else if (inRing.isIntOrFloat() && outRing.isIntOrFloat()) { + // No tropical semirings, simple cast. + /* + auto dataType = + convertToDataType(op, typeConverter->convertType(op.getType())); + if (mlir::failed(dataType)) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getInput(), rewriter.getAttr(*dataType)); + return mlir::success(); + */ + } else if (isTropicalnessCast(inRing, outRing)) { + // Only cast the 'tropicalness' of the type. The underlying relational type + // does not change. Preserve the value unless it is the additive identity, + // in which case we remap it to the additive identity of the output ring. + /* + auto selectOp = preserveAdditiveIdentity(op, adaptor.getInput(), + adaptor.getInput(), rewriter); + rewriter.replaceOp(op, selectOp); + return mlir::success(); + */ + } else if (llvm::isa(inRing) && + llvm::isa(outRing)) { + // Cast the underlying relational type, but preserve the additive identity. + /* + auto dataType = + convertToDataType(op, typeConverter->convertType(op.getType())); + if (mlir::failed(dataType)) { + return mlir::failure(); + } + + auto castOp = rewriter.create( + op.getLoc(), adaptor.getInput(), + rewriter.getAttr(*dataType)); + + auto selectOp = + preserveAdditiveIdentity(op, adaptor.getInput(), castOp, rewriter); + rewriter.replaceOp(op, selectOp); + return mlir::success(); + */ + } + + return op->emitOpError("cast from ") + << op.getInput().getType() << " to " << op.getType() + << " not yet supported in " << GARelDialect::getDialectNamespace() + << " dialect"; +} + static bool hasRelationSignature(mlir::func::FuncOp op) { // All inputs should be relations auto funcType = op.getFunctionType(); diff --git a/compiler/test/graphalg-to-rel/cast.mlir b/compiler/test/graphalg-to-rel/cast.mlir new file mode 100644 index 0000000..6f52472 --- /dev/null +++ b/compiler/test/graphalg-to-rel/cast.mlir @@ -0,0 +1,170 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @CastBoolInt +func.func @CastBoolInt(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i64> { + ^bb0(%arg1 : i1): + // CHECK: %[[#SLOT:]] = ipr.slot + // CHECK: %[[#MUL_IDENT:]] = ipr.constant_slot <<"S64"> (1)> + // CHECK: %[[#ADD_IDENT:]] = ipr.constant_slot <<"S64"> (0)> + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#SLOT]] + // CHECK-SAME: [<<"BOOLEAN"> (false)>] %[[#ADD_IDENT]] + // CHECK-SAME: default = %[[#MUL_IDENT]] + %1 = graphalg.cast_scalar %arg1 : i1 -> i64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @CastIntReal +func.func @CastIntReal(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x f64> { + ^bb0(%arg1 : i64): + // CHECK: %[[#INPUT:]] = ipr.slot + // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : si64 -> <"F64"> + %1 = graphalg.cast_scalar %arg1 : i64 -> f64 + + // CHECK: ipr.project.return %[[#CAST]] + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @CastRealInt +func.func @CastRealInt(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i64> { + ^bb0(%arg1 : f64): + // CHECK: %[[#INPUT:]] = ipr.slot + // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : f64 -> <"S64"> + %1 = graphalg.cast_scalar %arg1 : f64 -> i64 + + // CHECK: ipr.project.return %[[#CAST]] + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @CastBoolTrop +func.func @CastBoolTrop(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg1 : i1): + // CHECK: %[[#SLOT:]] = ipr.slot + + // CHECK: %[[#MUL_IDENT:]] = ipr.constant_slot <<"S64"> (0)> + // CHECK: %[[#ADD_IDENT:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#SLOT]] + // CHECK-SAME: [<<"BOOLEAN"> (false)>] %[[#ADD_IDENT]] + // CHECK-SAME: default = %[[#MUL_IDENT]] + %1 = graphalg.cast_scalar %arg1 : i1 -> !graphalg.trop_i64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @CastTropBool +func.func @CastTropBool(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[#INPUT:]] = ipr.slot + // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> + // CHECK: %[[#CMP:]] = ipr.cmp %[[#INPUT]] : si64 NE %[[#ZERO]] : si64 + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> i1 + + // CHECK: ipr.project.return %[[#CMP]] + graphalg.apply.return %1 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @CastTropIntTropReal +func.func @CastTropIntTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[#INPUT:]] = ipr.slot + // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : si64 -> <"F64"> + // + // CHECK: %[[#INF:]] = ipr.constant_slot <<"F64"> (INF)> + // + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] + // CHECK-SAME: [<<"S64"> (9223372036854775807)>] + // CHECK-SAME: %[[#INF]] + // CHECK-SAME: default = %[[#CAST]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> !graphalg.trop_f64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @CastTropRealTropInt +func.func @CastTropRealTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: %[[#INPUT:]] = ipr.slot + // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : f64 -> <"S64"> + // + // CHECK: %[[#INF:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> + // + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] + // CHECK-SAME: [<<"F64"> (INF)>] + // CHECK-SAME: %[[#INF]] + // CHECK-SAME: default = %[[#CAST]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_f64 -> !graphalg.trop_i64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @CastIntToTropMaxInt +func.func @CastIntToTropMaxInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg1 : i64): + // CHECK: %[[#INPUT:]] = ipr.slot + + // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (-9223372036854775808)> + // + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] + // CHECK-SAME: [<<"S64"> (0)>] %[[#ZERO]] + // CHECK-SAME: default = %[[#INPUT]] + %1 = graphalg.cast_scalar %arg1 : i64 -> !graphalg.trop_max_i64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} + +// CHECK-LABEL: @CastTropMaxIntToInt +func.func @CastTropMaxIntToInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i64> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: %[[#INPUT:]] = ipr.slot + + // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (0)> + // + // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] + // CHECK-SAME: [<<"S64"> (-9223372036854775808)>] %[[#ZERO]] + // CHECK-SAME: default = %[[#INPUT]] + %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_max_i64 -> i64 + + // CHECK: ipr.project.return %[[#SELECT]] + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} From b0528f81f8a89b549587c78c73abe8382985c937 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 14:20:48 +0000 Subject: [PATCH 17/32] Casting. --- compiler/src/garel/GraphAlgToRel.cpp | 171 ++++++++++++++------------- 1 file changed, 89 insertions(+), 82 deletions(-) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index d15de34..765268f 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -273,7 +273,6 @@ static mlir::FailureOr convertConstant(mlir::Operation *op, if (type == graphalg::SemiringTypes::forBool(ctx)) { return attr; } else if (type == graphalg::SemiringTypes::forInt(ctx)) { - // TODO: Need to convert to signed? return attr; } else if (type == graphalg::SemiringTypes::forReal(ctx)) { return attr; @@ -312,7 +311,7 @@ static mlir::FailureOr convertConstant(mlir::Operation *op, } return mlir::TypedAttr( - mlir::FloatAttr::get(mlir::Float64Type::get(ctx), value)); + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 64), value)); } return op->emitOpError("cannot convert constant ") << attr; @@ -327,6 +326,43 @@ static bool isTropicalnessCast(graphalg::SemiringTypeInterface inRing, return conv.convertType(inRing) == conv.convertType(outRing); } +static mlir::Value preserveAdditiveIdentity(graphalg::CastScalarOp op, + mlir::Value input, + mlir::Value defaultOutput, + mlir::OpBuilder &builder) { + // Return defaultOutput, except when input is the additive identity, which + // we need to remap to the additive identity of the target type. + auto inRing = + llvm::cast(op.getInput().getType()); + auto outRing = llvm::cast(op.getType()); + + auto inIdent = convertConstant(op, inRing.addIdentity()); + assert(mlir::succeeded(inIdent)); + auto outIdent = convertConstant(op, outRing.addIdentity()); + assert(mlir::succeeded(outIdent)); + + auto inIdentOp = + builder.create(input.getLoc(), *inIdent); + auto outIdentOp = + builder.create(input.getLoc(), *outIdent); + + // Compare input == inIdent + mlir::Value identCompare; + if (input.getType().isF64()) { + assert(inIdentOp.getType().isF64()); + identCompare = builder.create( + op.getLoc(), mlir::arith::CmpFPredicate::OEQ, input, inIdentOp); + } else { + assert(input.getType().isSignlessInteger(64)); + assert(inIdentOp.getType().isSignlessInteger(64)); + identCompare = builder.create( + op.getLoc(), mlir::arith::CmpIPredicate::eq, input, inIdentOp); + } + + return builder.create(op.getLoc(), identCompare, + outIdentOp, defaultOutput); +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -657,39 +693,6 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } -/* -static mlir::Value preserveAdditiveIdentity(graphalg::CastScalarOp op, - mlir::Value input, - mlir::Value defaultOutput, - mlir::OpBuilder &builder) { -// Return defaultOutput, except when input is the additive identity, which -// we need to remap to the additive identity of the target type. -auto inRing = llvm::cast(op.getInput().getType()); -auto outRing = llvm::cast(op.getType()); - -auto inIdent = convertConstant(op, inRing.addIdentity()); -assert (mlir::succeeded(inIdent)); -auto outIdent = convertConstant(op, outRing.addIdentity()); -assert (mlir::succeeded(outIdent)); - -auto outIdentOp = builder.create( - input.getLoc(), - builder.getAttr(*outIdent)); - -auto inIdentAttr = builder.getAttr( - *inIdent); -auto keysAttr = builder.getAttr( - llvm::ArrayRef(inIdentAttr)); - -return builder.create( - op.getLoc(), - defaultOutput, - input, - keysAttr, - mlir::ValueRange{ outIdentOp }); -} -*/ - template <> mlir::LogicalResult OpConversion::matchAndRewrite( graphalg::CastScalarOp op, OpAdaptor adaptor, @@ -702,80 +705,84 @@ mlir::LogicalResult OpConversion::matchAndRewrite( if (outRing == graphalg::SemiringTypes::forBool(ctx)) { // Rewrite to: input != zero(inRing) - /* auto addIdent = convertConstant(op, inRing.addIdentity()); if (mlir::failed(addIdent)) { return mlir::failure(); } - auto addIdentOp = rewriter.create( - op.getLoc(), rewriter.getAttr(*addIdent)); + auto addIdentOp = + rewriter.create(op.getLoc(), *addIdent); + + if (addIdentOp.getType().isF64()) { + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpFPredicate::ONE, adaptor.getInput(), addIdentOp); + } else { + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpIPredicate::ne, adaptor.getInput(), addIdentOp); + } - rewriter.replaceOpWithNewOp( - op, adaptor.getInput(), ipr::CmpPredicate::NE, addIdentOp); return mlir::success(); - */ } else if (inRing == graphalg::SemiringTypes::forBool(ctx)) { // Mapping: // true -> multiplicative identity // false -> additive identity - - /* - auto mulIdent = convertConstant(op, outRing.mulIdentity()); - if (mlir::failed(mulIdent)) { + auto trueValue = convertConstant(op, outRing.mulIdentity()); + if (mlir::failed(trueValue)) { return mlir::failure(); } - auto mulIdentOp = rewriter.create( - op.getLoc(), rewriter.getAttr(*mulIdent)); - - auto selectOp = - preserveAdditiveIdentity(op, adaptor.getInput(), mulIdentOp, rewriter); - rewriter.replaceOp(op, selectOp); - return mlir::success(); - */ - } else if (inRing.isIntOrFloat() && outRing.isIntOrFloat()) { - // No tropical semirings, simple cast. - /* - auto dataType = - convertToDataType(op, typeConverter->convertType(op.getType())); - if (mlir::failed(dataType)) { + auto falseValue = convertConstant(op, outRing.addIdentity()); + if (mlir::failed(falseValue)) { return mlir::failure(); } - rewriter.replaceOpWithNewOp( - op, adaptor.getInput(), rewriter.getAttr(*dataType)); + auto trueOp = + rewriter.create(op.getLoc(), *trueValue); + auto falseOp = + rewriter.create(op.getLoc(), *falseValue); + rewriter.replaceOpWithNewOp(op, adaptor.getInput(), + trueOp, falseOp); + return mlir::success(); + } else if (inRing == graphalg::SemiringTypes::forInt(ctx) && + outRing == graphalg::SemiringTypes::forReal(ctx)) { + // Promote to i64 -> f64 + rewriter.replaceOpWithNewOp(op, outRing, + adaptor.getInput()); + return mlir::success(); + } else if (inRing == graphalg::SemiringTypes::forReal(ctx) && + outRing == graphalg::SemiringTypes::forInt(ctx)) { + // Truncate to int + rewriter.replaceOpWithNewOp(op, outRing, + adaptor.getInput()); return mlir::success(); - */ } else if (isTropicalnessCast(inRing, outRing)) { // Only cast the 'tropicalness' of the type. The underlying relational type // does not change. Preserve the value unless it is the additive identity, // in which case we remap it to the additive identity of the output ring. - /* auto selectOp = preserveAdditiveIdentity(op, adaptor.getInput(), adaptor.getInput(), rewriter); rewriter.replaceOp(op, selectOp); return mlir::success(); - */ - } else if (llvm::isa(inRing) && - llvm::isa(outRing)) { + } else if (inRing == graphalg::SemiringTypes::forTropInt(ctx) && + outRing == graphalg::SemiringTypes::forTropReal(ctx)) { + // trop_i64 to trop_f64 // Cast the underlying relational type, but preserve the additive identity. - /* - auto dataType = - convertToDataType(op, typeConverter->convertType(op.getType())); - if (mlir::failed(dataType)) { - return mlir::failure(); - } - - auto castOp = rewriter.create( - op.getLoc(), adaptor.getInput(), - rewriter.getAttr(*dataType)); - + auto castOp = rewriter.create( + op.getLoc(), rewriter.getF64Type(), adaptor.getInput()); + auto selectOp = + preserveAdditiveIdentity(op, adaptor.getInput(), castOp, rewriter); + rewriter.replaceOp(op, selectOp); + return mlir::success(); + } else if (inRing == graphalg::SemiringTypes::forTropReal(ctx) && + outRing == graphalg::SemiringTypes::forTropInt(ctx)) { + // trop_f64 to trop_i64 + // Cast the underlying relational type, but preserve the additive identity. + auto castOp = rewriter.create( + op.getLoc(), rewriter.getI64Type(), adaptor.getInput()); auto selectOp = preserveAdditiveIdentity(op, adaptor.getInput(), castOp, rewriter); rewriter.replaceOp(op, selectOp); return mlir::success(); - */ } return op->emitOpError("cast from ") @@ -830,10 +837,10 @@ void GraphAlgToRel::runOnOperation() { // Scalar patterns. // patterns.add(convertArithConstant); - patterns - .add, - OpConversion, OpConversion>( - semiringTypeConverter, &getContext()); + patterns.add< + OpConversion, OpConversion, + OpConversion, OpConversion>( + semiringTypeConverter, &getContext()); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { From 16373ee98a9890a4752a387a3338cced16092478 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 15:21:16 +0000 Subject: [PATCH 18/32] ConstantMatrixOp. --- compiler/include/garel/GARelOps.td | 13 ++ compiler/src/garel/GARelOps.cpp | 18 +++ compiler/src/garel/GraphAlgToRel.cpp | 41 +++++- compiler/test/graphalg-to-rel/cast.mlir | 125 ++++++++++--------- compiler/test/graphalg-to-rel/const-mat.mlir | 113 +++++++++++++++++ 5 files changed, 246 insertions(+), 64 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/const-mat.mlir diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 6fb8bf3..e40e673 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -104,6 +104,7 @@ def JoinOp : GARel_Op<"join", [InferTypeOpAdaptor]> { }]; let hasVerifier = 1; + let hasFolder = 1; } def UnionOp : GARel_Op<"union", [SameOperandsAndResultType]> { @@ -204,4 +205,16 @@ def RemapOp : GARel_Op<"remap", [InferTypeOpAdaptor]> { let hasFolder = 1; } +def ConstantOp : GARel_Op<"const", [InferTypeOpAdaptor]> { + let summary = "A relation with one constant-value tuple"; + + let arguments = (ins TypedAttrInterface:$value); + + let results = (outs Relation:$result); + + let assemblyFormat = [{ + $value attr-dict + }]; +} + #endif // GAREL_OPS diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 5bf9518..9216814 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -113,6 +113,15 @@ SelectReturnOp SelectOp::getTerminator() { } // === JoinOp === +mlir::OpFoldResult JoinOp::fold(FoldAdaptor adaptor) { + if (getInputs().size() == 1) { + assert(getInputs()[0].getType() == getType()); + return getInputs()[0]; + } + + return nullptr; +} + mlir::LogicalResult JoinOp::verify() { // TODO: Inputs must use distinct columns. // TODO: Predicates must refer to columns in distinct inputs (and to columns @@ -272,4 +281,13 @@ mlir::OpFoldResult RemapOp::fold(FoldAdaptor adaptor) { return nullptr; } +// === ConstantOp === +mlir::LogicalResult ConstantOp::inferReturnTypes( + mlir::MLIRContext *ctx, std::optional location, + Adaptor adaptor, llvm::SmallVectorImpl &inferredReturnTypes) { + llvm::SmallVector outputColumns{adaptor.getValue().getType()}; + inferredReturnTypes.push_back(RelationType::get(ctx, outputColumns)); + return mlir::success(); +} + } // namespace garel diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 765268f..3ae148c 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -645,6 +645,41 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ConstantMatrixOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto constantValue = convertConstant(op, op.getValue()); + if (mlir::failed(constantValue)) { + return mlir::failure(); + } + + auto constantOp = rewriter.create(op.getLoc(), *constantValue); + + // Broadcast to rows/columns if needed. + llvm::SmallVector joinChildren; + if (!output.matrixType().getRows().isOne()) { + // Broadcast over all rows. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter)); + } + + if (!output.matrixType().getCols().isOne()) { + // Broadcast over all columns. + joinChildren.push_back( + createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter)); + } + + joinChildren.push_back(constantOp); + auto joinOp = rewriter.createOrFold( + op.getLoc(), joinChildren, llvm::ArrayRef{}); + + rewriter.replaceOp(op, joinOp); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -830,13 +865,13 @@ void GraphAlgToRel::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); patterns.add< OpConversion, OpConversion, - OpConversion, OpConversion>( - matrixTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion>(matrixTypeConverter, + &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); // Scalar patterns. - // patterns.add(convertArithConstant); patterns.add< OpConversion, OpConversion, OpConversion, OpConversion>( diff --git a/compiler/test/graphalg-to-rel/cast.mlir b/compiler/test/graphalg-to-rel/cast.mlir index 6f52472..af4a639 100644 --- a/compiler/test/graphalg-to-rel/cast.mlir +++ b/compiler/test/graphalg-to-rel/cast.mlir @@ -2,169 +2,172 @@ // CHECK-LABEL: @CastBoolInt func.func @CastBoolInt(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i64> { ^bb0(%arg1 : i1): - // CHECK: %[[#SLOT:]] = ipr.slot - // CHECK: %[[#MUL_IDENT:]] = ipr.constant_slot <<"S64"> (1)> - // CHECK: %[[#ADD_IDENT:]] = ipr.constant_slot <<"S64"> (0)> - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#SLOT]] - // CHECK-SAME: [<<"BOOLEAN"> (false)>] %[[#ADD_IDENT]] - // CHECK-SAME: default = %[[#MUL_IDENT]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#C1:.+]] = arith.constant 1 : i64 + // CHECK: %[[#C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[#C1]], %[[#C0]] %1 = graphalg.cast_scalar %arg1 : i1 -> i64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @CastIntReal func.func @CastIntReal(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x f64> { ^bb0(%arg1 : i64): - // CHECK: %[[#INPUT:]] = ipr.slot - // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : si64 -> <"F64"> + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.sitofp %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : i64 -> f64 - // CHECK: ipr.project.return %[[#CAST]] + // CHECK: garel.project.return %[[#CAST]] graphalg.apply.return %1 : f64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x f64> } // CHECK-LABEL: @CastRealInt func.func @CastRealInt(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i64> { ^bb0(%arg1 : f64): - // CHECK: %[[#INPUT:]] = ipr.slot - // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : f64 -> <"S64"> + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.fptosi %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : f64 -> i64 - // CHECK: ipr.project.return %[[#CAST]] + // CHECK: garel.project.return %[[#CAST]] graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i64> } // CHECK-LABEL: @CastBoolTrop func.func @CastBoolTrop(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x !graphalg.trop_i64> { ^bb0(%arg1 : i1): - // CHECK: %[[#SLOT:]] = ipr.slot - - // CHECK: %[[#MUL_IDENT:]] = ipr.constant_slot <<"S64"> (0)> - // CHECK: %[[#ADD_IDENT:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#SLOT]] - // CHECK-SAME: [<<"BOOLEAN"> (false)>] %[[#ADD_IDENT]] - // CHECK-SAME: default = %[[#MUL_IDENT]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[#C0]], %[[#CMAX]] %1 = graphalg.cast_scalar %arg1 : i1 -> !graphalg.trop_i64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : !graphalg.trop_i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> } // CHECK-LABEL: @CastTropBool func.func @CastTropBool(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { ^bb0(%arg1 : !graphalg.trop_i64): - // CHECK: %[[#INPUT:]] = ipr.slot - // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> - // CHECK: %[[#CMP:]] = ipr.cmp %[[#INPUT]] : si64 NE %[[#ZERO]] : si64 + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#EXTRACT]], %[[#CMAX]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> i1 - // CHECK: ipr.project.return %[[#CMP]] + // CHECK: garel.project.return %[[#CMP]] graphalg.apply.return %1 : i1 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i1> } // CHECK-LABEL: @CastTropIntTropReal func.func @CastTropIntTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_f64> { ^bb0(%arg1 : !graphalg.trop_i64): - // CHECK: %[[#INPUT:]] = ipr.slot - // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : si64 -> <"F64"> - // - // CHECK: %[[#INF:]] = ipr.constant_slot <<"F64"> (INF)> - // - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] - // CHECK-SAME: [<<"S64"> (9223372036854775807)>] - // CHECK-SAME: %[[#INF]] - // CHECK-SAME: default = %[[#CAST]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.sitofp %[[#EXTRACT]] + // CHECK: %[[#MAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#INF:.+]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#MAX]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#INF]], %[[#CAST]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> !graphalg.trop_f64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : !graphalg.trop_f64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> } // CHECK-LABEL: @CastTropRealTropInt func.func @CastTropRealTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_i64> { ^bb0(%arg1 : !graphalg.trop_f64): - // CHECK: %[[#INPUT:]] = ipr.slot - // CHECK: %[[#CAST:]] = ipr.cast %[[#INPUT]] : f64 -> <"S64"> - // - // CHECK: %[[#INF:]] = ipr.constant_slot <<"S64"> (9223372036854775807)> - // - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] - // CHECK-SAME: [<<"F64"> (INF)>] - // CHECK-SAME: %[[#INF]] - // CHECK-SAME: default = %[[#CAST]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#CAST:]] = arith.fptosi %[[#EXTRACT]] + // CHECK: %[[#INF:]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[#MAX:]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[#EXTRACT]], %[[#INF]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#MAX]], %[[#CAST]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_f64 -> !graphalg.trop_i64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : !graphalg.trop_i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> } // CHECK-LABEL: @CastIntToTropMaxInt func.func @CastIntToTropMaxInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x !graphalg.trop_max_i64> { ^bb0(%arg1 : i64): - // CHECK: %[[#INPUT:]] = ipr.slot - - // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (-9223372036854775808)> - // - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] - // CHECK-SAME: [<<"S64"> (0)>] %[[#ZERO]] - // CHECK-SAME: default = %[[#INPUT]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#C0:]] = arith.constant 0 : i64 + // CHECK: %[[#MIN:]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#C0]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#MIN]], %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : i64 -> !graphalg.trop_max_i64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : !graphalg.trop_max_i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> } // CHECK-LABEL: @CastTropMaxIntToInt func.func @CastTropMaxIntToInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#PROJECT:]] = garel.project %arg0 : -> %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i64> { ^bb0(%arg1 : !graphalg.trop_max_i64): - // CHECK: %[[#INPUT:]] = ipr.slot - - // CHECK: %[[#ZERO:]] = ipr.constant_slot <<"S64"> (0)> - // - // CHECK: %[[#SELECT:]] = ipr.select_slot %[[#INPUT]] - // CHECK-SAME: [<<"S64"> (-9223372036854775808)>] %[[#ZERO]] - // CHECK-SAME: default = %[[#INPUT]] + // CHECK: %[[#EXTRACT:]] = garel.extract 0 + // CHECK: %[[#MIN:]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[#C0:]] = arith.constant 0 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#MIN]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#C0]], %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_max_i64 -> i64 - // CHECK: ipr.project.return %[[#SELECT]] + // CHECK: garel.project.return %[[#SELECT]] graphalg.apply.return %1 : i64 } + // CHECK: return %[[#PROJECT]] return %0 : !graphalg.mat<1 x 1 x i64> } diff --git a/compiler/test/graphalg-to-rel/const-mat.mlir b/compiler/test/graphalg-to-rel/const-mat.mlir new file mode 100644 index 0000000..fe8fa40 --- /dev/null +++ b/compiler/test/graphalg-to-rel/const-mat.mlir @@ -0,0 +1,113 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Shapes + +// CHECK-LABEL: @ConstMat +func.func @ConstMat() -> !graphalg.mat<42 x 43 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#COLS:]] = garel.range 43 + // CHECK: %[[#JOIN:]] = garel.join %[[#ROWS]], %[[#COLS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <42 x 43 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ConstRowVec +func.func @ConstRowVec() -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#COLS:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#COLS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <1 x 42 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @ConstColVec +func.func @ConstColVec() -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + // CHECK: %[[#ROWS:]] = garel.range 42 + // CHECK: %[[#JOIN:]] = garel.join %[[#ROWS]], %[[#CONST]] + %0 = graphalg.const_mat 1 : i64 -> <42 x 1 x i64> + + // CHECK: return %[[#JOIN]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ConstScalar +func.func @ConstScalar() -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#CONST:]] = garel.const 1 : i64 + %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> + + // CHECK: return %[[#CONST]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @ConstBool +func.func @ConstBool() -> !graphalg.mat<1 x 1 x i1> { + // CHECK: garel.const true + %0 = graphalg.const_mat true -> <1 x 1 x i1> + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstInt +func.func @ConstInt() -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat 1 : i64 -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ConstReal +func.func @ConstReal() -> !graphalg.mat<1 x 1 x f64> { + // CHECK: garel.const 1.000000e+00 : f64 + %0 = graphalg.const_mat 1.000000e+00 : f64 -> <1 x 1 x f64> + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: ConstTropInt +func.func @ConstTropInt() -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat #graphalg.trop_int<1 : i64> : !graphalg.trop_i64 -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: ConstTropReal +func.func @ConstTropReal() -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: garel.const 1.000000e+00 : f64 + %0 = graphalg.const_mat #graphalg.trop_float<1.000000e+00 : f64> : !graphalg.trop_f64 -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: ConstTropMaxInt +func.func @ConstTropMaxInt() -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: garel.const 1 : i64 + %0 = graphalg.const_mat #graphalg.trop_int<1 : i64> : !graphalg.trop_max_i64 -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} + +// === Infinity values + +// CHECK-LABEL: ConstTropIntInf +func.func @ConstTropIntInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: garel.const 9223372036854775807 : i64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_i64 -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: ConstTropRealInf +func.func @ConstTropRealInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: garel.const 0x7FF0000000000000 : f64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_f64 -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: ConstTropMaxIntInf +func.func @ConstTropMaxIntInf() -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: garel.const -9223372036854775808 : i64 + %0 = graphalg.const_mat #graphalg.trop_inf : !graphalg.trop_max_i64 -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} From 60616f839965d791bd59e371fb4b1e0a7bbff837 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 15:52:24 +0000 Subject: [PATCH 19/32] constants. --- compiler/src/garel/GraphAlgToRel.cpp | 29 +++++- compiler/test/graphalg-to-rel/cast.mlir | 48 ++++----- compiler/test/graphalg-to-rel/const.mlir | 125 +++++++++++++++++++++++ 3 files changed, 174 insertions(+), 28 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/const.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 3ae148c..d2f5ef2 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -826,6 +826,26 @@ mlir::LogicalResult OpConversion::matchAndRewrite( << " dialect"; } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::EqOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + if (lhs.getType().isF64()) { + assert(rhs.getType().isF64()); + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpFPredicate::OEQ, lhs, rhs); + } else { + assert(lhs.getType().isSignlessInteger()); + assert(rhs.getType().isSignlessInteger()); + rewriter.replaceOpWithNewOp( + op, mlir::arith::CmpIPredicate::eq, lhs, rhs); + } + + return mlir::success(); +} + static bool hasRelationSignature(mlir::func::FuncOp op) { // All inputs should be relations auto funcType = op.getFunctionType(); @@ -872,10 +892,11 @@ void GraphAlgToRel::runOnOperation() { &getContext()); // Scalar patterns. - patterns.add< - OpConversion, OpConversion, - OpConversion, OpConversion>( - semiringTypeConverter, &getContext()); + patterns + .add, + OpConversion, OpConversion, + OpConversion, OpConversion>( + semiringTypeConverter, &getContext()); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { diff --git a/compiler/test/graphalg-to-rel/cast.mlir b/compiler/test/graphalg-to-rel/cast.mlir index af4a639..e37b00b 100644 --- a/compiler/test/graphalg-to-rel/cast.mlir +++ b/compiler/test/graphalg-to-rel/cast.mlir @@ -6,9 +6,9 @@ func.func @CastBoolInt(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i64> { ^bb0(%arg1 : i1): // CHECK: %[[#EXTRACT:]] = garel.extract 0 - // CHECK: %[[#C1:.+]] = arith.constant 1 : i64 - // CHECK: %[[#C0:.+]] = arith.constant 0 : i64 - // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[#C1]], %[[#C0]] + // CHECK: %[[C1:.+]] = arith.constant 1 : i64 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[C1]], %[[C0]] %1 = graphalg.cast_scalar %arg1 : i1 -> i64 // CHECK: garel.project.return %[[#SELECT]] @@ -59,9 +59,9 @@ func.func @CastBoolTrop(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x !graphalg.trop_i64> { ^bb0(%arg1 : i1): // CHECK: %[[#EXTRACT:]] = garel.extract 0 - // CHECK: %[[#C0:.+]] = arith.constant 0 : i64 - // CHECK: %[[#CMAX:.+]] = arith.constant 9223372036854775807 : i64 - // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[#C0]], %[[#CMAX]] + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#SELECT:]] = arith.select %[[#EXTRACT]], %[[C0]], %[[CMAX]] %1 = graphalg.cast_scalar %arg1 : i1 -> !graphalg.trop_i64 // CHECK: garel.project.return %[[#SELECT]] @@ -78,8 +78,8 @@ func.func @CastTropBool(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !gr %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { ^bb0(%arg1 : !graphalg.trop_i64): // CHECK: %[[#EXTRACT:]] = garel.extract 0 - // CHECK: %[[#CMAX:.+]] = arith.constant 9223372036854775807 : i64 - // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#EXTRACT]], %[[#CMAX]] + // CHECK: %[[CMAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#EXTRACT]], %[[CMAX]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> i1 // CHECK: garel.project.return %[[#CMP]] @@ -97,10 +97,10 @@ func.func @CastTropIntTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) ^bb0(%arg1 : !graphalg.trop_i64): // CHECK: %[[#EXTRACT:]] = garel.extract 0 // CHECK: %[[#CAST:]] = arith.sitofp %[[#EXTRACT]] - // CHECK: %[[#MAX:.+]] = arith.constant 9223372036854775807 : i64 - // CHECK: %[[#INF:.+]] = arith.constant 0x7FF0000000000000 : f64 - // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#MAX]] - // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#INF]], %[[#CAST]] + // CHECK: %[[MAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[INF:.+]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[MAX]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[INF]], %[[#CAST]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_i64 -> !graphalg.trop_f64 // CHECK: garel.project.return %[[#SELECT]] @@ -118,10 +118,10 @@ func.func @CastTropRealTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) ^bb0(%arg1 : !graphalg.trop_f64): // CHECK: %[[#EXTRACT:]] = garel.extract 0 // CHECK: %[[#CAST:]] = arith.fptosi %[[#EXTRACT]] - // CHECK: %[[#INF:]] = arith.constant 0x7FF0000000000000 : f64 - // CHECK: %[[#MAX:]] = arith.constant 9223372036854775807 : i64 - // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[#EXTRACT]], %[[#INF]] - // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#MAX]], %[[#CAST]] + // CHECK: %[[INF:.+]] = arith.constant 0x7FF0000000000000 : f64 + // CHECK: %[[MAX:.+]] = arith.constant 9223372036854775807 : i64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[#EXTRACT]], %[[INF]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[MAX]], %[[#CAST]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_f64 -> !graphalg.trop_i64 // CHECK: garel.project.return %[[#SELECT]] @@ -138,10 +138,10 @@ func.func @CastIntToTropMaxInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.m %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x !graphalg.trop_max_i64> { ^bb0(%arg1 : i64): // CHECK: %[[#EXTRACT:]] = garel.extract 0 - // CHECK: %[[#C0:]] = arith.constant 0 : i64 - // CHECK: %[[#MIN:]] = arith.constant -9223372036854775808 : i64 - // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#C0]] - // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#MIN]], %[[#EXTRACT]] + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[MIN:.+]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[C0]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[MIN]], %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : i64 -> !graphalg.trop_max_i64 // CHECK: garel.project.return %[[#SELECT]] @@ -158,10 +158,10 @@ func.func @CastTropMaxIntToInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i64> { ^bb0(%arg1 : !graphalg.trop_max_i64): // CHECK: %[[#EXTRACT:]] = garel.extract 0 - // CHECK: %[[#MIN:]] = arith.constant -9223372036854775808 : i64 - // CHECK: %[[#C0:]] = arith.constant 0 : i64 - // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[#MIN]] - // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[#C0]], %[[#EXTRACT]] + // CHECK: %[[MIN:.+]] = arith.constant -9223372036854775808 : i64 + // CHECK: %[[C0:.+]] = arith.constant 0 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[#EXTRACT]], %[[MIN]] + // CHECK: %[[#SELECT:]] = arith.select %[[#CMP]], %[[C0]], %[[#EXTRACT]] %1 = graphalg.cast_scalar %arg1 : !graphalg.trop_max_i64 -> i64 // CHECK: garel.project.return %[[#SELECT]] diff --git a/compiler/test/graphalg-to-rel/const.mlir b/compiler/test/graphalg-to-rel/const.mlir new file mode 100644 index 0000000..0947fce --- /dev/null +++ b/compiler/test/graphalg-to-rel/const.mlir @@ -0,0 +1,125 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// Note: The inputs for these tests have to be fairly complex because: +// - Scalar ops must be inside of an ApplyOp +// - An ApplyOp whose body is reducible to a constant value folds into a +// ConstantMatrixOp. + +// CHECK-LABEL: @ConstBool +func.func @ConstBool(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg1 : i1): + // CHECK: arith.constant false + %1 = graphalg.const false + %2 = graphalg.eq %arg1, %1 : i1 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstInt +func.func @ConstInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : i64): + // CHECK: arith.constant 42 : i64 + %1 = graphalg.const 42 : i64 + %2 = graphalg.eq %arg1, %1 : i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ConstReal +func.func @ConstReal(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : f64): + // CHECK: arith.constant 4.200000e+01 : f64 + %1 = graphalg.const 42.0 : f64 + %2 = graphalg.eq %arg1, %1 : f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropInt +func.func @ConstTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: arith.constant 42 : i64 + %1 = graphalg.const #graphalg.trop_int<42 : i64> : !graphalg.trop_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropReal +func.func @ConstTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: arith.constant 4.200000e+01 : f64 + %1 = graphalg.const #graphalg.trop_float<42.0 : f64> : !graphalg.trop_f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropMaxInt +func.func @ConstTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: arith.constant 42 : i64 + %1 = graphalg.const #graphalg.trop_int<42 : i64> : !graphalg.trop_max_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_max_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// === Infinity values + +// CHECK-LABEL: ConstTropIntInf +func.func @ConstTropIntInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: arith.constant 9223372036854775807 : i64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropRealInf +func.func @ConstTropRealInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: arith.constant 0x7FF0000000000000 : f64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: ConstTropMaxIntInf +func.func @ConstTropMaxIntInf(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_max_i64): + // CHECK: arith.constant -9223372036854775808 : i64 + %1 = graphalg.const #graphalg.trop_inf : !graphalg.trop_max_i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_max_i64 + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} From e98f561f4147a5a3826cde558d79c4689b47f7a0 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 16:26:33 +0000 Subject: [PATCH 20/32] DeferredReduceOp. --- compiler/include/garel/GARelOps.td | 2 + compiler/src/garel/GARelOps.cpp | 59 ++++++++- compiler/src/garel/GraphAlgToRel.cpp | 58 ++++++++- .../test/graphalg-to-rel/deferred-reduce.mlir | 121 ++++++++++++++++++ 4 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/deferred-reduce.mlir diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index e40e673..262dc71 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -118,6 +118,8 @@ def UnionOp : GARel_Op<"union", [SameOperandsAndResultType]> { $inputs `:` type($inputs) attr-dict }]; + + let hasFolder = 1; } def AggregateOp : GARel_Op<"aggregate", [InferTypeOpAdaptor]> { diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index 9216814..cf38e6b 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -115,6 +115,7 @@ SelectReturnOp SelectOp::getTerminator() { // === JoinOp === mlir::OpFoldResult JoinOp::fold(FoldAdaptor adaptor) { if (getInputs().size() == 1) { + // no-op if we only have one input. assert(getInputs()[0].getType() == getType()); return getInputs()[0]; } @@ -123,9 +124,50 @@ mlir::OpFoldResult JoinOp::fold(FoldAdaptor adaptor) { } mlir::LogicalResult JoinOp::verify() { - // TODO: Inputs must use distinct columns. - // TODO: Predicates must refer to columns in distinct inputs (and to columns - // present in the input). + for (auto pred : getPredicates()) { + // Valid input relation + if (pred.getLhsRelIdx() >= getInputs().size()) { + return emitOpError("predicate refers to input relation ") + << pred.getLhsRelIdx() << ", but there are only " + << getInputs().size() << " input relations: " << pred; + } else if (pred.getRhsRelIdx() >= getInputs().size()) { + return emitOpError("predicate refers to input relation ") + << pred.getRhsRelIdx() << ", but there are only " + << getInputs().size() << " input relations: " << pred; + } + + if (pred.getLhsRelIdx() == pred.getRhsRelIdx()) { + return emitOpError("predicate between columns of the same relation: ") + << pred; + } + + // Valid column on LHS relation. + auto lhsInputType = + llvm::cast(getInputs()[pred.getLhsRelIdx()].getType()); + if (pred.getLhsColIdx() >= lhsInputType.getColumns().size()) { + auto diag = emitOpError("predicate refers to column ") + << pred.getLhsColIdx() << ", but there are only " + << lhsInputType.getColumns().size() + << " input columns: " << pred; + diag.attachNote(getInputs()[pred.getLhsRelIdx()].getLoc()) + << "input relation defined here"; + return diag; + } + + // Valid column on RHS relation. + auto rhsInputType = + llvm::cast(getInputs()[pred.getRhsRelIdx()].getType()); + if (pred.getRhsColIdx() >= rhsInputType.getColumns().size()) { + auto diag = emitOpError("predicate refers to column ") + << pred.getRhsColIdx() << ", but there are only " + << rhsInputType.getColumns().size() + << " input columns: " << pred; + diag.attachNote(getInputs()[pred.getRhsRelIdx()].getLoc()) + << "input relation defined here"; + return diag; + } + } + return mlir::success(); } @@ -142,6 +184,17 @@ mlir::LogicalResult JoinOp::inferReturnTypes( return mlir::success(); } +// === UnionOp === +mlir::OpFoldResult UnionOp::fold(FoldAdaptor adaptor) { + if (getInputs().size() == 1) { + // no-op if we only have one input. + assert(getInputs()[0].getType() == getType()); + return getInputs()[0]; + } + + return nullptr; +} + // === AggregateOp === mlir::LogicalResult AggregateOp::inferReturnTypes( mlir::MLIRContext *ctx, std::optional location, diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index d2f5ef2..61c0ba5 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -363,6 +363,28 @@ static mlir::Value preserveAdditiveIdentity(graphalg::CastScalarOp op, outIdentOp, defaultOutput); } +static mlir::FailureOr +createAggregator(mlir::Operation *op, graphalg::SemiringTypeInterface sring, + ColumnIdx input, mlir::OpBuilder &builder) { + auto *ctx = builder.getContext(); + AggregateFunc func; + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + func = AggregateFunc::LOR; + } else if (sring.isIntOrFloat()) { + func = AggregateFunc::SUM; + } else if (llvm::isa(sring)) { + func = AggregateFunc::MIN; + } else if (llvm::isa(sring)) { + func = AggregateFunc::MAX; + } else { + return op->emitOpError("aggregation with semiring ") + << sring << " is not supported"; + } + + std::array inputs{input}; + return AggregatorAttr::get(ctx, func, inputs); +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -680,6 +702,39 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::DeferredReduceOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInputs()[0], adaptor.getInputs()[0]); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + // Group by keys + llvm::SmallVector groupBy; + if (output.hasRowColumn()) { + groupBy.push_back(input.rowColumn()); + } + + if (output.hasColColumn()) { + groupBy.push_back(input.colColumn()); + } + + // Aggregators + auto aggregator = + createAggregator(op, input.semiring(), input.valColumn(), rewriter); + if (mlir::failed(aggregator)) { + return mlir::failure(); + } + + std::array aggregators{*aggregator}; + + // union the inputs and then aggregate. + auto unionOp = + rewriter.createOrFold(op.getLoc(), adaptor.getInputs()); + rewriter.replaceOpWithNewOp(op, unionOp, groupBy, aggregators); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -886,7 +941,8 @@ void GraphAlgToRel::runOnOperation() { patterns.add< OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion>(matrixTypeConverter, + OpConversion, + OpConversion>(matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/deferred-reduce.mlir b/compiler/test/graphalg-to-rel/deferred-reduce.mlir new file mode 100644 index 0000000..e6c27c9 --- /dev/null +++ b/compiler/test/graphalg-to-rel/deferred-reduce.mlir @@ -0,0 +1,121 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === Input/output shapes + +// CHECK-LABEL: @ReduceMatScalar +func.func @ReduceMatScalar(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMatRowVec +func.func @ReduceMatRowVec(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[1] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 43 x i64> + return %0 : !graphalg.mat<1 x 43 x i64> +} + +// CHECK-LABEL: @ReduceMatColVec +func.func @ReduceMatColVec(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 1 x i64> + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMatMat +func.func @ReduceMatMat(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<42 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0, 1] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 43 x i64> + return %0 : !graphalg.mat<42 x 43 x i64> +} + +// CHECK-LABEL: @ReduceRowVecScalar +func.func @ReduceRowVecScalar(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceRowVecRowVec +func.func @ReduceRowVecRowVec(%arg0: !graphalg.mat<1 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 43 x i64> -> <1 x 43 x i64> + return %0 : !graphalg.mat<1 x 43 x i64> +} + +// CHECK-LABEL: @ReduceColVecScalar +func.func @ReduceColVecScalar(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceColVecColVec +func.func @ReduceColVecColVec(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[0] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @ReduceScalarScalar +func.func @ReduceScalarScalar(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: garel.aggregate %arg0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceMultiple +func.func @ReduceMultiple( + %arg0 : !graphalg.mat<1 x 43 x i64>, + %arg1 : !graphalg.mat<42 x 1 x i64>) + -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 : !garel.rel, !garel.rel + // CHECK: %[[#AGG:]] = garel.aggregate %0 : group_by=[] aggregators=[] + %0 = graphalg.deferred_reduce %arg0, %arg1 : !graphalg.mat<1 x 43 x i64>, !graphalg.mat<42 x 1 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @ReduceBool +func.func @ReduceBool(%arg0: !graphalg.mat<42 x 43 x i1>) -> !graphalg.mat<1 x 1 x i1> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i1> -> <1 x 1 x i1> + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @ReduceInt +func.func @ReduceInt(%arg0: !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ReduceReal +func.func @ReduceReal(%arg0: !graphalg.mat<42 x 43 x f64>) -> !graphalg.mat<1 x 1 x f64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x f64> -> <1 x 1 x f64> + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @ReduceTropInt +func.func @ReduceTropInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @ReduceTropReal +func.func @ReduceTropReal(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @ReduceTropMaxInt +func.func @ReduceTropMaxInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + // CHECK: aggregators=[] + %0 = graphalg.deferred_reduce %arg0 : !graphalg.mat<42 x 43 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} From 5a1c87e110b5dfea1bddccd1585d24dc2656550d Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 16:35:53 +0000 Subject: [PATCH 21/32] DiagOp. --- compiler/src/garel/GraphAlgToRel.cpp | 17 +++++++++++++++-- compiler/test/graphalg-to-rel/diag.mlir | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/diag.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 61c0ba5..58ab11e 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -735,6 +735,19 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::DiagOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + std::array mapping{0, 0, 1}; + rewriter.replaceOpWithNewOp(op, adaptor.getInput(), mapping); + + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -942,8 +955,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion>(matrixTypeConverter, - &getContext()); + OpConversion, OpConversion>( + matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/diag.mlir b/compiler/test/graphalg-to-rel/diag.mlir new file mode 100644 index 0000000..846dd8f --- /dev/null +++ b/compiler/test/graphalg-to-rel/diag.mlir @@ -0,0 +1,15 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @DiagCol +func.func @DiagCol(%arg0: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: garel.remap %arg0 : [0, 0, 1] + %0 = graphalg.diag %arg0 : !graphalg.mat<42 x 1 x i64> + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @DiagRow +func.func @DiagRow(%arg0: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: garel.remap %arg0 : [0, 0, 1] + %0 = graphalg.diag %arg0 : !graphalg.mat<1 x 42 x i64> + return %0 : !graphalg.mat<42 x 42 x i64> +} From d7993ce93586ea18022c26429b33b1ec519f6583 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Fri, 23 Jan 2026 16:47:43 +0000 Subject: [PATCH 22/32] EqOp. --- compiler/test/graphalg-to-rel/div.mlir | 16 +++++ compiler/test/graphalg-to-rel/eq.mlir | 81 ++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 compiler/test/graphalg-to-rel/div.mlir create mode 100644 compiler/test/graphalg-to-rel/eq.mlir diff --git a/compiler/test/graphalg-to-rel/div.mlir b/compiler/test/graphalg-to-rel/div.mlir new file mode 100644 index 0000000..66f8122 --- /dev/null +++ b/compiler/test/graphalg-to-rel/div.mlir @@ -0,0 +1,16 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @DivReal +func.func @DivReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#DIV:]] = arith.divf %2, %3 : f64 + %1 = arith.divf %arg2, %arg3 : f64 + // CHECK: garel.project.return %[[#DIV]] + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} diff --git a/compiler/test/graphalg-to-rel/eq.mlir b/compiler/test/graphalg-to-rel/eq.mlir new file mode 100644 index 0000000..7434046 --- /dev/null +++ b/compiler/test/graphalg-to-rel/eq.mlir @@ -0,0 +1,81 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @EqBool +func.func @EqBool(%arg0: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg1 : i1): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant false + %1 = graphalg.const false + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i1 + %2 = graphalg.eq %arg1, %1 : i1 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqInt +func.func @EqInt(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : i64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0 + %1 = graphalg.const 0 : i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i64 + %2 = graphalg.eq %arg1, %1 : i64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqReal +func.func @EqReal(%arg0: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : f64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0.000000e+00 + %1 = graphalg.const 0.000000e+00 : f64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[LHS]], %[[RHS]] : f64 + %2 = graphalg.eq %arg1, %1 : f64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqTropInt +func.func @EqTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_i64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0 + %1 = graphalg.const #graphalg.trop_int<0 : i64> : !graphalg.trop_i64 + // CHECK: %[[#CMP:]] = arith.cmpi eq, %[[LHS]], %[[RHS]] : i64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_i64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @EqTropReal +func.func @EqTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x i1> { + ^bb0(%arg1 : !graphalg.trop_f64): + // CHECK: %[[LHS:.+]] = garel.extract 0 + // CHECK: %[[RHS:.+]] = arith.constant 0.000000e+00 + %1 = graphalg.const #graphalg.trop_float<0.0 : f64> : !graphalg.trop_f64 + // CHECK: %[[#CMP:]] = arith.cmpf oeq, %[[LHS]], %[[RHS]] : f64 + %2 = graphalg.eq %arg1, %1 : !graphalg.trop_f64 + // CHECK: garel.project.return %[[#CMP]] + graphalg.apply.return %2 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} From 02a74e9a75cca6b4edf184185f3563cfdf80676e Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 16:35:15 +0000 Subject: [PATCH 23/32] Started with ForConstOp. --- compiler/include/garel/GARelOps.td | 8 +- compiler/include/garel/GARelTypes.td | 7 ++ compiler/src/garel/GraphAlgToRel.cpp | 90 +++++++++++++++++++- compiler/test/graphalg-to-rel/for-const.mlir | 51 +++++++++++ 4 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/for-const.mlir diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 262dc71..4d46290 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -148,7 +148,9 @@ def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { I64Attr:$iters, I64Attr:$resultIdx); - let regions = (region SizedRegion<1>:$body); + let regions = (region + SizedRegion<1>:$body, + MaxSizedRegion<1>:$until); let results = (outs Relation:$result); @@ -157,6 +159,7 @@ def ForOp : GARel_Op<"for", [InferTypeOpAdaptor]> { `iters` `` `=` `` $iters `result_idx` `` `=` `` $resultIdx $body + (`until` $until^)? attr-dict }]; @@ -172,7 +175,8 @@ def ForYieldOp : GARel_Op<"for.yield", [ let arguments = (ins Variadic:$inputs); let assemblyFormat = [{ - $inputs `:` type($inputs) attr-dict + $inputs `:` type($inputs) + attr-dict }]; // Note: verification performed by parent ForOp. diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index 5d848b5..fcb4709 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -22,6 +22,13 @@ def Relation : GARel_Type<"Relation", "rel"> { }]; } +def I1Relation + : ConfinedType + , BuildableType<[{ + $_builder.getType( + std::array{$_builder.getI1Type()}) + }]>; + def Tuple : GARel_Type<"Tuple", "tuple"> { let summary = "A single tuple"; diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 58ab11e..f0cee38 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,15 +1,19 @@ #include #include +#include #include #include #include +#include #include #include #include #include #include #include +#include +#include #include #include @@ -22,7 +26,6 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" -#include "mlir/IR/Builders.h" namespace garel { @@ -385,6 +388,15 @@ createAggregator(mlir::Operation *op, graphalg::SemiringTypeInterface sring, return AggregatorAttr::get(ctx, func, inputs); } +static mlir::IntegerAttr tryGetConstantInt(mlir::Value v) { + mlir::Attribute attr; + if (!mlir::matchPattern(v, mlir::m_Constant(&attr))) { + return nullptr; + } + + return llvm::cast(attr); +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -748,6 +760,79 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::ForConstOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto rangeBegin = tryGetConstantInt(op.getRangeBegin()); + auto rangeEnd = tryGetConstantInt(op.getRangeEnd()); + if (!rangeBegin || !rangeEnd) { + return op->emitOpError("iter range is not constant"); + } + + auto iters = rangeEnd.getInt() - rangeBegin.getInt(); + + llvm::SmallVector initArgs{adaptor.getRangeBegin()}; + initArgs.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); + + auto blockSignature = + typeConverter->convertBlockSignature(&op.getBody().front()); + if (!blockSignature) { + return op->emitOpError("Failed to convert iter args"); + } + + // The relational version of this op can only have a single output value. + // For loops with multiple results, duplicate. + llvm::SmallVector resultValues; + for (auto i : llvm::seq(op->getNumResults())) { + auto result = op->getResult(i); + if (result.use_empty()) { + // Not used. Take init arg as a dummy value. + resultValues.emplace_back(adaptor.getInitArgs()[i]); + continue; + } + + // We are adding the iteration count variable as a first argument, so offset + // the result index accordingly. + std::int64_t resultIdx = i + 1; + auto resultType = adaptor.getInitArgs()[i].getType(); + auto forOp = rewriter.create(op.getLoc(), resultType, initArgs, + iters, resultIdx); + // body block + rewriter.cloneRegionBefore(op.getBody(), forOp.getBody(), + forOp.getBody().begin()); + rewriter.applySignatureConversion(&forOp.getBody().front(), + *blockSignature); + + // until block + if (!op.getUntil().empty()) { + rewriter.cloneRegionBefore(op.getUntil(), forOp.getUntil(), + forOp.getUntil().begin()); + rewriter.applySignatureConversion(&forOp.getUntil().front(), + *blockSignature); + } + + resultValues.emplace_back(forOp); + } + + rewriter.replaceOp(op, resultValues); + return mlir::success(); +} + +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::YieldOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + // TODO: increment + auto iterVar = op->getBlock()->getArgument(0); + + llvm::SmallVector inputs{iterVar}; + inputs.append(adaptor.getInputs().begin(), adaptor.getInputs().end()); + + rewriter.replaceOpWithNewOp(op, inputs); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -955,7 +1040,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion>( + OpConversion, OpConversion, + OpConversion, OpConversion>( matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for-const.mlir new file mode 100644 index 0000000..dffb64e --- /dev/null +++ b/compiler/test/graphalg-to-rel/for-const.mlir @@ -0,0 +1,51 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @ForConst +func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#FOR:]] = ipr.for %arg0 {{.*}} range=[0 : 10) result_idx=0 { + // CHECK: ^bb0(%arg1: !ipr.tuplestream<[[#IT:]]:si64>, %arg2: !ipr.tuplestream<[[#V:]]:si64>): + // CHECK: ipr.for.yield %arg2 + %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): + graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> + } until { + } + + // CHECK: return %[[#FOR]] + return %2 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @ForResultUnused +func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + // CHECK: %[[#FOR:]] = ipr.for %arg0, %arg1 {{.*}} range=[0 : 10) result_idx=1 { + // CHECK: ipr.for.yield %arg3, %arg4 + %2:2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { + ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<1 x 1 x f64>): + graphalg.yield %arg3, %arg4 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> + } until { + } + + // CHECK: return %[[#FOR]] + return %2#1 : !graphalg.mat<1 x 1 x f64> +} + +func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x i1> { + %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> + %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + + %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): + graphalg.yield %arg2 : !graphalg.mat<42 x 42 x i1> + } until { + ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): + %3 = graphalg.reduce %arg2 : <42 x 42 x i1> -> <1 x 1 x i1> + graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> + } + return %2 : !graphalg.mat<42 x 42 x i1> +} From 2c917dadf323301eee7a08e73925a3c2e0fdf8e3 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 17:01:59 +0000 Subject: [PATCH 24/32] Finish ForConstOp. --- compiler/src/garel/GraphAlgToRel.cpp | 27 +++++++++++++++++--- compiler/test/graphalg-to-rel/for-const.mlir | 27 +++++++++++++++----- 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index f0cee38..2230f22 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -26,6 +26,7 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" +#include "mlir/IR/ValueRange.h" namespace garel { @@ -823,10 +824,30 @@ template <> mlir::LogicalResult OpConversion::matchAndRewrite( graphalg::YieldOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { - // TODO: increment - auto iterVar = op->getBlock()->getArgument(0); + llvm::SmallVector inputs; + + auto *block = op->getBlock(); + auto forOp = llvm::cast(block->getParentOp()); + if (block == &forOp.getBody().front()) { + // Main body + auto iterVar = op->getBlock()->getArgument(0); + // Increment the iteration counter using a garel.project op. + auto loc = forOp.getLoc(); + auto projOp = rewriter.create(loc, iterVar.getType(), iterVar); + auto &projBlock = projOp.createProjectionsBlock(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&projBlock); + auto iterOp = rewriter.create(loc, 0, projBlock.getArgument(0)); + auto oneOp = rewriter.create( + loc, 1, rewriter.getI64Type()); + auto addOp = rewriter.create(loc, iterOp, oneOp); + rewriter.create(loc, mlir::ValueRange{addOp}); + + inputs.push_back(projOp); + } else { + // No changes needed for 'until' block. + } - llvm::SmallVector inputs{iterVar}; inputs.append(adaptor.getInputs().begin(), adaptor.getInputs().end()); rewriter.replaceOpWithNewOp(op, inputs); diff --git a/compiler/test/graphalg-to-rel/for-const.mlir b/compiler/test/graphalg-to-rel/for-const.mlir index dffb64e..2d418d8 100644 --- a/compiler/test/graphalg-to-rel/for-const.mlir +++ b/compiler/test/graphalg-to-rel/for-const.mlir @@ -5,17 +5,22 @@ func.func @ForConst(%arg0: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#FOR:]] = ipr.for %arg0 {{.*}} range=[0 : 10) result_idx=0 { - // CHECK: ^bb0(%arg1: !ipr.tuplestream<[[#IT:]]:si64>, %arg2: !ipr.tuplestream<[[#V:]]:si64>): - // CHECK: ipr.for.yield %arg2 + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 : !garel.rel, !garel.rel iters=10 result_idx=1 { %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<1 x 1 x i64> -> !graphalg.mat<1 x 1 x i64> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<1 x 1 x i64>): + // CHECK: %[[#PROJ:]] = garel.project %arg1 + // CHECK: %[[#EXT:]] = garel.extract 0 + // CHECK: %[[#ADD:]] = arith.addi %[[#EXT]], %c1_i64 + // CHECK: garel.project.return %[[#ADD]] + // CHECK: garel.for.yield %[[#PROJ]], %arg2 graphalg.yield %arg2 : !graphalg.mat<1 x 1 x i64> } until { } // CHECK: return %[[#FOR]] return %2 : !graphalg.mat<1 x 1 x i64> + } // CHECK-LABEL: @ForResultUnused @@ -23,10 +28,12 @@ func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.m %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> - // CHECK: %[[#FOR:]] = ipr.for %arg0, %arg1 {{.*}} range=[0 : 10) result_idx=1 { - // CHECK: ipr.for.yield %arg3, %arg4 + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0, %arg1 %2:2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0, %arg1) : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> -> !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> body { ^bb0(%arg2: !graphalg.mat<1 x 1 x i64>, %arg3: !graphalg.mat<1 x 1 x i64>, %arg4: !graphalg.mat<1 x 1 x f64>): + // CHECK: %[[#PROJ:]] = garel.project %arg2 + // CHECK: garel.for.yield %[[#PROJ]], %arg3, %arg4 graphalg.yield %arg3, %arg4 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x f64> } until { } @@ -35,16 +42,24 @@ func.func @ForResultUnused(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.m return %2#1 : !graphalg.mat<1 x 1 x f64> } +// CHECK-LABEL: @Until func.func @Until(%arg0: !graphalg.mat<42 x 42 x i1>) -> !graphalg.mat<42 x 42 x i1> { %0 = graphalg.const_mat 0 : i64 -> <1 x 1 x i64> %1 = graphalg.const_mat 10 : i64 -> <1 x 1 x i64> + // CHECK: %[[#BEGIN:]] = garel.const 0 : i64 + // CHECK: %[[#FOR:]] = garel.for %[[#BEGIN]], %arg0 %2 = graphalg.for_const range(%0, %1) : <1 x 1 x i64> init(%arg0) : !graphalg.mat<42 x 42 x i1> -> !graphalg.mat<42 x 42 x i1> body { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): + // CHECK: %[[#PROJ:]] = garel.project %arg1 + // CHECK: garel.for.yield %[[#PROJ]], %arg2 graphalg.yield %arg2 : !graphalg.mat<42 x 42 x i1> + // CHECK: } until { } until { ^bb0(%arg1: !graphalg.mat<1 x 1 x i64>, %arg2: !graphalg.mat<42 x 42 x i1>): - %3 = graphalg.reduce %arg2 : <42 x 42 x i1> -> <1 x 1 x i1> + // CHECK: %[[#AGG:]] = garel.aggregate %arg2 + %3 = graphalg.deferred_reduce %arg2 : !graphalg.mat<42 x 42 x i1> -> <1 x 1 x i1> + // CHECK: garel.for.yield %[[#AGG]] graphalg.yield %3 : !graphalg.mat<1 x 1 x i1> } return %2 : !graphalg.mat<42 x 42 x i1> From 1a0d2d64261f52cb76a10e723384348daab1cb3c Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 20:49:30 +0000 Subject: [PATCH 25/32] MatMulJoinOp. --- compiler/src/garel/GraphAlgToRel.cpp | 95 +++++++++- .../test/graphalg-to-rel/mat-mul-join.mlir | 175 ++++++++++++++++++ 2 files changed, 268 insertions(+), 2 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/mat-mul-join.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 2230f22..fa87a65 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -398,6 +398,32 @@ static mlir::IntegerAttr tryGetConstantInt(mlir::Value v) { return llvm::cast(attr); } +static mlir::FailureOr +createMul(mlir::Operation *op, graphalg::SemiringTypeInterface sring, + mlir::Value lhs, mlir::Value rhs, mlir::OpBuilder &builder) { + auto *ctx = builder.getContext(); + if (sring == graphalg::SemiringTypes::forBool(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forInt(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forReal(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forTropInt(ctx) || + sring == graphalg::SemiringTypes::forTropMaxInt(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } else if (sring == graphalg::SemiringTypes::forTropReal(ctx)) { + return mlir::Value( + builder.create(op->getLoc(), lhs, rhs)); + } + + return op->emitOpError("multiplication with semiring ") + << sring << " is not supported"; +} + // ============================================================================= // =============================== Op Conversion =============================== // ============================================================================= @@ -854,6 +880,71 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::MatMulJoinOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor lhs(op.getLhs(), adaptor.getLhs()); + MatrixAdaptor rhs(op.getRhs(), adaptor.getRhs()); + MatrixAdaptor result(op, typeConverter->convertType(op.getType())); + + // Join matrices. + llvm::SmallVector predicates; + if (lhs.hasColColumn() && rhs.hasRowColumn()) { + predicates.push_back(rewriter.getAttr( + /*lhsRelIdx=*/0, lhs.colColumn(), /*rhsRelIdx=*/1, rhs.rowColumn())); + } + + auto joinOp = rewriter.create( + op->getLoc(), mlir::ValueRange{lhs.relation(), rhs.relation()}, + predicates); + + // Project the multiplied values. + auto projOp = + rewriter.create(op.getLoc(), result.relType(), joinOp); + { + auto &body = projOp.createProjectionsBlock(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&body); + + llvm::SmallVector projections; + + if (lhs.hasRowColumn()) { + projections.push_back(rewriter.create( + op.getLoc(), lhs.rowColumn(), body.getArgument(0))); + } + + if (rhs.hasColColumn()) { + // In the join output, rhs columns come after lhs columns. + auto colIdx = lhs.columns().size() + rhs.colColumn(); + projections.push_back( + rewriter.create(op.getLoc(), colIdx, body.getArgument(0))); + } + + // Get the value slots. + auto lhsVal = rewriter.create(op.getLoc(), lhs.valColumn(), + body.getArgument(0)); + auto rhsVal = rewriter.create( + op.getLoc(), lhs.columns().size() + rhs.valColumn(), + body.getArgument(0)); + + // Perform the multiplication + auto mulOp = createMul( + op, + llvm::cast(op.getType().getSemiring()), + lhsVal, rhsVal, rewriter); + if (mlir::failed(mulOp)) { + return mlir::failure(); + } + + projections.push_back(*mulOp); + rewriter.create(op.getLoc(), projections); + } + + rewriter.replaceOp(op, projOp); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1062,8 +1153,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion>( - matrixTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion>(matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/mat-mul-join.mlir b/compiler/test/graphalg-to-rel/mat-mul-join.mlir new file mode 100644 index 0000000..bae4363 --- /dev/null +++ b/compiler/test/graphalg-to-rel/mat-mul-join.mlir @@ -0,0 +1,175 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// === input/output shapes + +// (a,b) * (b,c) +// CHECK-LABEL: @MatMulABC +func.func @MatMulABC(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[1] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 4 + // CHECK: %[[#LHS:]] = garel.extract 2 + // CHECK: %[[#RHS:]] = garel.extract 5 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// (1,b) * (b,c) +// CHECK-LABEL: @MatMulBC +func.func @MatMulBC(%arg0: !graphalg.mat<1 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<1 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 3 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 4 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 43 x i64>, <43 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 44 x i64> +} + +// (a,1) * (1,c) +// CHECK-LABEL: @MatMulAC +func.func @MatMulAC(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<1 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 2 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 3 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 1 x i64>, <1 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// (a,b) * (b,1) +// CHECK-LABEL: @MatMulAB +func.func @MatMulAB(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[1] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#LHS:]] = garel.extract 2 + // CHECK: %[[#RHS:]] = garel.extract 4 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// (a,1) * (1,1) +// CHECK-LABEL: @MatMulA +func.func @MatMulA(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 2 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#ROW]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 1 x i64>, <1 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// (1, b) * (b, 1) +// CHECK-LABEL: @MatMulB +func.func @MatMulB(%arg0: !graphalg.mat<1 x 43 x i64>, %arg1: !graphalg.mat<43 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [<0[0] = 1[0]>] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#LHS:]] = garel.extract 1 + // CHECK: %[[#RHS:]] = garel.extract 3 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 43 x i64>, <43 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// (1,1) * (1,c) +// CHECK-LABEL: @MatMulC +func.func @MatMulC(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 44 x i64>) -> !graphalg.mat<1 x 44 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 2 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#COL]], %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 1 x i64>, <1 x 44 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 44 x i64> +} + +// (1,1) * (1,1) +// CHECK-LABEL: @MatMulScalar +func.func @MatMulScalar(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#JOIN:]] = garel.join %arg0, %arg1 : !garel.rel, !garel.rel [] + // CHECK: %[[#PROJECT:]] = garel.project %[[#JOIN]] + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#VAL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#VAL]] + %0 = graphalg.mxm_join %arg0, %arg1 : <1 x 1 x i64>, <1 x 1 x i64> + + // CHECK: return %[[#PROJECT]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// === Semirings + +// CHECK-LABEL: @MatMulBool +func.func @MatMulBool(%arg0: !graphalg.mat<42 x 43 x i1>, %arg1: !graphalg.mat<43 x 44 x i1>) -> !graphalg.mat<42 x 44 x i1> { + // CHECK: arith.andi + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i1>, <43 x 44 x i1> + return %0 : !graphalg.mat<42 x 44 x i1> +} + +// CHECK-LABEL: @MatMulInt +func.func @MatMulInt(%arg0: !graphalg.mat<42 x 43 x i64>, %arg1: !graphalg.mat<43 x 44 x i64>) -> !graphalg.mat<42 x 44 x i64> { + // CHECK: arith.muli + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x i64>, <43 x 44 x i64> + return %0 : !graphalg.mat<42 x 44 x i64> +} + +// CHECK-LABEL: @MatMulReal +func.func @MatMulReal(%arg0: !graphalg.mat<42 x 43 x f64>, %arg1: !graphalg.mat<43 x 44 x f64>) -> !graphalg.mat<42 x 44 x f64> { + // CHECK: arith.mulf + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x f64>, <43 x 44 x f64> + return %0 : !graphalg.mat<42 x 44 x f64> +} + +// CHECK-LABEL: @MatMulTropInt +func.func @MatMulTropInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_i64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_i64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_i64> { + // CHECK: arith.addi + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x !graphalg.trop_i64>, <43 x 44 x !graphalg.trop_i64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @MatMulTropReal +func.func @MatMulTropReal(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_f64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_f64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_f64> { + // CHECK: arith.addf + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x !graphalg.trop_f64>, <43 x 44 x !graphalg.trop_f64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @MatMulTropMaxInt +func.func @MatMulTropMaxInt(%arg0: !graphalg.mat<42 x 43 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<43 x 44 x !graphalg.trop_max_i64>) -> !graphalg.mat<42 x 44 x !graphalg.trop_max_i64> { + // CHECK: arith.addi + %0 = graphalg.mxm_join %arg0, %arg1 : <42 x 43 x !graphalg.trop_max_i64>, <43 x 44 x !graphalg.trop_max_i64> + return %0 : !graphalg.mat<42 x 44 x !graphalg.trop_max_i64> +} From faefc67d5567dc886e898845d5131aeda963484f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:00:57 +0000 Subject: [PATCH 26/32] MulOp. --- compiler/src/garel/GraphAlgToRel.cpp | 21 +++++- compiler/test/graphalg-to-rel/mul.mlir | 91 ++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 2 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/mul.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index fa87a65..65f9c55 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -22,6 +22,7 @@ #include "garel/GARelOps.h" #include "garel/GARelTypes.h" #include "graphalg/GraphAlgAttr.h" +#include "graphalg/GraphAlgCast.h" #include "graphalg/GraphAlgDialect.h" #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" @@ -1111,6 +1112,22 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::MulOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto sring = op.getType(); + auto *ctx = rewriter.getContext(); + auto mulOp = createMul(op, llvm::cast(sring), + adaptor.getLhs(), adaptor.getRhs(), rewriter); + if (mlir::failed(mulOp)) { + return mlir::failure(); + } + + rewriter.replaceOp(op, *mulOp); + return mlir::success(); +} + static bool hasRelationSignature(mlir::func::FuncOp op) { // All inputs should be relations auto funcType = op.getFunctionType(); @@ -1162,8 +1179,8 @@ void GraphAlgToRel::runOnOperation() { patterns .add, OpConversion, OpConversion, - OpConversion, OpConversion>( - semiringTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion>(semiringTypeConverter, &getContext()); if (mlir::failed(mlir::applyFullConversion(getOperation(), target, std::move(patterns)))) { diff --git a/compiler/test/graphalg-to-rel/mul.mlir b/compiler/test/graphalg-to-rel/mul.mlir new file mode 100644 index 0000000..8d35cf9 --- /dev/null +++ b/compiler/test/graphalg-to-rel/mul.mlir @@ -0,0 +1,91 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @MulBool +func.func @MulBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { + ^bb0(%arg2 : i1, %arg3: i1): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.andi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : i1 + graphalg.apply.return %1 : i1 + } + + return %0 : !graphalg.mat<1 x 1 x i1> +} + +// CHECK-LABEL: @MulInt +func.func @MulInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.muli %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : i64 + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @MulReal +func.func @MulReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.mulf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : f64 + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} + +// CHECK-LABEL: @MulTropInt +func.func @MulTropInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_i64> -> <1 x 1 x !graphalg.trop_i64> { + ^bb0(%arg2 : !graphalg.trop_i64, %arg3: !graphalg.trop_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_i64 + graphalg.apply.return %1 : !graphalg.trop_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_i64> +} + +// CHECK-LABEL: @MulTropReal +func.func @MulTropReal(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_f64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_f64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_f64>, !graphalg.mat<1 x 1 x !graphalg.trop_f64> -> <1 x 1 x !graphalg.trop_f64> { + ^bb0(%arg2 : !graphalg.trop_f64, %arg3: !graphalg.trop_f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addf %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_f64 + graphalg.apply.return %1 : !graphalg.trop_f64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_f64> +} + +// CHECK-LABEL: @MulTropMaxInt +func.func @MulTropMaxInt(%arg0: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, %arg1: !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>) -> !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64>, !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> -> <1 x 1 x !graphalg.trop_max_i64> { + ^bb0(%arg2 : !graphalg.trop_max_i64, %arg3: !graphalg.trop_max_i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#MUL:]] = arith.addi %[[#LHS]], %[[#RHS]] + // CHECK: garel.project.return %[[#MUL]] + %1 = graphalg.mul %arg2, %arg3 : !graphalg.trop_max_i64 + graphalg.apply.return %1 : !graphalg.trop_max_i64 + } + + return %0 : !graphalg.mat<1 x 1 x !graphalg.trop_max_i64> +} From 3057504f67f5b8d3d30cfb4b754c5b837bac1487 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:22:07 +0000 Subject: [PATCH 27/32] PickAnyOp. --- compiler/include/garel/GARelOps.td | 4 +- compiler/src/garel/GARelOps.cpp | 21 ++++---- compiler/src/garel/GraphAlgToRel.cpp | 60 ++++++++++++++++++++- compiler/test/graphalg-to-rel/pick-any.mlir | 27 ++++++++++ 4 files changed, 97 insertions(+), 15 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/pick-any.mlir diff --git a/compiler/include/garel/GARelOps.td b/compiler/include/garel/GARelOps.td index 4d46290..f8ccaa3 100644 --- a/compiler/include/garel/GARelOps.td +++ b/compiler/include/garel/GARelOps.td @@ -58,9 +58,6 @@ def SelectOp : GARel_Op<"select", [ let arguments = (ins Relation:$input); let regions = (region SizedRegion<1>:$predicates); - let builders = [OpBuilder<(ins "mlir::Value":$child)>]; - let skipDefaultBuilders = 1; - let results = (outs Relation:$result); let assemblyFormat = [{ @@ -70,6 +67,7 @@ def SelectOp : GARel_Op<"select", [ let hasRegionVerifier = 1; let extraClassDeclaration = [{ + mlir::Block& createPredicatesBlock(); SelectReturnOp getTerminator(); }]; } diff --git a/compiler/src/garel/GARelOps.cpp b/compiler/src/garel/GARelOps.cpp index cf38e6b..b70abcd 100644 --- a/compiler/src/garel/GARelOps.cpp +++ b/compiler/src/garel/GARelOps.cpp @@ -74,17 +74,6 @@ ProjectReturnOp ProjectOp::getTerminator() { } // === SelectOp === -void SelectOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value input) { - auto region = state.addRegion(); - auto &block = region->emplaceBlock(); - auto inputType = llvm::cast(input.getType()); - block.addArgument(builder.getType(inputType.getColumns()), - builder.getUnknownLoc()); - state.addTypes(input.getType()); - state.addOperands(input); -} - mlir::LogicalResult SelectOp::verifyRegions() { if (getPredicates().getNumArguments() != 1) { return emitOpError("predicates block should have exactly one argument"); @@ -108,6 +97,16 @@ mlir::LogicalResult SelectOp::verifyRegions() { return mlir::success(); } +mlir::Block &SelectOp::createPredicatesBlock() { + assert(getPredicates().empty() && "Already have a predicates block"); + auto &block = getPredicates().emplaceBlock(); + // Same columns as the input, but as a tuple. + block.addArgument( + TupleType::get(getContext(), getInput().getType().getColumns()), + getInput().getLoc()); + return block; +} + SelectReturnOp SelectOp::getTerminator() { return llvm::cast(getPredicates().front().getTerminator()); } diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 65f9c55..8db8679 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -946,6 +947,62 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::PickAnyOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + MatrixAdaptor output(op, typeConverter->convertType(op.getType())); + + auto *ctx = rewriter.getContext(); + + // Remove rows where value is the additive identity. + auto selectOp = rewriter.create(op.getLoc(), input.relation()); + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + auto &body = selectOp.createPredicatesBlock(); + rewriter.setInsertionPointToStart(&body); + + auto valOp = rewriter.create(op.getLoc(), input.valColumn(), + body.getArgument(0)); + auto addIdent = convertConstant(op, input.semiring().addIdentity()); + if (mlir::failed(addIdent)) { + return mlir::failure(); + } + + auto addIdentOp = + rewriter.create(op.getLoc(), *addIdent); + mlir::Value cmpOp; + if (addIdentOp.getType().isF64()) { + cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpFPredicate::ONE, valOp, addIdentOp); + } else { + cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpIPredicate::ne, valOp, addIdentOp); + } + + rewriter.create(op.getLoc(), mlir::ValueRange{cmpOp}); + } + + llvm::SmallVector groupBy; + if (input.hasRowColumn()) { + groupBy.push_back(input.rowColumn()); + } + + assert(input.hasColColumn()); + std::array aggregators{ + // Minimum column + rewriter.getAttr( + AggregateFunc::MIN, std::array{input.colColumn()}), + // Value for minimum column + rewriter.getAttr( + AggregateFunc::ARGMIN, + std::array{input.valColumn(), input.colColumn()}), + }; + rewriter.replaceOpWithNewOp(op, selectOp, groupBy, aggregators); + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1171,7 +1228,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion>(matrixTypeConverter, &getContext()); + OpConversion, OpConversion>( + matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/pick-any.mlir b/compiler/test/graphalg-to-rel/pick-any.mlir new file mode 100644 index 0000000..adefc89 --- /dev/null +++ b/compiler/test/graphalg-to-rel/pick-any.mlir @@ -0,0 +1,27 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +func.func @PickAnyMat(%arg0: !graphalg.mat<42 x 43 x i1>) -> !graphalg.mat<42 x 43 x i1> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#VAL:]] = garel.extract 2 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#VAL]], %false + // CHECK: garel.select.return %[[#CMP]] + // CHECK: %[[#AGG:]] = garel.aggregate %[[#SELECT]] : group_by=[0] aggregators=[, ] + %0 = graphalg.pick_any %arg0 : <42 x 43 x i1> + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<42 x 43 x i1> +} + +func.func @PickAnyRowVec(%arg0: !graphalg.mat<1 x 42 x i1>) -> !graphalg.mat<1 x 42 x i1> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#VAL:]] = garel.extract 1 + // CHECK: %[[#CMP:]] = arith.cmpi ne, %[[#VAL:]], %false + // CHECK: garel.select.return %[[#CMP]] + // CHECK: %[[#AGG:]] = garel.aggregate %[[#SELECT]] : group_by=[] aggregators=[, ] + %0 = graphalg.pick_any %arg0 : <1 x 42 x i1> + + // CHECK: return %[[#AGG]] + return %0 : !graphalg.mat<1 x 42 x i1> +} + +// Note: scalar and column vector cases are folded away From c1eecc35b48e41c2412571ad9ba4102427dffef7 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:24:43 +0000 Subject: [PATCH 28/32] SubOp. --- compiler/test/graphalg-to-rel/sub.mlir | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 compiler/test/graphalg-to-rel/sub.mlir diff --git a/compiler/test/graphalg-to-rel/sub.mlir b/compiler/test/graphalg-to-rel/sub.mlir new file mode 100644 index 0000000..ac6d389 --- /dev/null +++ b/compiler/test/graphalg-to-rel/sub.mlir @@ -0,0 +1,33 @@ +// RUN: ag-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @SubInt +func.func @SubInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> { + ^bb0(%arg2 : i64, %arg3: i64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#SUB:]] = arith.subi %[[#LHS]], %[[#RHS]] + %1 = arith.subi %arg2, %arg3 : i64 + + // CHECK: garel.project.return %[[#SUB]] + graphalg.apply.return %1 : i64 + } + + return %0 : !graphalg.mat<1 x 1 x i64> +} + +// CHECK-LABEL: @SubReal +func.func @SubReal(%arg0: !graphalg.mat<1 x 1 x f64>, %arg1: !graphalg.mat<1 x 1 x f64>) -> !graphalg.mat<1 x 1 x f64> { + %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x f64>, !graphalg.mat<1 x 1 x f64> -> <1 x 1 x f64> { + ^bb0(%arg2 : f64, %arg3: f64): + // CHECK: %[[#LHS:]] = garel.extract 0 + // CHECK: %[[#RHS:]] = garel.extract 1 + // CHECK: %[[#SUB:]] = arith.subf %[[#LHS]], %[[#RHS]] + %1 = arith.subf %arg2, %arg3 : f64 + + // CHECK: garel.project.return %[[#SUB]] + graphalg.apply.return %1 : f64 + } + + return %0 : !graphalg.mat<1 x 1 x f64> +} From 1559e17682135235e64512400fb7fc4e3caaa522 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:30:23 +0000 Subject: [PATCH 29/32] TrilOp. --- compiler/src/garel/GraphAlgToRel.cpp | 32 +++++++++++++++++++++++-- compiler/test/graphalg-to-rel/tril.mlir | 13 ++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 compiler/test/graphalg-to-rel/tril.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 8db8679..b7c529c 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1003,6 +1003,34 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::TrilOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + MatrixAdaptor input(op.getInput(), adaptor.getInput()); + + if (!input.hasRowColumn() || !input.hasColColumn()) { + return op->emitOpError( + "only works on full matrices (not scalars or vector)"); + } + + auto selectOp = rewriter.replaceOpWithNewOp(op, input.relation()); + + auto &body = selectOp.createPredicatesBlock(); + rewriter.setInsertionPointToStart(&body); + + auto row = rewriter.create(op.getLoc(), input.rowColumn(), + body.getArgument(0)); + auto col = rewriter.create(op.getLoc(), input.colColumn(), + body.getArgument(0)); + // col < row + auto cmpOp = rewriter.create( + op.getLoc(), mlir::arith::CmpIPredicate::ult, col, row); + rewriter.create(op.getLoc(), mlir::ValueRange{cmpOp}); + + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1228,8 +1256,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion, OpConversion>( - matrixTypeConverter, &getContext()); + OpConversion, OpConversion, + OpConversion>(matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/tril.mlir b/compiler/test/graphalg-to-rel/tril.mlir new file mode 100644 index 0000000..540b111 --- /dev/null +++ b/compiler/test/graphalg-to-rel/tril.mlir @@ -0,0 +1,13 @@ +// RUN: graphalg-opt --graphalg-to-ipr < %s | FileCheck %s + +func.func @Tril(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#SELECT:]] = garel.select %arg0 + // CHECK: %[[#ROW:]] = garel.extract 0 + // CHECK: %[[#COL:]] = garel.extract 1 + // CHECK: %[[#CMP:]] = arith.cmpi ult, %[[#COL]], %[[#ROW]] + // CHECK: garel.select.return %[[#CMP]] + %0 = graphalg.tril %arg0 : <42 x 42 x i64> + + // return %[[#SELECT]] + return %0 : !graphalg.mat<42 x 42 x i64> +} From 3422d5c35e098abb7ea143d6c2c2f3aa6422da82 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:38:28 +0000 Subject: [PATCH 30/32] UnionOp. --- compiler/src/garel/GraphAlgToRel.cpp | 37 +++++++++++++- compiler/test/graphalg-to-rel/union.mlir | 61 ++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 compiler/test/graphalg-to-rel/union.mlir diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index b7c529c..2c62f2d 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -1031,6 +1031,40 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::success(); } +template <> +mlir::LogicalResult OpConversion::matchAndRewrite( + graphalg::UnionOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + auto targetType = + llvm::cast(typeConverter->convertType(op.getType())); + MatrixAdaptor target(op, targetType); + + llvm::SmallVector inputs; + for (auto [matrix, rel] : + llvm::zip_equal(op.getInputs(), adaptor.getInputs())) { + MatrixAdaptor input(matrix, rel); + + // Drop columns we don't want in the output. + llvm::SmallVector remap; + if (target.hasRowColumn()) { + remap.push_back(input.rowColumn()); + } + + if (target.hasColColumn()) { + remap.push_back(input.colColumn()); + } + + remap.push_back(input.valColumn()); + inputs.push_back( + rewriter.createOrFold(op.getLoc(), input.relation(), remap)); + } + + auto newOp = rewriter.createOrFold(op.getLoc(), targetType, inputs); + rewriter.replaceOp(op, newOp); + + return mlir::success(); +} + // ============================================================================= // ============================ Tuple Op Conversion ============================ // ============================================================================= @@ -1257,7 +1291,8 @@ void GraphAlgToRel::runOnOperation() { OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, OpConversion, - OpConversion>(matrixTypeConverter, &getContext()); + OpConversion, OpConversion>( + matrixTypeConverter, &getContext()); patterns.add(semiringTypeConverter, matrixTypeConverter, &getContext()); diff --git a/compiler/test/graphalg-to-rel/union.mlir b/compiler/test/graphalg-to-rel/union.mlir new file mode 100644 index 0000000..7aa3154 --- /dev/null +++ b/compiler/test/graphalg-to-rel/union.mlir @@ -0,0 +1,61 @@ +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s + +// CHECK-LABEL: @UnionMat +func.func @UnionMat(%arg0: !graphalg.mat<42 x 42 x i64>, %arg1: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<42 x 42 x i64>, !graphalg.mat<42 x 42 x i64> -> <42 x 42 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<42 x 42 x i64> +} + +// CHECK-LABEL: @UnionRowVec +func.func @UnionRowVec(%arg0: !graphalg.mat<1 x 42 x i64>, %arg1: !graphalg.mat<1 x 42 x i64>) -> !graphalg.mat<1 x 42 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<1 x 42 x i64>, !graphalg.mat<1 x 42 x i64> -> <1 x 42 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<1 x 42 x i64> +} + +// CHECK-LABEL: @UnionColVec +func.func @UnionColVec(%arg0: !graphalg.mat<42 x 1 x i64>, %arg1: !graphalg.mat<42 x 1 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<42 x 1 x i64>, !graphalg.mat<42 x 1 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +// CHECK-LABEL: @UnionScalar +func.func @UnionScalar(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#UNION:]] = garel.union %arg0, %arg1 + %0 = graphalg.union %arg0, %arg1 : !graphalg.mat<1 x 1 x i64>, !graphalg.mat<1 x 1 x i64> -> <1 x 1 x i64> + + // CHECK: return %[[#UNION]] + return %0 : !graphalg.mat<1 x 1 x i64> +} + +func.func @UnionFlattenRow(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 43 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [1, 2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 43 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<1 x 43 x i64> +} + +func.func @UnionFlattenCol(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<42 x 1 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [0, 2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <42 x 1 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<42 x 1 x i64> +} + +func.func @UnionFlattenAll(%arg0 : !graphalg.mat<42 x 43 x i64>) -> !graphalg.mat<1 x 1 x i64> { + // CHECK: %[[#REMAP:]] = garel.remap %arg0 : [2] + %0 = graphalg.union %arg0 : !graphalg.mat<42 x 43 x i64> -> <1 x 1 x i64> + + // CHECK: return %[[#REMAP]] + return %0 : !graphalg.mat<1 x 1 x i64> +} From e40417cd37afe0dde84455b7441d54c68720912f Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Sat, 24 Jan 2026 21:41:42 +0000 Subject: [PATCH 31/32] Finished porting graphalg-to-rel. --- compiler/src/garel/GraphAlgToRel.cpp | 31 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 2c62f2d..002cb45 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -28,7 +29,6 @@ #include "graphalg/GraphAlgOps.h" #include "graphalg/GraphAlgTypes.h" #include "graphalg/SemiringTypes.h" -#include "mlir/IR/ValueRange.h" namespace garel { @@ -37,6 +37,21 @@ namespace garel { namespace { +/** + * Converts all GraphAlg IR ops into relation ops from the GARel dialect. + * + * Matrices, vectors and scalars are converted into relations: + * - Matrices => (row, column, value) tuples + * - Vectors => (row, value) tuples + * - Scalars => a single (value) tuple. For consistency, they are still + * relations. + * + * Top-level operations are converted into relational ops such as \c ProjectOp, + * \c JoinOp and \c AggregateOp. + * + * Scalar operations inside of \c ApplyOp are converted into ops from the arith + * dialect. + */ class GraphAlgToRel : public impl::GraphAlgToRelBase { public: using impl::GraphAlgToRelBase::GraphAlgToRelBase; @@ -489,7 +504,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( std::swap(columns[0], columns[1]); } - // Return the input slots (after row and column have been swapped) + // Return the input columns (after row and column have been swapped) llvm::SmallVector results; for (auto col : columns) { results.emplace_back( @@ -594,19 +609,19 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( auto &body = projectOp.createProjectionsBlock(); rewriter.setInsertionPointToStart(&body); - llvm::SmallVector slotReads; + llvm::SmallVector columnReads; for (auto col : valColumns) { - slotReads.emplace_back( + columnReads.emplace_back( rewriter.create(op->getLoc(), col, body.getArgument(0))); } // Inline into new body rewriter.inlineBlockBefore(&op.getBody().front(), &body, body.end(), - slotReads); + columnReads); rewriter.replaceOp(op, projectOp); - // Attach the row and column slot to the return op. + // Attach the row and column indexes to the return op. auto returnOp = llvm::cast(body.getTerminator()); if (!rowColumns.empty()) { returnOp->setAttr(APPLY_ROW_IDX_ATTR_KEY, @@ -641,7 +656,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( rewriter.create(op->getLoc(), idx, inputTuple)); } - // The value slot + // The value column results.emplace_back(adaptor.getValue()); rewriter.replaceOpWithNewOp(op, results); @@ -923,7 +938,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( rewriter.create(op.getLoc(), colIdx, body.getArgument(0))); } - // Get the value slots. + // Get the value columns. auto lhsVal = rewriter.create(op.getLoc(), lhs.valColumn(), body.getArgument(0)); auto rhsVal = rewriter.create( From 22b37e997e697ffc7f66605d7d19f9fa31565076 Mon Sep 17 00:00:00 2001 From: Daan de Graaf Date: Mon, 26 Jan 2026 09:24:40 +0000 Subject: [PATCH 32/32] Cleanup. --- compiler/include/garel/GARelTypes.td | 7 -- compiler/src/garel/GraphAlgToRel.cpp | 26 ++++--- compiler/test/graphalg-to-rel/GEMINI.md | 90 ------------------------- compiler/test/graphalg-to-rel/sub.mlir | 2 +- compiler/test/graphalg-to-rel/tril.mlir | 2 +- 5 files changed, 14 insertions(+), 113 deletions(-) delete mode 100644 compiler/test/graphalg-to-rel/GEMINI.md diff --git a/compiler/include/garel/GARelTypes.td b/compiler/include/garel/GARelTypes.td index fcb4709..5d848b5 100644 --- a/compiler/include/garel/GARelTypes.td +++ b/compiler/include/garel/GARelTypes.td @@ -22,13 +22,6 @@ def Relation : GARel_Type<"Relation", "rel"> { }]; } -def I1Relation - : ConfinedType - , BuildableType<[{ - $_builder.getType( - std::array{$_builder.getI1Type()}) - }]>; - def Tuple : GARel_Type<"Tuple", "tuple"> { let summary = "A single tuple"; diff --git a/compiler/src/garel/GraphAlgToRel.cpp b/compiler/src/garel/GraphAlgToRel.cpp index 002cb45..ee60821 100644 --- a/compiler/src/garel/GraphAlgToRel.cpp +++ b/compiler/src/garel/GraphAlgToRel.cpp @@ -456,7 +456,6 @@ mlir::LogicalResult OpConversion::matchAndRewrite( << op.getFunctionType() << " cannot be converted"; } - auto result = mlir::success(); rewriter.modifyOpInPlace(op, [&]() { // Update function type. op.setFunctionType(funcType); @@ -469,7 +468,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( return mlir::failure(); } - return result; + return mlir::success(); } template <> @@ -507,12 +506,11 @@ mlir::LogicalResult OpConversion::matchAndRewrite( // Return the input columns (after row and column have been swapped) llvm::SmallVector results; for (auto col : columns) { - results.emplace_back( + results.push_back( rewriter.create(op.getLoc(), col, body.getArgument(0))); } rewriter.create(op.getLoc(), results); - return mlir::success(); } @@ -536,7 +534,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( llvm::SmallVector valColumns; ColumnIdx nextColumnIdx = 0; for (const auto &[idx, input] : llvm::enumerate(inputs)) { - joinChildren.emplace_back(input.relation()); + joinChildren.push_back(input.relation()); if (input.hasRowColumn()) { rowColumns.push_back(InputColumnRef{ @@ -565,7 +563,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all rows. auto rowsOp = createDimRead(op.getLoc(), output.matrixType().getRows(), rewriter); - joinChildren.emplace_back(rowsOp); + joinChildren.push_back(rowsOp); rowColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, .colIdx = 0, @@ -578,7 +576,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( // Broadcast to all columns. auto colsOp = createDimRead(op.getLoc(), output.matrixType().getCols(), rewriter); - joinChildren.emplace_back(colsOp); + joinChildren.push_back(colsOp); colColumns.push_back(InputColumnRef{ .relIdx = joinChildren.size() - 1, .colIdx = 0, @@ -611,7 +609,7 @@ mlir::LogicalResult ApplyOpConversion::matchAndRewrite( llvm::SmallVector columnReads; for (auto col : valColumns) { - columnReads.emplace_back( + columnReads.push_back( rewriter.create(op->getLoc(), col, body.getArgument(0))); } @@ -647,17 +645,17 @@ mlir::LogicalResult OpConversion::matchAndRewrite( auto inputTuple = op->getBlock()->getArgument(0); if (auto idx = op->getAttrOfType(APPLY_ROW_IDX_ATTR_KEY)) { - results.emplace_back( + results.push_back( rewriter.create(op->getLoc(), idx, inputTuple)); } if (auto idx = op->getAttrOfType(APPLY_COL_IDX_ATTR_KEY)) { - results.emplace_back( + results.push_back( rewriter.create(op->getLoc(), idx, inputTuple)); } // The value column - results.emplace_back(adaptor.getValue()); + results.push_back(adaptor.getValue()); rewriter.replaceOpWithNewOp(op, results); return mlir::success(); @@ -705,7 +703,6 @@ mlir::LogicalResult OpConversion::matchAndRewrite( llvm::ArrayRef{}); // Remap to correctly order as (row, col, val). - // TODO: Skip this if it is unnecessary. llvm::SmallVector outputColumns; if (rowColumnIdx) { outputColumns.push_back(*rowColumnIdx); @@ -717,6 +714,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( outputColumns.push_back(valColumnIdx); + // NOTE: folds if the remapping is unncessary. auto remapped = rewriter.createOrFold(op.getLoc(), joinOp, outputColumns); rewriter.replaceOp(op, remapped); @@ -832,7 +830,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( auto result = op->getResult(i); if (result.use_empty()) { // Not used. Take init arg as a dummy value. - resultValues.emplace_back(adaptor.getInitArgs()[i]); + resultValues.push_back(adaptor.getInitArgs()[i]); continue; } @@ -856,7 +854,7 @@ mlir::LogicalResult OpConversion::matchAndRewrite( *blockSignature); } - resultValues.emplace_back(forOp); + resultValues.push_back(forOp); } rewriter.replaceOp(op, resultValues); diff --git a/compiler/test/graphalg-to-rel/GEMINI.md b/compiler/test/graphalg-to-rel/GEMINI.md deleted file mode 100644 index e0fe9e7..0000000 --- a/compiler/test/graphalg-to-rel/GEMINI.md +++ /dev/null @@ -1,90 +0,0 @@ -# GraphAlg to Relation Algebra tests -These test files verify the `graphalg-to-rel` pass, which converts ops from the `graphalg` dialect into `garel` and `arith` dialect ops. - -## Running Tests -Tests require the `graphalg-opt` binary, which is built by running `cmake --build compiler/build --target graphalg-opt`. -To get the output for a test file, run `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/.mlir`. -Test files contain `CHECK` comments that are verified using LLVM's FileCheck tool, installed as `FileCheck-20`. - -If you make any changes and have verified that the individual tests are correct, run the integration tests as a final check: `cmake --build compiler/build --target check`. - -## Coding style -### Use CHECK-LABEL For independent test functions -```mlir -// CHECK-LABEL: @AddBool -func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { - ... -} - -// CHECK-LABEL: @AddInt -func.func @AddInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { - ... -} -``` - -### Keep new op CHECKs close to original ops -Keep CHECK comments for output ops directly before and at the same indentation as the original ops they were generated from. - -**GOOD**: -```mlir -// CHECK-LABEL: @AddBool -func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { - // CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> - %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { - ^bb0(%arg2 : i1, %arg3: i1): - // CHECK: %[[#LHS:]] = garel.extract 0 - // CHECK: %[[#RHS:]] = garel.extract 1 - // CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] - %1 = graphalg.add %arg2, %arg3 : i1 - - // CHECK: garel.project.return %[[#ADD]] - graphalg.apply.return %1 : i1 - } - - // CHECK: return %[[#PROJECT]] - return %0 : !graphalg.mat<1 x 1 x i1> -} -``` - -**BAD**: -```mlir -// CHECK-LABEL: @AddBool -// CHECK: %[[#PROJECT:]] = garel.project {{.*}} : -> -// CHECK: %[[#LHS:]] = garel.extract 0 -// CHECK: %[[#RHS:]] = garel.extract 1 -// CHECK: %[[#ADD:]] = arith.ori %[[#LHS]], %[[#RHS]] -// CHECK: garel.project.return %[[#ADD]] -// CHECK: return %[[#PROJECT]] -func.func @AddBool(%arg0: !graphalg.mat<1 x 1 x i1>, %arg1: !graphalg.mat<1 x 1 x i1>) -> !graphalg.mat<1 x 1 x i1> { - %0 = graphalg.apply %arg0, %arg1 : !graphalg.mat<1 x 1 x i1>, !graphalg.mat<1 x 1 x i1> -> <1 x 1 x i1> { - ^bb0(%arg2 : i1, %arg3: i1): - %1 = graphalg.add %arg2, %arg3 : i1 - graphalg.apply.return %1 : i1 - } - - return %0 : !graphalg.mat<1 x 1 x i1> -} -``` - -### Prefer numeric substitution blocks when possible -**GOOD**: -```mlir -// CHECK: %[[#RNG:]] = garel.range 42 -``` - -**BAD**: -```mlir -// CHECK: %[[RNG:.+]] = garel.range 42 -``` - -## Porting IPR tests -If you are asked to port an IPR testcase, do these things: -1. Change ag-opt in the `RUN` comment to graphalg-opt and the pass from --graphalg-to-ipr to --graphalg-to-rel -2. Run `./compiler/build/tools/graphalg-opt --graphalg-to-rel compiler/test/graphalg-to-rel/.mlir` to see the expected output. Use that to guide the changes described in (3) and (4). -3. Replace `ipr.tuplestream` types with the corresponding `garel.relation`, and `ipr.tuple` with `garel.tuple` -4. Replace `ipr.*` ops with `garel.*` or `arith.*` ops. -5. Verify your changes with `FileCheck-20` (see guide to running tests above). -6. When you have verified your changes to the file, run the integration tests to double-check. - -Do not make changes to the input IR (the parts not in comments). -If you really think this is necessary, ask first. diff --git a/compiler/test/graphalg-to-rel/sub.mlir b/compiler/test/graphalg-to-rel/sub.mlir index ac6d389..7429978 100644 --- a/compiler/test/graphalg-to-rel/sub.mlir +++ b/compiler/test/graphalg-to-rel/sub.mlir @@ -1,4 +1,4 @@ -// RUN: ag-opt --graphalg-to-rel < %s | FileCheck %s +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s // CHECK-LABEL: @SubInt func.func @SubInt(%arg0: !graphalg.mat<1 x 1 x i64>, %arg1: !graphalg.mat<1 x 1 x i64>) -> !graphalg.mat<1 x 1 x i64> { diff --git a/compiler/test/graphalg-to-rel/tril.mlir b/compiler/test/graphalg-to-rel/tril.mlir index 540b111..de25b1d 100644 --- a/compiler/test/graphalg-to-rel/tril.mlir +++ b/compiler/test/graphalg-to-rel/tril.mlir @@ -1,4 +1,4 @@ -// RUN: graphalg-opt --graphalg-to-ipr < %s | FileCheck %s +// RUN: graphalg-opt --graphalg-to-rel < %s | FileCheck %s func.func @Tril(%arg0: !graphalg.mat<42 x 42 x i64>) -> !graphalg.mat<42 x 42 x i64> { // CHECK: %[[#SELECT:]] = garel.select %arg0