镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 06:26:47 +00:00
Master 4.0 (#2210)
* stage academic conversation * stage document conversation * fix buggy gradio version * file dynamic load * merge more academic plugins * accelerate nltk * feat: 为predict函数添加文件和URL读取功能 - 添加URL检测和网页内容提取功能,支持自动提取网页文本 - 添加文件路径识别和文件内容读取功能,支持private_upload路径格式 - 集成WebTextExtractor处理网页内容提取 - 集成TextContentLoader处理本地文件读取 - 支持文件路径与问题组合的智能处理 * back * block unstable --------- Co-authored-by: XiaoBoAI <liuboyin2019@ia.ac.cn>
这个提交包含在:
@@ -0,0 +1,68 @@
|
||||
from typing import List
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
|
||||
class EndNoteFormatter:
|
||||
"""EndNote参考文献格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||
"""生成EndNote格式的参考文献文本
|
||||
|
||||
Args:
|
||||
papers: 论文列表
|
||||
|
||||
Returns:
|
||||
str: EndNote格式的参考文献文本
|
||||
"""
|
||||
endnote_text = ""
|
||||
|
||||
for paper in papers:
|
||||
# 开始一个新条目
|
||||
endnote_text += "%0 Journal Article\n" # 默认类型为期刊文章
|
||||
|
||||
# 根据venue_type调整条目类型
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||
if paper.venue_type.lower() == 'conference':
|
||||
endnote_text = endnote_text.replace("Journal Article", "Conference Paper")
|
||||
elif paper.venue_type.lower() == 'preprint':
|
||||
endnote_text = endnote_text.replace("Journal Article", "Electronic Article")
|
||||
|
||||
# 添加标题
|
||||
endnote_text += f"%T {paper.title}\n"
|
||||
|
||||
# 添加作者
|
||||
for author in paper.authors:
|
||||
endnote_text += f"%A {author}\n"
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
endnote_text += f"%D {paper.year}\n"
|
||||
|
||||
# 添加期刊/会议名称
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
endnote_text += f"%J {paper.venue_name}\n"
|
||||
elif paper.venue:
|
||||
endnote_text += f"%J {paper.venue}\n"
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
endnote_text += f"%R {paper.doi}\n"
|
||||
endnote_text += f"%U https://doi.org/{paper.doi}\n"
|
||||
elif paper.url:
|
||||
endnote_text += f"%U {paper.url}\n"
|
||||
|
||||
# 添加摘要
|
||||
if paper.abstract:
|
||||
endnote_text += f"%X {paper.abstract}\n"
|
||||
|
||||
# 添加机构
|
||||
if hasattr(paper, 'institutions'):
|
||||
for institution in paper.institutions:
|
||||
endnote_text += f"%I {institution}\n"
|
||||
|
||||
# 条目之间添加空行
|
||||
endnote_text += "\n"
|
||||
|
||||
return endnote_text
|
||||
@@ -0,0 +1,211 @@
|
||||
import re
|
||||
import os
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ExcelTableFormatter:
|
||||
"""聊天记录中Markdown表格转Excel生成器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化Excel文档对象"""
|
||||
from openpyxl import Workbook
|
||||
self.workbook = Workbook()
|
||||
self._table_count = 0
|
||||
self._current_sheet = None
|
||||
|
||||
def _normalize_table_row(self, row):
|
||||
"""标准化表格行,处理不同的分隔符情况"""
|
||||
row = row.strip()
|
||||
if row.startswith('|'):
|
||||
row = row[1:]
|
||||
if row.endswith('|'):
|
||||
row = row[:-1]
|
||||
return [cell.strip() for cell in row.split('|')]
|
||||
|
||||
def _is_separator_row(self, row):
|
||||
"""检查是否是分隔行(由 - 或 : 组成)"""
|
||||
clean_row = re.sub(r'[\s|]', '', row)
|
||||
return bool(re.match(r'^[-:]+$', clean_row))
|
||||
|
||||
def _extract_tables_from_text(self, text):
|
||||
"""从文本中提取所有表格内容"""
|
||||
if not isinstance(text, str):
|
||||
return []
|
||||
|
||||
tables = []
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
if is_in_table and current_table:
|
||||
if len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
continue
|
||||
|
||||
if '|' in line:
|
||||
if not is_in_table:
|
||||
is_in_table = True
|
||||
current_table.append(line)
|
||||
else:
|
||||
if is_in_table and current_table:
|
||||
if len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
current_table = []
|
||||
is_in_table = False
|
||||
|
||||
if is_in_table and current_table and len(current_table) >= 2:
|
||||
tables.append(current_table)
|
||||
|
||||
return tables
|
||||
|
||||
def _parse_table(self, table_lines):
|
||||
"""解析表格内容为结构化数据"""
|
||||
try:
|
||||
headers = self._normalize_table_row(table_lines[0])
|
||||
|
||||
separator_index = next(
|
||||
(i for i, line in enumerate(table_lines) if self._is_separator_row(line)),
|
||||
1
|
||||
)
|
||||
|
||||
data_rows = []
|
||||
for line in table_lines[separator_index + 1:]:
|
||||
cells = self._normalize_table_row(line)
|
||||
# 确保单元格数量与表头一致
|
||||
while len(cells) < len(headers):
|
||||
cells.append('')
|
||||
cells = cells[:len(headers)]
|
||||
data_rows.append(cells)
|
||||
|
||||
if headers and data_rows:
|
||||
return {
|
||||
'headers': headers,
|
||||
'data': data_rows
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"解析表格时发生错误: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
def _create_sheet(self, question_num, table_num):
|
||||
"""创建新的工作表"""
|
||||
sheet_name = f'Q{question_num}_T{table_num}'
|
||||
if len(sheet_name) > 31:
|
||||
sheet_name = f'Table{self._table_count}'
|
||||
|
||||
if sheet_name in self.workbook.sheetnames:
|
||||
sheet_name = f'{sheet_name}_{datetime.now().strftime("%H%M%S")}'
|
||||
|
||||
return self.workbook.create_sheet(title=sheet_name)
|
||||
|
||||
def create_document(self, history):
|
||||
"""
|
||||
处理聊天历史中的所有表格并创建Excel文档
|
||||
|
||||
Args:
|
||||
history: 聊天历史列表
|
||||
|
||||
Returns:
|
||||
Workbook: 处理完成的Excel工作簿对象,如果没有表格则返回None
|
||||
"""
|
||||
has_tables = False
|
||||
|
||||
# 删除默认创建的工作表
|
||||
default_sheet = self.workbook['Sheet']
|
||||
self.workbook.remove(default_sheet)
|
||||
|
||||
# 遍历所有回答
|
||||
for i in range(1, len(history), 2):
|
||||
answer = history[i]
|
||||
tables = self._extract_tables_from_text(answer)
|
||||
|
||||
for table_lines in tables:
|
||||
parsed_table = self._parse_table(table_lines)
|
||||
if parsed_table:
|
||||
self._table_count += 1
|
||||
sheet = self._create_sheet(i // 2 + 1, self._table_count)
|
||||
|
||||
# 写入表头
|
||||
for col, header in enumerate(parsed_table['headers'], 1):
|
||||
sheet.cell(row=1, column=col, value=header)
|
||||
|
||||
# 写入数据
|
||||
for row_idx, row_data in enumerate(parsed_table['data'], 2):
|
||||
for col_idx, value in enumerate(row_data, 1):
|
||||
sheet.cell(row=row_idx, column=col_idx, value=value)
|
||||
|
||||
has_tables = True
|
||||
|
||||
return self.workbook if has_tables else None
|
||||
|
||||
|
||||
def save_chat_tables(history, save_dir, base_name):
|
||||
"""
|
||||
保存聊天历史中的表格到Excel文件
|
||||
|
||||
Args:
|
||||
history: 聊天历史列表
|
||||
save_dir: 保存目录
|
||||
base_name: 基础文件名
|
||||
|
||||
Returns:
|
||||
list: 保存的文件路径列表
|
||||
"""
|
||||
result_files = []
|
||||
|
||||
try:
|
||||
# 创建Excel格式
|
||||
excel_formatter = ExcelTableFormatter()
|
||||
workbook = excel_formatter.create_document(history)
|
||||
|
||||
if workbook is not None:
|
||||
# 确保保存目录存在
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 生成Excel文件路径
|
||||
excel_file = os.path.join(save_dir, base_name + '.xlsx')
|
||||
|
||||
# 保存Excel文件
|
||||
workbook.save(excel_file)
|
||||
result_files.append(excel_file)
|
||||
print(f"已保存表格到Excel文件: {excel_file}")
|
||||
except Exception as e:
|
||||
print(f"保存Excel格式失败: {str(e)}")
|
||||
|
||||
return result_files
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 示例聊天历史
|
||||
history = [
|
||||
"问题1",
|
||||
"""这是第一个表格:
|
||||
| A | B | C |
|
||||
|---|---|---|
|
||||
| 1 | 2 | 3 |""",
|
||||
|
||||
"问题2",
|
||||
"这是没有表格的回答",
|
||||
|
||||
"问题3",
|
||||
"""回答包含多个表格:
|
||||
| Name | Age |
|
||||
|------|-----|
|
||||
| Tom | 20 |
|
||||
|
||||
第二个表格:
|
||||
| X | Y |
|
||||
|---|---|
|
||||
| 1 | 2 |"""
|
||||
]
|
||||
|
||||
# 保存表格
|
||||
save_dir = "output"
|
||||
base_name = "chat_tables"
|
||||
saved_files = save_chat_tables(history, save_dir, base_name)
|
||||
@@ -0,0 +1,472 @@
|
||||
class HtmlFormatter:
|
||||
"""聊天记录HTML格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
self.css_styles = """
|
||||
:root {
|
||||
--primary-color: #2563eb;
|
||||
--primary-light: #eff6ff;
|
||||
--secondary-color: #1e293b;
|
||||
--background-color: #f8fafc;
|
||||
--text-color: #334155;
|
||||
--border-color: #e2e8f0;
|
||||
--card-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
line-height: 1.8;
|
||||
margin: 0;
|
||||
padding: 2rem;
|
||||
color: var(--text-color);
|
||||
background-color: var(--background-color);
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 2rem;
|
||||
border-radius: 16px;
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
::selection {
|
||||
background: var(--primary-light);
|
||||
color: var(--primary-color);
|
||||
}
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(20px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
@keyframes slideIn {
|
||||
from { transform: translateX(-20px); opacity: 0; }
|
||||
to { transform: translateX(0); opacity: 1; }
|
||||
}
|
||||
|
||||
.container {
|
||||
animation: fadeIn 0.6s ease-out;
|
||||
}
|
||||
|
||||
.QaBox {
|
||||
animation: slideIn 0.5s ease-out;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.QaBox:hover {
|
||||
transform: translateX(5px);
|
||||
}
|
||||
.Question, .Answer, .historyBox {
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.chat-title {
|
||||
color: var(--primary-color);
|
||||
font-size: 2em;
|
||||
text-align: center;
|
||||
margin: 1rem 0 2rem;
|
||||
padding-bottom: 1rem;
|
||||
border-bottom: 2px solid var(--primary-color);
|
||||
}
|
||||
|
||||
.chat-body {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.5rem;
|
||||
margin: 2rem 0;
|
||||
}
|
||||
|
||||
.QaBox {
|
||||
background: white;
|
||||
padding: 1.5rem;
|
||||
border-radius: 8px;
|
||||
border-left: 4px solid var(--primary-color);
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.Question {
|
||||
color: var(--secondary-color);
|
||||
font-weight: 500;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
.Answer {
|
||||
color: var(--text-color);
|
||||
background: var(--primary-light);
|
||||
padding: 1rem;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
.history-section {
|
||||
margin-top: 3rem;
|
||||
padding-top: 2rem;
|
||||
border-top: 2px solid var(--border-color);
|
||||
}
|
||||
|
||||
.history-title {
|
||||
color: var(--secondary-color);
|
||||
font-size: 1.5em;
|
||||
margin-bottom: 1.5rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.historyBox {
|
||||
background: white;
|
||||
padding: 1rem;
|
||||
margin: 0.5rem 0;
|
||||
border-radius: 6px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
:root {
|
||||
--background-color: #0f172a;
|
||||
--text-color: #e2e8f0;
|
||||
--border-color: #1e293b;
|
||||
}
|
||||
|
||||
.container, .QaBox {
|
||||
background: #1e293b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||
"""生成完整的HTML文档
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
Returns:
|
||||
str: 完整的HTML文档字符串
|
||||
"""
|
||||
chat_content = f'''
|
||||
<div class="QaBox">
|
||||
<div class="Question">{question}</div>
|
||||
<div class="Answer markdown-body" id="answer-content">{answer}</div>
|
||||
</div>
|
||||
'''
|
||||
|
||||
references_content = ""
|
||||
if ranked_papers:
|
||||
references_content = '<div class="history-section"><h2 class="history-title">参考文献</h2>'
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
authors = ', '.join(paper.authors)
|
||||
|
||||
# 构建引用信息
|
||||
citations_info = f"被引用次数:{paper.citations}" if paper.citations is not None else "引用信息未知"
|
||||
|
||||
# 构建下载链接
|
||||
download_links = []
|
||||
if paper.doi:
|
||||
# 检查是否是arXiv链接
|
||||
if 'arxiv.org' in paper.doi:
|
||||
# 如果DOI中包含完整的arXiv URL,直接使用
|
||||
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||
download_links.append(f'<a href="{arxiv_url}">arXiv链接</a>')
|
||||
# 提取arXiv ID并添加PDF链接
|
||||
arxiv_id = arxiv_url.split('abs/')[-1].split('v')[0]
|
||||
download_links.append(f'<a href="https://arxiv.org/pdf/{arxiv_id}.pdf">PDF下载</a>')
|
||||
else:
|
||||
# 非arXiv的DOI使用标准格式
|
||||
download_links.append(f'<a href="https://doi.org/{paper.doi}">DOI: {paper.doi}</a>')
|
||||
|
||||
if hasattr(paper, 'url') and paper.url and 'arxiv.org' not in str(paper.url):
|
||||
# 只有当URL不是arXiv链接时才添加
|
||||
download_links.append(f'<a href="{paper.url}">原文链接</a>')
|
||||
download_section = ' | '.join(download_links) if download_links else "无直接下载链接"
|
||||
|
||||
# 构建来源信息
|
||||
source_info = []
|
||||
if paper.venue_type:
|
||||
source_info.append(f"类型:{paper.venue_type}")
|
||||
if paper.venue_name:
|
||||
source_info.append(f"来源:{paper.venue_name}")
|
||||
|
||||
# 添加期刊指标信息
|
||||
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||
source_info.append(f"<span class='journal-metric'>IF: {paper.if_factor}</span>")
|
||||
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||
source_info.append(f"<span class='journal-metric'>JCR分区: {paper.jcr_division}</span>")
|
||||
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||
source_info.append(f"<span class='journal-metric'>中科院分区: {paper.cas_division}</span>")
|
||||
|
||||
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||
if paper.venue_info.get('journal_ref'):
|
||||
source_info.append(f"期刊参考:{paper.venue_info['journal_ref']}")
|
||||
if paper.venue_info.get('publisher'):
|
||||
source_info.append(f"出版商:{paper.venue_info['publisher']}")
|
||||
source_section = ' | '.join(source_info) if source_info else ""
|
||||
|
||||
# 构建标准引用格式
|
||||
standard_citation = f"[{idx}] "
|
||||
# 添加作者(最多3个,超过则添加et al.)
|
||||
author_list = paper.authors[:3]
|
||||
if len(paper.authors) > 3:
|
||||
author_list.append("et al.")
|
||||
standard_citation += ", ".join(author_list) + ". "
|
||||
# 添加标题
|
||||
standard_citation += f"<i>{paper.title}</i>"
|
||||
# 添加期刊/会议名称
|
||||
if paper.venue_name:
|
||||
standard_citation += f". {paper.venue_name}"
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
standard_citation += f", {paper.year}"
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
if 'arxiv.org' in paper.doi:
|
||||
# 如果是arXiv链接,直接使用arXiv URL
|
||||
arxiv_url = paper.doi if paper.doi.startswith('http') else f'http://{paper.doi}'
|
||||
standard_citation += f". {arxiv_url}"
|
||||
else:
|
||||
# 非arXiv的DOI使用标准格式
|
||||
standard_citation += f". DOI: {paper.doi}"
|
||||
standard_citation += "."
|
||||
|
||||
references_content += f'''
|
||||
<div class="historyBox">
|
||||
<div class="entry">
|
||||
<p class="paper-title"><b>[{idx}]</b> <i>{paper.title}</i></p>
|
||||
<p class="paper-authors">作者:{authors}</p>
|
||||
<p class="paper-year">发表年份:{paper.year if paper.year else "未知"}</p>
|
||||
<p class="paper-citations">{citations_info}</p>
|
||||
{f'<p class="paper-source">{source_section}</p>' if source_section else ""}
|
||||
<p class="paper-abstract">摘要:{paper.abstract if paper.abstract else "无摘要"}</p>
|
||||
<p class="paper-links">链接:{download_section}</p>
|
||||
<div class="standard-citation">
|
||||
<p class="citation-title">标准引用格式:</p>
|
||||
<p class="citation-text">{standard_citation}</p>
|
||||
<button class="copy-btn" onclick="copyToClipboard(this.previousElementSibling)">复制引用格式</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
'''
|
||||
references_content += '</div>'
|
||||
|
||||
# 添加新的CSS样式
|
||||
css_additions = """
|
||||
.paper-title {
|
||||
font-size: 1.1em;
|
||||
margin-bottom: 0.5em;
|
||||
}
|
||||
.paper-authors {
|
||||
color: var(--secondary-color);
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-year, .paper-citations {
|
||||
color: var(--text-color);
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-source {
|
||||
color: var(--text-color);
|
||||
font-style: italic;
|
||||
margin: 0.3em 0;
|
||||
}
|
||||
.paper-abstract {
|
||||
margin: 0.8em 0;
|
||||
padding: 0.8em;
|
||||
background: var(--primary-light);
|
||||
border-radius: 4px;
|
||||
}
|
||||
.paper-links {
|
||||
margin-top: 0.5em;
|
||||
}
|
||||
.paper-links a {
|
||||
color: var(--primary-color);
|
||||
text-decoration: none;
|
||||
margin-right: 1em;
|
||||
}
|
||||
.paper-links a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
.standard-citation {
|
||||
margin-top: 1em;
|
||||
padding: 1em;
|
||||
background: #f8fafc;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.citation-title {
|
||||
font-weight: bold;
|
||||
margin-bottom: 0.5em;
|
||||
color: var(--secondary-color);
|
||||
}
|
||||
|
||||
.citation-text {
|
||||
font-family: 'Times New Roman', Times, serif;
|
||||
line-height: 1.6;
|
||||
margin-bottom: 0.5em;
|
||||
padding: 0.5em;
|
||||
background: white;
|
||||
border-radius: 4px;
|
||||
border: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.copy-btn {
|
||||
background: var(--primary-color);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.5em 1em;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.copy-btn:hover {
|
||||
background: #1e40af;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.standard-citation {
|
||||
background: #1e293b;
|
||||
}
|
||||
.citation-text {
|
||||
background: #0f172a;
|
||||
}
|
||||
}
|
||||
|
||||
/* 添加期刊指标样式 */
|
||||
.journal-metric {
|
||||
display: inline-block;
|
||||
padding: 0.2em 0.6em;
|
||||
margin: 0 0.3em;
|
||||
background: var(--primary-light);
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
color: var(--primary-color);
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.journal-metric {
|
||||
background: #1e293b;
|
||||
color: #60a5fa;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# 修改 js_code 部分,添加 markdown 解析功能
|
||||
js_code = """
|
||||
<script>
|
||||
// 复制功能
|
||||
function copyToClipboard(element) {
|
||||
const text = element.innerText;
|
||||
navigator.clipboard.writeText(text).then(function() {
|
||||
const btn = element.nextElementSibling;
|
||||
const originalText = btn.innerText;
|
||||
btn.innerText = '已复制!';
|
||||
setTimeout(() => {
|
||||
btn.innerText = originalText;
|
||||
}, 2000);
|
||||
}).catch(function(err) {
|
||||
console.error('复制失败:', err);
|
||||
});
|
||||
}
|
||||
|
||||
// Markdown解析
|
||||
document.addEventListener('DOMContentLoaded', function() {
|
||||
const answerContent = document.getElementById('answer-content');
|
||||
if (answerContent) {
|
||||
const markdown = answerContent.textContent;
|
||||
answerContent.innerHTML = marked.parse(markdown);
|
||||
}
|
||||
});
|
||||
</script>
|
||||
"""
|
||||
|
||||
# 将新的CSS样式添加到现有样式中
|
||||
self.css_styles += css_additions
|
||||
|
||||
return f"""
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>学术对话存档</title>
|
||||
<!-- 添加 marked.js -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||
<!-- 添加 GitHub Markdown CSS -->
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/gh/sindresorhus/github-markdown-css@4.0.0/github-markdown.min.css">
|
||||
<style>
|
||||
{self.css_styles}
|
||||
/* 添加 Markdown 相关样式 */
|
||||
.markdown-body {{
|
||||
font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
|
||||
padding: 1rem;
|
||||
background: var(--primary-light);
|
||||
border-radius: 6px;
|
||||
}}
|
||||
.markdown-body pre {{
|
||||
background-color: #f6f8fa;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
}}
|
||||
.markdown-body code {{
|
||||
background-color: rgba(175,184,193,0.2);
|
||||
border-radius: 6px;
|
||||
padding: 0.2em 0.4em;
|
||||
font-size: 85%;
|
||||
}}
|
||||
.markdown-body pre code {{
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
}}
|
||||
.markdown-body blockquote {{
|
||||
border-left: 0.25em solid #d0d7de;
|
||||
padding: 0 1em;
|
||||
color: #656d76;
|
||||
}}
|
||||
.markdown-body table {{
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin: 1em 0;
|
||||
}}
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {{
|
||||
border: 1px solid #d0d7de;
|
||||
padding: 6px 13px;
|
||||
}}
|
||||
.markdown-body table tr:nth-child(2n) {{
|
||||
background-color: #f6f8fa;
|
||||
}}
|
||||
@media (prefers-color-scheme: dark) {{
|
||||
.markdown-body {{
|
||||
background: #1e293b;
|
||||
color: #e2e8f0;
|
||||
}}
|
||||
.markdown-body pre {{
|
||||
background-color: #0f172a;
|
||||
}}
|
||||
.markdown-body code {{
|
||||
background-color: rgba(99,110,123,0.4);
|
||||
}}
|
||||
.markdown-body blockquote {{
|
||||
border-left-color: #30363d;
|
||||
color: #8b949e;
|
||||
}}
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {{
|
||||
border-color: #30363d;
|
||||
}}
|
||||
.markdown-body table tr:nth-child(2n) {{
|
||||
background-color: #0f172a;
|
||||
}}
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1 class="chat-title">学术对话存档</h1>
|
||||
<div class="chat-body">
|
||||
{chat_content}
|
||||
{references_content}
|
||||
</div>
|
||||
</div>
|
||||
{js_code}
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -0,0 +1,47 @@
|
||||
class MarkdownFormatter:
|
||||
"""Markdown格式文档生成器 - 用于生成对话记录的markdown文档"""
|
||||
|
||||
def __init__(self):
|
||||
self.content = []
|
||||
|
||||
def _add_content(self, text: str):
|
||||
"""添加正文内容"""
|
||||
if text:
|
||||
self.content.append(f"\n{text}\n")
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None) -> str:
|
||||
"""创建完整的Markdown文档
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
Returns:
|
||||
str: 生成的Markdown文本
|
||||
"""
|
||||
content = []
|
||||
|
||||
# 添加问答部分
|
||||
content.append("## 问题")
|
||||
content.append(question)
|
||||
content.append("\n## 回答")
|
||||
content.append(answer)
|
||||
|
||||
# 添加参考文献
|
||||
if ranked_papers:
|
||||
content.append("\n## 参考文献")
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
authors = ', '.join(paper.authors[:3])
|
||||
if len(paper.authors) > 3:
|
||||
authors += ' et al.'
|
||||
|
||||
ref = f"[{idx}] {authors}. *{paper.title}*"
|
||||
if paper.venue_name:
|
||||
ref += f". {paper.venue_name}"
|
||||
if paper.year:
|
||||
ref += f", {paper.year}"
|
||||
if paper.doi:
|
||||
ref += f". [DOI: {paper.doi}](https://doi.org/{paper.doi})"
|
||||
|
||||
content.append(ref)
|
||||
|
||||
return "\n\n".join(content)
|
||||
@@ -0,0 +1,174 @@
|
||||
from typing import List
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
import re
|
||||
|
||||
class ReferenceFormatter:
|
||||
"""通用参考文献格式生成器"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def _sanitize_bibtex(self, text: str) -> str:
|
||||
"""清理BibTeX字符串,处理特殊字符"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 替换特殊字符
|
||||
replacements = {
|
||||
'&': '\\&',
|
||||
'%': '\\%',
|
||||
'$': '\\$',
|
||||
'#': '\\#',
|
||||
'_': '\\_',
|
||||
'{': '\\{',
|
||||
'}': '\\}',
|
||||
'~': '\\textasciitilde{}',
|
||||
'^': '\\textasciicircum{}',
|
||||
'\\': '\\textbackslash{}',
|
||||
'<': '\\textless{}',
|
||||
'>': '\\textgreater{}',
|
||||
'"': '``',
|
||||
"'": "'",
|
||||
'-': '--',
|
||||
'—': '---',
|
||||
}
|
||||
|
||||
for char, replacement in replacements.items():
|
||||
text = text.replace(char, replacement)
|
||||
|
||||
return text
|
||||
|
||||
def _generate_cite_key(self, paper: PaperMetadata) -> str:
|
||||
"""生成引用键
|
||||
格式: 第一作者姓氏_年份_第一个实词
|
||||
"""
|
||||
# 获取第一作者姓氏
|
||||
first_author = ""
|
||||
if paper.authors and len(paper.authors) > 0:
|
||||
first_author = paper.authors[0].split()[-1].lower()
|
||||
|
||||
# 获取年份
|
||||
year = str(paper.year) if paper.year else "0000"
|
||||
|
||||
# 从标题中获取第一个实词
|
||||
title_word = ""
|
||||
if paper.title:
|
||||
# 移除特殊字符,分割成单词
|
||||
words = re.findall(r'\w+', paper.title.lower())
|
||||
# 过滤掉常见的停用词
|
||||
stop_words = {'a', 'an', 'the', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
|
||||
for word in words:
|
||||
if word not in stop_words and len(word) > 2:
|
||||
title_word = word
|
||||
break
|
||||
|
||||
# 组合cite key
|
||||
cite_key = f"{first_author}{year}{title_word}"
|
||||
|
||||
# 确保cite key只包含合法字符
|
||||
cite_key = re.sub(r'[^a-z0-9]', '', cite_key.lower())
|
||||
|
||||
return cite_key
|
||||
|
||||
def _get_entry_type(self, paper: PaperMetadata) -> str:
|
||||
"""确定BibTeX条目类型"""
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type:
|
||||
venue_type = paper.venue_type.lower()
|
||||
if venue_type == 'conference':
|
||||
return 'inproceedings'
|
||||
elif venue_type == 'preprint':
|
||||
return 'unpublished'
|
||||
elif venue_type == 'journal':
|
||||
return 'article'
|
||||
elif venue_type == 'book':
|
||||
return 'book'
|
||||
elif venue_type == 'thesis':
|
||||
return 'phdthesis'
|
||||
return 'article' # 默认为期刊文章
|
||||
|
||||
|
||||
def create_document(self, papers: List[PaperMetadata]) -> str:
|
||||
"""生成BibTeX格式的参考文献文本"""
|
||||
bibtex_text = "% This file was automatically generated by GPT-Academic\n"
|
||||
bibtex_text += "% Compatible with: EndNote, Zotero, JabRef, and LaTeX\n\n"
|
||||
|
||||
for paper in papers:
|
||||
entry_type = self._get_entry_type(paper)
|
||||
cite_key = self._generate_cite_key(paper)
|
||||
|
||||
bibtex_text += f"@{entry_type}{{{cite_key},\n"
|
||||
|
||||
# 添加标题
|
||||
if paper.title:
|
||||
bibtex_text += f" title = {{{self._sanitize_bibtex(paper.title)}}},\n"
|
||||
|
||||
# 添加作者
|
||||
if paper.authors:
|
||||
# 确保每个作者的姓和名正确分隔
|
||||
processed_authors = []
|
||||
for author in paper.authors:
|
||||
names = author.split()
|
||||
if len(names) > 1:
|
||||
# 假设最后一个词是姓,其他的是名
|
||||
surname = names[-1]
|
||||
given_names = ' '.join(names[:-1])
|
||||
processed_authors.append(f"{surname}, {given_names}")
|
||||
else:
|
||||
processed_authors.append(author)
|
||||
|
||||
authors = " and ".join([self._sanitize_bibtex(author) for author in processed_authors])
|
||||
bibtex_text += f" author = {{{authors}}},\n"
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
bibtex_text += f" year = {{{paper.year}}},\n"
|
||||
|
||||
# 添加期刊/会议名称
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
if entry_type == 'inproceedings':
|
||||
bibtex_text += f" booktitle = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||
elif entry_type == 'article':
|
||||
bibtex_text += f" journal = {{{self._sanitize_bibtex(paper.venue_name)}}},\n"
|
||||
# 添加期刊相关信息
|
||||
if hasattr(paper, 'venue_info'):
|
||||
if 'volume' in paper.venue_info:
|
||||
bibtex_text += f" volume = {{{paper.venue_info['volume']}}},\n"
|
||||
if 'number' in paper.venue_info:
|
||||
bibtex_text += f" number = {{{paper.venue_info['number']}}},\n"
|
||||
if 'pages' in paper.venue_info:
|
||||
bibtex_text += f" pages = {{{paper.venue_info['pages']}}},\n"
|
||||
elif paper.venue:
|
||||
venue_field = "booktitle" if entry_type == "inproceedings" else "journal"
|
||||
bibtex_text += f" {venue_field} = {{{self._sanitize_bibtex(paper.venue)}}},\n"
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
bibtex_text += f" doi = {{{paper.doi}}},\n"
|
||||
|
||||
# 添加URL
|
||||
if paper.url:
|
||||
bibtex_text += f" url = {{{paper.url}}},\n"
|
||||
elif paper.doi:
|
||||
bibtex_text += f" url = {{https://doi.org/{paper.doi}}},\n"
|
||||
|
||||
# 添加摘要
|
||||
if paper.abstract:
|
||||
bibtex_text += f" abstract = {{{self._sanitize_bibtex(paper.abstract)}}},\n"
|
||||
|
||||
# 添加机构
|
||||
if hasattr(paper, 'institutions') and paper.institutions:
|
||||
institutions = " and ".join([self._sanitize_bibtex(inst) for inst in paper.institutions])
|
||||
bibtex_text += f" institution = {{{institutions}}},\n"
|
||||
|
||||
# 添加月份
|
||||
if hasattr(paper, 'month'):
|
||||
bibtex_text += f" month = {{{paper.month}}},\n"
|
||||
|
||||
# 添加注释字段
|
||||
if hasattr(paper, 'note'):
|
||||
bibtex_text += f" note = {{{self._sanitize_bibtex(paper.note)}}},\n"
|
||||
|
||||
# 移除最后一个逗号并关闭条目
|
||||
bibtex_text = bibtex_text.rstrip(',\n') + "\n}\n\n"
|
||||
|
||||
return bibtex_text
|
||||
@@ -0,0 +1,138 @@
|
||||
from docx2pdf import convert
|
||||
import os
|
||||
import platform
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
class WordToPdfConverter:
|
||||
"""Word文档转PDF转换器"""
|
||||
|
||||
@staticmethod
|
||||
def _replace_docx_in_filename(filename: Union[str, Path]) -> Path:
|
||||
"""
|
||||
将文件名中的'docx'替换为'pdf'
|
||||
例如: 'docx_test.pdf' -> 'pdf_test.pdf'
|
||||
"""
|
||||
path = Path(filename)
|
||||
new_name = path.stem.replace('docx', 'pdf')
|
||||
return path.parent / f"{new_name}{path.suffix}"
|
||||
|
||||
@staticmethod
|
||||
def convert_to_pdf(word_path: Union[str, Path], pdf_path: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将Word文档转换为PDF
|
||||
|
||||
参数:
|
||||
word_path: Word文档的路径
|
||||
pdf_path: 可选,PDF文件的输出路径。如果未指定,将使用与Word文档相同的名称和位置
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
|
||||
异常:
|
||||
如果转换失败,将抛出相应异常
|
||||
"""
|
||||
try:
|
||||
word_path = Path(word_path)
|
||||
|
||||
if pdf_path is None:
|
||||
# 创建新的pdf路径,同时替换文件名中的docx
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(word_path).with_suffix('.pdf')
|
||||
else:
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(Path(pdf_path))
|
||||
|
||||
# 检查操作系统
|
||||
if platform.system() == 'Linux':
|
||||
# Linux系统需要安装libreoffice
|
||||
if not os.system('which libreoffice') == 0:
|
||||
raise RuntimeError("请先安装LibreOffice: sudo apt-get install libreoffice")
|
||||
|
||||
# 使用libreoffice进行转换
|
||||
os.system(f'libreoffice --headless --convert-to pdf "{word_path}" --outdir "{pdf_path.parent}"')
|
||||
|
||||
# 如果输出路径与默认生成的不同,则重命名
|
||||
default_pdf = word_path.with_suffix('.pdf')
|
||||
if default_pdf != pdf_path:
|
||||
os.rename(default_pdf, pdf_path)
|
||||
else:
|
||||
# Windows和MacOS使用 docx2pdf
|
||||
convert(word_path, pdf_path)
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def batch_convert(word_dir: Union[str, Path], pdf_dir: Union[str, Path] = None) -> list:
|
||||
"""
|
||||
批量转换目录下的所有Word文档
|
||||
|
||||
参数:
|
||||
word_dir: 包含Word文档的目录路径
|
||||
pdf_dir: 可选,PDF文件的输出目录。如果未指定,将使用与Word文档相同的目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径列表
|
||||
"""
|
||||
word_dir = Path(word_dir)
|
||||
if pdf_dir:
|
||||
pdf_dir = Path(pdf_dir)
|
||||
pdf_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
converted_files = []
|
||||
|
||||
for word_file in word_dir.glob("*.docx"):
|
||||
try:
|
||||
if pdf_dir:
|
||||
pdf_path = pdf_dir / WordToPdfConverter._replace_docx_in_filename(
|
||||
word_file.with_suffix('.pdf')
|
||||
).name
|
||||
else:
|
||||
pdf_path = WordToPdfConverter._replace_docx_in_filename(
|
||||
word_file.with_suffix('.pdf')
|
||||
)
|
||||
|
||||
pdf_file = WordToPdfConverter.convert_to_pdf(word_file, pdf_path)
|
||||
converted_files.append(pdf_file)
|
||||
|
||||
except Exception as e:
|
||||
print(f"转换 {word_file} 失败: {str(e)}")
|
||||
|
||||
return converted_files
|
||||
|
||||
@staticmethod
|
||||
def convert_doc_to_pdf(doc, output_dir: Union[str, Path] = None) -> str:
|
||||
"""
|
||||
将docx对象直接转换为PDF
|
||||
|
||||
参数:
|
||||
doc: python-docx的Document对象
|
||||
output_dir: 可选,输出目录。如果未指定,将使用当前目录
|
||||
|
||||
返回:
|
||||
生成的PDF文件路径
|
||||
"""
|
||||
try:
|
||||
# 设置临时文件路径和输出路径
|
||||
output_dir = Path(output_dir) if output_dir else Path.cwd()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成临时word文件
|
||||
temp_docx = output_dir / f"temp_{datetime.now().strftime('%Y%m%d_%H%M%S')}.docx"
|
||||
doc.save(temp_docx)
|
||||
|
||||
# 转换为PDF
|
||||
pdf_path = temp_docx.with_suffix('.pdf')
|
||||
WordToPdfConverter.convert_to_pdf(temp_docx, pdf_path)
|
||||
|
||||
# 删除临时word文件
|
||||
temp_docx.unlink()
|
||||
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
if temp_docx.exists():
|
||||
temp_docx.unlink()
|
||||
raise Exception(f"转换PDF失败: {str(e)}")
|
||||
@@ -0,0 +1,246 @@
|
||||
import re
|
||||
from docx import Document
|
||||
from docx.shared import Cm, Pt
|
||||
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
|
||||
from docx.enum.style import WD_STYLE_TYPE
|
||||
from docx.oxml.ns import qn
|
||||
from datetime import datetime
|
||||
import docx
|
||||
from docx.oxml import shared
|
||||
from crazy_functions.doc_fns.conversation_doc.word_doc import convert_markdown_to_word
|
||||
|
||||
|
||||
class WordFormatter:
|
||||
"""聊天记录Word文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012)"""
|
||||
|
||||
def __init__(self):
|
||||
self.doc = Document()
|
||||
self._setup_document()
|
||||
self._create_styles()
|
||||
|
||||
def _setup_document(self):
|
||||
"""设置文档基本格式,包括页面设置和页眉"""
|
||||
sections = self.doc.sections
|
||||
for section in sections:
|
||||
# 设置页面大小为A4
|
||||
section.page_width = Cm(21)
|
||||
section.page_height = Cm(29.7)
|
||||
# 设置页边距
|
||||
section.top_margin = Cm(3.7) # 上边距37mm
|
||||
section.bottom_margin = Cm(3.5) # 下边距35mm
|
||||
section.left_margin = Cm(2.8) # 左边距28mm
|
||||
section.right_margin = Cm(2.6) # 右边距26mm
|
||||
# 设置页眉页脚距离
|
||||
section.header_distance = Cm(2.0)
|
||||
section.footer_distance = Cm(2.0)
|
||||
|
||||
# 修改页眉
|
||||
header = section.header
|
||||
header_para = header.paragraphs[0]
|
||||
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
header_run = header_para.add_run("GPT-Academic学术对话 (体验地址:https://auth.gpt-academic.top/)")
|
||||
header_run.font.name = '仿宋'
|
||||
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
header_run.font.size = Pt(9)
|
||||
|
||||
def _create_styles(self):
|
||||
"""创建文档样式"""
|
||||
# 创建正文样式
|
||||
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
style.font.name = '仿宋'
|
||||
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
style.font.size = Pt(12)
|
||||
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
style.paragraph_format.space_after = Pt(0)
|
||||
|
||||
# 创建问题样式
|
||||
question_style = self.doc.styles.add_style('Question_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
question_style.font.name = '黑体'
|
||||
question_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
question_style.font.size = Pt(14) # 调整为14磅
|
||||
question_style.font.bold = True
|
||||
question_style.paragraph_format.space_before = Pt(12) # 减小段前距
|
||||
question_style.paragraph_format.space_after = Pt(6)
|
||||
question_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
question_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||
|
||||
# 创建回答样式
|
||||
answer_style = self.doc.styles.add_style('Answer_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
answer_style.font.name = '仿宋'
|
||||
answer_style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
answer_style.font.size = Pt(12) # 调整为12磅
|
||||
answer_style.paragraph_format.space_before = Pt(6)
|
||||
answer_style.paragraph_format.space_after = Pt(12)
|
||||
answer_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
answer_style.paragraph_format.left_indent = Pt(0) # 移除左缩进
|
||||
|
||||
# 创建标题样式
|
||||
title_style = self.doc.styles.add_style('Title_Custom', WD_STYLE_TYPE.PARAGRAPH)
|
||||
title_style.font.name = '黑体' # 改用黑体
|
||||
title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
title_style.font.size = Pt(22) # 调整为22磅
|
||||
title_style.font.bold = True
|
||||
title_style.paragraph_format.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
title_style.paragraph_format.space_before = Pt(0)
|
||||
title_style.paragraph_format.space_after = Pt(24)
|
||||
title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
|
||||
# 添加参考文献样式
|
||||
ref_style = self.doc.styles.add_style('Reference_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
ref_style.font.name = '宋体'
|
||||
ref_style._element.rPr.rFonts.set(qn('w:eastAsia'), '宋体')
|
||||
ref_style.font.size = Pt(10.5) # 参考文献使用小号字体
|
||||
ref_style.paragraph_format.space_before = Pt(3)
|
||||
ref_style.paragraph_format.space_after = Pt(3)
|
||||
ref_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.SINGLE
|
||||
ref_style.paragraph_format.left_indent = Pt(21)
|
||||
ref_style.paragraph_format.first_line_indent = Pt(-21)
|
||||
|
||||
# 添加参考文献标题样式
|
||||
ref_title_style = self.doc.styles.add_style('Reference_Title_Style', WD_STYLE_TYPE.PARAGRAPH)
|
||||
ref_title_style.font.name = '黑体'
|
||||
ref_title_style._element.rPr.rFonts.set(qn('w:eastAsia'), '黑体')
|
||||
ref_title_style.font.size = Pt(16) # 参考文献标题与问题同样大小
|
||||
ref_title_style.font.bold = True
|
||||
ref_title_style.paragraph_format.space_before = Pt(24) # 增加段前距
|
||||
ref_title_style.paragraph_format.space_after = Pt(12)
|
||||
ref_title_style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
|
||||
|
||||
def create_document(self, question: str, answer: str, ranked_papers: list = None):
|
||||
"""写入聊天历史
|
||||
Args:
|
||||
question: str, 用户问题
|
||||
answer: str, AI回答
|
||||
ranked_papers: list, 排序后的论文列表
|
||||
"""
|
||||
try:
|
||||
# 添加标题
|
||||
title_para = self.doc.add_paragraph(style='Title_Custom')
|
||||
title_run = title_para.add_run('GPT-Academic 对话记录')
|
||||
|
||||
# 添加日期
|
||||
try:
|
||||
date_para = self.doc.add_paragraph()
|
||||
date_para.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
|
||||
date_run = date_para.add_run(datetime.now().strftime('%Y年%m月%d日'))
|
||||
date_run.font.name = '仿宋'
|
||||
date_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
|
||||
date_run.font.size = Pt(16)
|
||||
except Exception as e:
|
||||
print(f"添加日期失败: {str(e)}")
|
||||
raise
|
||||
|
||||
self.doc.add_paragraph() # 添加空行
|
||||
|
||||
# 添加问答对话
|
||||
try:
|
||||
q_para = self.doc.add_paragraph(style='Question_Style')
|
||||
q_para.add_run('问题:').bold = True
|
||||
q_para.add_run(str(question))
|
||||
|
||||
a_para = self.doc.add_paragraph(style='Answer_Style')
|
||||
a_para.add_run('回答:').bold = True
|
||||
a_para.add_run(convert_markdown_to_word(str(answer)))
|
||||
except Exception as e:
|
||||
print(f"添加问答对话失败: {str(e)}")
|
||||
raise
|
||||
|
||||
# 添加参考文献部分
|
||||
if ranked_papers:
|
||||
try:
|
||||
ref_title = self.doc.add_paragraph(style='Reference_Title_Style')
|
||||
ref_title.add_run("参考文献")
|
||||
|
||||
for idx, paper in enumerate(ranked_papers, 1):
|
||||
try:
|
||||
ref_para = self.doc.add_paragraph(style='Reference_Style')
|
||||
ref_para.add_run(f'[{idx}] ').bold = True
|
||||
|
||||
# 添加作者
|
||||
authors = ', '.join(paper.authors[:3])
|
||||
if len(paper.authors) > 3:
|
||||
authors += ' et al.'
|
||||
ref_para.add_run(f'{authors}. ')
|
||||
|
||||
# 添加标题
|
||||
title_run = ref_para.add_run(paper.title)
|
||||
title_run.italic = True
|
||||
if hasattr(paper, 'url') and paper.url:
|
||||
try:
|
||||
title_run._element.rPr.rStyle = self._create_hyperlink_style()
|
||||
self._add_hyperlink(ref_para, paper.title, paper.url)
|
||||
except Exception as e:
|
||||
print(f"添加超链接失败: {str(e)}")
|
||||
|
||||
# 添加期刊/会议信息
|
||||
if paper.venue_name:
|
||||
ref_para.add_run(f'. {paper.venue_name}')
|
||||
|
||||
# 添加年份
|
||||
if paper.year:
|
||||
ref_para.add_run(f', {paper.year}')
|
||||
|
||||
# 添加DOI
|
||||
if paper.doi:
|
||||
ref_para.add_run('. ')
|
||||
if "arxiv" in paper.url:
|
||||
doi_url = paper.doi
|
||||
else:
|
||||
doi_url = f'https://doi.org/{paper.doi}'
|
||||
self._add_hyperlink(ref_para, f'DOI: {paper.doi}', doi_url)
|
||||
|
||||
ref_para.add_run('.')
|
||||
except Exception as e:
|
||||
print(f"添加第 {idx} 篇参考文献失败: {str(e)}")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"添加参考文献部分失败: {str(e)}")
|
||||
raise
|
||||
|
||||
return self.doc
|
||||
|
||||
except Exception as e:
|
||||
print(f"Word文档创建失败: {str(e)}")
|
||||
import traceback
|
||||
print(f"详细错误信息: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
def _create_hyperlink_style(self):
|
||||
"""创建超链接样式"""
|
||||
styles = self.doc.styles
|
||||
if 'Hyperlink' not in styles:
|
||||
hyperlink_style = styles.add_style('Hyperlink', WD_STYLE_TYPE.CHARACTER)
|
||||
# 使用科技蓝 (#0066CC)
|
||||
hyperlink_style.font.color.rgb = 0x0066CC # 科技蓝
|
||||
hyperlink_style.font.underline = True
|
||||
return styles['Hyperlink']
|
||||
|
||||
def _add_hyperlink(self, paragraph, text, url):
|
||||
"""添加超链接到段落"""
|
||||
# 这个是在XML级别添加超链接
|
||||
part = paragraph.part
|
||||
r_id = part.relate_to(url, docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, is_external=True)
|
||||
|
||||
# 创建超链接XML元素
|
||||
hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink')
|
||||
hyperlink.set(docx.oxml.shared.qn('r:id'), r_id)
|
||||
|
||||
# 创建文本运行
|
||||
new_run = docx.oxml.shared.OxmlElement('w:r')
|
||||
rPr = docx.oxml.shared.OxmlElement('w:rPr')
|
||||
|
||||
# 应用超链接样式
|
||||
rStyle = docx.oxml.shared.OxmlElement('w:rStyle')
|
||||
rStyle.set(docx.oxml.shared.qn('w:val'), 'Hyperlink')
|
||||
rPr.append(rStyle)
|
||||
|
||||
# 添加文本
|
||||
t = docx.oxml.shared.OxmlElement('w:t')
|
||||
t.text = text
|
||||
new_run.append(rPr)
|
||||
new_run.append(t)
|
||||
hyperlink.append(new_run)
|
||||
|
||||
# 将超链接添加到段落
|
||||
paragraph._p.append(hyperlink)
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class AdsabsSource(DataSource):
|
||||
"""ADS (Astrophysics Data System) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: ADS API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.adsabs.harvard.edu/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, method: str = "GET", data: dict = None) -> Optional[dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
method: HTTP方法
|
||||
data: POST请求数据
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
if method == "GET":
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
elif method == "POST":
|
||||
async with session.post(url, json=data) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper(self, doc: dict) -> PaperMetadata:
|
||||
"""解析ADS文献数据
|
||||
|
||||
Args:
|
||||
doc: ADS文献数据
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
return PaperMetadata(
|
||||
title=doc.get('title', [''])[0] if doc.get('title') else '',
|
||||
authors=doc.get('author', []),
|
||||
abstract=doc.get('abstract', ''),
|
||||
year=doc.get('year'),
|
||||
doi=doc.get('doi', [''])[0] if doc.get('doi') else None,
|
||||
url=f"https://ui.adsabs.harvard.edu/abs/{doc.get('bibcode')}/abstract" if doc.get('bibcode') else None,
|
||||
citations=doc.get('citation_count'),
|
||||
venue=doc.get('pub', ''),
|
||||
institutions=doc.get('aff', []),
|
||||
venue_type="journal",
|
||||
venue_name=doc.get('pub', ''),
|
||||
venue_info={
|
||||
'volume': doc.get('volume'),
|
||||
'issue': doc.get('issue'),
|
||||
'pub_date': doc.get('pubdate', '')
|
||||
},
|
||||
source='adsabs'
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
if start_year:
|
||||
query = f"{query} year:{start_year}-"
|
||||
|
||||
# 设置排序
|
||||
sort_mapping = {
|
||||
'relevance': 'score desc',
|
||||
'date': 'date desc',
|
||||
'citations': 'citation_count desc'
|
||||
}
|
||||
sort = sort_mapping.get(sort_by, 'score desc')
|
||||
|
||||
# 构建搜索请求
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": query,
|
||||
"rows": limit,
|
||||
"sort": sort,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _build_query_string(self, params: dict) -> str:
|
||||
"""构建查询字符串"""
|
||||
return "&".join([f"{k}={v}" for k, v in params.items()])
|
||||
|
||||
async def get_paper_details(self, bibcode: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定bibcode的论文详情"""
|
||||
search_url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"identifier:{bibcode}",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{search_url}?{self._build_query_string(params)}")
|
||||
if response and 'response' in response and response['response']['docs']:
|
||||
return self._parse_paper(response['response']['docs'][0])
|
||||
return None
|
||||
|
||||
async def get_related_papers(self, bibcode: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode}) OR references(identifier:{bibcode})",
|
||||
"rows": limit,
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"author:\"{author}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"pub:\"{journal}\""
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"entdate:[NOW-{days}DAYS TO NOW]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"citations(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def get_references(self, bibcode: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
url = f"{self.base_url}/search/query"
|
||||
params = {
|
||||
"q": f"references(identifier:{bibcode})",
|
||||
"fl": "title,author,abstract,year,doi,bibcode,citation_count,pub,aff,volume,issue,pubdate"
|
||||
}
|
||||
|
||||
response = await self._make_request(f"{url}?{self._build_query_string(params)}")
|
||||
if not response or 'response' not in response:
|
||||
return []
|
||||
|
||||
papers = []
|
||||
for doc in response['response']['docs']:
|
||||
paper = self._parse_paper(doc)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
return papers
|
||||
|
||||
async def example_usage():
|
||||
"""AdsabsSource使用示例"""
|
||||
ads = AdsabsSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索黑洞相关论文 ===")
|
||||
papers = await ads.search("black hole", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 其他示例...
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# python -m crazy_functions.review_fns.data_sources.adsabs_source
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,636 @@
|
||||
import arxiv
|
||||
from typing import List, Optional, Union, Literal, Dict
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.request import urlretrieve
|
||||
import feedparser
|
||||
from tqdm import tqdm
|
||||
|
||||
class ArxivSource(DataSource):
|
||||
"""arXiv API实现"""
|
||||
|
||||
CATEGORIES = {
|
||||
# 物理学
|
||||
"Physics": {
|
||||
"astro-ph": "天体物理学",
|
||||
"cond-mat": "凝聚态物理",
|
||||
"gr-qc": "广义相对论与量子宇宙学",
|
||||
"hep-ex": "高能物理实验",
|
||||
"hep-lat": "格点场论",
|
||||
"hep-ph": "高能物理理论",
|
||||
"hep-th": "高能物理理论",
|
||||
"math-ph": "数学物理",
|
||||
"nlin": "非线性科学",
|
||||
"nucl-ex": "核实验",
|
||||
"nucl-th": "核理论",
|
||||
"physics": "物理学",
|
||||
"quant-ph": "量子物理",
|
||||
},
|
||||
|
||||
# 数学
|
||||
"Mathematics": {
|
||||
"math.AG": "代数几何",
|
||||
"math.AT": "代数拓扑",
|
||||
"math.AP": "分析与偏微分方程",
|
||||
"math.CT": "范畴论",
|
||||
"math.CA": "复分析",
|
||||
"math.CO": "组合数学",
|
||||
"math.AC": "交换代数",
|
||||
"math.CV": "复变函数",
|
||||
"math.DG": "微分几何",
|
||||
"math.DS": "动力系统",
|
||||
"math.FA": "泛函分析",
|
||||
"math.GM": "一般数学",
|
||||
"math.GN": "一般拓扑",
|
||||
"math.GT": "几何拓扑",
|
||||
"math.GR": "群论",
|
||||
"math.HO": "数学史与数学概述",
|
||||
"math.IT": "信息论",
|
||||
"math.KT": "K理论与同调",
|
||||
"math.LO": "逻辑",
|
||||
"math.MP": "数学物理",
|
||||
"math.MG": "度量几何",
|
||||
"math.NT": "数论",
|
||||
"math.NA": "数值分析",
|
||||
"math.OA": "算子代数",
|
||||
"math.OC": "最优化与控制",
|
||||
"math.PR": "概率论",
|
||||
"math.QA": "量子代数",
|
||||
"math.RT": "表示论",
|
||||
"math.RA": "环与代数",
|
||||
"math.SP": "谱理论",
|
||||
"math.ST": "统计理论",
|
||||
"math.SG": "辛几何",
|
||||
},
|
||||
|
||||
# 计算机科学
|
||||
"Computer Science": {
|
||||
"cs.AI": "人工智能",
|
||||
"cs.CL": "计算语言学",
|
||||
"cs.CC": "计算复杂性",
|
||||
"cs.CE": "计算工程",
|
||||
"cs.CG": "计算几何",
|
||||
"cs.GT": "计算机博弈论",
|
||||
"cs.CV": "计算机视觉",
|
||||
"cs.CY": "计算机与社会",
|
||||
"cs.CR": "密码学与安全",
|
||||
"cs.DS": "数据结构与算法",
|
||||
"cs.DB": "数据库",
|
||||
"cs.DL": "数字图书馆",
|
||||
"cs.DM": "离散数学",
|
||||
"cs.DC": "分布式计算",
|
||||
"cs.ET": "新兴技术",
|
||||
"cs.FL": "形式语言与自动机理论",
|
||||
"cs.GL": "一般文献",
|
||||
"cs.GR": "图形学",
|
||||
"cs.AR": "硬件架构",
|
||||
"cs.HC": "人机交互",
|
||||
"cs.IR": "信息检索",
|
||||
"cs.IT": "信息论",
|
||||
"cs.LG": "机器学习",
|
||||
"cs.LO": "逻辑与计算机",
|
||||
"cs.MS": "数学软件",
|
||||
"cs.MA": "多智能体系统",
|
||||
"cs.MM": "多媒体",
|
||||
"cs.NI": "网络与互联网架构",
|
||||
"cs.NE": "神经与进化计算",
|
||||
"cs.NA": "数值分析",
|
||||
"cs.OS": "操作系统",
|
||||
"cs.OH": "其他计算机科学",
|
||||
"cs.PF": "性能评估",
|
||||
"cs.PL": "编程语言",
|
||||
"cs.RO": "机器人学",
|
||||
"cs.SI": "社会与信息网络",
|
||||
"cs.SE": "软件工程",
|
||||
"cs.SD": "声音",
|
||||
"cs.SC": "符号计算",
|
||||
"cs.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 定量生物学
|
||||
"Quantitative Biology": {
|
||||
"q-bio.BM": "生物分子",
|
||||
"q-bio.CB": "细胞行为",
|
||||
"q-bio.GN": "基因组学",
|
||||
"q-bio.MN": "分子网络",
|
||||
"q-bio.NC": "神经计算",
|
||||
"q-bio.OT": "其他",
|
||||
"q-bio.PE": "群体与进化",
|
||||
"q-bio.QM": "定量方法",
|
||||
"q-bio.SC": "亚细胞过程",
|
||||
"q-bio.TO": "组织与器官",
|
||||
},
|
||||
|
||||
# 定量金融
|
||||
"Quantitative Finance": {
|
||||
"q-fin.CP": "计算金融",
|
||||
"q-fin.EC": "经济学",
|
||||
"q-fin.GN": "一般金融",
|
||||
"q-fin.MF": "数学金融",
|
||||
"q-fin.PM": "投资组合管理",
|
||||
"q-fin.PR": "定价理论",
|
||||
"q-fin.RM": "风险管理",
|
||||
"q-fin.ST": "统计金融",
|
||||
"q-fin.TR": "交易与市场微观结构",
|
||||
},
|
||||
|
||||
# 统计学
|
||||
"Statistics": {
|
||||
"stat.AP": "应用统计",
|
||||
"stat.CO": "计算统计",
|
||||
"stat.ML": "机器学习",
|
||||
"stat.ME": "方法论",
|
||||
"stat.OT": "其他统计",
|
||||
"stat.TH": "统计理论",
|
||||
},
|
||||
|
||||
# 电气工程与系统科学
|
||||
"Electrical Engineering and Systems Science": {
|
||||
"eess.AS": "音频与语音处理",
|
||||
"eess.IV": "图像与视频处理",
|
||||
"eess.SP": "信号处理",
|
||||
"eess.SY": "系统与控制",
|
||||
},
|
||||
|
||||
# 经济学
|
||||
"Economics": {
|
||||
"econ.EM": "计量经济学",
|
||||
"econ.GN": "一般经济学",
|
||||
"econ.TH": "理论经济学",
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化"""
|
||||
self._initialize() # 调用初始化方法
|
||||
# 修改排序选项映射
|
||||
self.sort_options = {
|
||||
'relevance': arxiv.SortCriterion.Relevance, # arXiv的相关性排序
|
||||
'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate, # 最后更新日期
|
||||
'submittedDate': arxiv.SortCriterion.SubmittedDate, # 提交日期
|
||||
}
|
||||
|
||||
self.sort_order_options = {
|
||||
'ascending': arxiv.SortOrder.Ascending,
|
||||
'descending': arxiv.SortOrder.Descending
|
||||
}
|
||||
|
||||
self.default_sort = 'lastUpdatedDate'
|
||||
self.default_order = 'descending'
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.client = arxiv.Client()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[Dict]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
# 使用默认排序如果提供的排序选项无效
|
||||
if not sort_by or sort_by not in self.sort_options:
|
||||
sort_by = self.default_sort
|
||||
|
||||
# 使用默认排序顺序如果提供的顺序无效
|
||||
if not sort_order or sort_order not in self.sort_order_options:
|
||||
sort_order = self.default_order
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
search = arxiv.Search(
|
||||
query=query,
|
||||
max_results=limit,
|
||||
sort_by=self.sort_options[sort_by],
|
||||
sort_order=self.sort_order_options[sort_order]
|
||||
)
|
||||
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
|
||||
"""按ID搜索论文
|
||||
|
||||
Args:
|
||||
paper_id: 单个arXiv ID或ID列表,例如:'2005.14165' 或 ['2005.14165', '2103.14030']
|
||||
"""
|
||||
if isinstance(paper_id, str):
|
||||
paper_id = [paper_id]
|
||||
|
||||
search = arxiv.Search(
|
||||
id_list=paper_id,
|
||||
max_results=len(paper_id)
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
async def search_by_category(
|
||||
self,
|
||||
category: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
sort_order: str = 'descending',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按类别搜索论文"""
|
||||
query = f"cat:{category}"
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = 'relevance',
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " AND ".join([f"au:\"{author}\"" for author in authors])
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"
|
||||
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
|
||||
sort_order: Literal['ascending', 'descending'] = 'descending'
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
|
||||
return await self.search(
|
||||
query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文PDF
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.pdf
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
# 清理标题中的非法字符
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.pdf"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
urlretrieve(paper.url, filepath)
|
||||
return filepath
|
||||
|
||||
async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
|
||||
"""下载论文源文件(通常是LaTeX源码)
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID
|
||||
dirpath: 保存目录
|
||||
filename: 文件名,如果为空则使用默认格式:{paper_id}_{标题}.tar.gz
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
papers = await self.search_by_id(paper_id)
|
||||
if not papers:
|
||||
raise ValueError(f"未找到ID为 {paper_id} 的论文")
|
||||
paper = papers[0]
|
||||
|
||||
if not filename:
|
||||
safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
|
||||
filename = f"{paper_id}_{safe_title}.tar.gz"
|
||||
|
||||
filepath = os.path.join(dirpath, filename)
|
||||
source_url = paper.url.replace("/pdf/", "/src/")
|
||||
urlretrieve(source_url, filepath)
|
||||
return filepath
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
# arXiv API不直接提供引用信息
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: arXiv ID 或 DOI
|
||||
|
||||
Returns:
|
||||
论文详细信息,如果未找到返回 None
|
||||
"""
|
||||
try:
|
||||
# 如果是完整的 arXiv URL,提取 ID
|
||||
if "arxiv.org" in paper_id:
|
||||
paper_id = paper_id.split("/")[-1]
|
||||
# 如果是 DOI 格式且是 arXiv 论文,提取 ID
|
||||
elif paper_id.startswith("10.48550/arXiv."):
|
||||
paper_id = paper_id.split(".")[-1]
|
||||
|
||||
papers = await self.search_by_id(paper_id)
|
||||
return papers[0] if papers else None
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, result: arxiv.Result) -> PaperMetadata:
|
||||
"""解析arXiv API返回的数据"""
|
||||
# 解析主要类别和次要类别
|
||||
primary_category = result.primary_category
|
||||
categories = result.categories
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
'primary_category': primary_category,
|
||||
'categories': categories,
|
||||
'comments': getattr(result, 'comment', None),
|
||||
'journal_ref': getattr(result, 'journal_ref', None)
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=result.title,
|
||||
authors=[author.name for author in result.authors],
|
||||
abstract=result.summary,
|
||||
year=result.published.year,
|
||||
doi=result.entry_id,
|
||||
url=result.pdf_url,
|
||||
citations=None,
|
||||
venue=f"arXiv:{primary_category}",
|
||||
institutions=[],
|
||||
venue_type='preprint', # arXiv论文都是预印本
|
||||
venue_name='arXiv',
|
||||
venue_info=venue_info,
|
||||
source='arxiv' # 添加来源标记
|
||||
)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
category: str,
|
||||
debug: bool = False,
|
||||
batch_size: int = 50
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定类别的最新论文
|
||||
|
||||
通过 RSS feed 获取最新发布的论文,然后批量获取详细信息
|
||||
|
||||
Args:
|
||||
category: arXiv类别,例如:
|
||||
- 整个领域: 'cs'
|
||||
- 具体方向: 'cs.AI'
|
||||
- 多个类别: 'cs.AI+q-bio.NC'
|
||||
debug: 是否为调试模式,如果为True则只返回5篇最新论文
|
||||
batch_size: 批量获取论文的数量,默认50
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果类别无效
|
||||
"""
|
||||
try:
|
||||
# 处理类别格式
|
||||
# 1. 转换为小写
|
||||
# 2. 确保多个类别之间使用+连接
|
||||
category = category.lower().replace(' ', '+')
|
||||
|
||||
# 构建RSS feed URL
|
||||
feed_url = f"https://rss.arxiv.org/rss/{category}"
|
||||
print(f"正在获取RSS feed: {feed_url}") # 添加调试信息
|
||||
|
||||
feed = feedparser.parse(feed_url)
|
||||
|
||||
# 检查feed是否有效
|
||||
if hasattr(feed, 'status') and feed.status != 200:
|
||||
raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")
|
||||
|
||||
if not feed.entries:
|
||||
print(f"警告:未在feed中找到任何条目") # 添加调试信息
|
||||
print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
|
||||
raise ValueError(f"无效的arXiv类别或未找到论文: {category}")
|
||||
|
||||
if debug:
|
||||
# 调试模式:只获取5篇最新论文
|
||||
search = arxiv.Search(
|
||||
query=f'cat:{category}',
|
||||
sort_by=arxiv.SortCriterion.SubmittedDate,
|
||||
sort_order=arxiv.SortOrder.Descending,
|
||||
max_results=5
|
||||
)
|
||||
results = list(self.client.results(search))
|
||||
return [self._parse_paper_data(result) for result in results]
|
||||
|
||||
# 正常模式:获取所有新论文
|
||||
# 从RSS条目中提取arXiv ID
|
||||
paper_ids = []
|
||||
for entry in feed.entries:
|
||||
try:
|
||||
# RSS链接格式可能是以下几种:
|
||||
# - http://arxiv.org/abs/2403.xxxxx
|
||||
# - http://arxiv.org/pdf/2403.xxxxx
|
||||
# - https://arxiv.org/abs/2403.xxxxx
|
||||
link = entry.link or entry.id
|
||||
arxiv_id = link.split('/')[-1].replace('.pdf', '')
|
||||
if arxiv_id:
|
||||
paper_ids.append(arxiv_id)
|
||||
except Exception as e:
|
||||
print(f"警告:处理条目时出错: {str(e)}") # 添加调试信息
|
||||
continue
|
||||
|
||||
if not paper_ids:
|
||||
print("未能从feed中提取到任何论文ID") # 添加调试信息
|
||||
return []
|
||||
|
||||
print(f"成功提取到 {len(paper_ids)} 个论文ID") # 添加调试信息
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
|
||||
for i in range(0, len(paper_ids), batch_size):
|
||||
batch_ids = paper_ids[i:i + batch_size]
|
||||
search = arxiv.Search(
|
||||
id_list=batch_ids,
|
||||
max_results=len(batch_ids)
|
||||
)
|
||||
batch_results = list(self.client.results(search))
|
||||
papers.extend([self._parse_paper_data(result) for result in batch_results])
|
||||
pbar.update(len(batch_results))
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取最新论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc()) # 添加完整的错误追踪
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""ArxivSource使用示例"""
|
||||
arxiv_source = ArxivSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
# print("\n=== 示例1:搜索最新的机器学习论文(按提交时间排序)===")
|
||||
# papers = await arxiv_source.search(
|
||||
# "ti:\"machine learning\"",
|
||||
# limit=3,
|
||||
# sort_by='submitted',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# for i, paper in enumerate(papers, 1):
|
||||
# print(f"\n--- 论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
# print(f"arXiv ID: {paper.doi}")
|
||||
# print(f"PDF URL: {paper.url}")
|
||||
# if paper.abstract:
|
||||
# print(f"\n摘要:")
|
||||
# print(paper.abstract)
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例2:按ID搜索
|
||||
# print("\n=== 示例2:按ID搜索论文 ===")
|
||||
# paper_id = "2005.14165" # GPT-3论文
|
||||
# papers = await arxiv_source.search_by_id(paper_id)
|
||||
# if papers:
|
||||
# paper = papers[0]
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例3:按类别搜索
|
||||
# print("\n=== 示例3:搜索人工智能领域最新论文 ===")
|
||||
# ai_papers = await arxiv_source.search_by_category(
|
||||
# "cs.AI",
|
||||
# limit=2,
|
||||
# sort_by='updated',
|
||||
# sort_order='descending'
|
||||
# )
|
||||
# for i, paper in enumerate(ai_papers, 1):
|
||||
# print(f"\n--- AI论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例4:按作者搜索
|
||||
# print("\n=== 示例4:搜索特定作者的论文 ===")
|
||||
# author_papers = await arxiv_source.search_by_authors(
|
||||
# ["Bengio"],
|
||||
# limit=2,
|
||||
# sort_by='relevance'
|
||||
# )
|
||||
# for i, paper in enumerate(author_papers, 1):
|
||||
# print(f"\n--- Bengio的论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表venue: {paper.venue}")
|
||||
|
||||
# # 示例5:按日期范围搜索
|
||||
# print("\n=== 示例5:搜索特定日期范围的论文 ===")
|
||||
# from datetime import datetime, timedelta
|
||||
# end_date = datetime.now()
|
||||
# start_date = end_date - timedelta(days=7) # 最近一周
|
||||
# recent_papers = await arxiv_source.search_by_date_range(
|
||||
# start_date,
|
||||
# end_date,
|
||||
# limit=2
|
||||
# )
|
||||
# for i, paper in enumerate(recent_papers, 1):
|
||||
# print(f"\n--- 最近论文 {i} ---")
|
||||
# print(f"标题: {paper.title}")
|
||||
# print(f"作者: {', '.join(paper.authors)}")
|
||||
# print(f"发表年份: {paper.year}")
|
||||
|
||||
# # 示例6:下载PDF
|
||||
# print("\n=== 示例6:下载论文PDF ===")
|
||||
# if papers: # 使用之前搜索到的GPT-3论文
|
||||
# pdf_path = await arxiv_source.download_pdf(paper_id)
|
||||
# print(f"PDF已下载到: {pdf_path}")
|
||||
|
||||
# # 示例7:下载源文件
|
||||
# print("\n=== 示例7:下载论文源文件 ===")
|
||||
# if papers:
|
||||
# source_path = await arxiv_source.download_source(paper_id)
|
||||
# print(f"源文件已下载到: {source_path}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例8:获取最新论文 ===")
|
||||
|
||||
# 获取CS.AI领域的最新论文
|
||||
print("\n--- 获取AI领域最新论文 ---")
|
||||
ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
|
||||
for i, paper in enumerate(ai_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取整个计算机科学领域的最新论文
|
||||
print("\n--- 获取整个CS领域最新论文 ---")
|
||||
cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
|
||||
for i, paper in enumerate(cs_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 获取多个类别的最新论文
|
||||
print("\n--- 获取AI和机器学习领域最新论文 ---")
|
||||
multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
|
||||
for i, paper in enumerate(multi_latest, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,102 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
class PaperMetadata:
|
||||
"""论文元数据"""
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
authors: List[str],
|
||||
abstract: str,
|
||||
year: int,
|
||||
doi: str = None,
|
||||
url: str = None,
|
||||
citations: int = None,
|
||||
venue: str = None,
|
||||
institutions: List[str] = None,
|
||||
venue_type: str = None, # 来源类型(journal/conference/preprint等)
|
||||
venue_name: str = None, # 具体的期刊/会议名称
|
||||
venue_info: Dict = None, # 更多来源详细信息(如影响因子、分区等)
|
||||
source: str = None # 新增: 论文来源标记
|
||||
):
|
||||
self.title = title
|
||||
self.authors = authors
|
||||
self.abstract = abstract
|
||||
self.year = year
|
||||
self.doi = doi
|
||||
self.url = url
|
||||
self.citations = citations
|
||||
self.venue = venue
|
||||
self.institutions = institutions or []
|
||||
self.venue_type = venue_type # 新增
|
||||
self.venue_name = venue_name # 新增
|
||||
self.venue_info = venue_info or {} # 新增
|
||||
self.source = source # 新增: 存储论文来源
|
||||
|
||||
# 新增:影响因子和分区信息,初始化为None
|
||||
self._if_factor = None
|
||||
self._cas_division = None
|
||||
self._jcr_division = None
|
||||
|
||||
@property
|
||||
def if_factor(self) -> Optional[float]:
|
||||
"""获取影响因子"""
|
||||
return self._if_factor
|
||||
|
||||
@if_factor.setter
|
||||
def if_factor(self, value: float):
|
||||
"""设置影响因子"""
|
||||
self._if_factor = value
|
||||
|
||||
@property
|
||||
def cas_division(self) -> Optional[str]:
|
||||
"""获取中科院分区"""
|
||||
return self._cas_division
|
||||
|
||||
@cas_division.setter
|
||||
def cas_division(self, value: str):
|
||||
"""设置中科院分区"""
|
||||
self._cas_division = value
|
||||
|
||||
@property
|
||||
def jcr_division(self) -> Optional[str]:
|
||||
"""获取JCR分区"""
|
||||
return self._jcr_division
|
||||
|
||||
@jcr_division.setter
|
||||
def jcr_division(self, value: str):
|
||||
"""设置JCR分区"""
|
||||
self._jcr_division = value
|
||||
|
||||
class DataSource(ABC):
|
||||
"""数据源基类"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key
|
||||
self._initialize()
|
||||
|
||||
@abstractmethod
|
||||
def _initialize(self) -> None:
|
||||
"""初始化数据源"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_paper_details(self, paper_id: str) -> PaperMetadata:
|
||||
"""获取论文详细信息"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
pass
|
||||
文件差异因一行或多行过长而隐藏
@@ -0,0 +1,400 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class CrossrefSource(DataSource):
|
||||
"""Crossref API实现"""
|
||||
|
||||
CONTACT_EMAILS = [
|
||||
"gpt_abc_academic@163.com",
|
||||
"gpt_abc_newapi@163.com",
|
||||
"gpt_abc_academic_pwd@163.com"
|
||||
]
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.crossref.org"
|
||||
# 随机选择一个邮箱
|
||||
contact_email = random.choice(self.CONTACT_EMAILS)
|
||||
self.headers = {
|
||||
"Accept": "application/json",
|
||||
"User-Agent": f"Mozilla/5.0 (compatible; PythonScript/1.0; mailto:{contact_email})",
|
||||
}
|
||||
if self.api_key:
|
||||
self.headers["Crossref-Plus-API-Token"] = f"Bearer {self.api_key}"
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序字段
|
||||
sort_order: 排序顺序
|
||||
start_year: 起始年份
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
# 请求更多的结果以补偿可能被过滤掉的文章
|
||||
adjusted_limit = min(limit * 3, 1000) # 设置上限以避免请求过多
|
||||
params = {
|
||||
"query": query,
|
||||
"rows": adjusted_limit,
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["filter"] = f"from-pub-date:{start_year}"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
params["sort"] = sort_by
|
||||
if sort_order:
|
||||
params["order"] = sort_order
|
||||
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
data = await response.json()
|
||||
items = data.get("message", {}).get("items", [])
|
||||
if not items:
|
||||
print(f"未找到相关论文")
|
||||
return []
|
||||
|
||||
# 过滤掉没有摘要的文章
|
||||
papers = []
|
||||
filtered_count = 0
|
||||
for work in items:
|
||||
paper = self._parse_work(work)
|
||||
if paper.abstract and paper.abstract.strip():
|
||||
papers.append(paper)
|
||||
if len(papers) >= limit: # 达到原始请求的限制后停止
|
||||
break
|
||||
else:
|
||||
filtered_count += 1
|
||||
|
||||
print(f"找到 {len(items)} 篇相关论文,其中 {filtered_count} 篇因缺少摘要被过滤")
|
||||
print(f"返回 {len(papers)} 篇包含摘要的论文")
|
||||
return papers
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={
|
||||
"select": (
|
||||
"DOI,title,author,published-print,abstract,reference,"
|
||||
"container-title,is-referenced-by-count,type,"
|
||||
"publisher,ISSN,ISBN,issue,volume,page"
|
||||
)
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取论文详情失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
return self._parse_work(data.get("message", {}))
|
||||
except Exception as e:
|
||||
print(f"解析论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/{doi}",
|
||||
params={"select": "reference"}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取参考文献失败: HTTP {response.status}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 确保我们正确处理返回的数据结构
|
||||
if not isinstance(data, dict):
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
|
||||
references = data.get("message", {}).get("reference", [])
|
||||
if not references:
|
||||
print(f"未找到参考文献")
|
||||
return []
|
||||
|
||||
return [
|
||||
PaperMetadata(
|
||||
title=ref.get("article-title", ""),
|
||||
authors=[ref.get("author", "")],
|
||||
year=ref.get("year"),
|
||||
doi=ref.get("DOI"),
|
||||
url=f"https://doi.org/{ref.get('DOI')}" if ref.get("DOI") else None,
|
||||
abstract="",
|
||||
citations=None,
|
||||
venue=ref.get("journal-title", ""),
|
||||
institutions=[]
|
||||
)
|
||||
for ref in references
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"解析参考文献数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params={
|
||||
"filter": f"reference.DOI:{doi}",
|
||||
"select": "DOI,title,author,published-print,abstract"
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
print(f"获取引用信息失败: HTTP {response.status}")
|
||||
print(f"响应内容: {await response.text()}")
|
||||
return []
|
||||
|
||||
try:
|
||||
data = await response.json()
|
||||
# 检查返回的数据结构
|
||||
if isinstance(data, dict):
|
||||
items = data.get("message", {}).get("items", [])
|
||||
return [self._parse_work(work) for work in items]
|
||||
else:
|
||||
print(f"API返回了意外的数据格式: {type(data)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"解析引用数据时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析Crossref返回的数据"""
|
||||
# 获取摘要 - 处理可能的不同格式
|
||||
abstract = ""
|
||||
if isinstance(work.get("abstract"), str):
|
||||
abstract = work.get("abstract", "")
|
||||
elif isinstance(work.get("abstract"), dict):
|
||||
abstract = work.get("abstract", {}).get("value", "")
|
||||
|
||||
if not abstract:
|
||||
print(f"警告: 论文 '{work.get('title', [''])[0]}' 没有可用的摘要")
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
for author in work.get("author", []):
|
||||
if "affiliation" in author:
|
||||
for affiliation in author["affiliation"]:
|
||||
if "name" in affiliation and affiliation["name"] not in institutions:
|
||||
institutions.append(affiliation["name"])
|
||||
|
||||
# 获取venue信息
|
||||
venue_name = work.get("container-title", [None])[0]
|
||||
venue_type = work.get("type", "unknown") # 文献类型
|
||||
venue_info = {
|
||||
"publisher": work.get("publisher"),
|
||||
"issn": work.get("ISSN", []),
|
||||
"isbn": work.get("ISBN", []),
|
||||
"issue": work.get("issue"),
|
||||
"volume": work.get("volume"),
|
||||
"page": work.get("page")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", [None])[0] or "",
|
||||
authors=[
|
||||
author.get("given", "") + " " + author.get("family", "")
|
||||
for author in work.get("author", [])
|
||||
],
|
||||
institutions=institutions, # 添加机构信息
|
||||
abstract=abstract,
|
||||
year=work.get("published-print", {}).get("date-parts", [[None]])[0][0],
|
||||
doi=work.get("DOI"),
|
||||
url=f"https://doi.org/{work.get('DOI')}" if work.get("DOI") else None,
|
||||
citations=work.get("is-referenced-by-count"),
|
||||
venue=venue_name,
|
||||
venue_type=venue_type, # 添加venue类型
|
||||
venue_name=venue_name, # 添加venue名称
|
||||
venue_info=venue_info, # 添加venue详细信息
|
||||
source='crossref' # 添加来源标记
|
||||
)
|
||||
|
||||
async def search_by_authors(
|
||||
self,
|
||||
authors: List[str],
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = " ".join([f"author:\"{author}\"" for author in authors])
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
start_year=start_year
|
||||
)
|
||||
|
||||
async def search_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
sort_order: str = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按日期范围搜索论文"""
|
||||
query = f"from-pub-date:{start_date.strftime('%Y-%m-%d')} until-pub-date:{end_date.strftime('%Y-%m-%d')}"
|
||||
return await self.search(
|
||||
query=query,
|
||||
limit=limit,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""CrossrefSource使用示例"""
|
||||
crossref = CrossrefSource(api_key=None)
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索,使用不同的排序方式
|
||||
print("\n=== 示例1:搜索最新的机器学习论文 ===")
|
||||
papers = await crossref.search(
|
||||
query="machine learning",
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc",
|
||||
start_year=2023
|
||||
)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
print(f"venue类型: {paper.venue_type}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:按DOI获取论文详情
|
||||
print("\n=== 示例2:获取特定论文详情 ===")
|
||||
# 使用BERT论文的DOI
|
||||
doi = "10.18653/v1/N19-1423"
|
||||
paper = await crossref.get_paper_details(doi)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:按作者搜索
|
||||
print("\n=== 示例3:搜索特定作者的论文 ===")
|
||||
author_papers = await crossref.search_by_authors(
|
||||
authors=["Yoshua Bengio"],
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
start_year=2020
|
||||
)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- {i}. {paper.title} ---")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例4:按日期范围搜索
|
||||
print("\n=== 示例4:搜索特定日期范围的论文 ===")
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=30) # 最近一个月
|
||||
recent_papers = await crossref.search_by_date_range(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=3,
|
||||
sort_by="published",
|
||||
sort_order="desc"
|
||||
)
|
||||
for i, paper in enumerate(recent_papers, 1):
|
||||
print(f"\n--- 最近发表的论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
|
||||
# 示例5:获取论文引用信息
|
||||
print("\n=== 示例5:获取论文引用信息 ===")
|
||||
if paper: # 使用之前获取的BERT论文
|
||||
print("\n获取引用该论文的文献:")
|
||||
citations = await crossref.get_citations(paper.doi)
|
||||
for i, citing_paper in enumerate(citations[:3], 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {citing_paper.title}")
|
||||
print(f"作者: {', '.join(citing_paper.authors)}")
|
||||
print(f"发表年份: {citing_paper.year}")
|
||||
|
||||
print("\n获取该论文引用的参考文献:")
|
||||
references = await crossref.get_references(paper.doi)
|
||||
for i, ref_paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {ref_paper.title}")
|
||||
print(f"作者: {', '.join(ref_paper.authors)}")
|
||||
print(f"发表年份: {ref_paper.year if ref_paper.year else '未知'}")
|
||||
|
||||
# 示例6:展示venue信息的使用
|
||||
print("\n=== 示例6:展示期刊/会议详细信息 ===")
|
||||
if papers:
|
||||
paper = papers[0]
|
||||
print(f"文献类型: {paper.venue_type}")
|
||||
print(f"发表venue: {paper.venue_name}")
|
||||
if paper.venue_info:
|
||||
print("Venue详细信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,449 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class ElsevierSource(DataSource):
|
||||
"""Elsevier (Scopus) API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Elsevier API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
# 添加更多必要的头部信息
|
||||
"X-ELS-Insttoken": "", # 如果有机构令牌
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
JSON响应
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
# 添加更详细的错误信息
|
||||
error_text = await response.text()
|
||||
print(f"请求失败: {response.status}")
|
||||
print(f"错误详情: {error_text}")
|
||||
if response.status == 401:
|
||||
print(f"使用的API密钥: {self.api_key}")
|
||||
# 尝试切换到另一个API密钥
|
||||
new_key = random.choice([k for k in self.API_KEYS if k != self.api_key])
|
||||
print(f"尝试切换到新的API密钥: {new_key}")
|
||||
self.api_key = new_key
|
||||
self.headers["X-ELS-APIKey"] = new_key
|
||||
# 重试请求
|
||||
return await self._make_request(url, params)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD",
|
||||
# 移除dc:description字段,因为它在STANDARD视图中不可用
|
||||
"field": "dc:title,dc:creator,prism:doi,prism:coverDate,citedby-count,prism:publicationName"
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by == "date":
|
||||
params["sort"] = "-coverDate"
|
||||
elif sort_by == "cited":
|
||||
params["sort"] = "-citedby-count"
|
||||
|
||||
# 发送搜索请求
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析搜索结果
|
||||
entries = response["search-results"].get("entry", [])
|
||||
papers = [paper for paper in (self._parse_entry(entry) for entry in entries) if paper is not None]
|
||||
|
||||
# 尝试为每篇论文获取摘要
|
||||
for paper in papers:
|
||||
if paper.doi:
|
||||
paper.abstract = await self.fetch_abstract(paper.doi) or ""
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_entry(self, entry: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析Scopus API返回的条目"""
|
||||
try:
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
creator = entry.get("dc:creator")
|
||||
if creator:
|
||||
authors = [creator]
|
||||
|
||||
# 获取发表年份
|
||||
year = None
|
||||
if "prism:coverDate" in entry:
|
||||
try:
|
||||
year = int(entry["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 简化venue信息
|
||||
venue_info = {
|
||||
'source_id': entry.get("source-id"),
|
||||
'issn': entry.get("prism:issn")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=entry.get("dc:title", ""),
|
||||
authors=authors,
|
||||
abstract=entry.get("dc:description", ""), # 从响应中获取摘要
|
||||
year=year,
|
||||
doi=entry.get("prism:doi"),
|
||||
url=entry.get("prism:url"),
|
||||
citations=int(entry.get("citedby-count", 0)),
|
||||
venue=entry.get("prism:publicationName"),
|
||||
institutions=[], # 移除机构信息
|
||||
venue_type="",
|
||||
venue_name=entry.get("prism:publicationName"),
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析条目时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献"""
|
||||
try:
|
||||
params = {
|
||||
"query": f"REF({doi})",
|
||||
"count": min(limit, 100),
|
||||
"view": "STANDARD"
|
||||
}
|
||||
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=params
|
||||
)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
entries = response["search-results"].get("entry", [])
|
||||
return [self._parse_entry(entry) for entry in entries]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献"""
|
||||
try:
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}/references",
|
||||
params={"view": "STANDARD"}
|
||||
)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
references = response["references"].get("reference", [])
|
||||
papers = [paper for paper in (self._parse_reference(ref) for ref in references) if paper is not None]
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_reference(self, ref: Dict) -> Optional[PaperMetadata]:
|
||||
"""解析参考文献数据"""
|
||||
try:
|
||||
authors = []
|
||||
if "author-list" in ref:
|
||||
author_list = ref["author-list"].get("author", [])
|
||||
if isinstance(author_list, list):
|
||||
authors = [f"{author.get('ce:given-name', '')} {author.get('ce:surname', '')}"
|
||||
for author in author_list]
|
||||
else:
|
||||
authors = [f"{author_list.get('ce:given-name', '')} {author_list.get('ce:surname', '')}"]
|
||||
|
||||
year = None
|
||||
if "prism:coverDate" in ref:
|
||||
try:
|
||||
year = int(ref["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
return PaperMetadata(
|
||||
title=ref.get("ce:title", ""),
|
||||
authors=authors,
|
||||
abstract="", # 参考文献通常不包含摘要
|
||||
year=year,
|
||||
doi=ref.get("prism:doi"),
|
||||
url=None,
|
||||
citations=None,
|
||||
venue=ref.get("prism:publicationName"),
|
||||
institutions=[],
|
||||
venue_type="unknown",
|
||||
venue_name=ref.get("prism:publicationName"),
|
||||
venue_info={}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析参考文献时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_affiliation(
|
||||
self,
|
||||
affiliation: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按机构搜索论文"""
|
||||
query = f"AF-ID({affiliation})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def search_by_venue(
|
||||
self,
|
||||
venue: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊/会议搜索论文"""
|
||||
query = f"SRCTITLE({venue})"
|
||||
return await self.search(query, limit=limit, start_year=start_year)
|
||||
|
||||
async def test_api_access(self):
|
||||
"""测试API访问权限"""
|
||||
print(f"\n测试API密钥: {self.api_key}")
|
||||
|
||||
# 测试1: 基础搜索
|
||||
basic_params = {
|
||||
"query": "test",
|
||||
"count": 1,
|
||||
"view": "STANDARD"
|
||||
}
|
||||
print("\n1. 测试基础搜索...")
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/search/scopus",
|
||||
params=basic_params
|
||||
)
|
||||
if response:
|
||||
print("基础搜索成功")
|
||||
print("可用字段:", list(response.get("search-results", {}).get("entry", [{}])[0].keys()))
|
||||
|
||||
# 测试2: 测试单篇文章访问
|
||||
print("\n2. 测试文章详情访问...")
|
||||
test_doi = "10.1016/j.artint.2021.103535" # 一个示例DOI
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{test_doi}",
|
||||
params={"view": "STANDARD"} # 改为STANDARD视图
|
||||
)
|
||||
if response:
|
||||
print("文章详情访问成功")
|
||||
else:
|
||||
print("文章详情访问失败")
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详细信息
|
||||
|
||||
注意:当前API权限不支持获取详细信息,返回None
|
||||
|
||||
Args:
|
||||
paper_id: 论文ID
|
||||
|
||||
Returns:
|
||||
None,因为当前API权限不支持此功能
|
||||
"""
|
||||
return None
|
||||
|
||||
async def fetch_abstract(self, doi: str) -> Optional[str]:
|
||||
"""获取论文摘要
|
||||
|
||||
使用Scopus Abstract API获取论文摘要
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
|
||||
Returns:
|
||||
摘要文本,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 使用Abstract API而不是Search API
|
||||
response = await self._make_request(
|
||||
f"{self.base_url}/abstract/doi/{doi}",
|
||||
params={
|
||||
"view": "FULL" # 使用FULL视图
|
||||
}
|
||||
)
|
||||
|
||||
if response and "abstracts-retrieval-response" in response:
|
||||
# 从coredata中获取摘要
|
||||
coredata = response["abstracts-retrieval-response"].get("coredata", {})
|
||||
return coredata.get("dc:description", "")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取摘要时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def example_usage():
|
||||
"""ElsevierSource使用示例"""
|
||||
elsevier = ElsevierSource()
|
||||
|
||||
try:
|
||||
# 首先测试API访问权限
|
||||
print("\n=== 测试API访问权限 ===")
|
||||
await elsevier.test_api_access()
|
||||
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await elsevier.search("machine learning", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
print("期刊信息:")
|
||||
for key, value in paper.venue_info.items():
|
||||
if value: # 只打印非空值
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
# 示例2:获取引用信息
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例2:获取引用该论文的文献 ===")
|
||||
citations = await elsevier.get_citations(papers[0].doi, limit=3)
|
||||
for i, paper in enumerate(citations, 1):
|
||||
print(f"\n--- 引用论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例3:获取参考文献
|
||||
if papers and papers[0].doi:
|
||||
print("\n=== 示例3:获取论文的参考文献 ===")
|
||||
references = await elsevier.get_references(papers[0].doi)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await elsevier.search_by_author("Hinton G", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例5:按机构搜索
|
||||
print("\n=== 示例5:按机构搜索 ===")
|
||||
affiliation_papers = await elsevier.search_by_affiliation("60027950", limit=3) # MIT的机构ID
|
||||
for i, paper in enumerate(affiliation_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"期刊/会议: {paper.venue}")
|
||||
|
||||
# 示例6:获取论文摘要
|
||||
print("\n=== 示例6:获取论文摘要 ===")
|
||||
test_doi = "10.1016/j.artint.2021.103535"
|
||||
abstract = await elsevier.fetch_abstract(test_doi)
|
||||
if abstract:
|
||||
print(f"摘要: {abstract[:200]}...") # 只显示前200个字符
|
||||
else:
|
||||
print("无法获取摘要")
|
||||
|
||||
# 在搜索结果中显示摘要
|
||||
print("\n=== 示例7:搜索结果中的摘要 ===")
|
||||
papers = await elsevier.search("machine learning", limit=1)
|
||||
for paper in papers:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"摘要: {paper.abstract[:200]}..." if paper.abstract else "摘要: 无")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,698 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
|
||||
class GitHubSource:
|
||||
"""GitHub API实现"""
|
||||
|
||||
# 默认API密钥列表 - 可以放置多个GitHub令牌
|
||||
API_KEYS = [
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"github_pat_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
]
|
||||
|
||||
def __init__(self, api_key: Optional[Union[str, List[str]]] = None):
|
||||
"""初始化GitHub API客户端
|
||||
|
||||
Args:
|
||||
api_key: GitHub个人访问令牌或令牌列表
|
||||
"""
|
||||
if api_key is None:
|
||||
self.api_keys = self.API_KEYS
|
||||
elif isinstance(api_key, str):
|
||||
self.api_keys = [api_key]
|
||||
else:
|
||||
self.api_keys = api_key
|
||||
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化客户端,设置默认参数"""
|
||||
self.base_url = "https://api.github.com"
|
||||
self.headers = {
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": "2022-11-28",
|
||||
"User-Agent": "GitHub-API-Python-Client"
|
||||
}
|
||||
|
||||
# 如果有可用的API密钥,随机选择一个
|
||||
if self.api_keys:
|
||||
selected_key = random.choice(self.api_keys)
|
||||
self.headers["Authorization"] = f"Bearer {selected_key}"
|
||||
print(f"已随机选择API密钥进行认证")
|
||||
else:
|
||||
print("警告: 未提供API密钥,将受到GitHub API请求限制")
|
||||
|
||||
async def _request(self, method: str, endpoint: str, params: Dict = None, data: Dict = None) -> Any:
|
||||
"""发送API请求
|
||||
|
||||
Args:
|
||||
method: HTTP方法 (GET, POST, PUT, DELETE等)
|
||||
endpoint: API端点
|
||||
params: URL参数
|
||||
data: 请求体数据
|
||||
|
||||
Returns:
|
||||
解析后的响应JSON
|
||||
"""
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
|
||||
# 为调试目的打印请求信息
|
||||
print(f"请求: {method} {url}")
|
||||
if params:
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 发送请求
|
||||
request_kwargs = {}
|
||||
if params:
|
||||
request_kwargs["params"] = params
|
||||
if data:
|
||||
request_kwargs["json"] = data
|
||||
|
||||
async with session.request(method, url, **request_kwargs) as response:
|
||||
response_text = await response.text()
|
||||
|
||||
# 检查HTTP状态码
|
||||
if response.status >= 400:
|
||||
print(f"API请求失败: HTTP {response.status}")
|
||||
print(f"响应内容: {response_text}")
|
||||
return None
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except json.JSONDecodeError:
|
||||
print(f"JSON解析错误: {response_text}")
|
||||
return None
|
||||
|
||||
# ===== 用户相关方法 =====
|
||||
|
||||
async def get_user(self, username: Optional[str] = None) -> Dict:
|
||||
"""获取用户信息
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
|
||||
Returns:
|
||||
用户信息字典
|
||||
"""
|
||||
endpoint = "/user" if username is None else f"/users/{username}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_user_repos(self, username: Optional[str] = None, sort: str = "updated",
|
||||
direction: str = "desc", per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户的仓库列表
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = "/user/repos" if username is None else f"/users/{username}/repos"
|
||||
params = {
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_user_starred(self, username: Optional[str] = None,
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取用户星标的仓库
|
||||
|
||||
Args:
|
||||
username: 指定用户名,不指定则获取当前授权用户
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
星标仓库列表
|
||||
"""
|
||||
endpoint = "/user/starred" if username is None else f"/users/{username}/starred"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 仓库相关方法 =====
|
||||
|
||||
async def get_repo(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库信息
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
仓库信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repo_branches(self, owner: str, repo: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的分支列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
分支列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/branches"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_repo_commits(self, owner: str, repo: str, sha: Optional[str] = None,
|
||||
path: Optional[str] = None, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的提交历史
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
sha: 特定提交SHA或分支名
|
||||
path: 文件路径筛选
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
提交列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
if sha:
|
||||
params["sha"] = sha
|
||||
if path:
|
||||
params["path"] = path
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_commit_details(self, owner: str, repo: str, commit_sha: str) -> Dict:
|
||||
"""获取特定提交的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
commit_sha: 提交SHA
|
||||
|
||||
Returns:
|
||||
提交详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/commits/{commit_sha}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 内容相关方法 =====
|
||||
|
||||
async def get_file_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> Dict:
|
||||
"""获取文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 文件路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
文件内容信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
response = await self._request("GET", endpoint, params=params)
|
||||
if response and isinstance(response, dict) and "content" in response:
|
||||
try:
|
||||
# 解码Base64编码的文件内容
|
||||
content = base64.b64decode(response["content"].encode()).decode()
|
||||
response["decoded_content"] = content
|
||||
except Exception as e:
|
||||
print(f"解码文件内容时出错: {str(e)}")
|
||||
|
||||
return response
|
||||
|
||||
async def get_directory_content(self, owner: str, repo: str, path: str, ref: Optional[str] = None) -> List[Dict]:
|
||||
"""获取目录内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
path: 目录路径
|
||||
ref: 分支名、标签名或提交SHA
|
||||
|
||||
Returns:
|
||||
目录内容列表
|
||||
"""
|
||||
# 注意:此方法与get_file_content使用相同的端点,但对于目录会返回列表
|
||||
endpoint = f"/repos/{owner}/{repo}/contents/{path}"
|
||||
params = {}
|
||||
if ref:
|
||||
params["ref"] = ref
|
||||
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== Issues相关方法 =====
|
||||
|
||||
async def get_issues(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Issues列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: Issue状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Issues列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_issue(self, owner: str, repo: str, issue_number: int) -> Dict:
|
||||
"""获取特定Issue的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
Issue详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_issue_comments(self, owner: str, repo: str, issue_number: int) -> List[Dict]:
|
||||
"""获取Issue的评论
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
issue_number: Issue编号
|
||||
|
||||
Returns:
|
||||
评论列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/issues/{issue_number}/comments"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== Pull Requests相关方法 =====
|
||||
|
||||
async def get_pull_requests(self, owner: str, repo: str, state: str = "open",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取仓库的Pull Request列表
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
state: PR状态 (open, closed, all)
|
||||
sort: 排序方式 (created, updated, popularity, long-running)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
Pull Request列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls"
|
||||
params = {
|
||||
"state": state,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_pull_request(self, owner: str, repo: str, pr_number: int) -> Dict:
|
||||
"""获取特定Pull Request的详情
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
Pull Request详情
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_pull_request_files(self, owner: str, repo: str, pr_number: int) -> List[Dict]:
|
||||
"""获取Pull Request中修改的文件
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
pr_number: Pull Request编号
|
||||
|
||||
Returns:
|
||||
修改文件列表
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/files"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
# ===== 搜索相关方法 =====
|
||||
|
||||
async def search_repositories(self, query: str, sort: str = "stars",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索仓库
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (stars, forks, updated)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/repositories"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_code(self, query: str, sort: str = "indexed",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索代码
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (indexed)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/code"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_issues(self, query: str, sort: str = "created",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索Issues和Pull Requests
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (created, updated, comments)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/issues"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def search_users(self, query: str, sort: str = "followers",
|
||||
order: str = "desc", per_page: int = 30, page: int = 1) -> Dict:
|
||||
"""搜索用户
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
sort: 排序方式 (followers, repositories, joined)
|
||||
order: 排序顺序 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
搜索结果
|
||||
"""
|
||||
endpoint = "/search/users"
|
||||
params = {
|
||||
"q": query,
|
||||
"sort": sort,
|
||||
"order": order,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 组织相关方法 =====
|
||||
|
||||
async def get_organization(self, org: str) -> Dict:
|
||||
"""获取组织信息
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
|
||||
Returns:
|
||||
组织信息
|
||||
"""
|
||||
endpoint = f"/orgs/{org}"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_organization_repos(self, org: str, type: str = "all",
|
||||
sort: str = "created", direction: str = "desc",
|
||||
per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织的仓库列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
type: 仓库类型 (all, public, private, forks, sources, member, internal)
|
||||
sort: 排序方式 (created, updated, pushed, full_name)
|
||||
direction: 排序方向 (asc, desc)
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
仓库列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/repos"
|
||||
params = {
|
||||
"type": type,
|
||||
"sort": sort,
|
||||
"direction": direction,
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
async def get_organization_members(self, org: str, per_page: int = 30, page: int = 1) -> List[Dict]:
|
||||
"""获取组织成员列表
|
||||
|
||||
Args:
|
||||
org: 组织名称
|
||||
per_page: 每页结果数量
|
||||
page: 页码
|
||||
|
||||
Returns:
|
||||
成员列表
|
||||
"""
|
||||
endpoint = f"/orgs/{org}/members"
|
||||
params = {
|
||||
"per_page": per_page,
|
||||
"page": page
|
||||
}
|
||||
return await self._request("GET", endpoint, params=params)
|
||||
|
||||
# ===== 更复杂的操作 =====
|
||||
|
||||
async def get_repository_languages(self, owner: str, repo: str) -> Dict:
|
||||
"""获取仓库使用的编程语言及其比例
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
语言使用情况
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/languages"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_contributors(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的贡献者统计
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
贡献者统计信息
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/contributors"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def get_repository_stats_commit_activity(self, owner: str, repo: str) -> List[Dict]:
|
||||
"""获取仓库的提交活动
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名
|
||||
|
||||
Returns:
|
||||
提交活动统计
|
||||
"""
|
||||
endpoint = f"/repos/{owner}/{repo}/stats/commit_activity"
|
||||
return await self._request("GET", endpoint)
|
||||
|
||||
async def example_usage():
|
||||
"""GitHubSource使用示例"""
|
||||
# 创建客户端实例(可选传入API令牌)
|
||||
# github = GitHubSource(api_key="your_github_token")
|
||||
github = GitHubSource()
|
||||
|
||||
try:
|
||||
# 示例1:搜索热门Python仓库
|
||||
print("\n=== 示例1:搜索热门Python仓库 ===")
|
||||
repos = await github.search_repositories(
|
||||
query="language:python stars:>1000",
|
||||
sort="stars",
|
||||
order="desc",
|
||||
per_page=5
|
||||
)
|
||||
|
||||
if repos and "items" in repos:
|
||||
for i, repo in enumerate(repos["items"], 1):
|
||||
print(f"\n--- 仓库 {i} ---")
|
||||
print(f"名称: {repo['full_name']}")
|
||||
print(f"描述: {repo['description']}")
|
||||
print(f"星标数: {repo['stargazers_count']}")
|
||||
print(f"Fork数: {repo['forks_count']}")
|
||||
print(f"最近更新: {repo['updated_at']}")
|
||||
print(f"URL: {repo['html_url']}")
|
||||
|
||||
# 示例2:获取特定仓库的详情
|
||||
print("\n=== 示例2:获取特定仓库的详情 ===")
|
||||
repo_details = await github.get_repo("microsoft", "vscode")
|
||||
if repo_details:
|
||||
print(f"名称: {repo_details['full_name']}")
|
||||
print(f"描述: {repo_details['description']}")
|
||||
print(f"星标数: {repo_details['stargazers_count']}")
|
||||
print(f"Fork数: {repo_details['forks_count']}")
|
||||
print(f"默认分支: {repo_details['default_branch']}")
|
||||
print(f"开源许可: {repo_details.get('license', {}).get('name', '无')}")
|
||||
print(f"语言: {repo_details['language']}")
|
||||
print(f"Open Issues数: {repo_details['open_issues_count']}")
|
||||
|
||||
# 示例3:获取仓库的提交历史
|
||||
print("\n=== 示例3:获取仓库的最近提交 ===")
|
||||
commits = await github.get_repo_commits("tensorflow", "tensorflow", per_page=5)
|
||||
if commits:
|
||||
for i, commit in enumerate(commits, 1):
|
||||
print(f"\n--- 提交 {i} ---")
|
||||
print(f"SHA: {commit['sha'][:7]}")
|
||||
print(f"作者: {commit['commit']['author']['name']}")
|
||||
print(f"日期: {commit['commit']['author']['date']}")
|
||||
print(f"消息: {commit['commit']['message'].splitlines()[0]}")
|
||||
|
||||
# 示例4:搜索代码
|
||||
print("\n=== 示例4:搜索代码 ===")
|
||||
code_results = await github.search_code(
|
||||
query="filename:README.md language:markdown pytorch in:file",
|
||||
per_page=3
|
||||
)
|
||||
if code_results and "items" in code_results:
|
||||
print(f"共找到: {code_results['total_count']} 个结果")
|
||||
for i, item in enumerate(code_results["items"], 1):
|
||||
print(f"\n--- 代码 {i} ---")
|
||||
print(f"仓库: {item['repository']['full_name']}")
|
||||
print(f"文件: {item['path']}")
|
||||
print(f"URL: {item['html_url']}")
|
||||
|
||||
# 示例5:获取文件内容
|
||||
print("\n=== 示例5:获取文件内容 ===")
|
||||
file_content = await github.get_file_content("python", "cpython", "README.rst")
|
||||
if file_content and "decoded_content" in file_content:
|
||||
content = file_content["decoded_content"]
|
||||
print(f"文件名: {file_content['name']}")
|
||||
print(f"大小: {file_content['size']} 字节")
|
||||
print(f"内容预览: {content[:200]}...")
|
||||
|
||||
# 示例6:获取仓库使用的编程语言
|
||||
print("\n=== 示例6:获取仓库使用的编程语言 ===")
|
||||
languages = await github.get_repository_languages("facebook", "react")
|
||||
if languages:
|
||||
print(f"React仓库使用的编程语言:")
|
||||
for lang, bytes_of_code in languages.items():
|
||||
print(f"- {lang}: {bytes_of_code} 字节")
|
||||
|
||||
# 示例7:获取组织信息
|
||||
print("\n=== 示例7:获取组织信息 ===")
|
||||
org_info = await github.get_organization("google")
|
||||
if org_info:
|
||||
print(f"名称: {org_info['name']}")
|
||||
print(f"描述: {org_info.get('description', '无')}")
|
||||
print(f"位置: {org_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {org_info['public_repos']}")
|
||||
print(f"成员数: {org_info.get('public_members', 0)}")
|
||||
print(f"URL: {org_info['html_url']}")
|
||||
|
||||
# 示例8:获取用户信息
|
||||
print("\n=== 示例8:获取用户信息 ===")
|
||||
user_info = await github.get_user("torvalds")
|
||||
if user_info:
|
||||
print(f"名称: {user_info['name']}")
|
||||
print(f"公司: {user_info.get('company', '无')}")
|
||||
print(f"博客: {user_info.get('blog', '无')}")
|
||||
print(f"位置: {user_info.get('location', '未指定')}")
|
||||
print(f"公共仓库数: {user_info['public_repos']}")
|
||||
print(f"关注者数: {user_info['followers']}")
|
||||
print(f"URL: {user_info['html_url']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,142 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
class JournalMetrics:
|
||||
"""期刊指标管理类"""
|
||||
|
||||
def __init__(self):
|
||||
self.journal_data: Dict = {} # 期刊名称到指标的映射
|
||||
self.issn_map: Dict = {} # ISSN到指标的映射
|
||||
self.name_map: Dict = {} # 标准化名称到指标的映射
|
||||
self._load_journal_data()
|
||||
|
||||
def _normalize_journal_name(self, name: str) -> str:
|
||||
"""标准化期刊名称
|
||||
|
||||
Args:
|
||||
name: 原始期刊名称
|
||||
|
||||
Returns:
|
||||
标准化后的期刊名称
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
|
||||
# 转换为小写
|
||||
name = name.lower()
|
||||
|
||||
# 移除常见的前缀和后缀
|
||||
prefixes = ['the ', 'proceedings of ', 'journal of ']
|
||||
suffixes = [' journal', ' proceedings', ' magazine', ' review', ' letters']
|
||||
|
||||
for prefix in prefixes:
|
||||
if name.startswith(prefix):
|
||||
name = name[len(prefix):]
|
||||
|
||||
for suffix in suffixes:
|
||||
if name.endswith(suffix):
|
||||
name = name[:-len(suffix)]
|
||||
|
||||
# 移除特殊字符,保留字母、数字和空格
|
||||
name = ''.join(c for c in name if c.isalnum() or c.isspace())
|
||||
|
||||
# 移除多余的空格
|
||||
name = ' '.join(name.split())
|
||||
|
||||
return name
|
||||
|
||||
def _convert_if_value(self, if_str: str) -> Optional[float]:
|
||||
"""转换IF值为float,处理特殊情况"""
|
||||
try:
|
||||
if if_str.startswith('<'):
|
||||
# 对于<0.1这样的值,返回0.1
|
||||
return float(if_str.strip('<'))
|
||||
return float(if_str)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
def _load_journal_data(self):
|
||||
"""加载期刊数据"""
|
||||
try:
|
||||
file_path = os.path.join(os.path.dirname(__file__), 'cas_if.json')
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 建立期刊名称到指标的映射
|
||||
for journal in data:
|
||||
# 准备指标数据
|
||||
metrics = {
|
||||
'if_factor': self._convert_if_value(journal.get('IF')),
|
||||
'jcr_division': journal.get('Q'),
|
||||
'cas_division': journal.get('B')
|
||||
}
|
||||
|
||||
# 存储期刊名称映射(使用标准化名称)
|
||||
if journal.get('journal'):
|
||||
normalized_name = self._normalize_journal_name(journal['journal'])
|
||||
self.journal_data[normalized_name] = metrics
|
||||
self.name_map[normalized_name] = metrics
|
||||
|
||||
# 存储期刊缩写映射
|
||||
if journal.get('jabb'):
|
||||
normalized_abbr = self._normalize_journal_name(journal['jabb'])
|
||||
self.journal_data[normalized_abbr] = metrics
|
||||
self.name_map[normalized_abbr] = metrics
|
||||
|
||||
# 存储ISSN映射
|
||||
if journal.get('issn'):
|
||||
self.issn_map[journal['issn']] = metrics
|
||||
if journal.get('eissn'):
|
||||
self.issn_map[journal['eissn']] = metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"加载期刊数据时出错: {str(e)}")
|
||||
self.journal_data = {}
|
||||
self.issn_map = {}
|
||||
self.name_map = {}
|
||||
|
||||
def get_journal_metrics(self, venue_name: str, venue_info: dict) -> dict:
|
||||
"""获取期刊指标
|
||||
|
||||
Args:
|
||||
venue_name: 期刊名称
|
||||
venue_info: 期刊详细信息
|
||||
|
||||
Returns:
|
||||
包含期刊指标的字典
|
||||
"""
|
||||
try:
|
||||
metrics = {}
|
||||
|
||||
# 1. 首先尝试通过ISSN匹配
|
||||
if venue_info and 'issn' in venue_info:
|
||||
issn_value = venue_info['issn']
|
||||
# 处理ISSN可能是列表的情况
|
||||
if isinstance(issn_value, list):
|
||||
# 尝试每个ISSN
|
||||
for issn in issn_value:
|
||||
metrics = self.issn_map.get(issn, {})
|
||||
if metrics: # 如果找到匹配的指标,就停止搜索
|
||||
break
|
||||
else: # ISSN是字符串的情况
|
||||
metrics = self.issn_map.get(issn_value, {})
|
||||
|
||||
# 2. 如果ISSN匹配失败,尝试通过期刊名称匹配
|
||||
if not metrics and venue_name:
|
||||
# 标准化期刊名称
|
||||
normalized_name = self._normalize_journal_name(venue_name)
|
||||
metrics = self.name_map.get(normalized_name, {})
|
||||
|
||||
# 如果完全匹配失败,尝试部分匹配
|
||||
# if not metrics:
|
||||
# for db_name, db_metrics in self.name_map.items():
|
||||
# if normalized_name in db_name:
|
||||
# metrics = db_metrics
|
||||
# break
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取期刊指标时出错: {str(e)}")
|
||||
return {}
|
||||
@@ -0,0 +1,163 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
class OpenAlexSource(DataSource):
|
||||
"""OpenAlex API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.openalex.org"
|
||||
self.mailto = "xxxxxxxxxxxxxxxxxxxxxxxx@163.com" # 直接写入邮件地址
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"title.search:{query}",
|
||||
"per-page": limit
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
results = data.get("results", [])
|
||||
return [self._parse_work(work) for work in results]
|
||||
except Exception as e:
|
||||
print(f"搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_work(self, work: Dict) -> PaperMetadata:
|
||||
"""解析OpenAlex返回的数据"""
|
||||
# 获取作者信息
|
||||
raw_author_names = [
|
||||
authorship.get("raw_author_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
if authorship
|
||||
]
|
||||
# 处理作者名字格式
|
||||
authors = [
|
||||
self._reformat_name(author)
|
||||
for author in raw_author_names
|
||||
]
|
||||
|
||||
# 获取机构信息
|
||||
institutions = [
|
||||
inst.get("display_name", "")
|
||||
for authorship in work.get("authorships", [])
|
||||
for inst in authorship.get("institutions", [])
|
||||
if inst
|
||||
]
|
||||
|
||||
# 获取主要发表位置信息
|
||||
primary_location = work.get("primary_location") or {}
|
||||
source = primary_location.get("source") or {}
|
||||
venue = source.get("display_name")
|
||||
|
||||
# 获取发表日期
|
||||
year = work.get("publication_year")
|
||||
|
||||
return PaperMetadata(
|
||||
title=work.get("title", ""),
|
||||
authors=authors,
|
||||
institutions=institutions,
|
||||
abstract=work.get("abstract", ""),
|
||||
year=year,
|
||||
doi=work.get("doi"),
|
||||
url=work.get("doi"), # OpenAlex 使用 DOI 作为 URL
|
||||
citations=work.get("cited_by_count"),
|
||||
venue=venue
|
||||
)
|
||||
|
||||
def _reformat_name(self, name: str) -> str:
|
||||
"""重新格式化作者名字"""
|
||||
if "," not in name:
|
||||
return name
|
||||
family, given_names = (x.strip() for x in name.split(",", maxsplit=1))
|
||||
return f"{given_names} {family}"
|
||||
|
||||
async def get_paper_details(self, doi: str) -> PaperMetadata:
|
||||
"""获取指定DOI的论文详情"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return self._parse_work(data)
|
||||
|
||||
async def get_references(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works/https://doi.org/{quote(doi, safe='')}/references",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def get_citations(self, doi: str) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
params = {"mailto": self.mailto} if self.mailto else {}
|
||||
params.update({
|
||||
"filter": f"cites:doi:{doi}",
|
||||
"per-page": 100
|
||||
})
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/works",
|
||||
params=params
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_work(work) for work in data.get("results", [])]
|
||||
|
||||
async def example_usage():
|
||||
"""OpenAlexSource使用示例"""
|
||||
# 初始化OpenAlexSource
|
||||
openalex = OpenAlexSource()
|
||||
|
||||
try:
|
||||
print("正在搜索论文...")
|
||||
# 搜索与"artificial intelligence"相关的论文,限制返回5篇
|
||||
papers = await openalex.search(query="artificial intelligence", limit=5)
|
||||
|
||||
if not papers:
|
||||
print("未获取到任何论文信息")
|
||||
return
|
||||
|
||||
print(f"找到 {len(papers)} 篇论文")
|
||||
|
||||
# 打印搜索结果
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors) if paper.authors else '未知'}")
|
||||
if paper.institutions:
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
print(f"发表年份: {paper.year if paper.year else '未知'}")
|
||||
print(f"DOI: {paper.doi if paper.doi else '未知'}")
|
||||
print(f"URL: {paper.url if paper.url else '未知'}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
print(f"引用次数: {paper.citations if paper.citations is not None else '未知'}")
|
||||
print(f"发表venue: {paper.venue if paper.venue else '未知'}")
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
# 如果直接运行此文件,执行示例代码
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
# 运行示例
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,458 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import quote
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
||||
class PubMedSource(DataSource):
|
||||
"""PubMed API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: PubMed API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS) # 随机选择一个API密钥
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
|
||||
self.headers = {
|
||||
"User-Agent": "Mozilla/5.0 PubMedDataSource/1.0",
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str) -> Optional[str]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
|
||||
Returns:
|
||||
响应内容
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
return await response.text()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = "relevance",
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
query = f"{query} AND {start_year}:3000[dp]"
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = (
|
||||
f"{self.base_url}/esearch.fcgi?"
|
||||
f"db=pubmed&term={quote(query)}&retmax={limit}"
|
||||
f"&usehistory=y&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
if sort_by == "date":
|
||||
search_url += "&sort=date"
|
||||
|
||||
# 获取搜索结果
|
||||
response = await self._make_request(search_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
id_list = root.findall(".//Id")
|
||||
pmids = [id_elem.text for id_elem in id_list]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 批量获取论文详情
|
||||
papers = []
|
||||
batch_size = 50
|
||||
for i in range(0, len(pmids), batch_size):
|
||||
batch = pmids[i:i + batch_size]
|
||||
batch_papers = await self._fetch_papers_batch(batch)
|
||||
papers.extend(batch_papers)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _fetch_papers_batch(self, pmids: List[str]) -> List[PaperMetadata]:
|
||||
"""批量获取论文详情
|
||||
|
||||
Args:
|
||||
pmids: PubMed ID列表
|
||||
|
||||
Returns:
|
||||
论文详情列表
|
||||
"""
|
||||
try:
|
||||
# 构建批量获取URL
|
||||
fetch_url = (
|
||||
f"{self.base_url}/efetch.fcgi?"
|
||||
f"db=pubmed&id={','.join(pmids)}"
|
||||
f"&retmode=xml&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(fetch_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
articles = root.findall(".//PubmedArticle")
|
||||
|
||||
return [self._parse_article(article) for article in articles]
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文批次时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_article(self, article: ET.Element) -> PaperMetadata:
|
||||
"""解析PubMed文章XML
|
||||
|
||||
Args:
|
||||
article: XML元素
|
||||
|
||||
Returns:
|
||||
解析后的论文数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
pmid = article.find(".//PMID").text
|
||||
article_meta = article.find(".//Article")
|
||||
|
||||
# 获取标题
|
||||
title = article_meta.find(".//ArticleTitle")
|
||||
title = title.text if title is not None else ""
|
||||
|
||||
# 获取作者列表
|
||||
authors = []
|
||||
author_list = article_meta.findall(".//Author")
|
||||
for author in author_list:
|
||||
last_name = author.find("LastName")
|
||||
fore_name = author.find("ForeName")
|
||||
if last_name is not None and fore_name is not None:
|
||||
authors.append(f"{fore_name.text} {last_name.text}")
|
||||
elif last_name is not None:
|
||||
authors.append(last_name.text)
|
||||
|
||||
# 获取摘要
|
||||
abstract = article_meta.find(".//Abstract/AbstractText")
|
||||
abstract = abstract.text if abstract is not None else ""
|
||||
|
||||
# 获取发表年份
|
||||
pub_date = article_meta.find(".//PubDate/Year")
|
||||
year = int(pub_date.text) if pub_date is not None else None
|
||||
|
||||
# 获取DOI
|
||||
doi = article.find(".//ELocationID[@EIdType='doi']")
|
||||
doi = doi.text if doi is not None else None
|
||||
|
||||
# 获取期刊信息
|
||||
journal = article_meta.find(".//Journal")
|
||||
if journal is not None:
|
||||
journal_title = journal.find(".//Title")
|
||||
venue = journal_title.text if journal_title is not None else None
|
||||
|
||||
# 获取期刊详细信息
|
||||
venue_info = {
|
||||
'issn': journal.findtext(".//ISSN"),
|
||||
'volume': journal.findtext(".//Volume"),
|
||||
'issue': journal.findtext(".//Issue"),
|
||||
'pub_date': journal.findtext(".//PubDate/MedlineDate") or
|
||||
f"{journal.findtext('.//PubDate/Year', '')}-{journal.findtext('.//PubDate/Month', '')}"
|
||||
}
|
||||
else:
|
||||
venue = None
|
||||
venue_info = {}
|
||||
|
||||
# 获取机构信息
|
||||
institutions = []
|
||||
affiliations = article_meta.findall(".//Affiliation")
|
||||
for affiliation in affiliations:
|
||||
if affiliation is not None and affiliation.text:
|
||||
institutions.append(affiliation.text)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/" if pmid else None,
|
||||
citations=None, # PubMed API不直接提供引用数据
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info,
|
||||
source='pubmed' # 添加来源标记
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析文章时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_paper_details(self, pmid: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定PMID的论文详情"""
|
||||
papers = await self._fetch_papers_batch([pmid])
|
||||
return papers[0] if papers else None
|
||||
|
||||
async def get_related_papers(self, pmid: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取相关论文
|
||||
|
||||
使用PubMed的相关文章功能
|
||||
|
||||
Args:
|
||||
pmid: PubMed ID
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
相关论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建相关文章URL
|
||||
link_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"db=pubmed&id={pmid}&cmd=neighbor&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(link_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
related_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in related_ids][:limit]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取相关论文详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取相关论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"{author}[Author]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"{journal}[Journal]"
|
||||
if start_year:
|
||||
query += f" AND {start_year}:3000[dp]"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文
|
||||
|
||||
Args:
|
||||
days: 最近几天的论文
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
最新论文列表
|
||||
"""
|
||||
query = f"last {days} days[dp]"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
注意:PubMed API本身不提供引用数据,此方法将返回空列表
|
||||
未来可以考虑集成其他数据源(如CrossRef)来获取引用信息
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
空列表,因为PubMed不提供引用数据
|
||||
"""
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
从PubMed文章的参考文献列表获取引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: PubMed ID
|
||||
|
||||
Returns:
|
||||
引用的文献列表
|
||||
"""
|
||||
try:
|
||||
# 构建获取参考文献的URL
|
||||
refs_url = (
|
||||
f"{self.base_url}/elink.fcgi?"
|
||||
f"dbfrom=pubmed&db=pubmed&id={paper_id}"
|
||||
f"&cmd=neighbor_history&linkname=pubmed_pubmed_refs"
|
||||
f"&api_key={self.api_key}"
|
||||
)
|
||||
|
||||
response = await self._make_request(refs_url)
|
||||
if not response:
|
||||
return []
|
||||
|
||||
# 解析XML响应
|
||||
root = ET.fromstring(response)
|
||||
ref_ids = root.findall(".//Link/Id")
|
||||
pmids = [id_elem.text for id_elem in ref_ids]
|
||||
|
||||
if not pmids:
|
||||
return []
|
||||
|
||||
# 获取参考文献详情
|
||||
return await self._fetch_papers_batch(pmids)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def example_usage():
|
||||
"""PubMedSource使用示例"""
|
||||
pubmed = PubMedSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索COVID-19相关论文 ===")
|
||||
papers = await pubmed.search("COVID-19", limit=3)
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要: {paper.abstract[:200]}...")
|
||||
|
||||
# 示例2:获取论文详情
|
||||
if papers:
|
||||
print("\n=== 示例2:获取论文详情 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
paper = await pubmed.get_paper_details(paper_id)
|
||||
if paper:
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"期刊: {paper.venue}")
|
||||
print(f"机构: {', '.join(paper.institutions)}")
|
||||
|
||||
# 示例3:获取相关论文
|
||||
if papers:
|
||||
print("\n=== 示例3:获取相关论文 ===")
|
||||
related = await pubmed.get_related_papers(paper_id, limit=3)
|
||||
for i, paper in enumerate(related, 1):
|
||||
print(f"\n--- 相关论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例4:按作者搜索
|
||||
print("\n=== 示例4:按作者搜索 ===")
|
||||
author_papers = await pubmed.search_by_author("Fauci AS", limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例5:按期刊搜索
|
||||
print("\n=== 示例5:按期刊搜索 ===")
|
||||
journal_papers = await pubmed.search_by_journal("Nature", limit=3)
|
||||
for i, paper in enumerate(journal_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例6:获取最新论文
|
||||
print("\n=== 示例6:获取最新论文 ===")
|
||||
latest = await pubmed.get_latest_papers(days=7, limit=3)
|
||||
for i, paper in enumerate(latest, 1):
|
||||
print(f"\n--- 最新论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表日期: {paper.venue_info.get('pub_date')}")
|
||||
|
||||
# 示例7:获取论文的参考文献
|
||||
if papers:
|
||||
print("\n=== 示例7:获取论文的参考文献 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
references = await pubmed.get_references(paper_id)
|
||||
for i, paper in enumerate(references[:3], 1):
|
||||
print(f"\n--- 参考文献 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例8:尝试获取引用信息(将返回空列表)
|
||||
if papers:
|
||||
print("\n=== 示例8:获取论文的引用信息 ===")
|
||||
paper_id = papers[0].url.split("/")[-2]
|
||||
citations = await pubmed.get_citations(paper_id)
|
||||
print(f"引用数据:{len(citations)} (PubMed API不提供引用信息)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,326 @@
|
||||
from pathlib import Path
|
||||
import requests
|
||||
from bs4 import BeautifulSoup
|
||||
import time
|
||||
from loguru import logger
|
||||
import PyPDF2
|
||||
import io
|
||||
|
||||
|
||||
class SciHub:
|
||||
# 更新的镜像列表,包含更多可用的镜像
|
||||
MIRRORS = [
|
||||
'https://sci-hub.se/',
|
||||
'https://sci-hub.st/',
|
||||
'https://sci-hub.ru/',
|
||||
'https://sci-hub.wf/',
|
||||
'https://sci-hub.ee/',
|
||||
'https://sci-hub.ren/',
|
||||
'https://sci-hub.tf/',
|
||||
'https://sci-hub.si/',
|
||||
'https://sci-hub.do/',
|
||||
'https://sci-hub.hkvisa.net/',
|
||||
'https://sci-hub.mksa.top/',
|
||||
'https://sci-hub.shop/',
|
||||
'https://sci-hub.yncjkj.com/',
|
||||
'https://sci-hub.41610.org/',
|
||||
'https://sci-hub.automic.us/',
|
||||
'https://sci-hub.et-fine.com/',
|
||||
'https://sci-hub.pooh.mu/',
|
||||
'https://sci-hub.bban.top/',
|
||||
'https://sci-hub.usualwant.com/',
|
||||
'https://sci-hub.unblockit.kim/'
|
||||
]
|
||||
|
||||
def __init__(self, doi: str, path: Path, url=None, timeout=60, use_proxy=True):
|
||||
self.timeout = timeout
|
||||
self.path = path
|
||||
self.doi = str(doi)
|
||||
self.use_proxy = use_proxy
|
||||
self.headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
|
||||
}
|
||||
self.payload = {
|
||||
'sci-hub-plugin-check': '',
|
||||
'request': self.doi
|
||||
}
|
||||
self.url = url if url else self.MIRRORS[0]
|
||||
self.proxies = {
|
||||
"http": "socks5h://localhost:10880",
|
||||
"https": "socks5h://localhost:10880",
|
||||
} if use_proxy else None
|
||||
|
||||
def _test_proxy_connection(self):
|
||||
"""测试代理连接是否可用"""
|
||||
if not self.use_proxy:
|
||||
return True
|
||||
|
||||
try:
|
||||
# 测试代理连接
|
||||
test_response = requests.get(
|
||||
'https://httpbin.org/ip',
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
logger.info("代理连接测试成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"代理连接测试失败: {str(e)}")
|
||||
return False
|
||||
return False
|
||||
|
||||
def _check_pdf_validity(self, content):
|
||||
"""检查PDF文件是否有效"""
|
||||
try:
|
||||
# 使用PyPDF2检查PDF是否可以正常打开和读取
|
||||
pdf = PyPDF2.PdfReader(io.BytesIO(content))
|
||||
if len(pdf.pages) > 0:
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"PDF文件无效: {str(e)}")
|
||||
return False
|
||||
|
||||
def _send_request(self):
|
||||
"""发送请求到Sci-Hub镜像站点"""
|
||||
# 首先测试代理连接
|
||||
if self.use_proxy and not self._test_proxy_connection():
|
||||
logger.warning("代理连接不可用,切换到直连模式")
|
||||
self.use_proxy = False
|
||||
self.proxies = None
|
||||
|
||||
last_exception = None
|
||||
working_mirrors = []
|
||||
|
||||
# 先测试哪些镜像可用
|
||||
logger.info("正在测试镜像站点可用性...")
|
||||
for mirror in self.MIRRORS:
|
||||
try:
|
||||
test_response = requests.get(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
proxies=self.proxies,
|
||||
timeout=10
|
||||
)
|
||||
if test_response.status_code == 200:
|
||||
working_mirrors.append(mirror)
|
||||
logger.info(f"镜像 {mirror} 可用")
|
||||
if len(working_mirrors) >= 5: # 找到5个可用镜像就够了
|
||||
break
|
||||
except Exception as e:
|
||||
logger.debug(f"镜像 {mirror} 不可用: {str(e)}")
|
||||
continue
|
||||
|
||||
if not working_mirrors:
|
||||
raise Exception("没有找到可用的镜像站点")
|
||||
|
||||
logger.info(f"找到 {len(working_mirrors)} 个可用镜像,开始尝试下载...")
|
||||
|
||||
# 使用可用的镜像进行下载
|
||||
for mirror in working_mirrors:
|
||||
try:
|
||||
res = requests.post(
|
||||
mirror,
|
||||
headers=self.headers,
|
||||
data=self.payload,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout
|
||||
)
|
||||
if res.ok:
|
||||
logger.info(f"成功使用镜像站点: {mirror}")
|
||||
self.url = mirror # 更新当前使用的镜像
|
||||
time.sleep(1) # 降低等待时间以提高效率
|
||||
return res
|
||||
except Exception as e:
|
||||
logger.error(f"尝试镜像 {mirror} 失败: {str(e)}")
|
||||
last_exception = e
|
||||
continue
|
||||
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise Exception("所有可用镜像站点均无法完成下载")
|
||||
|
||||
def _extract_url(self, response):
|
||||
"""从响应中提取PDF下载链接"""
|
||||
soup = BeautifulSoup(response.content, 'html.parser')
|
||||
try:
|
||||
# 尝试多种方式提取PDF链接
|
||||
pdf_element = soup.find(id='pdf')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 尝试其他可能的选择器
|
||||
pdf_element = soup.find('iframe')
|
||||
if pdf_element:
|
||||
content_url = pdf_element.get('src')
|
||||
else:
|
||||
# 查找直接的PDF链接
|
||||
pdf_links = soup.find_all('a', href=lambda x: x and '.pdf' in x)
|
||||
if pdf_links:
|
||||
content_url = pdf_links[0].get('href')
|
||||
else:
|
||||
raise AttributeError("未找到PDF链接")
|
||||
|
||||
if content_url:
|
||||
content_url = content_url.replace('#navpanes=0&view=FitH', '').replace('//', '/')
|
||||
if not content_url.endswith('.pdf') and 'pdf' not in content_url.lower():
|
||||
raise AttributeError("找到的链接不是PDF文件")
|
||||
except AttributeError:
|
||||
logger.error(f"未找到论文 {self.doi}")
|
||||
return None
|
||||
|
||||
current_mirror = self.url.rstrip('/')
|
||||
if content_url.startswith('/'):
|
||||
return current_mirror + content_url
|
||||
elif content_url.startswith('http'):
|
||||
return content_url
|
||||
else:
|
||||
return 'https:/' + content_url
|
||||
|
||||
def _download_pdf(self, pdf_url):
|
||||
"""下载PDF文件并验证其完整性"""
|
||||
try:
|
||||
# 尝试不同的下载方式
|
||||
download_methods = [
|
||||
# 方法1:直接下载
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout),
|
||||
# 方法2:添加 Referer 头
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': self.url}),
|
||||
# 方法3:使用原始域名作为 Referer
|
||||
lambda: requests.get(pdf_url, proxies=self.proxies, timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': pdf_url.split('/downloads')[0] if '/downloads' in pdf_url else self.url})
|
||||
]
|
||||
|
||||
for i, download_method in enumerate(download_methods):
|
||||
try:
|
||||
logger.info(f"尝试下载方式 {i+1}/3...")
|
||||
response = download_method()
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content): # 确保文件不是太小
|
||||
logger.info(f"PDF下载成功,文件大小: {len(content)} bytes")
|
||||
return content
|
||||
else:
|
||||
logger.warning("下载的文件可能不是有效的PDF")
|
||||
elif response.status_code == 403:
|
||||
logger.warning(f"访问被拒绝 (403 Forbidden),尝试其他下载方式")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"下载失败,状态码: {response.status_code}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"下载方式 {i+1} 失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 如果所有方法都失败,尝试构造替代URL
|
||||
try:
|
||||
logger.info("尝试使用替代镜像下载...")
|
||||
# 从原始URL提取关键信息
|
||||
if '/downloads/' in pdf_url:
|
||||
file_part = pdf_url.split('/downloads/')[-1]
|
||||
alternative_mirrors = [
|
||||
f"https://sci-hub.se/downloads/{file_part}",
|
||||
f"https://sci-hub.st/downloads/{file_part}",
|
||||
f"https://sci-hub.ru/downloads/{file_part}",
|
||||
f"https://sci-hub.wf/downloads/{file_part}",
|
||||
f"https://sci-hub.ee/downloads/{file_part}",
|
||||
f"https://sci-hub.ren/downloads/{file_part}",
|
||||
f"https://sci-hub.tf/downloads/{file_part}"
|
||||
]
|
||||
|
||||
for alt_url in alternative_mirrors:
|
||||
try:
|
||||
response = requests.get(
|
||||
alt_url,
|
||||
proxies=self.proxies,
|
||||
timeout=self.timeout,
|
||||
headers={**self.headers, 'Referer': alt_url.split('/downloads')[0]}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
content = response.content
|
||||
if len(content) > 1000 and self._check_pdf_validity(content):
|
||||
logger.info(f"使用替代镜像成功下载: {alt_url}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"替代镜像 {alt_url} 下载失败: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"所有下载方式都失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载PDF文件失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def fetch(self):
|
||||
"""获取论文PDF,包含重试和验证机制"""
|
||||
for attempt in range(2): # 最多重试3次
|
||||
try:
|
||||
logger.info(f"开始第 {attempt + 1} 次尝试下载论文: {self.doi}")
|
||||
|
||||
# 获取PDF下载链接
|
||||
response = self._send_request()
|
||||
pdf_url = self._extract_url(response)
|
||||
if pdf_url is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:未找到PDF下载链接")
|
||||
continue
|
||||
|
||||
logger.info(f"找到PDF下载链接: {pdf_url}")
|
||||
|
||||
# 下载并验证PDF
|
||||
pdf_content = self._download_pdf(pdf_url)
|
||||
if pdf_content is None:
|
||||
logger.warning(f"第 {attempt + 1} 次尝试:PDF下载失败")
|
||||
continue
|
||||
|
||||
# 保存PDF文件
|
||||
pdf_name = f"{self.doi.replace('/', '_').replace(':', '_')}.pdf"
|
||||
pdf_path = self.path.joinpath(pdf_name)
|
||||
pdf_path.write_bytes(pdf_content)
|
||||
|
||||
logger.info(f"成功下载论文: {pdf_name},文件大小: {len(pdf_content)} bytes")
|
||||
return str(pdf_path)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"第 {attempt + 1} 次尝试失败: {str(e)}")
|
||||
if attempt < 2: # 不是最后一次尝试
|
||||
wait_time = (attempt + 1) * 3 # 递增等待时间
|
||||
logger.info(f"等待 {wait_time} 秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
raise Exception(f"无法下载论文 {self.doi},所有重试都失败了")
|
||||
|
||||
# Usage Example
|
||||
if __name__ == '__main__':
|
||||
# 创建一个用于保存PDF的目录
|
||||
save_path = Path('./downloaded_papers')
|
||||
save_path.mkdir(exist_ok=True)
|
||||
|
||||
# DOI示例
|
||||
sample_doi = '10.3897/rio.7.e67379' # 这是一篇Nature的论文DOI
|
||||
|
||||
try:
|
||||
# 初始化SciHub下载器,先尝试使用代理
|
||||
logger.info("尝试使用代理模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=True)
|
||||
|
||||
# 开始下载
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"使用代理模式失败: {str(e)}")
|
||||
try:
|
||||
# 如果代理模式失败,尝试直连模式
|
||||
logger.info("尝试直连模式...")
|
||||
downloader = SciHub(doi=sample_doi, path=save_path, use_proxy=False)
|
||||
result = downloader.fetch()
|
||||
print(f"论文已保存到: {result}")
|
||||
except Exception as e2:
|
||||
print(f"直连模式也失败: {str(e2)}")
|
||||
print("建议检查网络连接或尝试其他DOI")
|
||||
@@ -0,0 +1,400 @@
|
||||
from typing import List, Optional, Dict, Union
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import random
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
from tqdm import tqdm
|
||||
|
||||
class ScopusSource(DataSource):
|
||||
"""Scopus API实现"""
|
||||
|
||||
# 定义API密钥列表
|
||||
API_KEYS = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Scopus API密钥,如果不提供则从预定义列表中随机选择
|
||||
"""
|
||||
self.api_key = api_key or random.choice(self.API_KEYS)
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化基础URL和请求头"""
|
||||
self.base_url = "https://api.elsevier.com/content"
|
||||
self.headers = {
|
||||
"X-ELS-APIKey": self.api_key,
|
||||
"Accept": "application/json"
|
||||
}
|
||||
|
||||
async def _make_request(self, url: str, params: Dict = None) -> Optional[Dict]:
|
||||
"""发送HTTP请求
|
||||
|
||||
Args:
|
||||
url: 请求URL
|
||||
params: 查询参数
|
||||
|
||||
Returns:
|
||||
响应JSON数据
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession(headers=self.headers) as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
print(f"请求失败: {response.status}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"请求发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
def _parse_paper_data(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Scopus API返回的数据
|
||||
|
||||
Args:
|
||||
data: Scopus API返回的论文数据
|
||||
|
||||
Returns:
|
||||
解析后的论文元数据
|
||||
"""
|
||||
try:
|
||||
# 提取基本信息
|
||||
title = data.get("dc:title", "")
|
||||
|
||||
# 提取作者信息
|
||||
authors = []
|
||||
if "author" in data:
|
||||
if isinstance(data["author"], list):
|
||||
for author in data["author"]:
|
||||
if "given-name" in author and "surname" in author:
|
||||
authors.append(f"{author['given-name']} {author['surname']}")
|
||||
elif "indexed-name" in author:
|
||||
authors.append(author["indexed-name"])
|
||||
elif isinstance(data["author"], dict):
|
||||
if "given-name" in data["author"] and "surname" in data["author"]:
|
||||
authors.append(f"{data['author']['given-name']} {data['author']['surname']}")
|
||||
elif "indexed-name" in data["author"]:
|
||||
authors.append(data["author"]["indexed-name"])
|
||||
|
||||
# 提取摘要
|
||||
abstract = data.get("dc:description", "")
|
||||
|
||||
# 提取年份
|
||||
year = None
|
||||
if "prism:coverDate" in data:
|
||||
try:
|
||||
year = int(data["prism:coverDate"][:4])
|
||||
except:
|
||||
pass
|
||||
|
||||
# 提取DOI
|
||||
doi = data.get("prism:doi")
|
||||
|
||||
# 提取引用次数
|
||||
citations = data.get("citedby-count")
|
||||
if citations:
|
||||
try:
|
||||
citations = int(citations)
|
||||
except:
|
||||
citations = None
|
||||
|
||||
# 提取期刊信息
|
||||
venue = data.get("prism:publicationName")
|
||||
|
||||
# 提取机构信息
|
||||
institutions = []
|
||||
if "affiliation" in data:
|
||||
if isinstance(data["affiliation"], list):
|
||||
for aff in data["affiliation"]:
|
||||
if "affilname" in aff:
|
||||
institutions.append(aff["affilname"])
|
||||
elif isinstance(data["affiliation"], dict):
|
||||
if "affilname" in data["affiliation"]:
|
||||
institutions.append(data["affiliation"]["affilname"])
|
||||
|
||||
# 构建venue信息
|
||||
venue_info = {
|
||||
"issn": data.get("prism:issn"),
|
||||
"eissn": data.get("prism:eIssn"),
|
||||
"volume": data.get("prism:volume"),
|
||||
"issue": data.get("prism:issueIdentifier"),
|
||||
"page_range": data.get("prism:pageRange"),
|
||||
"article_number": data.get("article-number"),
|
||||
"publication_date": data.get("prism:coverDate")
|
||||
}
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=authors,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=data.get("link", [{}])[0].get("@href"),
|
||||
citations=citations,
|
||||
venue=venue,
|
||||
institutions=institutions,
|
||||
venue_type="journal",
|
||||
venue_name=venue,
|
||||
venue_info=venue_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"解析论文数据时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
sort_by: str = None,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
limit: 返回结果数量限制
|
||||
sort_by: 排序方式 ('relevance', 'date', 'citations')
|
||||
start_year: 起始年份
|
||||
|
||||
Returns:
|
||||
论文列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询参数
|
||||
params = {
|
||||
"query": query,
|
||||
"count": min(limit, 100), # Scopus API单次请求限制
|
||||
"start": 0
|
||||
}
|
||||
|
||||
# 添加年份过滤
|
||||
if start_year:
|
||||
params["date"] = f"{start_year}-present"
|
||||
|
||||
# 添加排序
|
||||
if sort_by:
|
||||
sort_map = {
|
||||
"relevance": "-score",
|
||||
"date": "-coverDate",
|
||||
"citations": "-citedby-count"
|
||||
}
|
||||
if sort_by in sort_map:
|
||||
params["sort"] = sort_map[sort_by]
|
||||
|
||||
# 发送请求
|
||||
url = f"{self.base_url}/search/scopus"
|
||||
response = await self._make_request(url, params)
|
||||
|
||||
if not response or "search-results" not in response:
|
||||
return []
|
||||
|
||||
# 解析结果
|
||||
results = response["search-results"].get("entry", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
|
||||
"""获取论文详情
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID或DOI
|
||||
|
||||
Returns:
|
||||
论文详情
|
||||
"""
|
||||
try:
|
||||
# 判断是否为DOI
|
||||
if "/" in paper_id:
|
||||
url = f"{self.base_url}/article/doi/{paper_id}"
|
||||
else:
|
||||
url = f"{self.base_url}/abstract/scopus_id/{paper_id}"
|
||||
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "abstracts-retrieval-response" not in response:
|
||||
return None
|
||||
|
||||
data = response["abstracts-retrieval-response"]
|
||||
return self._parse_paper_data(data)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取引用该论文的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
引用论文列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/citations/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "citing-papers" not in response:
|
||||
return []
|
||||
|
||||
results = response["citing-papers"].get("papers", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取引用信息时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(self, paper_id: str) -> List[PaperMetadata]:
|
||||
"""获取该论文引用的文献
|
||||
|
||||
Args:
|
||||
paper_id: Scopus ID
|
||||
|
||||
Returns:
|
||||
参考文献列表
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/abstract/references/{paper_id}"
|
||||
response = await self._make_request(url)
|
||||
|
||||
if not response or "references" not in response:
|
||||
return []
|
||||
|
||||
results = response["references"].get("reference", [])
|
||||
papers = []
|
||||
|
||||
for result in results:
|
||||
paper = self._parse_paper_data(result)
|
||||
if paper:
|
||||
papers.append(paper)
|
||||
|
||||
return papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取参考文献时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_by_author(
|
||||
self,
|
||||
author: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按作者搜索论文"""
|
||||
query = f"AUTHOR-NAME({author})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def search_by_journal(
|
||||
self,
|
||||
journal: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""按期刊搜索论文"""
|
||||
query = f"SRCTITLE({journal})"
|
||||
if start_year:
|
||||
query += f" AND PUBYEAR > {start_year}"
|
||||
return await self.search(query, limit=limit)
|
||||
|
||||
async def get_latest_papers(
|
||||
self,
|
||||
days: int = 7,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取最新论文"""
|
||||
query = f"LOAD-DATE > NOW() - {days}d"
|
||||
return await self.search(query, limit=limit, sort_by="date")
|
||||
|
||||
async def example_usage():
|
||||
"""ScopusSource使用示例"""
|
||||
scopus = ScopusSource()
|
||||
|
||||
try:
|
||||
# 示例1:基本搜索
|
||||
print("\n=== 示例1:搜索机器学习相关论文 ===")
|
||||
papers = await scopus.search("machine learning", limit=3)
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例2:按作者搜索
|
||||
print("\n=== 示例2:搜索特定作者的论文 ===")
|
||||
author_papers = await scopus.search_by_author("Hinton G.", limit=3)
|
||||
print(f"\n找到 {len(author_papers)} 篇 Hinton 的论文:")
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
# 示例3:根据关键词搜索相关论文
|
||||
print("\n=== 示例3:搜索人工智能相关论文 ===")
|
||||
keywords = "artificial intelligence AND deep learning"
|
||||
papers = await scopus.search(
|
||||
query=keywords,
|
||||
limit=5,
|
||||
sort_by="citations", # 按引用次数排序
|
||||
start_year=2020 # 只搜索2020年之后的论文
|
||||
)
|
||||
|
||||
print(f"\n找到 {len(papers)} 篇相关论文:")
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n论文 {i}:")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"发表期刊: {paper.venue}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
if paper.abstract:
|
||||
print(f"摘要:\n{paper.abstract}")
|
||||
print("-" * 80)
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,480 @@
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from crazy_functions.review_fns.data_sources.base_source import DataSource, PaperMetadata
|
||||
import random
|
||||
|
||||
class SemanticScholarSource(DataSource):
|
||||
"""Semantic Scholar API实现,使用官方Python包"""
|
||||
|
||||
def __init__(self, api_key: str = None):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
api_key: Semantic Scholar API密钥(可选)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self._initialize() # 调用初始化方法
|
||||
|
||||
def _initialize(self) -> None:
|
||||
"""初始化API客户端"""
|
||||
if not self.api_key:
|
||||
# 默认API密钥列表
|
||||
default_api_keys = [
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
|
||||
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
||||
]
|
||||
self.api_key = random.choice(default_api_keys)
|
||||
|
||||
self.client = None # 延迟初始化
|
||||
self.fields = [
|
||||
"title",
|
||||
"authors",
|
||||
"abstract",
|
||||
"year",
|
||||
"externalIds",
|
||||
"citationCount",
|
||||
"venue",
|
||||
"openAccessPdf",
|
||||
"publicationVenue"
|
||||
]
|
||||
|
||||
async def _ensure_client(self):
|
||||
"""确保客户端已初始化"""
|
||||
if self.client is None:
|
||||
from semanticscholar import AsyncSemanticScholar
|
||||
self.client = AsyncSemanticScholar(api_key=self.api_key)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""搜索论文"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
|
||||
# 如果指定了起始年份,添加到查询中
|
||||
if start_year:
|
||||
query = f"{query} year>={start_year}"
|
||||
|
||||
# 直接使用 search_paper 的结果
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/search",
|
||||
f"query={query}&limit={min(limit, 100)}&fields={','.join(self.fields)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"搜索论文时发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def get_paper_details(self, doi: str) -> Optional[PaperMetadata]:
|
||||
"""获取指定DOI的论文详情"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
paper = await self.client.get_paper(f"DOI:{doi}", fields=self.fields)
|
||||
return self._parse_paper_data(paper)
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_citations(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取引用指定DOI论文的文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/citations",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
citations = response.get('data', [])
|
||||
return [self._parse_paper_data(citation.get('citingPaper', {})) for citation in citations]
|
||||
except Exception as e:
|
||||
print(f"获取引用列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_references(
|
||||
self,
|
||||
doi: str,
|
||||
limit: int = 100,
|
||||
start_year: int = None
|
||||
) -> List[PaperMetadata]:
|
||||
"""获取指定DOI论文的参考文献列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 构建查询参数
|
||||
fields_param = f"fields={','.join(self.fields)}"
|
||||
limit_param = f"limit={limit}"
|
||||
year_param = f"year>={start_year}" if start_year else ""
|
||||
params = "&".join(filter(None, [fields_param, limit_param, year_param]))
|
||||
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/DOI:{doi}/references",
|
||||
params,
|
||||
self.client.auth_header
|
||||
)
|
||||
references = response.get('data', [])
|
||||
return [self._parse_paper_data(reference.get('citedPaper', {})) for reference in references]
|
||||
except Exception as e:
|
||||
print(f"获取参考文献列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers(self, doi: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取论文推荐
|
||||
|
||||
根据一篇论文获取相关的推荐论文
|
||||
|
||||
Args:
|
||||
doi: 论文的DOI
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
papers = await self.client.get_recommended_papers(
|
||||
f"DOI:{doi}",
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_recommended_papers_from_lists(
|
||||
self,
|
||||
positive_dois: List[str],
|
||||
negative_dois: List[str] = None,
|
||||
limit: int = 100
|
||||
) -> List[PaperMetadata]:
|
||||
"""基于正负例论文列表获取推荐
|
||||
|
||||
Args:
|
||||
positive_dois: 正例论文DOI列表(想要获取类似的论文)
|
||||
negative_dois: 负例论文DOI列表(不想要类似的论文)
|
||||
limit: 返回结果数量限制,最大500
|
||||
|
||||
Returns:
|
||||
推荐论文列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
positive_ids = [f"DOI:{doi}" for doi in positive_dois]
|
||||
negative_ids = [f"DOI:{doi}" for doi in negative_dois] if negative_dois else None
|
||||
|
||||
papers = await self.client.get_recommended_papers_from_lists(
|
||||
positive_paper_ids=positive_ids,
|
||||
negative_paper_ids=negative_ids,
|
||||
fields=self.fields,
|
||||
limit=min(limit, 500)
|
||||
)
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取论文推荐列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def search_author(self, query: str, limit: int = 100) -> List[dict]:
|
||||
"""搜索作者"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求而不是 search_author 方法
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/search",
|
||||
f"query={query}&fields=name,paperCount,citationCount&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
authors = response.get('data', [])
|
||||
return [
|
||||
{
|
||||
'author_id': author.get('authorId'),
|
||||
'name': author.get('name'),
|
||||
'paper_count': author.get('paperCount'),
|
||||
'citation_count': author.get('citationCount'),
|
||||
}
|
||||
for author in authors
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"搜索作者时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_author_details(self, author_id: str) -> Optional[dict]:
|
||||
"""获取作者详细信息"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}",
|
||||
"fields=name,paperCount,citationCount,hIndex",
|
||||
self.client.auth_header
|
||||
)
|
||||
return {
|
||||
'author_id': response.get('authorId'),
|
||||
'name': response.get('name'),
|
||||
'paper_count': response.get('paperCount'),
|
||||
'citation_count': response.get('citationCount'),
|
||||
'h_index': response.get('hIndex'),
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"获取作者详情时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_author_papers(self, author_id: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
"""获取作者的论文列表"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/author/{author_id}/papers",
|
||||
f"fields={','.join(self.fields)}&limit={min(limit, 1000)}",
|
||||
self.client.auth_header
|
||||
)
|
||||
papers = response.get('data', [])
|
||||
return [self._parse_paper_data(paper) for paper in papers]
|
||||
except Exception as e:
|
||||
print(f"获取作者论文列表时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
async def get_paper_autocomplete(self, query: str) -> List[dict]:
|
||||
"""论文标题自动补全"""
|
||||
try:
|
||||
await self._ensure_client()
|
||||
# 直接使用 API 请求
|
||||
response = await self.client._requester.get_data_async(
|
||||
f"{self.client.api_url}{self.client.BASE_PATH_GRAPH}/paper/autocomplete",
|
||||
f"query={query}",
|
||||
self.client.auth_header
|
||||
)
|
||||
suggestions = response.get('matches', [])
|
||||
return [
|
||||
{
|
||||
'title': suggestion.get('title'),
|
||||
'paper_id': suggestion.get('paperId'),
|
||||
'year': suggestion.get('year'),
|
||||
'venue': suggestion.get('venue'),
|
||||
}
|
||||
for suggestion in suggestions
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"获取标题自动补全时发生错误: {str(e)}")
|
||||
return []
|
||||
|
||||
def _parse_paper_data(self, paper) -> PaperMetadata:
|
||||
"""解析论文数据"""
|
||||
# 获取DOI
|
||||
doi = None
|
||||
external_ids = paper.get('externalIds', {}) if isinstance(paper, dict) else paper.externalIds
|
||||
if external_ids:
|
||||
if isinstance(external_ids, dict):
|
||||
doi = external_ids.get('DOI')
|
||||
if not doi and 'ArXiv' in external_ids:
|
||||
doi = f"10.48550/arXiv.{external_ids['ArXiv']}"
|
||||
else:
|
||||
doi = external_ids.DOI if hasattr(external_ids, 'DOI') else None
|
||||
if not doi and hasattr(external_ids, 'ArXiv'):
|
||||
doi = f"10.48550/arXiv.{external_ids.ArXiv}"
|
||||
|
||||
# 获取PDF URL
|
||||
pdf_url = None
|
||||
pdf_info = paper.get('openAccessPdf', {}) if isinstance(paper, dict) else paper.openAccessPdf
|
||||
if pdf_info:
|
||||
pdf_url = pdf_info.get('url') if isinstance(pdf_info, dict) else pdf_info.url
|
||||
|
||||
# 获取发表场所详细信息
|
||||
venue_type = None
|
||||
venue_name = None
|
||||
venue_info = {}
|
||||
|
||||
venue = paper.get('publicationVenue', {}) if isinstance(paper, dict) else paper.publicationVenue
|
||||
if venue:
|
||||
if isinstance(venue, dict):
|
||||
venue_name = venue.get('name')
|
||||
venue_type = venue.get('type')
|
||||
# 提取更多venue信息
|
||||
venue_info = {
|
||||
'issn': venue.get('issn'),
|
||||
'publisher': venue.get('publisher'),
|
||||
'url': venue.get('url'),
|
||||
'alternate_names': venue.get('alternate_names', [])
|
||||
}
|
||||
else:
|
||||
venue_name = venue.name if hasattr(venue, 'name') else None
|
||||
venue_type = venue.type if hasattr(venue, 'type') else None
|
||||
venue_info = {
|
||||
'issn': getattr(venue, 'issn', None),
|
||||
'publisher': getattr(venue, 'publisher', None),
|
||||
'url': getattr(venue, 'url', None),
|
||||
'alternate_names': getattr(venue, 'alternate_names', [])
|
||||
}
|
||||
|
||||
# 获取标题
|
||||
title = paper.get('title', '') if isinstance(paper, dict) else getattr(paper, 'title', '')
|
||||
|
||||
# 获取作者
|
||||
authors = paper.get('authors', []) if isinstance(paper, dict) else getattr(paper, 'authors', [])
|
||||
author_names = []
|
||||
for author in authors:
|
||||
if isinstance(author, dict):
|
||||
author_names.append(author.get('name', ''))
|
||||
else:
|
||||
author_names.append(author.name if hasattr(author, 'name') else str(author))
|
||||
|
||||
# 获取摘要
|
||||
abstract = paper.get('abstract', '') if isinstance(paper, dict) else getattr(paper, 'abstract', '')
|
||||
|
||||
# 获取年份
|
||||
year = paper.get('year') if isinstance(paper, dict) else getattr(paper, 'year', None)
|
||||
|
||||
# 获取引用次数
|
||||
citations = paper.get('citationCount') if isinstance(paper, dict) else getattr(paper, 'citationCount', None)
|
||||
|
||||
return PaperMetadata(
|
||||
title=title,
|
||||
authors=author_names,
|
||||
abstract=abstract,
|
||||
year=year,
|
||||
doi=doi,
|
||||
url=pdf_url or (f"https://doi.org/{doi}" if doi else None),
|
||||
citations=citations,
|
||||
venue=venue_name,
|
||||
institutions=[],
|
||||
venue_type=venue_type,
|
||||
venue_name=venue_name,
|
||||
venue_info=venue_info,
|
||||
source='semantic' # 添加来源标记
|
||||
)
|
||||
|
||||
async def example_usage():
|
||||
"""SemanticScholarSource使用示例"""
|
||||
semantic = SemanticScholarSource()
|
||||
|
||||
try:
|
||||
# 示例1:使用DOI直接获取论文
|
||||
print("\n=== 示例1:通过DOI获取论文 ===")
|
||||
doi = "10.18653/v1/N19-1423" # BERT论文
|
||||
print(f"获取DOI为 {doi} 的论文信息...")
|
||||
|
||||
paper = await semantic.get_paper_details(doi)
|
||||
if paper:
|
||||
print("\n--- 论文信息 ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"DOI: {paper.doi}")
|
||||
print(f"URL: {paper.url}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\n引用次数: {paper.citations}")
|
||||
print(f"发表venue: {paper.venue}")
|
||||
|
||||
# 示例2:搜索论文
|
||||
print("\n=== 示例2:搜索论文 ===")
|
||||
query = "BERT pre-training"
|
||||
print(f"搜索关键词 '{query}' 相关的论文...")
|
||||
papers = await semantic.search(query=query, limit=3)
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
print(f"\n--- 搜索结果 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
if paper.abstract:
|
||||
print(f"\n摘要:")
|
||||
print(paper.abstract)
|
||||
print(f"\nDOI: {paper.doi}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例3:获取论文推荐
|
||||
print("\n=== 示例3:获取论文推荐 ===")
|
||||
print(f"获取与论文 {doi} 相关的推荐论文...")
|
||||
recommendations = await semantic.get_recommended_papers(doi, limit=3)
|
||||
for i, paper in enumerate(recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
|
||||
# 示例4:基于多篇论文的推荐
|
||||
print("\n=== 示例4:基于多篇论文的推荐 ===")
|
||||
positive_dois = ["10.18653/v1/N19-1423", "10.18653/v1/P19-1285"]
|
||||
print(f"基于 {len(positive_dois)} 篇论文获取推荐...")
|
||||
multi_recommendations = await semantic.get_recommended_papers_from_lists(
|
||||
positive_dois=positive_dois,
|
||||
limit=3
|
||||
)
|
||||
for i, paper in enumerate(multi_recommendations, 1):
|
||||
print(f"\n--- 推荐论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"作者: {', '.join(paper.authors)}")
|
||||
|
||||
# 示例5:搜索作者
|
||||
print("\n=== 示例5:搜索作者 ===")
|
||||
author_query = "Yann LeCun"
|
||||
print(f"搜索作者: '{author_query}'")
|
||||
authors = await semantic.search_author(author_query, limit=3)
|
||||
for i, author in enumerate(authors, 1):
|
||||
print(f"\n--- 作者 {i} ---")
|
||||
print(f"姓名: {author['name']}")
|
||||
print(f"论文数量: {author['paper_count']}")
|
||||
print(f"总引用次数: {author['citation_count']}")
|
||||
|
||||
# 示例6:获取作者详情
|
||||
print("\n=== 示例6:获取作者详情 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者ID {author_id} 的详细信息...")
|
||||
author_details = await semantic.get_author_details(author_id)
|
||||
if author_details:
|
||||
print(f"姓名: {author_details['name']}")
|
||||
print(f"H指数: {author_details['h_index']}")
|
||||
print(f"总引用次数: {author_details['citation_count']}")
|
||||
print(f"发表论文数: {author_details['paper_count']}")
|
||||
|
||||
# 示例7:获取作者论文
|
||||
print("\n=== 示例7:获取作者论文 ===")
|
||||
if authors: # 使用第一个搜索结果的作者ID
|
||||
author_id = authors[0]['author_id']
|
||||
print(f"获取作者 {authors[0]['name']} 的论文列表...")
|
||||
author_papers = await semantic.get_author_papers(author_id, limit=3)
|
||||
for i, paper in enumerate(author_papers, 1):
|
||||
print(f"\n--- 论文 {i} ---")
|
||||
print(f"标题: {paper.title}")
|
||||
print(f"发表年份: {paper.year}")
|
||||
print(f"引用次数: {paper.citations}")
|
||||
|
||||
# 示例8:论文标题自动补全
|
||||
print("\n=== 示例8:论文标题自动补全 ===")
|
||||
title_query = "Attention is all"
|
||||
print(f"搜索标题: '{title_query}'")
|
||||
suggestions = await semantic.get_paper_autocomplete(title_query)
|
||||
for i, suggestion in enumerate(suggestions[:3], 1):
|
||||
print(f"\n--- 建议 {i} ---")
|
||||
print(f"标题: {suggestion['title']}")
|
||||
print(f"发表年份: {suggestion['year']}")
|
||||
print(f"发表venue: {suggestion['venue']}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(example_usage())
|
||||
@@ -0,0 +1,46 @@
|
||||
import aiohttp
|
||||
from typing import List, Dict, Optional
|
||||
from datetime import datetime
|
||||
from .base_source import DataSource, PaperMetadata
|
||||
|
||||
class UnpaywallSource(DataSource):
|
||||
"""Unpaywall API实现"""
|
||||
|
||||
def _initialize(self) -> None:
|
||||
self.base_url = "https://api.unpaywall.org/v2"
|
||||
self.email = self.api_key # Unpaywall使用email作为API key
|
||||
|
||||
async def search(self, query: str, limit: int = 100) -> List[PaperMetadata]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.base_url}/search",
|
||||
params={
|
||||
"query": query,
|
||||
"email": self.email,
|
||||
"limit": limit
|
||||
}
|
||||
) as response:
|
||||
data = await response.json()
|
||||
return [self._parse_response(item.response)
|
||||
for item in data.get("results", [])]
|
||||
|
||||
def _parse_response(self, data: Dict) -> PaperMetadata:
|
||||
"""解析Unpaywall返回的数据"""
|
||||
return PaperMetadata(
|
||||
title=data.get("title", ""),
|
||||
authors=[
|
||||
f"{author.get('given', '')} {author.get('family', '')}"
|
||||
for author in data.get("z_authors", [])
|
||||
],
|
||||
institutions=[
|
||||
aff.get("name", "")
|
||||
for author in data.get("z_authors", [])
|
||||
for aff in author.get("affiliation", [])
|
||||
],
|
||||
abstract="", # Unpaywall不提供摘要
|
||||
year=data.get("year"),
|
||||
doi=data.get("doi"),
|
||||
url=data.get("doi_url"),
|
||||
citations=None, # Unpaywall不提供引用计数
|
||||
venue=data.get("journal_name")
|
||||
)
|
||||
@@ -0,0 +1,412 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from crazy_functions.review_fns.data_sources.arxiv_source import ArxivSource
|
||||
from crazy_functions.review_fns.data_sources.semantic_source import SemanticScholarSource
|
||||
from crazy_functions.review_fns.data_sources.pubmed_source import PubMedSource
|
||||
from crazy_functions.review_fns.paper_processor.paper_llm_ranker import PaperLLMRanker
|
||||
from crazy_functions.pdf_fns.breakdown_pdf_txt import cut_from_end_to_satisfy_token_limit
|
||||
from request_llms.bridge_all import model_info
|
||||
from crazy_functions.review_fns.data_sources.crossref_source import CrossrefSource
|
||||
from crazy_functions.review_fns.data_sources.adsabs_source import AdsabsSource
|
||||
from toolbox import get_conf
|
||||
|
||||
|
||||
class BaseHandler(ABC):
|
||||
"""处理器基类"""
|
||||
|
||||
def __init__(self, arxiv: ArxivSource, semantic: SemanticScholarSource, llm_kwargs: Dict = None):
|
||||
self.arxiv = arxiv
|
||||
self.semantic = semantic
|
||||
self.pubmed = PubMedSource()
|
||||
self.crossref = CrossrefSource() # 添加 Crossref 实例
|
||||
self.adsabs = AdsabsSource() # 添加 ADS 实例
|
||||
self.paper_ranker = PaperLLMRanker(llm_kwargs=llm_kwargs)
|
||||
self.ranked_papers = [] # 存储排序后的论文列表
|
||||
self.llm_kwargs = llm_kwargs or {} # 保存llm_kwargs
|
||||
|
||||
def _get_search_params(self, plugin_kwargs: Dict) -> Dict:
|
||||
"""获取搜索参数"""
|
||||
return {
|
||||
'max_papers': plugin_kwargs.get('max_papers', 100), # 最大论文数量
|
||||
'min_year': plugin_kwargs.get('min_year', 2015), # 最早年份
|
||||
'search_multiplier': plugin_kwargs.get('search_multiplier', 3), # 检索倍数
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> List[List[str]]:
|
||||
"""处理查询"""
|
||||
pass
|
||||
|
||||
async def _search_arxiv(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用arXiv专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 首先尝试基础搜索
|
||||
query = params.get("query", "")
|
||||
if query:
|
||||
papers = await self.arxiv.search(
|
||||
query,
|
||||
limit=params["limit"],
|
||||
sort_by=params.get("sort_by", "relevance"),
|
||||
sort_order=params.get("sort_order", "descending"),
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
# 如果基础搜索没有结果,尝试分类搜索
|
||||
if not papers:
|
||||
categories = params.get("categories", [])
|
||||
for category in categories:
|
||||
category_papers = await self.arxiv.search_by_category(
|
||||
category,
|
||||
limit=params["limit"],
|
||||
sort_by=params.get("sort_by", "relevance"),
|
||||
sort_order=params.get("sort_order", "descending"),
|
||||
)
|
||||
if category_papers:
|
||||
papers.extend(category_papers)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"arXiv搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_semantic(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用Semantic Scholar专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
|
||||
# 只使用基本的搜索参数
|
||||
papers = await self.semantic.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"]
|
||||
)
|
||||
|
||||
# 在内存中进行过滤
|
||||
if papers and min_year:
|
||||
papers = [p for p in papers if getattr(p, 'year', 0) and p.year >= min_year]
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"Semantic Scholar搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_pubmed(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用PubMed专用参数搜索"""
|
||||
try:
|
||||
# 如果不需要PubMed搜索,直接返回空列表
|
||||
if params.get("search_type") == "none":
|
||||
return []
|
||||
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 根据搜索类型选择搜索方法
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.pubmed.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "author":
|
||||
papers = await self.pubmed.search_by_author(
|
||||
author=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "journal":
|
||||
papers = await self.pubmed.search_by_journal(
|
||||
journal=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"PubMed搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_crossref(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用Crossref专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 根据搜索类型选择搜索方法
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.crossref.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "author":
|
||||
papers = await self.crossref.search_by_authors(
|
||||
authors=[params.get("query", "")],
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
elif params.get("search_type") == "journal":
|
||||
# 实现期刊搜索逻辑
|
||||
pass
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"Crossref搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_adsabs(self, params: Dict, limit_multiplier: int = 1, min_year: int = 2015) -> List:
|
||||
"""使用ADS专用参数搜索"""
|
||||
try:
|
||||
original_limit = params.get("limit", 20)
|
||||
params["limit"] = original_limit * limit_multiplier
|
||||
papers = []
|
||||
|
||||
# 执行搜索
|
||||
if params.get("search_type") == "basic":
|
||||
papers = await self.adsabs.search(
|
||||
query=params.get("query", ""),
|
||||
limit=params["limit"],
|
||||
start_year=min_year
|
||||
)
|
||||
|
||||
return papers or []
|
||||
|
||||
except Exception as e:
|
||||
print(f"ADS搜索出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _search_all_sources(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""从所有数据源搜索论文"""
|
||||
search_tasks = []
|
||||
|
||||
# # 检查是否需要执行PubMed搜索
|
||||
# is_using_pubmed = criteria.pubmed_params.get("search_type") != "none" and criteria.pubmed_params.get("query") != "none"
|
||||
is_using_pubmed = False # 开源版本不再搜索pubmed
|
||||
|
||||
# 如果使用PubMed,则只执行PubMed和Semantic Scholar搜索
|
||||
if is_using_pubmed:
|
||||
search_tasks.append(
|
||||
self._search_pubmed(
|
||||
criteria.pubmed_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
# Semantic Scholar总是执行搜索
|
||||
search_tasks.append(
|
||||
self._search_semantic(
|
||||
criteria.semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
# 如果不使用ADS,则执行Crossref搜索
|
||||
if criteria.crossref_params.get("search_type") != "none" and criteria.crossref_params.get("query") != "none":
|
||||
search_tasks.append(
|
||||
self._search_crossref(
|
||||
criteria.crossref_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
search_tasks.append(
|
||||
self._search_arxiv(
|
||||
criteria.arxiv_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
if get_conf("SEMANTIC_SCHOLAR_KEY"):
|
||||
search_tasks.append(
|
||||
self._search_semantic(
|
||||
criteria.semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=criteria.start_year
|
||||
)
|
||||
)
|
||||
|
||||
# 执行所有需要的搜索任务
|
||||
papers = await asyncio.gather(*search_tasks)
|
||||
|
||||
# 合并所有来源的论文并统计各来源的数量
|
||||
all_papers = []
|
||||
source_counts = {
|
||||
'arxiv': 0,
|
||||
'semantic': 0,
|
||||
'pubmed': 0,
|
||||
'crossref': 0,
|
||||
'adsabs': 0
|
||||
}
|
||||
|
||||
for source_papers in papers:
|
||||
if source_papers:
|
||||
for paper in source_papers:
|
||||
source = getattr(paper, 'source', 'unknown')
|
||||
if source in source_counts:
|
||||
source_counts[source] += 1
|
||||
all_papers.extend(source_papers)
|
||||
|
||||
# 打印各来源的论文数量
|
||||
print("\n=== 各数据源找到的论文数量 ===")
|
||||
for source, count in source_counts.items():
|
||||
if count > 0: # 只打印有论文的来源
|
||||
print(f"{source.capitalize()}: {count} 篇")
|
||||
print(f"总计: {len(all_papers)} 篇")
|
||||
print("===========================\n")
|
||||
|
||||
return all_papers
|
||||
|
||||
def _format_paper_time(self, paper) -> str:
|
||||
"""格式化论文时间信息"""
|
||||
year = getattr(paper, 'year', None)
|
||||
if not year:
|
||||
return ""
|
||||
|
||||
# 如果有具体的发表日期,使用具体日期
|
||||
if hasattr(paper, 'published') and paper.published:
|
||||
return f"(发表于 {paper.published.strftime('%Y-%m')})"
|
||||
# 如果只有年份,只显示年份
|
||||
return f"({year})"
|
||||
|
||||
def _format_papers(self, papers: List) -> str:
|
||||
"""格式化论文列表,使用token限制控制长度"""
|
||||
formatted = []
|
||||
|
||||
for i, paper in enumerate(papers, 1):
|
||||
# 只保留前三个作者
|
||||
authors = paper.authors[:3]
|
||||
if len(paper.authors) > 3:
|
||||
authors.append("et al.")
|
||||
|
||||
# 构建所有可能的下载链接
|
||||
download_links = []
|
||||
|
||||
# 添加arXiv链接
|
||||
if hasattr(paper, 'doi') and paper.doi:
|
||||
if paper.doi.startswith("10.48550/arXiv."):
|
||||
# 从DOI中提取完整的arXiv ID
|
||||
arxiv_id = paper.doi.split("arXiv.")[-1]
|
||||
# 移除多余的点号并确保格式正确
|
||||
arxiv_id = arxiv_id.replace("..", ".") # 移除重复的点号
|
||||
if arxiv_id.startswith("."): # 移除开头的点号
|
||||
arxiv_id = arxiv_id[1:]
|
||||
if arxiv_id.endswith("."): # 移除结尾的点号
|
||||
arxiv_id = arxiv_id[:-1]
|
||||
|
||||
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||
elif "arxiv.org/abs/" in paper.doi:
|
||||
# 直接从URL中提取arXiv ID
|
||||
arxiv_id = paper.doi.split("arxiv.org/abs/")[-1]
|
||||
if "v" in arxiv_id: # 移除版本号
|
||||
arxiv_id = arxiv_id.split("v")[0]
|
||||
|
||||
download_links.append(f"[arXiv PDF](https://arxiv.org/pdf/{arxiv_id}.pdf)")
|
||||
download_links.append(f"[arXiv Page](https://arxiv.org/abs/{arxiv_id})")
|
||||
else:
|
||||
download_links.append(f"[DOI](https://doi.org/{paper.doi})")
|
||||
|
||||
# 添加直接URL链接(如果存在且不同于前面的链接)
|
||||
if hasattr(paper, 'url') and paper.url:
|
||||
if not any(paper.url in link for link in download_links):
|
||||
download_links.append(f"[Source]({paper.url})")
|
||||
|
||||
# 构建下载链接字符串
|
||||
download_section = " | ".join(download_links) if download_links else "No direct download link available"
|
||||
|
||||
# 构建来源信息
|
||||
source_info = []
|
||||
if hasattr(paper, 'venue_type') and paper.venue_type and paper.venue_type != 'preprint':
|
||||
source_info.append(f"Type: {paper.venue_type}")
|
||||
if hasattr(paper, 'venue_name') and paper.venue_name:
|
||||
source_info.append(f"Venue: {paper.venue_name}")
|
||||
|
||||
# 添加IF指数和分区信息
|
||||
if hasattr(paper, 'if_factor') and paper.if_factor:
|
||||
source_info.append(f"IF: {paper.if_factor}")
|
||||
if hasattr(paper, 'cas_division') and paper.cas_division:
|
||||
source_info.append(f"中科院分区: {paper.cas_division}")
|
||||
if hasattr(paper, 'jcr_division') and paper.jcr_division:
|
||||
source_info.append(f"JCR分区: {paper.jcr_division}")
|
||||
|
||||
if hasattr(paper, 'venue_info') and paper.venue_info:
|
||||
if paper.venue_info.get('journal_ref'):
|
||||
source_info.append(f"Journal Reference: {paper.venue_info['journal_ref']}")
|
||||
if paper.venue_info.get('publisher'):
|
||||
source_info.append(f"Publisher: {paper.venue_info['publisher']}")
|
||||
|
||||
# 构建当前论文的格式化文本
|
||||
paper_text = (
|
||||
f"{i}. **{paper.title}**\n" +
|
||||
f" Authors: {', '.join(authors)}\n" +
|
||||
f" Year: {paper.year}\n" +
|
||||
f" Citations: {paper.citations if paper.citations else 'N/A'}\n" +
|
||||
(f" Source: {'; '.join(source_info)}\n" if source_info else "") +
|
||||
# 添加PubMed特有信息
|
||||
(f" MeSH Terms: {'; '.join(paper.mesh_terms)}\n" if hasattr(paper,
|
||||
'mesh_terms') and paper.mesh_terms else "") +
|
||||
f" 📥 PDF Downloads: {download_section}\n" +
|
||||
f" Abstract: {paper.abstract}\n"
|
||||
)
|
||||
|
||||
formatted.append(paper_text)
|
||||
|
||||
full_text = "\n".join(formatted)
|
||||
|
||||
# 根据不同模型设置不同的token限制
|
||||
model_name = getattr(self, 'llm_kwargs', {}).get('llm_model', 'gpt-3.5-turbo')
|
||||
|
||||
token_limit = model_info[model_name]['max_token'] * 3 // 4
|
||||
# 使用token限制控制长度
|
||||
return cut_from_end_to_satisfy_token_limit(full_text, limit=token_limit, reserve_token=0, llm_model=model_name)
|
||||
|
||||
def _get_current_time(self) -> str:
|
||||
"""获取当前时间信息"""
|
||||
now = datetime.now()
|
||||
return now.strftime("%Y年%m月%d日")
|
||||
|
||||
def _generate_apology_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成道歉提示"""
|
||||
return f"""很抱歉,我们未能找到与"{criteria.main_topic}"相关的有效文献。
|
||||
|
||||
可能的原因:
|
||||
1. 搜索词过于具体或专业
|
||||
2. 时间范围限制过严
|
||||
|
||||
建议解决方案:
|
||||
1. 尝试使用更通用的关键词
|
||||
2. 扩大搜索时间范围
|
||||
3. 使用同义词或相关术语
|
||||
请根据以上建议调整后重试。"""
|
||||
|
||||
def get_ranked_papers(self) -> str:
|
||||
"""获取排序后的论文列表的格式化字符串"""
|
||||
return self._format_papers(self.ranked_papers) if self.ranked_papers else ""
|
||||
|
||||
def _is_pubmed_paper(self, paper) -> bool:
|
||||
"""判断是否为PubMed论文"""
|
||||
return (paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url)
|
||||
@@ -0,0 +1,106 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class Arxiv最新论文推荐功能(BaseHandler):
|
||||
"""最新论文推荐处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理最新论文推荐请求"""
|
||||
|
||||
# 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 获取最新论文
|
||||
papers = []
|
||||
for category in criteria.arxiv_params["categories"]:
|
||||
latest_papers = await self.arxiv.get_latest_papers(
|
||||
category=category,
|
||||
debug=False,
|
||||
batch_size=50
|
||||
)
|
||||
papers.extend(latest_papers)
|
||||
|
||||
if not papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 使用embedding模型对论文进行排序
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Based on your interest in {criteria.main_topic}, here are the latest papers from arXiv in relevant categories:
|
||||
{', '.join(criteria.arxiv_params["categories"])}
|
||||
|
||||
Latest papers available:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please provide:
|
||||
1. A clear list of latext papers, organized by themes or approaches
|
||||
|
||||
|
||||
2. Group papers by sub-topics or themes if applicable
|
||||
|
||||
3. For each paper:
|
||||
- Publication time
|
||||
- The key contributions and main findings
|
||||
- Why it's relevant to the user's interests
|
||||
- How it relates to other latest papers
|
||||
- The paper's citation count and citation impact
|
||||
- The paper's download link
|
||||
|
||||
4. A suggested reading order based on:
|
||||
- Paper relationships and dependencies
|
||||
- Difficulty level
|
||||
- Significance
|
||||
|
||||
5. Future Directions
|
||||
- Emerging venues and research streams
|
||||
- Novel methodological approaches
|
||||
- Cross-disciplinary opportunities
|
||||
- Research gaps by publication type
|
||||
|
||||
IMPORTANT:
|
||||
- Focus on explaining why each paper is interesting
|
||||
- Highlight the novelty and potential impact
|
||||
- Consider the credibility and stage of each publication
|
||||
- Use the provided paper titles with their links when referring to specific papers
|
||||
- Base recommendations ONLY on the explicitly provided paper information
|
||||
- Do not make ANY assumptions about papers beyond the given data
|
||||
- When information is missing or unclear, acknowledge the limitation
|
||||
- Never speculate about:
|
||||
* Paper quality or rigor not evidenced in the data
|
||||
* Research impact beyond citation counts
|
||||
* Implementation details not mentioned
|
||||
* Author expertise or background
|
||||
* Future research directions not stated
|
||||
- For each paper, cite only verifiable information
|
||||
- Clearly distinguish between facts and potential implications
|
||||
- Each paper includes download links in its 📥 PDF Downloads section
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
@@ -0,0 +1,344 @@
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 单篇论文分析功能(BaseHandler):
|
||||
"""论文分析处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理论文分析请求,返回最终的prompt"""
|
||||
|
||||
# 1. 获取论文详情
|
||||
paper = await self._get_paper_details(criteria)
|
||||
if not paper:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 保存为ranked_papers以便统一接口
|
||||
self.ranked_papers = [paper]
|
||||
|
||||
# 2. 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
|
||||
# 获取论文信息
|
||||
title = getattr(paper, "title", "Unknown Title")
|
||||
authors = getattr(paper, "authors", [])
|
||||
year = getattr(paper, "year", "Unknown Year")
|
||||
abstract = getattr(paper, "abstract", "No abstract available")
|
||||
citations = getattr(paper, "citations", "N/A")
|
||||
|
||||
# 添加论文ID信息
|
||||
paper_id = ""
|
||||
if criteria.paper_source == "arxiv":
|
||||
paper_id = f"arXiv ID: {criteria.paper_id}\n"
|
||||
elif criteria.paper_source == "doi":
|
||||
paper_id = f"DOI: {criteria.paper_id}\n"
|
||||
|
||||
# 格式化作者列表
|
||||
authors_str = ', '.join(authors) if isinstance(authors, list) else authors
|
||||
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Please provide a comprehensive analysis of the following paper:
|
||||
|
||||
{paper_id}Title: {title}
|
||||
Authors: {authors_str}
|
||||
Year: {year}
|
||||
Citations: {citations}
|
||||
Publication Venue: {paper.venue_name} ({paper.venue_type})
|
||||
{f"Publisher: {paper.venue_info.get('publisher')}" if paper.venue_info.get('publisher') else ""}
|
||||
{f"Journal Reference: {paper.venue_info.get('journal_ref')}" if paper.venue_info.get('journal_ref') else ""}
|
||||
Abstract: {abstract}
|
||||
|
||||
Please provide:
|
||||
1. Publication Context
|
||||
- Publication venue analysis and impact factor (if available)
|
||||
- Paper type (journal article, conference paper, preprint)
|
||||
- Publication timeline and peer review status
|
||||
- Publisher reputation and venue prestige
|
||||
|
||||
2. Research Context
|
||||
- Field positioning and significance
|
||||
- Historical context and prior work
|
||||
- Related research streams
|
||||
- Cross-venue impact analysis
|
||||
|
||||
3. Technical Analysis
|
||||
- Detailed methodology review
|
||||
- Implementation details
|
||||
- Experimental setup and results
|
||||
- Technical innovations
|
||||
|
||||
4. Impact Analysis
|
||||
- Citation patterns and influence
|
||||
- Cross-venue recognition
|
||||
- Industry vs. academic impact
|
||||
- Practical applications
|
||||
|
||||
5. Critical Review
|
||||
- Methodological rigor assessment
|
||||
- Result reliability and reproducibility
|
||||
- Venue-appropriate evaluation standards
|
||||
- Limitations and potential improvements
|
||||
|
||||
IMPORTANT:
|
||||
- Strictly use ONLY the information provided above about the paper
|
||||
- Do not make ANY assumptions or inferences beyond the given data
|
||||
- If certain information is not provided, explicitly state that it is unknown
|
||||
- For any unclear or missing details, acknowledge the limitation rather than speculating
|
||||
- When discussing methodology or results, only describe what is explicitly stated in the abstract
|
||||
- Never fabricate or assume any details about:
|
||||
* Publication venues or status
|
||||
* Implementation details not mentioned
|
||||
* Results or findings not stated
|
||||
* Impact or influence not supported by the citation count
|
||||
* Authors' affiliations or backgrounds
|
||||
* Future work or implications not mentioned
|
||||
- You can find the paper's download options in the 📥 PDF Downloads section
|
||||
- Available download formats include arXiv PDF, DOI links, and source URLs
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _get_paper_details(self, criteria: SearchCriteria):
|
||||
"""获取论文详情"""
|
||||
try:
|
||||
if criteria.paper_source == "arxiv":
|
||||
# 使用 arxiv ID 搜索
|
||||
papers = await self.arxiv.search_by_id(criteria.paper_id)
|
||||
return papers[0] if papers else None
|
||||
|
||||
elif criteria.paper_source == "doi":
|
||||
# 尝试从所有来源获取
|
||||
paper = await self.semantic.get_paper_by_doi(criteria.paper_id)
|
||||
if not paper:
|
||||
# 如果Semantic Scholar没有找到,尝试PubMed
|
||||
papers = await self.pubmed.search(
|
||||
f"{criteria.paper_id}[doi]",
|
||||
limit=1
|
||||
)
|
||||
if papers:
|
||||
return papers[0]
|
||||
return paper
|
||||
|
||||
elif criteria.paper_source == "title":
|
||||
# 使用_search_all_sources搜索
|
||||
search_params = {
|
||||
'max_papers': 1,
|
||||
'min_year': 1900, # 不限制年份
|
||||
'search_multiplier': 1
|
||||
}
|
||||
|
||||
# 设置搜索参数
|
||||
criteria.arxiv_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'ti:"{criteria.paper_title}"',
|
||||
"limit": 1
|
||||
}
|
||||
criteria.semantic_params = {
|
||||
"query": criteria.paper_title,
|
||||
"limit": 1
|
||||
}
|
||||
criteria.pubmed_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'"{criteria.paper_title}"[Title]',
|
||||
"limit": 1
|
||||
}
|
||||
|
||||
papers = await self._search_all_sources(criteria, search_params)
|
||||
return papers[0] if papers else None
|
||||
|
||||
# 如果都没有找到,尝试使用 main_topic 作为标题搜索
|
||||
if not criteria.paper_title and not criteria.paper_id:
|
||||
search_params = {
|
||||
'max_papers': 1,
|
||||
'min_year': 1900,
|
||||
'search_multiplier': 1
|
||||
}
|
||||
|
||||
# 设置搜索参数
|
||||
criteria.arxiv_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'ti:"{criteria.main_topic}"',
|
||||
"limit": 1
|
||||
}
|
||||
criteria.semantic_params = {
|
||||
"query": criteria.main_topic,
|
||||
"limit": 1
|
||||
}
|
||||
criteria.pubmed_params = {
|
||||
"search_type": "basic",
|
||||
"query": f'"{criteria.main_topic}"[Title]',
|
||||
"limit": 1
|
||||
}
|
||||
|
||||
papers = await self._search_all_sources(criteria, search_params)
|
||||
return papers[0] if papers else None
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文详情时出错: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_citation_context(self, paper: Dict, plugin_kwargs: Dict) -> Tuple[List, List]:
|
||||
"""获取引用上下文"""
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 使用论文标题构建搜索参数
|
||||
title_query = f'ti:"{getattr(paper, "title", "")}"'
|
||||
arxiv_params = {
|
||||
"query": title_query,
|
||||
"limit": search_params['max_papers'],
|
||||
"search_type": "basic",
|
||||
"sort_by": "relevance",
|
||||
"sort_order": "descending"
|
||||
}
|
||||
semantic_params = {
|
||||
"query": getattr(paper, "title", ""),
|
||||
"limit": search_params['max_papers']
|
||||
}
|
||||
|
||||
citations, references = await asyncio.gather(
|
||||
self._search_semantic(
|
||||
semantic_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=search_params['min_year']
|
||||
),
|
||||
self._search_arxiv(
|
||||
arxiv_params,
|
||||
limit_multiplier=search_params['search_multiplier'],
|
||||
min_year=search_params['min_year']
|
||||
)
|
||||
)
|
||||
|
||||
return citations, references
|
||||
|
||||
async def _generate_analysis(
|
||||
self,
|
||||
paper: Dict,
|
||||
citations: List,
|
||||
references: List,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any]
|
||||
) -> List[List[str]]:
|
||||
"""生成论文分析"""
|
||||
|
||||
# 构建提示
|
||||
analysis_prompt = f"""Please provide a comprehensive analysis of the following paper:
|
||||
|
||||
Paper details:
|
||||
{self._format_paper(paper)}
|
||||
|
||||
Key references (papers cited by this paper):
|
||||
{self._format_papers(references)}
|
||||
|
||||
Important citations (papers that cite this paper):
|
||||
{self._format_papers(citations)}
|
||||
|
||||
Please provide:
|
||||
1. Paper Overview
|
||||
- Main research question/objective
|
||||
- Key methodology/approach
|
||||
- Main findings/contributions
|
||||
|
||||
2. Technical Analysis
|
||||
- Detailed methodology review
|
||||
- Technical innovations
|
||||
- Implementation details
|
||||
- Experimental setup and results
|
||||
|
||||
3. Impact Analysis
|
||||
- Significance in the field
|
||||
- Influence on subsequent research (based on citing papers)
|
||||
- Relationship to prior work (based on cited papers)
|
||||
- Practical applications
|
||||
|
||||
4. Critical Review
|
||||
- Strengths and limitations
|
||||
- Potential improvements
|
||||
- Open questions and future directions
|
||||
- Alternative approaches
|
||||
|
||||
5. Related Research Context
|
||||
- How it builds on previous work
|
||||
- How it has influenced subsequent research
|
||||
- Comparison with alternative approaches
|
||||
|
||||
Format your response in markdown with clear sections."""
|
||||
|
||||
# 并行生成概述和技术分析
|
||||
for response_chunk in request_gpt(
|
||||
inputs_array=[
|
||||
analysis_prompt,
|
||||
self._get_technical_prompt(paper)
|
||||
],
|
||||
inputs_show_user_array=[
|
||||
"Generating paper analysis...",
|
||||
"Analyzing technical details..."
|
||||
],
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[history, []],
|
||||
sys_prompt_array=[
|
||||
system_prompt,
|
||||
"You are an expert at analyzing technical details in research papers."
|
||||
]
|
||||
):
|
||||
pass # 等待生成完成
|
||||
|
||||
# 获取最后的两个回答
|
||||
if chatbot and len(chatbot[-2:]) == 2:
|
||||
analysis = chatbot[-2][1]
|
||||
technical = chatbot[-1][1]
|
||||
full_analysis = f"""# Paper Analysis: {paper.title}
|
||||
|
||||
## General Analysis
|
||||
{analysis}
|
||||
|
||||
## Technical Deep Dive
|
||||
{technical}
|
||||
"""
|
||||
chatbot.append(["Here is the paper analysis:", full_analysis])
|
||||
else:
|
||||
chatbot.append(["Here is the paper analysis:", "Failed to generate analysis."])
|
||||
|
||||
return chatbot
|
||||
|
||||
def _get_technical_prompt(self, paper: Dict) -> str:
|
||||
"""生成技术分析提示"""
|
||||
return f"""Please provide a detailed technical analysis of the following paper:
|
||||
|
||||
{self._format_paper(paper)}
|
||||
|
||||
Focus on:
|
||||
1. Mathematical formulations and their implications
|
||||
2. Algorithm design and complexity analysis
|
||||
3. Architecture details and design choices
|
||||
4. Implementation challenges and solutions
|
||||
5. Performance analysis and bottlenecks
|
||||
6. Technical limitations and potential improvements
|
||||
|
||||
Format your response in markdown, focusing purely on technical aspects."""
|
||||
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from textwrap import dedent
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 学术问答功能(BaseHandler):
|
||||
"""学术问答处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理学术问答请求,返回最终的prompt"""
|
||||
|
||||
# 1. 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 2. 搜索相关论文
|
||||
papers = await self._search_relevant_papers(criteria, search_params)
|
||||
if not papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = dedent(f"""Current time: {current_time}
|
||||
|
||||
Based on the following paper abstracts, please answer this academic question: {criteria.original_query}
|
||||
|
||||
Available papers for reference:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please structure your response in the following format:
|
||||
|
||||
1. Core Answer (2-3 paragraphs)
|
||||
- Provide a clear, direct answer synthesizing key findings
|
||||
- Support main points with citations [1,2,etc.]
|
||||
- Focus on consensus and differences across papers
|
||||
|
||||
2. Key Evidence (2-3 paragraphs)
|
||||
- Present supporting evidence from abstracts
|
||||
- Compare methodologies and results
|
||||
- Highlight significant findings with citations
|
||||
|
||||
3. Research Context (1-2 paragraphs)
|
||||
- Discuss current trends and developments
|
||||
- Identify research gaps or limitations
|
||||
- Suggest potential future directions
|
||||
|
||||
Guidelines:
|
||||
- Base your answer ONLY on the provided abstracts
|
||||
- Use numbered citations [1], [2,3], etc. for every claim
|
||||
- Maintain academic tone and objectivity
|
||||
- Synthesize findings across multiple papers
|
||||
- Focus on the most relevant information to the question
|
||||
|
||||
Constraints:
|
||||
- Do not include information beyond the provided abstracts
|
||||
- Avoid speculation or personal opinions
|
||||
- Do not elaborate on technical details unless directly relevant
|
||||
- Keep citations concise and focused
|
||||
- Use [N] citations for every major claim or finding
|
||||
- Cite multiple papers [1,2,3] when showing consensus
|
||||
- Place citations immediately after the relevant statements
|
||||
|
||||
Note: Provide citations for every major claim to ensure traceability to source papers.
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language. Use Chinese to answer if no language is specified.
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
)
|
||||
|
||||
return final_prompt
|
||||
|
||||
async def _search_relevant_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""搜索相关论文"""
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return []
|
||||
|
||||
# 使用BGE重排序
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.main_topic,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
return self.ranked_papers or []
|
||||
|
||||
async def _generate_answer(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
papers: List,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any]
|
||||
) -> List[List[str]]:
|
||||
"""生成答案"""
|
||||
|
||||
# 构建提示
|
||||
qa_prompt = dedent(f"""Please answer the following academic question based on recent research papers.
|
||||
|
||||
Question: {criteria.main_topic}
|
||||
|
||||
Relevant papers:
|
||||
{self._format_papers(papers)}
|
||||
|
||||
Please provide:
|
||||
1. A direct answer to the question
|
||||
2. Supporting evidence from the papers
|
||||
3. Different perspectives or approaches if applicable
|
||||
4. Current limitations and open questions
|
||||
5. References to specific papers
|
||||
|
||||
Format your response in markdown with clear sections."""
|
||||
)
|
||||
# 调用LLM生成答案
|
||||
for response_chunk in request_gpt(
|
||||
inputs_array=[qa_prompt],
|
||||
inputs_show_user_array=["Generating answer..."],
|
||||
llm_kwargs=llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[history],
|
||||
sys_prompt_array=[system_prompt]
|
||||
):
|
||||
pass # 等待生成完成
|
||||
|
||||
# 获取最后的回答
|
||||
if chatbot and len(chatbot[-1]) >= 2:
|
||||
answer = chatbot[-1][1]
|
||||
chatbot.append(["Here is the answer:", answer])
|
||||
else:
|
||||
chatbot.append(["Here is the answer:", "Failed to generate answer."])
|
||||
|
||||
return chatbot
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
from typing import List, Dict, Any
|
||||
from .base_handler import BaseHandler
|
||||
from textwrap import dedent
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
|
||||
class 论文推荐功能(BaseHandler):
|
||||
"""论文推荐处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理论文推荐请求,返回最终的prompt"""
|
||||
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 1. 先搜索种子论文
|
||||
seed_papers = await self._search_seed_papers(criteria, search_params)
|
||||
if not seed_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 使用BGE重排序
|
||||
all_papers = seed_papers
|
||||
|
||||
if not all_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
if not self.ranked_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 构建最终的prompt
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = dedent(f"""Current time: {current_time}
|
||||
|
||||
Based on the user's interest in {criteria.main_topic}, here are relevant papers.
|
||||
|
||||
Available papers for recommendation:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
Please provide:
|
||||
1. Group papers by sub-topics or themes if applicable
|
||||
|
||||
2. For each paper:
|
||||
- Publication time and venue (when available)
|
||||
- Journal metrics (when available):
|
||||
* Impact Factor (IF)
|
||||
* JCR Quartile
|
||||
* Chinese Academy of Sciences (CAS) Division
|
||||
- The key contributions and main findings
|
||||
- Why it's relevant to the user's interests
|
||||
- How it relates to other recommended papers
|
||||
- The paper's citation count and citation impact
|
||||
- The paper's download link
|
||||
|
||||
3. A suggested reading order based on:
|
||||
- Journal impact and quality metrics
|
||||
- Chronological development of ideas
|
||||
- Paper relationships and dependencies
|
||||
- Difficulty level
|
||||
- Impact and significance
|
||||
|
||||
4. Future Directions
|
||||
- Emerging venues and research streams
|
||||
- Novel methodological approaches
|
||||
- Cross-disciplinary opportunities
|
||||
- Research gaps by publication type
|
||||
|
||||
|
||||
IMPORTANT:
|
||||
- Focus on explaining why each paper is valuable
|
||||
- Highlight connections between papers
|
||||
- Consider both citation counts AND journal metrics when discussing impact
|
||||
- When available, use IF, JCR quartile, and CAS division to assess paper quality
|
||||
- Mention publication timing when discussing paper relationships
|
||||
- When referring to papers, use HTML links in this format:
|
||||
* For DOIs: <a href='https://doi.org/DOI_HERE' target='_blank'>DOI: DOI_HERE</a>
|
||||
* For titles: <a href='PAPER_URL' target='_blank'>PAPER_TITLE</a>
|
||||
- Present papers in a way that shows the evolution of ideas over time
|
||||
- Base recommendations ONLY on the explicitly provided paper information
|
||||
- Do not make ANY assumptions about papers beyond the given data
|
||||
- When information is missing or unclear, acknowledge the limitation
|
||||
- Never speculate about:
|
||||
* Paper quality or rigor not evidenced in the data
|
||||
* Research impact beyond citation counts and journal metrics
|
||||
* Implementation details not mentioned
|
||||
* Author expertise or background
|
||||
* Future research directions not stated
|
||||
- For each recommendation, cite only verifiable information
|
||||
- Clearly distinguish between facts and potential implications
|
||||
|
||||
Format your response in markdown with clear sections.
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
)
|
||||
return final_prompt
|
||||
|
||||
async def _search_seed_papers(self, criteria: SearchCriteria, search_params: Dict) -> List:
|
||||
"""搜索种子论文"""
|
||||
try:
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return []
|
||||
|
||||
return all_papers
|
||||
|
||||
except Exception as e:
|
||||
print(f"搜索种子论文时出错: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _get_recommendations(self, seed_papers: List, multiplier: int = 1) -> List:
|
||||
"""获取推荐论文"""
|
||||
recommendations = []
|
||||
base_limit = 3 * multiplier
|
||||
|
||||
# 将种子论文添加到推荐列表中
|
||||
recommendations.extend(seed_papers)
|
||||
|
||||
# 只使用前5篇论文作为种子
|
||||
seed_papers = seed_papers[:5]
|
||||
|
||||
for paper in seed_papers:
|
||||
try:
|
||||
if paper.doi and paper.doi.startswith("10.48550/arXiv."):
|
||||
# arXiv论文
|
||||
arxiv_id = paper.doi.split(".")[-1]
|
||||
paper_details = await self.arxiv.get_paper_details(arxiv_id)
|
||||
if paper_details and hasattr(paper_details, 'venue'):
|
||||
category = paper_details.venue.split(":")[-1]
|
||||
similar_papers = await self.arxiv.search_by_category(
|
||||
category,
|
||||
limit=base_limit,
|
||||
sort_by='relevance'
|
||||
)
|
||||
recommendations.extend(similar_papers)
|
||||
elif paper.doi: # 只对有DOI的论文获取推荐
|
||||
# Semantic Scholar论文
|
||||
similar_papers = await self.semantic.get_recommended_papers(
|
||||
paper.doi,
|
||||
limit=base_limit
|
||||
)
|
||||
if similar_papers: # 只添加成功获取的推荐
|
||||
recommendations.extend(similar_papers)
|
||||
else:
|
||||
# 对于没有DOI的论文,使用标题进行相关搜索
|
||||
if paper.title:
|
||||
similar_papers = await self.semantic.search(
|
||||
query=paper.title,
|
||||
limit=base_limit
|
||||
)
|
||||
recommendations.extend(similar_papers)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取论文 '{paper.title}' 的推荐时发生错误: {str(e)}")
|
||||
continue
|
||||
|
||||
# 去重处理
|
||||
seen_dois = set()
|
||||
unique_recommendations = []
|
||||
for paper in recommendations:
|
||||
if paper.doi and paper.doi not in seen_dois:
|
||||
seen_dois.add(paper.doi)
|
||||
unique_recommendations.append(paper)
|
||||
elif not paper.doi and paper not in unique_recommendations:
|
||||
unique_recommendations.append(paper)
|
||||
|
||||
return unique_recommendations
|
||||
@@ -0,0 +1,193 @@
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from .base_handler import BaseHandler
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import asyncio
|
||||
|
||||
class 文献综述功能(BaseHandler):
|
||||
"""文献综述处理器"""
|
||||
|
||||
def __init__(self, arxiv, semantic, llm_kwargs=None):
|
||||
super().__init__(arxiv, semantic, llm_kwargs)
|
||||
|
||||
async def handle(
|
||||
self,
|
||||
criteria: SearchCriteria,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
"""处理文献综述请求,返回最终的prompt"""
|
||||
|
||||
# 获取搜索参数
|
||||
search_params = self._get_search_params(plugin_kwargs)
|
||||
|
||||
# 使用_search_all_sources替代原来的并行搜索
|
||||
all_papers = await self._search_all_sources(criteria, search_params)
|
||||
|
||||
if not all_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
self.ranked_papers = self.paper_ranker.rank_papers(
|
||||
query=criteria.original_query,
|
||||
papers=all_papers,
|
||||
search_criteria=criteria
|
||||
)
|
||||
|
||||
# 检查排序后的论文数量
|
||||
if not self.ranked_papers:
|
||||
return self._generate_apology_prompt(criteria)
|
||||
|
||||
# 检查是否包含PubMed论文
|
||||
has_pubmed_papers = any(paper.url and 'pubmed.ncbi.nlm.nih.gov' in paper.url
|
||||
for paper in self.ranked_papers)
|
||||
|
||||
if has_pubmed_papers:
|
||||
return self._generate_medical_review_prompt(criteria)
|
||||
else:
|
||||
return self._generate_general_review_prompt(criteria)
|
||||
|
||||
def _generate_medical_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成医学文献综述prompt"""
|
||||
return f"""Current time: {self._get_current_time()}
|
||||
|
||||
Conduct a systematic medical literature review on {criteria.main_topic} based STRICTLY on the provided articles.
|
||||
|
||||
Available literature for review:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||
|
||||
Please structure your medical review following these guidelines:
|
||||
|
||||
1. Research Overview
|
||||
- Main research questions and objectives from the studies
|
||||
- Types of studies included (clinical trials, observational studies, etc.)
|
||||
- Study populations and settings
|
||||
- Time period of the research
|
||||
|
||||
2. Key Findings
|
||||
- Main outcomes and results reported in abstracts
|
||||
- Primary endpoints and their measurements
|
||||
- Statistical significance when reported
|
||||
- Observed trends across studies
|
||||
|
||||
3. Methods Summary
|
||||
- Study designs used
|
||||
- Major interventions or treatments studied
|
||||
- Key outcome measures
|
||||
- Patient populations studied
|
||||
|
||||
4. Clinical Relevance
|
||||
- Reported clinical implications
|
||||
- Main conclusions from authors
|
||||
- Reported benefits and risks
|
||||
- Treatment responses when available
|
||||
|
||||
5. Research Status
|
||||
- Current research focus areas
|
||||
- Reported limitations
|
||||
- Gaps identified in abstracts
|
||||
- Authors' suggested future directions
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
|
||||
Citation Rules (MANDATORY):
|
||||
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||
- Use ONLY the papers provided in the available literature list above
|
||||
- Citations must appear immediately after each statement
|
||||
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||
|
||||
Content Guidelines:
|
||||
- Present only information available in the provided papers
|
||||
- If certain information is not available, simply omit that aspect rather than explicitly stating its absence
|
||||
- Focus on synthesizing and presenting available findings
|
||||
- Maintain professional medical writing style
|
||||
- Present limitations and gaps as research opportunities rather than missing information
|
||||
|
||||
Writing Style:
|
||||
- Use precise medical terminology
|
||||
- Maintain objective reporting
|
||||
- Use consistent terminology throughout
|
||||
- Present a cohesive narrative without referencing data limitations
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
def _generate_general_review_prompt(self, criteria: SearchCriteria) -> str:
|
||||
"""生成通用文献综述prompt"""
|
||||
current_time = self._get_current_time()
|
||||
final_prompt = f"""Current time: {current_time}
|
||||
|
||||
Conduct a comprehensive literature review on {criteria.main_topic} focusing on the following aspects:
|
||||
{', '.join(criteria.sub_topics)}
|
||||
|
||||
Available literature for review:
|
||||
{self._format_papers(self.ranked_papers)}
|
||||
|
||||
IMPORTANT: If the user query contains specific requirements for the review structure or format, those requirements take precedence over the following guidelines.
|
||||
|
||||
Please structure your review following these guidelines:
|
||||
|
||||
1. Introduction and Research Background
|
||||
- Current state and significance of the research field
|
||||
- Key research problems and challenges
|
||||
- Research development timeline and evolution
|
||||
|
||||
2. Research Directions and Classifications
|
||||
- Major research directions and their relationships
|
||||
- Different technical approaches and their characteristics
|
||||
- Comparative analysis of various solutions
|
||||
|
||||
3. Core Technologies and Methods
|
||||
- Key technological breakthroughs
|
||||
- Advantages and limitations of different methods
|
||||
- Technical challenges and solutions
|
||||
|
||||
4. Applications and Impact
|
||||
- Real-world applications and use cases
|
||||
- Industry influence and practical value
|
||||
- Implementation challenges and solutions
|
||||
|
||||
5. Future Trends and Prospects
|
||||
- Emerging research directions
|
||||
- Unsolved problems and challenges
|
||||
- Potential breakthrough points
|
||||
|
||||
CRITICAL REQUIREMENTS:
|
||||
|
||||
Citation Rules (MANDATORY):
|
||||
- EVERY finding or statement MUST be supported by citations [N], where N is the number matching the paper in the provided literature list
|
||||
- When reporting outcomes, ALWAYS cite the source studies using the exact paper numbers from the literature list
|
||||
- For findings supported by multiple studies, use consecutive numbers as shown in the literature list [1,2,3]
|
||||
- Use ONLY the papers provided in the available literature list above
|
||||
- Citations must appear immediately after each statement
|
||||
- Citation numbers MUST match the numbers assigned to papers in the literature list above (e.g., if a finding comes from the first paper in the list, cite it as [1])
|
||||
- DO NOT change or reorder the citation numbers - they must exactly match the paper numbers in the literature list
|
||||
|
||||
Writing Style:
|
||||
- Maintain academic and professional tone
|
||||
- Focus on objective analysis with proper citations
|
||||
- Ensure logical flow and clear structure
|
||||
|
||||
Content Requirements:
|
||||
- Base ALL analysis STRICTLY on the provided papers with explicit citations
|
||||
- When introducing any concept, method, or finding, immediately follow with [N]
|
||||
- For each research direction or approach, cite the specific papers [N] that proposed or developed it
|
||||
- When discussing limitations or challenges, cite the papers [N] that identified them
|
||||
- DO NOT include information from sources outside the provided paper list
|
||||
- DO NOT make unsupported claims or statements
|
||||
|
||||
Language requirement:
|
||||
- If the query explicitly specifies a language, use that language
|
||||
- Otherwise, match the language of the original user query
|
||||
"""
|
||||
|
||||
return final_prompt
|
||||
|
||||
@@ -0,0 +1,452 @@
|
||||
from typing import List, Dict
|
||||
from crazy_functions.review_fns.data_sources.base_source import PaperMetadata
|
||||
from request_llms.embed_models.bge_llm import BGELLMRanker
|
||||
from crazy_functions.review_fns.query_analyzer import SearchCriteria
|
||||
import random
|
||||
from crazy_functions.review_fns.data_sources.journal_metrics import JournalMetrics
|
||||
|
||||
class PaperLLMRanker:
|
||||
"""使用LLM进行论文重排序"""
|
||||
|
||||
def __init__(self, llm_kwargs: Dict = None):
|
||||
self.ranker = BGELLMRanker(llm_kwargs=llm_kwargs)
|
||||
self.journal_metrics = JournalMetrics()
|
||||
|
||||
def _update_paper_metrics(self, papers: List[PaperMetadata]) -> None:
|
||||
"""更新论文的期刊指标"""
|
||||
for paper in papers:
|
||||
# 跳过arXiv来源的论文
|
||||
if getattr(paper, 'source', '') == 'arxiv':
|
||||
continue
|
||||
|
||||
if hasattr(paper, 'venue_name') or hasattr(paper, 'venue_info'):
|
||||
# 获取venue_name和venue_info
|
||||
venue_name = getattr(paper, 'venue_name', '')
|
||||
venue_info = getattr(paper, 'venue_info', {})
|
||||
|
||||
# 使用改进的匹配逻辑获取指标
|
||||
metrics = self.journal_metrics.get_journal_metrics(venue_name, venue_info)
|
||||
|
||||
# 更新论文的指标
|
||||
paper.if_factor = metrics.get('if_factor')
|
||||
paper.jcr_division = metrics.get('jcr_division')
|
||||
paper.cas_division = metrics.get('cas_division')
|
||||
|
||||
def _get_year_as_int(self, paper) -> int:
|
||||
"""统一获取论文年份为整数格式
|
||||
|
||||
Args:
|
||||
paper: 论文对象或直接是年份值
|
||||
|
||||
Returns:
|
||||
整数格式的年份,如果无法转换则返回0
|
||||
"""
|
||||
try:
|
||||
# 如果输入直接是年份而不是论文对象
|
||||
if isinstance(paper, int):
|
||||
return paper
|
||||
elif isinstance(paper, str):
|
||||
try:
|
||||
return int(paper)
|
||||
except ValueError:
|
||||
import re
|
||||
year_match = re.search(r'\d{4}', paper)
|
||||
if year_match:
|
||||
return int(year_match.group())
|
||||
return 0
|
||||
elif isinstance(paper, float):
|
||||
return int(paper)
|
||||
|
||||
# 处理论文对象
|
||||
year = getattr(paper, 'year', None)
|
||||
if year is None:
|
||||
return 0
|
||||
|
||||
# 如果是字符串,尝试转换为整数
|
||||
if isinstance(year, str):
|
||||
# 首先尝试直接转换整个字符串
|
||||
try:
|
||||
return int(year)
|
||||
except ValueError:
|
||||
# 如果直接转换失败,尝试提取第一个数字序列
|
||||
import re
|
||||
year_match = re.search(r'\d{4}', year)
|
||||
if year_match:
|
||||
return int(year_match.group())
|
||||
return 0
|
||||
# 如果是浮点数,转换为整数
|
||||
elif isinstance(year, float):
|
||||
return int(year)
|
||||
# 如果已经是整数,直接返回
|
||||
elif isinstance(year, int):
|
||||
return year
|
||||
return 0
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
def rank_papers(
|
||||
self,
|
||||
query: str,
|
||||
papers: List[PaperMetadata],
|
||||
search_criteria: SearchCriteria = None,
|
||||
top_k: int = 40,
|
||||
use_rerank: bool = False,
|
||||
pre_filter_ratio: float = 0.5,
|
||||
max_papers: int = 150
|
||||
) -> List[PaperMetadata]:
|
||||
"""对论文进行重排序"""
|
||||
initial_count = len(papers) if papers else 0
|
||||
stats = {'initial': initial_count}
|
||||
|
||||
if not papers or not query:
|
||||
return []
|
||||
|
||||
# 更新论文的期刊指标
|
||||
self._update_paper_metrics(papers)
|
||||
|
||||
# 构建增强查询
|
||||
# enhanced_query = self._build_enhanced_query(query, search_criteria) if search_criteria else query
|
||||
enhanced_query = query
|
||||
# 首先过滤不满足年份要求的论文
|
||||
if search_criteria and search_criteria.start_year and search_criteria.end_year:
|
||||
before_year_filter = len(papers)
|
||||
filtered_papers = []
|
||||
start_year = int(search_criteria.start_year)
|
||||
end_year = int(search_criteria.end_year)
|
||||
|
||||
for paper in papers:
|
||||
paper_year = self._get_year_as_int(paper)
|
||||
if paper_year == 0 or start_year <= paper_year <= end_year:
|
||||
filtered_papers.append(paper)
|
||||
|
||||
papers = filtered_papers
|
||||
stats['after_year_filter'] = len(papers)
|
||||
|
||||
if not papers: # 如果过滤后没有论文,直接返回空列表
|
||||
return []
|
||||
|
||||
# 新增:对少量论文的快速处理
|
||||
SMALL_PAPER_THRESHOLD = 10 # 定义"少量"论文的阈值
|
||||
if len(papers) <= SMALL_PAPER_THRESHOLD:
|
||||
# 对于少量论文,直接根据查询类型进行简单排序
|
||||
if search_criteria:
|
||||
if search_criteria.query_type == "latest":
|
||||
papers.sort(key=lambda x: getattr(x, 'year', 0) or 0, reverse=True)
|
||||
elif search_criteria.query_type == "recommend":
|
||||
papers.sort(key=lambda x: getattr(x, 'citations', 0) or 0, reverse=True)
|
||||
elif search_criteria.query_type == "review":
|
||||
papers.sort(key=lambda x:
|
||||
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||
for keyword in ['review', 'survey', 'overview'])
|
||||
else 0,
|
||||
reverse=True
|
||||
)
|
||||
return papers[:top_k]
|
||||
|
||||
# 1. 优先处理最新的论文
|
||||
if search_criteria and search_criteria.query_type == "latest":
|
||||
papers = sorted(papers, key=lambda x: self._get_year_as_int(x), reverse=True)
|
||||
|
||||
# 2. 如果是综述类查询,优先处理可能的综述论文
|
||||
if search_criteria and search_criteria.query_type == "review":
|
||||
papers = sorted(papers, key=lambda x:
|
||||
1 if any(keyword in (getattr(x, 'title', '') or '').lower() or
|
||||
keyword in (getattr(x, 'abstract', '') or '').lower()
|
||||
for keyword in ['review', 'survey', 'overview'])
|
||||
else 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 3. 如果论文数量超过限制,采用分层采样而不是完全随机
|
||||
if len(papers) > max_papers:
|
||||
before_max_limit = len(papers)
|
||||
papers = self._select_papers_strategically(papers, search_criteria, max_papers)
|
||||
stats['after_max_limit'] = len(papers)
|
||||
|
||||
try:
|
||||
paper_texts = []
|
||||
valid_papers = [] # 4. 跟踪有效论文
|
||||
|
||||
for paper in papers:
|
||||
if paper is None:
|
||||
continue
|
||||
# 5. 预先过滤明显不相关的论文
|
||||
if search_criteria and search_criteria.start_year:
|
||||
if getattr(paper, 'year', 0) and self._get_year_as_int(paper.year) < search_criteria.start_year:
|
||||
continue
|
||||
|
||||
doc = self._build_enhanced_document(paper, search_criteria)
|
||||
paper_texts.append(doc)
|
||||
valid_papers.append(paper) # 记录对应的论文
|
||||
|
||||
stats['after_valid_check'] = len(valid_papers)
|
||||
|
||||
if not paper_texts:
|
||||
return []
|
||||
|
||||
# 使用LLM判断相关性
|
||||
relevance_results = self.ranker.batch_check_relevance(
|
||||
query=enhanced_query, # 使用增强的查询
|
||||
paper_texts=paper_texts,
|
||||
show_progress=True
|
||||
)
|
||||
|
||||
# 6. 优化相关论文的选择策略
|
||||
relevant_papers = []
|
||||
for paper, is_relevant in zip(valid_papers, relevance_results):
|
||||
if is_relevant:
|
||||
relevant_papers.append(paper)
|
||||
|
||||
stats['after_llm_filter'] = len(relevant_papers)
|
||||
|
||||
# 打印统计信息
|
||||
print(f"论文筛选统计: 初始数量={stats['initial']}, " +
|
||||
f"年份过滤后={stats.get('after_year_filter', stats['initial'])}, " +
|
||||
f"数量限制后={stats.get('after_max_limit', stats.get('after_year_filter', stats['initial']))}, " +
|
||||
f"有效性检查后={stats['after_valid_check']}, " +
|
||||
f"LLM筛选后={stats['after_llm_filter']}")
|
||||
|
||||
# 7. 改进回退策略
|
||||
if len(relevant_papers) < min(5, len(papers)):
|
||||
# 如果相关论文太少,返回按引用量排序的论文
|
||||
return sorted(
|
||||
papers[:top_k],
|
||||
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 8. 对最终结果进行排序
|
||||
if search_criteria:
|
||||
if search_criteria.query_type == "latest":
|
||||
# 最新论文优先,但同年份按IF排序
|
||||
relevant_papers.sort(key=lambda x: (
|
||||
self._get_year_as_int(x),
|
||||
getattr(x, 'if_factor', 0) or 0
|
||||
), reverse=True)
|
||||
elif search_criteria.query_type == "recommend":
|
||||
# IF指数优先,其次是引用量
|
||||
relevant_papers.sort(key=lambda x: (
|
||||
getattr(x, 'if_factor', 0) or 0,
|
||||
getattr(x, 'citations', 0) or 0
|
||||
), reverse=True)
|
||||
else:
|
||||
# 默认按IF指数排序
|
||||
relevant_papers.sort(key=lambda x: getattr(x, 'if_factor', 0) or 0, reverse=True)
|
||||
|
||||
return relevant_papers[:top_k]
|
||||
|
||||
except Exception as e:
|
||||
print(f"论文排序时出错: {str(e)}")
|
||||
# 9. 改进错误处理的回退策略
|
||||
try:
|
||||
return sorted(
|
||||
papers[:top_k],
|
||||
key=lambda x: getattr(x, 'citations', 0) or 0,
|
||||
reverse=True
|
||||
)
|
||||
except:
|
||||
return papers[:top_k] if papers else []
|
||||
|
||||
def _build_enhanced_query(self, query: str, criteria: SearchCriteria) -> str:
|
||||
"""构建增强的查询文本"""
|
||||
components = []
|
||||
|
||||
# 强调这是用户的原始查询,是最重要的匹配依据
|
||||
components.append(f"Original user query that must be primarily matched: {query}")
|
||||
|
||||
if criteria:
|
||||
# 添加主题(如果与原始查询不同)
|
||||
if criteria.main_topic and criteria.main_topic != query:
|
||||
components.append(f"Additional context - The main topic is about: {criteria.main_topic}")
|
||||
|
||||
# 添加子主题
|
||||
if criteria.sub_topics:
|
||||
components.append(f"Secondary aspects to consider: {', '.join(criteria.sub_topics)}")
|
||||
|
||||
# 添加查询类型相关信息
|
||||
if criteria.query_type == "review":
|
||||
components.append("Paper type preference: Looking for comprehensive review papers, survey papers, or overview papers")
|
||||
elif criteria.query_type == "latest":
|
||||
components.append("Temporal preference: Focus on the most recent developments and latest papers")
|
||||
elif criteria.query_type == "recommend":
|
||||
components.append("Impact preference: Consider influential and fundamental papers")
|
||||
|
||||
# 直接连接所有组件,保持语序
|
||||
enhanced_query = ' '.join(components)
|
||||
|
||||
# 限制长度但不打乱顺序
|
||||
if len(enhanced_query) > 1000:
|
||||
enhanced_query = enhanced_query[:997] + "..."
|
||||
|
||||
return enhanced_query
|
||||
|
||||
def _build_enhanced_document(self, paper: PaperMetadata, criteria: SearchCriteria) -> str:
|
||||
"""构建增强的文档表示"""
|
||||
components = []
|
||||
|
||||
# 基本信息
|
||||
title = getattr(paper, 'title', '')
|
||||
authors = ', '.join(getattr(paper, 'authors', []))
|
||||
abstract = getattr(paper, 'abstract', '')
|
||||
year = getattr(paper, 'year', '')
|
||||
venue = getattr(paper, 'venue', '')
|
||||
|
||||
components.extend([
|
||||
f"Title: {title}",
|
||||
f"Authors: {authors}",
|
||||
f"Year: {year}",
|
||||
f"Venue: {venue}",
|
||||
f"Abstract: {abstract}"
|
||||
])
|
||||
|
||||
# 根据查询类型添加额外信息
|
||||
if criteria:
|
||||
if criteria.query_type == "review":
|
||||
# 对于综述类查询,强调论文的综述性质
|
||||
title_lower = (title or '').lower()
|
||||
abstract_lower = (abstract or '').lower()
|
||||
if any(keyword in title_lower or keyword in abstract_lower
|
||||
for keyword in ['review', 'survey', 'overview']):
|
||||
components.append("This is a review/survey paper")
|
||||
|
||||
elif criteria.query_type == "latest":
|
||||
# 对于最新论文查询,强调时间信息
|
||||
if year and int(year) >= criteria.start_year:
|
||||
components.append(f"This is a recent paper from {year}")
|
||||
|
||||
elif criteria.query_type == "recommend":
|
||||
# 对于推荐类查询,添加主题相关性信息
|
||||
if criteria.main_topic:
|
||||
title_lower = (title or '').lower()
|
||||
abstract_lower = (abstract or '').lower()
|
||||
topic_relevance = any(topic.lower() in title_lower or topic.lower() in abstract_lower
|
||||
for topic in [criteria.main_topic] + (criteria.sub_topics or []))
|
||||
if topic_relevance:
|
||||
components.append(f"This paper is directly related to {criteria.main_topic}")
|
||||
|
||||
return '\n'.join(components)
|
||||
|
||||
def _select_papers_strategically(
|
||||
self,
|
||||
papers: List[PaperMetadata],
|
||||
search_criteria: SearchCriteria,
|
||||
max_papers: int = 150
|
||||
) -> List[PaperMetadata]:
|
||||
"""战略性地选择论文子集,优先选择非Crossref来源的论文,
|
||||
当ADS论文充足时排除arXiv论文"""
|
||||
if len(papers) <= max_papers:
|
||||
return papers
|
||||
|
||||
# 1. 首先按来源分组
|
||||
papers_by_source = {
|
||||
'crossref': [],
|
||||
'adsabs': [],
|
||||
'arxiv': [],
|
||||
'others': [] # semantic, pubmed等其他来源
|
||||
}
|
||||
|
||||
for paper in papers:
|
||||
source = getattr(paper, 'source', '')
|
||||
if source == 'crossref':
|
||||
papers_by_source['crossref'].append(paper)
|
||||
elif source == 'adsabs':
|
||||
papers_by_source['adsabs'].append(paper)
|
||||
elif source == 'arxiv':
|
||||
papers_by_source['arxiv'].append(paper)
|
||||
else:
|
||||
papers_by_source['others'].append(paper)
|
||||
|
||||
# 2. 计算分数的通用函数
|
||||
def calculate_paper_score(paper):
|
||||
score = 0
|
||||
title = (getattr(paper, 'title', '') or '').lower()
|
||||
abstract = (getattr(paper, 'abstract', '') or '').lower()
|
||||
year = self._get_year_as_int(paper)
|
||||
citations = getattr(paper, 'citations', 0) or 0
|
||||
|
||||
# 安全地获取搜索条件
|
||||
main_topic = (getattr(search_criteria, 'main_topic', '') or '').lower()
|
||||
sub_topics = getattr(search_criteria, 'sub_topics', []) or []
|
||||
query_type = getattr(search_criteria, 'query_type', '')
|
||||
start_year = getattr(search_criteria, 'start_year', 0) or 0
|
||||
|
||||
# 主题相关性得分
|
||||
if main_topic and main_topic in title:
|
||||
score += 10
|
||||
if main_topic and main_topic in abstract:
|
||||
score += 5
|
||||
|
||||
# 子主题相关性得分
|
||||
for sub_topic in sub_topics:
|
||||
if sub_topic and sub_topic.lower() in title:
|
||||
score += 5
|
||||
if sub_topic and sub_topic.lower() in abstract:
|
||||
score += 2.5
|
||||
|
||||
# 根据查询类型调整分数
|
||||
if query_type == "review":
|
||||
review_keywords = ['review', 'survey', 'overview']
|
||||
if any(keyword in title for keyword in review_keywords):
|
||||
score *= 1.5
|
||||
if any(keyword in abstract for keyword in review_keywords):
|
||||
score *= 1.2
|
||||
elif query_type == "latest":
|
||||
if year and start_year:
|
||||
year_int = year if isinstance(year, int) else self._get_year_as_int(paper)
|
||||
start_year_int = start_year if isinstance(start_year, int) else int(start_year)
|
||||
if year_int >= start_year_int:
|
||||
recency_bonus = min(5, (year_int - start_year_int))
|
||||
score += recency_bonus * 2
|
||||
elif query_type == "recommend":
|
||||
citation_score = min(10, citations / 100)
|
||||
score += citation_score
|
||||
|
||||
return score
|
||||
|
||||
result = []
|
||||
|
||||
# 3. 处理ADS和arXiv论文
|
||||
non_crossref_papers = papers_by_source['others'] # 首先添加其他来源的论文
|
||||
|
||||
# 添加ADS论文
|
||||
if papers_by_source['adsabs']:
|
||||
non_crossref_papers.extend(papers_by_source['adsabs'])
|
||||
|
||||
# 只有当ADS论文不足20篇时,才添加arXiv论文
|
||||
if len(papers_by_source['adsabs']) <= 20:
|
||||
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||
elif not papers_by_source['adsabs'] and papers_by_source['arxiv']:
|
||||
# 如果没有ADS论文但有arXiv论文,也使用arXiv论文
|
||||
non_crossref_papers.extend(papers_by_source['arxiv'])
|
||||
|
||||
# 4. 对非Crossref论文评分和排序
|
||||
scored_non_crossref = [(p, calculate_paper_score(p)) for p in non_crossref_papers]
|
||||
scored_non_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 5. 先添加高分的非Crossref论文
|
||||
non_crossref_limit = max_papers * 0.9 # 90%的配额给非Crossref论文
|
||||
if len(scored_non_crossref) >= non_crossref_limit:
|
||||
result.extend([p[0] for p in scored_non_crossref[:int(non_crossref_limit)]])
|
||||
else:
|
||||
result.extend([p[0] for p in scored_non_crossref])
|
||||
|
||||
# 6. 如果还有剩余空间,考虑添加Crossref论文
|
||||
remaining_slots = max_papers - len(result)
|
||||
if remaining_slots > 0 and papers_by_source['crossref']:
|
||||
# 计算Crossref论文的最大数量(不超过总数的10%)
|
||||
max_crossref = min(remaining_slots, max_papers * 0.1)
|
||||
|
||||
# 对Crossref论文评分和排序
|
||||
scored_crossref = [(p, calculate_paper_score(p)) for p in papers_by_source['crossref']]
|
||||
scored_crossref.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 添加最高分的Crossref论文
|
||||
result.extend([p[0] for p in scored_crossref[:int(max_crossref)]])
|
||||
|
||||
# 7. 如果使用了Crossref论文后还有空位,继续使用非Crossref论文填充
|
||||
if len(result) < max_papers and len(scored_non_crossref) > len(result):
|
||||
remaining_non_crossref = [p[0] for p in scored_non_crossref[len(result):]]
|
||||
result.extend(remaining_non_crossref[:max_papers - len(result)])
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,76 @@
|
||||
# ADS query optimization prompt
|
||||
ADSABS_QUERY_PROMPT = """Analyze and optimize the following query for NASA ADS search.
|
||||
If the query is not related to astronomy, astrophysics, or physics, return <query>none</query>.
|
||||
If the query contains non-English terms, translate them to English first.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized ADS search query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Relevant research areas for ADS:
|
||||
- Astronomy and astrophysics
|
||||
- Physics (theoretical and experimental)
|
||||
- Space science and exploration
|
||||
- Planetary science
|
||||
- Cosmology
|
||||
- Astrobiology
|
||||
- Related instrumentation and methods
|
||||
|
||||
Available search fields and filters:
|
||||
1. Basic fields:
|
||||
- title: Search in title (title:"term")
|
||||
- abstract: Search in abstract (abstract:"term")
|
||||
- author: Search for author names (author:"lastname, firstname")
|
||||
- year: Filter by year (year:2020-2023)
|
||||
- bibstem: Search by journal abbreviation (bibstem:ApJ)
|
||||
|
||||
2. Boolean operators:
|
||||
- AND
|
||||
- OR
|
||||
- NOT
|
||||
- (): Group terms
|
||||
- "": Exact phrase match
|
||||
|
||||
3. Special filters:
|
||||
- citations(identifier:paper): Papers citing a specific paper
|
||||
- references(identifier:paper): References of a specific paper
|
||||
- citation_count: Filter by citation count
|
||||
- database: Filter by database (database:astronomy)
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Black holes in galaxy centers after 2020"
|
||||
<query>title:"black hole" AND abstract:"galaxy center" AND year:2020-</query>
|
||||
|
||||
2. Query: "Papers by Neil deGrasse Tyson about exoplanets"
|
||||
<query>author:"Tyson, Neil deGrasse" AND title:exoplanet</query>
|
||||
|
||||
3. Query: "Most cited papers about dark matter in ApJ"
|
||||
<query>title:"dark matter" AND bibstem:ApJ AND citation_count:[100 TO *]</query>
|
||||
|
||||
4. Query: "Latest research on diabetes treatment"
|
||||
<query>none</query>
|
||||
|
||||
5. Query: "Machine learning for galaxy classification"
|
||||
<query>title:("machine learning" OR "deep learning") AND (title:galaxy OR abstract:galaxy) AND abstract:classification</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized ADS search query using appropriate fields and operators, or "none" if not relevant</query>"""
|
||||
|
||||
# System prompt
|
||||
ADSABS_QUERY_SYSTEM_PROMPT = """You are an expert at crafting NASA ADS search queries.
|
||||
Your task is to:
|
||||
1. First determine if the query is relevant to astronomy, astrophysics, or physics research
|
||||
2. If relevant, optimize the natural language query for the ADS API
|
||||
3. If not relevant, return "none" to indicate the query should be handled by other databases
|
||||
|
||||
Focus on creating precise queries that will return relevant astronomical and physics literature.
|
||||
Always generate English search terms regardless of the input language.
|
||||
Consider using field-specific search terms and appropriate filters to improve search accuracy.
|
||||
|
||||
Remember: ADS is specifically for astronomy, astrophysics, and physics research.
|
||||
Medical, biological, or general research queries should return "none"."""
|
||||
@@ -0,0 +1,341 @@
|
||||
# Basic type analysis prompt
|
||||
ARXIV_TYPE_PROMPT = """Analyze the research query and determine if arXiv search is needed and its type.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task 1: Determine if this query requires arXiv search
|
||||
- arXiv is suitable for:
|
||||
* Computer science and AI/ML
|
||||
* Physics and mathematics
|
||||
* Quantitative biology and finance
|
||||
* Electrical engineering
|
||||
* Recent preprints in these fields
|
||||
- arXiv is NOT needed for:
|
||||
* Medical research (unless ML/AI applications)
|
||||
* Social sciences
|
||||
* Business studies
|
||||
* Humanities
|
||||
* Industry reports
|
||||
|
||||
Task 2: If arXiv search is needed, determine the most appropriate search type
|
||||
Available types:
|
||||
1. basic: Keyword-based search across all fields
|
||||
- For specific technical queries
|
||||
- When looking for particular methods or applications
|
||||
2. category: Category-based search within specific fields
|
||||
- For broad topic exploration
|
||||
- When surveying a research area
|
||||
3. none: arXiv search not needed for this query
|
||||
- When topic is outside arXiv's scope
|
||||
- For non-technical or clinical research
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "BERT transformer architecture"
|
||||
<search_type>basic</search_type>
|
||||
|
||||
2. Query: "latest developments in machine learning"
|
||||
<search_type>category</search_type>
|
||||
|
||||
3. Query: "COVID-19 clinical trials"
|
||||
<search_type>none</search_type>
|
||||
|
||||
4. Query: "psychological effects of social media"
|
||||
<search_type>none</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose either 'basic', 'category', or 'none'</search_type>"""
|
||||
|
||||
# Query optimization prompt
|
||||
ARXIV_QUERY_PROMPT = """Optimize the following query for arXiv search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized arXiv search query using boolean operators and field tags.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Available field tags:
|
||||
- ti: Search in title
|
||||
- abs: Search in abstract
|
||||
- au: Search for author
|
||||
- all: Search in all fields (default)
|
||||
|
||||
Boolean operators:
|
||||
- AND: Both terms must appear
|
||||
- OR: Either term can appear
|
||||
- NOT: Exclude terms
|
||||
- (): Group terms
|
||||
- "": Exact phrase match
|
||||
|
||||
Examples:
|
||||
|
||||
1. Natural query: "Recent papers about transformer models by Vaswani"
|
||||
<query>ti:"transformer model" AND au:Vaswani AND year:[2017 TO 2024]</query>
|
||||
|
||||
2. Natural query: "Deep learning for computer vision, excluding surveys"
|
||||
<query>ti:(deep learning AND "computer vision") NOT (ti:survey OR ti:review)</query>
|
||||
|
||||
3. Natural query: "Attention mechanism in language models"
|
||||
<query>ti:(attention OR "attention mechanism") AND abs:"language model"</query>
|
||||
|
||||
4. Natural query: "GANs or generative adversarial networks for image generation"
|
||||
<query>(ti:GAN OR ti:"generative adversarial network") AND abs:"image generation"</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized search query using appropriate operators and tags</query>
|
||||
|
||||
Note:
|
||||
- Use quotes for exact phrases
|
||||
- Combine multiple conditions with boolean operators
|
||||
- Consider both title and abstract for important concepts
|
||||
- Include author names when relevant
|
||||
- Use parentheses for complex logical groupings"""
|
||||
|
||||
# Sort parameters prompt
|
||||
ARXIV_SORT_PROMPT = """Determine optimal sorting parameters for the research query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting parameters to help users find the most relevant papers.
|
||||
|
||||
Available sorting options:
|
||||
|
||||
1. Sort by:
|
||||
- relevance: Best match to query terms (default)
|
||||
- lastUpdatedDate: Most recently updated papers
|
||||
- submittedDate: Most recently submitted papers
|
||||
|
||||
2. Sort order:
|
||||
- descending: Newest/Most relevant first (default)
|
||||
- ascending: Oldest/Least relevant first
|
||||
|
||||
3. Result limit:
|
||||
- Minimum: 10 papers
|
||||
- Maximum: 50 papers
|
||||
- Recommended: 20-30 papers for most queries
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in transformer models"
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Foundational papers about neural networks"
|
||||
<sort_by>relevance</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>20</limit>
|
||||
|
||||
3. Query: "Evolution of deep learning since 2012"
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>ascending</sort_order>
|
||||
<limit>50</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, lastUpdatedDate, or submittedDate</sort_by>
|
||||
<sort_order>Choose: ascending or descending</sort_order>
|
||||
<limit>Suggest number between 10-50</limit>
|
||||
|
||||
Note:
|
||||
- Choose relevance for specific technical queries
|
||||
- Use lastUpdatedDate for tracking paper revisions
|
||||
- Use submittedDate for following recent developments
|
||||
- Consider query context when setting the limit"""
|
||||
|
||||
# System prompts for each task
|
||||
ARXIV_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine whether the query is better suited for keyword search or category-based search.
|
||||
Consider the query's specificity, scope, and intended search area when making your decision.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
ARXIV_QUERY_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||
Your task is to optimize natural language queries using boolean operators and field tags.
|
||||
Focus on creating precise, targeted queries that will return the most relevant results.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
ARXIV_CATEGORIES_SYSTEM_PROMPT = """You are an expert at arXiv category classification.
|
||||
Your task is to select the most relevant categories for the given research query.
|
||||
Consider both primary and related interdisciplinary categories, while maintaining focus on the main research area.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
ARXIV_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the user's likely intent and temporal aspects of the research topic.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
# 添加新的搜索提示词
|
||||
ARXIV_SEARCH_PROMPT = """Analyze and optimize the research query for arXiv search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized arXiv search query.
|
||||
|
||||
Available search options:
|
||||
1. Basic search with field tags:
|
||||
- ti: Search in title
|
||||
- abs: Search in abstract
|
||||
- au: Search for author
|
||||
Example: "ti:transformer AND abs:attention"
|
||||
|
||||
2. Category-based search:
|
||||
- Use specific arXiv categories
|
||||
Example: "cat:cs.AI AND neural networks"
|
||||
|
||||
3. Date range:
|
||||
- Specify date range using submittedDate
|
||||
Example: "deep learning AND submittedDate:[20200101 TO 20231231]"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent papers about transformer models by Vaswani"
|
||||
<search_criteria>
|
||||
<query>ti:"transformer model" AND au:Vaswani AND submittedDate:[20170101 TO 99991231]</query>
|
||||
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>30</limit>
|
||||
</search_criteria>
|
||||
|
||||
2. Query: "Latest developments in computer vision"
|
||||
<search_criteria>
|
||||
<query>cat:cs.CV AND submittedDate:[20220101 TO 99991231]</query>
|
||||
<categories>cs.CV, cs.AI, cs.LG</categories>
|
||||
<sort_by>submittedDate</sort_by>
|
||||
<sort_order>descending</sort_order>
|
||||
<limit>25</limit>
|
||||
</search_criteria>
|
||||
|
||||
Please analyze the query and respond with XML tags containing search criteria."""
|
||||
|
||||
ARXIV_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting arXiv search queries.
|
||||
Your task is to analyze research queries and transform them into optimized arXiv search criteria.
|
||||
Consider query intent, relevant categories, and temporal aspects when creating the search parameters.
|
||||
Always generate English search terms and respond in English regardless of the input language."""
|
||||
|
||||
# Categories selection prompt
|
||||
ARXIV_CATEGORIES_PROMPT = """Select the most relevant arXiv categories for the research query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Choose 2-4 most relevant categories that best match the research topic.
|
||||
|
||||
Available Categories:
|
||||
|
||||
Computer Science (cs):
|
||||
- cs.AI: Artificial Intelligence (neural networks, machine learning, NLP)
|
||||
- cs.CL: Computation and Language (NLP, machine translation)
|
||||
- cs.CV: Computer Vision and Pattern Recognition
|
||||
- cs.LG: Machine Learning (deep learning, reinforcement learning)
|
||||
- cs.NE: Neural and Evolutionary Computing
|
||||
- cs.RO: Robotics
|
||||
- cs.IR: Information Retrieval
|
||||
- cs.SE: Software Engineering
|
||||
- cs.DB: Databases
|
||||
- cs.DC: Distributed Computing
|
||||
- cs.CY: Computers and Society
|
||||
- cs.HC: Human-Computer Interaction
|
||||
|
||||
Mathematics (math):
|
||||
- math.OC: Optimization and Control
|
||||
- math.PR: Probability
|
||||
- math.ST: Statistics
|
||||
- math.NA: Numerical Analysis
|
||||
- math.DS: Dynamical Systems
|
||||
|
||||
Statistics (stat):
|
||||
- stat.ML: Machine Learning
|
||||
- stat.ME: Methodology
|
||||
- stat.TH: Theory
|
||||
- stat.AP: Applications
|
||||
|
||||
Physics (physics):
|
||||
- physics.comp-ph: Computational Physics
|
||||
- physics.data-an: Data Analysis
|
||||
- physics.soc-ph: Physics and Society
|
||||
|
||||
Electrical Engineering (eess):
|
||||
- eess.SP: Signal Processing
|
||||
- eess.AS: Audio and Speech Processing
|
||||
- eess.IV: Image and Video Processing
|
||||
- eess.SY: Systems and Control
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Deep learning for computer vision"
|
||||
<categories>cs.CV, cs.LG, stat.ML</categories>
|
||||
|
||||
2. Query: "Natural language processing with transformers"
|
||||
<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
|
||||
3. Query: "Reinforcement learning for robotics"
|
||||
<categories>cs.RO, cs.AI, cs.LG</categories>
|
||||
|
||||
4. Query: "Statistical methods in machine learning"
|
||||
<categories>stat.ML, cs.LG, math.ST</categories>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<categories>List 2-4 most relevant categories, comma-separated</categories>
|
||||
|
||||
Note:
|
||||
- Choose primary categories first, then add related ones
|
||||
- Limit to 2-4 most relevant categories
|
||||
- Order by relevance (most relevant first)
|
||||
- Use comma and space between categories (e.g., "cs.AI, cs.LG")"""
|
||||
|
||||
# 在文件末尾添加新的 prompt
|
||||
ARXIV_LATEST_PROMPT = """Determine if the query is requesting latest papers from arXiv.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Analyze if the query is specifically asking for recent/latest papers from arXiv.
|
||||
|
||||
IMPORTANT RULE:
|
||||
- The query MUST explicitly mention "arXiv" or "arxiv" to be considered a latest arXiv papers request
|
||||
- Queries only asking for recent/latest papers WITHOUT mentioning arXiv should return false
|
||||
|
||||
Indicators for latest papers request:
|
||||
1. MUST HAVE keywords about arXiv:
|
||||
- "arxiv"
|
||||
- "arXiv"
|
||||
AND
|
||||
|
||||
2. Keywords about recency:
|
||||
- "latest"
|
||||
- "recent"
|
||||
- "new"
|
||||
- "newest"
|
||||
- "just published"
|
||||
- "this week/month"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Latest papers request (Valid):
|
||||
Query: "Show me the latest AI papers on arXiv"
|
||||
<is_latest_request>true</is_latest_request>
|
||||
|
||||
2. Latest papers request (Valid):
|
||||
Query: "What are the recent papers about transformers on arxiv"
|
||||
<is_latest_request>true</is_latest_request>
|
||||
|
||||
3. Not a latest papers request (Invalid - no mention of arXiv):
|
||||
Query: "Show me the latest papers about BERT"
|
||||
<is_latest_request>false</is_latest_request>
|
||||
|
||||
4. Not a latest papers request (Invalid - no recency):
|
||||
Query: "Find papers on arxiv about transformers"
|
||||
<is_latest_request>false</is_latest_request>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<is_latest_request>true/false</is_latest_request>
|
||||
|
||||
Note: The response should be true ONLY if both conditions are met:
|
||||
1. Query explicitly mentions arXiv/arxiv
|
||||
2. Query asks for recent/latest papers"""
|
||||
|
||||
ARXIV_LATEST_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine if the query is specifically requesting latest/recent papers from arXiv.
|
||||
Remember: The query MUST explicitly mention arXiv to be considered valid, even if it asks for recent papers.
|
||||
Always respond in English regardless of the input language."""
|
||||
@@ -0,0 +1,55 @@
|
||||
# Crossref query optimization prompt
|
||||
CROSSREF_QUERY_PROMPT = """Analyze and optimize the query for Crossref search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized Crossref search query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Available search fields and filters:
|
||||
1. Basic fields:
|
||||
- title: Search in title
|
||||
- abstract: Search in abstract
|
||||
- author: Search for author names
|
||||
- container-title: Search in journal/conference name
|
||||
- publisher: Search by publisher name
|
||||
- type: Filter by work type (journal-article, book-chapter, etc.)
|
||||
- year: Filter by publication year
|
||||
|
||||
2. Boolean operators:
|
||||
- AND: Both terms must appear
|
||||
- OR: Either term can appear
|
||||
- NOT: Exclude terms
|
||||
- "": Exact phrase match
|
||||
|
||||
3. Special filters:
|
||||
- is-referenced-by-count: Filter by citation count
|
||||
- from-pub-date: Filter by publication date
|
||||
- has-abstract: Filter papers with abstracts
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Machine learning in healthcare after 2020"
|
||||
<query>title:"machine learning" AND title:healthcare AND from-pub-date:2020</query>
|
||||
|
||||
2. Query: "Papers by Geoffrey Hinton about deep learning"
|
||||
<query>author:"Hinton, Geoffrey" AND (title:"deep learning" OR abstract:"deep learning")</query>
|
||||
|
||||
3. Query: "Most cited papers about transformers in Nature"
|
||||
<query>title:transformer AND container-title:Nature AND is-referenced-by-count:[100 TO *]</query>
|
||||
|
||||
4. Query: "Recent BERT applications in medical domain"
|
||||
<query>title:BERT AND abstract:medical AND from-pub-date:2020 AND type:journal-article</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized Crossref search query using appropriate fields and operators</query>"""
|
||||
|
||||
# System prompt
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Crossref search queries.
|
||||
Your task is to optimize natural language queries for Crossref's API.
|
||||
Focus on creating precise queries that will return relevant results.
|
||||
Always generate English search terms regardless of the input language.
|
||||
Consider using field-specific search terms and appropriate filters to improve search accuracy."""
|
||||
@@ -0,0 +1,47 @@
|
||||
# 新建文件,添加论文识别提示
|
||||
PAPER_IDENTIFY_PROMPT = """Analyze the query to identify paper details.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Extract paper identification information from the query.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on identifying paper details.
|
||||
|
||||
Possible paper identifiers:
|
||||
1. arXiv ID (e.g., 2103.14030, arXiv:2103.14030)
|
||||
2. DOI (e.g., 10.1234/xxx.xxx)
|
||||
3. Paper title (e.g., "Attention is All You Need")
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query with arXiv ID:
|
||||
Query: "Analyze paper 2103.14030"
|
||||
<paper_info>
|
||||
<paper_source>arxiv</paper_source>
|
||||
<paper_id>2103.14030</paper_id>
|
||||
<paper_title></paper_title>
|
||||
</paper_info>
|
||||
|
||||
2. Query with DOI:
|
||||
Query: "Review the paper with DOI 10.1234/xxx.xxx"
|
||||
<paper_info>
|
||||
<paper_source>doi</paper_source>
|
||||
<paper_id>10.1234/xxx.xxx</paper_id>
|
||||
<paper_title></paper_title>
|
||||
</paper_info>
|
||||
|
||||
3. Query with paper title:
|
||||
Query: "Analyze 'Attention is All You Need' paper"
|
||||
<paper_info>
|
||||
<paper_source>title</paper_source>
|
||||
<paper_id></paper_id>
|
||||
<paper_title>Attention is All You Need</paper_title>
|
||||
</paper_info>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags containing paper information."""
|
||||
|
||||
PAPER_IDENTIFY_SYSTEM_PROMPT = """You are an expert at identifying academic paper references.
|
||||
Your task is to extract paper identification information from queries.
|
||||
Look for arXiv IDs, DOIs, and paper titles."""
|
||||
@@ -0,0 +1,108 @@
|
||||
# PubMed search type prompt
|
||||
PUBMED_TYPE_PROMPT = """Analyze the research query and determine the appropriate PubMed search type.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Available search types:
|
||||
1. basic: General keyword search for medical/biomedical topics
|
||||
2. author: Search by author name
|
||||
3. journal: Search within specific journals
|
||||
4. none: Query not related to medical/biomedical research
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "COVID-19 treatment outcomes"
|
||||
<search_type>basic</search_type>
|
||||
|
||||
2. Query: "Papers by Anthony Fauci"
|
||||
<search_type>author</search_type>
|
||||
|
||||
3. Query: "Recent papers in Nature about CRISPR"
|
||||
<search_type>journal</search_type>
|
||||
|
||||
4. Query: "Deep learning for computer vision"
|
||||
<search_type>none</search_type>
|
||||
|
||||
5. Query: "Transformer architecture for NLP"
|
||||
<search_type>none</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose: basic, author, journal, or none</search_type>"""
|
||||
|
||||
# PubMed query optimization prompt
|
||||
PUBMED_QUERY_PROMPT = """Optimize the following query for PubMed search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized PubMed search query.
|
||||
Requirements:
|
||||
- Always generate English search terms regardless of input language
|
||||
- Translate any non-English terms to English before creating the query
|
||||
- Never include non-English characters in the final query
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core medical/biomedical topic for the search query.
|
||||
|
||||
Available field tags:
|
||||
- [Title] - Search in title
|
||||
- [Author] - Search for author
|
||||
- [Journal] - Search in journal name
|
||||
- [MeSH Terms] - Search using MeSH terms
|
||||
|
||||
Boolean operators:
|
||||
- AND
|
||||
- OR
|
||||
- NOT
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "COVID-19 treatment in elderly patients"
|
||||
<query>COVID-19[Title] AND treatment[Title/Abstract] AND elderly[Title/Abstract]</query>
|
||||
|
||||
2. Query: "Cancer immunotherapy review articles"
|
||||
<query>cancer immunotherapy[Title/Abstract] AND review[Publication Type]</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized PubMed search query</query>"""
|
||||
|
||||
# PubMed sort parameters prompt
|
||||
PUBMED_SORT_PROMPT = """Determine optimal sorting parameters for PubMed results.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting method and result limit.
|
||||
|
||||
Available sort options:
|
||||
- relevance: Best match to query
|
||||
- date: Most recent first
|
||||
- journal: Sort by journal name
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in gene therapy"
|
||||
<sort_by>date</sort_by>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Classic papers about DNA structure"
|
||||
<sort_by>relevance</sort_by>
|
||||
<limit>20</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, date, or journal</sort_by>
|
||||
<limit>Suggest number between 10-50</limit>"""
|
||||
|
||||
# System prompts
|
||||
PUBMED_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing medical and scientific queries.
|
||||
Your task is to determine the most appropriate PubMed search type.
|
||||
Consider the query's focus and intended search scope.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
PUBMED_QUERY_SYSTEM_PROMPT = """You are an expert at crafting PubMed search queries.
|
||||
Your task is to optimize natural language queries using PubMed's search syntax.
|
||||
Focus on creating precise, targeted queries that will return relevant medical literature.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
PUBMED_SORT_SYSTEM_PROMPT = """You are an expert at optimizing PubMed search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the balance between relevance and recency.
|
||||
Always respond in English regardless of the input language."""
|
||||
@@ -0,0 +1,276 @@
|
||||
# Search type prompt
|
||||
SEMANTIC_TYPE_PROMPT = """Determine the most appropriate search type for Semantic Scholar.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Analyze the research query and select the most appropriate search type for Semantic Scholar API.
|
||||
|
||||
Available search types:
|
||||
|
||||
1. paper: General paper search
|
||||
- Use for broad topic searches
|
||||
- Looking for specific papers
|
||||
- Keyword-based searches
|
||||
Example: "transformer models in NLP"
|
||||
|
||||
2. author: Author-based search
|
||||
- Finding works by specific researchers
|
||||
- Author profile analysis
|
||||
Example: "papers by Yoshua Bengio"
|
||||
|
||||
3. paper_details: Specific paper lookup
|
||||
- Getting details about a known paper
|
||||
- Finding specific versions or citations
|
||||
Example: "Attention is All You Need paper details"
|
||||
|
||||
4. citations: Citation analysis
|
||||
- Finding papers that cite a specific work
|
||||
- Impact analysis
|
||||
Example: "papers citing BERT"
|
||||
|
||||
5. references: Reference analysis
|
||||
- Finding papers cited by a specific work
|
||||
- Background research
|
||||
Example: "references in GPT-3 paper"
|
||||
|
||||
6. recommendations: Paper recommendations
|
||||
- Finding similar papers
|
||||
- Research direction exploration
|
||||
Example: "papers similar to Transformer"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest papers about deep learning"
|
||||
<search_type>paper</search_type>
|
||||
|
||||
2. Query: "Works by Geoffrey Hinton since 2020"
|
||||
<search_type>author</search_type>
|
||||
|
||||
3. Query: "Papers citing the original Transformer paper"
|
||||
<search_type>citations</search_type>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<search_type>Choose the most appropriate search type from the list above</search_type>"""
|
||||
|
||||
# Query optimization prompt
|
||||
SEMANTIC_QUERY_PROMPT = """Optimize the following query for Semantic Scholar search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into an optimized search query for maximum relevance.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements. Focus only on the core research topic for the search query.
|
||||
|
||||
Query optimization guidelines:
|
||||
|
||||
1. Use quotes for exact phrases
|
||||
- Ensures exact matching
|
||||
- Reduces irrelevant results
|
||||
Example: "\"attention mechanism\"" vs attention mechanism
|
||||
|
||||
2. Include key technical terms
|
||||
- Use specific technical terminology
|
||||
- Include common variations
|
||||
Example: "transformer architecture" neural networks
|
||||
|
||||
3. Author names (if relevant)
|
||||
- Include full names when known
|
||||
- Consider common name variations
|
||||
Example: "Geoffrey Hinton" OR "G. E. Hinton"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Natural query: "Recent advances in transformer models"
|
||||
<query>"transformer model" "neural architecture" deep learning</query>
|
||||
|
||||
2. Natural query: "BERT applications in text classification"
|
||||
<query>"BERT" "text classification" "language model" application</query>
|
||||
|
||||
3. Natural query: "Deep learning for computer vision by Kaiming He"
|
||||
<query>"deep learning" "computer vision" author:"Kaiming He"</query>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<query>Provide the optimized search query</query>
|
||||
|
||||
Note:
|
||||
- Balance between specificity and coverage
|
||||
- Include important technical terms
|
||||
- Use quotes for key phrases
|
||||
- Consider synonyms and related terms"""
|
||||
|
||||
# Fields selection prompt
|
||||
SEMANTIC_FIELDS_PROMPT = """Select relevant fields to retrieve from Semantic Scholar.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Determine which paper fields should be retrieved based on the research needs.
|
||||
|
||||
Available fields:
|
||||
|
||||
Core fields:
|
||||
- title: Paper title (always included)
|
||||
- abstract: Full paper abstract
|
||||
- authors: Author information
|
||||
- year: Publication year
|
||||
- venue: Publication venue
|
||||
|
||||
Citation fields:
|
||||
- citations: Papers citing this work
|
||||
- references: Papers cited by this work
|
||||
|
||||
Additional fields:
|
||||
- embedding: Paper vector embedding
|
||||
- tldr: AI-generated summary
|
||||
- venue: Publication venue/journal
|
||||
- url: Paper URL
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Latest developments in NLP"
|
||||
<fields>title, abstract, authors, year, venue, citations</fields>
|
||||
|
||||
2. Query: "Most influential papers in deep learning"
|
||||
<fields>title, abstract, authors, year, citations, references</fields>
|
||||
|
||||
3. Query: "Survey of transformer architectures"
|
||||
<fields>title, abstract, authors, year, tldr, references</fields>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<fields>List relevant fields, comma-separated</fields>
|
||||
|
||||
Note:
|
||||
- Choose fields based on the query's purpose
|
||||
- Include citation data for impact analysis
|
||||
- Consider tldr for quick paper screening
|
||||
- Balance completeness with API efficiency"""
|
||||
|
||||
# Sort parameters prompt
|
||||
SEMANTIC_SORT_PROMPT = """Determine optimal sorting parameters for the query.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Select the most appropriate sorting method and result limit for the search.
|
||||
Always generate English search terms regardless of the input language.
|
||||
|
||||
Sorting options:
|
||||
|
||||
1. relevance (default)
|
||||
- Best match to query terms
|
||||
- Recommended for specific technical searches
|
||||
Example: "specific algorithm implementations"
|
||||
|
||||
2. citations
|
||||
- Sort by citation count
|
||||
- Best for finding influential papers
|
||||
Example: "most important papers in deep learning"
|
||||
|
||||
3. year
|
||||
- Sort by publication date
|
||||
- Best for following recent developments
|
||||
Example: "latest advances in NLP"
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent breakthroughs in AI"
|
||||
<sort_by>year</sort_by>
|
||||
<limit>30</limit>
|
||||
|
||||
2. Query: "Most influential papers about GANs"
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>20</limit>
|
||||
|
||||
3. Query: "Specific papers about BERT fine-tuning"
|
||||
<sort_by>relevance</sort_by>
|
||||
<limit>25</limit>
|
||||
|
||||
Please analyze the query and respond ONLY with XML tags:
|
||||
<sort_by>Choose: relevance, citations, or year</sort_by>
|
||||
<limit>Suggest number between 10-50</limit>
|
||||
|
||||
Note:
|
||||
- Consider the query's temporal aspects
|
||||
- Balance between comprehensive coverage and information overload
|
||||
- Use citation sorting for impact analysis
|
||||
- Use year sorting for tracking developments"""
|
||||
|
||||
# System prompts for each task
|
||||
SEMANTIC_TYPE_SYSTEM_PROMPT = """You are an expert at analyzing academic queries.
|
||||
Your task is to determine the most appropriate type of search on Semantic Scholar.
|
||||
Consider the query's intent, scope, and specific research needs.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||
Your task is to optimize natural language queries for maximum relevance.
|
||||
Focus on creating precise queries that leverage the platform's search capabilities.
|
||||
Always generate English search terms regardless of the input language."""
|
||||
|
||||
SEMANTIC_FIELDS_SYSTEM_PROMPT = """You are an expert at Semantic Scholar data fields.
|
||||
Your task is to select the most relevant fields based on the research context.
|
||||
Consider both essential and supplementary information needs.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
SEMANTIC_SORT_SYSTEM_PROMPT = """You are an expert at optimizing search results.
|
||||
Your task is to determine the best sorting parameters based on the query context.
|
||||
Consider the balance between relevance, impact, and recency.
|
||||
Always respond in English regardless of the input language."""
|
||||
|
||||
# 添加新的综合搜索提示词
|
||||
SEMANTIC_SEARCH_PROMPT = """Analyze and optimize the research query for Semantic Scholar search.
|
||||
|
||||
Query: {query}
|
||||
|
||||
Task: Transform the natural language query into optimized search criteria for Semantic Scholar.
|
||||
|
||||
IMPORTANT: Ignore any requirements about journal ranking (CAS, JCR, IF index),
|
||||
or output format requirements when generating the search terms. These requirements
|
||||
should be considered only for post-search filtering, not as part of the core query.
|
||||
|
||||
Available search options:
|
||||
1. Paper search:
|
||||
- Title and abstract search
|
||||
- Author search
|
||||
- Field-specific search
|
||||
Example: "transformer architecture neural networks"
|
||||
|
||||
2. Field tags:
|
||||
- title: Search in title
|
||||
- abstract: Search in abstract
|
||||
- authors: Search by author names
|
||||
- venue: Search by publication venue
|
||||
Example: "title:transformer authors:\"Vaswani\""
|
||||
|
||||
3. Advanced options:
|
||||
- Year range filtering
|
||||
- Citation count filtering
|
||||
- Venue filtering
|
||||
Example: "deep learning year>2020 venue:\"NeurIPS\""
|
||||
|
||||
Examples:
|
||||
|
||||
1. Query: "Recent transformer papers by Vaswani with high impact"
|
||||
<search_criteria>
|
||||
<query>title:transformer authors:"Vaswani" year>2017</query>
|
||||
<search_type>paper</search_type>
|
||||
<fields>title,abstract,authors,year,citations,venue</fields>
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>30</limit>
|
||||
</search_criteria>
|
||||
|
||||
2. Query: "Most cited papers about BERT in top conferences"
|
||||
<search_criteria>
|
||||
<query>title:BERT venue:"ACL|EMNLP|NAACL"</query>
|
||||
<search_type>paper</search_type>
|
||||
<fields>title,abstract,authors,year,citations,venue,references</fields>
|
||||
<sort_by>citations</sort_by>
|
||||
<limit>25</limit>
|
||||
</search_criteria>
|
||||
|
||||
Please analyze the query and respond with XML tags containing complete search criteria."""
|
||||
|
||||
SEMANTIC_SEARCH_SYSTEM_PROMPT = """You are an expert at crafting Semantic Scholar search queries.
|
||||
Your task is to analyze research queries and transform them into optimized search criteria.
|
||||
Consider query intent, field relevance, and citation impact when creating the search parameters.
|
||||
Focus on producing precise and comprehensive search criteria that will yield the most relevant results.
|
||||
Always generate English search terms and respond in English regardless of the input language."""
|
||||
@@ -0,0 +1,493 @@
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
from textwrap import dedent
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchCriteria:
|
||||
"""搜索条件"""
|
||||
query_type: str # 查询类型: review/recommend/qa/paper
|
||||
main_topic: str # 主题
|
||||
sub_topics: List[str] # 子主题列表
|
||||
start_year: int # 起始年份
|
||||
end_year: int # 结束年份
|
||||
arxiv_params: Dict # arXiv搜索参数
|
||||
semantic_params: Dict # Semantic Scholar搜索参数
|
||||
pubmed_params: Dict # 新增: PubMed搜索参数
|
||||
crossref_params: Dict # 添加 Crossref 参数
|
||||
adsabs_params: Dict # 添加 ADS 参数
|
||||
paper_id: str = "" # 论文ID (arxiv ID 或 DOI)
|
||||
paper_title: str = "" # 论文标题
|
||||
paper_source: str = "" # 论文来源 (arxiv/doi/title)
|
||||
original_query: str = "" # 新增: 原始查询字符串
|
||||
|
||||
|
||||
class QueryAnalyzer:
|
||||
"""查询分析器"""
|
||||
|
||||
# 响应索引常量
|
||||
BASIC_QUERY_INDEX = 0
|
||||
PAPER_IDENTIFY_INDEX = 1
|
||||
ARXIV_QUERY_INDEX = 2
|
||||
ARXIV_CATEGORIES_INDEX = 3
|
||||
ARXIV_LATEST_INDEX = 4
|
||||
ARXIV_SORT_INDEX = 5
|
||||
SEMANTIC_QUERY_INDEX = 6
|
||||
SEMANTIC_FIELDS_INDEX = 7
|
||||
PUBMED_TYPE_INDEX = 8
|
||||
PUBMED_QUERY_INDEX = 9
|
||||
CROSSREF_QUERY_INDEX = 10
|
||||
ADSABS_QUERY_INDEX = 11
|
||||
|
||||
def __init__(self):
|
||||
self.current_year = datetime.now().year
|
||||
self.valid_types = {
|
||||
"review": ["review", "literature review", "survey"],
|
||||
"recommend": ["recommend", "recommendation", "suggest", "similar"],
|
||||
"qa": ["qa", "question", "answer", "explain", "what", "how", "why"],
|
||||
"paper": ["paper", "analyze", "analysis"]
|
||||
}
|
||||
|
||||
def analyze_query(self, query: str, chatbot: List, llm_kwargs: Dict):
|
||||
"""分析查询意图"""
|
||||
from crazy_functions.crazy_utils import \
|
||||
request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency as request_gpt
|
||||
from crazy_functions.review_fns.prompts.arxiv_prompts import (
|
||||
ARXIV_QUERY_PROMPT, ARXIV_CATEGORIES_PROMPT, ARXIV_LATEST_PROMPT,
|
||||
ARXIV_SORT_PROMPT, ARXIV_QUERY_SYSTEM_PROMPT, ARXIV_CATEGORIES_SYSTEM_PROMPT, ARXIV_SORT_SYSTEM_PROMPT,
|
||||
ARXIV_LATEST_SYSTEM_PROMPT
|
||||
)
|
||||
from crazy_functions.review_fns.prompts.semantic_prompts import (
|
||||
SEMANTIC_QUERY_PROMPT, SEMANTIC_FIELDS_PROMPT,
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT, SEMANTIC_FIELDS_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.paper_prompts import PAPER_IDENTIFY_PROMPT, PAPER_IDENTIFY_SYSTEM_PROMPT
|
||||
from .prompts.pubmed_prompts import (
|
||||
PUBMED_TYPE_PROMPT, PUBMED_QUERY_PROMPT, PUBMED_SORT_PROMPT,
|
||||
PUBMED_TYPE_SYSTEM_PROMPT, PUBMED_QUERY_SYSTEM_PROMPT, PUBMED_SORT_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.crossref_prompts import (
|
||||
CROSSREF_QUERY_PROMPT,
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT
|
||||
)
|
||||
from .prompts.adsabs_prompts import ADSABS_QUERY_PROMPT, ADSABS_QUERY_SYSTEM_PROMPT
|
||||
|
||||
# 1. 基本查询分析
|
||||
type_prompt = dedent(f"""Please analyze this academic query and respond STRICTLY in the following XML format:
|
||||
|
||||
Query: {query}
|
||||
|
||||
Instructions:
|
||||
1. Your response must use XML tags exactly as shown below
|
||||
2. Do not add any text outside the tags
|
||||
3. Choose query type from: review/recommend/qa/paper
|
||||
- review: for literature review or survey requests
|
||||
- recommend: for paper recommendation requests
|
||||
- qa: for general questions about research topics
|
||||
- paper: ONLY for queries about a SPECIFIC paper (with paper ID, DOI, or exact title)
|
||||
4. Identify main topic and subtopics
|
||||
5. Specify year range if mentioned
|
||||
|
||||
Required format:
|
||||
<query_type>ANSWER HERE</query_type>
|
||||
<main_topic>ANSWER HERE</main_topic>
|
||||
<sub_topics>SUBTOPIC1, SUBTOPIC2, ...</sub_topics>
|
||||
<year_range>START_YEAR-END_YEAR</year_range>
|
||||
|
||||
Example responses:
|
||||
|
||||
1. Literature Review Request:
|
||||
Query: "Review recent developments in transformer models for NLP from 2020 to 2023"
|
||||
<query_type>review</query_type>
|
||||
<main_topic>transformer models in natural language processing</main_topic>
|
||||
<sub_topics>architecture improvements, pre-training methods, fine-tuning techniques</sub_topics>
|
||||
<year_range>2020-2023</year_range>
|
||||
|
||||
2. Paper Recommendation Request:
|
||||
Query: "Suggest papers about reinforcement learning in robotics since 2018"
|
||||
<query_type>recommend</query_type>
|
||||
<main_topic>reinforcement learning in robotics</main_topic>
|
||||
<sub_topics>robot control, policy learning, sim-to-real transfer</sub_topics>
|
||||
<year_range>2018-2023</year_range>"""
|
||||
)
|
||||
|
||||
try:
|
||||
# 构建提示数组
|
||||
prompts = [
|
||||
type_prompt,
|
||||
PAPER_IDENTIFY_PROMPT.format(query=query),
|
||||
ARXIV_QUERY_PROMPT.format(query=query),
|
||||
ARXIV_CATEGORIES_PROMPT.format(query=query),
|
||||
ARXIV_LATEST_PROMPT.format(query=query),
|
||||
ARXIV_SORT_PROMPT.format(query=query),
|
||||
SEMANTIC_QUERY_PROMPT.format(query=query),
|
||||
SEMANTIC_FIELDS_PROMPT.format(query=query),
|
||||
PUBMED_TYPE_PROMPT.format(query=query),
|
||||
PUBMED_QUERY_PROMPT.format(query=query),
|
||||
CROSSREF_QUERY_PROMPT.format(query=query),
|
||||
ADSABS_QUERY_PROMPT.format(query=query)
|
||||
]
|
||||
|
||||
show_messages = [
|
||||
"Analyzing query type...",
|
||||
"Identifying paper details...",
|
||||
"Determining arXiv search type...",
|
||||
"Selecting arXiv categories...",
|
||||
"Checking if latest papers requested...",
|
||||
"Determining arXiv sort parameters...",
|
||||
"Optimizing Semantic Scholar query...",
|
||||
"Selecting Semantic Scholar fields...",
|
||||
"Determining PubMed search type...",
|
||||
"Optimizing PubMed query...",
|
||||
"Optimizing Crossref query...",
|
||||
"Optimizing ADS query..."
|
||||
]
|
||||
|
||||
sys_prompts = [
|
||||
"You are an expert at analyzing academic queries.",
|
||||
PAPER_IDENTIFY_SYSTEM_PROMPT,
|
||||
ARXIV_QUERY_SYSTEM_PROMPT,
|
||||
ARXIV_CATEGORIES_SYSTEM_PROMPT,
|
||||
ARXIV_LATEST_SYSTEM_PROMPT,
|
||||
ARXIV_SORT_SYSTEM_PROMPT,
|
||||
SEMANTIC_QUERY_SYSTEM_PROMPT,
|
||||
SEMANTIC_FIELDS_SYSTEM_PROMPT,
|
||||
PUBMED_TYPE_SYSTEM_PROMPT,
|
||||
PUBMED_QUERY_SYSTEM_PROMPT,
|
||||
CROSSREF_QUERY_SYSTEM_PROMPT,
|
||||
ADSABS_QUERY_SYSTEM_PROMPT
|
||||
]
|
||||
new_llm_kwargs = llm_kwargs.copy()
|
||||
# new_llm_kwargs['llm_model'] = 'deepseek-chat' # deepseek-ai/DeepSeek-V2.5
|
||||
|
||||
# 使用同步方式调用LLM
|
||||
responses = yield from request_gpt(
|
||||
inputs_array=prompts,
|
||||
inputs_show_user_array=show_messages,
|
||||
llm_kwargs=new_llm_kwargs,
|
||||
chatbot=chatbot,
|
||||
history_array=[[] for _ in prompts],
|
||||
sys_prompt_array=sys_prompts,
|
||||
max_workers=5
|
||||
)
|
||||
|
||||
# 从收集的响应中提取我们需要的内容
|
||||
extracted_responses = []
|
||||
for i in range(len(prompts)):
|
||||
if (i * 2 + 1) < len(responses):
|
||||
response = responses[i * 2 + 1]
|
||||
if response is None:
|
||||
raise Exception(f"Response {i} is None")
|
||||
if not isinstance(response, str):
|
||||
try:
|
||||
response = str(response)
|
||||
except:
|
||||
raise Exception(f"Cannot convert response {i} to string")
|
||||
extracted_responses.append(response)
|
||||
else:
|
||||
raise Exception(f"未收到第 {i + 1} 个响应")
|
||||
|
||||
# 解析基本信息
|
||||
query_type = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "query_type")
|
||||
if not query_type:
|
||||
print(
|
||||
f"Debug - Failed to extract query_type. Response was: {extracted_responses[self.BASIC_QUERY_INDEX]}")
|
||||
raise Exception("无法提取query_type标签内容")
|
||||
query_type = query_type.lower()
|
||||
|
||||
main_topic = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "main_topic")
|
||||
if not main_topic:
|
||||
print(f"Debug - Failed to extract main_topic. Using query as fallback.")
|
||||
main_topic = query
|
||||
|
||||
query_type = self._normalize_query_type(query_type, query)
|
||||
|
||||
# 解析arXiv参数
|
||||
try:
|
||||
arxiv_params = {
|
||||
"query": self._extract_tag(extracted_responses[self.ARXIV_QUERY_INDEX], "query"),
|
||||
"categories": [cat.strip() for cat in
|
||||
self._extract_tag(extracted_responses[self.ARXIV_CATEGORIES_INDEX],
|
||||
"categories").split(",")],
|
||||
"sort_by": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_by"),
|
||||
"sort_order": self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "sort_order"),
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
# 安全地解析limit值
|
||||
limit_str = self._extract_tag(extracted_responses[self.ARXIV_SORT_INDEX], "limit")
|
||||
if limit_str and limit_str.isdigit():
|
||||
arxiv_params["limit"] = int(limit_str)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing arXiv parameters: {str(e)}")
|
||||
arxiv_params = {
|
||||
"query": "",
|
||||
"categories": [],
|
||||
"sort_by": "relevance",
|
||||
"sort_order": "descending",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析Semantic Scholar参数
|
||||
try:
|
||||
semantic_params = {
|
||||
"query": self._extract_tag(extracted_responses[self.SEMANTIC_QUERY_INDEX], "query"),
|
||||
"fields": [field.strip() for field in
|
||||
self._extract_tag(extracted_responses[self.SEMANTIC_FIELDS_INDEX], "fields").split(",")],
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing Semantic Scholar parameters: {str(e)}")
|
||||
semantic_params = {
|
||||
"query": query,
|
||||
"fields": ["title", "abstract", "authors", "year"],
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
|
||||
# 解析PubMed参数
|
||||
try:
|
||||
# 首先检查是否需要PubMed搜索
|
||||
pubmed_search_type = self._extract_tag(extracted_responses[self.PUBMED_TYPE_INDEX], "search_type")
|
||||
|
||||
if pubmed_search_type == "none":
|
||||
# 不需要PubMed搜索,使用空参数
|
||||
pubmed_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
# 需要PubMed搜索,解析完整参数
|
||||
pubmed_params = {
|
||||
"search_type": pubmed_search_type,
|
||||
"query": self._extract_tag(extracted_responses[self.PUBMED_QUERY_INDEX], "query"),
|
||||
"sort_by": "relevance",
|
||||
"limit": 200
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing PubMed parameters: {str(e)}")
|
||||
pubmed_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析Crossref参数
|
||||
try:
|
||||
crossref_query = self._extract_tag(extracted_responses[self.CROSSREF_QUERY_INDEX], "query")
|
||||
|
||||
if not crossref_query:
|
||||
crossref_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
crossref_params = {
|
||||
"search_type": "basic",
|
||||
"query": crossref_query,
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing Crossref parameters: {str(e)}")
|
||||
crossref_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
# 解析ADS参数
|
||||
try:
|
||||
adsabs_query = self._extract_tag(extracted_responses[self.ADSABS_QUERY_INDEX], "query")
|
||||
|
||||
if not adsabs_query:
|
||||
adsabs_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
else:
|
||||
adsabs_params = {
|
||||
"search_type": "basic",
|
||||
"query": adsabs_query,
|
||||
"sort_by": "relevance",
|
||||
"limit": 20
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"Warning: Error parsing ADS parameters: {str(e)}")
|
||||
adsabs_params = {
|
||||
"search_type": "none",
|
||||
"query": "",
|
||||
"sort_by": "relevance",
|
||||
"limit": 0
|
||||
}
|
||||
|
||||
print(f"Debug - Extracted information:")
|
||||
print(f"Query type: {query_type}")
|
||||
print(f"Main topic: {main_topic}")
|
||||
print(f"arXiv params: {arxiv_params}")
|
||||
print(f"Semantic params: {semantic_params}")
|
||||
print(f"PubMed params: {pubmed_params}")
|
||||
print(f"Crossref params: {crossref_params}")
|
||||
print(f"ADS params: {adsabs_params}")
|
||||
|
||||
# 提取子主题
|
||||
sub_topics = []
|
||||
if "sub_topics" in query.lower():
|
||||
sub_topics_text = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "sub_topics")
|
||||
if sub_topics_text:
|
||||
sub_topics = [topic.strip() for topic in sub_topics_text.split(",")]
|
||||
|
||||
# 提取年份范围
|
||||
start_year = self.current_year - 5 # 默认最近5年
|
||||
end_year = self.current_year
|
||||
year_range = self._extract_tag(extracted_responses[self.BASIC_QUERY_INDEX], "year_range")
|
||||
if year_range:
|
||||
try:
|
||||
years = year_range.split("-")
|
||||
if len(years) == 2:
|
||||
start_year = int(years[0].strip())
|
||||
end_year = int(years[1].strip())
|
||||
except:
|
||||
pass
|
||||
|
||||
# 提取 latest request 判断
|
||||
is_latest_request = self._extract_tag(extracted_responses[self.ARXIV_LATEST_INDEX],
|
||||
"is_latest_request").lower() == "true"
|
||||
|
||||
# 如果是最新论文请求,将查询类型改为 "latest"
|
||||
if is_latest_request:
|
||||
query_type = "latest"
|
||||
|
||||
# 提取论文标识信息
|
||||
paper_source = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_source")
|
||||
paper_id = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_id")
|
||||
paper_title = self._extract_tag(extracted_responses[self.PAPER_IDENTIFY_INDEX], "paper_title")
|
||||
if start_year > end_year:
|
||||
start_year, end_year = end_year, start_year
|
||||
# 更新返回的 SearchCriteria
|
||||
return SearchCriteria(
|
||||
query_type=query_type,
|
||||
main_topic=main_topic,
|
||||
sub_topics=sub_topics,
|
||||
start_year=start_year,
|
||||
end_year=end_year,
|
||||
arxiv_params=arxiv_params,
|
||||
semantic_params=semantic_params,
|
||||
pubmed_params=pubmed_params,
|
||||
crossref_params=crossref_params,
|
||||
paper_id=paper_id,
|
||||
paper_title=paper_title,
|
||||
paper_source=paper_source,
|
||||
original_query=query,
|
||||
adsabs_params=adsabs_params
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to analyze query: {str(e)}")
|
||||
|
||||
def _normalize_query_type(self, query_type: str, query: str) -> str:
|
||||
"""规范化查询类型"""
|
||||
if query_type in ["review", "recommend", "qa", "paper"]:
|
||||
return query_type
|
||||
|
||||
query_lower = query.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_lower:
|
||||
return type_name
|
||||
|
||||
query_type_lower = query_type.lower()
|
||||
for type_name, keywords in self.valid_types.items():
|
||||
for keyword in keywords:
|
||||
if keyword in query_type_lower:
|
||||
return type_name
|
||||
|
||||
return "qa" # 默认返回qa类型
|
||||
|
||||
def _extract_tag(self, text: str, tag: str) -> str:
|
||||
"""提取标记内容"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. 标准XML格式(处理多行和特殊字符)
|
||||
pattern = f"<{tag}>(.*?)</{tag}>"
|
||||
match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content:
|
||||
return content
|
||||
|
||||
# 2. 处理特定标签的复杂内容
|
||||
if tag == "categories":
|
||||
# 处理arXiv类别
|
||||
patterns = [
|
||||
# 标准格式:<categories>cs.CL, cs.AI, cs.LG</categories>
|
||||
r"<categories>\s*((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)\s*</categories>",
|
||||
# 简单列表格式:cs.CL, cs.AI, cs.LG
|
||||
r"(?:^|\s)((?:(?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+(?:\s*,\s*)?)+)(?:\s|$)",
|
||||
# 单个类别格式:cs.AI
|
||||
r"(?:^|\s)((?:cs|stat|math|physics|q-bio|q-fin|nlin|astro-ph|cond-mat|gr-qc|hep-[a-z]+|math-ph|nucl-[a-z]+|quant-ph)\.[A-Z]+)(?:\s|$)"
|
||||
]
|
||||
|
||||
elif tag == "query":
|
||||
# 处理搜索查询
|
||||
patterns = [
|
||||
# 完整的查询格式:<query>complex query</query>
|
||||
r"<query>\s*((?:(?:ti|abs|au|cat):[^\n]*?|(?:AND|OR|NOT|\(|\)|\d{4}|year:\d{4}|[\"'][^\"']*[\"']|\s+))+)\s*</query>",
|
||||
# 简单的关键词列表:keyword1, keyword2
|
||||
r"(?:^|\s)((?:\"[^\"]*\"|'[^']*'|[^\s,]+)(?:\s*,\s*(?:\"[^\"]*\"|'[^']*'|[^\s,]+))*)",
|
||||
# 字段搜索格式:field:value
|
||||
r"((?:ti|abs|au|cat):\s*(?:\"[^\"]*\"|'[^']*'|[^\s]+))"
|
||||
]
|
||||
|
||||
elif tag == "fields":
|
||||
# 处理字段列表
|
||||
patterns = [
|
||||
# 标准格式:<fields>field1, field2</fields>
|
||||
r"<fields>\s*([\w\s,]+)\s*</fields>",
|
||||
# 简单列表格式:field1, field2
|
||||
r"(?:^|\s)([\w]+(?:\s*,\s*[\w]+)*)",
|
||||
]
|
||||
|
||||
elif tag == "sort_by":
|
||||
# 处理排序字段
|
||||
patterns = [
|
||||
# 标准格式:<sort_by>value</sort_by>
|
||||
r"<sort_by>\s*(relevance|date|citations|submittedDate|year)\s*</sort_by>",
|
||||
# 简单值格式:relevance
|
||||
r"(?:^|\s)(relevance|date|citations|submittedDate|year)(?:\s|$)"
|
||||
]
|
||||
|
||||
else:
|
||||
# 通用模式
|
||||
patterns = [
|
||||
f"<{tag}>\s*([\s\S]*?)\s*</{tag}>", # 标准XML格式
|
||||
f"<{tag}>([\s\S]*?)(?:</{tag}>|$)", # 未闭合的标签
|
||||
f"[{tag}]([\s\S]*?)[/{tag}]", # 方括号格式
|
||||
f"{tag}:\s*(.*?)(?=\n\w|$)", # 冒号格式
|
||||
f"<{tag}>\s*(.*?)(?=<|$)" # 部分闭合
|
||||
]
|
||||
|
||||
# 3. 尝试所有模式
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
|
||||
if match:
|
||||
content = match.group(1).strip()
|
||||
if content: # 确保提取的内容不为空
|
||||
return content
|
||||
|
||||
# 4. 如果所有模式都失败,返回空字符串
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
from typing import List, Dict, Any
|
||||
from .query_analyzer import QueryAnalyzer, SearchCriteria
|
||||
from .data_sources.arxiv_source import ArxivSource
|
||||
from .data_sources.semantic_source import SemanticScholarSource
|
||||
from .handlers.review_handler import 文献综述功能
|
||||
from .handlers.recommend_handler import 论文推荐功能
|
||||
from .handlers.qa_handler import 学术问答功能
|
||||
from .handlers.paper_handler import 单篇论文分析功能
|
||||
|
||||
class QueryProcessor:
|
||||
"""查询处理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.analyzer = QueryAnalyzer()
|
||||
self.arxiv = ArxivSource()
|
||||
self.semantic = SemanticScholarSource()
|
||||
|
||||
# 初始化各种处理器
|
||||
self.handlers = {
|
||||
"review": 文献综述功能(self.arxiv, self.semantic),
|
||||
"recommend": 论文推荐功能(self.arxiv, self.semantic),
|
||||
"qa": 学术问答功能(self.arxiv, self.semantic),
|
||||
"paper": 单篇论文分析功能(self.arxiv, self.semantic)
|
||||
}
|
||||
|
||||
async def process_query(
|
||||
self,
|
||||
query: str,
|
||||
chatbot: List[List[str]],
|
||||
history: List[List[str]],
|
||||
system_prompt: str,
|
||||
llm_kwargs: Dict[str, Any],
|
||||
plugin_kwargs: Dict[str, Any],
|
||||
) -> List[List[str]]:
|
||||
"""处理用户查询"""
|
||||
|
||||
# 设置默认的插件参数
|
||||
default_plugin_kwargs = {
|
||||
'max_papers': 20, # 最大论文数量
|
||||
'min_year': 2015, # 最早年份
|
||||
'search_multiplier': 3, # 检索倍数
|
||||
}
|
||||
# 更新插件参数
|
||||
plugin_kwargs.update({k: v for k, v in default_plugin_kwargs.items() if k not in plugin_kwargs})
|
||||
|
||||
# 1. 分析查询意图
|
||||
criteria = self.analyzer.analyze_query(query, chatbot, llm_kwargs)
|
||||
|
||||
# 2. 根据查询类型选择处理器
|
||||
handler = self.handlers.get(criteria.query_type)
|
||||
if not handler:
|
||||
handler = self.handlers["qa"] # 默认使用QA处理器
|
||||
|
||||
# 3. 处理查询
|
||||
response = await handler.handle(
|
||||
criteria,
|
||||
chatbot,
|
||||
history,
|
||||
system_prompt,
|
||||
llm_kwargs,
|
||||
plugin_kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
在新工单中引用
屏蔽一个用户