add configurable funcname

This commit is contained in:
yichuan520030910320
2025-07-01 05:02:01 +00:00
parent b81b5d0f86
commit 371e3de04e
4 changed files with 973 additions and 1131 deletions

2067
demo.ipynb
View File

File diff suppressed because it is too large Load Diff

View File

@@ -69,7 +69,7 @@ async def main():
query = "Based on the paper, what are the main techniques LEANN explores to reduce the storage overhead?"
print(f"You: {query}")
chat_response = chat.ask(query, recompute_beighbor_embeddings=True)
chat_response = chat.ask(query, top_k=10, recompute_beighbor_embeddings=True)
print(f"Leann: {chat_response}")
if __name__ == "__main__":

View File

@@ -246,7 +246,7 @@ class DiskannSearcher(LeannBackendSearcherInterface):
raise
def search(self, query: np.ndarray, top_k: int, **kwargs) -> Dict[str, any]:
complexity = kwargs.get("complexity", 100)
complexity = kwargs.get("complexity", 32)
beam_width = kwargs.get("beam_width", 4)
USE_DEFERRED_FETCH = kwargs.get("USE_DEFERRED_FETCH", False)

View File

@@ -139,9 +139,36 @@ class LeannChat:
self.openai_client = openai.OpenAI(api_key=api_key)
return self.openai_client
def ask(self, question: str, **kwargs):
# 1. 检索
results = self.searcher.search(question, top_k=5, **kwargs)
def ask(self, question: str, top_k=5, **kwargs):
"""
Additional keyword arguments (kwargs) for advanced search customization. Example usage:
chat.ask(
"What is ANN?",
top_k=10,
complexity=64,
beam_width=8,
USE_DEFERRED_FETCH=True,
skip_search_reorder=True,
recompute_beighbor_embeddings=True,
dedup_node_dis=True,
prune_ratio=0.1,
batch_recompute=True,
global_pruning=True
)
Supported kwargs:
- complexity (int): Search complexity parameter (default: 32)
- beam_width (int): Beam width for search (default: 4)
- USE_DEFERRED_FETCH (bool): Enable deferred fetch mode (default: False)
- skip_search_reorder (bool): Skip search reorder step (default: False)
- recompute_beighbor_embeddings (bool): Enable ZMQ embedding server for neighbor recomputation (default: False)
- dedup_node_dis (bool): Deduplicate nodes by distance (default: False)
- prune_ratio (float): Pruning ratio for search (default: 0.0)
- batch_recompute (bool): Enable batch recomputation (default: False)
- global_pruning (bool): Enable global pruning (default: False)
"""
results = self.searcher.search(question, top_k=top_k, **kwargs)
context = "\n\n".join([r['text'] for r in results])
# 2. 构建 Prompt