diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py
index 99783b52..7a32222b 100644
--- a/crazy_functions/Arxiv_论文对话.py
+++ b/crazy_functions/Arxiv_论文对话.py
@@ -12,7 +12,7 @@ from typing import List, Dict, Optional
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
-from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file
+from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg
@@ -51,6 +51,7 @@ class ArxivRagWorker:
self.user_name = user_name
self.llm_kwargs = llm_kwargs
self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None
+ self.fragments = None
# 初始化基础目录
@@ -63,7 +64,6 @@ class ArxivRagWorker:
self._processing_lock = ThreadLock()
self._processed_fragments = set()
self._processed_count = 0
-
# 优化的线程池配置
cpu_count = os.cpu_count() or 1
self.thread_pool = ThreadPoolExecutor(
@@ -268,27 +268,18 @@ class ArxivRagWorker:
f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)"
)
- async def process_paper(self, arxiv_id: str) -> bool:
+ async def process_paper(self, fragments: List[Fragment]) -> bool:
"""处理论文主函数"""
try:
- arxiv_id = self._normalize_arxiv_id(arxiv_id)
- logger.info(f"Starting to process paper: {arxiv_id}")
if self.paper_path.exists():
- logger.info(f"Paper {arxiv_id} already processed")
+ logger.info(f"Paper {self.arxiv_id} already processed")
return True
- task = self._create_processing_task(arxiv_id)
-
+ task = self._create_processing_task(self.arxiv_id)
try:
async with self.semaphore:
- fragments = await self.arxiv_splitter.process(arxiv_id)
- if not fragments:
- raise ValueError(f"No fragments extracted from paper {arxiv_id}")
-
- logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}")
await self._process_fragments(fragments)
-
self._complete_task(task, fragments, self.paper_path)
return True
@@ -297,7 +288,7 @@ class ArxivRagWorker:
raise
except Exception as e:
- logger.error(f"Error processing paper {arxiv_id}: {str(e)}")
+ logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}")
return False
def _create_processing_task(self, arxiv_id: str) -> ProcessingTask:
@@ -429,29 +420,28 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
return
user_name = chatbot.get_user()
- worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
+ arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt)
# 处理新论文的情况
- if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading:
+ if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading:
chatbot.append((txt, "正在处理论文,请稍等..."))
yield from update_ui(chatbot=chatbot, history=history)
-
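+        # Download and split the paper synchronously, then preview its formatted content in the chat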
+ arxiv_id = arxiv_worker.arxiv_id
+ fragments, formatted_content, output_dir = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id)
+ chatbot.append(["论文下载成功,接下来将编码论文,预计等待两分钟,请耐心等待,论文内容如下:", formatted_content])
try:
# 创建新的事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
- # 使用超时控制
- success = False
try:
# 设置超时时间为5分钟
success = loop.run_until_complete(
- asyncio.wait_for(worker.process_paper(txt), timeout=300)
+ asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300)
)
if success:
- arxiv_id = worker._normalize_arxiv_id(txt)
success = loop.run_until_complete(
- asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60)
+ asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60)
)
if success:
chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。")
@@ -515,7 +505,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot:
yield from update_ui(chatbot=chatbot, history=history)
# 生成提示词
- prompt = worker.retrieve_and_generate(query_clip)
+ prompt = arxiv_worker.retrieve_and_generate(query_clip)
if not prompt:
chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。")
yield from update_ui(chatbot=chatbot, history=history)
diff --git a/crazy_functions/doc_fns/content_folder.py b/crazy_functions/doc_fns/content_folder.py
new file mode 100644
index 00000000..1d32dfc4
--- /dev/null
+++ b/crazy_functions/doc_fns/content_folder.py
@@ -0,0 +1,387 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union
+from dataclasses import dataclass
+from enum import Enum, auto
+import logging
+from datetime import datetime
+from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+
+# Custom exception classes
+class FoldingError(Exception):
+    """Base class for folding-related exceptions"""
+ pass
+
+
+class FormattingError(FoldingError):
+ """格式化过程中的错误"""
+ pass
+
+
+class MetadataError(FoldingError):
+ """元数据相关的错误"""
+ pass
+
+
+class ValidationError(FoldingError):
+ """验证错误"""
+ pass
+
+
+class FoldingStyle(Enum):
+ """折叠样式枚举"""
+ SIMPLE = auto() # 简单折叠
+ DETAILED = auto() # 详细折叠(带有额外信息)
+ NESTED = auto() # 嵌套折叠
+
+
+@dataclass
+class FoldingOptions:
+ """折叠选项配置"""
+ style: FoldingStyle = FoldingStyle.DETAILED
+ code_language: Optional[str] = None # 代码块的语言
+ show_timestamp: bool = False # 是否显示时间戳
+ indent_level: int = 0 # 缩进级别
+ custom_css: Optional[str] = None # 自定义CSS类
+
+
+T = TypeVar('T') # generic type parameter
+
+
+class BaseMetadata(ABC):
+ """元数据基类"""
+
+ @abstractmethod
+ def validate(self) -> bool:
+ """验证元数据的有效性"""
+ pass
+
+ def _validate_non_empty_str(self, value: Optional[str]) -> bool:
+ """验证字符串非空"""
+ return bool(value and value.strip())
+
+
+@dataclass
+class FileMetadata(BaseMetadata):
+ """文件元数据"""
+ rel_path: str
+ size: float
+ last_modified: Optional[datetime] = None
+ mime_type: Optional[str] = None
+ encoding: str = 'utf-8'
+
+ def validate(self) -> bool:
+ """验证文件元数据的有效性"""
+ try:
+ if not self._validate_non_empty_str(self.rel_path):
+ return False
+ if self.size < 0:
+ return False
+ return True
+ except Exception as e:
+ logger.error(f"File metadata validation error: {str(e)}")
+ return False
+
+
+
+
+class ContentFormatter(ABC, Generic[T]):
+ """内容格式化抽象基类
+
+ 支持泛型类型参数,可以指定具体的元数据类型。
+ """
+
+ @abstractmethod
+ def format(self,
+ content: str,
+ metadata: T,
+ options: Optional[FoldingOptions] = None) -> str:
+ """格式化内容
+
+ Args:
+ content: 需要格式化的内容
+ metadata: 类型化的元数据
+ options: 折叠选项
+
+ Returns:
+ str: 格式化后的内容
+
+ Raises:
+ FormattingError: 格式化过程中的错误
+ """
+ pass
+
+ def _create_summary(self, metadata: T) -> str:
+ """创建折叠摘要,可被子类重写"""
+ return str(metadata)
+
+ def _format_content_block(self,
+ content: str,
+ options: Optional[FoldingOptions]) -> str:
+ """格式化内容块,处理代码块等特殊格式"""
+ if not options:
+ return content
+
+ if options.code_language:
+ return f"```{options.code_language}\n{content}\n```"
+ return content
+
+ def _add_indent(self, text: str, level: int) -> str:
+ """添加缩进"""
+ if level <= 0:
+ return text
+ indent = " " * level
+ return "\n".join(indent + line for line in text.splitlines())
+
+
+class FileContentFormatter(ContentFormatter[FileMetadata]):
+ """文件内容格式化器"""
+
+ def format(self,
+ content: str,
+ metadata: FileMetadata,
+ options: Optional[FoldingOptions] = None) -> str:
+ """格式化文件内容"""
+ if not metadata.validate():
+ raise MetadataError("Invalid file metadata")
+
+ try:
+ options = options or FoldingOptions()
+
+            # Build the summary line
+ summary_parts = [
+ f"{metadata.rel_path} ({metadata.size:.2f}MB)",
+ f"Type: {metadata.mime_type}" if metadata.mime_type else None,
+ (f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}"
+ if metadata.last_modified and options.show_timestamp else None)
+ ]
+ summary = " | ".join(filter(None, summary_parts))
+
+            # Build the HTML class attribute
+ css_class = f' class="{options.custom_css}"' if options.custom_css else ''
+
+            # Format the content
+            formatted_content = self._format_content_block(content, options)
+
+            # Assemble the final collapsible block
+            result = (
+                f'<details{css_class}><summary>{summary}</summary>\n\n'
+                f'{formatted_content}\n\n'
+                f'</details>\n\n'
+            )
+
+ return self._add_indent(result, options.indent_level)
+
+ except Exception as e:
+ logger.error(f"Error formatting file content: {str(e)}")
+ raise FormattingError(f"Failed to format file content: {str(e)}")
+
+
+class ContentFoldingManager:
+ """内容折叠管理器"""
+
+ def __init__(self):
+ """初始化折叠管理器"""
+ self._formatters: Dict[str, ContentFormatter] = {}
+ self._register_default_formatters()
+
+ def _register_default_formatters(self) -> None:
+ """注册默认的格式化器"""
+ self.register_formatter('file', FileContentFormatter())
+
+ def register_formatter(self, name: str, formatter: ContentFormatter) -> None:
+ """注册新的格式化器"""
+ if not isinstance(formatter, ContentFormatter):
+ raise TypeError("Formatter must implement ContentFormatter interface")
+ self._formatters[name] = formatter
+
+ def _guess_language(self, extension: str) -> Optional[str]:
+ """根据文件扩展名猜测编程语言"""
+ extension = extension.lower().lstrip('.')
+ language_map = {
+ 'py': 'python',
+ 'js': 'javascript',
+ 'java': 'java',
+ 'cpp': 'cpp',
+ 'cs': 'csharp',
+ 'html': 'html',
+ 'css': 'css',
+ 'md': 'markdown',
+ 'json': 'json',
+ 'xml': 'xml',
+ 'sql': 'sql',
+ 'sh': 'bash',
+ 'yaml': 'yaml',
+ 'yml': 'yaml',
+            'txt': None # plain text needs no language tag
+ }
+ return language_map.get(extension)
+
+ def format_content(self,
+ content: str,
+ formatter_type: str,
+                       metadata: FileMetadata,
+                       options: Optional[FoldingOptions] = None) -> str:
+        """Format content with the formatter registered under formatter_type."""
+ formatter = self._formatters.get(formatter_type)
+ if not formatter:
+ raise KeyError(f"No formatter registered for type: {formatter_type}")
+
+ if not isinstance(metadata, FileMetadata):
+ raise TypeError("Invalid metadata type")
+
+ return formatter.format(content, metadata, options)
+
+
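+# Minimal usage sketch for ContentFoldingManager (illustrative values only; the
+# rel_path, size and content string below are hypothetical, all other names are
+# defined above):
+#
+#   manager = ContentFoldingManager()
+#   meta = FileMetadata(rel_path="example/report.py", size=0.05)
+#   folded = manager.format_content(
+#       content="print('hello')",
+#       formatter_type="file",
+#       metadata=meta,
+#       options=FoldingOptions(code_language="python"),
+#   )
+
+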
+@dataclass
+class PaperMetadata(BaseMetadata):
+ """论文元数据"""
+ title: str
+ authors: str
+ abstract: str
+ catalogs: str
+ arxiv_id: str = ""
+
+ def validate(self) -> bool:
+ """验证论文元数据的有效性"""
+ try:
+ if not self._validate_non_empty_str(self.title):
+ return False
+ if not self._validate_non_empty_str(self.authors):
+ return False
+ if not self._validate_non_empty_str(self.abstract):
+ return False
+ if not self._validate_non_empty_str(self.catalogs):
+ return False
+ return True
+ except Exception as e:
+ logger.error(f"Paper metadata validation error: {str(e)}")
+ return False
+
+
+class PaperContentFormatter(ContentFormatter[PaperMetadata]):
+ """论文内容格式化器"""
+
+ def format(self,
+ fragments: list[SectionFragment],
+ metadata: PaperMetadata,
+ options: Optional[FoldingOptions] = None) -> str:
+ """格式化论文内容
+
+ Args:
+ fragments: 论文片段列表
+ metadata: 论文元数据
+ options: 折叠选项
+
+ Returns:
+ str: 格式化后的论文内容
+ """
+ if not metadata.validate():
+ raise MetadataError("Invalid paper metadata")
+
+ try:
+ options = options or FoldingOptions()
+
+            # 1. Title section (not folded)
+ result = [f"# {metadata.title}\n"]
+
+            # 2. Author information (folded)
+ result.append(self._create_folded_section(
+ "Authors",
+ metadata.authors,
+ options
+ ))
+
+            # 3. Abstract (folded)
+ result.append(self._create_folded_section(
+ "Abstract",
+ metadata.abstract,
+ options
+ ))
+
+            # 4. Table of contents (folded)
+ result.append(self._create_folded_section(
+ "Table of Contents",
+ f"```\n{metadata.catalogs}\n```",
+ options
+ ))
+
+            # 5. Organize the fragments by section and emit each section's content
+ sections = self._organize_sections(fragments)
+ for section, section_fragments in sections.items():
+                # Concatenate all content belonging to this section
+ section_content = "\n\n".join(
+ fragment.content for fragment in section_fragments
+ )
+
+ result.append(self._create_folded_section(
+ section,
+ section_content,
+ options
+ ))
+
+            # 6. Bibliography (folded)
+            # Collect all non-empty bibliography entries
+ all_refs = "\n".join(filter(None,
+ (fragment.bibliography for fragment in fragments)
+ ))
+ if all_refs:
+ result.append(self._create_folded_section(
+ "Bibliography",
+ f"```bibtex\n{all_refs}\n```",
+ options
+ ))
+
+ return "\n\n".join(result)
+
+ except Exception as e:
+ logger.error(f"Error formatting paper content: {str(e)}")
+ raise FormattingError(f"Failed to format paper content: {str(e)}")
+
+ def _create_folded_section(self,
+ title: str,
+ content: str,
+ options: FoldingOptions) -> str:
+ """创建折叠区块
+
+ Args:
+ title: 区块标题
+ content: 区块内容
+ options: 折叠选项
+
+ Returns:
+ str: 格式化后的折叠区块
+ """
+ css_class = f' class="{options.custom_css}"' if options.custom_css else ''
+
+        result = (
+            f'<details{css_class}><summary>{title}</summary>\n\n'
+            f'{content}\n\n'
+            f'</details>'
+        )
+
+ return self._add_indent(result, options.indent_level)
+
+ def _organize_sections(self,
+ fragments: list[SectionFragment]
+ ) -> Dict[str, list[SectionFragment]]:
+ """将片段按章节分组
+
+ Args:
+ fragments: 论文片段列表
+
+ Returns:
+ Dict[str, list[SectionFragment]]: 按章节分组的片段字典
+ """
+ sections: Dict[str, list[SectionFragment]] = {}
+
+ for fragment in fragments:
+ section = fragment.current_section or "Uncategorized"
+ if section not in sections:
+ sections[section] = []
+ sections[section].append(fragment)
+
+ return sections
\ No newline at end of file
diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
index c78f5ccc..83efd966 100644
--- a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
+++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py
@@ -14,9 +14,10 @@ from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructurePars
from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section
from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment
from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils
+from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata
-def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path:
+def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path) -> Path:
"""
Save all fragments to a single structured markdown file.
@@ -37,7 +38,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "
# Generate filename
filename = f"paper_latex_content_{timestamp}.md"
- file_path = output_path / filename
+ file_path = output_path/ filename
# Group fragments by section
sections = {}
@@ -733,7 +734,53 @@ class ArxivSplitter:
return content.strip()
-async def test_arxiv_splitter():
+def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, Path]:
+ """
+ 同步处理 ArXiv 文档并返回分割后的片段
+
+ Args:
+ splitter: ArxivSplitter 实例
+ arxiv_id: ArXiv 文档ID
+
+ Returns:
+ list: 分割后的文档片段列表
+ """
+ try:
+        # Wrap the async splitter call so it can be driven from synchronous code
+ async def _process():
+ return await splitter.process(arxiv_id)
+
+        # Run the coroutine with asyncio.run()
+ fragments = asyncio.run(_process())
+
+        # Save the fragments to a single markdown file
+ output_dir = save_fragments_to_file(
+ fragments,
+ output_dir=splitter.root_dir / "arxiv_fragments"
+ )
+ print(f"Output saved to: {output_dir}")
+        # Create the paper formatter
+ formatter = PaperContentFormatter()
+
+        # Prepare the paper-level metadata (taken from the first fragment)
+        # Formatting options fall back to the defaults
+
+ metadata = PaperMetadata(
+ title=fragments[0].title,
+ authors=fragments[0].authors,
+ abstract=fragments[0].abstract,
+ catalogs=fragments[0].catalogs,
+ arxiv_id=fragments[0].arxiv_id
+ )
+
+        # Format the content
+ formatted_content = formatter.format(fragments, metadata)
+ return fragments, formatted_content, output_dir
+
+ except Exception as e:
+ print(f"✗ Processing failed for {arxiv_id}: {str(e)}")
+        raise
+
+
+def test_arxiv_splitter():
"""测试ArXiv分割器的功能"""
# 测试配置
@@ -752,25 +799,21 @@ async def test_arxiv_splitter():
# 创建分割器实例
splitter = ArxivSplitter(
- root_dir="test_cache"
+ root_dir="private_upload/default_user"
)
for case in test_cases:
print(f"\nTesting paper: {case['arxiv_id']}")
try:
- fragments = await splitter.process(case['arxiv_id'])
-
+ # fragments = await splitter.process(case['arxiv_id'])
+ fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id'])
# 保存fragments
- output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log")
- print(f"Output saved to: {output_dir}")
- # # 内容检查
- # for fragment in fragments:
- # # 长度检查
- #
- # print((fragment.content))
- # print(len(fragment.content))
+ for fragment in fragments:
+                # Length check
+                print(fragment.content)
+ print(len(fragment.content))
# 类型检查
-
+ print(output_dir)
except Exception as e:
print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")