Mirrored from
https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-06 14:36:48 +00:00
normalize source code names
This commit is contained in:
@@ -9,7 +9,7 @@ from tqdm import tqdm

class ArxivSource(DataSource):
    """arXiv API implementation"""

    CATEGORIES = {
        # Physics
        "Physics": {
@@ -27,7 +27,7 @@ class ArxivSource(DataSource):
            "physics": "物理学",
            "quant-ph": "量子物理",
        },

        # Mathematics
        "Mathematics": {
            "math.AG": "代数几何",
@@ -63,7 +63,7 @@ class ArxivSource(DataSource):
            "math.ST": "统计理论",
            "math.SG": "辛几何",
        },

        # Computer science
        "Computer Science": {
            "cs.AI": "人工智能",
@@ -107,7 +107,7 @@ class ArxivSource(DataSource):
            "cs.SC": "符号计算",
            "cs.SY": "系统与控制",
        },

        # Quantitative biology
        "Quantitative Biology": {
            "q-bio.BM": "生物分子",
@@ -121,7 +121,7 @@ class ArxivSource(DataSource):
            "q-bio.SC": "亚细胞过程",
            "q-bio.TO": "组织与器官",
        },

        # Quantitative finance
        "Quantitative Finance": {
            "q-fin.CP": "计算金融",
@@ -134,7 +134,7 @@ class ArxivSource(DataSource):
            "q-fin.ST": "统计金融",
            "q-fin.TR": "交易与市场微观结构",
        },

        # Statistics
        "Statistics": {
            "stat.AP": "应用统计",
@@ -144,7 +144,7 @@ class ArxivSource(DataSource):
            "stat.OT": "其他统计",
            "stat.TH": "统计理论",
        },

        # Electrical engineering and systems science
        "Electrical Engineering and Systems Science": {
            "eess.AS": "音频与语音处理",
@@ -152,7 +152,7 @@ class ArxivSource(DataSource):
            "eess.SP": "信号处理",
            "eess.SY": "系统与控制",
        },

        # Economics
        "Economics": {
            "econ.EM": "计量经济学",
@@ -170,15 +170,15 @@ class ArxivSource(DataSource):
            'lastUpdatedDate': arxiv.SortCriterion.LastUpdatedDate,  # last-updated date
            'submittedDate': arxiv.SortCriterion.SubmittedDate,      # submission date
        }

        self.sort_order_options = {
            'ascending': arxiv.SortOrder.Ascending,
            'descending': arxiv.SortOrder.Descending
        }

        self.default_sort = 'lastUpdatedDate'
        self.default_order = 'descending'

    def _initialize(self) -> None:
        """Initialize the client and set default parameters"""
        self.client = arxiv.Client()
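For orientation: the CATEGORIES table maps arXiv category codes to Chinese labels, and the sort tables wrap the arxiv package's SortCriterion/SortOrder enums. A minimal sketch of how such a table can be flattened to validate a code before building a cat: query; the sample dict and the flatten helper are illustrative and not part of this commit:

import arxiv

# Illustrative subset of the CATEGORIES table shown above (codes -> Chinese labels).
SAMPLE_CATEGORIES = {
    "Computer Science": {"cs.AI": "人工智能", "cs.SY": "系统与控制"},
    "Statistics": {"stat.TH": "统计理论"},
}

def flatten_codes(categories: dict) -> set:
    """Collect every category code (e.g. 'cs.AI') from the nested mapping."""
    return {code for group in categories.values() for code in group}

code = "cs.AI"
assert code in flatten_codes(SAMPLE_CATEGORIES)

# The sort tables resolve plain strings to the arxiv package's enums before searching.
search = arxiv.Search(
    query=f"cat:{code}",
    max_results=5,
    sort_by=arxiv.SortCriterion.SubmittedDate,
    sort_order=arxiv.SortOrder.Descending,
)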
@@ -196,22 +196,22 @@ class ArxivSource(DataSource):
            # Fall back to the default sort if the given sort option is invalid
            if not sort_by or sort_by not in self.sort_options:
                sort_by = self.default_sort

            # Fall back to the default sort order if the given order is invalid
            if not sort_order or sort_order not in self.sort_order_options:
                sort_order = self.default_order

            # If a start year was given, add it to the query
            if start_year:
                query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"

            search = arxiv.Search(
                query=query,
                max_results=limit,
                sort_by=self.sort_options[sort_by],
                sort_order=self.sort_order_options[sort_order]
            )

            results = list(self.client.results(search))
            return [self._parse_paper_data(result) for result in results]
        except Exception as e:
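A caller-side sketch of the search() method above, assuming ArxivSource can be imported from this repository (the import path below is a placeholder, not shown in this diff). Unrecognized sort_by/sort_order strings fall back to the defaults set in the constructor:

import asyncio

from arxiv_source import ArxivSource  # placeholder import path; adjust to this repo's layout

async def demo_search():
    src = ArxivSource()
    # Title query limited to papers submitted since 2023.
    papers = await src.search(
        'ti:"diffusion model"',
        limit=5,
        sort_by='submittedDate',
        sort_order='descending',
        start_year=2023,
    )
    for paper in papers:
        print(paper.title)

asyncio.run(demo_search())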
@@ -220,13 +220,13 @@ class ArxivSource(DataSource):

    async def search_by_id(self, paper_id: Union[str, List[str]]) -> List[PaperMetadata]:
        """Search papers by ID

        Args:
            paper_id: a single arXiv ID or a list of IDs, e.g. '2005.14165' or ['2005.14165', '2103.14030']
        """
        if isinstance(paper_id, str):
            paper_id = [paper_id]

        search = arxiv.Search(
            id_list=paper_id,
            max_results=len(paper_id)
@@ -235,8 +235,8 @@ class ArxivSource(DataSource):
        return [self._parse_paper_data(result) for result in results]

    async def search_by_category(
        self,
        category: str,
        limit: int = 100,
        sort_by: str = 'relevance',
        sort_order: str = 'descending',
@@ -244,11 +244,11 @@ class ArxivSource(DataSource):
    ) -> List[PaperMetadata]:
        """Search papers by category"""
        query = f"cat:{category}"

        # If a start year was given, add it to the query
        if start_year:
            query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"

        return await self.search(
            query=query,
            limit=limit,
@@ -257,19 +257,19 @@ class ArxivSource(DataSource):
        )

    async def search_by_authors(
        self,
        authors: List[str],
        limit: int = 100,
        sort_by: str = 'relevance',
        start_year: int = None
    ) -> List[PaperMetadata]:
        """Search papers by author"""
        query = " AND ".join([f"au:\"{author}\"" for author in authors])

        # If a start year was given, add it to the query
        if start_year:
            query = f"{query} AND submittedDate:[{start_year}0101 TO 99991231]"

        return await self.search(
            query=query,
            limit=limit,
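A matching sketch for the two convenience wrappers above; the category and author names are examples, and the ArxivSource import path is again a placeholder:

import asyncio

from arxiv_source import ArxivSource  # placeholder import path

async def demo_wrappers():
    src = ArxivSource()
    # Expands to the query 'cat:cs.AI AND submittedDate:[20240101 TO 99991231]'.
    ai_papers = await src.search_by_category("cs.AI", limit=3, start_year=2024)
    # Expands to 'au:"Yoshua Bengio" AND au:"Yann LeCun"'.
    joint_papers = await src.search_by_authors(["Yoshua Bengio", "Yann LeCun"], limit=3)
    print(len(ai_papers), len(joint_papers))

asyncio.run(demo_wrappers())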
@@ -277,9 +277,9 @@ class ArxivSource(DataSource):
        )

    async def search_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        limit: int = 100,
        sort_by: Literal['relevance', 'updated', 'submitted'] = 'submitted',
        sort_order: Literal['ascending', 'descending'] = 'descending'
@@ -287,20 +287,20 @@ class ArxivSource(DataSource):
        """Search papers within a date range"""
        query = f"submittedDate:[{start_date.strftime('%Y%m%d')} TO {end_date.strftime('%Y%m%d')}]"
        return await self.search(
            query,
            limit=limit,
            sort_by=sort_by,
            sort_order=sort_order
        )

    async def download_pdf(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
        """Download a paper's PDF

        Args:
            paper_id: arXiv ID
            dirpath: directory to save to
            filename: file name; if empty, the default format {paper_id}_{title}.pdf is used

        Returns:
            Path of the saved file
        """
@@ -308,24 +308,24 @@ class ArxivSource(DataSource):
        if not papers:
            raise ValueError(f"未找到ID为 {paper_id} 的论文")
        paper = papers[0]

        if not filename:
            # Strip characters that are unsafe in file names from the title
            safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
            filename = f"{paper_id}_{safe_title}.pdf"

        filepath = os.path.join(dirpath, filename)
        urlretrieve(paper.url, filepath)
        return filepath

    async def download_source(self, paper_id: str, dirpath: str = "./", filename: str = "") -> str:
        """Download a paper's source files (usually the LaTeX sources)

        Args:
            paper_id: arXiv ID
            dirpath: directory to save to
            filename: file name; if empty, the default format {paper_id}_{title}.tar.gz is used

        Returns:
            Path of the saved file
        """
@@ -333,11 +333,11 @@ class ArxivSource(DataSource):
        if not papers:
            raise ValueError(f"未找到ID为 {paper_id} 的论文")
        paper = papers[0]

        if not filename:
            safe_title = "".join(c if c.isalnum() else "_" for c in paper.title)
            filename = f"{paper_id}_{safe_title}.tar.gz"

        filepath = os.path.join(dirpath, filename)
        source_url = paper.url.replace("/pdf/", "/src/")
        urlretrieve(source_url, filepath)
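A short sketch of the two download helpers above (placeholder import path; the arXiv ID is an example). Both helpers resolve the file URL from the url field of the parsed paper:

import asyncio
import os

from arxiv_source import ArxivSource  # placeholder import path

async def demo_download():
    src = ArxivSource()
    os.makedirs("./downloads", exist_ok=True)  # urlretrieve does not create directories
    pdf_path = await src.download_pdf("2005.14165", dirpath="./downloads")
    src_path = await src.download_source("2005.14165", dirpath="./downloads")
    print(pdf_path)
    print(src_path)

asyncio.run(demo_download())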
@@ -353,10 +353,10 @@ class ArxivSource(DataSource):

    async def get_paper_details(self, paper_id: str) -> Optional[PaperMetadata]:
        """Get the details of a paper

        Args:
            paper_id: arXiv ID or DOI

        Returns:
            The paper's details, or None if it cannot be found
        """
@@ -367,7 +367,7 @@ class ArxivSource(DataSource):
            # If the ID is a DOI pointing to an arXiv paper, extract the arXiv ID
            elif paper_id.startswith("10.48550/arXiv."):
                paper_id = paper_id.split("arXiv.")[-1]  # keep everything after 'arXiv.' so IDs like 2005.14165 stay intact

            papers = await self.search_by_id(paper_id)
            return papers[0] if papers else None
        except Exception as e:
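A quick worked check of the DOI branch above: arXiv DOIs have the form 10.48550/arXiv.<id>, and the ID itself contains a period, so the code splits on 'arXiv.' rather than on '.' to keep the full ID:

doi = "10.48550/arXiv.2005.14165"
assert doi.split("arXiv.")[-1] == "2005.14165"  # full ID preserved
assert doi.split(".")[-1] == "14165"            # splitting on '.' alone would truncate the ID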
@@ -379,7 +379,7 @@ class ArxivSource(DataSource):
        # Parse the primary and secondary categories
        primary_category = result.primary_category
        categories = result.categories

        # Build the venue information
        venue_info = {
            'primary_category': primary_category,
@@ -387,7 +387,7 @@ class ArxivSource(DataSource):
            'comments': getattr(result, 'comment', None),
            'journal_ref': getattr(result, 'journal_ref', None)
        }

        return PaperMetadata(
            title=result.title,
            authors=[author.name for author in result.authors],
@@ -405,15 +405,15 @@ class ArxivSource(DataSource):
        )

    async def get_latest_papers(
        self,
        category: str,
        debug: bool = False,
        batch_size: int = 50
    ) -> List[PaperMetadata]:
        """Get the latest papers in the given category

        Fetches newly published papers via the RSS feed, then retrieves their details in batches.

        Args:
            category: arXiv category, for example:
                - an entire field: 'cs'
@@ -421,10 +421,10 @@ class ArxivSource(DataSource):
                - multiple categories: 'cs.AI+q-bio.NC'
            debug: whether to run in debug mode; if True, only the 5 most recent papers are returned
            batch_size: number of papers fetched per batch, default 50

        Returns:
            List of papers

        Raises:
            ValueError: if the category is invalid
        """
@@ -433,22 +433,22 @@ class ArxivSource(DataSource):
            # 1. Convert to lowercase
            # 2. Join multiple categories with '+'
            category = category.lower().replace(' ', '+')

            # Build the RSS feed URL
            feed_url = f"https://rss.arxiv.org/rss/{category}"
            print(f"正在获取RSS feed: {feed_url}")  # debug output

            feed = feedparser.parse(feed_url)

            # Check that the feed is valid
            if hasattr(feed, 'status') and feed.status != 200:
                raise ValueError(f"获取RSS feed失败,状态码: {feed.status}")

            if not feed.entries:
                print(f"警告:未在feed中找到任何条目")  # debug output
                print(f"Feed标题: {feed.feed.title if hasattr(feed, 'feed') else '无标题'}")
                raise ValueError(f"无效的arXiv类别或未找到论文: {category}")

            if debug:
                # Debug mode: fetch only the 5 most recent papers
                search = arxiv.Search(
@@ -459,7 +459,7 @@ class ArxivSource(DataSource):
                )
                results = list(self.client.results(search))
                return [self._parse_paper_data(result) for result in results]

            # Normal mode: fetch all new papers
            # Extract arXiv IDs from the RSS entries
            paper_ids = []
@@ -476,13 +476,13 @@ class ArxivSource(DataSource):
                except Exception as e:
                    print(f"警告:处理条目时出错: {str(e)}")  # debug output
                    continue

            if not paper_ids:
                print("未能从feed中提取到任何论文ID")  # debug output
                return []

            print(f"成功提取到 {len(paper_ids)} 个论文ID")  # debug output

            # Fetch paper details in batches
            papers = []
            with tqdm(total=len(paper_ids), desc="获取arXiv论文") as pbar:
@@ -495,9 +495,9 @@ class ArxivSource(DataSource):
                    batch_results = list(self.client.results(search))
                    papers.extend([self._parse_paper_data(result) for result in batch_results])
                    pbar.update(len(batch_results))

            return papers

        except Exception as e:
            print(f"获取最新论文时发生错误: {str(e)}")
            import traceback
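The loop that actually pulls the arXiv IDs out of the RSS entries falls between the hunks above, so it is not shown in this diff. A minimal sketch of one way to do it with feedparser, assuming the current rss.arxiv.org format where each entry links to https://arxiv.org/abs/<id>; the field handling here is illustrative rather than the committed implementation:

import feedparser

feed = feedparser.parse("https://rss.arxiv.org/rss/cs.AI")
paper_ids = []
for entry in feed.entries:
    # Entry links look like https://arxiv.org/abs/2401.12345v1; keep the last
    # path segment and drop any version suffix.
    arxiv_id = entry.link.rsplit("/", 1)[-1].split("v")[0]
    if arxiv_id:
        paper_ids.append(arxiv_id)
print(f"extracted {len(paper_ids)} paper IDs")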
@@ -507,18 +507,18 @@ class ArxivSource(DataSource):
async def example_usage():
    """Usage examples for ArxivSource"""
    arxiv_source = ArxivSource()

    try:
        # Example 1: basic search with different sort orderings
        # print("\n=== 示例1:搜索最新的机器学习论文(按提交时间排序)===")
        # papers = await arxiv_source.search(
        #     "ti:\"machine learning\"",
        #     limit=3,
        #     sort_by='submitted',
        #     sort_order='descending'
        # )
        # print(f"找到 {len(papers)} 篇论文")

        # for i, paper in enumerate(papers, 1):
        #     print(f"\n--- 论文 {i} ---")
        #     print(f"标题: {paper.title}")
@@ -544,7 +544,7 @@ async def example_usage():
        # # Example 3: search by category
        # print("\n=== 示例3:搜索人工智能领域最新论文 ===")
        # ai_papers = await arxiv_source.search_by_category(
        #     "cs.AI",
        #     limit=2,
        #     sort_by='updated',
        #     sort_order='descending'
@@ -558,7 +558,7 @@ async def example_usage():
        # # Example 4: search by author
        # print("\n=== 示例4:搜索特定作者的论文 ===")
        # author_papers = await arxiv_source.search_by_authors(
        #     ["Bengio"],
        #     limit=2,
        #     sort_by='relevance'
        # )
@@ -598,7 +598,7 @@ async def example_usage():

        # Example 6: get the latest papers
        print("\n=== 示例8:获取最新论文 ===")

        # Latest papers in cs.AI
        print("\n--- 获取AI领域最新论文 ---")
        ai_latest = await arxiv_source.get_latest_papers("cs.AI", debug=True)
@@ -607,7 +607,7 @@ async def example_usage():
            print(f"标题: {paper.title}")
            print(f"作者: {', '.join(paper.authors)}")
            print(f"发表年份: {paper.year}")

        # Latest papers across all of computer science
        print("\n--- 获取整个CS领域最新论文 ---")
        cs_latest = await arxiv_source.get_latest_papers("cs", debug=True)
@@ -616,7 +616,7 @@ async def example_usage():
            print(f"标题: {paper.title}")
            print(f"作者: {', '.join(paper.authors)}")
            print(f"发表年份: {paper.year}")

        # Latest papers for multiple categories
        print("\n--- 获取AI和机器学习领域最新论文 ---")
        multi_latest = await arxiv_source.get_latest_papers("cs.AI+cs.LG", debug=True)
@@ -633,4 +633,4 @@ async def example_usage():

if __name__ == "__main__":
    import asyncio
    asyncio.run(example_usage())