diff --git a/crazy_functions/Arxiv_论文对话.py b/crazy_functions/Arxiv_论文对话.py index 99783b52..7a32222b 100644 --- a/crazy_functions/Arxiv_论文对话.py +++ b/crazy_functions/Arxiv_论文对话.py @@ -12,7 +12,7 @@ from typing import List, Dict, Optional from crazy_functions.crazy_utils import input_clipping from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive -from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file +from crazy_functions.rag_fns.arxiv_fns.arxiv_splitter import ArxivSplitter, save_fragments_to_file, process_arxiv_sync from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment as Fragment from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker from toolbox import CatchException, update_ui, get_log_folder, update_ui_lastest_msg @@ -51,6 +51,7 @@ class ArxivRagWorker: self.user_name = user_name self.llm_kwargs = llm_kwargs self.arxiv_id = self._normalize_arxiv_id(arxiv_id) if arxiv_id else None + self.fragments = None # 初始化基础目录 @@ -63,7 +64,6 @@ class ArxivRagWorker: self._processing_lock = ThreadLock() self._processed_fragments = set() self._processed_count = 0 - # 优化的线程池配置 cpu_count = os.cpu_count() or 1 self.thread_pool = ThreadPoolExecutor( @@ -268,27 +268,18 @@ class ArxivRagWorker: f"in {elapsed_time:.2f}s (rate: {processing_rate:.2f} fragments/s)" ) - async def process_paper(self, arxiv_id: str) -> bool: + async def process_paper(self, fragments: List[Fragment]) -> bool: """处理论文主函数""" try: - arxiv_id = self._normalize_arxiv_id(arxiv_id) - logger.info(f"Starting to process paper: {arxiv_id}") if self.paper_path.exists(): - logger.info(f"Paper {arxiv_id} already processed") + logger.info(f"Paper {self.arxiv_id} already processed") return True - task = self._create_processing_task(arxiv_id) - + task = self._create_processing_task(self.arxiv_id) try: async with self.semaphore: - fragments = await self.arxiv_splitter.process(arxiv_id) - if not fragments: - raise ValueError(f"No fragments extracted from paper {arxiv_id}") - - logger.info(f"Extracted {len(fragments)} fragments from paper {arxiv_id}") await self._process_fragments(fragments) - self._complete_task(task, fragments, self.paper_path) return True @@ -297,7 +288,7 @@ class ArxivRagWorker: raise except Exception as e: - logger.error(f"Error processing paper {arxiv_id}: {str(e)}") + logger.error(f"Error processing paper {self.arxiv_id}: {str(e)}") return False def _create_processing_task(self, arxiv_id: str) -> ProcessingTask: @@ -429,29 +420,28 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: return user_name = chatbot.get_user() - worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt) + arxiv_worker = ArxivRagWorker(user_name, llm_kwargs, arxiv_id=txt) # 处理新论文的情况 - if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not worker.loading: + if txt.lower().strip().startswith(('https://arxiv.org', 'arxiv.org', '0', '1', '2')) and not arxiv_worker.loading: chatbot.append((txt, "正在处理论文,请稍等...")) yield from update_ui(chatbot=chatbot, history=history) - + arxiv_id = arxiv_worker.arxiv_id + fragments, formatted_content, output_dir = process_arxiv_sync(arxiv_worker.arxiv_splitter, arxiv_worker.arxiv_id) + chatbot.append(["论文下载成功,接下来将编码论文,预计等待两分钟,请耐心等待,论文内容如下:", formatted_content]) try: # 创建新的事件循环 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - # 使用超时控制 - success = False try: # 设置超时时间为5分钟 success = loop.run_until_complete( - 
asyncio.wait_for(worker.process_paper(txt), timeout=300) + asyncio.wait_for(arxiv_worker.process_paper(fragments), timeout=300) ) if success: - arxiv_id = worker._normalize_arxiv_id(txt) success = loop.run_until_complete( - asyncio.wait_for(worker.wait_for_paper(arxiv_id), timeout=60) + asyncio.wait_for(arxiv_worker.wait_for_paper(arxiv_id), timeout=60) ) if success: chatbot[-1] = (txt, "论文处理完成,您现在可以开始提问。") @@ -515,7 +505,7 @@ def Arxiv论文对话(txt: str, llm_kwargs: Dict, plugin_kwargs: Dict, chatbot: yield from update_ui(chatbot=chatbot, history=history) # 生成提示词 - prompt = worker.retrieve_and_generate(query_clip) + prompt = arxiv_worker.retrieve_and_generate(query_clip) if not prompt: chatbot[-1] = (user_query, "抱歉,处理您的问题时出现错误,请重试。") yield from update_ui(chatbot=chatbot, history=history) diff --git a/crazy_functions/doc_fns/content_folder.py b/crazy_functions/doc_fns/content_folder.py new file mode 100644 index 00000000..1d32dfc4 --- /dev/null +++ b/crazy_functions/doc_fns/content_folder.py @@ -0,0 +1,387 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Type, TypeVar, Generic, Union +from dataclasses import dataclass +from enum import Enum, auto +import logging +from datetime import datetime +from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment + +# 设置日志 +logger = logging.getLogger(__name__) + + +# 自定义异常类定义 +class FoldingError(Exception): + """折叠相关的自定义异常基类""" + pass + + +class FormattingError(FoldingError): + """格式化过程中的错误""" + pass + + +class MetadataError(FoldingError): + """元数据相关的错误""" + pass + + +class ValidationError(FoldingError): + """验证错误""" + pass + + +class FoldingStyle(Enum): + """折叠样式枚举""" + SIMPLE = auto() # 简单折叠 + DETAILED = auto() # 详细折叠(带有额外信息) + NESTED = auto() # 嵌套折叠 + + +@dataclass +class FoldingOptions: + """折叠选项配置""" + style: FoldingStyle = FoldingStyle.DETAILED + code_language: Optional[str] = None # 代码块的语言 + show_timestamp: bool = False # 是否显示时间戳 + indent_level: int = 0 # 缩进级别 + custom_css: Optional[str] = None # 自定义CSS类 + + +T = TypeVar('T') # 用于泛型类型 + + +class BaseMetadata(ABC): + """元数据基类""" + + @abstractmethod + def validate(self) -> bool: + """验证元数据的有效性""" + pass + + def _validate_non_empty_str(self, value: Optional[str]) -> bool: + """验证字符串非空""" + return bool(value and value.strip()) + + +@dataclass +class FileMetadata(BaseMetadata): + """文件元数据""" + rel_path: str + size: float + last_modified: Optional[datetime] = None + mime_type: Optional[str] = None + encoding: str = 'utf-8' + + def validate(self) -> bool: + """验证文件元数据的有效性""" + try: + if not self._validate_non_empty_str(self.rel_path): + return False + if self.size < 0: + return False + return True + except Exception as e: + logger.error(f"File metadata validation error: {str(e)}") + return False + + + + +class ContentFormatter(ABC, Generic[T]): + """内容格式化抽象基类 + + 支持泛型类型参数,可以指定具体的元数据类型。 + """ + + @abstractmethod + def format(self, + content: str, + metadata: T, + options: Optional[FoldingOptions] = None) -> str: + """格式化内容 + + Args: + content: 需要格式化的内容 + metadata: 类型化的元数据 + options: 折叠选项 + + Returns: + str: 格式化后的内容 + + Raises: + FormattingError: 格式化过程中的错误 + """ + pass + + def _create_summary(self, metadata: T) -> str: + """创建折叠摘要,可被子类重写""" + return str(metadata) + + def _format_content_block(self, + content: str, + options: Optional[FoldingOptions]) -> str: + """格式化内容块,处理代码块等特殊格式""" + if not options: + return content + + if options.code_language: + return f"```{options.code_language}\n{content}\n```" + return content + + def _add_indent(self, text: str, 
level: int) -> str: + """添加缩进""" + if level <= 0: + return text + indent = " " * level + return "\n".join(indent + line for line in text.splitlines()) + + +class FileContentFormatter(ContentFormatter[FileMetadata]): + """文件内容格式化器""" + + def format(self, + content: str, + metadata: FileMetadata, + options: Optional[FoldingOptions] = None) -> str: + """格式化文件内容""" + if not metadata.validate(): + raise MetadataError("Invalid file metadata") + + try: + options = options or FoldingOptions() + + # 构建摘要信息 + summary_parts = [ + f"{metadata.rel_path} ({metadata.size:.2f}MB)", + f"Type: {metadata.mime_type}" if metadata.mime_type else None, + (f"Modified: {metadata.last_modified.strftime('%Y-%m-%d %H:%M:%S')}" + if metadata.last_modified and options.show_timestamp else None) + ] + summary = " | ".join(filter(None, summary_parts)) + + # 构建HTML类 + css_class = f' class="{options.custom_css}"' if options.custom_css else '' + + # 格式化内容 + formatted_content = self._format_content_block(content, options) + + # 组装最终结果 + result = ( + f'{summary}\n\n' + f'{formatted_content}\n\n' + f'\n\n' + ) + + return self._add_indent(result, options.indent_level) + + except Exception as e: + logger.error(f"Error formatting file content: {str(e)}") + raise FormattingError(f"Failed to format file content: {str(e)}") + + +class ContentFoldingManager: + """内容折叠管理器""" + + def __init__(self): + """初始化折叠管理器""" + self._formatters: Dict[str, ContentFormatter] = {} + self._register_default_formatters() + + def _register_default_formatters(self) -> None: + """注册默认的格式化器""" + self.register_formatter('file', FileContentFormatter()) + + def register_formatter(self, name: str, formatter: ContentFormatter) -> None: + """注册新的格式化器""" + if not isinstance(formatter, ContentFormatter): + raise TypeError("Formatter must implement ContentFormatter interface") + self._formatters[name] = formatter + + def _guess_language(self, extension: str) -> Optional[str]: + """根据文件扩展名猜测编程语言""" + extension = extension.lower().lstrip('.') + language_map = { + 'py': 'python', + 'js': 'javascript', + 'java': 'java', + 'cpp': 'cpp', + 'cs': 'csharp', + 'html': 'html', + 'css': 'css', + 'md': 'markdown', + 'json': 'json', + 'xml': 'xml', + 'sql': 'sql', + 'sh': 'bash', + 'yaml': 'yaml', + 'yml': 'yaml', + 'txt': None # 纯文本不需要语言标识 + } + return language_map.get(extension) + + def format_content(self, + content: str, + formatter_type: str, + metadata: Union[FileMetadata], + options: Optional[FoldingOptions] = None) -> str: + """格式化内容""" + formatter = self._formatters.get(formatter_type) + if not formatter: + raise KeyError(f"No formatter registered for type: {formatter_type}") + + if not isinstance(metadata, FileMetadata): + raise TypeError("Invalid metadata type") + + return formatter.format(content, metadata, options) + + +@dataclass +class PaperMetadata(BaseMetadata): + """论文元数据""" + title: str + authors: str + abstract: str + catalogs: str + arxiv_id: str = "" + + def validate(self) -> bool: + """验证论文元数据的有效性""" + try: + if not self._validate_non_empty_str(self.title): + return False + if not self._validate_non_empty_str(self.authors): + return False + if not self._validate_non_empty_str(self.abstract): + return False + if not self._validate_non_empty_str(self.catalogs): + return False + return True + except Exception as e: + logger.error(f"Paper metadata validation error: {str(e)}") + return False + + +class PaperContentFormatter(ContentFormatter[PaperMetadata]): + """论文内容格式化器""" + + def format(self, + fragments: list[SectionFragment], + metadata: PaperMetadata, + 
options: Optional[FoldingOptions] = None) -> str: + """格式化论文内容 + + Args: + fragments: 论文片段列表 + metadata: 论文元数据 + options: 折叠选项 + + Returns: + str: 格式化后的论文内容 + """ + if not metadata.validate(): + raise MetadataError("Invalid paper metadata") + + try: + options = options or FoldingOptions() + + # 1. 生成标题部分(不折叠) + result = [f"# {metadata.title}\n"] + + # 2. 生成作者信息(折叠) + result.append(self._create_folded_section( + "Authors", + metadata.authors, + options + )) + + # 3. 生成摘要(折叠) + result.append(self._create_folded_section( + "Abstract", + metadata.abstract, + options + )) + + # 4. 生成目录树(折叠) + result.append(self._create_folded_section( + "Table of Contents", + f"```\n{metadata.catalogs}\n```", + options + )) + + # 5. 按章节组织并生成内容 + sections = self._organize_sections(fragments) + for section, section_fragments in sections.items(): + # 拼接该章节的所有内容 + section_content = "\n\n".join( + fragment.content for fragment in section_fragments + ) + + result.append(self._create_folded_section( + section, + section_content, + options + )) + + # 6. 生成参考文献(折叠) + # 收集所有非空的参考文献 + all_refs = "\n".join(filter(None, + (fragment.bibliography for fragment in fragments) + )) + if all_refs: + result.append(self._create_folded_section( + "Bibliography", + f"```bibtex\n{all_refs}\n```", + options + )) + + return "\n\n".join(result) + + except Exception as e: + logger.error(f"Error formatting paper content: {str(e)}") + raise FormattingError(f"Failed to format paper content: {str(e)}") + + def _create_folded_section(self, + title: str, + content: str, + options: FoldingOptions) -> str: + """创建折叠区块 + + Args: + title: 区块标题 + content: 区块内容 + options: 折叠选项 + + Returns: + str: 格式化后的折叠区块 + """ + css_class = f' class="{options.custom_css}"' if options.custom_css else '' + + result = ( + f'{title}\n\n' + f'{content}\n\n' + f'' + ) + + return self._add_indent(result, options.indent_level) + + def _organize_sections(self, + fragments: list[SectionFragment] + ) -> Dict[str, list[SectionFragment]]: + """将片段按章节分组 + + Args: + fragments: 论文片段列表 + + Returns: + Dict[str, list[SectionFragment]]: 按章节分组的片段字典 + """ + sections: Dict[str, list[SectionFragment]] = {} + + for fragment in fragments: + section = fragment.current_section or "Uncategorized" + if section not in sections: + sections[section] = [] + sections[section].append(fragment) + + return sections \ No newline at end of file diff --git a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py index c78f5ccc..83efd966 100644 --- a/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py +++ b/crazy_functions/rag_fns/arxiv_fns/arxiv_splitter.py @@ -14,9 +14,10 @@ from crazy_functions.rag_fns.arxiv_fns.essay_structure import EssayStructurePars from crazy_functions.rag_fns.arxiv_fns.section_extractor import Section from crazy_functions.rag_fns.arxiv_fns.section_fragment import SectionFragment from crazy_functions.rag_fns.arxiv_fns.tex_utils import TexUtils +from crazy_functions.doc_fns.content_folder import PaperContentFormatter, PaperMetadata -def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = "fragment_outputs") -> Path: +def save_fragments_to_file(fragments: List[SectionFragment], output_dir: Path ) -> Path: """ Save all fragments to a single structured markdown file. 
@@ -37,7 +38,7 @@ def save_fragments_to_file(fragments: List[SectionFragment], output_dir: str = " # Generate filename filename = f"paper_latex_content_{timestamp}.md" - file_path = output_path / filename + file_path = output_path/ filename # Group fragments by section sections = {} @@ -733,7 +734,53 @@ class ArxivSplitter: return content.strip() -async def test_arxiv_splitter(): +def process_arxiv_sync(splitter: ArxivSplitter, arxiv_id: str) -> tuple[List[SectionFragment], str, Path]: + """ + 同步处理 ArXiv 文档并返回分割后的片段 + + Args: + splitter: ArxivSplitter 实例 + arxiv_id: ArXiv 文档ID + + Returns: + list: 分割后的文档片段列表 + """ + try: + # 创建一个异步函数来执行异步操作 + async def _process(): + return await splitter.process(arxiv_id) + + # 使用 asyncio.run() 运行异步函数 + fragments = asyncio.run(_process()) + + # 保存片段到文件 + output_dir = save_fragments_to_file( + fragments, + output_dir=splitter.root_dir / "arxiv_fragments" + ) + print(f"Output saved to: {output_dir}") + # 创建论文格式化器 + formatter = PaperContentFormatter() + + # 准备元数据 + # 创建格式化选项 + + metadata = PaperMetadata( + title=fragments[0].title, + authors=fragments[0].authors, + abstract=fragments[0].abstract, + catalogs=fragments[0].catalogs, + arxiv_id=fragments[0].arxiv_id + ) + + # 格式化内容 + formatted_content = formatter.format(fragments, metadata) + return fragments, formatted_content, output_dir + + except Exception as e: + print(f"✗ Processing failed for {arxiv_id}: {str(e)}") + raise +def test_arxiv_splitter(): """测试ArXiv分割器的功能""" # 测试配置 @@ -752,25 +799,21 @@ async def test_arxiv_splitter(): # 创建分割器实例 splitter = ArxivSplitter( - root_dir="test_cache" + root_dir="private_upload/default_user" ) for case in test_cases: print(f"\nTesting paper: {case['arxiv_id']}") try: - fragments = await splitter.process(case['arxiv_id']) - + # fragments = await splitter.process(case['arxiv_id']) + fragments, formatted_content, output_dir = process_arxiv_sync(splitter, case['arxiv_id']) # 保存fragments - output_dir = save_fragments_to_file(fragments, output_dir="crazy_functions/rag_fns/arxiv_fns/gpt_log") - print(f"Output saved to: {output_dir}") - # # 内容检查 - # for fragment in fragments: - # # 长度检查 - # - # print((fragment.content)) - # print(len(fragment.content)) + for fragment in fragments: + # 长度检查 + print((fragment.content)) + print(len(fragment.content)) # 类型检查 - + print(output_dir) except Exception as e: print(f"✗ Test failed for {case['arxiv_id']}: {str(e)}")
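Two event-loop idioms appear in this patch: `process_arxiv_sync` wraps the splitter's coroutine in `asyncio.run()`, while the plugin entry point `Arxiv论文对话` builds a fresh loop and caps `process_paper(fragments)` with `asyncio.wait_for(..., timeout=300)`. Below is a minimal, self-contained sketch of the second idiom; `encode_fragments` and `drive_with_timeout` are illustrative stand-ins, not names from this diff.

```python
import asyncio
from typing import List


async def encode_fragments(fragments: List[str]) -> bool:
    """Stand-in for ArxivRagWorker.process_paper(fragments)."""
    await asyncio.sleep(0.1)  # simulate vector-store indexing work
    return bool(fragments)


def drive_with_timeout(fragments: List[str], timeout: float = 300.0) -> bool:
    """Drive the coroutine from synchronous plugin code, mirroring the diff:
    a dedicated event loop plus asyncio.wait_for for the timeout cap."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(
            asyncio.wait_for(encode_fragments(fragments), timeout=timeout)
        )
    except asyncio.TimeoutError:
        return False
    finally:
        loop.close()


if __name__ == "__main__":
    print(drive_with_timeout(["fragment-1", "fragment-2"]))
```

One practical difference between the two idioms: `asyncio.run()` raises a `RuntimeError` when invoked from a thread that already has a running loop, whereas an explicitly managed loop leaves cleanup under the caller's control via `loop.close()`.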
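The new `crazy_functions/doc_fns/content_folder.py` module routes formatting through `ContentFoldingManager`, which looks up a formatter by key (`'file'` is registered by default). A usage sketch based on the signatures added in this diff; the sample file values are invented:

```python
from datetime import datetime

from crazy_functions.doc_fns.content_folder import (
    ContentFoldingManager, FileMetadata, FoldingOptions,
)

manager = ContentFoldingManager()

metadata = FileMetadata(
    rel_path="src/example.py",       # hypothetical file
    size=0.02,                       # size in MB, matching the "(x.xxMB)" summary
    last_modified=datetime.now(),
    mime_type="text/x-python",
)
options = FoldingOptions(code_language="python", show_timestamp=True)

folded = manager.format_content(
    content="print('hello')",
    formatter_type="file",           # key registered by _register_default_formatters
    metadata=metadata,
    options=options,
)
print(folded)
```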
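One thing to note when reading the formatter code above: `FileContentFormatter.format` and `PaperContentFormatter._create_folded_section` both compute a `css_class` attribute that never appears in the strings they assemble, and no folding markup is present in those f-strings, so the HTML tags have most likely been lost in rendering. A collapsible section of the kind the docstrings describe is conventionally built with `<details>`/`<summary>`; the sketch below is a plausible reconstruction under that assumption, not the verbatim source:

```python
def create_folded_section(title: str, content: str, css_class: str = "") -> str:
    """Emit a collapsible HTML block: the summary line stays visible,
    the body is hidden until the reader expands it."""
    class_attr = f' class="{css_class}"' if css_class else ""
    return (
        f"<details{class_attr}>\n"
        f"<summary>{title}</summary>\n\n"
        f"{content}\n\n"
        f"</details>"
    )


print(create_folded_section("Abstract", "We study ..."))
```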
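`ContentFoldingManager.register_formatter` is the extension point for formatter types beyond `'file'`. The sketch below shows what a custom formatter could look like; `SnippetMetadata` and `SnippetFormatter` are invented for illustration, and the formatter is invoked directly because `format_content` as written only accepts `FileMetadata`.

```python
from dataclasses import dataclass
from typing import Optional

from crazy_functions.doc_fns.content_folder import (
    BaseMetadata, ContentFoldingManager, ContentFormatter, FoldingOptions,
)


@dataclass
class SnippetMetadata(BaseMetadata):
    """Hypothetical metadata for a short code snippet."""
    label: str

    def validate(self) -> bool:
        return self._validate_non_empty_str(self.label)


class SnippetFormatter(ContentFormatter[SnippetMetadata]):
    """Hypothetical formatter that renders a snippet under a heading."""

    def format(self, content: str, metadata: SnippetMetadata,
               options: Optional[FoldingOptions] = None) -> str:
        if not metadata.validate():
            raise ValueError("empty snippet label")
        body = self._format_content_block(content, options)  # reuse the base helper
        return f"### {metadata.label}\n\n{body}"


manager = ContentFoldingManager()
manager.register_formatter("snippet", SnippetFormatter())

print(SnippetFormatter().format(
    "for i in range(3):\n    print(i)",
    SnippetMetadata(label="Loop example"),
    FoldingOptions(code_language="python"),
))
```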