镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 14:36:48 +00:00
116 行
3.5 KiB
Python
116 行
3.5 KiB
Python
import logging
|
||
import tarfile
|
||
from pathlib import Path
|
||
from typing import Optional, Dict
|
||
|
||
import requests
|
||
|
||
|
||
class ArxivDownloader:
    """Download and unpack the source archive of an arXiv paper.

    Every paper gets its own sub-directory below ``root_dir``. A
    sub-directory that already holds extracted files is reused as a cache
    instead of contacting arXiv again.
    """

    def __init__(
        self,
        root_dir: str = "./papers",
        proxies: Optional[Dict[str, str]] = None,
        timeout: float = 30.0,
    ):
        """Create the download root and remember the network settings.

        Args:
            root_dir: Directory under which one folder per paper is created.
            proxies: Proxy mapping passed to ``requests``, for example
                {"http": "http://proxy:port", "https": "https://proxy:port"}.
            timeout: Seconds to wait on a request before giving up on a URL.
        """
        self.root_dir = Path(root_dir)
        # parents=True so that a nested root such as "cache/arxiv" can be used.
        self.root_dir.mkdir(parents=True, exist_ok=True)
        self.proxies = proxies
        self.timeout = timeout

        # Configure logging
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )

    @staticmethod
    def _has_cached_sources(paper_dir: Path, tar_path: Path) -> bool:
        """Return True if ``paper_dir`` holds anything besides the archive itself.

        A directory that contains only the downloaded tarball is the remnant
        of a failed extraction and must not be treated as a cache hit.
        """
        if not paper_dir.is_dir():
            return False
        return any(entry.name != tar_path.name for entry in paper_dir.iterdir())

    @staticmethod
    def _safe_extract(tar: tarfile.TarFile, dest: Path) -> None:
        """Extract ``tar`` into ``dest`` without letting members escape ``dest``.

        Raises:
            tarfile.TarError: If a member would be written outside ``dest``.
        """
        if hasattr(tarfile, "data_filter"):
            # Interpreters with extraction filters reject absolute paths,
            # "../" components, outward-pointing links and device nodes.
            tar.extractall(path=dest, filter="data")
            return

        # Fallback for interpreters without extraction filters: validate the
        # target of every member and extract plain files/directories only.
        dest_root = Path(dest).resolve()
        safe_members = []
        for member in tar.getmembers():
            target = (dest_root / member.name).resolve()
            if target != dest_root and dest_root not in target.parents:
                raise tarfile.ExtractError(
                    f"Unsafe path in archive: {member.name}"
                )
            if member.isfile() or member.isdir():
                safe_members.append(member)
        tar.extractall(path=dest, members=safe_members)

    def _download_and_extract(self, arxiv_id: str) -> str:
        """Download and unpack the source archive of one arXiv paper.

        Args:
            arxiv_id: arXiv paper ID, for example "2103.00020".

        Returns:
            str: Path of the directory holding the extracted files.

        Raises:
            RuntimeError: If the paper could not be fetched from any URL.
        """
        # Local import: only the failure-cleanup path below needs shutil.
        import shutil

        paper_dir = self.root_dir / arxiv_id
        # Old-style IDs ("hep-th/9901001") contain a slash; name the archive
        # after the last component so it lands directly inside paper_dir.
        tar_path = paper_dir / f"{Path(arxiv_id).name}.tar.gz"

        # Check the cache
        if self._has_cached_sources(paper_dir, tar_path):
            logging.info(f"Using cached version for {arxiv_id}")
            return str(paper_dir)

        urls = [
            f"https://arxiv.org/src/{arxiv_id}",
            f"https://arxiv.org/e-print/{arxiv_id}"
        ]

        for url in urls:
            # Re-created on every attempt because a failed attempt removes it.
            paper_dir.mkdir(parents=True, exist_ok=True)
            try:
                logging.info(f"Downloading from {url}")
                response = requests.get(
                    url, proxies=self.proxies, timeout=self.timeout
                )
                if response.status_code != 200:
                    logging.warning(
                        f"Download failed for {url}: HTTP {response.status_code}"
                    )
                    continue
                tar_path.write_bytes(response.content)
                with tarfile.open(tar_path, 'r:gz') as tar:
                    self._safe_extract(tar, paper_dir)
                return str(paper_dir)
            except Exception as e:
                # Best effort: log the failure and fall through to the next URL.
                logging.warning(f"Download failed for {url}: {e}")
                # Drop the tarball and any partially extracted files so that a
                # later call does not mistake this failed attempt for a cache.
                shutil.rmtree(paper_dir, ignore_errors=True)

        raise RuntimeError(f"Failed to download paper {arxiv_id}")

    def download_paper(self, arxiv_id: str) -> str:
        """Download the given arXiv paper.

        Args:
            arxiv_id: arXiv paper ID.

        Returns:
            str: Path of the directory containing the paper's files.
        """
        return self._download_and_extract(arxiv_id)
|
||
|
||
|
||
def main():
    """Manually exercise the downloader: fetch one sample paper and list its files.

    Any error raised while downloading or listing is caught and printed
    rather than propagated.
    """
    # To route traffic through a proxy, pass a mapping such as
    # {"http": "http://your-proxy:port", "https": "https://your-proxy:port"}
    # as ``proxies``; None passes no explicit proxy setting.
    downloader = ArxivDownloader(root_dir="./downloaded_papers", proxies=None)

    # Try downloading one paper (a sample ID is used here)
    try:
        paper_id = "2103.00020"  # sample arXiv ID
        paper_dir = downloader.download_paper(paper_id)
        print(f"Successfully downloaded paper to: {paper_dir}")

        # Show what ended up on disk
        paper_path = Path(paper_dir)
        if paper_path.exists():
            print("Downloaded files:")
            for file in paper_path.rglob("*"):
                if file.is_file():
                    print(f"- {file.relative_to(paper_path)}")
    except Exception as e:
        print(f"Error downloading paper: {e}")
|
||
|
||
|
||
# Run the manual download test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|