Mirrored from https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-07 15:06:48 +00:00
This commit is contained in: up
@@ -12,7 +12,7 @@ from typing import List, Dict, Optional
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
-from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
+from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
@@ -51,6 +51,7 @@ class ArxivRagWorker:
        self.user_name = user_name
        self.llm_kwargs = llm_kwargs
        self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
+        self.fragments = None


        # 初始化基础目录
@@ -63,7 +64,6 @@ class ArxivRagWorker:
        self._processing_lock = ThreadLock()
        self._processed_fragments = set()
        self._processed_count = 0
-
        # 优化的线程池配置
        cpu_count = os.cpu_count() or 1
        self.thread_pool = ThreadPoolExecutor(
@@ -268,27 +268,18 @@ class ArxivRagWorker:
            f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
        )

-    async def process_paper(self, arxiv_id: str) -> bool:
+    async def process_paper(self, fragments: List[Fragment]) -> bool:
        """处理论文主函数"""
        try:
-            arxiv_id = self._normalize_arxiv_id(arxiv_id)
-            logger.info(f"Starting to process paper: {arxiv_id}")

            if self.paper_path.exists():
-                logger.info(f"Paper {arxiv_id} already processed")
+                logger.info(f"Paper {self.arxiv_id} already processed")
                return True

-            task = self._create_processing_task(arxiv_id)
+            task = self._create_processing_task(self.arxiv_id)

            try:
                async with self.semaphore:
-                    fragments = await self.arxiv_splitter.process(arxiv_id)
-                    if not fragments:
-                        raise ValueError(f"No fragments extracted from paper {arxiv_id}")

-                    logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
                    await self._process_fragments(fragments)

                    self._complete_task(task, fragments, self.paper_path)
                    return True

@@ -297,7 +288,7 @@ class ArxivRagWorker:
                raise

        except Exception as e:
-            logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
+            logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
            return False

    def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
@@ -429,29 +420,28 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
        return

    user_name = chatbot.get_user()
-    worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
+    arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)

    # 处理新论文的情况
-    if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading:
+    if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
        chatbot.append((txt, "正在处理论文,请稍等..."))
        yield from update_ui(chatbot=chatbot, history=history)
+        arxiv_id = arxiv_worker.arxiv_id
+        fragments, formatted_content, output_dir = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id)
+        chatbot.append(["论文下载成功,接下来将编码论文,预计等待两分钟,请耐心等待,论文内容如下:", formatted_content])
        try:
            # 创建新的事件循环
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

-            # 使用超时控制
-            success = False
            try:
                # 设置超时时间为5分钟
                success = loop.run_until_complete(
-                    asyncio.wait_for(worker.process_paper(txt), timeout=300)
+                    asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
                )
                if success:
-                    arxiv_id = worker._normalize_arxiv_id(txt)
                    success = loop.run_until_complete(
-                        asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
+                        asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
                    )
                    if success:
                        chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
@@ -515,7 +505,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
        yield from update_ui(chatbot=chatbot, history=history)

    # 生成提示词
-    prompt = worker.retrieve_and_generate(query_clip)
+    prompt = arxiv_worker.retrieve_and_generate(query_clip)
    if not prompt:
        chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
        yield from update_ui(chatbot=chatbot, history=history)
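For orientation before the new file below: the hunks above move paper download and LaTeX splitting out of the asyncio path. The plugin now builds the fragments synchronously with process_arxiv_sync, shows the folded paper preview right away, and only runs the RAG indexing step on an event loop, with process_paper taking the pre-extracted fragments instead of an arXiv id. The following is a minimal sketch of the resulting control flow as it would sit inside the plugin module; handle_new_paper is a hypothetical wrapper name, UI updates and timeout handling for the error paths are omitted, and only ArxivRagWorker, process_arxiv_sync, process_paper and wait_for_paper are taken from the commit.

# Illustrative sketch only, not part of the commit.
import asyncio
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import process_arxiv_sync

def handle_new_paper(txt, user_name, llm_kwargs):
    arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)

    # 1) Blocking download + splitting; also yields the folded markdown
    #    preview that the plugin appends to the chat immediately.
    fragments, formatted_content, output_dir = process_arxiv_sync(
        arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id)

    # 2) Only the indexing step still runs inside an event loop, and it
    #    consumes the fragments produced above instead of re-downloading.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        ok = loop.run_until_complete(
            asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300))
        if ok:
            ok = loop.run_until_complete(
                asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_worker.arxiv_id), timeout=60))
        return ok
    finally:
        loop.close()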
@@ -0,0 +1,387 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
+from dataclasses import dataclass
+from enum import Enum, auto
+import logging
+from datetime import datetime
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+
+# 设置日志
+logger = logging.getLogger(__name__)
+
+
+# 自定义异常类定义
+class FoldingError(Exception):
+    """折叠相关的自定义异常基类"""
+    pass
+
+
+class FormattingError(FoldingError):
+    """格式化过程中的错误"""
+    pass
+
+
+class MetadataError(FoldingError):
+    """元数据相关的错误"""
+    pass
+
+
+class ValidationError(FoldingError):
+    """验证错误"""
+    pass
+
+
+class FoldingStyle(Enum):
+    """折叠样式枚举"""
+    SIMPLE = auto()  # 简单折叠
+    DETAILED = auto()  # 详细折叠(带有额外信息)
+    NESTED = auto()  # 嵌套折叠
+
+
+@dataclass
+class FoldingOptions:
+    """折叠选项配置"""
+    style: FoldingStyle = FoldingStyle.DETAILED
+    code_language: Optional[str] = None  # 代码块的语言
+    show_timestamp: bool = False  # 是否显示时间戳
+    indent_level: int = 0  # 缩进级别
+    custom_css: Optional[str] = None  # 自定义CSS类
+
+
+T = TypeVar('T')  # 用于泛型类型
+
+
+class BaseMetadata(ABC):
+    """元数据基类"""
+
+    @abstractmethod
+    def validate(self) -> bool:
+        """验证元数据的有效性"""
+        pass
+
+    def _validate_non_empty_str(self, value: Optional[str]) -> bool:
+        """验证字符串非空"""
+        return bool(value and value.strip())
+
+
+@dataclass
+class FileMetadata(BaseMetadata):
+    """文件元数据"""
+    rel_path: str
+    size: float
+    last_modified: Optional[datetime] = None
+    mime_type: Optional[str] = None
+    encoding: str = 'utf-8'
+
+    def validate(self) -> bool:
+        """验证文件元数据的有效性"""
+        try:
+            if not self._validate_non_empty_str(self.rel_path):
+                return False
+            if self.size < 0:
+                return False
+            return True
+        except Exception as e:
+            logger.error(f"File metadata validation error: {str(e)}")
+            return False
+
+
+
+
+class ContentFormatter(ABC, Generic[T]):
+    """内容格式化抽象基类
+
+    支持泛型类型参数,可以指定具体的元数据类型。
+    """
+
+    @abstractmethod
+    def format(self,
+               content: str,
+               metadata: T,
+               options: Optional[FoldingOptions] = None) -> str:
+        """格式化内容
+
+        Args:
+            content: 需要格式化的内容
+            metadata: 类型化的元数据
+            options: 折叠选项
+
+        Returns:
+            str: 格式化后的内容
+
+        Raises:
+            FormattingError: 格式化过程中的错误
+        """
+        pass
+
+    def _create_summary(self, metadata: T) -> str:
+        """创建折叠摘要,可被子类重写"""
+        return str(metadata)
+
+    def _format_content_block(self,
+                              content: str,
+                              options: Optional[FoldingOptions]) -> str:
+        """格式化内容块,处理代码块等特殊格式"""
+        if not options:
+            return content
+
+        if options.code_language:
+            return f"```{options.code_language}\n{content}\n```"
+        return content
+
+    def _add_indent(self, text: str, level: int) -> str:
+        """添加缩进"""
+        if level <= 0:
+            return text
+        indent = " " * level
+        return "\n".join(indent + line for line in text.splitlines())
+
+
+class FileContentFormatter(ContentFormatter[FileMetadata]):
+    """文件内容格式化器"""
+
+    def format(self,
+               content: str,
+               metadata: FileMetadata,
+               options: Optional[FoldingOptions] = None) -> str:
+        """格式化文件内容"""
+        if not metadata.validate():
+            raise MetadataError("Invalid file metadata")
+
+        try:
+            options = options or FoldingOptions()
+
+            # 构建摘要信息
+            summary_parts = [
+                f"{metadata.rel_path} ({metadata.size:.2f}MB)",
+                f"Type: {metadata.mime_type}" if metadata.mime_type else None,
+                (f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
+                 if metadata.last_modified and options.show_timestamp else None)
+            ]
+            summary = " | ".join(filter(None, summary_parts))
+
+            # 构建HTML类
+            css_class = f' class="{options.custom_css}"' if options.custom_css else ''
+
+            # 格式化内容
+            formatted_content = self._format_content_block(content, options)
+
+            # 组装最终结果
+            result = (
+                f'<details{css_class}><summary>{summary}</summary>\n\n'
+                f'{formatted_content}\n\n'
+                f'</details>\n\n'
+            )
+
+            return self._add_indent(result, options.indent_level)
+
+        except Exception as e:
+            logger.error(f"Error formatting file content: {str(e)}")
+            raise FormattingError(f"Failed to format file content: {str(e)}")
+
+
+class ContentFoldingManager:
+    """内容折叠管理器"""
+
+    def __init__(self):
+        """初始化折叠管理器"""
+        self._formatters: Dict[str, ContentFormatter] = {}
+        self._register_default_formatters()
+
+    def _register_default_formatters(self) -> None:
+        """注册默认的格式化器"""
+        self.register_formatter('file', FileContentFormatter())
+
+    def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
+        """注册新的格式化器"""
+        if not isinstance(formatter, ContentFormatter):
+            raise TypeError("Formatter must implement ContentFormatter interface")
+        self._formatters[name] = formatter
+
+    def _guess_language(self, extension: str) -> Optional[str]:
+        """根据文件扩展名猜测编程语言"""
+        extension = extension.lower().lstrip('.')
+        language_map = {
+            'py': 'python',
+            'js': 'javascript',
+            'java': 'java',
+            'cpp': 'cpp',
+            'cs': 'csharp',
+            'html': 'html',
+            'css': 'css',
+            'md': 'markdown',
+            'json': 'json',
+            'xml': 'xml',
+            'sql': 'sql',
+            'sh': 'bash',
+            'yaml': 'yaml',
+            'yml': 'yaml',
+            'txt': None  # 纯文本不需要语言标识
+        }
+        return language_map.get(extension)
+
+    def format_content(self,
+                       content: str,
+                       formatter_type: str,
+                       metadata: Union[FileMetadata],
+                       options: Optional[FoldingOptions] = None) -> str:
+        """格式化内容"""
+        formatter = self._formatters.get(formatter_type)
+        if not formatter:
+            raise KeyError(f"No formatter registered for type: {formatter_type}")
+
+        if not isinstance(metadata, FileMetadata):
+            raise TypeError("Invalid metadata type")
+
+        return formatter.format(content, metadata, options)
+
+
+@dataclass
+class PaperMetadata(BaseMetadata):
+    """论文元数据"""
+    title: str
+    authors: str
+    abstract: str
+    catalogs: str
+    arxiv_id: str = ""
+
+    def validate(self) -> bool:
+        """验证论文元数据的有效性"""
+        try:
+            if not self._validate_non_empty_str(self.title):
+                return False
+            if not self._validate_non_empty_str(self.authors):
+                return False
+            if not self._validate_non_empty_str(self.abstract):
+                return False
+            if not self._validate_non_empty_str(self.catalogs):
+                return False
+            return True
+        except Exception as e:
+            logger.error(f"Paper metadata validation error: {str(e)}")
+            return False
+
+
+class PaperContentFormatter(ContentFormatter[PaperMetadata]):
+    """论文内容格式化器"""
+
+    def format(self,
+               fragments: list[SectionFragment],
+               metadata: PaperMetadata,
+               options: Optional[FoldingOptions] = None) -> str:
+        """格式化论文内容
+
+        Args:
+            fragments: 论文片段列表
+            metadata: 论文元数据
+            options: 折叠选项
+
+        Returns:
+            str: 格式化后的论文内容
+        """
+        if not metadata.validate():
+            raise MetadataError("Invalid paper metadata")
+
+        try:
+            options = options or FoldingOptions()
+
+            # 1. 生成标题部分(不折叠)
+            result = [f"# {metadata.title}\n"]
+
+            # 2. 生成作者信息(折叠)
+            result.append(self._create_folded_section(
+                "Authors",
+                metadata.authors,
+                options
+            ))
+
+            # 3. 生成摘要(折叠)
+            result.append(self._create_folded_section(
+                "Abstract",
+                metadata.abstract,
+                options
+            ))
+
+            # 4. 生成目录树(折叠)
+            result.append(self._create_folded_section(
+                "Table of Contents",
+                f"```\n{metadata.catalogs}\n```",
+                options
+            ))
+
+            # 5. 按章节组织并生成内容
+            sections = self._organize_sections(fragments)
+            for section, section_fragments in sections.items():
+                # 拼接该章节的所有内容
+                section_content = "\n\n".join(
+                    fragment.content for fragment in section_fragments
+                )
+
+                result.append(self._create_folded_section(
+                    section,
+                    section_content,
+                    options
+                ))
+
+            # 6. 生成参考文献(折叠)
+            # 收集所有非空的参考文献
+            all_refs = "\n".join(filter(None,
+                (fragment.bibliography for fragment in fragments)
+            ))
+            if all_refs:
+                result.append(self._create_folded_section(
+                    "Bibliography",
+                    f"```bibtex\n{all_refs}\n```",
+                    options
+                ))
+
+            return "\n\n".join(result)
+
+        except Exception as e:
+            logger.error(f"Error formatting paper content: {str(e)}")
+            raise FormattingError(f"Failed to format paper content: {str(e)}")
+
+    def _create_folded_section(self,
+                               title: str,
+                               content: str,
+                               options: FoldingOptions) -> str:
+        """创建折叠区块
+
+        Args:
+            title: 区块标题
+            content: 区块内容
+            options: 折叠选项
+
+        Returns:
+            str: 格式化后的折叠区块
+        """
+        css_class = f' class="{options.custom_css}"' if options.custom_css else ''
+
+        result = (
+            f'<details{css_class}><summary>{title}</summary>\n\n'
+            f'{content}\n\n'
+            f'</details>'
+        )
+
+        return self._add_indent(result, options.indent_level)
+
+    def _organize_sections(self,
+                           fragments: list[SectionFragment]
+                           ) -> Dict[str, list[SectionFragment]]:
+        """将片段按章节分组
+
+        Args:
+            fragments: 论文片段列表
+
+        Returns:
+            Dict[str, list[SectionFragment]]: 按章节分组的片段字典
+        """
+        sections: Dict[str, list[SectionFragment]] = {}
+
+        for fragment in fragments:
+            section = fragment.current_section or "Uncategorized"
+            if section not in sections:
+                sections[section] = []
+            sections[section].append(fragment)
+
+        return sections
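The 'file' formatter that ContentFoldingManager registers above is not exercised elsewhere in this commit, so a small usage sketch may help. The sample path, size and content below are invented; only the class and method names, and the import path (which matches the content_folder import added in the next hunk), come from the diff.

# Illustrative usage of the module added above; sample values are made up.
from crazy_functions.doc_fns.content_folder import (
    ContentFoldingManager, FileMetadata, FoldingOptions)

manager = ContentFoldingManager()                        # __init__ registers the 'file' formatter
meta = FileMetadata(rel_path="src/demo.py", size=0.01)   # size is in MB per the summary format
opts = FoldingOptions(code_language="python")            # wrap the body in a fenced code block

folded = manager.format_content(
    content="print('hello')",
    formatter_type="file",
    metadata=meta,
    options=opts,
)
# 'folded' is a <details><summary>src/demo.py (0.01MB)</summary> ... </details>
# block whose body is the content inside a ```python fence.
print(folded)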
@@ -14,9 +14,10 @@ from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructurePars
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
+from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata


-def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
+def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path ) -> Path:
    """
    Save all fragments to a single structured markdown file.

@@ -37,7 +38,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "

    # Generate filename
    filename = f"paper_latex_content_{timestamp}.md"
-    file_path = output_path / filename
+    file_path = output_path/ filename

    # Group fragments by section
    sections = {}
@@ -733,7 +734,53 @@ class ArxivSplitter:
        return content.strip()


-async def test_arxiv_splitter():
+def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, Path]:
+    """
+    同步处理 ArXiv 文档并返回分割后的片段
+
+    Args:
+        splitter: ArxivSplitter 实例
+        arxiv_id: ArXiv 文档ID
+
+    Returns:
+        list: 分割后的文档片段列表
+    """
+    try:
+        # 创建一个异步函数来执行异步操作
+        async def _process():
+            return await splitter.process(arxiv_id)
+
+        # 使用 asyncio.run() 运行异步函数
+        fragments = asyncio.run(_process())
+
+        # 保存片段到文件
+        output_dir = save_fragments_to_file(
+            fragments,
+            output_dir=splitter.root_dir / "arxiv_fragments"
+        )
+        print(f"Output saved to: {output_dir}")
+        # 创建论文格式化器
+        formatter = PaperContentFormatter()
+
+        # 准备元数据
+        # 创建格式化选项
+
+        metadata = PaperMetadata(
+            title=fragments[0].title,
+            authors=fragments[0].authors,
+            abstract=fragments[0].abstract,
+            catalogs=fragments[0].catalogs,
+            arxiv_id=fragments[0].arxiv_id
+        )
+
+        # 格式化内容
+        formatted_content = formatter.format(fragments, metadata)
+        return fragments, formatted_content, output_dir
+
+    except Exception as e:
+        print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
+        raise
+
+def test_arxiv_splitter():
    """测试ArXiv分割器的功能"""

    # 测试配置
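Because process_arxiv_sync drives the splitter's coroutine with asyncio.run, it can be called from plain synchronous code but not from inside an already-running event loop. A minimal standalone sketch follows; the arXiv id is a placeholder, while the constructor arguments mirror the test function in the next hunk.

# Illustrative only; "2303.12345" is a placeholder arXiv id.
from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, process_arxiv_sync

splitter = ArxivSplitter(root_dir="private_upload/default_user")
fragments, formatted_content, output_dir = process_arxiv_sync(splitter, "2303.12345")
print(f"{len(fragments)} fragments; folded preview is {len(formatted_content)} chars; "
      f"raw fragments saved under {output_dir}")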
@@ -752,25 +799,21 @@ async def test_arxiv_splitter():

    # 创建分割器实例
    splitter = ArxivSplitter(
-        root_dir="test_cache"
+        root_dir="private_upload/default_user"
    )

    for case in test_cases:
        print(f"\nTesting paper: {case['arxiv_id']}")
        try:
-            fragments = await splitter.process(case['arxiv_id'])
+            # fragments = await splitter.process(case['arxiv_id'])
+            fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
            # 保存fragments
-            output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
-            print(f"Output saved to: {output_dir}")
-            # # 内容检查
-            # for fragment in fragments:
-            # # 长度检查
-            #
-            # print((fragment.content))
-            # print(len(fragment.content))
+            for fragment in fragments:
+                # 长度检查
+                print((fragment.content))
+                print(len(fragment.content))
            # 类型检查
+            print(output_dir)

        except Exception as e:
            print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")