镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 14:36:48 +00:00
stage academic conversation
这个提交包含在:
30
config.py
30
config.py
@@ -43,7 +43,7 @@ AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-
|
|||||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
|
||||||
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
|
||||||
"gemini-1.5-pro", "chatglm3", "chatglm4",
|
"gemini-1.5-pro", "chatglm3", "chatglm4",
|
||||||
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
"deepseek-chat", "deepseek-coder", "deepseek-reasoner",
|
||||||
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
|
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
|
||||||
"dashscope-deepseek-r1", "dashscope-deepseek-v3",
|
"dashscope-deepseek-r1", "dashscope-deepseek-v3",
|
||||||
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
|
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
|
||||||
@@ -94,19 +94,19 @@ AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gsta
|
|||||||
|
|
||||||
FONT = "Theme-Default-Font"
|
FONT = "Theme-Default-Font"
|
||||||
AVAIL_FONTS = [
|
AVAIL_FONTS = [
|
||||||
"默认值(Theme-Default-Font)",
|
"默认值(Theme-Default-Font)",
|
||||||
"宋体(SimSun)",
|
"宋体(SimSun)",
|
||||||
"黑体(SimHei)",
|
"黑体(SimHei)",
|
||||||
"楷体(KaiTi)",
|
"楷体(KaiTi)",
|
||||||
"仿宋(FangSong)",
|
"仿宋(FangSong)",
|
||||||
"华文细黑(STHeiti Light)",
|
"华文细黑(STHeiti Light)",
|
||||||
"华文楷体(STKaiti)",
|
"华文楷体(STKaiti)",
|
||||||
"华文仿宋(STFangsong)",
|
"华文仿宋(STFangsong)",
|
||||||
"华文宋体(STSong)",
|
"华文宋体(STSong)",
|
||||||
"华文中宋(STZhongsong)",
|
"华文中宋(STZhongsong)",
|
||||||
"华文新魏(STXinwei)",
|
"华文新魏(STXinwei)",
|
||||||
"华文隶书(STLiti)",
|
"华文隶书(STLiti)",
|
||||||
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
|
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
|
||||||
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
|
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
|
||||||
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
|
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
|
||||||
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
|
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
|
||||||
@@ -355,6 +355,10 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
|
|||||||
JINA_API_KEY = ""
|
JINA_API_KEY = ""
|
||||||
|
|
||||||
|
|
||||||
|
# SEMANTIC SCHOLAR API KEY
|
||||||
|
SEMANTIC_SCHOLAR_KEY = ""
|
||||||
|
|
||||||
|
|
||||||
# 是否自动裁剪上下文长度(是否启动,默认不启动)
|
# 是否自动裁剪上下文长度(是否启动,默认不启动)
|
||||||
AUTO_CONTEXT_CLIP_ENABLE = False
|
AUTO_CONTEXT_CLIP_ENABLE = False
|
||||||
# 目标裁剪上下文的token长度(如果超过这个长度,则会自动裁剪)
|
# 目标裁剪上下文的token长度(如果超过这个长度,则会自动裁剪)
|
||||||
|
|||||||
@@ -0,0 +1,290 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Dict, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from textwrap import dedent
|
||||||
|
from toolbox import CatchException, get_conf, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
|
||||||
|
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
|
||||||
|
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
|
||||||
|
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||||
|
from crazy_functions.review_fns.query_analyzer import QueryAnalyzer
|
||||||
|
from crazy_functions.review_fns.handlers.review_handler import 文献综述功能
|
||||||
|
from crazy_functions.review_fns.handlers.recommend_handler import 论文推荐功能
|
||||||
|
from crazy_functions.review_fns.handlers.qa_handler import 学术问答功能
|
||||||
|
from crazy_functions.review_fns.handlers.paper_handler import 单篇论文分析功能
|
||||||
|
from crazy_functions.Conversation_To_File import write_chat_to_file
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
|
||||||
|
from crazy_functions.review_fns.handlers.latest_handler import Arxiv最新论文推荐功能
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
@CatchException
|
||||||
|
def 学术对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
|
||||||
|
history: List, system_prompt: str, user_request: str):
|
||||||
|
"""主函数"""
|
||||||
|
|
||||||
|
# 初始化数据源
|
||||||
|
arxiv_source = ArxivSource()
|
||||||
|
semantic_source = SemanticScholarSource(
|
||||||
|
api_key=get_conf("SEMANTIC_SCHOLAR_KEY")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化处理器
|
||||||
|
handlers = {
|
||||||
|
"review": 文献综述功能(arxiv_source, semantic_source, llm_kwargs),
|
||||||
|
"recommend": 论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
|
||||||
|
"qa": 学术问答功能(arxiv_source, semantic_source, llm_kwargs),
|
||||||
|
"paper": 单篇论文分析功能(arxiv_source, semantic_source, llm_kwargs),
|
||||||
|
"latest": Arxiv最新论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 分析查询意图
|
||||||
|
chatbot.append([None, "正在分析研究主题和查询要求..."])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
query_analyzer = QueryAnalyzer()
|
||||||
|
search_criteria = yield from query_analyzer.analyze_query(txt, chatbot, llm_kwargs)
|
||||||
|
handler = handlers.get(search_criteria.query_type)
|
||||||
|
if not handler:
|
||||||
|
handler = handlers["qa"] # 默认使用QA处理器
|
||||||
|
|
||||||
|
# 处理查询
|
||||||
|
chatbot.append([None, f"使用{handler.__class__.__name__}处理...,可能需要您耐心等待3~5分钟..."])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
final_prompt = asyncio.run(handler.handle(
|
||||||
|
criteria=search_criteria,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=history,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
plugin_kwargs=plugin_kwargs
|
||||||
|
))
|
||||||
|
|
||||||
|
if final_prompt:
|
||||||
|
# 检查是否是道歉提示
|
||||||
|
if "很抱歉,我们未能找到" in final_prompt:
|
||||||
|
chatbot.append([txt, final_prompt])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
return
|
||||||
|
# 在 final_prompt 末尾添加用户原始查询要求
|
||||||
|
final_prompt += dedent(f"""
|
||||||
|
Original user query: "{txt}"
|
||||||
|
|
||||||
|
IMPORTANT NOTE :
|
||||||
|
- Your response must directly address the user's original user query above
|
||||||
|
- While following the previous guidelines, prioritize answering what the user specifically asked
|
||||||
|
- Make sure your response format and content align with the user's expectations
|
||||||
|
- Do not translate paper titles, keep them in their original language
|
||||||
|
- Do not generate a reference list in your response - references will be handled separately
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 使用最终的prompt生成回答
|
||||||
|
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
|
||||||
|
inputs=final_prompt,
|
||||||
|
inputs_show_user=txt,
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history=[],
|
||||||
|
sys_prompt=f"You are a helpful academic assistant. Response in Chinese by default unless specified language is required in the user's query."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 获取文献列表
|
||||||
|
papers_list = handler.ranked_papers # 直接使用原始论文数据
|
||||||
|
|
||||||
|
# 在新的对话中添加格式化的参考文献列表
|
||||||
|
if papers_list:
|
||||||
|
references = ""
|
||||||
|
for idx, paper in enumerate(papers_list, 1):
|
||||||
|
# 构建作者列表
|
||||||
|
authors = paper.authors[:3]
|
||||||
|
if len(paper.authors) > 3:
|
||||||
|
authors.append("et al.")
|
||||||
|
authors_str = ", ".join(authors)
|
||||||
|
|
||||||
|
# 构建期刊指标信息
|
||||||
|
metrics = []
|
||||||
|
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||||
|
metrics.append(f"IF: {paper.if_factor}")
|
||||||
|
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||||
|
metrics.append(f"JCR: {paper.jcr_division}")
|
||||||
|
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||||
|
metrics.append(f"中科院分区: {paper.cas_division}")
|
||||||
|
metrics_str = f" [{', '.join(metrics)}]" if metrics else ""
|
||||||
|
|
||||||
|
# 构建DOI链接
|
||||||
|
doi_link = ""
|
||||||
|
if paper.doi:
|
||||||
|
if "arxiv.org" in str(paper.doi):
|
||||||
|
doi_url = paper.doi
|
||||||
|
else:
|
||||||
|
doi_url = f"https://doi.org/{paper.doi}"
|
||||||
|
doi_link = f" <a href='{doi_url}' target='_blank'>DOI: {paper.doi}</a>"
|
||||||
|
|
||||||
|
# 构建完整的引用
|
||||||
|
reference = f"[{idx}] {authors_str}. *{paper.title}*"
|
||||||
|
if paper.venue_name:
|
||||||
|
reference += f". {paper.venue_name}"
|
||||||
|
if paper.year:
|
||||||
|
reference += f", {paper.year}"
|
||||||
|
reference += metrics_str
|
||||||
|
if doi_link:
|
||||||
|
reference += f".{doi_link}"
|
||||||
|
reference += " \n"
|
||||||
|
|
||||||
|
references += reference
|
||||||
|
|
||||||
|
# 添加新的对话显示参考文献
|
||||||
|
chatbot.append(["参考文献如下:", references])
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
|
||||||
|
|
||||||
|
# 2. 保存为不同格式
|
||||||
|
from .review_fns.conversation_doc.word_doc import WordFormatter
|
||||||
|
from .review_fns.conversation_doc.word2pdf import WordToPdfConverter
|
||||||
|
from .review_fns.conversation_doc.markdown_doc import MarkdownFormatter
|
||||||
|
from .review_fns.conversation_doc.html_doc import HtmlFormatter
|
||||||
|
|
||||||
|
# 创建保存目录
|
||||||
|
save_dir = get_log_folder(get_user(chatbot), plugin_name='chatscholar')
|
||||||
|
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
os.makedirs(save_dir)
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
def get_safe_filename(txt, max_length=10):
|
||||||
|
# 获取文本前max_length个字符作为文件名
|
||||||
|
filename = txt[:max_length].strip()
|
||||||
|
# 移除不安全的文件名字符
|
||||||
|
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
|
||||||
|
# 如果文件名为空,使用时间戳
|
||||||
|
if not filename:
|
||||||
|
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||||
|
return filename
|
||||||
|
|
||||||
|
base_filename = get_safe_filename(txt)
|
||||||
|
|
||||||
|
result_files = [] # 收集所有生成的文件
|
||||||
|
pdf_path = None # 用于跟踪PDF是否成功生成
|
||||||
|
|
||||||
|
# 保存为Markdown
|
||||||
|
try:
|
||||||
|
md_formatter = MarkdownFormatter()
|
||||||
|
md_content = md_formatter.create_document(txt, response, papers_list)
|
||||||
|
result_file_md = write_history_to_file(
|
||||||
|
history=[md_content],
|
||||||
|
file_basename=f"markdown_{base_filename}.md"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_md)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Markdown保存失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存为HTML
|
||||||
|
try:
|
||||||
|
html_formatter = HtmlFormatter()
|
||||||
|
html_content = html_formatter.create_document(txt, response, papers_list)
|
||||||
|
result_file_html = write_history_to_file(
|
||||||
|
history=[html_content],
|
||||||
|
file_basename=f"html_{base_filename}.html"
|
||||||
|
)
|
||||||
|
result_files.append(result_file_html)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"HTML保存失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存为Word
|
||||||
|
try:
|
||||||
|
word_formatter = WordFormatter()
|
||||||
|
try:
|
||||||
|
doc = word_formatter.create_document(txt, response, papers_list)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Word文档内容生成失败: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
try:
|
||||||
|
result_file_docx = os.path.join(
|
||||||
|
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||||
|
f"docx_{base_filename}.docx"
|
||||||
|
)
|
||||||
|
doc.save(result_file_docx)
|
||||||
|
result_files.append(result_file_docx)
|
||||||
|
print(f"Word文档已保存到: {result_file_docx}")
|
||||||
|
|
||||||
|
# 转换为PDF
|
||||||
|
try:
|
||||||
|
pdf_path = WordToPdfConverter.convert_to_pdf(result_file_docx)
|
||||||
|
if pdf_path:
|
||||||
|
result_files.append(pdf_path)
|
||||||
|
print(f"PDF文档已生成: {pdf_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"PDF转换失败: {str(e)}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Word文档保存失败: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Word格式化失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(f"详细错误信息: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
# 保存为BibTeX格式
|
||||||
|
try:
|
||||||
|
from .review_fns.conversation_doc.reference_formatter import ReferenceFormatter
|
||||||
|
ref_formatter = ReferenceFormatter()
|
||||||
|
bibtex_content = ref_formatter.create_document(papers_list)
|
||||||
|
|
||||||
|
# 在与其他文件相同目录下创建BibTeX文件
|
||||||
|
result_file_bib = os.path.join(
|
||||||
|
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||||
|
f"references_{base_filename}.bib"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 直接写入文件
|
||||||
|
with open(result_file_bib, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(bibtex_content)
|
||||||
|
|
||||||
|
result_files.append(result_file_bib)
|
||||||
|
print(f"BibTeX文件已保存到: {result_file_bib}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"BibTeX格式保存失败: {str(e)}")
|
||||||
|
|
||||||
|
# 保存为EndNote格式
|
||||||
|
try:
|
||||||
|
from .review_fns.conversation_doc.endnote_doc import EndNoteFormatter
|
||||||
|
endnote_formatter = EndNoteFormatter()
|
||||||
|
endnote_content = endnote_formatter.create_document(papers_list)
|
||||||
|
|
||||||
|
# 在与其他文件相同目录下创建EndNote文件
|
||||||
|
result_file_enw = os.path.join(
|
||||||
|
os.path.dirname(result_file_md) if result_file_md else save_dir,
|
||||||
|
f"references_{base_filename}.enw"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 直接写入文件
|
||||||
|
with open(result_file_enw, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(endnote_content)
|
||||||
|
|
||||||
|
result_files.append(result_file_enw)
|
||||||
|
print(f"EndNote文件已保存到: {result_file_enw}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"EndNote格式保存失败: {str(e)}")
|
||||||
|
|
||||||
|
# 添加所有文件到下载区
|
||||||
|
success_files = []
|
||||||
|
for file in result_files:
|
||||||
|
try:
|
||||||
|
promote_file_to_downloadzone(file, chatbot=chatbot)
|
||||||
|
success_files.append(os.path.basename(file))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"文件添加到下载区失败: {str(e)}")
|
||||||
|
|
||||||
|
# 更新成功提示消息
|
||||||
|
if success_files:
|
||||||
|
chatbot.append(["保存对话记录成功,bib和enw文件支持导入到EndNote、Zotero、JabRef、Mendeley等文献管理软件,HTML文件支持在浏览器中打开,里面包含详细论文源信息", "对话已保存并添加到下载区,可以在下载区找到相关文件"])
|
||||||
|
else:
|
||||||
|
chatbot.append(["保存对话记录", "所有格式的保存都失败了,请检查错误日志。"])
|
||||||
|
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
|
else:
|
||||||
|
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
|
||||||
|
yield from update_ui(chatbot=chatbot, history=history)
|
||||||
@@ -0,0 +1,275 @@
|
|||||||
|
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
|
||||||
|
from loguru import logger
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
def force_breakdown(txt, limit, get_token_fn):
|
||||||
|
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
|
||||||
|
"""
|
||||||
|
for i in reversed(range(len(txt))):
|
||||||
|
if get_token_fn(txt[:i]) < limit:
|
||||||
|
return txt[:i], txt[i:]
|
||||||
|
return "Tiktoken未知错误", "Tiktoken未知错误"
|
||||||
|
|
||||||
|
|
||||||
|
def maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage):
|
||||||
|
""" 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||||
|
当 remain_txt_to_cut < `_min` 时,我们再把 remain_txt_to_cut_storage 中的部分文字取出
|
||||||
|
"""
|
||||||
|
_min = int(5e4)
|
||||||
|
_max = int(1e5)
|
||||||
|
# print(len(remain_txt_to_cut), len(remain_txt_to_cut_storage))
|
||||||
|
if len(remain_txt_to_cut) < _min and len(remain_txt_to_cut_storage) > 0:
|
||||||
|
remain_txt_to_cut = remain_txt_to_cut + remain_txt_to_cut_storage
|
||||||
|
remain_txt_to_cut_storage = ""
|
||||||
|
if len(remain_txt_to_cut) > _max:
|
||||||
|
remain_txt_to_cut_storage = remain_txt_to_cut[_max:] + remain_txt_to_cut_storage
|
||||||
|
remain_txt_to_cut = remain_txt_to_cut[:_max]
|
||||||
|
return remain_txt_to_cut, remain_txt_to_cut_storage
|
||||||
|
|
||||||
|
|
||||||
|
def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=False):
|
||||||
|
""" 文本切分
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
total_len = len(txt_tocut)
|
||||||
|
fin_len = 0
|
||||||
|
remain_txt_to_cut = txt_tocut
|
||||||
|
remain_txt_to_cut_storage = ""
|
||||||
|
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||||
|
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if get_token_fn(remain_txt_to_cut) <= limit:
|
||||||
|
# 如果剩余文本的token数小于限制,那么就不用切了
|
||||||
|
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果剩余文本的token数大于限制,那么就切
|
||||||
|
lines = remain_txt_to_cut.split('\n')
|
||||||
|
|
||||||
|
# 估计一个切分点
|
||||||
|
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
|
||||||
|
estimated_line_cut = int(estimated_line_cut)
|
||||||
|
|
||||||
|
# 开始查找合适切分点的偏移(cnt)
|
||||||
|
cnt = 0
|
||||||
|
for cnt in reversed(range(estimated_line_cut)):
|
||||||
|
if must_break_at_empty_line:
|
||||||
|
# 首先尝试用双空行(\n\n)作为切分点
|
||||||
|
if lines[cnt] != "":
|
||||||
|
continue
|
||||||
|
prev = "\n".join(lines[:cnt])
|
||||||
|
post = "\n".join(lines[cnt:])
|
||||||
|
if get_token_fn(prev) < limit:
|
||||||
|
break
|
||||||
|
|
||||||
|
if cnt == 0:
|
||||||
|
# 如果没有找到合适的切分点
|
||||||
|
if break_anyway:
|
||||||
|
# 是否允许暴力切分
|
||||||
|
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
|
||||||
|
else:
|
||||||
|
# 不允许直接报错
|
||||||
|
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
|
||||||
|
|
||||||
|
# 追加列表
|
||||||
|
res.append(prev); fin_len+=len(prev)
|
||||||
|
# 准备下一次迭代
|
||||||
|
remain_txt_to_cut = post
|
||||||
|
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||||
|
process = fin_len/total_len
|
||||||
|
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||||
|
if len(remain_txt_to_cut.strip()) == 0:
|
||||||
|
break
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def breakdown_text_to_satisfy_token_limit_(txt, limit, llm_model="gpt-3.5-turbo"):
|
||||||
|
""" 使用多种方式尝试切分文本,以满足 token 限制
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
enc = model_info[llm_model]['tokenizer']
|
||||||
|
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||||
|
try:
|
||||||
|
# 第1次尝试,将双空行(\n\n)作为切分点
|
||||||
|
return cut(limit, get_token_fn, txt, must_break_at_empty_line=True)
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
# 第2次尝试,将单空行(\n)作为切分点
|
||||||
|
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False)
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
# 第3次尝试,将英文句号(.)作为切分点
|
||||||
|
res = cut(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
|
||||||
|
return [r.replace('。\n', '.') for r in res]
|
||||||
|
except RuntimeError as e:
|
||||||
|
try:
|
||||||
|
# 第4次尝试,将中文句号(。)作为切分点
|
||||||
|
res = cut(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
|
||||||
|
return [r.replace('。。\n', '。') for r in res]
|
||||||
|
except RuntimeError as e:
|
||||||
|
# 第5次尝试,没办法了,随便切一下吧
|
||||||
|
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
|
||||||
|
|
||||||
|
breakdown_text_to_satisfy_token_limit = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_, timeout=60)
|
||||||
|
|
||||||
|
def cut_new(limit, get_token_fn, txt_tocut, must_break_at_empty_line, must_break_at_one_empty_line=False, break_anyway=False):
|
||||||
|
""" 文本切分
|
||||||
|
"""
|
||||||
|
res = []
|
||||||
|
res_empty_line = []
|
||||||
|
total_len = len(txt_tocut)
|
||||||
|
fin_len = 0
|
||||||
|
remain_txt_to_cut = txt_tocut
|
||||||
|
remain_txt_to_cut_storage = ""
|
||||||
|
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
|
||||||
|
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||||
|
empty=0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if get_token_fn(remain_txt_to_cut) <= limit:
|
||||||
|
# 如果剩余文本的token数小于限制,那么就不用切了
|
||||||
|
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
|
||||||
|
res_empty_line.append(empty)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果剩余文本的token数大于限制,那么就切
|
||||||
|
lines = remain_txt_to_cut.split('\n')
|
||||||
|
|
||||||
|
# 估计一个切分点
|
||||||
|
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
|
||||||
|
estimated_line_cut = int(estimated_line_cut)
|
||||||
|
|
||||||
|
# 开始查找合适切分点的偏移(cnt)
|
||||||
|
cnt = 0
|
||||||
|
for cnt in reversed(range(estimated_line_cut)):
|
||||||
|
if must_break_at_empty_line:
|
||||||
|
# 首先尝试用双空行(\n\n)作为切分点
|
||||||
|
if lines[cnt] != "":
|
||||||
|
continue
|
||||||
|
if must_break_at_empty_line or must_break_at_one_empty_line:
|
||||||
|
empty=1
|
||||||
|
prev = "\n".join(lines[:cnt])
|
||||||
|
post = "\n".join(lines[cnt:])
|
||||||
|
if get_token_fn(prev) < limit :
|
||||||
|
break
|
||||||
|
# empty=0
|
||||||
|
if get_token_fn(prev)>limit:
|
||||||
|
if '.' not in prev or '。' not in prev:
|
||||||
|
# empty = 0
|
||||||
|
break
|
||||||
|
|
||||||
|
# if cnt
|
||||||
|
if cnt == 0:
|
||||||
|
# 如果没有找到合适的切分点
|
||||||
|
if break_anyway:
|
||||||
|
# 是否允许暴力切分
|
||||||
|
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
|
||||||
|
empty =0
|
||||||
|
else:
|
||||||
|
# 不允许直接报错
|
||||||
|
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
|
||||||
|
|
||||||
|
# 追加列表
|
||||||
|
res.append(prev); fin_len+=len(prev)
|
||||||
|
res_empty_line.append(empty)
|
||||||
|
# 准备下一次迭代
|
||||||
|
remain_txt_to_cut = post
|
||||||
|
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
|
||||||
|
process = fin_len/total_len
|
||||||
|
logger.info(f'正在文本切分 {int(process*100)}%')
|
||||||
|
if len(remain_txt_to_cut.strip()) == 0:
|
||||||
|
break
|
||||||
|
return res,res_empty_line
|
||||||
|
|
||||||
|
|
||||||
|
def breakdown_text_to_satisfy_token_limit_new_(txt, limit, llm_model="gpt-3.5-turbo"):
|
||||||
|
""" 使用多种方式尝试切分文本,以满足 token 限制
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
enc = model_info[llm_model]['tokenizer']
|
||||||
|
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||||
|
try:
|
||||||
|
# 第1次尝试,将双空行(\n\n)作为切分点
|
||||||
|
res, empty_line =cut_new(limit, get_token_fn, txt, must_break_at_empty_line=True)
|
||||||
|
return res,empty_line
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
# 第2次尝试,将单空行(\n)作为切分点
|
||||||
|
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False,must_break_at_one_empty_line=True)
|
||||||
|
return res, _
|
||||||
|
except RuntimeError:
|
||||||
|
try:
|
||||||
|
# 第3次尝试,将英文句号(.)作为切分点
|
||||||
|
res, _ = cut_new(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
|
||||||
|
return [r.replace('。\n', '.') for r in res],_
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
try:
|
||||||
|
# 第4次尝试,将中文句号(。)作为切分点
|
||||||
|
res,_ = cut_new(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
|
||||||
|
return [r.replace('。。\n', '。') for r in res], _
|
||||||
|
except RuntimeError as e:
|
||||||
|
# 第5次尝试,没办法了,随便切一下吧
|
||||||
|
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
|
||||||
|
return res,_
|
||||||
|
breakdown_text_to_satisfy_token_limit_new = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_new_, timeout=60)
|
||||||
|
|
||||||
|
def cut_from_end_to_satisfy_token_limit_(txt, limit, reserve_token=500, llm_model="gpt-3.5-turbo"):
|
||||||
|
"""从后往前裁剪文本,以论文为单位进行裁剪
|
||||||
|
|
||||||
|
参数:
|
||||||
|
txt: 要处理的文本(格式化后的论文列表字符串)
|
||||||
|
limit: token数量上限
|
||||||
|
reserve_token: 需要预留的token数量,默认500
|
||||||
|
llm_model: 使用的模型名称
|
||||||
|
返回:
|
||||||
|
裁剪后的文本
|
||||||
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
enc = model_info[llm_model]['tokenizer']
|
||||||
|
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
|
||||||
|
|
||||||
|
# 计算当前文本的token数
|
||||||
|
current_tokens = get_token_fn(txt)
|
||||||
|
target_limit = limit - reserve_token
|
||||||
|
|
||||||
|
# 如果当前token数已经在限制范围内,直接返回
|
||||||
|
if current_tokens <= target_limit:
|
||||||
|
return txt
|
||||||
|
|
||||||
|
# 按论文编号分割文本
|
||||||
|
papers = re.split(r'\n(?=\d+\. \*\*)', txt)
|
||||||
|
if not papers:
|
||||||
|
return txt
|
||||||
|
|
||||||
|
# 从前往后累加论文,直到达到token限制
|
||||||
|
result = papers[0] # 保留第一篇
|
||||||
|
current_tokens = get_token_fn(result)
|
||||||
|
|
||||||
|
for paper in papers[1:]:
|
||||||
|
paper_tokens = get_token_fn(paper)
|
||||||
|
if current_tokens + paper_tokens <= target_limit:
|
||||||
|
result += "\n" + paper
|
||||||
|
current_tokens += paper_tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 添加超时保护
|
||||||
|
cut_from_end_to_satisfy_token_limit = run_in_subprocess_with_timeout(cut_from_end_to_satisfy_token_limit_, timeout=20)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from crazy_functions.crazy_utils import read_and_clean_pdf_text
|
||||||
|
file_content, page_one = read_and_clean_pdf_text("build/assets/at.pdf")
|
||||||
|
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
for i in range(5):
|
||||||
|
file_content += file_content
|
||||||
|
|
||||||
|
logger.info(len(file_content))
|
||||||
|
TOKEN_LIMIT_PER_FRAGMENT = 2500
|
||||||
|
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
|
||||||
|
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
from typing import List
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||||
|
|
||||||
|
class EndNoteFormatter:
|
||||||
|
"""EndNote参考文献格式生成器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||||
|
"""生成EndNote格式的参考文献文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
papers: 论文列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: EndNote格式的参考文献文本
|
||||||
|
"""
|
||||||
|
endnote_text = ""
|
||||||
|
|
||||||
|
for paper in papers:
|
||||||
|
# 开始一个新条目
|
||||||
|
endnote_text += "%0 Journal Article\n" # 默认类型为期刊文章
|
||||||
|
|
||||||
|
# 根据venue_type调整条目类型
|
||||||
|
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||||
|
if paper.venue_type.lower() == 'conference':
|
||||||
|
endnote_text = endnote_text.replace("Journal Article", "Conference Paper")
|
||||||
|
elif paper.venue_type.lower() == 'preprint':
|
||||||
|
endnote_text = endnote_text.replace("Journal Article", "Electronic Article")
|
||||||
|
|
||||||
|
# 添加标题
|
||||||
|
endnote_text += f"%T {paper.title}\n"
|
||||||
|
|
||||||
|
# 添加作者
|
||||||
|
for author in paper.authors:
|
||||||
|
endnote_text += f"%A {author}\n"
|
||||||
|
|
||||||
|
# 添加年份
|
||||||
|
if paper.year:
|
||||||
|
endnote_text += f"%D {paper.year}\n"
|
||||||
|
|
||||||
|
# 添加期刊/会议名称
|
||||||
|
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||||
|
endnote_text += f"%J {paper.venue_name}\n"
|
||||||
|
elif paper.venue:
|
||||||
|
endnote_text += f"%J {paper.venue}\n"
|
||||||
|
|
||||||
|
# 添加DOI
|
||||||
|
if paper.doi:
|
||||||
|
endnote_text += f"%R {paper.doi}\n"
|
||||||
|
endnote_text += f"%U https://doi.org/{paper.doi}\n"
|
||||||
|
elif paper.url:
|
||||||
|
endnote_text += f"%U {paper.url}\n"
|
||||||
|
|
||||||
|
# 添加摘要
|
||||||
|
if paper.abstract:
|
||||||
|
endnote_text += f"%X {paper.abstract}\n"
|
||||||
|
|
||||||
|
# 添加机构
|
||||||
|
if hasattr(paper, 'institutions'):
|
||||||
|
for institution in paper.institutions:
|
||||||
|
endnote_text += f"%I {institution}\n"
|
||||||
|
|
||||||
|
# 条目之间添加空行
|
||||||
|
endnote_text += "\n"
|
||||||
|
|
||||||
|
return endnote_text
|
||||||
@@ -0,0 +1,211 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ExcelTableFormatter:
|
||||||
|
"""聊天记录中Markdown表格转Excel生成器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化Excel文档对象"""
|
||||||
|
from openpyxl import Workbook
|
||||||
|
self.workbook = Workbook()
|
||||||
|
self._table_count = 0
|
||||||
|
self._current_sheet = None
|
||||||
|
|
||||||
|
def _normalize_table_row(self, row):
|
||||||
|
"""标准化表格行,处理不同的分隔符情况"""
|
||||||
|
row = row.strip()
|
||||||
|
if row.startswith('|'):
|
||||||
|
row = row[1:]
|
||||||
|
if row.endswith('|'):
|
||||||
|
row = row[:-1]
|
||||||
|
return [cell.strip() for cell in row.split('|')]
|
||||||
|
|
||||||
|
def _is_separator_row(self, row):
|
||||||
|
"""检查是否是分隔行(由 - 或 : 组成)"""
|
||||||
|
clean_row = re.sub(r'[\s|]', '', row)
|
||||||
|
return bool(re.match(r'^[-:]+$', clean_row))
|
||||||
|
|
||||||
|
def _extract_tables_from_text(self, text):
|
||||||
|
"""从文本中提取所有表格内容"""
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return []
|
||||||
|
|
||||||
|
tables = []
|
||||||
|
current_table = []
|
||||||
|
is_in_table = False
|
||||||
|
|
||||||
|
for line in text.split('\n'):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
if is_in_table and current_table:
|
||||||
|
if len(current_table) >= 2:
|
||||||
|
tables.append(current_table)
|
||||||
|
current_table = []
|
||||||
|
is_in_table = False
|
||||||
|
continue
|
||||||
|
|
||||||
|
if '|' in line:
|
||||||
|
if not is_in_table:
|
||||||
|
is_in_table = True
|
||||||
|
current_table.append(line)
|
||||||
|
else:
|
||||||
|
if is_in_table and current_table:
|
||||||
|
if len(current_table) >= 2:
|
||||||
|
tables.append(current_table)
|
||||||
|
current_table = []
|
||||||
|
is_in_table = False
|
||||||
|
|
||||||
|
if is_in_table and current_table and len(current_table) >= 2:
|
||||||
|
tables.append(current_table)
|
||||||
|
|
||||||
|
return tables
|
||||||
|
|
||||||
|
def _parse_table(self, table_lines):
|
||||||
|
"""解析表格内容为结构化数据"""
|
||||||
|
try:
|
||||||
|
headers = self._normalize_table_row(table_lines[0])
|
||||||
|
|
||||||
|
separator_index = next(
|
||||||
|
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
|
||||||
|
1
|
||||||
|
)
|
||||||
|
|
||||||
|
data_rows = []
|
||||||
|
for line in table_lines[separator_index + 1:]:
|
||||||
|
cells = self._normalize_table_row(line)
|
||||||
|
# 确保单元格数量与表头一致
|
||||||
|
while len(cells) < len(headers):
|
||||||
|
cells.append('')
|
||||||
|
cells = cells[:len(headers)]
|
||||||
|
data_rows.append(cells)
|
||||||
|
|
||||||
|
if headers and data_rows:
|
||||||
|
return {
|
||||||
|
'headers': headers,
|
||||||
|
'data': data_rows
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析表格时发生错误: {str(e)}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_sheet(self, question_num, table_num):
|
||||||
|
"""创建新的工作表"""
|
||||||
|
sheet_name = f'Q{question_num}_T{table_num}'
|
||||||
|
if len(sheet_name) > 31:
|
||||||
|
sheet_name = f'Table{self._table_count}'
|
||||||
|
|
||||||
|
if sheet_name in self.workbook.sheetnames:
|
||||||
|
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
|
||||||
|
|
||||||
|
return self.workbook.create_sheet(title=sheet_name)
|
||||||
|
|
||||||
|
def create_document(self, history):
|
||||||
|
"""
|
||||||
|
处理聊天历史中的所有表格并创建Excel文档
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history: 聊天历史列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
|
||||||
|
"""
|
||||||
|
has_tables = False
|
||||||
|
|
||||||
|
# 删除默认创建的工作表
|
||||||
|
default_sheet = self.workbook['Sheet']
|
||||||
|
self.workbook.remove(default_sheet)
|
||||||
|
|
||||||
|
# 遍历所有回答
|
||||||
|
for i in range(1, len(history), 2):
|
||||||
|
answer = history[i]
|
||||||
|
tables = self._extract_tables_from_text(answer)
|
||||||
|
|
||||||
|
for table_lines in tables:
|
||||||
|
parsed_table = self._parse_table(table_lines)
|
||||||
|
if parsed_table:
|
||||||
|
self._table_count += 1
|
||||||
|
sheet = self._create_sheet(i // 2 + 1, self._table_count)
|
||||||
|
|
||||||
|
# 写入表头
|
||||||
|
for col, header in enumerate(parsed_table['headers'], 1):
|
||||||
|
sheet.cell(row=1, column=col, value=header)
|
||||||
|
|
||||||
|
# 写入数据
|
||||||
|
for row_idx, row_data in enumerate(parsed_table['data'], 2):
|
||||||
|
for col_idx, value in enumerate(row_data, 1):
|
||||||
|
sheet.cell(row=row_idx, column=col_idx, value=value)
|
||||||
|
|
||||||
|
has_tables = True
|
||||||
|
|
||||||
|
return self.workbook if has_tables else None
|
||||||
|
|
||||||
|
|
||||||
|
def save_chat_tables(history, save_dir, base_name):
|
||||||
|
"""
|
||||||
|
保存聊天历史中的表格到Excel文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history: 聊天历史列表
|
||||||
|
save_dir: 保存目录
|
||||||
|
base_name: 基础文件名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 保存的文件路径列表
|
||||||
|
"""
|
||||||
|
result_files = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建Excel格式
|
||||||
|
excel_formatter = ExcelTableFormatter()
|
||||||
|
workbook = excel_formatter.create_document(history)
|
||||||
|
|
||||||
|
if workbook is not None:
|
||||||
|
# 确保保存目录存在
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成Excel文件路径
|
||||||
|
excel_file = os.path.join(save_dir, base_name + '.xlsx')
|
||||||
|
|
||||||
|
# 保存Excel文件
|
||||||
|
workbook.save(excel_file)
|
||||||
|
result_files.append(excel_file)
|
||||||
|
print(f"已保存表格到Excel文件: {excel_file}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存Excel格式失败: {str(e)}")
|
||||||
|
|
||||||
|
return result_files
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 示例聊天历史
|
||||||
|
history = [
|
||||||
|
"问题1",
|
||||||
|
"""这是第一个表格:
|
||||||
|
| A | B | C |
|
||||||
|
|---|---|---|
|
||||||
|
| 1 | 2 | 3 |""",
|
||||||
|
|
||||||
|
"问题2",
|
||||||
|
"这是没有表格的回答",
|
||||||
|
|
||||||
|
"问题3",
|
||||||
|
"""回答包含多个表格:
|
||||||
|
| Name | Age |
|
||||||
|
|------|-----|
|
||||||
|
| Tom | 20 |
|
||||||
|
|
||||||
|
第二个表格:
|
||||||
|
| X | Y |
|
||||||
|
|---|---|
|
||||||
|
| 1 | 2 |"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# 保存表格
|
||||||
|
save_dir = "output"
|
||||||
|
base_name = "chat_tables"
|
||||||
|
saved_files = save_chat_tables(history, save_dir, base_name)
|
||||||
@@ -0,0 +1,472 @@
|
|||||||
|
class HtmlFormatter:
|
||||||
|
"""聊天记录HTML格式生成器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.css_styles = """
|
||||||
|
:root {
|
||||||
|
--primary-color: #2563eb;
|
||||||
|
--primary-light: #eff6ff;
|
||||||
|
--secondary-color: #1e293b;
|
||||||
|
--background-color: #f8fafc;
|
||||||
|
--text-color: #334155;
|
||||||
|
--border-color: #e2e8f0;
|
||||||
|
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
line-height: 1.8;
|
||||||
|
margin: 0;
|
||||||
|
padding: 2rem;
|
||||||
|
color: var(--text-color);
|
||||||
|
background-color: var(--background-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
max-width: 1200px;
|
||||||
|
margin: 0 auto;
|
||||||
|
background: white;
|
||||||
|
padding: 2rem;
|
||||||
|
border-radius: 16px;
|
||||||
|
box-shadow: var(--card-shadow);
|
||||||
|
}
|
||||||
|
::selection {
|
||||||
|
background: var(--primary-light);
|
||||||
|
color: var(--primary-color);
|
||||||
|
}
|
||||||
|
@keyframes fadeIn {
|
||||||
|
from { opacity: 0; transform: translateY(20px); }
|
||||||
|
to { opacity: 1; transform: translateY(0); }
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes slideIn {
|
||||||
|
from { transform: translateX(-20px); opacity: 0; }
|
||||||
|
to { transform: translateX(0); opacity: 1; }
|
||||||
|
}
|
||||||
|
|
||||||
|
.container {
|
||||||
|
animation: fadeIn 0.6s ease-out;
|
||||||
|
}
|
||||||
|
|
||||||
|
.QaBox {
|
||||||
|
animation: slideIn 0.5s ease-out;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.QaBox:hover {
|
||||||
|
transform: translateX(5px);
|
||||||
|
}
|
||||||
|
.Question, .Answer, .historyBox {
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
.chat-title {
|
||||||
|
color: var(--primary-color);
|
||||||
|
font-size: 2em;
|
||||||
|
text-align: center;
|
||||||
|
margin: 1rem 0 2rem;
|
||||||
|
padding-bottom: 1rem;
|
||||||
|
border-bottom: 2px solid var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-body {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1.5rem;
|
||||||
|
margin: 2rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.QaBox {
|
||||||
|
background: white;
|
||||||
|
padding: 1.5rem;
|
||||||
|
border-radius: 8px;
|
||||||
|
border-left: 4px solid var(--primary-color);
|
||||||
|
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.Question {
|
||||||
|
color: var(--secondary-color);
|
||||||
|
font-weight: 500;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.Answer {
|
||||||
|
color: var(--text-color);
|
||||||
|
background: var(--primary-light);
|
||||||
|
padding: 1rem;
|
||||||
|
border-radius: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.history-section {
|
||||||
|
margin-top: 3rem;
|
||||||
|
padding-top: 2rem;
|
||||||
|
border-top: 2px solid var(--border-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.history-title {
|
||||||
|
color: var(--secondary-color);
|
||||||
|
font-size: 1.5em;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.historyBox {
|
||||||
|
background: white;
|
||||||
|
padding: 1rem;
|
||||||
|
margin: 0.5rem 0;
|
||||||
|
border-radius: 6px;
|
||||||
|
border: 1px solid var(--border-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
:root {
|
||||||
|
--background-color: #0f172a;
|
||||||
|
--text-color: #e2e8f0;
|
||||||
|
--border-color: #1e293b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.container, .QaBox {
|
||||||
|
background: #1e293b;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||||
|
"""生成完整的HTML文档
|
||||||
|
Args:
|
||||||
|
question: str, 用户问题
|
||||||
|
answer: str, AI回答
|
||||||
|
ranked_papers: list, 排序后的论文列表
|
||||||
|
Returns:
|
||||||
|
str: 完整的HTML文档字符串
|
||||||
|
"""
|
||||||
|
chat_content = f'''
|
||||||
|
<div class="QaBox">
|
||||||
|
<div class="Question">{question}</div>
|
||||||
|
<div class="Answer markdown-body" id="answer-content">{answer}</div>
|
||||||
|
</div>
|
||||||
|
'''
|
||||||
|
|
||||||
|
references_content = ""
|
||||||
|
if ranked_papers:
|
||||||
|
references_content = '<div class="history-section"><h2 class="history-title">参考文献</h2>'
|
||||||
|
for idx, paper in enumerate(ranked_papers, 1):
|
||||||
|
authors = ', '.join(paper.authors)
|
||||||
|
|
||||||
|
# 构建引用信息
|
||||||
|
citations_info = f"被引用次数:{paper.citations}" if paper.citations is not None else "引用信息未知"
|
||||||
|
|
||||||
|
# 构建下载链接
|
||||||
|
download_links = []
|
||||||
|
if paper.doi:
|
||||||
|
# 检查是否是arXiv链接
|
||||||
|
if 'arxiv.org' in paper.doi:
|
||||||
|
# 如果DOI中包含完整的arXiv URL,直接使用
|
||||||
|
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||||
|
download_links.append(f'<a href="{arxiv_url}">arXiv链接</a>')
|
||||||
|
# 提取arXiv ID并添加PDF链接
|
||||||
|
arxiv_id = arxiv_url.split('abs/')[-1].split('v')[0]
|
||||||
|
download_links.append(f'<a href="https://arxiv.org/pdf/{arxiv_id}.pdf">PDF下载</a>')
|
||||||
|
else:
|
||||||
|
# 非arXiv的DOI使用标准格式
|
||||||
|
download_links.append(f'<a href="https://doi.org/{paper.doi}">DOI: {paper.doi}</a>')
|
||||||
|
|
||||||
|
if hasattr(paper, 'url') and paper.url and 'arxiv.org' not in str(paper.url):
|
||||||
|
# 只有当URL不是arXiv链接时才添加
|
||||||
|
download_links.append(f'<a href="{paper.url}">原文链接</a>')
|
||||||
|
download_section = ' | '.join(download_links) if download_links else "无直接下载链接"
|
||||||
|
|
||||||
|
# 构建来源信息
|
||||||
|
source_info = []
|
||||||
|
if paper.venue_type:
|
||||||
|
source_info.append(f"类型:{paper.venue_type}")
|
||||||
|
if paper.venue_name:
|
||||||
|
source_info.append(f"来源:{paper.venue_name}")
|
||||||
|
|
||||||
|
# 添加期刊指标信息
|
||||||
|
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||||
|
source_info.append(f"<span class='journal-metric'>IF: {paper.if_factor}</span>")
|
||||||
|
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||||
|
source_info.append(f"<span class='journal-metric'>JCR分区: {paper.jcr_division}</span>")
|
||||||
|
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||||
|
source_info.append(f"<span class='journal-metric'>中科院分区: {paper.cas_division}</span>")
|
||||||
|
|
||||||
|
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||||
|
if paper.venue_info.get('journal_ref'):
|
||||||
|
source_info.append(f"期刊参考:{paper.venue_info['journal_ref']}")
|
||||||
|
if paper.venue_info.get('publisher'):
|
||||||
|
source_info.append(f"出版商:{paper.venue_info['publisher']}")
|
||||||
|
source_section = ' | '.join(source_info) if source_info else ""
|
||||||
|
|
||||||
|
# 构建标准引用格式
|
||||||
|
standard_citation = f"[{idx}] "
|
||||||
|
# 添加作者(最多3个,超过则添加et al.)
|
||||||
|
author_list = paper.authors[:3]
|
||||||
|
if len(paper.authors) > 3:
|
||||||
|
author_list.append("et al.")
|
||||||
|
standard_citation += ", ".join(author_list) + ". "
|
||||||
|
# 添加标题
|
||||||
|
standard_citation += f"<i>{paper.title}</i>"
|
||||||
|
# 添加期刊/会议名称
|
||||||
|
if paper.venue_name:
|
||||||
|
standard_citation += f". {paper.venue_name}"
|
||||||
|
# 添加年份
|
||||||
|
if paper.year:
|
||||||
|
standard_citation += f", {paper.year}"
|
||||||
|
# 添加DOI
|
||||||
|
if paper.doi:
|
||||||
|
if 'arxiv.org' in paper.doi:
|
||||||
|
# 如果是arXiv链接,直接使用arXiv URL
|
||||||
|
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||||
|
standard_citation += f". {arxiv_url}"
|
||||||
|
else:
|
||||||
|
# 非arXiv的DOI使用标准格式
|
||||||
|
standard_citation += f". DOI: {paper.doi}"
|
||||||
|
standard_citation += "."
|
||||||
|
|
||||||
|
references_content += f'''
|
||||||
|
<div class="historyBox">
|
||||||
|
<div class="entry">
|
||||||
|
<p class="paper-title"><b>[{idx}]</b> <i>{paper.title}</i></p>
|
||||||
|
<p class="paper-authors">作者:{authors}</p>
|
||||||
|
<p class="paper-year">发表年份:{paper.year if paper.year else "未知"}</p>
|
||||||
|
<p class="paper-citations">{citations_info}</p>
|
||||||
|
{f'<p class="paper-source">{source_section}</p>' if source_section else ""}
|
||||||
|
<p class="paper-abstract">摘要:{paper.abstract if paper.abstract else "无摘要"}</p>
|
||||||
|
<p class="paper-links">链接:{download_section}</p>
|
||||||
|
<div class="standard-citation">
|
||||||
|
<p class="citation-title">标准引用格式:</p>
|
||||||
|
<p class="citation-text">{standard_citation}</p>
|
||||||
|
<button class="copy-btn" onclick="copyToClipboard(this.previousElementSibling)">复制引用格式</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
'''
|
||||||
|
references_content += '</div>'
|
||||||
|
|
||||||
|
# 添加新的CSS样式
|
||||||
|
css_additions = """
|
||||||
|
.paper-title {
|
||||||
|
font-size: 1.1em;
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
}
|
||||||
|
.paper-authors {
|
||||||
|
color: var(--secondary-color);
|
||||||
|
margin: 0.3em 0;
|
||||||
|
}
|
||||||
|
.paper-year, .paper-citations {
|
||||||
|
color: var(--text-color);
|
||||||
|
margin: 0.3em 0;
|
||||||
|
}
|
||||||
|
.paper-source {
|
||||||
|
color: var(--text-color);
|
||||||
|
font-style: italic;
|
||||||
|
margin: 0.3em 0;
|
||||||
|
}
|
||||||
|
.paper-abstract {
|
||||||
|
margin: 0.8em 0;
|
||||||
|
padding: 0.8em;
|
||||||
|
background: var(--primary-light);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.paper-links {
|
||||||
|
margin-top: 0.5em;
|
||||||
|
}
|
||||||
|
.paper-links a {
|
||||||
|
color: var(--primary-color);
|
||||||
|
text-decoration: none;
|
||||||
|
margin-right: 1em;
|
||||||
|
}
|
||||||
|
.paper-links a:hover {
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
.standard-citation {
|
||||||
|
margin-top: 1em;
|
||||||
|
padding: 1em;
|
||||||
|
background: #f8fafc;
|
||||||
|
border-radius: 4px;
|
||||||
|
border: 1px solid var(--border-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.citation-title {
|
||||||
|
font-weight: bold;
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
color: var(--secondary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.citation-text {
|
||||||
|
font-family: 'Times New Roman', Times, serif;
|
||||||
|
line-height: 1.6;
|
||||||
|
margin-bottom: 0.5em;
|
||||||
|
padding: 0.5em;
|
||||||
|
background: white;
|
||||||
|
border-radius: 4px;
|
||||||
|
border: 1px solid var(--border-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-btn {
|
||||||
|
background: var(--primary-color);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 0.5em 1em;
|
||||||
|
border-radius: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.9em;
|
||||||
|
transition: background-color 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.copy-btn:hover {
|
||||||
|
background: #1e40af;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
.standard-citation {
|
||||||
|
background: #1e293b;
|
||||||
|
}
|
||||||
|
.citation-text {
|
||||||
|
background: #0f172a;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 添加期刊指标样式 */
|
||||||
|
.journal-metric {
|
||||||
|
display: inline-block;
|
||||||
|
padding: 0.2em 0.6em;
|
||||||
|
margin: 0 0.3em;
|
||||||
|
background: var(--primary-light);
|
||||||
|
border-radius: 4px;
|
||||||
|
font-weight: 500;
|
||||||
|
color: var(--primary-color);
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: dark) {
|
||||||
|
.journal-metric {
|
||||||
|
background: #1e293b;
|
||||||
|
color: #60a5fa;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 修改 js_code 部分,添加 markdown 解析功能
|
||||||
|
js_code = """
|
||||||
|
<script>
|
||||||
|
// 复制功能
|
||||||
|
function copyToClipboard(element) {
|
||||||
|
const text = element.innerText;
|
||||||
|
navigator.clipboard.writeText(text).then(function() {
|
||||||
|
const btn = element.nextElementSibling;
|
||||||
|
const originalText = btn.innerText;
|
||||||
|
btn.innerText = '已复制!';
|
||||||
|
setTimeout(() => {
|
||||||
|
btn.innerText = originalText;
|
||||||
|
}, 2000);
|
||||||
|
}).catch(function(err) {
|
||||||
|
console.error('复制失败:', err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Markdown解析
|
||||||
|
document.addEventListener('DOMContentLoaded', function() {
|
||||||
|
const answerContent = document.getElementById('answer-content');
|
||||||
|
if (answerContent) {
|
||||||
|
const markdown = answerContent.textContent;
|
||||||
|
answerContent.innerHTML = marked.parse(markdown);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 将新的CSS样式添加到现有样式中
|
||||||
|
self.css_styles += css_additions
|
||||||
|
|
||||||
|
return f"""
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="zh-CN">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||||
|
<title>学术对话存档</title>
|
||||||
|
<!-- 添加 marked.js -->
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||||
|
<!-- 添加 GitHub Markdown CSS -->
|
||||||
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/sindresorhus/github-markdown-css@4.0.0/github-markdown.min.css">
|
||||||
|
<style>
|
||||||
|
{self.css_styles}
|
||||||
|
/* 添加 Markdown 相关样式 */
|
||||||
|
.markdown-body {{
|
||||||
|
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||||
|
padding: 1rem;
|
||||||
|
background: var(--primary-light);
|
||||||
|
border-radius: 6px;
|
||||||
|
}}
|
||||||
|
.markdown-body pre {{
|
||||||
|
background-color: #f6f8fa;
|
||||||
|
border-radius: 6px;
|
||||||
|
padding: 16px;
|
||||||
|
overflow: auto;
|
||||||
|
}}
|
||||||
|
.markdown-body code {{
|
||||||
|
background-color: rgba(175,184,193,0.2);
|
||||||
|
border-radius: 6px;
|
||||||
|
padding: 0.2em 0.4em;
|
||||||
|
font-size: 85%;
|
||||||
|
}}
|
||||||
|
.markdown-body pre code {{
|
||||||
|
background-color: transparent;
|
||||||
|
padding: 0;
|
||||||
|
}}
|
||||||
|
.markdown-body blockquote {{
|
||||||
|
border-left: 0.25em solid #d0d7de;
|
||||||
|
padding: 0 1em;
|
||||||
|
color: #656d76;
|
||||||
|
}}
|
||||||
|
.markdown-body table {{
|
||||||
|
border-collapse: collapse;
|
||||||
|
width: 100%;
|
||||||
|
margin: 1em 0;
|
||||||
|
}}
|
||||||
|
.markdown-body table th,
|
||||||
|
.markdown-body table td {{
|
||||||
|
border: 1px solid #d0d7de;
|
||||||
|
padding: 6px 13px;
|
||||||
|
}}
|
||||||
|
.markdown-body table tr:nth-child(2n) {{
|
||||||
|
background-color: #f6f8fa;
|
||||||
|
}}
|
||||||
|
@media (prefers-color-scheme: dark) {{
|
||||||
|
.markdown-body {{
|
||||||
|
background: #1e293b;
|
||||||
|
color: #e2e8f0;
|
||||||
|
}}
|
||||||
|
.markdown-body pre {{
|
||||||
|
background-color: #0f172a;
|
||||||
|
}}
|
||||||
|
.markdown-body code {{
|
||||||
|
background-color: rgba(99,110,123,0.4);
|
||||||
|
}}
|
||||||
|
.markdown-body blockquote {{
|
||||||
|
border-left-color: #30363d;
|
||||||
|
color: #8b949e;
|
||||||
|
}}
|
||||||
|
.markdown-body table th,
|
||||||
|
.markdown-body table td {{
|
||||||
|
border-color: #30363d;
|
||||||
|
}}
|
||||||
|
.markdown-body table tr:nth-child(2n) {{
|
||||||
|
background-color: #0f172a;
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1 class="chat-title">学术对话存档</h1>
|
||||||
|
<div class="chat-body">
|
||||||
|
{chat_content}
|
||||||
|
{references_content}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{js_code}
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
class MarkdownFormatter:
|
||||||
|
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.content = []
|
||||||
|
|
||||||
|
def _add_content(self, text: str):
|
||||||
|
"""添加正文内容"""
|
||||||
|
if text:
|
||||||
|
self.content.append(f"\n{text}\n")
|
||||||
|
|
||||||
|
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||||
|
"""创建完整的Markdown文档
|
||||||
|
Args:
|
||||||
|
question: str, 用户问题
|
||||||
|
answer: str, AI回答
|
||||||
|
ranked_papers: list, 排序后的论文列表
|
||||||
|
Returns:
|
||||||
|
str: 生成的Markdown文本
|
||||||
|
"""
|
||||||
|
content = []
|
||||||
|
|
||||||
|
# 添加问答部分
|
||||||
|
content.append("## 问题")
|
||||||
|
content.append(question)
|
||||||
|
content.append("\n## 回答")
|
||||||
|
content.append(answer)
|
||||||
|
|
||||||
|
# 添加参考文献
|
||||||
|
if ranked_papers:
|
||||||
|
content.append("\n## 参考文献")
|
||||||
|
for idx, paper in enumerate(ranked_papers, 1):
|
||||||
|
authors = ', '.join(paper.authors[:3])
|
||||||
|
if len(paper.authors) > 3:
|
||||||
|
authors += ' et al.'
|
||||||
|
|
||||||
|
ref = f"[{idx}] {authors}. *{paper.title}*"
|
||||||
|
if paper.venue_name:
|
||||||
|
ref += f". {paper.venue_name}"
|
||||||
|
if paper.year:
|
||||||
|
ref += f", {paper.year}"
|
||||||
|
if paper.doi:
|
||||||
|
ref += f". [DOI: {paper.doi}](https://doi.org/{paper.doi})"
|
||||||
|
|
||||||
|
content.append(ref)
|
||||||
|
|
||||||
|
return "\n\n".join(content)
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
from typing import List
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||||
|
import re
|
||||||
|
|
||||||
|
class ReferenceFormatter:
|
||||||
|
"""通用参考文献格式生成器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _sanitize_bibtex(self, text: str) -> str:
|
||||||
|
"""清理BibTeX字符串,处理特殊字符"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 替换特殊字符
|
||||||
|
replacements = {
|
||||||
|
'&': '\\&',
|
||||||
|
'%': '\\%',
|
||||||
|
'$': '\\$',
|
||||||
|
'#': '\\#',
|
||||||
|
'_': '\\_',
|
||||||
|
'{': '\\{',
|
||||||
|
'}': '\\}',
|
||||||
|
'~': '\\textasciitilde{}',
|
||||||
|
'^': '\\textasciicircum{}',
|
||||||
|
'\\': '\\textbackslash{}',
|
||||||
|
'<': '\\textless{}',
|
||||||
|
'>': '\\textgreater{}',
|
||||||
|
'"': '``',
|
||||||
|
"'": "'",
|
||||||
|
'-': '--',
|
||||||
|
'—': '---',
|
||||||
|
}
|
||||||
|
|
||||||
|
for char, replacement in replacements.items():
|
||||||
|
text = text.replace(char, replacement)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _generate_cite_key(self, paper: PaperMetadata) -> str:
|
||||||
|
"""生成引用键
|
||||||
|
格式: 第一作者姓氏_年份_第一个实词
|
||||||
|
"""
|
||||||
|
# 获取第一作者姓氏
|
||||||
|
first_author = ""
|
||||||
|
if paper.authors and len(paper.authors) > 0:
|
||||||
|
first_author = paper.authors[0].split()[-1].lower()
|
||||||
|
|
||||||
|
# 获取年份
|
||||||
|
year = str(paper.year) if paper.year else "0000"
|
||||||
|
|
||||||
|
# 从标题中获取第一个实词
|
||||||
|
title_word = ""
|
||||||
|
if paper.title:
|
||||||
|
# 移除特殊字符,分割成单词
|
||||||
|
words = re.findall(r'\w+', paper.title.lower())
|
||||||
|
# 过滤掉常见的停用词
|
||||||
|
stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
|
||||||
|
for word in words:
|
||||||
|
if word not in stop_words and len(word) > 2:
|
||||||
|
title_word = word
|
||||||
|
break
|
||||||
|
|
||||||
|
# 组合cite key
|
||||||
|
cite_key = f"{first_author}{year}{title_word}"
|
||||||
|
|
||||||
|
# 确保cite key只包含合法字符
|
||||||
|
cite_key = re.sub(r'[^a-z0-9]', '', cite_key.lower())
|
||||||
|
|
||||||
|
return cite_key
|
||||||
|
|
||||||
|
def _get_entry_type(self, paper: PaperMetadata) -> str:
|
||||||
|
"""确定BibTeX条目类型"""
|
||||||
|
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||||
|
venue_type = paper.venue_type.lower()
|
||||||
|
if venue_type == 'conference':
|
||||||
|
return 'inproceedings'
|
||||||
|
elif venue_type == 'preprint':
|
||||||
|
return 'unpublished'
|
||||||
|
elif venue_type == 'journal':
|
||||||
|
return 'article'
|
||||||
|
elif venue_type == 'book':
|
||||||
|
return 'book'
|
||||||
|
elif venue_type == 'thesis':
|
||||||
|
return 'phdthesis'
|
||||||
|
return 'article' # 默认为期刊文章
|
||||||
|
|
||||||
|
|
||||||
|
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||||
|
"""生成BibTeX格式的参考文献文本"""
|
||||||
|
bibtex_text = "% This file was automatically generated by GPT-Academic\n"
|
||||||
|
bibtex_text += "% Compatible with: EndNote, Zotero, JabRef, and LaTeX\n\n"
|
||||||
|
|
||||||
|
for paper in papers:
|
||||||
|
entry_type = self._get_entry_type(paper)
|
||||||
|
cite_key = self._generate_cite_key(paper)
|
||||||
|
|
||||||
|
bibtex_text += f"@{entry_type}{{{cite_key},\n"
|
||||||
|
|
||||||
|
# 添加标题
|
||||||
|
if paper.title:
|
||||||
|
bibtex_text += f" title = {{{self._sanitize_bibtex(paper.title)}}},\n"
|
||||||
|
|
||||||
|
# 添加作者
|
||||||
|
if paper.authors:
|
||||||
|
# 确保每个作者的姓和名正确分隔
|
||||||
|
processed_authors = []
|
||||||
|
for author in paper.authors:
|
||||||
|
names = author.split()
|
||||||
|
if len(names) > 1:
|
||||||
|
# 假设最后一个词是姓,其他的是名
|
||||||
|
surname = names[-1]
|
||||||
|
given_names = ' '.join(names[:-1])
|
||||||
|
processed_authors.append(f"{surname}, {given_names}")
|
||||||
|
else:
|
||||||
|
processed_authors.append(author)
|
||||||
|
|
||||||
|
authors = " and ".join([self._sanitize_bibtex(author) for author in processed_authors])
|
||||||
|
bibtex_text += f" author = {{{authors}}},\n"
|
||||||
|
|
||||||
|
# 添加年份
|
||||||
|
if paper.year:
|
||||||
|
bibtex_text += f" year = {{{paper.year}}},\n"
|
||||||
|
|
||||||
|
# 添加期刊/会议名称
|
||||||
|
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||||
|
if entry_type == 'inproceedings':
|
||||||
|
bibtex_text += f" booktitle = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||||
|
elif entry_type == 'article':
|
||||||
|
bibtex_text += f" journal = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||||
|
# 添加期刊相关信息
|
||||||
|
if hasattr(paper, 'venue_info'):
|
||||||
|
if 'volume' in paper.venue_info:
|
||||||
|
bibtex_text += f" volume = {{{paper.venue_info['volume']}}},\n"
|
||||||
|
if 'number' in paper.venue_info:
|
||||||
|
bibtex_text += f" number = {{{paper.venue_info['number']}}},\n"
|
||||||
|
if 'pages' in paper.venue_info:
|
||||||
|
bibtex_text += f" pages = {{{paper.venue_info['pages']}}},\n"
|
||||||
|
elif paper.venue:
|
||||||
|
venue_field = "booktitle" if entry_type == "inproceedings" else "journal"
|
||||||
|
bibtex_text += f" {venue_field} = {{{self._sanitize_bibtex(paper.venue)}}},\n"
|
||||||
|
|
||||||
|
# 添加DOI
|
||||||
|
if paper.doi:
|
||||||
|
bibtex_text += f" doi = {{{paper.doi}}},\n"
|
||||||
|
|
||||||
|
# 添加URL
|
||||||
|
if paper.url:
|
||||||
|
bibtex_text += f" url = {{{paper.url}}},\n"
|
||||||
|
elif paper.doi:
|
||||||
|
bibtex_text += f" url = {{https://doi.org/{paper.doi}}},\n"
|
||||||
|
|
||||||
|
# 添加摘要
|
||||||
|
if paper.abstract:
|
||||||
|
bibtex_text += f" abstract = {{{self._sanitize_bibtex(paper.abstract)}}},\n"
|
||||||
|
|
||||||
|
# 添加机构
|
||||||
|
if hasattr(paper, 'institutions') and paper.institutions:
|
||||||
|
institutions = " and ".join([self._sanitize_bibtex(inst) for inst in paper.institutions])
|
||||||
|
bibtex_text += f" institution = {{{institutions}}},\n"
|
||||||
|
|
||||||
|
# 添加月份
|
||||||
|
if hasattr(paper, 'month'):
|
||||||
|
bibtex_text += f" month = {{{paper.month}}},\n"
|
||||||
|
|
||||||
|
# 添加注释字段
|
||||||
|
if hasattr(paper, 'note'):
|
||||||
|
bibtex_text += f" note = {{{self._sanitize_bibtex(paper.note)}}},\n"
|
||||||
|
|
||||||
|
# 移除最后一个逗号并关闭条目
|
||||||
|
bibtex_text = bibtex_text.rstrip(',\n') + "\n}\n\n"
|
||||||
|
|
||||||
|
return bibtex_text
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
from docx2pdf import convert
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from typing import Union
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class WordToPdfConverter:
|
||||||
|
"""Word文档转PDF转换器"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _replace_docx_in_filename(filename: Union[str, Path]) -> Path:
|
||||||
|
"""
|
||||||
|
将文件名中的'docx'替换为'pdf'
|
||||||
|
例如: 'docx_test.pdf' -> 'pdf_test.pdf'
|
||||||
|
"""
|
||||||
|
path = Path(filename)
|
||||||
|
new_name = path.stem.replace('docx', 'pdf')
|
||||||
|
return path.parent / f"{new_name}{path.suffix}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
|
||||||
|
"""
|
||||||
|
将Word文档转换为PDF
|
||||||
|
|
||||||
|
参数:
|
||||||
|
word_path: Word文档的路径
|
||||||
|
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
|
||||||
|
|
||||||
|
返回:
|
||||||
|
生成的PDF文件路径
|
||||||
|
|
||||||
|
异常:
|
||||||
|
如果转换失败,将抛出相应异常
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
word_path = Path(word_path)
|
||||||
|
|
||||||
|
if pdf_path is None:
|
||||||
|
# 创建新的pdf路径,同时替换文件名中的docx
|
||||||
|
pdf_path = WordToPdfConverter._replace_docx_in_filename(word_path).with_suffix('.pdf')
|
||||||
|
else:
|
||||||
|
pdf_path = WordToPdfConverter._replace_docx_in_filename(Path(pdf_path))
|
||||||
|
|
||||||
|
# 检查操作系统
|
||||||
|
if platform.system() == 'Linux':
|
||||||
|
# Linux系统需要安装libreoffice
|
||||||
|
if not os.system('which libreoffice') == 0:
|
||||||
|
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
|
||||||
|
|
||||||
|
# 使用libreoffice进行转换
|
||||||
|
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
|
||||||
|
|
||||||
|
# 如果输出路径与默认生成的不同,则重命名
|
||||||
|
default_pdf = word_path.with_suffix('.pdf')
|
||||||
|
if default_pdf != pdf_path:
|
||||||
|
os.rename(default_pdf, pdf_path)
|
||||||
|
else:
|
||||||
|
# Windows和MacOS使用 docx2pdf
|
||||||
|
convert(word_path, pdf_path)
|
||||||
|
|
||||||
|
return str(pdf_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"转换PDF失败: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
|
||||||
|
"""
|
||||||
|
批量转换目录下的所有Word文档
|
||||||
|
|
||||||
|
参数:
|
||||||
|
word_dir: 包含Word文档的目录路径
|
||||||
|
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
|
||||||
|
|
||||||
|
返回:
|
||||||
|
生成的PDF文件路径列表
|
||||||
|
"""
|
||||||
|
word_dir = Path(word_dir)
|
||||||
|
if pdf_dir:
|
||||||
|
pdf_dir = Path(pdf_dir)
|
||||||
|
pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
converted_files = []
|
||||||
|
|
||||||
|
for word_file in word_dir.glob("*.docx"):
|
||||||
|
try:
|
||||||
|
if pdf_dir:
|
||||||
|
pdf_path = pdf_dir / WordToPdfConverter._replace_docx_in_filename(
|
||||||
|
word_file.with_suffix('.pdf')
|
||||||
|
).name
|
||||||
|
else:
|
||||||
|
pdf_path = WordToPdfConverter._replace_docx_in_filename(
|
||||||
|
word_file.with_suffix('.pdf')
|
||||||
|
)
|
||||||
|
|
||||||
|
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
|
||||||
|
converted_files.append(pdf_file)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"转换 {word_file} 失败: {str(e)}")
|
||||||
|
|
||||||
|
return converted_files
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
|
||||||
|
"""
|
||||||
|
将docx对象直接转换为PDF
|
||||||
|
|
||||||
|
参数:
|
||||||
|
doc: python-docx的Document对象
|
||||||
|
output_dir: 可选,输出目录。如果未指定,将使用当前目录
|
||||||
|
|
||||||
|
返回:
|
||||||
|
生成的PDF文件路径
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 设置临时文件路径和输出路径
|
||||||
|
output_dir = Path(output_dir) if output_dir else Path.cwd()
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成临时word文件
|
||||||
|
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
|
||||||
|
doc.save(temp_docx)
|
||||||
|
|
||||||
|
# 转换为PDF
|
||||||
|
pdf_path = temp_docx.with_suffix('.pdf')
|
||||||
|
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
|
||||||
|
|
||||||
|
# 删除临时word文件
|
||||||
|
temp_docx.unlink()
|
||||||
|
|
||||||
|
return str(pdf_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if temp_docx.exists():
|
||||||
|
temp_docx.unlink()
|
||||||
|
raise Exception(f"转换PDF失败: {str(e)}")
|
||||||
@@ -0,0 +1,246 @@
|
|||||||
|
import re
|
||||||
|
from docx import Document
|
||||||
|
from docx.shared import Cm, Pt
|
||||||
|
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||||
|
from docx.enum.style import WD_STYLE_TYPE
|
||||||
|
from docx.oxml.ns import qn
|
||||||
|
from datetime import datetime
|
||||||
|
import docx
|
||||||
|
from docx.oxml import shared
|
||||||
|
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
|
||||||
|
|
||||||
|
|
||||||
|
class WordFormatter:
|
||||||
|
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.doc = Document()
|
||||||
|
self._setup_document()
|
||||||
|
self._create_styles()
|
||||||
|
|
||||||
|
def _setup_document(self):
|
||||||
|
"""设置文档基本格式,包括页面设置和页眉"""
|
||||||
|
sections = self.doc.sections
|
||||||
|
for section in sections:
|
||||||
|
# 设置页面大小为A4
|
||||||
|
section.page_width = Cm(21)
|
||||||
|
section.page_height = Cm(29.7)
|
||||||
|
# 设置页边距
|
||||||
|
section.top_margin = Cm(3.7) # 上边距37mm
|
||||||
|
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||||
|
section.left_margin = Cm(2.8) # 左边距28mm
|
||||||
|
section.right_margin = Cm(2.6) # 右边距26mm
|
||||||
|
# 设置页眉页脚距离
|
||||||
|
section.header_distance = Cm(2.0)
|
||||||
|
section.footer_distance = Cm(2.0)
|
||||||
|
|
||||||
|
# 修改页眉
|
||||||
|
header = section.header
|
||||||
|
header_para = header.paragraphs[0]
|
||||||
|
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||||
|
header_run = header_para.add_run("GPT-Academic学术对话 (体验地址:https://auth.gpt-academic.top/)")
|
||||||
|
header_run.font.name = '仿宋'
|
||||||
|
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
header_run.font.size = Pt(9)
|
||||||
|
|
||||||
|
def _create_styles(self):
|
||||||
|
"""创建文档样式"""
|
||||||
|
# 创建正文样式
|
||||||
|
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
style.font.name = '仿宋'
|
||||||
|
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
style.font.size = Pt(12)
|
||||||
|
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
style.paragraph_format.space_after = Pt(0)
|
||||||
|
|
||||||
|
# 创建问题样式
|
||||||
|
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
question_style.font.name = '黑体'
|
||||||
|
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||||
|
question_style.font.size = Pt(14) # 调整为14磅
|
||||||
|
question_style.font.bold = True
|
||||||
|
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
|
||||||
|
question_style.paragraph_format.space_after = Pt(6)
|
||||||
|
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||||
|
|
||||||
|
# 创建回答样式
|
||||||
|
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
answer_style.font.name = '仿宋'
|
||||||
|
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
answer_style.font.size = Pt(12) # 调整为12磅
|
||||||
|
answer_style.paragraph_format.space_before = Pt(6)
|
||||||
|
answer_style.paragraph_format.space_after = Pt(12)
|
||||||
|
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||||
|
|
||||||
|
# 创建标题样式
|
||||||
|
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
title_style.font.name = '黑体' # 改用黑体
|
||||||
|
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||||
|
title_style.font.size = Pt(22) # 调整为22磅
|
||||||
|
title_style.font.bold = True
|
||||||
|
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||||
|
title_style.paragraph_format.space_before = Pt(0)
|
||||||
|
title_style.paragraph_format.space_after = Pt(24)
|
||||||
|
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
|
||||||
|
# 添加参考文献样式
|
||||||
|
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
ref_style.font.name = '宋体'
|
||||||
|
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
|
||||||
|
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
|
||||||
|
ref_style.paragraph_format.space_before = Pt(3)
|
||||||
|
ref_style.paragraph_format.space_after = Pt(3)
|
||||||
|
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
|
||||||
|
ref_style.paragraph_format.left_indent = Pt(21)
|
||||||
|
ref_style.paragraph_format.first_line_indent = Pt(-21)
|
||||||
|
|
||||||
|
# 添加参考文献标题样式
|
||||||
|
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||||
|
ref_title_style.font.name = '黑体'
|
||||||
|
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||||
|
ref_title_style.font.size = Pt(16) # 参考文献标题与问题同样大小
|
||||||
|
ref_title_style.font.bold = True
|
||||||
|
ref_title_style.paragraph_format.space_before = Pt(24) # 增加段前距
|
||||||
|
ref_title_style.paragraph_format.space_after = Pt(12)
|
||||||
|
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||||
|
|
||||||
|
def create_document(self, question: str, answer: str, ranked_papers: list = None):
|
||||||
|
"""写入聊天历史
|
||||||
|
Args:
|
||||||
|
question: str, 用户问题
|
||||||
|
answer: str, AI回答
|
||||||
|
ranked_papers: list, 排序后的论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 添加标题
|
||||||
|
title_para = self.doc.add_paragraph(style='Title_Custom')
|
||||||
|
title_run = title_para.add_run('GPT-Academic 对话记录')
|
||||||
|
|
||||||
|
# 添加日期
|
||||||
|
try:
|
||||||
|
date_para = self.doc.add_paragraph()
|
||||||
|
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||||
|
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||||
|
date_run.font.name = '仿宋'
|
||||||
|
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||||
|
date_run.font.size = Pt(16)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"添加日期失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
self.doc.add_paragraph() # 添加空行
|
||||||
|
|
||||||
|
# 添加问答对话
|
||||||
|
try:
|
||||||
|
q_para = self.doc.add_paragraph(style='Question_Style')
|
||||||
|
q_para.add_run('问题:').bold = True
|
||||||
|
q_para.add_run(str(question))
|
||||||
|
|
||||||
|
a_para = self.doc.add_paragraph(style='Answer_Style')
|
||||||
|
a_para.add_run('回答:').bold = True
|
||||||
|
a_para.add_run(convert_markdown_to_word(str(answer)))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"添加问答对话失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 添加参考文献部分
|
||||||
|
if ranked_papers:
|
||||||
|
try:
|
||||||
|
ref_title = self.doc.add_paragraph(style='Reference_Title_Style')
|
||||||
|
ref_title.add_run("参考文献")
|
||||||
|
|
||||||
|
for idx, paper in enumerate(ranked_papers, 1):
|
||||||
|
try:
|
||||||
|
ref_para = self.doc.add_paragraph(style='Reference_Style')
|
||||||
|
ref_para.add_run(f'[{idx}] ').bold = True
|
||||||
|
|
||||||
|
# 添加作者
|
||||||
|
authors = ', '.join(paper.authors[:3])
|
||||||
|
if len(paper.authors) > 3:
|
||||||
|
authors += ' et al.'
|
||||||
|
ref_para.add_run(f'{authors}. ')
|
||||||
|
|
||||||
|
# 添加标题
|
||||||
|
title_run = ref_para.add_run(paper.title)
|
||||||
|
title_run.italic = True
|
||||||
|
if hasattr(paper, 'url') and paper.url:
|
||||||
|
try:
|
||||||
|
title_run._element.rPr.rStyle = self._create_hyperlink_style()
|
||||||
|
self._add_hyperlink(ref_para, paper.title, paper.url)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"添加超链接失败: {str(e)}")
|
||||||
|
|
||||||
|
# 添加期刊/会议信息
|
||||||
|
if paper.venue_name:
|
||||||
|
ref_para.add_run(f'. {paper.venue_name}')
|
||||||
|
|
||||||
|
# 添加年份
|
||||||
|
if paper.year:
|
||||||
|
ref_para.add_run(f', {paper.year}')
|
||||||
|
|
||||||
|
# 添加DOI
|
||||||
|
if paper.doi:
|
||||||
|
ref_para.add_run('. ')
|
||||||
|
if "arxiv" in paper.url:
|
||||||
|
doi_url = paper.doi
|
||||||
|
else:
|
||||||
|
doi_url = f'https://doi.org/{paper.doi}'
|
||||||
|
self._add_hyperlink(ref_para, f'DOI: {paper.doi}', doi_url)
|
||||||
|
|
||||||
|
ref_para.add_run('.')
|
||||||
|
except Exception as e:
|
||||||
|
print(f"添加第 {idx} 篇参考文献失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
print(f"添加参考文献部分失败: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
return self.doc
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Word文档创建失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(f"详细错误信息: {traceback.format_exc()}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _create_hyperlink_style(self):
|
||||||
|
"""创建超链接样式"""
|
||||||
|
styles = self.doc.styles
|
||||||
|
if 'Hyperlink' not in styles:
|
||||||
|
hyperlink_style = styles.add_style('Hyperlink', WD_STYLE_TYPE.CHARACTER)
|
||||||
|
# 使用科技蓝 (#0066CC)
|
||||||
|
hyperlink_style.font.color.rgb = 0x0066CC # 科技蓝
|
||||||
|
hyperlink_style.font.underline = True
|
||||||
|
return styles['Hyperlink']
|
||||||
|
|
||||||
|
def _add_hyperlink(self, paragraph, text, url):
|
||||||
|
"""添加超链接到段落"""
|
||||||
|
# 这个是在XML级别添加超链接
|
||||||
|
part = paragraph.part
|
||||||
|
r_id = part.relate_to(url, docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, is_external=True)
|
||||||
|
|
||||||
|
# 创建超链接XML元素
|
||||||
|
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
|
||||||
|
hyperlink.set(docx.oxml.shared.qn('r:id'), r_id)
|
||||||
|
|
||||||
|
# 创建文本运行
|
||||||
|
new_run = docx.oxml.shared.OxmlElement('w:r')
|
||||||
|
rPr = docx.oxml.shared.OxmlElement('w:rPr')
|
||||||
|
|
||||||
|
# 应用超链接样式
|
||||||
|
rStyle = docx.oxml.shared.OxmlElement('w:rStyle')
|
||||||
|
rStyle.set(docx.oxml.shared.qn('w:val'), 'Hyperlink')
|
||||||
|
rPr.append(rStyle)
|
||||||
|
|
||||||
|
# 添加文本
|
||||||
|
t = docx.oxml.shared.OxmlElement('w:t')
|
||||||
|
t.text = text
|
||||||
|
new_run.append(rPr)
|
||||||
|
new_run.append(t)
|
||||||
|
hyperlink.append(new_run)
|
||||||
|
|
||||||
|
# 将超链接添加到段落
|
||||||
|
paragraph._p.append(hyperlink)
|
||||||
|
|
||||||
@@ -0,0 +1,279 @@
|
|||||||
|
from typing import List, Optional, Dict, Union
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
|
|
||||||
|
class AdsabsSource(DataSource):
|
||||||
|
"""ADS (Astrophysics Data System) API实现"""
|
||||||
|
|
||||||
|
# 定义API密钥列表
|
||||||
|
API_KEYS = [
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None):
|
||||||
|
"""初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
|
||||||
|
"""
|
||||||
|
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化基础URL和请求头"""
|
||||||
|
self.base_url = "https://api.adsabs.harvard.edu/v1"
|
||||||
|
self.headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
|
||||||
|
"""发送HTTP请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 请求URL
|
||||||
|
method: HTTP方法
|
||||||
|
data: POST请求数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
if method == "GET":
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
elif method == "POST":
|
||||||
|
async with session.post(url, json=data) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_paper(self, doc: dict) -> PaperMetadata:
|
||||||
|
"""解析ADS文献数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc: ADS文献数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的论文数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return PaperMetadata(
|
||||||
|
title=doc.get('title', [''])[0] if doc.get('title') else '',
|
||||||
|
authors=doc.get('author', []),
|
||||||
|
abstract=doc.get('abstract', ''),
|
||||||
|
year=doc.get('year'),
|
||||||
|
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
|
||||||
|
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
|
||||||
|
citations=doc.get('citation_count'),
|
||||||
|
venue=doc.get('pub', ''),
|
||||||
|
institutions=doc.get('aff', []),
|
||||||
|
venue_type="journal",
|
||||||
|
venue_name=doc.get('pub', ''),
|
||||||
|
venue_info={
|
||||||
|
'volume': doc.get('volume'),
|
||||||
|
'issue': doc.get('issue'),
|
||||||
|
'pub_date': doc.get('pubdate', '')
|
||||||
|
},
|
||||||
|
source='adsabs'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析文章时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = "relevance",
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||||
|
start_year: 起始年份
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建查询
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} year:{start_year}-"
|
||||||
|
|
||||||
|
# 设置排序
|
||||||
|
sort_mapping = {
|
||||||
|
'relevance': 'score desc',
|
||||||
|
'date': 'date desc',
|
||||||
|
'citations': 'citation_count desc'
|
||||||
|
}
|
||||||
|
sort = sort_mapping.get(sort_by, 'score desc')
|
||||||
|
|
||||||
|
# 构建搜索请求
|
||||||
|
search_url = f"{self.base_url}/search/query"
|
||||||
|
params = {
|
||||||
|
"q": query,
|
||||||
|
"rows": limit,
|
||||||
|
"sort": sort,
|
||||||
|
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||||
|
if not response or 'response' not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
papers = []
|
||||||
|
for doc in response['response']['docs']:
|
||||||
|
paper = self._parse_paper(doc)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _build_query_string(self, params: dict) -> str:
|
||||||
|
"""构建查询字符串"""
|
||||||
|
return "&".join([f"{k}={v}" for k, v in params.items()])
|
||||||
|
|
||||||
|
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取指定bibcode的论文详情"""
|
||||||
|
search_url = f"{self.base_url}/search/query"
|
||||||
|
params = {
|
||||||
|
"q": f"identifier:{bibcode}",
|
||||||
|
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||||
|
if response and 'response' in response and response['response']['docs']:
|
||||||
|
return self._parse_paper(response['response']['docs'][0])
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""获取相关论文"""
|
||||||
|
url = f"{self.base_url}/search/query"
|
||||||
|
params = {
|
||||||
|
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
|
||||||
|
"rows": limit,
|
||||||
|
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||||
|
if not response or 'response' not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
papers = []
|
||||||
|
for doc in response['response']['docs']:
|
||||||
|
paper = self._parse_paper(doc)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
return papers
|
||||||
|
|
||||||
|
async def search_by_author(
|
||||||
|
self,
|
||||||
|
author: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = f"author:\"{author}\""
|
||||||
|
return await self.search(query, limit=limit, start_year=start_year)
|
||||||
|
|
||||||
|
async def search_by_journal(
|
||||||
|
self,
|
||||||
|
journal: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按期刊搜索论文"""
|
||||||
|
query = f"pub:\"{journal}\""
|
||||||
|
return await self.search(query, limit=limit, start_year=start_year)
|
||||||
|
|
||||||
|
async def get_latest_papers(
|
||||||
|
self,
|
||||||
|
days: int = 7,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取最新论文"""
|
||||||
|
query = f"entdate:[NOW-{days}DAYS TO NOW]"
|
||||||
|
return await self.search(query, limit=limit, sort_by="date")
|
||||||
|
|
||||||
|
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用该论文的文献"""
|
||||||
|
url = f"{self.base_url}/search/query"
|
||||||
|
params = {
|
||||||
|
"q": f"citations(identifier:{bibcode})",
|
||||||
|
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||||
|
if not response or 'response' not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
papers = []
|
||||||
|
for doc in response['response']['docs']:
|
||||||
|
paper = self._parse_paper(doc)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
return papers
|
||||||
|
|
||||||
|
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
|
||||||
|
"""获取该论文引用的文献"""
|
||||||
|
url = f"{self.base_url}/search/query"
|
||||||
|
params = {
|
||||||
|
"q": f"references(identifier:{bibcode})",
|
||||||
|
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||||
|
if not response or 'response' not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
papers = []
|
||||||
|
for doc in response['response']['docs']:
|
||||||
|
paper = self._parse_paper(doc)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
return papers
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""AdsabsSource使用示例"""
|
||||||
|
ads = AdsabsSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:基本搜索
|
||||||
|
print("\n=== 示例1:搜索黑洞相关论文 ===")
|
||||||
|
papers = await ads.search("black hole", limit=3)
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
|
||||||
|
# 其他示例...
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# python -m crazy_functions.review_fns.data_sources.adsabs_source
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,636 @@
|
|||||||
|
import arxiv
|
||||||
|
from typing import List, Optional, Union, Literal, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
from .base_source import DataSource, PaperMetadata
|
||||||
|
import os
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
import feedparser
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class ArxivSource(DataSource):
|
||||||
|
"""arXiv API实现"""
|
||||||
|
|
||||||
|
CATEGORIES = {
|
||||||
|
# 物理学
|
||||||
|
"Physics": {
|
||||||
|
"astro-ph": "天体物理学",
|
||||||
|
"cond-mat": "凝聚态物理",
|
||||||
|
"gr-qc": "广义相对论与量子宇宙学",
|
||||||
|
"hep-ex": "高能物理实验",
|
||||||
|
"hep-lat": "格点场论",
|
||||||
|
"hep-ph": "高能物理理论",
|
||||||
|
"hep-th": "高能物理理论",
|
||||||
|
"math-ph": "数学物理",
|
||||||
|
"nlin": "非线性科学",
|
||||||
|
"nucl-ex": "核实验",
|
||||||
|
"nucl-th": "核理论",
|
||||||
|
"physics": "物理学",
|
||||||
|
"quant-ph": "量子物理",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 数学
|
||||||
|
"Mathematics": {
|
||||||
|
"math.AG": "代数几何",
|
||||||
|
"math.AT": "代数拓扑",
|
||||||
|
"math.AP": "分析与偏微分方程",
|
||||||
|
"math.CT": "范畴论",
|
||||||
|
"math.CA": "复分析",
|
||||||
|
"math.CO": "组合数学",
|
||||||
|
"math.AC": "交换代数",
|
||||||
|
"math.CV": "复变函数",
|
||||||
|
"math.DG": "微分几何",
|
||||||
|
"math.DS": "动力系统",
|
||||||
|
"math.FA": "泛函分析",
|
||||||
|
"math.GM": "一般数学",
|
||||||
|
"math.GN": "一般拓扑",
|
||||||
|
"math.GT": "几何拓扑",
|
||||||
|
"math.GR": "群论",
|
||||||
|
"math.HO": "数学史与数学概述",
|
||||||
|
"math.IT": "信息论",
|
||||||
|
"math.KT": "K理论与同调",
|
||||||
|
"math.LO": "逻辑",
|
||||||
|
"math.MP": "数学物理",
|
||||||
|
"math.MG": "度量几何",
|
||||||
|
"math.NT": "数论",
|
||||||
|
"math.NA": "数值分析",
|
||||||
|
"math.OA": "算子代数",
|
||||||
|
"math.OC": "最优化与控制",
|
||||||
|
"math.PR": "概率论",
|
||||||
|
"math.QA": "量子代数",
|
||||||
|
"math.RT": "表示论",
|
||||||
|
"math.RA": "环与代数",
|
||||||
|
"math.SP": "谱理论",
|
||||||
|
"math.ST": "统计理论",
|
||||||
|
"math.SG": "辛几何",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 计算机科学
|
||||||
|
"Computer Science": {
|
||||||
|
"cs.AI": "人工智能",
|
||||||
|
"cs.CL": "计算语言学",
|
||||||
|
"cs.CC": "计算复杂性",
|
||||||
|
"cs.CE": "计算工程",
|
||||||
|
"cs.CG": "计算几何",
|
||||||
|
"cs.GT": "计算机博弈论",
|
||||||
|
"cs.CV": "计算机视觉",
|
||||||
|
"cs.CY": "计算机与社会",
|
||||||
|
"cs.CR": "密码学与安全",
|
||||||
|
"cs.DS": "数据结构与算法",
|
||||||
|
"cs.DB": "数据库",
|
||||||
|
"cs.DL": "数字图书馆",
|
||||||
|
"cs.DM": "离散数学",
|
||||||
|
"cs.DC": "分布式计算",
|
||||||
|
"cs.ET": "新兴技术",
|
||||||
|
"cs.FL": "形式语言与自动机理论",
|
||||||
|
"cs.GL": "一般文献",
|
||||||
|
"cs.GR": "图形学",
|
||||||
|
"cs.AR": "硬件架构",
|
||||||
|
"cs.HC": "人机交互",
|
||||||
|
"cs.IR": "信息检索",
|
||||||
|
"cs.IT": "信息论",
|
||||||
|
"cs.LG": "机器学习",
|
||||||
|
"cs.LO": "逻辑与计算机",
|
||||||
|
"cs.MS": "数学软件",
|
||||||
|
"cs.MA": "多智能体系统",
|
||||||
|
"cs.MM": "多媒体",
|
||||||
|
"cs.NI": "网络与互联网架构",
|
||||||
|
"cs.NE": "神经与进化计算",
|
||||||
|
"cs.NA": "数值分析",
|
||||||
|
"cs.OS": "操作系统",
|
||||||
|
"cs.OH": "其他计算机科学",
|
||||||
|
"cs.PF": "性能评估",
|
||||||
|
"cs.PL": "编程语言",
|
||||||
|
"cs.RO": "机器人学",
|
||||||
|
"cs.SI": "社会与信息网络",
|
||||||
|
"cs.SE": "软件工程",
|
||||||
|
"cs.SD": "声音",
|
||||||
|
"cs.SC": "符号计算",
|
||||||
|
"cs.SY": "系统与控制",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 定量生物学
|
||||||
|
"Quantitative Biology": {
|
||||||
|
"q-bio.BM": "生物分子",
|
||||||
|
"q-bio.CB": "细胞行为",
|
||||||
|
"q-bio.GN": "基因组学",
|
||||||
|
"q-bio.MN": "分子网络",
|
||||||
|
"q-bio.NC": "神经计算",
|
||||||
|
"q-bio.OT": "其他",
|
||||||
|
"q-bio.PE": "群体与进化",
|
||||||
|
"q-bio.QM": "定量方法",
|
||||||
|
"q-bio.SC": "亚细胞过程",
|
||||||
|
"q-bio.TO": "组织与器官",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 定量金融
|
||||||
|
"Quantitative Finance": {
|
||||||
|
"q-fin.CP": "计算金融",
|
||||||
|
"q-fin.EC": "经济学",
|
||||||
|
"q-fin.GN": "一般金融",
|
||||||
|
"q-fin.MF": "数学金融",
|
||||||
|
"q-fin.PM": "投资组合管理",
|
||||||
|
"q-fin.PR": "定价理论",
|
||||||
|
"q-fin.RM": "风险管理",
|
||||||
|
"q-fin.ST": "统计金融",
|
||||||
|
"q-fin.TR": "交易与市场微观结构",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 统计学
|
||||||
|
"Statistics": {
|
||||||
|
"stat.AP": "应用统计",
|
||||||
|
"stat.CO": "计算统计",
|
||||||
|
"stat.ML": "机器学习",
|
||||||
|
"stat.ME": "方法论",
|
||||||
|
"stat.OT": "其他统计",
|
||||||
|
"stat.TH": "统计理论",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 电气工程与系统科学
|
||||||
|
"Electrical Engineering and Systems Science": {
|
||||||
|
"eess.AS": "音频与语音处理",
|
||||||
|
"eess.IV": "图像与视频处理",
|
||||||
|
"eess.SP": "信号处理",
|
||||||
|
"eess.SY": "系统与控制",
|
||||||
|
},
|
||||||
|
|
||||||
|
# 经济学
|
||||||
|
"Economics": {
|
||||||
|
"econ.EM": "计量经济学",
|
||||||
|
"econ.GN": "一般经济学",
|
||||||
|
"econ.TH": "理论经济学",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化"""
|
||||||
|
self._initialize() # 调用初始化方法
|
||||||
|
# 修改排序选项映射
|
||||||
|
self.sort_options = {
|
||||||
|
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
|
||||||
|
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
|
||||||
|
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
|
||||||
|
}
|
||||||
|
|
||||||
|
self.sort_order_options = {
|
||||||
|
'ascending': arxiv.SortOrder.Ascending,
|
||||||
|
'descending': arxiv.SortOrder.Descending
|
||||||
|
}
|
||||||
|
|
||||||
|
self.default_sort = 'lastUpdatedDate'
|
||||||
|
self.default_order = 'descending'
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化客户端,设置默认参数"""
|
||||||
|
self.client = arxiv.Client()
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
sort_by: str = None,
|
||||||
|
sort_order: str = None,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""搜索论文"""
|
||||||
|
try:
|
||||||
|
# 使用默认排序如果提供的排序选项无效
|
||||||
|
if not sort_by or sort_by not in self.sort_options:
|
||||||
|
sort_by = self.default_sort
|
||||||
|
|
||||||
|
# 使用默认排序顺序如果提供的顺序无效
|
||||||
|
if not sort_order or sort_order not in self.sort_order_options:
|
||||||
|
sort_order = self.default_order
|
||||||
|
|
||||||
|
# 如果指定了起始年份,添加到查询中
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||||
|
|
||||||
|
search = arxiv.Search(
|
||||||
|
query=query,
|
||||||
|
max_results=limit,
|
||||||
|
sort_by=self.sort_options[sort_by],
|
||||||
|
sort_order=self.sort_order_options[sort_order]
|
||||||
|
)
|
||||||
|
|
||||||
|
results = list(self.client.results(search))
|
||||||
|
return [self._parse_paper_data(result) for result in results]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
|
||||||
|
"""按ID搜索论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: 单个arXiv ID或ID列表,例如:'2005.14165' 或 ['2005.14165', '2103.14030']
|
||||||
|
"""
|
||||||
|
if isinstance(paper_id, str):
|
||||||
|
paper_id = [paper_id]
|
||||||
|
|
||||||
|
search = arxiv.Search(
|
||||||
|
id_list=paper_id,
|
||||||
|
max_results=len(paper_id)
|
||||||
|
)
|
||||||
|
results = list(self.client.results(search))
|
||||||
|
return [self._parse_paper_data(result) for result in results]
|
||||||
|
|
||||||
|
async def search_by_category(
|
||||||
|
self,
|
||||||
|
category: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = 'relevance',
|
||||||
|
sort_order: str = 'descending',
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按类别搜索论文"""
|
||||||
|
query = f"cat:{category}"
|
||||||
|
|
||||||
|
# 如果指定了起始年份,添加到查询中
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||||
|
|
||||||
|
return await self.search(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_by_authors(
|
||||||
|
self,
|
||||||
|
authors: List[str],
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = 'relevance',
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = " AND ".join([f"au:\"{author}\"" for author in authors])
|
||||||
|
|
||||||
|
# 如果指定了起始年份,添加到查询中
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||||
|
|
||||||
|
return await self.search(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_by_date_range(
|
||||||
|
self,
|
||||||
|
start_date: datetime,
|
||||||
|
end_date: datetime,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
|
||||||
|
sort_order: Literal['ascending', 'descending'] = 'descending'
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按日期范围搜索论文"""
|
||||||
|
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
|
||||||
|
return await self.search(
|
||||||
|
query,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order
|
||||||
|
)
|
||||||
|
|
||||||
|
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||||
|
"""下载论文PDF
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: arXiv ID
|
||||||
|
dirpath: 保存目录
|
||||||
|
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保存的文件路径
|
||||||
|
"""
|
||||||
|
papers = await self.search_by_id(paper_id)
|
||||||
|
if not papers:
|
||||||
|
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||||
|
paper = papers[0]
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
# 清理标题中的非法字符
|
||||||
|
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||||
|
filename = f"{paper_id}_{safe_title}.pdf"
|
||||||
|
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
urlretrieve(paper.url, filepath)
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||||
|
"""下载论文源文件(通常是LaTeX源码)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: arXiv ID
|
||||||
|
dirpath: 保存目录
|
||||||
|
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
保存的文件路径
|
||||||
|
"""
|
||||||
|
papers = await self.search_by_id(paper_id)
|
||||||
|
if not papers:
|
||||||
|
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||||
|
paper = papers[0]
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||||
|
filename = f"{paper_id}_{safe_title}.tar.gz"
|
||||||
|
|
||||||
|
filepath = os.path.join(dirpath, filename)
|
||||||
|
source_url = paper.url.replace("/pdf/", "/src/")
|
||||||
|
urlretrieve(source_url, filepath)
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
# arXiv API不直接提供引用信息
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
# arXiv API不直接提供引用信息
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取论文详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: arXiv ID 或 DOI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文详细信息,如果未找到返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 如果是完整的 arXiv URL,提取 ID
|
||||||
|
if "arxiv.org" in paper_id:
|
||||||
|
paper_id = paper_id.split("/")[-1]
|
||||||
|
# 如果是 DOI 格式且是 arXiv 论文,提取 ID
|
||||||
|
elif paper_id.startswith("10.48550/arXiv."):
|
||||||
|
paper_id = paper_id.split(".")[-1]
|
||||||
|
|
||||||
|
papers = await self.search_by_id(paper_id)
|
||||||
|
return papers[0] if papers else None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文详情时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
|
||||||
|
"""解析arXiv API返回的数据"""
|
||||||
|
# 解析主要类别和次要类别
|
||||||
|
primary_category = result.primary_category
|
||||||
|
categories = result.categories
|
||||||
|
|
||||||
|
# 构建venue信息
|
||||||
|
venue_info = {
|
||||||
|
'primary_category': primary_category,
|
||||||
|
'categories': categories,
|
||||||
|
'comments': getattr(result, 'comment', None),
|
||||||
|
'journal_ref': getattr(result, 'journal_ref', None)
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=result.title,
|
||||||
|
authors=[author.name for author in result.authors],
|
||||||
|
abstract=result.summary,
|
||||||
|
year=result.published.year,
|
||||||
|
doi=result.entry_id,
|
||||||
|
url=result.pdf_url,
|
||||||
|
citations=None,
|
||||||
|
venue=f"arXiv:{primary_category}",
|
||||||
|
institutions=[],
|
||||||
|
venue_type='preprint', # arXiv论文都是预印本
|
||||||
|
venue_name='arXiv',
|
||||||
|
venue_info=venue_info,
|
||||||
|
source='arxiv' # 添加来源标记
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_latest_papers(
|
||||||
|
self,
|
||||||
|
category: str,
|
||||||
|
debug: bool = False,
|
||||||
|
batch_size: int = 50
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取指定类别的最新论文
|
||||||
|
|
||||||
|
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: arXiv类别,例如:
|
||||||
|
- 整个领域: 'cs'
|
||||||
|
- 具体方向: 'cs.AI'
|
||||||
|
- 多个类别: 'cs.AI+q-bio.NC'
|
||||||
|
debug: 是否为调试模式,如果为True则只返回5篇最新论文
|
||||||
|
batch_size: 批量获取论文的数量,默认50
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果类别无效
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 处理类别格式
|
||||||
|
# 1. 转换为小写
|
||||||
|
# 2. 确保多个类别之间使用+连接
|
||||||
|
category = category.lower().replace(' ', '+')
|
||||||
|
|
||||||
|
# 构建RSS feed URL
|
||||||
|
feed_url = f"https://rss.arxiv.org/rss/{category}"
|
||||||
|
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
|
||||||
|
|
||||||
|
feed = feedparser.parse(feed_url)
|
||||||
|
|
||||||
|
# 检查feed是否有效
|
||||||
|
if hasattr(feed, 'status') and feed.status != 200:
|
||||||
|
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
|
||||||
|
|
||||||
|
if not feed.entries:
|
||||||
|
print(f"警告:未在feed中找到任何条目") # 添加调试信息
|
||||||
|
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
|
||||||
|
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
# 调试模式:只获取5篇最新论文
|
||||||
|
search = arxiv.Search(
|
||||||
|
query=f'cat:{category}',
|
||||||
|
sort_by=arxiv.SortCriterion.SubmittedDate,
|
||||||
|
sort_order=arxiv.SortOrder.Descending,
|
||||||
|
max_results=5
|
||||||
|
)
|
||||||
|
results = list(self.client.results(search))
|
||||||
|
return [self._parse_paper_data(result) for result in results]
|
||||||
|
|
||||||
|
# 正常模式:获取所有新论文
|
||||||
|
# 从RSS条目中提取arXiv ID
|
||||||
|
paper_ids = []
|
||||||
|
for entry in feed.entries:
|
||||||
|
try:
|
||||||
|
# RSS链接格式可能是以下几种:
|
||||||
|
# - http://arxiv.org/abs/2403.xxxxx
|
||||||
|
# - http://arxiv.org/pdf/2403.xxxxx
|
||||||
|
# - https://arxiv.org/abs/2403.xxxxx
|
||||||
|
link = entry.link or entry.id
|
||||||
|
arxiv_id = link.split('/')[-1].replace('.pdf', '')
|
||||||
|
if arxiv_id:
|
||||||
|
paper_ids.append(arxiv_id)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not paper_ids:
|
||||||
|
print("未能从feed中提取到任何论文ID") # 添加调试信息
|
||||||
|
return []
|
||||||
|
|
||||||
|
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
|
||||||
|
|
||||||
|
# 批量获取论文详情
|
||||||
|
papers = []
|
||||||
|
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
|
||||||
|
for i in range(0, len(paper_ids), batch_size):
|
||||||
|
batch_ids = paper_ids[i:i + batch_size]
|
||||||
|
search = arxiv.Search(
|
||||||
|
id_list=batch_ids,
|
||||||
|
max_results=len(batch_ids)
|
||||||
|
)
|
||||||
|
batch_results = list(self.client.results(search))
|
||||||
|
papers.extend([self._parse_paper_data(result) for result in batch_results])
|
||||||
|
pbar.update(len(batch_results))
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取最新论文时发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc()) # 添加完整的错误追踪
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""ArxivSource使用示例"""
|
||||||
|
arxiv_source = ArxivSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:基本搜索,使用不同的排序方式
|
||||||
|
# print("\n=== 示例1:搜索最新的机器学习论文(按提交时间排序)===")
|
||||||
|
# papers = await arxiv_source.search(
|
||||||
|
# "ti:\"machine learning\"",
|
||||||
|
# limit=3,
|
||||||
|
# sort_by='submitted',
|
||||||
|
# sort_order='descending'
|
||||||
|
# )
|
||||||
|
# print(f"找到 {len(papers)} 篇论文")
|
||||||
|
|
||||||
|
# for i, paper in enumerate(papers, 1):
|
||||||
|
# print(f"\n--- 论文 {i} ---")
|
||||||
|
# print(f"标题: {paper.title}")
|
||||||
|
# print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
# print(f"发表年份: {paper.year}")
|
||||||
|
# print(f"arXiv ID: {paper.doi}")
|
||||||
|
# print(f"PDF URL: {paper.url}")
|
||||||
|
# if paper.abstract:
|
||||||
|
# print(f"\n摘要:")
|
||||||
|
# print(paper.abstract)
|
||||||
|
# print(f"发表venue: {paper.venue}")
|
||||||
|
|
||||||
|
# # 示例2:按ID搜索
|
||||||
|
# print("\n=== 示例2:按ID搜索论文 ===")
|
||||||
|
# paper_id = "2005.14165" # GPT-3论文
|
||||||
|
# papers = await arxiv_source.search_by_id(paper_id)
|
||||||
|
# if papers:
|
||||||
|
# paper = papers[0]
|
||||||
|
# print(f"标题: {paper.title}")
|
||||||
|
# print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
# print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# # 示例3:按类别搜索
|
||||||
|
# print("\n=== 示例3:搜索人工智能领域最新论文 ===")
|
||||||
|
# ai_papers = await arxiv_source.search_by_category(
|
||||||
|
# "cs.AI",
|
||||||
|
# limit=2,
|
||||||
|
# sort_by='updated',
|
||||||
|
# sort_order='descending'
|
||||||
|
# )
|
||||||
|
# for i, paper in enumerate(ai_papers, 1):
|
||||||
|
# print(f"\n--- AI论文 {i} ---")
|
||||||
|
# print(f"标题: {paper.title}")
|
||||||
|
# print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
# print(f"发表venue: {paper.venue}")
|
||||||
|
|
||||||
|
# # 示例4:按作者搜索
|
||||||
|
# print("\n=== 示例4:搜索特定作者的论文 ===")
|
||||||
|
# author_papers = await arxiv_source.search_by_authors(
|
||||||
|
# ["Bengio"],
|
||||||
|
# limit=2,
|
||||||
|
# sort_by='relevance'
|
||||||
|
# )
|
||||||
|
# for i, paper in enumerate(author_papers, 1):
|
||||||
|
# print(f"\n--- Bengio的论文 {i} ---")
|
||||||
|
# print(f"标题: {paper.title}")
|
||||||
|
# print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
# print(f"发表venue: {paper.venue}")
|
||||||
|
|
||||||
|
# # 示例5:按日期范围搜索
|
||||||
|
# print("\n=== 示例5:搜索特定日期范围的论文 ===")
|
||||||
|
# from datetime import datetime, timedelta
|
||||||
|
# end_date = datetime.now()
|
||||||
|
# start_date = end_date - timedelta(days=7) # 最近一周
|
||||||
|
# recent_papers = await arxiv_source.search_by_date_range(
|
||||||
|
# start_date,
|
||||||
|
# end_date,
|
||||||
|
# limit=2
|
||||||
|
# )
|
||||||
|
# for i, paper in enumerate(recent_papers, 1):
|
||||||
|
# print(f"\n--- 最近论文 {i} ---")
|
||||||
|
# print(f"标题: {paper.title}")
|
||||||
|
# print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
# print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# # 示例6:下载PDF
|
||||||
|
# print("\n=== 示例6:下载论文PDF ===")
|
||||||
|
# if papers: # 使用之前搜索到的GPT-3论文
|
||||||
|
# pdf_path = await arxiv_source.download_pdf(paper_id)
|
||||||
|
# print(f"PDF已下载到: {pdf_path}")
|
||||||
|
|
||||||
|
# # 示例7:下载源文件
|
||||||
|
# print("\n=== 示例7:下载论文源文件 ===")
|
||||||
|
# if papers:
|
||||||
|
# source_path = await arxiv_source.download_source(paper_id)
|
||||||
|
# print(f"源文件已下载到: {source_path}")
|
||||||
|
|
||||||
|
# 示例6:获取最新论文
|
||||||
|
print("\n=== 示例8:获取最新论文 ===")
|
||||||
|
|
||||||
|
# 获取CS.AI领域的最新论文
|
||||||
|
print("\n--- 获取AI领域最新论文 ---")
|
||||||
|
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
|
||||||
|
for i, paper in enumerate(ai_latest, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 获取整个计算机科学领域的最新论文
|
||||||
|
print("\n--- 获取整个CS领域最新论文 ---")
|
||||||
|
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
|
||||||
|
for i, paper in enumerate(cs_latest, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 获取多个类别的最新论文
|
||||||
|
print("\n--- 获取AI和机器学习领域最新论文 ---")
|
||||||
|
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
|
||||||
|
for i, paper in enumerate(multi_latest, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
class PaperMetadata:
|
||||||
|
"""论文元数据"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
title: str,
|
||||||
|
authors: List[str],
|
||||||
|
abstract: str,
|
||||||
|
year: int,
|
||||||
|
doi: str = None,
|
||||||
|
url: str = None,
|
||||||
|
citations: int = None,
|
||||||
|
venue: str = None,
|
||||||
|
institutions: List[str] = None,
|
||||||
|
venue_type: str = None, # 来源类型(journal/conference/preprint等)
|
||||||
|
venue_name: str = None, # 具体的期刊/会议名称
|
||||||
|
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
|
||||||
|
source: str = None # 新增: 论文来源标记
|
||||||
|
):
|
||||||
|
self.title = title
|
||||||
|
self.authors = authors
|
||||||
|
self.abstract = abstract
|
||||||
|
self.year = year
|
||||||
|
self.doi = doi
|
||||||
|
self.url = url
|
||||||
|
self.citations = citations
|
||||||
|
self.venue = venue
|
||||||
|
self.institutions = institutions or []
|
||||||
|
self.venue_type = venue_type # 新增
|
||||||
|
self.venue_name = venue_name # 新增
|
||||||
|
self.venue_info = venue_info or {} # 新增
|
||||||
|
self.source = source # 新增: 存储论文来源
|
||||||
|
|
||||||
|
# 新增:影响因子和分区信息,初始化为None
|
||||||
|
self._if_factor = None
|
||||||
|
self._cas_division = None
|
||||||
|
self._jcr_division = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def if_factor(self) -> Optional[float]:
|
||||||
|
"""获取影响因子"""
|
||||||
|
return self._if_factor
|
||||||
|
|
||||||
|
@if_factor.setter
|
||||||
|
def if_factor(self, value: float):
|
||||||
|
"""设置影响因子"""
|
||||||
|
self._if_factor = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cas_division(self) -> Optional[str]:
|
||||||
|
"""获取中科院分区"""
|
||||||
|
return self._cas_division
|
||||||
|
|
||||||
|
@cas_division.setter
|
||||||
|
def cas_division(self, value: str):
|
||||||
|
"""设置中科院分区"""
|
||||||
|
self._cas_division = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jcr_division(self) -> Optional[str]:
|
||||||
|
"""获取JCR分区"""
|
||||||
|
return self._jcr_division
|
||||||
|
|
||||||
|
@jcr_division.setter
|
||||||
|
def jcr_division(self, value: str):
|
||||||
|
"""设置JCR分区"""
|
||||||
|
self._jcr_division = value
|
||||||
|
|
||||||
|
class DataSource(ABC):
|
||||||
|
"""数据源基类"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: Optional[str] = None):
|
||||||
|
self.api_key = api_key
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化数据源"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
|
||||||
|
"""获取论文详细信息"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用该论文的文献"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取该论文引用的文献"""
|
||||||
|
pass
|
||||||
文件差异因一行或多行过长而隐藏
@@ -0,0 +1,400 @@
|
|||||||
|
import aiohttp
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||||
|
import random
|
||||||
|
|
||||||
|
class CrossrefSource(DataSource):
|
||||||
|
"""Crossref API实现"""
|
||||||
|
|
||||||
|
CONTACT_EMAILS = [
|
||||||
|
"gpt_abc_academic@163.com",
|
||||||
|
"gpt_abc_newapi@163.com",
|
||||||
|
"gpt_abc_academic_pwd@163.com"
|
||||||
|
]
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化客户端,设置默认参数"""
|
||||||
|
self.base_url = "https://api.crossref.org"
|
||||||
|
# 随机选择一个邮箱
|
||||||
|
contact_email = random.choice(self.CONTACT_EMAILS)
|
||||||
|
self.headers = {
|
||||||
|
"Accept": "application/json",
|
||||||
|
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
|
||||||
|
}
|
||||||
|
if self.api_key:
|
||||||
|
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = None,
|
||||||
|
sort_order: str = None,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
sort_by: 排序字段
|
||||||
|
sort_order: 排序顺序
|
||||||
|
start_year: 起始年份
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
# 请求更多的结果以补偿可能被过滤掉的文章
|
||||||
|
adjusted_limit = min(limit * 3, 1000) # 设置上限以避免请求过多
|
||||||
|
params = {
|
||||||
|
"query": query,
|
||||||
|
"rows": adjusted_limit,
|
||||||
|
"select": (
|
||||||
|
"DOI,title,author,published-print,abstract,reference,"
|
||||||
|
"container-title,is-referenced-by-count,type,"
|
||||||
|
"publisher,ISSN,ISBN,issue,volume,page"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加年份过滤
|
||||||
|
if start_year:
|
||||||
|
params["filter"] = f"from-pub-date:{start_year}"
|
||||||
|
|
||||||
|
# 添加排序
|
||||||
|
if sort_by:
|
||||||
|
params["sort"] = sort_by
|
||||||
|
if sort_order:
|
||||||
|
params["order"] = sort_order
|
||||||
|
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works",
|
||||||
|
params=params
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
print(f"API请求失败: HTTP {response.status}")
|
||||||
|
print(f"响应内容: {await response.text()}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
data = await response.json()
|
||||||
|
items = data.get("message", {}).get("items", [])
|
||||||
|
if not items:
|
||||||
|
print(f"未找到相关论文")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 过滤掉没有摘要的文章
|
||||||
|
papers = []
|
||||||
|
filtered_count = 0
|
||||||
|
for work in items:
|
||||||
|
paper = self._parse_work(work)
|
||||||
|
if paper.abstract and paper.abstract.strip():
|
||||||
|
papers.append(paper)
|
||||||
|
if len(papers) >= limit: # 达到原始请求的限制后停止
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
filtered_count += 1
|
||||||
|
|
||||||
|
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
|
||||||
|
print(f"返回 {len(papers)} 篇包含摘要的论文")
|
||||||
|
return papers
|
||||||
|
|
||||||
|
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||||
|
"""获取指定DOI的论文详情"""
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works/{doi}",
|
||||||
|
params={
|
||||||
|
"select": (
|
||||||
|
"DOI,title,author,published-print,abstract,reference,"
|
||||||
|
"container-title,is-referenced-by-count,type,"
|
||||||
|
"publisher,ISSN,ISBN,issue,volume,page"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
print(f"获取论文详情失败: HTTP {response.status}")
|
||||||
|
print(f"响应内容: {await response.text()}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await response.json()
|
||||||
|
return self._parse_work(data.get("message", {}))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析论文详情时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||||
|
"""获取指定DOI论文的参考文献列表"""
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works/{doi}",
|
||||||
|
params={"select": "reference"}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
print(f"获取参考文献失败: HTTP {response.status}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await response.json()
|
||||||
|
# 确保我们正确处理返回的数据结构
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
print(f"API返回了意外的数据格式: {type(data)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
references = data.get("message", {}).get("reference", [])
|
||||||
|
if not references:
|
||||||
|
print(f"未找到参考文献")
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
PaperMetadata(
|
||||||
|
title=ref.get("article-title", ""),
|
||||||
|
authors=[ref.get("author", "")],
|
||||||
|
year=ref.get("year"),
|
||||||
|
doi=ref.get("DOI"),
|
||||||
|
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
|
||||||
|
abstract="",
|
||||||
|
citations=None,
|
||||||
|
venue=ref.get("journal-title", ""),
|
||||||
|
institutions=[]
|
||||||
|
)
|
||||||
|
for ref in references
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析参考文献数据时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用指定DOI论文的文献列表"""
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works",
|
||||||
|
params={
|
||||||
|
"filter": f"reference.DOI:{doi}",
|
||||||
|
"select": "DOI,title,author,published-print,abstract"
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
print(f"获取引用信息失败: HTTP {response.status}")
|
||||||
|
print(f"响应内容: {await response.text()}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await response.json()
|
||||||
|
# 检查返回的数据结构
|
||||||
|
if isinstance(data, dict):
|
||||||
|
items = data.get("message", {}).get("items", [])
|
||||||
|
return [self._parse_work(work) for work in items]
|
||||||
|
else:
|
||||||
|
print(f"API返回了意外的数据格式: {type(data)}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析引用数据时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||||
|
"""解析Crossref返回的数据"""
|
||||||
|
# 获取摘要 - 处理可能的不同格式
|
||||||
|
abstract = ""
|
||||||
|
if isinstance(work.get("abstract"), str):
|
||||||
|
abstract = work.get("abstract", "")
|
||||||
|
elif isinstance(work.get("abstract"), dict):
|
||||||
|
abstract = work.get("abstract", {}).get("value", "")
|
||||||
|
|
||||||
|
if not abstract:
|
||||||
|
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
|
||||||
|
|
||||||
|
# 获取机构信息
|
||||||
|
institutions = []
|
||||||
|
for author in work.get("author", []):
|
||||||
|
if "affiliation" in author:
|
||||||
|
for affiliation in author["affiliation"]:
|
||||||
|
if "name" in affiliation and affiliation["name"] not in institutions:
|
||||||
|
institutions.append(affiliation["name"])
|
||||||
|
|
||||||
|
# 获取venue信息
|
||||||
|
venue_name = work.get("container-title", [None])[0]
|
||||||
|
venue_type = work.get("type", "unknown") # 文献类型
|
||||||
|
venue_info = {
|
||||||
|
"publisher": work.get("publisher"),
|
||||||
|
"issn": work.get("ISSN", []),
|
||||||
|
"isbn": work.get("ISBN", []),
|
||||||
|
"issue": work.get("issue"),
|
||||||
|
"volume": work.get("volume"),
|
||||||
|
"page": work.get("page")
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=work.get("title", [None])[0] or "",
|
||||||
|
authors=[
|
||||||
|
author.get("given", "") + " " + author.get("family", "")
|
||||||
|
for author in work.get("author", [])
|
||||||
|
],
|
||||||
|
institutions=institutions, # 添加机构信息
|
||||||
|
abstract=abstract,
|
||||||
|
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
|
||||||
|
doi=work.get("DOI"),
|
||||||
|
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
|
||||||
|
citations=work.get("is-referenced-by-count"),
|
||||||
|
venue=venue_name,
|
||||||
|
venue_type=venue_type, # 添加venue类型
|
||||||
|
venue_name=venue_name, # 添加venue名称
|
||||||
|
venue_info=venue_info, # 添加venue详细信息
|
||||||
|
source='crossref' # 添加来源标记
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_by_authors(
|
||||||
|
self,
|
||||||
|
authors: List[str],
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = None,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = " ".join([f"author:\"{author}\"" for author in authors])
|
||||||
|
return await self.search(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
start_year=start_year
|
||||||
|
)
|
||||||
|
|
||||||
|
async def search_by_date_range(
|
||||||
|
self,
|
||||||
|
start_date: datetime,
|
||||||
|
end_date: datetime,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = None,
|
||||||
|
sort_order: str = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按日期范围搜索论文"""
|
||||||
|
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
|
||||||
|
return await self.search(
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
sort_by=sort_by,
|
||||||
|
sort_order=sort_order
|
||||||
|
)
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""CrossrefSource使用示例"""
|
||||||
|
crossref = CrossrefSource(api_key=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:基本搜索,使用不同的排序方式
|
||||||
|
print("\n=== 示例1:搜索最新的机器学习论文 ===")
|
||||||
|
papers = await crossref.search(
|
||||||
|
query="machine learning",
|
||||||
|
limit=3,
|
||||||
|
sort_by="published",
|
||||||
|
sort_order="desc",
|
||||||
|
start_year=2023
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"URL: {paper.url}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要: {paper.abstract[:200]}...")
|
||||||
|
if paper.institutions:
|
||||||
|
print(f"机构: {', '.join(paper.institutions)}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"发表venue: {paper.venue}")
|
||||||
|
print(f"venue类型: {paper.venue_type}")
|
||||||
|
if paper.venue_info:
|
||||||
|
print("Venue详细信息:")
|
||||||
|
for key, value in paper.venue_info.items():
|
||||||
|
if value:
|
||||||
|
print(f" - {key}: {value}")
|
||||||
|
|
||||||
|
# 示例2:按DOI获取论文详情
|
||||||
|
print("\n=== 示例2:获取特定论文详情 ===")
|
||||||
|
# 使用BERT论文的DOI
|
||||||
|
doi = "10.18653/v1/N19-1423"
|
||||||
|
paper = await crossref.get_paper_details(doi)
|
||||||
|
if paper:
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要: {paper.abstract[:200]}...")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
|
||||||
|
# 示例3:按作者搜索
|
||||||
|
print("\n=== 示例3:搜索特定作者的论文 ===")
|
||||||
|
author_papers = await crossref.search_by_authors(
|
||||||
|
authors=["Yoshua Bengio"],
|
||||||
|
limit=3,
|
||||||
|
sort_by="published",
|
||||||
|
start_year=2020
|
||||||
|
)
|
||||||
|
for i, paper in enumerate(author_papers, 1):
|
||||||
|
print(f"\n--- {i}. {paper.title} ---")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
|
||||||
|
# 示例4:按日期范围搜索
|
||||||
|
print("\n=== 示例4:搜索特定日期范围的论文 ===")
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
end_date = datetime.now()
|
||||||
|
start_date = end_date - timedelta(days=30) # 最近一个月
|
||||||
|
recent_papers = await crossref.search_by_date_range(
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=3,
|
||||||
|
sort_by="published",
|
||||||
|
sort_order="desc"
|
||||||
|
)
|
||||||
|
for i, paper in enumerate(recent_papers, 1):
|
||||||
|
print(f"\n--- 最近发表的论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
|
||||||
|
# 示例5:获取论文引用信息
|
||||||
|
print("\n=== 示例5:获取论文引用信息 ===")
|
||||||
|
if paper: # 使用之前获取的BERT论文
|
||||||
|
print("\n获取引用该论文的文献:")
|
||||||
|
citations = await crossref.get_citations(paper.doi)
|
||||||
|
for i, citing_paper in enumerate(citations[:3], 1):
|
||||||
|
print(f"\n--- 引用论文 {i} ---")
|
||||||
|
print(f"标题: {citing_paper.title}")
|
||||||
|
print(f"作者: {', '.join(citing_paper.authors)}")
|
||||||
|
print(f"发表年份: {citing_paper.year}")
|
||||||
|
|
||||||
|
print("\n获取该论文引用的参考文献:")
|
||||||
|
references = await crossref.get_references(paper.doi)
|
||||||
|
for i, ref_paper in enumerate(references[:3], 1):
|
||||||
|
print(f"\n--- 参考文献 {i} ---")
|
||||||
|
print(f"标题: {ref_paper.title}")
|
||||||
|
print(f"作者: {', '.join(ref_paper.authors)}")
|
||||||
|
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
|
||||||
|
|
||||||
|
# 示例6:展示venue信息的使用
|
||||||
|
print("\n=== 示例6:展示期刊/会议详细信息 ===")
|
||||||
|
if papers:
|
||||||
|
paper = papers[0]
|
||||||
|
print(f"文献类型: {paper.venue_type}")
|
||||||
|
print(f"发表venue: {paper.venue_name}")
|
||||||
|
if paper.venue_info:
|
||||||
|
print("Venue详细信息:")
|
||||||
|
for key, value in paper.venue_info.items():
|
||||||
|
if value:
|
||||||
|
print(f" - {key}: {value}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 运行示例
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,449 @@
|
|||||||
|
from typing import List, Optional, Dict, Union
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
|
|
||||||
|
class ElsevierSource(DataSource):
|
||||||
|
"""Elsevier (Scopus) API实现"""
|
||||||
|
|
||||||
|
# 定义API密钥列表
|
||||||
|
API_KEYS = [
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None):
|
||||||
|
"""初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
|
||||||
|
"""
|
||||||
|
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化基础URL和请求头"""
|
||||||
|
self.base_url = "https://api.elsevier.com/content"
|
||||||
|
self.headers = {
|
||||||
|
"X-ELS-APIKey": self.api_key,
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
# 添加更多必要的头部信息
|
||||||
|
"X-ELS-Insttoken": "", # 如果有机构令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||||
|
"""发送HTTP请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 请求URL
|
||||||
|
params: 查询参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(url, params=params) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
else:
|
||||||
|
# 添加更详细的错误信息
|
||||||
|
error_text = await response.text()
|
||||||
|
print(f"请求失败: {response.status}")
|
||||||
|
print(f"错误详情: {error_text}")
|
||||||
|
if response.status == 401:
|
||||||
|
print(f"使用的API密钥: {self.api_key}")
|
||||||
|
# 尝试切换到另一个API密钥
|
||||||
|
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
|
||||||
|
print(f"尝试切换到新的API密钥: {new_key}")
|
||||||
|
self.api_key = new_key
|
||||||
|
self.headers["X-ELS-APIKey"] = new_key
|
||||||
|
# 重试请求
|
||||||
|
return await self._make_request(url, params)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = "relevance",
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文"""
|
||||||
|
try:
|
||||||
|
params = {
|
||||||
|
"query": query,
|
||||||
|
"count": min(limit, 100),
|
||||||
|
"view": "STANDARD",
|
||||||
|
# 移除dc:description字段,因为它在STANDARD视图中不可用
|
||||||
|
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加年份过滤
|
||||||
|
if start_year:
|
||||||
|
params["date"] = f"{start_year}-present"
|
||||||
|
|
||||||
|
# 添加排序
|
||||||
|
if sort_by == "date":
|
||||||
|
params["sort"] = "-coverDate"
|
||||||
|
elif sort_by == "cited":
|
||||||
|
params["sort"] = "-citedby-count"
|
||||||
|
|
||||||
|
# 发送搜索请求
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/search/scopus",
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response or "search-results" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析搜索结果
|
||||||
|
entries = response["search-results"].get("entry", [])
|
||||||
|
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
|
||||||
|
|
||||||
|
# 尝试为每篇论文获取摘要
|
||||||
|
for paper in papers:
|
||||||
|
if paper.doi:
|
||||||
|
paper.abstract = await self.fetch_abstract(paper.doi) or ""
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
|
||||||
|
"""解析Scopus API返回的条目"""
|
||||||
|
try:
|
||||||
|
# 获取作者列表
|
||||||
|
authors = []
|
||||||
|
creator = entry.get("dc:creator")
|
||||||
|
if creator:
|
||||||
|
authors = [creator]
|
||||||
|
|
||||||
|
# 获取发表年份
|
||||||
|
year = None
|
||||||
|
if "prism:coverDate" in entry:
|
||||||
|
try:
|
||||||
|
year = int(entry["prism:coverDate"][:4])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 简化venue信息
|
||||||
|
venue_info = {
|
||||||
|
'source_id': entry.get("source-id"),
|
||||||
|
'issn': entry.get("prism:issn")
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=entry.get("dc:title", ""),
|
||||||
|
authors=authors,
|
||||||
|
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
|
||||||
|
year=year,
|
||||||
|
doi=entry.get("prism:doi"),
|
||||||
|
url=entry.get("prism:url"),
|
||||||
|
citations=int(entry.get("citedby-count", 0)),
|
||||||
|
venue=entry.get("prism:publicationName"),
|
||||||
|
institutions=[], # 移除机构信息
|
||||||
|
venue_type="",
|
||||||
|
venue_name=entry.get("prism:publicationName"),
|
||||||
|
venue_info=venue_info
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析条目时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""获取引用该论文的文献"""
|
||||||
|
try:
|
||||||
|
params = {
|
||||||
|
"query": f"REF({doi})",
|
||||||
|
"count": min(limit, 100),
|
||||||
|
"view": "STANDARD"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/search/scopus",
|
||||||
|
params=params
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response or "search-results" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
entries = response["search-results"].get("entry", [])
|
||||||
|
return [self._parse_entry(entry) for entry in entries]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取引用文献时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||||
|
"""获取该论文引用的文献"""
|
||||||
|
try:
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/abstract/doi/{doi}/references",
|
||||||
|
params={"view": "STANDARD"}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response or "references" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
references = response["references"].get("reference", [])
|
||||||
|
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取参考文献时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
|
||||||
|
"""解析参考文献数据"""
|
||||||
|
try:
|
||||||
|
authors = []
|
||||||
|
if "author-list" in ref:
|
||||||
|
author_list = ref["author-list"].get("author", [])
|
||||||
|
if isinstance(author_list, list):
|
||||||
|
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
|
||||||
|
for author in author_list]
|
||||||
|
else:
|
||||||
|
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
|
||||||
|
|
||||||
|
year = None
|
||||||
|
if "prism:coverDate" in ref:
|
||||||
|
try:
|
||||||
|
year = int(ref["prism:coverDate"][:4])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=ref.get("ce:title", ""),
|
||||||
|
authors=authors,
|
||||||
|
abstract="", # 参考文献通常不包含摘要
|
||||||
|
year=year,
|
||||||
|
doi=ref.get("prism:doi"),
|
||||||
|
url=None,
|
||||||
|
citations=None,
|
||||||
|
venue=ref.get("prism:publicationName"),
|
||||||
|
institutions=[],
|
||||||
|
venue_type="unknown",
|
||||||
|
venue_name=ref.get("prism:publicationName"),
|
||||||
|
venue_info={}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析参考文献时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search_by_author(
|
||||||
|
self,
|
||||||
|
author: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = f"AUTHOR-NAME({author})"
|
||||||
|
return await self.search(query, limit=limit, start_year=start_year)
|
||||||
|
|
||||||
|
async def search_by_affiliation(
|
||||||
|
self,
|
||||||
|
affiliation: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按机构搜索论文"""
|
||||||
|
query = f"AF-ID({affiliation})"
|
||||||
|
return await self.search(query, limit=limit, start_year=start_year)
|
||||||
|
|
||||||
|
async def search_by_venue(
|
||||||
|
self,
|
||||||
|
venue: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按期刊/会议搜索论文"""
|
||||||
|
query = f"SRCTITLE({venue})"
|
||||||
|
return await self.search(query, limit=limit, start_year=start_year)
|
||||||
|
|
||||||
|
async def test_api_access(self):
|
||||||
|
"""测试API访问权限"""
|
||||||
|
print(f"\n测试API密钥: {self.api_key}")
|
||||||
|
|
||||||
|
# 测试1: 基础搜索
|
||||||
|
basic_params = {
|
||||||
|
"query": "test",
|
||||||
|
"count": 1,
|
||||||
|
"view": "STANDARD"
|
||||||
|
}
|
||||||
|
print("\n1. 测试基础搜索...")
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/search/scopus",
|
||||||
|
params=basic_params
|
||||||
|
)
|
||||||
|
if response:
|
||||||
|
print("基础搜索成功")
|
||||||
|
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
|
||||||
|
|
||||||
|
# 测试2: 测试单篇文章访问
|
||||||
|
print("\n2. 测试文章详情访问...")
|
||||||
|
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/abstract/doi/{test_doi}",
|
||||||
|
params={"view": "STANDARD"} # 改为STANDARD视图
|
||||||
|
)
|
||||||
|
if response:
|
||||||
|
print("文章详情访问成功")
|
||||||
|
else:
|
||||||
|
print("文章详情访问失败")
|
||||||
|
|
||||||
|
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取论文详细信息
|
||||||
|
|
||||||
|
注意:当前API权限不支持获取详细信息,返回None
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: 论文ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None,因为当前API权限不支持此功能
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def fetch_abstract(self, doi: str) -> Optional[str]:
|
||||||
|
"""获取论文摘要
|
||||||
|
|
||||||
|
使用Scopus Abstract API获取论文摘要
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doi: 论文的DOI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
摘要文本,如果获取失败则返回None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用Abstract API而不是Search API
|
||||||
|
response = await self._make_request(
|
||||||
|
f"{self.base_url}/abstract/doi/{doi}",
|
||||||
|
params={
|
||||||
|
"view": "FULL" # 使用FULL视图
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if response and "abstracts-retrieval-response" in response:
|
||||||
|
# 从coredata中获取摘要
|
||||||
|
coredata = response["abstracts-retrieval-response"].get("coredata", {})
|
||||||
|
return coredata.get("dc:description", "")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取摘要时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""ElsevierSource使用示例"""
|
||||||
|
elsevier = ElsevierSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 首先测试API访问权限
|
||||||
|
print("\n=== 测试API访问权限 ===")
|
||||||
|
await elsevier.test_api_access()
|
||||||
|
|
||||||
|
# 示例1:基本搜索
|
||||||
|
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||||
|
papers = await elsevier.search("machine learning", limit=3)
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"URL: {paper.url}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"期刊/会议: {paper.venue}")
|
||||||
|
print("期刊信息:")
|
||||||
|
for key, value in paper.venue_info.items():
|
||||||
|
if value: # 只打印非空值
|
||||||
|
print(f" - {key}: {value}")
|
||||||
|
|
||||||
|
# 示例2:获取引用信息
|
||||||
|
if papers and papers[0].doi:
|
||||||
|
print("\n=== 示例2:获取引用该论文的文献 ===")
|
||||||
|
citations = await elsevier.get_citations(papers[0].doi, limit=3)
|
||||||
|
for i, paper in enumerate(citations, 1):
|
||||||
|
print(f"\n--- 引用论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"期刊/会议: {paper.venue}")
|
||||||
|
|
||||||
|
# 示例3:获取参考文献
|
||||||
|
if papers and papers[0].doi:
|
||||||
|
print("\n=== 示例3:获取论文的参考文献 ===")
|
||||||
|
references = await elsevier.get_references(papers[0].doi)
|
||||||
|
for i, paper in enumerate(references[:3], 1):
|
||||||
|
print(f"\n--- 参考文献 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"期刊/会议: {paper.venue}")
|
||||||
|
|
||||||
|
# 示例4:按作者搜索
|
||||||
|
print("\n=== 示例4:按作者搜索 ===")
|
||||||
|
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
|
||||||
|
for i, paper in enumerate(author_papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"期刊/会议: {paper.venue}")
|
||||||
|
|
||||||
|
# 示例5:按机构搜索
|
||||||
|
print("\n=== 示例5:按机构搜索 ===")
|
||||||
|
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
|
||||||
|
for i, paper in enumerate(affiliation_papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"期刊/会议: {paper.venue}")
|
||||||
|
|
||||||
|
# 示例6:获取论文摘要
|
||||||
|
print("\n=== 示例6:获取论文摘要 ===")
|
||||||
|
test_doi = "10.1016/j.artint.2021.103535"
|
||||||
|
abstract = await elsevier.fetch_abstract(test_doi)
|
||||||
|
if abstract:
|
||||||
|
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
|
||||||
|
else:
|
||||||
|
print("无法获取摘要")
|
||||||
|
|
||||||
|
# 在搜索结果中显示摘要
|
||||||
|
print("\n=== 示例7:搜索结果中的摘要 ===")
|
||||||
|
papers = await elsevier.search("machine learning", limit=1)
|
||||||
|
for paper in papers:
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,698 @@
|
|||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Optional, Union, Any
|
||||||
|
|
||||||
|
class GitHubSource:
|
||||||
|
"""GitHub API实现"""
|
||||||
|
|
||||||
|
# 默认API密钥列表 - 可以放置多个GitHub令牌
|
||||||
|
API_KEYS = [
|
||||||
|
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
|
||||||
|
"""初始化GitHub API客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: GitHub个人访问令牌或令牌列表
|
||||||
|
"""
|
||||||
|
if api_key is None:
|
||||||
|
self.api_keys = self.API_KEYS
|
||||||
|
elif isinstance(api_key, str):
|
||||||
|
self.api_keys = [api_key]
|
||||||
|
else:
|
||||||
|
self.api_keys = api_key
|
||||||
|
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化客户端,设置默认参数"""
|
||||||
|
self.base_url = "https://api.github.com"
|
||||||
|
self.headers = {
|
||||||
|
"Accept": "application/vnd.github+json",
|
||||||
|
"X-GitHub-Api-Version": "2022-11-28",
|
||||||
|
"User-Agent": "GitHub-API-Python-Client"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有可用的API密钥,随机选择一个
|
||||||
|
if self.api_keys:
|
||||||
|
selected_key = random.choice(self.api_keys)
|
||||||
|
self.headers["Authorization"] = f"Bearer {selected_key}"
|
||||||
|
print(f"已随机选择API密钥进行认证")
|
||||||
|
else:
|
||||||
|
print("警告: 未提供API密钥,将受到GitHub API请求限制")
|
||||||
|
|
||||||
|
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
|
||||||
|
"""发送API请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP方法 (GET, POST, PUT, DELETE等)
|
||||||
|
endpoint: API端点
|
||||||
|
params: URL参数
|
||||||
|
data: 请求体数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的响应JSON
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
url = f"{self.base_url}{endpoint}"
|
||||||
|
|
||||||
|
# 为调试目的打印请求信息
|
||||||
|
print(f"请求: {method} {url}")
|
||||||
|
if params:
|
||||||
|
print(f"参数: {params}")
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
request_kwargs = {}
|
||||||
|
if params:
|
||||||
|
request_kwargs["params"] = params
|
||||||
|
if data:
|
||||||
|
request_kwargs["json"] = data
|
||||||
|
|
||||||
|
async with session.request(method, url, **request_kwargs) as response:
|
||||||
|
response_text = await response.text()
|
||||||
|
|
||||||
|
# 检查HTTP状态码
|
||||||
|
if response.status >= 400:
|
||||||
|
print(f"API请求失败: HTTP {response.status}")
|
||||||
|
print(f"响应内容: {response_text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解析JSON响应
|
||||||
|
try:
|
||||||
|
return json.loads(response_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"JSON解析错误: {response_text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ===== 用户相关方法 =====
|
||||||
|
|
||||||
|
async def get_user(self, username: Optional[str] = None) -> Dict:
|
||||||
|
"""获取用户信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: 指定用户名,不指定则获取当前授权用户
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户信息字典
|
||||||
|
"""
|
||||||
|
endpoint = "/user" if username is None else f"/users/{username}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
|
||||||
|
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取用户的仓库列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: 指定用户名,不指定则获取当前授权用户
|
||||||
|
sort: 排序方式 (created, updated, pushed, full_name)
|
||||||
|
direction: 排序方向 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
仓库列表
|
||||||
|
"""
|
||||||
|
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
|
||||||
|
params = {
|
||||||
|
"sort": sort,
|
||||||
|
"direction": direction,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_user_starred(self, username: Optional[str] = None,
|
||||||
|
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取用户星标的仓库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: 指定用户名,不指定则获取当前授权用户
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
星标仓库列表
|
||||||
|
"""
|
||||||
|
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
|
||||||
|
params = {
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
# ===== 仓库相关方法 =====
|
||||||
|
|
||||||
|
async def get_repo(self, owner: str, repo: str) -> Dict:
|
||||||
|
"""获取仓库信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
仓库信息
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取仓库的分支列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分支列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/branches"
|
||||||
|
params = {
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
|
||||||
|
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取仓库的提交历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
sha: 特定提交SHA或分支名
|
||||||
|
path: 文件路径筛选
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提交列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/commits"
|
||||||
|
params = {
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
if sha:
|
||||||
|
params["sha"] = sha
|
||||||
|
if path:
|
||||||
|
params["path"] = path
|
||||||
|
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
|
||||||
|
"""获取特定提交的详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
commit_sha: 提交SHA
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提交详情
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
# ===== 内容相关方法 =====
|
||||||
|
|
||||||
|
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
|
||||||
|
"""获取文件内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
path: 文件路径
|
||||||
|
ref: 分支名、标签名或提交SHA
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文件内容信息
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||||
|
params = {}
|
||||||
|
if ref:
|
||||||
|
params["ref"] = ref
|
||||||
|
|
||||||
|
response = await self._request("GET", endpoint, params=params)
|
||||||
|
if response and isinstance(response, dict) and "content" in response:
|
||||||
|
try:
|
||||||
|
# 解码Base64编码的文件内容
|
||||||
|
content = base64.b64decode(response["content"].encode()).decode()
|
||||||
|
response["decoded_content"] = content
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解码文件内容时出错: {str(e)}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
|
||||||
|
"""获取目录内容
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
path: 目录路径
|
||||||
|
ref: 分支名、标签名或提交SHA
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
目录内容列表
|
||||||
|
"""
|
||||||
|
# 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||||
|
params = {}
|
||||||
|
if ref:
|
||||||
|
params["ref"] = ref
|
||||||
|
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
# ===== Issues相关方法 =====
|
||||||
|
|
||||||
|
async def get_issues(self, owner: str, repo: str, state: str = "open",
|
||||||
|
sort: str = "created", direction: str = "desc",
|
||||||
|
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取仓库的Issues列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
state: Issue状态 (open, closed, all)
|
||||||
|
sort: 排序方式 (created, updated, comments)
|
||||||
|
direction: 排序方向 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Issues列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/issues"
|
||||||
|
params = {
|
||||||
|
"state": state,
|
||||||
|
"sort": sort,
|
||||||
|
"direction": direction,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
|
||||||
|
"""获取特定Issue的详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
issue_number: Issue编号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Issue详情
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
|
||||||
|
"""获取Issue的评论
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
issue_number: Issue编号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
评论列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
# ===== Pull Requests相关方法 =====
|
||||||
|
|
||||||
|
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
|
||||||
|
sort: str = "created", direction: str = "desc",
|
||||||
|
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取仓库的Pull Request列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
state: PR状态 (open, closed, all)
|
||||||
|
sort: 排序方式 (created, updated, popularity, long-running)
|
||||||
|
direction: 排序方向 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pull Request列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/pulls"
|
||||||
|
params = {
|
||||||
|
"state": state,
|
||||||
|
"sort": sort,
|
||||||
|
"direction": direction,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
|
||||||
|
"""获取特定Pull Request的详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
pr_number: Pull Request编号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pull Request详情
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
|
||||||
|
"""获取Pull Request中修改的文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
pr_number: Pull Request编号
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
修改文件列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
# ===== 搜索相关方法 =====
|
||||||
|
|
||||||
|
async def search_repositories(self, query: str, sort: str = "stars",
|
||||||
|
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||||
|
"""搜索仓库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
sort: 排序方式 (stars, forks, updated)
|
||||||
|
order: 排序顺序 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果
|
||||||
|
"""
|
||||||
|
endpoint = "/search/repositories"
|
||||||
|
params = {
|
||||||
|
"q": query,
|
||||||
|
"sort": sort,
|
||||||
|
"order": order,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def search_code(self, query: str, sort: str = "indexed",
|
||||||
|
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||||
|
"""搜索代码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
sort: 排序方式 (indexed)
|
||||||
|
order: 排序顺序 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果
|
||||||
|
"""
|
||||||
|
endpoint = "/search/code"
|
||||||
|
params = {
|
||||||
|
"q": query,
|
||||||
|
"sort": sort,
|
||||||
|
"order": order,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def search_issues(self, query: str, sort: str = "created",
|
||||||
|
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||||
|
"""搜索Issues和Pull Requests
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
sort: 排序方式 (created, updated, comments)
|
||||||
|
order: 排序顺序 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果
|
||||||
|
"""
|
||||||
|
endpoint = "/search/issues"
|
||||||
|
params = {
|
||||||
|
"q": query,
|
||||||
|
"sort": sort,
|
||||||
|
"order": order,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def search_users(self, query: str, sort: str = "followers",
|
||||||
|
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||||
|
"""搜索用户
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
sort: 排序方式 (followers, repositories, joined)
|
||||||
|
order: 排序顺序 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果
|
||||||
|
"""
|
||||||
|
endpoint = "/search/users"
|
||||||
|
params = {
|
||||||
|
"q": query,
|
||||||
|
"sort": sort,
|
||||||
|
"order": order,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
# ===== 组织相关方法 =====
|
||||||
|
|
||||||
|
async def get_organization(self, org: str) -> Dict:
|
||||||
|
"""获取组织信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org: 组织名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
组织信息
|
||||||
|
"""
|
||||||
|
endpoint = f"/orgs/{org}"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_organization_repos(self, org: str, type: str = "all",
|
||||||
|
sort: str = "created", direction: str = "desc",
|
||||||
|
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取组织的仓库列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org: 组织名称
|
||||||
|
type: 仓库类型 (all, public, private, forks, sources, member, internal)
|
||||||
|
sort: 排序方式 (created, updated, pushed, full_name)
|
||||||
|
direction: 排序方向 (asc, desc)
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
仓库列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/orgs/{org}/repos"
|
||||||
|
params = {
|
||||||
|
"type": type,
|
||||||
|
"sort": sort,
|
||||||
|
"direction": direction,
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||||
|
"""获取组织成员列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
org: 组织名称
|
||||||
|
per_page: 每页结果数量
|
||||||
|
page: 页码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
成员列表
|
||||||
|
"""
|
||||||
|
endpoint = f"/orgs/{org}/members"
|
||||||
|
params = {
|
||||||
|
"per_page": per_page,
|
||||||
|
"page": page
|
||||||
|
}
|
||||||
|
return await self._request("GET", endpoint, params=params)
|
||||||
|
|
||||||
|
# ===== 更复杂的操作 =====
|
||||||
|
|
||||||
|
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
|
||||||
|
"""获取仓库使用的编程语言及其比例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
语言使用情况
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/languages"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
|
||||||
|
"""获取仓库的贡献者统计
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
贡献者统计信息
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
|
||||||
|
"""获取仓库的提交活动
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owner: 仓库所有者
|
||||||
|
repo: 仓库名
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提交活动统计
|
||||||
|
"""
|
||||||
|
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
|
||||||
|
return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""GitHubSource使用示例"""
|
||||||
|
# 创建客户端实例(可选传入API令牌)
|
||||||
|
# github = GitHubSource(api_key="your_github_token")
|
||||||
|
github = GitHubSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:搜索热门Python仓库
|
||||||
|
print("\n=== 示例1:搜索热门Python仓库 ===")
|
||||||
|
repos = await github.search_repositories(
|
||||||
|
query="language:python stars:>1000",
|
||||||
|
sort="stars",
|
||||||
|
order="desc",
|
||||||
|
per_page=5
|
||||||
|
)
|
||||||
|
|
||||||
|
if repos and "items" in repos:
|
||||||
|
for i, repo in enumerate(repos["items"], 1):
|
||||||
|
print(f"\n--- 仓库 {i} ---")
|
||||||
|
print(f"名称: {repo['full_name']}")
|
||||||
|
print(f"描述: {repo['description']}")
|
||||||
|
print(f"星标数: {repo['stargazers_count']}")
|
||||||
|
print(f"Fork数: {repo['forks_count']}")
|
||||||
|
print(f"最近更新: {repo['updated_at']}")
|
||||||
|
print(f"URL: {repo['html_url']}")
|
||||||
|
|
||||||
|
# 示例2:获取特定仓库的详情
|
||||||
|
print("\n=== 示例2:获取特定仓库的详情 ===")
|
||||||
|
repo_details = await github.get_repo("microsoft", "vscode")
|
||||||
|
if repo_details:
|
||||||
|
print(f"名称: {repo_details['full_name']}")
|
||||||
|
print(f"描述: {repo_details['description']}")
|
||||||
|
print(f"星标数: {repo_details['stargazers_count']}")
|
||||||
|
print(f"Fork数: {repo_details['forks_count']}")
|
||||||
|
print(f"默认分支: {repo_details['default_branch']}")
|
||||||
|
print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}")
|
||||||
|
print(f"语言: {repo_details['language']}")
|
||||||
|
print(f"Open Issues数: {repo_details['open_issues_count']}")
|
||||||
|
|
||||||
|
# 示例3:获取仓库的提交历史
|
||||||
|
print("\n=== 示例3:获取仓库的最近提交 ===")
|
||||||
|
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
|
||||||
|
if commits:
|
||||||
|
for i, commit in enumerate(commits, 1):
|
||||||
|
print(f"\n--- 提交 {i} ---")
|
||||||
|
print(f"SHA: {commit['sha'][:7]}")
|
||||||
|
print(f"作者: {commit['commit']['author']['name']}")
|
||||||
|
print(f"日期: {commit['commit']['author']['date']}")
|
||||||
|
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
|
||||||
|
|
||||||
|
# 示例4:搜索代码
|
||||||
|
print("\n=== 示例4:搜索代码 ===")
|
||||||
|
code_results = await github.search_code(
|
||||||
|
query="filename:README.md language:markdown pytorch in:file",
|
||||||
|
per_page=3
|
||||||
|
)
|
||||||
|
if code_results and "items" in code_results:
|
||||||
|
print(f"共找到: {code_results['total_count']} 个结果")
|
||||||
|
for i, item in enumerate(code_results["items"], 1):
|
||||||
|
print(f"\n--- 代码 {i} ---")
|
||||||
|
print(f"仓库: {item['repository']['full_name']}")
|
||||||
|
print(f"文件: {item['path']}")
|
||||||
|
print(f"URL: {item['html_url']}")
|
||||||
|
|
||||||
|
# 示例5:获取文件内容
|
||||||
|
print("\n=== 示例5:获取文件内容 ===")
|
||||||
|
file_content = await github.get_file_content("python", "cpython", "README.rst")
|
||||||
|
if file_content and "decoded_content" in file_content:
|
||||||
|
content = file_content["decoded_content"]
|
||||||
|
print(f"文件名: {file_content['name']}")
|
||||||
|
print(f"大小: {file_content['size']} 字节")
|
||||||
|
print(f"内容预览: {content[:200]}...")
|
||||||
|
|
||||||
|
# 示例6:获取仓库使用的编程语言
|
||||||
|
print("\n=== 示例6:获取仓库使用的编程语言 ===")
|
||||||
|
languages = await github.get_repository_languages("facebook", "react")
|
||||||
|
if languages:
|
||||||
|
print(f"React仓库使用的编程语言:")
|
||||||
|
for lang, bytes_of_code in languages.items():
|
||||||
|
print(f"- {lang}: {bytes_of_code} 字节")
|
||||||
|
|
||||||
|
# 示例7:获取组织信息
|
||||||
|
print("\n=== 示例7:获取组织信息 ===")
|
||||||
|
org_info = await github.get_organization("google")
|
||||||
|
if org_info:
|
||||||
|
print(f"名称: {org_info['name']}")
|
||||||
|
print(f"描述: {org_info.get('description', '无')}")
|
||||||
|
print(f"位置: {org_info.get('location', '未指定')}")
|
||||||
|
print(f"公共仓库数: {org_info['public_repos']}")
|
||||||
|
print(f"成员数: {org_info.get('public_members', 0)}")
|
||||||
|
print(f"URL: {org_info['html_url']}")
|
||||||
|
|
||||||
|
# 示例8:获取用户信息
|
||||||
|
print("\n=== 示例8:获取用户信息 ===")
|
||||||
|
user_info = await github.get_user("torvalds")
|
||||||
|
if user_info:
|
||||||
|
print(f"名称: {user_info['name']}")
|
||||||
|
print(f"公司: {user_info.get('company', '无')}")
|
||||||
|
print(f"博客: {user_info.get('blog', '无')}")
|
||||||
|
print(f"位置: {user_info.get('location', '未指定')}")
|
||||||
|
print(f"公共仓库数: {user_info['public_repos']}")
|
||||||
|
print(f"关注者数: {user_info['followers']}")
|
||||||
|
print(f"URL: {user_info['html_url']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 运行示例
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
class JournalMetrics:
|
||||||
|
"""期刊指标管理类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.journal_data: Dict = {} # 期刊名称到指标的映射
|
||||||
|
self.issn_map: Dict = {} # ISSN到指标的映射
|
||||||
|
self.name_map: Dict = {} # 标准化名称到指标的映射
|
||||||
|
self._load_journal_data()
|
||||||
|
|
||||||
|
def _normalize_journal_name(self, name: str) -> str:
|
||||||
|
"""标准化期刊名称
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: 原始期刊名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化后的期刊名称
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 转换为小写
|
||||||
|
name = name.lower()
|
||||||
|
|
||||||
|
# 移除常见的前缀和后缀
|
||||||
|
prefixes = ['the ', 'proceedings of ', 'journal of ']
|
||||||
|
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
|
||||||
|
|
||||||
|
for prefix in prefixes:
|
||||||
|
if name.startswith(prefix):
|
||||||
|
name = name[len(prefix):]
|
||||||
|
|
||||||
|
for suffix in suffixes:
|
||||||
|
if name.endswith(suffix):
|
||||||
|
name = name[:-len(suffix)]
|
||||||
|
|
||||||
|
# 移除特殊字符,保留字母、数字和空格
|
||||||
|
name = ''.join(c for c in name if c.isalnum() or c.isspace())
|
||||||
|
|
||||||
|
# 移除多余的空格
|
||||||
|
name = ' '.join(name.split())
|
||||||
|
|
||||||
|
return name
|
||||||
|
|
||||||
|
def _convert_if_value(self, if_str: str) -> Optional[float]:
|
||||||
|
"""转换IF值为float,处理特殊情况"""
|
||||||
|
try:
|
||||||
|
if if_str.startswith('<'):
|
||||||
|
# 对于<0.1这样的值,返回0.1
|
||||||
|
return float(if_str.strip('<'))
|
||||||
|
return float(if_str)
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_journal_data(self):
|
||||||
|
"""加载期刊数据"""
|
||||||
|
try:
|
||||||
|
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
# 建立期刊名称到指标的映射
|
||||||
|
for journal in data:
|
||||||
|
# 准备指标数据
|
||||||
|
metrics = {
|
||||||
|
'if_factor': self._convert_if_value(journal.get('IF')),
|
||||||
|
'jcr_division': journal.get('Q'),
|
||||||
|
'cas_division': journal.get('B')
|
||||||
|
}
|
||||||
|
|
||||||
|
# 存储期刊名称映射(使用标准化名称)
|
||||||
|
if journal.get('journal'):
|
||||||
|
normalized_name = self._normalize_journal_name(journal['journal'])
|
||||||
|
self.journal_data[normalized_name] = metrics
|
||||||
|
self.name_map[normalized_name] = metrics
|
||||||
|
|
||||||
|
# 存储期刊缩写映射
|
||||||
|
if journal.get('jabb'):
|
||||||
|
normalized_abbr = self._normalize_journal_name(journal['jabb'])
|
||||||
|
self.journal_data[normalized_abbr] = metrics
|
||||||
|
self.name_map[normalized_abbr] = metrics
|
||||||
|
|
||||||
|
# 存储ISSN映射
|
||||||
|
if journal.get('issn'):
|
||||||
|
self.issn_map[journal['issn']] = metrics
|
||||||
|
if journal.get('eissn'):
|
||||||
|
self.issn_map[journal['eissn']] = metrics
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"加载期刊数据时出错: {str(e)}")
|
||||||
|
self.journal_data = {}
|
||||||
|
self.issn_map = {}
|
||||||
|
self.name_map = {}
|
||||||
|
|
||||||
|
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
|
||||||
|
"""获取期刊指标
|
||||||
|
|
||||||
|
Args:
|
||||||
|
venue_name: 期刊名称
|
||||||
|
venue_info: 期刊详细信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含期刊指标的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
# 1. 首先尝试通过ISSN匹配
|
||||||
|
if venue_info and 'issn' in venue_info:
|
||||||
|
issn_value = venue_info['issn']
|
||||||
|
# 处理ISSN可能是列表的情况
|
||||||
|
if isinstance(issn_value, list):
|
||||||
|
# 尝试每个ISSN
|
||||||
|
for issn in issn_value:
|
||||||
|
metrics = self.issn_map.get(issn, {})
|
||||||
|
if metrics: # 如果找到匹配的指标,就停止搜索
|
||||||
|
break
|
||||||
|
else: # ISSN是字符串的情况
|
||||||
|
metrics = self.issn_map.get(issn_value, {})
|
||||||
|
|
||||||
|
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
|
||||||
|
if not metrics and venue_name:
|
||||||
|
# 标准化期刊名称
|
||||||
|
normalized_name = self._normalize_journal_name(venue_name)
|
||||||
|
metrics = self.name_map.get(normalized_name, {})
|
||||||
|
|
||||||
|
# 如果完全匹配失败,尝试部分匹配
|
||||||
|
# if not metrics:
|
||||||
|
# for db_name, db_metrics in self.name_map.items():
|
||||||
|
# if normalized_name in db_name:
|
||||||
|
# metrics = db_metrics
|
||||||
|
# break
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取期刊指标时出错: {str(e)}")
|
||||||
|
return {}
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
import aiohttp
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from .base_source import DataSource, PaperMetadata
|
||||||
|
import os
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
class OpenAlexSource(DataSource):
|
||||||
|
"""OpenAlex API实现"""
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
self.base_url = "https://api.openalex.org"
|
||||||
|
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
|
||||||
|
|
||||||
|
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
params = {"mailto": self.mailto} if self.mailto else {}
|
||||||
|
params.update({
|
||||||
|
"filter": f"title.search:{query}",
|
||||||
|
"per-page": limit
|
||||||
|
})
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works",
|
||||||
|
params=params
|
||||||
|
) as response:
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
data = await response.json()
|
||||||
|
results = data.get("results", [])
|
||||||
|
return [self._parse_work(work) for work in results]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||||
|
"""解析OpenAlex返回的数据"""
|
||||||
|
# 获取作者信息
|
||||||
|
raw_author_names = [
|
||||||
|
authorship.get("raw_author_name", "")
|
||||||
|
for authorship in work.get("authorships", [])
|
||||||
|
if authorship
|
||||||
|
]
|
||||||
|
# 处理作者名字格式
|
||||||
|
authors = [
|
||||||
|
self._reformat_name(author)
|
||||||
|
for author in raw_author_names
|
||||||
|
]
|
||||||
|
|
||||||
|
# 获取机构信息
|
||||||
|
institutions = [
|
||||||
|
inst.get("display_name", "")
|
||||||
|
for authorship in work.get("authorships", [])
|
||||||
|
for inst in authorship.get("institutions", [])
|
||||||
|
if inst
|
||||||
|
]
|
||||||
|
|
||||||
|
# 获取主要发表位置信息
|
||||||
|
primary_location = work.get("primary_location") or {}
|
||||||
|
source = primary_location.get("source") or {}
|
||||||
|
venue = source.get("display_name")
|
||||||
|
|
||||||
|
# 获取发表日期
|
||||||
|
year = work.get("publication_year")
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=work.get("title", ""),
|
||||||
|
authors=authors,
|
||||||
|
institutions=institutions,
|
||||||
|
abstract=work.get("abstract", ""),
|
||||||
|
year=year,
|
||||||
|
doi=work.get("doi"),
|
||||||
|
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
|
||||||
|
citations=work.get("cited_by_count"),
|
||||||
|
venue=venue
|
||||||
|
)
|
||||||
|
|
||||||
|
def _reformat_name(self, name: str) -> str:
|
||||||
|
"""重新格式化作者名字"""
|
||||||
|
if "," not in name:
|
||||||
|
return name
|
||||||
|
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
|
||||||
|
return f"{given_names} {family}"
|
||||||
|
|
||||||
|
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||||
|
"""获取指定DOI的论文详情"""
|
||||||
|
params = {"mailto": self.mailto} if self.mailto else {}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
|
||||||
|
params=params
|
||||||
|
) as response:
|
||||||
|
data = await response.json()
|
||||||
|
return self._parse_work(data)
|
||||||
|
|
||||||
|
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||||
|
"""获取指定DOI论文的参考文献列表"""
|
||||||
|
params = {"mailto": self.mailto} if self.mailto else {}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
|
||||||
|
params=params
|
||||||
|
) as response:
|
||||||
|
data = await response.json()
|
||||||
|
return [self._parse_work(work) for work in data.get("results", [])]
|
||||||
|
|
||||||
|
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用指定DOI论文的文献列表"""
|
||||||
|
params = {"mailto": self.mailto} if self.mailto else {}
|
||||||
|
params.update({
|
||||||
|
"filter": f"cites:doi:{doi}",
|
||||||
|
"per-page": 100
|
||||||
|
})
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/works",
|
||||||
|
params=params
|
||||||
|
) as response:
|
||||||
|
data = await response.json()
|
||||||
|
return [self._parse_work(work) for work in data.get("results", [])]
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""OpenAlexSource使用示例"""
|
||||||
|
# 初始化OpenAlexSource
|
||||||
|
openalex = OpenAlexSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
print("正在搜索论文...")
|
||||||
|
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
|
||||||
|
papers = await openalex.search(query="artificial intelligence", limit=5)
|
||||||
|
|
||||||
|
if not papers:
|
||||||
|
print("未获取到任何论文信息")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"找到 {len(papers)} 篇论文")
|
||||||
|
|
||||||
|
# 打印搜索结果
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
|
||||||
|
if paper.institutions:
|
||||||
|
print(f"机构: {', '.join(paper.institutions)}")
|
||||||
|
print(f"发表年份: {paper.year if paper.year else '未知'}")
|
||||||
|
print(f"DOI: {paper.doi if paper.doi else '未知'}")
|
||||||
|
print(f"URL: {paper.url if paper.url else '未知'}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要: {paper.abstract[:200]}...")
|
||||||
|
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
|
||||||
|
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
# 如果直接运行此文件,执行示例代码
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# 运行示例
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,458 @@
|
|||||||
|
from typing import List, Optional, Dict, Union
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
from urllib.parse import quote
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
import random
|
||||||
|
|
||||||
|
class PubMedSource(DataSource):
|
||||||
|
"""PubMed API实现"""
|
||||||
|
|
||||||
|
# 定义API密钥列表
|
||||||
|
API_KEYS = [
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None):
|
||||||
|
"""初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: PubMed API密钥,如果不提供则从预定义列表中随机选择
|
||||||
|
"""
|
||||||
|
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化基础URL和请求头"""
|
||||||
|
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
|
||||||
|
self.headers = {
|
||||||
|
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
|
||||||
|
"Accept": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _make_request(self, url: str) -> Optional[str]:
|
||||||
|
"""发送HTTP请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 请求URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.text()
|
||||||
|
else:
|
||||||
|
print(f"请求失败: {response.status}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = "relevance",
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||||
|
start_year: 起始年份
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 添加年份过滤
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} AND {start_year}:3000[dp]"
|
||||||
|
|
||||||
|
# 构建搜索URL
|
||||||
|
search_url = (
|
||||||
|
f"{self.base_url}/esearch.fcgi?"
|
||||||
|
f"db=pubmed&term={quote(query)}&retmax={limit}"
|
||||||
|
f"&usehistory=y&api_key={self.api_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if sort_by == "date":
|
||||||
|
search_url += "&sort=date"
|
||||||
|
|
||||||
|
# 获取搜索结果
|
||||||
|
response = await self._make_request(search_url)
|
||||||
|
if not response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析XML响应
|
||||||
|
root = ET.fromstring(response)
|
||||||
|
id_list = root.findall(".//Id")
|
||||||
|
pmids = [id_elem.text for id_elem in id_list]
|
||||||
|
|
||||||
|
if not pmids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 批量获取论文详情
|
||||||
|
papers = []
|
||||||
|
batch_size = 50
|
||||||
|
for i in range(0, len(pmids), batch_size):
|
||||||
|
batch = pmids[i:i + batch_size]
|
||||||
|
batch_papers = await self._fetch_papers_batch(batch)
|
||||||
|
papers.extend(batch_papers)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
|
||||||
|
"""批量获取论文详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pmids: PubMed ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文详情列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建批量获取URL
|
||||||
|
fetch_url = (
|
||||||
|
f"{self.base_url}/efetch.fcgi?"
|
||||||
|
f"db=pubmed&id={','.join(pmids)}"
|
||||||
|
f"&retmode=xml&api_key={self.api_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self._make_request(fetch_url)
|
||||||
|
if not response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析XML响应
|
||||||
|
root = ET.fromstring(response)
|
||||||
|
articles = root.findall(".//PubmedArticle")
|
||||||
|
|
||||||
|
return [self._parse_article(article) for article in articles]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文批次时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_article(self, article: ET.Element) -> PaperMetadata:
|
||||||
|
"""解析PubMed文章XML
|
||||||
|
|
||||||
|
Args:
|
||||||
|
article: XML元素
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的论文数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 提取基本信息
|
||||||
|
pmid = article.find(".//PMID").text
|
||||||
|
article_meta = article.find(".//Article")
|
||||||
|
|
||||||
|
# 获取标题
|
||||||
|
title = article_meta.find(".//ArticleTitle")
|
||||||
|
title = title.text if title is not None else ""
|
||||||
|
|
||||||
|
# 获取作者列表
|
||||||
|
authors = []
|
||||||
|
author_list = article_meta.findall(".//Author")
|
||||||
|
for author in author_list:
|
||||||
|
last_name = author.find("LastName")
|
||||||
|
fore_name = author.find("ForeName")
|
||||||
|
if last_name is not None and fore_name is not None:
|
||||||
|
authors.append(f"{fore_name.text} {last_name.text}")
|
||||||
|
elif last_name is not None:
|
||||||
|
authors.append(last_name.text)
|
||||||
|
|
||||||
|
# 获取摘要
|
||||||
|
abstract = article_meta.find(".//Abstract/AbstractText")
|
||||||
|
abstract = abstract.text if abstract is not None else ""
|
||||||
|
|
||||||
|
# 获取发表年份
|
||||||
|
pub_date = article_meta.find(".//PubDate/Year")
|
||||||
|
year = int(pub_date.text) if pub_date is not None else None
|
||||||
|
|
||||||
|
# 获取DOI
|
||||||
|
doi = article.find(".//ELocationID[@EIdType='doi']")
|
||||||
|
doi = doi.text if doi is not None else None
|
||||||
|
|
||||||
|
# 获取期刊信息
|
||||||
|
journal = article_meta.find(".//Journal")
|
||||||
|
if journal is not None:
|
||||||
|
journal_title = journal.find(".//Title")
|
||||||
|
venue = journal_title.text if journal_title is not None else None
|
||||||
|
|
||||||
|
# 获取期刊详细信息
|
||||||
|
venue_info = {
|
||||||
|
'issn': journal.findtext(".//ISSN"),
|
||||||
|
'volume': journal.findtext(".//Volume"),
|
||||||
|
'issue': journal.findtext(".//Issue"),
|
||||||
|
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
|
||||||
|
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
venue = None
|
||||||
|
venue_info = {}
|
||||||
|
|
||||||
|
# 获取机构信息
|
||||||
|
institutions = []
|
||||||
|
affiliations = article_meta.findall(".//Affiliation")
|
||||||
|
for affiliation in affiliations:
|
||||||
|
if affiliation is not None and affiliation.text:
|
||||||
|
institutions.append(affiliation.text)
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=title,
|
||||||
|
authors=authors,
|
||||||
|
abstract=abstract,
|
||||||
|
year=year,
|
||||||
|
doi=doi,
|
||||||
|
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
|
||||||
|
citations=None, # PubMed API不直接提供引用数据
|
||||||
|
venue=venue,
|
||||||
|
institutions=institutions,
|
||||||
|
venue_type="journal",
|
||||||
|
venue_name=venue,
|
||||||
|
venue_info=venue_info,
|
||||||
|
source='pubmed' # 添加来源标记
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析文章时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取指定PMID的论文详情"""
|
||||||
|
papers = await self._fetch_papers_batch([pmid])
|
||||||
|
return papers[0] if papers else None
|
||||||
|
|
||||||
|
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""获取相关论文
|
||||||
|
|
||||||
|
使用PubMed的相关文章功能
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pmid: PubMed ID
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建相关文章URL
|
||||||
|
link_url = (
|
||||||
|
f"{self.base_url}/elink.fcgi?"
|
||||||
|
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self._make_request(link_url)
|
||||||
|
if not response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析XML响应
|
||||||
|
root = ET.fromstring(response)
|
||||||
|
related_ids = root.findall(".//Link/Id")
|
||||||
|
pmids = [id_elem.text for id_elem in related_ids][:limit]
|
||||||
|
|
||||||
|
if not pmids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取相关论文详情
|
||||||
|
return await self._fetch_papers_batch(pmids)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取相关论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def search_by_author(
|
||||||
|
self,
|
||||||
|
author: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = f"{author}[Author]"
|
||||||
|
if start_year:
|
||||||
|
query += f" AND {start_year}:3000[dp]"
|
||||||
|
return await self.search(query, limit=limit)
|
||||||
|
|
||||||
|
async def search_by_journal(
|
||||||
|
self,
|
||||||
|
journal: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按期刊搜索论文"""
|
||||||
|
query = f"{journal}[Journal]"
|
||||||
|
if start_year:
|
||||||
|
query += f" AND {start_year}:3000[dp]"
|
||||||
|
return await self.search(query, limit=limit)
|
||||||
|
|
||||||
|
async def get_latest_papers(
|
||||||
|
self,
|
||||||
|
days: int = 7,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取最新论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
days: 最近几天的论文
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
最新论文列表
|
||||||
|
"""
|
||||||
|
query = f"last {days} days[dp]"
|
||||||
|
return await self.search(query, limit=limit, sort_by="date")
|
||||||
|
|
||||||
|
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用该论文的文献
|
||||||
|
|
||||||
|
注意:PubMed API本身不提供引用数据,此方法将返回空列表
|
||||||
|
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: PubMed ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
空列表,因为PubMed不提供引用数据
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取该论文引用的文献
|
||||||
|
|
||||||
|
从PubMed文章的参考文献列表获取引用的文献
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: PubMed ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
引用的文献列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建获取参考文献的URL
|
||||||
|
refs_url = (
|
||||||
|
f"{self.base_url}/elink.fcgi?"
|
||||||
|
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
|
||||||
|
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
|
||||||
|
f"&api_key={self.api_key}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self._make_request(refs_url)
|
||||||
|
if not response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析XML响应
|
||||||
|
root = ET.fromstring(response)
|
||||||
|
ref_ids = root.findall(".//Link/Id")
|
||||||
|
pmids = [id_elem.text for id_elem in ref_ids]
|
||||||
|
|
||||||
|
if not pmids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取参考文献详情
|
||||||
|
return await self._fetch_papers_batch(pmids)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取参考文献时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""PubMedSource使用示例"""
|
||||||
|
pubmed = PubMedSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:基本搜索
|
||||||
|
print("\n=== 示例1:搜索COVID-19相关论文 ===")
|
||||||
|
papers = await pubmed.search("COVID-19", limit=3)
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要: {paper.abstract[:200]}...")
|
||||||
|
|
||||||
|
# 示例2:获取论文详情
|
||||||
|
if papers:
|
||||||
|
print("\n=== 示例2:获取论文详情 ===")
|
||||||
|
paper_id = papers[0].url.split("/")[-2]
|
||||||
|
paper = await pubmed.get_paper_details(paper_id)
|
||||||
|
if paper:
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"期刊: {paper.venue}")
|
||||||
|
print(f"机构: {', '.join(paper.institutions)}")
|
||||||
|
|
||||||
|
# 示例3:获取相关论文
|
||||||
|
if papers:
|
||||||
|
print("\n=== 示例3:获取相关论文 ===")
|
||||||
|
related = await pubmed.get_related_papers(paper_id, limit=3)
|
||||||
|
for i, paper in enumerate(related, 1):
|
||||||
|
print(f"\n--- 相关论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
|
||||||
|
# 示例4:按作者搜索
|
||||||
|
print("\n=== 示例4:按作者搜索 ===")
|
||||||
|
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
|
||||||
|
for i, paper in enumerate(author_papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 示例5:按期刊搜索
|
||||||
|
print("\n=== 示例5:按期刊搜索 ===")
|
||||||
|
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
|
||||||
|
for i, paper in enumerate(journal_papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 示例6:获取最新论文
|
||||||
|
print("\n=== 示例6:获取最新论文 ===")
|
||||||
|
latest = await pubmed.get_latest_papers(days=7, limit=3)
|
||||||
|
for i, paper in enumerate(latest, 1):
|
||||||
|
print(f"\n--- 最新论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"发表日期: {paper.venue_info.get('pub_date')}")
|
||||||
|
|
||||||
|
# 示例7:获取论文的参考文献
|
||||||
|
if papers:
|
||||||
|
print("\n=== 示例7:获取论文的参考文献 ===")
|
||||||
|
paper_id = papers[0].url.split("/")[-2]
|
||||||
|
references = await pubmed.get_references(paper_id)
|
||||||
|
for i, paper in enumerate(references[:3], 1):
|
||||||
|
print(f"\n--- 参考文献 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 示例8:尝试获取引用信息(将返回空列表)
|
||||||
|
if papers:
|
||||||
|
print("\n=== 示例8:获取论文的引用信息 ===")
|
||||||
|
paper_id = papers[0].url.split("/")[-2]
|
||||||
|
citations = await pubmed.get_citations(paper_id)
|
||||||
|
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,326 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
import time
|
||||||
|
from loguru import logger
|
||||||
|
import PyPDF2
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
class SciHub:
|
||||||
|
# 更新的镜像列表,包含更多可用的镜像
|
||||||
|
MIRRORS = [
|
||||||
|
'https://sci-hub.se/',
|
||||||
|
'https://sci-hub.st/',
|
||||||
|
'https://sci-hub.ru/',
|
||||||
|
'https://sci-hub.wf/',
|
||||||
|
'https://sci-hub.ee/',
|
||||||
|
'https://sci-hub.ren/',
|
||||||
|
'https://sci-hub.tf/',
|
||||||
|
'https://sci-hub.si/',
|
||||||
|
'https://sci-hub.do/',
|
||||||
|
'https://sci-hub.hkvisa.net/',
|
||||||
|
'https://sci-hub.mksa.top/',
|
||||||
|
'https://sci-hub.shop/',
|
||||||
|
'https://sci-hub.yncjkj.com/',
|
||||||
|
'https://sci-hub.41610.org/',
|
||||||
|
'https://sci-hub.automic.us/',
|
||||||
|
'https://sci-hub.et-fine.com/',
|
||||||
|
'https://sci-hub.pooh.mu/',
|
||||||
|
'https://sci-hub.bban.top/',
|
||||||
|
'https://sci-hub.usualwant.com/',
|
||||||
|
'https://sci-hub.unblockit.kim/'
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
|
||||||
|
self.timeout = timeout
|
||||||
|
self.path = path
|
||||||
|
self.doi = str(doi)
|
||||||
|
self.use_proxy = use_proxy
|
||||||
|
self.headers = {
|
||||||
|
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
|
||||||
|
}
|
||||||
|
self.payload = {
|
||||||
|
'sci-hub-plugin-check': '',
|
||||||
|
'request': self.doi
|
||||||
|
}
|
||||||
|
self.url = url if url else self.MIRRORS[0]
|
||||||
|
self.proxies = {
|
||||||
|
"http": "socks5h://localhost:10880",
|
||||||
|
"https": "socks5h://localhost:10880",
|
||||||
|
} if use_proxy else None
|
||||||
|
|
||||||
|
def _test_proxy_connection(self):
|
||||||
|
"""测试代理连接是否可用"""
|
||||||
|
if not self.use_proxy:
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 测试代理连接
|
||||||
|
test_response = requests.get(
|
||||||
|
'https://httpbin.org/ip',
|
||||||
|
proxies=self.proxies,
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
if test_response.status_code == 200:
|
||||||
|
logger.info("代理连接测试成功")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"代理连接测试失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_pdf_validity(self, content):
|
||||||
|
"""检查PDF文件是否有效"""
|
||||||
|
try:
|
||||||
|
# 使用PyPDF2检查PDF是否可以正常打开和读取
|
||||||
|
pdf = PyPDF2.PdfReader(io.BytesIO(content))
|
||||||
|
if len(pdf.pages) > 0:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"PDF文件无效: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _send_request(self):
|
||||||
|
"""发送请求到Sci-Hub镜像站点"""
|
||||||
|
# 首先测试代理连接
|
||||||
|
if self.use_proxy and not self._test_proxy_connection():
|
||||||
|
logger.warning("代理连接不可用,切换到直连模式")
|
||||||
|
self.use_proxy = False
|
||||||
|
self.proxies = None
|
||||||
|
|
||||||
|
last_exception = None
|
||||||
|
working_mirrors = []
|
||||||
|
|
||||||
|
# 先测试哪些镜像可用
|
||||||
|
logger.info("正在测试镜像站点可用性...")
|
||||||
|
for mirror in self.MIRRORS:
|
||||||
|
try:
|
||||||
|
test_response = requests.get(
|
||||||
|
mirror,
|
||||||
|
headers=self.headers,
|
||||||
|
proxies=self.proxies,
|
||||||
|
timeout=10
|
||||||
|
)
|
||||||
|
if test_response.status_code == 200:
|
||||||
|
working_mirrors.append(mirror)
|
||||||
|
logger.info(f"镜像 {mirror} 可用")
|
||||||
|
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not working_mirrors:
|
||||||
|
raise Exception("没有找到可用的镜像站点")
|
||||||
|
|
||||||
|
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
|
||||||
|
|
||||||
|
# 使用可用的镜像进行下载
|
||||||
|
for mirror in working_mirrors:
|
||||||
|
try:
|
||||||
|
res = requests.post(
|
||||||
|
mirror,
|
||||||
|
headers=self.headers,
|
||||||
|
data=self.payload,
|
||||||
|
proxies=self.proxies,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
if res.ok:
|
||||||
|
logger.info(f"成功使用镜像站点: {mirror}")
|
||||||
|
self.url = mirror # 更新当前使用的镜像
|
||||||
|
time.sleep(1) # 降低等待时间以提高效率
|
||||||
|
return res
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
|
||||||
|
last_exception = e
|
||||||
|
continue
|
||||||
|
|
||||||
|
if last_exception:
|
||||||
|
raise last_exception
|
||||||
|
raise Exception("所有可用镜像站点均无法完成下载")
|
||||||
|
|
||||||
|
def _extract_url(self, response):
|
||||||
|
"""从响应中提取PDF下载链接"""
|
||||||
|
soup = BeautifulSoup(response.content, 'html.parser')
|
||||||
|
try:
|
||||||
|
# 尝试多种方式提取PDF链接
|
||||||
|
pdf_element = soup.find(id='pdf')
|
||||||
|
if pdf_element:
|
||||||
|
content_url = pdf_element.get('src')
|
||||||
|
else:
|
||||||
|
# 尝试其他可能的选择器
|
||||||
|
pdf_element = soup.find('iframe')
|
||||||
|
if pdf_element:
|
||||||
|
content_url = pdf_element.get('src')
|
||||||
|
else:
|
||||||
|
# 查找直接的PDF链接
|
||||||
|
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
|
||||||
|
if pdf_links:
|
||||||
|
content_url = pdf_links[0].get('href')
|
||||||
|
else:
|
||||||
|
raise AttributeError("未找到PDF链接")
|
||||||
|
|
||||||
|
if content_url:
|
||||||
|
content_url = content_url.replace('#navpanes=0&view=FitH', '').replace('//', '/')
|
||||||
|
if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
|
||||||
|
raise AttributeError("找到的链接不是PDF文件")
|
||||||
|
except AttributeError:
|
||||||
|
logger.error(f"未找到论文 {self.doi}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
current_mirror = self.url.rstrip('/')
|
||||||
|
if content_url.startswith('/'):
|
||||||
|
return current_mirror + content_url
|
||||||
|
elif content_url.startswith('http'):
|
||||||
|
return content_url
|
||||||
|
else:
|
||||||
|
return 'https:/' + content_url
|
||||||
|
|
||||||
|
def _download_pdf(self, pdf_url):
|
||||||
|
"""下载PDF文件并验证其完整性"""
|
||||||
|
try:
|
||||||
|
# 尝试不同的下载方式
|
||||||
|
download_methods = [
|
||||||
|
# 方法1:直接下载
|
||||||
|
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
|
||||||
|
# 方法2:添加 Referer 头
|
||||||
|
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||||
|
headers={**self.headers, 'Referer': self.url}),
|
||||||
|
# 方法3:使用原始域名作为 Referer
|
||||||
|
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||||
|
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
|
||||||
|
]
|
||||||
|
|
||||||
|
for i, download_method in enumerate(download_methods):
|
||||||
|
try:
|
||||||
|
logger.info(f"尝试下载方式 {i+1}/3...")
|
||||||
|
response = download_method()
|
||||||
|
if response.status_code == 200:
|
||||||
|
content = response.content
|
||||||
|
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
|
||||||
|
logger.info(f"PDF下载成功,文件大小: {len(content)} bytes")
|
||||||
|
return content
|
||||||
|
else:
|
||||||
|
logger.warning("下载的文件可能不是有效的PDF")
|
||||||
|
elif response.status_code == 403:
|
||||||
|
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.warning(f"下载失败,状态码: {response.status_code}")
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果所有方法都失败,尝试构造替代URL
|
||||||
|
try:
|
||||||
|
logger.info("尝试使用替代镜像下载...")
|
||||||
|
# 从原始URL提取关键信息
|
||||||
|
if '/downloads/' in pdf_url:
|
||||||
|
file_part = pdf_url.split('/downloads/')[-1]
|
||||||
|
alternative_mirrors = [
|
||||||
|
f"https://sci-hub.se/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.st/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.ru/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.wf/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.ee/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.ren/downloads/{file_part}",
|
||||||
|
f"https://sci-hub.tf/downloads/{file_part}"
|
||||||
|
]
|
||||||
|
|
||||||
|
for alt_url in alternative_mirrors:
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
alt_url,
|
||||||
|
proxies=self.proxies,
|
||||||
|
timeout=self.timeout,
|
||||||
|
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
content = response.content
|
||||||
|
if len(content) > 1000 and self._check_pdf_validity(content):
|
||||||
|
logger.info(f"使用替代镜像成功下载: {alt_url}")
|
||||||
|
return content
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"所有下载方式都失败: {str(e)}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"下载PDF文件失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fetch(self):
|
||||||
|
"""获取论文PDF,包含重试和验证机制"""
|
||||||
|
for attempt in range(2): # 最多重试3次
|
||||||
|
try:
|
||||||
|
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
|
||||||
|
|
||||||
|
# 获取PDF下载链接
|
||||||
|
response = self._send_request()
|
||||||
|
pdf_url = self._extract_url(response)
|
||||||
|
if pdf_url is None:
|
||||||
|
logger.warning(f"第 {attempt + 1} 次尝试:未找到PDF下载链接")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info(f"找到PDF下载链接: {pdf_url}")
|
||||||
|
|
||||||
|
# 下载并验证PDF
|
||||||
|
pdf_content = self._download_pdf(pdf_url)
|
||||||
|
if pdf_content is None:
|
||||||
|
logger.warning(f"第 {attempt + 1} 次尝试:PDF下载失败")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 保存PDF文件
|
||||||
|
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
|
||||||
|
pdf_path = self.path.joinpath(pdf_name)
|
||||||
|
pdf_path.write_bytes(pdf_content)
|
||||||
|
|
||||||
|
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
|
||||||
|
return str(pdf_path)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"第 {attempt + 1} 次尝试失败: {str(e)}")
|
||||||
|
if attempt < 2: # 不是最后一次尝试
|
||||||
|
wait_time = (attempt + 1) * 3 # 递增等待时间
|
||||||
|
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
|
||||||
|
|
||||||
|
# Usage Example
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 创建一个用于保存PDF的目录
|
||||||
|
save_path = Path('./downloaded_papers')
|
||||||
|
save_path.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# DOI示例
|
||||||
|
sample_doi = '10.3897/rio.7.e67379' # 这是一篇Nature的论文DOI
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 初始化SciHub下载器,先尝试使用代理
|
||||||
|
logger.info("尝试使用代理模式...")
|
||||||
|
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
|
||||||
|
|
||||||
|
# 开始下载
|
||||||
|
result = downloader.fetch()
|
||||||
|
print(f"论文已保存到: {result}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"使用代理模式失败: {str(e)}")
|
||||||
|
try:
|
||||||
|
# 如果代理模式失败,尝试直连模式
|
||||||
|
logger.info("尝试直连模式...")
|
||||||
|
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
|
||||||
|
result = downloader.fetch()
|
||||||
|
print(f"论文已保存到: {result}")
|
||||||
|
except Exception as e2:
|
||||||
|
print(f"直连模式也失败: {str(e2)}")
|
||||||
|
print("建议检查网络连接或尝试其他DOI")
|
||||||
@@ -0,0 +1,400 @@
|
|||||||
|
from typing import List, Optional, Dict, Union
|
||||||
|
from datetime import datetime
|
||||||
|
import aiohttp
|
||||||
|
import random
|
||||||
|
from .base_source import DataSource, PaperMetadata
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
class ScopusSource(DataSource):
|
||||||
|
"""Scopus API实现"""
|
||||||
|
|
||||||
|
# 定义API密钥列表
|
||||||
|
API_KEYS = [
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None):
|
||||||
|
"""初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Scopus API密钥,如果不提供则从预定义列表中随机选择
|
||||||
|
"""
|
||||||
|
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||||
|
self._initialize()
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化基础URL和请求头"""
|
||||||
|
self.base_url = "https://api.elsevier.com/content"
|
||||||
|
self.headers = {
|
||||||
|
"X-ELS-APIKey": self.api_key,
|
||||||
|
"Accept": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||||
|
"""发送HTTP请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 请求URL
|
||||||
|
params: 查询参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应JSON数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||||
|
async with session.get(url, params=params) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return await response.json()
|
||||||
|
else:
|
||||||
|
print(f"请求失败: {response.status}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"请求发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
|
||||||
|
"""解析Scopus API返回的数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Scopus API返回的论文数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
解析后的论文元数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 提取基本信息
|
||||||
|
title = data.get("dc:title", "")
|
||||||
|
|
||||||
|
# 提取作者信息
|
||||||
|
authors = []
|
||||||
|
if "author" in data:
|
||||||
|
if isinstance(data["author"], list):
|
||||||
|
for author in data["author"]:
|
||||||
|
if "given-name" in author and "surname" in author:
|
||||||
|
authors.append(f"{author['given-name']} {author['surname']}")
|
||||||
|
elif "indexed-name" in author:
|
||||||
|
authors.append(author["indexed-name"])
|
||||||
|
elif isinstance(data["author"], dict):
|
||||||
|
if "given-name" in data["author"] and "surname" in data["author"]:
|
||||||
|
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
|
||||||
|
elif "indexed-name" in data["author"]:
|
||||||
|
authors.append(data["author"]["indexed-name"])
|
||||||
|
|
||||||
|
# 提取摘要
|
||||||
|
abstract = data.get("dc:description", "")
|
||||||
|
|
||||||
|
# 提取年份
|
||||||
|
year = None
|
||||||
|
if "prism:coverDate" in data:
|
||||||
|
try:
|
||||||
|
year = int(data["prism:coverDate"][:4])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 提取DOI
|
||||||
|
doi = data.get("prism:doi")
|
||||||
|
|
||||||
|
# 提取引用次数
|
||||||
|
citations = data.get("citedby-count")
|
||||||
|
if citations:
|
||||||
|
try:
|
||||||
|
citations = int(citations)
|
||||||
|
except:
|
||||||
|
citations = None
|
||||||
|
|
||||||
|
# 提取期刊信息
|
||||||
|
venue = data.get("prism:publicationName")
|
||||||
|
|
||||||
|
# 提取机构信息
|
||||||
|
institutions = []
|
||||||
|
if "affiliation" in data:
|
||||||
|
if isinstance(data["affiliation"], list):
|
||||||
|
for aff in data["affiliation"]:
|
||||||
|
if "affilname" in aff:
|
||||||
|
institutions.append(aff["affilname"])
|
||||||
|
elif isinstance(data["affiliation"], dict):
|
||||||
|
if "affilname" in data["affiliation"]:
|
||||||
|
institutions.append(data["affiliation"]["affilname"])
|
||||||
|
|
||||||
|
# 构建venue信息
|
||||||
|
venue_info = {
|
||||||
|
"issn": data.get("prism:issn"),
|
||||||
|
"eissn": data.get("prism:eIssn"),
|
||||||
|
"volume": data.get("prism:volume"),
|
||||||
|
"issue": data.get("prism:issueIdentifier"),
|
||||||
|
"page_range": data.get("prism:pageRange"),
|
||||||
|
"article_number": data.get("article-number"),
|
||||||
|
"publication_date": data.get("prism:coverDate")
|
||||||
|
}
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=title,
|
||||||
|
authors=authors,
|
||||||
|
abstract=abstract,
|
||||||
|
year=year,
|
||||||
|
doi=doi,
|
||||||
|
url=data.get("link", [{}])[0].get("@href"),
|
||||||
|
citations=citations,
|
||||||
|
venue=venue,
|
||||||
|
institutions=institutions,
|
||||||
|
venue_type="journal",
|
||||||
|
venue_name=venue,
|
||||||
|
venue_info=venue_info
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"解析论文数据时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
sort_by: str = None,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 搜索关键词
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||||
|
start_year: 起始年份
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建查询参数
|
||||||
|
params = {
|
||||||
|
"query": query,
|
||||||
|
"count": min(limit, 100), # Scopus API单次请求限制
|
||||||
|
"start": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加年份过滤
|
||||||
|
if start_year:
|
||||||
|
params["date"] = f"{start_year}-present"
|
||||||
|
|
||||||
|
# 添加排序
|
||||||
|
if sort_by:
|
||||||
|
sort_map = {
|
||||||
|
"relevance": "-score",
|
||||||
|
"date": "-coverDate",
|
||||||
|
"citations": "-citedby-count"
|
||||||
|
}
|
||||||
|
if sort_by in sort_map:
|
||||||
|
params["sort"] = sort_map[sort_by]
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
url = f"{self.base_url}/search/scopus"
|
||||||
|
response = await self._make_request(url, params)
|
||||||
|
|
||||||
|
if not response or "search-results" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 解析结果
|
||||||
|
results = response["search-results"].get("entry", [])
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
paper = self._parse_paper_data(result)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取论文详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: Scopus ID或DOI
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
论文详情
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 判断是否为DOI
|
||||||
|
if "/" in paper_id:
|
||||||
|
url = f"{self.base_url}/article/doi/{paper_id}"
|
||||||
|
else:
|
||||||
|
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
|
||||||
|
|
||||||
|
response = await self._make_request(url)
|
||||||
|
|
||||||
|
if not response or "abstracts-retrieval-response" not in response:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = response["abstracts-retrieval-response"]
|
||||||
|
return self._parse_paper_data(data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文详情时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取引用该论文的文献
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: Scopus ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
引用论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = f"{self.base_url}/abstract/citations/{paper_id}"
|
||||||
|
response = await self._make_request(url)
|
||||||
|
|
||||||
|
if not response or "citing-papers" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = response["citing-papers"].get("papers", [])
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
paper = self._parse_paper_data(result)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取引用信息时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||||
|
"""获取该论文引用的文献
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper_id: Scopus ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参考文献列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = f"{self.base_url}/abstract/references/{paper_id}"
|
||||||
|
response = await self._make_request(url)
|
||||||
|
|
||||||
|
if not response or "references" not in response:
|
||||||
|
return []
|
||||||
|
|
||||||
|
results = response["references"].get("reference", [])
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
paper = self._parse_paper_data(result)
|
||||||
|
if paper:
|
||||||
|
papers.append(paper)
|
||||||
|
|
||||||
|
return papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取参考文献时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def search_by_author(
|
||||||
|
self,
|
||||||
|
author: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按作者搜索论文"""
|
||||||
|
query = f"AUTHOR-NAME({author})"
|
||||||
|
if start_year:
|
||||||
|
query += f" AND PUBYEAR > {start_year}"
|
||||||
|
return await self.search(query, limit=limit)
|
||||||
|
|
||||||
|
async def search_by_journal(
|
||||||
|
self,
|
||||||
|
journal: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""按期刊搜索论文"""
|
||||||
|
query = f"SRCTITLE({journal})"
|
||||||
|
if start_year:
|
||||||
|
query += f" AND PUBYEAR > {start_year}"
|
||||||
|
return await self.search(query, limit=limit)
|
||||||
|
|
||||||
|
async def get_latest_papers(
|
||||||
|
self,
|
||||||
|
days: int = 7,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取最新论文"""
|
||||||
|
query = f"LOAD-DATE > NOW() - {days}d"
|
||||||
|
return await self.search(query, limit=limit, sort_by="date")
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""ScopusSource使用示例"""
|
||||||
|
scopus = ScopusSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:基本搜索
|
||||||
|
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||||
|
papers = await scopus.search("machine learning", limit=3)
|
||||||
|
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"发表期刊: {paper.venue}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要:\n{paper.abstract}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# 示例2:按作者搜索
|
||||||
|
print("\n=== 示例2:搜索特定作者的论文 ===")
|
||||||
|
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
|
||||||
|
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
|
||||||
|
for i, paper in enumerate(author_papers, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"发表期刊: {paper.venue}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要:\n{paper.abstract}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
# 示例3:根据关键词搜索相关论文
|
||||||
|
print("\n=== 示例3:搜索人工智能相关论文 ===")
|
||||||
|
keywords = "artificial intelligence AND deep learning"
|
||||||
|
papers = await scopus.search(
|
||||||
|
query=keywords,
|
||||||
|
limit=5,
|
||||||
|
sort_by="citations", # 按引用次数排序
|
||||||
|
start_year=2020 # 只搜索2020年之后的论文
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n论文 {i}:")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"发表期刊: {paper.venue}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"摘要:\n{paper.abstract}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,480 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||||
|
import random
|
||||||
|
|
||||||
|
class SemanticScholarSource(DataSource):
|
||||||
|
"""Semantic Scholar API实现,使用官方Python包"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None):
|
||||||
|
"""初始化
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: Semantic Scholar API密钥(可选)
|
||||||
|
"""
|
||||||
|
self.api_key = api_key
|
||||||
|
self._initialize() # 调用初始化方法
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
"""初始化API客户端"""
|
||||||
|
if not self.api_key:
|
||||||
|
# 默认API密钥列表
|
||||||
|
default_api_keys = [
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||||
|
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||||
|
]
|
||||||
|
self.api_key = random.choice(default_api_keys)
|
||||||
|
|
||||||
|
self.client = None # 延迟初始化
|
||||||
|
self.fields = [
|
||||||
|
"title",
|
||||||
|
"authors",
|
||||||
|
"abstract",
|
||||||
|
"year",
|
||||||
|
"externalIds",
|
||||||
|
"citationCount",
|
||||||
|
"venue",
|
||||||
|
"openAccessPdf",
|
||||||
|
"publicationVenue"
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _ensure_client(self):
|
||||||
|
"""确保客户端已初始化"""
|
||||||
|
if self.client is None:
|
||||||
|
from semanticscholar import AsyncSemanticScholar
|
||||||
|
self.client = AsyncSemanticScholar(api_key=self.api_key)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""搜索论文"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
|
||||||
|
# 如果指定了起始年份,添加到查询中
|
||||||
|
if start_year:
|
||||||
|
query = f"{query} year>={start_year}"
|
||||||
|
|
||||||
|
# 直接使用 search_paper 的结果
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
|
||||||
|
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
papers = response.get('data', [])
|
||||||
|
return [self._parse_paper_data(paper) for paper in papers]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索论文时发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
|
||||||
|
"""获取指定DOI的论文详情"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
|
||||||
|
return self._parse_paper_data(paper)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文详情时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_citations(
|
||||||
|
self,
|
||||||
|
doi: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取引用指定DOI论文的文献列表"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 构建查询参数
|
||||||
|
fields_param = f"fields={','.join(self.fields)}"
|
||||||
|
limit_param = f"limit={limit}"
|
||||||
|
year_param = f"year>={start_year}" if start_year else ""
|
||||||
|
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||||
|
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
|
||||||
|
params,
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
citations = response.get('data', [])
|
||||||
|
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取引用列表时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_references(
|
||||||
|
self,
|
||||||
|
doi: str,
|
||||||
|
limit: int = 100,
|
||||||
|
start_year: int = None
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""获取指定DOI论文的参考文献列表"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 构建查询参数
|
||||||
|
fields_param = f"fields={','.join(self.fields)}"
|
||||||
|
limit_param = f"limit={limit}"
|
||||||
|
year_param = f"year>={start_year}" if start_year else ""
|
||||||
|
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||||
|
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
|
||||||
|
params,
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
references = response.get('data', [])
|
||||||
|
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取参考文献列表时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""获取论文推荐
|
||||||
|
|
||||||
|
根据一篇论文获取相关的推荐论文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doi: 论文的DOI
|
||||||
|
limit: 返回结果数量限制,最大500
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
papers = await self.client.get_recommended_papers(
|
||||||
|
f"DOI:{doi}",
|
||||||
|
fields=self.fields,
|
||||||
|
limit=min(limit, 500)
|
||||||
|
)
|
||||||
|
return [self._parse_paper_data(paper) for paper in papers]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文推荐时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_recommended_papers_from_lists(
|
||||||
|
self,
|
||||||
|
positive_dois: List[str],
|
||||||
|
negative_dois: List[str] = None,
|
||||||
|
limit: int = 100
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""基于正负例论文列表获取推荐
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positive_dois: 正例论文DOI列表(想要获取类似的论文)
|
||||||
|
negative_dois: 负例论文DOI列表(不想要类似的论文)
|
||||||
|
limit: 返回结果数量限制,最大500
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐论文列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
|
||||||
|
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
|
||||||
|
|
||||||
|
papers = await self.client.get_recommended_papers_from_lists(
|
||||||
|
positive_paper_ids=positive_ids,
|
||||||
|
negative_paper_ids=negative_ids,
|
||||||
|
fields=self.fields,
|
||||||
|
limit=min(limit, 500)
|
||||||
|
)
|
||||||
|
return [self._parse_paper_data(paper) for paper in papers]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文推荐列表时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
|
||||||
|
"""搜索作者"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 直接使用 API 请求而不是 search_author 方法
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
|
||||||
|
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
authors = response.get('data', [])
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'author_id': author.get('authorId'),
|
||||||
|
'name': author.get('name'),
|
||||||
|
'paper_count': author.get('paperCount'),
|
||||||
|
'citation_count': author.get('citationCount'),
|
||||||
|
}
|
||||||
|
for author in authors
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索作者时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_author_details(self, author_id: str) -> Optional[dict]:
|
||||||
|
"""获取作者详细信息"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 直接使用 API 请求
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
|
||||||
|
"fields=name,paperCount,citationCount,hIndex",
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
'author_id': response.get('authorId'),
|
||||||
|
'name': response.get('name'),
|
||||||
|
'paper_count': response.get('paperCount'),
|
||||||
|
'citation_count': response.get('citationCount'),
|
||||||
|
'h_index': response.get('hIndex'),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取作者详情时发生错误: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
"""获取作者的论文列表"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 直接使用 API 请求
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
|
||||||
|
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
papers = response.get('data', [])
|
||||||
|
return [self._parse_paper_data(paper) for paper in papers]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取作者论文列表时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_paper_autocomplete(self, query: str) -> List[dict]:
|
||||||
|
"""论文标题自动补全"""
|
||||||
|
try:
|
||||||
|
await self._ensure_client()
|
||||||
|
# 直接使用 API 请求
|
||||||
|
response = await self.client._requester.get_data_async(
|
||||||
|
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
|
||||||
|
f"query={query}",
|
||||||
|
self.client.auth_header
|
||||||
|
)
|
||||||
|
suggestions = response.get('matches', [])
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'title': suggestion.get('title'),
|
||||||
|
'paper_id': suggestion.get('paperId'),
|
||||||
|
'year': suggestion.get('year'),
|
||||||
|
'venue': suggestion.get('venue'),
|
||||||
|
}
|
||||||
|
for suggestion in suggestions
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取标题自动补全时发生错误: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _parse_paper_data(self, paper) -> PaperMetadata:
|
||||||
|
"""解析论文数据"""
|
||||||
|
# 获取DOI
|
||||||
|
doi = None
|
||||||
|
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
|
||||||
|
if external_ids:
|
||||||
|
if isinstance(external_ids, dict):
|
||||||
|
doi = external_ids.get('DOI')
|
||||||
|
if not doi and 'ArXiv' in external_ids:
|
||||||
|
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
|
||||||
|
else:
|
||||||
|
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
|
||||||
|
if not doi and hasattr(external_ids, 'ArXiv'):
|
||||||
|
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
|
||||||
|
|
||||||
|
# 获取PDF URL
|
||||||
|
pdf_url = None
|
||||||
|
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
|
||||||
|
if pdf_info:
|
||||||
|
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
|
||||||
|
|
||||||
|
# 获取发表场所详细信息
|
||||||
|
venue_type = None
|
||||||
|
venue_name = None
|
||||||
|
venue_info = {}
|
||||||
|
|
||||||
|
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
|
||||||
|
if venue:
|
||||||
|
if isinstance(venue, dict):
|
||||||
|
venue_name = venue.get('name')
|
||||||
|
venue_type = venue.get('type')
|
||||||
|
# 提取更多venue信息
|
||||||
|
venue_info = {
|
||||||
|
'issn': venue.get('issn'),
|
||||||
|
'publisher': venue.get('publisher'),
|
||||||
|
'url': venue.get('url'),
|
||||||
|
'alternate_names': venue.get('alternate_names', [])
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
venue_name = venue.name if hasattr(venue, 'name') else None
|
||||||
|
venue_type = venue.type if hasattr(venue, 'type') else None
|
||||||
|
venue_info = {
|
||||||
|
'issn': getattr(venue, 'issn', None),
|
||||||
|
'publisher': getattr(venue, 'publisher', None),
|
||||||
|
'url': getattr(venue, 'url', None),
|
||||||
|
'alternate_names': getattr(venue, 'alternate_names', [])
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取标题
|
||||||
|
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
|
||||||
|
|
||||||
|
# 获取作者
|
||||||
|
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
|
||||||
|
author_names = []
|
||||||
|
for author in authors:
|
||||||
|
if isinstance(author, dict):
|
||||||
|
author_names.append(author.get('name', ''))
|
||||||
|
else:
|
||||||
|
author_names.append(author.name if hasattr(author, 'name') else str(author))
|
||||||
|
|
||||||
|
# 获取摘要
|
||||||
|
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
|
||||||
|
|
||||||
|
# 获取年份
|
||||||
|
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
|
||||||
|
|
||||||
|
# 获取引用次数
|
||||||
|
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
|
||||||
|
|
||||||
|
return PaperMetadata(
|
||||||
|
title=title,
|
||||||
|
authors=author_names,
|
||||||
|
abstract=abstract,
|
||||||
|
year=year,
|
||||||
|
doi=doi,
|
||||||
|
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
|
||||||
|
citations=citations,
|
||||||
|
venue=venue_name,
|
||||||
|
institutions=[],
|
||||||
|
venue_type=venue_type,
|
||||||
|
venue_name=venue_name,
|
||||||
|
venue_info=venue_info,
|
||||||
|
source='semantic' # 添加来源标记
|
||||||
|
)
|
||||||
|
|
||||||
|
async def example_usage():
|
||||||
|
"""SemanticScholarSource使用示例"""
|
||||||
|
semantic = SemanticScholarSource()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 示例1:使用DOI直接获取论文
|
||||||
|
print("\n=== 示例1:通过DOI获取论文 ===")
|
||||||
|
doi = "10.18653/v1/N19-1423" # BERT论文
|
||||||
|
print(f"获取DOI为 {doi} 的论文信息...")
|
||||||
|
|
||||||
|
paper = await semantic.get_paper_details(doi)
|
||||||
|
if paper:
|
||||||
|
print("\n--- 论文信息 ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"DOI: {paper.doi}")
|
||||||
|
print(f"URL: {paper.url}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"\n摘要:")
|
||||||
|
print(paper.abstract)
|
||||||
|
print(f"\n引用次数: {paper.citations}")
|
||||||
|
print(f"发表venue: {paper.venue}")
|
||||||
|
|
||||||
|
# 示例2:搜索论文
|
||||||
|
print("\n=== 示例2:搜索论文 ===")
|
||||||
|
query = "BERT pre-training"
|
||||||
|
print(f"搜索关键词 '{query}' 相关的论文...")
|
||||||
|
papers = await semantic.search(query=query, limit=3)
|
||||||
|
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
print(f"\n--- 搜索结果 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
if paper.abstract:
|
||||||
|
print(f"\n摘要:")
|
||||||
|
print(paper.abstract)
|
||||||
|
print(f"\nDOI: {paper.doi}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
|
||||||
|
# 示例3:获取论文推荐
|
||||||
|
print("\n=== 示例3:获取论文推荐 ===")
|
||||||
|
print(f"获取与论文 {doi} 相关的推荐论文...")
|
||||||
|
recommendations = await semantic.get_recommended_papers(doi, limit=3)
|
||||||
|
for i, paper in enumerate(recommendations, 1):
|
||||||
|
print(f"\n--- 推荐论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
|
||||||
|
# 示例4:基于多篇论文的推荐
|
||||||
|
print("\n=== 示例4:基于多篇论文的推荐 ===")
|
||||||
|
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
|
||||||
|
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
|
||||||
|
multi_recommendations = await semantic.get_recommended_papers_from_lists(
|
||||||
|
positive_dois=positive_dois,
|
||||||
|
limit=3
|
||||||
|
)
|
||||||
|
for i, paper in enumerate(multi_recommendations, 1):
|
||||||
|
print(f"\n--- 推荐论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"作者: {', '.join(paper.authors)}")
|
||||||
|
|
||||||
|
# 示例5:搜索作者
|
||||||
|
print("\n=== 示例5:搜索作者 ===")
|
||||||
|
author_query = "Yann LeCun"
|
||||||
|
print(f"搜索作者: '{author_query}'")
|
||||||
|
authors = await semantic.search_author(author_query, limit=3)
|
||||||
|
for i, author in enumerate(authors, 1):
|
||||||
|
print(f"\n--- 作者 {i} ---")
|
||||||
|
print(f"姓名: {author['name']}")
|
||||||
|
print(f"论文数量: {author['paper_count']}")
|
||||||
|
print(f"总引用次数: {author['citation_count']}")
|
||||||
|
|
||||||
|
# 示例6:获取作者详情
|
||||||
|
print("\n=== 示例6:获取作者详情 ===")
|
||||||
|
if authors: # 使用第一个搜索结果的作者ID
|
||||||
|
author_id = authors[0]['author_id']
|
||||||
|
print(f"获取作者ID {author_id} 的详细信息...")
|
||||||
|
author_details = await semantic.get_author_details(author_id)
|
||||||
|
if author_details:
|
||||||
|
print(f"姓名: {author_details['name']}")
|
||||||
|
print(f"H指数: {author_details['h_index']}")
|
||||||
|
print(f"总引用次数: {author_details['citation_count']}")
|
||||||
|
print(f"发表论文数: {author_details['paper_count']}")
|
||||||
|
|
||||||
|
# 示例7:获取作者论文
|
||||||
|
print("\n=== 示例7:获取作者论文 ===")
|
||||||
|
if authors: # 使用第一个搜索结果的作者ID
|
||||||
|
author_id = authors[0]['author_id']
|
||||||
|
print(f"获取作者 {authors[0]['name']} 的论文列表...")
|
||||||
|
author_papers = await semantic.get_author_papers(author_id, limit=3)
|
||||||
|
for i, paper in enumerate(author_papers, 1):
|
||||||
|
print(f"\n--- 论文 {i} ---")
|
||||||
|
print(f"标题: {paper.title}")
|
||||||
|
print(f"发表年份: {paper.year}")
|
||||||
|
print(f"引用次数: {paper.citations}")
|
||||||
|
|
||||||
|
# 示例8:论文标题自动补全
|
||||||
|
print("\n=== 示例8:论文标题自动补全 ===")
|
||||||
|
title_query = "Attention is all"
|
||||||
|
print(f"搜索标题: '{title_query}'")
|
||||||
|
suggestions = await semantic.get_paper_autocomplete(title_query)
|
||||||
|
for i, suggestion in enumerate(suggestions[:3], 1):
|
||||||
|
print(f"\n--- 建议 {i} ---")
|
||||||
|
print(f"标题: {suggestion['title']}")
|
||||||
|
print(f"发表年份: {suggestion['year']}")
|
||||||
|
print(f"发表venue: {suggestion['venue']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"发生错误: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import asyncio
|
||||||
|
asyncio.run(example_usage())
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
import aiohttp
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from .base_source import DataSource, PaperMetadata
|
||||||
|
|
||||||
|
class UnpaywallSource(DataSource):
|
||||||
|
"""Unpaywall API实现"""
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
self.base_url = "https://api.unpaywall.org/v2"
|
||||||
|
self.email = self.api_key # Unpaywall使用email作为API key
|
||||||
|
|
||||||
|
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
f"{self.base_url}/search",
|
||||||
|
params={
|
||||||
|
"query": query,
|
||||||
|
"email": self.email,
|
||||||
|
"limit": limit
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
data = await response.json()
|
||||||
|
return [self._parse_response(item.response)
|
||||||
|
for item in data.get("results", [])]
|
||||||
|
|
||||||
|
def _parse_response(self, data: Dict) -> PaperMetadata:
|
||||||
|
"""解析Unpaywall返回的数据"""
|
||||||
|
return PaperMetadata(
|
||||||
|
title=data.get("title", ""),
|
||||||
|
authors=[
|
||||||
|
f"{author.get('given', '')} {author.get('family', '')}"
|
||||||
|
for author in data.get("z_authors", [])
|
||||||
|
],
|
||||||
|
institutions=[
|
||||||
|
aff.get("name", "")
|
||||||
|
for author in data.get("z_authors", [])
|
||||||
|
for aff in author.get("affiliation", [])
|
||||||
|
],
|
||||||
|
abstract="", # Unpaywall不提供摘要
|
||||||
|
year=data.get("year"),
|
||||||
|
doi=data.get("doi"),
|
||||||
|
url=data.get("doi_url"),
|
||||||
|
citations=None, # Unpaywall不提供引用计数
|
||||||
|
venue=data.get("journal_name")
|
||||||
|
)
|
||||||
@@ -0,0 +1,412 @@
|
|||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||||
|
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
|
||||||
|
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
|
||||||
|
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
|
||||||
|
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
|
||||||
|
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
|
||||||
|
from toolbox import get_conf
|
||||||
|
|
||||||
|
|
||||||
|
class BaseHandler(ABC):
|
||||||
|
"""处理器基类"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
|
||||||
|
self.arxiv = arxiv
|
||||||
|
self.semantic = semantic
|
||||||
|
self.pubmed = PubMedSource()
|
||||||
|
self.crossref = CrossrefSource() # 添加 Crossref 实例
|
||||||
|
self.adsabs = AdsabsSource() # 添加 ADS 实例
|
||||||
|
self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
|
||||||
|
self.ranked_papers = [] # 存储排序后的论文列表
|
||||||
|
self.llm_kwargs = llm_kwargs or {} # 保存llm_kwargs
|
||||||
|
|
||||||
|
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
|
||||||
|
"""获取搜索参数"""
|
||||||
|
return {
|
||||||
|
'max_papers': plugin_kwargs.get('max_papers', 100), # 最大论文数量
|
||||||
|
'min_year': plugin_kwargs.get('min_year', 2015), # 最早年份
|
||||||
|
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
|
||||||
|
}
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""处理查询"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||||
|
"""使用arXiv专用参数搜索"""
|
||||||
|
try:
|
||||||
|
original_limit = params.get("limit", 20)
|
||||||
|
params["limit"] = original_limit * limit_multiplier
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
# 首先尝试基础搜索
|
||||||
|
query = params.get("query", "")
|
||||||
|
if query:
|
||||||
|
papers = await self.arxiv.search(
|
||||||
|
query,
|
||||||
|
limit=params["limit"],
|
||||||
|
sort_by=params.get("sort_by", "relevance"),
|
||||||
|
sort_order=params.get("sort_order", "descending"),
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
|
||||||
|
# 如果基础搜索没有结果,尝试分类搜索
|
||||||
|
if not papers:
|
||||||
|
categories = params.get("categories", [])
|
||||||
|
for category in categories:
|
||||||
|
category_papers = await self.arxiv.search_by_category(
|
||||||
|
category,
|
||||||
|
limit=params["limit"],
|
||||||
|
sort_by=params.get("sort_by", "relevance"),
|
||||||
|
sort_order=params.get("sort_order", "descending"),
|
||||||
|
)
|
||||||
|
if category_papers:
|
||||||
|
papers.extend(category_papers)
|
||||||
|
|
||||||
|
return papers or []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"arXiv搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||||
|
"""使用Semantic Scholar专用参数搜索"""
|
||||||
|
try:
|
||||||
|
original_limit = params.get("limit", 20)
|
||||||
|
params["limit"] = original_limit * limit_multiplier
|
||||||
|
|
||||||
|
# 只使用基本的搜索参数
|
||||||
|
papers = await self.semantic.search(
|
||||||
|
query=params.get("query", ""),
|
||||||
|
limit=params["limit"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在内存中进行过滤
|
||||||
|
if papers and min_year:
|
||||||
|
papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]
|
||||||
|
|
||||||
|
return papers or []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Semantic Scholar搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||||
|
"""使用PubMed专用参数搜索"""
|
||||||
|
try:
|
||||||
|
# 如果不需要PubMed搜索,直接返回空列表
|
||||||
|
if params.get("search_type") == "none":
|
||||||
|
return []
|
||||||
|
|
||||||
|
original_limit = params.get("limit", 20)
|
||||||
|
params["limit"] = original_limit * limit_multiplier
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
# 根据搜索类型选择搜索方法
|
||||||
|
if params.get("search_type") == "basic":
|
||||||
|
papers = await self.pubmed.search(
|
||||||
|
query=params.get("query", ""),
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
elif params.get("search_type") == "author":
|
||||||
|
papers = await self.pubmed.search_by_author(
|
||||||
|
author=params.get("query", ""),
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
elif params.get("search_type") == "journal":
|
||||||
|
papers = await self.pubmed.search_by_journal(
|
||||||
|
journal=params.get("query", ""),
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
|
||||||
|
return papers or []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"PubMed搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||||
|
"""使用Crossref专用参数搜索"""
|
||||||
|
try:
|
||||||
|
original_limit = params.get("limit", 20)
|
||||||
|
params["limit"] = original_limit * limit_multiplier
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
# 根据搜索类型选择搜索方法
|
||||||
|
if params.get("search_type") == "basic":
|
||||||
|
papers = await self.crossref.search(
|
||||||
|
query=params.get("query", ""),
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
elif params.get("search_type") == "author":
|
||||||
|
papers = await self.crossref.search_by_authors(
|
||||||
|
authors=[params.get("query", "")],
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
elif params.get("search_type") == "journal":
|
||||||
|
# 实现期刊搜索逻辑
|
||||||
|
pass
|
||||||
|
|
||||||
|
return papers or []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Crossref搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||||
|
"""使用ADS专用参数搜索"""
|
||||||
|
try:
|
||||||
|
original_limit = params.get("limit", 20)
|
||||||
|
params["limit"] = original_limit * limit_multiplier
|
||||||
|
papers = []
|
||||||
|
|
||||||
|
# 执行搜索
|
||||||
|
if params.get("search_type") == "basic":
|
||||||
|
papers = await self.adsabs.search(
|
||||||
|
query=params.get("query", ""),
|
||||||
|
limit=params["limit"],
|
||||||
|
start_year=min_year
|
||||||
|
)
|
||||||
|
|
||||||
|
return papers or []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"ADS搜索出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||||
|
"""从所有数据源搜索论文"""
|
||||||
|
search_tasks = []
|
||||||
|
|
||||||
|
# # 检查是否需要执行PubMed搜索
|
||||||
|
# is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
|
||||||
|
is_using_pubmed = False # 开源版本不再搜索pubmed
|
||||||
|
|
||||||
|
# 如果使用PubMed,则只执行PubMed和Semantic Scholar搜索
|
||||||
|
if is_using_pubmed:
|
||||||
|
search_tasks.append(
|
||||||
|
self._search_pubmed(
|
||||||
|
criteria.pubmed_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=criteria.start_year
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Semantic Scholar总是执行搜索
|
||||||
|
search_tasks.append(
|
||||||
|
self._search_semantic(
|
||||||
|
criteria.semantic_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=criteria.start_year
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# 如果不使用ADS,则执行Crossref搜索
|
||||||
|
if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
|
||||||
|
search_tasks.append(
|
||||||
|
self._search_crossref(
|
||||||
|
criteria.crossref_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=criteria.start_year
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
search_tasks.append(
|
||||||
|
self._search_arxiv(
|
||||||
|
criteria.arxiv_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=criteria.start_year
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if get_conf("SEMANTIC_SCHOLAR_KEY"):
|
||||||
|
search_tasks.append(
|
||||||
|
self._search_semantic(
|
||||||
|
criteria.semantic_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=criteria.start_year
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行所有需要的搜索任务
|
||||||
|
papers = await asyncio.gather(*search_tasks)
|
||||||
|
|
||||||
|
# 合并所有来源的论文并统计各来源的数量
|
||||||
|
all_papers = []
|
||||||
|
source_counts = {
|
||||||
|
'arxiv': 0,
|
||||||
|
'semantic': 0,
|
||||||
|
'pubmed': 0,
|
||||||
|
'crossref': 0,
|
||||||
|
'adsabs': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
for source_papers in papers:
|
||||||
|
if source_papers:
|
||||||
|
for paper in source_papers:
|
||||||
|
source = getattr(paper, 'source', 'unknown')
|
||||||
|
if source in source_counts:
|
||||||
|
source_counts[source] += 1
|
||||||
|
all_papers.extend(source_papers)
|
||||||
|
|
||||||
|
# 打印各来源的论文数量
|
||||||
|
print("\n=== 各数据源找到的论文数量 ===")
|
||||||
|
for source, count in source_counts.items():
|
||||||
|
if count > 0: # 只打印有论文的来源
|
||||||
|
print(f"{source.capitalize()}: {count} 篇")
|
||||||
|
print(f"总计: {len(all_papers)} 篇")
|
||||||
|
print("===========================\n")
|
||||||
|
|
||||||
|
return all_papers
|
||||||
|
|
||||||
|
def _format_paper_time(self, paper) -> str:
|
||||||
|
"""格式化论文时间信息"""
|
||||||
|
year = getattr(paper, 'year', None)
|
||||||
|
if not year:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 如果有具体的发表日期,使用具体日期
|
||||||
|
if hasattr(paper, 'published') and paper.published:
|
||||||
|
return f"(发表于 {paper.published.strftime('%Y-%m')})"
|
||||||
|
# 如果只有年份,只显示年份
|
||||||
|
return f"({year})"
|
||||||
|
|
||||||
|
def _format_papers(self, papers: List) -> str:
|
||||||
|
"""格式化论文列表,使用token限制控制长度"""
|
||||||
|
formatted = []
|
||||||
|
|
||||||
|
for i, paper in enumerate(papers, 1):
|
||||||
|
# 只保留前三个作者
|
||||||
|
authors = paper.authors[:3]
|
||||||
|
if len(paper.authors) > 3:
|
||||||
|
authors.append("et al.")
|
||||||
|
|
||||||
|
# 构建所有可能的下载链接
|
||||||
|
download_links = []
|
||||||
|
|
||||||
|
# 添加arXiv链接
|
||||||
|
if hasattr(paper, 'doi') and paper.doi:
|
||||||
|
if paper.doi.startswith("10.48550/arXiv."):
|
||||||
|
# 从DOI中提取完整的arXiv ID
|
||||||
|
arxiv_id = paper.doi.split("arXiv.")[-1]
|
||||||
|
# 移除多余的点号并确保格式正确
|
||||||
|
arxiv_id = arxiv_id.replace("..", ".") # 移除重复的点号
|
||||||
|
if arxiv_id.startswith("."): # 移除开头的点号
|
||||||
|
arxiv_id = arxiv_id[1:]
|
||||||
|
if arxiv_id.endswith("."): # 移除结尾的点号
|
||||||
|
arxiv_id = arxiv_id[:-1]
|
||||||
|
|
||||||
|
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||||
|
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||||
|
elif "arxiv.org/abs/" in paper.doi:
|
||||||
|
# 直接从URL中提取arXiv ID
|
||||||
|
arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
|
||||||
|
if "v" in arxiv_id: # 移除版本号
|
||||||
|
arxiv_id = arxiv_id.split("v")[0]
|
||||||
|
|
||||||
|
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||||
|
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||||
|
else:
|
||||||
|
download_links.append(f"[DOI](https://doi.org/{paper.doi})")
|
||||||
|
|
||||||
|
# 添加直接URL链接(如果存在且不同于前面的链接)
|
||||||
|
if hasattr(paper, 'url') and paper.url:
|
||||||
|
if not any(paper.url in link for link in download_links):
|
||||||
|
download_links.append(f"[Source]({paper.url})")
|
||||||
|
|
||||||
|
# 构建下载链接字符串
|
||||||
|
download_section = " | ".join(download_links) if download_links else "No direct download link available"
|
||||||
|
|
||||||
|
# 构建来源信息
|
||||||
|
source_info = []
|
||||||
|
if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
|
||||||
|
source_info.append(f"Type: {paper.venue_type}")
|
||||||
|
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||||
|
source_info.append(f"Venue: {paper.venue_name}")
|
||||||
|
|
||||||
|
# 添加IF指数和分区信息
|
||||||
|
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||||
|
source_info.append(f"IF: {paper.if_factor}")
|
||||||
|
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||||
|
source_info.append(f"中科院分区: {paper.cas_division}")
|
||||||
|
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||||
|
source_info.append(f"JCR分区: {paper.jcr_division}")
|
||||||
|
|
||||||
|
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||||
|
if paper.venue_info.get('journal_ref'):
|
||||||
|
source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
|
||||||
|
if paper.venue_info.get('publisher'):
|
||||||
|
source_info.append(f"Publisher: {paper.venue_info['publisher']}")
|
||||||
|
|
||||||
|
# 构建当前论文的格式化文本
|
||||||
|
paper_text = (
|
||||||
|
f"{i}. **{paper.title}**\n" +
|
||||||
|
f" Authors: {', '.join(authors)}\n" +
|
||||||
|
f" Year: {paper.year}\n" +
|
||||||
|
f" Citations: {paper.citations if paper.citations else 'N/A'}\n" +
|
||||||
|
(f" Source: {'; '.join(source_info)}\n" if source_info else "") +
|
||||||
|
# 添加PubMed特有信息
|
||||||
|
(f" MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper,
|
||||||
|
'mesh_terms') and paper.mesh_terms else "") +
|
||||||
|
f" 📥 PDF Downloads: {download_section}\n" +
|
||||||
|
f" Abstract: {paper.abstract}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted.append(paper_text)
|
||||||
|
|
||||||
|
full_text = "\n".join(formatted)
|
||||||
|
|
||||||
|
# 根据不同模型设置不同的token限制
|
||||||
|
model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')
|
||||||
|
|
||||||
|
token_limit = model_info[model_name]['max_token'] * 3 // 4
|
||||||
|
# 使用token限制控制长度
|
||||||
|
return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)
|
||||||
|
|
||||||
|
def _get_current_time(self) -> str:
|
||||||
|
"""获取当前时间信息"""
|
||||||
|
now = datetime.now()
|
||||||
|
return now.strftime("%Y年%m月%d日")
|
||||||
|
|
||||||
|
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
|
||||||
|
"""生成道歉提示"""
|
||||||
|
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的有效文献。
|
||||||
|
|
||||||
|
可能的原因:
|
||||||
|
1. 搜索词过于具体或专业
|
||||||
|
2. 时间范围限制过严
|
||||||
|
|
||||||
|
建议解决方案:
|
||||||
|
1. 尝试使用更通用的关键词
|
||||||
|
2. 扩大搜索时间范围
|
||||||
|
3. 使用同义词或相关术语
|
||||||
|
请根据以上建议调整后重试。"""
|
||||||
|
|
||||||
|
def get_ranked_papers(self) -> str:
|
||||||
|
"""获取排序后的论文列表的格式化字符串"""
|
||||||
|
return self._format_papers(self.ranked_papers) if self.ranked_papers else ""
|
||||||
|
|
||||||
|
def _is_pubmed_paper(self, paper) -> bool:
|
||||||
|
"""判断是否为PubMed论文"""
|
||||||
|
return (paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
from .base_handler import BaseHandler
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
class Arxiv最新论文推荐功能(BaseHandler):
|
||||||
|
"""最新论文推荐处理器"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||||
|
super().__init__(arxiv, semantic, llm_kwargs)
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""处理最新论文推荐请求"""
|
||||||
|
|
||||||
|
# 获取搜索参数
|
||||||
|
search_params = self._get_search_params(plugin_kwargs)
|
||||||
|
|
||||||
|
# 获取最新论文
|
||||||
|
papers = []
|
||||||
|
for category in criteria.arxiv_params["categories"]:
|
||||||
|
latest_papers = await self.arxiv.get_latest_papers(
|
||||||
|
category=category,
|
||||||
|
debug=False,
|
||||||
|
batch_size=50
|
||||||
|
)
|
||||||
|
papers.extend(latest_papers)
|
||||||
|
|
||||||
|
if not papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 使用embedding模型对论文进行排序
|
||||||
|
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||||
|
query=criteria.original_query,
|
||||||
|
papers=papers,
|
||||||
|
search_criteria=criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建最终的prompt
|
||||||
|
current_time = self._get_current_time()
|
||||||
|
final_prompt = f"""Current time: {current_time}
|
||||||
|
|
||||||
|
Based on your interest in {criteria.main_topic}, here are the latest papers from arXiv in relevant categories:
|
||||||
|
{', '.join(criteria.arxiv_params["categories"])}
|
||||||
|
|
||||||
|
Latest papers available:
|
||||||
|
{self._format_papers(self.ranked_papers)}
|
||||||
|
|
||||||
|
Please provide:
|
||||||
|
1. A clear list of latext papers, organized by themes or approaches
|
||||||
|
|
||||||
|
|
||||||
|
2. Group papers by sub-topics or themes if applicable
|
||||||
|
|
||||||
|
3. For each paper:
|
||||||
|
- Publication time
|
||||||
|
- The key contributions and main findings
|
||||||
|
- Why it's relevant to the user's interests
|
||||||
|
- How it relates to other latest papers
|
||||||
|
- The paper's citation count and citation impact
|
||||||
|
- The paper's download link
|
||||||
|
|
||||||
|
4. A suggested reading order based on:
|
||||||
|
- Paper relationships and dependencies
|
||||||
|
- Difficulty level
|
||||||
|
- Significance
|
||||||
|
|
||||||
|
5. Future Directions
|
||||||
|
- Emerging venues and research streams
|
||||||
|
- Novel methodological approaches
|
||||||
|
- Cross-disciplinary opportunities
|
||||||
|
- Research gaps by publication type
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- Focus on explaining why each paper is interesting
|
||||||
|
- Highlight the novelty and potential impact
|
||||||
|
- Consider the credibility and stage of each publication
|
||||||
|
- Use the provided paper titles with their links when referring to specific papers
|
||||||
|
- Base recommendations ONLY on the explicitly provided paper information
|
||||||
|
- Do not make ANY assumptions about papers beyond the given data
|
||||||
|
- When information is missing or unclear, acknowledge the limitation
|
||||||
|
- Never speculate about:
|
||||||
|
* Paper quality or rigor not evidenced in the data
|
||||||
|
* Research impact beyond citation counts
|
||||||
|
* Implementation details not mentioned
|
||||||
|
* Author expertise or background
|
||||||
|
* Future research directions not stated
|
||||||
|
- For each paper, cite only verifiable information
|
||||||
|
- Clearly distinguish between facts and potential implications
|
||||||
|
- Each paper includes download links in its 📥 PDF Downloads section
|
||||||
|
|
||||||
|
Format your response in markdown with clear sections.
|
||||||
|
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
@@ -0,0 +1,344 @@
|
|||||||
|
from typing import List, Dict, Any, Optional, Tuple
|
||||||
|
from .base_handler import BaseHandler
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
import asyncio
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||||
|
|
||||||
|
class 单篇论文分析功能(BaseHandler):
|
||||||
|
"""论文分析处理器"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||||
|
super().__init__(arxiv, semantic, llm_kwargs)
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""处理论文分析请求,返回最终的prompt"""
|
||||||
|
|
||||||
|
# 1. 获取论文详情
|
||||||
|
paper = await self._get_paper_details(criteria)
|
||||||
|
if not paper:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 保存为ranked_papers以便统一接口
|
||||||
|
self.ranked_papers = [paper]
|
||||||
|
|
||||||
|
# 2. 构建最终的prompt
|
||||||
|
current_time = self._get_current_time()
|
||||||
|
|
||||||
|
# 获取论文信息
|
||||||
|
title = getattr(paper, "title", "Unknown Title")
|
||||||
|
authors = getattr(paper, "authors", [])
|
||||||
|
year = getattr(paper, "year", "Unknown Year")
|
||||||
|
abstract = getattr(paper, "abstract", "No abstract available")
|
||||||
|
citations = getattr(paper, "citations", "N/A")
|
||||||
|
|
||||||
|
# 添加论文ID信息
|
||||||
|
paper_id = ""
|
||||||
|
if criteria.paper_source == "arxiv":
|
||||||
|
paper_id = f"arXiv ID: {criteria.paper_id}\n"
|
||||||
|
elif criteria.paper_source == "doi":
|
||||||
|
paper_id = f"DOI: {criteria.paper_id}\n"
|
||||||
|
|
||||||
|
# 格式化作者列表
|
||||||
|
authors_str = ', '.join(authors) if isinstance(authors, list) else authors
|
||||||
|
|
||||||
|
final_prompt = f"""Current time: {current_time}
|
||||||
|
|
||||||
|
Please provide a comprehensive analysis of the following paper:
|
||||||
|
|
||||||
|
{paper_id}Title: {title}
|
||||||
|
Authors: {authors_str}
|
||||||
|
Year: {year}
|
||||||
|
Citations: {citations}
|
||||||
|
Publication Venue: {paper.venue_name} ({paper.venue_type})
|
||||||
|
{f"Publisher: {paper.venue_info.get('publisher')}" if paper.venue_info.get('publisher') else ""}
|
||||||
|
{f"Journal Reference: {paper.venue_info.get('journal_ref')}" if paper.venue_info.get('journal_ref') else ""}
|
||||||
|
Abstract: {abstract}
|
||||||
|
|
||||||
|
Please provide:
|
||||||
|
1. Publication Context
|
||||||
|
- Publication venue analysis and impact factor (if available)
|
||||||
|
- Paper type (journal article, conference paper, preprint)
|
||||||
|
- Publication timeline and peer review status
|
||||||
|
- Publisher reputation and venue prestige
|
||||||
|
|
||||||
|
2. Research Context
|
||||||
|
- Field positioning and significance
|
||||||
|
- Historical context and prior work
|
||||||
|
- Related research streams
|
||||||
|
- Cross-venue impact analysis
|
||||||
|
|
||||||
|
3. Technical Analysis
|
||||||
|
- Detailed methodology review
|
||||||
|
- Implementation details
|
||||||
|
- Experimental setup and results
|
||||||
|
- Technical innovations
|
||||||
|
|
||||||
|
4. Impact Analysis
|
||||||
|
- Citation patterns and influence
|
||||||
|
- Cross-venue recognition
|
||||||
|
- Industry vs. academic impact
|
||||||
|
- Practical applications
|
||||||
|
|
||||||
|
5. Critical Review
|
||||||
|
- Methodological rigor assessment
|
||||||
|
- Result reliability and reproducibility
|
||||||
|
- Venue-appropriate evaluation standards
|
||||||
|
- Limitations and potential improvements
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- Strictly use ONLY the information provided above about the paper
|
||||||
|
- Do not make ANY assumptions or inferences beyond the given data
|
||||||
|
- If certain information is not provided, explicitly state that it is unknown
|
||||||
|
- For any unclear or missing details, acknowledge the limitation rather than speculating
|
||||||
|
- When discussing methodology or results, only describe what is explicitly stated in the abstract
|
||||||
|
- Never fabricate or assume any details about:
|
||||||
|
* Publication venues or status
|
||||||
|
* Implementation details not mentioned
|
||||||
|
* Results or findings not stated
|
||||||
|
* Impact or influence not supported by the citation count
|
||||||
|
* Authors' affiliations or backgrounds
|
||||||
|
* Future work or implications not mentioned
|
||||||
|
- You can find the paper's download options in the 📥 PDF Downloads section
|
||||||
|
- Available download formats include arXiv PDF, DOI links, and source URLs
|
||||||
|
|
||||||
|
Format your response in markdown with clear sections.
|
||||||
|
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
async def _get_paper_details(self, criteria: SearchCriteria):
|
||||||
|
"""获取论文详情"""
|
||||||
|
try:
|
||||||
|
if criteria.paper_source == "arxiv":
|
||||||
|
# 使用 arxiv ID 搜索
|
||||||
|
papers = await self.arxiv.search_by_id(criteria.paper_id)
|
||||||
|
return papers[0] if papers else None
|
||||||
|
|
||||||
|
elif criteria.paper_source == "doi":
|
||||||
|
# 尝试从所有来源获取
|
||||||
|
paper = await self.semantic.get_paper_by_doi(criteria.paper_id)
|
||||||
|
if not paper:
|
||||||
|
# 如果Semantic Scholar没有找到,尝试PubMed
|
||||||
|
papers = await self.pubmed.search(
|
||||||
|
f"{criteria.paper_id}[doi]",
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
if papers:
|
||||||
|
return papers[0]
|
||||||
|
return paper
|
||||||
|
|
||||||
|
elif criteria.paper_source == "title":
|
||||||
|
# 使用_search_all_sources搜索
|
||||||
|
search_params = {
|
||||||
|
'max_papers': 1,
|
||||||
|
'min_year': 1900, # 不限制年份
|
||||||
|
'search_multiplier': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 设置搜索参数
|
||||||
|
criteria.arxiv_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": f'ti:"{criteria.paper_title}"',
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
criteria.semantic_params = {
|
||||||
|
"query": criteria.paper_title,
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
criteria.pubmed_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": f'"{criteria.paper_title}"[Title]',
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
papers = await self._search_all_sources(criteria, search_params)
|
||||||
|
return papers[0] if papers else None
|
||||||
|
|
||||||
|
# 如果都没有找到,尝试使用 main_topic 作为标题搜索
|
||||||
|
if not criteria.paper_title and not criteria.paper_id:
|
||||||
|
search_params = {
|
||||||
|
'max_papers': 1,
|
||||||
|
'min_year': 1900,
|
||||||
|
'search_multiplier': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 设置搜索参数
|
||||||
|
criteria.arxiv_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": f'ti:"{criteria.main_topic}"',
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
criteria.semantic_params = {
|
||||||
|
"query": criteria.main_topic,
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
criteria.pubmed_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": f'"{criteria.main_topic}"[Title]',
|
||||||
|
"limit": 1
|
||||||
|
}
|
||||||
|
|
||||||
|
papers = await self._search_all_sources(criteria, search_params)
|
||||||
|
return papers[0] if papers else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文详情时出错: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get_citation_context(self, paper: Dict, plugin_kwargs: Dict) -> Tuple[List, List]:
|
||||||
|
"""获取引用上下文"""
|
||||||
|
search_params = self._get_search_params(plugin_kwargs)
|
||||||
|
|
||||||
|
# 使用论文标题构建搜索参数
|
||||||
|
title_query = f'ti:"{getattr(paper, "title", "")}"'
|
||||||
|
arxiv_params = {
|
||||||
|
"query": title_query,
|
||||||
|
"limit": search_params['max_papers'],
|
||||||
|
"search_type": "basic",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"sort_order": "descending"
|
||||||
|
}
|
||||||
|
semantic_params = {
|
||||||
|
"query": getattr(paper, "title", ""),
|
||||||
|
"limit": search_params['max_papers']
|
||||||
|
}
|
||||||
|
|
||||||
|
citations, references = await asyncio.gather(
|
||||||
|
self._search_semantic(
|
||||||
|
semantic_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=search_params['min_year']
|
||||||
|
),
|
||||||
|
self._search_arxiv(
|
||||||
|
arxiv_params,
|
||||||
|
limit_multiplier=search_params['search_multiplier'],
|
||||||
|
min_year=search_params['min_year']
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return citations, references
|
||||||
|
|
||||||
|
async def _generate_analysis(
|
||||||
|
self,
|
||||||
|
paper: Dict,
|
||||||
|
citations: List,
|
||||||
|
references: List,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any]
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""生成论文分析"""
|
||||||
|
|
||||||
|
# 构建提示
|
||||||
|
analysis_prompt = f"""Please provide a comprehensive analysis of the following paper:
|
||||||
|
|
||||||
|
Paper details:
|
||||||
|
{self._format_paper(paper)}
|
||||||
|
|
||||||
|
Key references (papers cited by this paper):
|
||||||
|
{self._format_papers(references)}
|
||||||
|
|
||||||
|
Important citations (papers that cite this paper):
|
||||||
|
{self._format_papers(citations)}
|
||||||
|
|
||||||
|
Please provide:
|
||||||
|
1. Paper Overview
|
||||||
|
- Main research question/objective
|
||||||
|
- Key methodology/approach
|
||||||
|
- Main findings/contributions
|
||||||
|
|
||||||
|
2. Technical Analysis
|
||||||
|
- Detailed methodology review
|
||||||
|
- Technical innovations
|
||||||
|
- Implementation details
|
||||||
|
- Experimental setup and results
|
||||||
|
|
||||||
|
3. Impact Analysis
|
||||||
|
- Significance in the field
|
||||||
|
- Influence on subsequent research (based on citing papers)
|
||||||
|
- Relationship to prior work (based on cited papers)
|
||||||
|
- Practical applications
|
||||||
|
|
||||||
|
4. Critical Review
|
||||||
|
- Strengths and limitations
|
||||||
|
- Potential improvements
|
||||||
|
- Open questions and future directions
|
||||||
|
- Alternative approaches
|
||||||
|
|
||||||
|
5. Related Research Context
|
||||||
|
- How it builds on previous work
|
||||||
|
- How it has influenced subsequent research
|
||||||
|
- Comparison with alternative approaches
|
||||||
|
|
||||||
|
Format your response in markdown with clear sections."""
|
||||||
|
|
||||||
|
# 并行生成概述和技术分析
|
||||||
|
for response_chunk in request_gpt(
|
||||||
|
inputs_array=[
|
||||||
|
analysis_prompt,
|
||||||
|
self._get_technical_prompt(paper)
|
||||||
|
],
|
||||||
|
inputs_show_user_array=[
|
||||||
|
"Generating paper analysis...",
|
||||||
|
"Analyzing technical details..."
|
||||||
|
],
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history_array=[history, []],
|
||||||
|
sys_prompt_array=[
|
||||||
|
system_prompt,
|
||||||
|
"You are an expert at analyzing technical details in research papers."
|
||||||
|
]
|
||||||
|
):
|
||||||
|
pass # 等待生成完成
|
||||||
|
|
||||||
|
# 获取最后的两个回答
|
||||||
|
if chatbot and len(chatbot[-2:]) == 2:
|
||||||
|
analysis = chatbot[-2][1]
|
||||||
|
technical = chatbot[-1][1]
|
||||||
|
full_analysis = f"""# Paper Analysis: {paper.title}
|
||||||
|
|
||||||
|
## General Analysis
|
||||||
|
{analysis}
|
||||||
|
|
||||||
|
## Technical Deep Dive
|
||||||
|
{technical}
|
||||||
|
"""
|
||||||
|
chatbot.append(["Here is the paper analysis:", full_analysis])
|
||||||
|
else:
|
||||||
|
chatbot.append(["Here is the paper analysis:", "Failed to generate analysis."])
|
||||||
|
|
||||||
|
return chatbot
|
||||||
|
|
||||||
|
def _get_technical_prompt(self, paper: Dict) -> str:
|
||||||
|
"""生成技术分析提示"""
|
||||||
|
return f"""Please provide a detailed technical analysis of the following paper:
|
||||||
|
|
||||||
|
{self._format_paper(paper)}
|
||||||
|
|
||||||
|
Focus on:
|
||||||
|
1. Mathematical formulations and their implications
|
||||||
|
2. Algorithm design and complexity analysis
|
||||||
|
3. Architecture details and design choices
|
||||||
|
4. Implementation challenges and solutions
|
||||||
|
5. Performance analysis and bottlenecks
|
||||||
|
6. Technical limitations and potential improvements
|
||||||
|
|
||||||
|
Format your response in markdown, focusing purely on technical aspects."""
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
from .base_handler import BaseHandler
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
from textwrap import dedent
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||||
|
|
||||||
|
class 学术问答功能(BaseHandler):
|
||||||
|
"""学术问答处理器"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||||
|
super().__init__(arxiv, semantic, llm_kwargs)
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""处理学术问答请求,返回最终的prompt"""
|
||||||
|
|
||||||
|
# 1. 获取搜索参数
|
||||||
|
search_params = self._get_search_params(plugin_kwargs)
|
||||||
|
|
||||||
|
# 2. 搜索相关论文
|
||||||
|
papers = await self._search_relevant_papers(criteria, search_params)
|
||||||
|
if not papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 构建最终的prompt
|
||||||
|
current_time = self._get_current_time()
|
||||||
|
final_prompt = dedent(f"""Current time: {current_time}
|
||||||
|
|
||||||
|
Based on the following paper abstracts, please answer this academic question: {criteria.original_query}
|
||||||
|
|
||||||
|
Available papers for reference:
|
||||||
|
{self._format_papers(self.ranked_papers)}
|
||||||
|
|
||||||
|
Please structure your response in the following format:
|
||||||
|
|
||||||
|
1. Core Answer (2-3 paragraphs)
|
||||||
|
- Provide a clear, direct answer synthesizing key findings
|
||||||
|
- Support main points with citations [1,2,etc.]
|
||||||
|
- Focus on consensus and differences across papers
|
||||||
|
|
||||||
|
2. Key Evidence (2-3 paragraphs)
|
||||||
|
- Present supporting evidence from abstracts
|
||||||
|
- Compare methodologies and results
|
||||||
|
- Highlight significant findings with citations
|
||||||
|
|
||||||
|
3. Research Context (1-2 paragraphs)
|
||||||
|
- Discuss current trends and developments
|
||||||
|
- Identify research gaps or limitations
|
||||||
|
- Suggest potential future directions
|
||||||
|
|
||||||
|
Guidelines:
|
||||||
|
- Base your answer ONLY on the provided abstracts
|
||||||
|
- Use numbered citations [1], [2,3], etc. for every claim
|
||||||
|
- Maintain academic tone and objectivity
|
||||||
|
- Synthesize findings across multiple papers
|
||||||
|
- Focus on the most relevant information to the question
|
||||||
|
|
||||||
|
Constraints:
|
||||||
|
- Do not include information beyond the provided abstracts
|
||||||
|
- Avoid speculation or personal opinions
|
||||||
|
- Do not elaborate on technical details unless directly relevant
|
||||||
|
- Keep citations concise and focused
|
||||||
|
- Use [N] citations for every major claim or finding
|
||||||
|
- Cite multiple papers [1,2,3] when showing consensus
|
||||||
|
- Place citations immediately after the relevant statements
|
||||||
|
|
||||||
|
Note: Provide citations for every major claim to ensure traceability to source papers.
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language. Use Chinese to answer if no language is specified.
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
async def _search_relevant_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||||
|
"""搜索相关论文"""
|
||||||
|
# 使用_search_all_sources替代原来的并行搜索
|
||||||
|
all_papers = await self._search_all_sources(criteria, search_params)
|
||||||
|
|
||||||
|
if not all_papers:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用BGE重排序
|
||||||
|
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||||
|
query=criteria.main_topic,
|
||||||
|
papers=all_papers,
|
||||||
|
search_criteria=criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.ranked_papers or []
|
||||||
|
|
||||||
|
async def _generate_answer(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
papers: List,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any]
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""生成答案"""
|
||||||
|
|
||||||
|
# 构建提示
|
||||||
|
qa_prompt = dedent(f"""Please answer the following academic question based on recent research papers.
|
||||||
|
|
||||||
|
Question: {criteria.main_topic}
|
||||||
|
|
||||||
|
Relevant papers:
|
||||||
|
{self._format_papers(papers)}
|
||||||
|
|
||||||
|
Please provide:
|
||||||
|
1. A direct answer to the question
|
||||||
|
2. Supporting evidence from the papers
|
||||||
|
3. Different perspectives or approaches if applicable
|
||||||
|
4. Current limitations and open questions
|
||||||
|
5. References to specific papers
|
||||||
|
|
||||||
|
Format your response in markdown with clear sections."""
|
||||||
|
)
|
||||||
|
# 调用LLM生成答案
|
||||||
|
for response_chunk in request_gpt(
|
||||||
|
inputs_array=[qa_prompt],
|
||||||
|
inputs_show_user_array=["Generating answer..."],
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history_array=[history],
|
||||||
|
sys_prompt_array=[system_prompt]
|
||||||
|
):
|
||||||
|
pass # 等待生成完成
|
||||||
|
|
||||||
|
# 获取最后的回答
|
||||||
|
if chatbot and len(chatbot[-1]) >= 2:
|
||||||
|
answer = chatbot[-1][1]
|
||||||
|
chatbot.append(["Here is the answer:", answer])
|
||||||
|
else:
|
||||||
|
chatbot.append(["Here is the answer:", "Failed to generate answer."])
|
||||||
|
|
||||||
|
return chatbot
|
||||||
|
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
from .base_handler import BaseHandler
|
||||||
|
from textwrap import dedent
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||||
|
|
||||||
|
class 论文推荐功能(BaseHandler):
|
||||||
|
"""论文推荐处理器"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||||
|
super().__init__(arxiv, semantic, llm_kwargs)
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""处理论文推荐请求,返回最终的prompt"""
|
||||||
|
|
||||||
|
search_params = self._get_search_params(plugin_kwargs)
|
||||||
|
|
||||||
|
# 1. 先搜索种子论文
|
||||||
|
seed_papers = await self._search_seed_papers(criteria, search_params)
|
||||||
|
if not seed_papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 使用BGE重排序
|
||||||
|
all_papers = seed_papers
|
||||||
|
|
||||||
|
if not all_papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||||
|
query=criteria.original_query,
|
||||||
|
papers=all_papers,
|
||||||
|
search_criteria=criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.ranked_papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 构建最终的prompt
|
||||||
|
current_time = self._get_current_time()
|
||||||
|
final_prompt = dedent(f"""Current time: {current_time}
|
||||||
|
|
||||||
|
Based on the user's interest in {criteria.main_topic}, here are relevant papers.
|
||||||
|
|
||||||
|
Available papers for recommendation:
|
||||||
|
{self._format_papers(self.ranked_papers)}
|
||||||
|
|
||||||
|
Please provide:
|
||||||
|
1. Group papers by sub-topics or themes if applicable
|
||||||
|
|
||||||
|
2. For each paper:
|
||||||
|
- Publication time and venue (when available)
|
||||||
|
- Journal metrics (when available):
|
||||||
|
* Impact Factor (IF)
|
||||||
|
* JCR Quartile
|
||||||
|
* Chinese Academy of Sciences (CAS) Division
|
||||||
|
- The key contributions and main findings
|
||||||
|
- Why it's relevant to the user's interests
|
||||||
|
- How it relates to other recommended papers
|
||||||
|
- The paper's citation count and citation impact
|
||||||
|
- The paper's download link
|
||||||
|
|
||||||
|
3. A suggested reading order based on:
|
||||||
|
- Journal impact and quality metrics
|
||||||
|
- Chronological development of ideas
|
||||||
|
- Paper relationships and dependencies
|
||||||
|
- Difficulty level
|
||||||
|
- Impact and significance
|
||||||
|
|
||||||
|
4. Future Directions
|
||||||
|
- Emerging venues and research streams
|
||||||
|
- Novel methodological approaches
|
||||||
|
- Cross-disciplinary opportunities
|
||||||
|
- Research gaps by publication type
|
||||||
|
|
||||||
|
|
||||||
|
IMPORTANT:
|
||||||
|
- Focus on explaining why each paper is valuable
|
||||||
|
- Highlight connections between papers
|
||||||
|
- Consider both citation counts AND journal metrics when discussing impact
|
||||||
|
- When available, use IF, JCR quartile, and CAS division to assess paper quality
|
||||||
|
- Mention publication timing when discussing paper relationships
|
||||||
|
- When referring to papers, use HTML links in this format:
|
||||||
|
* For DOIs: <a href='https://doi.org/DOI_HERE' target='_blank'>DOI: DOI_HERE</a>
|
||||||
|
* For titles: <a href='PAPER_URL' target='_blank'>PAPER_TITLE</a>
|
||||||
|
- Present papers in a way that shows the evolution of ideas over time
|
||||||
|
- Base recommendations ONLY on the explicitly provided paper information
|
||||||
|
- Do not make ANY assumptions about papers beyond the given data
|
||||||
|
- When information is missing or unclear, acknowledge the limitation
|
||||||
|
- Never speculate about:
|
||||||
|
* Paper quality or rigor not evidenced in the data
|
||||||
|
* Research impact beyond citation counts and journal metrics
|
||||||
|
* Implementation details not mentioned
|
||||||
|
* Author expertise or background
|
||||||
|
* Future research directions not stated
|
||||||
|
- For each recommendation, cite only verifiable information
|
||||||
|
- Clearly distinguish between facts and potential implications
|
||||||
|
|
||||||
|
Format your response in markdown with clear sections.
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
async def _search_seed_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||||
|
"""搜索种子论文"""
|
||||||
|
try:
|
||||||
|
# 使用_search_all_sources替代原来的并行搜索
|
||||||
|
all_papers = await self._search_all_sources(criteria, search_params)
|
||||||
|
|
||||||
|
if not all_papers:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return all_papers
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"搜索种子论文时出错: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _get_recommendations(self, seed_papers: List, multiplier: int = 1) -> List:
|
||||||
|
"""获取推荐论文"""
|
||||||
|
recommendations = []
|
||||||
|
base_limit = 3 * multiplier
|
||||||
|
|
||||||
|
# 将种子论文添加到推荐列表中
|
||||||
|
recommendations.extend(seed_papers)
|
||||||
|
|
||||||
|
# 只使用前5篇论文作为种子
|
||||||
|
seed_papers = seed_papers[:5]
|
||||||
|
|
||||||
|
for paper in seed_papers:
|
||||||
|
try:
|
||||||
|
if paper.doi and paper.doi.startswith("10.48550/arXiv."):
|
||||||
|
# arXiv论文
|
||||||
|
arxiv_id = paper.doi.split(".")[-1]
|
||||||
|
paper_details = await self.arxiv.get_paper_details(arxiv_id)
|
||||||
|
if paper_details and hasattr(paper_details, 'venue'):
|
||||||
|
category = paper_details.venue.split(":")[-1]
|
||||||
|
similar_papers = await self.arxiv.search_by_category(
|
||||||
|
category,
|
||||||
|
limit=base_limit,
|
||||||
|
sort_by='relevance'
|
||||||
|
)
|
||||||
|
recommendations.extend(similar_papers)
|
||||||
|
elif paper.doi: # 只对有DOI的论文获取推荐
|
||||||
|
# Semantic Scholar论文
|
||||||
|
similar_papers = await self.semantic.get_recommended_papers(
|
||||||
|
paper.doi,
|
||||||
|
limit=base_limit
|
||||||
|
)
|
||||||
|
if similar_papers: # 只添加成功获取的推荐
|
||||||
|
recommendations.extend(similar_papers)
|
||||||
|
else:
|
||||||
|
# 对于没有DOI的论文,使用标题进行相关搜索
|
||||||
|
if paper.title:
|
||||||
|
similar_papers = await self.semantic.search(
|
||||||
|
query=paper.title,
|
||||||
|
limit=base_limit
|
||||||
|
)
|
||||||
|
recommendations.extend(similar_papers)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取论文 '{paper.title}' 的推荐时发生错误: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 去重处理
|
||||||
|
seen_dois = set()
|
||||||
|
unique_recommendations = []
|
||||||
|
for paper in recommendations:
|
||||||
|
if paper.doi and paper.doi not in seen_dois:
|
||||||
|
seen_dois.add(paper.doi)
|
||||||
|
unique_recommendations.append(paper)
|
||||||
|
elif not paper.doi and paper not in unique_recommendations:
|
||||||
|
unique_recommendations.append(paper)
|
||||||
|
|
||||||
|
return unique_recommendations
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
from typing import List, Dict, Any, Tuple
|
||||||
|
from .base_handler import BaseHandler
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
class 文献综述功能(BaseHandler):
|
||||||
|
"""文献综述处理器"""
|
||||||
|
|
||||||
|
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||||
|
super().__init__(arxiv, semantic, llm_kwargs)
|
||||||
|
|
||||||
|
async def handle(
|
||||||
|
self,
|
||||||
|
criteria: SearchCriteria,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> str:
|
||||||
|
"""处理文献综述请求,返回最终的prompt"""
|
||||||
|
|
||||||
|
# 获取搜索参数
|
||||||
|
search_params = self._get_search_params(plugin_kwargs)
|
||||||
|
|
||||||
|
# 使用_search_all_sources替代原来的并行搜索
|
||||||
|
all_papers = await self._search_all_sources(criteria, search_params)
|
||||||
|
|
||||||
|
if not all_papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||||
|
query=criteria.original_query,
|
||||||
|
papers=all_papers,
|
||||||
|
search_criteria=criteria
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查排序后的论文数量
|
||||||
|
if not self.ranked_papers:
|
||||||
|
return self._generate_apology_prompt(criteria)
|
||||||
|
|
||||||
|
# 检查是否包含PubMed论文
|
||||||
|
has_pubmed_papers = any(paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url
|
||||||
|
for paper in self.ranked_papers)
|
||||||
|
|
||||||
|
if has_pubmed_papers:
|
||||||
|
return self._generate_medical_review_prompt(criteria)
|
||||||
|
else:
|
||||||
|
return self._generate_general_review_prompt(criteria)
|
||||||
|
|
||||||
|
def _generate_medical_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||||
|
"""生成医学文献综述prompt"""
|
||||||
|
return f"""Current time: {self._get_current_time()}
|
||||||
|
|
||||||
|
Conduct a systematic medical literature review on {criteria.main_topic} based STRICTLY on the provided articles.
|
||||||
|
|
||||||
|
Available literature for review:
|
||||||
|
{self._format_papers(self.ranked_papers)}
|
||||||
|
|
||||||
|
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||||
|
|
||||||
|
Please structure your medical review following these guidelines:
|
||||||
|
|
||||||
|
1. Research Overview
|
||||||
|
- Main research questions and objectives from the studies
|
||||||
|
- Types of studies included (clinical trials, observational studies, etc.)
|
||||||
|
- Study populations and settings
|
||||||
|
- Time period of the research
|
||||||
|
|
||||||
|
2. Key Findings
|
||||||
|
- Main outcomes and results reported in abstracts
|
||||||
|
- Primary endpoints and their measurements
|
||||||
|
- Statistical significance when reported
|
||||||
|
- Observed trends across studies
|
||||||
|
|
||||||
|
3. Methods Summary
|
||||||
|
- Study designs used
|
||||||
|
- Major interventions or treatments studied
|
||||||
|
- Key outcome measures
|
||||||
|
- Patient populations studied
|
||||||
|
|
||||||
|
4. Clinical Relevance
|
||||||
|
- Reported clinical implications
|
||||||
|
- Main conclusions from authors
|
||||||
|
- Reported benefits and risks
|
||||||
|
- Treatment responses when available
|
||||||
|
|
||||||
|
5. Research Status
|
||||||
|
- Current research focus areas
|
||||||
|
- Reported limitations
|
||||||
|
- Gaps identified in abstracts
|
||||||
|
- Authors' suggested future directions
|
||||||
|
|
||||||
|
CRITICAL REQUIREMENTS:
|
||||||
|
|
||||||
|
Citation Rules (MANDATORY):
|
||||||
|
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||||
|
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||||
|
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||||
|
- Use ONLY the papers provided in the available literature list above
|
||||||
|
- Citations must appear immediately after each statement
|
||||||
|
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||||
|
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||||
|
|
||||||
|
Content Guidelines:
|
||||||
|
- Present only information available in the provided papers
|
||||||
|
- If certain information is not available, simply omit that aspect rather than explicitly stating its absence
|
||||||
|
- Focus on synthesizing and presenting available findings
|
||||||
|
- Maintain professional medical writing style
|
||||||
|
- Present limitations and gaps as research opportunities rather than missing information
|
||||||
|
|
||||||
|
Writing Style:
|
||||||
|
- Use precise medical terminology
|
||||||
|
- Maintain objective reporting
|
||||||
|
- Use consistent terminology throughout
|
||||||
|
- Present a cohesive narrative without referencing data limitations
|
||||||
|
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _generate_general_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||||
|
"""生成通用文献综述prompt"""
|
||||||
|
current_time = self._get_current_time()
|
||||||
|
final_prompt = f"""Current time: {current_time}
|
||||||
|
|
||||||
|
Conduct a comprehensive literature review on {criteria.main_topic} focusing on the following aspects:
|
||||||
|
{', '.join(criteria.sub_topics)}
|
||||||
|
|
||||||
|
Available literature for review:
|
||||||
|
{self._format_papers(self.ranked_papers)}
|
||||||
|
|
||||||
|
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||||
|
|
||||||
|
Please structure your review following these guidelines:
|
||||||
|
|
||||||
|
1. Introduction and Research Background
|
||||||
|
- Current state and significance of the research field
|
||||||
|
- Key research problems and challenges
|
||||||
|
- Research development timeline and evolution
|
||||||
|
|
||||||
|
2. Research Directions and Classifications
|
||||||
|
- Major research directions and their relationships
|
||||||
|
- Different technical approaches and their characteristics
|
||||||
|
- Comparative analysis of various solutions
|
||||||
|
|
||||||
|
3. Core Technologies and Methods
|
||||||
|
- Key technological breakthroughs
|
||||||
|
- Advantages and limitations of different methods
|
||||||
|
- Technical challenges and solutions
|
||||||
|
|
||||||
|
4. Applications and Impact
|
||||||
|
- Real-world applications and use cases
|
||||||
|
- Industry influence and practical value
|
||||||
|
- Implementation challenges and solutions
|
||||||
|
|
||||||
|
5. Future Trends and Prospects
|
||||||
|
- Emerging research directions
|
||||||
|
- Unsolved problems and challenges
|
||||||
|
- Potential breakthrough points
|
||||||
|
|
||||||
|
CRITICAL REQUIREMENTS:
|
||||||
|
|
||||||
|
Citation Rules (MANDATORY):
|
||||||
|
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||||
|
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||||
|
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||||
|
- Use ONLY the papers provided in the available literature list above
|
||||||
|
- Citations must appear immediately after each statement
|
||||||
|
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||||
|
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||||
|
|
||||||
|
Writing Style:
|
||||||
|
- Maintain academic and professional tone
|
||||||
|
- Focus on objective analysis with proper citations
|
||||||
|
- Ensure logical flow and clear structure
|
||||||
|
|
||||||
|
Content Requirements:
|
||||||
|
- Base ALL analysis STRICTLY on the provided papers with explicit citations
|
||||||
|
- When introducing any concept, method, or finding, immediately follow with [N]
|
||||||
|
- For each research direction or approach, cite the specific papers [N] that proposed or developed it
|
||||||
|
- When discussing limitations or challenges, cite the papers [N] that identified them
|
||||||
|
- DO NOT include information from sources outside the provided paper list
|
||||||
|
- DO NOT make unsupported claims or statements
|
||||||
|
|
||||||
|
Language requirement:
|
||||||
|
- If the query explicitly specifies a language, use that language
|
||||||
|
- Otherwise, match the language of the original user query
|
||||||
|
"""
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
@@ -0,0 +1,452 @@
|
|||||||
|
from typing import List, Dict
|
||||||
|
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||||
|
from request_llms.embed_models.bge_llm import BGELLMRanker
|
||||||
|
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||||
|
import random
|
||||||
|
from crazy_functions.review_fns.data_sources.journal_metrics import JournalMetrics
|
||||||
|
|
||||||
|
class PaperLLMRanker:
|
||||||
|
"""使用LLM进行论文重排序"""
|
||||||
|
|
||||||
|
def __init__(self, llm_kwargs: Dict = None):
|
||||||
|
self.ranker = BGELLMRanker(llm_kwargs=llm_kwargs)
|
||||||
|
self.journal_metrics = JournalMetrics()
|
||||||
|
|
||||||
|
def _update_paper_metrics(self, papers: List[PaperMetadata]) -> None:
|
||||||
|
"""更新论文的期刊指标"""
|
||||||
|
for paper in papers:
|
||||||
|
# 跳过arXiv来源的论文
|
||||||
|
if getattr(paper, 'source', '') == 'arxiv':
|
||||||
|
continue
|
||||||
|
|
||||||
|
if hasattr(paper, 'venue_name') or hasattr(paper, 'venue_info'):
|
||||||
|
# 获取venue_name和venue_info
|
||||||
|
venue_name = getattr(paper, 'venue_name', '')
|
||||||
|
venue_info = getattr(paper, 'venue_info', {})
|
||||||
|
|
||||||
|
# 使用改进的匹配逻辑获取指标
|
||||||
|
metrics = self.journal_metrics.get_journal_metrics(venue_name, venue_info)
|
||||||
|
|
||||||
|
# 更新论文的指标
|
||||||
|
paper.if_factor = metrics.get('if_factor')
|
||||||
|
paper.jcr_division = metrics.get('jcr_division')
|
||||||
|
paper.cas_division = metrics.get('cas_division')
|
||||||
|
|
||||||
|
def _get_year_as_int(self, paper) -> int:
|
||||||
|
"""统一获取论文年份为整数格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paper: 论文对象或直接是年份值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
整数格式的年份,如果无法转换则返回0
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 如果输入直接是年份而不是论文对象
|
||||||
|
if isinstance(paper, int):
|
||||||
|
return paper
|
||||||
|
elif isinstance(paper, str):
|
||||||
|
try:
|
||||||
|
return int(paper)
|
||||||
|
except ValueError:
|
||||||
|
import re
|
||||||
|
year_match = re.search(r'\d{4}', paper)
|
||||||
|
if year_match:
|
||||||
|
return int(year_match.group())
|
||||||
|
return 0
|
||||||
|
elif isinstance(paper, float):
|
||||||
|
return int(paper)
|
||||||
|
|
||||||
|
# 处理论文对象
|
||||||
|
year = getattr(paper, 'year', None)
|
||||||
|
if year is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# 如果是字符串,尝试转换为整数
|
||||||
|
if isinstance(year, str):
|
||||||
|
# 首先尝试直接转换整个字符串
|
||||||
|
try:
|
||||||
|
return int(year)
|
||||||
|
except ValueError:
|
||||||
|
# 如果直接转换失败,尝试提取第一个数字序列
|
||||||
|
import re
|
||||||
|
year_match = re.search(r'\d{4}', year)
|
||||||
|
if year_match:
|
||||||
|
return int(year_match.group())
|
||||||
|
return 0
|
||||||
|
# 如果是浮点数,转换为整数
|
||||||
|
elif isinstance(year, float):
|
||||||
|
return int(year)
|
||||||
|
# 如果已经是整数,直接返回
|
||||||
|
elif isinstance(year, int):
|
||||||
|
return year
|
||||||
|
return 0
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def rank_papers(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
papers: List[PaperMetadata],
|
||||||
|
search_criteria: SearchCriteria = None,
|
||||||
|
top_k: int = 40,
|
||||||
|
use_rerank: bool = False,
|
||||||
|
pre_filter_ratio: float = 0.5,
|
||||||
|
max_papers: int = 150
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""对论文进行重排序"""
|
||||||
|
initial_count = len(papers) if papers else 0
|
||||||
|
stats = {'initial': initial_count}
|
||||||
|
|
||||||
|
if not papers or not query:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 更新论文的期刊指标
|
||||||
|
self._update_paper_metrics(papers)
|
||||||
|
|
||||||
|
# 构建增强查询
|
||||||
|
# enhanced_query = self._build_enhanced_query(query, search_criteria) if search_criteria else query
|
||||||
|
enhanced_query = query
|
||||||
|
# 首先过滤不满足年份要求的论文
|
||||||
|
if search_criteria and search_criteria.start_year and search_criteria.end_year:
|
||||||
|
before_year_filter = len(papers)
|
||||||
|
filtered_papers = []
|
||||||
|
start_year = int(search_criteria.start_year)
|
||||||
|
end_year = int(search_criteria.end_year)
|
||||||
|
|
||||||
|
for paper in papers:
|
||||||
|
paper_year = self._get_year_as_int(paper)
|
||||||
|
if paper_year == 0 or start_year <= paper_year <= end_year:
|
||||||
|
filtered_papers.append(paper)
|
||||||
|
|
||||||
|
papers = filtered_papers
|
||||||
|
stats['after_year_filter'] = len(papers)
|
||||||
|
|
||||||
|
if not papers: # 如果过滤后没有论文,直接返回空列表
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 新增:对少量论文的快速处理
|
||||||
|
SMALL_PAPER_THRESHOLD = 10 # 定义"少量"论文的阈值
|
||||||
|
if len(papers) <= SMALL_PAPER_THRESHOLD:
|
||||||
|
# 对于少量论文,直接根据查询类型进行简单排序
|
||||||
|
if search_criteria:
|
||||||
|
if search_criteria.query_type == "latest":
|
||||||
|
papers.sort(key=lambda x: getattr(x, 'year', 0) or 0, reverse=True)
|
||||||
|
elif search_criteria.query_type == "recommend":
|
||||||
|
papers.sort(key=lambda x: getattr(x, 'citations', 0) or 0, reverse=True)
|
||||||
|
elif search_criteria.query_type == "review":
|
||||||
|
papers.sort(key=lambda x:
|
||||||
|
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||||
|
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||||
|
for keyword in ['review', 'survey', 'overview'])
|
||||||
|
else 0,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
return papers[:top_k]
|
||||||
|
|
||||||
|
# 1. 优先处理最新的论文
|
||||||
|
if search_criteria and search_criteria.query_type == "latest":
|
||||||
|
papers = sorted(papers, key=lambda x: self._get_year_as_int(x), reverse=True)
|
||||||
|
|
||||||
|
# 2. 如果是综述类查询,优先处理可能的综述论文
|
||||||
|
if search_criteria and search_criteria.query_type == "review":
|
||||||
|
papers = sorted(papers, key=lambda x:
|
||||||
|
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||||
|
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||||
|
for keyword in ['review', 'survey', 'overview'])
|
||||||
|
else 0,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 如果论文数量超过限制,采用分层采样而不是完全随机
|
||||||
|
if len(papers) > max_papers:
|
||||||
|
before_max_limit = len(papers)
|
||||||
|
papers = self._select_papers_strategically(papers, search_criteria, max_papers)
|
||||||
|
stats['after_max_limit'] = len(papers)
|
||||||
|
|
||||||
|
try:
|
||||||
|
paper_texts = []
|
||||||
|
valid_papers = [] # 4. 跟踪有效论文
|
||||||
|
|
||||||
|
for paper in papers:
|
||||||
|
if paper is None:
|
||||||
|
continue
|
||||||
|
# 5. 预先过滤明显不相关的论文
|
||||||
|
if search_criteria and search_criteria.start_year:
|
||||||
|
if getattr(paper, 'year', 0) and self._get_year_as_int(paper.year) < search_criteria.start_year:
|
||||||
|
continue
|
||||||
|
|
||||||
|
doc = self._build_enhanced_document(paper, search_criteria)
|
||||||
|
paper_texts.append(doc)
|
||||||
|
valid_papers.append(paper) # 记录对应的论文
|
||||||
|
|
||||||
|
stats['after_valid_check'] = len(valid_papers)
|
||||||
|
|
||||||
|
if not paper_texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 使用LLM判断相关性
|
||||||
|
relevance_results = self.ranker.batch_check_relevance(
|
||||||
|
query=enhanced_query, # 使用增强的查询
|
||||||
|
paper_texts=paper_texts,
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 优化相关论文的选择策略
|
||||||
|
relevant_papers = []
|
||||||
|
for paper, is_relevant in zip(valid_papers, relevance_results):
|
||||||
|
if is_relevant:
|
||||||
|
relevant_papers.append(paper)
|
||||||
|
|
||||||
|
stats['after_llm_filter'] = len(relevant_papers)
|
||||||
|
|
||||||
|
# 打印统计信息
|
||||||
|
print(f"论文筛选统计: 初始数量={stats['initial']}, " +
|
||||||
|
f"年份过滤后={stats.get('after_year_filter', stats['initial'])}, " +
|
||||||
|
f"数量限制后={stats.get('after_max_limit', stats.get('after_year_filter', stats['initial']))}, " +
|
||||||
|
f"有效性检查后={stats['after_valid_check']}, " +
|
||||||
|
f"LLM筛选后={stats['after_llm_filter']}")
|
||||||
|
|
||||||
|
# 7. 改进回退策略
|
||||||
|
if len(relevant_papers) < min(5, len(papers)):
|
||||||
|
# 如果相关论文太少,返回按引用量排序的论文
|
||||||
|
return sorted(
|
||||||
|
papers[:top_k],
|
||||||
|
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 8. 对最终结果进行排序
|
||||||
|
if search_criteria:
|
||||||
|
if search_criteria.query_type == "latest":
|
||||||
|
# 最新论文优先,但同年份按IF排序
|
||||||
|
relevant_papers.sort(key=lambda x: (
|
||||||
|
self._get_year_as_int(x),
|
||||||
|
getattr(x, 'if_factor', 0) or 0
|
||||||
|
), reverse=True)
|
||||||
|
elif search_criteria.query_type == "recommend":
|
||||||
|
# IF指数优先,其次是引用量
|
||||||
|
relevant_papers.sort(key=lambda x: (
|
||||||
|
getattr(x, 'if_factor', 0) or 0,
|
||||||
|
getattr(x, 'citations', 0) or 0
|
||||||
|
), reverse=True)
|
||||||
|
else:
|
||||||
|
# 默认按IF指数排序
|
||||||
|
relevant_papers.sort(key=lambda x: getattr(x, 'if_factor', 0) or 0, reverse=True)
|
||||||
|
|
||||||
|
return relevant_papers[:top_k]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"论文排序时出错: {str(e)}")
|
||||||
|
# 9. 改进错误处理的回退策略
|
||||||
|
try:
|
||||||
|
return sorted(
|
||||||
|
papers[:top_k],
|
||||||
|
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
return papers[:top_k] if papers else []
|
||||||
|
|
||||||
|
def _build_enhanced_query(self, query: str, criteria: SearchCriteria) -> str:
|
||||||
|
"""构建增强的查询文本"""
|
||||||
|
components = []
|
||||||
|
|
||||||
|
# 强调这是用户的原始查询,是最重要的匹配依据
|
||||||
|
components.append(f"Original user query that must be primarily matched: {query}")
|
||||||
|
|
||||||
|
if criteria:
|
||||||
|
# 添加主题(如果与原始查询不同)
|
||||||
|
if criteria.main_topic and criteria.main_topic != query:
|
||||||
|
components.append(f"Additional context - The main topic is about: {criteria.main_topic}")
|
||||||
|
|
||||||
|
# 添加子主题
|
||||||
|
if criteria.sub_topics:
|
||||||
|
components.append(f"Secondary aspects to consider: {', '.join(criteria.sub_topics)}")
|
||||||
|
|
||||||
|
# 添加查询类型相关信息
|
||||||
|
if criteria.query_type == "review":
|
||||||
|
components.append("Paper type preference: Looking for comprehensive review papers, survey papers, or overview papers")
|
||||||
|
elif criteria.query_type == "latest":
|
||||||
|
components.append("Temporal preference: Focus on the most recent developments and latest papers")
|
||||||
|
elif criteria.query_type == "recommend":
|
||||||
|
components.append("Impact preference: Consider influential and fundamental papers")
|
||||||
|
|
||||||
|
# 直接连接所有组件,保持语序
|
||||||
|
enhanced_query = ' '.join(components)
|
||||||
|
|
||||||
|
# 限制长度但不打乱顺序
|
||||||
|
if len(enhanced_query) > 1000:
|
||||||
|
enhanced_query = enhanced_query[:997] + "..."
|
||||||
|
|
||||||
|
return enhanced_query
|
||||||
|
|
||||||
|
def _build_enhanced_document(self, paper: PaperMetadata, criteria: SearchCriteria) -> str:
|
||||||
|
"""构建增强的文档表示"""
|
||||||
|
components = []
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
title = getattr(paper, 'title', '')
|
||||||
|
authors = ', '.join(getattr(paper, 'authors', []))
|
||||||
|
abstract = getattr(paper, 'abstract', '')
|
||||||
|
year = getattr(paper, 'year', '')
|
||||||
|
venue = getattr(paper, 'venue', '')
|
||||||
|
|
||||||
|
components.extend([
|
||||||
|
f"Title: {title}",
|
||||||
|
f"Authors: {authors}",
|
||||||
|
f"Year: {year}",
|
||||||
|
f"Venue: {venue}",
|
||||||
|
f"Abstract: {abstract}"
|
||||||
|
])
|
||||||
|
|
||||||
|
# 根据查询类型添加额外信息
|
||||||
|
if criteria:
|
||||||
|
if criteria.query_type == "review":
|
||||||
|
# 对于综述类查询,强调论文的综述性质
|
||||||
|
title_lower = (title or '').lower()
|
||||||
|
abstract_lower = (abstract or '').lower()
|
||||||
|
if any(keyword in title_lower or keyword in abstract_lower
|
||||||
|
for keyword in ['review', 'survey', 'overview']):
|
||||||
|
components.append("This is a review/survey paper")
|
||||||
|
|
||||||
|
elif criteria.query_type == "latest":
|
||||||
|
# 对于最新论文查询,强调时间信息
|
||||||
|
if year and int(year) >= criteria.start_year:
|
||||||
|
components.append(f"This is a recent paper from {year}")
|
||||||
|
|
||||||
|
elif criteria.query_type == "recommend":
|
||||||
|
# 对于推荐类查询,添加主题相关性信息
|
||||||
|
if criteria.main_topic:
|
||||||
|
title_lower = (title or '').lower()
|
||||||
|
abstract_lower = (abstract or '').lower()
|
||||||
|
topic_relevance = any(topic.lower() in title_lower or topic.lower() in abstract_lower
|
||||||
|
for topic in [criteria.main_topic] + (criteria.sub_topics or []))
|
||||||
|
if topic_relevance:
|
||||||
|
components.append(f"This paper is directly related to {criteria.main_topic}")
|
||||||
|
|
||||||
|
return '\n'.join(components)
|
||||||
|
|
||||||
|
def _select_papers_strategically(
|
||||||
|
self,
|
||||||
|
papers: List[PaperMetadata],
|
||||||
|
search_criteria: SearchCriteria,
|
||||||
|
max_papers: int = 150
|
||||||
|
) -> List[PaperMetadata]:
|
||||||
|
"""战略性地选择论文子集,优先选择非Crossref来源的论文,
|
||||||
|
当ADS论文充足时排除arXiv论文"""
|
||||||
|
if len(papers) <= max_papers:
|
||||||
|
return papers
|
||||||
|
|
||||||
|
# 1. 首先按来源分组
|
||||||
|
papers_by_source = {
|
||||||
|
'crossref': [],
|
||||||
|
'adsabs': [],
|
||||||
|
'arxiv': [],
|
||||||
|
'others': [] # semantic, pubmed等其他来源
|
||||||
|
}
|
||||||
|
|
||||||
|
for paper in papers:
|
||||||
|
source = getattr(paper, 'source', '')
|
||||||
|
if source == 'crossref':
|
||||||
|
papers_by_source['crossref'].append(paper)
|
||||||
|
elif source == 'adsabs':
|
||||||
|
papers_by_source['adsabs'].append(paper)
|
||||||
|
elif source == 'arxiv':
|
||||||
|
papers_by_source['arxiv'].append(paper)
|
||||||
|
else:
|
||||||
|
papers_by_source['others'].append(paper)
|
||||||
|
|
||||||
|
# 2. 计算分数的通用函数
|
||||||
|
def calculate_paper_score(paper):
|
||||||
|
score = 0
|
||||||
|
title = (getattr(paper, 'title', '') or '').lower()
|
||||||
|
abstract = (getattr(paper, 'abstract', '') or '').lower()
|
||||||
|
year = self._get_year_as_int(paper)
|
||||||
|
citations = getattr(paper, 'citations', 0) or 0
|
||||||
|
|
||||||
|
# 安全地获取搜索条件
|
||||||
|
main_topic = (getattr(search_criteria, 'main_topic', '') or '').lower()
|
||||||
|
sub_topics = getattr(search_criteria, 'sub_topics', []) or []
|
||||||
|
query_type = getattr(search_criteria, 'query_type', '')
|
||||||
|
start_year = getattr(search_criteria, 'start_year', 0) or 0
|
||||||
|
|
||||||
|
# 主题相关性得分
|
||||||
|
if main_topic and main_topic in title:
|
||||||
|
score += 10
|
||||||
|
if main_topic and main_topic in abstract:
|
||||||
|
score += 5
|
||||||
|
|
||||||
|
# 子主题相关性得分
|
||||||
|
for sub_topic in sub_topics:
|
||||||
|
if sub_topic and sub_topic.lower() in title:
|
||||||
|
score += 5
|
||||||
|
if sub_topic and sub_topic.lower() in abstract:
|
||||||
|
score += 2.5
|
||||||
|
|
||||||
|
# 根据查询类型调整分数
|
||||||
|
if query_type == "review":
|
||||||
|
review_keywords = ['review', 'survey', 'overview']
|
||||||
|
if any(keyword in title for keyword in review_keywords):
|
||||||
|
score *= 1.5
|
||||||
|
if any(keyword in abstract for keyword in review_keywords):
|
||||||
|
score *= 1.2
|
||||||
|
elif query_type == "latest":
|
||||||
|
if year and start_year:
|
||||||
|
year_int = year if isinstance(year, int) else self._get_year_as_int(paper)
|
||||||
|
start_year_int = start_year if isinstance(start_year, int) else int(start_year)
|
||||||
|
if year_int >= start_year_int:
|
||||||
|
recency_bonus = min(5, (year_int - start_year_int))
|
||||||
|
score += recency_bonus * 2
|
||||||
|
elif query_type == "recommend":
|
||||||
|
citation_score = min(10, citations / 100)
|
||||||
|
score += citation_score
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
result = []
|
||||||
|
|
||||||
|
# 3. 处理ADS和arXiv论文
|
||||||
|
non_crossref_papers = papers_by_source['others'] # 首先添加其他来源的论文
|
||||||
|
|
||||||
|
# 添加ADS论文
|
||||||
|
if papers_by_source['adsabs']:
|
||||||
|
non_crossref_papers.extend(papers_by_source['adsabs'])
|
||||||
|
|
||||||
|
# 只有当ADS论文不足20篇时,才添加arXiv论文
|
||||||
|
if len(papers_by_source['adsabs']) <= 20:
|
||||||
|
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||||
|
elif not papers_by_source['adsabs'] and papers_by_source['arxiv']:
|
||||||
|
# 如果没有ADS论文但有arXiv论文,也使用arXiv论文
|
||||||
|
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||||
|
|
||||||
|
# 4. 对非Crossref论文评分和排序
|
||||||
|
scored_non_crossref = [(p, calculate_paper_score(p)) for p in non_crossref_papers]
|
||||||
|
scored_non_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 5. 先添加高分的非Crossref论文
|
||||||
|
non_crossref_limit = max_papers * 0.9 # 90%的配额给非Crossref论文
|
||||||
|
if len(scored_non_crossref) >= non_crossref_limit:
|
||||||
|
result.extend([p[0] for p in scored_non_crossref[:int(non_crossref_limit)]])
|
||||||
|
else:
|
||||||
|
result.extend([p[0] for p in scored_non_crossref])
|
||||||
|
|
||||||
|
# 6. 如果还有剩余空间,考虑添加Crossref论文
|
||||||
|
remaining_slots = max_papers - len(result)
|
||||||
|
if remaining_slots > 0 and papers_by_source['crossref']:
|
||||||
|
# 计算Crossref论文的最大数量(不超过总数的10%)
|
||||||
|
max_crossref = min(remaining_slots, max_papers * 0.1)
|
||||||
|
|
||||||
|
# 对Crossref论文评分和排序
|
||||||
|
scored_crossref = [(p, calculate_paper_score(p)) for p in papers_by_source['crossref']]
|
||||||
|
scored_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 添加最高分的Crossref论文
|
||||||
|
result.extend([p[0] for p in scored_crossref[:int(max_crossref)]])
|
||||||
|
|
||||||
|
# 7. 如果使用了Crossref论文后还有空位,继续使用非Crossref论文填充
|
||||||
|
if len(result) < max_papers and len(scored_non_crossref) > len(result):
|
||||||
|
remaining_non_crossref = [p[0] for p in scored_non_crossref[len(result):]]
|
||||||
|
result.extend(remaining_non_crossref[:max_papers - len(result)])
|
||||||
|
|
||||||
|
return result
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
# ADS query optimization prompt
|
||||||
|
ADSABS_QUERY_PROMPT = """Analyze and optimize the following query for NASA ADS search.
|
||||||
|
If the query is not related to astronomy, astrophysics, or physics, return <query>none</query>.
|
||||||
|
If the query contains non-English terms, translate them to English first.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized ADS search query.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on the core research topic for the search query.
|
||||||
|
|
||||||
|
Relevant research areas for ADS:
|
||||||
|
- Astronomy and astrophysics
|
||||||
|
- Physics (theoretical and experimental)
|
||||||
|
- Space science and exploration
|
||||||
|
- Planetary science
|
||||||
|
- Cosmology
|
||||||
|
- Astrobiology
|
||||||
|
- Related instrumentation and methods
|
||||||
|
|
||||||
|
Available search fields and filters:
|
||||||
|
1. Basic fields:
|
||||||
|
- title: Search in title (title:"term")
|
||||||
|
- abstract: Search in abstract (abstract:"term")
|
||||||
|
- author: Search for author names (author:"lastname, firstname")
|
||||||
|
- year: Filter by year (year:2020-2023)
|
||||||
|
- bibstem: Search by journal abbreviation (bibstem:ApJ)
|
||||||
|
|
||||||
|
2. Boolean operators:
|
||||||
|
- AND
|
||||||
|
- OR
|
||||||
|
- NOT
|
||||||
|
- (): Group terms
|
||||||
|
- "": Exact phrase match
|
||||||
|
|
||||||
|
3. Special filters:
|
||||||
|
- citations(identifier:paper): Papers citing a specific paper
|
||||||
|
- references(identifier:paper): References of a specific paper
|
||||||
|
- citation_count: Filter by citation count
|
||||||
|
- database: Filter by database (database:astronomy)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Black holes in galaxy centers after 2020"
|
||||||
|
<query>title:"black hole" AND abstract:"galaxy center" AND year:2020-</query>
|
||||||
|
|
||||||
|
2. Query: "Papers by Neil deGrasse Tyson about exoplanets"
|
||||||
|
<query>author:"Tyson, Neil deGrasse" AND title:exoplanet</query>
|
||||||
|
|
||||||
|
3. Query: "Most cited papers about dark matter in ApJ"
|
||||||
|
<query>title:"dark matter" AND bibstem:ApJ AND citation_count:[100 TO *]</query>
|
||||||
|
|
||||||
|
4. Query: "Latest research on diabetes treatment"
|
||||||
|
<query>none</query>
|
||||||
|
|
||||||
|
5. Query: "Machine learning for galaxy classification"
|
||||||
|
<query>title:("machine learning" OR "deep learning") AND (title:galaxy OR abstract:galaxy) AND abstract:classification</query>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<query>Provide the optimized ADS search query using appropriate fields and operators, or "none" if not relevant</query>"""
|
||||||
|
|
||||||
|
# System prompt
|
||||||
|
ADSABS_QUERY_SYSTEM_PROMPT = """You are an expert at crafting NASA ADS search queries.
|
||||||
|
Your task is to:
|
||||||
|
1. First determine if the query is relevant to astronomy, astrophysics, or physics research
|
||||||
|
2. If relevant, optimize the natural language query for the ADS API
|
||||||
|
3. If not relevant, return "none" to indicate the query should be handled by other databases
|
||||||
|
|
||||||
|
Focus on creating precise queries that will return relevant astronomical and physics literature.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
Consider using field-specific search terms and appropriate filters to improve search accuracy.
|
||||||
|
|
||||||
|
Remember: ADS is specifically for astronomy, astrophysics, and physics research.
|
||||||
|
Medical, biological, or general research queries should return "none"."""
|
||||||
@@ -0,0 +1,341 @@
|
|||||||
|
# Basic type analysis prompt
|
||||||
|
ARXIV_TYPE_PROMPT = """Analyze the research query and determine if arXiv search is needed and its type.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task 1: Determine if this query requires arXiv search
|
||||||
|
- arXiv is suitable for:
|
||||||
|
* Computer science and AI/ML
|
||||||
|
* Physics and mathematics
|
||||||
|
* Quantitative biology and finance
|
||||||
|
* Electrical engineering
|
||||||
|
* Recent preprints in these fields
|
||||||
|
- arXiv is NOT needed for:
|
||||||
|
* Medical research (unless ML/AI applications)
|
||||||
|
* Social sciences
|
||||||
|
* Business studies
|
||||||
|
* Humanities
|
||||||
|
* Industry reports
|
||||||
|
|
||||||
|
Task 2: If arXiv search is needed, determine the most appropriate search type
|
||||||
|
Available types:
|
||||||
|
1. basic: Keyword-based search across all fields
|
||||||
|
- For specific technical queries
|
||||||
|
- When looking for particular methods or applications
|
||||||
|
2. category: Category-based search within specific fields
|
||||||
|
- For broad topic exploration
|
||||||
|
- When surveying a research area
|
||||||
|
3. none: arXiv search not needed for this query
|
||||||
|
- When topic is outside arXiv's scope
|
||||||
|
- For non-technical or clinical research
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "BERT transformer architecture"
|
||||||
|
<search_type>basic</search_type>
|
||||||
|
|
||||||
|
2. Query: "latest developments in machine learning"
|
||||||
|
<search_type>category</search_type>
|
||||||
|
|
||||||
|
3. Query: "COVID-19 clinical trials"
|
||||||
|
<search_type>none</search_type>
|
||||||
|
|
||||||
|
4. Query: "psychological effects of social media"
|
||||||
|
<search_type>none</search_type>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<search_type>Choose either 'basic', 'category', or 'none'</search_type>"""
|
||||||
|
|
||||||
|
# Query optimization prompt
|
||||||
|
ARXIV_QUERY_PROMPT = """Optimize the following query for arXiv search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized arXiv search query using boolean operators and field tags.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on the core research topic for the search query.
|
||||||
|
|
||||||
|
Available field tags:
|
||||||
|
- ti: Search in title
|
||||||
|
- abs: Search in abstract
|
||||||
|
- au: Search for author
|
||||||
|
- all: Search in all fields (default)
|
||||||
|
|
||||||
|
Boolean operators:
|
||||||
|
- AND: Both terms must appear
|
||||||
|
- OR: Either term can appear
|
||||||
|
- NOT: Exclude terms
|
||||||
|
- (): Group terms
|
||||||
|
- "": Exact phrase match
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Natural query: "Recent papers about transformer models by Vaswani"
|
||||||
|
<query>ti:"transformer model" AND au:Vaswani AND year:[2017 TO 2024]</query>
|
||||||
|
|
||||||
|
2. Natural query: "Deep learning for computer vision, excluding surveys"
|
||||||
|
<query>ti:(deep learning AND "computer vision") NOT (ti:survey OR ti:review)</query>
|
||||||
|
|
||||||
|
3. Natural query: "Attention mechanism in language models"
|
||||||
|
<query>ti:(attention OR "attention mechanism") AND abs:"language model"</query>
|
||||||
|
|
||||||
|
4. Natural query: "GANs or generative adversarial networks for image generation"
|
||||||
|
<query>(ti:GAN OR ti:"generative adversarial network") AND abs:"image generation"</query>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<query>Provide the optimized search query using appropriate operators and tags</query>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Use quotes for exact phrases
|
||||||
|
- Combine multiple conditions with boolean operators
|
||||||
|
- Consider both title and abstract for important concepts
|
||||||
|
- Include author names when relevant
|
||||||
|
- Use parentheses for complex logical groupings"""
|
||||||
|
|
||||||
|
# Sort parameters prompt
|
||||||
|
ARXIV_SORT_PROMPT = """Determine optimal sorting parameters for the research query.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Select the most appropriate sorting parameters to help users find the most relevant papers.
|
||||||
|
|
||||||
|
Available sorting options:
|
||||||
|
|
||||||
|
1. Sort by:
|
||||||
|
- relevance: Best match to query terms (default)
|
||||||
|
- lastUpdatedDate: Most recently updated papers
|
||||||
|
- submittedDate: Most recently submitted papers
|
||||||
|
|
||||||
|
2. Sort order:
|
||||||
|
- descending: Newest/Most relevant first (default)
|
||||||
|
- ascending: Oldest/Least relevant first
|
||||||
|
|
||||||
|
3. Result limit:
|
||||||
|
- Minimum: 10 papers
|
||||||
|
- Maximum: 50 papers
|
||||||
|
- Recommended: 20-30 papers for most queries
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Latest developments in transformer models"
|
||||||
|
<sort_by>submittedDate</sort_by>
|
||||||
|
<sort_order>descending</sort_order>
|
||||||
|
<limit>30</limit>
|
||||||
|
|
||||||
|
2. Query: "Foundational papers about neural networks"
|
||||||
|
<sort_by>relevance</sort_by>
|
||||||
|
<sort_order>descending</sort_order>
|
||||||
|
<limit>20</limit>
|
||||||
|
|
||||||
|
3. Query: "Evolution of deep learning since 2012"
|
||||||
|
<sort_by>submittedDate</sort_by>
|
||||||
|
<sort_order>ascending</sort_order>
|
||||||
|
<limit>50</limit>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<sort_by>Choose: relevance, lastUpdatedDate, or submittedDate</sort_by>
|
||||||
|
<sort_order>Choose: ascending or descending</sort_order>
|
||||||
|
<limit>Suggest number between 10-50</limit>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Choose relevance for specific technical queries
|
||||||
|
- Use lastUpdatedDate for tracking paper revisions
|
||||||
|
- Use submittedDate for following recent developments
|
||||||
|
- Consider query context when setting the limit"""
|
||||||
|
|
||||||
|
# System prompts for each task
|
||||||
|
ARXIV_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||||
|
Your task is to determine whether the query is better suited for keyword search or category-based search.
|
||||||
|
Consider the query's specificity, scope, and intended search area when making your decision.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
ARXIV_QUERY_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||||
|
Your task is to optimize natural language queries using boolean operators and field tags.
|
||||||
|
Focus on creating precise, targeted queries that will return the most relevant results.
|
||||||
|
Always generate English search terms regardless of the input language."""
|
||||||
|
|
||||||
|
ARXIV_CATEGORIES_SYSTEM_PROMPT = """You are an expert at arXiv category classification.
|
||||||
|
Your task is to select the most relevant categories for the given research query.
|
||||||
|
Consider both primary and related interdisciplinary categories, while maintaining focus on the main research area.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
ARXIV_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||||
|
Your task is to determine the best sorting parameters based on the query context.
|
||||||
|
Consider the user's likely intent and temporal aspects of the research topic.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
# 添加新的搜索提示词
|
||||||
|
ARXIV_SEARCH_PROMPT = """Analyze and optimize the research query for arXiv search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized arXiv search query.
|
||||||
|
|
||||||
|
Available search options:
|
||||||
|
1. Basic search with field tags:
|
||||||
|
- ti: Search in title
|
||||||
|
- abs: Search in abstract
|
||||||
|
- au: Search for author
|
||||||
|
Example: "ti:transformer AND abs:attention"
|
||||||
|
|
||||||
|
2. Category-based search:
|
||||||
|
- Use specific arXiv categories
|
||||||
|
Example: "cat:cs.AI AND neural networks"
|
||||||
|
|
||||||
|
3. Date range:
|
||||||
|
- Specify date range using submittedDate
|
||||||
|
Example: "deep learning AND submittedDate:[20200101 TO 20231231]"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Recent papers about transformer models by Vaswani"
|
||||||
|
<search_criteria>
|
||||||
|
<query>ti:"transformer model" AND au:Vaswani AND submittedDate:[20170101 TO 99991231]</query>
|
||||||
|
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||||
|
<sort_by>submittedDate</sort_by>
|
||||||
|
<sort_order>descending</sort_order>
|
||||||
|
<limit>30</limit>
|
||||||
|
</search_criteria>
|
||||||
|
|
||||||
|
2. Query: "Latest developments in computer vision"
|
||||||
|
<search_criteria>
|
||||||
|
<query>cat:cs.CV AND submittedDate:[20220101 TO 99991231]</query>
|
||||||
|
<categories>cs.CV, cs.AI, cs.LG</categories>
|
||||||
|
<sort_by>submittedDate</sort_by>
|
||||||
|
<sort_order>descending</sort_order>
|
||||||
|
<limit>25</limit>
|
||||||
|
</search_criteria>
|
||||||
|
|
||||||
|
Please analyze the query and respond with XML tags containing search criteria."""
|
||||||
|
|
||||||
|
ARXIV_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||||
|
Your task is to analyze research queries and transform them into optimized arXiv search criteria.
|
||||||
|
Consider query intent, relevant categories, and temporal aspects when creating the search parameters.
|
||||||
|
Always generate English search terms and respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
# Categories selection prompt
|
||||||
|
ARXIV_CATEGORIES_PROMPT = """Select the most relevant arXiv categories for the research query.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Choose 2-4 most relevant categories that best match the research topic.
|
||||||
|
|
||||||
|
Available Categories:
|
||||||
|
|
||||||
|
Computer Science (cs):
|
||||||
|
- cs.AI: Artificial Intelligence (neural networks, machine learning, NLP)
|
||||||
|
- cs.CL: Computation and Language (NLP, machine translation)
|
||||||
|
- cs.CV: Computer Vision and Pattern Recognition
|
||||||
|
- cs.LG: Machine Learning (deep learning, reinforcement learning)
|
||||||
|
- cs.NE: Neural and Evolutionary Computing
|
||||||
|
- cs.RO: Robotics
|
||||||
|
- cs.IR: Information Retrieval
|
||||||
|
- cs.SE: Software Engineering
|
||||||
|
- cs.DB: Databases
|
||||||
|
- cs.DC: Distributed Computing
|
||||||
|
- cs.CY: Computers and Society
|
||||||
|
- cs.HC: Human-Computer Interaction
|
||||||
|
|
||||||
|
Mathematics (math):
|
||||||
|
- math.OC: Optimization and Control
|
||||||
|
- math.PR: Probability
|
||||||
|
- math.ST: Statistics
|
||||||
|
- math.NA: Numerical Analysis
|
||||||
|
- math.DS: Dynamical Systems
|
||||||
|
|
||||||
|
Statistics (stat):
|
||||||
|
- stat.ML: Machine Learning
|
||||||
|
- stat.ME: Methodology
|
||||||
|
- stat.TH: Theory
|
||||||
|
- stat.AP: Applications
|
||||||
|
|
||||||
|
Physics (physics):
|
||||||
|
- physics.comp-ph: Computational Physics
|
||||||
|
- physics.data-an: Data Analysis
|
||||||
|
- physics.soc-ph: Physics and Society
|
||||||
|
|
||||||
|
Electrical Engineering (eess):
|
||||||
|
- eess.SP: Signal Processing
|
||||||
|
- eess.AS: Audio and Speech Processing
|
||||||
|
- eess.IV: Image and Video Processing
|
||||||
|
- eess.SY: Systems and Control
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Deep learning for computer vision"
|
||||||
|
<categories>cs.CV, cs.LG, stat.ML</categories>
|
||||||
|
|
||||||
|
2. Query: "Natural language processing with transformers"
|
||||||
|
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||||
|
|
||||||
|
3. Query: "Reinforcement learning for robotics"
|
||||||
|
<categories>cs.RO, cs.AI, cs.LG</categories>
|
||||||
|
|
||||||
|
4. Query: "Statistical methods in machine learning"
|
||||||
|
<categories>stat.ML, cs.LG, math.ST</categories>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<categories>List 2-4 most relevant categories, comma-separated</categories>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Choose primary categories first, then add related ones
|
||||||
|
- Limit to 2-4 most relevant categories
|
||||||
|
- Order by relevance (most relevant first)
|
||||||
|
- Use comma and space between categories (e.g., "cs.AI, cs.LG")"""
|
||||||
|
|
||||||
|
# 在文件末尾添加新的 prompt
|
||||||
|
ARXIV_LATEST_PROMPT = """Determine if the query is requesting latest papers from arXiv.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Analyze if the query is specifically asking for recent/latest papers from arXiv.
|
||||||
|
|
||||||
|
IMPORTANT RULE:
|
||||||
|
- The query MUST explicitly mention "arXiv" or "arxiv" to be considered a latest arXiv papers request
|
||||||
|
- Queries only asking for recent/latest papers WITHOUT mentioning arXiv should return false
|
||||||
|
|
||||||
|
Indicators for latest papers request:
|
||||||
|
1. MUST HAVE keywords about arXiv:
|
||||||
|
- "arxiv"
|
||||||
|
- "arXiv"
|
||||||
|
AND
|
||||||
|
|
||||||
|
2. Keywords about recency:
|
||||||
|
- "latest"
|
||||||
|
- "recent"
|
||||||
|
- "new"
|
||||||
|
- "newest"
|
||||||
|
- "just published"
|
||||||
|
- "this week/month"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Latest papers request (Valid):
|
||||||
|
Query: "Show me the latest AI papers on arXiv"
|
||||||
|
<is_latest_request>true</is_latest_request>
|
||||||
|
|
||||||
|
2. Latest papers request (Valid):
|
||||||
|
Query: "What are the recent papers about transformers on arxiv"
|
||||||
|
<is_latest_request>true</is_latest_request>
|
||||||
|
|
||||||
|
3. Not a latest papers request (Invalid - no mention of arXiv):
|
||||||
|
Query: "Show me the latest papers about BERT"
|
||||||
|
<is_latest_request>false</is_latest_request>
|
||||||
|
|
||||||
|
4. Not a latest papers request (Invalid - no recency):
|
||||||
|
Query: "Find papers on arxiv about transformers"
|
||||||
|
<is_latest_request>false</is_latest_request>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<is_latest_request>true/false</is_latest_request>
|
||||||
|
|
||||||
|
Note: The response should be true ONLY if both conditions are met:
|
||||||
|
1. Query explicitly mentions arXiv/arxiv
|
||||||
|
2. Query asks for recent/latest papers"""
|
||||||
|
|
||||||
|
ARXIV_LATEST_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||||
|
Your task is to determine if the query is specifically requesting latest/recent papers from arXiv.
|
||||||
|
Remember: The query MUST explicitly mention arXiv to be considered valid, even if it asks for recent papers.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
# Crossref query optimization prompt
|
||||||
|
CROSSREF_QUERY_PROMPT = """Analyze and optimize the query for Crossref search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized Crossref search query.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on the core research topic for the search query.
|
||||||
|
|
||||||
|
Available search fields and filters:
|
||||||
|
1. Basic fields:
|
||||||
|
- title: Search in title
|
||||||
|
- abstract: Search in abstract
|
||||||
|
- author: Search for author names
|
||||||
|
- container-title: Search in journal/conference name
|
||||||
|
- publisher: Search by publisher name
|
||||||
|
- type: Filter by work type (journal-article, book-chapter, etc.)
|
||||||
|
- year: Filter by publication year
|
||||||
|
|
||||||
|
2. Boolean operators:
|
||||||
|
- AND: Both terms must appear
|
||||||
|
- OR: Either term can appear
|
||||||
|
- NOT: Exclude terms
|
||||||
|
- "": Exact phrase match
|
||||||
|
|
||||||
|
3. Special filters:
|
||||||
|
- is-referenced-by-count: Filter by citation count
|
||||||
|
- from-pub-date: Filter by publication date
|
||||||
|
- has-abstract: Filter papers with abstracts
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Machine learning in healthcare after 2020"
|
||||||
|
<query>title:"machine learning" AND title:healthcare AND from-pub-date:2020</query>
|
||||||
|
|
||||||
|
2. Query: "Papers by Geoffrey Hinton about deep learning"
|
||||||
|
<query>author:"Hinton, Geoffrey" AND (title:"deep learning" OR abstract:"deep learning")</query>
|
||||||
|
|
||||||
|
3. Query: "Most cited papers about transformers in Nature"
|
||||||
|
<query>title:transformer AND container-title:Nature AND is-referenced-by-count:[100 TO *]</query>
|
||||||
|
|
||||||
|
4. Query: "Recent BERT applications in medical domain"
|
||||||
|
<query>title:BERT AND abstract:medical AND from-pub-date:2020 AND type:journal-article</query>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<query>Provide the optimized Crossref search query using appropriate fields and operators</query>"""
|
||||||
|
|
||||||
|
# System prompt
|
||||||
|
CROSSREF_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Crossref search queries.
|
||||||
|
Your task is to optimize natural language queries for Crossref's API.
|
||||||
|
Focus on creating precise queries that will return relevant results.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
Consider using field-specific search terms and appropriate filters to improve search accuracy."""
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
# 新建文件,添加论文识别提示
|
||||||
|
PAPER_IDENTIFY_PROMPT = """Analyze the query to identify paper details.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Extract paper identification information from the query.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on identifying paper details.
|
||||||
|
|
||||||
|
Possible paper identifiers:
|
||||||
|
1. arXiv ID (e.g., 2103.14030, arXiv:2103.14030)
|
||||||
|
2. DOI (e.g., 10.1234/xxx.xxx)
|
||||||
|
3. Paper title (e.g., "Attention is All You Need")
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query with arXiv ID:
|
||||||
|
Query: "Analyze paper 2103.14030"
|
||||||
|
<paper_info>
|
||||||
|
<paper_source>arxiv</paper_source>
|
||||||
|
<paper_id>2103.14030</paper_id>
|
||||||
|
<paper_title></paper_title>
|
||||||
|
</paper_info>
|
||||||
|
|
||||||
|
2. Query with DOI:
|
||||||
|
Query: "Review the paper with DOI 10.1234/xxx.xxx"
|
||||||
|
<paper_info>
|
||||||
|
<paper_source>doi</paper_source>
|
||||||
|
<paper_id>10.1234/xxx.xxx</paper_id>
|
||||||
|
<paper_title></paper_title>
|
||||||
|
</paper_info>
|
||||||
|
|
||||||
|
3. Query with paper title:
|
||||||
|
Query: "Analyze 'Attention is All You Need' paper"
|
||||||
|
<paper_info>
|
||||||
|
<paper_source>title</paper_source>
|
||||||
|
<paper_id></paper_id>
|
||||||
|
<paper_title>Attention is All You Need</paper_title>
|
||||||
|
</paper_info>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags containing paper information."""
|
||||||
|
|
||||||
|
PAPER_IDENTIFY_SYSTEM_PROMPT = """You are an expert at identifying academic paper references.
|
||||||
|
Your task is to extract paper identification information from queries.
|
||||||
|
Look for arXiv IDs, DOIs, and paper titles."""
|
||||||
@@ -0,0 +1,108 @@
|
|||||||
|
# PubMed search type prompt
|
||||||
|
PUBMED_TYPE_PROMPT = """Analyze the research query and determine the appropriate PubMed search type.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Available search types:
|
||||||
|
1. basic: General keyword search for medical/biomedical topics
|
||||||
|
2. author: Search by author name
|
||||||
|
3. journal: Search within specific journals
|
||||||
|
4. none: Query not related to medical/biomedical research
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "COVID-19 treatment outcomes"
|
||||||
|
<search_type>basic</search_type>
|
||||||
|
|
||||||
|
2. Query: "Papers by Anthony Fauci"
|
||||||
|
<search_type>author</search_type>
|
||||||
|
|
||||||
|
3. Query: "Recent papers in Nature about CRISPR"
|
||||||
|
<search_type>journal</search_type>
|
||||||
|
|
||||||
|
4. Query: "Deep learning for computer vision"
|
||||||
|
<search_type>none</search_type>
|
||||||
|
|
||||||
|
5. Query: "Transformer architecture for NLP"
|
||||||
|
<search_type>none</search_type>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<search_type>Choose: basic, author, journal, or none</search_type>"""
|
||||||
|
|
||||||
|
# PubMed query optimization prompt
|
||||||
|
PUBMED_QUERY_PROMPT = """Optimize the following query for PubMed search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized PubMed search query.
|
||||||
|
Requirements:
|
||||||
|
- Always generate English search terms regardless of input language
|
||||||
|
- Translate any non-English terms to English before creating the query
|
||||||
|
- Never include non-English characters in the final query
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on the core medical/biomedical topic for the search query.
|
||||||
|
|
||||||
|
Available field tags:
|
||||||
|
- [Title] - Search in title
|
||||||
|
- [Author] - Search for author
|
||||||
|
- [Journal] - Search in journal name
|
||||||
|
- [MeSH Terms] - Search using MeSH terms
|
||||||
|
|
||||||
|
Boolean operators:
|
||||||
|
- AND
|
||||||
|
- OR
|
||||||
|
- NOT
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "COVID-19 treatment in elderly patients"
|
||||||
|
<query>COVID-19[Title] AND treatment[Title/Abstract] AND elderly[Title/Abstract]</query>
|
||||||
|
|
||||||
|
2. Query: "Cancer immunotherapy review articles"
|
||||||
|
<query>cancer immunotherapy[Title/Abstract] AND review[Publication Type]</query>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<query>Provide the optimized PubMed search query</query>"""
|
||||||
|
|
||||||
|
# PubMed sort parameters prompt
|
||||||
|
PUBMED_SORT_PROMPT = """Determine optimal sorting parameters for PubMed results.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Select the most appropriate sorting method and result limit.
|
||||||
|
|
||||||
|
Available sort options:
|
||||||
|
- relevance: Best match to query
|
||||||
|
- date: Most recent first
|
||||||
|
- journal: Sort by journal name
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Latest developments in gene therapy"
|
||||||
|
<sort_by>date</sort_by>
|
||||||
|
<limit>30</limit>
|
||||||
|
|
||||||
|
2. Query: "Classic papers about DNA structure"
|
||||||
|
<sort_by>relevance</sort_by>
|
||||||
|
<limit>20</limit>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<sort_by>Choose: relevance, date, or journal</sort_by>
|
||||||
|
<limit>Suggest number between 10-50</limit>"""
|
||||||
|
|
||||||
|
# System prompts
|
||||||
|
PUBMED_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing medical and scientific queries.
|
||||||
|
Your task is to determine the most appropriate PubMed search type.
|
||||||
|
Consider the query's focus and intended search scope.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
PUBMED_QUERY_SYSTEM_PROMPT = """You are an expert at crafting PubMed search queries.
|
||||||
|
Your task is to optimize natural language queries using PubMed's search syntax.
|
||||||
|
Focus on creating precise, targeted queries that will return relevant medical literature.
|
||||||
|
Always generate English search terms regardless of the input language."""
|
||||||
|
|
||||||
|
PUBMED_SORT_SYSTEM_PROMPT = """You are an expert at optimizing PubMed search results.
|
||||||
|
Your task is to determine the best sorting parameters based on the query context.
|
||||||
|
Consider the balance between relevance and recency.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
@@ -0,0 +1,276 @@
|
|||||||
|
# Search type prompt
|
||||||
|
SEMANTIC_TYPE_PROMPT = """Determine the most appropriate search type for Semantic Scholar.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Analyze the research query and select the most appropriate search type for Semantic Scholar API.
|
||||||
|
|
||||||
|
Available search types:
|
||||||
|
|
||||||
|
1. paper: General paper search
|
||||||
|
- Use for broad topic searches
|
||||||
|
- Looking for specific papers
|
||||||
|
- Keyword-based searches
|
||||||
|
Example: "transformer models in NLP"
|
||||||
|
|
||||||
|
2. author: Author-based search
|
||||||
|
- Finding works by specific researchers
|
||||||
|
- Author profile analysis
|
||||||
|
Example: "papers by Yoshua Bengio"
|
||||||
|
|
||||||
|
3. paper_details: Specific paper lookup
|
||||||
|
- Getting details about a known paper
|
||||||
|
- Finding specific versions or citations
|
||||||
|
Example: "Attention is All You Need paper details"
|
||||||
|
|
||||||
|
4. citations: Citation analysis
|
||||||
|
- Finding papers that cite a specific work
|
||||||
|
- Impact analysis
|
||||||
|
Example: "papers citing BERT"
|
||||||
|
|
||||||
|
5. references: Reference analysis
|
||||||
|
- Finding papers cited by a specific work
|
||||||
|
- Background research
|
||||||
|
Example: "references in GPT-3 paper"
|
||||||
|
|
||||||
|
6. recommendations: Paper recommendations
|
||||||
|
- Finding similar papers
|
||||||
|
- Research direction exploration
|
||||||
|
Example: "papers similar to Transformer"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Latest papers about deep learning"
|
||||||
|
<search_type>paper</search_type>
|
||||||
|
|
||||||
|
2. Query: "Works by Geoffrey Hinton since 2020"
|
||||||
|
<search_type>author</search_type>
|
||||||
|
|
||||||
|
3. Query: "Papers citing the original Transformer paper"
|
||||||
|
<search_type>citations</search_type>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<search_type>Choose the most appropriate search type from the list above</search_type>"""
|
||||||
|
|
||||||
|
# Query optimization prompt
|
||||||
|
SEMANTIC_QUERY_PROMPT = """Optimize the following query for Semantic Scholar search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into an optimized search query for maximum relevance.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements. Focus only on the core research topic for the search query.
|
||||||
|
|
||||||
|
Query optimization guidelines:
|
||||||
|
|
||||||
|
1. Use quotes for exact phrases
|
||||||
|
- Ensures exact matching
|
||||||
|
- Reduces irrelevant results
|
||||||
|
Example: "\"attention mechanism\"" vs attention mechanism
|
||||||
|
|
||||||
|
2. Include key technical terms
|
||||||
|
- Use specific technical terminology
|
||||||
|
- Include common variations
|
||||||
|
Example: "transformer architecture" neural networks
|
||||||
|
|
||||||
|
3. Author names (if relevant)
|
||||||
|
- Include full names when known
|
||||||
|
- Consider common name variations
|
||||||
|
Example: "Geoffrey Hinton" OR "G. E. Hinton"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Natural query: "Recent advances in transformer models"
|
||||||
|
<query>"transformer model" "neural architecture" deep learning</query>
|
||||||
|
|
||||||
|
2. Natural query: "BERT applications in text classification"
|
||||||
|
<query>"BERT" "text classification" "language model" application</query>
|
||||||
|
|
||||||
|
3. Natural query: "Deep learning for computer vision by Kaiming He"
|
||||||
|
<query>"deep learning" "computer vision" author:"Kaiming He"</query>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<query>Provide the optimized search query</query>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Balance between specificity and coverage
|
||||||
|
- Include important technical terms
|
||||||
|
- Use quotes for key phrases
|
||||||
|
- Consider synonyms and related terms"""
|
||||||
|
|
||||||
|
# Fields selection prompt
|
||||||
|
SEMANTIC_FIELDS_PROMPT = """Select relevant fields to retrieve from Semantic Scholar.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Determine which paper fields should be retrieved based on the research needs.
|
||||||
|
|
||||||
|
Available fields:
|
||||||
|
|
||||||
|
Core fields:
|
||||||
|
- title: Paper title (always included)
|
||||||
|
- abstract: Full paper abstract
|
||||||
|
- authors: Author information
|
||||||
|
- year: Publication year
|
||||||
|
- venue: Publication venue
|
||||||
|
|
||||||
|
Citation fields:
|
||||||
|
- citations: Papers citing this work
|
||||||
|
- references: Papers cited by this work
|
||||||
|
|
||||||
|
Additional fields:
|
||||||
|
- embedding: Paper vector embedding
|
||||||
|
- tldr: AI-generated summary
|
||||||
|
- venue: Publication venue/journal
|
||||||
|
- url: Paper URL
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Latest developments in NLP"
|
||||||
|
<fields>title, abstract, authors, year, venue, citations</fields>
|
||||||
|
|
||||||
|
2. Query: "Most influential papers in deep learning"
|
||||||
|
<fields>title, abstract, authors, year, citations, references</fields>
|
||||||
|
|
||||||
|
3. Query: "Survey of transformer architectures"
|
||||||
|
<fields>title, abstract, authors, year, tldr, references</fields>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<fields>List relevant fields, comma-separated</fields>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Choose fields based on the query's purpose
|
||||||
|
- Include citation data for impact analysis
|
||||||
|
- Consider tldr for quick paper screening
|
||||||
|
- Balance completeness with API efficiency"""
|
||||||
|
|
||||||
|
# Sort parameters prompt
|
||||||
|
SEMANTIC_SORT_PROMPT = """Determine optimal sorting parameters for the query.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Select the most appropriate sorting method and result limit for the search.
|
||||||
|
Always generate English search terms regardless of the input language.
|
||||||
|
|
||||||
|
Sorting options:
|
||||||
|
|
||||||
|
1. relevance (default)
|
||||||
|
- Best match to query terms
|
||||||
|
- Recommended for specific technical searches
|
||||||
|
Example: "specific algorithm implementations"
|
||||||
|
|
||||||
|
2. citations
|
||||||
|
- Sort by citation count
|
||||||
|
- Best for finding influential papers
|
||||||
|
Example: "most important papers in deep learning"
|
||||||
|
|
||||||
|
3. year
|
||||||
|
- Sort by publication date
|
||||||
|
- Best for following recent developments
|
||||||
|
Example: "latest advances in NLP"
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Recent breakthroughs in AI"
|
||||||
|
<sort_by>year</sort_by>
|
||||||
|
<limit>30</limit>
|
||||||
|
|
||||||
|
2. Query: "Most influential papers about GANs"
|
||||||
|
<sort_by>citations</sort_by>
|
||||||
|
<limit>20</limit>
|
||||||
|
|
||||||
|
3. Query: "Specific papers about BERT fine-tuning"
|
||||||
|
<sort_by>relevance</sort_by>
|
||||||
|
<limit>25</limit>
|
||||||
|
|
||||||
|
Please analyze the query and respond ONLY with XML tags:
|
||||||
|
<sort_by>Choose: relevance, citations, or year</sort_by>
|
||||||
|
<limit>Suggest number between 10-50</limit>
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Consider the query's temporal aspects
|
||||||
|
- Balance between comprehensive coverage and information overload
|
||||||
|
- Use citation sorting for impact analysis
|
||||||
|
- Use year sorting for tracking developments"""
|
||||||
|
|
||||||
|
# System prompts for each task
|
||||||
|
SEMANTIC_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||||
|
Your task is to determine the most appropriate type of search on Semantic Scholar.
|
||||||
|
Consider the query's intent, scope, and specific research needs.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
SEMANTIC_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||||
|
Your task is to optimize natural language queries for maximum relevance.
|
||||||
|
Focus on creating precise queries that leverage the platform's search capabilities.
|
||||||
|
Always generate English search terms regardless of the input language."""
|
||||||
|
|
||||||
|
SEMANTIC_FIELDS_SYSTEM_PROMPT = """You are an expert at Semantic Scholar data fields.
|
||||||
|
Your task is to select the most relevant fields based on the research context.
|
||||||
|
Consider both essential and supplementary information needs.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
SEMANTIC_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||||
|
Your task is to determine the best sorting parameters based on the query context.
|
||||||
|
Consider the balance between relevance, impact, and recency.
|
||||||
|
Always respond in English regardless of the input language."""
|
||||||
|
|
||||||
|
# 添加新的综合搜索提示词
|
||||||
|
SEMANTIC_SEARCH_PROMPT = """Analyze and optimize the research query for Semantic Scholar search.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Task: Transform the natural language query into optimized search criteria for Semantic Scholar.
|
||||||
|
|
||||||
|
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||||
|
or output format requirements when generating the search terms. These requirements
|
||||||
|
should be considered only for post-search filtering, not as part of the core query.
|
||||||
|
|
||||||
|
Available search options:
|
||||||
|
1. Paper search:
|
||||||
|
- Title and abstract search
|
||||||
|
- Author search
|
||||||
|
- Field-specific search
|
||||||
|
Example: "transformer architecture neural networks"
|
||||||
|
|
||||||
|
2. Field tags:
|
||||||
|
- title: Search in title
|
||||||
|
- abstract: Search in abstract
|
||||||
|
- authors: Search by author names
|
||||||
|
- venue: Search by publication venue
|
||||||
|
Example: "title:transformer authors:\"Vaswani\""
|
||||||
|
|
||||||
|
3. Advanced options:
|
||||||
|
- Year range filtering
|
||||||
|
- Citation count filtering
|
||||||
|
- Venue filtering
|
||||||
|
Example: "deep learning year>2020 venue:\"NeurIPS\""
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
1. Query: "Recent transformer papers by Vaswani with high impact"
|
||||||
|
<search_criteria>
|
||||||
|
<query>title:transformer authors:"Vaswani" year>2017</query>
|
||||||
|
<search_type>paper</search_type>
|
||||||
|
<fields>title,abstract,authors,year,citations,venue</fields>
|
||||||
|
<sort_by>citations</sort_by>
|
||||||
|
<limit>30</limit>
|
||||||
|
</search_criteria>
|
||||||
|
|
||||||
|
2. Query: "Most cited papers about BERT in top conferences"
|
||||||
|
<search_criteria>
|
||||||
|
<query>title:BERT venue:"ACL|EMNLP|NAACL"</query>
|
||||||
|
<search_type>paper</search_type>
|
||||||
|
<fields>title,abstract,authors,year,citations,venue,references</fields>
|
||||||
|
<sort_by>citations</sort_by>
|
||||||
|
<limit>25</limit>
|
||||||
|
</search_criteria>
|
||||||
|
|
||||||
|
Please analyze the query and respond with XML tags containing complete search criteria."""
|
||||||
|
|
||||||
|
SEMANTIC_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||||
|
Your task is to analyze research queries and transform them into optimized search criteria.
|
||||||
|
Consider query intent, field relevance, and citation impact when creating the search parameters.
|
||||||
|
Focus on producing precise and comprehensive search criteria that will yield the most relevant results.
|
||||||
|
Always generate English search terms and respond in English regardless of the input language."""
|
||||||
@@ -0,0 +1,493 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from textwrap import dedent
|
||||||
|
from datetime import datetime
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchCriteria:
|
||||||
|
"""搜索条件"""
|
||||||
|
query_type: str # 查询类型: review/recommend/qa/paper
|
||||||
|
main_topic: str # 主题
|
||||||
|
sub_topics: List[str] # 子主题列表
|
||||||
|
start_year: int # 起始年份
|
||||||
|
end_year: int # 结束年份
|
||||||
|
arxiv_params: Dict # arXiv搜索参数
|
||||||
|
semantic_params: Dict # Semantic Scholar搜索参数
|
||||||
|
pubmed_params: Dict # 新增: PubMed搜索参数
|
||||||
|
crossref_params: Dict # 添加 Crossref 参数
|
||||||
|
adsabs_params: Dict # 添加 ADS 参数
|
||||||
|
paper_id: str = "" # 论文ID (arxiv ID 或 DOI)
|
||||||
|
paper_title: str = "" # 论文标题
|
||||||
|
paper_source: str = "" # 论文来源 (arxiv/doi/title)
|
||||||
|
original_query: str = "" # 新增: 原始查询字符串
|
||||||
|
|
||||||
|
|
||||||
|
class QueryAnalyzer:
|
||||||
|
"""查询分析器"""
|
||||||
|
|
||||||
|
# 响应索引常量
|
||||||
|
BASIC_QUERY_INDEX = 0
|
||||||
|
PAPER_IDENTIFY_INDEX = 1
|
||||||
|
ARXIV_QUERY_INDEX = 2
|
||||||
|
ARXIV_CATEGORIES_INDEX = 3
|
||||||
|
ARXIV_LATEST_INDEX = 4
|
||||||
|
ARXIV_SORT_INDEX = 5
|
||||||
|
SEMANTIC_QUERY_INDEX = 6
|
||||||
|
SEMANTIC_FIELDS_INDEX = 7
|
||||||
|
PUBMED_TYPE_INDEX = 8
|
||||||
|
PUBMED_QUERY_INDEX = 9
|
||||||
|
CROSSREF_QUERY_INDEX = 10
|
||||||
|
ADSABS_QUERY_INDEX = 11
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.current_year = datetime.now().year
|
||||||
|
self.valid_types = {
|
||||||
|
"review": ["review", "literature review", "survey"],
|
||||||
|
"recommend": ["recommend", "recommendation", "suggest", "similar"],
|
||||||
|
"qa": ["qa", "question", "answer", "explain", "what", "how", "why"],
|
||||||
|
"paper": ["paper", "analyze", "analysis"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
|
||||||
|
"""分析查询意图"""
|
||||||
|
from crazy_functions.crazy_utils import \
|
||||||
|
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||||
|
from crazy_functions.review_fns.prompts.arxiv_prompts import (
|
||||||
|
ARXIV_QUERY_PROMPT, ARXIV_CATEGORIES_PROMPT, ARXIV_LATEST_PROMPT,
|
||||||
|
ARXIV_SORT_PROMPT, ARXIV_QUERY_SYSTEM_PROMPT, ARXIV_CATEGORIES_SYSTEM_PROMPT, ARXIV_SORT_SYSTEM_PROMPT,
|
||||||
|
ARXIV_LATEST_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
from crazy_functions.review_fns.prompts.semantic_prompts import (
|
||||||
|
SEMANTIC_QUERY_PROMPT, SEMANTIC_FIELDS_PROMPT,
|
||||||
|
SEMANTIC_QUERY_SYSTEM_PROMPT, SEMANTIC_FIELDS_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
from .prompts.paper_prompts import PAPER_IDENTIFY_PROMPT, PAPER_IDENTIFY_SYSTEM_PROMPT
|
||||||
|
from .prompts.pubmed_prompts import (
|
||||||
|
PUBMED_TYPE_PROMPT, PUBMED_QUERY_PROMPT, PUBMED_SORT_PROMPT,
|
||||||
|
PUBMED_TYPE_SYSTEM_PROMPT, PUBMED_QUERY_SYSTEM_PROMPT, PUBMED_SORT_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
from .prompts.crossref_prompts import (
|
||||||
|
CROSSREF_QUERY_PROMPT,
|
||||||
|
CROSSREF_QUERY_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
from .prompts.adsabs_prompts import ADSABS_QUERY_PROMPT, ADSABS_QUERY_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
# 1. 基本查询分析
|
||||||
|
type_prompt = dedent(f"""Please analyze this academic query and respond STRICTLY in the following XML format:
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
1. Your response must use XML tags exactly as shown below
|
||||||
|
2. Do not add any text outside the tags
|
||||||
|
3. Choose query type from: review/recommend/qa/paper
|
||||||
|
- review: for literature review or survey requests
|
||||||
|
- recommend: for paper recommendation requests
|
||||||
|
- qa: for general questions about research topics
|
||||||
|
- paper: ONLY for queries about a SPECIFIC paper (with paper ID, DOI, or exact title)
|
||||||
|
4. Identify main topic and subtopics
|
||||||
|
5. Specify year range if mentioned
|
||||||
|
|
||||||
|
Required format:
|
||||||
|
<query_type>ANSWER HERE</query_type>
|
||||||
|
<main_topic>ANSWER HERE</main_topic>
|
||||||
|
<sub_topics>SUBTOPIC1, SUBTOPIC2, ...</sub_topics>
|
||||||
|
<year_range>START_YEAR-END_YEAR</year_range>
|
||||||
|
|
||||||
|
Example responses:
|
||||||
|
|
||||||
|
1. Literature Review Request:
|
||||||
|
Query: "Review recent developments in transformer models for NLP from 2020 to 2023"
|
||||||
|
<query_type>review</query_type>
|
||||||
|
<main_topic>transformer models in natural language processing</main_topic>
|
||||||
|
<sub_topics>architecture improvements, pre-training methods, fine-tuning techniques</sub_topics>
|
||||||
|
<year_range>2020-2023</year_range>
|
||||||
|
|
||||||
|
2. Paper Recommendation Request:
|
||||||
|
Query: "Suggest papers about reinforcement learning in robotics since 2018"
|
||||||
|
<query_type>recommend</query_type>
|
||||||
|
<main_topic>reinforcement learning in robotics</main_topic>
|
||||||
|
<sub_topics>robot control, policy learning, sim-to-real transfer</sub_topics>
|
||||||
|
<year_range>2018-2023</year_range>"""
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建提示数组
|
||||||
|
prompts = [
|
||||||
|
type_prompt,
|
||||||
|
PAPER_IDENTIFY_PROMPT.format(query=query),
|
||||||
|
ARXIV_QUERY_PROMPT.format(query=query),
|
||||||
|
ARXIV_CATEGORIES_PROMPT.format(query=query),
|
||||||
|
ARXIV_LATEST_PROMPT.format(query=query),
|
||||||
|
ARXIV_SORT_PROMPT.format(query=query),
|
||||||
|
SEMANTIC_QUERY_PROMPT.format(query=query),
|
||||||
|
SEMANTIC_FIELDS_PROMPT.format(query=query),
|
||||||
|
PUBMED_TYPE_PROMPT.format(query=query),
|
||||||
|
PUBMED_QUERY_PROMPT.format(query=query),
|
||||||
|
CROSSREF_QUERY_PROMPT.format(query=query),
|
||||||
|
ADSABS_QUERY_PROMPT.format(query=query)
|
||||||
|
]
|
||||||
|
|
||||||
|
show_messages = [
|
||||||
|
"Analyzing query type...",
|
||||||
|
"Identifying paper details...",
|
||||||
|
"Determining arXiv search type...",
|
||||||
|
"Selecting arXiv categories...",
|
||||||
|
"Checking if latest papers requested...",
|
||||||
|
"Determining arXiv sort parameters...",
|
||||||
|
"Optimizing Semantic Scholar query...",
|
||||||
|
"Selecting Semantic Scholar fields...",
|
||||||
|
"Determining PubMed search type...",
|
||||||
|
"Optimizing PubMed query...",
|
||||||
|
"Optimizing Crossref query...",
|
||||||
|
"Optimizing ADS query..."
|
||||||
|
]
|
||||||
|
|
||||||
|
sys_prompts = [
|
||||||
|
"You are an expert at analyzing academic queries.",
|
||||||
|
PAPER_IDENTIFY_SYSTEM_PROMPT,
|
||||||
|
ARXIV_QUERY_SYSTEM_PROMPT,
|
||||||
|
ARXIV_CATEGORIES_SYSTEM_PROMPT,
|
||||||
|
ARXIV_LATEST_SYSTEM_PROMPT,
|
||||||
|
ARXIV_SORT_SYSTEM_PROMPT,
|
||||||
|
SEMANTIC_QUERY_SYSTEM_PROMPT,
|
||||||
|
SEMANTIC_FIELDS_SYSTEM_PROMPT,
|
||||||
|
PUBMED_TYPE_SYSTEM_PROMPT,
|
||||||
|
PUBMED_QUERY_SYSTEM_PROMPT,
|
||||||
|
CROSSREF_QUERY_SYSTEM_PROMPT,
|
||||||
|
ADSABS_QUERY_SYSTEM_PROMPT
|
||||||
|
]
|
||||||
|
new_llm_kwargs = llm_kwargs.copy()
|
||||||
|
# new_llm_kwargs['llm_model'] = 'deepseek-chat' # deepseek-ai/DeepSeek-V2.5
|
||||||
|
|
||||||
|
# 使用同步方式调用LLM
|
||||||
|
responses = yield from request_gpt(
|
||||||
|
inputs_array=prompts,
|
||||||
|
inputs_show_user_array=show_messages,
|
||||||
|
llm_kwargs=new_llm_kwargs,
|
||||||
|
chatbot=chatbot,
|
||||||
|
history_array=[[] for _ in prompts],
|
||||||
|
sys_prompt_array=sys_prompts,
|
||||||
|
max_workers=5
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从收集的响应中提取我们需要的内容
|
||||||
|
extracted_responses = []
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
if (i * 2 + 1) < len(responses):
|
||||||
|
response = responses[i * 2 + 1]
|
||||||
|
if response is None:
|
||||||
|
raise Exception(f"Response {i} is None")
|
||||||
|
if not isinstance(response, str):
|
||||||
|
try:
|
||||||
|
response = str(response)
|
||||||
|
except:
|
||||||
|
raise Exception(f"Cannot convert response {i} to string")
|
||||||
|
extracted_responses.append(response)
|
||||||
|
else:
|
||||||
|
raise Exception(f"未收到第 {i + 1} 个响应")
|
||||||
|
|
||||||
|
# 解析基本信息
|
||||||
|
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
|
||||||
|
if not query_type:
|
||||||
|
print(
|
||||||
|
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
|
||||||
|
raise Exception("无法提取query_type标签内容")
|
||||||
|
query_type = query_type.lower()
|
||||||
|
|
||||||
|
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
|
||||||
|
if not main_topic:
|
||||||
|
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
|
||||||
|
main_topic = query
|
||||||
|
|
||||||
|
query_type = self._normalize_query_type(query_type, query)
|
||||||
|
|
||||||
|
# 解析arXiv参数
|
||||||
|
try:
|
||||||
|
arxiv_params = {
|
||||||
|
"query": self._extract_tag(extracted_responses[self.ARXIV_QUERY_INDEX], "query"),
|
||||||
|
"categories": [cat.strip() for cat in
|
||||||
|
self._extract_tag(extracted_responses[self.ARXIV_CATEGORIES_INDEX],
|
||||||
|
"categories").split(",")],
|
||||||
|
"sort_by": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_by"),
|
||||||
|
"sort_order": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_order"),
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
|
||||||
|
# 安全地解析limit值
|
||||||
|
limit_str = self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "limit")
|
||||||
|
if limit_str and limit_str.isdigit():
|
||||||
|
arxiv_params["limit"] = int(limit_str)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error parsing arXiv parameters: {str(e)}")
|
||||||
|
arxiv_params = {
|
||||||
|
"query": "",
|
||||||
|
"categories": [],
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"sort_order": "descending",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析Semantic Scholar参数
|
||||||
|
try:
|
||||||
|
semantic_params = {
|
||||||
|
"query": self._extract_tag(extracted_responses[self.SEMANTIC_QUERY_INDEX], "query"),
|
||||||
|
"fields": [field.strip() for field in
|
||||||
|
self._extract_tag(extracted_responses[self.SEMANTIC_FIELDS_INDEX], "fields").split(",")],
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error parsing Semantic Scholar parameters: {str(e)}")
|
||||||
|
semantic_params = {
|
||||||
|
"query": query,
|
||||||
|
"fields": ["title", "abstract", "authors", "year"],
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析PubMed参数
|
||||||
|
try:
|
||||||
|
# 首先检查是否需要PubMed搜索
|
||||||
|
pubmed_search_type = self._extract_tag(extracted_responses[self.PUBMED_TYPE_INDEX], "search_type")
|
||||||
|
|
||||||
|
if pubmed_search_type == "none":
|
||||||
|
# 不需要PubMed搜索,使用空参数
|
||||||
|
pubmed_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 需要PubMed搜索,解析完整参数
|
||||||
|
pubmed_params = {
|
||||||
|
"search_type": pubmed_search_type,
|
||||||
|
"query": self._extract_tag(extracted_responses[self.PUBMED_QUERY_INDEX], "query"),
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 200
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error parsing PubMed parameters: {str(e)}")
|
||||||
|
pubmed_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析Crossref参数
|
||||||
|
try:
|
||||||
|
crossref_query = self._extract_tag(extracted_responses[self.CROSSREF_QUERY_INDEX], "query")
|
||||||
|
|
||||||
|
if not crossref_query:
|
||||||
|
crossref_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
crossref_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": crossref_query,
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error parsing Crossref parameters: {str(e)}")
|
||||||
|
crossref_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析ADS参数
|
||||||
|
try:
|
||||||
|
adsabs_query = self._extract_tag(extracted_responses[self.ADSABS_QUERY_INDEX], "query")
|
||||||
|
|
||||||
|
if not adsabs_query:
|
||||||
|
adsabs_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
adsabs_params = {
|
||||||
|
"search_type": "basic",
|
||||||
|
"query": adsabs_query,
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 20
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Error parsing ADS parameters: {str(e)}")
|
||||||
|
adsabs_params = {
|
||||||
|
"search_type": "none",
|
||||||
|
"query": "",
|
||||||
|
"sort_by": "relevance",
|
||||||
|
"limit": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"Debug - Extracted information:")
|
||||||
|
print(f"Query type: {query_type}")
|
||||||
|
print(f"Main topic: {main_topic}")
|
||||||
|
print(f"arXiv params: {arxiv_params}")
|
||||||
|
print(f"Semantic params: {semantic_params}")
|
||||||
|
print(f"PubMed params: {pubmed_params}")
|
||||||
|
print(f"Crossref params: {crossref_params}")
|
||||||
|
print(f"ADS params: {adsabs_params}")
|
||||||
|
|
||||||
|
# 提取子主题
|
||||||
|
sub_topics = []
|
||||||
|
if "sub_topics" in query.lower():
|
||||||
|
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
|
||||||
|
if sub_topics_text:
|
||||||
|
sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
|
||||||
|
|
||||||
|
# 提取年份范围
|
||||||
|
start_year = self.current_year - 5 # 默认最近5年
|
||||||
|
end_year = self.current_year
|
||||||
|
year_range = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "year_range")
|
||||||
|
if year_range:
|
||||||
|
try:
|
||||||
|
years = year_range.split("-")
|
||||||
|
if len(years) == 2:
|
||||||
|
start_year = int(years[0].strip())
|
||||||
|
end_year = int(years[1].strip())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 提取 latest request 判断
|
||||||
|
is_latest_request = self._extract_tag(extracted_responses[self.ARXIV_LATEST_INDEX],
|
||||||
|
"is_latest_request").lower() == "true"
|
||||||
|
|
||||||
|
# 如果是最新论文请求,将查询类型改为 "latest"
|
||||||
|
if is_latest_request:
|
||||||
|
query_type = "latest"
|
||||||
|
|
||||||
|
# 提取论文标识信息
|
||||||
|
paper_source = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_source")
|
||||||
|
paper_id = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_id")
|
||||||
|
paper_title = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_title")
|
||||||
|
if start_year > end_year:
|
||||||
|
start_year, end_year = end_year, start_year
|
||||||
|
# 更新返回的 SearchCriteria
|
||||||
|
return SearchCriteria(
|
||||||
|
query_type=query_type,
|
||||||
|
main_topic=main_topic,
|
||||||
|
sub_topics=sub_topics,
|
||||||
|
start_year=start_year,
|
||||||
|
end_year=end_year,
|
||||||
|
arxiv_params=arxiv_params,
|
||||||
|
semantic_params=semantic_params,
|
||||||
|
pubmed_params=pubmed_params,
|
||||||
|
crossref_params=crossref_params,
|
||||||
|
paper_id=paper_id,
|
||||||
|
paper_title=paper_title,
|
||||||
|
paper_source=paper_source,
|
||||||
|
original_query=query,
|
||||||
|
adsabs_params=adsabs_params
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Failed to analyze query: {str(e)}")
|
||||||
|
|
||||||
|
def _normalize_query_type(self, query_type: str, query: str) -> str:
|
||||||
|
"""规范化查询类型"""
|
||||||
|
if query_type in ["review", "recommend", "qa", "paper"]:
|
||||||
|
return query_type
|
||||||
|
|
||||||
|
query_lower = query.lower()
|
||||||
|
for type_name, keywords in self.valid_types.items():
|
||||||
|
for keyword in keywords:
|
||||||
|
if keyword in query_lower:
|
||||||
|
return type_name
|
||||||
|
|
||||||
|
query_type_lower = query_type.lower()
|
||||||
|
for type_name, keywords in self.valid_types.items():
|
||||||
|
for keyword in keywords:
|
||||||
|
if keyword in query_type_lower:
|
||||||
|
return type_name
|
||||||
|
|
||||||
|
return "qa" # 默认返回qa类型
|
||||||
|
|
||||||
|
def _extract_tag(self, text: str, tag: str) -> str:
|
||||||
|
"""提取标记内容"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 1. 标准XML格式(处理多行和特殊字符)
|
||||||
|
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||||
|
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
content = match.group(1).strip()
|
||||||
|
if content:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# 2. 处理特定标签的复杂内容
|
||||||
|
if tag == "categories":
|
||||||
|
# 处理arXiv类别
|
||||||
|
patterns = [
|
||||||
|
# 标准格式:<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||||
|
r"<categories>\s*((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)\s*</categories>",
|
||||||
|
# 简单列表格式:cs.CL, cs.AI, cs.LG
|
||||||
|
r"(?:^|\s)((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)(?:\s|$)",
|
||||||
|
# 单个类别格式:cs.AI
|
||||||
|
r"(?:^|\s)((?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+)(?:\s|$)"
|
||||||
|
]
|
||||||
|
|
||||||
|
elif tag == "query":
|
||||||
|
# 处理搜索查询
|
||||||
|
patterns = [
|
||||||
|
# 完整的查询格式:<query>complex query</query>
|
||||||
|
r"<query>\s*((?:(?:ti|abs|au|cat):[^\n]*?|(?:AND|OR|NOT|\(|\)|\d{4}|year:\d{4}|[\"'][^\"']*[\"']|\s+))+)\s*</query>",
|
||||||
|
# 简单的关键词列表:keyword1, keyword2
|
||||||
|
r"(?:^|\s)((?:\"[^\"]*\"|'[^']*'|[^\s,]+)(?:\s*,\s*(?:\"[^\"]*\"|'[^']*'|[^\s,]+))*)",
|
||||||
|
# 字段搜索格式:field:value
|
||||||
|
r"((?:ti|abs|au|cat):\s*(?:\"[^\"]*\"|'[^']*'|[^\s]+))"
|
||||||
|
]
|
||||||
|
|
||||||
|
elif tag == "fields":
|
||||||
|
# 处理字段列表
|
||||||
|
patterns = [
|
||||||
|
# 标准格式:<fields>field1, field2</fields>
|
||||||
|
r"<fields>\s*([\w\s,]+)\s*</fields>",
|
||||||
|
# 简单列表格式:field1, field2
|
||||||
|
r"(?:^|\s)([\w]+(?:\s*,\s*[\w]+)*)",
|
||||||
|
]
|
||||||
|
|
||||||
|
elif tag == "sort_by":
|
||||||
|
# 处理排序字段
|
||||||
|
patterns = [
|
||||||
|
# 标准格式:<sort_by>value</sort_by>
|
||||||
|
r"<sort_by>\s*(relevance|date|citations|submittedDate|year)\s*</sort_by>",
|
||||||
|
# 简单值格式:relevance
|
||||||
|
r"(?:^|\s)(relevance|date|citations|submittedDate|year)(?:\s|$)"
|
||||||
|
]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 通用模式
|
||||||
|
patterns = [
|
||||||
|
f"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
|
||||||
|
f"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
|
||||||
|
f"[{tag}]([\s\S]*?)[/{tag}]", # 方括号格式
|
||||||
|
f"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
|
||||||
|
f"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. 尝试所有模式
|
||||||
|
for pattern in patterns:
|
||||||
|
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||||
|
if match:
|
||||||
|
content = match.group(1).strip()
|
||||||
|
if content: # 确保提取的内容不为空
|
||||||
|
return content
|
||||||
|
|
||||||
|
# 4. 如果所有模式都失败,返回空字符串
|
||||||
|
return ""
|
||||||
|
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
from .query_analyzer import QueryAnalyzer, SearchCriteria
|
||||||
|
from .data_sources.arxiv_source import ArxivSource
|
||||||
|
from .data_sources.semantic_source import SemanticScholarSource
|
||||||
|
from .handlers.review_handler import 文献综述功能
|
||||||
|
from .handlers.recommend_handler import 论文推荐功能
|
||||||
|
from .handlers.qa_handler import 学术问答功能
|
||||||
|
from .handlers.paper_handler import 单篇论文分析功能
|
||||||
|
|
||||||
|
class QueryProcessor:
|
||||||
|
"""查询处理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.analyzer = QueryAnalyzer()
|
||||||
|
self.arxiv = ArxivSource()
|
||||||
|
self.semantic = SemanticScholarSource()
|
||||||
|
|
||||||
|
# 初始化各种处理器
|
||||||
|
self.handlers = {
|
||||||
|
"review": 文献综述功能(self.arxiv, self.semantic),
|
||||||
|
"recommend": 论文推荐功能(self.arxiv, self.semantic),
|
||||||
|
"qa": 学术问答功能(self.arxiv, self.semantic),
|
||||||
|
"paper": 单篇论文分析功能(self.arxiv, self.semantic)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def process_query(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
chatbot: List[List[str]],
|
||||||
|
history: List[List[str]],
|
||||||
|
system_prompt: str,
|
||||||
|
llm_kwargs: Dict[str, Any],
|
||||||
|
plugin_kwargs: Dict[str, Any],
|
||||||
|
) -> List[List[str]]:
|
||||||
|
"""处理用户查询"""
|
||||||
|
|
||||||
|
# 设置默认的插件参数
|
||||||
|
default_plugin_kwargs = {
|
||||||
|
'max_papers': 20, # 最大论文数量
|
||||||
|
'min_year': 2015, # 最早年份
|
||||||
|
'search_multiplier': 3, # 检索倍数
|
||||||
|
}
|
||||||
|
# 更新插件参数
|
||||||
|
plugin_kwargs.update({k: v for k, v in default_plugin_kwargs.items() if k not in plugin_kwargs})
|
||||||
|
|
||||||
|
# 1. 分析查询意图
|
||||||
|
criteria = self.analyzer.analyze_query(query, chatbot, llm_kwargs)
|
||||||
|
|
||||||
|
# 2. 根据查询类型选择处理器
|
||||||
|
handler = self.handlers.get(criteria.query_type)
|
||||||
|
if not handler:
|
||||||
|
handler = self.handlers["qa"] # 默认使用QA处理器
|
||||||
|
|
||||||
|
# 3. 处理查询
|
||||||
|
response = await handler.handle(
|
||||||
|
criteria,
|
||||||
|
chatbot,
|
||||||
|
history,
|
||||||
|
system_prompt,
|
||||||
|
llm_kwargs,
|
||||||
|
plugin_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
109
request_llms/embed_models/bge_llm.py
普通文件
109
request_llms/embed_models/bge_llm.py
普通文件
@@ -0,0 +1,109 @@
|
|||||||
|
import re
|
||||||
|
import requests
|
||||||
|
from loguru import logger
|
||||||
|
from typing import List, Dict
|
||||||
|
from urllib3.util import Retry
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from textwrap import dedent
|
||||||
|
from request_llms.bridge_all import predict_no_ui_long_connection
|
||||||
|
|
||||||
|
class BGELLMRanker:
|
||||||
|
"""使用LLM进行论文相关性判断的类"""
|
||||||
|
def __init__(self, llm_kwargs):
|
||||||
|
self.llm_kwargs = llm_kwargs
|
||||||
|
|
||||||
|
def is_paper_relevant(self, query: str, paper_text: str) -> bool:
|
||||||
|
"""判断论文是否与查询相关"""
|
||||||
|
prompt = dedent(f"""
|
||||||
|
Evaluate if this academic paper contains information that directly addresses the user's query.
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
Paper Content:
|
||||||
|
{paper_text}
|
||||||
|
|
||||||
|
Evaluation Criteria:
|
||||||
|
1. The paper must contain core information that directly answers the query
|
||||||
|
2. The paper's main research focus must be highly relevant to the query
|
||||||
|
3. Papers that only mention query-related content in abstract should be excluded
|
||||||
|
4. Papers with superficial or general discussions should be excluded
|
||||||
|
5. For queries about "recent" or "latest" advances, paper should be from last 3 years
|
||||||
|
|
||||||
|
Instructions:
|
||||||
|
- Carefully evaluate against ALL criteria above
|
||||||
|
- Return true ONLY if paper meets ALL criteria
|
||||||
|
- If any criteria is not met or unclear, return false
|
||||||
|
- Be strict but not overly restrictive
|
||||||
|
|
||||||
|
Output Rules:
|
||||||
|
- Must ONLY respond with <decision>true</decision> or <decision>false</decision>
|
||||||
|
- true = paper contains relevant information to answer the query
|
||||||
|
- false = paper does not contain sufficient relevant information
|
||||||
|
|
||||||
|
Do not include any explanation or additional text."""
|
||||||
|
)
|
||||||
|
response = predict_no_ui_long_connection(
|
||||||
|
inputs=prompt,
|
||||||
|
history=[],
|
||||||
|
llm_kwargs=self.llm_kwargs,
|
||||||
|
sys_prompt="You are an expert at determining paper relevance to queries. Respond only with <decision>true</decision> or <decision>false</decision>."
|
||||||
|
)
|
||||||
|
# 提取decision标签中的内容
|
||||||
|
match = re.search(r'<decision>(.*?)</decision>', response, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
decision = match.group(1).lower()
|
||||||
|
return decision == "true"
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def batch_check_relevance(self, query: str, paper_texts: List[str], show_progress: bool = True) -> List[bool]:
|
||||||
|
"""批量检查论文相关性
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询
|
||||||
|
paper_texts: 论文文本列表
|
||||||
|
show_progress: 是否显示进度条
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[bool]: 相关性判断结果列表
|
||||||
|
"""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
results = [False] * len(paper_texts)
|
||||||
|
|
||||||
|
# 减少并发线程数以避免连接池耗尽
|
||||||
|
max_workers = min(20, len(paper_texts)) # 限制最大线程数
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
future_to_idx = {
|
||||||
|
executor.submit(self.is_paper_relevant, query, text): i
|
||||||
|
for i, text in enumerate(paper_texts)
|
||||||
|
}
|
||||||
|
iterator = as_completed(future_to_idx)
|
||||||
|
if show_progress:
|
||||||
|
iterator = tqdm(iterator, total=len(paper_texts), desc="判断论文相关性")
|
||||||
|
for future in iterator:
|
||||||
|
idx = future_to_idx[future]
|
||||||
|
try:
|
||||||
|
results[idx] = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"处理论文 {idx} 时出错: {str(e)}")
|
||||||
|
results[idx] = False
|
||||||
|
return results
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 测试代码
|
||||||
|
ranker = BGELLMRanker()
|
||||||
|
|
||||||
|
query = "Recent advances in transformer models"
|
||||||
|
paper_text = """
|
||||||
|
Title: Attention Is All You Need
|
||||||
|
Abstract: The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely...
|
||||||
|
"""
|
||||||
|
|
||||||
|
is_relevant = ranker.is_paper_relevant(query, paper_text)
|
||||||
|
print(f"Paper relevant: {is_relevant}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -8,13 +8,10 @@ API_URL_REDIRECT, AZURE_ENDPOINT, AZURE_ENGINE = get_conf("API_URL_REDIRECT", "A
|
|||||||
openai_endpoint = "https://api.openai.com/v1/chat/completions"
|
openai_endpoint = "https://api.openai.com/v1/chat/completions"
|
||||||
if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
|
if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
|
||||||
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
|
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
|
||||||
|
|
||||||
|
|
||||||
if openai_endpoint in API_URL_REDIRECT: openai_endpoint = API_URL_REDIRECT[openai_endpoint]
|
if openai_endpoint in API_URL_REDIRECT: openai_endpoint = API_URL_REDIRECT[openai_endpoint]
|
||||||
|
|
||||||
openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings")
|
openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings")
|
||||||
|
|
||||||
from .openai_embed import OpenAiEmbeddingModel
|
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
|
||||||
|
|
||||||
embed_model_info = {
|
embed_model_info = {
|
||||||
# text-embedding-3-small Increased performance over 2nd generation ada embedding model | 1,536
|
# text-embedding-3-small Increased performance over 2nd generation ada embedding model | 1,536
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ mdtex2html
|
|||||||
dashscope
|
dashscope
|
||||||
pyautogen
|
pyautogen
|
||||||
colorama
|
colorama
|
||||||
|
docx2pdf
|
||||||
Markdown
|
Markdown
|
||||||
pygments
|
pygments
|
||||||
edge-tts>=7.0.0
|
edge-tts>=7.0.0
|
||||||
|
|||||||
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
对项目中的各个插件进行测试。运行方法:直接运行 python tests/test_plugins.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import init_test
|
||||||
|
import os, sys
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from test_utils import plugin_test
|
||||||
|
plugin_test(plugin='crazy_functions.Academic_Conversation->学术对话', main_input="搜索最新使用GAIA Benchmark的论文")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.Internet_GPT->连接网络回答问题', main_input="谁是应急食品?")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.函数动态生成->函数动态生成', main_input='交换图像的蓝色通道和红色通道', advanced_arg={"file_path_arg": "./build/ants.jpg"})
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2307.07522")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='build/pdf/t1.pdf')
|
||||||
|
|
||||||
|
# plugin_test(
|
||||||
|
# plugin="crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF",
|
||||||
|
# main_input="G:/SEAFILE_LOCAL/50503047/我的资料库/学位/paperlatex/aaai/Fu_8368_with_appendix",
|
||||||
|
# )
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='修改api-key为sk-jhoejriotherjep')
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.批量翻译PDF文档_NOUGAT->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='调用插件,对C:/Users/fuqingxu/Desktop/旧文件/gpt/chatgpt_academic/crazy_functions/latex_fns中的python文件进行解析')
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.命令行助手->命令行助手', main_input='查看当前的docker容器列表')
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个Python项目', main_input="crazy_functions/test_project/python/dqn")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个C项目', main_input="crazy_functions/test_project/cpp/cppipc")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.Latex_Project_Polish->Latex英文润色', main_input="crazy_functions/test_project/latex/attention")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown中译英', main_input="README.md")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.谷歌检索小助手->谷歌检索小助手', main_input="https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=auto+reinforcement+learning&btnG=")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.总结word文档->总结word文档', main_input="crazy_functions/test_project/pdf_and_word")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.下载arxiv论文翻译摘要->下载arxiv论文并翻译摘要', main_input="1812.10695")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.解析JupyterNotebook->解析ipynb文件', main_input="crazy_functions/test_samples")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.数学动画生成manim->动画生成', main_input="A ball split into 2, and then split into 4, and finally split into 8.")
|
||||||
|
|
||||||
|
# for lang in ["English", "French", "Japanese", "Korean", "Russian", "Italian", "German", "Portuguese", "Arabic"]:
|
||||||
|
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown翻译指定语言', main_input="README.md", advanced_arg={"advanced_arg": lang})
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.知识库文件注入->知识库文件注入', main_input="./")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="What is the installation method?")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="远程云服务器部署?")
|
||||||
|
|
||||||
|
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2210.03629")
|
||||||
|
|
||||||
在新工单中引用
屏蔽一个用户