Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG][Elaboration] Add support for arith constants #7890

Draft
wants to merge 1 commit into
base: maerhart-rtg-elaboration-debug-printing
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions include/circt/Dialect/RTG/IR/ArithVisitors.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
//===- ArithVisitors.h - Arith Dialect Visitors -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines visitors that make it easier to work with Arith Ops.
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H
#define CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
namespace arith {

/// This helps visit TypeOp nodes.
template <typename ConcreteType, typename ResultType = void,
typename... ExtraArgs>
class ArithOpVisitor {
public:
ResultType dispatchOpVisitor(Operation *op, ExtraArgs... args) {
auto *thisCast = static_cast<ConcreteType *>(this);
return TypeSwitch<Operation *, ResultType>(op)
.template Case<ConstantOp, AddIOp, AddUIExtendedOp, SubIOp, MulIOp,
MulSIExtendedOp, MulUIExtendedOp, DivUIOp, DivSIOp,
CeilDivUIOp, CeilDivSIOp, FloorDivSIOp, RemUIOp, RemSIOp,
AndIOp, OrIOp, XOrIOp, ShLIOp, ShRUIOp, ShRSIOp, NegFOp,
AddFOp, SubFOp, MaximumFOp, MaxNumFOp, MaxSIOp, MaxUIOp,
MinimumFOp, MinNumFOp, MinSIOp, MinUIOp, MulFOp, DivFOp,
RemFOp, ExtUIOp, ExtSIOp, ExtFOp, TruncIOp, TruncFOp,
UIToFPOp, SIToFPOp, FPToUIOp, FPToSIOp, IndexCastOp,
IndexCastUIOp, BitcastOp, CmpIOp, CmpFOp, SelectOp>(
[&](auto expr) -> ResultType {
return thisCast->visitOp(expr, args...);
})
.Default([&](auto expr) -> ResultType {
if (op->getDialect() ==
op->getContext()->getLoadedDialect<ArithDialect>()) {
return visitInvalidTypeOp(op, args...);
}
return thisCast->visitExternalOp(op, args...);
});
}

/// This callback is invoked on any RTG operations not handled properly by the
/// TypeSwitch.
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args) {
op->emitOpError("Unknown Arith operation: ") << op->getName();
abort();
}

/// This callback is invoked on any operations that are not
/// handled by the concrete visitor.
ResultType visitUnhandledOp(Operation *op, ExtraArgs... args);

ResultType visitExternalOp(Operation *op, ExtraArgs... args) {
return ResultType();
}

#define HANDLE(OPTYPE, OPKIND) \
ResultType visitOp(OPTYPE op, ExtraArgs... args) { \
return static_cast<ConcreteType *>(this)->visit##OPKIND##Op(op, args...); \
}

HANDLE(ConstantOp, Unhandled);
HANDLE(AddIOp, Unhandled);
HANDLE(AddUIExtendedOp, Unhandled);
HANDLE(SubIOp, Unhandled);
HANDLE(MulIOp, Unhandled);
HANDLE(MulSIExtendedOp, Unhandled);
HANDLE(MulUIExtendedOp, Unhandled);
HANDLE(DivUIOp, Unhandled);
HANDLE(DivSIOp, Unhandled);
HANDLE(CeilDivUIOp, Unhandled);
HANDLE(CeilDivSIOp, Unhandled);
HANDLE(FloorDivSIOp, Unhandled);
HANDLE(RemUIOp, Unhandled);
HANDLE(RemSIOp, Unhandled);
HANDLE(AndIOp, Unhandled);
HANDLE(OrIOp, Unhandled);
HANDLE(XOrIOp, Unhandled);
HANDLE(ShLIOp, Unhandled);
HANDLE(ShRUIOp, Unhandled);
HANDLE(ShRSIOp, Unhandled);
HANDLE(NegFOp, Unhandled);
HANDLE(AddFOp, Unhandled);
HANDLE(SubFOp, Unhandled);
HANDLE(MaximumFOp, Unhandled);
HANDLE(MaxNumFOp, Unhandled);
HANDLE(MaxSIOp, Unhandled);
HANDLE(MaxUIOp, Unhandled);
HANDLE(MinimumFOp, Unhandled);
HANDLE(MinNumFOp, Unhandled);
HANDLE(MinSIOp, Unhandled);
HANDLE(MinUIOp, Unhandled);
HANDLE(MulFOp, Unhandled);
HANDLE(DivFOp, Unhandled);
HANDLE(RemFOp, Unhandled);
HANDLE(ExtUIOp, Unhandled);
HANDLE(ExtSIOp, Unhandled);
HANDLE(ExtFOp, Unhandled);
HANDLE(TruncIOp, Unhandled);
HANDLE(TruncFOp, Unhandled);
HANDLE(UIToFPOp, Unhandled);
HANDLE(SIToFPOp, Unhandled);
HANDLE(FPToUIOp, Unhandled);
HANDLE(FPToSIOp, Unhandled);
HANDLE(IndexCastOp, Unhandled);
HANDLE(IndexCastUIOp, Unhandled);
HANDLE(BitcastOp, Unhandled);
HANDLE(CmpIOp, Unhandled);
HANDLE(CmpFOp, Unhandled);
HANDLE(SelectOp, Unhandled);
#undef HANDLE
};

} // namespace arith
} // namespace mlir

#endif // CIRCT_DIALECT_RTG_IR_ARITHVISITORS_H
1 change: 1 addition & 0 deletions lib/Dialect/RTG/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_circt_dialect_library(CIRCTRTGTransforms

LINK_LIBS PRIVATE
CIRCTRTGDialect
MLIRArithDialect
MLIRIR
MLIRPass
)
Expand Down
87 changes: 83 additions & 4 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/RTG/IR/ArithVisitors.h"
#include "circt/Dialect/RTG/IR/RTGOps.h"
#include "circt/Dialect/RTG/IR/RTGVisitors.h"
#include "circt/Dialect/RTG/Transforms/RTGPasses.h"
Expand Down Expand Up @@ -228,6 +229,53 @@ class SequenceClosureValue : public ElaboratorValue {
SmallVector<ElaboratorValue *> args;
};

/// Holds an evaluated value of an `IndexType` or `IntegerType`'d value.
/// TODO: support integers with more than 64 bits
class IntegerValue : public ElaboratorValue {
public:
IntegerValue(Value value, uint64_t integer)
: ElaboratorValue(value, false), integer(integer) {
assert((isa<IntegerType>(value.getType()) &&
value.getType().getIntOrFloatBitWidth() <= 64) ||
isa<IndexType>(value.getType()));
}

// Implement LLVMs RTTI
static bool classof(const ElaboratorValue *val) {
return !val->isOpaqueValue() &&
(IndexType::classof(val->getType()) ||
(IntegerType::classof(val->getType()) &&
val->getType().getIntOrFloatBitWidth() <= 64));
}

bool containsOpaqueValue() const override { return false; }

llvm::hash_code getHashValue() const override {
return llvm::hash_combine(integer, getType());
}

bool isEqual(const ElaboratorValue &other) const override {
auto *intVal = dyn_cast<IntegerValue>(&other);
if (!intVal)
return false;

return integer == intVal->integer && getType() == intVal->getType();
}

std::string toString() const override {
std::string out;
llvm::raw_string_ostream stream(out);
stream << "<const-integer " << integer << " of type " << getType() << " at "
<< this << ">";
return out;
}

uint64_t getInt() const { return integer; }

private:
uint64_t integer;
};

//===----------------------------------------------------------------------===//
// Hash Map Helpers
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -269,11 +317,19 @@ struct InternMapInfo : public DenseMapInfo<ElaboratorValue *> {
enum class DeletionKind { Keep, Delete };

/// Interprets the IR to perform and lower the represented randomizations.
class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>> {
class Elaborator
: public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>>,
public mlir::arith::ArithOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>> {
public:
using RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>>::visitOp;
using RTGBase = RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>>;
using ArithBase = ArithOpVisitor<Elaborator, FailureOr<DeletionKind>,
function_ref<void(Operation *)>>;

using ArithBase::visitOp;
using RTGBase::visitOp;

Elaborator(SymbolTable &table, const ElaborationOptions &options)
: options(options), symTable(table) {
Expand Down Expand Up @@ -313,6 +369,20 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
return DeletionKind::Keep;
}

FailureOr<DeletionKind>
visitOp(arith::ConstantOp op, function_ref<void(Operation *)> addToWorklist) {
if (auto val = dyn_cast<IntegerAttr>(op.getValue())) {
if (val.getValue().getBitWidth() <= 64 &&
!val.getType().isSignedInteger()) {
internalizeResult<IntegerValue>(op.getResult(),
val.getValue().getZExtValue());
return DeletionKind::Delete;
}
}

return visitExternalOp(op, addToWorklist);
}

FailureOr<DeletionKind>
visitOp(SequenceClosureOp op, function_ref<void(Operation *)> addToWorklist) {
SmallVector<ElaboratorValue *> args;
Expand Down Expand Up @@ -431,6 +501,15 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>,
return DeletionKind::Delete;
}

FailureOr<DeletionKind>
dispatchOpVisitor(Operation *op,
function_ref<void(Operation *)> addToWorklist) {
if (op->getDialect() == op->getContext()->getLoadedDialect<RTGDialect>())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sigh. I don't suppose we can get around this.

return RTGBase::dispatchOpVisitor(op, addToWorklist);

return ArithBase::dispatchOpVisitor(op, addToWorklist);
}

LogicalResult elaborate(TestOp testOp) {
LLVM_DEBUG(llvm::dbgs()
<< "\n=== Elaborating Test @" << testOp.getSymName() << "\n\n");
Expand Down
37 changes: 22 additions & 15 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// CHECK-LABEL: rtg.sequence @seq0
rtg.sequence @seq0 {
%2 = arith.constant 2 : i32
%2 = hw.constant 2 : i32
}

// CHECK-LABEL: rtg.sequence @seq2
Expand All @@ -19,8 +19,8 @@ rtg.sequence @seq2 {
// Test the set operations and passing a sequence ot another one via argument
// CHECK-LABEL: rtg.test @setOperations
rtg.test @setOperations : !rtg.dict<> {
// CHECK-NEXT: arith.constant 2 : i32
// CHECK-NEXT: arith.constant 2 : i32
// CHECK-NEXT: hw.constant 2 : i32
// CHECK-NEXT: hw.constant 2 : i32
// CHECK-NEXT: }
%0 = rtg.sequence_closure @seq0
%1 = rtg.sequence_closure @seq2(%0 : !rtg.sequence)
Expand All @@ -41,8 +41,8 @@ rtg.sequence @seq3 {

// CHECK-LABEL: rtg.test @setArguments
rtg.test @setArguments : !rtg.dict<> {
// CHECK-NEXT: arith.constant 2 : i32
// CHECK-NEXT: arith.constant 2 : i32
// CHECK-NEXT: hw.constant 2 : i32
// CHECK-NEXT: hw.constant 2 : i32
// CHECK-NEXT: }
%0 = rtg.sequence_closure @seq0
%1 = rtg.sequence_closure @seq2(%0 : !rtg.sequence)
Expand Down Expand Up @@ -70,45 +70,52 @@ rtg.test @noNullOperands : !rtg.dict<> {
}

rtg.target @target0 : !rtg.dict<num_cpus: i32> {
%0 = arith.constant 0 : i32
%0 = hw.constant 0 : i32
rtg.yield %0 : i32
}

rtg.target @target1 : !rtg.dict<num_cpus: i32> {
%0 = arith.constant 1 : i32
%0 = hw.constant 1 : i32
rtg.yield %0 : i32
}

// CHECK-LABEL: @targetTest_target0
// CHECK: [[V0:%.+]] = arith.constant 0
// CHECK: arith.addi [[V0]], [[V0]]
// CHECK: [[V0:%.+]] = hw.constant 0
// CHECK: comb.add [[V0]], [[V0]]

// CHECK-LABEL: @targetTest_target1
// CHECK: [[V0:%.+]] = arith.constant 1
// CHECK: arith.addi [[V0]], [[V0]]
// CHECK: [[V0:%.+]] = hw.constant 1
// CHECK: comb.add [[V0]], [[V0]]
rtg.test @targetTest : !rtg.dict<num_cpus: i32> {
^bb0(%arg0: i32):
arith.addi %arg0, %arg0 : i32
comb.add %arg0, %arg0 : i32
}

// CHECK-NOT: @unmatchedTest
rtg.test @unmatchedTest : !rtg.dict<num_cpus: i64> {
^bb0(%arg0: i64):
arith.addi %arg0, %arg0 : i64
comb.add %arg0, %arg0 : i64
}

// CHECK-LABEL: rtg.test @arithConstant
rtg.test @arithConstant : !rtg.dict<> {
%0 = arith.constant 2 : index
%1 = arith.constant 2 : i32
// CHECK-NEXT: }
}

// -----

rtg.test @opaqueValuesAndSets : !rtg.dict<> {
%0 = arith.constant 2 : i32
%0 = hw.constant 2 : i32
// expected-error @below {{cannot create a set of opaque values because they cannot be reliably uniqued}}
%1 = rtg.set_create %0 : i32
}

// -----

rtg.sequence @seq0 {
%2 = arith.constant 2 : i32
%2 = hw.constant 2 : i32
}

// Test that the elaborator value interning works as intended and exercise 'set_select_random' error messages.
Expand Down
Loading