Skip to content

Commit

Permalink
Simpler approach
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Dec 14, 2024
1 parent 560d300 commit c517066
Showing 1 changed file with 140 additions and 61 deletions.
201 changes: 140 additions & 61 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,33 @@ class ElaboratorValue {
uintptr_t storage;
};

template<typename StorageTy>
struct HashedStorage {
HashedStorage(unsigned hashcode = 0, StorageTy *storage = nullptr)
: hashcode(hashcode), storage(storage) {}

unsigned hashcode;
StorageTy *storage;
};

// struct SetStorage;
// struct BagStorage;
// struct SequenceStorage;

} // namespace

namespace llvm {
// llvm::hash_code hash_value(const HashedStorage<SetStorage> &storage) {
// return storage.hashcode;
// }
// llvm::hash_code hash_value(const HashedStorage<BagStorage> &storage) {
// return storage.hashcode;
// }
// llvm::hash_code hash_value(const HashedStorage<SequenceStorage> &storage) {
// return storage.hashcode;
// }


/// Add support for llvm style casts. We provide a cast between To and From if
/// From is mlir::Attribute or derives from it.
template <typename To, typename From>
Expand Down Expand Up @@ -172,23 +196,63 @@ struct DenseMapInfo<ElaboratorValue> {
} // namespace llvm

namespace {
llvm::hash_code hash_value(const ElaboratorValue &val) {
return val.getHashValue();
}

template<typename StorageTy>
struct StorageKeyInfo {
static inline HashedStorage<StorageTy> getEmptyKey() {
return HashedStorage<StorageTy>(0, DenseMapInfo<StorageTy *>::getEmptyKey());
}
static inline HashedStorage<StorageTy> getTombstoneKey() {
return HashedStorage<StorageTy>(0, DenseMapInfo<StorageTy *>::getTombstoneKey());
}

static inline unsigned getHashValue(const HashedStorage<StorageTy> &key) {
return key.hashcode;
}
static inline unsigned getHashValue(const StorageTy &key) {
return key.hashcode;
}

static inline bool isEqual(const HashedStorage<StorageTy> &lhs,
const HashedStorage<StorageTy> &rhs) {
return lhs.storage == rhs.storage;
}
static inline bool isEqual(const StorageTy &lhs, const HashedStorage<StorageTy> &rhs) {
if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
return false;
// Invoke the equality function on the lookup key.
return lhs.isEqual(rhs.storage);
}
};

template<typename StorageTy>
struct StorageInfo : public DenseMapInfo<StorageTy *> {
using Base = DenseMapInfo<StorageTy *>;
static inline unsigned getHashValue(const StorageTy *key) {
return key->hashcode;
}

static inline bool isEqual(const StorageTy *lhs, const StorageTy *rhs) {
if (lhs == rhs)
return true;
if (lhs == Base::getEmptyKey() || lhs == Base::getTombstoneKey() || rhs == Base::getEmptyKey() || rhs == Base::getTombstoneKey())
return false;
return lhs->isEqual(rhs);
}
};

struct SetStorage : public llvm::FoldingSetNode {
struct SetStorage {
SetStorage(SetVector<ElaboratorValue> &&set, Type type)
: set(std::move(set)), type(type) {}

// NOLINTNEXTLINE(readability-identifier-naming)
static void Profile(llvm::FoldingSetNodeID &ID,
const SetVector<ElaboratorValue> &set, Type type) {
for (auto el : set) {
ID.AddInteger(el.getRawStorage());
ID.AddInteger(static_cast<unsigned>(el.getKind()));
}
ID.AddPointer(type.getAsOpaquePointer());
: hashcode(llvm::hash_combine(type, llvm::hash_combine_range(set.begin(), set.end()))), set(std::move(set)), type(type) {}

bool isEqual(const SetStorage *other) const {
return set == other->set && type == other->type;
}

// NOLINTNEXTLINE(readability-identifier-naming)
void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, set, type); }
unsigned hashcode;

// Stores the elaborated values of the set.
SetVector<ElaboratorValue> set;
Expand All @@ -198,24 +262,15 @@ struct SetStorage : public llvm::FoldingSetNode {
Type type;
};

struct BagStorage : public llvm::FoldingSetNode {
struct BagStorage {
BagStorage(MapVector<ElaboratorValue, uint64_t> &&bag, Type type)
: bag(std::move(bag)), type(type) {}

// NOLINTNEXTLINE(readability-identifier-naming)
static void Profile(llvm::FoldingSetNodeID &ID,
const MapVector<ElaboratorValue, uint64_t> &bag,
Type type) {
for (auto el : bag) {
ID.AddInteger(el.first.getRawStorage());
ID.AddInteger(static_cast<unsigned>(el.first.getKind()));
ID.AddInteger(el.second);
}
ID.AddPointer(type.getAsOpaquePointer());
: hashcode(llvm::hash_combine(type, llvm::hash_combine_range(bag.begin(), bag.end()))), bag(std::move(bag)), type(type) {}

bool isEqual(const BagStorage *other) const {
return llvm::equal(bag, other->bag) && type == other->type;
}

// NOLINTNEXTLINE(readability-identifier-naming)
void Profile(llvm::FoldingSetNodeID &ID) const { Profile(ID, bag, type); }
unsigned hashcode;

// Stores the elaborated values of the bag.
MapVector<ElaboratorValue, uint64_t> bag;
Expand All @@ -225,65 +280,86 @@ struct BagStorage : public llvm::FoldingSetNode {
Type type;
};

struct SequenceStorage : public llvm::FoldingSetNode {
struct SequenceStorage {
SequenceStorage(StringRef name, StringAttr familyName,
SmallVector<ElaboratorValue> &&args)
: name(name), familyName(familyName), args(std::move(args)) {}

// NOLINTNEXTLINE(readability-identifier-naming)
static void Profile(llvm::FoldingSetNodeID &ID, StringRef name,
StringAttr familyName, ArrayRef<ElaboratorValue> args) {
ID.AddString(name);
ID.AddPointer(familyName.getAsOpaquePointer());
for (auto el : args) {
ID.AddInteger(el.getRawStorage());
ID.AddInteger(static_cast<unsigned>(el.getKind()));
}
}
: hashcode(llvm::hash_combine(name, familyName, llvm::hash_combine_range(args.begin(), args.end()))), name(name), familyName(familyName), args(std::move(args)) {}

// NOLINTNEXTLINE(readability-identifier-naming)
void Profile(llvm::FoldingSetNodeID &ID) const {
Profile(ID, name, familyName, args);
bool isEqual(const SequenceStorage *other) const {
return name == other->name && familyName == other->familyName && args == other->args;
}

unsigned hashcode;
StringRef name;
StringAttr familyName;
SmallVector<ElaboratorValue> args;
};

// struct LookupKey {
// unsigned hascode;
// function_ref<bool(const BaseStorage *)> isEqual;
// };

class Internalizer {
public:
// template <typename StorageTy, typename... Args>
// StorageTy *internalize(Args &&...args) {
// StorageTy storage(std::forward<Args>(args)...);

// auto existing = getInternSet<StorageTy>().insert_as(HashedStorage<StorageTy>(storage.hashcode), storage);
// StorageTy *&storagePtr = existing.first->storage;
// if (existing.second)
// storagePtr = new (allocator.Allocate<StorageTy>()) StorageTy(std::move(storage));
// return storagePtr;
// }

// template <typename StorageTy>
// DenseSet<HashedStorage<StorageTy>, StorageKeyInfo<StorageTy>> &getInternSet() {
// assert(false && "no generic internalization set");
// }

// template <>
// DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> &getInternSet() {
// return internedSets;
// }

// template <>
// DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> &getInternSet() {
// return internedBags;
// }

// template <>
// DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>> &getInternSet() {
// return internedSequences;
// }

template <typename StorageTy, typename... Args>
StorageTy *internalize(Args &&...args) {
llvm::FoldingSetNodeID profile;
StorageTy::Profile(profile, args...);
void *insertPos = nullptr;
if (auto *storage =
getInternSet<StorageTy>().FindNodeOrInsertPos(profile, insertPos))
return static_cast<StorageTy *>(storage);
auto *storagePtr = new (allocator.Allocate<StorageTy>())
StorageTy(std::forward<Args>(args)...);
getInternSet<StorageTy>().InsertNode(storagePtr, insertPos);
return storagePtr;
auto *storagePtr = new (allocator.Allocate<StorageTy>()) StorageTy(std::forward<Args>(args)...);
auto existing = getInternSet<StorageTy>().insert(storagePtr);
if (!existing.second)
allocator.Deallocate(storagePtr);

return *existing.first;
}

template <typename StorageTy>
llvm::FoldingSet<StorageTy> &getInternSet() {
DenseSet<StorageTy*, StorageInfo<StorageTy>> &getInternSet() {
assert(false && "no generic internalization set");
}

template <>
llvm::FoldingSet<SetStorage> &getInternSet() {
DenseSet<SetStorage*, StorageInfo<SetStorage>> &getInternSet() {
return internedSets;
}

template <>
llvm::FoldingSet<BagStorage> &getInternSet() {
DenseSet<BagStorage*, StorageInfo<BagStorage>> &getInternSet() {
return internedBags;
}

template <>
llvm::FoldingSet<SequenceStorage> &getInternSet() {
DenseSet<SequenceStorage*, StorageInfo<SequenceStorage>> &getInternSet() {
return internedSequences;
}

Expand All @@ -308,9 +384,12 @@ class Internalizer {
// inserting an object of a derived class of ElaboratorValue.
// The custom MapInfo makes sure that we do a value comparison instead of
// comparing the pointers.
llvm::FoldingSet<SetStorage> internedSets;
llvm::FoldingSet<BagStorage> internedBags;
llvm::FoldingSet<SequenceStorage> internedSequences;
// DenseSet<HashedStorage<SetStorage>, StorageKeyInfo<SetStorage>> internedSets;
// DenseSet<HashedStorage<BagStorage>, StorageKeyInfo<BagStorage>> internedBags;
// DenseSet<HashedStorage<SequenceStorage>, StorageKeyInfo<SequenceStorage>> internedSequences;
DenseSet<SetStorage*, StorageInfo<SetStorage>> internedSets;
DenseSet<BagStorage*, StorageInfo<BagStorage>> internedBags;
DenseSet<SequenceStorage*, StorageInfo<SequenceStorage>> internedSequences;
};

/// Holds any typed attribute. Wrapping around an MLIR `Attribute` allows us to
Expand Down

0 comments on commit c517066

Please sign in to comment.