stage academic conversation

This commit is contained in:
binary-husky
2025-06-22 18:31:41 +08:00
Parent 8c21432291
Commit 73f573092b
45 files changed, with 9992 insertions and 17 deletions

View file

@@ -43,7 +43,7 @@ AVAIL_LLM_MODELS = ["qwen-max", "o1-mini", "o1-mini-2024-09-12", "o1", "o1-2024-
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo", "gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
"gemini-1.5-pro", "chatglm3", "chatglm4", "gemini-1.5-pro", "chatglm3", "chatglm4",
"deepseek-chat", "deepseek-coder", "deepseek-reasoner", "deepseek-chat", "deepseek-coder", "deepseek-reasoner",
"volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226", "volcengine-deepseek-r1-250120", "volcengine-deepseek-v3-241226",
"dashscope-deepseek-r1", "dashscope-deepseek-v3", "dashscope-deepseek-r1", "dashscope-deepseek-v3",
"dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b", "dashscope-qwen3-14b", "dashscope-qwen3-235b-a22b", "dashscope-qwen3-32b",
@@ -94,19 +94,19 @@ AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gsta
FONT = "Theme-Default-Font" FONT = "Theme-Default-Font"
AVAIL_FONTS = [ AVAIL_FONTS = [
"默认值(Theme-Default-Font)", "默认值(Theme-Default-Font)",
"宋体(SimSun)", "宋体(SimSun)",
"黑体(SimHei)", "黑体(SimHei)",
"楷体(KaiTi)", "楷体(KaiTi)",
"仿宋(FangSong)", "仿宋(FangSong)",
"华文细黑(STHeiti Light)", "华文细黑(STHeiti Light)",
"华文楷体(STKaiti)", "华文楷体(STKaiti)",
"华文仿宋(STFangsong)", "华文仿宋(STFangsong)",
"华文宋体(STSong)", "华文宋体(STSong)",
"华文中宋(STZhongsong)", "华文中宋(STZhongsong)",
"华文新魏(STXinwei)", "华文新魏(STXinwei)",
"华文隶书(STLiti)", "华文隶书(STLiti)",
# 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)" # 备注:以下字体需要网络支持,您可以自定义任意您喜欢的字体,如下所示,需要满足的格式为 "字体昵称(字体英文真名@字体css下载链接)"
"思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)", "思源宋体(Source Han Serif CN VF@https://chinese-fonts-cdn.deno.dev/packages/syst/dist/SourceHanSerifCN/result.css)",
"月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)", "月星楷(Moon Stars Kai HW@https://chinese-fonts-cdn.deno.dev/packages/moon-stars-kai/dist/MoonStarsKaiHW-Regular/result.css)",
"珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)", "珠圆体(MaokenZhuyuanTi@https://chinese-fonts-cdn.deno.dev/packages/mkzyt/dist/猫啃珠圆体/result.css)",
@@ -355,6 +355,10 @@ DAAS_SERVER_URLS = [ f"https://niuziniu-biligpt{i}.hf.space/stream" for i in ran
JINA_API_KEY = "" JINA_API_KEY = ""
# SEMANTIC SCHOLAR API KEY
SEMANTIC_SCHOLAR_KEY = ""
# 是否自动裁剪上下文长度(是否启动,默认不启动)
AUTO_CONTEXT_CLIP_ENABLE = False
# 目标裁剪上下文的token长度,如果超过这个长度,则会自动裁剪
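The two new configuration entries added here are consumed elsewhere in this commit through `get_conf` (see the 学术对话 plugin below). A minimal sketch of reading them, assuming a standard deployment of this repository; the printed message is illustrative only:

# Sketch (not part of the commit): reading the new config entries via the repo's toolbox.
from toolbox import get_conf

semantic_key = get_conf("SEMANTIC_SCHOLAR_KEY")   # "" unless the operator fills it in
auto_clip = get_conf("AUTO_CONTEXT_CLIP_ENABLE")  # False by default, per the diff above
if not semantic_key:
    print("SEMANTIC_SCHOLAR_KEY is empty; Semantic Scholar queries will run unauthenticated.")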

View file

@@ -0,0 +1,290 @@
import re
import os
import asyncio
from typing import List, Dict, Tuple
from dataclasses import dataclass
from textwrap import dedent
from toolbox import CatchException, get_conf, update_ui, promote_file_to_downloadzone, get_log_folder, get_user
from toolbox import update_ui, CatchException, report_exception, write_history_to_file
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.query_analyzer import QueryAnalyzer
from crazy_functions.review_fns.handlers.review_handler import 文献综述功能
from crazy_functions.review_fns.handlers.recommend_handler import 论文推荐功能
from crazy_functions.review_fns.handlers.qa_handler import 学术问答功能
from crazy_functions.review_fns.handlers.paper_handler import 单篇论文分析功能
from crazy_functions.Conversation_To_File import write_chat_to_file
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.review_fns.handlers.latest_handler import Arxiv最新论文推荐功能
from datetime import datetime
@CatchException
def 学术对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: List,
history: List, system_prompt: str, user_request: str):
"""主函数"""
# 初始化数据源
arxiv_source = ArxivSource()
semantic_source = SemanticScholarSource(
api_key=get_conf("SEMANTIC_SCHOLAR_KEY")
)
# 初始化处理器
handlers = {
"review": 文献综述功能(arxiv_source, semantic_source, llm_kwargs),
"recommend": 论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
"qa": 学术问答功能(arxiv_source, semantic_source, llm_kwargs),
"paper": 单篇论文分析功能(arxiv_source, semantic_source, llm_kwargs),
"latest": Arxiv最新论文推荐功能(arxiv_source, semantic_source, llm_kwargs),
}
# 分析查询意图
chatbot.append([None, "正在分析研究主题和查询要求..."])
yield from update_ui(chatbot=chatbot, history=history)
query_analyzer = QueryAnalyzer()
search_criteria = yield from query_analyzer.analyze_query(txt, chatbot, llm_kwargs)
handler = handlers.get(search_criteria.query_type)
if not handler:
handler = handlers["qa"] # 默认使用QA处理器
# 处理查询
chatbot.append([None, f"使用{handler.__class__.__name__}处理...,可能需要您耐心等待35分钟..."])
yield from update_ui(chatbot=chatbot, history=history)
final_prompt = asyncio.run(handler.handle(
criteria=search_criteria,
chatbot=chatbot,
history=history,
system_prompt=system_prompt,
llm_kwargs=llm_kwargs,
plugin_kwargs=plugin_kwargs
))
if final_prompt:
# 检查是否是道歉提示
if "很抱歉,我们未能找到" in final_prompt:
chatbot.append([txt, final_prompt])
yield from update_ui(chatbot=chatbot, history=history)
return
# 在 final_prompt 末尾添加用户原始查询要求
final_prompt += dedent(f"""
Original user query: "{txt}"
IMPORTANT NOTE :
- Your response must directly address the user's original user query above
- While following the previous guidelines, prioritize answering what the user specifically asked
- Make sure your response format and content align with the user's expectations
- Do not translate paper titles, keep them in their original language
- Do not generate a reference list in your response - references will be handled separately
""")
# 使用最终的prompt生成回答
response = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=final_prompt,
inputs_show_user=txt,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history=[],
sys_prompt=f"You are a helpful academic assistant. Response in Chinese by default unless specified language is required in the user's query."
)
# 1. 获取文献列表
papers_list = handler.ranked_papers # 直接使用原始论文数据
# 在新的对话中添加格式化的参考文献列表
if papers_list:
references = ""
for idx, paper in enumerate(papers_list, 1):
# 构建作者列表
authors = paper.authors[:3]
if len(paper.authors) > 3:
authors.append("et al.")
authors_str = ", ".join(authors)
# 构建期刊指标信息
metrics = []
if hasattr(paper, 'if_factor') and paper.if_factor:
metrics.append(f"IF: {paper.if_factor}")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
metrics.append(f"JCR: {paper.jcr_division}")
if hasattr(paper, 'cas_division') and paper.cas_division:
metrics.append(f"中科院分区: {paper.cas_division}")
metrics_str = f" [{', '.join(metrics)}]" if metrics else ""
# 构建DOI链接
doi_link = ""
if paper.doi:
if "arxiv.org" in str(paper.doi):
doi_url = paper.doi
else:
doi_url = f"https://doi.org/{paper.doi}"
doi_link = f" <a href='{doi_url}' target='_blank'>DOI: {paper.doi}</a>"
# 构建完整的引用
reference = f"[{idx}] {authors_str}. *{paper.title}*"
if paper.venue_name:
reference += f". {paper.venue_name}"
if paper.year:
reference += f", {paper.year}"
reference += metrics_str
if doi_link:
reference += f".{doi_link}"
reference += " \n"
references += reference
# 添加新的对话显示参考文献
chatbot.append(["参考文献如下:", references])
yield from update_ui(chatbot=chatbot, history=history)
# 2. 保存为不同格式
from .review_fns.conversation_doc.word_doc import WordFormatter
from .review_fns.conversation_doc.word2pdf import WordToPdfConverter
from .review_fns.conversation_doc.markdown_doc import MarkdownFormatter
from .review_fns.conversation_doc.html_doc import HtmlFormatter
# 创建保存目录
save_dir = get_log_folder(get_user(chatbot), plugin_name='chatscholar')
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# 生成文件名
def get_safe_filename(txt, max_length=10):
# 获取文本前max_length个字符作为文件名
filename = txt[:max_length].strip()
# 移除不安全的文件名字符
filename = re.sub(r'[\\/:*?"<>|]', '', filename)
# 如果文件名为空,使用时间戳
if not filename:
filename = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
return filename
base_filename = get_safe_filename(txt)
result_files = [] # 收集所有生成的文件
pdf_path = None # 用于跟踪PDF是否成功生成
# 保存为Markdown
try:
md_formatter = MarkdownFormatter()
md_content = md_formatter.create_document(txt, response, papers_list)
result_file_md = write_history_to_file(
history=[md_content],
file_basename=f"markdown_{base_filename}.md"
)
result_files.append(result_file_md)
except Exception as e:
print(f"Markdown保存失败: {str(e)}")
# 保存为HTML
try:
html_formatter = HtmlFormatter()
html_content = html_formatter.create_document(txt, response, papers_list)
result_file_html = write_history_to_file(
history=[html_content],
file_basename=f"html_{base_filename}.html"
)
result_files.append(result_file_html)
except Exception as e:
print(f"HTML保存失败: {str(e)}")
# 保存为Word
try:
word_formatter = WordFormatter()
try:
doc = word_formatter.create_document(txt, response, papers_list)
except Exception as e:
print(f"Word文档内容生成失败: {str(e)}")
raise e
try:
result_file_docx = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"docx_{base_filename}.docx"
)
doc.save(result_file_docx)
result_files.append(result_file_docx)
print(f"Word文档已保存到: {result_file_docx}")
# 转换为PDF
try:
pdf_path = WordToPdfConverter.convert_to_pdf(result_file_docx)
if pdf_path:
result_files.append(pdf_path)
print(f"PDF文档已生成: {pdf_path}")
except Exception as e:
print(f"PDF转换失败: {str(e)}")
except Exception as e:
print(f"Word文档保存失败: {str(e)}")
raise e
except Exception as e:
print(f"Word格式化失败: {str(e)}")
import traceback
print(f"详细错误信息: {traceback.format_exc()}")
# 保存为BibTeX格式
try:
from .review_fns.conversation_doc.reference_formatter import ReferenceFormatter
ref_formatter = ReferenceFormatter()
bibtex_content = ref_formatter.create_document(papers_list)
# 在与其他文件相同目录下创建BibTeX文件
result_file_bib = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"references_{base_filename}.bib"
)
# 直接写入文件
with open(result_file_bib, 'w', encoding='utf-8') as f:
f.write(bibtex_content)
result_files.append(result_file_bib)
print(f"BibTeX文件已保存到: {result_file_bib}")
except Exception as e:
print(f"BibTeX格式保存失败: {str(e)}")
# 保存为EndNote格式
try:
from .review_fns.conversation_doc.endnote_doc import EndNoteFormatter
endnote_formatter = EndNoteFormatter()
endnote_content = endnote_formatter.create_document(papers_list)
# 在与其他文件相同目录下创建EndNote文件
result_file_enw = os.path.join(
os.path.dirname(result_file_md) if result_file_md else save_dir,
f"references_{base_filename}.enw"
)
# 直接写入文件
with open(result_file_enw, 'w', encoding='utf-8') as f:
f.write(endnote_content)
result_files.append(result_file_enw)
print(f"EndNote文件已保存到: {result_file_enw}")
except Exception as e:
print(f"EndNote格式保存失败: {str(e)}")
# 添加所有文件到下载区
success_files = []
for file in result_files:
try:
promote_file_to_downloadzone(file, chatbot=chatbot)
success_files.append(os.path.basename(file))
except Exception as e:
print(f"文件添加到下载区失败: {str(e)}")
# 更新成功提示消息
if success_files:
chatbot.append(["保存对话记录成功,bib和enw文件支持导入到EndNote、Zotero、JabRef、Mendeley等文献管理软件,HTML文件支持在浏览器中打开,里面包含详细论文源信息", "对话已保存并添加到下载区,可以在下载区找到相关文件"])
else:
chatbot.append(["保存对话记录", "所有格式的保存都失败了,请检查错误日志。"])
yield from update_ui(chatbot=chatbot, history=history)
else:
report_exception(chatbot, history, a=f"处理失败", b=f"请尝试其他查询")
yield from update_ui(chatbot=chatbot, history=history)
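The handler table above routes each analyzed query type to one of five handlers and falls back to the QA handler for unknown types. A small self-contained sketch of that dispatch rule (handler names stand in for the real handler objects):

# Sketch only: the query_type -> handler fallback used by 学术对话 above.
HANDLER_NAMES = {
    "review": "文献综述功能",
    "recommend": "论文推荐功能",
    "qa": "学术问答功能",
    "paper": "单篇论文分析功能",
    "latest": "Arxiv最新论文推荐功能",
}

def pick_handler(query_type: str) -> str:
    # Unknown intents fall back to the QA handler, mirroring handlers.get(...) in the plugin.
    return HANDLER_NAMES.get(query_type, HANDLER_NAMES["qa"])

assert pick_handler("latest") == "Arxiv最新论文推荐功能"
assert pick_handler("something-else") == "学术问答功能"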

View file

@@ -0,0 +1,275 @@
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
from loguru import logger
import time
import re
def force_breakdown(txt, limit, get_token_fn):
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
"""
for i in reversed(range(len(txt))):
if get_token_fn(txt[:i]) < limit:
return txt[:i], txt[i:]
return "Tiktoken未知错误", "Tiktoken未知错误"
def maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage):
""" 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
当 remain_txt_to_cut < `_min` 时,我们再把 remain_txt_to_cut_storage 中的部分文字取出
"""
_min = int(5e4)
_max = int(1e5)
# print(len(remain_txt_to_cut), len(remain_txt_to_cut_storage))
if len(remain_txt_to_cut) < _min and len(remain_txt_to_cut_storage) > 0:
remain_txt_to_cut = remain_txt_to_cut + remain_txt_to_cut_storage
remain_txt_to_cut_storage = ""
if len(remain_txt_to_cut) > _max:
remain_txt_to_cut_storage = remain_txt_to_cut[_max:] + remain_txt_to_cut_storage
remain_txt_to_cut = remain_txt_to_cut[:_max]
return remain_txt_to_cut, remain_txt_to_cut_storage
def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=False):
""" 文本切分
"""
res = []
total_len = len(txt_tocut)
fin_len = 0
remain_txt_to_cut = txt_tocut
remain_txt_to_cut_storage = ""
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
while True:
if get_token_fn(remain_txt_to_cut) <= limit:
# 如果剩余文本的token数小于限制,那么就不用切了
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
break
else:
# 如果剩余文本的token数大于限制,那么就切
lines = remain_txt_to_cut.split('\n')
# 估计一个切分点
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
estimated_line_cut = int(estimated_line_cut)
# 开始查找合适切分点的偏移cnt
cnt = 0
for cnt in reversed(range(estimated_line_cut)):
if must_break_at_empty_line:
# 首先尝试用双空行(\n\n)作为切分点
if lines[cnt] != "":
continue
prev = "\n".join(lines[:cnt])
post = "\n".join(lines[cnt:])
if get_token_fn(prev) < limit:
break
if cnt == 0:
# 如果没有找到合适的切分点
if break_anyway:
# 是否允许暴力切分
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
else:
# 不允许直接报错
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
# 追加列表
res.append(prev); fin_len+=len(prev)
# 准备下一次迭代
remain_txt_to_cut = post
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
process = fin_len/total_len
logger.info(f'正在文本切分 {int(process*100)}%')
if len(remain_txt_to_cut.strip()) == 0:
break
return res
def breakdown_text_to_satisfy_token_limit_(txt, limit, llm_model="gpt-3.5-turbo"):
""" 使用多种方式尝试切分文本,以满足 token 限制
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
try:
# 第1次尝试,将双空行\n\n作为切分点
return cut(limit, get_token_fn, txt, must_break_at_empty_line=True)
except RuntimeError:
try:
# 第2次尝试,将单空行\n作为切分点
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False)
except RuntimeError:
try:
# 第3次尝试,将英文句号(.)作为切分点
res = cut(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
return [r.replace('。\n', '.') for r in res]
except RuntimeError as e:
try:
# 第4次尝试,将中文句号(。)作为切分点
res = cut(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
return [r.replace('。。\n', '。') for r in res]
except RuntimeError as e:
# 第5次尝试,没办法了,随便切一下吧
return cut(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
breakdown_text_to_satisfy_token_limit = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_, timeout=60)
def cut_new(limit, get_token_fn, txt_tocut, must_break_at_empty_line, must_break_at_one_empty_line=False, break_anyway=False):
""" 文本切分
"""
res = []
res_empty_line = []
total_len = len(txt_tocut)
fin_len = 0
remain_txt_to_cut = txt_tocut
remain_txt_to_cut_storage = ""
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
empty=0
while True:
if get_token_fn(remain_txt_to_cut) <= limit:
# 如果剩余文本的token数小于限制,那么就不用切了
res.append(remain_txt_to_cut); fin_len+=len(remain_txt_to_cut)
res_empty_line.append(empty)
break
else:
# 如果剩余文本的token数大于限制,那么就切
lines = remain_txt_to_cut.split('\n')
# 估计一个切分点
estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
estimated_line_cut = int(estimated_line_cut)
# 开始查找合适切分点的偏移cnt
cnt = 0
for cnt in reversed(range(estimated_line_cut)):
if must_break_at_empty_line:
# 首先尝试用双空行(\n\n)作为切分点
if lines[cnt] != "":
continue
if must_break_at_empty_line or must_break_at_one_empty_line:
empty=1
prev = "\n".join(lines[:cnt])
post = "\n".join(lines[cnt:])
if get_token_fn(prev) < limit :
break
# empty=0
if get_token_fn(prev)>limit:
if '.' not in prev or '。' not in prev:
# empty = 0
break
# if cnt
if cnt == 0:
# 如果没有找到合适的切分点
if break_anyway:
# 是否允许暴力切分
prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
empty =0
else:
# 不允许直接报错
raise RuntimeError(f"存在一行极长的文本!{remain_txt_to_cut}")
# 追加列表
res.append(prev); fin_len+=len(prev)
res_empty_line.append(empty)
# 准备下一次迭代
remain_txt_to_cut = post
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
process = fin_len/total_len
logger.info(f'正在文本切分 {int(process*100)}%')
if len(remain_txt_to_cut.strip()) == 0:
break
return res,res_empty_line
def breakdown_text_to_satisfy_token_limit_new_(txt, limit, llm_model="gpt-3.5-turbo"):
""" 使用多种方式尝试切分文本,以满足 token 限制
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
try:
# 第1次尝试,将双空行\n\n作为切分点
res, empty_line =cut_new(limit, get_token_fn, txt, must_break_at_empty_line=True)
return res,empty_line
except RuntimeError:
try:
# 第2次尝试,将单空行\n作为切分点
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False,must_break_at_one_empty_line=True)
return res, _
except RuntimeError:
try:
# 第3次尝试,将英文句号(.)作为切分点
res, _ = cut_new(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False) # 这个中文的句号是故意的,作为一个标识而存在
return [r.replace('。\n', '.') for r in res], _
except RuntimeError as e:
try:
# 第4次尝试,将中文句号(。)作为切分点
res, _ = cut_new(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
return [r.replace('。。\n', '。') for r in res], _
except RuntimeError as e:
# 第5次尝试,没办法了,随便切一下吧
res, _ = cut_new(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
return res,_
breakdown_text_to_satisfy_token_limit_new = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_new_, timeout=60)
def cut_from_end_to_satisfy_token_limit_(txt, limit, reserve_token=500, llm_model="gpt-3.5-turbo"):
"""从后往前裁剪文本,以论文为单位进行裁剪
参数:
txt: 要处理的文本(格式化后的论文列表字符串)
limit: token数量上限
reserve_token: 需要预留的token数量,默认500
llm_model: 使用的模型名称
返回:
裁剪后的文本
"""
from request_llms.bridge_all import model_info
enc = model_info[llm_model]['tokenizer']
def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
# 计算当前文本的token数
current_tokens = get_token_fn(txt)
target_limit = limit - reserve_token
# 如果当前token数已经在限制范围内,直接返回
if current_tokens <= target_limit:
return txt
# 按论文编号分割文本
papers = re.split(r'\n(?=\d+\. \*\*)', txt)
if not papers:
return txt
# 从前往后累加论文,直到达到token限制
result = papers[0] # 保留第一篇
current_tokens = get_token_fn(result)
for paper in papers[1:]:
paper_tokens = get_token_fn(paper)
if current_tokens + paper_tokens <= target_limit:
result += "\n" + paper
current_tokens += paper_tokens
else:
break
return result
# 添加超时保护
cut_from_end_to_satisfy_token_limit = run_in_subprocess_with_timeout(cut_from_end_to_satisfy_token_limit_, timeout=20)
if __name__ == '__main__':
from crazy_functions.crazy_utils import read_and_clean_pdf_text
file_content, page_one = read_and_clean_pdf_text("build/assets/at.pdf")
from request_llms.bridge_all import model_info
for i in range(5):
file_content += file_content
logger.info(len(file_content))
TOKEN_LIMIT_PER_FRAGMENT = 2500
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)
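A usage sketch for the two subprocess-wrapped helpers defined above; it assumes `model_info` knows the tokenizer for "gpt-3.5-turbo" (as in the `__main__` block) and uses a placeholder string instead of a real PDF:

# Sketch only: trimming text to a token budget with the helpers defined above.
long_text = "This is a placeholder paragraph.\n\n" * 2000

# Split into fragments of at most ~2500 tokens each (runs in a watchdog subprocess).
fragments = breakdown_text_to_satisfy_token_limit(long_text, limit=2500, llm_model="gpt-3.5-turbo")
print(f"{len(fragments)} fragments produced")

# Keep the head of a formatted paper list while reserving ~500 tokens for the answer.
trimmed = cut_from_end_to_satisfy_token_limit(long_text, limit=4096, reserve_token=500)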

View file

View file

@@ -0,0 +1,68 @@
from typing import List
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
class EndNoteFormatter:
"""EndNote参考文献格式生成器"""
def __init__(self):
pass
def create_document(self, papers: List[PaperMetadata]) -> str:
"""生成EndNote格式的参考文献文本
Args:
papers: 论文列表
Returns:
str: EndNote格式的参考文献文本
"""
endnote_text = ""
for paper in papers:
# 根据venue_type确定条目类型(默认类型为期刊文章)
# 注:不要对整个endnote_text做replace,否则会改写之前已生成条目的类型
entry_type = "Journal Article"
if hasattr(paper, 'venue_type') and paper.venue_type:
if paper.venue_type.lower() == 'conference':
entry_type = "Conference Paper"
elif paper.venue_type.lower() == 'preprint':
entry_type = "Electronic Article"
# 开始一个新条目
endnote_text += f"%0 {entry_type}\n"
# 添加标题
endnote_text += f"%T {paper.title}\n"
# 添加作者
for author in paper.authors:
endnote_text += f"%A {author}\n"
# 添加年份
if paper.year:
endnote_text += f"%D {paper.year}\n"
# 添加期刊/会议名称
if hasattr(paper, 'venue_name') and paper.venue_name:
endnote_text += f"%J {paper.venue_name}\n"
elif paper.venue:
endnote_text += f"%J {paper.venue}\n"
# 添加DOI
if paper.doi:
endnote_text += f"%R {paper.doi}\n"
endnote_text += f"%U https://doi.org/{paper.doi}\n"
elif paper.url:
endnote_text += f"%U {paper.url}\n"
# 添加摘要
if paper.abstract:
endnote_text += f"%X {paper.abstract}\n"
# 添加机构
if hasattr(paper, 'institutions'):
for institution in paper.institutions:
endnote_text += f"%I {institution}\n"
# 条目之间添加空行
endnote_text += "\n"
return endnote_text
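A hedged usage sketch for the formatter above. `PaperMetadata` is the project's own dataclass defined elsewhere, so the sketch uses a `SimpleNamespace` stand-in exposing only the attributes `create_document` reads, with invented values:

# Sketch only: exporting a .enw file readable by EndNote/Zotero.
from types import SimpleNamespace

paper = SimpleNamespace(
    title="Attention Is All You Need",
    authors=["Ashish Vaswani", "Noam Shazeer"],
    year=2017,
    venue="NeurIPS 2017",
    venue_name="NeurIPS",
    venue_type="conference",
    doi="10.48550/arXiv.1706.03762",
    url=None,
    abstract="",
)
enw_text = EndNoteFormatter().create_document([paper])
with open("references.enw", "w", encoding="utf-8") as f:
    f.write(enw_text)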

View file

@@ -0,0 +1,211 @@
import re
import os
import pandas as pd
from datetime import datetime
class ExcelTableFormatter:
"""聊天记录中Markdown表格转Excel生成器"""
def __init__(self):
"""初始化Excel文档对象"""
from openpyxl import Workbook
self.workbook = Workbook()
self._table_count = 0
self._current_sheet = None
def _normalize_table_row(self, row):
"""标准化表格行,处理不同的分隔符情况"""
row = row.strip()
if row.startswith('|'):
row = row[1:]
if row.endswith('|'):
row = row[:-1]
return [cell.strip() for cell in row.split('|')]
def _is_separator_row(self, row):
"""检查是否是分隔行(由 - 或 : 组成)"""
clean_row = re.sub(r'[\s|]', '', row)
return bool(re.match(r'^[-:]+$', clean_row))
def _extract_tables_from_text(self, text):
"""从文本中提取所有表格内容"""
if not isinstance(text, str):
return []
tables = []
current_table = []
is_in_table = False
for line in text.split('\n'):
line = line.strip()
if not line:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
continue
if '|' in line:
if not is_in_table:
is_in_table = True
current_table.append(line)
else:
if is_in_table and current_table:
if len(current_table) >= 2:
tables.append(current_table)
current_table = []
is_in_table = False
if is_in_table and current_table and len(current_table) >= 2:
tables.append(current_table)
return tables
def _parse_table(self, table_lines):
"""解析表格内容为结构化数据"""
try:
headers = self._normalize_table_row(table_lines[0])
separator_index = next(
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
1
)
data_rows = []
for line in table_lines[separator_index + 1:]:
cells = self._normalize_table_row(line)
# 确保单元格数量与表头一致
while len(cells) < len(headers):
cells.append('')
cells = cells[:len(headers)]
data_rows.append(cells)
if headers and data_rows:
return {
'headers': headers,
'data': data_rows
}
except Exception as e:
print(f"解析表格时发生错误: {str(e)}")
return None
def _create_sheet(self, question_num, table_num):
"""创建新的工作表"""
sheet_name = f'Q{question_num}_T{table_num}'
if len(sheet_name) > 31:
sheet_name = f'Table{self._table_count}'
if sheet_name in self.workbook.sheetnames:
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
return self.workbook.create_sheet(title=sheet_name)
def create_document(self, history):
"""
处理聊天历史中的所有表格并创建Excel文档
Args:
history: 聊天历史列表
Returns:
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
"""
has_tables = False
# 删除默认创建的工作表
default_sheet = self.workbook['Sheet']
self.workbook.remove(default_sheet)
# 遍历所有回答
for i in range(1, len(history), 2):
answer = history[i]
tables = self._extract_tables_from_text(answer)
for table_lines in tables:
parsed_table = self._parse_table(table_lines)
if parsed_table:
self._table_count += 1
sheet = self._create_sheet(i // 2 + 1, self._table_count)
# 写入表头
for col, header in enumerate(parsed_table['headers'], 1):
sheet.cell(row=1, column=col, value=header)
# 写入数据
for row_idx, row_data in enumerate(parsed_table['data'], 2):
for col_idx, value in enumerate(row_data, 1):
sheet.cell(row=row_idx, column=col_idx, value=value)
has_tables = True
return self.workbook if has_tables else None
def save_chat_tables(history, save_dir, base_name):
"""
保存聊天历史中的表格到Excel文件
Args:
history: 聊天历史列表
save_dir: 保存目录
base_name: 基础文件名
Returns:
list: 保存的文件路径列表
"""
result_files = []
try:
# 创建Excel格式
excel_formatter = ExcelTableFormatter()
workbook = excel_formatter.create_document(history)
if workbook is not None:
# 确保保存目录存在
os.makedirs(save_dir, exist_ok=True)
# 生成Excel文件路径
excel_file = os.path.join(save_dir, base_name + '.xlsx')
# 保存Excel文件
workbook.save(excel_file)
result_files.append(excel_file)
print(f"已保存表格到Excel文件: {excel_file}")
except Exception as e:
print(f"保存Excel格式失败: {str(e)}")
return result_files
# 使用示例
if __name__ == "__main__":
# 示例聊天历史
history = [
"问题1",
"""这是第一个表格:
| A | B | C |
|---|---|---|
| 1 | 2 | 3 |""",
"问题2",
"这是没有表格的回答",
"问题3",
"""回答包含多个表格:
| Name | Age |
|------|-----|
| Tom | 20 |
第二个表格:
| X | Y |
|---|---|
| 1 | 2 |"""
]
# 保存表格
save_dir = "output"
base_name = "chat_tables"
saved_files = save_chat_tables(history, save_dir, base_name)

View file

@@ -0,0 +1,472 @@
class HtmlFormatter:
"""聊天记录HTML格式生成器"""
def __init__(self):
self.css_styles = """
:root {
--primary-color: #2563eb;
--primary-light: #eff6ff;
--secondary-color: #1e293b;
--background-color: #f8fafc;
--text-color: #334155;
--border-color: #e2e8f0;
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
}
body {
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
line-height: 1.8;
margin: 0;
padding: 2rem;
color: var(--text-color);
background-color: var(--background-color);
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
padding: 2rem;
border-radius: 16px;
box-shadow: var(--card-shadow);
}
::selection {
background: var(--primary-light);
color: var(--primary-color);
}
@keyframes fadeIn {
from { opacity: 0; transform: translateY(20px); }
to { opacity: 1; transform: translateY(0); }
}
@keyframes slideIn {
from { transform: translateX(-20px); opacity: 0; }
to { transform: translateX(0); opacity: 1; }
}
.container {
animation: fadeIn 0.6s ease-out;
}
.QaBox {
animation: slideIn 0.5s ease-out;
transition: all 0.3s ease;
}
.QaBox:hover {
transform: translateX(5px);
}
.Question, .Answer, .historyBox {
transition: all 0.3s ease;
}
.chat-title {
color: var(--primary-color);
font-size: 2em;
text-align: center;
margin: 1rem 0 2rem;
padding-bottom: 1rem;
border-bottom: 2px solid var(--primary-color);
}
.chat-body {
display: flex;
flex-direction: column;
gap: 1.5rem;
margin: 2rem 0;
}
.QaBox {
background: white;
padding: 1.5rem;
border-radius: 8px;
border-left: 4px solid var(--primary-color);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
margin-bottom: 1.5rem;
}
.Question {
color: var(--secondary-color);
font-weight: 500;
margin-bottom: 1rem;
}
.Answer {
color: var(--text-color);
background: var(--primary-light);
padding: 1rem;
border-radius: 6px;
}
.history-section {
margin-top: 3rem;
padding-top: 2rem;
border-top: 2px solid var(--border-color);
}
.history-title {
color: var(--secondary-color);
font-size: 1.5em;
margin-bottom: 1.5rem;
text-align: center;
}
.historyBox {
background: white;
padding: 1rem;
margin: 0.5rem 0;
border-radius: 6px;
border: 1px solid var(--border-color);
}
@media (prefers-color-scheme: dark) {
:root {
--background-color: #0f172a;
--text-color: #e2e8f0;
--border-color: #1e293b;
}
.container, .QaBox {
background: #1e293b;
}
}
"""
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
"""生成完整的HTML文档
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
Returns:
str: 完整的HTML文档字符串
"""
chat_content = f'''
<div class="QaBox">
<div class="Question">{question}</div>
<div class="Answer markdown-body" id="answer-content">{answer}</div>
</div>
'''
references_content = ""
if ranked_papers:
references_content = '<div class="history-section"><h2 class="history-title">参考文献</h2>'
for idx, paper in enumerate(ranked_papers, 1):
authors = ', '.join(paper.authors)
# 构建引用信息
citations_info = f"被引用次数:{paper.citations}" if paper.citations is not None else "引用信息未知"
# 构建下载链接
download_links = []
if paper.doi:
# 检查是否是arXiv链接
if 'arxiv.org' in paper.doi:
# 如果DOI中包含完整的arXiv URL,直接使用
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
download_links.append(f'<a href="{arxiv_url}">arXiv链接</a>')
# 提取arXiv ID并添加PDF链接
arxiv_id = arxiv_url.split('abs/')[-1].split('v')[0]
download_links.append(f'<a href="https://arxiv.org/pdf/{arxiv_id}.pdf">PDF下载</a>')
else:
# 非arXiv的DOI使用标准格式
download_links.append(f'<a href="https://doi.org/{paper.doi}">DOI: {paper.doi}</a>')
if hasattr(paper, 'url') and paper.url and 'arxiv.org' not in str(paper.url):
# 只有当URL不是arXiv链接时才添加
download_links.append(f'<a href="{paper.url}">原文链接</a>')
download_section = ' | '.join(download_links) if download_links else "无直接下载链接"
# 构建来源信息
source_info = []
if paper.venue_type:
source_info.append(f"类型:{paper.venue_type}")
if paper.venue_name:
source_info.append(f"来源:{paper.venue_name}")
# 添加期刊指标信息
if hasattr(paper, 'if_factor') and paper.if_factor:
source_info.append(f"<span class='journal-metric'>IF: {paper.if_factor}</span>")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
source_info.append(f"<span class='journal-metric'>JCR分区: {paper.jcr_division}</span>")
if hasattr(paper, 'cas_division') and paper.cas_division:
source_info.append(f"<span class='journal-metric'>中科院分区: {paper.cas_division}</span>")
if hasattr(paper, 'venue_info') and paper.venue_info:
if paper.venue_info.get('journal_ref'):
source_info.append(f"期刊参考:{paper.venue_info['journal_ref']}")
if paper.venue_info.get('publisher'):
source_info.append(f"出版商:{paper.venue_info['publisher']}")
source_section = ' | '.join(source_info) if source_info else ""
# 构建标准引用格式
standard_citation = f"[{idx}] "
# 添加作者最多3个,超过则添加et al.
author_list = paper.authors[:3]
if len(paper.authors) > 3:
author_list.append("et al.")
standard_citation += ", ".join(author_list) + ". "
# 添加标题
standard_citation += f"<i>{paper.title}</i>"
# 添加期刊/会议名称
if paper.venue_name:
standard_citation += f". {paper.venue_name}"
# 添加年份
if paper.year:
standard_citation += f", {paper.year}"
# 添加DOI
if paper.doi:
if 'arxiv.org' in paper.doi:
# 如果是arXiv链接,直接使用arXiv URL
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
standard_citation += f". {arxiv_url}"
else:
# 非arXiv的DOI使用标准格式
standard_citation += f". DOI: {paper.doi}"
standard_citation += "."
references_content += f'''
<div class="historyBox">
<div class="entry">
<p class="paper-title"><b>[{idx}]</b> <i>{paper.title}</i></p>
<p class="paper-authors">作者:{authors}</p>
<p class="paper-year">发表年份:{paper.year if paper.year else "未知"}</p>
<p class="paper-citations">{citations_info}</p>
{f'<p class="paper-source">{source_section}</p>' if source_section else ""}
<p class="paper-abstract">摘要:{paper.abstract if paper.abstract else "无摘要"}</p>
<p class="paper-links">链接:{download_section}</p>
<div class="standard-citation">
<p class="citation-title">标准引用格式:</p>
<p class="citation-text">{standard_citation}</p>
<button class="copy-btn" onclick="copyToClipboard(this.previousElementSibling)">复制引用格式</button>
</div>
</div>
</div>
'''
references_content += '</div>'
# 添加新的CSS样式
css_additions = """
.paper-title {
font-size: 1.1em;
margin-bottom: 0.5em;
}
.paper-authors {
color: var(--secondary-color);
margin: 0.3em 0;
}
.paper-year, .paper-citations {
color: var(--text-color);
margin: 0.3em 0;
}
.paper-source {
color: var(--text-color);
font-style: italic;
margin: 0.3em 0;
}
.paper-abstract {
margin: 0.8em 0;
padding: 0.8em;
background: var(--primary-light);
border-radius: 4px;
}
.paper-links {
margin-top: 0.5em;
}
.paper-links a {
color: var(--primary-color);
text-decoration: none;
margin-right: 1em;
}
.paper-links a:hover {
text-decoration: underline;
}
.standard-citation {
margin-top: 1em;
padding: 1em;
background: #f8fafc;
border-radius: 4px;
border: 1px solid var(--border-color);
}
.citation-title {
font-weight: bold;
margin-bottom: 0.5em;
color: var(--secondary-color);
}
.citation-text {
font-family: 'Times New Roman', Times, serif;
line-height: 1.6;
margin-bottom: 0.5em;
padding: 0.5em;
background: white;
border-radius: 4px;
border: 1px solid var(--border-color);
}
.copy-btn {
background: var(--primary-color);
color: white;
border: none;
padding: 0.5em 1em;
border-radius: 4px;
cursor: pointer;
font-size: 0.9em;
transition: background-color 0.2s;
}
.copy-btn:hover {
background: #1e40af;
}
@media (prefers-color-scheme: dark) {
.standard-citation {
background: #1e293b;
}
.citation-text {
background: #0f172a;
}
}
/* 添加期刊指标样式 */
.journal-metric {
display: inline-block;
padding: 0.2em 0.6em;
margin: 0 0.3em;
background: var(--primary-light);
border-radius: 4px;
font-weight: 500;
color: var(--primary-color);
}
@media (prefers-color-scheme: dark) {
.journal-metric {
background: #1e293b;
color: #60a5fa;
}
}
"""
# 修改 js_code 部分,添加 markdown 解析功能
js_code = """
<script>
// 复制功能
function copyToClipboard(element) {
const text = element.innerText;
navigator.clipboard.writeText(text).then(function() {
const btn = element.nextElementSibling;
const originalText = btn.innerText;
btn.innerText = '已复制!';
setTimeout(() => {
btn.innerText = originalText;
}, 2000);
}).catch(function(err) {
console.error('复制失败:', err);
});
}
// Markdown解析
document.addEventListener('DOMContentLoaded', function() {
const answerContent = document.getElementById('answer-content');
if (answerContent) {
const markdown = answerContent.textContent;
answerContent.innerHTML = marked.parse(markdown);
}
});
</script>
"""
# 将新的CSS样式添加到现有样式中
self.css_styles += css_additions
return f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>学术对话存档</title>
<!-- 添加 marked.js -->
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<!-- 添加 GitHub Markdown CSS -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/sindresorhus/github-markdown-css@4.0.0/github-markdown.min.css">
<style>
{self.css_styles}
/* 添加 Markdown 相关样式 */
.markdown-body {{
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
padding: 1rem;
background: var(--primary-light);
border-radius: 6px;
}}
.markdown-body pre {{
background-color: #f6f8fa;
border-radius: 6px;
padding: 16px;
overflow: auto;
}}
.markdown-body code {{
background-color: rgba(175,184,193,0.2);
border-radius: 6px;
padding: 0.2em 0.4em;
font-size: 85%;
}}
.markdown-body pre code {{
background-color: transparent;
padding: 0;
}}
.markdown-body blockquote {{
border-left: 0.25em solid #d0d7de;
padding: 0 1em;
color: #656d76;
}}
.markdown-body table {{
border-collapse: collapse;
width: 100%;
margin: 1em 0;
}}
.markdown-body table th,
.markdown-body table td {{
border: 1px solid #d0d7de;
padding: 6px 13px;
}}
.markdown-body table tr:nth-child(2n) {{
background-color: #f6f8fa;
}}
@media (prefers-color-scheme: dark) {{
.markdown-body {{
background: #1e293b;
color: #e2e8f0;
}}
.markdown-body pre {{
background-color: #0f172a;
}}
.markdown-body code {{
background-color: rgba(99,110,123,0.4);
}}
.markdown-body blockquote {{
border-left-color: #30363d;
color: #8b949e;
}}
.markdown-body table th,
.markdown-body table td {{
border-color: #30363d;
}}
.markdown-body table tr:nth-child(2n) {{
background-color: #0f172a;
}}
}}
</style>
</head>
<body>
<div class="container">
<h1 class="chat-title">学术对话存档</h1>
<div class="chat-body">
{chat_content}
{references_content}
</div>
</div>
{js_code}
</body>
</html>
"""

View file

@@ -0,0 +1,47 @@
class MarkdownFormatter:
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
def __init__(self):
self.content = []
def _add_content(self, text: str):
"""添加正文内容"""
if text:
self.content.append(f"\n{text}\n")
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
"""创建完整的Markdown文档
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
Returns:
str: 生成的Markdown文本
"""
content = []
# 添加问答部分
content.append("## 问题")
content.append(question)
content.append("\n## 回答")
content.append(answer)
# 添加参考文献
if ranked_papers:
content.append("\n## 参考文献")
for idx, paper in enumerate(ranked_papers, 1):
authors = ', '.join(paper.authors[:3])
if len(paper.authors) > 3:
authors += ' et al.'
ref = f"[{idx}] {authors}. *{paper.title}*"
if paper.venue_name:
ref += f". {paper.venue_name}"
if paper.year:
ref += f", {paper.year}"
if paper.doi:
ref += f". [DOI: {paper.doi}](https://doi.org/{paper.doi})"
content.append(ref)
return "\n\n".join(content)

View file

@@ -0,0 +1,174 @@
from typing import List
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
import re
class ReferenceFormatter:
"""通用参考文献格式生成器"""
def __init__(self):
pass
def _sanitize_bibtex(self, text: str) -> str:
"""清理BibTeX字符串,处理特殊字符"""
if not text:
return ""
# 替换特殊字符
replacements = {
'&': '\\&',
'%': '\\%',
'$': '\\$',
'#': '\\#',
'_': '\\_',
'{': '\\{',
'}': '\\}',
'~': '\\textasciitilde{}',
'^': '\\textasciicircum{}',
'\\': '\\textbackslash{}',
'<': '\\textless{}',
'>': '\\textgreater{}',
'"': '``',
"'": "'",
'-': '--',
'—': '---',
}
# 用单次正则替换,避免先替换产生的反斜杠被后续规则再次转义
pattern = re.compile('|'.join(re.escape(char) for char in replacements))
return pattern.sub(lambda m: replacements[m.group(0)], text)
def _generate_cite_key(self, paper: PaperMetadata) -> str:
"""生成引用键
格式: 第一作者姓氏_年份_第一个实词
"""
# 获取第一作者姓氏
first_author = ""
if paper.authors and len(paper.authors) > 0:
first_author = paper.authors[0].split()[-1].lower()
# 获取年份
year = str(paper.year) if paper.year else "0000"
# 从标题中获取第一个实词
title_word = ""
if paper.title:
# 移除特殊字符,分割成单词
words = re.findall(r'\w+', paper.title.lower())
# 过滤掉常见的停用词
stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
for word in words:
if word not in stop_words and len(word) > 2:
title_word = word
break
# 组合cite key
cite_key = f"{first_author}{year}{title_word}"
# 确保cite key只包含合法字符
cite_key = re.sub(r'[^a-z0-9]', '', cite_key.lower())
return cite_key
def _get_entry_type(self, paper: PaperMetadata) -> str:
"""确定BibTeX条目类型"""
if hasattr(paper, 'venue_type') and paper.venue_type:
venue_type = paper.venue_type.lower()
if venue_type == 'conference':
return 'inproceedings'
elif venue_type == 'preprint':
return 'unpublished'
elif venue_type == 'journal':
return 'article'
elif venue_type == 'book':
return 'book'
elif venue_type == 'thesis':
return 'phdthesis'
return 'article' # 默认为期刊文章
def create_document(self, papers: List[PaperMetadata]) -> str:
"""生成BibTeX格式的参考文献文本"""
bibtex_text = "% This file was automatically generated by GPT-Academic\n"
bibtex_text += "% Compatible with: EndNote, Zotero, JabRef, and LaTeX\n\n"
for paper in papers:
entry_type = self._get_entry_type(paper)
cite_key = self._generate_cite_key(paper)
bibtex_text += f"@{entry_type}{{{cite_key},\n"
# 添加标题
if paper.title:
bibtex_text += f" title = {{{self._sanitize_bibtex(paper.title)}}},\n"
# 添加作者
if paper.authors:
# 确保每个作者的姓和名正确分隔
processed_authors = []
for author in paper.authors:
names = author.split()
if len(names) > 1:
# 假设最后一个词是姓,其他的是名
surname = names[-1]
given_names = ' '.join(names[:-1])
processed_authors.append(f"{surname}, {given_names}")
else:
processed_authors.append(author)
authors = " and ".join([self._sanitize_bibtex(author) for author in processed_authors])
bibtex_text += f" author = {{{authors}}},\n"
# 添加年份
if paper.year:
bibtex_text += f" year = {{{paper.year}}},\n"
# 添加期刊/会议名称
if hasattr(paper, 'venue_name') and paper.venue_name:
if entry_type == 'inproceedings':
bibtex_text += f" booktitle = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
elif entry_type == 'article':
bibtex_text += f" journal = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
# 添加期刊相关信息
if hasattr(paper, 'venue_info'):
if 'volume' in paper.venue_info:
bibtex_text += f" volume = {{{paper.venue_info['volume']}}},\n"
if 'number' in paper.venue_info:
bibtex_text += f" number = {{{paper.venue_info['number']}}},\n"
if 'pages' in paper.venue_info:
bibtex_text += f" pages = {{{paper.venue_info['pages']}}},\n"
elif paper.venue:
venue_field = "booktitle" if entry_type == "inproceedings" else "journal"
bibtex_text += f" {venue_field} = {{{self._sanitize_bibtex(paper.venue)}}},\n"
# 添加DOI
if paper.doi:
bibtex_text += f" doi = {{{paper.doi}}},\n"
# 添加URL
if paper.url:
bibtex_text += f" url = {{{paper.url}}},\n"
elif paper.doi:
bibtex_text += f" url = {{https://doi.org/{paper.doi}}},\n"
# 添加摘要
if paper.abstract:
bibtex_text += f" abstract = {{{self._sanitize_bibtex(paper.abstract)}}},\n"
# 添加机构
if hasattr(paper, 'institutions') and paper.institutions:
institutions = " and ".join([self._sanitize_bibtex(inst) for inst in paper.institutions])
bibtex_text += f" institution = {{{institutions}}},\n"
# 添加月份
if hasattr(paper, 'month'):
bibtex_text += f" month = {{{paper.month}}},\n"
# 添加注释字段
if hasattr(paper, 'note'):
bibtex_text += f" note = {{{self._sanitize_bibtex(paper.note)}}},\n"
# 移除最后一个逗号并关闭条目
bibtex_text = bibtex_text.rstrip(',\n') + "\n}\n\n"
return bibtex_text
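A sketch of the BibTeX export, written the same way the 学术对话 plugin above writes its `references_*.bib` file; `papers` is assumed to be the same kind of list of paper objects (see the stand-in after the EndNote formatter):

# Sketch only: dumping a .bib file for the retrieved papers.
bibtex = ReferenceFormatter().create_document(papers)  # `papers`: list of PaperMetadata-like objects
with open("references.bib", "w", encoding="utf-8") as f:
    f.write(bibtex)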

View file

@@ -0,0 +1,138 @@
from docx2pdf import convert
import os
import platform
from typing import Union
from pathlib import Path
from datetime import datetime
class WordToPdfConverter:
"""Word文档转PDF转换器"""
@staticmethod
def _replace_docx_in_filename(filename: Union[str, Path]) -> Path:
"""
将文件名中的'docx'替换为'pdf'
例如: 'docx_test.pdf' -> 'pdf_test.pdf'
"""
path = Path(filename)
new_name = path.stem.replace('docx', 'pdf')
return path.parent / f"{new_name}{path.suffix}"
@staticmethod
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
"""
将Word文档转换为PDF
参数:
word_path: Word文档的路径
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
返回:
生成的PDF文件路径
异常:
如果转换失败,将抛出相应异常
"""
try:
word_path = Path(word_path)
if pdf_path is None:
# 创建新的pdf路径,同时替换文件名中的docx
pdf_path = WordToPdfConverter._replace_docx_in_filename(word_path).with_suffix('.pdf')
else:
pdf_path = WordToPdfConverter._replace_docx_in_filename(Path(pdf_path))
# 检查操作系统
if platform.system() == 'Linux':
# Linux系统需要安装libreoffice
if not os.system('which libreoffice') == 0:
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
# 使用libreoffice进行转换
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
# 如果输出路径与默认生成的不同,则重命名
default_pdf = word_path.with_suffix('.pdf')
if default_pdf != pdf_path:
os.rename(default_pdf, pdf_path)
else:
# Windows和MacOS使用 docx2pdf
convert(word_path, pdf_path)
return str(pdf_path)
except Exception as e:
raise Exception(f"转换PDF失败: {str(e)}")
@staticmethod
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
"""
批量转换目录下的所有Word文档
参数:
word_dir: 包含Word文档的目录路径
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
返回:
生成的PDF文件路径列表
"""
word_dir = Path(word_dir)
if pdf_dir:
pdf_dir = Path(pdf_dir)
pdf_dir.mkdir(parents=True, exist_ok=True)
converted_files = []
for word_file in word_dir.glob("*.docx"):
try:
if pdf_dir:
pdf_path = pdf_dir / WordToPdfConverter._replace_docx_in_filename(
word_file.with_suffix('.pdf')
).name
else:
pdf_path = WordToPdfConverter._replace_docx_in_filename(
word_file.with_suffix('.pdf')
)
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
converted_files.append(pdf_file)
except Exception as e:
print(f"转换 {word_file} 失败: {str(e)}")
return converted_files
@staticmethod
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
"""
将docx对象直接转换为PDF
参数:
doc: python-docx的Document对象
output_dir: 可选,输出目录。如果未指定,将使用当前目录
返回:
生成的PDF文件路径
"""
try:
# 设置临时文件路径和输出路径
output_dir = Path(output_dir) if output_dir else Path.cwd()
output_dir.mkdir(parents=True, exist_ok=True)
# 生成临时word文件
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
doc.save(temp_docx)
# 转换为PDF
pdf_path = temp_docx.with_suffix('.pdf')
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
# 删除临时word文件
temp_docx.unlink()
return str(pdf_path)
except Exception as e:
if temp_docx.exists():
temp_docx.unlink()
raise Exception(f"转换PDF失败: {str(e)}")
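A sketch of converting one exported document; on Linux the converter shells out to LibreOffice, so the `libreoffice` binary must be installed, and the file name below is a placeholder:

# Sketch only: .docx -> .pdf conversion with the helper above.
pdf_file = WordToPdfConverter.convert_to_pdf("docx_chat.docx")
print(f"PDF written to: {pdf_file}")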

View file

@@ -0,0 +1,246 @@
import re
from docx import Document
from docx.shared import Cm, Pt
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.enum.style import WD_STYLE_TYPE
from docx.oxml.ns import qn
from datetime import datetime
import docx
from docx.oxml import shared
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
class WordFormatter:
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
def __init__(self):
self.doc = Document()
self._setup_document()
self._create_styles()
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 修改页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
header_run = header_para.add_run("GPT-Academic学术对话 (体验地址https://auth.gpt-academic.top/)")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
# 创建问题样式
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
question_style.font.name = '黑体'
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
question_style.font.size = Pt(14) # 调整为14磅
question_style.font.bold = True
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
question_style.paragraph_format.space_after = Pt(6)
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建回答样式
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
answer_style.font.name = '仿宋'
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
answer_style.font.size = Pt(12) # 调整为12磅
answer_style.paragraph_format.space_before = Pt(6)
answer_style.paragraph_format.space_after = Pt(12)
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
# 创建标题样式
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
title_style.font.name = '黑体' # 改用黑体
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
title_style.font.size = Pt(22) # 调整为22磅
title_style.font.bold = True
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
title_style.paragraph_format.space_before = Pt(0)
title_style.paragraph_format.space_after = Pt(24)
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
# 添加参考文献样式
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_style.font.name = '宋体'
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
ref_style.paragraph_format.space_before = Pt(3)
ref_style.paragraph_format.space_after = Pt(3)
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
ref_style.paragraph_format.left_indent = Pt(21)
ref_style.paragraph_format.first_line_indent = Pt(-21)
# 添加参考文献标题样式
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
ref_title_style.font.name = '黑体'
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
ref_title_style.font.size = Pt(16) # 参考文献标题与问题同样大小
ref_title_style.font.bold = True
ref_title_style.paragraph_format.space_before = Pt(24) # 增加段前距
ref_title_style.paragraph_format.space_after = Pt(12)
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
def create_document(self, question: str, answer: str, ranked_papers: list = None):
"""写入聊天历史
Args:
question: str, 用户问题
answer: str, AI回答
ranked_papers: list, 排序后的论文列表
"""
try:
# 添加标题
title_para = self.doc.add_paragraph(style='Title_Custom')
title_run = title_para.add_run('GPT-Academic 对话记录')
# 添加日期
try:
date_para = self.doc.add_paragraph()
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d日'))
date_run.font.name = '仿宋'
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
date_run.font.size = Pt(16)
except Exception as e:
print(f"添加日期失败: {str(e)}")
raise
self.doc.add_paragraph() # 添加空行
# 添加问答对话
try:
q_para = self.doc.add_paragraph(style='Question_Style')
q_para.add_run('问题:').bold = True
q_para.add_run(str(question))
a_para = self.doc.add_paragraph(style='Answer_Style')
a_para.add_run('回答:').bold = True
a_para.add_run(convert_markdown_to_word(str(answer)))
except Exception as e:
print(f"添加问答对话失败: {str(e)}")
raise
# 添加参考文献部分
if ranked_papers:
try:
ref_title = self.doc.add_paragraph(style='Reference_Title_Style')
ref_title.add_run("参考文献")
for idx, paper in enumerate(ranked_papers, 1):
try:
ref_para = self.doc.add_paragraph(style='Reference_Style')
ref_para.add_run(f'[{idx}] ').bold = True
# 添加作者
authors = ', '.join(paper.authors[:3])
if len(paper.authors) > 3:
authors += ' et al.'
ref_para.add_run(f'{authors}. ')
# 添加标题
title_run = ref_para.add_run(paper.title)
title_run.italic = True
if hasattr(paper, 'url') and paper.url:
try:
title_run._element.rPr.rStyle = self._create_hyperlink_style()
self._add_hyperlink(ref_para, paper.title, paper.url)
except Exception as e:
print(f"添加超链接失败: {str(e)}")
# 添加期刊/会议信息
if paper.venue_name:
ref_para.add_run(f'. {paper.venue_name}')
# 添加年份
if paper.year:
ref_para.add_run(f', {paper.year}')
# 添加DOI
if paper.doi:
ref_para.add_run('. ')
if "arxiv" in paper.url:
doi_url = paper.doi
else:
doi_url = f'https://doi.org/{paper.doi}'
self._add_hyperlink(ref_para, f'DOI: {paper.doi}', doi_url)
ref_para.add_run('.')
except Exception as e:
print(f"添加第 {idx} 篇参考文献失败: {str(e)}")
continue
except Exception as e:
print(f"添加参考文献部分失败: {str(e)}")
raise
return self.doc
except Exception as e:
print(f"Word文档创建失败: {str(e)}")
import traceback
print(f"详细错误信息: {traceback.format_exc()}")
raise
def _create_hyperlink_style(self):
"""创建超链接样式"""
styles = self.doc.styles
if 'Hyperlink' not in styles:
hyperlink_style = styles.add_style('Hyperlink', WD_STYLE_TYPE.CHARACTER)
# 使用科技蓝 (#0066CC);python-docx 要求 RGBColor 对象而非整数
from docx.shared import RGBColor
hyperlink_style.font.color.rgb = RGBColor(0x00, 0x66, 0xCC)
hyperlink_style.font.underline = True
return styles['Hyperlink']
def _add_hyperlink(self, paragraph, text, url):
"""添加超链接到段落"""
# 这个是在XML级别添加超链接
part = paragraph.part
r_id = part.relate_to(url, docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, is_external=True)
# 创建超链接XML元素
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
hyperlink.set(docx.oxml.shared.qn('r:id'), r_id)
# 创建文本运行
new_run = docx.oxml.shared.OxmlElement('w:r')
rPr = docx.oxml.shared.OxmlElement('w:rPr')
# 应用超链接样式
rStyle = docx.oxml.shared.OxmlElement('w:rStyle')
rStyle.set(docx.oxml.shared.qn('w:val'), 'Hyperlink')
rPr.append(rStyle)
# 添加文本
t = docx.oxml.shared.OxmlElement('w:t')
t.text = text
new_run.append(rPr)
new_run.append(t)
hyperlink.append(new_run)
# 将超链接添加到段落
paragraph._p.append(hyperlink)
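A sketch of building and saving the Word transcript, mirroring how the 学术对话 plugin earlier in this commit calls the formatter; the strings are placeholders and `ranked_papers` may be None:

# Sketch only: generating the .docx transcript.
doc = WordFormatter().create_document(
    question="What is a transformer?",
    answer="A transformer is a neural architecture built on self-attention.",
    ranked_papers=None,
)
doc.save("docx_chat.docx")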

View file

@@ -0,0 +1,279 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import json
from tqdm import tqdm
import random
class AdsabsSource(DataSource):
"""ADS (Astrophysics Data System) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.adsabs.harvard.edu/v1"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
"""发送HTTP请求
Args:
url: 请求URL
method: HTTP方法
data: POST请求数据
Returns:
响应内容
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
if method == "GET":
async with session.get(url) as response:
if response.status == 200:
return await response.json()
elif method == "POST":
async with session.post(url, json=data) as response:
if response.status == 200:
return await response.json()
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
def _parse_paper(self, doc: dict) -> PaperMetadata:
"""解析ADS文献数据
Args:
doc: ADS文献数据
Returns:
解析后的论文数据
"""
try:
return PaperMetadata(
title=doc.get('title', [''])[0] if doc.get('title') else '',
authors=doc.get('author', []),
abstract=doc.get('abstract', ''),
year=doc.get('year'),
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
citations=doc.get('citation_count'),
venue=doc.get('pub', ''),
institutions=doc.get('aff', []),
venue_type="journal",
venue_name=doc.get('pub', ''),
venue_info={
'volume': doc.get('volume'),
'issue': doc.get('issue'),
'pub_date': doc.get('pubdate', '')
},
source='adsabs'
)
except Exception as e:
print(f"解析文章时发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# 构建查询
if start_year:
query = f"{query} year:{start_year}-"
# 设置排序
sort_mapping = {
'relevance': 'score desc',
'date': 'date desc',
'citations': 'citation_count desc'
}
sort = sort_mapping.get(sort_by, 'score desc')
# 构建搜索请求
search_url = f"{self.base_url}/search/query"
params = {
"q": query,
"rows": limit,
"sort": sort,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
# 解析结果
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _build_query_string(self, params: dict) -> str:
"""构建查询字符串"""
return "&".join([f"{k}={v}" for k, v in params.items()])
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
"""获取指定bibcode的论文详情"""
search_url = f"{self.base_url}/search/query"
params = {
"q": f"identifier:{bibcode}",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
if response and 'response' in response and response['response']['docs']:
return self._parse_paper(response['response']['docs'][0])
return None
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
"""获取相关论文"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
"rows": limit,
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"author:\"{author}\""
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"pub:\"{journal}\""
return await self.search(query, limit=limit, start_year=start_year)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文"""
query = f"entdate:[NOW-{days}DAYS TO NOW]"
return await self.search(query, limit=limit, sort_by="date")
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"citations(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
url = f"{self.base_url}/search/query"
params = {
"q": f"references(identifier:{bibcode})",
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
}
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
if not response or 'response' not in response:
return []
papers = []
for doc in response['response']['docs']:
paper = self._parse_paper(doc)
if paper:
papers.append(paper)
return papers
async def example_usage():
"""AdsabsSource使用示例"""
ads = AdsabsSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索黑洞相关论文 ===")
papers = await ads.search("black hole", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
# 其他示例...
except Exception as e:
print(f"发生错误: {str(e)}")
if __name__ == "__main__":
# python -m crazy_functions.review_fns.data_sources.adsabs_source
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,636 @@
import arxiv
from typing import List, Optional, Union, Literal, Dict
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.request import urlretrieve
import feedparser
from tqdm import tqdm
class ArxivSource(DataSource):
"""arXiv API实现"""
CATEGORIES = {
# 物理学
"Physics": {
"astro-ph": "天体物理学",
"cond-mat": "凝聚态物理",
"gr-qc": "广义相对论与量子宇宙学",
"hep-ex": "高能物理实验",
"hep-lat": "格点场论",
"hep-ph": "高能物理理论",
"hep-th": "高能物理理论",
"math-ph": "数学物理",
"nlin": "非线性科学",
"nucl-ex": "核实验",
"nucl-th": "核理论",
"physics": "物理学",
"quant-ph": "量子物理",
},
# 数学
"Mathematics": {
"math.AG": "代数几何",
"math.AT": "代数拓扑",
"math.AP": "分析与偏微分方程",
"math.CT": "范畴论",
"math.CA": "复分析",
"math.CO": "组合数学",
"math.AC": "交换代数",
"math.CV": "复变函数",
"math.DG": "微分几何",
"math.DS": "动力系统",
"math.FA": "泛函分析",
"math.GM": "一般数学",
"math.GN": "一般拓扑",
"math.GT": "几何拓扑",
"math.GR": "群论",
"math.HO": "数学史与数学概述",
"math.IT": "信息论",
"math.KT": "K理论与同调",
"math.LO": "逻辑",
"math.MP": "数学物理",
"math.MG": "度量几何",
"math.NT": "数论",
"math.NA": "数值分析",
"math.OA": "算子代数",
"math.OC": "最优化与控制",
"math.PR": "概率论",
"math.QA": "量子代数",
"math.RT": "表示论",
"math.RA": "环与代数",
"math.SP": "谱理论",
"math.ST": "统计理论",
"math.SG": "辛几何",
},
# 计算机科学
"Computer Science": {
"cs.AI": "人工智能",
"cs.CL": "计算语言学",
"cs.CC": "计算复杂性",
"cs.CE": "计算工程",
"cs.CG": "计算几何",
"cs.GT": "计算机博弈论",
"cs.CV": "计算机视觉",
"cs.CY": "计算机与社会",
"cs.CR": "密码学与安全",
"cs.DS": "数据结构与算法",
"cs.DB": "数据库",
"cs.DL": "数字图书馆",
"cs.DM": "离散数学",
"cs.DC": "分布式计算",
"cs.ET": "新兴技术",
"cs.FL": "形式语言与自动机理论",
"cs.GL": "一般文献",
"cs.GR": "图形学",
"cs.AR": "硬件架构",
"cs.HC": "人机交互",
"cs.IR": "信息检索",
"cs.IT": "信息论",
"cs.LG": "机器学习",
"cs.LO": "逻辑与计算机",
"cs.MS": "数学软件",
"cs.MA": "多智能体系统",
"cs.MM": "多媒体",
"cs.NI": "网络与互联网架构",
"cs.NE": "神经与进化计算",
"cs.NA": "数值分析",
"cs.OS": "操作系统",
"cs.OH": "其他计算机科学",
"cs.PF": "性能评估",
"cs.PL": "编程语言",
"cs.RO": "机器人学",
"cs.SI": "社会与信息网络",
"cs.SE": "软件工程",
"cs.SD": "声音",
"cs.SC": "符号计算",
"cs.SY": "系统与控制",
},
# 定量生物学
"Quantitative Biology": {
"q-bio.BM": "生物分子",
"q-bio.CB": "细胞行为",
"q-bio.GN": "基因组学",
"q-bio.MN": "分子网络",
"q-bio.NC": "神经计算",
"q-bio.OT": "其他",
"q-bio.PE": "群体与进化",
"q-bio.QM": "定量方法",
"q-bio.SC": "亚细胞过程",
"q-bio.TO": "组织与器官",
},
# 定量金融
"Quantitative Finance": {
"q-fin.CP": "计算金融",
"q-fin.EC": "经济学",
"q-fin.GN": "一般金融",
"q-fin.MF": "数学金融",
"q-fin.PM": "投资组合管理",
"q-fin.PR": "定价理论",
"q-fin.RM": "风险管理",
"q-fin.ST": "统计金融",
"q-fin.TR": "交易与市场微观结构",
},
# 统计学
"Statistics": {
"stat.AP": "应用统计",
"stat.CO": "计算统计",
"stat.ML": "机器学习",
"stat.ME": "方法论",
"stat.OT": "其他统计",
"stat.TH": "统计理论",
},
# 电气工程与系统科学
"Electrical Engineering and Systems Science": {
"eess.AS": "音频与语音处理",
"eess.IV": "图像与视频处理",
"eess.SP": "信号处理",
"eess.SY": "系统与控制",
},
# 经济学
"Economics": {
"econ.EM": "计量经济学",
"econ.GN": "一般经济学",
"econ.TH": "理论经济学",
}
}
def __init__(self):
"""初始化"""
self._initialize() # 调用初始化方法
# 修改排序选项映射
self.sort_options = {
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
}
self.sort_order_options = {
'ascending': arxiv.SortOrder.Ascending,
'descending': arxiv.SortOrder.Descending
}
self.default_sort = 'lastUpdatedDate'
self.default_order = 'descending'
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.client = arxiv.Client()
async def search(
self,
query: str,
limit: int = 10,
sort_by: str = None,
sort_order: str = None,
start_year: int = None
) -> List[Dict]:
"""搜索论文"""
try:
# 使用默认排序如果提供的排序选项无效
if not sort_by or sort_by not in self.sort_options:
sort_by = self.default_sort
# 使用默认排序顺序如果提供的顺序无效
if not sort_order or sort_order not in self.sort_order_options:
sort_order = self.default_order
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
search = arxiv.Search(
query=query,
max_results=limit,
sort_by=self.sort_options[sort_by],
sort_order=self.sort_order_options[sort_order]
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
"""按ID搜索论文
Args:
paper_id: 单个arXiv ID或ID列表,例如'2005.14165' 或 ['2005.14165', '2103.14030']
"""
if isinstance(paper_id, str):
paper_id = [paper_id]
search = arxiv.Search(
id_list=paper_id,
max_results=len(paper_id)
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
async def search_by_category(
self,
category: str,
limit: int = 100,
sort_by: str = 'relevance',
sort_order: str = 'descending',
start_year: int = None
) -> List[PaperMetadata]:
"""按类别搜索论文"""
query = f"cat:{category}"
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = 'relevance',
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " AND ".join([f"au:\"{author}\"" for author in authors])
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
return await self.search(
query=query,
limit=limit,
sort_by=sort_by
)
async def search_by_date_range(
self,
start_date: datetime,
end_date: datetime,
limit: int = 100,
        sort_by: Literal['relevance', 'lastUpdatedDate', 'submittedDate'] = 'submittedDate',
        sort_order: Literal['ascending', 'descending'] = 'descending'
) -> List[PaperMetadata]:
"""按日期范围搜索论文"""
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
return await self.search(
query,
limit=limit,
sort_by=sort_by,
sort_order=sort_order
)
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文PDF
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
# 清理标题中的非法字符
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.pdf"
filepath = os.path.join(dirpath, filename)
urlretrieve(paper.url, filepath)
return filepath
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
"""下载论文源文件通常是LaTeX源码
Args:
paper_id: arXiv ID
dirpath: 保存目录
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
Returns:
保存的文件路径
"""
papers = await self.search_by_id(paper_id)
if not papers:
raise ValueError(f"未找到ID为 {paper_id} 的论文")
paper = papers[0]
if not filename:
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
filename = f"{paper_id}_{safe_title}.tar.gz"
filepath = os.path.join(dirpath, filename)
source_url = paper.url.replace("/pdf/", "/src/")
urlretrieve(source_url, filepath)
return filepath
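    # 补充说明(仅作参考)arxiv 包的 Result 对象本身也提供 download_pdf() / download_source() 方法,
    # 若后续重构,可直接复用库内实现,而不必像上面这样手动拼接 URL 下载。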
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
# arXiv API不直接提供引用信息
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
# arXiv API不直接提供引用信息
return []
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详情
Args:
paper_id: arXiv ID 或 DOI
Returns:
论文详细信息,如果未找到返回 None
"""
try:
            # 如果是完整的 arXiv URL,提取 ID(去掉可能的 .pdf 后缀)
            if "arxiv.org" in paper_id:
                paper_id = paper_id.split("/")[-1].replace(".pdf", "")
            # 如果是 DOI 格式且是 arXiv 论文,去掉前缀、保留完整 ID(如 2005.14165),不能按 "." 切分
            elif paper_id.startswith("10.48550/arXiv."):
                paper_id = paper_id[len("10.48550/arXiv."):]
papers = await self.search_by_id(paper_id)
return papers[0] if papers else None
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
"""解析arXiv API返回的数据"""
# 解析主要类别和次要类别
primary_category = result.primary_category
categories = result.categories
# 构建venue信息
venue_info = {
'primary_category': primary_category,
'categories': categories,
'comments': getattr(result, 'comment', None),
'journal_ref': getattr(result, 'journal_ref', None)
}
return PaperMetadata(
title=result.title,
authors=[author.name for author in result.authors],
abstract=result.summary,
year=result.published.year,
doi=result.entry_id,
url=result.pdf_url,
citations=None,
venue=f"arXiv:{primary_category}",
institutions=[],
venue_type='preprint', # arXiv论文都是预印本
venue_name='arXiv',
venue_info=venue_info,
source='arxiv' # 添加来源标记
)
async def get_latest_papers(
self,
category: str,
debug: bool = False,
batch_size: int = 50
) -> List[PaperMetadata]:
"""获取指定类别的最新论文
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
Args:
category: arXiv类别,例如
- 整个领域: 'cs'
- 具体方向: 'cs.AI'
- 多个类别: 'cs.AI+q-bio.NC'
debug: 是否为调试模式,如果为True则只返回5篇最新论文
batch_size: 批量获取论文的数量,默认50
Returns:
论文列表
Raises:
ValueError: 如果类别无效
"""
try:
# 处理类别格式
# 1. 转换为小写
# 2. 确保多个类别之间使用+连接
category = category.lower().replace(' ', '+')
# 构建RSS feed URL
feed_url = f"https://rss.arxiv.org/rss/{category}"
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
feed = feedparser.parse(feed_url)
# 检查feed是否有效
if hasattr(feed, 'status') and feed.status != 200:
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
if not feed.entries:
print(f"警告未在feed中找到任何条目") # 添加调试信息
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
if debug:
# 调试模式只获取5篇最新论文
search = arxiv.Search(
query=f'cat:{category}',
sort_by=arxiv.SortCriterion.SubmittedDate,
sort_order=arxiv.SortOrder.Descending,
max_results=5
)
results = list(self.client.results(search))
return [self._parse_paper_data(result) for result in results]
# 正常模式:获取所有新论文
# 从RSS条目中提取arXiv ID
paper_ids = []
for entry in feed.entries:
try:
# RSS链接格式可能是以下几种
# - http://arxiv.org/abs/2403.xxxxx
# - http://arxiv.org/pdf/2403.xxxxx
# - https://arxiv.org/abs/2403.xxxxx
link = entry.link or entry.id
arxiv_id = link.split('/')[-1].replace('.pdf', '')
if arxiv_id:
paper_ids.append(arxiv_id)
except Exception as e:
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
continue
if not paper_ids:
print("未能从feed中提取到任何论文ID") # 添加调试信息
return []
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
# 批量获取论文详情
papers = []
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
for i in range(0, len(paper_ids), batch_size):
batch_ids = paper_ids[i:i + batch_size]
search = arxiv.Search(
id_list=batch_ids,
max_results=len(batch_ids)
)
batch_results = list(self.client.results(search))
papers.extend([self._parse_paper_data(result) for result in batch_results])
pbar.update(len(batch_results))
return papers
except Exception as e:
print(f"获取最新论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc()) # 添加完整的错误追踪
return []
async def example_usage():
"""ArxivSource使用示例"""
arxiv_source = ArxivSource()
try:
# 示例1基本搜索,使用不同的排序方式
# print("\n=== 示例1搜索最新的机器学习论文按提交时间排序===")
# papers = await arxiv_source.search(
# "ti:\"machine learning\"",
# limit=3,
# sort_by='submitted',
# sort_order='descending'
# )
# print(f"找到 {len(papers)} 篇论文")
# for i, paper in enumerate(papers, 1):
# print(f"\n--- 论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# print(f"arXiv ID: {paper.doi}")
# print(f"PDF URL: {paper.url}")
# if paper.abstract:
# print(f"\n摘要:")
# print(paper.abstract)
# print(f"发表venue: {paper.venue}")
# # 示例2按ID搜索
# print("\n=== 示例2按ID搜索论文 ===")
# paper_id = "2005.14165" # GPT-3论文
# papers = await arxiv_source.search_by_id(paper_id)
# if papers:
# paper = papers[0]
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# # 示例3按类别搜索
# print("\n=== 示例3搜索人工智能领域最新论文 ===")
# ai_papers = await arxiv_source.search_by_category(
# "cs.AI",
# limit=2,
# sort_by='updated',
# sort_order='descending'
# )
# for i, paper in enumerate(ai_papers, 1):
# print(f"\n--- AI论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表venue: {paper.venue}")
# # 示例4按作者搜索
# print("\n=== 示例4搜索特定作者的论文 ===")
# author_papers = await arxiv_source.search_by_authors(
# ["Bengio"],
# limit=2,
# sort_by='relevance'
# )
# for i, paper in enumerate(author_papers, 1):
# print(f"\n--- Bengio的论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表venue: {paper.venue}")
# # 示例5按日期范围搜索
# print("\n=== 示例5搜索特定日期范围的论文 ===")
# from datetime import datetime, timedelta
# end_date = datetime.now()
# start_date = end_date - timedelta(days=7) # 最近一周
# recent_papers = await arxiv_source.search_by_date_range(
# start_date,
# end_date,
# limit=2
# )
# for i, paper in enumerate(recent_papers, 1):
# print(f"\n--- 最近论文 {i} ---")
# print(f"标题: {paper.title}")
# print(f"作者: {', '.join(paper.authors)}")
# print(f"发表年份: {paper.year}")
# # 示例6下载PDF
# print("\n=== 示例6下载论文PDF ===")
# if papers: # 使用之前搜索到的GPT-3论文
# pdf_path = await arxiv_source.download_pdf(paper_id)
# print(f"PDF已下载到: {pdf_path}")
# # 示例7下载源文件
# print("\n=== 示例7下载论文源文件 ===")
# if papers:
# source_path = await arxiv_source.download_source(paper_id)
# print(f"源文件已下载到: {source_path}")
        # 示例8:获取最新论文
print("\n=== 示例8获取最新论文 ===")
# 获取CS.AI领域的最新论文
print("\n--- 获取AI领域最新论文 ---")
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
for i, paper in enumerate(ai_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 获取整个计算机科学领域的最新论文
print("\n--- 获取整个CS领域最新论文 ---")
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
for i, paper in enumerate(cs_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 获取多个类别的最新论文
print("\n--- 获取AI和机器学习领域最新论文 ---")
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
for i, paper in enumerate(multi_latest, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,102 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Optional
from dataclasses import dataclass
class PaperMetadata:
"""论文元数据"""
def __init__(
self,
title: str,
authors: List[str],
abstract: str,
year: int,
doi: str = None,
url: str = None,
citations: int = None,
venue: str = None,
institutions: List[str] = None,
venue_type: str = None, # 来源类型(journal/conference/preprint等)
venue_name: str = None, # 具体的期刊/会议名称
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
source: str = None # 新增: 论文来源标记
):
self.title = title
self.authors = authors
self.abstract = abstract
self.year = year
self.doi = doi
self.url = url
self.citations = citations
self.venue = venue
self.institutions = institutions or []
self.venue_type = venue_type # 新增
self.venue_name = venue_name # 新增
self.venue_info = venue_info or {} # 新增
self.source = source # 新增: 存储论文来源
# 新增影响因子和分区信息,初始化为None
self._if_factor = None
self._cas_division = None
self._jcr_division = None
@property
def if_factor(self) -> Optional[float]:
"""获取影响因子"""
return self._if_factor
@if_factor.setter
def if_factor(self, value: float):
"""设置影响因子"""
self._if_factor = value
@property
def cas_division(self) -> Optional[str]:
"""获取中科院分区"""
return self._cas_division
@cas_division.setter
def cas_division(self, value: str):
"""设置中科院分区"""
self._cas_division = value
@property
def jcr_division(self) -> Optional[str]:
"""获取JCR分区"""
return self._jcr_division
@jcr_division.setter
def jcr_division(self, value: str):
"""设置JCR分区"""
self._jcr_division = value
class DataSource(ABC):
"""数据源基类"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key
self._initialize()
@abstractmethod
def _initialize(self) -> None:
"""初始化数据源"""
pass
@abstractmethod
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
"""搜索论文"""
pass
@abstractmethod
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
"""获取论文详细信息"""
pass
@abstractmethod
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
pass
@abstractmethod
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
pass
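# 补充示例(仅作参考,非原实现的一部分):演示 PaperMetadata 的典型用法,
# 先构造元数据对象,再通过属性补充期刊指标;其中的标题、DOI 等均为虚构的占位值。
if __name__ == "__main__":
    paper = PaperMetadata(
        title="An Example Paper",
        authors=["Alice", "Bob"],
        abstract="A short abstract.",
        year=2024,
        doi="10.1234/example",
        venue_type="journal",
        venue_name="Example Journal",
        source="manual",
    )
    paper.if_factor = 3.2       # 影响因子等指标通常在检索后由 JournalMetrics 补充
    paper.cas_division = "2区"
    paper.jcr_division = "Q1"
    print(paper.title, paper.year, paper.if_factor)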

文件差异因一行或多行过长而隐藏

查看文件

@@ -0,0 +1,400 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class CrossrefSource(DataSource):
"""Crossref API实现"""
CONTACT_EMAILS = [
"gpt_abc_academic@163.com",
"gpt_abc_newapi@163.com",
"gpt_abc_academic_pwd@163.com"
]
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.crossref.org"
# 随机选择一个邮箱
contact_email = random.choice(self.CONTACT_EMAILS)
self.headers = {
"Accept": "application/json",
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
}
if self.api_key:
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
    async def search(
        self,
        query: str = None,
        limit: int = 100,
        sort_by: str = None,
        sort_order: str = None,
        start_year: int = None,
        extra_filter: str = None
    ) -> List[PaperMetadata]:
        """搜索论文

        Args:
            query: 搜索关键词(可为空,此时仅按filter条件检索)
            limit: 返回结果数量限制
            sort_by: 排序字段
            sort_order: 排序顺序
            start_year: 起始年份
            extra_filter: 额外的Crossref filter表达式(如日期范围),会与年份过滤合并
        """
        async with aiohttp.ClientSession(headers=self.headers) as session:
            # 请求更多的结果以补偿可能被过滤掉的文章
            adjusted_limit = min(limit * 3, 1000)  # 设置上限以避免请求过多
            params = {
                "rows": adjusted_limit,
                "select": (
                    "DOI,title,author,published-print,abstract,reference,"
                    "container-title,is-referenced-by-count,type,"
                    "publisher,ISSN,ISBN,issue,volume,page"
                )
            }
            if query:
                params["query"] = query
            # 组合过滤条件:Crossref 的 filter 参数用逗号分隔多个条件
            filters = []
            if start_year:
                filters.append(f"from-pub-date:{start_year}")
            if extra_filter:
                filters.append(extra_filter)
            if filters:
                params["filter"] = ",".join(filters)
            # 添加排序
            if sort_by:
                params["sort"] = sort_by
            if sort_order:
                params["order"] = sort_order
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
if response.status != 200:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
data = await response.json()
items = data.get("message", {}).get("items", [])
if not items:
print(f"未找到相关论文")
return []
# 过滤掉没有摘要的文章
papers = []
filtered_count = 0
for work in items:
paper = self._parse_work(work)
if paper.abstract and paper.abstract.strip():
papers.append(paper)
if len(papers) >= limit: # 达到原始请求的限制后停止
break
else:
filtered_count += 1
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
print(f"返回 {len(papers)} 篇包含摘要的论文")
return papers
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={
"select": (
"DOI,title,author,published-print,abstract,reference,"
"container-title,is-referenced-by-count,type,"
"publisher,ISSN,ISBN,issue,volume,page"
)
}
) as response:
if response.status != 200:
print(f"获取论文详情失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return None
try:
data = await response.json()
return self._parse_work(data.get("message", {}))
except Exception as e:
print(f"解析论文详情时发生错误: {str(e)}")
return None
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works/{doi}",
params={"select": "reference"}
) as response:
if response.status != 200:
print(f"获取参考文献失败: HTTP {response.status}")
return []
try:
data = await response.json()
# 确保我们正确处理返回的数据结构
if not isinstance(data, dict):
print(f"API返回了意外的数据格式: {type(data)}")
return []
references = data.get("message", {}).get("reference", [])
if not references:
print(f"未找到参考文献")
return []
return [
PaperMetadata(
title=ref.get("article-title", ""),
authors=[ref.get("author", "")],
year=ref.get("year"),
doi=ref.get("DOI"),
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
abstract="",
citations=None,
venue=ref.get("journal-title", ""),
institutions=[]
)
for ref in references
]
except Exception as e:
print(f"解析参考文献数据时发生错误: {str(e)}")
return []
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(
f"{self.base_url}/works",
params={
"filter": f"reference.DOI:{doi}",
"select": "DOI,title,author,published-print,abstract"
}
) as response:
if response.status != 200:
print(f"获取引用信息失败: HTTP {response.status}")
print(f"响应内容: {await response.text()}")
return []
try:
data = await response.json()
# 检查返回的数据结构
if isinstance(data, dict):
items = data.get("message", {}).get("items", [])
return [self._parse_work(work) for work in items]
else:
print(f"API返回了意外的数据格式: {type(data)}")
return []
except Exception as e:
print(f"解析引用数据时发生错误: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析Crossref返回的数据"""
# 获取摘要 - 处理可能的不同格式
abstract = ""
if isinstance(work.get("abstract"), str):
abstract = work.get("abstract", "")
elif isinstance(work.get("abstract"), dict):
abstract = work.get("abstract", {}).get("value", "")
if not abstract:
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
# 获取机构信息
institutions = []
for author in work.get("author", []):
if "affiliation" in author:
for affiliation in author["affiliation"]:
if "name" in affiliation and affiliation["name"] not in institutions:
institutions.append(affiliation["name"])
# 获取venue信息
        venue_name = (work.get("container-title") or [None])[0]  # container-title 可能为空列表
venue_type = work.get("type", "unknown") # 文献类型
venue_info = {
"publisher": work.get("publisher"),
"issn": work.get("ISSN", []),
"isbn": work.get("ISBN", []),
"issue": work.get("issue"),
"volume": work.get("volume"),
"page": work.get("page")
}
return PaperMetadata(
title=work.get("title", [None])[0] or "",
authors=[
author.get("given", "") + " " + author.get("family", "")
for author in work.get("author", [])
],
institutions=institutions, # 添加机构信息
abstract=abstract,
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
doi=work.get("DOI"),
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
citations=work.get("is-referenced-by-count"),
venue=venue_name,
venue_type=venue_type, # 添加venue类型
venue_name=venue_name, # 添加venue名称
venue_info=venue_info, # 添加venue详细信息
source='crossref' # 添加来源标记
)
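    # 补充示例(仅作参考)Crossref 返回的 abstract 字段通常是带 JATS 标记的 XML 片段(如 <jats:p>…</jats:p>)。
    # 下面是一个最小的清洗示例:去掉标签并压缩空白;_parse_work 目前并未调用它,是否启用由调用方决定。
    @staticmethod
    def _strip_jats_markup(abstract: str) -> str:
        """去除摘要中的 JATS/XML 标签,返回纯文本"""
        import re
        if not abstract:
            return ""
        text = re.sub(r"<[^>]+>", " ", abstract)
        return " ".join(text.split())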
async def search_by_authors(
self,
authors: List[str],
limit: int = 100,
sort_by: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = " ".join([f"author:\"{author}\"" for author in authors])
return await self.search(
query=query,
limit=limit,
sort_by=sort_by,
start_year=start_year
)
    async def search_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        limit: int = 100,
        sort_by: str = None,
        sort_order: str = None
    ) -> List[PaperMetadata]:
        """按日期范围搜索论文(日期约束通过Crossref的filter参数传递,而不是放进全文检索词)"""
        date_filter = (
            f"from-pub-date:{start_date.strftime('%Y-%m-%d')},"
            f"until-pub-date:{end_date.strftime('%Y-%m-%d')}"
        )
        return await self.search(
            query=None,
            limit=limit,
            sort_by=sort_by,
            sort_order=sort_order,
            extra_filter=date_filter
        )
async def example_usage():
"""CrossrefSource使用示例"""
crossref = CrossrefSource(api_key=None)
try:
# 示例1基本搜索,使用不同的排序方式
print("\n=== 示例1搜索最新的机器学习论文 ===")
papers = await crossref.search(
query="machine learning",
limit=3,
sort_by="published",
sort_order="desc",
start_year=2023
)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
if paper.institutions:
print(f"机构: {', '.join(paper.institutions)}")
print(f"引用次数: {paper.citations}")
print(f"发表venue: {paper.venue}")
print(f"venue类型: {paper.venue_type}")
if paper.venue_info:
print("Venue详细信息:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
# 示例2按DOI获取论文详情
print("\n=== 示例2获取特定论文详情 ===")
# 使用BERT论文的DOI
doi = "10.18653/v1/N19-1423"
paper = await crossref.get_paper_details(doi)
if paper:
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
print(f"引用次数: {paper.citations}")
# 示例3按作者搜索
print("\n=== 示例3搜索特定作者的论文 ===")
author_papers = await crossref.search_by_authors(
authors=["Yoshua Bengio"],
limit=3,
sort_by="published",
start_year=2020
)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- {i}. {paper.title} ---")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
# 示例4按日期范围搜索
print("\n=== 示例4搜索特定日期范围的论文 ===")
from datetime import datetime, timedelta
end_date = datetime.now()
start_date = end_date - timedelta(days=30) # 最近一个月
recent_papers = await crossref.search_by_date_range(
start_date=start_date,
end_date=end_date,
limit=3,
sort_by="published",
sort_order="desc"
)
for i, paper in enumerate(recent_papers, 1):
print(f"\n--- 最近发表的论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
# 示例5获取论文引用信息
print("\n=== 示例5获取论文引用信息 ===")
if paper: # 使用之前获取的BERT论文
print("\n获取引用该论文的文献:")
citations = await crossref.get_citations(paper.doi)
for i, citing_paper in enumerate(citations[:3], 1):
print(f"\n--- 引用论文 {i} ---")
print(f"标题: {citing_paper.title}")
print(f"作者: {', '.join(citing_paper.authors)}")
print(f"发表年份: {citing_paper.year}")
print("\n获取该论文引用的参考文献:")
references = await crossref.get_references(paper.doi)
for i, ref_paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {ref_paper.title}")
print(f"作者: {', '.join(ref_paper.authors)}")
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
# 示例6展示venue信息的使用
print("\n=== 示例6展示期刊/会议详细信息 ===")
if papers:
paper = papers[0]
print(f"文献类型: {paper.venue_type}")
print(f"发表venue: {paper.venue_name}")
if paper.venue_info:
print("Venue详细信息:")
for key, value in paper.venue_info.items():
if value:
print(f" - {key}: {value}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,449 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import json
from tqdm import tqdm
import random
class ElsevierSource(DataSource):
"""Elsevier (Scopus) API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS)
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.elsevier.com/content"
self.headers = {
"X-ELS-APIKey": self.api_key,
"Accept": "application/json",
"Content-Type": "application/json",
# 添加更多必要的头部信息
"X-ELS-Insttoken": "", # 如果有机构令牌
}
    async def _make_request(self, url: str, params: Dict = None, allow_retry: bool = True) -> Optional[Dict]:
        """发送HTTP请求

        Args:
            url: 请求URL
            params: 查询参数
            allow_retry: 认证失败(401)时是否允许切换API密钥并重试一次

        Returns:
            JSON响应
        """
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 200:
return await response.json()
else:
# 添加更详细的错误信息
error_text = await response.text()
print(f"请求失败: {response.status}")
print(f"错误详情: {error_text}")
                        candidates = [k for k in self.API_KEYS if k != self.api_key]
                        if response.status == 401 and allow_retry and candidates:
                            print(f"认证失败的API密钥: {self.api_key}")
                            # 尝试切换到另一个API密钥(每个请求至多重试一次,避免密钥全部无效时无限递归)
                            new_key = random.choice(candidates)
                            print(f"尝试切换到新的API密钥: {new_key}")
                            self.api_key = new_key
                            self.headers["X-ELS-APIKey"] = new_key
                            # 重试请求
                            return await self._make_request(url, params, allow_retry=False)
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
params = {
"query": query,
"count": min(limit, 100),
"view": "STANDARD",
# 移除dc:description字段,因为它在STANDARD视图中不可用
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
}
# 添加年份过滤
if start_year:
params["date"] = f"{start_year}-present"
# 添加排序
if sort_by == "date":
params["sort"] = "-coverDate"
elif sort_by == "cited":
params["sort"] = "-citedby-count"
# 发送搜索请求
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
# 解析搜索结果
entries = response["search-results"].get("entry", [])
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
# 尝试为每篇论文获取摘要
for paper in papers:
if paper.doi:
paper.abstract = await self.fetch_abstract(paper.doi) or ""
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
"""解析Scopus API返回的条目"""
try:
# 获取作者列表
authors = []
creator = entry.get("dc:creator")
if creator:
authors = [creator]
# 获取发表年份
year = None
if "prism:coverDate" in entry:
try:
year = int(entry["prism:coverDate"][:4])
                except (ValueError, TypeError):
                    pass
# 简化venue信息
venue_info = {
'source_id': entry.get("source-id"),
'issn': entry.get("prism:issn")
}
return PaperMetadata(
title=entry.get("dc:title", ""),
authors=authors,
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
year=year,
doi=entry.get("prism:doi"),
url=entry.get("prism:url"),
citations=int(entry.get("citedby-count", 0)),
venue=entry.get("prism:publicationName"),
institutions=[], # 移除机构信息
venue_type="",
venue_name=entry.get("prism:publicationName"),
venue_info=venue_info
)
except Exception as e:
print(f"解析条目时发生错误: {str(e)}")
return None
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取引用该论文的文献"""
try:
params = {
"query": f"REF({doi})",
"count": min(limit, 100),
"view": "STANDARD"
}
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=params
)
if not response or "search-results" not in response:
return []
entries = response["search-results"].get("entry", [])
            return [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
except Exception as e:
print(f"获取引用文献时发生错误: {str(e)}")
return []
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取该论文引用的文献"""
try:
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}/references",
params={"view": "STANDARD"}
)
if not response or "references" not in response:
return []
references = response["references"].get("reference", [])
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
return papers
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
"""解析参考文献数据"""
try:
authors = []
if "author-list" in ref:
author_list = ref["author-list"].get("author", [])
if isinstance(author_list, list):
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
for author in author_list]
else:
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
year = None
if "prism:coverDate" in ref:
try:
year = int(ref["prism:coverDate"][:4])
                except (ValueError, TypeError):
                    pass
return PaperMetadata(
title=ref.get("ce:title", ""),
authors=authors,
abstract="", # 参考文献通常不包含摘要
year=year,
doi=ref.get("prism:doi"),
url=None,
citations=None,
venue=ref.get("prism:publicationName"),
institutions=[],
venue_type="unknown",
venue_name=ref.get("prism:publicationName"),
venue_info={}
)
except Exception as e:
print(f"解析参考文献时发生错误: {str(e)}")
return None
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"AUTHOR-NAME({author})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_affiliation(
self,
affiliation: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按机构搜索论文"""
query = f"AF-ID({affiliation})"
return await self.search(query, limit=limit, start_year=start_year)
async def search_by_venue(
self,
venue: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊/会议搜索论文"""
query = f"SRCTITLE({venue})"
return await self.search(query, limit=limit, start_year=start_year)
async def test_api_access(self):
"""测试API访问权限"""
print(f"\n测试API密钥: {self.api_key}")
# 测试1: 基础搜索
basic_params = {
"query": "test",
"count": 1,
"view": "STANDARD"
}
print("\n1. 测试基础搜索...")
response = await self._make_request(
f"{self.base_url}/search/scopus",
params=basic_params
)
if response:
print("基础搜索成功")
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
# 测试2: 测试单篇文章访问
print("\n2. 测试文章详情访问...")
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
response = await self._make_request(
f"{self.base_url}/abstract/doi/{test_doi}",
params={"view": "STANDARD"} # 改为STANDARD视图
)
if response:
print("文章详情访问成功")
else:
print("文章详情访问失败")
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详细信息
注意当前API权限不支持获取详细信息,返回None
Args:
paper_id: 论文ID
Returns:
None,因为当前API权限不支持此功能
"""
return None
async def fetch_abstract(self, doi: str) -> Optional[str]:
"""获取论文摘要
使用Scopus Abstract API获取论文摘要
Args:
doi: 论文的DOI
Returns:
摘要文本,如果获取失败则返回None
"""
try:
# 使用Abstract API而不是Search API
response = await self._make_request(
f"{self.base_url}/abstract/doi/{doi}",
params={
"view": "FULL" # 使用FULL视图
}
)
if response and "abstracts-retrieval-response" in response:
# 从coredata中获取摘要
coredata = response["abstracts-retrieval-response"].get("coredata", {})
return coredata.get("dc:description", "")
return None
except Exception as e:
print(f"获取摘要时发生错误: {str(e)}")
return None
async def example_usage():
"""ElsevierSource使用示例"""
elsevier = ElsevierSource()
try:
# 首先测试API访问权限
print("\n=== 测试API访问权限 ===")
await elsevier.test_api_access()
# 示例1基本搜索
print("\n=== 示例1搜索机器学习相关论文 ===")
papers = await elsevier.search("machine learning", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
print("期刊信息:")
for key, value in paper.venue_info.items():
if value: # 只打印非空值
print(f" - {key}: {value}")
# 示例2获取引用信息
if papers and papers[0].doi:
print("\n=== 示例2获取引用该论文的文献 ===")
citations = await elsevier.get_citations(papers[0].doi, limit=3)
for i, paper in enumerate(citations, 1):
print(f"\n--- 引用论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例3获取参考文献
if papers and papers[0].doi:
print("\n=== 示例3获取论文的参考文献 ===")
references = await elsevier.get_references(papers[0].doi)
for i, paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"期刊/会议: {paper.venue}")
# 示例4按作者搜索
print("\n=== 示例4按作者搜索 ===")
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例5按机构搜索
print("\n=== 示例5按机构搜索 ===")
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
for i, paper in enumerate(affiliation_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
print(f"期刊/会议: {paper.venue}")
# 示例6获取论文摘要
print("\n=== 示例6获取论文摘要 ===")
test_doi = "10.1016/j.artint.2021.103535"
abstract = await elsevier.fetch_abstract(test_doi)
if abstract:
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
else:
print("无法获取摘要")
# 在搜索结果中显示摘要
print("\n=== 示例7搜索结果中的摘要 ===")
papers = await elsevier.search("machine learning", limit=1)
for paper in papers:
print(f"标题: {paper.title}")
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,698 @@
import aiohttp
import asyncio
import base64
import json
import random
from datetime import datetime
from typing import List, Dict, Optional, Union, Any
class GitHubSource:
"""GitHub API实现"""
# 默认API密钥列表 - 可以放置多个GitHub令牌
API_KEYS = [
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
]
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
"""初始化GitHub API客户端
Args:
api_key: GitHub个人访问令牌或令牌列表
"""
if api_key is None:
self.api_keys = self.API_KEYS
elif isinstance(api_key, str):
self.api_keys = [api_key]
else:
self.api_keys = api_key
self._initialize()
def _initialize(self) -> None:
"""初始化客户端,设置默认参数"""
self.base_url = "https://api.github.com"
self.headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
"User-Agent": "GitHub-API-Python-Client"
}
# 如果有可用的API密钥,随机选择一个
if self.api_keys:
selected_key = random.choice(self.api_keys)
self.headers["Authorization"] = f"Bearer {selected_key}"
print(f"已随机选择API密钥进行认证")
else:
print("警告: 未提供API密钥,将受到GitHub API请求限制")
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
"""发送API请求
Args:
method: HTTP方法 (GET, POST, PUT, DELETE等)
endpoint: API端点
params: URL参数
data: 请求体数据
Returns:
解析后的响应JSON
"""
async with aiohttp.ClientSession(headers=self.headers) as session:
url = f"{self.base_url}{endpoint}"
# 为调试目的打印请求信息
print(f"请求: {method} {url}")
if params:
print(f"参数: {params}")
# 发送请求
request_kwargs = {}
if params:
request_kwargs["params"] = params
if data:
request_kwargs["json"] = data
async with session.request(method, url, **request_kwargs) as response:
response_text = await response.text()
# 检查HTTP状态码
if response.status >= 400:
print(f"API请求失败: HTTP {response.status}")
print(f"响应内容: {response_text}")
return None
# 解析JSON响应
try:
return json.loads(response_text)
except json.JSONDecodeError:
print(f"JSON解析错误: {response_text}")
return None
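    # 补充示例(仅作参考)GitHub 对未认证请求限流较严(约每小时60次),认证后额度会高很多。
    # 下面是查询剩余配额的最小示例,直接调用官方 /rate_limit 端点;其余代码并未依赖它。
    async def get_rate_limit(self) -> Dict:
        """查询当前API配额(返回体中 resources.core.remaining 为剩余请求次数)"""
        return await self._request("GET", "/rate_limit")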
# ===== 用户相关方法 =====
async def get_user(self, username: Optional[str] = None) -> Dict:
"""获取用户信息
Args:
username: 指定用户名,不指定则获取当前授权用户
Returns:
用户信息字典
"""
endpoint = "/user" if username is None else f"/users/{username}"
return await self._request("GET", endpoint)
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户的仓库列表
Args:
username: 指定用户名,不指定则获取当前授权用户
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
params = {
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_user_starred(self, username: Optional[str] = None,
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取用户星标的仓库
Args:
username: 指定用户名,不指定则获取当前授权用户
per_page: 每页结果数量
page: 页码
Returns:
星标仓库列表
"""
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 仓库相关方法 =====
async def get_repo(self, owner: str, repo: str) -> Dict:
"""获取仓库信息
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
仓库信息
"""
endpoint = f"/repos/{owner}/{repo}"
return await self._request("GET", endpoint)
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的分支列表
Args:
owner: 仓库所有者
repo: 仓库名
per_page: 每页结果数量
page: 页码
Returns:
分支列表
"""
endpoint = f"/repos/{owner}/{repo}/branches"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的提交历史
Args:
owner: 仓库所有者
repo: 仓库名
sha: 特定提交SHA或分支名
path: 文件路径筛选
per_page: 每页结果数量
page: 页码
Returns:
提交列表
"""
endpoint = f"/repos/{owner}/{repo}/commits"
params = {
"per_page": per_page,
"page": page
}
if sha:
params["sha"] = sha
if path:
params["path"] = path
return await self._request("GET", endpoint, params=params)
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
"""获取特定提交的详情
Args:
owner: 仓库所有者
repo: 仓库名
commit_sha: 提交SHA
Returns:
提交详情
"""
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
return await self._request("GET", endpoint)
# ===== 内容相关方法 =====
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
"""获取文件内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 文件路径
ref: 分支名、标签名或提交SHA
Returns:
文件内容信息
"""
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
response = await self._request("GET", endpoint, params=params)
if response and isinstance(response, dict) and "content" in response:
try:
# 解码Base64编码的文件内容
content = base64.b64decode(response["content"].encode()).decode()
response["decoded_content"] = content
except Exception as e:
print(f"解码文件内容时出错: {str(e)}")
return response
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
"""获取目录内容
Args:
owner: 仓库所有者
repo: 仓库名
path: 目录路径
ref: 分支名、标签名或提交SHA
Returns:
目录内容列表
"""
# 注意此方法与get_file_content使用相同的端点,但对于目录会返回列表
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
params = {}
if ref:
params["ref"] = ref
return await self._request("GET", endpoint, params=params)
# ===== Issues相关方法 =====
async def get_issues(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Issues列表
Args:
owner: 仓库所有者
repo: 仓库名
state: Issue状态 (open, closed, all)
sort: 排序方式 (created, updated, comments)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Issues列表
"""
endpoint = f"/repos/{owner}/{repo}/issues"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
"""获取特定Issue的详情
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
Issue详情
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
return await self._request("GET", endpoint)
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
"""获取Issue的评论
Args:
owner: 仓库所有者
repo: 仓库名
issue_number: Issue编号
Returns:
评论列表
"""
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
return await self._request("GET", endpoint)
# ===== Pull Requests相关方法 =====
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取仓库的Pull Request列表
Args:
owner: 仓库所有者
repo: 仓库名
state: PR状态 (open, closed, all)
sort: 排序方式 (created, updated, popularity, long-running)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
Pull Request列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls"
params = {
"state": state,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
"""获取特定Pull Request的详情
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
Pull Request详情
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
return await self._request("GET", endpoint)
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
"""获取Pull Request中修改的文件
Args:
owner: 仓库所有者
repo: 仓库名
pr_number: Pull Request编号
Returns:
修改文件列表
"""
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
return await self._request("GET", endpoint)
# ===== 搜索相关方法 =====
async def search_repositories(self, query: str, sort: str = "stars",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索仓库
Args:
query: 搜索关键词
sort: 排序方式 (stars, forks, updated)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/repositories"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_code(self, query: str, sort: str = "indexed",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索代码
Args:
query: 搜索关键词
sort: 排序方式 (indexed)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/code"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_issues(self, query: str, sort: str = "created",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索Issues和Pull Requests
Args:
query: 搜索关键词
sort: 排序方式 (created, updated, comments)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/issues"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def search_users(self, query: str, sort: str = "followers",
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
"""搜索用户
Args:
query: 搜索关键词
sort: 排序方式 (followers, repositories, joined)
order: 排序顺序 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
搜索结果
"""
endpoint = "/search/users"
params = {
"q": query,
"sort": sort,
"order": order,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 组织相关方法 =====
async def get_organization(self, org: str) -> Dict:
"""获取组织信息
Args:
org: 组织名称
Returns:
组织信息
"""
endpoint = f"/orgs/{org}"
return await self._request("GET", endpoint)
async def get_organization_repos(self, org: str, type: str = "all",
sort: str = "created", direction: str = "desc",
per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织的仓库列表
Args:
org: 组织名称
type: 仓库类型 (all, public, private, forks, sources, member, internal)
sort: 排序方式 (created, updated, pushed, full_name)
direction: 排序方向 (asc, desc)
per_page: 每页结果数量
page: 页码
Returns:
仓库列表
"""
endpoint = f"/orgs/{org}/repos"
params = {
"type": type,
"sort": sort,
"direction": direction,
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
"""获取组织成员列表
Args:
org: 组织名称
per_page: 每页结果数量
page: 页码
Returns:
成员列表
"""
endpoint = f"/orgs/{org}/members"
params = {
"per_page": per_page,
"page": page
}
return await self._request("GET", endpoint, params=params)
# ===== 更复杂的操作 =====
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
"""获取仓库使用的编程语言及其比例
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
语言使用情况
"""
endpoint = f"/repos/{owner}/{repo}/languages"
return await self._request("GET", endpoint)
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的贡献者统计
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
贡献者统计信息
"""
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
return await self._request("GET", endpoint)
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
"""获取仓库的提交活动
Args:
owner: 仓库所有者
repo: 仓库名
Returns:
提交活动统计
"""
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
return await self._request("GET", endpoint)
async def example_usage():
"""GitHubSource使用示例"""
# 创建客户端实例可选传入API令牌
# github = GitHubSource(api_key="your_github_token")
github = GitHubSource()
try:
# 示例1搜索热门Python仓库
print("\n=== 示例1搜索热门Python仓库 ===")
repos = await github.search_repositories(
query="language:python stars:>1000",
sort="stars",
order="desc",
per_page=5
)
if repos and "items" in repos:
for i, repo in enumerate(repos["items"], 1):
print(f"\n--- 仓库 {i} ---")
print(f"名称: {repo['full_name']}")
print(f"描述: {repo['description']}")
print(f"星标数: {repo['stargazers_count']}")
print(f"Fork数: {repo['forks_count']}")
print(f"最近更新: {repo['updated_at']}")
print(f"URL: {repo['html_url']}")
# 示例2获取特定仓库的详情
print("\n=== 示例2获取特定仓库的详情 ===")
repo_details = await github.get_repo("microsoft", "vscode")
if repo_details:
print(f"名称: {repo_details['full_name']}")
print(f"描述: {repo_details['description']}")
print(f"星标数: {repo_details['stargazers_count']}")
print(f"Fork数: {repo_details['forks_count']}")
print(f"默认分支: {repo_details['default_branch']}")
print(f"开源许可: {repo_details.get('license', {}).get('name', '')}")
print(f"语言: {repo_details['language']}")
print(f"Open Issues数: {repo_details['open_issues_count']}")
# 示例3获取仓库的提交历史
print("\n=== 示例3获取仓库的最近提交 ===")
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
if commits:
for i, commit in enumerate(commits, 1):
print(f"\n--- 提交 {i} ---")
print(f"SHA: {commit['sha'][:7]}")
print(f"作者: {commit['commit']['author']['name']}")
print(f"日期: {commit['commit']['author']['date']}")
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
# 示例4搜索代码
print("\n=== 示例4搜索代码 ===")
code_results = await github.search_code(
query="filename:README.md language:markdown pytorch in:file",
per_page=3
)
if code_results and "items" in code_results:
print(f"共找到: {code_results['total_count']} 个结果")
for i, item in enumerate(code_results["items"], 1):
print(f"\n--- 代码 {i} ---")
print(f"仓库: {item['repository']['full_name']}")
print(f"文件: {item['path']}")
print(f"URL: {item['html_url']}")
# 示例5获取文件内容
print("\n=== 示例5获取文件内容 ===")
file_content = await github.get_file_content("python", "cpython", "README.rst")
if file_content and "decoded_content" in file_content:
content = file_content["decoded_content"]
print(f"文件名: {file_content['name']}")
print(f"大小: {file_content['size']} 字节")
print(f"内容预览: {content[:200]}...")
# 示例6获取仓库使用的编程语言
print("\n=== 示例6获取仓库使用的编程语言 ===")
languages = await github.get_repository_languages("facebook", "react")
if languages:
print(f"React仓库使用的编程语言:")
for lang, bytes_of_code in languages.items():
print(f"- {lang}: {bytes_of_code} 字节")
# 示例7获取组织信息
print("\n=== 示例7获取组织信息 ===")
org_info = await github.get_organization("google")
if org_info:
print(f"名称: {org_info['name']}")
print(f"描述: {org_info.get('description', '')}")
print(f"位置: {org_info.get('location', '未指定')}")
print(f"公共仓库数: {org_info['public_repos']}")
print(f"成员数: {org_info.get('public_members', 0)}")
print(f"URL: {org_info['html_url']}")
# 示例8获取用户信息
print("\n=== 示例8获取用户信息 ===")
user_info = await github.get_user("torvalds")
if user_info:
print(f"名称: {user_info['name']}")
print(f"公司: {user_info.get('company', '')}")
print(f"博客: {user_info.get('blog', '')}")
print(f"位置: {user_info.get('location', '未指定')}")
print(f"公共仓库数: {user_info['public_repos']}")
print(f"关注者数: {user_info['followers']}")
print(f"URL: {user_info['html_url']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,142 @@
import json
import os
from typing import Dict, Optional
class JournalMetrics:
"""期刊指标管理类"""
def __init__(self):
self.journal_data: Dict = {} # 期刊名称到指标的映射
self.issn_map: Dict = {} # ISSN到指标的映射
self.name_map: Dict = {} # 标准化名称到指标的映射
self._load_journal_data()
def _normalize_journal_name(self, name: str) -> str:
"""标准化期刊名称
Args:
name: 原始期刊名称
Returns:
标准化后的期刊名称
"""
if not name:
return ""
# 转换为小写
name = name.lower()
# 移除常见的前缀和后缀
prefixes = ['the ', 'proceedings of ', 'journal of ']
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
for prefix in prefixes:
if name.startswith(prefix):
name = name[len(prefix):]
for suffix in suffixes:
if name.endswith(suffix):
name = name[:-len(suffix)]
# 移除特殊字符,保留字母、数字和空格
name = ''.join(c for c in name if c.isalnum() or c.isspace())
# 移除多余的空格
name = ' '.join(name.split())
return name
def _convert_if_value(self, if_str: str) -> Optional[float]:
"""转换IF值为float,处理特殊情况"""
try:
if if_str.startswith('<'):
# 对于<0.1这样的值,返回0.1
return float(if_str.strip('<'))
return float(if_str)
except (ValueError, AttributeError):
return None
def _load_journal_data(self):
"""加载期刊数据"""
try:
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# 建立期刊名称到指标的映射
for journal in data:
# 准备指标数据
metrics = {
'if_factor': self._convert_if_value(journal.get('IF')),
'jcr_division': journal.get('Q'),
'cas_division': journal.get('B')
}
# 存储期刊名称映射(使用标准化名称)
if journal.get('journal'):
normalized_name = self._normalize_journal_name(journal['journal'])
self.journal_data[normalized_name] = metrics
self.name_map[normalized_name] = metrics
# 存储期刊缩写映射
if journal.get('jabb'):
normalized_abbr = self._normalize_journal_name(journal['jabb'])
self.journal_data[normalized_abbr] = metrics
self.name_map[normalized_abbr] = metrics
# 存储ISSN映射
if journal.get('issn'):
self.issn_map[journal['issn']] = metrics
if journal.get('eissn'):
self.issn_map[journal['eissn']] = metrics
except Exception as e:
print(f"加载期刊数据时出错: {str(e)}")
self.journal_data = {}
self.issn_map = {}
self.name_map = {}
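    # 补充说明(根据上面的解析逻辑推断,字段名以实际文件为准)cas_if.json 中每条期刊记录大致形如:
    # {
    #     "journal": "Example Journal of Science",   # 期刊全名
    #     "jabb": "Example J Sci",                    # 期刊缩写
    #     "issn": "1234-5678", "eissn": "8765-4321",
    #     "IF": "3.2",                                # 影响因子(也可能是 "<0.1" 这类字符串)
    #     "Q": "Q1",                                  # JCR分区
    #     "B": "2区"                                  # 中科院分区
    # }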
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
"""获取期刊指标
Args:
venue_name: 期刊名称
venue_info: 期刊详细信息
Returns:
包含期刊指标的字典
"""
try:
metrics = {}
# 1. 首先尝试通过ISSN匹配
if venue_info and 'issn' in venue_info:
issn_value = venue_info['issn']
# 处理ISSN可能是列表的情况
if isinstance(issn_value, list):
# 尝试每个ISSN
for issn in issn_value:
metrics = self.issn_map.get(issn, {})
if metrics: # 如果找到匹配的指标,就停止搜索
break
else: # ISSN是字符串的情况
metrics = self.issn_map.get(issn_value, {})
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
if not metrics and venue_name:
# 标准化期刊名称
normalized_name = self._normalize_journal_name(venue_name)
metrics = self.name_map.get(normalized_name, {})
# 如果完全匹配失败,尝试部分匹配
# if not metrics:
# for db_name, db_metrics in self.name_map.items():
# if normalized_name in db_name:
# metrics = db_metrics
# break
return metrics
except Exception as e:
print(f"获取期刊指标时出错: {str(e)}")
return {}
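# 补充示例(仅作参考):演示 JournalMetrics 的典型调用方式。
# 前提假设:同目录下存在 cas_if.json;示例中的期刊名与 ISSN 仅作演示,未命中时返回空字典。
if __name__ == "__main__":
    metrics_db = JournalMetrics()
    metrics = metrics_db.get_journal_metrics(
        venue_name="Nature Communications",
        venue_info={"issn": ["2041-1723"]},
    )
    print(metrics)  # 形如 {'if_factor': ..., 'jcr_division': ..., 'cas_division': ...} 或 {}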

查看文件

@@ -0,0 +1,163 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from .base_source import DataSource, PaperMetadata
import os
from urllib.parse import quote
class OpenAlexSource(DataSource):
"""OpenAlex API实现"""
def _initialize(self) -> None:
self.base_url = "https://api.openalex.org"
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"title.search:{query}",
"per-page": limit
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
try:
response.raise_for_status()
data = await response.json()
results = data.get("results", [])
return [self._parse_work(work) for work in results]
except Exception as e:
print(f"搜索出错: {str(e)}")
return []
def _parse_work(self, work: Dict) -> PaperMetadata:
"""解析OpenAlex返回的数据"""
# 获取作者信息
raw_author_names = [
authorship.get("raw_author_name", "")
for authorship in work.get("authorships", [])
if authorship
]
# 处理作者名字格式
authors = [
self._reformat_name(author)
for author in raw_author_names
]
# 获取机构信息
institutions = [
inst.get("display_name", "")
for authorship in work.get("authorships", [])
for inst in authorship.get("institutions", [])
if inst
]
# 获取主要发表位置信息
primary_location = work.get("primary_location") or {}
source = primary_location.get("source") or {}
venue = source.get("display_name")
# 获取发表日期
year = work.get("publication_year")
# 获取摘要OpenAlex 通常不直接返回 abstract 字段,而是提供 abstract_inverted_index这里据此重建摘要文本
abstract = work.get("abstract") or ""
abstract_index = work.get("abstract_inverted_index")
if not abstract and isinstance(abstract_index, dict):
    positions = [(pos, word) for word, pos_list in abstract_index.items() for pos in pos_list]
    abstract = " ".join(word for _, word in sorted(positions))
return PaperMetadata(
title=work.get("title", ""),
authors=authors,
institutions=institutions,
abstract=abstract,
year=year,
doi=work.get("doi"),
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
citations=work.get("cited_by_count"),
venue=venue
)
def _reformat_name(self, name: str) -> str:
"""重新格式化作者名字"""
if "," not in name:
return name
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
return f"{given_names} {family}"
async def get_paper_details(self, doi: str) -> PaperMetadata:
"""获取指定DOI的论文详情"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
params=params
) as response:
data = await response.json()
return self._parse_work(data)
async def get_references(self, doi: str) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def get_citations(self, doi: str) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
params = {"mailto": self.mailto} if self.mailto else {}
params.update({
"filter": f"cites:doi:{doi}",
"per-page": 100
})
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/works",
params=params
) as response:
data = await response.json()
return [self._parse_work(work) for work in data.get("results", [])]
async def example_usage():
"""OpenAlexSource使用示例"""
# 初始化OpenAlexSource
openalex = OpenAlexSource()
try:
print("正在搜索论文...")
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
papers = await openalex.search(query="artificial intelligence", limit=5)
if not papers:
print("未获取到任何论文信息")
return
print(f"找到 {len(papers)} 篇论文")
# 打印搜索结果
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
if paper.institutions:
print(f"机构: {', '.join(paper.institutions)}")
print(f"发表年份: {paper.year if paper.year else '未知'}")
print(f"DOI: {paper.doi if paper.doi else '未知'}")
print(f"URL: {paper.url if paper.url else '未知'}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
# 如果直接运行此文件,执行示例代码
if __name__ == "__main__":
import asyncio
# 运行示例
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,458 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import asyncio
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import xml.etree.ElementTree as ET
from urllib.parse import quote
import json
from tqdm import tqdm
import random
class PubMedSource(DataSource):
"""PubMed API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: PubMed API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
self.headers = {
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
"Accept": "application/json"
}
async def _make_request(self, url: str) -> Optional[str]:
"""发送HTTP请求
Args:
url: 请求URL
Returns:
响应内容
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url) as response:
if response.status == 200:
return await response.text()
else:
print(f"请求失败: {response.status}")
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = "relevance",
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# 添加年份过滤
if start_year:
query = f"{query} AND {start_year}:3000[dp]"
# 构建搜索URL
search_url = (
f"{self.base_url}/esearch.fcgi?"
f"db=pubmed&term={quote(query)}&retmax={limit}"
f"&usehistory=y&api_key={self.api_key}"
)
if sort_by == "date":
search_url += "&sort=date"
# 获取搜索结果
response = await self._make_request(search_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
id_list = root.findall(".//Id")
pmids = [id_elem.text for id_elem in id_list]
if not pmids:
return []
# 批量获取论文详情
papers = []
batch_size = 50
for i in range(0, len(pmids), batch_size):
batch = pmids[i:i + batch_size]
batch_papers = await self._fetch_papers_batch(batch)
papers.extend(batch_papers)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
"""批量获取论文详情
Args:
pmids: PubMed ID列表
Returns:
论文详情列表
"""
try:
# 构建批量获取URL
fetch_url = (
f"{self.base_url}/efetch.fcgi?"
f"db=pubmed&id={','.join(pmids)}"
f"&retmode=xml&api_key={self.api_key}"
)
response = await self._make_request(fetch_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
articles = root.findall(".//PubmedArticle")
return [paper for paper in (self._parse_article(article) for article in articles) if paper is not None]  # 过滤解析失败的条目
except Exception as e:
print(f"获取论文批次时发生错误: {str(e)}")
return []
def _parse_article(self, article: ET.Element) -> PaperMetadata:
"""解析PubMed文章XML
Args:
article: XML元素
Returns:
解析后的论文数据
"""
try:
# 提取基本信息
pmid = article.find(".//PMID").text
article_meta = article.find(".//Article")
# 获取标题
title = article_meta.find(".//ArticleTitle")
title = title.text if title is not None else ""
# 获取作者列表
authors = []
author_list = article_meta.findall(".//Author")
for author in author_list:
last_name = author.find("LastName")
fore_name = author.find("ForeName")
if last_name is not None and fore_name is not None:
authors.append(f"{fore_name.text} {last_name.text}")
elif last_name is not None:
authors.append(last_name.text)
# 获取摘要
abstract = article_meta.find(".//Abstract/AbstractText")
abstract = abstract.text if abstract is not None else ""
# 获取发表年份
pub_date = article_meta.find(".//PubDate/Year")
year = int(pub_date.text) if pub_date is not None else None
# 获取DOI
doi = article.find(".//ELocationID[@EIdType='doi']")
doi = doi.text if doi is not None else None
# 获取期刊信息
journal = article_meta.find(".//Journal")
if journal is not None:
journal_title = journal.find(".//Title")
venue = journal_title.text if journal_title is not None else None
# 获取期刊详细信息
venue_info = {
'issn': journal.findtext(".//ISSN"),
'volume': journal.findtext(".//Volume"),
'issue': journal.findtext(".//Issue"),
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
}
else:
venue = None
venue_info = {}
# 获取机构信息
institutions = []
affiliations = article_meta.findall(".//Affiliation")
for affiliation in affiliations:
if affiliation is not None and affiliation.text:
institutions.append(affiliation.text)
return PaperMetadata(
title=title,
authors=authors,
abstract=abstract,
year=year,
doi=doi,
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
citations=None, # PubMed API不直接提供引用数据
venue=venue,
institutions=institutions,
venue_type="journal",
venue_name=venue,
venue_info=venue_info,
source='pubmed' # 添加来源标记
)
except Exception as e:
print(f"解析文章时发生错误: {str(e)}")
return None
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
"""获取指定PMID的论文详情"""
papers = await self._fetch_papers_batch([pmid])
return papers[0] if papers else None
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
"""获取相关论文
使用PubMed的相关文章功能
Args:
pmid: PubMed ID
limit: 返回结果数量限制
Returns:
相关论文列表
"""
try:
# 构建相关文章URL
link_url = (
f"{self.base_url}/elink.fcgi?"
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
)
response = await self._make_request(link_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
related_ids = root.findall(".//Link/Id")
pmids = [id_elem.text for id_elem in related_ids][:limit]
if not pmids:
return []
# 获取相关论文详情
return await self._fetch_papers_batch(pmids)
except Exception as e:
print(f"获取相关论文时发生错误: {str(e)}")
return []
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"{author}[Author]"
if start_year:
query += f" AND {start_year}:3000[dp]"
return await self.search(query, limit=limit)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"{journal}[Journal]"
if start_year:
query += f" AND {start_year}:3000[dp]"
return await self.search(query, limit=limit)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文
Args:
days: 最近几天的论文
limit: 返回结果数量限制
Returns:
最新论文列表
"""
query = f"last {days} days[dp]"
return await self.search(query, limit=limit, sort_by="date")
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献
注意PubMed API本身不提供引用数据,此方法将返回空列表
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
Args:
paper_id: PubMed ID
Returns:
空列表,因为PubMed不提供引用数据
"""
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献
从PubMed文章的参考文献列表获取引用的文献
Args:
paper_id: PubMed ID
Returns:
引用的文献列表
"""
try:
# 构建获取参考文献的URL
refs_url = (
f"{self.base_url}/elink.fcgi?"
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
f"&api_key={self.api_key}"
)
response = await self._make_request(refs_url)
if not response:
return []
# 解析XML响应
root = ET.fromstring(response)
ref_ids = root.findall(".//Link/Id")
pmids = [id_elem.text for id_elem in ref_ids]
if not pmids:
return []
# 获取参考文献详情
return await self._fetch_papers_batch(pmids)
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
async def example_usage():
"""PubMedSource使用示例"""
pubmed = PubMedSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索COVID-19相关论文 ===")
papers = await pubmed.search("COVID-19", limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要: {paper.abstract[:200]}...")
# 示例2获取论文详情
if papers:
print("\n=== 示例2获取论文详情 ===")
paper_id = papers[0].url.split("/")[-2]
paper = await pubmed.get_paper_details(paper_id)
if paper:
print(f"标题: {paper.title}")
print(f"期刊: {paper.venue}")
print(f"机构: {', '.join(paper.institutions)}")
# 示例3获取相关论文
if papers:
print("\n=== 示例3获取相关论文 ===")
related = await pubmed.get_related_papers(paper_id, limit=3)
for i, paper in enumerate(related, 1):
print(f"\n--- 相关论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
# 示例4按作者搜索
print("\n=== 示例4按作者搜索 ===")
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
# 示例5按期刊搜索
print("\n=== 示例5按期刊搜索 ===")
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
for i, paper in enumerate(journal_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
# 示例6获取最新论文
print("\n=== 示例6获取最新论文 ===")
latest = await pubmed.get_latest_papers(days=7, limit=3)
for i, paper in enumerate(latest, 1):
print(f"\n--- 最新论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表日期: {paper.venue_info.get('pub_date')}")
# 示例7获取论文的参考文献
if papers:
print("\n=== 示例7获取论文的参考文献 ===")
paper_id = papers[0].url.split("/")[-2]
references = await pubmed.get_references(paper_id)
for i, paper in enumerate(references[:3], 1):
print(f"\n--- 参考文献 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 示例8尝试获取引用信息将返回空列表
if papers:
print("\n=== 示例8获取论文的引用信息 ===")
paper_id = papers[0].url.split("/")[-2]
citations = await pubmed.get_citations(paper_id)
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,326 @@
from pathlib import Path
import requests
from bs4 import BeautifulSoup
import time
from loguru import logger
import PyPDF2
import io
class SciHub:
# 更新的镜像列表,包含更多可用的镜像
MIRRORS = [
'https://sci-hub.se/',
'https://sci-hub.st/',
'https://sci-hub.ru/',
'https://sci-hub.wf/',
'https://sci-hub.ee/',
'https://sci-hub.ren/',
'https://sci-hub.tf/',
'https://sci-hub.si/',
'https://sci-hub.do/',
'https://sci-hub.hkvisa.net/',
'https://sci-hub.mksa.top/',
'https://sci-hub.shop/',
'https://sci-hub.yncjkj.com/',
'https://sci-hub.41610.org/',
'https://sci-hub.automic.us/',
'https://sci-hub.et-fine.com/',
'https://sci-hub.pooh.mu/',
'https://sci-hub.bban.top/',
'https://sci-hub.usualwant.com/',
'https://sci-hub.unblockit.kim/'
]
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
self.timeout = timeout
self.path = path
self.doi = str(doi)
self.use_proxy = use_proxy
self.headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
}
self.payload = {
'sci-hub-plugin-check': '',
'request': self.doi
}
self.url = url if url else self.MIRRORS[0]
self.proxies = {
"http": "socks5h://localhost:10880",
"https": "socks5h://localhost:10880",
} if use_proxy else None
def _test_proxy_connection(self):
"""测试代理连接是否可用"""
if not self.use_proxy:
return True
try:
# 测试代理连接
test_response = requests.get(
'https://httpbin.org/ip',
proxies=self.proxies,
timeout=10
)
if test_response.status_code == 200:
logger.info("代理连接测试成功")
return True
except Exception as e:
logger.warning(f"代理连接测试失败: {str(e)}")
return False
return False
def _check_pdf_validity(self, content):
"""检查PDF文件是否有效"""
try:
# 使用PyPDF2检查PDF是否可以正常打开和读取
pdf = PyPDF2.PdfReader(io.BytesIO(content))
if len(pdf.pages) > 0:
return True
return False
except Exception as e:
logger.error(f"PDF文件无效: {str(e)}")
return False
def _send_request(self):
"""发送请求到Sci-Hub镜像站点"""
# 首先测试代理连接
if self.use_proxy and not self._test_proxy_connection():
logger.warning("代理连接不可用,切换到直连模式")
self.use_proxy = False
self.proxies = None
last_exception = None
working_mirrors = []
# 先测试哪些镜像可用
logger.info("正在测试镜像站点可用性...")
for mirror in self.MIRRORS:
try:
test_response = requests.get(
mirror,
headers=self.headers,
proxies=self.proxies,
timeout=10
)
if test_response.status_code == 200:
working_mirrors.append(mirror)
logger.info(f"镜像 {mirror} 可用")
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
break
except Exception as e:
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
continue
if not working_mirrors:
raise Exception("没有找到可用的镜像站点")
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
# 使用可用的镜像进行下载
for mirror in working_mirrors:
try:
res = requests.post(
mirror,
headers=self.headers,
data=self.payload,
proxies=self.proxies,
timeout=self.timeout
)
if res.ok:
logger.info(f"成功使用镜像站点: {mirror}")
self.url = mirror # 更新当前使用的镜像
time.sleep(1) # 降低等待时间以提高效率
return res
except Exception as e:
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
last_exception = e
continue
if last_exception:
raise last_exception
raise Exception("所有可用镜像站点均无法完成下载")
def _extract_url(self, response):
"""从响应中提取PDF下载链接"""
soup = BeautifulSoup(response.content, 'html.parser')
try:
# 尝试多种方式提取PDF链接
pdf_element = soup.find(id='pdf')
if pdf_element:
content_url = pdf_element.get('src')
else:
# 尝试其他可能的选择器
pdf_element = soup.find('iframe')
if pdf_element:
content_url = pdf_element.get('src')
else:
# 查找直接的PDF链接
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
if pdf_links:
content_url = pdf_links[0].get('href')
else:
raise AttributeError("未找到PDF链接")
if content_url:
content_url = content_url.replace('#navpanes=0&view=FitH', '').replace('//', '/')
if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
raise AttributeError("找到的链接不是PDF文件")
except AttributeError:
logger.error(f"未找到论文 {self.doi}")
return None
current_mirror = self.url.rstrip('/')
if content_url.startswith('/'):
return current_mirror + content_url
elif content_url.startswith('http'):
return content_url
else:
return 'https:/' + content_url
def _download_pdf(self, pdf_url):
"""下载PDF文件并验证其完整性"""
try:
# 尝试不同的下载方式
download_methods = [
# 方法1直接下载
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
# 方法2添加 Referer 头
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
headers={**self.headers, 'Referer': self.url}),
# 方法3使用原始域名作为 Referer
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
]
for i, download_method in enumerate(download_methods):
try:
logger.info(f"尝试下载方式 {i+1}/3...")
response = download_method()
if response.status_code == 200:
content = response.content
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
logger.info(f"PDF下载成功,文件大小: {len(content)} bytes")
return content
else:
logger.warning("下载的文件可能不是有效的PDF")
elif response.status_code == 403:
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
continue
else:
logger.warning(f"下载失败,状态码: {response.status_code}")
continue
except Exception as e:
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
continue
# 如果所有方法都失败,尝试构造替代URL
try:
logger.info("尝试使用替代镜像下载...")
# 从原始URL提取关键信息
if '/downloads/' in pdf_url:
file_part = pdf_url.split('/downloads/')[-1]
alternative_mirrors = [
f"https://sci-hub.se/downloads/{file_part}",
f"https://sci-hub.st/downloads/{file_part}",
f"https://sci-hub.ru/downloads/{file_part}",
f"https://sci-hub.wf/downloads/{file_part}",
f"https://sci-hub.ee/downloads/{file_part}",
f"https://sci-hub.ren/downloads/{file_part}",
f"https://sci-hub.tf/downloads/{file_part}"
]
for alt_url in alternative_mirrors:
try:
response = requests.get(
alt_url,
proxies=self.proxies,
timeout=self.timeout,
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
)
if response.status_code == 200:
content = response.content
if len(content) > 1000 and self._check_pdf_validity(content):
logger.info(f"使用替代镜像成功下载: {alt_url}")
return content
except Exception as e:
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
continue
except Exception as e:
logger.error(f"所有下载方式都失败: {str(e)}")
return None
except Exception as e:
logger.error(f"下载PDF文件失败: {str(e)}")
return None
def fetch(self):
"""获取论文PDF,包含重试和验证机制"""
for attempt in range(3): # 最多尝试3次
try:
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
# 获取PDF下载链接
response = self._send_request()
pdf_url = self._extract_url(response)
if pdf_url is None:
logger.warning(f"{attempt + 1} 次尝试未找到PDF下载链接")
continue
logger.info(f"找到PDF下载链接: {pdf_url}")
# 下载并验证PDF
pdf_content = self._download_pdf(pdf_url)
if pdf_content is None:
logger.warning(f"{attempt + 1} 次尝试PDF下载失败")
continue
# 保存PDF文件
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
pdf_path = self.path.joinpath(pdf_name)
pdf_path.write_bytes(pdf_content)
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
return str(pdf_path)
except Exception as e:
logger.error(f"{attempt + 1} 次尝试失败: {str(e)}")
if attempt < 2: # 不是最后一次尝试
wait_time = (attempt + 1) * 3 # 递增等待时间
logger.info(f"等待 {wait_time} 秒后重试...")
time.sleep(wait_time)
continue
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
# Usage Example
if __name__ == '__main__':
# 创建一个用于保存PDF的目录
save_path = Path('./downloaded_papers')
save_path.mkdir(exist_ok=True)
# DOI示例
sample_doi = '10.3897/rio.7.e67379' # 用于测试下载流程的示例DOI
try:
# 初始化SciHub下载器,先尝试使用代理
logger.info("尝试使用代理模式...")
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
# 开始下载
result = downloader.fetch()
print(f"论文已保存到: {result}")
except Exception as e:
print(f"使用代理模式失败: {str(e)}")
try:
# 如果代理模式失败,尝试直连模式
logger.info("尝试直连模式...")
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
result = downloader.fetch()
print(f"论文已保存到: {result}")
except Exception as e2:
print(f"直连模式也失败: {str(e2)}")
print("建议检查网络连接或尝试其他DOI")

查看文件

@@ -0,0 +1,400 @@
from typing import List, Optional, Dict, Union
from datetime import datetime
import aiohttp
import random
from .base_source import DataSource, PaperMetadata
from tqdm import tqdm
class ScopusSource(DataSource):
"""Scopus API实现"""
# 定义API密钥列表
API_KEYS = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Scopus API密钥,如果不提供则从预定义列表中随机选择
"""
self.api_key = api_key or random.choice(self.API_KEYS)
self._initialize()
def _initialize(self) -> None:
"""初始化基础URL和请求头"""
self.base_url = "https://api.elsevier.com/content"
self.headers = {
"X-ELS-APIKey": self.api_key,
"Accept": "application/json"
}
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
"""发送HTTP请求
Args:
url: 请求URL
params: 查询参数
Returns:
响应JSON数据
"""
try:
async with aiohttp.ClientSession(headers=self.headers) as session:
async with session.get(url, params=params) as response:
if response.status == 200:
return await response.json()
else:
print(f"请求失败: {response.status}")
return None
except Exception as e:
print(f"请求发生错误: {str(e)}")
return None
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
"""解析Scopus API返回的数据
Args:
data: Scopus API返回的论文数据
Returns:
解析后的论文元数据
"""
try:
# 提取基本信息
title = data.get("dc:title", "")
# 提取作者信息
authors = []
if "author" in data:
if isinstance(data["author"], list):
for author in data["author"]:
if "given-name" in author and "surname" in author:
authors.append(f"{author['given-name']} {author['surname']}")
elif "indexed-name" in author:
authors.append(author["indexed-name"])
elif isinstance(data["author"], dict):
if "given-name" in data["author"] and "surname" in data["author"]:
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
elif "indexed-name" in data["author"]:
authors.append(data["author"]["indexed-name"])
# 提取摘要
abstract = data.get("dc:description", "")
# 提取年份
year = None
if "prism:coverDate" in data:
try:
year = int(data["prism:coverDate"][:4])
except (ValueError, TypeError):
pass
# 提取DOI
doi = data.get("prism:doi")
# 提取引用次数
citations = data.get("citedby-count")
if citations:
try:
citations = int(citations)
except (ValueError, TypeError):
citations = None
# 提取期刊信息
venue = data.get("prism:publicationName")
# 提取机构信息
institutions = []
if "affiliation" in data:
if isinstance(data["affiliation"], list):
for aff in data["affiliation"]:
if "affilname" in aff:
institutions.append(aff["affilname"])
elif isinstance(data["affiliation"], dict):
if "affilname" in data["affiliation"]:
institutions.append(data["affiliation"]["affilname"])
# 构建venue信息
venue_info = {
"issn": data.get("prism:issn"),
"eissn": data.get("prism:eIssn"),
"volume": data.get("prism:volume"),
"issue": data.get("prism:issueIdentifier"),
"page_range": data.get("prism:pageRange"),
"article_number": data.get("article-number"),
"publication_date": data.get("prism:coverDate")
}
return PaperMetadata(
title=title,
authors=authors,
abstract=abstract,
year=year,
doi=doi,
url=(data.get("link") or [{}])[0].get("@href"),
citations=citations,
venue=venue,
institutions=institutions,
venue_type="journal",
venue_name=venue,
venue_info=venue_info
)
except Exception as e:
print(f"解析论文数据时发生错误: {str(e)}")
return None
async def search(
self,
query: str,
limit: int = 100,
sort_by: str = None,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文
Args:
query: 搜索关键词
limit: 返回结果数量限制
sort_by: 排序方式 ('relevance', 'date', 'citations')
start_year: 起始年份
Returns:
论文列表
"""
try:
# 构建查询参数
params = {
"query": query,
"count": min(limit, 100), # Scopus API单次请求限制
"start": 0
}
# 添加年份过滤
if start_year:
params["date"] = f"{start_year}-present"
# 添加排序
if sort_by:
sort_map = {
"relevance": "-score",
"date": "-coverDate",
"citations": "-citedby-count"
}
if sort_by in sort_map:
params["sort"] = sort_map[sort_by]
# 发送请求
url = f"{self.base_url}/search/scopus"
response = await self._make_request(url, params)
if not response or "search-results" not in response:
return []
# 解析结果
results = response["search-results"].get("entry", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
return []
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
"""获取论文详情
Args:
paper_id: Scopus ID或DOI
Returns:
论文详情
"""
try:
# 判断是否为DOI
if "/" in paper_id:
url = f"{self.base_url}/article/doi/{paper_id}"
else:
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
response = await self._make_request(url)
if not response or "abstracts-retrieval-response" not in response:
return None
data = response["abstracts-retrieval-response"]
return self._parse_paper_data(data)
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
"""获取引用该论文的文献
Args:
paper_id: Scopus ID
Returns:
引用论文列表
"""
try:
url = f"{self.base_url}/abstract/citations/{paper_id}"
response = await self._make_request(url)
if not response or "citing-papers" not in response:
return []
results = response["citing-papers"].get("papers", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"获取引用信息时发生错误: {str(e)}")
return []
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
"""获取该论文引用的文献
Args:
paper_id: Scopus ID
Returns:
参考文献列表
"""
try:
url = f"{self.base_url}/abstract/references/{paper_id}"
response = await self._make_request(url)
if not response or "references" not in response:
return []
results = response["references"].get("reference", [])
papers = []
for result in results:
paper = self._parse_paper_data(result)
if paper:
papers.append(paper)
return papers
except Exception as e:
print(f"获取参考文献时发生错误: {str(e)}")
return []
async def search_by_author(
self,
author: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按作者搜索论文"""
query = f"AUTHOR-NAME({author})"
if start_year:
query += f" AND PUBYEAR > {start_year}"
return await self.search(query, limit=limit)
async def search_by_journal(
self,
journal: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""按期刊搜索论文"""
query = f"SRCTITLE({journal})"
if start_year:
query += f" AND PUBYEAR > {start_year}"
return await self.search(query, limit=limit)
async def get_latest_papers(
self,
days: int = 7,
limit: int = 100
) -> List[PaperMetadata]:
"""获取最新论文"""
query = f"LOAD-DATE > NOW() - {days}d"
return await self.search(query, limit=limit, sort_by="date")
async def example_usage():
"""ScopusSource使用示例"""
scopus = ScopusSource()
try:
# 示例1基本搜索
print("\n=== 示例1搜索机器学习相关论文 ===")
papers = await scopus.search("machine learning", limit=3)
print(f"\n找到 {len(papers)} 篇相关论文:")
for i, paper in enumerate(papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
# 示例2按作者搜索
print("\n=== 示例2搜索特定作者的论文 ===")
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
for i, paper in enumerate(author_papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
# 示例3根据关键词搜索相关论文
print("\n=== 示例3搜索人工智能相关论文 ===")
keywords = "artificial intelligence AND deep learning"
papers = await scopus.search(
query=keywords,
limit=5,
sort_by="citations", # 按引用次数排序
start_year=2020 # 只搜索2020年之后的论文
)
print(f"\n找到 {len(papers)} 篇相关论文:")
for i, paper in enumerate(papers, 1):
print(f"\n论文 {i}:")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"发表期刊: {paper.venue}")
print(f"引用次数: {paper.citations}")
print(f"DOI: {paper.doi}")
if paper.abstract:
print(f"摘要:\n{paper.abstract}")
print("-" * 80)
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,480 @@
from typing import List, Optional
from datetime import datetime
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
import random
class SemanticScholarSource(DataSource):
"""Semantic Scholar API实现,使用官方Python包"""
def __init__(self, api_key: str = None):
"""初始化
Args:
api_key: Semantic Scholar API密钥(可选)
"""
self.api_key = api_key
self._initialize() # 调用初始化方法
def _initialize(self) -> None:
"""初始化API客户端"""
if not self.api_key:
# 默认API密钥列表
default_api_keys = [
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
]
self.api_key = random.choice(default_api_keys)
self.client = None # 延迟初始化
self.fields = [
"title",
"authors",
"abstract",
"year",
"externalIds",
"citationCount",
"venue",
"openAccessPdf",
"publicationVenue"
]
async def _ensure_client(self):
"""确保客户端已初始化"""
if self.client is None:
from semanticscholar import AsyncSemanticScholar
self.client = AsyncSemanticScholar(api_key=self.api_key)
async def search(
self,
query: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""搜索论文"""
try:
await self._ensure_client()
# 如果指定了起始年份,添加到查询中
if start_year:
query = f"{query} year>={start_year}"
# 直接使用 search_paper 的结果
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
self.client.auth_header
)
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"搜索论文时发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
return []
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
"""获取指定DOI的论文详情"""
try:
await self._ensure_client()
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
return self._parse_paper_data(paper)
except Exception as e:
print(f"获取论文详情时发生错误: {str(e)}")
return None
async def get_citations(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取引用指定DOI论文的文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
params,
self.client.auth_header
)
citations = response.get('data', [])
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
except Exception as e:
print(f"获取引用列表时发生错误: {str(e)}")
return []
async def get_references(
self,
doi: str,
limit: int = 100,
start_year: int = None
) -> List[PaperMetadata]:
"""获取指定DOI论文的参考文献列表"""
try:
await self._ensure_client()
# 构建查询参数
fields_param = f"fields={','.join(self.fields)}"
limit_param = f"limit={limit}"
year_param = f"year>={start_year}" if start_year else ""
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
params,
self.client.auth_header
)
references = response.get('data', [])
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
except Exception as e:
print(f"获取参考文献列表时发生错误: {str(e)}")
return []
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
"""获取论文推荐
根据一篇论文获取相关的推荐论文
Args:
doi: 论文的DOI
limit: 返回结果数量限制,最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
papers = await self.client.get_recommended_papers(
f"DOI:{doi}",
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐时发生错误: {str(e)}")
return []
async def get_recommended_papers_from_lists(
self,
positive_dois: List[str],
negative_dois: List[str] = None,
limit: int = 100
) -> List[PaperMetadata]:
"""基于正负例论文列表获取推荐
Args:
positive_dois: 正例论文DOI列表(希望获取与之类似的论文)
negative_dois: 负例论文DOI列表(不希望获取与之类似的论文)
limit: 返回结果数量限制,最大500
Returns:
推荐论文列表
"""
try:
await self._ensure_client()
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
papers = await self.client.get_recommended_papers_from_lists(
positive_paper_ids=positive_ids,
negative_paper_ids=negative_ids,
fields=self.fields,
limit=min(limit, 500)
)
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取论文推荐列表时发生错误: {str(e)}")
return []
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
"""搜索作者"""
try:
await self._ensure_client()
# 直接使用 API 请求而不是 search_author 方法
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
self.client.auth_header
)
authors = response.get('data', [])
return [
{
'author_id': author.get('authorId'),
'name': author.get('name'),
'paper_count': author.get('paperCount'),
'citation_count': author.get('citationCount'),
}
for author in authors
]
except Exception as e:
print(f"搜索作者时发生错误: {str(e)}")
return []
async def get_author_details(self, author_id: str) -> Optional[dict]:
"""获取作者详细信息"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
"fields=name,paperCount,citationCount,hIndex",
self.client.auth_header
)
return {
'author_id': response.get('authorId'),
'name': response.get('name'),
'paper_count': response.get('paperCount'),
'citation_count': response.get('citationCount'),
'h_index': response.get('hIndex'),
}
except Exception as e:
print(f"获取作者详情时发生错误: {str(e)}")
return None
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
"""获取作者的论文列表"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
self.client.auth_header
)
papers = response.get('data', [])
return [self._parse_paper_data(paper) for paper in papers]
except Exception as e:
print(f"获取作者论文列表时发生错误: {str(e)}")
return []
async def get_paper_autocomplete(self, query: str) -> List[dict]:
"""论文标题自动补全"""
try:
await self._ensure_client()
# 直接使用 API 请求
response = await self.client._requester.get_data_async(
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
f"query={query}",
self.client.auth_header
)
suggestions = response.get('matches', [])
return [
{
'title': suggestion.get('title'),
'paper_id': suggestion.get('paperId'),
'year': suggestion.get('year'),
'venue': suggestion.get('venue'),
}
for suggestion in suggestions
]
except Exception as e:
print(f"获取标题自动补全时发生错误: {str(e)}")
return []
def _parse_paper_data(self, paper) -> PaperMetadata:
"""解析论文数据"""
# 获取DOI
doi = None
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
if external_ids:
if isinstance(external_ids, dict):
doi = external_ids.get('DOI')
if not doi and 'ArXiv' in external_ids:
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
else:
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
if not doi and hasattr(external_ids, 'ArXiv'):
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
# 获取PDF URL
pdf_url = None
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
if pdf_info:
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
# 获取发表场所详细信息
venue_type = None
venue_name = None
venue_info = {}
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
if venue:
if isinstance(venue, dict):
venue_name = venue.get('name')
venue_type = venue.get('type')
# 提取更多venue信息
venue_info = {
'issn': venue.get('issn'),
'publisher': venue.get('publisher'),
'url': venue.get('url'),
'alternate_names': venue.get('alternate_names', [])
}
else:
venue_name = venue.name if hasattr(venue, 'name') else None
venue_type = venue.type if hasattr(venue, 'type') else None
venue_info = {
'issn': getattr(venue, 'issn', None),
'publisher': getattr(venue, 'publisher', None),
'url': getattr(venue, 'url', None),
'alternate_names': getattr(venue, 'alternate_names', [])
}
# 获取标题
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
# 获取作者
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
author_names = []
for author in authors:
if isinstance(author, dict):
author_names.append(author.get('name', ''))
else:
author_names.append(author.name if hasattr(author, 'name') else str(author))
# 获取摘要
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
# 获取年份
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
# 获取引用次数
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
return PaperMetadata(
title=title,
authors=author_names,
abstract=abstract,
year=year,
doi=doi,
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
citations=citations,
venue=venue_name,
institutions=[],
venue_type=venue_type,
venue_name=venue_name,
venue_info=venue_info,
source='semantic' # 添加来源标记
)
async def example_usage():
"""SemanticScholarSource使用示例"""
semantic = SemanticScholarSource()
try:
# 示例1使用DOI直接获取论文
print("\n=== 示例1通过DOI获取论文 ===")
doi = "10.18653/v1/N19-1423" # BERT论文
print(f"获取DOI为 {doi} 的论文信息...")
paper = await semantic.get_paper_details(doi)
if paper:
print("\n--- 论文信息 ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
print(f"DOI: {paper.doi}")
print(f"URL: {paper.url}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\n引用次数: {paper.citations}")
print(f"发表venue: {paper.venue}")
# 示例2搜索论文
print("\n=== 示例2搜索论文 ===")
query = "BERT pre-training"
print(f"搜索关键词 '{query}' 相关的论文...")
papers = await semantic.search(query=query, limit=3)
for i, paper in enumerate(papers, 1):
print(f"\n--- 搜索结果 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
if paper.abstract:
print(f"\n摘要:")
print(paper.abstract)
print(f"\nDOI: {paper.doi}")
print(f"引用次数: {paper.citations}")
# 示例3获取论文推荐
print("\n=== 示例3获取论文推荐 ===")
print(f"获取与论文 {doi} 相关的推荐论文...")
recommendations = await semantic.get_recommended_papers(doi, limit=3)
for i, paper in enumerate(recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
print(f"发表年份: {paper.year}")
# 示例4基于多篇论文的推荐
print("\n=== 示例4基于多篇论文的推荐 ===")
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
multi_recommendations = await semantic.get_recommended_papers_from_lists(
positive_dois=positive_dois,
limit=3
)
for i, paper in enumerate(multi_recommendations, 1):
print(f"\n--- 推荐论文 {i} ---")
print(f"标题: {paper.title}")
print(f"作者: {', '.join(paper.authors)}")
# 示例5搜索作者
print("\n=== 示例5搜索作者 ===")
author_query = "Yann LeCun"
print(f"搜索作者: '{author_query}'")
authors = await semantic.search_author(author_query, limit=3)
for i, author in enumerate(authors, 1):
print(f"\n--- 作者 {i} ---")
print(f"姓名: {author['name']}")
print(f"论文数量: {author['paper_count']}")
print(f"总引用次数: {author['citation_count']}")
# 示例6获取作者详情
print("\n=== 示例6获取作者详情 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者ID {author_id} 的详细信息...")
author_details = await semantic.get_author_details(author_id)
if author_details:
print(f"姓名: {author_details['name']}")
print(f"H指数: {author_details['h_index']}")
print(f"总引用次数: {author_details['citation_count']}")
print(f"发表论文数: {author_details['paper_count']}")
# 示例7获取作者论文
print("\n=== 示例7获取作者论文 ===")
if authors: # 使用第一个搜索结果的作者ID
author_id = authors[0]['author_id']
print(f"获取作者 {authors[0]['name']} 的论文列表...")
author_papers = await semantic.get_author_papers(author_id, limit=3)
for i, paper in enumerate(author_papers, 1):
print(f"\n--- 论文 {i} ---")
print(f"标题: {paper.title}")
print(f"发表年份: {paper.year}")
print(f"引用次数: {paper.citations}")
# 示例8论文标题自动补全
print("\n=== 示例8论文标题自动补全 ===")
title_query = "Attention is all"
print(f"搜索标题: '{title_query}'")
suggestions = await semantic.get_paper_autocomplete(title_query)
for i, suggestion in enumerate(suggestions[:3], 1):
print(f"\n--- 建议 {i} ---")
print(f"标题: {suggestion['title']}")
print(f"发表年份: {suggestion['year']}")
print(f"发表venue: {suggestion['venue']}")
except Exception as e:
print(f"发生错误: {str(e)}")
import traceback
print(traceback.format_exc())
if __name__ == "__main__":
import asyncio
asyncio.run(example_usage())

查看文件

@@ -0,0 +1,46 @@
import aiohttp
from typing import List, Dict, Optional
from datetime import datetime
from .base_source import DataSource, PaperMetadata
class UnpaywallSource(DataSource):
"""Unpaywall API实现"""
def _initialize(self) -> None:
self.base_url = "https://api.unpaywall.org/v2"
self.email = self.api_key # Unpaywall使用email作为API key
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.base_url}/search",
params={
"query": query,
"email": self.email,
"limit": limit
}
) as response:
data = await response.json()
return [self._parse_response(item.get("response", {}))
for item in data.get("results", [])]
def _parse_response(self, data: Dict) -> PaperMetadata:
"""解析Unpaywall返回的数据"""
return PaperMetadata(
title=data.get("title", ""),
authors=[
f"{author.get('given', '')} {author.get('family', '')}"
for author in data.get("z_authors", [])
],
institutions=[
aff.get("name", "")
for author in data.get("z_authors", [])
for aff in author.get("affiliation", [])
],
abstract="", # Unpaywall不提供摘要
year=data.get("year"),
doi=data.get("doi"),
url=data.get("doi_url"),
citations=None, # Unpaywall不提供引用计数
venue=data.get("journal_name")
)
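# 使用示例(草稿,基于上述接口推断,假设基类 DataSource 的构造函数接受 api_key 参数Unpaywall 以邮箱作为 API key
# import asyncio
# unpaywall = UnpaywallSource(api_key="your_email@example.com")
# papers = asyncio.run(unpaywall.search("open access publishing", limit=5))
# for paper in papers:
#     print(paper.title, paper.doi, paper.url)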

查看文件

@@ -0,0 +1,412 @@
import asyncio
from datetime import datetime
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from crazy_functions.review_fns.query_analyzer import SearchCriteria
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
from request_llms.bridge_all import model_info
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
from toolbox import get_conf
class BaseHandler(ABC):
"""处理器基类"""
def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
self.arxiv = arxiv
self.semantic = semantic
self.pubmed = PubMedSource()
self.crossref = CrossrefSource() # 添加 Crossref 实例
self.adsabs = AdsabsSource() # 添加 ADS 实例
self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
self.ranked_papers = [] # 存储排序后的论文列表
self.llm_kwargs = llm_kwargs or {} # 保存llm_kwargs
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
"""获取搜索参数"""
return {
'max_papers': plugin_kwargs.get('max_papers', 100), # 最大论文数量
'min_year': plugin_kwargs.get('min_year', 2015), # 最早年份
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
}
@abstractmethod
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> List[List[str]]:
"""处理查询"""
pass
async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用arXiv专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
# 首先尝试基础搜索
query = params.get("query", "")
if query:
papers = await self.arxiv.search(
query,
limit=params["limit"],
sort_by=params.get("sort_by", "relevance"),
sort_order=params.get("sort_order", "descending"),
start_year=min_year
)
# 如果基础搜索没有结果,尝试分类搜索
if not papers:
categories = params.get("categories", [])
for category in categories:
category_papers = await self.arxiv.search_by_category(
category,
limit=params["limit"],
sort_by=params.get("sort_by", "relevance"),
sort_order=params.get("sort_order", "descending"),
)
if category_papers:
papers.extend(category_papers)
return papers or []
except Exception as e:
print(f"arXiv搜索出错: {str(e)}")
return []
async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用Semantic Scholar专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
# 只使用基本的搜索参数
papers = await self.semantic.search(
query=params.get("query", ""),
limit=params["limit"]
)
# 在内存中进行过滤
if papers and min_year:
papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]
return papers or []
except Exception as e:
print(f"Semantic Scholar搜索出错: {str(e)}")
return []
async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用PubMed专用参数搜索"""
try:
# 如果不需要PubMed搜索,直接返回空列表
if params.get("search_type") == "none":
return []
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
# 根据搜索类型选择搜索方法
if params.get("search_type") == "basic":
papers = await self.pubmed.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "author":
papers = await self.pubmed.search_by_author(
author=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "journal":
papers = await self.pubmed.search_by_journal(
journal=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
return papers or []
except Exception as e:
print(f"PubMed搜索出错: {str(e)}")
return []
async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用Crossref专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
# 根据搜索类型选择搜索方法
if params.get("search_type") == "basic":
papers = await self.crossref.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "author":
papers = await self.crossref.search_by_authors(
authors=[params.get("query", "")],
limit=params["limit"],
start_year=min_year
)
elif params.get("search_type") == "journal":
# 实现期刊搜索逻辑
pass
return papers or []
except Exception as e:
print(f"Crossref搜索出错: {str(e)}")
return []
async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
"""使用ADS专用参数搜索"""
try:
original_limit = params.get("limit", 20)
params["limit"] = original_limit * limit_multiplier
papers = []
# 执行搜索
if params.get("search_type") == "basic":
papers = await self.adsabs.search(
query=params.get("query", ""),
limit=params["limit"],
start_year=min_year
)
return papers or []
except Exception as e:
print(f"ADS搜索出错: {str(e)}")
return []
async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
"""从所有数据源搜索论文"""
search_tasks = []
# # 检查是否需要执行PubMed搜索
# is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
is_using_pubmed = False # 开源版本不再搜索pubmed
# 如果使用PubMed,则只执行PubMed和Semantic Scholar搜索
if is_using_pubmed:
search_tasks.append(
self._search_pubmed(
criteria.pubmed_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
# Semantic Scholar总是执行搜索
search_tasks.append(
self._search_semantic(
criteria.semantic_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
else:
# 如果不使用PubMed,则执行Crossref、arXiv以及可选的Semantic Scholar搜索
if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
search_tasks.append(
self._search_crossref(
criteria.crossref_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
search_tasks.append(
self._search_arxiv(
criteria.arxiv_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
if get_conf("SEMANTIC_SCHOLAR_KEY"):
search_tasks.append(
self._search_semantic(
criteria.semantic_params,
limit_multiplier=search_params['search_multiplier'],
min_year=criteria.start_year
)
)
# 执行所有需要的搜索任务
papers = await asyncio.gather(*search_tasks)
# 合并所有来源的论文并统计各来源的数量
all_papers = []
source_counts = {
'arxiv': 0,
'semantic': 0,
'pubmed': 0,
'crossref': 0,
'adsabs': 0
}
for source_papers in papers:
if source_papers:
for paper in source_papers:
source = getattr(paper, 'source', 'unknown')
if source in source_counts:
source_counts[source] += 1
all_papers.extend(source_papers)
# 打印各来源的论文数量
print("\n=== 各数据源找到的论文数量 ===")
for source, count in source_counts.items():
if count > 0: # 只打印有论文的来源
print(f"{source.capitalize()}: {count}")
print(f"总计: {len(all_papers)}")
print("===========================\n")
return all_papers
def _format_paper_time(self, paper) -> str:
"""格式化论文时间信息"""
year = getattr(paper, 'year', None)
if not year:
return ""
# 如果有具体的发表日期,使用具体日期
if hasattr(paper, 'published') and paper.published:
return f"(发表于 {paper.published.strftime('%Y-%m')})"
# 如果只有年份,只显示年份
return f"({year})"
def _format_papers(self, papers: List) -> str:
"""格式化论文列表,使用token限制控制长度"""
formatted = []
for i, paper in enumerate(papers, 1):
# 只保留前三个作者
authors = paper.authors[:3]
if len(paper.authors) > 3:
authors.append("et al.")
# 构建所有可能的下载链接
download_links = []
# 添加arXiv链接
if hasattr(paper, 'doi') and paper.doi:
if paper.doi.startswith("10.48550/arXiv."):
# 从DOI中提取完整的arXiv ID
arxiv_id = paper.doi.split("arXiv.")[-1]
# 移除多余的点号并确保格式正确
arxiv_id = arxiv_id.replace("..", ".") # 移除重复的点号
if arxiv_id.startswith("."): # 移除开头的点号
arxiv_id = arxiv_id[1:]
if arxiv_id.endswith("."): # 移除结尾的点号
arxiv_id = arxiv_id[:-1]
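# 例如 DOI "10.48550/arXiv.2106.09685" 会提取出 arXiv ID "2106.09685"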
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
elif "arxiv.org/abs/" in paper.doi:
# 直接从URL中提取arXiv ID
arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
if "v" in arxiv_id: # 移除版本号
arxiv_id = arxiv_id.split("v")[0]
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
else:
download_links.append(f"[DOI](https://doi.org/{paper.doi})")
# 添加直接URL链接如果存在且不同于前面的链接
if hasattr(paper, 'url') and paper.url:
if not any(paper.url in link for link in download_links):
download_links.append(f"[Source]({paper.url})")
# 构建下载链接字符串
download_section = " | ".join(download_links) if download_links else "No direct download link available"
# 构建来源信息
source_info = []
if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
source_info.append(f"Type: {paper.venue_type}")
if hasattr(paper, 'venue_name') and paper.venue_name:
source_info.append(f"Venue: {paper.venue_name}")
# 添加IF指数和分区信息
if hasattr(paper, 'if_factor') and paper.if_factor:
source_info.append(f"IF: {paper.if_factor}")
if hasattr(paper, 'cas_division') and paper.cas_division:
source_info.append(f"中科院分区: {paper.cas_division}")
if hasattr(paper, 'jcr_division') and paper.jcr_division:
source_info.append(f"JCR分区: {paper.jcr_division}")
if hasattr(paper, 'venue_info') and paper.venue_info:
if paper.venue_info.get('journal_ref'):
source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
if paper.venue_info.get('publisher'):
source_info.append(f"Publisher: {paper.venue_info['publisher']}")
# 构建当前论文的格式化文本
paper_text = (
f"{i}. **{paper.title}**\n" +
f" Authors: {', '.join(authors)}\n" +
f" Year: {paper.year}\n" +
f" Citations: {paper.citations if paper.citations else 'N/A'}\n" +
(f" Source: {'; '.join(source_info)}\n" if source_info else "") +
# 添加PubMed特有信息
(f" MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper,
'mesh_terms') and paper.mesh_terms else "") +
f" 📥 PDF Downloads: {download_section}\n" +
f" Abstract: {paper.abstract}\n"
)
formatted.append(paper_text)
full_text = "\n".join(formatted)
# 根据不同模型设置不同的token限制
model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')
token_limit = model_info[model_name]['max_token'] * 3 // 4
# 使用token限制控制长度
return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)
def _get_current_time(self) -> str:
"""获取当前时间信息"""
now = datetime.now()
return now.strftime("%Y年%m月%d日")
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
"""生成道歉提示"""
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的有效文献。
可能的原因:
1. 搜索词过于具体或专业
2. 时间范围限制过严
建议解决方案:
1. 尝试使用更通用的关键词
2. 扩大搜索时间范围
3. 使用同义词或相关术语
请根据以上建议调整后重试。"""
def get_ranked_papers(self) -> str:
"""获取排序后的论文列表的格式化字符串"""
return self._format_papers(self.ranked_papers) if self.ranked_papers else ""
def _is_pubmed_paper(self, paper) -> bool:
"""判断是否为PubMed论文"""
return (paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)

查看文件

@@ -0,0 +1,106 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from crazy_functions.review_fns.query_analyzer import SearchCriteria
import asyncio
class Arxiv最新论文推荐功能(BaseHandler):
"""最新论文推荐处理器"""
def __init__(self, arxiv, semantic, llm_kwargs=None):
super().__init__(arxiv, semantic, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理最新论文推荐请求"""
# 获取搜索参数
search_params = self._get_search_params(plugin_kwargs)
# 获取最新论文
papers = []
for category in criteria.arxiv_params["categories"]:
latest_papers = await self.arxiv.get_latest_papers(
category=category,
debug=False,
batch_size=50
)
papers.extend(latest_papers)
if not papers:
return self._generate_apology_prompt(criteria)
# 使用embedding模型对论文进行排序
self.ranked_papers = self.paper_ranker.rank_papers(
query=criteria.original_query,
papers=papers,
search_criteria=criteria
)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = f"""Current time: {current_time}
Based on your interest in {criteria.main_topic}, here are the latest papers from arXiv in relevant categories:
{', '.join(criteria.arxiv_params["categories"])}
Latest papers available:
{self._format_papers(self.ranked_papers)}
Please provide:
1. A clear list of latest papers, organized by themes or approaches
2. Group papers by sub-topics or themes if applicable
3. For each paper:
- Publication time
- The key contributions and main findings
- Why it's relevant to the user's interests
- How it relates to other latest papers
- The paper's citation count and citation impact
- The paper's download link
4. A suggested reading order based on:
- Paper relationships and dependencies
- Difficulty level
- Significance
5. Future Directions
- Emerging venues and research streams
- Novel methodological approaches
- Cross-disciplinary opportunities
- Research gaps by publication type
IMPORTANT:
- Focus on explaining why each paper is interesting
- Highlight the novelty and potential impact
- Consider the credibility and stage of each publication
- Use the provided paper titles with their links when referring to specific papers
- Base recommendations ONLY on the explicitly provided paper information
- Do not make ANY assumptions about papers beyond the given data
- When information is missing or unclear, acknowledge the limitation
- Never speculate about:
* Paper quality or rigor not evidenced in the data
* Research impact beyond citation counts
* Implementation details not mentioned
* Author expertise or background
* Future research directions not stated
- For each paper, cite only verifiable information
- Clearly distinguish between facts and potential implications
- Each paper includes download links in its 📥 PDF Downloads section
Format your response in markdown with clear sections.
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
return final_prompt
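# Illustrative sketch: handle() above fetches each arXiv category sequentially. A
# minimal, self-contained example of the same fan-out done concurrently with
# asyncio.gather is shown below, using a stub source; the _StubArxivSource name and
# its behaviour are assumptions, not the real ArxivSource API.
import asyncio

class _StubArxivSource:
    async def get_latest_papers(self, category, debug=False, batch_size=50):
        await asyncio.sleep(0)  # stand-in for the network call
        return [f"{category}-paper-{i}" for i in range(3)]

async def _demo_fetch_latest(categories):
    arxiv = _StubArxivSource()
    batches = await asyncio.gather(*(arxiv.get_latest_papers(c) for c in categories))
    return [paper for batch in batches for paper in batch]

if __name__ == "__main__":
    print(asyncio.run(_demo_fetch_latest(["cs.AI", "cs.LG"])))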

查看文件

@@ -0,0 +1,344 @@
from typing import List, Dict, Any, Optional, Tuple
from .base_handler import BaseHandler
from crazy_functions.review_fns.query_analyzer import SearchCriteria
import asyncio
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
class 单篇论文分析功能(BaseHandler):
"""论文分析处理器"""
def __init__(self, arxiv, semantic, llm_kwargs=None):
super().__init__(arxiv, semantic, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理论文分析请求,返回最终的prompt"""
# 1. 获取论文详情
paper = await self._get_paper_details(criteria)
if not paper:
return self._generate_apology_prompt(criteria)
# 保存为ranked_papers以便统一接口
self.ranked_papers = [paper]
# 2. 构建最终的prompt
current_time = self._get_current_time()
# 获取论文信息
title = getattr(paper, "title", "Unknown Title")
authors = getattr(paper, "authors", [])
year = getattr(paper, "year", "Unknown Year")
abstract = getattr(paper, "abstract", "No abstract available")
citations = getattr(paper, "citations", "N/A")
# 添加论文ID信息
paper_id = ""
if criteria.paper_source == "arxiv":
paper_id = f"arXiv ID: {criteria.paper_id}\n"
elif criteria.paper_source == "doi":
paper_id = f"DOI: {criteria.paper_id}\n"
# 格式化作者列表
authors_str = ', '.join(authors) if isinstance(authors, list) else authors
final_prompt = f"""Current time: {current_time}
Please provide a comprehensive analysis of the following paper:
{paper_id}Title: {title}
Authors: {authors_str}
Year: {year}
Citations: {citations}
Publication Venue: {getattr(paper, 'venue_name', 'Unknown')} ({getattr(paper, 'venue_type', 'unknown')})
{f"Publisher: {paper.venue_info.get('publisher')}" if getattr(paper, 'venue_info', None) and paper.venue_info.get('publisher') else ""}
{f"Journal Reference: {paper.venue_info.get('journal_ref')}" if getattr(paper, 'venue_info', None) and paper.venue_info.get('journal_ref') else ""}
Abstract: {abstract}
Please provide:
1. Publication Context
- Publication venue analysis and impact factor (if available)
- Paper type (journal article, conference paper, preprint)
- Publication timeline and peer review status
- Publisher reputation and venue prestige
2. Research Context
- Field positioning and significance
- Historical context and prior work
- Related research streams
- Cross-venue impact analysis
3. Technical Analysis
- Detailed methodology review
- Implementation details
- Experimental setup and results
- Technical innovations
4. Impact Analysis
- Citation patterns and influence
- Cross-venue recognition
- Industry vs. academic impact
- Practical applications
5. Critical Review
- Methodological rigor assessment
- Result reliability and reproducibility
- Venue-appropriate evaluation standards
- Limitations and potential improvements
IMPORTANT:
- Strictly use ONLY the information provided above about the paper
- Do not make ANY assumptions or inferences beyond the given data
- If certain information is not provided, explicitly state that it is unknown
- For any unclear or missing details, acknowledge the limitation rather than speculating
- When discussing methodology or results, only describe what is explicitly stated in the abstract
- Never fabricate or assume any details about:
* Publication venues or status
* Implementation details not mentioned
* Results or findings not stated
* Impact or influence not supported by the citation count
* Authors' affiliations or backgrounds
* Future work or implications not mentioned
- You can find the paper's download options in the 📥 PDF Downloads section
- Available download formats include arXiv PDF, DOI links, and source URLs
Format your response in markdown with clear sections.
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
return final_prompt
async def _get_paper_details(self, criteria: SearchCriteria):
"""获取论文详情"""
try:
if criteria.paper_source == "arxiv":
# 使用 arxiv ID 搜索
papers = await self.arxiv.search_by_id(criteria.paper_id)
return papers[0] if papers else None
elif criteria.paper_source == "doi":
# 尝试从所有来源获取
paper = await self.semantic.get_paper_by_doi(criteria.paper_id)
if not paper:
# 如果Semantic Scholar没有找到,尝试PubMed
papers = await self.pubmed.search(
f"{criteria.paper_id}[doi]",
limit=1
)
if papers:
return papers[0]
return paper
elif criteria.paper_source == "title":
# 使用_search_all_sources搜索
search_params = {
'max_papers': 1,
'min_year': 1900, # 不限制年份
'search_multiplier': 1
}
# 设置搜索参数
criteria.arxiv_params = {
"search_type": "basic",
"query": f'ti:"{criteria.paper_title}"',
"limit": 1
}
criteria.semantic_params = {
"query": criteria.paper_title,
"limit": 1
}
criteria.pubmed_params = {
"search_type": "basic",
"query": f'"{criteria.paper_title}"[Title]',
"limit": 1
}
papers = await self._search_all_sources(criteria, search_params)
return papers[0] if papers else None
# 如果都没有找到,尝试使用 main_topic 作为标题搜索
if not criteria.paper_title and not criteria.paper_id:
search_params = {
'max_papers': 1,
'min_year': 1900,
'search_multiplier': 1
}
# 设置搜索参数
criteria.arxiv_params = {
"search_type": "basic",
"query": f'ti:"{criteria.main_topic}"',
"limit": 1
}
criteria.semantic_params = {
"query": criteria.main_topic,
"limit": 1
}
criteria.pubmed_params = {
"search_type": "basic",
"query": f'"{criteria.main_topic}"[Title]',
"limit": 1
}
papers = await self._search_all_sources(criteria, search_params)
return papers[0] if papers else None
return None
except Exception as e:
print(f"获取论文详情时出错: {str(e)}")
return None
async def _get_citation_context(self, paper: Dict, plugin_kwargs: Dict) -> Tuple[List, List]:
"""获取引用上下文"""
search_params = self._get_search_params(plugin_kwargs)
# 使用论文标题构建搜索参数
title_query = f'ti:"{getattr(paper, "title", "")}"'
arxiv_params = {
"query": title_query,
"limit": search_params['max_papers'],
"search_type": "basic",
"sort_by": "relevance",
"sort_order": "descending"
}
semantic_params = {
"query": getattr(paper, "title", ""),
"limit": search_params['max_papers']
}
citations, references = await asyncio.gather(
self._search_semantic(
semantic_params,
limit_multiplier=search_params['search_multiplier'],
min_year=search_params['min_year']
),
self._search_arxiv(
arxiv_params,
limit_multiplier=search_params['search_multiplier'],
min_year=search_params['min_year']
)
)
return citations, references
async def _generate_analysis(
self,
paper: Dict,
citations: List,
references: List,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any]
) -> List[List[str]]:
"""生成论文分析"""
# 构建提示
analysis_prompt = f"""Please provide a comprehensive analysis of the following paper:
Paper details:
{self._format_paper(paper)}
Key references (papers cited by this paper):
{self._format_papers(references)}
Important citations (papers that cite this paper):
{self._format_papers(citations)}
Please provide:
1. Paper Overview
- Main research question/objective
- Key methodology/approach
- Main findings/contributions
2. Technical Analysis
- Detailed methodology review
- Technical innovations
- Implementation details
- Experimental setup and results
3. Impact Analysis
- Significance in the field
- Influence on subsequent research (based on citing papers)
- Relationship to prior work (based on cited papers)
- Practical applications
4. Critical Review
- Strengths and limitations
- Potential improvements
- Open questions and future directions
- Alternative approaches
5. Related Research Context
- How it builds on previous work
- How it has influenced subsequent research
- Comparison with alternative approaches
Format your response in markdown with clear sections."""
# 并行生成概述和技术分析
for response_chunk in request_gpt(
inputs_array=[
analysis_prompt,
self._get_technical_prompt(paper)
],
inputs_show_user_array=[
"Generating paper analysis...",
"Analyzing technical details..."
],
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history_array=[history, []],
sys_prompt_array=[
system_prompt,
"You are an expert at analyzing technical details in research papers."
]
):
pass # 等待生成完成
# 获取最后的两个回答
if chatbot and len(chatbot[-2:]) == 2:
analysis = chatbot[-2][1]
technical = chatbot[-1][1]
full_analysis = f"""# Paper Analysis: {paper.title}
## General Analysis
{analysis}
## Technical Deep Dive
{technical}
"""
chatbot.append(["Here is the paper analysis:", full_analysis])
else:
chatbot.append(["Here is the paper analysis:", "Failed to generate analysis."])
return chatbot
def _get_technical_prompt(self, paper: Dict) -> str:
"""生成技术分析提示"""
return f"""Please provide a detailed technical analysis of the following paper:
{self._format_paper(paper)}
Focus on:
1. Mathematical formulations and their implications
2. Algorithm design and complexity analysis
3. Architecture details and design choices
4. Implementation challenges and solutions
5. Performance analysis and bottlenecks
6. Technical limitations and potential improvements
Format your response in markdown, focusing purely on technical aspects."""
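# Illustrative sketch: _get_paper_details above branches on criteria.paper_source.
# The standalone example below reduces that branching to a table-driven dispatch with
# stub lookups; the _lookup_* names are hypothetical and do not exist in this project.
import asyncio

async def _lookup_arxiv(identifier): return {"source": "arxiv", "id": identifier}
async def _lookup_doi(identifier): return {"source": "doi", "id": identifier}
async def _lookup_title(identifier): return {"source": "title", "title": identifier}

async def _demo_get_paper(source: str, identifier: str):
    dispatch = {"arxiv": _lookup_arxiv, "doi": _lookup_doi, "title": _lookup_title}
    lookup = dispatch.get(source)
    return await lookup(identifier) if lookup else None

if __name__ == "__main__":
    print(asyncio.run(_demo_get_paper("arxiv", "2103.14030")))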

查看文件

@@ -0,0 +1,147 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from crazy_functions.review_fns.query_analyzer import SearchCriteria
from textwrap import dedent
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
class 学术问答功能(BaseHandler):
"""学术问答处理器"""
def __init__(self, arxiv, semantic, llm_kwargs=None):
super().__init__(arxiv, semantic, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理学术问答请求,返回最终的prompt"""
# 1. 获取搜索参数
search_params = self._get_search_params(plugin_kwargs)
# 2. 搜索相关论文
papers = await self._search_relevant_papers(criteria, search_params)
if not papers:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = dedent(f"""Current time: {current_time}
Based on the following paper abstracts, please answer this academic question: {criteria.original_query}
Available papers for reference:
{self._format_papers(self.ranked_papers)}
Please structure your response in the following format:
1. Core Answer (2-3 paragraphs)
- Provide a clear, direct answer synthesizing key findings
- Support main points with citations [1,2,etc.]
- Focus on consensus and differences across papers
2. Key Evidence (2-3 paragraphs)
- Present supporting evidence from abstracts
- Compare methodologies and results
- Highlight significant findings with citations
3. Research Context (1-2 paragraphs)
- Discuss current trends and developments
- Identify research gaps or limitations
- Suggest potential future directions
Guidelines:
- Base your answer ONLY on the provided abstracts
- Use numbered citations [1], [2,3], etc. for every claim
- Maintain academic tone and objectivity
- Synthesize findings across multiple papers
- Focus on the most relevant information to the question
Constraints:
- Do not include information beyond the provided abstracts
- Avoid speculation or personal opinions
- Do not elaborate on technical details unless directly relevant
- Keep citations concise and focused
- Use [N] citations for every major claim or finding
- Cite multiple papers [1,2,3] when showing consensus
- Place citations immediately after the relevant statements
Note: Provide citations for every major claim to ensure traceability to source papers.
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
)
return final_prompt
async def _search_relevant_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
"""搜索相关论文"""
# 使用_search_all_sources替代原来的并行搜索
all_papers = await self._search_all_sources(criteria, search_params)
if not all_papers:
return []
# 使用BGE重排序
self.ranked_papers = self.paper_ranker.rank_papers(
query=criteria.main_topic,
papers=all_papers,
search_criteria=criteria
)
return self.ranked_papers or []
async def _generate_answer(
self,
criteria: SearchCriteria,
papers: List,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any]
) -> List[List[str]]:
"""生成答案"""
# 构建提示
qa_prompt = dedent(f"""Please answer the following academic question based on recent research papers.
Question: {criteria.main_topic}
Relevant papers:
{self._format_papers(papers)}
Please provide:
1. A direct answer to the question
2. Supporting evidence from the papers
3. Different perspectives or approaches if applicable
4. Current limitations and open questions
5. References to specific papers
Format your response in markdown with clear sections."""
)
# 调用LLM生成答案
for response_chunk in request_gpt(
inputs_array=[qa_prompt],
inputs_show_user_array=["Generating answer..."],
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history_array=[history],
sys_prompt_array=[system_prompt]
):
pass # 等待生成完成
# 获取最后的回答
if chatbot and len(chatbot[-1]) >= 2:
answer = chatbot[-1][1]
chatbot.append(["Here is the answer:", answer])
else:
chatbot.append(["Here is the answer:", "Failed to generate answer."])
return chatbot
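# Illustrative sketch: the citation rules in the prompt above assume the answer refers
# to papers as [1], [2], ... in the exact order they appear in the formatted list. A
# minimal helper that builds such a numbered reference block is shown below; the
# _RefPaper stand-in is hypothetical and far simpler than the real paper objects.
from dataclasses import dataclass
from typing import List

@dataclass
class _RefPaper:
    title: str
    year: int

def _build_reference_block(papers: List[_RefPaper]) -> str:
    return "\n".join(f"[{i}] {p.title} ({p.year})" for i, p in enumerate(papers, start=1))

if __name__ == "__main__":
    print(_build_reference_block([_RefPaper("Attention Is All You Need", 2017)]))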

查看文件

@@ -0,0 +1,185 @@
from typing import List, Dict, Any
from .base_handler import BaseHandler
from textwrap import dedent
from crazy_functions.review_fns.query_analyzer import SearchCriteria
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
class 论文推荐功能(BaseHandler):
"""论文推荐处理器"""
def __init__(self, arxiv, semantic, llm_kwargs=None):
super().__init__(arxiv, semantic, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理论文推荐请求,返回最终的prompt"""
search_params = self._get_search_params(plugin_kwargs)
# 1. 先搜索种子论文
seed_papers = await self._search_seed_papers(criteria, search_params)
if not seed_papers:
return self._generate_apology_prompt(criteria)
# 使用BGE重排序
all_papers = seed_papers
if not all_papers:
return self._generate_apology_prompt(criteria)
self.ranked_papers = self.paper_ranker.rank_papers(
query=criteria.original_query,
papers=all_papers,
search_criteria=criteria
)
if not self.ranked_papers:
return self._generate_apology_prompt(criteria)
# 构建最终的prompt
current_time = self._get_current_time()
final_prompt = dedent(f"""Current time: {current_time}
Based on the user's interest in {criteria.main_topic}, here are relevant papers.
Available papers for recommendation:
{self._format_papers(self.ranked_papers)}
Please provide:
1. Group papers by sub-topics or themes if applicable
2. For each paper:
- Publication time and venue (when available)
- Journal metrics (when available):
* Impact Factor (IF)
* JCR Quartile
* Chinese Academy of Sciences (CAS) Division
- The key contributions and main findings
- Why it's relevant to the user's interests
- How it relates to other recommended papers
- The paper's citation count and citation impact
- The paper's download link
3. A suggested reading order based on:
- Journal impact and quality metrics
- Chronological development of ideas
- Paper relationships and dependencies
- Difficulty level
- Impact and significance
4. Future Directions
- Emerging venues and research streams
- Novel methodological approaches
- Cross-disciplinary opportunities
- Research gaps by publication type
IMPORTANT:
- Focus on explaining why each paper is valuable
- Highlight connections between papers
- Consider both citation counts AND journal metrics when discussing impact
- When available, use IF, JCR quartile, and CAS division to assess paper quality
- Mention publication timing when discussing paper relationships
- When referring to papers, use HTML links in this format:
* For DOIs: <a href='https://doi.org/DOI_HERE' target='_blank'>DOI: DOI_HERE</a>
* For titles: <a href='PAPER_URL' target='_blank'>PAPER_TITLE</a>
- Present papers in a way that shows the evolution of ideas over time
- Base recommendations ONLY on the explicitly provided paper information
- Do not make ANY assumptions about papers beyond the given data
- When information is missing or unclear, acknowledge the limitation
- Never speculate about:
* Paper quality or rigor not evidenced in the data
* Research impact beyond citation counts and journal metrics
* Implementation details not mentioned
* Author expertise or background
* Future research directions not stated
- For each recommendation, cite only verifiable information
- Clearly distinguish between facts and potential implications
Format your response in markdown with clear sections.
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
)
return final_prompt
async def _search_seed_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
"""搜索种子论文"""
try:
# 使用_search_all_sources替代原来的并行搜索
all_papers = await self._search_all_sources(criteria, search_params)
if not all_papers:
return []
return all_papers
except Exception as e:
print(f"搜索种子论文时出错: {str(e)}")
return []
async def _get_recommendations(self, seed_papers: List, multiplier: int = 1) -> List:
"""获取推荐论文"""
recommendations = []
base_limit = 3 * multiplier
# 将种子论文添加到推荐列表中
recommendations.extend(seed_papers)
# 只使用前5篇论文作为种子
seed_papers = seed_papers[:5]
for paper in seed_papers:
try:
if paper.doi and paper.doi.startswith("10.48550/arXiv."):
# arXiv论文
arxiv_id = paper.doi.split("arXiv.")[-1]  # 例如 10.48550/arXiv.2103.14030 -> 2103.14030
paper_details = await self.arxiv.get_paper_details(arxiv_id)
if paper_details and hasattr(paper_details, 'venue'):
category = paper_details.venue.split(":")[-1]
similar_papers = await self.arxiv.search_by_category(
category,
limit=base_limit,
sort_by='relevance'
)
recommendations.extend(similar_papers)
elif paper.doi: # 只对有DOI的论文获取推荐
# Semantic Scholar论文
similar_papers = await self.semantic.get_recommended_papers(
paper.doi,
limit=base_limit
)
if similar_papers: # 只添加成功获取的推荐
recommendations.extend(similar_papers)
else:
# 对于没有DOI的论文,使用标题进行相关搜索
if paper.title:
similar_papers = await self.semantic.search(
query=paper.title,
limit=base_limit
)
recommendations.extend(similar_papers)
except Exception as e:
print(f"获取论文 '{paper.title}' 的推荐时发生错误: {str(e)}")
continue
# 去重处理
seen_dois = set()
unique_recommendations = []
for paper in recommendations:
if paper.doi and paper.doi not in seen_dois:
seen_dois.add(paper.doi)
unique_recommendations.append(paper)
elif not paper.doi and paper not in unique_recommendations:
unique_recommendations.append(paper)
return unique_recommendations
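# Illustrative sketch: the de-duplication loop above keys on DOI and falls back to
# object membership. The same idea expressed with an explicit key function (DOI when
# present, otherwise a normalised title) is shown below; the _RecPaper stand-in is
# hypothetical.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class _RecPaper:
    title: str
    doi: Optional[str] = None

def _dedupe_papers(papers: List[_RecPaper]) -> List[_RecPaper]:
    seen, unique = set(), []
    for p in papers:
        key = p.doi or (p.title or "").strip().lower()
        if key and key not in seen:
            seen.add(key)
            unique.append(p)
    return unique

if __name__ == "__main__":
    sample = [_RecPaper("A", "10.1/x"), _RecPaper("a"), _RecPaper("A", "10.1/x")]
    print(len(_dedupe_papers(sample)))  # 2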

查看文件

@@ -0,0 +1,193 @@
from typing import List, Dict, Any, Tuple
from .base_handler import BaseHandler
from crazy_functions.review_fns.query_analyzer import SearchCriteria
import asyncio
class 文献综述功能(BaseHandler):
"""文献综述处理器"""
def __init__(self, arxiv, semantic, llm_kwargs=None):
super().__init__(arxiv, semantic, llm_kwargs)
async def handle(
self,
criteria: SearchCriteria,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> str:
"""处理文献综述请求,返回最终的prompt"""
# 获取搜索参数
search_params = self._get_search_params(plugin_kwargs)
# 使用_search_all_sources替代原来的并行搜索
all_papers = await self._search_all_sources(criteria, search_params)
if not all_papers:
return self._generate_apology_prompt(criteria)
self.ranked_papers = self.paper_ranker.rank_papers(
query=criteria.original_query,
papers=all_papers,
search_criteria=criteria
)
# 检查排序后的论文数量
if not self.ranked_papers:
return self._generate_apology_prompt(criteria)
# 检查是否包含PubMed论文
has_pubmed_papers = any(paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url
for paper in self.ranked_papers)
if has_pubmed_papers:
return self._generate_medical_review_prompt(criteria)
else:
return self._generate_general_review_prompt(criteria)
def _generate_medical_review_prompt(self, criteria: SearchCriteria) -> str:
"""生成医学文献综述prompt"""
return f"""Current time: {self._get_current_time()}
Conduct a systematic medical literature review on {criteria.main_topic} based STRICTLY on the provided articles.
Available literature for review:
{self._format_papers(self.ranked_papers)}
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
Please structure your medical review following these guidelines:
1. Research Overview
- Main research questions and objectives from the studies
- Types of studies included (clinical trials, observational studies, etc.)
- Study populations and settings
- Time period of the research
2. Key Findings
- Main outcomes and results reported in abstracts
- Primary endpoints and their measurements
- Statistical significance when reported
- Observed trends across studies
3. Methods Summary
- Study designs used
- Major interventions or treatments studied
- Key outcome measures
- Patient populations studied
4. Clinical Relevance
- Reported clinical implications
- Main conclusions from authors
- Reported benefits and risks
- Treatment responses when available
5. Research Status
- Current research focus areas
- Reported limitations
- Gaps identified in abstracts
- Authors' suggested future directions
CRITICAL REQUIREMENTS:
Citation Rules (MANDATORY):
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
- Use ONLY the papers provided in the available literature list above
- Citations must appear immediately after each statement
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
Content Guidelines:
- Present only information available in the provided papers
- If certain information is not available, simply omit that aspect rather than explicitly stating its absence
- Focus on synthesizing and presenting available findings
- Maintain professional medical writing style
- Present limitations and gaps as research opportunities rather than missing information
Writing Style:
- Use precise medical terminology
- Maintain objective reporting
- Use consistent terminology throughout
- Present a cohesive narrative without referencing data limitations
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
def _generate_general_review_prompt(self, criteria: SearchCriteria) -> str:
"""生成通用文献综述prompt"""
current_time = self._get_current_time()
final_prompt = f"""Current time: {current_time}
Conduct a comprehensive literature review on {criteria.main_topic} focusing on the following aspects:
{', '.join(criteria.sub_topics)}
Available literature for review:
{self._format_papers(self.ranked_papers)}
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
Please structure your review following these guidelines:
1. Introduction and Research Background
- Current state and significance of the research field
- Key research problems and challenges
- Research development timeline and evolution
2. Research Directions and Classifications
- Major research directions and their relationships
- Different technical approaches and their characteristics
- Comparative analysis of various solutions
3. Core Technologies and Methods
- Key technological breakthroughs
- Advantages and limitations of different methods
- Technical challenges and solutions
4. Applications and Impact
- Real-world applications and use cases
- Industry influence and practical value
- Implementation challenges and solutions
5. Future Trends and Prospects
- Emerging research directions
- Unsolved problems and challenges
- Potential breakthrough points
CRITICAL REQUIREMENTS:
Citation Rules (MANDATORY):
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
- Use ONLY the papers provided in the available literature list above
- Citations must appear immediately after each statement
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
Writing Style:
- Maintain academic and professional tone
- Focus on objective analysis with proper citations
- Ensure logical flow and clear structure
Content Requirements:
- Base ALL analysis STRICTLY on the provided papers with explicit citations
- When introducing any concept, method, or finding, immediately follow with [N]
- For each research direction or approach, cite the specific papers [N] that proposed or developed it
- When discussing limitations or challenges, cite the papers [N] that identified them
- DO NOT include information from sources outside the provided paper list
- DO NOT make unsupported claims or statements
Language requirement:
- If the query explicitly specifies a language, use that language
- Otherwise, match the language of the original user query
"""
return final_prompt
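# Illustrative sketch: handle() above picks the medical or the general review prompt
# depending on whether any ranked paper URL points at PubMed. That predicate, isolated
# as a tiny helper:
def _looks_like_pubmed(url) -> bool:
    return bool(url) and "pubmed.ncbi.nlm.nih.gov" in url

if __name__ == "__main__":
    urls = ["https://pubmed.ncbi.nlm.nih.gov/123456/", "https://arxiv.org/abs/2103.14030"]
    print(any(_looks_like_pubmed(u) for u in urls))  # True -> medical review prompt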

查看文件

@@ -0,0 +1,452 @@
from typing import List, Dict
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
from request_llms.embed_models.bge_llm import BGELLMRanker
from crazy_functions.review_fns.query_analyzer import SearchCriteria
import random
from crazy_functions.review_fns.data_sources.journal_metrics import JournalMetrics
class PaperLLMRanker:
"""使用LLM进行论文重排序"""
def __init__(self, llm_kwargs: Dict = None):
self.ranker = BGELLMRanker(llm_kwargs=llm_kwargs)
self.journal_metrics = JournalMetrics()
def _update_paper_metrics(self, papers: List[PaperMetadata]) -> None:
"""更新论文的期刊指标"""
for paper in papers:
# 跳过arXiv来源的论文
if getattr(paper, 'source', '') == 'arxiv':
continue
if hasattr(paper, 'venue_name') or hasattr(paper, 'venue_info'):
# 获取venue_name和venue_info
venue_name = getattr(paper, 'venue_name', '')
venue_info = getattr(paper, 'venue_info', {})
# 使用改进的匹配逻辑获取指标
metrics = self.journal_metrics.get_journal_metrics(venue_name, venue_info)
# 更新论文的指标
paper.if_factor = metrics.get('if_factor')
paper.jcr_division = metrics.get('jcr_division')
paper.cas_division = metrics.get('cas_division')
def _get_year_as_int(self, paper) -> int:
"""统一获取论文年份为整数格式
Args:
paper: 论文对象或直接是年份值
Returns:
整数格式的年份,如果无法转换则返回0
"""
try:
# 如果输入直接是年份而不是论文对象
if isinstance(paper, int):
return paper
elif isinstance(paper, str):
try:
return int(paper)
except ValueError:
import re
year_match = re.search(r'\d{4}', paper)
if year_match:
return int(year_match.group())
return 0
elif isinstance(paper, float):
return int(paper)
# 处理论文对象
year = getattr(paper, 'year', None)
if year is None:
return 0
# 如果是字符串,尝试转换为整数
if isinstance(year, str):
# 首先尝试直接转换整个字符串
try:
return int(year)
except ValueError:
# 如果直接转换失败,尝试提取第一个数字序列
import re
year_match = re.search(r'\d{4}', year)
if year_match:
return int(year_match.group())
return 0
# 如果是浮点数,转换为整数
elif isinstance(year, float):
return int(year)
# 如果已经是整数,直接返回
elif isinstance(year, int):
return year
return 0
except (ValueError, TypeError):
return 0
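# Worked examples for the normalisation above:
#   _get_year_as_int(2021)        -> 2021
#   _get_year_as_int("2021")      -> 2021
#   _get_year_as_int("July 2021") -> 2021  (first 4-digit run is extracted)
#   _get_year_as_int(2021.0)      -> 2021
#   _get_year_as_int("n/a")       -> 0     (unparseable years fall back to 0)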
def rank_papers(
self,
query: str,
papers: List[PaperMetadata],
search_criteria: SearchCriteria = None,
top_k: int = 40,
use_rerank: bool = False,
pre_filter_ratio: float = 0.5,
max_papers: int = 150
) -> List[PaperMetadata]:
"""对论文进行重排序"""
initial_count = len(papers) if papers else 0
stats = {'initial': initial_count}
if not papers or not query:
return []
# 更新论文的期刊指标
self._update_paper_metrics(papers)
# 构建增强查询
# enhanced_query = self._build_enhanced_query(query, search_criteria) if search_criteria else query
enhanced_query = query
# 首先过滤不满足年份要求的论文
if search_criteria and search_criteria.start_year and search_criteria.end_year:
before_year_filter = len(papers)
filtered_papers = []
start_year = int(search_criteria.start_year)
end_year = int(search_criteria.end_year)
for paper in papers:
paper_year = self._get_year_as_int(paper)
if paper_year == 0 or start_year <= paper_year <= end_year:
filtered_papers.append(paper)
papers = filtered_papers
stats['after_year_filter'] = len(papers)
if not papers: # 如果过滤后没有论文,直接返回空列表
return []
# 新增:对少量论文的快速处理
SMALL_PAPER_THRESHOLD = 10 # 定义"少量"论文的阈值
if len(papers) <= SMALL_PAPER_THRESHOLD:
# 对于少量论文,直接根据查询类型进行简单排序
if search_criteria:
if search_criteria.query_type == "latest":
papers.sort(key=lambda x: self._get_year_as_int(x), reverse=True)
elif search_criteria.query_type == "recommend":
papers.sort(key=lambda x: getattr(x, 'citations', 0) or 0, reverse=True)
elif search_criteria.query_type == "review":
papers.sort(key=lambda x:
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
keyword in (getattr(x, 'abstract', '') or '').lower()
for keyword in ['review', 'survey', 'overview'])
else 0,
reverse=True
)
return papers[:top_k]
# 1. 优先处理最新的论文
if search_criteria and search_criteria.query_type == "latest":
papers = sorted(papers, key=lambda x: self._get_year_as_int(x), reverse=True)
# 2. 如果是综述类查询,优先处理可能的综述论文
if search_criteria and search_criteria.query_type == "review":
papers = sorted(papers, key=lambda x:
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
keyword in (getattr(x, 'abstract', '') or '').lower()
for keyword in ['review', 'survey', 'overview'])
else 0,
reverse=True
)
# 3. 如果论文数量超过限制,采用分层采样而不是完全随机
if len(papers) > max_papers:
before_max_limit = len(papers)
papers = self._select_papers_strategically(papers, search_criteria, max_papers)
stats['after_max_limit'] = len(papers)
try:
paper_texts = []
valid_papers = [] # 4. 跟踪有效论文
for paper in papers:
if paper is None:
continue
# 5. 预先过滤明显不相关的论文
if search_criteria and search_criteria.start_year:
paper_year = self._get_year_as_int(paper)
if paper_year and paper_year < int(search_criteria.start_year):
continue
doc = self._build_enhanced_document(paper, search_criteria)
paper_texts.append(doc)
valid_papers.append(paper) # 记录对应的论文
stats['after_valid_check'] = len(valid_papers)
if not paper_texts:
return []
# 使用LLM判断相关性
relevance_results = self.ranker.batch_check_relevance(
query=enhanced_query, # 使用增强的查询
paper_texts=paper_texts,
show_progress=True
)
# 6. 优化相关论文的选择策略
relevant_papers = []
for paper, is_relevant in zip(valid_papers, relevance_results):
if is_relevant:
relevant_papers.append(paper)
stats['after_llm_filter'] = len(relevant_papers)
# 打印统计信息
print(f"论文筛选统计: 初始数量={stats['initial']}, " +
f"年份过滤后={stats.get('after_year_filter', stats['initial'])}, " +
f"数量限制后={stats.get('after_max_limit', stats.get('after_year_filter', stats['initial']))}, " +
f"有效性检查后={stats['after_valid_check']}, " +
f"LLM筛选后={stats['after_llm_filter']}")
# 7. 改进回退策略
if len(relevant_papers) < min(5, len(papers)):
# 如果相关论文太少,返回按引用量排序的论文
return sorted(
papers[:top_k],
key=lambda x: getattr(x, 'citations', 0) or 0,
reverse=True
)
# 8. 对最终结果进行排序
if search_criteria:
if search_criteria.query_type == "latest":
# 最新论文优先,但同年份按IF排序
relevant_papers.sort(key=lambda x: (
self._get_year_as_int(x),
getattr(x, 'if_factor', 0) or 0
), reverse=True)
elif search_criteria.query_type == "recommend":
# IF指数优先,其次是引用量
relevant_papers.sort(key=lambda x: (
getattr(x, 'if_factor', 0) or 0,
getattr(x, 'citations', 0) or 0
), reverse=True)
else:
# 默认按IF指数排序
relevant_papers.sort(key=lambda x: getattr(x, 'if_factor', 0) or 0, reverse=True)
return relevant_papers[:top_k]
except Exception as e:
print(f"论文排序时出错: {str(e)}")
# 9. 改进错误处理的回退策略
try:
return sorted(
papers[:top_k],
key=lambda x: getattr(x, 'citations', 0) or 0,
reverse=True
)
except:
return papers[:top_k] if papers else []
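# Summary of the ranking pipeline above:
# 1) hard filter by the requested year range, 2) fast path that simply sorts when 10
# or fewer papers remain, 3) query-type presort (recency for "latest", review keywords
# for "review"), 4) strategic down-sampling to max_papers, 5) LLM relevance screening
# via BGELLMRanker, 6) final sort by recency / impact factor / citations, with a
# citation-count fallback whenever too few papers survive the LLM filter.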
def _build_enhanced_query(self, query: str, criteria: SearchCriteria) -> str:
"""构建增强的查询文本"""
components = []
# 强调这是用户的原始查询,是最重要的匹配依据
components.append(f"Original user query that must be primarily matched: {query}")
if criteria:
# 添加主题(如果与原始查询不同)
if criteria.main_topic and criteria.main_topic != query:
components.append(f"Additional context - The main topic is about: {criteria.main_topic}")
# 添加子主题
if criteria.sub_topics:
components.append(f"Secondary aspects to consider: {', '.join(criteria.sub_topics)}")
# 添加查询类型相关信息
if criteria.query_type == "review":
components.append("Paper type preference: Looking for comprehensive review papers, survey papers, or overview papers")
elif criteria.query_type == "latest":
components.append("Temporal preference: Focus on the most recent developments and latest papers")
elif criteria.query_type == "recommend":
components.append("Impact preference: Consider influential and fundamental papers")
# 直接连接所有组件,保持语序
enhanced_query = ' '.join(components)
# 限制长度但不打乱顺序
if len(enhanced_query) > 1000:
enhanced_query = enhanced_query[:997] + "..."
return enhanced_query
def _build_enhanced_document(self, paper: PaperMetadata, criteria: SearchCriteria) -> str:
"""构建增强的文档表示"""
components = []
# 基本信息
title = getattr(paper, 'title', '')
authors = ', '.join(getattr(paper, 'authors', []))
abstract = getattr(paper, 'abstract', '')
year = getattr(paper, 'year', '')
venue = getattr(paper, 'venue', '')
components.extend([
f"Title: {title}",
f"Authors: {authors}",
f"Year: {year}",
f"Venue: {venue}",
f"Abstract: {abstract}"
])
# 根据查询类型添加额外信息
if criteria:
if criteria.query_type == "review":
# 对于综述类查询,强调论文的综述性质
title_lower = (title or '').lower()
abstract_lower = (abstract or '').lower()
if any(keyword in title_lower or keyword in abstract_lower
for keyword in ['review', 'survey', 'overview']):
components.append("This is a review/survey paper")
elif criteria.query_type == "latest":
# 对于最新论文查询,强调时间信息
if criteria.start_year and self._get_year_as_int(year) >= int(criteria.start_year):
components.append(f"This is a recent paper from {year}")
elif criteria.query_type == "recommend":
# 对于推荐类查询,添加主题相关性信息
if criteria.main_topic:
title_lower = (title or '').lower()
abstract_lower = (abstract or '').lower()
topic_relevance = any(topic.lower() in title_lower or topic.lower() in abstract_lower
for topic in [criteria.main_topic] + (criteria.sub_topics or []))
if topic_relevance:
components.append(f"This paper is directly related to {criteria.main_topic}")
return '\n'.join(components)
def _select_papers_strategically(
self,
papers: List[PaperMetadata],
search_criteria: SearchCriteria,
max_papers: int = 150
) -> List[PaperMetadata]:
"""战略性地选择论文子集,优先选择非Crossref来源的论文,
当ADS论文充足时排除arXiv论文"""
if len(papers) <= max_papers:
return papers
# 1. 首先按来源分组
papers_by_source = {
'crossref': [],
'adsabs': [],
'arxiv': [],
'others': [] # semantic, pubmed等其他来源
}
for paper in papers:
source = getattr(paper, 'source', '')
if source == 'crossref':
papers_by_source['crossref'].append(paper)
elif source == 'adsabs':
papers_by_source['adsabs'].append(paper)
elif source == 'arxiv':
papers_by_source['arxiv'].append(paper)
else:
papers_by_source['others'].append(paper)
# 2. 计算分数的通用函数
def calculate_paper_score(paper):
score = 0
title = (getattr(paper, 'title', '') or '').lower()
abstract = (getattr(paper, 'abstract', '') or '').lower()
year = self._get_year_as_int(paper)
citations = getattr(paper, 'citations', 0) or 0
# 安全地获取搜索条件
main_topic = (getattr(search_criteria, 'main_topic', '') or '').lower()
sub_topics = getattr(search_criteria, 'sub_topics', []) or []
query_type = getattr(search_criteria, 'query_type', '')
start_year = getattr(search_criteria, 'start_year', 0) or 0
# 主题相关性得分
if main_topic and main_topic in title:
score += 10
if main_topic and main_topic in abstract:
score += 5
# 子主题相关性得分
for sub_topic in sub_topics:
if sub_topic and sub_topic.lower() in title:
score += 5
if sub_topic and sub_topic.lower() in abstract:
score += 2.5
# 根据查询类型调整分数
if query_type == "review":
review_keywords = ['review', 'survey', 'overview']
if any(keyword in title for keyword in review_keywords):
score *= 1.5
if any(keyword in abstract for keyword in review_keywords):
score *= 1.2
elif query_type == "latest":
if year and start_year:
year_int = year if isinstance(year, int) else self._get_year_as_int(paper)
start_year_int = start_year if isinstance(start_year, int) else int(start_year)
if year_int >= start_year_int:
recency_bonus = min(5, (year_int - start_year_int))
score += recency_bonus * 2
elif query_type == "recommend":
citation_score = min(10, citations / 100)
score += citation_score
return score
result = []
# 3. 处理ADS和arXiv论文
non_crossref_papers = papers_by_source['others'] # 首先添加其他来源的论文
# 添加ADS论文
if papers_by_source['adsabs']:
non_crossref_papers.extend(papers_by_source['adsabs'])
# 只有当ADS论文不足20篇时,才添加arXiv论文
if len(papers_by_source['adsabs']) <= 20:
non_crossref_papers.extend(papers_by_source['arxiv'])
elif not papers_by_source['adsabs'] and papers_by_source['arxiv']:
# 如果没有ADS论文但有arXiv论文,也使用arXiv论文
non_crossref_papers.extend(papers_by_source['arxiv'])
# 4. 对非Crossref论文评分和排序
scored_non_crossref = [(p, calculate_paper_score(p)) for p in non_crossref_papers]
scored_non_crossref.sort(key=lambda x: x[1], reverse=True)
# 5. 先添加高分的非Crossref论文
non_crossref_limit = max_papers * 0.9 # 90%的配额给非Crossref论文
if len(scored_non_crossref) >= non_crossref_limit:
result.extend([p[0] for p in scored_non_crossref[:int(non_crossref_limit)]])
else:
result.extend([p[0] for p in scored_non_crossref])
# 6. 如果还有剩余空间,考虑添加Crossref论文
remaining_slots = max_papers - len(result)
if remaining_slots > 0 and papers_by_source['crossref']:
# 计算Crossref论文的最大数量不超过总数的10%
max_crossref = min(remaining_slots, max_papers * 0.1)
# 对Crossref论文评分和排序
scored_crossref = [(p, calculate_paper_score(p)) for p in papers_by_source['crossref']]
scored_crossref.sort(key=lambda x: x[1], reverse=True)
# 添加最高分的Crossref论文
result.extend([p[0] for p in scored_crossref[:int(max_crossref)]])
# 7. 如果使用了Crossref论文后还有空位,继续使用非Crossref论文填充
if len(result) < max_papers:
already_added = {id(p) for p in result}
remaining_non_crossref = [p[0] for p in scored_non_crossref if id(p[0]) not in already_added]
result.extend(remaining_non_crossref[:max_papers - len(result)])
return result
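# Illustrative sketch: the selection above reserves roughly 90% of the slots for
# non-Crossref sources and at most 10% for Crossref. The quota arithmetic, isolated
# as a small helper:
def _split_quota(max_papers: int, crossref_share: float = 0.1):
    crossref_slots = int(max_papers * crossref_share)
    return max_papers - crossref_slots, crossref_slots

if __name__ == "__main__":
    print(_split_quota(150))  # (135, 15)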

查看文件

@@ -0,0 +1,76 @@
# ADS query optimization prompt
ADSABS_QUERY_PROMPT = """Analyze and optimize the following query for NASA ADS search.
If the query is not related to astronomy, astrophysics, or physics, return <query>none</query>.
If the query contains non-English terms, translate them to English first.
Query: {query}
Task: Transform the natural language query into an optimized ADS search query.
Always generate English search terms regardless of the input language.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on the core research topic for the search query.
Relevant research areas for ADS:
- Astronomy and astrophysics
- Physics (theoretical and experimental)
- Space science and exploration
- Planetary science
- Cosmology
- Astrobiology
- Related instrumentation and methods
Available search fields and filters:
1. Basic fields:
- title: Search in title (title:"term")
- abstract: Search in abstract (abstract:"term")
- author: Search for author names (author:"lastname, firstname")
- year: Filter by year (year:2020-2023)
- bibstem: Search by journal abbreviation (bibstem:ApJ)
2. Boolean operators:
- AND
- OR
- NOT
- (): Group terms
- "": Exact phrase match
3. Special filters:
- citations(identifier:paper): Papers citing a specific paper
- references(identifier:paper): References of a specific paper
- citation_count: Filter by citation count
- database: Filter by database (database:astronomy)
Examples:
1. Query: "Black holes in galaxy centers after 2020"
<query>title:"black hole" AND abstract:"galaxy center" AND year:2020-</query>
2. Query: "Papers by Neil deGrasse Tyson about exoplanets"
<query>author:"Tyson, Neil deGrasse" AND title:exoplanet</query>
3. Query: "Most cited papers about dark matter in ApJ"
<query>title:"dark matter" AND bibstem:ApJ AND citation_count:[100 TO *]</query>
4. Query: "Latest research on diabetes treatment"
<query>none</query>
5. Query: "Machine learning for galaxy classification"
<query>title:("machine learning" OR "deep learning") AND (title:galaxy OR abstract:galaxy) AND abstract:classification</query>
Please analyze the query and respond ONLY with XML tags:
<query>Provide the optimized ADS search query using appropriate fields and operators, or "none" if not relevant</query>"""
# System prompt
ADSABS_QUERY_SYSTEM_PROMPT = """You are an expert at crafting NASA ADS search queries.
Your task is to:
1. First determine if the query is relevant to astronomy, astrophysics, or physics research
2. If relevant, optimize the natural language query for the ADS API
3. If not relevant, return "none" to indicate the query should be handled by other databases
Focus on creating precise queries that will return relevant astronomical and physics literature.
Always generate English search terms regardless of the input language.
Consider using field-specific search terms and appropriate filters to improve search accuracy.
Remember: ADS is specifically for astronomy, astrophysics, and physics research.
Medical, biological, or general research queries should return "none"."""
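# Illustrative sketch: a template like ADSABS_QUERY_PROMPT is filled with
# .format(query=...) and the model reply is expected to carry a single <query> tag.
# The regex extraction below is only an assumption about how such a reply could be
# parsed; the query analyzer in this project may do it differently.
import re

def _extract_ads_query(llm_reply: str):
    match = re.search(r"<query>(.*?)</query>", llm_reply, re.DOTALL)
    if not match:
        return None
    query = match.group(1).strip()
    return None if query.lower() == "none" else query

if __name__ == "__main__":
    prompt = ADSABS_QUERY_PROMPT.format(query="dark matter halos after 2020")
    print(prompt.splitlines()[0])
    fake_reply = '<query>title:"dark matter" AND abstract:halo AND year:2020-</query>'
    print(_extract_ads_query(fake_reply))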

查看文件

@@ -0,0 +1,341 @@
# Basic type analysis prompt
ARXIV_TYPE_PROMPT = """Analyze the research query and determine if arXiv search is needed and its type.
Query: {query}
Task 1: Determine if this query requires arXiv search
- arXiv is suitable for:
* Computer science and AI/ML
* Physics and mathematics
* Quantitative biology and finance
* Electrical engineering
* Recent preprints in these fields
- arXiv is NOT needed for:
* Medical research (unless ML/AI applications)
* Social sciences
* Business studies
* Humanities
* Industry reports
Task 2: If arXiv search is needed, determine the most appropriate search type
Available types:
1. basic: Keyword-based search across all fields
- For specific technical queries
- When looking for particular methods or applications
2. category: Category-based search within specific fields
- For broad topic exploration
- When surveying a research area
3. none: arXiv search not needed for this query
- When topic is outside arXiv's scope
- For non-technical or clinical research
Examples:
1. Query: "BERT transformer architecture"
<search_type>basic</search_type>
2. Query: "latest developments in machine learning"
<search_type>category</search_type>
3. Query: "COVID-19 clinical trials"
<search_type>none</search_type>
4. Query: "psychological effects of social media"
<search_type>none</search_type>
Please analyze the query and respond ONLY with XML tags:
<search_type>Choose either 'basic', 'category', or 'none'</search_type>"""
# Query optimization prompt
ARXIV_QUERY_PROMPT = """Optimize the following query for arXiv search.
Query: {query}
Task: Transform the natural language query into an optimized arXiv search query using boolean operators and field tags.
Always generate English search terms regardless of the input language.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on the core research topic for the search query.
Available field tags:
- ti: Search in title
- abs: Search in abstract
- au: Search for author
- all: Search in all fields (default)
Boolean operators:
- AND: Both terms must appear
- OR: Either term can appear
- NOT: Exclude terms
- (): Group terms
- "": Exact phrase match
Examples:
1. Natural query: "Recent papers about transformer models by Vaswani"
<query>ti:"transformer model" AND au:Vaswani AND year:[2017 TO 2024]</query>
2. Natural query: "Deep learning for computer vision, excluding surveys"
<query>ti:(deep learning AND "computer vision") NOT (ti:survey OR ti:review)</query>
3. Natural query: "Attention mechanism in language models"
<query>ti:(attention OR "attention mechanism") AND abs:"language model"</query>
4. Natural query: "GANs or generative adversarial networks for image generation"
<query>(ti:GAN OR ti:"generative adversarial network") AND abs:"image generation"</query>
Please analyze the query and respond ONLY with XML tags:
<query>Provide the optimized search query using appropriate operators and tags</query>
Note:
- Use quotes for exact phrases
- Combine multiple conditions with boolean operators
- Consider both title and abstract for important concepts
- Include author names when relevant
- Use parentheses for complex logical groupings"""
# Sort parameters prompt
ARXIV_SORT_PROMPT = """Determine optimal sorting parameters for the research query.
Query: {query}
Task: Select the most appropriate sorting parameters to help users find the most relevant papers.
Available sorting options:
1. Sort by:
- relevance: Best match to query terms (default)
- lastUpdatedDate: Most recently updated papers
- submittedDate: Most recently submitted papers
2. Sort order:
- descending: Newest/Most relevant first (default)
- ascending: Oldest/Least relevant first
3. Result limit:
- Minimum: 10 papers
- Maximum: 50 papers
- Recommended: 20-30 papers for most queries
Examples:
1. Query: "Latest developments in transformer models"
<sort_by>submittedDate</sort_by>
<sort_order>descending</sort_order>
<limit>30</limit>
2. Query: "Foundational papers about neural networks"
<sort_by>relevance</sort_by>
<sort_order>descending</sort_order>
<limit>20</limit>
3. Query: "Evolution of deep learning since 2012"
<sort_by>submittedDate</sort_by>
<sort_order>ascending</sort_order>
<limit>50</limit>
Please analyze the query and respond ONLY with XML tags:
<sort_by>Choose: relevance, lastUpdatedDate, or submittedDate</sort_by>
<sort_order>Choose: ascending or descending</sort_order>
<limit>Suggest number between 10-50</limit>
Note:
- Choose relevance for specific technical queries
- Use lastUpdatedDate for tracking paper revisions
- Use submittedDate for following recent developments
- Consider query context when setting the limit"""
# System prompts for each task
ARXIV_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
Your task is to determine whether the query is better suited for keyword search or category-based search.
Consider the query's specificity, scope, and intended search area when making your decision.
Always respond in English regardless of the input language."""
ARXIV_QUERY_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
Your task is to optimize natural language queries using boolean operators and field tags.
Focus on creating precise, targeted queries that will return the most relevant results.
Always generate English search terms regardless of the input language."""
ARXIV_CATEGORIES_SYSTEM_PROMPT = """You are an expert at arXiv category classification.
Your task is to select the most relevant categories for the given research query.
Consider both primary and related interdisciplinary categories, while maintaining focus on the main research area.
Always respond in English regardless of the input language."""
ARXIV_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
Your task is to determine the best sorting parameters based on the query context.
Consider the user's likely intent and temporal aspects of the research topic.
Always respond in English regardless of the input language."""
# 添加新的搜索提示词
ARXIV_SEARCH_PROMPT = """Analyze and optimize the research query for arXiv search.
Query: {query}
Task: Transform the natural language query into an optimized arXiv search query.
Available search options:
1. Basic search with field tags:
- ti: Search in title
- abs: Search in abstract
- au: Search for author
Example: "ti:transformer AND abs:attention"
2. Category-based search:
- Use specific arXiv categories
Example: "cat:cs.AI AND neural networks"
3. Date range:
- Specify date range using submittedDate
Example: "deep learning AND submittedDate:[20200101 TO 20231231]"
Examples:
1. Query: "Recent papers about transformer models by Vaswani"
<search_criteria>
<query>ti:"transformer model" AND au:Vaswani AND submittedDate:[20170101 TO 99991231]</query>
<categories>cs.CL, cs.AI, cs.LG</categories>
<sort_by>submittedDate</sort_by>
<sort_order>descending</sort_order>
<limit>30</limit>
</search_criteria>
2. Query: "Latest developments in computer vision"
<search_criteria>
<query>cat:cs.CV AND submittedDate:[20220101 TO 99991231]</query>
<categories>cs.CV, cs.AI, cs.LG</categories>
<sort_by>submittedDate</sort_by>
<sort_order>descending</sort_order>
<limit>25</limit>
</search_criteria>
Please analyze the query and respond with XML tags containing search criteria."""
ARXIV_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
Your task is to analyze research queries and transform them into optimized arXiv search criteria.
Consider query intent, relevant categories, and temporal aspects when creating the search parameters.
Always generate English search terms and respond in English regardless of the input language."""
# Categories selection prompt
ARXIV_CATEGORIES_PROMPT = """Select the most relevant arXiv categories for the research query.
Query: {query}
Task: Choose 2-4 most relevant categories that best match the research topic.
Available Categories:
Computer Science (cs):
- cs.AI: Artificial Intelligence (neural networks, machine learning, NLP)
- cs.CL: Computation and Language (NLP, machine translation)
- cs.CV: Computer Vision and Pattern Recognition
- cs.LG: Machine Learning (deep learning, reinforcement learning)
- cs.NE: Neural and Evolutionary Computing
- cs.RO: Robotics
- cs.IR: Information Retrieval
- cs.SE: Software Engineering
- cs.DB: Databases
- cs.DC: Distributed Computing
- cs.CY: Computers and Society
- cs.HC: Human-Computer Interaction
Mathematics (math):
- math.OC: Optimization and Control
- math.PR: Probability
- math.ST: Statistics
- math.NA: Numerical Analysis
- math.DS: Dynamical Systems
Statistics (stat):
- stat.ML: Machine Learning
- stat.ME: Methodology
- stat.TH: Theory
- stat.AP: Applications
Physics (physics):
- physics.comp-ph: Computational Physics
- physics.data-an: Data Analysis
- physics.soc-ph: Physics and Society
Electrical Engineering (eess):
- eess.SP: Signal Processing
- eess.AS: Audio and Speech Processing
- eess.IV: Image and Video Processing
- eess.SY: Systems and Control
Examples:
1. Query: "Deep learning for computer vision"
<categories>cs.CV, cs.LG, stat.ML</categories>
2. Query: "Natural language processing with transformers"
<categories>cs.CL, cs.AI, cs.LG</categories>
3. Query: "Reinforcement learning for robotics"
<categories>cs.RO, cs.AI, cs.LG</categories>
4. Query: "Statistical methods in machine learning"
<categories>stat.ML, cs.LG, math.ST</categories>
Please analyze the query and respond ONLY with XML tags:
<categories>List 2-4 most relevant categories, comma-separated</categories>
Note:
- Choose primary categories first, then add related ones
- Limit to 2-4 most relevant categories
- Order by relevance (most relevant first)
- Use comma and space between categories (e.g., "cs.AI, cs.LG")"""
# 在文件末尾添加新的 prompt
ARXIV_LATEST_PROMPT = """Determine if the query is requesting latest papers from arXiv.
Query: {query}
Task: Analyze if the query is specifically asking for recent/latest papers from arXiv.
IMPORTANT RULE:
- The query MUST explicitly mention "arXiv" or "arxiv" to be considered a latest arXiv papers request
- Queries only asking for recent/latest papers WITHOUT mentioning arXiv should return false
Indicators for latest papers request:
1. MUST HAVE keywords about arXiv:
- "arxiv"
- "arXiv"
AND
2. Keywords about recency:
- "latest"
- "recent"
- "new"
- "newest"
- "just published"
- "this week/month"
Examples:
1. Latest papers request (Valid):
Query: "Show me the latest AI papers on arXiv"
<is_latest_request>true</is_latest_request>
2. Latest papers request (Valid):
Query: "What are the recent papers about transformers on arxiv"
<is_latest_request>true</is_latest_request>
3. Not a latest papers request (Invalid - no mention of arXiv):
Query: "Show me the latest papers about BERT"
<is_latest_request>false</is_latest_request>
4. Not a latest papers request (Invalid - no recency):
Query: "Find papers on arxiv about transformers"
<is_latest_request>false</is_latest_request>
Please analyze the query and respond ONLY with XML tags:
<is_latest_request>true/false</is_latest_request>
Note: The response should be true ONLY if both conditions are met:
1. Query explicitly mentions arXiv/arxiv
2. Query asks for recent/latest papers"""
ARXIV_LATEST_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
Your task is to determine if the query is specifically requesting latest/recent papers from arXiv.
Remember: The query MUST explicitly mention arXiv to be considered valid, even if it asks for recent papers.
Always respond in English regardless of the input language."""

查看文件

@@ -0,0 +1,55 @@
# Crossref query optimization prompt
CROSSREF_QUERY_PROMPT = """Analyze and optimize the query for Crossref search.
Query: {query}
Task: Transform the natural language query into an optimized Crossref search query.
Always generate English search terms regardless of the input language.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on the core research topic for the search query.
Available search fields and filters:
1. Basic fields:
- title: Search in title
- abstract: Search in abstract
- author: Search for author names
- container-title: Search in journal/conference name
- publisher: Search by publisher name
- type: Filter by work type (journal-article, book-chapter, etc.)
- year: Filter by publication year
2. Boolean operators:
- AND: Both terms must appear
- OR: Either term can appear
- NOT: Exclude terms
- "": Exact phrase match
3. Special filters:
- is-referenced-by-count: Filter by citation count
- from-pub-date: Filter by publication date
- has-abstract: Filter papers with abstracts
Examples:
1. Query: "Machine learning in healthcare after 2020"
<query>title:"machine learning" AND title:healthcare AND from-pub-date:2020</query>
2. Query: "Papers by Geoffrey Hinton about deep learning"
<query>author:"Hinton, Geoffrey" AND (title:"deep learning" OR abstract:"deep learning")</query>
3. Query: "Most cited papers about transformers in Nature"
<query>title:transformer AND container-title:Nature AND is-referenced-by-count:[100 TO *]</query>
4. Query: "Recent BERT applications in medical domain"
<query>title:BERT AND abstract:medical AND from-pub-date:2020 AND type:journal-article</query>
Please analyze the query and respond ONLY with XML tags:
<query>Provide the optimized Crossref search query using appropriate fields and operators</query>"""
# System prompt
CROSSREF_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Crossref search queries.
Your task is to optimize natural language queries for Crossref's API.
Focus on creating precise queries that will return relevant results.
Always generate English search terms regardless of the input language.
Consider using field-specific search terms and appropriate filters to improve search accuracy."""
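For context, a hedged sketch of sending the optimized query to the Crossref REST API. The /works endpoint and its query.bibliographic/rows parameters are part of the public API; the wrapper function and field mapping are illustrative assumptions, not code from this commit.

import requests

def search_crossref(optimized_query: str, rows: int = 20) -> list:
    # Query the Crossref works endpoint with the LLM-optimized search string
    resp = requests.get(
        "https://api.crossref.org/works",
        params={"query.bibliographic": optimized_query, "rows": rows},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["message"]["items"]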


@@ -0,0 +1,47 @@
# 新建文件,添加论文识别提示
PAPER_IDENTIFY_PROMPT = """Analyze the query to identify paper details.
Query: {query}
Task: Extract paper identification information from the query.
Always generate English search terms regardless of the input language.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on identifying paper details.
Possible paper identifiers:
1. arXiv ID (e.g., 2103.14030, arXiv:2103.14030)
2. DOI (e.g., 10.1234/xxx.xxx)
3. Paper title (e.g., "Attention is All You Need")
Examples:
1. Query with arXiv ID:
Query: "Analyze paper 2103.14030"
<paper_info>
<paper_source>arxiv</paper_source>
<paper_id>2103.14030</paper_id>
<paper_title></paper_title>
</paper_info>
2. Query with DOI:
Query: "Review the paper with DOI 10.1234/xxx.xxx"
<paper_info>
<paper_source>doi</paper_source>
<paper_id>10.1234/xxx.xxx</paper_id>
<paper_title></paper_title>
</paper_info>
3. Query with paper title:
Query: "Analyze 'Attention is All You Need' paper"
<paper_info>
<paper_source>title</paper_source>
<paper_id></paper_id>
<paper_title>Attention is All You Need</paper_title>
</paper_info>
Please analyze the query and respond ONLY with XML tags containing paper information."""
PAPER_IDENTIFY_SYSTEM_PROMPT = """You are an expert at identifying academic paper references.
Your task is to extract paper identification information from queries.
Look for arXiv IDs, DOIs, and paper titles."""
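A small illustrative fallback (an assumption, not part of this commit): when the LLM reply cannot be parsed, the same identification can often be done locally with regular expressions.

import re

def identify_paper_locally(query: str) -> dict:
    # Prefer an arXiv ID, then a DOI, otherwise treat the query text as a title
    arxiv_match = re.search(r"(?:arXiv:)?(\d{4}\.\d{4,5})(?:v\d+)?", query, re.IGNORECASE)
    if arxiv_match:
        return {"paper_source": "arxiv", "paper_id": arxiv_match.group(1), "paper_title": ""}
    doi_match = re.search(r"\b(10\.\d{4,9}/[^\s'\"]+)", query)
    if doi_match:
        return {"paper_source": "doi", "paper_id": doi_match.group(1), "paper_title": ""}
    return {"paper_source": "title", "paper_id": "", "paper_title": query.strip()}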


@@ -0,0 +1,108 @@
# PubMed search type prompt
PUBMED_TYPE_PROMPT = """Analyze the research query and determine the appropriate PubMed search type.
Query: {query}
Available search types:
1. basic: General keyword search for medical/biomedical topics
2. author: Search by author name
3. journal: Search within specific journals
4. none: Query not related to medical/biomedical research
Examples:
1. Query: "COVID-19 treatment outcomes"
<search_type>basic</search_type>
2. Query: "Papers by Anthony Fauci"
<search_type>author</search_type>
3. Query: "Recent papers in Nature about CRISPR"
<search_type>journal</search_type>
4. Query: "Deep learning for computer vision"
<search_type>none</search_type>
5. Query: "Transformer architecture for NLP"
<search_type>none</search_type>
Please analyze the query and respond ONLY with XML tags:
<search_type>Choose: basic, author, journal, or none</search_type>"""
# PubMed query optimization prompt
PUBMED_QUERY_PROMPT = """Optimize the following query for PubMed search.
Query: {query}
Task: Transform the natural language query into an optimized PubMed search query.
Requirements:
- Always generate English search terms regardless of input language
- Translate any non-English terms to English before creating the query
- Never include non-English characters in the final query
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on the core medical/biomedical topic for the search query.
Available field tags:
- [Title] - Search in title
- [Title/Abstract] - Search in title or abstract
- [Author] - Search for author
- [Journal] - Search in journal name
- [MeSH Terms] - Search using MeSH terms
- [Publication Type] - Filter by publication type (e.g., review)
Boolean operators:
- AND
- OR
- NOT
Examples:
1. Query: "COVID-19 treatment in elderly patients"
<query>COVID-19[Title] AND treatment[Title/Abstract] AND elderly[Title/Abstract]</query>
2. Query: "Cancer immunotherapy review articles"
<query>cancer immunotherapy[Title/Abstract] AND review[Publication Type]</query>
Please analyze the query and respond ONLY with XML tags:
<query>Provide the optimized PubMed search query</query>"""
# PubMed sort parameters prompt
PUBMED_SORT_PROMPT = """Determine optimal sorting parameters for PubMed results.
Query: {query}
Task: Select the most appropriate sorting method and result limit.
Available sort options:
- relevance: Best match to query
- date: Most recent first
- journal: Sort by journal name
Examples:
1. Query: "Latest developments in gene therapy"
<sort_by>date</sort_by>
<limit>30</limit>
2. Query: "Classic papers about DNA structure"
<sort_by>relevance</sort_by>
<limit>20</limit>
Please analyze the query and respond ONLY with XML tags:
<sort_by>Choose: relevance, date, or journal</sort_by>
<limit>Suggest number between 10-50</limit>"""
# System prompts
PUBMED_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing medical and scientific queries.
Your task is to determine the most appropriate PubMed search type.
Consider the query's focus and intended search scope.
Always respond in English regardless of the input language."""
PUBMED_QUERY_SYSTEM_PROMPT = """You are an expert at crafting PubMed search queries.
Your task is to optimize natural language queries using PubMed's search syntax.
Focus on creating precise, targeted queries that will return relevant medical literature.
Always generate English search terms regardless of the input language."""
PUBMED_SORT_SYSTEM_PROMPT = """You are an expert at optimizing PubMed search results.
Your task is to determine the best sorting parameters based on the query context.
Consider the balance between relevance and recency.
Always respond in English regardless of the input language."""
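For reference, a hedged sketch of passing the optimized query to the NCBI E-utilities esearch endpoint. The db/term/retmax/retmode parameters are standard E-utilities options; the sort-value mapping and the wrapper function are assumptions, not code from this commit.

import requests

def search_pubmed(optimized_query: str, sort_by: str = "relevance", limit: int = 30) -> list:
    # Returns a list of PMIDs matching the optimized PubMed query
    resp = requests.get(
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi",
        params={
            "db": "pubmed",
            "term": optimized_query,
            "retmax": limit,
            "sort": "pub_date" if sort_by == "date" else "relevance",  # assumed mapping
            "retmode": "json",
        },
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["esearchresult"]["idlist"]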


@@ -0,0 +1,276 @@
# Search type prompt
SEMANTIC_TYPE_PROMPT = """Determine the most appropriate search type for Semantic Scholar.
Query: {query}
Task: Analyze the research query and select the most appropriate search type for Semantic Scholar API.
Available search types:
1. paper: General paper search
- Use for broad topic searches
- Looking for specific papers
- Keyword-based searches
Example: "transformer models in NLP"
2. author: Author-based search
- Finding works by specific researchers
- Author profile analysis
Example: "papers by Yoshua Bengio"
3. paper_details: Specific paper lookup
- Getting details about a known paper
- Finding specific versions or citations
Example: "Attention is All You Need paper details"
4. citations: Citation analysis
- Finding papers that cite a specific work
- Impact analysis
Example: "papers citing BERT"
5. references: Reference analysis
- Finding papers cited by a specific work
- Background research
Example: "references in GPT-3 paper"
6. recommendations: Paper recommendations
- Finding similar papers
- Research direction exploration
Example: "papers similar to Transformer"
Examples:
1. Query: "Latest papers about deep learning"
<search_type>paper</search_type>
2. Query: "Works by Geoffrey Hinton since 2020"
<search_type>author</search_type>
3. Query: "Papers citing the original Transformer paper"
<search_type>citations</search_type>
Please analyze the query and respond ONLY with XML tags:
<search_type>Choose the most appropriate search type from the list above</search_type>"""
# Query optimization prompt
SEMANTIC_QUERY_PROMPT = """Optimize the following query for Semantic Scholar search.
Query: {query}
Task: Transform the natural language query into an optimized search query for maximum relevance.
Always generate English search terms regardless of the input language.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements. Focus only on the core research topic for the search query.
Query optimization guidelines:
1. Use quotes for exact phrases
- Ensures exact matching
- Reduces irrelevant results
Example: "\"attention mechanism\"" vs attention mechanism
2. Include key technical terms
- Use specific technical terminology
- Include common variations
Example: "transformer architecture" neural networks
3. Author names (if relevant)
- Include full names when known
- Consider common name variations
Example: "Geoffrey Hinton" OR "G. E. Hinton"
Examples:
1. Natural query: "Recent advances in transformer models"
<query>"transformer model" "neural architecture" deep learning</query>
2. Natural query: "BERT applications in text classification"
<query>"BERT" "text classification" "language model" application</query>
3. Natural query: "Deep learning for computer vision by Kaiming He"
<query>"deep learning" "computer vision" author:"Kaiming He"</query>
Please analyze the query and respond ONLY with XML tags:
<query>Provide the optimized search query</query>
Note:
- Balance between specificity and coverage
- Include important technical terms
- Use quotes for key phrases
- Consider synonyms and related terms"""
# Fields selection prompt
SEMANTIC_FIELDS_PROMPT = """Select relevant fields to retrieve from Semantic Scholar.
Query: {query}
Task: Determine which paper fields should be retrieved based on the research needs.
Available fields:
Core fields:
- title: Paper title (always included)
- abstract: Full paper abstract
- authors: Author information
- year: Publication year
- venue: Publication venue
Citation fields:
- citations: Papers citing this work
- references: Papers cited by this work
Additional fields:
- embedding: Paper vector embedding
- tldr: AI-generated summary
- url: Paper URL
Examples:
1. Query: "Latest developments in NLP"
<fields>title, abstract, authors, year, venue, citations</fields>
2. Query: "Most influential papers in deep learning"
<fields>title, abstract, authors, year, citations, references</fields>
3. Query: "Survey of transformer architectures"
<fields>title, abstract, authors, year, tldr, references</fields>
Please analyze the query and respond ONLY with XML tags:
<fields>List relevant fields, comma-separated</fields>
Note:
- Choose fields based on the query's purpose
- Include citation data for impact analysis
- Consider tldr for quick paper screening
- Balance completeness with API efficiency"""
# Sort parameters prompt
SEMANTIC_SORT_PROMPT = """Determine optimal sorting parameters for the query.
Query: {query}
Task: Select the most appropriate sorting method and result limit for the search.
Always generate English search terms regardless of the input language.
Sorting options:
1. relevance (default)
- Best match to query terms
- Recommended for specific technical searches
Example: "specific algorithm implementations"
2. citations
- Sort by citation count
- Best for finding influential papers
Example: "most important papers in deep learning"
3. year
- Sort by publication date
- Best for following recent developments
Example: "latest advances in NLP"
Examples:
1. Query: "Recent breakthroughs in AI"
<sort_by>year</sort_by>
<limit>30</limit>
2. Query: "Most influential papers about GANs"
<sort_by>citations</sort_by>
<limit>20</limit>
3. Query: "Specific papers about BERT fine-tuning"
<sort_by>relevance</sort_by>
<limit>25</limit>
Please analyze the query and respond ONLY with XML tags:
<sort_by>Choose: relevance, citations, or year</sort_by>
<limit>Suggest number between 10-50</limit>
Note:
- Consider the query's temporal aspects
- Balance between comprehensive coverage and information overload
- Use citation sorting for impact analysis
- Use year sorting for tracking developments"""
# System prompts for each task
SEMANTIC_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
Your task is to determine the most appropriate type of search on Semantic Scholar.
Consider the query's intent, scope, and specific research needs.
Always respond in English regardless of the input language."""
SEMANTIC_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
Your task is to optimize natural language queries for maximum relevance.
Focus on creating precise queries that leverage the platform's search capabilities.
Always generate English search terms regardless of the input language."""
SEMANTIC_FIELDS_SYSTEM_PROMPT = """You are an expert at Semantic Scholar data fields.
Your task is to select the most relevant fields based on the research context.
Consider both essential and supplementary information needs.
Always respond in English regardless of the input language."""
SEMANTIC_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
Your task is to determine the best sorting parameters based on the query context.
Consider the balance between relevance, impact, and recency.
Always respond in English regardless of the input language."""
# 添加新的综合搜索提示词
SEMANTIC_SEARCH_PROMPT = """Analyze and optimize the research query for Semantic Scholar search.
Query: {query}
Task: Transform the natural language query into optimized search criteria for Semantic Scholar.
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
or output format requirements when generating the search terms. These requirements
should be considered only for post-search filtering, not as part of the core query.
Available search options:
1. Paper search:
- Title and abstract search
- Author search
- Field-specific search
Example: "transformer architecture neural networks"
2. Field tags:
- title: Search in title
- abstract: Search in abstract
- authors: Search by author names
- venue: Search by publication venue
Example: "title:transformer authors:\"Vaswani\""
3. Advanced options:
- Year range filtering
- Citation count filtering
- Venue filtering
Example: "deep learning year>2020 venue:\"NeurIPS\""
Examples:
1. Query: "Recent transformer papers by Vaswani with high impact"
<search_criteria>
<query>title:transformer authors:"Vaswani" year>2017</query>
<search_type>paper</search_type>
<fields>title,abstract,authors,year,citations,venue</fields>
<sort_by>citations</sort_by>
<limit>30</limit>
</search_criteria>
2. Query: "Most cited papers about BERT in top conferences"
<search_criteria>
<query>title:BERT venue:"ACL|EMNLP|NAACL"</query>
<search_type>paper</search_type>
<fields>title,abstract,authors,year,citations,venue,references</fields>
<sort_by>citations</sort_by>
<limit>25</limit>
</search_criteria>
Please analyze the query and respond with XML tags containing complete search criteria."""
SEMANTIC_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
Your task is to analyze research queries and transform them into optimized search criteria.
Consider query intent, field relevance, and citation impact when creating the search parameters.
Focus on producing precise and comprehensive search criteria that will yield the most relevant results.
Always generate English search terms and respond in English regardless of the input language."""
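As a point of reference, a hedged sketch of sending the resulting criteria to the Semantic Scholar Graph API paper-search endpoint. The endpoint and its query/fields/limit parameters exist in the public API, and an API key may optionally be supplied via the x-api-key header; the wrapper function itself is illustrative and not part of this commit.

import requests

def search_semantic_scholar(query: str, fields: list, limit: int = 20, api_key: str = "") -> list:
    # Search papers and return the raw result list from the Graph API
    headers = {"x-api-key": api_key} if api_key else {}
    resp = requests.get(
        "https://api.semanticscholar.org/graph/v1/paper/search",
        params={"query": query, "fields": ",".join(fields), "limit": limit},
        headers=headers,
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json().get("data", [])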


@@ -0,0 +1,493 @@
from typing import Dict, List
from dataclasses import dataclass
from textwrap import dedent
from datetime import datetime
import re
@dataclass
class SearchCriteria:
"""搜索条件"""
query_type: str # 查询类型: review/recommend/qa/paper
main_topic: str # 主题
sub_topics: List[str] # 子主题列表
start_year: int # 起始年份
end_year: int # 结束年份
arxiv_params: Dict # arXiv搜索参数
semantic_params: Dict # Semantic Scholar搜索参数
pubmed_params: Dict # 新增: PubMed搜索参数
crossref_params: Dict # 添加 Crossref 参数
adsabs_params: Dict # 添加 ADS 参数
paper_id: str = "" # 论文ID (arxiv ID 或 DOI)
paper_title: str = "" # 论文标题
paper_source: str = "" # 论文来源 (arxiv/doi/title)
original_query: str = "" # 新增: 原始查询字符串
class QueryAnalyzer:
"""查询分析器"""
# 响应索引常量
BASIC_QUERY_INDEX = 0
PAPER_IDENTIFY_INDEX = 1
ARXIV_QUERY_INDEX = 2
ARXIV_CATEGORIES_INDEX = 3
ARXIV_LATEST_INDEX = 4
ARXIV_SORT_INDEX = 5
SEMANTIC_QUERY_INDEX = 6
SEMANTIC_FIELDS_INDEX = 7
PUBMED_TYPE_INDEX = 8
PUBMED_QUERY_INDEX = 9
CROSSREF_QUERY_INDEX = 10
ADSABS_QUERY_INDEX = 11
def __init__(self):
self.current_year = datetime.now().year
self.valid_types = {
"review": ["review", "literature review", "survey"],
"recommend": ["recommend", "recommendation", "suggest", "similar"],
"qa": ["qa", "question", "answer", "explain", "what", "how", "why"],
"paper": ["paper", "analyze", "analysis"]
}
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
"""分析查询意图"""
from crazy_functions.crazy_utils import \
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
from crazy_functions.review_fns.prompts.arxiv_prompts import (
ARXIV_QUERY_PROMPT, ARXIV_CATEGORIES_PROMPT, ARXIV_LATEST_PROMPT,
ARXIV_SORT_PROMPT, ARXIV_QUERY_SYSTEM_PROMPT, ARXIV_CATEGORIES_SYSTEM_PROMPT, ARXIV_SORT_SYSTEM_PROMPT,
ARXIV_LATEST_SYSTEM_PROMPT
)
from crazy_functions.review_fns.prompts.semantic_prompts import (
SEMANTIC_QUERY_PROMPT, SEMANTIC_FIELDS_PROMPT,
SEMANTIC_QUERY_SYSTEM_PROMPT, SEMANTIC_FIELDS_SYSTEM_PROMPT
)
from .prompts.paper_prompts import PAPER_IDENTIFY_PROMPT, PAPER_IDENTIFY_SYSTEM_PROMPT
from .prompts.pubmed_prompts import (
PUBMED_TYPE_PROMPT, PUBMED_QUERY_PROMPT, PUBMED_SORT_PROMPT,
PUBMED_TYPE_SYSTEM_PROMPT, PUBMED_QUERY_SYSTEM_PROMPT, PUBMED_SORT_SYSTEM_PROMPT
)
from .prompts.crossref_prompts import (
CROSSREF_QUERY_PROMPT,
CROSSREF_QUERY_SYSTEM_PROMPT
)
from .prompts.adsabs_prompts import ADSABS_QUERY_PROMPT, ADSABS_QUERY_SYSTEM_PROMPT
# 1. 基本查询分析
type_prompt = dedent(f"""Please analyze this academic query and respond STRICTLY in the following XML format:
Query: {query}
Instructions:
1. Your response must use XML tags exactly as shown below
2. Do not add any text outside the tags
3. Choose query type from: review/recommend/qa/paper
- review: for literature review or survey requests
- recommend: for paper recommendation requests
- qa: for general questions about research topics
- paper: ONLY for queries about a SPECIFIC paper (with paper ID, DOI, or exact title)
4. Identify main topic and subtopics
5. Specify year range if mentioned
Required format:
<query_type>ANSWER HERE</query_type>
<main_topic>ANSWER HERE</main_topic>
<sub_topics>SUBTOPIC1, SUBTOPIC2, ...</sub_topics>
<year_range>START_YEAR-END_YEAR</year_range>
Example responses:
1. Literature Review Request:
Query: "Review recent developments in transformer models for NLP from 2020 to 2023"
<query_type>review</query_type>
<main_topic>transformer models in natural language processing</main_topic>
<sub_topics>architecture improvements, pre-training methods, fine-tuning techniques</sub_topics>
<year_range>2020-2023</year_range>
2. Paper Recommendation Request:
Query: "Suggest papers about reinforcement learning in robotics since 2018"
<query_type>recommend</query_type>
<main_topic>reinforcement learning in robotics</main_topic>
<sub_topics>robot control, policy learning, sim-to-real transfer</sub_topics>
<year_range>2018-2023</year_range>"""
)
try:
# 构建提示数组
prompts = [
type_prompt,
PAPER_IDENTIFY_PROMPT.format(query=query),
ARXIV_QUERY_PROMPT.format(query=query),
ARXIV_CATEGORIES_PROMPT.format(query=query),
ARXIV_LATEST_PROMPT.format(query=query),
ARXIV_SORT_PROMPT.format(query=query),
SEMANTIC_QUERY_PROMPT.format(query=query),
SEMANTIC_FIELDS_PROMPT.format(query=query),
PUBMED_TYPE_PROMPT.format(query=query),
PUBMED_QUERY_PROMPT.format(query=query),
CROSSREF_QUERY_PROMPT.format(query=query),
ADSABS_QUERY_PROMPT.format(query=query)
]
show_messages = [
"Analyzing query type...",
"Identifying paper details...",
"Determining arXiv search type...",
"Selecting arXiv categories...",
"Checking if latest papers requested...",
"Determining arXiv sort parameters...",
"Optimizing Semantic Scholar query...",
"Selecting Semantic Scholar fields...",
"Determining PubMed search type...",
"Optimizing PubMed query...",
"Optimizing Crossref query...",
"Optimizing ADS query..."
]
sys_prompts = [
"You are an expert at analyzing academic queries.",
PAPER_IDENTIFY_SYSTEM_PROMPT,
ARXIV_QUERY_SYSTEM_PROMPT,
ARXIV_CATEGORIES_SYSTEM_PROMPT,
ARXIV_LATEST_SYSTEM_PROMPT,
ARXIV_SORT_SYSTEM_PROMPT,
SEMANTIC_QUERY_SYSTEM_PROMPT,
SEMANTIC_FIELDS_SYSTEM_PROMPT,
PUBMED_TYPE_SYSTEM_PROMPT,
PUBMED_QUERY_SYSTEM_PROMPT,
CROSSREF_QUERY_SYSTEM_PROMPT,
ADSABS_QUERY_SYSTEM_PROMPT
]
new_llm_kwargs = llm_kwargs.copy()
# new_llm_kwargs['llm_model'] = 'deepseek-chat' # deepseek-ai/DeepSeek-V2.5
# 使用同步方式调用LLM
responses = yield from request_gpt(
inputs_array=prompts,
inputs_show_user_array=show_messages,
llm_kwargs=new_llm_kwargs,
chatbot=chatbot,
history_array=[[] for _ in prompts],
sys_prompt_array=sys_prompts,
max_workers=5
)
# 从收集的响应中提取我们需要的内容
extracted_responses = []
for i in range(len(prompts)):
if (i * 2 + 1) < len(responses):
response = responses[i * 2 + 1]
if response is None:
raise Exception(f"Response {i} is None")
if not isinstance(response, str):
try:
response = str(response)
except:
raise Exception(f"Cannot convert response {i} to string")
extracted_responses.append(response)
else:
raise Exception(f"未收到第 {i + 1} 个响应")
# 解析基本信息
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
if not query_type:
print(
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
raise Exception("无法提取query_type标签内容")
query_type = query_type.lower()
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
if not main_topic:
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
main_topic = query
query_type = self._normalize_query_type(query_type, query)
# 解析arXiv参数
try:
arxiv_params = {
"query": self._extract_tag(extracted_responses[self.ARXIV_QUERY_INDEX], "query"),
"categories": [cat.strip() for cat in
self._extract_tag(extracted_responses[self.ARXIV_CATEGORIES_INDEX],
"categories").split(",")],
"sort_by": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_by"),
"sort_order": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_order"),
"limit": 20
}
# 安全地解析limit值
limit_str = self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "limit")
if limit_str and limit_str.isdigit():
arxiv_params["limit"] = int(limit_str)
except Exception as e:
print(f"Warning: Error parsing arXiv parameters: {str(e)}")
arxiv_params = {
"query": "",
"categories": [],
"sort_by": "relevance",
"sort_order": "descending",
"limit": 0
}
# 解析Semantic Scholar参数
try:
semantic_params = {
"query": self._extract_tag(extracted_responses[self.SEMANTIC_QUERY_INDEX], "query"),
"fields": [field.strip() for field in
self._extract_tag(extracted_responses[self.SEMANTIC_FIELDS_INDEX], "fields").split(",")],
"sort_by": "relevance",
"limit": 20
}
except Exception as e:
print(f"Warning: Error parsing Semantic Scholar parameters: {str(e)}")
semantic_params = {
"query": query,
"fields": ["title", "abstract", "authors", "year"],
"sort_by": "relevance",
"limit": 20
}
# 解析PubMed参数
try:
# 首先检查是否需要PubMed搜索
pubmed_search_type = self._extract_tag(extracted_responses[self.PUBMED_TYPE_INDEX], "search_type")
if pubmed_search_type == "none":
# 不需要PubMed搜索,使用空参数
pubmed_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
else:
# 需要PubMed搜索,解析完整参数
pubmed_params = {
"search_type": pubmed_search_type,
"query": self._extract_tag(extracted_responses[self.PUBMED_QUERY_INDEX], "query"),
"sort_by": "relevance",
"limit": 200
}
except Exception as e:
print(f"Warning: Error parsing PubMed parameters: {str(e)}")
pubmed_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
# 解析Crossref参数
try:
crossref_query = self._extract_tag(extracted_responses[self.CROSSREF_QUERY_INDEX], "query")
if not crossref_query:
crossref_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
else:
crossref_params = {
"search_type": "basic",
"query": crossref_query,
"sort_by": "relevance",
"limit": 20
}
except Exception as e:
print(f"Warning: Error parsing Crossref parameters: {str(e)}")
crossref_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
# 解析ADS参数
try:
adsabs_query = self._extract_tag(extracted_responses[self.ADSABS_QUERY_INDEX], "query")
if not adsabs_query:
adsabs_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
else:
adsabs_params = {
"search_type": "basic",
"query": adsabs_query,
"sort_by": "relevance",
"limit": 20
}
except Exception as e:
print(f"Warning: Error parsing ADS parameters: {str(e)}")
adsabs_params = {
"search_type": "none",
"query": "",
"sort_by": "relevance",
"limit": 0
}
print(f"Debug - Extracted information:")
print(f"Query type: {query_type}")
print(f"Main topic: {main_topic}")
print(f"arXiv params: {arxiv_params}")
print(f"Semantic params: {semantic_params}")
print(f"PubMed params: {pubmed_params}")
print(f"Crossref params: {crossref_params}")
print(f"ADS params: {adsabs_params}")
# 提取子主题
sub_topics = []
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
if sub_topics_text:
    sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
# 提取年份范围
start_year = self.current_year - 5 # 默认最近5年
end_year = self.current_year
year_range = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "year_range")
if year_range:
try:
years = year_range.split("-")
if len(years) == 2:
start_year = int(years[0].strip())
end_year = int(years[1].strip())
except:
pass
# 提取 latest request 判断
is_latest_request = self._extract_tag(extracted_responses[self.ARXIV_LATEST_INDEX],
"is_latest_request").lower() == "true"
# 如果是最新论文请求,将查询类型改为 "latest"
if is_latest_request:
query_type = "latest"
# 提取论文标识信息
paper_source = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_source")
paper_id = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_id")
paper_title = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_title")
if start_year > end_year:
start_year, end_year = end_year, start_year
# 更新返回的 SearchCriteria
return SearchCriteria(
query_type=query_type,
main_topic=main_topic,
sub_topics=sub_topics,
start_year=start_year,
end_year=end_year,
arxiv_params=arxiv_params,
semantic_params=semantic_params,
pubmed_params=pubmed_params,
crossref_params=crossref_params,
paper_id=paper_id,
paper_title=paper_title,
paper_source=paper_source,
original_query=query,
adsabs_params=adsabs_params
)
except Exception as e:
raise Exception(f"Failed to analyze query: {str(e)}")
def _normalize_query_type(self, query_type: str, query: str) -> str:
"""规范化查询类型"""
if query_type in ["review", "recommend", "qa", "paper"]:
return query_type
query_lower = query.lower()
for type_name, keywords in self.valid_types.items():
for keyword in keywords:
if keyword in query_lower:
return type_name
query_type_lower = query_type.lower()
for type_name, keywords in self.valid_types.items():
for keyword in keywords:
if keyword in query_type_lower:
return type_name
return "qa" # 默认返回qa类型
def _extract_tag(self, text: str, tag: str) -> str:
"""提取标记内容"""
if not text:
return ""
# 1. 标准XML格式处理多行和特殊字符
pattern = f"<{tag}>(.*?)</{tag}>"
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
if match:
content = match.group(1).strip()
if content:
return content
# 2. 处理特定标签的复杂内容
if tag == "categories":
# 处理arXiv类别
patterns = [
# 标准格式:<categories>cs.CL, cs.AI, cs.LG</categories>
r"<categories>\s*((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)\s*</categories>",
# 简单列表格式cs.CL, cs.AI, cs.LG
r"(?:^|\s)((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)(?:\s|$)",
# 单个类别格式cs.AI
r"(?:^|\s)((?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+)(?:\s|$)"
]
elif tag == "query":
# 处理搜索查询
patterns = [
# 完整的查询格式:<query>complex query</query>
r"<query>\s*((?:(?:ti|abs|au|cat):[^\n]*?|(?:AND|OR|NOT|\(|\)|\d{4}|year:\d{4}|[\"'][^\"']*[\"']|\s+))+)\s*</query>",
# 简单的关键词列表keyword1, keyword2
r"(?:^|\s)((?:\"[^\"]*\"|'[^']*'|[^\s,]+)(?:\s*,\s*(?:\"[^\"]*\"|'[^']*'|[^\s,]+))*)",
# 字段搜索格式field:value
r"((?:ti|abs|au|cat):\s*(?:\"[^\"]*\"|'[^']*'|[^\s]+))"
]
elif tag == "fields":
# 处理字段列表
patterns = [
# 标准格式:<fields>field1, field2</fields>
r"<fields>\s*([\w\s,]+)\s*</fields>",
# 简单列表格式field1, field2
r"(?:^|\s)([\w]+(?:\s*,\s*[\w]+)*)",
]
elif tag == "sort_by":
# 处理排序字段
patterns = [
# 标准格式:<sort_by>value</sort_by>
r"<sort_by>\s*(relevance|date|citations|submittedDate|year)\s*</sort_by>",
# 简单值格式relevance
r"(?:^|\s)(relevance|date|citations|submittedDate|year)(?:\s|$)"
]
else:
# 通用模式
patterns = [
rf"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
rf"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
rf"\[{tag}\]([\s\S]*?)\[/{tag}\]", # 方括号格式:[tag]...[/tag]
rf"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
rf"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
]
# 3. 尝试所有模式
for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
if match:
content = match.group(1).strip()
if content: # 确保提取的内容不为空
return content
# 4. 如果所有模式都失败,返回空字符串
return ""


@@ -0,0 +1,64 @@
from typing import List, Dict, Any
from .query_analyzer import QueryAnalyzer, SearchCriteria
from .data_sources.arxiv_source import ArxivSource
from .data_sources.semantic_source import SemanticScholarSource
from .handlers.review_handler import 文献综述功能
from .handlers.recommend_handler import 论文推荐功能
from .handlers.qa_handler import 学术问答功能
from .handlers.paper_handler import 单篇论文分析功能
class QueryProcessor:
"""查询处理器"""
def __init__(self):
self.analyzer = QueryAnalyzer()
self.arxiv = ArxivSource()
self.semantic = SemanticScholarSource()
# 初始化各种处理器
self.handlers = {
"review": 文献综述功能(self.arxiv, self.semantic),
"recommend": 论文推荐功能(self.arxiv, self.semantic),
"qa": 学术问答功能(self.arxiv, self.semantic),
"paper": 单篇论文分析功能(self.arxiv, self.semantic)
}
async def process_query(
self,
query: str,
chatbot: List[List[str]],
history: List[List[str]],
system_prompt: str,
llm_kwargs: Dict[str, Any],
plugin_kwargs: Dict[str, Any],
) -> List[List[str]]:
"""处理用户查询"""
# 设置默认的插件参数
default_plugin_kwargs = {
'max_papers': 20, # 最大论文数量
'min_year': 2015, # 最早年份
'search_multiplier': 3, # 检索倍数
}
# 更新插件参数
plugin_kwargs.update({k: v for k, v in default_plugin_kwargs.items() if k not in plugin_kwargs})
# 1. 分析查询意图
criteria = self.analyzer.analyze_query(query, chatbot, llm_kwargs)
# 2. 根据查询类型选择处理器
handler = self.handlers.get(criteria.query_type)
if not handler:
handler = self.handlers["qa"] # 默认使用QA处理器
# 3. 处理查询
response = await handler.handle(
criteria,
chatbot,
history,
system_prompt,
llm_kwargs,
plugin_kwargs
)
return response


@@ -0,0 +1,109 @@
import re
import requests
from loguru import logger
from typing import List, Dict
from urllib3.util import Retry
from requests.adapters import HTTPAdapter
from textwrap import dedent
from request_llms.bridge_all import predict_no_ui_long_connection
class BGELLMRanker:
"""使用LLM进行论文相关性判断的类"""
def __init__(self, llm_kwargs):
self.llm_kwargs = llm_kwargs
def is_paper_relevant(self, query: str, paper_text: str) -> bool:
"""判断论文是否与查询相关"""
prompt = dedent(f"""
Evaluate if this academic paper contains information that directly addresses the user's query.
Query: {query}
Paper Content:
{paper_text}
Evaluation Criteria:
1. The paper must contain core information that directly answers the query
2. The paper's main research focus must be highly relevant to the query
3. Papers that only mention query-related content in abstract should be excluded
4. Papers with superficial or general discussions should be excluded
5. For queries about "recent" or "latest" advances, the paper should be from the last 3 years
Instructions:
- Carefully evaluate against ALL criteria above
- Return true ONLY if paper meets ALL criteria
- If any criterion is not met or unclear, return false
- Be strict but not overly restrictive
Output Rules:
- Must ONLY respond with <decision>true</decision> or <decision>false</decision>
- true = paper contains relevant information to answer the query
- false = paper does not contain sufficient relevant information
Do not include any explanation or additional text."""
)
response = predict_no_ui_long_connection(
inputs=prompt,
history=[],
llm_kwargs=self.llm_kwargs,
sys_prompt="You are an expert at determining paper relevance to queries. Respond only with <decision>true</decision> or <decision>false</decision>."
)
# 提取decision标签中的内容
match = re.search(r'<decision>(.*?)</decision>', response, re.IGNORECASE)
if match:
decision = match.group(1).lower()
return decision == "true"
else:
return False
def batch_check_relevance(self, query: str, paper_texts: List[str], show_progress: bool = True) -> List[bool]:
"""批量检查论文相关性
Args:
query: 用户查询
paper_texts: 论文文本列表
show_progress: 是否显示进度条
Returns:
List[bool]: 相关性判断结果列表
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
results = [False] * len(paper_texts)
# 减少并发线程数以避免连接池耗尽
max_workers = min(20, len(paper_texts)) # 限制最大线程数
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_idx = {
executor.submit(self.is_paper_relevant, query, text): i
for i, text in enumerate(paper_texts)
}
iterator = as_completed(future_to_idx)
if show_progress:
iterator = tqdm(iterator, total=len(paper_texts), desc="判断论文相关性")
for future in iterator:
idx = future_to_idx[future]
try:
results[idx] = future.result()
except Exception as e:
logger.exception(f"处理论文 {idx} 时出错: {str(e)}")
results[idx] = False
return results
def main():
# 测试代码
ranker = BGELLMRanker(llm_kwargs={})  # 注意:实际使用时需传入包含模型配置的 llm_kwargs
query = "Recent advances in transformer models"
paper_text = """
Title: Attention Is All You Need
Abstract: The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely...
"""
is_relevant = ranker.is_paper_relevant(query, paper_text)
print(f"Paper relevant: {is_relevant}")
if __name__ == "__main__":
main()
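A hedged batch-usage sketch for the ranker above; the commented module path and the llm_kwargs value are placeholders that must match the host application's configuration, and the paper texts are dummies.

# from crazy_functions.review_fns.llm_ranker import BGELLMRanker  # assumed module path
texts = [
    "Title: GraphSAGE\nAbstract: Inductive representation learning on large graphs ...",
    "Title: Deep Residual Learning\nAbstract: Residual networks for image recognition ...",
]
ranker = BGELLMRanker(llm_kwargs={"llm_model": "gpt-3.5-turbo"})  # placeholder model config
flags = ranker.batch_check_relevance("graph neural networks", texts, show_progress=False)
relevant = [t for t, keep in zip(texts, flags) if keep]  # keep only papers judged relevant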


@@ -8,13 +8,10 @@ API_URL_REDIRECT, AZURE_ENDPOINT, AZURE_ENGINE = get_conf("API_URL_REDIRECT", "A
openai_endpoint = "https://api.openai.com/v1/chat/completions" openai_endpoint = "https://api.openai.com/v1/chat/completions"
if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/' if not AZURE_ENDPOINT.endswith('/'): AZURE_ENDPOINT += '/'
azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15' azure_endpoint = AZURE_ENDPOINT + f'openai/deployments/{AZURE_ENGINE}/chat/completions?api-version=2023-05-15'
if openai_endpoint in API_URL_REDIRECT: openai_endpoint = API_URL_REDIRECT[openai_endpoint] if openai_endpoint in API_URL_REDIRECT: openai_endpoint = API_URL_REDIRECT[openai_endpoint]
openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings") openai_embed_endpoint = openai_endpoint.replace("chat/completions", "embeddings")
from .openai_embed import OpenAiEmbeddingModel from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
embed_model_info = { embed_model_info = {
# text-embedding-3-small Increased performance over 2nd generation ada embedding model | 1,536 # text-embedding-3-small Increased performance over 2nd generation ada embedding model | 1,536


@@ -23,6 +23,7 @@ mdtex2html
dashscope
pyautogen
colorama
+ docx2pdf
Markdown
pygments
edge-tts>=7.0.0


@@ -0,0 +1,64 @@
"""
对项目中的各个插件进行测试。运行方法:直接运行 python tests/test_plugins.py
"""
import init_test
import os, sys
if __name__ == "__main__":
from test_utils import plugin_test
plugin_test(plugin='crazy_functions.Academic_Conversation->学术对话', main_input="搜索最新使用GAIA Benchmark的论文")
# plugin_test(plugin='crazy_functions.Internet_GPT->连接网络回答问题', main_input="谁是应急食品?")
# plugin_test(plugin='crazy_functions.函数动态生成->函数动态生成', main_input='交换图像的蓝色通道和红色通道', advanced_arg={"file_path_arg": "./build/ants.jpg"})
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2307.07522")
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='build/pdf/t1.pdf')
# plugin_test(
# plugin="crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF",
# main_input="G:/SEAFILE_LOCAL/50503047/我的资料库/学位/paperlatex/aaai/Fu_8368_with_appendix",
# )
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='修改api-key为sk-jhoejriotherjep')
# plugin_test(plugin='crazy_functions.批量翻译PDF文档_NOUGAT->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
# plugin_test(plugin='crazy_functions.虚空终端->虚空终端', main_input='调用插件,对C:/Users/fuqingxu/Desktop/旧文件/gpt/chatgpt_academic/crazy_functions/latex_fns中的python文件进行解析')
# plugin_test(plugin='crazy_functions.命令行助手->命令行助手', main_input='查看当前的docker容器列表')
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个Python项目', main_input="crazy_functions/test_project/python/dqn")
# plugin_test(plugin='crazy_functions.SourceCode_Analyse->解析一个C项目', main_input="crazy_functions/test_project/cpp/cppipc")
# plugin_test(plugin='crazy_functions.Latex_Project_Polish->Latex英文润色', main_input="crazy_functions/test_project/latex/attention")
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown中译英', main_input="README.md")
# plugin_test(plugin='crazy_functions.PDF_Translate->批量翻译PDF文档', main_input='crazy_functions/test_project/pdf_and_word/aaai.pdf')
# plugin_test(plugin='crazy_functions.谷歌检索小助手->谷歌检索小助手', main_input="https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=auto+reinforcement+learning&btnG=")
# plugin_test(plugin='crazy_functions.总结word文档->总结word文档', main_input="crazy_functions/test_project/pdf_and_word")
# plugin_test(plugin='crazy_functions.下载arxiv论文翻译摘要->下载arxiv论文并翻译摘要', main_input="1812.10695")
# plugin_test(plugin='crazy_functions.解析JupyterNotebook->解析ipynb文件', main_input="crazy_functions/test_samples")
# plugin_test(plugin='crazy_functions.数学动画生成manim->动画生成', main_input="A ball split into 2, and then split into 4, and finally split into 8.")
# for lang in ["English", "French", "Japanese", "Korean", "Russian", "Italian", "German", "Portuguese", "Arabic"]:
# plugin_test(plugin='crazy_functions.Markdown_Translate->Markdown翻译指定语言', main_input="README.md", advanced_arg={"advanced_arg": lang})
# plugin_test(plugin='crazy_functions.知识库文件注入->知识库文件注入', main_input="./")
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="What is the installation method?")
# plugin_test(plugin='crazy_functions.知识库文件注入->读取知识库作答', main_input="远程云服务器部署?")
# plugin_test(plugin='crazy_functions.Latex_Function->Latex翻译中文并重新编译PDF', main_input="2210.03629")