CS224W ReaRev GNN-RAG #9857

Open
wants to merge 41 commits into master

Commits (41)
d6c115a
Create rearev_data_loader.py
natalieshell22 Dec 13, 2024
f216279
Create rearev_data_loader_test.py
natalieshell22 Dec 13, 2024
7188791
Create rearev.py
natalieshell22 Dec 13, 2024
92a097d
Create trainer_kbqa.py
natalieshell22 Dec 13, 2024
218abf2
Create graph_utils.py
natalieshell22 Dec 13, 2024
6f7676e
Create reason.py
natalieshell22 Dec 13, 2024
019dd36
Update CHANGELOG.md
natalieshell22 Dec 13, 2024
03508a0
Update rearev_data_loader_test.py
natalieshell22 Dec 13, 2024
04fe125
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
7edb8ac
Update rearev_data_loader.py
natalieshell22 Dec 13, 2024
a126372
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
0c30f2e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
c788edb
Update rearev.py
natalieshell22 Dec 13, 2024
e411b63
Update reason.py
natalieshell22 Dec 13, 2024
21e3139
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
f59bc67
Update graph_utils.py
natalieshell22 Dec 13, 2024
c7fc6cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
397291f
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
291d461
Update rearev.py
natalieshell22 Dec 13, 2024
8ed7c76
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
90961e0
Update graph_utils.py
natalieshell22 Dec 13, 2024
f06cab1
Update reason.py
natalieshell22 Dec 13, 2024
e85f635
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
5a090cd
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
70db241
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
834565a
Update trainer_kbqa.py
natalieshell22 Dec 13, 2024
0a0ace9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
119fd07
Update reason.py
natalieshell22 Dec 13, 2024
7f4e99a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
5f24243
Update graph_utils.py
natalieshell22 Dec 13, 2024
280fae9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
9f49b2f
Update rearev.py
natalieshell22 Dec 13, 2024
fbec144
Update graph_utils.py
natalieshell22 Dec 13, 2024
5952480
Update reason.py
natalieshell22 Dec 13, 2024
17295ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
33c9758
Update graph_utils.py
natalieshell22 Dec 13, 2024
3f43890
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
3a78d97
Update reason.py
natalieshell22 Dec 13, 2024
8485eaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
320f77e
Update graph_utils.py
natalieshell22 Dec 13, 2024
5ce693e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
10 changes: 9 additions & 1 deletion CHANGELOG.md
@@ -3,7 +3,15 @@
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## \[2.7.0\] - 2024-MM-DD
## \[2.8.0\] - 2024-MM-DD

## \[2.7.0\] - 2024-12-13

- Added GNN-RAG functionality to PyG, which uses graph neural networks to enhance LLM performance on question answering.
- Added `rearev_data_loader` to load and preprocess data for ReaRev.
- Added `rearev_data_loader_test` as a unit test for `rearev_data_loader`.
- Added `rearev.py` and `trainer_kbqa.py` to define and train the ReaRev model on knowledge graphs.
- Added `graph_utils.py` and `reason.py`.

### Added

207 changes: 207 additions & 0 deletions torch_geometric/loader/rearev_data_loader.py
@@ -0,0 +1,207 @@
import json

import numpy as np
import torch
from transformers import AutoTokenizer

from torch_geometric.data import Data, DataLoader


class BasicDataLoader:
def __init__(self, config, word2id, relation2id, entity2id, tokenize,
data_type="train"):
self.batch_size = config['batch_size']
print("Init called!")
self.tokenize = tokenize
self._initialize(config, word2id, relation2id, entity2id)
print("Loading file")
self._load_file(config, data_type)
print("Preparing data")
self._prepare_data()
self.build_rel_words(self.tokenize)

def _initialize(self, config, word2id, relation2id, entity2id):
print("Initializing")
self.config = config
self.word2id = word2id
self.relation2id, self.entity2id = relation2id, entity2id
self.id2entity = {v: k for k, v in entity2id.items()}
self.num_relations = len(relation2id)
if config.get('use_inverse_relation', False):
self.num_relations *= 2
if config.get('use_self_loop', False):
self.num_relations += 1
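        # Illustrative count: with len(relation2id) == 100,
        # use_inverse_relation doubles this to 200 and use_self_loop adds
        # one more slot, giving num_relations == 201.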

    def _load_file(self, config, data_type):
        self.data = []
        file_path = f"{config['data_folder']}{data_type}.json"
        print(f"Loading data from {file_path}...")
        # First pass counts lines so progress can be reported; second pass
        # parses one JSON object per line.
        with open(file_path) as f:
            num_lines = sum(1 for _ in f)
        print("Number of lines:", num_lines)
        with open(file_path) as f:
            for line_idx, line in enumerate(f):
                if line_idx % 1000 == 0:
                    print(f"Line {line_idx} out of {num_lines}")
                sample = json.loads(line)
                # Keep only samples that come with seed entities.
                if 'entities' in sample:
                    self.data.append(sample)
        print(f"Loaded {len(self.data)} samples.")

def _prepare_data(self):
self.graphs = [self._create_graph(sample) for sample in self.data]
self.data_loader = DataLoader(self.graphs, batch_size=self.batch_size,
shuffle=True)

def _create_graph(self, sample):
entity_map = {
ent: idx
for idx, ent in enumerate(sample.get('entities', []))
}
x = torch.zeros(len(entity_map), self.num_relations)
edges, edge_attrs = [], []

for head, rel, tail in sample['subgraph']['tuples']:
if head in entity_map and tail in entity_map:
h = entity_map[head]
t = entity_map[tail]
r = self.relation2id.get(rel, len(self.relation2id))
edges.append([h, t])
edge_attrs.append(r)

edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
edge_attr = torch.tensor(edge_attrs, dtype=torch.long)

return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
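
    # Illustrative example (toy values, not from a real dataset): for a
    # sample {'entities': ['e0', 'e1'],
    #         'subgraph': {'tuples': [['e0', 'rel_a', 'e1']]}},
    # `entity_map` becomes {'e0': 0, 'e1': 1}, `edge_index` is
    # tensor([[0], [1]]), `edge_attr` holds the mapped relation id, and the
    # returned object is Data(x=[2, num_relations], edge_index=[2, 1],
    # edge_attr=[1]).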

def reset_batches(self):
self.data_loader = DataLoader(self.graphs, batch_size=self.batch_size,
shuffle=True)

def build_rel_words(self, tokenize):
# Tokenizes relation surface forms.
max_rel_words = 0
rel_words = []
        # MetaQA relations are plain '_'-separated strings; Freebase-style
        # relations (dotted paths) are handled in the `else` branch below.
        if 'metaqa' in self.config.get('data_folder', ''):
            for rel in self.relation2id:
                words = rel.split('_')
                max_rel_words = max(len(words), max_rel_words)
                rel_words.append(words)
        else:
            for rel in self.relation2id:
                rel = rel.strip()
                fields = rel.split('.')
                try:
                    # Use the last two fields of the dotted relation path.
                    words = fields[-2].split('_') + fields[-1].split('_')
                    max_rel_words = max(len(words), max_rel_words)
                    rel_words.append(words)
                except Exception:
                    # Relations without at least two dotted fields fall back
                    # to an unknown-word placeholder.
                    words = ['UNK']
                    rel_words.append(words)

self.max_rel_words = max_rel_words
if tokenize == 'lstm':
            # `num_relations` is set in `_initialize`; one extra row is
            # reserved for padding.
            self.rel_texts = np.full(
                (self.num_relations + 1, self.max_rel_words),
                len(self.word2id), dtype=int)
            self.rel_texts_inv = np.full(
                (self.num_relations + 1, self.max_rel_words),
                len(self.word2id), dtype=int)
for rel_id, tokens in enumerate(rel_words):
for j, word in enumerate(tokens):
if j < self.max_rel_words:
if word in self.word2id:
self.rel_texts[rel_id, j] = self.word2id[word]
self.rel_texts_inv[rel_id, j] = self.word2id[word]
else:
self.rel_texts[rel_id, j] = len(self.word2id)
self.rel_texts_inv[rel_id, j] = len(self.word2id)
else:
if tokenize == 'bert':
tokenizer_name = 'bert-base-uncased'
elif tokenize == 'roberta':
tokenizer_name = 'roberta-base'
elif tokenize == 'sbert':
tokenizer_name = 'sentence-transformers/all-MiniLM-L6-v2'
elif tokenize == 'sbert2':
tokenizer_name = 'sentence-transformers/all-mpnet-base-v2'
elif tokenize == 'simcse':
tokenizer_name = 'princeton-nlp/sup-simcse-bert-base-uncased'
elif tokenize == 't5':
tokenizer_name = 't5-small'
            elif tokenize == 'relbert':
                tokenizer_name = 'pretrained_lms/sr-simbert/'
            else:
                raise ValueError(f"Unsupported tokenizer option: {tokenize}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
pad_val = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
            self.rel_texts = np.full(
                (self.num_relations + 1, self.max_rel_words), pad_val,
                dtype=int)
            self.rel_texts_inv = np.full(
                (self.num_relations + 1, self.max_rel_words), pad_val,
                dtype=int)

for rel_id, words in enumerate(rel_words):
                tokens = tokenizer.encode_plus(text=' '.join(words),
                                               max_length=self.max_rel_words,
                                               padding='max_length',
                                               return_attention_mask=False,
                                               truncation=True)
                tokens_inv = tokenizer.encode_plus(
                    text=' '.join(words[::-1]), max_length=self.max_rel_words,
                    padding='max_length', return_attention_mask=False,
                    truncation=True)
self.rel_texts[rel_id] = np.array(tokens['input_ids'])
self.rel_texts_inv[rel_id] = np.array(tokens_inv['input_ids'])
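
    # Illustrative example: a Freebase-style relation such as
    # 'people.person.place_of_birth' is split on '.' so that
    # fields[-2] == 'person' and fields[-1] == 'place_of_birth', giving the
    # surface words ['person', 'place', 'of', 'birth']; these are either
    # mapped through `word2id` (tokenize == 'lstm') or re-encoded with the
    # chosen Hugging Face tokenizer.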


class SingleDataLoader(BasicDataLoader):
def get_batch(self):
return next(iter(self.data_loader))


def load_dict(file_path):
with open(file_path, encoding='utf-8') as f:
return {line.strip(): idx for idx, line in enumerate(f)}
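
# Illustrative example: `load_dict` expects one entry per line and assigns
# ids by line position, e.g. a relation2id.txt containing
#     people.person.place_of_birth
#     location.location.containedby
# yields {'people.person.place_of_birth': 0,
#         'location.location.containedby': 1}.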


def load_data(config, tokenize):
print("Load data called...")
entity2id = load_dict(f"{config['data_folder']}{config['entity2id']}")
word2id = load_dict(f"{config['data_folder']}{config['word2id']}")
relation2id = load_dict(f"{config['data_folder']}{config['relation2id']}")

print("Dictionaries loaded!")

loaders = {
data_type:
SingleDataLoader(config, word2id, relation2id, entity2id, tokenize,
data_type)
for data_type in ['train', 'dev', 'test']
}

    return {
        **loaders,
        "entity2id": entity2id,
        "relation2id": relation2id,
        "word2id": word2id,
        # Vocabulary size used by the word-embedding (LSTM) path.
        "num_word": len(word2id),
    }


if __name__ == "__main__":
    print("data loading! Main function")
    # Example configuration; adjust paths to point at your dataset.
    args = {
        'batch_size': 32,
        'data_folder': './',
        'entity2id': 'entity2id.txt',
        'word2id': 'word2id.txt',
        'relation2id': 'relation2id.txt'
    }
    # `tokenize` selects how relation words are encoded; 'lstm' uses the
    # plain word2id vocabulary and needs no pretrained tokenizer.
    dataset = load_data(args, tokenize='lstm')
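
For reference, a minimal sketch of what one line of the expected {train,dev,test}.json input could look like, inferred from the fields this loader reads ('entities' and 'subgraph' -> 'tuples'); real ReaRev/WebQSP dumps may carry additional keys such as the question text and answers:

# One JSON object per line (illustrative entity and relation ids):
sample = {
    "entities": ["m.02mjmr", "m.09c7w0"],  # seed entities, must appear in entity2id
    "subgraph": {
        "tuples": [  # (head, relation, tail) triples of the retrieved subgraph
            ["m.02mjmr", "people.person.place_of_birth", "m.09c7w0"],
        ],
    },
}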
Empty file.
166 changes: 166 additions & 0 deletions torch_geometric/nn/models/trainer_kbqa.py
@@ -0,0 +1,166 @@
import os
import time

import numpy as np
import torch
from torch_geometric.loader.rearev_data_loader import load_data
from evaluate import Evaluator
from rearev import ReaRev
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm


class TrainerKBQA:
"""Trainer for Knowledge-Based Question Answering
(KBQA) using PyTorch and PyG.

Handles data loading, model training, evaluation,
and checkpoint management.
"""
def __init__(self, args, model_name, logger=None):
"""Initialize Trainer with configuration, model, and logger.

Args:
args (dict): Training configurations and hyperparameters.
model_name (str): Name of the model to use.
logger (logging.Logger, optional): Logger for training logs.
"""
self.args = args
self.logger = logger
self.device = torch.device("cuda" if args["use_cuda"] else "cpu")

# Hyperparameters and training settings
self.learning_rate = args["lr"]
self.decay_rate = args.get("decay_rate", 0.98)
self.test_batch_size = args["test_batch_size"]
self.warmup_epoch = args["warmup_epoch"]

# Data loading
self.dataset = load_data(args, args["lm"])
self.train_data = self.dataset["train"]
        self.valid_data = self.dataset["dev"]  # load_data builds 'dev', not 'valid'
self.test_data = self.dataset["test"]
self.entity2id = self.dataset["entity2id"]
self.relation2id = self.dataset["relation2id"]
self.word2id = self.dataset["word2id"]
self.rel_texts = self.dataset.get("rel_texts")
self.rel_texts_inv = self.dataset.get("rel_texts_inv")
self.num_entity = len(self.entity2id)
self.num_relation = len(self.relation2id)
self.num_word = len(self.word2id)

# Model initialization
if model_name == "ReaRev":
self.model = ReaRev(
args,
num_entity=self.num_entity,
num_relation=self.num_relation,
num_word=self.num_word,
            )
        else:
            raise ValueError(f"Unknown model name: {model_name}")
if args.get("relation_word_emb"):
self.model.encode_rel_texts(self.rel_texts, self.rel_texts_inv)

self.model.to(self.device)
self.optimizer = Adam(self.model.parameters(), lr=self.learning_rate)
self.scheduler = ExponentialLR(self.optimizer, gamma=self.decay_rate)
self.evaluator = Evaluator(
args=args,
model=self.model,
entity2id=self.entity2id,
relation2id=self.relation2id,
device=self.device,
)

# Load pretrained weights if specified
if args.get("load_experiment"):
self.load_checkpoint(args["load_experiment"])

def train(self, start_epoch, end_epoch):
"""Train the model over a range of epochs.

Args:
start_epoch (int): Starting epoch.
end_epoch (int): Ending epoch.
"""
eval_every = self.args["eval_every"]
for epoch in range(start_epoch, end_epoch + 1):
            epoch_start = time.time()

            loss, h1_list, f1_list = self._train_epoch()
            if self.decay_rate > 0:
                self.scheduler.step()

            avg_h1, avg_f1 = np.mean(h1_list), np.mean(f1_list)
            self.logger.info(
                f"Epoch {epoch}: loss={loss:.4f}, H1={avg_h1:.4f}, "
                f"F1={avg_f1:.4f} ({time.time() - epoch_start:.1f}s)")

# Evaluation
if (epoch + 1) % eval_every == 0:
self.evaluate(self.valid_data)

def _train_epoch(self):
"""Train the model for one epoch.

Returns:
tuple: Average loss, H1 scores, and F1 scores.
"""
self.model.train()
losses, h1_list, f1_list = [], [], []
        # `train_data` is a SingleDataLoader whose `data_loader` attribute is
        # a pre-built PyG DataLoader; reshuffle it before each epoch.
        self.train_data.reset_batches()
        data_loader = self.train_data.data_loader

for batch in tqdm(data_loader, desc="Training"):
self.optimizer.zero_grad()
loss, _, _, tp_list = self.model(batch, training=True)
h1_scores, f1_scores = tp_list

loss.backward()
clip_grad_norm_(self.model.parameters(),
self.args["gradient_clip"])
self.optimizer.step()

losses.append(loss.item())
h1_list.extend(h1_scores)
f1_list.extend(f1_scores)

return np.mean(losses), h1_list, f1_list

def evaluate(self, data):
"""Evaluate the model on a dataset.

Args:
data (Dataset): Dataset to evaluate.

Returns:
dict: Evaluation metrics (F1, H1, EM).
"""
self.model.eval()
with torch.no_grad():
return self.evaluator.evaluate(data,
batch_size=self.test_batch_size)

def save_checkpoint(self, name):
"""Save model checkpoint.

Args:
name (str): Name of the checkpoint.
"""
checkpoint_path = os.path.join(self.args["checkpoint_dir"],
f"{name}.ckpt")
torch.save({"model_state_dict": self.model.state_dict()},
checkpoint_path)
self.logger.info(f"Checkpoint saved: {checkpoint_path}")

def load_checkpoint(self, name):
"""Load model checkpoint.

Args:
name (str): Name of the checkpoint to load.
"""
checkpoint_path = os.path.join(self.args["checkpoint_dir"],
f"{name}.ckpt")
checkpoint = torch.load(checkpoint_path)
self.model.load_state_dict(checkpoint["model_state_dict"],
strict=False)
self.logger.info(f"Checkpoint loaded: {checkpoint_path}")