Skip to content

Commit

Permalink
增加spo检索
Browse files Browse the repository at this point in the history
  • Loading branch information
royzhao committed Dec 27, 2024
1 parent 3b7dd83 commit e7f0a1a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion kag/solver/execute/default_lf_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _execute_chunk_answer(self, req_id: str, query: str, lf: LFPlan, process_inf
# force chunk retriever, so we clear kg solved answer
process_info['kg_solved_answer'] = []
# chunk retriever
all_related_entities = kg_graph.get_all_entity()
all_related_entities = kg_graph.get_all_spo()
all_related_entities = list(set(all_related_entities))
sub_query = self._generate_sub_query_with_history_qa(history, lf.query)
doc_retrieved = self.chunk_retriever.recall_docs(queries=[query, sub_query],
Expand Down
7 changes: 6 additions & 1 deletion kag/solver/logic/core_modules/common/one_hop_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,12 @@ def get_all_entity(self):
for d in self.entity_map[k]:
all_entity.append(d)
return list(set(all_entity))

def get_all_spo(self):
all_spo = []
for k in self.edge_map.keys():
for d in self.edge_map[k]:
all_spo.append(d)
return all_spo
def _graph_to_json(self):
total_entity_map = {}
edge_dict = {}
Expand Down
10 changes: 4 additions & 6 deletions kag/solver/retriever/impl/default_chunk_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,10 @@ def _add_extra_entity_from_spo(self, matched_entities: Dict, retrieved_spo: List
all_related_entities = []
if retrieved_spo:
for spo in retrieved_spo:
if spo.type not in ['Text', 'attribute']:
all_related_entities.append(spo)
# if spo.from_entity.type not in ['Text', 'attribute']:
# all_related_entities.append(spo.from_entity)
# if spo.end_entity.type not in ['Text', 'attribute']:
# all_related_entities.append(spo.end_entity)
if spo.from_entity.type not in ['Text', 'attribute']:
all_related_entities.append(spo.from_entity)
if spo.end_entity.type not in ['Text', 'attribute']:
all_related_entities.append(spo.end_entity)
all_related_entities = list(set(all_related_entities))

if len(all_related_entities) == 0:
Expand Down

0 comments on commit e7f0a1a

Please sign in to comment.