比较提交

..

47 次代码提交

作者 SHA1 备注 提交日期
binary-husky
68ff3660ae Update HUGGINGFACE_ACCESS_TOKEN 2024-02-25 22:18:14 +08:00
binary-husky
d0703ef32d update 2024-02-25 22:16:46 +08:00
binary-husky
47289f863d update gr 2024-01-23 15:54:09 +08:00
binary-husky
eaf27df32a ac 2023-12-27 00:05:52 +08:00
binary-husky
d245958dfa new 2023-12-26 23:59:53 +08:00
binary-husky
8dd4d48474 new 2023-12-26 23:59:36 +08:00
binary-husky
15f14f51ff bug fixes 2023-11-29 00:36:26 +08:00
binary-husky
1de63835fc remove old folder 2023-11-20 01:39:45 +08:00
binary-husky
17d0a32f36 version 3.6 2023-11-20 01:17:59 +08:00
binary-husky
971ac206f3 new version 2023-10-06 12:00:27 +08:00
binary-husky
c89c62b914 up 2023-09-15 17:31:18 +08:00
binary-husky
f8946c13f2 follow master 2023-09-15 17:25:21 +08:00
binary-husky
098d8654b3 up 2023-09-09 19:01:31 +08:00
binary-husky
ac8830e30e change default theme 2023-09-09 18:58:08 +08:00
binary-husky
5c0a0882c8 up 2023-09-09 18:56:10 +08:00
binary-husky
f5357f67ca up 2023-08-28 01:44:42 +08:00
binary-husky
a1fe67d7f2 up 2023-08-28 01:40:35 +08:00
binary-husky
cbee909bc8 sync 2023-08-16 13:28:43 +08:00
binary-husky
8a5e8bc5c1 "version": 3.48 2023-08-16 13:26:37 +08:00
binary-husky
96c1852abc Merge branch 'master' into huggingface 2023-06-30 12:09:25 +08:00
binary-husky
cd145c0794 1 2023-06-29 15:04:03 +08:00
binary-husky
7a4d4ad956 Merge branch 'huggingface' of github.com:binary-husky/chatgpt_academic into huggingface 2023-06-29 12:54:24 +08:00
binary-husky
9f9848c6e9 again 2023-06-29 12:54:19 +08:00
binary-husky
94425c49fd again 2023-05-28 21:34:50 +08:00
binary-husky
e874a16050 try again 2023-05-28 21:33:28 +08:00
binary-husky
c28388c5fe load version 2023-05-28 21:32:10 +08:00
binary-husky
b4a56d391b Merge branch 'huggingface' of github.com:binary-husky/chatgpt_academic into huggingface 2023-05-28 21:30:34 +08:00
binary-husky
7075092f86 fix app 2023-05-28 21:30:29 +08:00
binary-husky
1086ff8092 Merge branch 'huggingface' of github.com:binary-husky/chatgpt_academic into huggingface 2023-05-28 21:27:31 +08:00
binary-husky
3a22446b47 try4 2023-05-28 21:27:25 +08:00
binary-husky
7842cf03cc Merge branch 'master' into huggingface 2023-05-28 21:27:20 +08:00
binary-husky
54f55c32f2 213 2023-05-28 21:25:45 +08:00
binary-husky
94318ff0a2 try3 2023-05-28 21:24:46 +08:00
binary-husky
5be6b83762 try2 2023-05-28 21:24:02 +08:00
binary-husky
6f18d1716e Merge branch 'master' into huggingface 2023-05-28 21:21:12 +08:00
binary-husky
90944bd744 up 2023-05-25 15:04:53 +08:00
binary-husky
752937cb70 Merge branch 'master' into huggingface 2023-05-25 15:01:30 +08:00
binary-husky
c584cbac5b fix ver 2023-05-19 14:08:47 +08:00
binary-husky
309d12b404 Merge branch 'master' into huggingface 2023-05-19 14:05:23 +08:00
binary-husky
52ea0acd61 Merge branch 'master' into huggingface 2023-05-06 23:06:53 +08:00
binary-husky
9f5e3e0fd5 Merge branch 'master' into huggingface 2023-05-05 18:24:36 +08:00
binary-husky
315e78e5d9 Merge branch 'master' into huggingface 2023-04-29 03:53:32 +08:00
binary-husky
b6b4ba684a Merge branch 'master' into huggingface 2023-04-24 18:32:56 +08:00
binary-husky
2281a5ca7f 修改提示 2023-04-24 12:55:53 +08:00
binary-husky
49558686f2 Merge branch 'master' into huggingface 2023-04-24 12:30:59 +08:00
Your Name
b050ccedb5 Merge branch 'master' into huggingface 2023-04-21 18:48:00 +08:00
Your Name
ae56cab6f4 huggingface 2023-04-19 18:07:32 +08:00
共有 287 个文件被更改,包括 27942 次插入16449 次删除

查看文件

@@ -11,8 +11,6 @@ body:
- Please choose | 请选择
- Pip Install (I ignored requirements.txt)
- Pip Install (I used latest requirements.txt)
- OneKeyInstall (一键安装脚本-windows)
- OneKeyInstall (一键安装脚本-mac)
- Anaconda (I ignored requirements.txt)
- Anaconda (I used latest requirements.txt)
- DockerWindows/Mac
@@ -34,7 +32,7 @@ body:
- Others | 非最新版
validations:
required: true
- type: dropdown
id: os
attributes:
@@ -47,7 +45,7 @@ body:
- Docker
validations:
required: true
- type: textarea
id: describe
attributes:
@@ -55,7 +53,7 @@ body:
description: Describe the bug | 简述
validations:
required: true
- type: textarea
id: screenshot
attributes:
@@ -63,9 +61,15 @@ body:
description: Screen Shot | 有帮助的截图
validations:
required: true
- type: textarea
id: traceback
attributes:
label: Terminal Traceback & Material to Help Reproduce Bugs | 终端traceback如有 + 帮助我们复现的测试材料样本(如有)
description: Terminal Traceback & Material to Help Reproduce Bugs | 终端traceback如有 + 帮助我们复现的测试材料样本(如有)

查看文件

@@ -21,3 +21,8 @@ body:
attributes:
label: Feature Request | 功能请求
description: Feature Request | 功能请求

查看文件

@@ -1,44 +0,0 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-with-audio-assistant
on:
push:
branches:
- 'master'
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}_audio_assistant
jobs:
build-and-push-image:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v3
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
- name: Build and push Docker image
uses: docker/build-push-action@v4
with:
context: .
push: true
file: docs/GithubAction+NoLocal+AudioAssistant
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

查看文件

@@ -1,5 +1,5 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-with-chatglm
name: Create and publish a Docker image for ChatGLM support
on:
push:

查看文件

@@ -1,5 +1,5 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-with-all-capacity
name: Create and publish a Docker image for ChatGLM support
on:
push:
@@ -8,7 +8,7 @@ on:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}_with_all_capacity
IMAGE_NAME: ${{ github.repository }}_jittorllms
jobs:
build-and-push-image:
@@ -39,6 +39,6 @@ jobs:
with:
context: .
push: true
file: docs/GithubAction+AllCapacity
file: docs/GithubAction+JittorLLMs
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

查看文件

@@ -1,51 +0,0 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-with-latex-arm
on:
push:
branches:
- "master"
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}_with_latex_arm
jobs:
build-and-push-image:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: linux/arm64
file: docs/GithubAction+NoLocal+Latex
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

查看文件

@@ -1,5 +1,5 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-with-latex
name: Create and publish a Docker image for Latex support
on:
push:

查看文件

@@ -1,5 +1,5 @@
# https://docs.github.com/en/actions/publishing-packages/publishing-docker-images#publishing-images-to-github-packages
name: build-without-local-llms
name: Create and publish a Docker image
on:
push:

查看文件

@@ -1,25 +0,0 @@
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: 'Close stale issues and PRs'
on:
schedule:
- cron: '*/5 * * * *'
jobs:
stale:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: read
steps:
- uses: actions/stale@v8
with:
stale-issue-message: 'This issue is stale because it has been open 100 days with no activity. Remove stale label or comment or this will be closed in 1 days.'
days-before-stale: 100
days-before-close: 1
debug-only: true

15
.gitignore vendored
查看文件

@@ -131,9 +131,6 @@ dmypy.json
# Pyre type checker
.pyre/
# macOS files
.DS_Store
.vscode
.idea
@@ -149,15 +146,7 @@ debug*
private*
crazy_functions/test_project/pdf_and_word
crazy_functions/test_samples
request_llms/jittorllms
request_llm/jittorllms
multi-language
request_llms/moss
request_llm/moss
media
flagged
request_llms/ChatGLM-6b-onnx-u8s8
.pre-commit-config.yaml
test.*
temp.*
objdump*
*.min.*.js
TODO

32
.pre-commit-config.yaml 普通文件
查看文件

@@ -0,0 +1,32 @@
default_language_version:
python: python3
exclude: 'dotnet'
ci:
autofix_prs: true
autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
autoupdate_schedule: 'quarterly'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-ast
# - id: check-yaml
- id: check-toml
- id: check-json
- id: check-byte-order-marker
exclude: .gitignore
- id: check-merge-conflict
- id: detect-private-key
- id: trailing-whitespace
- id: end-of-file-fixer
- id: no-commit-to-branch
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
# - repo: https://github.com/charliermarsh/ruff-pre-commit
# rev: v0.0.261
# hooks:
# - id: ruff
# args: ["--fix"]

查看文件

@@ -12,16 +12,11 @@ RUN echo '[global]' > /etc/pip.conf && \
echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
# 语音输出功能以下两行,第一行更换阿里源,第二行安装ffmpeg,都可以删除
RUN UBUNTU_VERSION=$(awk -F= '/^VERSION_CODENAME=/{print $2}' /etc/os-release); echo "deb https://mirrors.aliyun.com/debian/ $UBUNTU_VERSION main non-free contrib" > /etc/apt/sources.list; apt-get update
RUN apt-get install ffmpeg -y
# 进入工作路径(必要)
WORKDIR /gpt
# 安装大部分依赖,利用Docker缓存加速以后的构建 (以下行,可以删除)
# 安装大部分依赖,利用Docker缓存加速以后的构建 (以下行,可以删除)
COPY requirements.txt ./
RUN pip3 install -r requirements.txt

查看文件

@@ -1,9 +1,20 @@
> [!IMPORTANT]
> 2024.10.10: 突发停电,紧急恢复了提供[whl包](https://drive.google.com/file/d/19U_hsLoMrjOlQSzYS3pzWX9fTzyusArP/view?usp=sharing)的文件服务器
> 2024.10.8: 版本3.90加入对llama-index的初步支持,版本3.80加入插件二级菜单功能详见wiki
> 2024.5.1: 加入Doc2x翻译PDF论文的功能,[查看详情](https://github.com/binary-husky/gpt_academic/wiki/Doc2x)
> 2024.3.11: 全力支持Qwen、GLM、DeepseekCoder等中文大语言模型 SoVits语音克隆模块,[查看详情](https://www.bilibili.com/video/BV1Rp421S7tF/)
> 2024.1.17: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
---
title: GPT-Academic
emoji: 😻
colorFrom: blue
colorTo: blue
sdk: gradio
sdk_version: 3.32.0
app_file: app.py
pinned: false
---
# ChatGPT 学术优化
> **Note**
>
> 2023.11.12: 某些依赖包尚不兼容python 3.12,推荐python 3.11。
>
> 2023.12.26: 安装依赖时,请选择`requirements.txt`中**指定的版本**。 安装命令:`pip install -r requirements.txt`。本项目完全开源免费,您可通过订阅[在线服务](https://github.com/binary-husky/gpt_academic/wiki/online)的方式鼓励本项目的发展。
<br>
@@ -68,7 +79,7 @@ Read this in [English](docs/README.English.md) | [日本語](docs/README.Japanes
读论文、[翻译](https://www.bilibili.com/video/BV1KT411x7Wn)论文 | [插件] 一键解读latex/pdf论文全文并生成摘要
Latex全文[翻译](https://www.bilibili.com/video/BV1nk4y1Y7Js/)、[润色](https://www.bilibili.com/video/BV1FT411H7c5/) | [插件] 一键翻译或润色latex论文
批量注释生成 | [插件] 一键批量生成函数注释
Markdown[中英互译](https://www.bilibili.com/video/BV1yo4y157jV/) | [插件] 看到上面5种语言的[README](https://github.com/binary-husky/gpt_academic/blob/master/docs/README.English.md)了吗?就是出自他的手笔
Markdown[中英互译](https://www.bilibili.com/video/BV1yo4y157jV/) | [插件] 看到上面5种语言的[README](https://github.com/binary-husky/gpt_academic/blob/master/docs/README_EN.md)了吗?就是出自他的手笔
[PDF论文全文翻译功能](https://www.bilibili.com/video/BV1KT411x7Wn) | [插件] PDF论文提取题目&摘要+翻译全文(多线程)
[Arxiv小助手](https://www.bilibili.com/video/BV1LM4y1279X) | [插件] 输入arxiv文章url即可一键翻译摘要+下载PDF
Latex论文一键校对 | [插件] 仿Grammarly对Latex文章进行语法、拼写纠错+输出对照PDF
@@ -88,10 +99,6 @@ Latex论文一键校对 | [插件] 仿Grammarly对Latex文章进行语法、拼
<img src="https://user-images.githubusercontent.com/96192199/279702205-d81137c3-affd-4cd1-bb5e-b15610389762.gif" width="700" >
</div>
<div align="center">
<img src="https://github.com/binary-husky/gpt_academic/assets/96192199/70ff1ec5-e589-4561-a29e-b831079b37fb.gif" width="700" >
</div>
- 所有按钮都通过读取functional.py动态生成,可随意加自定义功能,解放剪贴板
<div align="center">
@@ -258,7 +265,8 @@ P.S. 如果需要依赖Latex的插件功能,请见Wiki。另外,您也可以
# Advanced Usage
### I自定义新的便捷按钮学术快捷键
现在已可以通过UI中的`界面外观`菜单中的`自定义菜单`添加新的便捷按钮。如果需要在代码中定义,请使用任意文本编辑器打开`core_functional.py`,添加如下条目即可:
任意文本编辑器打开`core_functional.py`,添加如下条目,然后重启程序。(如果按钮已存在,那么可以直接修改(前缀、后缀都已支持热修改),无需重启程序即可生效。)
例如
```python
"超级英译中": {

412
app.py 普通文件
查看文件

@@ -0,0 +1,412 @@
import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
help_menu_description = \
"""Github源代码开源和更新[地址🚀](https://github.com/binary-husky/gpt_academic),
感谢热情的[开发者们❤️](https://github.com/binary-husky/gpt_academic/graphs/contributors).
</br></br>常见问题请查阅[项目Wiki](https://github.com/binary-husky/gpt_academic/wiki),
如遇到Bug请前往[Bug反馈](https://github.com/binary-husky/gpt_academic/issues).
</br></br>普通对话使用说明: 1. 输入问题; 2. 点击提交
</br></br>基础功能区使用说明: 1. 输入文本; 2. 点击任意基础功能区按钮
</br></br>函数插件区使用说明: 1. 输入路径/问题, 或者上传文件; 2. 点击任意函数插件区按钮
</br></br>虚空终端使用说明: 点击虚空终端, 然后根据提示输入指令, 再次点击虚空终端
</br></br>如何保存对话: 点击保存当前的对话按钮
</br></br>如何语音对话: 请阅读Wiki
</br></br>如何临时更换API_KEY: 在输入区输入临时API_KEY后提交网页刷新后失效"""
def main():
import subprocess, sys
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'https://public.agent-matrix.com/publish/gradio-3.32.8-py3-none-any.whl'])
import gradio as gr
if gr.__version__ not in ['3.32.8']:
raise ModuleNotFoundError("使用项目内置Gradio获取最优体验! 请运行 `pip install -r requirements.txt` 指令安装内置Gradio及其他依赖, 详情信息见requirements.txt.")
from request_llms.bridge_all import predict
from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated, get_conf, ArgsGeneralWrapper, load_chat_cookies, DummyWith
# 建议您复制一个config_private.py放自己的秘密, 如API和代理网址
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION = get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION')
CHATBOT_HEIGHT, LAYOUT, AVAIL_LLM_MODELS, AUTO_CLEAR_TXT = get_conf('CHATBOT_HEIGHT', 'LAYOUT', 'AVAIL_LLM_MODELS', 'AUTO_CLEAR_TXT')
ENABLE_AUDIO, AUTO_CLEAR_TXT, PATH_LOGGING, AVAIL_THEMES, THEME, ADD_WAIFU = get_conf('ENABLE_AUDIO', 'AUTO_CLEAR_TXT', 'PATH_LOGGING', 'AVAIL_THEMES', 'THEME', 'ADD_WAIFU')
DARK_MODE, NUM_CUSTOM_BASIC_BTN, SSL_KEYFILE, SSL_CERTFILE = get_conf('DARK_MODE', 'NUM_CUSTOM_BASIC_BTN', 'SSL_KEYFILE', 'SSL_CERTFILE')
INIT_SYS_PROMPT = get_conf('INIT_SYS_PROMPT')
# 如果WEB_PORT是-1, 则随机选取WEB端口
PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
from check_proxy import get_current_version
from themes.theme import adjust_theme, advanced_css, theme_declaration, js_code_clear, js_code_reset, js_code_show_or_hide, js_code_show_or_hide_group2
from themes.theme import js_code_for_css_changing, js_code_for_toggle_darkmode, js_code_for_persistent_cookie_init
from themes.theme import load_dynamic_theme, to_cookie_str, from_cookie_str, init_cookie
title_html = f"<h1 align=\"center\">GPT 学术优化 {get_current_version()}</h1>{theme_declaration}"
# 问询记录, python 版本建议3.9+(越新越好)
import logging, uuid
os.makedirs(PATH_LOGGING, exist_ok=True)
try:logging.basicConfig(filename=f"{PATH_LOGGING}/chat_secrets.log", level=logging.INFO, encoding="utf-8", format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
except:logging.basicConfig(filename=f"{PATH_LOGGING}/chat_secrets.log", level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
# Disable logging output from the 'httpx' logger
logging.getLogger("httpx").setLevel(logging.WARNING)
print(f"所有问询记录将自动保存在本地目录./{PATH_LOGGING}/chat_secrets.log, 请注意自我隐私保护哦!")
# 一些普通功能模块
from core_functional import get_core_functions
functional = get_core_functions()
# 高级函数插件
from crazy_functional import get_crazy_functions
DEFAULT_FN_GROUPS = get_conf('DEFAULT_FN_GROUPS')
plugins = get_crazy_functions()
all_plugin_groups = list(set([g for _, plugin in plugins.items() for g in plugin['Group'].split('|')]))
match_group = lambda tags, groups: any([g in groups for g in tags.split('|')])
# 处理markdown文本格式的转变
gr.Chatbot.postprocess = format_io
# 做一些外观色彩上的调整
set_theme = adjust_theme()
# 代理与自动更新
from check_proxy import check_proxy, auto_update, warm_up_modules
proxy_info = check_proxy(proxies)
gr_L1 = lambda: gr.Row().style()
gr_L2 = lambda scale, elem_id: gr.Column(scale=scale, elem_id=elem_id, min_width=400)
if LAYOUT == "TOP-DOWN":
gr_L1 = lambda: DummyWith()
gr_L2 = lambda scale, elem_id: gr.Row()
CHATBOT_HEIGHT /= 2
cancel_handles = []
customize_btns = {}
predefined_btns = {}
with gr.Blocks(title="GPT 学术优化", theme=set_theme, analytics_enabled=False, css=advanced_css) as demo:
gr.HTML(title_html)
gr.HTML('''<center><a href="https://huggingface.co/spaces/qingxu98/gpt-academic?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>请您打开此页面后务必点击上方的“复制空间”Duplicate Space按钮<font color="#FF00FF">使用时,先在输入框填入API-KEY然后回车。</font><br/>切忌在“复制空间”Duplicate Space之前填入API_KEY或进行提问,否则您的API_KEY将极可能被空间所有者攫取<br/>支持任意数量的OpenAI的密钥和API2D的密钥共存,例如输入"OpenAI密钥1,API2D密钥2",然后提交,即可同时使用两种模型接口。</center>''')
secret_css, dark_mode, py_pickle_cookie = gr.Textbox(visible=False), gr.Textbox(DARK_MODE, visible=False), gr.Textbox(visible=False)
cookies = gr.State(load_chat_cookies())
with gr_L1():
with gr_L2(scale=2, elem_id="gpt-chat"):
chatbot = gr.Chatbot(label=f"当前模型:{LLM_MODEL}", elem_id="gpt-chatbot")
if LAYOUT == "TOP-DOWN": chatbot.style(height=CHATBOT_HEIGHT)
history = gr.State([])
with gr_L2(scale=1, elem_id="gpt-panel"):
with gr.Accordion("输入区", open=True, elem_id="input-panel") as area_input_primary:
with gr.Row():
txt = gr.Textbox(show_label=False, lines=2, placeholder="输入问题或API密钥,输入多个密钥时,用英文逗号间隔。支持多个OpenAI密钥共存。").style(container=False)
with gr.Row():
submitBtn = gr.Button("提交", elem_id="elem_submit", variant="primary")
with gr.Row():
resetBtn = gr.Button("重置", elem_id="elem_reset", variant="secondary"); resetBtn.style(size="sm")
stopBtn = gr.Button("停止", elem_id="elem_stop", variant="secondary"); stopBtn.style(size="sm")
clearBtn = gr.Button("清除", elem_id="elem_clear", variant="secondary", visible=False); clearBtn.style(size="sm")
if ENABLE_AUDIO:
with gr.Row():
audio_mic = gr.Audio(source="microphone", type="numpy", elem_id="elem_audio", streaming=True, show_label=False).style(container=False)
with gr.Row():
status = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行。当前模型: {LLM_MODEL} \n {proxy_info}", elem_id="state-panel")
with gr.Accordion("基础功能区", open=True, elem_id="basic-panel") as area_basic_fn:
with gr.Row():
for k in range(NUM_CUSTOM_BASIC_BTN):
customize_btn = gr.Button("自定义按钮" + str(k+1), visible=False, variant="secondary", info_str=f'基础功能区: 自定义按钮')
customize_btn.style(size="sm")
customize_btns.update({"自定义按钮" + str(k+1): customize_btn})
for k in functional:
if ("Visible" in functional[k]) and (not functional[k]["Visible"]): continue
variant = functional[k]["Color"] if "Color" in functional[k] else "secondary"
functional[k]["Button"] = gr.Button(k, variant=variant, info_str=f'基础功能区: {k}')
functional[k]["Button"].style(size="sm")
predefined_btns.update({k: functional[k]["Button"]})
with gr.Accordion("函数插件区", open=True, elem_id="plugin-panel") as area_crazy_fn:
with gr.Row():
gr.Markdown("插件可读取“输入区”文本/路径作为参数(上传文件自动修正路径)")
with gr.Row(elem_id="input-plugin-group"):
plugin_group_sel = gr.Dropdown(choices=all_plugin_groups, label='', show_label=False, value=DEFAULT_FN_GROUPS,
multiselect=True, interactive=True, elem_classes='normal_mut_select').style(container=False)
with gr.Row():
for k, plugin in plugins.items():
if not plugin.get("AsButton", True): continue
visible = True if match_group(plugin['Group'], DEFAULT_FN_GROUPS) else False
variant = plugins[k]["Color"] if "Color" in plugin else "secondary"
info = plugins[k].get("Info", k)
plugin['Button'] = plugins[k]['Button'] = gr.Button(k, variant=variant,
visible=visible, info_str=f'函数插件区: {info}').style(size="sm")
with gr.Row():
with gr.Accordion("更多函数插件", open=True):
dropdown_fn_list = []
for k, plugin in plugins.items():
if not match_group(plugin['Group'], DEFAULT_FN_GROUPS): continue
if not plugin.get("AsButton", True): dropdown_fn_list.append(k) # 排除已经是按钮的插件
elif plugin.get('AdvancedArgs', False): dropdown_fn_list.append(k) # 对于需要高级参数的插件,亦在下拉菜单中显示
with gr.Row():
dropdown = gr.Dropdown(dropdown_fn_list, value=r"打开插件列表", label="", show_label=False).style(container=False)
with gr.Row():
plugin_advanced_arg = gr.Textbox(show_label=True, label="高级参数输入区", visible=False,
placeholder="这里是特殊函数插件的高级参数输入区").style(container=False)
with gr.Row():
switchy_bt = gr.Button(r"请先从插件列表中选择", variant="secondary").style(size="sm")
with gr.Row():
with gr.Accordion("点击展开“文件下载区”。", open=False) as area_file_up:
file_upload = gr.Files(label="任何文件, 推荐上传压缩文件(zip, tar)", file_count="multiple", elem_id="elem_upload")
with gr.Floating(init_x="0%", init_y="0%", visible=True, width=None, drag="forbidden", elem_id="tooltip"):
with gr.Row():
with gr.Tab("上传文件", elem_id="interact-panel"):
gr.Markdown("请上传本地文件/压缩包供“函数插件区”功能调用。请注意: 上传文件后会自动把输入区修改为相应路径。")
file_upload_2 = gr.Files(label="任何文件, 推荐上传压缩文件(zip, tar)", file_count="multiple", elem_id="elem_upload_float")
with gr.Tab("更换模型", elem_id="interact-panel"):
md_dropdown = gr.Dropdown(AVAIL_LLM_MODELS, value=LLM_MODEL, label="更换LLM模型/请求源").style(container=False)
top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
max_length_sl = gr.Slider(minimum=256, maximum=1024*32, value=4096, step=128, interactive=True, label="Local LLM MaxLength",)
system_prompt = gr.Textbox(show_label=True, lines=2, placeholder=f"System Prompt", label="System prompt", value=INIT_SYS_PROMPT)
with gr.Tab("界面外观", elem_id="interact-panel"):
theme_dropdown = gr.Dropdown(AVAIL_THEMES, value=THEME, label="更换UI主题").style(container=False)
checkboxes = gr.CheckboxGroup(["基础功能区", "函数插件区", "浮动输入区", "输入清除键", "插件参数区"], value=["基础功能区", "函数插件区"], label="显示/隐藏功能区", elem_id='cbs').style(container=False)
opt = ["自定义菜单"]
value=[]
if ADD_WAIFU: opt += ["添加Live2D形象"]; value += ["添加Live2D形象"]
checkboxes_2 = gr.CheckboxGroup(opt, value=value, label="显示/隐藏自定义菜单", elem_id='cbsc').style(container=False)
dark_mode_btn = gr.Button("切换界面明暗 ☀", variant="secondary").style(size="sm")
dark_mode_btn.click(None, None, None, _js=js_code_for_toggle_darkmode)
with gr.Tab("帮助", elem_id="interact-panel"):
gr.Markdown(help_menu_description)
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top") as area_input_secondary:
with gr.Accordion("浮动输入区", open=True, elem_id="input-panel2"):
with gr.Row() as row:
row.style(equal_height=True)
with gr.Column(scale=10):
txt2 = gr.Textbox(show_label=False, placeholder="Input question here.",
elem_id='user_input_float', lines=8, label="输入区2").style(container=False)
with gr.Column(scale=1, min_width=40):
submitBtn2 = gr.Button("提交", variant="primary"); submitBtn2.style(size="sm")
resetBtn2 = gr.Button("重置", variant="secondary"); resetBtn2.style(size="sm")
stopBtn2 = gr.Button("停止", variant="secondary"); stopBtn2.style(size="sm")
clearBtn2 = gr.Button("清除", elem_id="elem_clear2", variant="secondary", visible=False); clearBtn2.style(size="sm")
with gr.Floating(init_x="20%", init_y="50%", visible=False, width="40%", drag="top") as area_customize:
with gr.Accordion("自定义菜单", open=True, elem_id="edit-panel"):
with gr.Row() as row:
with gr.Column(scale=10):
AVAIL_BTN = [btn for btn in customize_btns.keys()] + [k for k in functional]
basic_btn_dropdown = gr.Dropdown(AVAIL_BTN, value="自定义按钮1", label="选择一个需要自定义基础功能区按钮").style(container=False)
basic_fn_title = gr.Textbox(show_label=False, placeholder="输入新按钮名称", lines=1).style(container=False)
basic_fn_prefix = gr.Textbox(show_label=False, placeholder="输入新提示前缀", lines=4).style(container=False)
basic_fn_suffix = gr.Textbox(show_label=False, placeholder="输入新提示后缀", lines=4).style(container=False)
with gr.Column(scale=1, min_width=70):
basic_fn_confirm = gr.Button("确认并保存", variant="primary"); basic_fn_confirm.style(size="sm")
basic_fn_clean = gr.Button("恢复默认", variant="primary"); basic_fn_clean.style(size="sm")
def assign_btn(persistent_cookie_, cookies_, basic_btn_dropdown_, basic_fn_title, basic_fn_prefix, basic_fn_suffix, clean_up=False):
ret = {}
# 读取之前的自定义按钮
customize_fn_overwrite_ = cookies_['customize_fn_overwrite']
# 更新新的自定义按钮
customize_fn_overwrite_.update({
basic_btn_dropdown_:
{
"Title":basic_fn_title,
"Prefix":basic_fn_prefix,
"Suffix":basic_fn_suffix,
}
}
)
if clean_up:
customize_fn_overwrite_ = {}
cookies_.update(customize_fn_overwrite_) # 更新cookie
visible = (not clean_up) and (basic_fn_title != "")
if basic_btn_dropdown_ in customize_btns:
# 是自定义按钮,不是预定义按钮
ret.update({customize_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
else:
# 是预定义按钮
ret.update({predefined_btns[basic_btn_dropdown_]: gr.update(visible=visible, value=basic_fn_title)})
ret.update({cookies: cookies_})
try: persistent_cookie_ = from_cookie_str(persistent_cookie_) # persistent cookie to dict
except: persistent_cookie_ = {}
persistent_cookie_["custom_bnt"] = customize_fn_overwrite_ # dict update new value
persistent_cookie_ = to_cookie_str(persistent_cookie_) # persistent cookie to dict
ret.update({py_pickle_cookie: persistent_cookie_}) # write persistent cookie
return ret
# update btn
h = basic_fn_confirm.click(assign_btn, [py_pickle_cookie, cookies, basic_btn_dropdown, basic_fn_title, basic_fn_prefix, basic_fn_suffix],
[py_pickle_cookie, cookies, *customize_btns.values(), *predefined_btns.values()])
h.then(None, [py_pickle_cookie], None, _js="""(py_pickle_cookie)=>{setCookie("py_pickle_cookie", py_pickle_cookie, 365);}""")
# clean up btn
h2 = basic_fn_clean.click(assign_btn, [py_pickle_cookie, cookies, basic_btn_dropdown, basic_fn_title, basic_fn_prefix, basic_fn_suffix, gr.State(True)],
[py_pickle_cookie, cookies, *customize_btns.values(), *predefined_btns.values()])
h2.then(None, [py_pickle_cookie], None, _js="""(py_pickle_cookie)=>{setCookie("py_pickle_cookie", py_pickle_cookie, 365);}""")
def persistent_cookie_reload(persistent_cookie_, cookies_):
ret = {}
for k in customize_btns:
ret.update({customize_btns[k]: gr.update(visible=False, value="")})
try: persistent_cookie_ = from_cookie_str(persistent_cookie_) # persistent cookie to dict
except: return ret
customize_fn_overwrite_ = persistent_cookie_.get("custom_bnt", {})
cookies_['customize_fn_overwrite'] = customize_fn_overwrite_
ret.update({cookies: cookies_})
for k,v in persistent_cookie_["custom_bnt"].items():
if v['Title'] == "": continue
if k in customize_btns: ret.update({customize_btns[k]: gr.update(visible=True, value=v['Title'])})
else: ret.update({predefined_btns[k]: gr.update(visible=True, value=v['Title'])})
return ret
# 功能区显示开关与功能区的互动
def fn_area_visibility(a):
ret = {}
ret.update({area_input_primary: gr.update(visible=("浮动输入区" not in a))})
ret.update({area_input_secondary: gr.update(visible=("浮动输入区" in a))})
ret.update({plugin_advanced_arg: gr.update(visible=("插件参数区" in a))})
if "浮动输入区" in a: ret.update({txt: gr.update(value="")})
return ret
checkboxes.select(fn_area_visibility, [checkboxes], [area_basic_fn, area_crazy_fn, area_input_primary, area_input_secondary, txt, txt2, plugin_advanced_arg] )
checkboxes.select(None, [checkboxes], None, _js=js_code_show_or_hide)
# 功能区显示开关与功能区的互动
def fn_area_visibility_2(a):
ret = {}
ret.update({area_customize: gr.update(visible=("自定义菜单" in a))})
return ret
checkboxes_2.select(fn_area_visibility_2, [checkboxes_2], [area_customize] )
checkboxes_2.select(None, [checkboxes_2], None, _js=js_code_show_or_hide_group2)
# 整理反复出现的控件句柄组合
input_combo = [cookies, max_length_sl, md_dropdown, txt, txt2, top_p, temperature, chatbot, history, system_prompt, plugin_advanced_arg]
output_combo = [cookies, chatbot, history, status]
predict_args = dict(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True)], outputs=output_combo)
# 提交按钮、重置按钮
cancel_handles.append(txt.submit(**predict_args))
cancel_handles.append(txt2.submit(**predict_args))
cancel_handles.append(submitBtn.click(**predict_args))
cancel_handles.append(submitBtn2.click(**predict_args))
resetBtn.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
resetBtn2.click(None, None, [chatbot, history, status], _js=js_code_reset) # 先在前端快速清除chatbot&status
resetBtn.click(lambda: ([], [], "已重置"), None, [chatbot, history, status]) # 再在后端清除history
resetBtn2.click(lambda: ([], [], "已重置"), None, [chatbot, history, status]) # 再在后端清除history
clearBtn.click(None, None, [txt, txt2], _js=js_code_clear)
clearBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
if AUTO_CLEAR_TXT:
submitBtn.click(None, None, [txt, txt2], _js=js_code_clear)
submitBtn2.click(None, None, [txt, txt2], _js=js_code_clear)
txt.submit(None, None, [txt, txt2], _js=js_code_clear)
txt2.submit(None, None, [txt, txt2], _js=js_code_clear)
# 基础功能区的回调函数注册
for k in functional:
if ("Visible" in functional[k]) and (not functional[k]["Visible"]): continue
click_handle = functional[k]["Button"].click(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True), gr.State(k)], outputs=output_combo)
cancel_handles.append(click_handle)
for btn in customize_btns.values():
click_handle = btn.click(fn=ArgsGeneralWrapper(predict), inputs=[*input_combo, gr.State(True), gr.State(btn.value)], outputs=output_combo)
cancel_handles.append(click_handle)
# 文件上传区,接收文件后与chatbot的互动
file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
file_upload_2.upload(on_file_uploaded, [file_upload_2, chatbot, txt, txt2, checkboxes, cookies], [chatbot, txt, txt2, cookies]).then(None, None, None, _js=r"()=>{toast_push('上传完毕 ...'); cancel_loading_status();}")
# 函数插件-固定按钮区
for k in plugins:
if not plugins[k].get("AsButton", True): continue
click_handle = plugins[k]["Button"].click(ArgsGeneralWrapper(plugins[k]["Function"]), [*input_combo], output_combo)
click_handle.then(on_report_generated, [cookies, file_upload, chatbot], [cookies, file_upload, chatbot])
cancel_handles.append(click_handle)
# 函数插件-下拉菜单与随变按钮的互动
def on_dropdown_changed(k):
variant = plugins[k]["Color"] if "Color" in plugins[k] else "secondary"
info = plugins[k].get("Info", k)
ret = {switchy_bt: gr.update(value=k, variant=variant, info_str=f'函数插件区: {info}')}
if plugins[k].get("AdvancedArgs", False): # 是否唤起高级插件参数区
ret.update({plugin_advanced_arg: gr.update(visible=True, label=f"插件[{k}]的高级参数说明:" + plugins[k].get("ArgsReminder", [f"没有提供高级参数功能说明"]))})
else:
ret.update({plugin_advanced_arg: gr.update(visible=False, label=f"插件[{k}]不需要高级参数。")})
return ret
dropdown.select(on_dropdown_changed, [dropdown], [switchy_bt, plugin_advanced_arg] )
def on_md_dropdown_changed(k):
return {chatbot: gr.update(label="当前模型:"+k)}
md_dropdown.select(on_md_dropdown_changed, [md_dropdown], [chatbot] )
def on_theme_dropdown_changed(theme, secret_css):
adjust_theme, css_part1, _, adjust_dynamic_theme = load_dynamic_theme(theme)
if adjust_dynamic_theme:
css_part2 = adjust_dynamic_theme._get_theme_css()
else:
css_part2 = adjust_theme()._get_theme_css()
return css_part2 + css_part1
theme_handle = theme_dropdown.select(on_theme_dropdown_changed, [theme_dropdown, secret_css], [secret_css])
theme_handle.then(
None,
[secret_css],
None,
_js=js_code_for_css_changing
)
# 随变按钮的回调函数注册
def route(request: gr.Request, k, *args, **kwargs):
if k in [r"打开插件列表", r"请先从插件列表中选择"]: return
yield from ArgsGeneralWrapper(plugins[k]["Function"])(request, *args, **kwargs)
click_handle = switchy_bt.click(route,[switchy_bt, *input_combo], output_combo)
click_handle.then(on_report_generated, [cookies, file_upload, chatbot], [cookies, file_upload, chatbot])
cancel_handles.append(click_handle)
# 终止按钮的回调函数注册
stopBtn.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
stopBtn2.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
plugins_as_btn = {name:plugin for name, plugin in plugins.items() if plugin.get('Button', None)}
def on_group_change(group_list):
btn_list = []
fns_list = []
if not group_list: # 处理特殊情况:没有选择任何插件组
return [*[plugin['Button'].update(visible=False) for _, plugin in plugins_as_btn.items()], gr.Dropdown.update(choices=[])]
for k, plugin in plugins.items():
if plugin.get("AsButton", True):
btn_list.append(plugin['Button'].update(visible=match_group(plugin['Group'], group_list))) # 刷新按钮
if plugin.get('AdvancedArgs', False): dropdown_fn_list.append(k) # 对于需要高级参数的插件,亦在下拉菜单中显示
elif match_group(plugin['Group'], group_list): fns_list.append(k) # 刷新下拉列表
return [*btn_list, gr.Dropdown.update(choices=fns_list)]
plugin_group_sel.select(fn=on_group_change, inputs=[plugin_group_sel], outputs=[*[plugin['Button'] for name, plugin in plugins_as_btn.items()], dropdown])
if ENABLE_AUDIO:
from crazy_functions.live_audio.audio_io import RealtimeAudioDistribution
rad = RealtimeAudioDistribution()
def deal_audio(audio, cookies):
rad.feed(cookies['uuid'].hex, audio)
audio_mic.stream(deal_audio, inputs=[audio_mic, cookies])
demo.load(init_cookie, inputs=[cookies], outputs=[cookies])
demo.load(persistent_cookie_reload, inputs = [py_pickle_cookie, cookies],
outputs = [py_pickle_cookie, cookies, *customize_btns.values(), *predefined_btns.values()], _js=js_code_for_persistent_cookie_init)
demo.load(None, inputs=[dark_mode], outputs=None, _js="""(dark_mode)=>{apply_cookie_for_checkbox(dark_mode);}""") # 配置暗色主题或亮色主题
demo.load(None, inputs=[gr.Textbox(LAYOUT, visible=False)], outputs=None, _js='(LAYOUT)=>{GptAcademicJavaScriptInit(LAYOUT);}')
# gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
def run_delayed_tasks():
import threading, webbrowser, time
print(f"如果浏览器没有自动打开,请复制并转到以下URL")
if DARK_MODE: print(f"\t「暗色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
else: print(f"\t「亮色主题已启用(支持动态切换主题)」: http://localhost:{PORT}")
def auto_updates(): time.sleep(0); auto_update()
def open_browser(): time.sleep(2); webbrowser.open_new_tab(f"http://localhost:{PORT}")
def warm_up_mods(): time.sleep(6); warm_up_modules()
threading.Thread(target=auto_updates, name="self-upgrade", daemon=True).start() # 查看自动更新
threading.Thread(target=open_browser, name="open-browser", daemon=True).start() # 打开浏览器页面
threading.Thread(target=warm_up_mods, name="warm-up", daemon=True).start() # 预热tiktoken模块
run_delayed_tasks()
demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", share=False, favicon_path="docs/logo.png", blocked_paths=["config.py","config_private.py","docker-compose.yml","Dockerfile"])
# 如果需要在二级路径下运行
# CUSTOM_PATH = get_conf('CUSTOM_PATH')
# if CUSTOM_PATH != "/":
# from toolbox import run_gradio_in_subpath
# run_gradio_in_subpath(demo, auth=AUTHENTICATION, port=PORT, custom_path=CUSTOM_PATH)
# else:
# demo.launch(server_name="0.0.0.0", server_port=PORT, auth=AUTHENTICATION, favicon_path="docs/logo.png",
# blocked_paths=["config.py","config_private.py","docker-compose.yml","Dockerfile",f"{PATH_LOGGING}/admin"])
if __name__ == "__main__":
main()

查看文件

@@ -1,77 +1,37 @@
from loguru import logger
def check_proxy(proxies, return_ip=False):
"""
检查代理配置并返回结果。
Args:
proxies (dict): 包含http和https代理配置的字典。
return_ip (bool, optional): 是否返回代理的IP地址。默认为False。
Returns:
str or None: 检查的结果信息或代理的IP地址如果`return_ip`为True
"""
def check_proxy(proxies):
import requests
proxies_https = proxies['https'] if proxies is not None else ''
ip = None
try:
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4) # ⭐ 执行GET请求以获取代理信息
response = requests.get("https://ipapi.co/json/", proxies=proxies, timeout=4)
data = response.json()
if 'country_name' in data:
country = data['country_name']
result = f"代理配置 {proxies_https}, 代理所在地:{country}"
if 'ip' in data:
ip = data['ip']
elif 'error' in data:
alternative, ip = _check_with_backup_source(proxies) # ⭐ 调用备用方法检查代理配置
alternative = _check_with_backup_source(proxies)
if alternative is None:
result = f"代理配置 {proxies_https}, 代理所在地未知,IP查询频率受限"
else:
result = f"代理配置 {proxies_https}, 代理所在地:{alternative}"
else:
result = f"代理配置 {proxies_https}, 代理数据解析失败:{data}"
if not return_ip:
logger.warning(result)
return result
else:
return ip
print(result)
return result
except:
result = f"代理配置 {proxies_https}, 代理所在地查询超时,代理可能无效"
if not return_ip:
logger.warning(result)
return result
else:
return ip
print(result)
return result
def _check_with_backup_source(proxies):
"""
通过备份源检查代理,并获取相应信息。
Args:
proxies (dict): 包含代理信息的字典。
Returns:
tuple: 代理信息(geo)和IP地址(ip)的元组。
"""
import random, string, requests
random_string = ''.join(random.choices(string.ascii_letters + string.digits, k=32))
try:
res_json = requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json() # ⭐ 执行代理检查和备份源请求
return res_json['dns']['geo'], res_json['dns']['ip']
except:
return None, None
try: return requests.get(f"http://{random_string}.edns.ip-api.com/json", proxies=proxies, timeout=4).json()['dns']['geo']
except: return None
def backup_and_download(current_version, remote_version):
"""
一键更新协议:备份当前版本,下载远程版本并解压缩。
Args:
current_version (str): 当前版本号。
remote_version (str): 远程版本号。
Returns:
str: 新版本目录的路径。
一键更新协议:备份和下载
"""
from toolbox import get_conf
import shutil
@@ -87,8 +47,8 @@ def backup_and_download(current_version, remote_version):
shutil.copytree('./', backup_dir, ignore=lambda x, y: ['history'])
proxies = get_conf('proxies')
try: r = requests.get('https://github.com/binary-husky/chatgpt_academic/archive/refs/heads/master.zip', proxies=proxies, stream=True)
except: r = requests.get('https://public.agent-matrix.com/publish/master.zip', proxies=proxies, stream=True)
zip_file_path = backup_dir+'/master.zip' # ⭐ 保存备份文件的路径
except: r = requests.get('https://public.gpt-academic.top/publish/master.zip', proxies=proxies, stream=True)
zip_file_path = backup_dir+'/master.zip'
with open(zip_file_path, 'wb+') as f:
f.write(r.content)
dst_path = new_version_dir
@@ -104,17 +64,6 @@ def backup_and_download(current_version, remote_version):
def patch_and_restart(path):
"""
一键更新协议:覆盖和重启
Args:
path (str): 新版本代码所在的路径
注意事项:
如果您的程序没有使用config_private.py私密配置文件,则会将config.py重命名为config_private.py以避免配置丢失。
更新流程:
- 复制最新版本代码到当前目录
- 更新pip包依赖
- 如果更新失败,则提示手动安装依赖库并重启
"""
from distutils import dir_util
import shutil
@@ -122,44 +71,33 @@ def patch_and_restart(path):
import sys
import time
import glob
from shared_utils.colorful import log亮黄, log亮绿, log亮红
from colorful import print亮黄, print亮绿, print亮红
# if not using config_private, move origin config.py as config_private.py
if not os.path.exists('config_private.py'):
log亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
print亮黄('由于您没有设置config_private.py私密配置,现将您的现有配置移动至config_private.py以防止配置丢失,',
'另外您可以随时在history子文件夹下找回旧版的程序。')
shutil.copyfile('config.py', 'config_private.py')
path_new_version = glob.glob(path + '/*-master')[0]
dir_util.copy_tree(path_new_version, './') # ⭐ 将最新版本代码复制到当前目录
log亮绿('代码已经更新,即将更新pip包依赖……')
for i in reversed(range(5)): time.sleep(1); log亮绿(i)
try:
dir_util.copy_tree(path_new_version, './')
print亮绿('代码已经更新,即将更新pip包依赖……')
for i in reversed(range(5)): time.sleep(1); print(i)
try:
import subprocess
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'])
except:
log亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
log亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
log亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
log亮绿(' ------------------------------ -----------------------------------')
for i in reversed(range(8)): time.sleep(1); log亮绿(i)
os.execl(sys.executable, sys.executable, *sys.argv) # 重启程序
print亮红('pip包依赖安装出现问题,需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
print亮绿('更新完成,您可以随时在history子文件夹下找回旧版的程序,5s之后重启')
print亮红('假如重启失败,您可能需要手动安装新增的依赖库 `python -m pip install -r requirements.txt`,然后在用常规的`python main.py`的方式启动。')
print(' ------------------------------ -----------------------------------')
for i in reversed(range(8)): time.sleep(1); print(i)
os.execl(sys.executable, sys.executable, *sys.argv)
def get_current_version():
"""
获取当前的版本号。
Returns:
str: 当前的版本号。如果无法获取版本号,则返回空字符串。
"""
import json
try:
with open('./version', 'r', encoding='utf8') as f:
current_version = json.loads(f.read())['version'] # ⭐ 从读取的json数据中提取版本号
current_version = json.loads(f.read())['version']
except:
current_version = ""
return current_version
@@ -168,12 +106,6 @@ def get_current_version():
def auto_update(raise_error=False):
"""
一键更新协议:查询版本和用户意见
Args:
raise_error (bool, optional): 是否在出错时抛出错误。默认为 False。
Returns:
None
"""
try:
from toolbox import get_conf
@@ -181,7 +113,7 @@ def auto_update(raise_error=False):
import json
proxies = get_conf('proxies')
try: response = requests.get("https://raw.githubusercontent.com/binary-husky/chatgpt_academic/master/version", proxies=proxies, timeout=5)
except: response = requests.get("https://public.agent-matrix.com/publish/version", proxies=proxies, timeout=5)
except: response = requests.get("https://public.gpt-academic.top/publish/version", proxies=proxies, timeout=5)
remote_json_data = json.loads(response.text)
remote_version = remote_json_data['version']
if remote_json_data["show_feature"]:
@@ -192,22 +124,22 @@ def auto_update(raise_error=False):
current_version = f.read()
current_version = json.loads(current_version)['version']
if (remote_version - current_version) >= 0.01-1e-5:
from shared_utils.colorful import log亮黄
log亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}{new_feature}') # ⭐ 在控制台打印新版本信息
logger.info('1Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
from colorful import print亮黄
print亮黄(f'\n新版本可用。新版本:{remote_version},当前版本:{current_version}{new_feature}')
print('1Github更新地址:\nhttps://github.com/binary-husky/chatgpt_academic\n')
user_instruction = input('2是否一键更新代码Y+回车=确认,输入其他/无输入+回车=不更新)?')
if user_instruction in ['Y', 'y']:
path = backup_and_download(current_version, remote_version) # ⭐ 备份并下载文件
path = backup_and_download(current_version, remote_version)
try:
patch_and_restart(path) # ⭐ 执行覆盖并重启操作
patch_and_restart(path)
except:
msg = '更新失败。'
if raise_error:
from toolbox import trimmed_format_exc
msg += trimmed_format_exc()
logger.warning(msg)
print(msg)
else:
logger.info('自动更新程序:已禁用')
print('自动更新程序:已禁用')
return
else:
return
@@ -216,13 +148,10 @@ def auto_update(raise_error=False):
if raise_error:
from toolbox import trimmed_format_exc
msg += trimmed_format_exc()
logger.info(msg)
print(msg)
def warm_up_modules():
"""
预热模块,加载特定模块并执行预热操作。
"""
logger.info('正在执行一些模块的预热 ...')
print('正在执行一些模块的预热 ...')
from toolbox import ProxyNetworkActivate
from request_llms.bridge_all import model_info
with ProxyNetworkActivate("Warmup_Modules"):
@@ -230,28 +159,18 @@ def warm_up_modules():
enc.encode("模块预热", disallowed_special=())
enc = model_info["gpt-4"]['tokenizer']
enc.encode("模块预热", disallowed_special=())
def warm_up_vectordb():
"""
执行一些模块的预热操作。
本函数主要用于执行一些模块的预热操作,确保在后续的流程中能够顺利运行。
⭐ 关键作用:预热模块
Returns:
None
"""
logger.info('正在执行一些模块的预热 ...')
print('正在执行一些模块的预热 ...')
from toolbox import ProxyNetworkActivate
with ProxyNetworkActivate("Warmup_Modules"):
import nltk
with ProxyNetworkActivate("Warmup_Modules"): nltk.download("punkt")
if __name__ == '__main__':
import os
os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
from toolbox import get_conf
proxies = get_conf('proxies')
check_proxy(proxies)
check_proxy(proxies)

查看文件

@@ -1,10 +1,9 @@
import platform
from sys import stdout
from loguru import logger
if platform.system()=="Linux":
pass
else:
else:
from colorama import init
init()
@@ -60,29 +59,3 @@ def sprint亮紫(*kw):
return "\033[1;35m"+' '.join(kw)+"\033[0m"
def sprint亮靛(*kw):
return "\033[1;36m"+' '.join(kw)+"\033[0m"
def log红(*kw,**kargs):
logger.opt(depth=1).info(sprint红(*kw))
def log绿(*kw,**kargs):
logger.opt(depth=1).info(sprint绿(*kw))
def log黄(*kw,**kargs):
logger.opt(depth=1).info(sprint黄(*kw))
def log蓝(*kw,**kargs):
logger.opt(depth=1).info(sprint蓝(*kw))
def log紫(*kw,**kargs):
logger.opt(depth=1).info(sprint紫(*kw))
def log靛(*kw,**kargs):
logger.opt(depth=1).info(sprint靛(*kw))
def log亮红(*kw,**kargs):
logger.opt(depth=1).info(sprint亮红(*kw))
def log亮绿(*kw,**kargs):
logger.opt(depth=1).info(sprint亮绿(*kw))
def log亮黄(*kw,**kargs):
logger.opt(depth=1).info(sprint亮黄(*kw))
def log亮蓝(*kw,**kargs):
logger.opt(depth=1).info(sprint亮蓝(*kw))
def log亮紫(*kw,**kargs):
logger.opt(depth=1).info(sprint亮紫(*kw))
def log亮靛(*kw,**kargs):
logger.opt(depth=1).info(sprint亮靛(*kw))

157
config.py
查看文件

@@ -11,6 +11,10 @@
API_KEY = "此处填API密钥" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
# [step 1]>> API_KEY = "sk-123456789xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx123456789"。极少数情况下,还需要填写组织格式如org-123456789abcdefghijklmno的,请向下翻,找 API_ORG 设置项
API_KEY = "此处填API密钥" # 可同时填写多个API-KEY,用英文逗号分割,例如API_KEY = "sk-openaikey1,sk-openaikey2,fkxxxx-api2dkey3,azure-apikey4"
# [step 2]>> 改为True应用代理,如果直接在海外服务器部署,此处不修改;如果使用本地或无地域限制的大模型时,此处也不需要修改
USE_PROXY = False
if USE_PROXY:
@@ -30,44 +34,11 @@ if USE_PROXY:
else:
proxies = None
# [step 3]>> 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
"gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-4v", "glm-3-turbo",
"gemini-1.5-pro", "chatglm3"
]
EMBEDDING_MODEL = "text-embedding-3-small"
# --- --- --- ---
# P.S. 其他可用的模型还包括
# AVAIL_LLM_MODELS = [
# "glm-4-0520", "glm-4-air", "glm-4-airx", "glm-4-flash",
# "qianfan", "deepseekcoder",
# "spark", "sparkv2", "sparkv3", "sparkv3.5", "sparkv4",
# "qwen-turbo", "qwen-plus", "qwen-max", "qwen-local",
# "moonshot-v1-128k", "moonshot-v1-32k", "moonshot-v1-8k",
# "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0125", "gpt-4o-2024-05-13"
# "claude-3-haiku-20240307","claude-3-sonnet-20240229","claude-3-opus-20240229", "claude-2.1", "claude-instant-1.2",
# "moss", "llama2", "chatglm_onnx", "internlm", "jittorllms_pangualpha", "jittorllms_llama",
# "deepseek-chat" ,"deepseek-coder",
# "gemini-1.5-flash",
# "yi-34b-chat-0205","yi-34b-chat-200k","yi-large","yi-medium","yi-spark","yi-large-turbo","yi-large-preview",
# ]
# --- --- --- ---
# 此外,您还可以在接入one-api/vllm/ollama/Openroute时,
# 使用"one-api-*","vllm-*","ollama-*","openrouter-*"前缀直接使用非标准方式接入的模型,例如
# AVAIL_LLM_MODELS = ["one-api-claude-3-sonnet-20240229(max_token=100000)", "ollama-phi3(max_token=4096)","openrouter-openai/gpt-4o-mini","openrouter-openai/chatgpt-4o-latest"]
# --- --- --- ---
# --------------- 以下配置可以优化体验 ---------------
# ------------------------------------ 以下配置可以优化体验, 但大部分场合下并不需要修改 ------------------------------------
# 重新URL重新定向,实现更换API_URL的作用高危设置! 常规情况下不要修改! 通过修改此设置,您将把您的API-KEY和对话隐私完全暴露给您设定的中间人
# 格式: API_URL_REDIRECT = {"https://api.openai.com/v1/chat/completions": "在这里填写重定向的api.openai.com的URL"}
# 举例: API_URL_REDIRECT = {"https://api.openai.com/v1/chat/completions": "https://reverse-proxy-url/v1/chat/completions", "http://localhost:11434/api/chat": "在这里填写您ollama的URL"}
# 举例: API_URL_REDIRECT = {"https://api.openai.com/v1/chat/completions": "https://reverse-proxy-url/v1/chat/completions"}
API_URL_REDIRECT = {}
@@ -78,7 +49,7 @@ DEFAULT_WORKER_NUM = 3
# 色彩主题, 可选 ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast"]
# 更多主题, 请查阅Gradio主题商店: https://huggingface.co/spaces/gradio/theme-gallery 可选 ["Gstaff/Xkcd", "NoCrypt/Miku", ...]
THEME = "Default"
THEME = "Chuanhu-Small-and-Beautiful"
AVAIL_THEMES = ["Default", "Chuanhu-Small-and-Beautiful", "High-Contrast", "Gstaff/Xkcd", "NoCrypt/Miku"]
@@ -99,7 +70,7 @@ LAYOUT = "LEFT-RIGHT" # "LEFT-RIGHT"(左右布局) # "TOP-DOWN"(上下
# 暗色模式 / 亮色模式
DARK_MODE = True
DARK_MODE = False
# 发送请求到OpenAI后,等待多久判定为超时
@@ -110,18 +81,31 @@ TIMEOUT_SECONDS = 30
WEB_PORT = -1
# 是否自动打开浏览器页面
AUTO_OPEN_BROWSER = True
# 如果OpenAI不响应网络卡顿、代理失败、KEY失效,重试的次数限制
MAX_RETRY = 2
# OpenAI模型选择是gpt4现在只对申请成功的人开放
LLM_MODEL = "gpt-3.5-turbo" # 可选 "chatglm"
AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "gpt-4", "api2d-gpt-4", "api2d-gpt-3.5-turbo", "spark", "azure-gpt-3.5"]
# 插件分类默认选项
DEFAULT_FN_GROUPS = ['对话', '编程', '学术', '智能体']
# 模型选择是 (注意: LLM_MODEL是默认选中的模型, 它*必须*被包含在AVAIL_LLM_MODELS列表中 )
LLM_MODEL = "gpt-3.5-turbo-16k" # 可选 ↓↓↓
AVAIL_LLM_MODELS = ["gpt-4-1106-preview", "gpt-4-turbo-preview", "gpt-4-vision-preview",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
"gpt-4", "gpt-4-32k", "azure-gpt-4", "glm-4", "glm-3-turbo",
"gemini-pro", "chatglm3", "claude-2"]
# P.S. 其他可用的模型还包括 [
# "moss", "qwen-turbo", "qwen-plus", "qwen-max"
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
# ]
# 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
MULTI_QUERY_LLM_MODELS = "gpt-3.5-turbo&chatglm3"
@@ -139,7 +123,7 @@ DASHSCOPE_API_KEY = "" # 阿里灵积云API_KEY
# 百度千帆LLM_MODEL="qianfan"
BAIDU_CLOUD_API_KEY = ''
BAIDU_CLOUD_SECRET_KEY = ''
BAIDU_CLOUD_QIANFAN_MODEL = 'ERNIE-Bot' # 可选 "ERNIE-Bot-4"(文心大模型4.0), "ERNIE-Bot"(文心一言), "ERNIE-Bot-turbo", "BLOOMZ-7B", "Llama-2-70B-Chat", "Llama-2-13B-Chat", "Llama-2-7B-Chat", "ERNIE-Speed-128K", "ERNIE-Speed-8K", "ERNIE-Lite-8K"
BAIDU_CLOUD_QIANFAN_MODEL = 'ERNIE-Bot' # 可选 "ERNIE-Bot-4"(文心大模型4.0), "ERNIE-Bot"(文心一言), "ERNIE-Bot-turbo", "BLOOMZ-7B", "Llama-2-70B-Chat", "Llama-2-13B-Chat", "Llama-2-7B-Chat"
# 如果使用ChatGLM2微调模型,请把 LLM_MODEL="chatglmft",并在此处指定模型路径
@@ -150,7 +134,6 @@ CHATGLM_PTUNING_CHECKPOINT = "" # 例如"/home/hmp/ChatGLM2-6B/ptuning/output/6b
LOCAL_MODEL_DEVICE = "cpu" # 可选 "cuda"
LOCAL_MODEL_QUANT = "FP16" # 默认 "FP16" "INT4" 启用量化INT4版本 "INT8" 启用量化INT8版本
# 设置gradio的并行线程数不需要修改
CONCURRENT_COUNT = 100
@@ -160,7 +143,7 @@ AUTO_CLEAR_TXT = False
# 加一个live2d装饰
ADD_WAIFU = False
ADD_WAIFU = True
# 设置用户名和密码不需要修改相关功能不稳定,与gradio版本和网络都相关,如果本地使用不建议加这个
@@ -168,8 +151,7 @@ ADD_WAIFU = False
AUTHENTICATION = []
# 如果需要在二级路径下运行(常规情况下,不要修改!!
# (举例 CUSTOM_PATH = "/gpt_academic",可以让软件运行在 http://ip:port/gpt_academic/ 下。)
# 如果需要在二级路径下运行(常规情况下,不要修改!!需要配合修改main.py才能生效!
CUSTOM_PATH = "/"
@@ -197,8 +179,14 @@ AZURE_ENGINE = "填入你亲手写的部署名" # 读 docs\use_azure.
AZURE_CFG_ARRAY = {}
# 阿里云实时语音识别 配置难度较高
# 参考 https://github.com/binary-husky/gpt_academic/blob/master/docs/use_audio.md
# 使用Newbing (不推荐使用,未来将删除)
NEWBING_STYLE = "creative" # ["creative", "balanced", "precise"]
NEWBING_COOKIES = """
put your new bing cookies here
"""
# 阿里云实时语音识别 配置难度较高 仅建议高手用户使用 参考 https://github.com/binary-husky/gpt_academic/blob/master/docs/use_audio.md
ENABLE_AUDIO = False
ALIYUN_TOKEN="" # 例如 f37f30e0f9934c34a992f6f64f7eba4f
ALIYUN_APPKEY="" # 例如 RoPlZrM88DnAFkZK
@@ -206,12 +194,6 @@ ALIYUN_ACCESSKEY="" # (无需填写)
ALIYUN_SECRET="" # (无需填写)
# GPT-SOVITS 文本转语音服务的运行地址(将语言模型的生成文本朗读出来)
TTS_TYPE = "EDGE_TTS" # EDGE_TTS / LOCAL_SOVITS_API / DISABLE
GPT_SOVITS_URL = ""
EDGE_TTS_VOICE = "zh-CN-XiaoxiaoNeural"
# 接入讯飞星火大模型 https://console.xfyun.cn/services/iat
XFYUN_APPID = "00000000"
XFYUN_API_SECRET = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
@@ -223,35 +205,21 @@ ZHIPUAI_API_KEY = ""
ZHIPUAI_MODEL = "" # 此选项已废弃,不再需要填写
# # 火山引擎YUNQUE大模型
# YUNQUE_SECRET_KEY = ""
# YUNQUE_ACCESS_KEY = ""
# YUNQUE_MODEL = ""
# Claude API KEY
ANTHROPIC_API_KEY = ""
# 月之暗面 API KEY
MOONSHOT_API_KEY = ""
# 零一万物(Yi Model) API KEY
YIMODEL_API_KEY = ""
# 深度求索(DeepSeek) API KEY,默认请求地址为"https://api.deepseek.com/v1/chat/completions"
DEEPSEEK_API_KEY = ""
# 紫东太初大模型 https://ai-maas.wair.ac.cn
TAICHU_API_KEY = ""
# Mathpix 拥有执行PDF的OCR功能,但是需要注册账号
MATHPIX_APPID = ""
MATHPIX_APPKEY = ""
# DOC2X的PDF解析服务,注册账号并获取API KEY: https://doc2x.noedgeai.com/login
DOC2X_API_KEY = ""
# 自定义API KEY格式
CUSTOM_API_KEY_PATTERN = ""
@@ -261,7 +229,7 @@ GEMINI_API_KEY = ''
# HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
HUGGINGFACE_ACCESS_TOKEN = "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"
HUGGINGFACE_ACCESS_TOKEN = ""
# GROBID服务器地址填写多个可以均衡负载,用于高质量地读取PDF文档
@@ -273,10 +241,6 @@ GROBID_URLS = [
]
# Searxng互联网检索服务
SEARXNG_URL = "https://cloud-1.agent-matrix.com/"
# 是否允许通过自然语言描述修改本页的配置,该功能具有一定的危险性,默认关闭
ALLOW_RESET_CONFIG = False
@@ -285,21 +249,21 @@ ALLOW_RESET_CONFIG = False
AUTOGEN_USE_DOCKER = False
# 临时的上传文件夹位置,请尽量不要修改
# 临时的上传文件夹位置,请修改
PATH_PRIVATE_UPLOAD = "private_upload"
# 日志文件夹的位置,请尽量不要修改
# 日志文件夹的位置,请修改
PATH_LOGGING = "gpt_log"
# 存储翻译好的arxiv论文的路径,请尽量不要修改
ARXIV_CACHE_DIR = "gpt_log/arxiv_cache"
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请尽量不要修改
# 除了连接OpenAI之外,还有哪些场合允许使用代理,请勿修改
WHEN_TO_USE_PROXY = ["Download_LLM", "Download_Gradio_Theme", "Connect_Grobid",
"Warmup_Modules", "Nougat_Download", "AutoGen", "Connect_OpenAI_Embedding"]
"Warmup_Modules", "Nougat_Download", "AutoGen"]
# *实验性功能*: 自动检测并屏蔽失效的KEY,请勿使用
BLOCK_INVALID_APIKEY = False
# 启用插件热加载
@@ -309,11 +273,7 @@ PLUGIN_HOT_RELOAD = False
# 自定义按钮的最大数量限制
NUM_CUSTOM_BASIC_BTN = 4
"""
--------------- 配置关联关系说明 ---------------
在线大模型配置关联关系示意图
├── "gpt-3.5-turbo" 等openai模型
@@ -337,7 +297,7 @@ NUM_CUSTOM_BASIC_BTN = 4
│ ├── XFYUN_API_SECRET
│ └── XFYUN_API_KEY
├── "claude-3-opus-20240229" 等claude模型
├── "claude-1-100k" 等claude模型
│ └── ANTHROPIC_API_KEY
├── "stack-claude"
@@ -352,19 +312,15 @@ NUM_CUSTOM_BASIC_BTN = 4
├── "glm-4", "glm-3-turbo", "zhipuai" 智谱AI大模型
│ └── ZHIPUAI_API_KEY
├── "yi-34b-chat-0205", "yi-34b-chat-200k" 等零一万物(Yi Model)大模型
│ └── YIMODEL_API_KEY
├── "qwen-turbo" 等通义千问大模型
│ └── DASHSCOPE_API_KEY
├── "Gemini"
│ └── GEMINI_API_KEY
└── "one-api-...(max_token=...)" 用一种更方便的方式接入one-api多模型管理界面
├── AVAIL_LLM_MODELS
── API_KEY
└── API_URL_REDIRECT
└── "newbing" Newbing接口不再稳定,不推荐使用
├── NEWBING_STYLE
── NEWBING_COOKIES
本地大模型示意图
@@ -398,9 +354,6 @@ NUM_CUSTOM_BASIC_BTN = 4
插件在线服务配置依赖关系示意图
├── 互联网检索
│ └── SEARXNG_URL
├── 语音功能
│ ├── ENABLE_AUDIO
│ ├── ALIYUN_TOKEN

查看文件

@@ -17,7 +17,7 @@ def get_core_functions():
text_show_english=
r"Below is a paragraph from an academic paper. Polish the writing to meet the academic style, "
r"improve the spelling, grammar, clarity, concision and overall readability. When necessary, rewrite the whole sentence. "
r"Firstly, you should provide the polished paragraph (in English). "
r"Firstly, you should provide the polished paragraph. "
r"Secondly, you should list all your modification and explain the reasons to do so in markdown table.",
text_show_chinese=
r"作为一名中文学术论文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,"
@@ -33,19 +33,17 @@ def get_core_functions():
"AutoClearHistory": False,
# [6] 文本预处理 (可选参数,默认 None,举例写个函数移除所有的换行符
"PreProcess": None,
# [7] 模型选择 (可选参数。如不设置,则使用当前全局模型;如设置,则用指定模型覆盖全局模型。)
# "ModelOverride": "gpt-3.5-turbo", # 主要用途:强制点击此基础功能按钮时,使用指定的模型。
},
"总结绘制脑图": {
# 前缀,会被加在你的输入之前。例如,用来描述你的要求,例如翻译、解释代码、润色等等
"Prefix": '''"""\n\n''',
"Prefix": r"",
# 后缀,会被加在你的输入之后。例如,配合前缀可以把你的输入内容用引号圈起来
"Suffix":
# dedent() 函数用于去除多行字符串的缩进
dedent("\n\n"+r'''
"""
dedent("\n"+r'''
==============================
使用mermaid flowchart对以上文本进行总结,概括上述段落的内容以及内在逻辑关系,例如
@@ -59,15 +57,15 @@ def get_core_functions():
C --> |"箭头名2"| F["节点名6"]
```
注意
警告
1使用中文
2节点名字使用引号包裹,如["Laptop"]
3`|` 和 `"`之间不要存在空格
4根据情况选择flowchart LR从左到右或者flowchart TD从上到下
'''),
},
"查找语法错误": {
"Prefix": r"Help me ensure that the grammar and the spelling is correct. "
r"Do not try to polish the text, if no mistake is found, tell me that this paragraph is good. "
@@ -87,14 +85,14 @@ def get_core_functions():
"Suffix": r"",
"PreProcess": clear_line_break, # 预处理:清除换行符
},
"中译英": {
"Prefix": r"Please translate following sentence to English:" + "\n\n",
"Suffix": r"",
},
"学术英中互译": {
"Prefix": build_gpt_academic_masked_string_langbased(
text_show_chinese=
@@ -114,29 +112,29 @@ def get_core_functions():
) + "\n\n",
"Suffix": r"",
},
"英译中": {
"Prefix": r"翻译成地道的中文:" + "\n\n",
"Suffix": r"",
"Visible": False,
},
"找图片": {
"Prefix": r"我需要你找一张网络图片。使用Unsplash API(https://source.unsplash.com/960x640/?<英语关键词>)获取图片URL,"
r"然后请使用Markdown格式封装,并且不要有反斜线,不要用代码块。现在,请按以下描述给我发送图片" + "\n\n",
"Suffix": r"",
"Visible": False,
},
"解释代码": {
"Prefix": r"请解释以下代码:" + "\n```\n",
"Suffix": "\n```\n",
},
"参考文献转Bib": {
"Prefix": r"Here are some bibliography items, please transform them into bibtex style."
r"Note that, reference styles maybe more than one kind, you should transform each item correctly."

查看文件

@@ -1,62 +1,46 @@
from toolbox import HotReload # HotReload 的意思是热更新,修改函数插件后,不需要重启程序,代码直接生效
from toolbox import trimmed_format_exc
from loguru import logger
def get_crazy_functions():
from crazy_functions.读文章写摘要 import 读文章写摘要
from crazy_functions.生成函数注释 import 批量生成函数注释
from crazy_functions.SourceCode_Analyse import 解析项目本身
from crazy_functions.SourceCode_Analyse import 解析一个Python项目
from crazy_functions.SourceCode_Analyse import 解析一个Matlab项目
from crazy_functions.SourceCode_Analyse import 解析一个C项目的头文件
from crazy_functions.SourceCode_Analyse import 解析一个C项目
from crazy_functions.SourceCode_Analyse import 解析一个Golang项目
from crazy_functions.SourceCode_Analyse import 解析一个Rust项目
from crazy_functions.SourceCode_Analyse import 解析一个Java项目
from crazy_functions.SourceCode_Analyse import 解析一个前端项目
from crazy_functions.解析项目源代码 import 解析项目本身
from crazy_functions.解析项目源代码 import 解析一个Python项目
from crazy_functions.解析项目源代码 import 解析一个Matlab项目
from crazy_functions.解析项目源代码 import 解析一个C项目的头文件
from crazy_functions.解析项目源代码 import 解析一个C项目
from crazy_functions.解析项目源代码 import 解析一个Golang项目
from crazy_functions.解析项目源代码 import 解析一个Rust项目
from crazy_functions.解析项目源代码 import 解析一个Java项目
from crazy_functions.解析项目源代码 import 解析一个前端项目
from crazy_functions.高级功能函数模板 import 高阶功能模板函数
from crazy_functions.高级功能函数模板 import Demo_Wrap
from crazy_functions.Latex全文润色 import Latex英文润色
from crazy_functions.询问多个大语言模型 import 同时问询
from crazy_functions.SourceCode_Analyse import 解析一个Lua项目
from crazy_functions.SourceCode_Analyse import 解析一个CSharp项目
from crazy_functions.解析项目源代码 import 解析一个Lua项目
from crazy_functions.解析项目源代码 import 解析一个CSharp项目
from crazy_functions.总结word文档 import 总结word文档
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
from crazy_functions.Conversation_To_File import 载入对话历史存档
from crazy_functions.Conversation_To_File import 对话历史存档
from crazy_functions.Conversation_To_File import Conversation_To_File_Wrap
from crazy_functions.Conversation_To_File import 删除所有本地对话历史记录
from crazy_functions.对话历史存档 import 对话历史存档
from crazy_functions.对话历史存档 import 载入对话历史存档
from crazy_functions.对话历史存档 import 删除所有本地对话历史记录
from crazy_functions.辅助功能 import 清除缓存
from crazy_functions.批量文件询问 import 批量文件询问
from crazy_functions.Markdown_Translate import Markdown英译中
from crazy_functions.批量Markdown翻译 import Markdown英译中
from crazy_functions.批量总结PDF文档 import 批量总结PDF文档
from crazy_functions.PDF_Translate import 批量翻译PDF文档
from crazy_functions.批量翻译PDF文档_多线程 import 批量翻译PDF文档
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
from crazy_functions.理解PDF文档内容 import 理解PDF文档内容标准文件输入
from crazy_functions.Latex全文润色 import Latex中文润色
from crazy_functions.Latex全文润色 import Latex英文纠错
from crazy_functions.Markdown_Translate import Markdown中译英
from crazy_functions.批量Markdown翻译 import Markdown中译英
from crazy_functions.虚空终端 import 虚空终端
from crazy_functions.生成多种Mermaid图表 import Mermaid_Gen
from crazy_functions.PDF_Translate_Wrap import PDF_Tran
from crazy_functions.Latex_Function import Latex英文纠错加PDF对比
from crazy_functions.Latex_Function import Latex翻译中文并重新编译PDF
from crazy_functions.Latex_Function import PDF翻译中文并重新编译PDF
from crazy_functions.Latex_Function_Wrap import Arxiv_Localize
from crazy_functions.Latex_Function_Wrap import PDF_Localize
from crazy_functions.Internet_GPT import 连接网络回答问题
from crazy_functions.Internet_GPT_Wrap import NetworkGPT_Wrap
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
from crazy_functions.Image_Generate_Wrap import ImageGen_Wrap
from crazy_functions.SourceCode_Comment import 注释Python项目
from crazy_functions.SourceCode_Comment_Wrap import SourceCodeComment_Wrap
from crazy_functions.生成多种Mermaid图表 import 生成多种Mermaid图表
function_plugins = {
"虚空终端": {
"Group": "对话|编程|学术|智能体",
"Color": "stop",
"AsButton": True,
"Info": "使用自然语言实现您的想法",
"Function": HotReload(虚空终端),
},
"解析整个Python项目": {
@@ -66,14 +50,6 @@ def get_crazy_functions():
"Info": "解析一个Python项目的所有源文件(.py) | 输入参数为路径",
"Function": HotReload(解析一个Python项目),
},
"注释Python项目": {
"Group": "编程",
"Color": "stop",
"AsButton": False,
"Info": "上传一系列python源文件(或者压缩包), 为这些代码添加docstring | 输入参数为路径",
"Function": HotReload(注释Python项目),
"Class": SourceCodeComment_Wrap,
},
"载入对话历史存档(先上传存档或输入路径)": {
"Group": "对话",
"Color": "stop",
@@ -99,24 +75,16 @@ def get_crazy_functions():
"Color": "stop",
"AsButton": False,
"Info" : "基于当前对话或文件生成多种Mermaid图表,图表类型由模型判断",
"Function": None,
"Class": Mermaid_Gen
"Function": HotReload(生成多种Mermaid图表),
"AdvancedArgs": True,
"ArgsReminder": "请输入图类型对应的数字,不输入则为模型自行判断:1-流程图,2-序列图,3-类图,4-饼图,5-甘特图,6-状态图,7-实体关系图,8-象限提示图,9-思维导图",
},
"Arxiv论文翻译": {
"批量总结Word文档": {
"Group": "学术",
"Color": "stop",
"AsButton": True,
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
"批量文件询问": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"Info": "通过在高级参数区写入prompt,可自定义询问逻辑,默认情况下为总结逻辑 | 输入参数为路径",
"Function": HotReload(批量文件询问),
"Info": "批量总结word文档 | 输入参数为路径",
"Function": HotReload(总结word文档),
},
"解析整个Matlab项目": {
"Group": "编程",
@@ -220,42 +188,28 @@ def get_crazy_functions():
},
"保存当前的对话": {
"Group": "对话",
"Color": "stop",
"AsButton": True,
"Info": "保存当前的对话 | 不需要输入参数",
"Function": HotReload(对话历史存档), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Conversation_To_File_Wrap # 新一代插件需要注册Class
"Function": HotReload(对话历史存档),
},
"[多线程Demo]解析此项目本身(源码自译解)": {
"Group": "对话|编程",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "多线程解析并翻译此项目的源码 | 不需要输入参数",
"Function": HotReload(解析项目本身),
},
"查互联网后回答": {
"Group": "对话",
"Color": "stop",
"AsButton": True, # 加入下拉菜单中
# "Info": "连接网络回答问题(需要访问谷歌)| 输入参数是一个问题",
"Function": HotReload(连接网络回答问题),
"Class": NetworkGPT_Wrap # 新一代插件需要注册Class
},
"历史上的今天": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"AsButton": True,
"Info": "查看历史上的今天事件 (这是一个面向开发者的插件Demo) | 不需要输入参数",
"Function": None,
"Class": Demo_Wrap, # 新一代插件需要注册Class
"Function": HotReload(高阶功能模板函数),
},
"精准翻译PDF论文": {
"Group": "学术",
"Color": "stop",
"AsButton": True,
"Info": "精准翻译PDF论文为中文 | 输入参数为路径",
"Function": HotReload(批量翻译PDF文档), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": PDF_Tran, # 新一代插件需要注册Class
"Function": HotReload(批量翻译PDF文档),
},
"询问多个GPT模型": {
"Group": "对话",
@@ -330,85 +284,8 @@ def get_crazy_functions():
"Info": "批量将Markdown文件中文翻译为英文 | 输入参数为路径或上传压缩包",
"Function": HotReload(Markdown中译英),
},
"Latex英文纠错+高亮修正位置 [需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": "如果有必要, 请在此处追加更细致的矫错指令(使用英文)。",
"Function": HotReload(Latex英文纠错加PDF对比),
},
"📚Arxiv论文精细翻译输入arxivID[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Latex翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": Arxiv_Localize, # 新一代插件需要注册Class
},
"📚本地Latex论文精细翻译上传Latex项目[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "本地Latex论文精细翻译 | 输入参数是路径",
"Function": HotReload(Latex翻译中文并重新编译PDF),
},
"PDF翻译中文并重新编译PDF上传PDF[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "PDF翻译中文,并重新编译PDF | 输入参数为路径",
"Function": HotReload(PDF翻译中文并重新编译PDF), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": PDF_Localize # 新一代插件需要注册Class
}
}
function_plugins.update(
{
"🎨图片生成DALLE2/DALLE3, 使用前切换到GPT系列模型": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info": "使用 DALLE2/DALLE3 生成图片 | 输入参数字符串,提供图像的内容",
"Function": HotReload(图片生成_DALLE2), # 当注册Class后,Function旧接口仅会在“虚空终端”中起作用
"Class": ImageGen_Wrap # 新一代插件需要注册Class
},
}
)
function_plugins.update(
{
"🎨图片修改_DALLE2 使用前请切换模型到GPT系列": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": False, # 调用时,唤起高级参数输入区默认False
# "Info": "使用DALLE2修改图片 | 输入参数字符串,提供图像的内容",
"Function": HotReload(图片修改_DALLE2),
},
}
)
# -=--=- 尚未充分测试的实验性插件 & 需要额外依赖的插件 -=--=-
try:
from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
@@ -425,42 +302,42 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
# try:
# from crazy_functions.联网的ChatGPT import 连接网络回答问题
# function_plugins.update(
# {
# "连接网络回答问题(输入问题后点击该插件,需要访问谷歌)": {
# "Group": "对话",
# "Color": "stop",
# "AsButton": False, # 加入下拉菜单中
# # "Info": "连接网络回答问题(需要访问谷歌)| 输入参数是一个问题",
# "Function": HotReload(连接网络回答问题),
# }
# }
# )
# from crazy_functions.联网的ChatGPT_bing版 import 连接bing搜索回答问题
# function_plugins.update(
# {
# "连接网络回答问题中文Bing版,输入问题后点击该插件": {
# "Group": "对话",
# "Color": "stop",
# "AsButton": False, # 加入下拉菜单中
# "Info": "连接网络回答问题需要访问中文Bing| 输入参数是一个问题",
# "Function": HotReload(连接bing搜索回答问题),
# }
# }
# )
# except:
# logger.error(trimmed_format_exc())
# logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.SourceCode_Analyse import 解析任意code项目
from crazy_functions.联网的ChatGPT import 连接网络回答问题
function_plugins.update(
{
"连接网络回答问题(输入问题后点击该插件,需要访问谷歌)": {
"Group": "对话",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
# "Info": "连接网络回答问题(需要访问谷歌)| 输入参数是一个问题",
"Function": HotReload(连接网络回答问题),
}
}
)
from crazy_functions.联网的ChatGPT_bing版 import 连接bing搜索回答问题
function_plugins.update(
{
"连接网络回答问题中文Bing版,输入问题后点击该插件": {
"Group": "对话",
"Color": "stop",
"AsButton": False, # 加入下拉菜单中
"Info": "连接网络回答问题需要访问中文Bing| 输入参数是一个问题",
"Function": HotReload(连接bing搜索回答问题),
}
}
)
except:
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.解析项目源代码 import 解析任意code项目
function_plugins.update(
{
@@ -475,8 +352,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.询问多个大语言模型 import 同时问询_指定模型
@@ -494,10 +371,53 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.图片生成 import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
function_plugins.update(
{
"图片生成_DALLE2 先切换模型到gpt-*": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True, # 调用时,唤起高级参数输入区默认False
"ArgsReminder": "在这里输入分辨率, 如1024x1024默认,支持 256x256, 512x512, 1024x1024", # 高级参数输入区的显示提示
"Info": "使用DALLE2生成图片 | 输入参数字符串,提供图像的内容",
"Function": HotReload(图片生成_DALLE2),
},
}
)
function_plugins.update(
{
"图片生成_DALLE3 先切换模型到gpt-*": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True, # 调用时,唤起高级参数输入区默认False
"ArgsReminder": "在这里输入自定义参数「分辨率-质量(可选)-风格(可选)」, 参数示例「1024x1024-hd-vivid」 || 分辨率支持 「1024x1024」(默认) /「1792x1024」/「1024x1792」 || 质量支持 「-standard」(默认) /「-hd」 || 风格支持 「-vivid」(默认) /「-natural」", # 高级参数输入区的显示提示
"Info": "使用DALLE3生成图片 | 输入参数字符串,提供图像的内容",
"Function": HotReload(图片生成_DALLE3),
},
}
)
function_plugins.update(
{
"图片修改_DALLE2 先切换模型到gpt-*": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": False, # 调用时,唤起高级参数输入区默认False
# "Info": "使用DALLE2修改图片 | 输入参数字符串,提供图像的内容",
"Function": HotReload(图片修改_DALLE2),
},
}
)
except:
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.总结音视频 import 总结音视频
@@ -516,8 +436,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.数学动画生成manim import 动画生成
@@ -534,11 +454,11 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.Markdown_Translate import Markdown翻译指定语言
from crazy_functions.批量Markdown翻译 import Markdown翻译指定语言
function_plugins.update(
{
@@ -553,8 +473,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.知识库问答 import 知识库文件注入
@@ -572,8 +492,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.知识库问答 import 读取知识库作答
@@ -591,8 +511,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.交互功能函数模板 import 交互功能模板函数
@@ -608,9 +528,62 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.Latex输出PDF import Latex英文纠错加PDF对比
from crazy_functions.Latex输出PDF import Latex翻译中文并重新编译PDF
from crazy_functions.Latex输出PDF import PDF翻译中文并重新编译PDF
function_plugins.update(
{
"Latex英文纠错+高亮修正位置 [需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": "如果有必要, 请在此处追加更细致的矫错指令(使用英文)。",
"Function": HotReload(Latex英文纠错加PDF对比),
},
"Arxiv论文精细翻译输入arxivID[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "Arixv论文精细翻译 | 输入参数arxiv论文的ID,比如1812.10695",
"Function": HotReload(Latex翻译中文并重新编译PDF),
},
"本地Latex论文精细翻译上传Latex项目[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "本地Latex论文精细翻译 | 输入参数是路径",
"Function": HotReload(Latex翻译中文并重新编译PDF),
},
"PDF翻译中文并重新编译PDF上传PDF[需Latex]": {
"Group": "学术",
"Color": "stop",
"AsButton": False,
"AdvancedArgs": True,
"ArgsReminder": r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
"Info": "PDF翻译中文,并重新编译PDF | 输入参数为路径",
"Function": HotReload(PDF翻译中文并重新编译PDF)
}
}
)
except:
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from toolbox import get_conf
@@ -631,8 +604,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.批量翻译PDF文档_NOUGAT import 批量翻译PDF文档
@@ -648,8 +621,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.函数动态生成 import 函数动态生成
@@ -665,8 +638,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.多智能体 import 多智能体终端
@@ -682,8 +655,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
try:
from crazy_functions.互动小游戏 import 随机小游戏
@@ -699,33 +672,8 @@ def get_crazy_functions():
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
try:
from crazy_functions.Rag_Interface import Rag问答
function_plugins.update(
{
"Rag智能召回": {
"Group": "对话",
"Color": "stop",
"AsButton": False,
"Info": "将问答数据记录到向量库中,作为长期参考。",
"Function": HotReload(Rag问答),
},
}
)
except:
logger.error(trimmed_format_exc())
logger.error("Load function plugin failed")
print(trimmed_format_exc())
print("Load function plugin failed")
# try:
# from crazy_functions.高级功能函数模板 import 测试图表渲染
@@ -738,7 +686,7 @@ def get_crazy_functions():
# }
# })
# except:
# logger.error(trimmed_format_exc())
# print(trimmed_format_exc())
# print('Load function plugin failed')
# try:

查看文件

@@ -0,0 +1,232 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Any
from toolbox import CatchException, update_ui, gen_time_str, trimmed_format_exc
from toolbox import promote_file_to_downloadzone, get_log_folder
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from .crazy_utils import input_clipping, try_install_deps
from multiprocessing import Process, Pipe
import os
import time
templete = """
```python
import ... # Put dependencies here, e.g. import numpy as np
class TerminalFunction(object): # Do not change the name of the class, The name of the class must be `TerminalFunction`
def run(self, path): # The name of the function must be `run`, it takes only a positional argument.
# rewrite the function you have just written here
...
return generated_file_path
```
"""
def inspect_dependency(chatbot, history):
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return True
def get_code_block(reply):
import re
pattern = r"```([\s\S]*?)```" # regex pattern to match code blocks
matches = re.findall(pattern, reply) # find all code blocks in text
if len(matches) == 1:
return matches[0].strip('python') # code block
for match in matches:
if 'class TerminalFunction' in match:
return match.strip('python') # code block
raise RuntimeError("GPT is not generating proper code.")
def gpt_interact_multi_step(txt, file_type, llm_kwargs, chatbot, history):
# 输入
prompt_compose = [
f'Your job:\n'
f'1. write a single Python function, which takes a path of a `{file_type}` file as the only argument and returns a `string` containing the result of analysis or the path of generated files. \n',
f"2. You should write this function to perform following task: " + txt + "\n",
f"3. Wrap the output python function with markdown codeblock."
]
i_say = "".join(prompt_compose)
demo = []
# 第一步
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=demo,
sys_prompt= r"You are a programmer."
)
history.extend([i_say, gpt_say])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
# 第二步
prompt_compose = [
"If previous stage is successful, rewrite the function you have just written to satisfy following templete: \n",
templete
]
i_say = "".join(prompt_compose); inputs_show_user = "If previous stage is successful, rewrite the function you have just written to satisfy executable templete. "
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say, inputs_show_user=inputs_show_user,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt= r"You are a programmer."
)
code_to_return = gpt_say
history.extend([i_say, gpt_say])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
# # 第三步
# i_say = "Please list to packages to install to run the code above. Then show me how to use `try_install_deps` function to install them."
# i_say += 'For instance. `try_install_deps(["opencv-python", "scipy", "numpy"])`'
# installation_advance = yield from request_gpt_model_in_new_thread_with_ui_alive(
# inputs=i_say, inputs_show_user=inputs_show_user,
# llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
# sys_prompt= r"You are a programmer."
# )
# # # 第三步
# i_say = "Show me how to use `pip` to install packages to run the code above. "
# i_say += 'For instance. `pip install -r opencv-python scipy numpy`'
# installation_advance = yield from request_gpt_model_in_new_thread_with_ui_alive(
# inputs=i_say, inputs_show_user=i_say,
# llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
# sys_prompt= r"You are a programmer."
# )
installation_advance = ""
return code_to_return, installation_advance, txt, file_type, llm_kwargs, chatbot, history
def make_module(code):
module_file = 'gpt_fn_' + gen_time_str().replace('-','_')
with open(f'{get_log_folder()}/{module_file}.py', 'w', encoding='utf8') as f:
f.write(code)
def get_class_name(class_string):
import re
# Use regex to extract the class name
class_name = re.search(r'class (\w+)\(', class_string).group(1)
return class_name
class_name = get_class_name(code)
return f"{get_log_folder().replace('/', '.')}.{module_file}->{class_name}"
def init_module_instance(module):
import importlib
module_, class_ = module.split('->')
init_f = getattr(importlib.import_module(module_), class_)
return init_f()
def for_immediate_show_off_when_possible(file_type, fp, chatbot):
if file_type in ['png', 'jpg']:
image_path = os.path.abspath(fp)
chatbot.append(['这是一张图片, 展示如下:',
f'本地文件地址: <br/>`{image_path}`<br/>'+
f'本地文件预览: <br/><div align="center"><img src="file={image_path}"></div>'
])
return chatbot
def subprocess_worker(instance, file_path, return_dict):
return_dict['result'] = instance.run(file_path)
def have_any_recent_upload_files(chatbot):
_5min = 5 * 60
if not chatbot: return False # chatbot is None
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
if not most_recent_uploaded: return False # most_recent_uploaded is None
if time.time() - most_recent_uploaded["time"] < _5min: return True # most_recent_uploaded is new
else: return False # most_recent_uploaded is too old
def get_recent_file_prompt_support(chatbot):
most_recent_uploaded = chatbot._cookies.get("most_recent_uploaded", None)
path = most_recent_uploaded['path']
return path
@CatchException
def 虚空终端CodeInterpreter(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
"""
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
llm_kwargs gpt模型参数,如温度和top_p等,一般原样传递下去就行
plugin_kwargs 插件模型的参数,暂时没有用武之地
chatbot 聊天显示框的句柄,用于显示给用户
history 聊天历史,前情提要
system_prompt 给gpt的静默提醒
web_port 当前软件运行的端口号
"""
raise NotImplementedError
# 清空历史,以免输入溢出
history = []; clear_file_downloadzone(chatbot)
# 基本信息:功能、贡献者
chatbot.append([
"函数插件功能?",
"CodeInterpreter开源版, 此插件处于开发阶段, 建议暂时不要使用, 插件初始化中 ..."
])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
if have_any_recent_upload_files(chatbot):
file_path = get_recent_file_prompt_support(chatbot)
else:
chatbot.append(["文件检索", "没有发现任何近期上传的文件。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 读取文件
if ("recently_uploaded_files" in plugin_kwargs) and (plugin_kwargs["recently_uploaded_files"] == ""): plugin_kwargs.pop("recently_uploaded_files")
recently_uploaded_files = plugin_kwargs.get("recently_uploaded_files", None)
file_path = recently_uploaded_files[-1]
file_type = file_path.split('.')[-1]
# 粗心检查
if is_the_upload_folder(txt):
chatbot.append([
"...",
f"请在输入框内填写需求,然后再次点击该插件(文件路径 {file_path} 已经被记忆)"
])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 开始干正事
for j in range(5): # 最多重试5次
try:
code, installation_advance, txt, file_type, llm_kwargs, chatbot, history = \
yield from gpt_interact_multi_step(txt, file_type, llm_kwargs, chatbot, history)
code = get_code_block(code)
res = make_module(code)
instance = init_module_instance(res)
break
except Exception as e:
chatbot.append([f"{j}次代码生成尝试,失败了", f"错误追踪\n```\n{trimmed_format_exc()}\n```\n"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 代码生成结束, 开始执行
try:
import multiprocessing
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(target=subprocess_worker, args=(instance, file_path, return_dict))
# only has 10 seconds to run
p.start(); p.join(timeout=10)
if p.is_alive(): p.terminate(); p.join()
p.close()
res = return_dict['result']
# res = instance.run(file_path)
except Exception as e:
chatbot.append(["执行失败了", f"错误追踪\n```\n{trimmed_format_exc()}\n```\n"])
# chatbot.append(["如果是缺乏依赖,请参考以下建议", installation_advance])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 顺利完成,收尾
res = str(res)
if os.path.exists(res):
chatbot.append(["执行成功了,结果是一个有效文件", "结果:" + res])
new_file_path = promote_file_to_downloadzone(res, chatbot=chatbot)
chatbot = for_immediate_show_off_when_possible(file_type, new_file_path, chatbot)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
else:
chatbot.append(["执行成功了,结果是一个字符串", "结果:" + res])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
"""
测试:
裁剪图像,保留下半部分
交换图像的蓝色通道和红色通道
将图像转为灰度图像
将csv文件转excel表格
"""

查看文件

@@ -1,56 +0,0 @@
from toolbox import get_conf, update_ui
from crazy_functions.Image_Generate import 图片生成_DALLE2, 图片生成_DALLE3, 图片修改_DALLE2
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
class ImageGen_Wrap(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
"""
gui_definition = {
"main_input":
ArgProperty(title="输入图片描述", description="需要生成图像的文本描述,尽量使用英文", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"model_name":
ArgProperty(title="模型", options=["DALLE2", "DALLE3"], default_value="DALLE3", description="", type="dropdown").model_dump_json(),
"resolution":
ArgProperty(title="分辨率", options=["256x256(限DALLE2)", "512x512(限DALLE2)", "1024x1024", "1792x1024(限DALLE3)", "1024x1792(限DALLE3)"], default_value="1024x1024", description="", type="dropdown").model_dump_json(),
"quality (仅DALLE3生效)":
ArgProperty(title="质量", options=["standard", "hd"], default_value="standard", description="", type="dropdown").model_dump_json(),
"style (仅DALLE3生效)":
ArgProperty(title="风格", options=["vivid", "natural"], default_value="vivid", description="", type="dropdown").model_dump_json(),
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
# 分辨率
resolution = plugin_kwargs["resolution"].replace("(限DALLE2)", "").replace("(限DALLE3)", "")
if plugin_kwargs["model_name"] == "DALLE2":
plugin_kwargs["advanced_arg"] = resolution
yield from 图片生成_DALLE2(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
elif plugin_kwargs["model_name"] == "DALLE3":
quality = plugin_kwargs["quality (仅DALLE3生效)"]
style = plugin_kwargs["style (仅DALLE3生效)"]
plugin_kwargs["advanced_arg"] = f"{resolution}-{quality}-{style}"
yield from 图片生成_DALLE3(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
else:
chatbot.append([None, "抱歉,找不到该模型"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面

查看文件

@@ -1,278 +0,0 @@
import requests
import random
import time
import re
import json
from bs4 import BeautifulSoup
from functools import lru_cache
from itertools import zip_longest
from check_proxy import check_proxy
from toolbox import CatchException, update_ui, get_conf
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, input_clipping
from request_llms.bridge_all import model_info
from request_llms.bridge_all import predict_no_ui_long_connection
from crazy_functions.prompts.internet import SearchOptimizerPrompt, SearchAcademicOptimizerPrompt
def search_optimizer(
query,
proxies,
history,
llm_kwargs,
optimizer=1,
categories="general",
searxng_url=None,
engines=None,
):
# ------------- < 第1步尝试进行搜索优化 > -------------
# * 增强优化,会尝试结合历史记录进行搜索优化
if optimizer == 2:
his = " "
if len(history) == 0:
pass
else:
for i, h in enumerate(history):
if i % 2 == 0:
his += f"Q: {h}\n"
else:
his += f"A: {h}\n"
if categories == "general":
sys_prompt = SearchOptimizerPrompt.format(query=query, history=his, num=4)
elif categories == "science":
sys_prompt = SearchAcademicOptimizerPrompt.format(query=query, history=his, num=4)
else:
his = " "
if categories == "general":
sys_prompt = SearchOptimizerPrompt.format(query=query, history=his, num=3)
elif categories == "science":
sys_prompt = SearchAcademicOptimizerPrompt.format(query=query, history=his, num=3)
mutable = ["", time.time(), ""]
llm_kwargs["temperature"] = 0.8
try:
querys_json = predict_no_ui_long_connection(
inputs=query,
llm_kwargs=llm_kwargs,
history=[],
sys_prompt=sys_prompt,
observe_window=mutable,
)
except Exception:
querys_json = "1234"
#* 尝试解码优化后的搜索结果
querys_json = re.sub(r"```json|```", "", querys_json)
try:
querys = json.loads(querys_json)
except Exception:
#* 如果解码失败,降低温度再试一次
try:
llm_kwargs["temperature"] = 0.4
querys_json = predict_no_ui_long_connection(
inputs=query,
llm_kwargs=llm_kwargs,
history=[],
sys_prompt=sys_prompt,
observe_window=mutable,
)
querys_json = re.sub(r"```json|```", "", querys_json)
querys = json.loads(querys_json)
except Exception:
#* 如果再次失败,直接返回原始问题
querys = [query]
links = []
success = 0
Exceptions = ""
for q in querys:
try:
link = searxng_request(q, proxies, categories, searxng_url, engines=engines)
if len(link) > 0:
links.append(link[:-5])
success += 1
except Exception:
Exceptions = Exception
pass
if success == 0:
raise ValueError(f"在线搜索失败!\n{Exceptions}")
# * 清洗搜索结果,依次放入每组第一,第二个搜索结果,并清洗重复的搜索结果
seen_links = set()
result = []
for tuple in zip_longest(*links, fillvalue=None):
for item in tuple:
if item is not None:
link = item["link"]
if link not in seen_links:
seen_links.add(link)
result.append(item)
return result
@lru_cache
def get_auth_ip():
ip = check_proxy(None, return_ip=True)
if ip is None:
return '114.114.114.' + str(random.randint(1, 10))
return ip
def searxng_request(query, proxies, categories='general', searxng_url=None, engines=None):
if searxng_url is None:
url = get_conf("SEARXNG_URL")
else:
url = searxng_url
if engines == "Mixed":
engines = None
if categories == 'general':
params = {
'q': query, # 搜索查询
'format': 'json', # 输出格式为JSON
'language': 'zh', # 搜索语言
'engines': engines,
}
elif categories == 'science':
params = {
'q': query, # 搜索查询
'format': 'json', # 输出格式为JSON
'language': 'zh', # 搜索语言
'categories': 'science'
}
else:
raise ValueError('不支持的检索类型')
headers = {
'Accept-Language': 'zh-CN,zh;q=0.9',
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36',
'X-Forwarded-For': get_auth_ip(),
'X-Real-IP': get_auth_ip()
}
results = []
response = requests.post(url, params=params, headers=headers, proxies=proxies, timeout=30)
if response.status_code == 200:
json_result = response.json()
for result in json_result['results']:
item = {
"title": result.get("title", ""),
"source": result.get("engines", "unknown"),
"content": result.get("content", ""),
"link": result["url"],
}
results.append(item)
return results
else:
if response.status_code == 429:
raise ValueError("Searxng在线搜索服务当前使用人数太多,请稍后。")
else:
raise ValueError("在线搜索失败,状态码: " + str(response.status_code) + '\t' + response.content.decode('utf-8'))
def scrape_text(url, proxies) -> str:
"""Scrape text from a webpage
Args:
url (str): The URL to scrape text from
Returns:
str: The scraped text
"""
headers = {
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/94.0.4606.61 Safari/537.36',
'Content-Type': 'text/plain',
}
try:
response = requests.get(url, headers=headers, proxies=proxies, timeout=8)
if response.encoding == "ISO-8859-1": response.encoding = response.apparent_encoding
except:
return "无法连接到该网页"
soup = BeautifulSoup(response.text, "html.parser")
for script in soup(["script", "style"]):
script.extract()
text = soup.get_text()
lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = "\n".join(chunk for chunk in chunks if chunk)
return text
@CatchException
def 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
optimizer_history = history[:-8]
history = [] # 清空历史,以免输入溢出
chatbot.append((f"请结合互联网信息回答以下问题:{txt}", "检索中..."))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# ------------- < 第1步爬取搜索引擎的结果 > -------------
from toolbox import get_conf
proxies = get_conf('proxies')
categories = plugin_kwargs.get('categories', 'general')
searxng_url = plugin_kwargs.get('searxng_url', None)
engines = plugin_kwargs.get('engine', None)
optimizer = plugin_kwargs.get('optimizer', "关闭")
if optimizer == "关闭":
urls = searxng_request(txt, proxies, categories, searxng_url, engines=engines)
else:
urls = search_optimizer(txt, proxies, optimizer_history, llm_kwargs, optimizer, categories, searxng_url, engines)
history = []
if len(urls) == 0:
chatbot.append((f"结论:{txt}",
"[Local Message] 受到限制,无法从searxng获取信息请尝试更换搜索引擎。"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# ------------- < 第2步依次访问网页 > -------------
max_search_result = 5 # 最多收纳多少个网页的结果
if optimizer == "开启(增强)":
max_search_result = 8
chatbot.append(["联网检索中 ...", None])
for index, url in enumerate(urls[:max_search_result]):
res = scrape_text(url['link'], proxies)
prefix = f"{index}份搜索结果 [源自{url['source'][0]}搜索] {url['title'][:25]}"
history.extend([prefix, res])
res_squeeze = res.replace('\n', '...')
chatbot[-1] = [prefix + "\n\n" + res_squeeze[:500] + "......", None]
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# ------------- < 第3步ChatGPT综合 > -------------
if (optimizer != "开启(增强)"):
i_say = f"从以上搜索结果中抽取信息,然后回答问题:{txt}"
i_say, history = input_clipping( # 裁剪输入,从最长的条目开始裁剪,防止爆token
inputs=i_say,
history=history,
max_token_limit=min(model_info[llm_kwargs['llm_model']]['max_token']*3//4, 8192)
)
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt="请从给定的若干条搜索结果中抽取信息,对最相关的两个搜索结果进行总结,然后回答问题。"
)
chatbot[-1] = (i_say, gpt_say)
history.append(i_say);history.append(gpt_say)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
#* 或者使用搜索优化器,这样可以保证后续问答能读取到有效的历史记录
else:
i_say = f"从以上搜索结果中抽取与问题:{txt} 相关的信息:"
i_say, history = input_clipping( # 裁剪输入,从最长的条目开始裁剪,防止爆token
inputs=i_say,
history=history,
max_token_limit=min(model_info[llm_kwargs['llm_model']]['max_token']*3//4, 8192)
)
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt="请从给定的若干条搜索结果中抽取信息,对最相关的三个搜索结果进行总结"
)
chatbot[-1] = (i_say, gpt_say)
history = []
history.append(i_say);history.append(gpt_say)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 界面更新
# ------------- < 第4步根据综合回答问题 > -------------
i_say = f"请根据以上搜索结果回答问题:{txt}"
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=i_say, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt="请根据给定的若干条搜索结果回答问题"
)
chatbot[-1] = (i_say, gpt_say)
history.append(i_say);history.append(gpt_say)
yield from update_ui(chatbot=chatbot, history=history)

查看文件

@@ -1,45 +0,0 @@
from toolbox import get_conf
from crazy_functions.Internet_GPT import 连接网络回答问题
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
class NetworkGPT_Wrap(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第三个参数,名称`allow_cache`,参数`type`声明这是一个下拉菜单,下拉菜单上方显示`title`+`description`,下拉菜单的选项为`options`,`default_value`为下拉菜单默认值;
"""
gui_definition = {
"main_input":
ArgProperty(title="输入问题", description="待通过互联网检索的问题,会自动读取输入框内容", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"categories":
ArgProperty(title="搜索分类", options=["网页", "学术论文"], default_value="网页", description="", type="dropdown").model_dump_json(),
"engine":
ArgProperty(title="选择搜索引擎", options=["Mixed", "bing", "google", "duckduckgo"], default_value="google", description="", type="dropdown").model_dump_json(),
"optimizer":
ArgProperty(title="搜索优化", options=["关闭", "开启", "开启(增强)"], default_value="关闭", description="是否使用搜索增强。注意这可能会消耗较多token", type="dropdown").model_dump_json(),
"searxng_url":
ArgProperty(title="Searxng服务地址", description="输入Searxng的地址", default_value=get_conf("SEARXNG_URL"), type="string").model_dump_json(), # 主输入,自动从输入框同步
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
if plugin_kwargs["categories"] == "网页": plugin_kwargs["categories"] = "general"
if plugin_kwargs["categories"] == "学术论文": plugin_kwargs["categories"] = "science"
yield from 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -0,0 +1,106 @@
from toolbox import CatchException, update_ui, ProxyNetworkActivate, update_ui_lastest_msg
from .crazy_utils import request_gpt_model_in_new_thread_with_ui_alive, get_files_from_everything
@CatchException
def 知识库问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
"""
txt 输入栏用户输入的文本,例如需要翻译的一段话,再例如一个包含了待处理文件的路径
llm_kwargs gpt模型参数, 如温度和top_p等, 一般原样传递下去就行
plugin_kwargs 插件模型的参数,暂时没有用武之地
chatbot 聊天显示框的句柄,用于显示给用户
history 聊天历史,前情提要
system_prompt 给gpt的静默提醒
web_port 当前软件运行的端口号
"""
history = [] # 清空历史,以免输入溢出
# < --------------------读取参数--------------- >
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
kai_id = plugin_kwargs.get("advanced_arg", 'default')
chatbot.append((f"向`{kai_id}`知识库中添加文件。", "[Local Message] 从一批文件(txt, md, tex)中读取数据构建知识库, 然后进行问答。"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# resolve deps
try:
from zh_langchain import construct_vector_store
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from .crazy_utils import knowledge_archive_interface
except Exception as e:
chatbot.append(["依赖不足", "导入依赖失败。正在尝试自动安装,请查看终端的输出或耐心等待..."])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
from .crazy_utils import try_install_deps
try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
return
# < --------------------读取文件--------------- >
file_manifest = []
spl = ["txt", "doc", "docx", "email", "epub", "html", "json", "md", "msg", "pdf", "ppt", "pptx", "rtf"]
for sp in spl:
_, file_manifest_tmp, _ = get_files_from_everything(txt, type=f'.{sp}')
file_manifest += file_manifest_tmp
if len(file_manifest) == 0:
chatbot.append(["没有找到任何可读取文件", "当前支持的格式包括: txt, md, docx, pptx, pdf, json等"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# < -------------------预热文本向量化模组--------------- >
chatbot.append(['<br/>'.join(file_manifest), "正在预热文本向量化模组, 如果是第一次运行, 将消耗较长时间下载中文向量化模型..."])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
print('Checking Text2vec ...')
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
# < -------------------构建知识库--------------- >
chatbot.append(['<br/>'.join(file_manifest), "正在构建知识库..."])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
print('Establishing knowledge archive ...')
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
kai = knowledge_archive_interface()
kai.feed_archive(file_manifest=file_manifest, id=kai_id)
kai_files = kai.get_loaded_file()
kai_files = '<br/>'.join(kai_files)
# chatbot.append(['知识库构建成功', "正在将知识库存储至cookie中"])
# yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# chatbot._cookies['langchain_plugin_embedding'] = kai.get_current_archive_id()
# chatbot._cookies['lock_plugin'] = 'crazy_functions.Langchain知识库->读取知识库作答'
# chatbot.append(['完成', "“根据知识库作答”函数插件已经接管问答系统, 提问吧! 但注意, 您接下来不能再使用其他插件了,刷新页面即可以退出知识库问答模式。"])
chatbot.append(['构建完成', f"当前知识库内的有效文件:\n\n---\n\n{kai_files}\n\n---\n\n请切换至“知识库问答”插件进行知识库访问, 或者使用此插件继续上传更多文件。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
@CatchException
def 读取知识库作答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port=-1):
# resolve deps
try:
from zh_langchain import construct_vector_store
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from .crazy_utils import knowledge_archive_interface
except Exception as e:
chatbot.append(["依赖不足", "导入依赖失败。正在尝试自动安装,请查看终端的输出或耐心等待..."])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
from .crazy_utils import try_install_deps
try_install_deps(['zh_langchain==0.2.1', 'pypinyin'], reload_m=['pypinyin', 'zh_langchain'])
yield from update_ui_lastest_msg("安装完成,您可以再次重试。", chatbot, history)
return
# < ------------------- --------------- >
kai = knowledge_archive_interface()
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
kai_id = plugin_kwargs.get("advanced_arg", 'default')
resp, prompt = kai.answer_with_archive_by_id(txt, kai_id)
chatbot.append((txt, f'[知识库 {kai_id}] ' + prompt))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新
gpt_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt, inputs_show_user=txt,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=[],
sys_prompt=system_prompt
)
history.extend((prompt, gpt_say))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面 # 由于请求gpt需要一段时间,我们先及时地做一次界面更新

查看文件

@@ -1,85 +0,0 @@
from crazy_functions.Latex_Function import Latex翻译中文并重新编译PDF, PDF翻译中文并重新编译PDF
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
class Arxiv_Localize(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
第一个参数,名称`main_input`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第二个参数,名称`advanced_arg`,参数`type`声明这是一个文本框,文本框上方显示`title`,文本框内部显示`description`,`default_value`为默认值;
第三个参数,名称`allow_cache`,参数`type`声明这是一个下拉菜单,下拉菜单上方显示`title`+`description`,下拉菜单的选项为`options`,`default_value`为下拉菜单默认值;
"""
gui_definition = {
"main_input":
ArgProperty(title="ArxivID", description="输入Arxiv的ID或者网址", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"advanced_arg":
ArgProperty(title="额外的翻译提示词",
description=r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
"allow_cache":
ArgProperty(title="是否允许从缓存中调取结果", options=["允许缓存", "从头执行"], default_value="允许缓存", description="", type="dropdown").model_dump_json(),
"allow_cloudio":
ArgProperty(title="是否允许从GPTAC学术云下载(或者上传)翻译结果(仅针对Arxiv论文)", options=["允许", "禁止"], default_value="禁止", description="共享文献,互助互利", type="dropdown").model_dump_json(),
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
allow_cache = plugin_kwargs["allow_cache"]
allow_cloudio = plugin_kwargs["allow_cloudio"]
advanced_arg = plugin_kwargs["advanced_arg"]
if allow_cache == "从头执行": plugin_kwargs["advanced_arg"] = "--no-cache " + plugin_kwargs["advanced_arg"]
# 从云端下载翻译结果,以及上传翻译结果到云端;人人为我,我为人人。
if allow_cloudio == "允许": plugin_kwargs["advanced_arg"] = "--allow-cloudio " + plugin_kwargs["advanced_arg"]
yield from Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
class PDF_Localize(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
"""
gui_definition = {
"main_input":
ArgProperty(title="PDF文件路径", description="未指定路径,请上传文件后,再点击该插件", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"advanced_arg":
ArgProperty(title="额外的翻译提示词",
description=r"如果有必要, 请在此处给出自定义翻译命令, 解决部分词汇翻译不准确的问题。 "
r"例如当单词'agent'翻译不准确时, 请尝试把以下指令复制到高级参数区: "
r'If the term "agent" is used in this section, it should be translated to "智能体". ',
default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
"method":
ArgProperty(title="采用哪种方法执行转换", options=["MATHPIX", "DOC2X"], default_value="DOC2X", description="", type="dropdown").model_dump_json(),
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
yield from PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -1,6 +1,6 @@
from toolbox import update_ui, trimmed_format_exc, promote_file_to_downloadzone, get_log_folder
from toolbox import CatchException, report_exception, write_history_to_file, zip_folder
from loguru import logger
class PaperFileGroup():
def __init__(self):
@@ -33,7 +33,7 @@ class PaperFileGroup():
self.sp_file_index.append(index)
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
logger.info('Segmentation: done')
print('Segmentation: done')
def merge_result(self):
self.file_result = ["" for _ in range(len(self.file_paths))]
for r, k in zip(self.sp_file_result, self.sp_file_index):
@@ -46,7 +46,7 @@ class PaperFileGroup():
manifest.append(path + '.polish.tex')
f.write(res)
return manifest
def zip_result(self):
import os, time
folder = os.path.dirname(self.file_paths[0])
@@ -56,10 +56,10 @@ class PaperFileGroup():
def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en', mode='polish'):
import time, os, re
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
pfg = PaperFileGroup()
for index, fp in enumerate(file_manifest):
@@ -73,31 +73,31 @@ def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
pfg.file_paths.append(fp)
pfg.file_contents.append(clean_tex_content)
# <-------- 拆分过长的latex文件 ---------->
# <-------- 拆分过长的latex文件 ---------->
pfg.run_file_split(max_token_limit=1024)
n_split = len(pfg.sp_file_contents)
# <-------- 多线程润色开始 ---------->
# <-------- 多线程润色开始 ---------->
if language == 'en':
if mode == 'polish':
inputs_array = [r"Below is a section from an academic paper, polish this section to meet the academic standard, " +
r"improve the grammar, clarity and overall readability, do not modify any latex command such as \section, \cite and equations:" +
inputs_array = ["Below is a section from an academic paper, polish this section to meet the academic standard, " +
"improve the grammar, clarity and overall readability, do not modify any latex command such as \section, \cite and equations:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
else:
inputs_array = [r"Below is a section from an academic paper, proofread this section." +
r"Do not modify any latex command such as \section, \cite, \begin, \item and equations. " +
r"Answer me only with the revised text:" +
inputs_array = [r"Below is a section from an academic paper, proofread this section." +
r"Do not modify any latex command such as \section, \cite, \begin, \item and equations. " +
r"Answer me only with the revised text:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
inputs_show_user_array = [f"Polish {f}" for f in pfg.sp_file_tag]
sys_prompt_array = ["You are a professional academic paper writer." for _ in range(n_split)]
elif language == 'zh':
if mode == 'polish':
inputs_array = [r"以下是一篇学术论文中的一段内容,请将此部分润色以满足学术标准,提高语法、清晰度和整体可读性,不要修改任何LaTeX命令,例如\section,\cite和方程式" +
inputs_array = [f"以下是一篇学术论文中的一段内容,请将此部分润色以满足学术标准,提高语法、清晰度和整体可读性,不要修改任何LaTeX命令,例如\section,\cite和方程式" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
else:
inputs_array = [r"以下是一篇学术论文中的一段内容,请对这部分内容进行语法矫正。不要修改任何LaTeX命令,例如\section,\cite和方程式" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
inputs_array = [f"以下是一篇学术论文中的一段内容,请对这部分内容进行语法矫正。不要修改任何LaTeX命令,例如\section,\cite和方程式" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
inputs_show_user_array = [f"润色 {f}" for f in pfg.sp_file_tag]
sys_prompt_array=["你是一位专业的中文学术论文作家。" for _ in range(n_split)]
@@ -113,7 +113,7 @@ def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
scroller_max_len = 80
)
# <-------- 文本碎片重组为完整的tex文件,整理结果为压缩包 ---------->
# <-------- 文本碎片重组为完整的tex文件,整理结果为压缩包 ---------->
try:
pfg.sp_file_result = []
for i_say, gpt_say in zip(gpt_response_collection[0::2], gpt_response_collection[1::2]):
@@ -122,9 +122,9 @@ def 多文件润色(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
pfg.write_result()
pfg.zip_result()
except:
logger.error(trimmed_format_exc())
print(trimmed_format_exc())
# <-------- 整理结果,退出 ---------->
# <-------- 整理结果,退出 ---------->
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
res = write_history_to_file(gpt_response_collection, file_basename=create_report_file_name)
promote_file_to_downloadzone(res, chatbot=chatbot)

查看文件

@@ -1,6 +1,6 @@
from toolbox import update_ui, promote_file_to_downloadzone
from toolbox import CatchException, report_exception, write_history_to_file
from loguru import logger
fast_debug = False
class PaperFileGroup():
def __init__(self):
@@ -33,13 +33,13 @@ class PaperFileGroup():
self.sp_file_index.append(index)
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
logger.info('Segmentation: done')
print('Segmentation: done')
def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, language='en'):
import time, os, re
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
# <-------- 读取Latex文件,删除其中的所有注释 ---------->
pfg = PaperFileGroup()
for index, fp in enumerate(file_manifest):
@@ -53,11 +53,11 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
pfg.file_paths.append(fp)
pfg.file_contents.append(clean_tex_content)
# <-------- 拆分过长的latex文件 ---------->
# <-------- 拆分过长的latex文件 ---------->
pfg.run_file_split(max_token_limit=1024)
n_split = len(pfg.sp_file_contents)
# <-------- 抽取摘要 ---------->
# <-------- 抽取摘要 ---------->
# if language == 'en':
# abs_extract_inputs = f"Please write an abstract for this paper"
@@ -70,14 +70,14 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
# sys_prompt="Your job is to collect information from materials。",
# )
# <-------- 多线程润色开始 ---------->
# <-------- 多线程润色开始 ---------->
if language == 'en->zh':
inputs_array = ["Below is a section from an English academic paper, translate it into Chinese, do not modify any latex command such as \section, \cite and equations:" +
inputs_array = ["Below is a section from an English academic paper, translate it into Chinese, do not modify any latex command such as \section, \cite and equations:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
inputs_show_user_array = [f"翻译 {f}" for f in pfg.sp_file_tag]
sys_prompt_array = ["You are a professional academic paper translator." for _ in range(n_split)]
elif language == 'zh->en':
inputs_array = [f"Below is a section from a Chinese academic paper, translate it into English, do not modify any latex command such as \section, \cite and equations:" +
inputs_array = [f"Below is a section from a Chinese academic paper, translate it into English, do not modify any latex command such as \section, \cite and equations:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
inputs_show_user_array = [f"翻译 {f}" for f in pfg.sp_file_tag]
sys_prompt_array = ["You are a professional academic paper translator." for _ in range(n_split)]
@@ -93,7 +93,7 @@ def 多文件翻译(file_manifest, project_folder, llm_kwargs, plugin_kwargs, ch
scroller_max_len = 80
)
# <-------- 整理结果,退出 ---------->
# <-------- 整理结果,退出 ---------->
create_report_file_name = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + f"-chatgpt.polish.md"
res = write_history_to_file(gpt_response_collection, create_report_file_name)
promote_file_to_downloadzone(res, chatbot=chatbot)

查看文件

@@ -1,12 +1,10 @@
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone, check_repeat_upload, map_file_to_sha256
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
from functools import partial
from loguru import logger
import glob, os, requests, time, json, tarfile, threading
import glob, os, requests, time, json, tarfile
pj = os.path.join
ARXIV_CACHE_DIR = get_conf("ARXIV_CACHE_DIR")
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 工具函数 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
@@ -42,7 +40,7 @@ def switch_prompt(pfg, mode, more_requirement):
def desend_to_extracted_folder_if_exist(project_folder):
"""
"""
Descend into the extracted folder if it exists, otherwise return the original folder.
Args:
@@ -58,7 +56,7 @@ def desend_to_extracted_folder_if_exist(project_folder):
def move_project(project_folder, arxiv_id=None):
"""
"""
Create a new work folder and copy the project folder to it.
Args:
@@ -109,25 +107,20 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
except ValueError:
return False
if txt.startswith('https://arxiv.org/pdf/'):
arxiv_id = txt.split('/')[-1] # 2402.14207v2.pdf
txt = arxiv_id.split('v')[0] # 2402.14207
if ('.' in txt) and ('/' not in txt) and is_float(txt): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt.strip()
if ('.' in txt) and ('/' not in txt) and is_float(txt[:10]): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt[:10]
if not txt.startswith('https://arxiv.org'):
if not txt.startswith('https://arxiv.org'):
return txt, None # 是本地文件,跳过下载
# <-------------- inspect format ------------->
chatbot.append([f"检测到arxiv文档连接", '尝试下载 ...'])
yield from update_ui(chatbot=chatbot, history=history)
time.sleep(1) # 刷新界面
url_ = txt # https://arxiv.org/abs/1707.06690
if not txt.startswith('https://arxiv.org/abs/'):
msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}"
yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
@@ -138,122 +131,97 @@ def arxiv_download(chatbot, history, txt, allow_cache=True):
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
url_tar = url_.replace('/abs/', '/e-print/')
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
dst = pj(translation_dir, arxiv_id + '.tar')
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
os.makedirs(translation_dir, exist_ok=True)
# <-------------- download arxiv source file ------------->
def fix_url_and_download():
# for url_tar in [url_.replace('/abs/', '/e-print/'), url_.replace('/abs/', '/src/')]:
for url_tar in [url_.replace('/abs/', '/src/'), url_.replace('/abs/', '/e-print/')]:
proxies = get_conf('proxies')
r = requests.get(url_tar, proxies=proxies)
if r.status_code == 200:
with open(dst, 'wb+') as f:
f.write(r.content)
return True
return False
if os.path.exists(dst) and allow_cache:
yield from update_ui_lastest_msg(f"调用缓存 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
success = True
dst = pj(translation_dir, arxiv_id + '.tar')
if os.path.exists(dst):
yield from update_ui_lastest_msg("调用缓存", chatbot=chatbot, history=history) # 刷新界面
else:
yield from update_ui_lastest_msg(f"开始下载 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
success = fix_url_and_download()
yield from update_ui_lastest_msg(f"下载完成 {arxiv_id}", chatbot=chatbot, history=history) # 刷新界面
if not success:
yield from update_ui_lastest_msg(f"下载失败 {arxiv_id}", chatbot=chatbot, history=history)
raise tarfile.ReadError(f"论文下载失败 {arxiv_id}")
yield from update_ui_lastest_msg("开始下载", chatbot=chatbot, history=history) # 刷新界面
proxies = get_conf('proxies')
r = requests.get(url_tar, proxies=proxies)
with open(dst, 'wb+') as f:
f.write(r.content)
# <-------------- extract file ------------->
yield from update_ui_lastest_msg("下载完成", chatbot=chatbot, history=history) # 刷新界面
from toolbox import extract_archive
try:
extract_archive(file_path=dst, dest_dir=extract_dst)
except tarfile.ReadError:
os.remove(dst)
raise tarfile.ReadError(f"论文下载失败")
extract_archive(file_path=dst, dest_dir=extract_dst)
return extract_dst, arxiv_id
def pdf2tex_project(pdf_file_path, plugin_kwargs):
if plugin_kwargs["method"] == "MATHPIX":
# Mathpix API credentials
app_id, app_key = get_conf('MATHPIX_APPID', 'MATHPIX_APPKEY')
headers = {"app_id": app_id, "app_key": app_key}
def pdf2tex_project(pdf_file_path):
# Mathpix API credentials
app_id, app_key = get_conf('MATHPIX_APPID', 'MATHPIX_APPKEY')
headers = {"app_id": app_id, "app_key": app_key}
# Step 1: Send PDF file for processing
options = {
"conversion_formats": {"tex.zip": True},
"math_inline_delimiters": ["$", "$"],
"rm_spaces": True
}
# Step 1: Send PDF file for processing
options = {
"conversion_formats": {"tex.zip": True},
"math_inline_delimiters": ["$", "$"],
"rm_spaces": True
}
response = requests.post(url="https://api.mathpix.com/v3/pdf",
headers=headers,
data={"options_json": json.dumps(options)},
files={"file": open(pdf_file_path, "rb")})
response = requests.post(url="https://api.mathpix.com/v3/pdf",
headers=headers,
data={"options_json": json.dumps(options)},
files={"file": open(pdf_file_path, "rb")})
if response.ok:
pdf_id = response.json()["pdf_id"]
logger.info(f"PDF processing initiated. PDF ID: {pdf_id}")
if response.ok:
pdf_id = response.json()["pdf_id"]
print(f"PDF processing initiated. PDF ID: {pdf_id}")
# Step 2: Check processing status
while True:
conversion_response = requests.get(f"https://api.mathpix.com/v3/pdf/{pdf_id}", headers=headers)
conversion_data = conversion_response.json()
# Step 2: Check processing status
while True:
conversion_response = requests.get(f"https://api.mathpix.com/v3/pdf/{pdf_id}", headers=headers)
conversion_data = conversion_response.json()
if conversion_data["status"] == "completed":
logger.info("PDF processing completed.")
break
elif conversion_data["status"] == "error":
logger.info("Error occurred during processing.")
else:
logger.info(f"Processing status: {conversion_data['status']}")
time.sleep(5) # wait for a few seconds before checking again
if conversion_data["status"] == "completed":
print("PDF processing completed.")
break
elif conversion_data["status"] == "error":
print("Error occurred during processing.")
else:
print(f"Processing status: {conversion_data['status']}")
time.sleep(5) # wait for a few seconds before checking again
# Step 3: Save results to local files
output_dir = os.path.join(os.path.dirname(pdf_file_path), 'mathpix_output')
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Step 3: Save results to local files
output_dir = os.path.join(os.path.dirname(pdf_file_path), 'mathpix_output')
if not os.path.exists(output_dir):
os.makedirs(output_dir)
url = f"https://api.mathpix.com/v3/pdf/{pdf_id}.tex"
response = requests.get(url, headers=headers)
file_name_wo_dot = '_'.join(os.path.basename(pdf_file_path).split('.')[:-1])
output_name = f"{file_name_wo_dot}.tex.zip"
output_path = os.path.join(output_dir, output_name)
with open(output_path, "wb") as output_file:
output_file.write(response.content)
logger.info(f"tex.zip file saved at: {output_path}")
url = f"https://api.mathpix.com/v3/pdf/{pdf_id}.tex"
response = requests.get(url, headers=headers)
file_name_wo_dot = '_'.join(os.path.basename(pdf_file_path).split('.')[:-1])
output_name = f"{file_name_wo_dot}.tex.zip"
output_path = os.path.join(output_dir, output_name)
with open(output_path, "wb") as output_file:
output_file.write(response.content)
print(f"tex.zip file saved at: {output_path}")
import zipfile
unzip_dir = os.path.join(output_dir, file_name_wo_dot)
with zipfile.ZipFile(output_path, 'r') as zip_ref:
zip_ref.extractall(unzip_dir)
import zipfile
unzip_dir = os.path.join(output_dir, file_name_wo_dot)
with zipfile.ZipFile(output_path, 'r') as zip_ref:
zip_ref.extractall(unzip_dir)
return unzip_dir
else:
logger.error(f"Error sending PDF for processing. Status code: {response.status_code}")
return None
else:
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_DOC2X_转Latex
unzip_dir = 解析PDF_DOC2X_转Latex(pdf_file_path)
return unzip_dir
else:
print(f"Error sending PDF for processing. Status code: {response.status_code}")
return None
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
# <-------------- information about this plugin ------------->
chatbot.append(["函数插件功能?",
"对整个Latex项目进行纠错, 用latex编译为PDF对修正处做高亮。函数插件贡献者: Binary-Husky。注意事项: 目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。仅在Windows系统进行了测试,其他操作系统表现未知。"])
"对整个Latex项目进行纠错, 用latex编译为PDF对修正处做高亮。函数插件贡献者: Binary-Husky。注意事项: 目前仅支持GPT3.5/GPT4,其他模型转化效果未知。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。仅在Windows系统进行了测试,其他操作系统表现未知。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------------- more requirements ------------->
@@ -291,8 +259,6 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
project_folder = desend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
project_folder = move_project(project_folder, arxiv_id=None)
# <-------------- if merge_translate_zh is already generated, skip gpt req ------------->
@@ -316,7 +282,7 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
else:
chatbot.append((f"失败了",
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 也是可读的, 您可以到Github Issue区, 用该压缩包+Conversation_To_File进行反馈 ...'))
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 也是可读的, 您可以到Github Issue区, 用该压缩包+对话历史存档进行反馈 ...'))
yield from update_ui(chatbot=chatbot, history=history);
time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
@@ -325,30 +291,24 @@ def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, histo
return success
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
# <-------------- information about this plugin ------------->
chatbot.append([
"函数插件功能?",
"对整个Latex项目进行翻译, 生成中文PDF。函数插件贡献者: Binary-Husky。注意事项: 此插件Windows支持最佳,Linux下必须使用Docker安装,详见项目主README.md。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。"])
"对整个Latex项目进行翻译, 生成中文PDF。函数插件贡献者: Binary-Husky。注意事项: 此插件Windows支持最佳,Linux下必须使用Docker安装,详见项目主README.md。目前仅支持GPT3.5/GPT4,其他模型转化效果未知。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------------- more requirements ------------->
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
more_req = plugin_kwargs.get("advanced_arg", "")
no_cache = ("--no-cache" in more_req)
if no_cache: more_req = more_req.replace("--no-cache", "").strip()
allow_gptac_cloud_io = ("--allow-cloudio" in more_req) # 从云端下载翻译结果,以及上传翻译结果到云端
if allow_gptac_cloud_io: more_req = more_req.replace("--allow-cloudio", "").strip()
no_cache = more_req.startswith("--no-cache")
if no_cache: more_req.lstrip("--no-cache")
allow_cache = not no_cache
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
# <-------------- check deps ------------->
try:
import glob, os, time, subprocess
@@ -366,7 +326,7 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
txt, arxiv_id = yield from arxiv_download(chatbot, history, txt, allow_cache)
except tarfile.ReadError as e:
yield from update_ui_lastest_msg(
"无法自动下载该论文的Latex源码,请前往arxiv打开此论文下载页面,点other Formats,然后download source手动下载latex源码包。接下来调用本地Latex翻译插件即可。",
"无法自动下载该论文的Latex源码,请前往arxiv打开此论文下载页面,点other Formats,然后download source手动下载latex源码包。接下来调用本地Latex翻译插件即可。",
chatbot=chatbot, history=history)
return
@@ -375,20 +335,6 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# #################################################################
if allow_gptac_cloud_io and arxiv_id:
# 访问 GPTAC学术云,查询云端是否存在该论文的翻译版本
from crazy_functions.latex_fns.latex_actions import check_gptac_cloud
success, downloaded = check_gptac_cloud(arxiv_id, chatbot)
if success:
chatbot.append([
f"检测到GPTAC云端存在翻译版本, 如果不满意翻译结果, 请禁用云端分享, 然后重新执行。",
None
])
yield from update_ui(chatbot=chatbot, history=history)
return
#################################################################
if os.path.exists(txt):
project_folder = txt
else:
@@ -407,8 +353,6 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
project_folder = desend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
project_folder = move_project(project_folder, arxiv_id)
# <-------------- if merge_translate_zh is already generated, skip gpt req ------------->
@@ -426,21 +370,14 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
# <-------------- zip PDF ------------->
zip_res = zip_result(project_folder)
if success:
if allow_gptac_cloud_io and arxiv_id:
# 如果用户允许,我们将翻译好的arxiv论文PDF上传到GPTAC学术云
from crazy_functions.latex_fns.latex_actions import upload_to_gptac_cloud_if_user_allow
threading.Thread(target=upload_to_gptac_cloud_if_user_allow,
args=(chatbot, arxiv_id), daemon=True).start()
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
yield from update_ui(chatbot=chatbot, history=history)
yield from update_ui(chatbot=chatbot, history=history);
time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
else:
chatbot.append((f"失败了",
'虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体见Github wiki ...'))
yield from update_ui(chatbot=chatbot, history=history)
yield from update_ui(chatbot=chatbot, history=history);
time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
@@ -448,14 +385,14 @@ def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot,
return success
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 插件主程序3 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 插件主程序3 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
# <-------------- information about this plugin ------------->
chatbot.append([
"函数插件功能?",
"将PDF转换为Latex项目,翻译为中文后重新编译为PDF。函数插件贡献者: Marroh。注意事项: 此插件Windows支持最佳,Linux下必须使用Docker安装,详见项目主README.md。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。"])
"将PDF转换为Latex项目,翻译为中文后重新编译为PDF。函数插件贡献者: Marroh。注意事项: 此插件Windows支持最佳,Linux下必须使用Docker安装,详见项目主README.md。目前仅支持GPT3.5/GPT4,其他模型转化效果未知。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------------- more requirements ------------->
@@ -495,55 +432,16 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"不支持同时处理多个pdf文件: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
if plugin_kwargs.get("method", "") == 'MATHPIX':
app_id, app_key = get_conf('MATHPIX_APPID', 'MATHPIX_APPKEY')
if len(app_id) == 0 or len(app_key) == 0:
report_exception(chatbot, history, a="缺失 MATHPIX_APPID 和 MATHPIX_APPKEY。", b=f"请配置 MATHPIX_APPID 和 MATHPIX_APPKEY")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
if plugin_kwargs.get("method", "") == 'DOC2X':
app_id, app_key = "", ""
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
if len(DOC2X_API_KEY) == 0:
report_exception(chatbot, history, a="缺失 DOC2X_API_KEY。", b=f"请配置 DOC2X_API_KEY")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
hash_tag = map_file_to_sha256(file_manifest[0])
# # <-------------- check repeated pdf ------------->
# chatbot.append([f"检查PDF是否被重复上传", "正在检查..."])
# yield from update_ui(chatbot=chatbot, history=history)
# repeat, project_folder = check_repeat_upload(file_manifest[0], hash_tag)
# if repeat:
# yield from update_ui_lastest_msg(f"发现重复上传,请查收结果(压缩包)...", chatbot=chatbot, history=history)
# try:
# translate_pdf = [f for f in glob.glob(f'{project_folder}/**/merge_translate_zh.pdf', recursive=True)][0]
# promote_file_to_downloadzone(translate_pdf, rename_file=None, chatbot=chatbot)
# comparison_pdf = [f for f in glob.glob(f'{project_folder}/**/comparison.pdf', recursive=True)][0]
# promote_file_to_downloadzone(comparison_pdf, rename_file=None, chatbot=chatbot)
# zip_res = zip_result(project_folder)
# promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
# return
# except:
# report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"发现重复上传,但是无法找到相关文件")
# yield from update_ui(chatbot=chatbot, history=history)
# else:
# yield from update_ui_lastest_msg(f"未发现重复上传", chatbot=chatbot, history=history)
app_id, app_key = get_conf('MATHPIX_APPID', 'MATHPIX_APPKEY')
if len(app_id) == 0 or len(app_key) == 0:
report_exception(chatbot, history, a="缺失 MATHPIX_APPID 和 MATHPIX_APPKEY。", b=f"请配置 MATHPIX_APPID 和 MATHPIX_APPKEY")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# <-------------- convert pdf into tex ------------->
chatbot.append([f"解析项目: {txt}", "正在将PDF转换为tex项目,请耐心等待..."])
yield from update_ui(chatbot=chatbot, history=history)
project_folder = pdf2tex_project(file_manifest[0], plugin_kwargs)
if project_folder is None:
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"PDF转换为tex项目失败")
yield from update_ui(chatbot=chatbot, history=history)
return False
project_folder = pdf2tex_project(file_manifest[0])
# <-------------- translate latex file into Chinese ------------->
yield from update_ui_lastest_msg("正在tex项目将翻译为中文...", chatbot=chatbot, history=history)
# Translate English Latex to Chinese Latex, and compile it
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)]
if len(file_manifest) == 0:
report_exception(chatbot, history, a=f"解析项目: {txt}", b=f"找不到任何.tex文件: {txt}")
@@ -554,28 +452,19 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
project_folder = desend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
from shared_utils.fastapi_server import validate_path_safety
validate_path_safety(project_folder, chatbot.get_user())
project_folder = move_project(project_folder)
# <-------------- set a hash tag for repeat-checking ------------->
with open(pj(project_folder, hash_tag + '.tag'), 'w') as f:
f.write(hash_tag)
f.close()
# <-------------- if merge_translate_zh is already generated, skip gpt req ------------->
if not os.path.exists(project_folder + '/merge_translate_zh.tex'):
yield from Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, mode='translate_zh',
switch_prompt=_switch_prompt_)
chatbot, history, system_prompt, mode='translate_zh',
switch_prompt=_switch_prompt_)
# <-------------- compile PDF ------------->
yield from update_ui_lastest_msg("正在将翻译好的项目tex项目编译为PDF...", chatbot=chatbot, history=history)
success = yield from 编译Latex(chatbot, history, main_file_original='merge',
main_file_modified='merge_translate_zh', mode='translate_zh',
work_folder_original=project_folder, work_folder_modified=project_folder,
work_folder=project_folder)
main_file_modified='merge_translate_zh', mode='translate_zh',
work_folder_original=project_folder, work_folder_modified=project_folder,
work_folder=project_folder)
# <-------------- zip PDF ------------->
zip_res = zip_result(project_folder)
@@ -592,4 +481,4 @@ def PDF翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, h
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
# <-------------- we are done ------------->
return success
return success

查看文件

@@ -0,0 +1,306 @@
from toolbox import update_ui, trimmed_format_exc, get_conf, get_log_folder, promote_file_to_downloadzone
from toolbox import CatchException, report_exception, update_ui_lastest_msg, zip_result, gen_time_str
from functools import partial
import glob, os, requests, time
pj = os.path.join
ARXIV_CACHE_DIR = os.path.expanduser(f"~/arxiv_cache/")
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- 工具函数 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
# 专业词汇声明 = 'If the term "agent" is used in this section, it should be translated to "智能体". '
def switch_prompt(pfg, mode, more_requirement):
"""
Generate prompts and system prompts based on the mode for proofreading or translating.
Args:
- pfg: Proofreader or Translator instance.
- mode: A string specifying the mode, either 'proofread' or 'translate_zh'.
Returns:
- inputs_array: A list of strings containing prompts for users to respond to.
- sys_prompt_array: A list of strings containing prompts for system prompts.
"""
n_split = len(pfg.sp_file_contents)
if mode == 'proofread_en':
inputs_array = [r"Below is a section from an academic paper, proofread this section." +
r"Do not modify any latex command such as \section, \cite, \begin, \item and equations. " + more_requirement +
r"Answer me only with the revised text:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
sys_prompt_array = ["You are a professional academic paper writer." for _ in range(n_split)]
elif mode == 'translate_zh':
inputs_array = [r"Below is a section from an English academic paper, translate it into Chinese. " + more_requirement +
r"Do not modify any latex command such as \section, \cite, \begin, \item and equations. " +
r"Answer me only with the translated text:" +
f"\n\n{frag}" for frag in pfg.sp_file_contents]
sys_prompt_array = ["You are a professional translator." for _ in range(n_split)]
else:
assert False, "未知指令"
return inputs_array, sys_prompt_array
def desend_to_extracted_folder_if_exist(project_folder):
"""
Descend into the extracted folder if it exists, otherwise return the original folder.
Args:
- project_folder: A string specifying the folder path.
Returns:
- A string specifying the path to the extracted folder, or the original folder if there is no extracted folder.
"""
maybe_dir = [f for f in glob.glob(f'{project_folder}/*') if os.path.isdir(f)]
if len(maybe_dir) == 0: return project_folder
if maybe_dir[0].endswith('.extract'): return maybe_dir[0]
return project_folder
def move_project(project_folder, arxiv_id=None):
"""
Create a new work folder and copy the project folder to it.
Args:
- project_folder: A string specifying the folder path of the project.
Returns:
- A string specifying the path to the new work folder.
"""
import shutil, time
time.sleep(2) # avoid time string conflict
if arxiv_id is not None:
new_workfolder = pj(ARXIV_CACHE_DIR, arxiv_id, 'workfolder')
else:
new_workfolder = f'{get_log_folder()}/{gen_time_str()}'
try:
shutil.rmtree(new_workfolder)
except:
pass
# align subfolder if there is a folder wrapper
items = glob.glob(pj(project_folder,'*'))
items = [item for item in items if os.path.basename(item)!='__MACOSX']
if len(glob.glob(pj(project_folder,'*.tex'))) == 0 and len(items) == 1:
if os.path.isdir(items[0]): project_folder = items[0]
shutil.copytree(src=project_folder, dst=new_workfolder)
return new_workfolder
def arxiv_download(chatbot, history, txt, allow_cache=True):
def check_cached_translation_pdf(arxiv_id):
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'translation')
if not os.path.exists(translation_dir):
os.makedirs(translation_dir)
target_file = pj(translation_dir, 'translate_zh.pdf')
if os.path.exists(target_file):
promote_file_to_downloadzone(target_file, rename_file=None, chatbot=chatbot)
target_file_compare = pj(translation_dir, 'comparison.pdf')
if os.path.exists(target_file_compare):
promote_file_to_downloadzone(target_file_compare, rename_file=None, chatbot=chatbot)
return target_file
return False
def is_float(s):
try:
float(s)
return True
except ValueError:
return False
if ('.' in txt) and ('/' not in txt) and is_float(txt): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt.strip()
if ('.' in txt) and ('/' not in txt) and is_float(txt[:10]): # is arxiv ID
txt = 'https://arxiv.org/abs/' + txt[:10]
if not txt.startswith('https://arxiv.org'):
return txt, None
# <-------------- inspect format ------------->
chatbot.append([f"检测到arxiv文档连接", '尝试下载 ...'])
yield from update_ui(chatbot=chatbot, history=history)
time.sleep(1) # 刷新界面
url_ = txt # https://arxiv.org/abs/1707.06690
if not txt.startswith('https://arxiv.org/abs/'):
msg = f"解析arxiv网址失败, 期望格式例如: https://arxiv.org/abs/1707.06690。实际得到格式: {url_}"
yield from update_ui_lastest_msg(msg, chatbot=chatbot, history=history) # 刷新界面
return msg, None
# <-------------- set format ------------->
arxiv_id = url_.split('/abs/')[-1]
if 'v' in arxiv_id: arxiv_id = arxiv_id[:10]
cached_translation_pdf = check_cached_translation_pdf(arxiv_id)
if cached_translation_pdf and allow_cache: return cached_translation_pdf, arxiv_id
url_tar = url_.replace('/abs/', '/e-print/')
translation_dir = pj(ARXIV_CACHE_DIR, arxiv_id, 'e-print')
extract_dst = pj(ARXIV_CACHE_DIR, arxiv_id, 'extract')
os.makedirs(translation_dir, exist_ok=True)
# <-------------- download arxiv source file ------------->
dst = pj(translation_dir, arxiv_id+'.tar')
if os.path.exists(dst):
yield from update_ui_lastest_msg("调用缓存", chatbot=chatbot, history=history) # 刷新界面
else:
yield from update_ui_lastest_msg("开始下载", chatbot=chatbot, history=history) # 刷新界面
proxies = get_conf('proxies')
r = requests.get(url_tar, proxies=proxies)
with open(dst, 'wb+') as f:
f.write(r.content)
# <-------------- extract file ------------->
yield from update_ui_lastest_msg("下载完成", chatbot=chatbot, history=history) # 刷新界面
from toolbox import extract_archive
extract_archive(file_path=dst, dest_dir=extract_dst)
return extract_dst, arxiv_id
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序1 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def Latex英文纠错加PDF对比(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
# <-------------- information about this plugin ------------->
chatbot.append([ "函数插件功能?",
"对整个Latex项目进行纠错, 用latex编译为PDF对修正处做高亮。函数插件贡献者: Binary-Husky。注意事项: 目前仅支持GPT3.5/GPT4,其他模型转化效果未知。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。仅在Windows系统进行了测试,其他操作系统表现未知。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------------- more requirements ------------->
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
more_req = plugin_kwargs.get("advanced_arg", "")
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
# <-------------- check deps ------------->
try:
import glob, os, time, subprocess
subprocess.Popen(['pdflatex', '-version'])
from .latex_fns.latex_actions import Latex精细分解与转化, 编译Latex
except Exception as e:
chatbot.append([ f"解析项目: {txt}",
f"尝试执行Latex指令失败。Latex没有安装, 或者不在环境变量PATH中。安装方法https://tug.org/texlive/。报错信息\n\n```\n\n{trimmed_format_exc()}\n\n```\n\n"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# <-------------- clear history and read input ------------->
history = []
if os.path.exists(txt):
project_folder = txt
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)]
if len(file_manifest) == 0:
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.tex文件: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# <-------------- if is a zip/tar file ------------->
project_folder = desend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
project_folder = move_project(project_folder, arxiv_id=None)
# <-------------- if merge_translate_zh is already generated, skip gpt req ------------->
if not os.path.exists(project_folder + '/merge_proofread_en.tex'):
yield from Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, mode='proofread_en', switch_prompt=_switch_prompt_)
# <-------------- compile PDF ------------->
success = yield from 编译Latex(chatbot, history, main_file_original='merge', main_file_modified='merge_proofread_en',
work_folder_original=project_folder, work_folder_modified=project_folder, work_folder=project_folder)
# <-------------- zip PDF ------------->
zip_res = zip_result(project_folder)
if success:
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
yield from update_ui(chatbot=chatbot, history=history); time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
else:
chatbot.append((f"失败了", '虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 也是可读的, 您可以到Github Issue区, 用该压缩包+对话历史存档进行反馈 ...'))
yield from update_ui(chatbot=chatbot, history=history); time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
# <-------------- we are done ------------->
return success
# =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= 插件主程序2 =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
@CatchException
def Latex翻译中文并重新编译PDF(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
# <-------------- information about this plugin ------------->
chatbot.append([
"函数插件功能?",
"对整个Latex项目进行翻译, 生成中文PDF。函数插件贡献者: Binary-Husky。注意事项: 此插件Windows支持最佳,Linux下必须使用Docker安装,详见项目主README.md。目前仅支持GPT3.5/GPT4,其他模型转化效果未知。目前对机器学习类文献转化效果最好,其他类型文献转化效果未知。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------------- more requirements ------------->
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
more_req = plugin_kwargs.get("advanced_arg", "")
no_cache = more_req.startswith("--no-cache")
if no_cache: more_req.lstrip("--no-cache")
allow_cache = not no_cache
_switch_prompt_ = partial(switch_prompt, more_requirement=more_req)
# <-------------- check deps ------------->
try:
import glob, os, time, subprocess
subprocess.Popen(['pdflatex', '-version'])
from .latex_fns.latex_actions import Latex精细分解与转化, 编译Latex
except Exception as e:
chatbot.append([ f"解析项目: {txt}",
f"尝试执行Latex指令失败。Latex没有安装, 或者不在环境变量PATH中。安装方法https://tug.org/texlive/。报错信息\n\n```\n\n{trimmed_format_exc()}\n\n```\n\n"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# <-------------- clear history and read input ------------->
history = []
txt, arxiv_id = yield from arxiv_download(chatbot, history, txt, allow_cache)
if txt.endswith('.pdf'):
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"发现已经存在翻译好的PDF文档")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
if os.path.exists(txt):
project_folder = txt
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无法处理: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.tex', recursive=True)]
if len(file_manifest) == 0:
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何.tex文件: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# <-------------- if is a zip/tar file ------------->
project_folder = desend_to_extracted_folder_if_exist(project_folder)
# <-------------- move latex project away from temp folder ------------->
project_folder = move_project(project_folder, arxiv_id)
# <-------------- if merge_translate_zh is already generated, skip gpt req ------------->
if not os.path.exists(project_folder + '/merge_translate_zh.tex'):
yield from Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin_kwargs,
chatbot, history, system_prompt, mode='translate_zh', switch_prompt=_switch_prompt_)
# <-------------- compile PDF ------------->
success = yield from 编译Latex(chatbot, history, main_file_original='merge', main_file_modified='merge_translate_zh', mode='translate_zh',
work_folder_original=project_folder, work_folder_modified=project_folder, work_folder=project_folder)
# <-------------- zip PDF ------------->
zip_res = zip_result(project_folder)
if success:
chatbot.append((f"成功啦", '请查收结果(压缩包)...'))
yield from update_ui(chatbot=chatbot, history=history); time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
else:
chatbot.append((f"失败了", '虽然PDF生成失败了, 但请查收结果(压缩包), 内含已经翻译的Tex文档, 您可以到Github Issue区, 用该压缩包进行反馈。如系统是Linux,请检查系统字体见Github wiki ...'))
yield from update_ui(chatbot=chatbot, history=history); time.sleep(1) # 刷新界面
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
# <-------------- we are done ------------->
return success

查看文件

@@ -1,83 +0,0 @@
from toolbox import CatchException, check_packages, get_conf
from toolbox import update_ui, update_ui_lastest_msg, disable_auto_promotion
from toolbox import trimmed_format_exc_markdown
from crazy_functions.crazy_utils import get_files_from_everything
from crazy_functions.pdf_fns.parse_pdf import get_avail_grobid_url
from crazy_functions.pdf_fns.parse_pdf_via_doc2x import 解析PDF_基于DOC2X
from crazy_functions.pdf_fns.parse_pdf_legacy import 解析PDF_简单拆解
from crazy_functions.pdf_fns.parse_pdf_grobid import 解析PDF_基于GROBID
from shared_utils.colorful import *
@CatchException
def 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
disable_auto_promotion(chatbot)
# 基本信息:功能、贡献者
chatbot.append([None, "插件功能批量翻译PDF文档。函数插件贡献者: Binary-Husky"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 尝试导入依赖,如果缺少依赖,则给出安装建议
try:
check_packages(["fitz", "tiktoken", "scipdf"])
except:
chatbot.append([None, f"导入软件依赖失败。使用该模块需要额外依赖,安装方法```pip install --upgrade pymupdf tiktoken scipdf_parser```。"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 清空历史,以免输入溢出
history = []
success, file_manifest, project_folder = get_files_from_everything(txt, type='.pdf')
# 检测输入参数,如没有给定输入参数,直接退出
if (not success) and txt == "": txt = '空空如也的输入栏。提示请先上传文件把PDF文件拖入对话'
# 如果没找到任何文件
if len(file_manifest) == 0:
chatbot.append([None, f"找不到任何.pdf拓展名的文件: {txt}"])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
# 开始正式执行任务
method = plugin_kwargs.get("pdf_parse_method", None)
if method == "DOC2X":
# ------- 第一种方法,效果最好,但是需要DOC2X服务 -------
DOC2X_API_KEY = get_conf("DOC2X_API_KEY")
if len(DOC2X_API_KEY) != 0:
try:
yield from 解析PDF_基于DOC2X(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request)
return
except:
chatbot.append([None, f"DOC2X服务不可用,现在将执行效果稍差的旧版代码。{trimmed_format_exc_markdown()}"])
yield from update_ui(chatbot=chatbot, history=history)
if method == "GROBID":
# ------- 第二种方法,效果次优 -------
grobid_url = get_avail_grobid_url()
if grobid_url is not None:
yield from 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, grobid_url)
return
if method == "ClASSIC":
# ------- 第三种方法,早期代码,效果不理想 -------
yield from update_ui_lastest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
return
if method is None:
# ------- 以上三种方法都试一遍 -------
DOC2X_API_KEY = get_conf("DOC2X_API_KEY")
if len(DOC2X_API_KEY) != 0:
try:
yield from 解析PDF_基于DOC2X(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request)
return
except:
chatbot.append([None, f"DOC2X服务不可用,正在尝试GROBID。{trimmed_format_exc_markdown()}"])
yield from update_ui(chatbot=chatbot, history=history)
grobid_url = get_avail_grobid_url()
if grobid_url is not None:
yield from 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, grobid_url)
return
yield from update_ui_lastest_msg("GROBID服务不可用,请检查config中的GROBID_URL。作为替代,现在将执行效果稍差的旧版代码。", chatbot, history, delay=3)
yield from 解析PDF_简单拆解(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)
return

查看文件

@@ -1,33 +0,0 @@
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
from .PDF_Translate import 批量翻译PDF文档
class PDF_Tran(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
"""
gui_definition = {
"main_input":
ArgProperty(title="PDF文件路径", description="未指定路径,请上传文件后,再点击该插件", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"additional_prompt":
ArgProperty(title="额外提示词", description="例如:对专有名词、翻译语气等方面的要求", default_value="", type="string").model_dump_json(), # 高级参数输入区,自动同步
"pdf_parse_method":
ArgProperty(title="PDF解析方法", options=["DOC2X", "GROBID", "ClASSIC"], description="", default_value="GROBID", type="dropdown").model_dump_json(),
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
main_input = plugin_kwargs["main_input"]
additional_prompt = plugin_kwargs["additional_prompt"]
pdf_parse_method = plugin_kwargs["pdf_parse_method"]
yield from 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -1,92 +0,0 @@
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
RAG_WORKER_REGISTER = {}
MAX_HISTORY_ROUND = 5
MAX_CONTEXT_TOKEN_LIMIT = 4096
REMEMBER_PREVIEW = 1000
@CatchException
def Rag问答(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
# import vector store lib
VECTOR_STORE_TYPE = "Milvus"
if VECTOR_STORE_TYPE == "Milvus":
try:
from crazy_functions.rag_fns.milvus_worker import MilvusRagWorker as LlamaIndexRagWorker
except:
VECTOR_STORE_TYPE = "Simple"
if VECTOR_STORE_TYPE == "Simple":
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
# 1. we retrieve rag worker from global context
user_name = chatbot.get_user()
checkpoint_dir = get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in RAG_WORKER_REGISTER:
rag_worker = RAG_WORKER_REGISTER[user_name]
else:
rag_worker = RAG_WORKER_REGISTER[user_name] = LlamaIndexRagWorker(
user_name,
llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True)
current_context = f"{VECTOR_STORE_TYPE} @ {checkpoint_dir}"
tip = "提示输入“清空向量数据库”可以清空RAG向量数据库"
if txt == "清空向量数据库":
chatbot.append([txt, f'正在清空 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
rag_worker.purge()
yield from update_ui_lastest_msg('已清空', chatbot, history, delay=0) # 刷新界面
return
chatbot.append([txt, f'正在召回知识 ({current_context}) ...'])
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 2. clip history to reduce token consumption
# 2-1. reduce chat round
txt_origin = txt
if len(history) > MAX_HISTORY_ROUND * 2:
history = history[-(MAX_HISTORY_ROUND * 2):]
txt_clip, history, flags = input_clipping(txt, history, max_token_limit=MAX_CONTEXT_TOKEN_LIMIT, return_clip_flags=True)
input_is_clipped_flag = (flags["original_input_len"] != flags["clipped_input_len"])
# 2-2. if input is clipped, add input to vector store before retrieve
if input_is_clipped_flag:
yield from update_ui_lastest_msg('检测到长输入, 正在向量化 ...', chatbot, history, delay=0) # 刷新界面
# save input to vector store
rag_worker.add_text_to_vector_store(txt_origin)
yield from update_ui_lastest_msg('向量化完成 ...', chatbot, history, delay=0) # 刷新界面
if len(txt_origin) > REMEMBER_PREVIEW:
HALF = REMEMBER_PREVIEW//2
i_say_to_remember = txt[:HALF] + f" ...\n...(省略{len(txt_origin)-REMEMBER_PREVIEW}字)...\n... " + txt[-HALF:]
if (flags["original_input_len"] - flags["clipped_input_len"]) > HALF:
txt_clip = txt_clip + f" ...\n...(省略{len(txt_origin)-len(txt_clip)-HALF}字)...\n... " + txt[-HALF:]
else:
pass
i_say = txt_clip
else:
i_say_to_remember = i_say = txt_clip
else:
i_say_to_remember = i_say = txt_clip
# 3. we search vector store and build prompts
nodes = rag_worker.retrieve_from_store_with_query(i_say)
prompt = rag_worker.build_prompt(query=i_say, nodes=nodes)
# 4. it is time to query llms
if len(chatbot) != 0: chatbot.pop(-1) # pop temp chat, because we are going to add them again inside `request_gpt_model_in_new_thread_with_ui_alive`
model_say = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs=prompt, inputs_show_user=i_say,
llm_kwargs=llm_kwargs, chatbot=chatbot, history=history,
sys_prompt=system_prompt,
retry_times_at_unknown_error=0
)
# 5. remember what has been asked / answered
yield from update_ui_lastest_msg(model_say + '</br></br>' + f'对话记忆中, 请稍等 ({current_context}) ...', chatbot, history, delay=0.5) # 刷新界面
rag_worker.remember_qa(i_say_to_remember, model_say)
history.extend([i_say, model_say])
yield from update_ui_lastest_msg(model_say, chatbot, history, delay=0, msg=tip) # 刷新界面

查看文件

@@ -1,167 +0,0 @@
import pickle, os, random
from toolbox import CatchException, update_ui, get_conf, get_log_folder, update_ui_lastest_msg
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from request_llms.bridge_all import predict_no_ui_long_connection
from crazy_functions.json_fns.select_tool import structure_output, select_tool
from pydantic import BaseModel, Field
from loguru import logger
from typing import List
SOCIAL_NETWOK_WORKER_REGISTER = {}
class SocialNetwork():
def __init__(self):
self.people = []
class SaveAndLoad():
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.social_network = self.load_from_checkpoint(checkpoint_dir)
else:
self.social_network = SocialNetwork()
def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "social_network.pkl"))) == 0: return False
return True
def save_to_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "wb+") as f:
pickle.dump(self.social_network, f)
return
def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
with open(os.path.join(checkpoint_dir, 'social_network.pkl'), "rb") as f:
social_network = pickle.load(f)
return social_network
else:
return SocialNetwork()
class Friend(BaseModel):
friend_name: str = Field(description="name of a friend")
friend_description: str = Field(description="description of a friend (everything about this friend)")
friend_relationship: str = Field(description="The relationship with a friend (e.g. friend, family, colleague)")
class FriendList(BaseModel):
friends_list: List[Friend] = Field(description="The list of friends")
class SocialNetworkWorker(SaveAndLoad):
def ai_socail_advice(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
pass
def ai_remove_friend(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
pass
def ai_list_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
pass
def ai_add_multi_friends(self, prompt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type):
friend, err_msg = structure_output(
txt=prompt,
prompt="根据提示, 解析多个联系人的身份信息\n\n",
err_msg=f"不能理解该联系人",
run_gpt_fn=run_gpt_fn,
pydantic_cls=FriendList
)
if friend.friends_list:
for f in friend.friends_list:
self.add_friend(f)
msg = f"成功添加{len(friend.friends_list)}个联系人: {str(friend.friends_list)}"
yield from update_ui_lastest_msg(lastmsg=msg, chatbot=chatbot, history=history, delay=0)
def run(self, txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
prompt = txt
run_gpt_fn = lambda inputs, sys_prompt: predict_no_ui_long_connection(inputs=inputs, llm_kwargs=llm_kwargs, history=[], sys_prompt=sys_prompt, observe_window=[])
self.tools_to_select = {
"SocialAdvice":{
"explain_to_llm": "如果用户希望获取社交指导,调用SocialAdvice生成一些社交建议",
"callback": self.ai_socail_advice,
},
"AddFriends":{
"explain_to_llm": "如果用户给出了联系人,调用AddMultiFriends把联系人添加到数据库",
"callback": self.ai_add_multi_friends,
},
"RemoveFriend":{
"explain_to_llm": "如果用户希望移除某个联系人,调用RemoveFriend",
"callback": self.ai_remove_friend,
},
"ListFriends":{
"explain_to_llm": "如果用户列举联系人,调用ListFriends",
"callback": self.ai_list_friends,
}
}
try:
Explaination = '\n'.join([f'{k}: {v["explain_to_llm"]}' for k, v in self.tools_to_select.items()])
class UserSociaIntention(BaseModel):
intention_type: str = Field(
description=
f"The type of user intention. You must choose from {self.tools_to_select.keys()}.\n\n"
f"Explaination:\n{Explaination}",
default="SocialAdvice"
)
pydantic_cls_instance, err_msg = select_tool(
prompt=txt,
run_gpt_fn=run_gpt_fn,
pydantic_cls=UserSociaIntention
)
except Exception as e:
yield from update_ui_lastest_msg(
lastmsg=f"无法理解用户意图 {err_msg}",
chatbot=chatbot,
history=history,
delay=0
)
return
intention_type = pydantic_cls_instance.intention_type
intention_callback = self.tools_to_select[pydantic_cls_instance.intention_type]['callback']
yield from intention_callback(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, run_gpt_fn, intention_type)
def add_friend(self, friend):
# check whether the friend is already in the social network
for f in self.social_network.people:
if f.friend_name == friend.friend_name:
f.friend_description = friend.friend_description
f.friend_relationship = friend.friend_relationship
logger.info(f"Repeated friend, update info: {friend}")
return
logger.info(f"Add a new friend: {friend}")
self.social_network.people.append(friend)
return
@CatchException
def I人助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
# 1. we retrieve worker from global context
user_name = chatbot.get_user()
checkpoint_dir=get_log_folder(user_name, plugin_name='experimental_rag')
if user_name in SOCIAL_NETWOK_WORKER_REGISTER:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name]
else:
social_network_worker = SOCIAL_NETWOK_WORKER_REGISTER[user_name] = SocialNetworkWorker(
user_name,
llm_kwargs,
checkpoint_dir=checkpoint_dir,
auto_load_checkpoint=True
)
# 2. save
yield from social_network_worker.run(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
social_network_worker.save_to_checkpoint(checkpoint_dir)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面

查看文件

@@ -1,162 +0,0 @@
import os, copy, time
from toolbox import CatchException, report_exception, update_ui, zip_result, promote_file_to_downloadzone, update_ui_lastest_msg, get_conf, generate_file_link
from shared_utils.fastapi_server import validate_path_safety
from crazy_functions.crazy_utils import input_clipping
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
from crazy_functions.agent_fns.python_comment_agent import PythonCodeComment
from crazy_functions.diagram_fns.file_tree import FileNode
from crazy_functions.agent_fns.watchdog import WatchDog
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
from loguru import logger
def 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt):
summary_batch_isolation = True
inputs_array = []
inputs_show_user_array = []
history_array = []
sys_prompt_array = []
assert len(file_manifest) <= 512, "源文件太多超过512个, 请缩减输入文件的数量。或者,您也可以选择删除此行警告,并修改代码拆分file_manifest列表,从而实现分批次处理。"
# 建立文件树
file_tree_struct = FileNode("root", build_manifest=True)
for file_path in file_manifest:
file_tree_struct.add_file(file_path, file_path)
# <第一步,逐个文件分析,多线程>
lang = "" if not plugin_kwargs["use_chinese"] else " (you must use Chinese)"
for index, fp in enumerate(file_manifest):
# 读取文件
with open(fp, 'r', encoding='utf-8', errors='replace') as f:
file_content = f.read()
prefix = ""
i_say = prefix + f'Please conclude the following source code at {os.path.relpath(fp, project_folder)} with only one sentence{lang}, the code is:\n```{file_content}```'
i_say_show_user = prefix + f'[{index+1}/{len(file_manifest)}] 请用一句话对下面的程序文件做一个整体概述: {fp}'
# 装载请求内容
MAX_TOKEN_SINGLE_FILE = 2560
i_say, _ = input_clipping(inputs=i_say, history=[], max_token_limit=MAX_TOKEN_SINGLE_FILE)
inputs_array.append(i_say)
inputs_show_user_array.append(i_say_show_user)
history_array.append([])
sys_prompt_array.append(f"You are a software architecture analyst analyzing a source code project. Do not dig into details, tell me what the code is doing in general. Your answer must be short, simple and clear{lang}.")
# 文件读取完成,对每一个源代码文件,生成一个请求线程,发送到大模型进行分析
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array = inputs_array,
inputs_show_user_array = inputs_show_user_array,
history_array = history_array,
sys_prompt_array = sys_prompt_array,
llm_kwargs = llm_kwargs,
chatbot = chatbot,
show_user_at_complete = True
)
# <第二步,逐个文件分析,生成带注释文件>
tasks = ["" for _ in range(len(file_manifest))]
def bark_fn(tasks):
for i in range(len(tasks)): tasks[i] = "watchdog is dead"
wd = WatchDog(timeout=10, bark_fn=lambda: bark_fn(tasks), interval=3, msg="ThreadWatcher timeout")
wd.begin_watch()
from concurrent.futures import ThreadPoolExecutor
executor = ThreadPoolExecutor(max_workers=get_conf('DEFAULT_WORKER_NUM'))
def _task_multi_threading(i_say, gpt_say, fp, file_tree_struct, index):
language = 'Chinese' if plugin_kwargs["use_chinese"] else 'English'
def observe_window_update(x):
if tasks[index] == "watchdog is dead":
raise TimeoutError("ThreadWatcher: watchdog is dead")
tasks[index] = x
pcc = PythonCodeComment(llm_kwargs, plugin_kwargs, language=language, observe_window_update=observe_window_update)
pcc.read_file(path=fp, brief=gpt_say)
revised_path, revised_content = pcc.begin_comment_source_code(None, None)
file_tree_struct.manifest[fp].revised_path = revised_path
file_tree_struct.manifest[fp].revised_content = revised_content
# <将结果写回源文件>
with open(fp, 'w', encoding='utf-8') as f:
f.write(file_tree_struct.manifest[fp].revised_content)
# <生成对比html>
with open("crazy_functions/agent_fns/python_comment_compare.html", 'r', encoding='utf-8') as f:
html_template = f.read()
warp = lambda x: "```python\n\n" + x + "\n\n```"
from themes.theme import load_dynamic_theme
_, advanced_css, _, _ = load_dynamic_theme("Default")
html_template = html_template.replace("ADVANCED_CSS", advanced_css)
html_template = html_template.replace("REPLACE_CODE_FILE_LEFT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(pcc.original_content))))
html_template = html_template.replace("REPLACE_CODE_FILE_RIGHT", pcc.get_markdown_block_in_html(markdown_convertion_for_file(warp(revised_content))))
compare_html_path = fp + '.compare.html'
file_tree_struct.manifest[fp].compare_html = compare_html_path
with open(compare_html_path, 'w', encoding='utf-8') as f:
f.write(html_template)
tasks[index] = ""
chatbot.append([None, f"正在处理:"])
futures = []
index = 0
for i_say, gpt_say, fp in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], file_manifest):
future = executor.submit(_task_multi_threading, i_say, gpt_say, fp, file_tree_struct, index)
index += 1
futures.append(future)
# <第三步,等待任务完成>
cnt = 0
while True:
cnt += 1
wd.feed()
time.sleep(3)
worker_done = [h.done() for h in futures]
remain = len(worker_done) - sum(worker_done)
# <展示已经完成的部分>
preview_html_list = []
for done, fp in zip(worker_done, file_manifest):
if not done: continue
if hasattr(file_tree_struct.manifest[fp], 'compare_html'):
preview_html_list.append(file_tree_struct.manifest[fp].compare_html)
else:
logger.error(f"文件: {fp} 的注释结果未能成功")
file_links = generate_file_link(preview_html_list)
yield from update_ui_lastest_msg(
f"当前任务: <br/>{'<br/>'.join(tasks)}.<br/>" +
f"剩余源文件数量: {remain}.<br/>" +
f"已完成的文件: {sum(worker_done)}.<br/>" +
file_links +
"<br/>" +
''.join(['.']*(cnt % 10 + 1)
), chatbot=chatbot, history=history, delay=0)
yield from update_ui(chatbot=chatbot, history=[]) # 刷新界面
if all(worker_done):
executor.shutdown()
break
# <第四步,压缩结果>
zip_res = zip_result(project_folder)
promote_file_to_downloadzone(file=zip_res, chatbot=chatbot)
# <END>
chatbot.append((None, "所有源文件均已处理完毕。"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
@CatchException
def 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
history = [] # 清空历史,以免输入溢出
plugin_kwargs["use_chinese"] = plugin_kwargs.get("use_chinese", False)
import glob, os
if os.path.exists(txt):
project_folder = txt
validate_path_safety(project_folder, chatbot.get_user())
else:
if txt == "": txt = '空空如也的输入栏'
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到本地项目或无权访问: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
file_manifest = [f for f in glob.glob(f'{project_folder}/**/*.py', recursive=True)]
if len(file_manifest) == 0:
report_exception(chatbot, history, a = f"解析项目: {txt}", b = f"找不到任何python文件: {txt}")
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return
yield from 注释源代码(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt)

查看文件

@@ -1,36 +0,0 @@
from toolbox import get_conf, update_ui
from crazy_functions.plugin_template.plugin_class_template import GptAcademicPluginTemplate, ArgProperty
from crazy_functions.SourceCode_Comment import 注释Python项目
class SourceCodeComment_Wrap(GptAcademicPluginTemplate):
def __init__(self):
"""
请注意`execute`会执行在不同的线程中,因此您在定义和使用类变量时,应当慎之又慎!
"""
pass
def define_arg_selection_menu(self):
"""
定义插件的二级选项菜单
"""
gui_definition = {
"main_input":
ArgProperty(title="路径", description="程序路径(上传文件后自动填写)", default_value="", type="string").model_dump_json(), # 主输入,自动从输入框同步
"use_chinese":
ArgProperty(title="注释语言", options=["英文", "中文"], default_value="英文", description="", type="dropdown").model_dump_json(),
# "use_emoji":
# ArgProperty(title="在注释中使用emoji", options=["禁止", "允许"], default_value="禁止", description="无", type="dropdown").model_dump_json(),
}
return gui_definition
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
"""
执行插件
"""
if plugin_kwargs["use_chinese"] == "中文":
plugin_kwargs["use_chinese"] = True
else:
plugin_kwargs["use_chinese"] = False
yield from 注释Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)

查看文件

@@ -1,5 +1,4 @@
from crazy_functions.agent_fns.pipe import PluginMultiprocessManager, PipeCom
from loguru import logger
class EchoDemo(PluginMultiprocessManager):
def subprocess_worker(self, child_conn):
@@ -17,4 +16,4 @@ class EchoDemo(PluginMultiprocessManager):
elif msg.cmd == "terminate":
self.child_conn.send(PipeCom("done", ""))
break
logger.info('[debug] subprocess_worker terminated')
print('[debug] subprocess_worker terminated')

查看文件

@@ -1,6 +1,5 @@
from toolbox import get_log_folder, update_ui, gen_time_str, get_conf, promote_file_to_downloadzone
from crazy_functions.agent_fns.watchdog import WatchDog
from loguru import logger
import time, os
class PipeCom:
@@ -48,7 +47,7 @@ class PluginMultiprocessManager:
def terminate(self):
self.p.terminate()
self.alive = False
logger.info("[debug] instance terminated")
print("[debug] instance terminated")
def subprocess_worker(self, child_conn):
# ⭐⭐ run in subprocess
@@ -73,7 +72,7 @@ class PluginMultiprocessManager:
if file_type.lower() in ['png', 'jpg']:
image_path = os.path.abspath(fp)
self.chatbot.append([
'检测到新生图像:',
'检测到新生图像:',
f'本地文件预览: <br/><div align="center"><img src="file={image_path}"></div>'
])
yield from update_ui(chatbot=self.chatbot, history=self.history)
@@ -115,21 +114,21 @@ class PluginMultiprocessManager:
self.cnt = 1
self.parent_conn = self.launch_subprocess_with_pipe() # ⭐⭐⭐
repeated, cmd_to_autogen = self.send_command(txt)
if txt == 'exit':
if txt == 'exit':
self.chatbot.append([f"结束", "结束信号已明确,终止AutoGen程序。"])
yield from update_ui(chatbot=self.chatbot, history=self.history)
self.terminate()
return "terminate"
# patience = 10
while True:
time.sleep(0.5)
if not self.alive:
# the heartbeat watchdog might have it killed
self.terminate()
return "terminate"
if self.parent_conn.poll():
if self.parent_conn.poll():
self.feed_heartbeat_watchdog()
if "[GPT-Academic] 等待中" in self.chatbot[-1][-1]:
self.chatbot.pop(-1) # remove the last line
@@ -153,8 +152,8 @@ class PluginMultiprocessManager:
yield from update_ui(chatbot=self.chatbot, history=self.history)
if msg.cmd == "interact":
yield from self.overwatch_workdir_file_change()
self.chatbot.append([f"程序抵达用户反馈节点.", msg.content +
"\n\n等待您的进一步指令." +
self.chatbot.append([f"程序抵达用户反馈节点.", msg.content +
"\n\n等待您的进一步指令." +
"\n\n(1) 一般情况下您不需要说什么, 清空输入区, 然后直接点击“提交”以继续. " +
"\n\n(2) 如果您需要补充些什么, 输入要反馈的内容, 直接点击“提交”以继续. " +
"\n\n(3) 如果您想终止程序, 输入exit, 直接点击“提交”以终止AutoGen并解锁. "

查看文件

@@ -1,457 +0,0 @@
import datetime
import re
import os
from loguru import logger
from textwrap import dedent
from toolbox import CatchException, update_ui
from request_llms.bridge_all import predict_no_ui_long_connection
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
# TODO: 解决缩进问题
find_function_end_prompt = '''
Below is a page of code that you need to read. This page may not yet complete, you job is to split this page to sperate functions, class functions etc.
- Provide the line number where the first visible function ends.
- Provide the line number where the next visible function begins.
- If there are no other functions in this page, you should simply return the line number of the last line.
- Only focus on functions declared by `def` keyword. Ignore inline functions. Ignore function calls.
------------------ Example ------------------
INPUT:
```
L0000 |import sys
L0001 |import re
L0002 |
L0003 |def trimmed_format_exc():
L0004 | import os
L0005 | import traceback
L0006 | str = traceback.format_exc()
L0007 | current_path = os.getcwd()
L0008 | replace_path = "."
L0009 | return str.replace(current_path, replace_path)
L0010 |
L0011 |
L0012 |def trimmed_format_exc_markdown():
L0013 | ...
L0014 | ...
```
OUTPUT:
```
<first_function_end_at>L0009</first_function_end_at>
<next_function_begin_from>L0012</next_function_begin_from>
```
------------------ End of Example ------------------
------------------ the real INPUT you need to process NOW ------------------
```
{THE_TAGGED_CODE}
```
'''
revise_funtion_prompt = '''
You need to read the following code, and revise the source code ({FILE_BASENAME}) according to following instructions:
1. You should analyze the purpose of the functions (if there are any).
2. You need to add docstring for the provided functions (if there are any).
Be aware:
1. You must NOT modify the indent of code.
2. You are NOT authorized to change or translate non-comment code, and you are NOT authorized to add empty lines either, toggle qu.
3. Use {LANG} to add comments and docstrings. Do NOT translate Chinese that is already in the code.
4. Besides adding a docstring, use the ⭐ symbol to annotate the most core and important line of code within the function, explaining its role.
------------------ Example ------------------
INPUT:
```
L0000 |
L0001 |def zip_result(folder):
L0002 | t = gen_time_str()
L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
L0004 | return os.path.join(get_log_folder(), f"result.zip")
L0005 |
L0006 |
```
OUTPUT:
<instruction_1_purpose>
This function compresses a given folder, and return the path of the resulting `zip` file.
</instruction_1_purpose>
<instruction_2_revised_code>
```
def zip_result(folder):
"""
Compresses the specified folder into a zip file and stores it in the log folder.
Args:
folder (str): The path to the folder that needs to be compressed.
Returns:
str: The path to the created zip file in the log folder.
"""
t = gen_time_str()
zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ Execute the zipping of folder
return os.path.join(get_log_folder(), f"result.zip")
```
</instruction_2_revised_code>
------------------ End of Example ------------------
------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
```
{THE_CODE}
```
{INDENT_REMINDER}
{BRIEF_REMINDER}
{HINT_REMINDER}
'''
revise_funtion_prompt_chinese = '''
您需要阅读以下代码,并根据以下说明修订源代码({FILE_BASENAME}):
1. 如果源代码中包含函数的话, 你应该分析给定函数实现了什么功能
2. 如果源代码中包含函数的话, 你需要为函数添加docstring, docstring必须使用中文
请注意:
1. 你不得修改代码的缩进
2. 你无权更改或翻译代码中的非注释部分,也不允许添加空行
3. 使用 {LANG} 添加注释和文档字符串。不要翻译代码中已有的中文
4. 除了添加docstring之外, 使用⭐符号给该函数中最核心、最重要的一行代码添加注释,并说明其作用
------------------ 示例 ------------------
INPUT:
```
L0000 |
L0001 |def zip_result(folder):
L0002 | t = gen_time_str()
L0003 | zip_folder(folder, get_log_folder(), f"result.zip")
L0004 | return os.path.join(get_log_folder(), f"result.zip")
L0005 |
L0006 |
```
OUTPUT:
<instruction_1_purpose>
该函数用于压缩指定文件夹,并返回生成的`zip`文件的路径。
</instruction_1_purpose>
<instruction_2_revised_code>
```
def zip_result(folder):
"""
该函数将指定的文件夹压缩成ZIP文件, 并将其存储在日志文件夹中。
输入参数:
folder (str): 需要压缩的文件夹的路径。
返回值:
str: 日志文件夹中创建的ZIP文件的路径。
"""
t = gen_time_str()
zip_folder(folder, get_log_folder(), f"result.zip") # ⭐ 执行文件夹的压缩
return os.path.join(get_log_folder(), f"result.zip")
```
</instruction_2_revised_code>
------------------ End of Example ------------------
------------------ the real INPUT you need to process NOW ({FILE_BASENAME}) ------------------
```
{THE_CODE}
```
{INDENT_REMINDER}
{BRIEF_REMINDER}
{HINT_REMINDER}
'''
class PythonCodeComment():
def __init__(self, llm_kwargs, plugin_kwargs, language, observe_window_update) -> None:
self.original_content = ""
self.full_context = []
self.full_context_with_line_no = []
self.current_page_start = 0
self.page_limit = 100 # 100 lines of code each page
self.ignore_limit = 20
self.llm_kwargs = llm_kwargs
self.plugin_kwargs = plugin_kwargs
self.language = language
self.observe_window_update = observe_window_update
if self.language == "chinese":
self.core_prompt = revise_funtion_prompt_chinese
else:
self.core_prompt = revise_funtion_prompt
self.path = None
self.file_basename = None
self.file_brief = ""
def generate_tagged_code_from_full_context(self):
for i, code in enumerate(self.full_context):
number = i
padded_number = f"{number:04}"
result = f"L{padded_number}"
self.full_context_with_line_no.append(f"{result} | {code}")
return self.full_context_with_line_no
def read_file(self, path, brief):
with open(path, 'r', encoding='utf8') as f:
self.full_context = f.readlines()
self.original_content = ''.join(self.full_context)
self.file_basename = os.path.basename(path)
self.file_brief = brief
self.full_context_with_line_no = self.generate_tagged_code_from_full_context()
self.path = path
def find_next_function_begin(self, tagged_code:list, begin_and_end):
begin, end = begin_and_end
THE_TAGGED_CODE = ''.join(tagged_code)
self.llm_kwargs['temperature'] = 0
result = predict_no_ui_long_connection(
inputs=find_function_end_prompt.format(THE_TAGGED_CODE=THE_TAGGED_CODE),
llm_kwargs=self.llm_kwargs,
history=[],
sys_prompt="",
observe_window=[],
console_slience=True
)
def extract_number(text):
# 使用正则表达式匹配模式
match = re.search(r'<next_function_begin_from>L(\d+)</next_function_begin_from>', text)
if match:
# 提取匹配的数字部分并转换为整数
return int(match.group(1))
return None
line_no = extract_number(result)
if line_no is not None:
return line_no
else:
return end
def _get_next_window(self):
#
current_page_start = self.current_page_start
if self.current_page_start == len(self.full_context) + 1:
raise StopIteration
# 如果剩余的行数非常少,一鼓作气处理掉
if len(self.full_context) - self.current_page_start < self.ignore_limit:
future_page_start = len(self.full_context) + 1
self.current_page_start = future_page_start
return current_page_start, future_page_start
tagged_code = self.full_context_with_line_no[ self.current_page_start: self.current_page_start + self.page_limit]
line_no = self.find_next_function_begin(tagged_code, [self.current_page_start, self.current_page_start + self.page_limit])
if line_no > len(self.full_context) - 5:
line_no = len(self.full_context) + 1
future_page_start = line_no
self.current_page_start = future_page_start
# ! consider eof
return current_page_start, future_page_start
def dedent(self, text):
"""Remove any common leading whitespace from every line in `text`.
"""
# Look for the longest leading string of spaces and tabs common to
# all lines.
margin = None
_whitespace_only_re = re.compile('^[ \t]+$', re.MULTILINE)
_leading_whitespace_re = re.compile('(^[ \t]*)(?:[^ \t\n])', re.MULTILINE)
text = _whitespace_only_re.sub('', text)
indents = _leading_whitespace_re.findall(text)
for indent in indents:
if margin is None:
margin = indent
# Current line more deeply indented than previous winner:
# no change (previous winner is still on top).
elif indent.startswith(margin):
pass
# Current line consistent with and no deeper than previous winner:
# it's the new winner.
elif margin.startswith(indent):
margin = indent
# Find the largest common whitespace between current line and previous
# winner.
else:
for i, (x, y) in enumerate(zip(margin, indent)):
if x != y:
margin = margin[:i]
break
# sanity check (testing/debugging only)
if 0 and margin:
for line in text.split("\n"):
assert not line or line.startswith(margin), \
"line = %r, margin = %r" % (line, margin)
if margin:
text = re.sub(r'(?m)^' + margin, '', text)
return text, len(margin)
else:
return text, 0
def get_next_batch(self):
current_page_start, future_page_start = self._get_next_window()
return ''.join(self.full_context[current_page_start: future_page_start]), current_page_start, future_page_start
def tag_code(self, fn, hint):
code = fn
_, n_indent = self.dedent(code)
indent_reminder = "" if n_indent == 0 else "(Reminder: as you can see, this piece of code has indent made up with {n_indent} whitespace, please preseve them in the OUTPUT.)"
brief_reminder = "" if self.file_brief == "" else f"({self.file_basename} abstract: {self.file_brief})"
hint_reminder = "" if hint is None else f"(Reminder: do not ignore or modify code such as `{hint}`, provide complete code in the OUTPUT.)"
self.llm_kwargs['temperature'] = 0
result = predict_no_ui_long_connection(
inputs=self.core_prompt.format(
LANG=self.language,
FILE_BASENAME=self.file_basename,
THE_CODE=code,
INDENT_REMINDER=indent_reminder,
BRIEF_REMINDER=brief_reminder,
HINT_REMINDER=hint_reminder
),
llm_kwargs=self.llm_kwargs,
history=[],
sys_prompt="",
observe_window=[],
console_slience=True
)
def get_code_block(reply):
import re
pattern = r"```([\s\S]*?)```" # regex pattern to match code blocks
matches = re.findall(pattern, reply) # find all code blocks in text
if len(matches) == 1:
return matches[0].strip('python') # code block
return None
code_block = get_code_block(result)
if code_block is not None:
code_block = self.sync_and_patch(original=code, revised=code_block)
return code_block
else:
return code
def get_markdown_block_in_html(self, html):
from bs4 import BeautifulSoup
soup = BeautifulSoup(html, 'lxml')
found_list = soup.find_all("div", class_="markdown-body")
if found_list:
res = found_list[0]
return res.prettify()
else:
return None
def sync_and_patch(self, original, revised):
"""Ensure the number of pre-string empty lines in revised matches those in original."""
def count_leading_empty_lines(s, reverse=False):
"""Count the number of leading empty lines in a string."""
lines = s.split('\n')
if reverse: lines = list(reversed(lines))
count = 0
for line in lines:
if line.strip() == '':
count += 1
else:
break
return count
original_empty_lines = count_leading_empty_lines(original)
revised_empty_lines = count_leading_empty_lines(revised)
if original_empty_lines > revised_empty_lines:
additional_lines = '\n' * (original_empty_lines - revised_empty_lines)
revised = additional_lines + revised
elif original_empty_lines < revised_empty_lines:
lines = revised.split('\n')
revised = '\n'.join(lines[revised_empty_lines - original_empty_lines:])
original_empty_lines = count_leading_empty_lines(original, reverse=True)
revised_empty_lines = count_leading_empty_lines(revised, reverse=True)
if original_empty_lines > revised_empty_lines:
additional_lines = '\n' * (original_empty_lines - revised_empty_lines)
revised = revised + additional_lines
elif original_empty_lines < revised_empty_lines:
lines = revised.split('\n')
revised = '\n'.join(lines[:-(revised_empty_lines - original_empty_lines)])
return revised
def begin_comment_source_code(self, chatbot=None, history=None):
# from toolbox import update_ui_lastest_msg
assert self.path is not None
assert '.py' in self.path # must be python source code
# write_target = self.path + '.revised.py'
write_content = ""
# with open(self.path + '.revised.py', 'w+', encoding='utf8') as f:
while True:
try:
# yield from update_ui_lastest_msg(f"({self.file_basename}) 正在读取下一段代码片段:\n", chatbot=chatbot, history=history, delay=0)
next_batch, line_no_start, line_no_end = self.get_next_batch()
self.observe_window_update(f"正在处理{self.file_basename} - {line_no_start}/{len(self.full_context)}\n")
# yield from update_ui_lastest_msg(f"({self.file_basename}) 处理代码片段:\n\n{next_batch}", chatbot=chatbot, history=history, delay=0)
hint = None
MAX_ATTEMPT = 2
for attempt in range(MAX_ATTEMPT):
result = self.tag_code(next_batch, hint)
try:
successful, hint = self.verify_successful(next_batch, result)
except Exception as e:
logger.error('ignored exception:\n' + str(e))
break
if successful:
break
if attempt == MAX_ATTEMPT - 1:
# cannot deal with this, give up
result = next_batch
break
# f.write(result)
write_content += result
except StopIteration:
next_batch, line_no_start, line_no_end = [], -1, -1
return None, write_content
def verify_successful(self, original, revised):
""" Determine whether the revised code contains every line that already exists
"""
from crazy_functions.ast_fns.comment_remove import remove_python_comments
original = remove_python_comments(original)
original_lines = original.split('\n')
revised_lines = revised.split('\n')
for l in original_lines:
l = l.strip()
if '\'' in l or '\"' in l: continue # ast sometimes toggle " to '
found = False
for lt in revised_lines:
if l in lt:
found = True
break
if not found:
return False, l
return True, None

查看文件

@@ -1,45 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<style>ADVANCED_CSS</style>
<meta charset="UTF-8">
<title>源文件对比</title>
<style>
body {
font-family: Arial, sans-serif;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
.container {
display: flex;
width: 95%;
height: -webkit-fill-available;
}
.code-container {
flex: 1;
margin: 0px;
padding: 0px;
border: 1px solid #ccc;
background-color: #f9f9f9;
overflow: auto;
}
pre {
white-space: pre-wrap;
word-wrap: break-word;
}
</style>
</head>
<body>
<div class="container">
<div class="code-container">
REPLACE_CODE_FILE_LEFT
</div>
<div class="code-container">
REPLACE_CODE_FILE_RIGHT
</div>
</div>
</body>
</html>

查看文件

@@ -1,5 +1,4 @@
import threading, time
from loguru import logger
class WatchDog():
def __init__(self, timeout, bark_fn, interval=3, msg="") -> None:
@@ -9,12 +8,12 @@ class WatchDog():
self.interval = interval
self.msg = msg
self.kill_dog = False
def watch(self):
while True:
if self.kill_dog: break
if time.time() - self.last_feed > self.timeout:
if len(self.msg) > 0: logger.info(self.msg)
if len(self.msg) > 0: print(self.msg)
self.bark_fn()
break
time.sleep(self.interval)

查看文件

@@ -1,54 +0,0 @@
import token
import tokenize
import copy
import io
def remove_python_comments(input_source: str) -> str:
source_flag = copy.copy(input_source)
source = io.StringIO(input_source)
ls = input_source.split('\n')
prev_toktype = token.INDENT
readline = source.readline
def get_char_index(lineno, col):
# find the index of the char in the source code
if lineno == 1:
return len('\n'.join(ls[:(lineno-1)])) + col
else:
return len('\n'.join(ls[:(lineno-1)])) + col + 1
def replace_char_between(start_lineno, start_col, end_lineno, end_col, source, replace_char, ls):
# replace char between start_lineno, start_col and end_lineno, end_col with replace_char, but keep '\n' and ' '
b = get_char_index(start_lineno, start_col)
e = get_char_index(end_lineno, end_col)
for i in range(b, e):
if source[i] == '\n':
source = source[:i] + '\n' + source[i+1:]
elif source[i] == ' ':
source = source[:i] + ' ' + source[i+1:]
else:
source = source[:i] + replace_char + source[i+1:]
return source
tokgen = tokenize.generate_tokens(readline)
for toktype, ttext, (slineno, scol), (elineno, ecol), ltext in tokgen:
if toktype == token.STRING and (prev_toktype == token.INDENT):
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
elif toktype == token.STRING and (prev_toktype == token.NEWLINE):
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
elif toktype == tokenize.COMMENT:
source_flag = replace_char_between(slineno, scol, elineno, ecol, source_flag, ' ', ls)
prev_toktype = toktype
return source_flag
# 示例使用
if __name__ == "__main__":
with open("source.py", "r", encoding="utf-8") as f:
source_code = f.read()
cleaned_code = remove_python_comments(source_code)
with open("cleaned_source.py", "w", encoding="utf-8") as f:
f.write(cleaned_code)

查看文件

@@ -1,5 +1,5 @@
from toolbox import CatchException, update_ui, promote_file_to_downloadzone
from crazy_functions.crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
import datetime, json
def fetch_items(list_of_items, batch_size):
@@ -46,7 +46,7 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
args = plugin_kwargs.get("advanced_arg", None)
if args is None:
if args is None:
chatbot.append(("没给定指令", "退出"))
yield from update_ui(chatbot=chatbot, history=history); return
else:
@@ -69,7 +69,7 @@ def 微调数据集生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, syst
sys_prompt_array=[arguments.system_prompt for _ in (batch)],
max_workers=10 # OpenAI所允许的最大并行过载
)
with open(txt+'.generated.json', 'a+', encoding='utf8') as f:
for b, r in zip(batch, res[1::2]):
f.write(json.dumps({"content":b, "summary":r}, ensure_ascii=False)+'\n')
@@ -95,12 +95,12 @@ def 启动微调(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt
chatbot.append(("这是什么功能?", "[Local Message] 微调数据集生成"))
if ("advanced_arg" in plugin_kwargs) and (plugin_kwargs["advanced_arg"] == ""): plugin_kwargs.pop("advanced_arg")
args = plugin_kwargs.get("advanced_arg", None)
if args is None:
if args is None:
chatbot.append(("没给定指令", "退出"))
yield from update_ui(chatbot=chatbot, history=history); return
else:
arguments = string_to_options(arguments=args)
pre_seq_len = arguments.pre_seq_len # 128

查看文件

@@ -0,0 +1,231 @@
"""
这是什么?
这个文件用于函数插件的单元测试
运行方法 python crazy_functions/crazy_functions_test.py
"""
# ==============================================================================================================================
def validate_path():
import os, sys
dir_name = os.path.dirname(__file__)
root_dir_assume = os.path.abspath(os.path.dirname(__file__) + '/..')
os.chdir(root_dir_assume)
sys.path.append(root_dir_assume)
validate_path() # validate path so you can run from base directory
# ==============================================================================================================================
from colorful import *
from toolbox import get_conf, ChatBotWithCookies
import contextlib
import os
import sys
from functools import wraps
proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION, CHATBOT_HEIGHT, LAYOUT, API_KEY = \
get_conf('proxies', 'WEB_PORT', 'LLM_MODEL', 'CONCURRENT_COUNT', 'AUTHENTICATION', 'CHATBOT_HEIGHT', 'LAYOUT', 'API_KEY')
llm_kwargs = {
'api_key': API_KEY,
'llm_model': LLM_MODEL,
'top_p':1.0,
'max_length': None,
'temperature':1.0,
}
plugin_kwargs = { }
chatbot = ChatBotWithCookies(llm_kwargs)
history = []
system_prompt = "Serve me as a writing and programming assistant."
web_port = 1024
# ==============================================================================================================================
def silence_stdout(func):
@wraps(func)
def wrapper(*args, **kwargs):
_original_stdout = sys.stdout
sys.stdout = open(os.devnull, 'w')
for q in func(*args, **kwargs):
sys.stdout = _original_stdout
yield q
sys.stdout = open(os.devnull, 'w')
sys.stdout.close()
sys.stdout = _original_stdout
return wrapper
class CLI_Printer():
def __init__(self) -> None:
self.pre_buf = ""
def print(self, buf):
bufp = ""
for index, chat in enumerate(buf):
a, b = chat
bufp += sprint亮靛('[Me]:' + a) + '\n'
bufp += '[GPT]:' + b
if index < len(buf)-1:
bufp += '\n'
if self.pre_buf!="" and bufp.startswith(self.pre_buf):
print(bufp[len(self.pre_buf):], end='')
else:
print('\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n'+bufp, end='')
self.pre_buf = bufp
return
cli_printer = CLI_Printer()
# ==============================================================================================================================
def test_解析一个Python项目():
from crazy_functions.解析项目源代码 import 解析一个Python项目
txt = "crazy_functions/test_project/python/dqn"
for cookies, cb, hist, msg in 解析一个Python项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_解析一个Cpp项目():
from crazy_functions.解析项目源代码 import 解析一个C项目
txt = "crazy_functions/test_project/cpp/cppipc"
for cookies, cb, hist, msg in 解析一个C项目(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_Latex英文润色():
from crazy_functions.Latex全文润色 import Latex英文润色
txt = "crazy_functions/test_project/latex/attention"
for cookies, cb, hist, msg in Latex英文润色(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_Markdown中译英():
from crazy_functions.批量Markdown翻译 import Markdown中译英
txt = "README.md"
for cookies, cb, hist, msg in Markdown中译英(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_批量翻译PDF文档():
from crazy_functions.批量翻译PDF文档_多线程 import 批量翻译PDF文档
txt = "crazy_functions/test_project/pdf_and_word"
for cookies, cb, hist, msg in 批量翻译PDF文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_谷歌检索小助手():
from crazy_functions.谷歌检索小助手 import 谷歌检索小助手
txt = "https://scholar.google.com/scholar?hl=en&as_sdt=0%2C5&q=auto+reinforcement+learning&btnG="
for cookies, cb, hist, msg in 谷歌检索小助手(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_总结word文档():
from crazy_functions.总结word文档 import 总结word文档
txt = "crazy_functions/test_project/pdf_and_word"
for cookies, cb, hist, msg in 总结word文档(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_下载arxiv论文并翻译摘要():
from crazy_functions.下载arxiv论文翻译摘要 import 下载arxiv论文并翻译摘要
txt = "1812.10695"
for cookies, cb, hist, msg in 下载arxiv论文并翻译摘要(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_联网回答问题():
from crazy_functions.联网的ChatGPT import 连接网络回答问题
# txt = "谁是应急食品?"
# >> '根据以上搜索结果可以得知,应急食品是“原神”游戏中的角色派蒙的外号。'
# txt = "道路千万条,安全第一条。后面两句是?"
# >> '行车不规范,亲人两行泪。'
# txt = "You should have gone for the head. What does that mean?"
# >> The phrase "You should have gone for the head" is a quote from the Marvel movies, Avengers: Infinity War and Avengers: Endgame. It was spoken by the character Thanos in Infinity War and by Thor in Endgame.
txt = "AutoGPT是什么?"
for cookies, cb, hist, msg in 连接网络回答问题(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print("当前问答:", cb[-1][-1].replace("\n"," "))
for i, it in enumerate(cb): print亮蓝(it[0]); print亮黄(it[1])
def test_解析ipynb文件():
from crazy_functions.解析JupyterNotebook import 解析ipynb文件
txt = "crazy_functions/test_samples"
for cookies, cb, hist, msg in 解析ipynb文件(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_数学动画生成manim():
from crazy_functions.数学动画生成manim import 动画生成
txt = "A ball split into 2, and then split into 4, and finally split into 8."
for cookies, cb, hist, msg in 动画生成(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_Markdown多语言():
from crazy_functions.批量Markdown翻译 import Markdown翻译指定语言
txt = "README.md"
history = []
for lang in ["English", "French", "Japanese", "Korean", "Russian", "Italian", "German", "Portuguese", "Arabic"]:
plugin_kwargs = {"advanced_arg": lang}
for cookies, cb, hist, msg in Markdown翻译指定语言(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
print(cb)
def test_Langchain知识库():
from crazy_functions.Langchain知识库 import 知识库问答
txt = "./"
chatbot = ChatBotWithCookies(llm_kwargs)
for cookies, cb, hist, msg in silence_stdout(知识库问答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
cli_printer.print(cb) # print(cb)
chatbot = ChatBotWithCookies(cookies)
from crazy_functions.Langchain知识库 import 读取知识库作答
txt = "What is the installation method?"
for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
cli_printer.print(cb) # print(cb)
def test_Langchain知识库读取():
from crazy_functions.Langchain知识库 import 读取知识库作答
txt = "远程云服务器部署?"
for cookies, cb, hist, msg in silence_stdout(读取知识库作答)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
cli_printer.print(cb) # print(cb)
def test_Latex():
from crazy_functions.Latex输出PDF结果 import Latex英文纠错加PDF对比, Latex翻译中文并重新编译PDF
# txt = r"https://arxiv.org/abs/1706.03762"
# txt = r"https://arxiv.org/abs/1902.03185"
# txt = r"https://arxiv.org/abs/2305.18290"
# txt = r"https://arxiv.org/abs/2305.17608"
# txt = r"https://arxiv.org/abs/2211.16068" # ACE
# txt = r"C:\Users\x\arxiv_cache\2211.16068\workfolder" # ACE
# txt = r"https://arxiv.org/abs/2002.09253"
# txt = r"https://arxiv.org/abs/2306.07831"
# txt = r"https://arxiv.org/abs/2212.10156"
# txt = r"https://arxiv.org/abs/2211.11559"
# txt = r"https://arxiv.org/abs/2303.08774"
txt = r"https://arxiv.org/abs/2303.12712"
# txt = r"C:\Users\fuqingxu\arxiv_cache\2303.12712\workfolder"
for cookies, cb, hist, msg in (Latex翻译中文并重新编译PDF)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
cli_printer.print(cb) # print(cb)
# txt = "2302.02948.tar"
# print(txt)
# main_tex, work_folder = Latex预处理(txt)
# print('main tex:', main_tex)
# res = 编译Latex(main_tex, work_folder)
# # for cookies, cb, hist, msg in silence_stdout(编译Latex)(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, web_port):
# cli_printer.print(cb) # print(cb)
# test_解析一个Python项目()
# test_Latex英文润色()
# test_Markdown中译英()
# test_批量翻译PDF文档()
# test_谷歌检索小助手()
# test_总结word文档()
# test_下载arxiv论文并翻译摘要()
# test_解析一个Cpp项目()
# test_联网回答问题()
# test_解析ipynb文件()
# test_数学动画生成manim()
# test_Langchain知识库()
# test_Langchain知识库读取()
if __name__ == "__main__":
test_Latex()
input("程序完成,回车退出。")
print("退出。")

查看文件

@@ -1,39 +1,25 @@
import os
import threading
from loguru import logger
from shared_utils.char_visual_effect import scolling_visual_effect
from toolbox import update_ui, get_conf, trimmed_format_exc, get_max_token, Singleton
import threading
import os
import logging
def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
"""
当输入文本 + 历史文本超出最大限制时,采取措施丢弃一部分文本。
输入:
- inputs 本次请求
- history 历史上下文
- max_token_limit 最大token限制
输出:
- inputs 本次请求经过clip
- history 历史上下文经过clip
"""
def input_clipping(inputs, history, max_token_limit):
import numpy as np
from request_llms.bridge_all import model_info
enc = model_info["gpt-3.5-turbo"]['tokenizer']
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
mode = 'input-and-history'
# 当 输入部分的token占比 小于 全文的一半时,只裁剪历史
input_token_num = get_token_num(inputs)
original_input_len = len(inputs)
if input_token_num < max_token_limit//2:
mode = 'only-history'
max_token_limit = max_token_limit - input_token_num
everything = [inputs] if mode == 'input-and-history' else ['']
everything.extend(history)
full_token_num = n_token = get_token_num('\n'.join(everything))
n_token = get_token_num('\n'.join(everything))
everything_token = [get_token_num(e) for e in everything]
everything_token_num = sum(everything_token)
delta = max(everything_token) // 16 # 截断时的颗粒度
while n_token > max_token_limit:
@@ -46,24 +32,10 @@ def input_clipping(inputs, history, max_token_limit, return_clip_flags=False):
if mode == 'input-and-history':
inputs = everything[0]
full_token_num = everything_token_num
else:
full_token_num = everything_token_num + input_token_num
pass
history = everything[1:]
flags = {
"mode": mode,
"original_input_token_num": input_token_num,
"original_full_token_num": full_token_num,
"original_input_len": original_input_len,
"clipped_input_len": len(inputs),
}
if not return_clip_flags:
return inputs, history
else:
return inputs, history, flags
return inputs, history
def request_gpt_model_in_new_thread_with_ui_alive(
inputs, inputs_show_user, llm_kwargs,
@@ -133,7 +105,7 @@ def request_gpt_model_in_new_thread_with_ui_alive(
except:
# 【第三种情况】:其他错误:重试几次
tb_str = '```\n' + trimmed_format_exc() + '```'
logger.error(tb_str)
print(tb_str)
mutable[0] += f"[Local Message] 警告,在执行过程中遭遇问题, Traceback\n\n{tb_str}\n\n"
if retry_op > 0:
retry_op -= 1
@@ -163,30 +135,18 @@ def request_gpt_model_in_new_thread_with_ui_alive(
yield from update_ui(chatbot=chatbot, history=[]) # 如果最后成功了,则删除报错信息
return final_result
def can_multi_process(llm) -> bool:
from request_llms.bridge_all import model_info
def default_condition(llm) -> bool:
# legacy condition
if llm.startswith('gpt-'): return True
if llm.startswith('api2d-'): return True
if llm.startswith('azure-'): return True
if llm.startswith('spark'): return True
if llm.startswith('zhipuai') or llm.startswith('glm-'): return True
return False
if llm in model_info:
if 'can_multi_thread' in model_info[llm]:
return model_info[llm]['can_multi_thread']
else:
return default_condition(llm)
else:
return default_condition(llm)
def can_multi_process(llm):
if llm.startswith('gpt-'): return True
if llm.startswith('api2d-'): return True
if llm.startswith('azure-'): return True
if llm.startswith('spark'): return True
if llm.startswith('zhipuai') or llm.startswith('glm-'): return True
return False
def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array, inputs_show_user_array, llm_kwargs,
chatbot, history_array, sys_prompt_array,
refresh_interval=0.2, max_workers=-1, scroller_max_len=75,
refresh_interval=0.2, max_workers=-1, scroller_max_len=30,
handle_token_exceed=True, show_user_at_complete=False,
retry_times_at_unknown_error=2,
):
@@ -283,7 +243,7 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
# 【第三种情况】:其他错误
if detect_timeout(): raise RuntimeError("检测到程序终止。")
tb_str = '```\n' + trimmed_format_exc() + '```'
logger.error(tb_str)
print(tb_str)
gpt_say += f"[Local Message] 警告,线程{index}在执行过程中遭遇问题, Traceback\n\n{tb_str}\n\n"
if len(mutable[index][0]) > 0: gpt_say += "此线程失败前收到的回答:\n\n" + mutable[index][0]
if retry_op > 0:
@@ -311,8 +271,6 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
futures = [executor.submit(_req_gpt, index, inputs, history, sys_prompt) for index, inputs, history, sys_prompt in zip(
range(len(inputs_array)), inputs_array, history_array, sys_prompt_array)]
cnt = 0
while True:
# yield一次以刷新前端页面
time.sleep(refresh_interval)
@@ -325,7 +283,8 @@ def request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
mutable[thread_index][1] = time.time()
# 在前端打印些好玩的东西
for thread_index, _ in enumerate(worker_done):
print_something_really_funny = f"[ ...`{scolling_visual_effect(mutable[thread_index][0], scroller_max_len)}`... ]"
print_something_really_funny = "[ ...`"+mutable[thread_index][0][-scroller_max_len:].\
replace('\n', '').replace('`', '.').replace(' ', '.').replace('<br/>', '.....').replace('$', '.')+"`... ]"
observe_win.append(print_something_really_funny)
# 在前端打印些好玩的东西
stat_str = ''.join([f'`{mutable[thread_index][2]}`: {obs}\n\n'
@@ -378,7 +337,7 @@ def read_and_clean_pdf_text(fp):
import fitz, copy
import re
import numpy as np
# from shared_utils.colorful import print亮黄, print亮绿
from colorful import print亮黄, print亮绿
fc = 0 # Index 0 文本
fs = 1 # Index 1 字体
fb = 2 # Index 2 框框
@@ -595,15 +554,15 @@ class nougat_interface():
def nougat_with_timeout(self, command, cwd, timeout=3600):
import subprocess
from toolbox import ProxyNetworkActivate
logger.info(f'正在执行命令 {command}')
logging.info(f'正在执行命令 {command}')
with ProxyNetworkActivate("Nougat_Download"):
process = subprocess.Popen(command, shell=False, cwd=cwd, env=os.environ)
process = subprocess.Popen(command, shell=True, cwd=cwd, env=os.environ)
try:
stdout, stderr = process.communicate(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
stdout, stderr = process.communicate()
logger.error("Process timed out!")
print("Process timed out!")
return False
return True
@@ -621,8 +580,7 @@ class nougat_interface():
yield from update_ui_lastest_msg("正在解析论文, 请稍候。进度正在加载NOUGAT... 提示首次运行需要花费较长时间下载NOUGAT参数",
chatbot=chatbot, history=history, delay=0)
command = ['nougat', '--out', os.path.abspath(dst), os.path.abspath(fp)]
self.nougat_with_timeout(command, cwd=os.getcwd(), timeout=3600)
self.nougat_with_timeout(f'nougat --out "{os.path.abspath(dst)}" "{os.path.abspath(fp)}"', os.getcwd(), timeout=3600)
res = glob.glob(os.path.join(dst,'*.mmd'))
if len(res) == 0:
self.threadLock.release()

查看文件

@@ -1,9 +1,8 @@
import os
from textwrap import indent
from loguru import logger
class FileNode:
def __init__(self, name, build_manifest=False):
def __init__(self, name):
self.name = name
self.children = []
self.is_leaf = False
@@ -11,9 +10,7 @@ class FileNode:
self.parenting_ship = []
self.comment = ""
self.comment_maxlen_show = 50
self.build_manifest = build_manifest
self.manifest = {}
@staticmethod
def add_linebreaks_at_spaces(string, interval=10):
return '\n'.join(string[i:i+interval] for i in range(0, len(string), interval))
@@ -32,7 +29,6 @@ class FileNode:
level = 1
if directory_names == "":
new_node = FileNode(file_name)
self.manifest[file_path] = new_node
current_node.children.append(new_node)
new_node.is_leaf = True
new_node.comment = self.sanitize_comment(file_comment)
@@ -54,14 +50,13 @@ class FileNode:
new_node.level = level - 1
current_node = new_node
term = FileNode(file_name)
self.manifest[file_path] = term
term.level = level
term.comment = self.sanitize_comment(file_comment)
term.is_leaf = True
current_node.children.append(term)
def print_files_recursively(self, level=0, code="R0"):
logger.info(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
print(' '*level + self.name + ' ' + str(self.is_leaf) + ' ' + str(self.level))
for j, child in enumerate(self.children):
child.print_files_recursively(level=level+1, code=code+str(j))
self.parenting_ship.extend(child.parenting_ship)
@@ -124,4 +119,4 @@ if __name__ == "__main__":
"用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器用于加载和分割文件中的文本的通用文件加载器",
"包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类包含了用于构建和管理向量数据库的函数和类",
]
logger.info(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))
print(build_file_tree_mermaid_diagram(file_manifest, file_comments, "项目文件树"))

查看文件

@@ -1,450 +0,0 @@
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from docx import Document
from docx.enum.style import WD_STYLE_TYPE
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT, WD_LINE_SPACING
from docx.oxml.ns import qn
from docx.shared import Inches, Cm
from docx.shared import Pt, RGBColor, Inches
from typing import Dict, List, Tuple
class DocumentFormatter(ABC):
"""文档格式化基类,定义文档格式化的基本接口"""
def __init__(self, final_summary: str, file_summaries_map: Dict, failed_files: List[Tuple]):
self.final_summary = final_summary
self.file_summaries_map = file_summaries_map
self.failed_files = failed_files
@abstractmethod
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
pass
@abstractmethod
def format_file_summaries(self) -> str:
"""格式化文件总结内容"""
pass
@abstractmethod
def create_document(self) -> str:
"""创建完整文档"""
pass
class WordFormatter(DocumentFormatter):
"""Word格式文档生成器 - 符合中国政府公文格式规范(GB/T 9704-2012),并进行了优化"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.doc = Document()
self._setup_document()
self._create_styles()
# 初始化三级标题编号系统
self.numbers = {
1: 0, # 一级标题编号
2: 0, # 二级标题编号
3: 0 # 三级标题编号
}
def _setup_document(self):
"""设置文档基本格式,包括页面设置和页眉"""
sections = self.doc.sections
for section in sections:
# 设置页面大小为A4
section.page_width = Cm(21)
section.page_height = Cm(29.7)
# 设置页边距
section.top_margin = Cm(3.7) # 上边距37mm
section.bottom_margin = Cm(3.5) # 下边距35mm
section.left_margin = Cm(2.8) # 左边距28mm
section.right_margin = Cm(2.6) # 右边距26mm
# 设置页眉页脚距离
section.header_distance = Cm(2.0)
section.footer_distance = Cm(2.0)
# 添加页眉
header = section.header
header_para = header.paragraphs[0]
header_para.alignment = WD_PARAGRAPH_ALIGNMENT.RIGHT
header_run = header_para.add_run("该文档由GPT-academic生成")
header_run.font.name = '仿宋'
header_run._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
header_run.font.size = Pt(9)
def _create_styles(self):
"""创建文档样式"""
# 创建正文样式
style = self.doc.styles.add_style('Normal_Custom', WD_STYLE_TYPE.PARAGRAPH)
style.font.name = '仿宋'
style._element.rPr.rFonts.set(qn('w:eastAsia'), '仿宋')
style.font.size = Pt(14)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
style.paragraph_format.space_after = Pt(0)
style.paragraph_format.first_line_indent = Pt(28)
# 创建各级标题样式
self._create_heading_style('Title_Custom', '方正小标宋简体', 32, WD_PARAGRAPH_ALIGNMENT.CENTER)
self._create_heading_style('Heading1_Custom', '黑体', 22, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading2_Custom', '黑体', 18, WD_PARAGRAPH_ALIGNMENT.LEFT)
self._create_heading_style('Heading3_Custom', '黑体', 16, WD_PARAGRAPH_ALIGNMENT.LEFT)
def _create_heading_style(self, style_name: str, font_name: str, font_size: int, alignment):
"""创建标题样式"""
style = self.doc.styles.add_style(style_name, WD_STYLE_TYPE.PARAGRAPH)
style.font.name = font_name
style._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
style.font.size = Pt(font_size)
style.font.bold = True
style.paragraph_format.alignment = alignment
style.paragraph_format.space_before = Pt(12)
style.paragraph_format.space_after = Pt(12)
style.paragraph_format.line_spacing_rule = WD_LINE_SPACING.ONE_POINT_FIVE
return style
def _get_heading_number(self, level: int) -> str:
"""
生成标题编号
Args:
level: 标题级别 (0-3)
Returns:
str: 格式化的标题编号
"""
if level == 0: # 主标题不需要编号
return ""
self.numbers[level] += 1 # 增加当前级别的编号
# 重置下级标题编号
for i in range(level + 1, 4):
self.numbers[i] = 0
# 根据级别返回不同格式的编号
if level == 1:
return f"{self.numbers[1]}. "
elif level == 2:
return f"{self.numbers[1]}.{self.numbers[2]} "
elif level == 3:
return f"{self.numbers[1]}.{self.numbers[2]}.{self.numbers[3]} "
return ""
def _add_heading(self, text: str, level: int):
"""
添加带编号的标题
Args:
text: 标题文本
level: 标题级别 (0-3)
"""
style_map = {
0: 'Title_Custom',
1: 'Heading1_Custom',
2: 'Heading2_Custom',
3: 'Heading3_Custom'
}
number = self._get_heading_number(level)
paragraph = self.doc.add_paragraph(style=style_map[level])
if number:
number_run = paragraph.add_run(number)
font_size = 22 if level == 1 else (18 if level == 2 else 16)
self._get_run_style(number_run, '黑体', font_size, True)
text_run = paragraph.add_run(text)
font_size = 32 if level == 0 else (22 if level == 1 else (18 if level == 2 else 16))
self._get_run_style(text_run, '黑体', font_size, True)
# 主标题添加日期
if level == 0:
date_paragraph = self.doc.add_paragraph()
date_paragraph.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER
date_run = date_paragraph.add_run(datetime.now().strftime('%Y年%m月%d'))
self._get_run_style(date_run, '仿宋', 16, False)
return paragraph
def _get_run_style(self, run, font_name: str, font_size: int, bold: bool = False):
"""设置文本运行对象的样式"""
run.font.name = font_name
run._element.rPr.rFonts.set(qn('w:eastAsia'), font_name)
run.font.size = Pt(font_size)
run.font.bold = bold
def format_failed_files(self) -> str:
"""格式化失败文件列表"""
result = []
if not self.failed_files:
return "\n".join(result)
result.append("处理失败文件:")
for fp, reason in self.failed_files:
result.append(f"{os.path.basename(fp)}: {reason}")
self._add_heading("处理失败文件", 1)
for fp, reason in self.failed_files:
self._add_content(f"{os.path.basename(fp)}: {reason}", indent=False)
self.doc.add_paragraph()
return "\n".join(result)
def _add_content(self, text: str, indent: bool = True):
"""添加正文内容"""
paragraph = self.doc.add_paragraph(text, style='Normal_Custom')
if not indent:
paragraph.paragraph_format.first_line_indent = Pt(0)
return paragraph
def format_file_summaries(self) -> str:
"""
格式化文件总结内容,确保正确的标题层级
返回:
str: 格式化后的文件总结字符串
标题层级规则:
1. 一级标题为"各文件详细总结"
2. 如果文件有目录路径:
- 目录路径作为二级标题 (2.1, 2.2 等)
- 该目录下所有文件作为三级标题 (2.1.1, 2.1.2 等)
3. 如果文件没有目录路径:
- 文件直接作为二级标题 (2.1, 2.2 等)
"""
result = []
# 首先对文件路径进行分组整理
file_groups = {}
for path in sorted(self.file_summaries_map.keys()):
dir_path = os.path.dirname(path)
if dir_path not in file_groups:
file_groups[dir_path] = []
file_groups[dir_path].append(path)
# 处理没有目录的文件
root_files = file_groups.get("", [])
if root_files:
for path in sorted(root_files):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 无目录的文件作为二级标题
self._add_heading(f"📄 {file_name}", 2)
self._add_content(self.file_summaries_map[path])
self.doc.add_paragraph()
# 处理有目录的文件
for dir_path in sorted(file_groups.keys()):
if dir_path == "": # 跳过已处理的根目录文件
continue
# 添加目录作为二级标题
result.append(f"\n📁 {dir_path}")
self._add_heading(f"📁 {dir_path}", 2)
# 该目录下的所有文件作为三级标题
for path in sorted(file_groups[dir_path]):
file_name = os.path.basename(path)
result.append(f"\n📄 {file_name}")
result.append(self.file_summaries_map[path])
# 添加文件名作为三级标题
self._add_heading(f"📄 {file_name}", 3)
self._add_content(self.file_summaries_map[path])
self.doc.add_paragraph()
return "\n".join(result)
def create_document(self):
"""创建完整Word文档并返回文档对象"""
# 重置所有编号
for level in self.numbers:
self.numbers[level] = 0
# 添加主标题
self._add_heading("文档总结报告", 0)
self.doc.add_paragraph()
# 添加总体摘要
self._add_heading("总体摘要", 1)
self._add_content(self.final_summary)
self.doc.add_paragraph()
# 添加失败文件列表(如果有)
if self.failed_files:
self.format_failed_files()
# 添加文件详细总结
self._add_heading("各文件详细总结", 1)
self.format_file_summaries()
return self.doc
class MarkdownFormatter(DocumentFormatter):
"""Markdown格式文档生成器"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
formatted_text = ["\n## ⚠️ 处理失败的文件"]
for fp, reason in self.failed_files:
formatted_text.append(f"- {os.path.basename(fp)}: {reason}")
formatted_text.append("\n---")
return "\n".join(formatted_text)
def format_file_summaries(self) -> str:
formatted_text = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_text.append(f"\n## 📁 {dir_path}")
current_dir = dir_path
file_name = os.path.basename(path)
formatted_text.append(f"\n### 📄 {file_name}")
formatted_text.append(self.file_summaries_map[path])
formatted_text.append("\n---")
return "\n".join(formatted_text)
def create_document(self) -> str:
document = [
"# 📑 文档总结报告",
"\n## 总体摘要",
self.final_summary
]
if self.failed_files:
document.append(self.format_failed_files())
document.extend([
"\n# 📚 各文件详细总结",
self.format_file_summaries()
])
return "\n".join(document)
class HtmlFormatter(DocumentFormatter):
"""HTML格式文档生成器"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.css_styles = """
body {
font-family: "Microsoft YaHei", Arial, sans-serif;
line-height: 1.6;
max-width: 1000px;
margin: 0 auto;
padding: 20px;
color: #333;
}
h1 {
color: #2c3e50;
border-bottom: 2px solid #eee;
padding-bottom: 10px;
font-size: 24px;
text-align: center;
}
h2 {
color: #34495e;
margin-top: 30px;
font-size: 20px;
border-left: 4px solid #3498db;
padding-left: 10px;
}
h3 {
color: #2c3e50;
font-size: 18px;
margin-top: 20px;
}
.summary {
background-color: #f8f9fa;
padding: 20px;
border-radius: 5px;
margin: 20px 0;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
.details {
margin-top: 40px;
}
.failed-files {
background-color: #fff3f3;
padding: 15px;
border-left: 4px solid #e74c3c;
margin: 20px 0;
}
.file-summary {
background-color: #fff;
padding: 15px;
margin: 15px 0;
border-radius: 4px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
"""
def format_failed_files(self) -> str:
if not self.failed_files:
return ""
failed_files_html = ['<div class="failed-files">']
failed_files_html.append("<h2>⚠️ 处理失败的文件</h2>")
failed_files_html.append("<ul>")
for fp, reason in self.failed_files:
failed_files_html.append(f"<li><strong>{os.path.basename(fp)}:</strong> {reason}</li>")
failed_files_html.append("</ul></div>")
return "\n".join(failed_files_html)
def format_file_summaries(self) -> str:
formatted_html = []
sorted_paths = sorted(self.file_summaries_map.keys())
current_dir = ""
for path in sorted_paths:
dir_path = os.path.dirname(path)
if dir_path != current_dir:
if dir_path:
formatted_html.append(f'<h2>📁 {dir_path}</h2>')
current_dir = dir_path
file_name = os.path.basename(path)
formatted_html.append('<div class="file-summary">')
formatted_html.append(f'<h3>📄 {file_name}</h3>')
formatted_html.append(f'<p>{self.file_summaries_map[path]}</p>')
formatted_html.append('</div>')
return "\n".join(formatted_html)
def create_document(self) -> str:
return f"""
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'>
<title>文档总结报告</title>
<style>{self.css_styles}</style>
</head>
<body>
<h1>📑 文档总结报告</h1>
<h2>总体摘要</h2>
<div class="summary">{self.final_summary}</div>
{self.format_failed_files()}
<div class="details">
<h2>📚 各文件详细总结</h2>
{self.format_file_summaries()}
</div>
</body>
</html>
"""

查看文件

@@ -8,7 +8,7 @@ import random
class MiniGame_ASCII_Art(GptAcademicGameBaseState):
def step(self, prompt, chatbot, history):
if self.step_cnt == 0:
if self.step_cnt == 0:
chatbot.append(["我画你猜(动物)", "请稍等..."])
else:
if prompt.strip() == 'exit':

查看文件

@@ -88,23 +88,23 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
self.story = []
chatbot.append(["互动写故事", f"这次的故事开头是:{self.headstart}"])
self.sys_prompt_ = '你是一个想象力丰富的杰出作家。正在与你的朋友互动,一起写故事,因此你每次写的故事段落应少于300字结局除外'
def generate_story_image(self, story_paragraph):
try:
from crazy_functions.Image_Generate import gen_image
from crazy_functions.图片生成 import gen_image
prompt_ = predict_no_ui_long_connection(inputs=story_paragraph, llm_kwargs=self.llm_kwargs, history=[], sys_prompt='你需要根据用户给出的小说段落,进行简短的环境描写。要求80字以内。')
image_url, image_path = gen_image(self.llm_kwargs, prompt_, '512x512', model="dall-e-2", quality='standard', style='natural')
return f'<br/><div align="center"><img src="file={image_path}"></div>'
except:
return ''
def step(self, prompt, chatbot, history):
"""
首先,处理游戏初始化等特殊情况
"""
if self.step_cnt == 0:
if self.step_cnt == 0:
self.begin_game_step_0(prompt, chatbot, history)
self.lock_plugin(chatbot)
self.cur_task = 'head_start'
@@ -132,7 +132,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
inputs_ = prompts_hs.format(headstart=self.headstart)
history_ = []
story_paragraph = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs_, '故事开头', self.llm_kwargs,
inputs_, '故事开头', self.llm_kwargs,
chatbot, history_, self.sys_prompt_
)
self.story.append(story_paragraph)
@@ -147,7 +147,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
inputs_ = prompts_interact.format(previously_on_story=previously_on_story)
history_ = []
self.next_choices = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs_, '请在以下几种故事走向中,选择一种(当然,您也可以选择给出其他故事走向):', self.llm_kwargs,
inputs_, '请在以下几种故事走向中,选择一种(当然,您也可以选择给出其他故事走向):', self.llm_kwargs,
chatbot,
history_,
self.sys_prompt_
@@ -166,7 +166,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
inputs_ = prompts_resume.format(previously_on_story=previously_on_story, choice=self.next_choices, user_choice=prompt)
history_ = []
story_paragraph = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs_, f'下一段故事(您的选择是:{prompt})。', self.llm_kwargs,
inputs_, f'下一段故事(您的选择是:{prompt})。', self.llm_kwargs,
chatbot, history_, self.sys_prompt_
)
self.story.append(story_paragraph)
@@ -181,10 +181,10 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
inputs_ = prompts_interact.format(previously_on_story=previously_on_story)
history_ = []
self.next_choices = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs_,
'请在以下几种故事走向中,选择一种。当然,您也可以给出您心中的其他故事走向。另外,如果您希望剧情立即收尾,请输入剧情走向,并以“剧情收尾”四个字提示程序。', self.llm_kwargs,
chatbot,
history_,
inputs_,
'请在以下几种故事走向中,选择一种。当然,您也可以给出您心中的其他故事走向。另外,如果您希望剧情立即收尾,请输入剧情走向,并以“剧情收尾”四个字提示程序。', self.llm_kwargs,
chatbot,
history_,
self.sys_prompt_
)
self.cur_task = 'user_choice'
@@ -200,7 +200,7 @@ class MiniGame_ResumeStory(GptAcademicGameBaseState):
inputs_ = prompts_terminate.format(previously_on_story=previously_on_story, user_choice=prompt)
history_ = []
story_paragraph = yield from request_gpt_model_in_new_thread_with_ui_alive(
inputs_, f'故事收尾(您的选择是:{prompt})。', self.llm_kwargs,
inputs_, f'故事收尾(您的选择是:{prompt})。', self.llm_kwargs,
chatbot, history_, self.sys_prompt_
)
# # 配图

查看文件

@@ -5,7 +5,7 @@ def get_code_block(reply):
import re
pattern = r"```([\s\S]*?)```" # regex pattern to match code blocks
matches = re.findall(pattern, reply) # find all code blocks in text
if len(matches) == 1:
if len(matches) == 1:
return "```" + matches[0] + "```" # code block
raise RuntimeError("GPT is not generating proper code.")
@@ -13,10 +13,10 @@ def is_same_thing(a, b, llm_kwargs):
from pydantic import BaseModel, Field
class IsSameThing(BaseModel):
is_same_thing: bool = Field(description="determine whether two objects are same thing.", default=False)
def run_gpt_fn(inputs, sys_prompt, history=[]):
def run_gpt_fn(inputs, sys_prompt, history=[]):
return predict_no_ui_long_connection(
inputs=inputs, llm_kwargs=llm_kwargs,
inputs=inputs, llm_kwargs=llm_kwargs,
history=history, sys_prompt=sys_prompt, observe_window=[]
)
@@ -24,7 +24,7 @@ def is_same_thing(a, b, llm_kwargs):
inputs_01 = "Identity whether the user input and the target is the same thing: \n target object: {a} \n user input object: {b} \n\n\n".format(a=a, b=b)
inputs_01 += "\n\n\n Note that the user may describe the target object with a different language, e.g. cat and 猫 are the same thing."
analyze_res_cot_01 = run_gpt_fn(inputs_01, "", [])
inputs_02 = inputs_01 + gpt_json_io.format_instructions
analyze_res = run_gpt_fn(inputs_02, "", [inputs_01, analyze_res_cot_01])

查看文件

@@ -41,11 +41,11 @@ def is_function_successfully_generated(fn_path, class_name, return_dict):
# Now you can create an instance of the class
instance = some_class()
return_dict['success'] = True
return
return
except:
return_dict['traceback'] = trimmed_format_exc()
return
def subprocess_worker(code, file_path, return_dict):
return_dict['result'] = None
return_dict['success'] = False

查看文件

@@ -1,4 +1,4 @@
import platform
import platform
import pickle
import multiprocessing

查看文件

@@ -24,8 +24,8 @@ class Actor(BaseModel):
film_names: List[str] = Field(description="list of names of films they starred in")
"""
import json, re
from loguru import logger as logging
import json, re, logging
PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
@@ -62,8 +62,8 @@ class GptJsonIO():
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)
if self.example_instruction:
schema_str = json.dumps(reduced_schema)
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
else:
return PYDANTIC_FORMAT_INSTRUCTIONS_SIMPLE.format(schema=schema_str)
@@ -89,7 +89,7 @@ class GptJsonIO():
error + "\n\n" + \
"Now, fix this json string. \n\n"
return prompt
def generate_output_auto_repair(self, response, gpt_gen_fn):
"""
response: string containing canidate json

查看文件

@@ -1,26 +0,0 @@
from crazy_functions.json_fns.pydantic_io import GptJsonIO, JsonStringError
def structure_output(txt, prompt, err_msg, run_gpt_fn, pydantic_cls):
gpt_json_io = GptJsonIO(pydantic_cls)
analyze_res = run_gpt_fn(
txt,
sys_prompt=prompt + gpt_json_io.format_instructions
)
try:
friend = gpt_json_io.generate_output_auto_repair(analyze_res, run_gpt_fn)
except JsonStringError as e:
return None, err_msg
err_msg = ""
return friend, err_msg
def select_tool(prompt, run_gpt_fn, pydantic_cls):
pydantic_cls_instance, err_msg = structure_output(
txt=prompt,
prompt="根据提示, 分析应该调用哪个工具函数\n\n",
err_msg=f"不能理解该联系人",
run_gpt_fn=run_gpt_fn,
pydantic_cls=pydantic_cls
)
return pydantic_cls_instance, err_msg

查看文件

@@ -1,17 +1,14 @@
import os
import re
import shutil
import numpy as np
from loguru import logger
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder, gen_time_str
from toolbox import get_conf, promote_file_to_downloadzone
from crazy_functions.latex_fns.latex_toolbox import PRESERVE, TRANSFORM
from crazy_functions.latex_fns.latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
from crazy_functions.latex_fns.latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
from crazy_functions.latex_fns.latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
from crazy_functions.latex_fns.latex_toolbox import find_title_and_abs
from crazy_functions.latex_fns.latex_pickle_io import objdump, objload
from toolbox import update_ui, update_ui_lastest_msg, get_log_folder
from toolbox import get_conf, objdump, objload, promote_file_to_downloadzone
from .latex_toolbox import PRESERVE, TRANSFORM
from .latex_toolbox import set_forbidden_text, set_forbidden_text_begin_end, set_forbidden_text_careful_brace
from .latex_toolbox import reverse_forbidden_text_careful_brace, reverse_forbidden_text, convert_to_linklist, post_process
from .latex_toolbox import fix_content, find_main_tex_file, merge_tex_files, compile_latex_with_timeout
from .latex_toolbox import find_title_and_abs
import os, shutil
import re
import numpy as np
pj = os.path.join
@@ -93,16 +90,16 @@ class LatexPaperSplit():
"版权归原文作者所有。翻译内容可靠性无保障,请仔细鉴别并以原文为准。" + \
"项目Github地址 \\url{https://github.com/binary-husky/gpt_academic/}。"
# 请您不要删除或修改这行警告,除非您是论文的原作者如果您是论文原作者,欢迎加REAME中的QQ联系开发者
self.msg_declare = "为了防止大语言模型的意外谬误产生扩散影响,禁止移除或修改此警告。}}\\\\"
self.msg_declare = "为了防止大语言模型的意外谬误产生扩散影响,禁止移除或修改此警告。}}\\\\"
self.title = "unknown"
self.abstract = "unknown"
def read_title_and_abstract(self, txt):
try:
title, abstract = find_title_and_abs(txt)
if title is not None:
if title is not None:
self.title = title.replace('\n', ' ').replace('\\\\', ' ').replace(' ', '').replace(' ', '')
if abstract is not None:
if abstract is not None:
self.abstract = abstract.replace('\n', ' ').replace('\\\\', ' ').replace(' ', '').replace(' ', '')
except:
pass
@@ -114,7 +111,7 @@ class LatexPaperSplit():
result_string = ""
node_cnt = 0
line_cnt = 0
for node in self.nodes:
if node.preserve:
line_cnt += node.string.count('\n')
@@ -147,7 +144,7 @@ class LatexPaperSplit():
return result_string
def split(self, txt, project_folder, opts):
def split(self, txt, project_folder, opts):
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
@@ -158,7 +155,7 @@ class LatexPaperSplit():
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=split_subprocess,
target=split_subprocess,
args=(txt, project_folder, return_dict, opts))
p.start()
p.join()
@@ -220,13 +217,13 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
from ..crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from .latex_actions import LatexPaperFileGroup, LatexPaperSplit
# <-------- 寻找主tex文件 ---------->
# <-------- 寻找主tex文件 ---------->
maintex = find_main_tex_file(file_manifest, mode)
chatbot.append((f"定位主Latex文件", f'[Local Message] 分析结果该项目的Latex主文件是{maintex}, 如果分析错误, 请立即终止程序, 删除或修改歧义文件, 然后重试。主程序即将开始, 请稍候。'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
time.sleep(3)
# <-------- 读取Latex文件, 将多文件tex工程融合为一个巨型tex ---------->
# <-------- 读取Latex文件, 将多文件tex工程融合为一个巨型tex ---------->
main_tex_basename = os.path.basename(maintex)
assert main_tex_basename.endswith('.tex')
main_tex_basename_bare = main_tex_basename[:-4]
@@ -243,13 +240,13 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
with open(project_folder + '/merge.tex', 'w', encoding='utf-8', errors='replace') as f:
f.write(merged_content)
# <-------- 精细切分latex文件 ---------->
# <-------- 精细切分latex文件 ---------->
chatbot.append((f"Latex文件融合完成", f'[Local Message] 正在精细切分latex文件,这需要一段时间计算,文档越长耗时越长,请耐心等待。'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
lps = LatexPaperSplit()
lps.read_title_and_abstract(merged_content)
res = lps.split(merged_content, project_folder, opts) # 消耗时间的函数
# <-------- 拆分过长的latex片段 ---------->
# <-------- 拆分过长的latex片段 ---------->
pfg = LatexPaperFileGroup()
for index, r in enumerate(res):
pfg.file_paths.append('segment-' + str(index))
@@ -258,17 +255,17 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
pfg.run_file_split(max_token_limit=1024)
n_split = len(pfg.sp_file_contents)
# <-------- 根据需要切换prompt ---------->
# <-------- 根据需要切换prompt ---------->
inputs_array, sys_prompt_array = switch_prompt(pfg, mode)
inputs_show_user_array = [f"{mode} {f}" for f in pfg.sp_file_tag]
if os.path.exists(pj(project_folder,'temp.pkl')):
# <-------- 【仅调试】如果存在调试缓存文件,则跳过GPT请求环节 ---------->
# <-------- 【仅调试】如果存在调试缓存文件,则跳过GPT请求环节 ---------->
pfg = objload(file=pj(project_folder,'temp.pkl'))
else:
# <-------- gpt 多线程请求 ---------->
# <-------- gpt 多线程请求 ---------->
history_array = [[""] for _ in range(n_split)]
# LATEX_EXPERIMENTAL, = get_conf('LATEX_EXPERIMENTAL')
# if LATEX_EXPERIMENTAL:
@@ -287,32 +284,32 @@ def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin
scroller_max_len = 40
)
# <-------- 文本碎片重组为完整的tex片段 ---------->
# <-------- 文本碎片重组为完整的tex片段 ---------->
pfg.sp_file_result = []
for i_say, gpt_say, orig_content in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], pfg.sp_file_contents):
pfg.sp_file_result.append(gpt_say)
pfg.merge_result()
# <-------- 临时存储用于调试 ---------->
# <-------- 临时存储用于调试 ---------->
pfg.get_token_num = None
objdump(pfg, file=pj(project_folder,'temp.pkl'))
write_html(pfg.sp_file_contents, pfg.sp_file_result, chatbot=chatbot, project_folder=project_folder)
# <-------- 写出文件 ---------->
# <-------- 写出文件 ---------->
msg = f"当前大语言模型: {llm_kwargs['llm_model']},当前语言模型温度设定: {llm_kwargs['temperature']}"
final_tex = lps.merge_result(pfg.file_result, mode, msg)
objdump((lps, pfg.file_result, mode, msg), file=pj(project_folder,'merge_result.pkl'))
with open(project_folder + f'/merge_{mode}.tex', 'w', encoding='utf-8', errors='replace') as f:
if mode != 'translate_zh' or "binary" in final_tex: f.write(final_tex)
# <-------- 整理结果, 退出 ---------->
# <-------- 整理结果, 退出 ---------->
chatbot.append((f"完成了吗?", 'GPT结果已输出, 即将编译PDF'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------- 返回 ---------->
# <-------- 返回 ---------->
return project_folder + f'/merge_{mode}.tex'
@@ -325,7 +322,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
buggy_lines = [int(l) for l in buggy_lines]
buggy_lines = sorted(buggy_lines)
buggy_line = buggy_lines[0]-1
logger.warning("reversing tex line that has errors", buggy_line)
print("reversing tex line that has errors", buggy_line)
# 重组,逆转出错的段落
if buggy_line not in fixed_line:
@@ -339,7 +336,7 @@ def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
except:
logger.error("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
print("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
return False, -1, [-1]
@@ -365,7 +362,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译转化后的PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
if ok and os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf')):
# 只有第二步成功,才能继续下面的步骤
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译BibTex ...', chatbot, history) # 刷新Gradio前端界面
@@ -382,7 +379,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
if mode!='translate_zh':
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
logger.info( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
print( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex', os.getcwd())
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
@@ -396,9 +393,9 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
original_pdf_success = os.path.exists(pj(work_folder_original, f'{main_file_original}.pdf'))
modified_pdf_success = os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf'))
diff_pdf_success = os.path.exists(pj(work_folder, f'merge_diff.pdf'))
results_ += f"原始PDF编译是否成功: {original_pdf_success};"
results_ += f"转化PDF编译是否成功: {modified_pdf_success};"
results_ += f"对比PDF编译是否成功: {diff_pdf_success};"
results_ += f"原始PDF编译是否成功: {original_pdf_success};"
results_ += f"转化PDF编译是否成功: {modified_pdf_success};"
results_ += f"对比PDF编译是否成功: {diff_pdf_success};"
yield from update_ui_lastest_msg(f'{n_fix}编译结束:<br/>{results_}...', chatbot, history) # 刷新Gradio前端界面
if diff_pdf_success:
@@ -412,7 +409,7 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
shutil.copyfile(result_pdf, pj(work_folder, '..', 'translation', 'translate_zh.pdf'))
promote_file_to_downloadzone(result_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
# 将两个PDF拼接
if original_pdf_success:
if original_pdf_success:
try:
from .latex_toolbox import merge_pdfs
concat_pdf = pj(work_folder_modified, f'comparison.pdf')
@@ -421,14 +418,14 @@ def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_f
shutil.copyfile(concat_pdf, pj(work_folder, '..', 'translation', 'comparison.pdf'))
promote_file_to_downloadzone(concat_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
except Exception as e:
logger.error(e)
print(e)
pass
return True # 成功啦
else:
if n_fix>=max_try: break
n_fix += 1
can_retry, main_file_modified, buggy_lines = remove_buggy_lines(
file_path=pj(work_folder_modified, f'{main_file_modified}.tex'),
file_path=pj(work_folder_modified, f'{main_file_modified}.tex'),
log_path=pj(work_folder_modified, f'{main_file_modified}.log'),
tex_name=f'{main_file_modified}.tex',
tex_name_pure=f'{main_file_modified}',
@@ -448,14 +445,14 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
import shutil
from crazy_functions.pdf_fns.report_gen_html import construct_html
from toolbox import gen_time_str
ch = construct_html()
ch = construct_html()
orig = ""
trans = ""
final = []
for c,r in zip(sp_file_contents, sp_file_result):
for c,r in zip(sp_file_contents, sp_file_result):
final.append(c)
final.append(r)
for i, k in enumerate(final):
for i, k in enumerate(final):
if i%2==0:
orig = k
if i%2==1:
@@ -467,71 +464,4 @@ def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
promote_file_to_downloadzone(file=res, chatbot=chatbot)
except:
from toolbox import trimmed_format_exc
logger.error('writing html result failed:', trimmed_format_exc())
def upload_to_gptac_cloud_if_user_allow(chatbot, arxiv_id):
try:
# 如果用户允许,我们将arxiv论文PDF上传到GPTAC学术云
from toolbox import map_file_to_sha256
# 检查是否顺利,如果没有生成预期的文件,则跳过
is_result_good = False
for file_path in chatbot._cookies.get("files_to_promote", []):
if file_path.endswith('translate_zh.pdf'):
is_result_good = True
if not is_result_good:
return
# 上传文件
for file_path in chatbot._cookies.get("files_to_promote", []):
align_name = None
# normalized name
for name in ['translate_zh.pdf', 'comparison.pdf']:
if file_path.endswith(name): align_name = name
# if match any align name
if align_name:
logger.info(f'Uploading to GPTAC cloud as the user has set `allow_cloud_io`: {file_path}')
with open(file_path, 'rb') as f:
import requests
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_upload'
files = {'file': (align_name, f, 'application/octet-stream')}
data = {
'arxiv_id': arxiv_id,
'file_hash': map_file_to_sha256(file_path),
'language': 'zh',
'trans_prompt': 'to_be_implemented',
'llm_model': 'to_be_implemented',
'llm_model_param': 'to_be_implemented',
}
resp = requests.post(url=url, files=files, data=data, timeout=30)
logger.info(f'Uploading terminate ({resp.status_code})`: {file_path}')
except:
# 如果上传失败,不会中断程序,因为这是次要功能
pass
def check_gptac_cloud(arxiv_id, chatbot):
import requests
success = False
downloaded = []
try:
for pdf_target in ['translate_zh.pdf', 'comparison.pdf']:
url = 'https://cloud-2.agent-matrix.com/arxiv_tf_paper_normal_exist'
data = {
'arxiv_id': arxiv_id,
'name': pdf_target,
}
resp = requests.post(url=url, data=data)
cache_hit_result = resp.text.strip('"')
if cache_hit_result.startswith("http"):
url = cache_hit_result
logger.info(f'Downloading from GPTAC cloud: {url}')
resp = requests.get(url=url, timeout=30)
target = os.path.join(get_log_folder(plugin_name='gptac_cloud'), gen_time_str(), pdf_target)
os.makedirs(os.path.dirname(target), exist_ok=True)
with open(target, 'wb') as f:
f.write(resp.content)
new_path = promote_file_to_downloadzone(target, chatbot=chatbot)
success = True
downloaded.append(new_path)
except:
pass
return success, downloaded
print('writing html result failed:', trimmed_format_exc())

查看文件

@@ -1,46 +0,0 @@
import pickle
class SafeUnpickler(pickle.Unpickler):
def get_safe_classes(self):
from crazy_functions.latex_fns.latex_actions import LatexPaperFileGroup, LatexPaperSplit
from crazy_functions.latex_fns.latex_toolbox import LinkedListNode
# 定义允许的安全类
safe_classes = {
# 在这里添加其他安全的类
'LatexPaperFileGroup': LatexPaperFileGroup,
'LatexPaperSplit': LatexPaperSplit,
'LinkedListNode': LinkedListNode,
}
return safe_classes
def find_class(self, module, name):
# 只允许特定的类进行反序列化
self.safe_classes = self.get_safe_classes()
match_class_name = None
for class_name in self.safe_classes.keys():
if (class_name in f'{module}.{name}'):
match_class_name = class_name
if module == 'numpy' or module.startswith('numpy.'):
return super().find_class(module, name)
if match_class_name is not None:
return self.safe_classes[match_class_name]
# 如果尝试加载未授权的类,则抛出异常
raise pickle.UnpicklingError(f"Attempted to deserialize unauthorized class '{name}' from module '{module}'")
def objdump(obj, file="objdump.tmp"):
with open(file, "wb+") as f:
pickle.dump(obj, f)
return
def objload(file="objdump.tmp"):
import os
if not os.path.exists(file):
return
with open(file, "rb") as f:
unpickler = SafeUnpickler(f)
return unpickler.load()

查看文件

@@ -1,8 +1,6 @@
import os
import os, shutil
import re
import shutil
import numpy as np
from loguru import logger
PRESERVE = 0
TRANSFORM = 1
@@ -57,7 +55,7 @@ def post_process(root):
str_stack.append("{")
elif c == "}":
if len(str_stack) == 1:
logger.warning("fixing brace error")
print("stack fix")
return i
str_stack.pop(-1)
else:
@@ -603,7 +601,7 @@ def compile_latex_with_timeout(command, cwd, timeout=60):
except subprocess.TimeoutExpired:
process.kill()
stdout, stderr = process.communicate()
logger.error("Process timed out (compile_latex_with_timeout)!")
print("Process timed out!")
return False
return True
@@ -644,216 +642,6 @@ def run_in_subprocess(func):
def _merge_pdfs(pdf1_path, pdf2_path, output_path):
try:
logger.info("Merging PDFs using _merge_pdfs_ng")
_merge_pdfs_ng(pdf1_path, pdf2_path, output_path)
except:
logger.info("Merging PDFs using _merge_pdfs_legacy")
_merge_pdfs_legacy(pdf1_path, pdf2_path, output_path)
def _merge_pdfs_ng(pdf1_path, pdf2_path, output_path):
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
from PyPDF2.generic import NameObject, TextStringObject, ArrayObject, FloatObject, NumberObject
Percent = 1
# raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
# Open the first PDF file
with open(pdf1_path, "rb") as pdf1_file:
pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
# Open the second PDF file
with open(pdf2_path, "rb") as pdf2_file:
pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
# Create a new PDF file to store the merged pages
output_writer = PyPDF2.PdfFileWriter()
# Determine the number of pages in each PDF file
num_pages = max(pdf1_reader.numPages, pdf2_reader.numPages)
# Merge the pages from the two PDF files
for page_num in range(num_pages):
# Add the page from the first PDF file
if page_num < pdf1_reader.numPages:
page1 = pdf1_reader.getPage(page_num)
else:
page1 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
# Add the page from the second PDF file
if page_num < pdf2_reader.numPages:
page2 = pdf2_reader.getPage(page_num)
else:
page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
# Create a new empty page with double width
new_page = PyPDF2.PageObject.createBlankPage(
width=int(
int(page1.mediaBox.getWidth())
+ int(page2.mediaBox.getWidth()) * Percent
),
height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
)
new_page.mergeTranslatedPage(page1, 0, 0)
new_page.mergeTranslatedPage(
page2,
int(
int(page1.mediaBox.getWidth())
- int(page2.mediaBox.getWidth()) * (1 - Percent)
),
0,
)
if "/Annots" in new_page:
annotations = new_page["/Annots"]
for i, annot in enumerate(annotations):
annot_obj = annot.get_object()
# 检查注释类型是否是链接(/Link
if annot_obj.get("/Subtype") == "/Link":
# 检查是否为内部链接跳转(/GoTo或外部URI链接/URI
action = annot_obj.get("/A")
if action:
if "/S" in action and action["/S"] == "/GoTo":
# 内部链接:跳转到文档中的某个页面
dest = action.get("/D") # 目标页或目标位置
# if dest and annot.idnum in page2_annot_id:
# if dest in pdf2_reader.named_destinations:
if dest and page2.annotations:
if annot in page2.annotations:
# 获取原始文件中跳转信息,包括跳转页面
destination = pdf2_reader.named_destinations[
dest
]
page_number = (
pdf2_reader.get_destination_page_number(
destination
)
)
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
# “/D”:[10,'/XYZ',100,100,0]
if destination.dest_array[1] == "/XYZ":
annot_obj["/A"].update(
{
NameObject("/D"): ArrayObject(
[
NumberObject(page_number),
destination.dest_array[1],
FloatObject(
destination.dest_array[
2
]
+ int(
page1.mediaBox.getWidth()
)
),
destination.dest_array[3],
destination.dest_array[4],
]
) # 确保键和值是 PdfObject
}
)
else:
annot_obj["/A"].update(
{
NameObject("/D"): ArrayObject(
[
NumberObject(page_number),
destination.dest_array[1],
]
) # 确保键和值是 PdfObject
}
)
rect = annot_obj.get("/Rect")
# 更新点击坐标
rect = ArrayObject(
[
FloatObject(
rect[0]
+ int(page1.mediaBox.getWidth())
),
rect[1],
FloatObject(
rect[2]
+ int(page1.mediaBox.getWidth())
),
rect[3],
]
)
annot_obj.update(
{
NameObject(
"/Rect"
): rect # 确保键和值是 PdfObject
}
)
# if dest and annot.idnum in page1_annot_id:
# if dest in pdf1_reader.named_destinations:
if dest and page1.annotations:
if annot in page1.annotations:
# 获取原始文件中跳转信息,包括跳转页面
destination = pdf1_reader.named_destinations[
dest
]
page_number = (
pdf1_reader.get_destination_page_number(
destination
)
)
# 更新跳转信息,跳转到对应的页面和,指定坐标 (100, 150),缩放比例为 100%
# “/D”:[10,'/XYZ',100,100,0]
if destination.dest_array[1] == "/XYZ":
annot_obj["/A"].update(
{
NameObject("/D"): ArrayObject(
[
NumberObject(page_number),
destination.dest_array[1],
FloatObject(
destination.dest_array[
2
]
),
destination.dest_array[3],
destination.dest_array[4],
]
) # 确保键和值是 PdfObject
}
)
else:
annot_obj["/A"].update(
{
NameObject("/D"): ArrayObject(
[
NumberObject(page_number),
destination.dest_array[1],
]
) # 确保键和值是 PdfObject
}
)
rect = annot_obj.get("/Rect")
rect = ArrayObject(
[
FloatObject(rect[0]),
rect[1],
FloatObject(rect[2]),
rect[3],
]
)
annot_obj.update(
{
NameObject(
"/Rect"
): rect # 确保键和值是 PdfObject
}
)
elif "/S" in action and action["/S"] == "/URI":
# 外部链接跳转到某个URI
uri = action.get("/URI")
output_writer.addPage(new_page)
# Save the merged PDF file
with open(output_path, "wb") as output_file:
output_writer.write(output_file)
def _merge_pdfs_legacy(pdf1_path, pdf2_path, output_path):
import PyPDF2 # PyPDF2这个库有严重的内存泄露问题,把它放到子进程中运行,从而方便内存的释放
Percent = 0.95

查看文件

@@ -0,0 +1,788 @@
from toolbox import update_ui, update_ui_lastest_msg # 刷新Gradio前端界面
from toolbox import zip_folder, objdump, objload, promote_file_to_downloadzone
import os, shutil
import re
import numpy as np
pj = os.path.join
"""
========================================================================
Part One
Latex segmentation with a binary mask (PRESERVE=0, TRANSFORM=1)
========================================================================
"""
PRESERVE = 0
TRANSFORM = 1
def set_forbidden_text(text, mask, pattern, flags=0):
"""
Add a preserve text area in this paper
e.g. with pattern = r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}"
you can mask out (mask = PRESERVE so that text become untouchable for GPT)
everything between "\begin{equation}" and "\end{equation}"
"""
if isinstance(pattern, list): pattern = '|'.join(pattern)
pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text):
mask[res.span()[0]:res.span()[1]] = PRESERVE
return text, mask
def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
"""
Move area out of preserve area (make text editable for GPT)
count the number of the braces so as to catch compelete text area.
e.g.
\begin{abstract} blablablablablabla. \end{abstract}
"""
if isinstance(pattern, list): pattern = '|'.join(pattern)
pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text):
if not forbid_wrapper:
mask[res.span()[0]:res.span()[1]] = TRANSFORM
else:
mask[res.regs[0][0]: res.regs[1][0]] = PRESERVE # '\\begin{abstract}'
mask[res.regs[1][0]: res.regs[1][1]] = TRANSFORM # abstract
mask[res.regs[1][1]: res.regs[0][1]] = PRESERVE # abstract
return text, mask
def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
"""
Add a preserve text area in this paper (text become untouchable for GPT).
count the number of the braces so as to catch compelete text area.
e.g.
\caption{blablablablabla\texbf{blablabla}blablabla.}
"""
pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text):
brace_level = -1
p = begin = end = res.regs[0][0]
for _ in range(1024*16):
if text[p] == '}' and brace_level == 0: break
elif text[p] == '}': brace_level -= 1
elif text[p] == '{': brace_level += 1
p += 1
end = p+1
mask[begin:end] = PRESERVE
return text, mask
def reverse_forbidden_text_careful_brace(text, mask, pattern, flags=0, forbid_wrapper=True):
"""
Move area out of preserve area (make text editable for GPT)
count the number of the braces so as to catch compelete text area.
e.g.
\caption{blablablablabla\texbf{blablabla}blablabla.}
"""
pattern_compile = re.compile(pattern, flags)
for res in pattern_compile.finditer(text):
brace_level = 0
p = begin = end = res.regs[1][0]
for _ in range(1024*16):
if text[p] == '}' and brace_level == 0: break
elif text[p] == '}': brace_level -= 1
elif text[p] == '{': brace_level += 1
p += 1
end = p
mask[begin:end] = TRANSFORM
if forbid_wrapper:
mask[res.regs[0][0]:begin] = PRESERVE
mask[end:res.regs[0][1]] = PRESERVE
return text, mask
def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
"""
Find all \begin{} ... \end{} text block that with less than limit_n_lines lines.
Add it to preserve area
"""
pattern_compile = re.compile(pattern, flags)
def search_with_line_limit(text, mask):
for res in pattern_compile.finditer(text):
cmd = res.group(1) # begin{what}
this = res.group(2) # content between begin and end
this_mask = mask[res.regs[2][0]:res.regs[2][1]]
white_list = ['document', 'abstract', 'lemma', 'definition', 'sproof',
'em', 'emph', 'textit', 'textbf', 'itemize', 'enumerate']
if (cmd in white_list) or this.count('\n') >= limit_n_lines: # use a magical number 42
this, this_mask = search_with_line_limit(this, this_mask)
mask[res.regs[2][0]:res.regs[2][1]] = this_mask
else:
mask[res.regs[0][0]:res.regs[0][1]] = PRESERVE
return text, mask
return search_with_line_limit(text, mask)
class LinkedListNode():
"""
Linked List Node
"""
def __init__(self, string, preserve=True) -> None:
self.string = string
self.preserve = preserve
self.next = None
# self.begin_line = 0
# self.begin_char = 0
def convert_to_linklist(text, mask):
root = LinkedListNode("", preserve=True)
current_node = root
for c, m, i in zip(text, mask, range(len(text))):
if (m==PRESERVE and current_node.preserve) \
or (m==TRANSFORM and not current_node.preserve):
# add
current_node.string += c
else:
current_node.next = LinkedListNode(c, preserve=(m==PRESERVE))
current_node = current_node.next
return root
"""
========================================================================
Latex Merge File
========================================================================
"""
def 寻找Latex主文件(file_manifest, mode):
"""
在多Tex文档中,寻找主文件,必须包含documentclass,返回找到的第一个。
P.S. 但愿没人把latex模板放在里面传进来 (6.25 加入判定latex模板的代码)
"""
canidates = []
for texf in file_manifest:
if os.path.basename(texf).startswith('merge'):
continue
with open(texf, 'r', encoding='utf8') as f:
file_content = f.read()
if r'\documentclass' in file_content:
canidates.append(texf)
else:
continue
if len(canidates) == 0:
raise RuntimeError('无法找到一个主Tex文件包含documentclass关键字')
elif len(canidates) == 1:
return canidates[0]
else: # if len(canidates) >= 2 通过一些Latex模板中常见但通常不会出现在正文的单词,对不同latex源文件扣分,取评分最高者返回
canidates_score = []
# 给出一些判定模板文档的词作为扣分项
unexpected_words = ['\LaTeX', 'manuscript', 'Guidelines', 'font', 'citations', 'rejected', 'blind review', 'reviewers']
expected_words = ['\input', '\ref', '\cite']
for texf in canidates:
canidates_score.append(0)
with open(texf, 'r', encoding='utf8') as f:
file_content = f.read()
for uw in unexpected_words:
if uw in file_content:
canidates_score[-1] -= 1
for uw in expected_words:
if uw in file_content:
canidates_score[-1] += 1
select = np.argmax(canidates_score) # 取评分最高者返回
return canidates[select]
def rm_comments(main_file):
new_file_remove_comment_lines = []
for l in main_file.splitlines():
# 删除整行的空注释
if l.lstrip().startswith("%"):
pass
else:
new_file_remove_comment_lines.append(l)
main_file = '\n'.join(new_file_remove_comment_lines)
# main_file = re.sub(r"\\include{(.*?)}", r"\\input{\1}", main_file) # 将 \include 命令转换为 \input 命令
main_file = re.sub(r'(?<!\\)%.*', '', main_file) # 使用正则表达式查找半行注释, 并替换为空字符串
return main_file
def merge_tex_files_(project_foler, main_file, mode):
"""
Merge Tex project recrusively
"""
main_file = rm_comments(main_file)
for s in reversed([q for q in re.finditer(r"\\input\{(.*?)\}", main_file, re.M)]):
f = s.group(1)
fp = os.path.join(project_foler, f)
if os.path.exists(fp):
# e.g., \input{srcs/07_appendix.tex}
with open(fp, 'r', encoding='utf-8', errors='replace') as fx:
c = fx.read()
else:
# e.g., \input{srcs/07_appendix}
with open(fp+'.tex', 'r', encoding='utf-8', errors='replace') as fx:
c = fx.read()
c = merge_tex_files_(project_foler, c, mode)
main_file = main_file[:s.span()[0]] + c + main_file[s.span()[1]:]
return main_file
def merge_tex_files(project_foler, main_file, mode):
"""
Merge Tex project recrusively
P.S. 顺便把CTEX塞进去以支持中文
P.S. 顺便把Latex的注释去除
"""
main_file = merge_tex_files_(project_foler, main_file, mode)
main_file = rm_comments(main_file)
if mode == 'translate_zh':
# find paper documentclass
pattern = re.compile(r'\\documentclass.*\n')
match = pattern.search(main_file)
assert match is not None, "Cannot find documentclass statement!"
position = match.end()
add_ctex = '\\usepackage{ctex}\n'
add_url = '\\usepackage{url}\n' if '{url}' not in main_file else ''
main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
# fontset=windows
import platform
main_file = re.sub(r"\\documentclass\[(.*?)\]{(.*?)}", r"\\documentclass[\1,fontset=windows,UTF8]{\2}",main_file)
main_file = re.sub(r"\\documentclass{(.*?)}", r"\\documentclass[fontset=windows,UTF8]{\1}",main_file)
# find paper abstract
pattern_opt1 = re.compile(r'\\begin\{abstract\}.*\n')
pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
match_opt1 = pattern_opt1.search(main_file)
match_opt2 = pattern_opt2.search(main_file)
assert (match_opt1 is not None) or (match_opt2 is not None), "Cannot find paper abstract section!"
return main_file
"""
========================================================================
Post process
========================================================================
"""
def mod_inbraket(match):
"""
为啥chatgpt会把cite里面的逗号换成中文逗号呀
"""
# get the matched string
cmd = match.group(1)
str_to_modify = match.group(2)
# modify the matched string
str_to_modify = str_to_modify.replace('', ':') # 前面是中文冒号,后面是英文冒号
str_to_modify = str_to_modify.replace('', ',') # 前面是中文逗号,后面是英文逗号
# str_to_modify = 'BOOM'
return "\\" + cmd + "{" + str_to_modify + "}"
def fix_content(final_tex, node_string):
"""
Fix common GPT errors to increase success rate
"""
final_tex = re.sub(r"(?<!\\)%", "\\%", final_tex)
final_tex = re.sub(r"\\([a-z]{2,10})\ \{", r"\\\1{", string=final_tex)
final_tex = re.sub(r"\\\ ([a-z]{2,10})\{", r"\\\1{", string=final_tex)
final_tex = re.sub(r"\\([a-z]{2,10})\{([^\}]*?)\}", mod_inbraket, string=final_tex)
if "Traceback" in final_tex and "[Local Message]" in final_tex:
final_tex = node_string # 出问题了,还原原文
if node_string.count('\\begin') != final_tex.count('\\begin'):
final_tex = node_string # 出问题了,还原原文
if node_string.count('\_') > 0 and node_string.count('\_') > final_tex.count('\_'):
# walk and replace any _ without \
final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)
def compute_brace_level(string):
# this function count the number of { and }
brace_level = 0
for c in string:
if c == "{": brace_level += 1
elif c == "}": brace_level -= 1
return brace_level
def join_most(tex_t, tex_o):
# this function join translated string and original string when something goes wrong
p_t = 0
p_o = 0
def find_next(string, chars, begin):
p = begin
while p < len(string):
if string[p] in chars: return p, string[p]
p += 1
return None, None
while True:
res1, char = find_next(tex_o, ['{','}'], p_o)
if res1 is None: break
res2, char = find_next(tex_t, [char], p_t)
if res2 is None: break
p_o = res1 + 1
p_t = res2 + 1
return tex_t[:p_t] + tex_o[p_o:]
if compute_brace_level(final_tex) != compute_brace_level(node_string):
# 出问题了,还原部分原文,保证括号正确
final_tex = join_most(final_tex, node_string)
return final_tex
def split_subprocess(txt, project_folder, return_dict, opts):
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
"""
text = txt
mask = np.zeros(len(txt), dtype=np.uint8) + TRANSFORM
# 吸收title与作者以上的部分
text, mask = set_forbidden_text(text, mask, r"(.*?)\\maketitle", re.DOTALL)
# 吸收iffalse注释
text, mask = set_forbidden_text(text, mask, r"\\iffalse(.*?)\\fi", re.DOTALL)
# 吸收在42行以内的begin-end组合
text, mask = set_forbidden_text_begin_end(text, mask, r"\\begin\{([a-z\*]*)\}(.*?)\\end\{\1\}", re.DOTALL, limit_n_lines=42)
# 吸收匿名公式
text, mask = set_forbidden_text(text, mask, [ r"\$\$(.*?)\$\$", r"\\\[.*?\\\]" ], re.DOTALL)
# 吸收其他杂项
text, mask = set_forbidden_text(text, mask, [ r"\\section\{(.*?)\}", r"\\section\*\{(.*?)\}", r"\\subsection\{(.*?)\}", r"\\subsubsection\{(.*?)\}" ])
text, mask = set_forbidden_text(text, mask, [ r"\\bibliography\{(.*?)\}", r"\\bibliographystyle\{(.*?)\}" ])
text, mask = set_forbidden_text(text, mask, r"\\begin\{thebibliography\}.*?\\end\{thebibliography\}", re.DOTALL)
text, mask = set_forbidden_text(text, mask, r"\\begin\{lstlisting\}(.*?)\\end\{lstlisting\}", re.DOTALL)
text, mask = set_forbidden_text(text, mask, r"\\begin\{wraptable\}(.*?)\\end\{wraptable\}", re.DOTALL)
text, mask = set_forbidden_text(text, mask, r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}", re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{wrapfigure\}(.*?)\\end\{wrapfigure\}", r"\\begin\{wrapfigure\*\}(.*?)\\end\{wrapfigure\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{figure\}(.*?)\\end\{figure\}", r"\\begin\{figure\*\}(.*?)\\end\{figure\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{multline\}(.*?)\\end\{multline\}", r"\\begin\{multline\*\}(.*?)\\end\{multline\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{table\}(.*?)\\end\{table\}", r"\\begin\{table\*\}(.*?)\\end\{table\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{minipage\}(.*?)\\end\{minipage\}", r"\\begin\{minipage\*\}(.*?)\\end\{minipage\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{align\*\}(.*?)\\end\{align\*\}", r"\\begin\{align\}(.*?)\\end\{align\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\begin\{equation\}(.*?)\\end\{equation\}", r"\\begin\{equation\*\}(.*?)\\end\{equation\*\}"], re.DOTALL)
text, mask = set_forbidden_text(text, mask, [r"\\includepdf\[(.*?)\]\{(.*?)\}", r"\\clearpage", r"\\newpage", r"\\appendix", r"\\tableofcontents", r"\\include\{(.*?)\}"])
text, mask = set_forbidden_text(text, mask, [r"\\vspace\{(.*?)\}", r"\\hspace\{(.*?)\}", r"\\label\{(.*?)\}", r"\\begin\{(.*?)\}", r"\\end\{(.*?)\}", r"\\item "])
text, mask = set_forbidden_text_careful_brace(text, mask, r"\\hl\{(.*?)\}", re.DOTALL)
# reverse 操作必须放在最后
text, mask = reverse_forbidden_text_careful_brace(text, mask, r"\\caption\{(.*?)\}", re.DOTALL, forbid_wrapper=True)
text, mask = reverse_forbidden_text_careful_brace(text, mask, r"\\abstract\{(.*?)\}", re.DOTALL, forbid_wrapper=True)
text, mask = reverse_forbidden_text(text, mask, r"\\begin\{abstract\}(.*?)\\end\{abstract\}", re.DOTALL, forbid_wrapper=True)
root = convert_to_linklist(text, mask)
# 修复括号
node = root
while True:
string = node.string
if node.preserve:
node = node.next
if node is None: break
continue
def break_check(string):
str_stack = [""] # (lv, index)
for i, c in enumerate(string):
if c == '{':
str_stack.append('{')
elif c == '}':
if len(str_stack) == 1:
print('stack fix')
return i
str_stack.pop(-1)
else:
str_stack[-1] += c
return -1
bp = break_check(string)
if bp == -1:
pass
elif bp == 0:
node.string = string[:1]
q = LinkedListNode(string[1:], False)
q.next = node.next
node.next = q
else:
node.string = string[:bp]
q = LinkedListNode(string[bp:], False)
q.next = node.next
node.next = q
node = node.next
if node is None: break
# 屏蔽空行和太短的句子
node = root
while True:
if len(node.string.strip('\n').strip(''))==0: node.preserve = True
if len(node.string.strip('\n').strip(''))<42: node.preserve = True
node = node.next
if node is None: break
node = root
while True:
if node.next and node.preserve and node.next.preserve:
node.string += node.next.string
node.next = node.next.next
node = node.next
if node is None: break
# 将前后断行符脱离
node = root
prev_node = None
while True:
if not node.preserve:
lstriped_ = node.string.lstrip().lstrip('\n')
if (prev_node is not None) and (prev_node.preserve) and (len(lstriped_)!=len(node.string)):
prev_node.string += node.string[:-len(lstriped_)]
node.string = lstriped_
rstriped_ = node.string.rstrip().rstrip('\n')
if (node.next is not None) and (node.next.preserve) and (len(rstriped_)!=len(node.string)):
node.next.string = node.string[len(rstriped_):] + node.next.string
node.string = rstriped_
# =====
prev_node = node
node = node.next
if node is None: break
# 输出html调试文件,用红色标注处保留区PRESERVE,用黑色标注转换区TRANSFORM
with open(pj(project_folder, 'debug_log.html'), 'w', encoding='utf8') as f:
segment_parts_for_gpt = []
nodes = []
node = root
while True:
nodes.append(node)
show_html = node.string.replace('\n','<br/>')
if not node.preserve:
segment_parts_for_gpt.append(node.string)
f.write(f'<p style="color:black;">#{show_html}#</p>')
else:
f.write(f'<p style="color:red;">{show_html}</p>')
node = node.next
if node is None: break
for n in nodes: n.next = None # break
return_dict['nodes'] = nodes
return_dict['segment_parts_for_gpt'] = segment_parts_for_gpt
return return_dict
class LatexPaperSplit():
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
"""
def __init__(self) -> None:
self.nodes = None
self.msg = "*{\\scriptsize\\textbf{警告该PDF由GPT-Academic开源项目调用大语言模型+Latex翻译插件一键生成," + \
"版权归原文作者所有。翻译内容可靠性无保障,请仔细鉴别并以原文为准。" + \
"项目Github地址 \\url{https://github.com/binary-husky/gpt_academic/}。"
# 请您不要删除或修改这行警告,除非您是论文的原作者如果您是论文原作者,欢迎加REAME中的QQ联系开发者
self.msg_declare = "为了防止大语言模型的意外谬误产生扩散影响,禁止移除或修改此警告。}}\\\\"
def merge_result(self, arr, mode, msg):
"""
Merge the result after the GPT process completed
"""
result_string = ""
p = 0
for node in self.nodes:
if node.preserve:
result_string += node.string
else:
result_string += fix_content(arr[p], node.string)
p += 1
if mode == 'translate_zh':
pattern = re.compile(r'\\begin\{abstract\}.*\n')
match = pattern.search(result_string)
if not match:
# match \abstract{xxxx}
pattern_compile = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)
match = pattern_compile.search(result_string)
position = match.regs[1][0]
else:
# match \begin{abstract}xxxx\end{abstract}
position = match.end()
result_string = result_string[:position] + self.msg + msg + self.msg_declare + result_string[position:]
return result_string
def split(self, txt, project_folder, opts):
"""
break down latex file to a linked list,
each node use a preserve flag to indicate whether it should
be proccessed by GPT.
P.S. use multiprocessing to avoid timeout error
"""
import multiprocessing
manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(
target=split_subprocess,
args=(txt, project_folder, return_dict, opts))
p.start()
p.join()
p.close()
self.nodes = return_dict['nodes']
self.sp = return_dict['segment_parts_for_gpt']
return self.sp
class LatexPaperFileGroup():
"""
use tokenizer to break down text according to max_token_limit
"""
def __init__(self):
self.file_paths = []
self.file_contents = []
self.sp_file_contents = []
self.sp_file_index = []
self.sp_file_tag = []
# count_token
from request_llm.bridge_all import model_info
enc = model_info["gpt-3.5-turbo"]['tokenizer']
def get_token_num(txt): return len(enc.encode(txt, disallowed_special=()))
self.get_token_num = get_token_num
def run_file_split(self, max_token_limit=1900):
"""
use tokenizer to break down text according to max_token_limit
"""
for index, file_content in enumerate(self.file_contents):
if self.get_token_num(file_content) < max_token_limit:
self.sp_file_contents.append(file_content)
self.sp_file_index.append(index)
self.sp_file_tag.append(self.file_paths[index])
else:
from .crazy_utils import breakdown_txt_to_satisfy_token_limit_for_pdf
segments = breakdown_txt_to_satisfy_token_limit_for_pdf(file_content, self.get_token_num, max_token_limit)
for j, segment in enumerate(segments):
self.sp_file_contents.append(segment)
self.sp_file_index.append(index)
self.sp_file_tag.append(self.file_paths[index] + f".part-{j}.tex")
print('Segmentation: done')
def merge_result(self):
self.file_result = ["" for _ in range(len(self.file_paths))]
for r, k in zip(self.sp_file_result, self.sp_file_index):
self.file_result[k] += r
def write_result(self):
manifest = []
for path, res in zip(self.file_paths, self.file_result):
with open(path + '.polish.tex', 'w', encoding='utf8') as f:
manifest.append(path + '.polish.tex')
f.write(res)
return manifest
def write_html(sp_file_contents, sp_file_result, chatbot, project_folder):
# write html
try:
import shutil
from .crazy_utils import construct_html
from toolbox import gen_time_str
ch = construct_html()
orig = ""
trans = ""
final = []
for c,r in zip(sp_file_contents, sp_file_result):
final.append(c)
final.append(r)
for i, k in enumerate(final):
if i%2==0:
orig = k
if i%2==1:
trans = k
ch.add_row(a=orig, b=trans)
create_report_file_name = f"{gen_time_str()}.trans.html"
ch.save_file(create_report_file_name)
shutil.copyfile(pj('./gpt_log/', create_report_file_name), pj(project_folder, create_report_file_name))
promote_file_to_downloadzone(file=f'./gpt_log/{create_report_file_name}', chatbot=chatbot)
except:
from toolbox import trimmed_format_exc
print('writing html result failed:', trimmed_format_exc())
def Latex精细分解与转化(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, mode='proofread', switch_prompt=None, opts=[]):
import time, os, re
from .crazy_utils import request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency
from .latex_utils import LatexPaperFileGroup, merge_tex_files, LatexPaperSplit, 寻找Latex主文件
# <-------- 寻找主tex文件 ---------->
maintex = 寻找Latex主文件(file_manifest, mode)
chatbot.append((f"定位主Latex文件", f'[Local Message] 分析结果该项目的Latex主文件是{maintex}, 如果分析错误, 请立即终止程序, 删除或修改歧义文件, 然后重试。主程序即将开始, 请稍候。'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
time.sleep(3)
# <-------- 读取Latex文件, 将多文件tex工程融合为一个巨型tex ---------->
main_tex_basename = os.path.basename(maintex)
assert main_tex_basename.endswith('.tex')
main_tex_basename_bare = main_tex_basename[:-4]
may_exist_bbl = pj(project_folder, f'{main_tex_basename_bare}.bbl')
if os.path.exists(may_exist_bbl):
shutil.copyfile(may_exist_bbl, pj(project_folder, f'merge.bbl'))
shutil.copyfile(may_exist_bbl, pj(project_folder, f'merge_{mode}.bbl'))
shutil.copyfile(may_exist_bbl, pj(project_folder, f'merge_diff.bbl'))
with open(maintex, 'r', encoding='utf-8', errors='replace') as f:
content = f.read()
merged_content = merge_tex_files(project_folder, content, mode)
with open(project_folder + '/merge.tex', 'w', encoding='utf-8', errors='replace') as f:
f.write(merged_content)
# <-------- 精细切分latex文件 ---------->
chatbot.append((f"Latex文件融合完成", f'[Local Message] 正在精细切分latex文件,这需要一段时间计算,文档越长耗时越长,请耐心等待。'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
lps = LatexPaperSplit()
res = lps.split(merged_content, project_folder, opts) # 消耗时间的函数
# <-------- 拆分过长的latex片段 ---------->
pfg = LatexPaperFileGroup()
for index, r in enumerate(res):
pfg.file_paths.append('segment-' + str(index))
pfg.file_contents.append(r)
pfg.run_file_split(max_token_limit=1024)
n_split = len(pfg.sp_file_contents)
# <-------- 根据需要切换prompt ---------->
inputs_array, sys_prompt_array = switch_prompt(pfg, mode)
inputs_show_user_array = [f"{mode} {f}" for f in pfg.sp_file_tag]
if os.path.exists(pj(project_folder,'temp.pkl')):
# <-------- 【仅调试】如果存在调试缓存文件,则跳过GPT请求环节 ---------->
pfg = objload(file=pj(project_folder,'temp.pkl'))
else:
# <-------- gpt 多线程请求 ---------->
gpt_response_collection = yield from request_gpt_model_multi_threads_with_very_awesome_ui_and_high_efficiency(
inputs_array=inputs_array,
inputs_show_user_array=inputs_show_user_array,
llm_kwargs=llm_kwargs,
chatbot=chatbot,
history_array=[[""] for _ in range(n_split)],
sys_prompt_array=sys_prompt_array,
# max_workers=5, # 并行任务数量限制, 最多同时执行5个, 其他的排队等待
scroller_max_len = 40
)
# <-------- 文本碎片重组为完整的tex片段 ---------->
pfg.sp_file_result = []
for i_say, gpt_say, orig_content in zip(gpt_response_collection[0::2], gpt_response_collection[1::2], pfg.sp_file_contents):
pfg.sp_file_result.append(gpt_say)
pfg.merge_result()
# <-------- 临时存储用于调试 ---------->
pfg.get_token_num = None
objdump(pfg, file=pj(project_folder,'temp.pkl'))
write_html(pfg.sp_file_contents, pfg.sp_file_result, chatbot=chatbot, project_folder=project_folder)
# <-------- 写出文件 ---------->
msg = f"当前大语言模型: {llm_kwargs['llm_model']},当前语言模型温度设定: {llm_kwargs['temperature']}"
final_tex = lps.merge_result(pfg.file_result, mode, msg)
with open(project_folder + f'/merge_{mode}.tex', 'w', encoding='utf-8', errors='replace') as f:
if mode != 'translate_zh' or "binary" in final_tex: f.write(final_tex)
# <-------- 整理结果, 退出 ---------->
chatbot.append((f"完成了吗?", 'GPT结果已输出, 正在编译PDF'))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# <-------- 返回 ---------->
return project_folder + f'/merge_{mode}.tex'
def remove_buggy_lines(file_path, log_path, tex_name, tex_name_pure, n_fix, work_folder_modified):
try:
with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
log = f.read()
with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
file_lines = f.readlines()
import re
buggy_lines = re.findall(tex_name+':([0-9]{1,5}):', log)
buggy_lines = [int(l) for l in buggy_lines]
buggy_lines = sorted(buggy_lines)
print("removing lines that has errors", buggy_lines)
file_lines.pop(buggy_lines[0]-1)
with open(pj(work_folder_modified, f"{tex_name_pure}_fix_{n_fix}.tex"), 'w', encoding='utf-8', errors='replace') as f:
f.writelines(file_lines)
return True, f"{tex_name_pure}_fix_{n_fix}", buggy_lines
except:
print("Fatal error occurred, but we cannot identify error, please download zip, read latex log, and compile manually.")
return False, -1, [-1]
def compile_latex_with_timeout(command, cwd, timeout=60):
import subprocess
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd)
try:
stdout, stderr = process.communicate(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
stdout, stderr = process.communicate()
print("Process timed out!")
return False
return True
def 编译Latex(chatbot, history, main_file_original, main_file_modified, work_folder_original, work_folder_modified, work_folder, mode='default'):
import os, time
current_dir = os.getcwd()
n_fix = 1
max_try = 32
chatbot.append([f"正在编译PDF文档", f'编译已经开始。当前工作路径为{work_folder},如果程序停顿5分钟以上,请直接去该路径下取回翻译结果,或者重启之后再度尝试 ...']); yield from update_ui(chatbot=chatbot, history=history)
chatbot.append([f"正在编译PDF文档", '...']); yield from update_ui(chatbot=chatbot, history=history); time.sleep(1); chatbot[-1] = list(chatbot[-1]) # 刷新界面
yield from update_ui_lastest_msg('编译已经开始...', chatbot, history) # 刷新Gradio前端界面
while True:
import os
# https://stackoverflow.com/questions/738755/dont-make-me-manually-abort-a-latex-compile-when-theres-an-error
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译原始PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译转化后的PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
if ok and os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf')):
# 只有第二步成功,才能继续下面的步骤
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译BibTex ...', chatbot, history) # 刷新Gradio前端界面
if not os.path.exists(pj(work_folder_original, f'{main_file_original}.bbl')):
ok = compile_latex_with_timeout(f'bibtex {main_file_original}.aux', work_folder_original)
if not os.path.exists(pj(work_folder_modified, f'{main_file_modified}.bbl')):
ok = compile_latex_with_timeout(f'bibtex {main_file_modified}.aux', work_folder_modified)
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 编译文献交叉引用 ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_original}.tex', work_folder_original)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error {main_file_modified}.tex', work_folder_modified)
if mode!='translate_zh':
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 使用latexdiff生成论文转化前后对比 ...', chatbot, history) # 刷新Gradio前端界面
print( f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
ok = compile_latex_with_timeout(f'latexdiff --encoding=utf8 --append-safecmd=subfile {work_folder_original}/{main_file_original}.tex {work_folder_modified}/{main_file_modified}.tex --flatten > {work_folder}/merge_diff.tex')
yield from update_ui_lastest_msg(f'尝试第 {n_fix}/{max_try} 次编译, 正在编译对比PDF ...', chatbot, history) # 刷新Gradio前端界面
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
ok = compile_latex_with_timeout(f'bibtex merge_diff.aux', work_folder)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
ok = compile_latex_with_timeout(f'pdflatex -interaction=batchmode -file-line-error merge_diff.tex', work_folder)
# <---------- 检查结果 ----------->
results_ = ""
original_pdf_success = os.path.exists(pj(work_folder_original, f'{main_file_original}.pdf'))
modified_pdf_success = os.path.exists(pj(work_folder_modified, f'{main_file_modified}.pdf'))
diff_pdf_success = os.path.exists(pj(work_folder, f'merge_diff.pdf'))
results_ += f"原始PDF编译是否成功: {original_pdf_success};"
results_ += f"转化PDF编译是否成功: {modified_pdf_success};"
results_ += f"对比PDF编译是否成功: {diff_pdf_success};"
yield from update_ui_lastest_msg(f'{n_fix}编译结束:<br/>{results_}...', chatbot, history) # 刷新Gradio前端界面
if diff_pdf_success:
result_pdf = pj(work_folder_modified, f'merge_diff.pdf') # get pdf path
promote_file_to_downloadzone(result_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
if modified_pdf_success:
yield from update_ui_lastest_msg(f'转化PDF编译已经成功, 即将退出 ...', chatbot, history) # 刷新Gradio前端界面
result_pdf = pj(work_folder_modified, f'{main_file_modified}.pdf') # get pdf path
if os.path.exists(pj(work_folder, '..', 'translation')):
shutil.copyfile(result_pdf, pj(work_folder, '..', 'translation', 'translate_zh.pdf'))
promote_file_to_downloadzone(result_pdf, rename_file=None, chatbot=chatbot) # promote file to web UI
return True # 成功啦
else:
if n_fix>=max_try: break
n_fix += 1
can_retry, main_file_modified, buggy_lines = remove_buggy_lines(
file_path=pj(work_folder_modified, f'{main_file_modified}.tex'),
log_path=pj(work_folder_modified, f'{main_file_modified}.log'),
tex_name=f'{main_file_modified}.tex',
tex_name_pure=f'{main_file_modified}',
n_fix=n_fix,
work_folder_modified=work_folder_modified,
)
yield from update_ui_lastest_msg(f'由于最为关键的转化PDF编译失败, 将根据报错信息修正tex源文件并重试, 当前报错的latex代码处于第{buggy_lines}行 ...', chatbot, history) # 刷新Gradio前端界面
if not can_retry: break
return False # 失败啦

查看文件

@@ -1,6 +1,5 @@
import time, json, sys, struct
import time, logging, json, sys, struct
import numpy as np
from loguru import logger as logging
from scipy.io.wavfile import WAVE_FORMAT
def write_numpy_to_wave(filename, rate, data, add_header=False):
@@ -86,8 +85,8 @@ def write_numpy_to_wave(filename, rate, data, add_header=False):
def is_speaker_speaking(vad, data, sample_rate):
# Function to detect if the speaker is speaking
# The WebRTC VAD only accepts 16-bit mono PCM audio,
# sampled at 8000, 16000, 32000 or 48000 Hz.
# The WebRTC VAD only accepts 16-bit mono PCM audio,
# sampled at 8000, 16000, 32000 or 48000 Hz.
# A frame must be either 10, 20, or 30 ms in duration:
frame_duration = 30
n_bit_each = int(sample_rate * frame_duration / 1000)*2 # x2 because audio is 16 bit (2 bytes)
@@ -95,7 +94,7 @@ def is_speaker_speaking(vad, data, sample_rate):
for t in range(len(data)):
if t!=0 and t % n_bit_each == 0:
res_list.append(vad.is_speech(data[t-n_bit_each:t], sample_rate))
info = ''.join(['^' if r else '.' for r in res_list])
info = info[:10]
if any(res_list):
@@ -107,14 +106,18 @@ def is_speaker_speaking(vad, data, sample_rate):
class AliyunASR():
def test_on_sentence_begin(self, message, *args):
# print("test_on_sentence_begin:{}".format(message))
pass
def test_on_sentence_end(self, message, *args):
# print("test_on_sentence_end:{}".format(message))
message = json.loads(message)
self.parsed_sentence = message['payload']['result']
self.event_on_entence_end.set()
# print(self.parsed_sentence)
def test_on_start(self, message, *args):
# print("test_on_start:{}".format(message))
pass
def test_on_error(self, message, *args):
@@ -126,11 +129,13 @@ class AliyunASR():
pass
def test_on_result_chg(self, message, *args):
# print("test_on_chg:{}".format(message))
message = json.loads(message)
self.parsed_text = message['payload']['result']
self.event_on_result_chg.set()
def test_on_completed(self, message, *args):
# print("on_completed:args=>{} message=>{}".format(args, message))
pass
def audio_convertion_thread(self, uuid):
@@ -181,10 +186,10 @@ class AliyunASR():
keep_alive_last_send_time = time.time()
while not self.stop:
# time.sleep(self.capture_interval)
audio = rad.read(uuid.hex)
audio = rad.read(uuid.hex)
if audio is not None:
# convert to pcm file
temp_file = f'{temp_folder}/{uuid.hex}.pcm' #
temp_file = f'{temp_folder}/{uuid.hex}.pcm' #
dsdata = change_sample_rate(audio, rad.rate, NEW_SAMPLERATE) # 48000 --> 16000
write_numpy_to_wave(temp_file, NEW_SAMPLERATE, dsdata)
# read pcm binary
@@ -243,14 +248,14 @@ class AliyunASR():
try:
response = client.do_action_with_exception(request)
logging.info(response)
print(response)
jss = json.loads(response)
if 'Token' in jss and 'Id' in jss['Token']:
token = jss['Token']['Id']
expireTime = jss['Token']['ExpireTime']
logging.info("token = " + token)
logging.info("expireTime = " + str(expireTime))
print("token = " + token)
print("expireTime = " + str(expireTime))
except Exception as e:
logging.error(e)
print(e)
return token

查看文件

@@ -3,12 +3,12 @@ from scipy import interpolate
def Singleton(cls):
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
@@ -39,7 +39,7 @@ class RealtimeAudioDistribution():
else:
res = None
return res
def change_sample_rate(audio, old_sr, new_sr):
duration = audio.shape[0] / old_sr

查看文件

@@ -40,7 +40,7 @@ class GptAcademicState():
class GptAcademicGameBaseState():
"""
1. first init: __init__ ->
1. first init: __init__ ->
"""
def init_game(self, chatbot, lock_plugin):
self.plugin_name = None
@@ -53,7 +53,7 @@ class GptAcademicGameBaseState():
raise ValueError("callback_fn is None")
chatbot._cookies['lock_plugin'] = self.callback_fn
self.dump_state(chatbot)
def get_plugin_name(self):
if self.plugin_name is None:
raise ValueError("plugin_name is None")
@@ -71,7 +71,7 @@ class GptAcademicGameBaseState():
state = chatbot._cookies.get(f'plugin_state/{plugin_name}', None)
if state is not None:
state = pickle.loads(state)
else:
else:
state = cls()
state.init_game(chatbot, lock_plugin)
state.plugin_name = plugin_name
@@ -79,7 +79,7 @@ class GptAcademicGameBaseState():
state.chatbot = chatbot
state.callback_fn = callback_fn
return state
def continue_game(self, prompt, chatbot, history):
# 游戏主体
yield from self.step(prompt, chatbot, history)

查看文件

@@ -1,5 +1,4 @@
from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout
from loguru import logger
def force_breakdown(txt, limit, get_token_fn):
""" 当无法用标点、空行分割时,我们用最暴力的方法切割
@@ -36,7 +35,7 @@ def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=F
remain_txt_to_cut_storage = ""
# 为了加速计算,我们采样一个特殊的手段。当 remain_txt_to_cut > `_max` 时, 我们把 _max 后的文字转存至 remain_txt_to_cut_storage
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
while True:
if get_token_fn(remain_txt_to_cut) <= limit:
# 如果剩余文本的token数小于限制,那么就不用切了
@@ -77,7 +76,7 @@ def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=F
remain_txt_to_cut = post
remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
process = fin_len/total_len
logger.info(f'正在文本切分 {int(process*100)}%')
print(f'正在文本切分 {int(process*100)}%')
if len(remain_txt_to_cut.strip()) == 0:
break
return res
@@ -120,7 +119,7 @@ if __name__ == '__main__':
for i in range(5):
file_content += file_content
logger.info(len(file_content))
print(len(file_content))
TOKEN_LIMIT_PER_FRAGMENT = 2500
res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)

查看文件

@@ -4,7 +4,7 @@ from toolbox import promote_file_to_downloadzone
from toolbox import write_history_to_file, promote_file_to_downloadzone
from toolbox import get_conf
from toolbox import ProxyNetworkActivate
from shared_utils.colorful import *
from colorful import *
import requests
import random
import copy
@@ -64,15 +64,15 @@ def produce_report_markdown(gpt_response_collection, meta, paper_meta_info, chat
# 再做一个小修改重新修改当前part的标题,默认用英文的
cur_value += value
translated_res_array.append(cur_value)
res_path = write_history_to_file(meta + ["# Meta Translation" , paper_meta_info] + translated_res_array,
file_basename = f"{gen_time_str()}-translated_only.md",
res_path = write_history_to_file(meta + ["# Meta Translation" , paper_meta_info] + translated_res_array,
file_basename = f"{gen_time_str()}-translated_only.md",
file_fullname = None,
auto_caption = False)
promote_file_to_downloadzone(res_path, rename_file=os.path.basename(res_path)+'.md', chatbot=chatbot)
generated_conclusion_files.append(res_path)
return res_path
def translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_files, TOKEN_LIMIT_PER_FRAGMENT, DST_LANG, plugin_kwargs={}):
def translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_files, TOKEN_LIMIT_PER_FRAGMENT, DST_LANG):
from crazy_functions.pdf_fns.report_gen_html import construct_html
from crazy_functions.pdf_fns.breakdown_txt import breakdown_text_to_satisfy_token_limit
from crazy_functions.crazy_utils import request_gpt_model_in_new_thread_with_ui_alive
@@ -138,17 +138,17 @@ def translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_fi
chatbot=chatbot,
history_array=[meta for _ in inputs_array],
sys_prompt_array=[
"请你作为一个学术翻译,负责把学术论文准确翻译成中文。注意文章中的每一句话都要翻译。" + plugin_kwargs.get("additional_prompt", "") for _ in inputs_array],
"请你作为一个学术翻译,负责把学术论文准确翻译成中文。注意文章中的每一句话都要翻译。" for _ in inputs_array],
)
# -=-=-=-=-=-=-=-= 写出Markdown文件 -=-=-=-=-=-=-=-=
produce_report_markdown(gpt_response_collection, meta, paper_meta_info, chatbot, fp, generated_conclusion_files)
# -=-=-=-=-=-=-=-= 写出HTML文件 -=-=-=-=-=-=-=-=
ch = construct_html()
ch = construct_html()
orig = ""
trans = ""
gpt_response_collection_html = copy.deepcopy(gpt_response_collection)
for i,k in enumerate(gpt_response_collection_html):
for i,k in enumerate(gpt_response_collection_html):
if i%2==0:
gpt_response_collection_html[i] = inputs_show_user_array[i//2]
else:
@@ -159,7 +159,7 @@ def translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_fi
final = ["", "", "一、论文概况", "", "Abstract", paper_meta_info, "二、论文翻译", ""]
final.extend(gpt_response_collection_html)
for i, k in enumerate(final):
for i, k in enumerate(final):
if i%2==0:
orig = k
if i%2==1:

查看文件

@@ -1,26 +0,0 @@
import os
from toolbox import CatchException, report_exception, get_log_folder, gen_time_str, check_packages
from toolbox import update_ui, promote_file_to_downloadzone, update_ui_lastest_msg, disable_auto_promotion
from toolbox import write_history_to_file, promote_file_to_downloadzone, get_conf, extract_archive
from crazy_functions.pdf_fns.parse_pdf import parse_pdf, translate_pdf
def 解析PDF_基于GROBID(file_manifest, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, grobid_url):
import copy, json
TOKEN_LIMIT_PER_FRAGMENT = 1024
generated_conclusion_files = []
generated_html_files = []
DST_LANG = "中文"
from crazy_functions.pdf_fns.report_gen_html import construct_html
for index, fp in enumerate(file_manifest):
chatbot.append(["当前进度:", f"正在连接GROBID服务,请稍候: {grobid_url}\n如果等待时间过长,请修改config中的GROBID_URL,可修改成本地GROBID服务。"]); yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
article_dict = parse_pdf(fp, grobid_url)
grobid_json_res = os.path.join(get_log_folder(), gen_time_str() + "grobid.json")
with open(grobid_json_res, 'w+', encoding='utf8') as f:
f.write(json.dumps(article_dict, indent=4, ensure_ascii=False))
promote_file_to_downloadzone(grobid_json_res, chatbot=chatbot)
if article_dict is None: raise RuntimeError("解析PDF失败,请检查PDF是否损坏。")
yield from translate_pdf(article_dict, llm_kwargs, chatbot, fp, generated_conclusion_files, TOKEN_LIMIT_PER_FRAGMENT, DST_LANG, plugin_kwargs=plugin_kwargs)
chatbot.append(("给出输出文件清单", str(generated_conclusion_files + generated_html_files)))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面

查看文件

@@ -1,250 +0,0 @@
from toolbox import get_log_folder, gen_time_str, get_conf
from toolbox import update_ui, promote_file_to_downloadzone
from toolbox import promote_file_to_downloadzone, extract_archive
from toolbox import generate_file_link, zip_folder
from crazy_functions.crazy_utils import get_files_from_everything
from shared_utils.colorful import *
from loguru import logger
import os
import time
def refresh_key(doc2x_api_key):
import requests, json
url = "https://api.doc2x.noedgeai.com/api/token/refresh"
res = requests.post(
url,
headers={"Authorization": "Bearer " + doc2x_api_key}
)
res_json = []
if res.status_code == 200:
decoded = res.content.decode("utf-8")
res_json = json.loads(decoded)
doc2x_api_key = res_json['data']['token']
else:
raise RuntimeError(format("[ERROR] status code: %d, body: %s" % (res.status_code, res.text)))
return doc2x_api_key
def 解析PDF_DOC2X_转Latex(pdf_file_path):
zip_file_path, unzipped_folder = 解析PDF_DOC2X(pdf_file_path, format='tex')
return unzipped_folder
def 解析PDF_DOC2X(pdf_file_path, format='tex'):
"""
format: 'tex', 'md', 'docx'
"""
import requests, json, os
DOC2X_API_KEY = get_conf('DOC2X_API_KEY')
latex_dir = get_log_folder(plugin_name="pdf_ocr_latex")
markdown_dir = get_log_folder(plugin_name="pdf_ocr")
doc2x_api_key = DOC2X_API_KEY
# < ------ 第1步上传 ------ >
logger.info("Doc2x 第1步上传")
with open(pdf_file_path, 'rb') as file:
res = requests.post(
"https://v2.doc2x.noedgeai.com/api/v2/parse/pdf",
headers={"Authorization": "Bearer " + doc2x_api_key},
data=file
)
# res_json = []
if res.status_code == 200:
res_json = res.json()
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
uuid = res_json['data']['uid']
# < ------ 第2步轮询等待 ------ >
logger.info("Doc2x 第2步轮询等待")
params = {'uid': uuid}
while True:
res = requests.get(
'https://v2.doc2x.noedgeai.com/api/v2/parse/status',
headers={"Authorization": "Bearer " + doc2x_api_key},
params=params
)
res_json = res.json()
if res_json['data']['status'] == "success":
break
elif res_json['data']['status'] == "processing":
time.sleep(3)
logger.info(f"Doc2x is processing at {res_json['data']['progress']}%")
elif res_json['data']['status'] == "failed":
raise RuntimeError(f"Doc2x return an error: {res_json}")
# < ------ 第3步提交转化 ------ >
logger.info("Doc2x 第3步提交转化")
data = {
"uid": uuid,
"to": format,
"formula_mode": "dollar",
"filename": "output"
}
res = requests.post(
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse',
headers={"Authorization": "Bearer " + doc2x_api_key},
json=data
)
if res.status_code == 200:
res_json = res.json()
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
# < ------ 第4步等待结果 ------ >
logger.info("Doc2x 第4步等待结果")
params = {'uid': uuid}
while True:
res = requests.get(
'https://v2.doc2x.noedgeai.com/api/v2/convert/parse/result',
headers={"Authorization": "Bearer " + doc2x_api_key},
params=params
)
res_json = res.json()
if res_json['data']['status'] == "success":
break
elif res_json['data']['status'] == "processing":
time.sleep(3)
logger.info(f"Doc2x still processing")
elif res_json['data']['status'] == "failed":
raise RuntimeError(f"Doc2x return an error: {res_json}")
# < ------ 第5步最后的处理 ------ >
logger.info("Doc2x 第5步最后的处理")
if format=='tex':
target_path = latex_dir
if format=='md':
target_path = markdown_dir
os.makedirs(target_path, exist_ok=True)
max_attempt = 3
# < ------ 下载 ------ >
for attempt in range(max_attempt):
try:
result_url = res_json['data']['url']
res = requests.get(result_url)
zip_path = os.path.join(target_path, gen_time_str() + '.zip')
unzip_path = os.path.join(target_path, gen_time_str())
if res.status_code == 200:
with open(zip_path, "wb") as f: f.write(res.content)
else:
raise RuntimeError(f"Doc2x return an error: {res.json()}")
except Exception as e:
if attempt < max_attempt - 1:
logger.error(f"Failed to download latex file, retrying... {e}")
time.sleep(3)
continue
else:
raise e
# < ------ 解压 ------ >
import zipfile
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(unzip_path)
return zip_path, unzip_path
def 解析PDF_DOC2X_单文件(fp, project_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, DOC2X_API_KEY, user_request):
def pdf2markdown(filepath):
chatbot.append((None, f"Doc2x 解析中"))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
md_zip_path, unzipped_folder = 解析PDF_DOC2X(filepath, format='md')
promote_file_to_downloadzone(md_zip_path, chatbot=chatbot)
chatbot.append((None, f"完成解析 {md_zip_path} ..."))
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
return md_zip_path
def deliver_to_markdown_plugin(md_zip_path, user_request):
from crazy_functions.Markdown_Translate import Markdown英译中
import shutil, re
time_tag = gen_time_str()
target_path_base = get_log_folder(chatbot.get_user())
file_origin_name = os.path.basename(md_zip_path)
this_file_path = os.path.join(target_path_base, file_origin_name)
os.makedirs(target_path_base, exist_ok=True)
shutil.copyfile(md_zip_path, this_file_path)
ex_folder = this_file_path + ".extract"
extract_archive(
file_path=this_file_path, dest_dir=ex_folder
)
# edit markdown files
success, file_manifest, project_folder = get_files_from_everything(ex_folder, type='.md')
for generated_fp in file_manifest:
# 修正一些公式问题
with open(generated_fp, 'r', encoding='utf8') as f:
content = f.read()
# 将公式中的\[ \]替换成$$
content = content.replace(r'\[', r'$$').replace(r'\]', r'$$')
# 将公式中的\( \)替换成$
content = content.replace(r'\(', r'$').replace(r'\)', r'$')
content = content.replace('```markdown', '\n').replace('```', '\n')
with open(generated_fp, 'w', encoding='utf8') as f:
f.write(content)
promote_file_to_downloadzone(generated_fp, chatbot=chatbot)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
# 生成在线预览html
file_name = '在线预览翻译(原文)' + gen_time_str() + '.html'
preview_fp = os.path.join(ex_folder, file_name)
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
with open(generated_fp, "r", encoding="utf-8") as f:
md = f.read()
# # Markdown中使用不标准的表格,需要在表格前加上一个emoji,以便公式渲染
# md = re.sub(r'^<table>', r'.<table>', md, flags=re.MULTILINE)
html = markdown_convertion_for_file(md)
with open(preview_fp, "w", encoding="utf-8") as f: f.write(html)
chatbot.append([None, f"生成在线预览:{generate_file_link([preview_fp])}"])
promote_file_to_downloadzone(preview_fp, chatbot=chatbot)
chatbot.append((None, f"调用Markdown插件 {ex_folder} ..."))
plugin_kwargs['markdown_expected_output_dir'] = ex_folder
translated_f_name = 'translated_markdown.md'
generated_fp = plugin_kwargs['markdown_expected_output_path'] = os.path.join(ex_folder, translated_f_name)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
yield from Markdown英译中(ex_folder, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request)
if os.path.exists(generated_fp):
# 修正一些公式问题
with open(generated_fp, 'r', encoding='utf8') as f: content = f.read()
content = content.replace('```markdown', '\n').replace('```', '\n')
# Markdown中使用不标准的表格,需要在表格前加上一个emoji,以便公式渲染
# content = re.sub(r'^<table>', r'.<table>', content, flags=re.MULTILINE)
with open(generated_fp, 'w', encoding='utf8') as f: f.write(content)
# 生成在线预览html
file_name = '在线预览翻译' + gen_time_str() + '.html'
preview_fp = os.path.join(ex_folder, file_name)
from shared_utils.advanced_markdown_format import markdown_convertion_for_file
with open(generated_fp, "r", encoding="utf-8") as f:
md = f.read()
html = markdown_convertion_for_file(md)
with open(preview_fp, "w", encoding="utf-8") as f: f.write(html)
promote_file_to_downloadzone(preview_fp, chatbot=chatbot)
# 生成包含图片的压缩包
dest_folder = get_log_folder(chatbot.get_user())
zip_name = '翻译后的带图文档.zip'
zip_folder(source_folder=ex_folder, dest_folder=dest_folder, zip_name=zip_name)
zip_fp = os.path.join(dest_folder, zip_name)
promote_file_to_downloadzone(zip_fp, chatbot=chatbot)
yield from update_ui(chatbot=chatbot, history=history) # 刷新界面
md_zip_path = yield from pdf2markdown(fp)
yield from deliver_to_markdown_plugin(md_zip_path, user_request)
def 解析PDF_基于DOC2X(file_manifest, *args):
for index, fp in enumerate(file_manifest):
yield from 解析PDF_DOC2X_单文件(fp, *args)
return

查看文件

@@ -22,10 +22,10 @@ def extract_text_from_files(txt, chatbot, history):
file_manifest = []
excption = ""
if txt == "":
if txt == "":
final_result.append(txt)
return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容
#查找输入区内容中的文件
file_pdf,pdf_manifest,folder_pdf = get_files_from_everything(txt, '.pdf')
file_md,md_manifest,folder_md = get_files_from_everything(txt, '.md')
@@ -35,12 +35,12 @@ def extract_text_from_files(txt, chatbot, history):
if file_doc:
excption = "word"
return False, final_result, page_one, file_manifest, excption
file_num = len(pdf_manifest) + len(md_manifest) + len(word_manifest)
if file_num == 0:
final_result.append(txt)
return False, final_result, page_one, file_manifest, excption #如输入区内容不是文件则直接返回输入区内容
if file_pdf:
try: # 尝试导入依赖,如果缺少依赖,则给出安装建议
import fitz
@@ -61,7 +61,7 @@ def extract_text_from_files(txt, chatbot, history):
file_content = f.read()
file_content = file_content.encode('utf-8', 'ignore').decode()
headers = re.findall(r'^#\s(.*)$', file_content, re.MULTILINE) #接下来提取md中的一级/二级标题作为摘要
if len(headers) > 0:
if len(headers) > 0:
page_one.append("\n".join(headers)) #合并所有的标题,以换行符分割
else:
page_one.append("")
@@ -81,5 +81,5 @@ def extract_text_from_files(txt, chatbot, history):
page_one.append(file_content[:200])
final_result.append(file_content)
file_manifest.append(os.path.relpath(fp, folder_word))
return True, final_result, page_one, file_manifest, excption

查看文件

@@ -1,73 +0,0 @@
<!DOCTYPE html>
<html xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />
<title>GPT-Academic 翻译报告书</title>
<style>
.centered-a {
color: red;
text-align: center;
margin-bottom: 2%;
font-size: 1.5em;
}
.centered-b {
color: red;
text-align: center;
margin-top: 10%;
margin-bottom: 20%;
font-size: 1.5em;
}
.centered-c {
color: rgba(255, 0, 0, 0);
text-align: center;
margin-top: 2%;
margin-bottom: 20%;
font-size: 7em;
}
</style>
<script>
// Configure MathJax settings
MathJax = {
tex: {
inlineMath: [
['$', '$'],
['\(', '\)']
]
}
}
addEventListener('zero-md-rendered', () => {MathJax.typeset(); console.log('MathJax typeset!');})
</script>
<!-- Load MathJax library -->
<script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml.js"></script>
<script
type="module"
src="https://cdn.jsdelivr.net/gh/zerodevx/zero-md@2/dist/zero-md.min.js"
></script>
</head>
<body>
<div class="test_temp1" style="width:10%; height: 500px; float:left;">
</div>
<div class="test_temp2" style="width:80%; height: 500px; float:left;">
<!-- Simply set the `src` attribute to your MD file and win -->
<div class="centered-a">
请按Ctrl+S保存此页面,否则该页面可能在几分钟后失效。
</div>
<zero-md src="translated_markdown.md" no-shadow>
</zero-md>
<div class="centered-b">
本报告由GPT-Academic开源项目生成,地址https://github.com/binary-husky/gpt_academic。
</div>
<div class="centered-c">
本报告由GPT-Academic开源项目生成,地址https://github.com/binary-husky/gpt_academic。
</div>
</div>
<div class="test_temp3" style="width:10%; height: 500px; float:left;">
</div>
</body>
</html>

查看文件

@@ -1,52 +0,0 @@
import os, json, base64
from pydantic import BaseModel, Field
from textwrap import dedent
from typing import List
class ArgProperty(BaseModel): # PLUGIN_ARG_MENU
title: str = Field(description="The title", default="")
description: str = Field(description="The description", default="")
default_value: str = Field(description="The default value", default="")
type: str = Field(description="The type", default="") # currently we support ['string', 'dropdown']
options: List[str] = Field(default=[], description="List of options available for the argument") # only used when type is 'dropdown'
class GptAcademicPluginTemplate():
def __init__(self):
# please note that `execute` method may run in different threads,
# thus you should not store any state in the plugin instance,
# which may be accessed by multiple threads
pass
def define_arg_selection_menu(self):
"""
An example as below:
```
def define_arg_selection_menu(self):
gui_definition = {
"main_input":
ArgProperty(title="main input", description="description", default_value="default_value", type="string").model_dump_json(),
"advanced_arg":
ArgProperty(title="advanced arguments", description="description", default_value="default_value", type="string").model_dump_json(),
"additional_arg_01":
ArgProperty(title="additional", description="description", default_value="default_value", type="string").model_dump_json(),
}
return gui_definition
```
"""
raise NotImplementedError("You need to implement this method in your plugin class")
def get_js_code_for_generating_menu(self, btnName):
define_arg_selection = self.define_arg_selection_menu()
if len(define_arg_selection.keys()) > 8:
raise ValueError("You can only have up to 8 arguments in the define_arg_selection")
# if "main_input" not in define_arg_selection:
# raise ValueError("You must have a 'main_input' in the define_arg_selection")
DEFINE_ARG_INPUT_INTERFACE = json.dumps(define_arg_selection)
return base64.b64encode(DEFINE_ARG_INPUT_INTERFACE.encode('utf-8')).decode('utf-8')
def execute(txt, llm_kwargs, plugin_kwargs, chatbot, history, system_prompt, user_request):
raise NotImplementedError("You need to implement this method in your plugin class")

查看文件

@@ -1,87 +0,0 @@
SearchOptimizerPrompt="""作为一个网页搜索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高网页检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如:
历史记录:
"
Q: 对话背景。
A: 当前对话是关于 Nginx 的介绍和在Ubuntu上的使用等。
"
原问题: 怎么下载
检索词: ["Nginx 下载","Ubuntu Nginx","Ubuntu安装Nginx"]
----------------
历史记录:
"
Q: 对话背景。
A: 当前对话是关于 Nginx 的介绍和使用等。
Q: 报错 "no connection"
A: 报错"no connection"可能是因为……
"
原问题: 怎么解决
检索词: ["Nginx报错"no connection" 解决","Nginx'no connection'报错 原因","Nginx提示'no connection'"]
----------------
历史记录:
"
"
原问题: 你知道 Python 么?
检索词: ["Python","Python 使用教程。","Python 特点和优势"]
----------------
历史记录:
"
Q: 列出Java的三种特点?
A: 1. Java 是一种编译型语言。
2. Java 是一种面向对象的编程语言。
3. Java 是一种跨平台的编程语言。
"
原问题: 介绍下第2点。
检索词: ["Java 面向对象特点","Java 面向对象编程优势。","Java 面向对象编程"]
----------------
现在有历史记录:
"
{history}
"
有其原问题: {query}
直接给出最多{num}个检索词,必须以json形式给出,不得有多余字符:
"""
SearchAcademicOptimizerPrompt="""作为一个学术论文搜索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高学术论文检索的精度。生成的问题要求指向对象清晰明确,并与“原问题语言相同”。例如:
历史记录:
"
Q: 对话背景。
A: 当前对话是关于深度学习的介绍和在图像识别中的应用等。
"
原问题: 怎么下载相关论文
检索词: ["深度学习 图像识别 论文下载","图像识别 深度学习 研究论文","深度学习 图像识别 论文资源","Deep Learning Image Recognition Paper Download","Image Recognition Deep Learning Research Paper"]
----------------
历史记录:
"
Q: 对话背景。
A: 当前对话是关于深度学习的介绍和应用等。
Q: 报错 "模型不收敛"
A: 报错"模型不收敛"可能是因为……
"
原问题: 怎么解决
检索词: ["深度学习 模型不收敛 解决方案 论文","深度学习 模型不收敛 原因 研究","深度学习 模型不收敛 论文","Deep Learning Model Convergence Issue Solution Paper","Deep Learning Model Convergence Problem Research"]
----------------
历史记录:
"
"
原问题: 你知道 GAN 么?
检索词: ["生成对抗网络 论文","GAN 使用教程 论文","GAN 特点和优势 研究","Generative Adversarial Network Paper","GAN Usage Tutorial Paper"]
----------------
历史记录:
"
Q: 列出机器学习的三种应用?
A: 1. 机器学习在图像识别中的应用。
2. 机器学习在自然语言处理中的应用。
3. 机器学习在推荐系统中的应用。
"
原问题: 介绍下第2点。
检索词: ["机器学习 自然语言处理 应用 论文","机器学习 自然语言处理 研究","机器学习 NLP 应用 论文","Machine Learning Natural Language Processing Application Paper","Machine Learning NLP Research"]
----------------
现在有历史记录:
"
{history}
"
有其原问题: {query}
直接给出最多{num}个检索词,必须以json形式给出,不得有多余字符:
"""

查看文件

@@ -1,138 +0,0 @@
import atexit
from loguru import logger
from typing import List
from llama_index.core import Document
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import TextNode
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
---------------------
{query_str}
"""
QUESTION_ANSWER_RECORD = """\
{{
"type": "This is a previous conversation with the user",
"question": "{question}",
"answer": "{answer}",
}}
"""
class SaveLoad():
def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
return True
def save_to_checkpoint(self, checkpoint_dir=None):
logger.info(f'saving vector store to: {checkpoint_dir}')
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
logger.info('loading checkpoint from disk')
from llama_index.core import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
return self.vs_index
else:
return self.create_new_vs()
def create_new_vs(self):
return GptacVectorStoreIndex.default_vector_store(embed_model=self.embed_model)
def purge(self):
import shutil
shutil.rmtree(self.checkpoint_dir, ignore_errors=True)
self.vs_index = self.create_new_vs(self.checkpoint_dir)
class LlamaIndexRagWorker(SaveLoad):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.debug_mode = True
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else:
self.vs_index = self.create_new_vs()
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
def assign_embedding_model(self):
pass
def inspect_vector_store(self):
# This function is for debugging
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
logger.info('\n++ --------inspect_vector_store begin--------')
logger.info(vector_store_preview)
logger.info('oo --------inspect_vector_store end--------')
return vector_store_preview
def add_documents_to_vector_store(self, document_list: List[Document]):
"""
Adds a list of Document objects to the vector store after processing.
"""
documents = document_list
documents_nodes = run_transformations(
documents, # type: ignore
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
def add_text_to_vector_store(self, text: str):
node = TextNode(text=text)
documents_nodes = run_transformations(
[node],
self.vs_index._transformations,
show_progress=True
)
self.vs_index.insert_nodes(documents_nodes)
if self.debug_mode:
self.inspect_vector_store()
def remember_qa(self, question, answer):
formatted_str = QUESTION_ANSWER_RECORD.format(question=question, answer=answer)
self.add_text_to_vector_store(formatted_str)
def retrieve_from_store_with_query(self, query):
if self.debug_mode:
self.inspect_vector_store()
retriever = self.vs_index.as_retriever()
return retriever.retrieve(query)
def build_prompt(self, query, nodes):
context_str = self.generate_node_array_preview(nodes)
return DEFAULT_QUERY_GENERATION_PROMPT.format(context_str=context_str, query_str=query)
def generate_node_array_preview(self, nodes):
buf = "\n".join(([f"(No.{i+1} | score {n.score:.3f}): {n.text}" for i, n in enumerate(nodes)]))
if self.debug_mode: logger.info(buf)
return buf
def purge_vector_store(self):
"""
Purges the current vector store and creates a new one.
"""
self.purge()

查看文件

@@ -1,108 +0,0 @@
import llama_index
import os
import atexit
from typing import List
from loguru import logger
from llama_index.core import Document
from llama_index.core.schema import TextNode
from request_llms.embed_models.openai_embed import OpenAiEmbeddingModel
from shared_utils.connect_void_terminal import get_chat_default_kwargs
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from crazy_functions.rag_fns.vector_store_index import GptacVectorStoreIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core import PromptTemplate
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.core import StorageContext
from llama_index.vector_stores.milvus import MilvusVectorStore
from crazy_functions.rag_fns.llama_index_worker import LlamaIndexRagWorker
DEFAULT_QUERY_GENERATION_PROMPT = """\
Now, you have context information as below:
---------------------
{context_str}
---------------------
Answer the user request below (use the context information if necessary, otherwise you can ignore them):
---------------------
{query_str}
"""
QUESTION_ANSWER_RECORD = """\
{{
"type": "This is a previous conversation with the user",
"question": "{question}",
"answer": "{answer}",
}}
"""
class MilvusSaveLoad():
def does_checkpoint_exist(self, checkpoint_dir=None):
import os, glob
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if not os.path.exists(checkpoint_dir): return False
if len(glob.glob(os.path.join(checkpoint_dir, "*.json"))) == 0: return False
return True
def save_to_checkpoint(self, checkpoint_dir=None):
logger.info(f'saving vector store to: {checkpoint_dir}')
# if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
# self.vs_index.storage_context.persist(persist_dir=checkpoint_dir)
def load_from_checkpoint(self, checkpoint_dir=None):
if checkpoint_dir is None: checkpoint_dir = self.checkpoint_dir
if self.does_checkpoint_exist(checkpoint_dir=checkpoint_dir):
logger.info('loading checkpoint from disk')
from llama_index.core import StorageContext, load_index_from_storage
storage_context = StorageContext.from_defaults(persist_dir=checkpoint_dir)
try:
self.vs_index = load_index_from_storage(storage_context, embed_model=self.embed_model)
return self.vs_index
except:
return self.create_new_vs(checkpoint_dir)
else:
return self.create_new_vs(checkpoint_dir)
def create_new_vs(self, checkpoint_dir, overwrite=False):
vector_store = MilvusVectorStore(
uri=os.path.join(checkpoint_dir, "milvus_demo.db"),
dim=self.embed_model.embedding_dimension(),
overwrite=overwrite
)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = GptacVectorStoreIndex.default_vector_store(storage_context=storage_context, embed_model=self.embed_model)
return index
def purge(self):
self.vs_index = self.create_new_vs(self.checkpoint_dir, overwrite=True)
class MilvusRagWorker(MilvusSaveLoad, LlamaIndexRagWorker):
def __init__(self, user_name, llm_kwargs, auto_load_checkpoint=True, checkpoint_dir=None) -> None:
self.debug_mode = True
self.embed_model = OpenAiEmbeddingModel(llm_kwargs)
self.user_name = user_name
self.checkpoint_dir = checkpoint_dir
if auto_load_checkpoint:
self.vs_index = self.load_from_checkpoint(checkpoint_dir)
else:
self.vs_index = self.create_new_vs(checkpoint_dir)
atexit.register(lambda: self.save_to_checkpoint(checkpoint_dir))
def inspect_vector_store(self):
# This function is for debugging
try:
self.vs_index.storage_context.index_store.to_dict()
docstore = self.vs_index.storage_context.docstore.docs
if not docstore.items():
raise ValueError("cannot inspect")
vector_store_preview = "\n".join([ f"{_id} | {tn.text}" for _id, tn in docstore.items() ])
except:
dummy_retrieve_res: List["NodeWithScore"] = self.vs_index.as_retriever().retrieve(' ')
vector_store_preview = "\n".join(
[f"{node.id_} | {node.text}" for node in dummy_retrieve_res]
)
logger.info('\n++ --------inspect_vector_store begin--------')
logger.info(vector_store_preview)
logger.info('oo --------inspect_vector_store end--------')
return vector_store_preview

查看文件

@@ -1,45 +0,0 @@
import os
from llama_index.core import SimpleDirectoryReader
supports_format = ['.csv', '.docx','.doc', '.epub', '.ipynb', '.mbox', '.md', '.pdf', '.txt', '.ppt',
'.pptm', '.pptx','.py', '.xls', '.xlsx', '.html', '.json', '.xml', '.yaml', '.yml' ,'.m']
def read_docx_doc(file_path):
if file_path.split(".")[-1] == "docx":
from docx import Document
doc = Document(file_path)
file_content = "\n".join([para.text for para in doc.paragraphs])
else:
try:
import win32com.client
word = win32com.client.Dispatch("Word.Application")
word.visible = False
# 打开文件
doc = word.Documents.Open(os.getcwd() + '/' + file_path)
# file_content = doc.Content.Text
doc = word.ActiveDocument
file_content = doc.Range().Text
doc.Close()
word.Quit()
except:
raise RuntimeError('请先将.doc文档转换为.docx文档。')
return file_content
# 修改后的 extract_text 函数,结合 SimpleDirectoryReader 和自定义解析逻辑
import os
def extract_text(file_path):
_, ext = os.path.splitext(file_path.lower())
# 使用 SimpleDirectoryReader 处理它支持的文件格式
if ext in ['.docx', '.doc']:
return read_docx_doc(file_path)
try:
reader = SimpleDirectoryReader(input_files=[file_path])
documents = reader.load_data()
if len(documents) > 0:
return documents[0].text
except Exception as e:
pass
return None

查看文件

@@ -1,58 +0,0 @@
from llama_index.core import VectorStoreIndex
from typing import Any, List, Optional
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.schema import TransformComponent
from llama_index.core.service_context import ServiceContext
from llama_index.core.settings import (
Settings,
callback_manager_from_settings_or_context,
transformations_from_settings_or_context,
)
from llama_index.core.storage.storage_context import StorageContext
class GptacVectorStoreIndex(VectorStoreIndex):
@classmethod
def default_vector_store(
cls,
storage_context: Optional[StorageContext] = None,
show_progress: bool = False,
callback_manager: Optional[CallbackManager] = None,
transformations: Optional[List[TransformComponent]] = None,
# deprecated
service_context: Optional[ServiceContext] = None,
embed_model = None,
**kwargs: Any,
):
"""Create index from documents.
Args:
documents (Optional[Sequence[BaseDocument]]): List of documents to
build the index from.
"""
storage_context = storage_context or StorageContext.from_defaults()
docstore = storage_context.docstore
callback_manager = (
callback_manager
or callback_manager_from_settings_or_context(Settings, service_context)
)
transformations = transformations or transformations_from_settings_or_context(
Settings, service_context
)
with callback_manager.as_trace("index_construction"):
return cls(
nodes=[],
storage_context=storage_context,
callback_manager=callback_manager,
show_progress=show_progress,
transformations=transformations,
service_context=service_context,
embed_model=embed_model,
**kwargs,
)

查看文件

@@ -0,0 +1,87 @@
#include "libipc/buffer.h"
#include "libipc/utility/pimpl.h"
#include <cstring>
namespace ipc {
bool operator==(buffer const & b1, buffer const & b2) {
return (b1.size() == b2.size()) && (std::memcmp(b1.data(), b2.data(), b1.size()) == 0);
}
bool operator!=(buffer const & b1, buffer const & b2) {
return !(b1 == b2);
}
class buffer::buffer_ : public pimpl<buffer_> {
public:
void* p_;
std::size_t s_;
void* a_;
buffer::destructor_t d_;
buffer_(void* p, std::size_t s, buffer::destructor_t d, void* a)
: p_(p), s_(s), a_(a), d_(d) {
}
~buffer_() {
if (d_ == nullptr) return;
d_((a_ == nullptr) ? p_ : a_, s_);
}
};
buffer::buffer()
: buffer(nullptr, 0, nullptr, nullptr) {
}
buffer::buffer(void* p, std::size_t s, destructor_t d)
: p_(p_->make(p, s, d, nullptr)) {
}
buffer::buffer(void* p, std::size_t s, destructor_t d, void* additional)
: p_(p_->make(p, s, d, additional)) {
}
buffer::buffer(void* p, std::size_t s)
: buffer(p, s, nullptr) {
}
buffer::buffer(char const & c)
: buffer(const_cast<char*>(&c), 1) {
}
buffer::buffer(buffer&& rhs)
: buffer() {
swap(rhs);
}
buffer::~buffer() {
p_->clear();
}
void buffer::swap(buffer& rhs) {
std::swap(p_, rhs.p_);
}
buffer& buffer::operator=(buffer rhs) {
swap(rhs);
return *this;
}
bool buffer::empty() const noexcept {
return (impl(p_)->p_ == nullptr) || (impl(p_)->s_ == 0);
}
void* buffer::data() noexcept {
return impl(p_)->p_;
}
void const * buffer::data() const noexcept {
return impl(p_)->p_;
}
std::size_t buffer::size() const noexcept {
return impl(p_)->s_;
}
} // namespace ipc

查看文件

@@ -0,0 +1,701 @@
#include <type_traits>
#include <cstring>
#include <algorithm>
#include <utility> // std::pair, std::move, std::forward
#include <atomic>
#include <type_traits> // aligned_storage_t
#include <string>
#include <vector>
#include <array>
#include <cassert>
#include "libipc/ipc.h"
#include "libipc/def.h"
#include "libipc/shm.h"
#include "libipc/pool_alloc.h"
#include "libipc/queue.h"
#include "libipc/policy.h"
#include "libipc/rw_lock.h"
#include "libipc/waiter.h"
#include "libipc/utility/log.h"
#include "libipc/utility/id_pool.h"
#include "libipc/utility/scope_guard.h"
#include "libipc/utility/utility.h"
#include "libipc/memory/resource.h"
#include "libipc/platform/detail.h"
#include "libipc/circ/elem_array.h"
namespace {
using msg_id_t = std::uint32_t;
using acc_t = std::atomic<msg_id_t>;
template <std::size_t DataSize, std::size_t AlignSize>
struct msg_t;
template <std::size_t AlignSize>
struct msg_t<0, AlignSize> {
msg_id_t cc_id_;
msg_id_t id_;
std::int32_t remain_;
bool storage_;
};
template <std::size_t DataSize, std::size_t AlignSize>
struct msg_t : msg_t<0, AlignSize> {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
msg_t() = default;
msg_t(msg_id_t cc_id, msg_id_t id, std::int32_t remain, void const * data, std::size_t size)
: msg_t<0, AlignSize> {cc_id, id, remain, (data == nullptr) || (size == 0)} {
if (this->storage_) {
if (data != nullptr) {
// copy storage-id
*reinterpret_cast<ipc::storage_id_t*>(&data_) =
*static_cast<ipc::storage_id_t const *>(data);
}
}
else std::memcpy(&data_, data, size);
}
};
template <typename T>
ipc::buff_t make_cache(T& data, std::size_t size) {
auto ptr = ipc::mem::alloc(size);
std::memcpy(ptr, &data, (ipc::detail::min)(sizeof(data), size));
return { ptr, size, ipc::mem::free };
}
struct cache_t {
std::size_t fill_;
ipc::buff_t buff_;
cache_t(std::size_t f, ipc::buff_t && b)
: fill_(f), buff_(std::move(b))
{}
void append(void const * data, std::size_t size) {
if (fill_ >= buff_.size() || data == nullptr || size == 0) return;
auto new_fill = (ipc::detail::min)(fill_ + size, buff_.size());
std::memcpy(static_cast<ipc::byte_t*>(buff_.data()) + fill_, data, new_fill - fill_);
fill_ = new_fill;
}
};
auto cc_acc() {
static ipc::shm::handle acc_h("__CA_CONN__", sizeof(acc_t));
return static_cast<acc_t*>(acc_h.get());
}
IPC_CONSTEXPR_ std::size_t align_chunk_size(std::size_t size) noexcept {
return (((size - 1) / ipc::large_msg_align) + 1) * ipc::large_msg_align;
}
IPC_CONSTEXPR_ std::size_t calc_chunk_size(std::size_t size) noexcept {
return ipc::make_align(alignof(std::max_align_t), align_chunk_size(
ipc::make_align(alignof(std::max_align_t), sizeof(std::atomic<ipc::circ::cc_t>)) + size));
}
struct chunk_t {
std::atomic<ipc::circ::cc_t> &conns() noexcept {
return *reinterpret_cast<std::atomic<ipc::circ::cc_t> *>(this);
}
void *data() noexcept {
return reinterpret_cast<ipc::byte_t *>(this)
+ ipc::make_align(alignof(std::max_align_t), sizeof(std::atomic<ipc::circ::cc_t>));
}
};
struct chunk_info_t {
ipc::id_pool<> pool_;
ipc::spin_lock lock_;
IPC_CONSTEXPR_ static std::size_t chunks_mem_size(std::size_t chunk_size) noexcept {
return ipc::id_pool<>::max_count * chunk_size;
}
ipc::byte_t *chunks_mem() noexcept {
return reinterpret_cast<ipc::byte_t *>(this + 1);
}
chunk_t *at(std::size_t chunk_size, ipc::storage_id_t id) noexcept {
if (id < 0) return nullptr;
return reinterpret_cast<chunk_t *>(chunks_mem() + (chunk_size * id));
}
};
auto& chunk_storages() {
class chunk_handle_t {
ipc::shm::handle handle_;
public:
chunk_info_t *get_info(std::size_t chunk_size) {
if (!handle_.valid() &&
!handle_.acquire( ("__CHUNK_INFO__" + ipc::to_string(chunk_size)).c_str(),
sizeof(chunk_info_t) + chunk_info_t::chunks_mem_size(chunk_size) )) {
ipc::error("[chunk_storages] chunk_shm.id_info_.acquire failed: chunk_size = %zd\n", chunk_size);
return nullptr;
}
auto info = static_cast<chunk_info_t*>(handle_.get());
if (info == nullptr) {
ipc::error("[chunk_storages] chunk_shm.id_info_.get failed: chunk_size = %zd\n", chunk_size);
return nullptr;
}
return info;
}
};
static ipc::map<std::size_t, chunk_handle_t> chunk_hs;
return chunk_hs;
}
chunk_info_t *chunk_storage_info(std::size_t chunk_size) {
auto &storages = chunk_storages();
std::decay_t<decltype(storages)>::iterator it;
{
static ipc::rw_lock lock;
IPC_UNUSED_ std::shared_lock<ipc::rw_lock> guard {lock};
if ((it = storages.find(chunk_size)) == storages.end()) {
using chunk_handle_t = std::decay_t<decltype(storages)>::value_type::second_type;
guard.unlock();
IPC_UNUSED_ std::lock_guard<ipc::rw_lock> guard {lock};
it = storages.emplace(chunk_size, chunk_handle_t{}).first;
}
}
return it->second.get_info(chunk_size);
}
std::pair<ipc::storage_id_t, void*> acquire_storage(std::size_t size, ipc::circ::cc_t conns) {
std::size_t chunk_size = calc_chunk_size(size);
auto info = chunk_storage_info(chunk_size);
if (info == nullptr) return {};
info->lock_.lock();
info->pool_.prepare();
// got an unique id
auto id = info->pool_.acquire();
info->lock_.unlock();
auto chunk = info->at(chunk_size, id);
if (chunk == nullptr) return {};
chunk->conns().store(conns, std::memory_order_relaxed);
return { id, chunk->data() };
}
void *find_storage(ipc::storage_id_t id, std::size_t size) {
if (id < 0) {
ipc::error("[find_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
return nullptr;
}
std::size_t chunk_size = calc_chunk_size(size);
auto info = chunk_storage_info(chunk_size);
if (info == nullptr) return nullptr;
return info->at(chunk_size, id)->data();
}
void release_storage(ipc::storage_id_t id, std::size_t size) {
if (id < 0) {
ipc::error("[release_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
return;
}
std::size_t chunk_size = calc_chunk_size(size);
auto info = chunk_storage_info(chunk_size);
if (info == nullptr) return;
info->lock_.lock();
info->pool_.release(id);
info->lock_.unlock();
}
template <ipc::relat Rp, ipc::relat Rc>
bool sub_rc(ipc::wr<Rp, Rc, ipc::trans::unicast>,
std::atomic<ipc::circ::cc_t> &/*conns*/, ipc::circ::cc_t /*curr_conns*/, ipc::circ::cc_t /*conn_id*/) noexcept {
return true;
}
template <ipc::relat Rp, ipc::relat Rc>
bool sub_rc(ipc::wr<Rp, Rc, ipc::trans::broadcast>,
std::atomic<ipc::circ::cc_t> &conns, ipc::circ::cc_t curr_conns, ipc::circ::cc_t conn_id) noexcept {
auto last_conns = curr_conns & ~conn_id;
for (unsigned k = 0;;) {
auto chunk_conns = conns.load(std::memory_order_acquire);
if (conns.compare_exchange_weak(chunk_conns, chunk_conns & last_conns, std::memory_order_release)) {
return (chunk_conns & last_conns) == 0;
}
ipc::yield(k);
}
}
template <typename Flag>
void recycle_storage(ipc::storage_id_t id, std::size_t size, ipc::circ::cc_t curr_conns, ipc::circ::cc_t conn_id) {
if (id < 0) {
ipc::error("[recycle_storage] id is invalid: id = %ld, size = %zd\n", (long)id, size);
return;
}
std::size_t chunk_size = calc_chunk_size(size);
auto info = chunk_storage_info(chunk_size);
if (info == nullptr) return;
auto chunk = info->at(chunk_size, id);
if (chunk == nullptr) return;
if (!sub_rc(Flag{}, chunk->conns(), curr_conns, conn_id)) {
return;
}
info->lock_.lock();
info->pool_.release(id);
info->lock_.unlock();
}
template <typename MsgT>
bool clear_message(void* p) {
auto msg = static_cast<MsgT*>(p);
if (msg->storage_) {
std::int32_t r_size = static_cast<std::int32_t>(ipc::data_length) + msg->remain_;
if (r_size <= 0) {
ipc::error("[clear_message] invalid msg size: %d\n", (int)r_size);
return true;
}
release_storage(
*reinterpret_cast<ipc::storage_id_t*>(&msg->data_),
static_cast<std::size_t>(r_size));
}
return true;
}
struct conn_info_head {
ipc::string name_;
msg_id_t cc_id_; // connection-info id
ipc::detail::waiter cc_waiter_, wt_waiter_, rd_waiter_;
ipc::shm::handle acc_h_;
conn_info_head(char const * name)
: name_ {name}
, cc_id_ {(cc_acc() == nullptr) ? 0 : cc_acc()->fetch_add(1, std::memory_order_relaxed)}
, cc_waiter_{("__CC_CONN__" + name_).c_str()}
, wt_waiter_{("__WT_CONN__" + name_).c_str()}
, rd_waiter_{("__RD_CONN__" + name_).c_str()}
, acc_h_ {("__AC_CONN__" + name_).c_str(), sizeof(acc_t)} {
}
void quit_waiting() {
cc_waiter_.quit_waiting();
wt_waiter_.quit_waiting();
rd_waiter_.quit_waiting();
}
auto acc() {
return static_cast<acc_t*>(acc_h_.get());
}
auto& recv_cache() {
thread_local ipc::unordered_map<msg_id_t, cache_t> tls;
return tls;
}
};
template <typename W, typename F>
bool wait_for(W& waiter, F&& pred, std::uint64_t tm) {
if (tm == 0) return !pred();
for (unsigned k = 0; pred();) {
bool ret = true;
ipc::sleep(k, [&k, &ret, &waiter, &pred, tm] {
ret = waiter.wait_if(std::forward<F>(pred), tm);
k = 0;
});
if (!ret) return false; // timeout or fail
if (k == 0) break; // k has been reset
}
return true;
}
template <typename Policy,
std::size_t DataSize = ipc::data_length,
std::size_t AlignSize = (ipc::detail::min)(DataSize, alignof(std::max_align_t))>
struct queue_generator {
using queue_t = ipc::queue<msg_t<DataSize, AlignSize>, Policy>;
struct conn_info_t : conn_info_head {
queue_t que_;
conn_info_t(char const * name)
: conn_info_head{name}
, que_{("__QU_CONN__" +
ipc::to_string(DataSize) + "__" +
ipc::to_string(AlignSize) + "__" + name).c_str()} {
}
void disconnect_receiver() {
bool dis = que_.disconnect();
this->quit_waiting();
if (dis) {
this->recv_cache().clear();
}
}
};
};
template <typename Policy>
struct detail_impl {
using policy_t = Policy;
using flag_t = typename policy_t::flag_t;
using queue_t = typename queue_generator<policy_t>::queue_t;
using conn_info_t = typename queue_generator<policy_t>::conn_info_t;
constexpr static conn_info_t* info_of(ipc::handle_t h) noexcept {
return static_cast<conn_info_t*>(h);
}
constexpr static queue_t* queue_of(ipc::handle_t h) noexcept {
return (info_of(h) == nullptr) ? nullptr : &(info_of(h)->que_);
}
/* API implementations */
static void disconnect(ipc::handle_t h) {
auto que = queue_of(h);
if (que == nullptr) {
return;
}
que->shut_sending();
assert(info_of(h) != nullptr);
info_of(h)->disconnect_receiver();
}
static bool reconnect(ipc::handle_t * ph, bool start_to_recv) {
assert(ph != nullptr);
assert(*ph != nullptr);
auto que = queue_of(*ph);
if (que == nullptr) {
return false;
}
if (start_to_recv) {
que->shut_sending();
if (que->connect()) { // wouldn't connect twice
info_of(*ph)->cc_waiter_.broadcast();
return true;
}
return false;
}
// start_to_recv == false
if (que->connected()) {
info_of(*ph)->disconnect_receiver();
}
return que->ready_sending();
}
static bool connect(ipc::handle_t * ph, char const * name, bool start_to_recv) {
assert(ph != nullptr);
if (*ph == nullptr) {
*ph = ipc::mem::alloc<conn_info_t>(name);
}
return reconnect(ph, start_to_recv);
}
static void destroy(ipc::handle_t h) {
disconnect(h);
ipc::mem::free(info_of(h));
}
static std::size_t recv_count(ipc::handle_t h) noexcept {
auto que = queue_of(h);
if (que == nullptr) {
return ipc::invalid_value;
}
return que->conn_count();
}
static bool wait_for_recv(ipc::handle_t h, std::size_t r_count, std::uint64_t tm) {
auto que = queue_of(h);
if (que == nullptr) {
return false;
}
return wait_for(info_of(h)->cc_waiter_, [que, r_count] {
return que->conn_count() < r_count;
}, tm);
}
template <typename F>
static bool send(F&& gen_push, ipc::handle_t h, void const * data, std::size_t size) {
if (data == nullptr || size == 0) {
ipc::error("fail: send(%p, %zd)\n", data, size);
return false;
}
auto que = queue_of(h);
if (que == nullptr) {
ipc::error("fail: send, queue_of(h) == nullptr\n");
return false;
}
if (que->elems() == nullptr) {
ipc::error("fail: send, queue_of(h)->elems() == nullptr\n");
return false;
}
if (!que->ready_sending()) {
ipc::error("fail: send, que->ready_sending() == false\n");
return false;
}
ipc::circ::cc_t conns = que->elems()->connections(std::memory_order_relaxed);
if (conns == 0) {
ipc::error("fail: send, there is no receiver on this connection.\n");
return false;
}
// calc a new message id
auto acc = info_of(h)->acc();
if (acc == nullptr) {
ipc::error("fail: send, info_of(h)->acc() == nullptr\n");
return false;
}
auto msg_id = acc->fetch_add(1, std::memory_order_relaxed);
auto try_push = std::forward<F>(gen_push)(info_of(h), que, msg_id);
if (size > ipc::large_msg_limit) {
auto dat = acquire_storage(size, conns);
void * buf = dat.second;
if (buf != nullptr) {
std::memcpy(buf, data, size);
return try_push(static_cast<std::int32_t>(size) -
static_cast<std::int32_t>(ipc::data_length), &(dat.first), 0);
}
// try using message fragment
//ipc::log("fail: shm::handle for big message. msg_id: %zd, size: %zd\n", msg_id, size);
}
// push message fragment
std::int32_t offset = 0;
for (std::int32_t i = 0; i < static_cast<std::int32_t>(size / ipc::data_length); ++i, offset += ipc::data_length) {
if (!try_push(static_cast<std::int32_t>(size) - offset - static_cast<std::int32_t>(ipc::data_length),
static_cast<ipc::byte_t const *>(data) + offset, ipc::data_length)) {
return false;
}
}
// if remain > 0, this is the last message fragment
std::int32_t remain = static_cast<std::int32_t>(size) - offset;
if (remain > 0) {
if (!try_push(remain - static_cast<std::int32_t>(ipc::data_length),
static_cast<ipc::byte_t const *>(data) + offset,
static_cast<std::size_t>(remain))) {
return false;
}
}
return true;
}
static bool send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
return send([tm](auto info, auto que, auto msg_id) {
return [tm, info, que, msg_id](std::int32_t remain, void const * data, std::size_t size) {
if (!wait_for(info->wt_waiter_, [&] {
return !que->push(
[](void*) { return true; },
info->cc_id_, msg_id, remain, data, size);
}, tm)) {
ipc::log("force_push: msg_id = %zd, remain = %d, size = %zd\n", msg_id, remain, size);
if (!que->force_push(
clear_message<typename queue_t::value_t>,
info->cc_id_, msg_id, remain, data, size)) {
return false;
}
}
info->rd_waiter_.broadcast();
return true;
};
}, h, data, size);
}
static bool try_send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
return send([tm](auto info, auto que, auto msg_id) {
return [tm, info, que, msg_id](std::int32_t remain, void const * data, std::size_t size) {
if (!wait_for(info->wt_waiter_, [&] {
return !que->push(
[](void*) { return true; },
info->cc_id_, msg_id, remain, data, size);
}, tm)) {
return false;
}
info->rd_waiter_.broadcast();
return true;
};
}, h, data, size);
}
static ipc::buff_t recv(ipc::handle_t h, std::uint64_t tm) {
auto que = queue_of(h);
if (que == nullptr) {
ipc::error("fail: recv, queue_of(h) == nullptr\n");
return {};
}
if (!que->connected()) {
// hasn't connected yet, just return.
return {};
}
auto& rc = info_of(h)->recv_cache();
for (;;) {
// pop a new message
typename queue_t::value_t msg;
if (!wait_for(info_of(h)->rd_waiter_, [que, &msg] {
return !que->pop(msg);
}, tm)) {
// pop failed, just return.
return {};
}
info_of(h)->wt_waiter_.broadcast();
if ((info_of(h)->acc() != nullptr) && (msg.cc_id_ == info_of(h)->cc_id_)) {
continue; // ignore message to self
}
// msg.remain_ may minus & abs(msg.remain_) < data_length
std::int32_t r_size = static_cast<std::int32_t>(ipc::data_length) + msg.remain_;
if (r_size <= 0) {
ipc::error("fail: recv, r_size = %d\n", (int)r_size);
return {};
}
std::size_t msg_size = static_cast<std::size_t>(r_size);
// large message
if (msg.storage_) {
ipc::storage_id_t buf_id = *reinterpret_cast<ipc::storage_id_t*>(&msg.data_);
void* buf = find_storage(buf_id, msg_size);
if (buf != nullptr) {
struct recycle_t {
ipc::storage_id_t storage_id;
ipc::circ::cc_t curr_conns;
ipc::circ::cc_t conn_id;
} *r_info = ipc::mem::alloc<recycle_t>(recycle_t{
buf_id, que->elems()->connections(std::memory_order_relaxed), que->connected_id()
});
if (r_info == nullptr) {
ipc::log("fail: ipc::mem::alloc<recycle_t>.\n");
return ipc::buff_t{buf, msg_size}; // no recycle
} else {
return ipc::buff_t{buf, msg_size, [](void* p_info, std::size_t size) {
auto r_info = static_cast<recycle_t *>(p_info);
IPC_UNUSED_ auto finally = ipc::guard([r_info] {
ipc::mem::free(r_info);
});
recycle_storage<flag_t>(r_info->storage_id, size, r_info->curr_conns, r_info->conn_id);
}, r_info};
}
} else {
ipc::log("fail: shm::handle for large message. msg_id: %zd, buf_id: %zd, size: %zd\n", msg.id_, buf_id, msg_size);
continue;
}
}
// find cache with msg.id_
auto cac_it = rc.find(msg.id_);
if (cac_it == rc.end()) {
if (msg_size <= ipc::data_length) {
return make_cache(msg.data_, msg_size);
}
// gc
if (rc.size() > 1024) {
std::vector<msg_id_t> need_del;
for (auto const & pair : rc) {
auto cmp = std::minmax(msg.id_, pair.first);
if (cmp.second - cmp.first > 8192) {
need_del.push_back(pair.first);
}
}
for (auto id : need_del) rc.erase(id);
}
// cache the first message fragment
rc.emplace(msg.id_, cache_t { ipc::data_length, make_cache(msg.data_, msg_size) });
}
// has cached before this message
else {
auto& cac = cac_it->second;
// this is the last message fragment
if (msg.remain_ <= 0) {
cac.append(&(msg.data_), msg_size);
// finish this message, erase it from cache
auto buff = std::move(cac.buff_);
rc.erase(cac_it);
return buff;
}
// there are remain datas after this message
cac.append(&(msg.data_), ipc::data_length);
}
}
}
static ipc::buff_t try_recv(ipc::handle_t h) {
return recv(h, 0);
}
}; // detail_impl<Policy>
template <typename Flag>
using policy_t = ipc::policy::choose<ipc::circ::elem_array, Flag>;
} // internal-linkage
namespace ipc {
template <typename Flag>
ipc::handle_t chan_impl<Flag>::inited() {
ipc::detail::waiter::init();
return nullptr;
}
template <typename Flag>
bool chan_impl<Flag>::connect(ipc::handle_t * ph, char const * name, unsigned mode) {
return detail_impl<policy_t<Flag>>::connect(ph, name, mode & receiver);
}
template <typename Flag>
bool chan_impl<Flag>::reconnect(ipc::handle_t * ph, unsigned mode) {
return detail_impl<policy_t<Flag>>::reconnect(ph, mode & receiver);
}
template <typename Flag>
void chan_impl<Flag>::disconnect(ipc::handle_t h) {
detail_impl<policy_t<Flag>>::disconnect(h);
}
template <typename Flag>
void chan_impl<Flag>::destroy(ipc::handle_t h) {
detail_impl<policy_t<Flag>>::destroy(h);
}
template <typename Flag>
char const * chan_impl<Flag>::name(ipc::handle_t h) {
auto info = detail_impl<policy_t<Flag>>::info_of(h);
return (info == nullptr) ? nullptr : info->name_.c_str();
}
template <typename Flag>
std::size_t chan_impl<Flag>::recv_count(ipc::handle_t h) {
return detail_impl<policy_t<Flag>>::recv_count(h);
}
template <typename Flag>
bool chan_impl<Flag>::wait_for_recv(ipc::handle_t h, std::size_t r_count, std::uint64_t tm) {
return detail_impl<policy_t<Flag>>::wait_for_recv(h, r_count, tm);
}
template <typename Flag>
bool chan_impl<Flag>::send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
return detail_impl<policy_t<Flag>>::send(h, data, size, tm);
}
template <typename Flag>
buff_t chan_impl<Flag>::recv(ipc::handle_t h, std::uint64_t tm) {
return detail_impl<policy_t<Flag>>::recv(h, tm);
}
template <typename Flag>
bool chan_impl<Flag>::try_send(ipc::handle_t h, void const * data, std::size_t size, std::uint64_t tm) {
return detail_impl<policy_t<Flag>>::try_send(h, data, size, tm);
}
template <typename Flag>
buff_t chan_impl<Flag>::try_recv(ipc::handle_t h) {
return detail_impl<policy_t<Flag>>::try_recv(h);
}
template struct chan_impl<ipc::wr<relat::single, relat::single, trans::unicast >>;
// template struct chan_impl<ipc::wr<relat::single, relat::multi , trans::unicast >>; // TBD
// template struct chan_impl<ipc::wr<relat::multi , relat::multi , trans::unicast >>; // TBD
template struct chan_impl<ipc::wr<relat::single, relat::multi , trans::broadcast>>;
template struct chan_impl<ipc::wr<relat::multi , relat::multi , trans::broadcast>>;
} // namespace ipc

查看文件

@@ -0,0 +1,25 @@
#pragma once
#include <type_traits>
#include "libipc/def.h"
#include "libipc/prod_cons.h"
#include "libipc/circ/elem_array.h"
namespace ipc {
namespace policy {
template <template <typename, std::size_t...> class Elems, typename Flag>
struct choose;
template <typename Flag>
struct choose<circ::elem_array, Flag> {
using flag_t = Flag;
template <std::size_t DataSize, std::size_t AlignSize>
using elems_t = circ::elem_array<ipc::prod_cons_impl<flag_t>, DataSize, AlignSize>;
};
} // namespace policy
} // namespace ipc

查看文件

@@ -0,0 +1,17 @@
#include "libipc/pool_alloc.h"
#include "libipc/memory/resource.h"
namespace ipc {
namespace mem {
void* pool_alloc::alloc(std::size_t size) {
return async_pool_alloc::alloc(size);
}
void pool_alloc::free(void* p, std::size_t size) {
async_pool_alloc::free(p, size);
}
} // namespace mem
} // namespace ipc

查看文件

@@ -0,0 +1,433 @@
#pragma once
#include <atomic>
#include <utility>
#include <cstring>
#include <type_traits>
#include <cstdint>
#include "libipc/def.h"
#include "libipc/platform/detail.h"
#include "libipc/circ/elem_def.h"
#include "libipc/utility/log.h"
#include "libipc/utility/utility.h"
namespace ipc {
////////////////////////////////////////////////////////////////
/// producer-consumer implementation
////////////////////////////////////////////////////////////////
template <typename Flag>
struct prod_cons_impl;
template <>
struct prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
};
alignas(cache_line_size) std::atomic<circ::u2_t> rd_; // read index
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
constexpr circ::u2_t cursor() const noexcept {
return 0;
}
template <typename W, typename F, typename E>
bool push(W* /*wrapper*/, F&& f, E* elems) {
auto cur_wt = circ::index_of(wt_.load(std::memory_order_relaxed));
if (cur_wt == circ::index_of(rd_.load(std::memory_order_acquire) - 1)) {
return false; // full
}
std::forward<F>(f)(&(elems[cur_wt].data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
/**
* In single-single-unicast, 'force_push' means 'no reader' or 'the only one reader is dead'.
* So we could just disconnect all connections of receiver, and return false.
*/
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(~static_cast<circ::cc_t>(0u));
return false;
}
template <typename W, typename F, typename R, typename E>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E* elems) {
auto cur_rd = circ::index_of(rd_.load(std::memory_order_relaxed));
if (cur_rd == circ::index_of(wt_.load(std::memory_order_acquire))) {
return false; // empty
}
std::forward<F>(f)(&(elems[cur_rd].data_));
std::forward<R>(out)(true);
rd_.fetch_add(1, std::memory_order_release);
return true;
}
};
template <>
struct prod_cons_impl<wr<relat::single, relat::multi , trans::unicast>>
: prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(1);
return false;
}
template <typename W, typename F, typename R,
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
byte_t buff[DS];
for (unsigned k = 0;;) {
auto cur_rd = rd_.load(std::memory_order_relaxed);
if (circ::index_of(cur_rd) ==
circ::index_of(wt_.load(std::memory_order_acquire))) {
return false; // empty
}
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
std::forward<F>(f)(buff);
std::forward<R>(out)(true);
return true;
}
ipc::yield(k);
}
}
};
template <>
struct prod_cons_impl<wr<relat::multi , relat::multi, trans::unicast>>
: prod_cons_impl<wr<relat::single, relat::multi, trans::unicast>> {
using flag_t = std::uint64_t;
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
};
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
template <typename W, typename F, typename E>
bool push(W* /*wrapper*/, F&& f, E* elems) {
circ::u2_t cur_ct, nxt_ct;
for (unsigned k = 0;;) {
cur_ct = ct_.load(std::memory_order_relaxed);
if (circ::index_of(nxt_ct = cur_ct + 1) ==
circ::index_of(rd_.load(std::memory_order_acquire))) {
return false; // full
}
if (ct_.compare_exchange_weak(cur_ct, nxt_ct, std::memory_order_acq_rel)) {
break;
}
ipc::yield(k);
}
auto* el = elems + circ::index_of(cur_ct);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
while (1) {
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
if (cur_ct != wt_.load(std::memory_order_relaxed)) {
return true;
}
if ((~cac_ct) != cur_ct) {
return true;
}
if (!el->f_ct_.compare_exchange_strong(cac_ct, 0, std::memory_order_relaxed)) {
return true;
}
wt_.store(nxt_ct, std::memory_order_release);
cur_ct = nxt_ct;
nxt_ct = cur_ct + 1;
el = elems + circ::index_of(cur_ct);
}
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(1);
return false;
}
template <typename W, typename F, typename R,
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
byte_t buff[DS];
for (unsigned k = 0;;) {
auto cur_rd = rd_.load(std::memory_order_relaxed);
auto cur_wt = wt_.load(std::memory_order_acquire);
auto id_rd = circ::index_of(cur_rd);
auto id_wt = circ::index_of(cur_wt);
if (id_rd == id_wt) {
auto* el = elems + id_wt;
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
if ((~cac_ct) != cur_wt) {
return false; // empty
}
if (el->f_ct_.compare_exchange_weak(cac_ct, 0, std::memory_order_relaxed)) {
wt_.store(cur_wt + 1, std::memory_order_release);
}
k = 0;
}
else {
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
std::forward<F>(f)(buff);
std::forward<R>(out)(true);
return true;
}
ipc::yield(k);
}
}
}
};
template <>
struct prod_cons_impl<wr<relat::single, relat::multi, trans::broadcast>> {
using rc_t = std::uint64_t;
enum : rc_t {
ep_mask = 0x00000000ffffffffull,
ep_incr = 0x0000000100000000ull
};
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<rc_t> rc_ { 0 }; // read-counter
};
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
alignas(cache_line_size) rc_t epoch_ { 0 }; // only one writer
circ::u2_t cursor() const noexcept {
return wt_.load(std::memory_order_acquire);
}
template <typename W, typename F, typename E>
bool push(W* wrapper, F&& f, E* elems) {
E* el;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & ep_mask;
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch_)) {
return false; // has not finished yet
}
// consider rem_cc to be 0 here
if (el->rc_.compare_exchange_weak(
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
break;
}
ipc::yield(k);
}
std::forward<F>(f)(&(el->data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&& f, E* elems) {
E* el;
epoch_ += ep_incr;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & ep_mask;
if (cc & rem_cc) {
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
if (cc == 0) return false; // no reader
}
// just compare & exchange
if (el->rc_.compare_exchange_weak(
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
break;
}
ipc::yield(k);
}
std::forward<F>(f)(&(el->data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
template <typename W, typename F, typename R, typename E>
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E* elems) {
if (cur == cursor()) return false; // acquire
auto* el = elems + circ::index_of(cur++);
std::forward<F>(f)(&(el->data_));
for (unsigned k = 0;;) {
auto cur_rc = el->rc_.load(std::memory_order_acquire);
if ((cur_rc & ep_mask) == 0) {
std::forward<R>(out)(true);
return true;
}
auto nxt_rc = cur_rc & ~static_cast<rc_t>(wrapper->connected_id());
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
std::forward<R>(out)((nxt_rc & ep_mask) == 0);
return true;
}
ipc::yield(k);
}
}
};
template <>
struct prod_cons_impl<wr<relat::multi, relat::multi, trans::broadcast>> {
using rc_t = std::uint64_t;
using flag_t = std::uint64_t;
enum : rc_t {
rc_mask = 0x00000000ffffffffull,
ep_mask = 0x00ffffffffffffffull,
ep_incr = 0x0100000000000000ull,
ic_mask = 0xff000000ffffffffull,
ic_incr = 0x0000000100000000ull
};
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<rc_t > rc_ { 0 }; // read-counter
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
};
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
alignas(cache_line_size) std::atomic<rc_t> epoch_ { 0 };
circ::u2_t cursor() const noexcept {
return ct_.load(std::memory_order_acquire);
}
constexpr static rc_t inc_rc(rc_t rc) noexcept {
return (rc & ic_mask) | ((rc + ic_incr) & ~ic_mask);
}
constexpr static rc_t inc_mask(rc_t rc) noexcept {
return inc_rc(rc) & ~rc_mask;
}
template <typename W, typename F, typename E>
bool push(W* wrapper, F&& f, E* elems) {
E* el;
circ::u2_t cur_ct;
rc_t epoch = epoch_.load(std::memory_order_acquire);
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_relaxed);
circ::cc_t rem_cc = cur_rc & rc_mask;
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch)) {
return false; // has not finished yet
}
else if (!rem_cc) {
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
if ((cur_fl != cur_ct) && cur_fl) {
return false; // full
}
}
// consider rem_cc to be 0 here
if (el->rc_.compare_exchange_weak(
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed) &&
epoch_.compare_exchange_weak(epoch, epoch, std::memory_order_acq_rel)) {
break;
}
ipc::yield(k);
}
// only one thread/process would touch here at one time
ct_.store(cur_ct + 1, std::memory_order_release);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&& f, E* elems) {
E* el;
circ::u2_t cur_ct;
rc_t epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & rc_mask;
if (cc & rem_cc) {
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
if (cc == 0) return false; // no reader
}
// just compare & exchange
if (el->rc_.compare_exchange_weak(
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed)) {
if (epoch == epoch_.load(std::memory_order_acquire)) {
break;
}
else if (push(wrapper, std::forward<F>(f), elems)) {
return true;
}
epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
}
ipc::yield(k);
}
// only one thread/process would touch here at one time
ct_.store(cur_ct + 1, std::memory_order_release);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
return true;
}
template <typename W, typename F, typename R, typename E, std::size_t N>
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E(& elems)[N]) {
auto* el = elems + circ::index_of(cur);
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
if (cur_fl != ~static_cast<flag_t>(cur)) {
return false; // empty
}
++cur;
std::forward<F>(f)(&(el->data_));
for (unsigned k = 0;;) {
auto cur_rc = el->rc_.load(std::memory_order_acquire);
if ((cur_rc & rc_mask) == 0) {
std::forward<R>(out)(true);
el->f_ct_.store(cur + N - 1, std::memory_order_release);
return true;
}
auto nxt_rc = inc_rc(cur_rc) & ~static_cast<rc_t>(wrapper->connected_id());
bool last_one = false;
if ((last_one = (nxt_rc & rc_mask) == 0)) {
el->f_ct_.store(cur + N - 1, std::memory_order_release);
}
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
std::forward<R>(out)(last_one);
return true;
}
ipc::yield(k);
}
}
};
} // namespace ipc

查看文件

@@ -0,0 +1,216 @@
#pragma once
#include <type_traits>
#include <new>
#include <utility> // [[since C++14]]: std::exchange
#include <algorithm>
#include <atomic>
#include <tuple>
#include <thread>
#include <chrono>
#include <string>
#include <cassert> // assert
#include "libipc/def.h"
#include "libipc/shm.h"
#include "libipc/rw_lock.h"
#include "libipc/utility/log.h"
#include "libipc/platform/detail.h"
#include "libipc/circ/elem_def.h"
namespace ipc {
namespace detail {
class queue_conn {
protected:
circ::cc_t connected_ = 0;
shm::handle elems_h_;
template <typename Elems>
Elems* open(char const * name) {
if (name == nullptr || name[0] == '\0') {
ipc::error("fail open waiter: name is empty!\n");
return nullptr;
}
if (!elems_h_.acquire(name, sizeof(Elems))) {
return nullptr;
}
auto elems = static_cast<Elems*>(elems_h_.get());
if (elems == nullptr) {
ipc::error("fail acquire elems: %s\n", name);
return nullptr;
}
elems->init();
return elems;
}
void close() {
elems_h_.release();
}
public:
queue_conn() = default;
queue_conn(const queue_conn&) = delete;
queue_conn& operator=(const queue_conn&) = delete;
bool connected() const noexcept {
return connected_ != 0;
}
circ::cc_t connected_id() const noexcept {
return connected_;
}
template <typename Elems>
auto connect(Elems* elems) noexcept
/*needs 'optional' here*/
-> std::tuple<bool, bool, decltype(std::declval<Elems>().cursor())> {
if (elems == nullptr) return {};
// if it's already connected, just return
if (connected()) return {connected(), false, 0};
connected_ = elems->connect_receiver();
return {connected(), true, elems->cursor()};
}
template <typename Elems>
bool disconnect(Elems* elems) noexcept {
if (elems == nullptr) return false;
// if it's already disconnected, just return false
if (!connected()) return false;
elems->disconnect_receiver(std::exchange(connected_, 0));
return true;
}
};
template <typename Elems>
class queue_base : public queue_conn {
using base_t = queue_conn;
public:
using elems_t = Elems;
using policy_t = typename elems_t::policy_t;
protected:
elems_t * elems_ = nullptr;
decltype(std::declval<elems_t>().cursor()) cursor_ = 0;
bool sender_flag_ = false;
public:
using base_t::base_t;
queue_base() = default;
explicit queue_base(char const * name)
: queue_base{} {
elems_ = open<elems_t>(name);
}
explicit queue_base(elems_t * elems) noexcept
: queue_base{} {
assert(elems != nullptr);
elems_ = elems;
}
/* not virtual */ ~queue_base() {
base_t::close();
}
elems_t * elems() noexcept { return elems_; }
elems_t const * elems() const noexcept { return elems_; }
bool ready_sending() noexcept {
if (elems_ == nullptr) return false;
return sender_flag_ || (sender_flag_ = elems_->connect_sender());
}
void shut_sending() noexcept {
if (elems_ == nullptr) return;
if (!sender_flag_) return;
elems_->disconnect_sender();
}
bool connect() noexcept {
auto tp = base_t::connect(elems_);
if (std::get<0>(tp) && std::get<1>(tp)) {
cursor_ = std::get<2>(tp);
return true;
}
return std::get<0>(tp);
}
bool disconnect() noexcept {
return base_t::disconnect(elems_);
}
std::size_t conn_count() const noexcept {
return (elems_ == nullptr) ? static_cast<std::size_t>(invalid_value) : elems_->conn_count();
}
bool valid() const noexcept {
return elems_ != nullptr;
}
bool empty() const noexcept {
return !valid() || (cursor_ == elems_->cursor());
}
template <typename T, typename F, typename... P>
bool push(F&& prep, P&&... params) {
if (elems_ == nullptr) return false;
return elems_->push(this, [&](void* p) {
if (prep(p)) ::new (p) T(std::forward<P>(params)...);
});
}
template <typename T, typename F, typename... P>
bool force_push(F&& prep, P&&... params) {
if (elems_ == nullptr) return false;
return elems_->force_push(this, [&](void* p) {
if (prep(p)) ::new (p) T(std::forward<P>(params)...);
});
}
template <typename T, typename F>
bool pop(T& item, F&& out) {
if (elems_ == nullptr) {
return false;
}
return elems_->pop(this, &(this->cursor_), [&item](void* p) {
::new (&item) T(std::move(*static_cast<T*>(p)));
}, std::forward<F>(out));
}
};
} // namespace detail
template <typename T, typename Policy>
class queue final : public detail::queue_base<typename Policy::template elems_t<sizeof(T), alignof(T)>> {
using base_t = detail::queue_base<typename Policy::template elems_t<sizeof(T), alignof(T)>>;
public:
using value_t = T;
using base_t::base_t;
template <typename... P>
bool push(P&&... params) {
return base_t::template push<T>(std::forward<P>(params)...);
}
template <typename... P>
bool force_push(P&&... params) {
return base_t::template force_push<T>(std::forward<P>(params)...);
}
bool pop(T& item) {
return base_t::pop(item, [](bool) {});
}
template <typename F>
bool pop(T& item, F&& out) {
return base_t::pop(item, std::forward<F>(out));
}
};
} // namespace ipc

查看文件

@@ -0,0 +1,103 @@
#include <string>
#include <utility>
#include "libipc/shm.h"
#include "libipc/utility/pimpl.h"
#include "libipc/memory/resource.h"
namespace ipc {
namespace shm {
class handle::handle_ : public pimpl<handle_> {
public:
shm::id_t id_ = nullptr;
void* m_ = nullptr;
ipc::string n_;
std::size_t s_ = 0;
};
handle::handle()
: p_(p_->make()) {
}
handle::handle(char const * name, std::size_t size, unsigned mode)
: handle() {
acquire(name, size, mode);
}
handle::handle(handle&& rhs)
: handle() {
swap(rhs);
}
handle::~handle() {
release();
p_->clear();
}
void handle::swap(handle& rhs) {
std::swap(p_, rhs.p_);
}
handle& handle::operator=(handle rhs) {
swap(rhs);
return *this;
}
bool handle::valid() const noexcept {
return impl(p_)->m_ != nullptr;
}
std::size_t handle::size() const noexcept {
return impl(p_)->s_;
}
char const * handle::name() const noexcept {
return impl(p_)->n_.c_str();
}
std::int32_t handle::ref() const noexcept {
return shm::get_ref(impl(p_)->id_);
}
void handle::sub_ref() noexcept {
shm::sub_ref(impl(p_)->id_);
}
bool handle::acquire(char const * name, std::size_t size, unsigned mode) {
release();
impl(p_)->id_ = shm::acquire((impl(p_)->n_ = name).c_str(), size, mode);
impl(p_)->m_ = shm::get_mem(impl(p_)->id_, &(impl(p_)->s_));
return valid();
}
std::int32_t handle::release() {
if (impl(p_)->id_ == nullptr) return -1;
return shm::release(detach());
}
void* handle::get() const {
return impl(p_)->m_;
}
void handle::attach(id_t id) {
if (id == nullptr) return;
release();
impl(p_)->id_ = id;
impl(p_)->m_ = shm::get_mem(impl(p_)->id_, &(impl(p_)->s_));
}
id_t handle::detach() {
auto old = impl(p_)->id_;
impl(p_)->id_ = nullptr;
impl(p_)->m_ = nullptr;
impl(p_)->s_ = 0;
impl(p_)->n_.clear();
return old;
}
} // namespace shm
} // namespace ipc

查看文件

@@ -0,0 +1,83 @@
#pragma once
#include <utility>
#include <string>
#include <mutex>
#include <atomic>
#include "libipc/def.h"
#include "libipc/mutex.h"
#include "libipc/condition.h"
#include "libipc/platform/detail.h"
namespace ipc {
namespace detail {
class waiter {
ipc::sync::condition cond_;
ipc::sync::mutex lock_;
std::atomic<bool> quit_ {false};
public:
static void init();
waiter() = default;
waiter(char const *name) {
open(name);
}
~waiter() {
close();
}
bool valid() const noexcept {
return cond_.valid() && lock_.valid();
}
bool open(char const *name) noexcept {
quit_.store(false, std::memory_order_relaxed);
if (!cond_.open((std::string{"_waiter_cond_"} + name).c_str())) {
return false;
}
if (!lock_.open((std::string{"_waiter_lock_"} + name).c_str())) {
cond_.close();
return false;
}
return valid();
}
void close() noexcept {
cond_.close();
lock_.close();
}
template <typename F>
bool wait_if(F &&pred, std::uint64_t tm = ipc::invalid_value) noexcept {
IPC_UNUSED_ std::lock_guard<ipc::sync::mutex> guard {lock_};
while ([this, &pred] {
return !quit_.load(std::memory_order_relaxed)
&& std::forward<F>(pred)();
}()) {
if (!cond_.wait(lock_, tm)) return false;
}
return true;
}
bool notify() noexcept {
std::lock_guard<ipc::sync::mutex>{lock_}; // barrier
return cond_.notify(lock_);
}
bool broadcast() noexcept {
std::lock_guard<ipc::sync::mutex>{lock_}; // barrier
return cond_.broadcast(lock_);
}
bool quit_waiting() {
quit_.store(true, std::memory_order_release);
return broadcast();
}
};
} // namespace detail
} // namespace ipc

查看文件

@@ -0,0 +1,3 @@
https://github.com/mutouyun/cpp-ipc
A high-performance inter-process communication library using shared memory on Linux/Windows.

文件差异内容过多而无法显示 加载差异

查看文件

@@ -0,0 +1,316 @@
// jpgd.h - C++ class for JPEG decompression.
// Public domain, Rich Geldreich <richgel99@gmail.com>
#ifndef JPEG_DECODER_H
#define JPEG_DECODER_H
#include <stdlib.h>
#include <stdio.h>
#include <setjmp.h>
namespace jpgd
{
typedef unsigned char uint8;
typedef signed short int16;
typedef unsigned short uint16;
typedef unsigned int uint;
typedef signed int int32;
// Loads a JPEG image from a memory buffer or a file.
// req_comps can be 1 (grayscale), 3 (RGB), or 4 (RGBA).
// On return, width/height will be set to the image's dimensions, and actual_comps will be set to the either 1 (grayscale) or 3 (RGB).
// Notes: For more control over where and how the source data is read, see the decompress_jpeg_image_from_stream() function below, or call the jpeg_decoder class directly.
// Requesting a 8 or 32bpp image is currently a little faster than 24bpp because the jpeg_decoder class itself currently always unpacks to either 8 or 32bpp.
// BEGIN EPIC MOD
//unsigned char *decompress_jpeg_image_from_memory(const unsigned char *pSrc_data, int src_data_size, int *width, int *height, int *actual_comps, int req_comps);
unsigned char *decompress_jpeg_image_from_memory(const unsigned char *pSrc_data, int src_data_size, int *width, int *height, int *actual_comps, int req_comps, int format);
// END EPIC MOD
unsigned char *decompress_jpeg_image_from_file(const char *pSrc_filename, int *width, int *height, int *actual_comps, int req_comps);
// Success/failure error codes.
enum jpgd_status
{
JPGD_SUCCESS = 0, JPGD_FAILED = -1, JPGD_DONE = 1,
JPGD_BAD_DHT_COUNTS = -256, JPGD_BAD_DHT_INDEX, JPGD_BAD_DHT_MARKER, JPGD_BAD_DQT_MARKER, JPGD_BAD_DQT_TABLE,
JPGD_BAD_PRECISION, JPGD_BAD_HEIGHT, JPGD_BAD_WIDTH, JPGD_TOO_MANY_COMPONENTS,
JPGD_BAD_SOF_LENGTH, JPGD_BAD_VARIABLE_MARKER, JPGD_BAD_DRI_LENGTH, JPGD_BAD_SOS_LENGTH,
JPGD_BAD_SOS_COMP_ID, JPGD_W_EXTRA_BYTES_BEFORE_MARKER, JPGD_NO_ARITHMITIC_SUPPORT, JPGD_UNEXPECTED_MARKER,
JPGD_NOT_JPEG, JPGD_UNSUPPORTED_MARKER, JPGD_BAD_DQT_LENGTH, JPGD_TOO_MANY_BLOCKS,
JPGD_UNDEFINED_QUANT_TABLE, JPGD_UNDEFINED_HUFF_TABLE, JPGD_NOT_SINGLE_SCAN, JPGD_UNSUPPORTED_COLORSPACE,
JPGD_UNSUPPORTED_SAMP_FACTORS, JPGD_DECODE_ERROR, JPGD_BAD_RESTART_MARKER, JPGD_ASSERTION_ERROR,
JPGD_BAD_SOS_SPECTRAL, JPGD_BAD_SOS_SUCCESSIVE, JPGD_STREAM_READ, JPGD_NOTENOUGHMEM
};
// Input stream interface.
// Derive from this class to read input data from sources other than files or memory. Set m_eof_flag to true when no more data is available.
// The decoder is rather greedy: it will keep on calling this method until its internal input buffer is full, or until the EOF flag is set.
// It the input stream contains data after the JPEG stream's EOI (end of image) marker it will probably be pulled into the internal buffer.
// Call the get_total_bytes_read() method to determine the actual size of the JPEG stream after successful decoding.
class jpeg_decoder_stream
{
public:
jpeg_decoder_stream() { }
virtual ~jpeg_decoder_stream() { }
// The read() method is called when the internal input buffer is empty.
// Parameters:
// pBuf - input buffer
// max_bytes_to_read - maximum bytes that can be written to pBuf
// pEOF_flag - set this to true if at end of stream (no more bytes remaining)
// Returns -1 on error, otherwise return the number of bytes actually written to the buffer (which may be 0).
// Notes: This method will be called in a loop until you set *pEOF_flag to true or the internal buffer is full.
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag) = 0;
};
// stdio FILE stream class.
class jpeg_decoder_file_stream : public jpeg_decoder_stream
{
jpeg_decoder_file_stream(const jpeg_decoder_file_stream &);
jpeg_decoder_file_stream &operator =(const jpeg_decoder_file_stream &);
FILE *m_pFile;
bool m_eof_flag, m_error_flag;
public:
jpeg_decoder_file_stream();
virtual ~jpeg_decoder_file_stream();
bool open(const char *Pfilename);
void close();
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag);
};
// Memory stream class.
class jpeg_decoder_mem_stream : public jpeg_decoder_stream
{
const uint8 *m_pSrc_data;
uint m_ofs, m_size;
public:
jpeg_decoder_mem_stream() : m_pSrc_data(NULL), m_ofs(0), m_size(0) { }
jpeg_decoder_mem_stream(const uint8 *pSrc_data, uint size) : m_pSrc_data(pSrc_data), m_ofs(0), m_size(size) { }
virtual ~jpeg_decoder_mem_stream() { }
bool open(const uint8 *pSrc_data, uint size);
void close() { m_pSrc_data = NULL; m_ofs = 0; m_size = 0; }
virtual int read(uint8 *pBuf, int max_bytes_to_read, bool *pEOF_flag);
};
// Loads JPEG file from a jpeg_decoder_stream.
unsigned char *decompress_jpeg_image_from_stream(jpeg_decoder_stream *pStream, int *width, int *height, int *actual_comps, int req_comps);
enum
{
JPGD_IN_BUF_SIZE = 8192, JPGD_MAX_BLOCKS_PER_MCU = 10, JPGD_MAX_HUFF_TABLES = 8, JPGD_MAX_QUANT_TABLES = 4,
JPGD_MAX_COMPONENTS = 4, JPGD_MAX_COMPS_IN_SCAN = 4, JPGD_MAX_BLOCKS_PER_ROW = 8192, JPGD_MAX_HEIGHT = 16384, JPGD_MAX_WIDTH = 16384
};
typedef int16 jpgd_quant_t;
typedef int16 jpgd_block_t;
class jpeg_decoder
{
public:
// Call get_error_code() after constructing to determine if the stream is valid or not. You may call the get_width(), get_height(), etc.
// methods after the constructor is called. You may then either destruct the object, or begin decoding the image by calling begin_decoding(), then decode() on each scanline.
jpeg_decoder(jpeg_decoder_stream *pStream);
~jpeg_decoder();
// Call this method after constructing the object to begin decompression.
// If JPGD_SUCCESS is returned you may then call decode() on each scanline.
int begin_decoding();
// Returns the next scan line.
// For grayscale images, pScan_line will point to a buffer containing 8-bit pixels (get_bytes_per_pixel() will return 1).
// Otherwise, it will always point to a buffer containing 32-bit RGBA pixels (A will always be 255, and get_bytes_per_pixel() will return 4).
// Returns JPGD_SUCCESS if a scan line has been returned.
// Returns JPGD_DONE if all scan lines have been returned.
// Returns JPGD_FAILED if an error occurred. Call get_error_code() for a more info.
int decode(const void** pScan_line, uint* pScan_line_len);
inline jpgd_status get_error_code() const { return m_error_code; }
inline int get_width() const { return m_image_x_size; }
inline int get_height() const { return m_image_y_size; }
inline int get_num_components() const { return m_comps_in_frame; }
inline int get_bytes_per_pixel() const { return m_dest_bytes_per_pixel; }
inline int get_bytes_per_scan_line() const { return m_image_x_size * get_bytes_per_pixel(); }
// Returns the total number of bytes actually consumed by the decoder (which should equal the actual size of the JPEG file).
inline int get_total_bytes_read() const { return m_total_bytes_read; }
private:
jpeg_decoder(const jpeg_decoder &);
jpeg_decoder &operator =(const jpeg_decoder &);
typedef void (*pDecode_block_func)(jpeg_decoder *, int, int, int);
struct huff_tables
{
bool ac_table;
uint look_up[256];
uint look_up2[256];
uint8 code_size[256];
uint tree[512];
};
struct coeff_buf
{
uint8 *pData;
int block_num_x, block_num_y;
int block_len_x, block_len_y;
int block_size;
};
struct mem_block
{
mem_block *m_pNext;
size_t m_used_count;
size_t m_size;
char m_data[1];
};
jmp_buf m_jmp_state;
mem_block *m_pMem_blocks;
int m_image_x_size;
int m_image_y_size;
jpeg_decoder_stream *m_pStream;
int m_progressive_flag;
uint8 m_huff_ac[JPGD_MAX_HUFF_TABLES];
uint8* m_huff_num[JPGD_MAX_HUFF_TABLES]; // pointer to number of Huffman codes per bit size
uint8* m_huff_val[JPGD_MAX_HUFF_TABLES]; // pointer to Huffman codes per bit size
jpgd_quant_t* m_quant[JPGD_MAX_QUANT_TABLES]; // pointer to quantization tables
int m_scan_type; // Gray, Yh1v1, Yh1v2, Yh2v1, Yh2v2 (CMYK111, CMYK4114 no longer supported)
int m_comps_in_frame; // # of components in frame
int m_comp_h_samp[JPGD_MAX_COMPONENTS]; // component's horizontal sampling factor
int m_comp_v_samp[JPGD_MAX_COMPONENTS]; // component's vertical sampling factor
int m_comp_quant[JPGD_MAX_COMPONENTS]; // component's quantization table selector
int m_comp_ident[JPGD_MAX_COMPONENTS]; // component's ID
int m_comp_h_blocks[JPGD_MAX_COMPONENTS];
int m_comp_v_blocks[JPGD_MAX_COMPONENTS];
int m_comps_in_scan; // # of components in scan
int m_comp_list[JPGD_MAX_COMPS_IN_SCAN]; // components in this scan
int m_comp_dc_tab[JPGD_MAX_COMPONENTS]; // component's DC Huffman coding table selector
int m_comp_ac_tab[JPGD_MAX_COMPONENTS]; // component's AC Huffman coding table selector
int m_spectral_start; // spectral selection start
int m_spectral_end; // spectral selection end
int m_successive_low; // successive approximation low
int m_successive_high; // successive approximation high
int m_max_mcu_x_size; // MCU's max. X size in pixels
int m_max_mcu_y_size; // MCU's max. Y size in pixels
int m_blocks_per_mcu;
int m_max_blocks_per_row;
int m_mcus_per_row, m_mcus_per_col;
int m_mcu_org[JPGD_MAX_BLOCKS_PER_MCU];
int m_total_lines_left; // total # lines left in image
int m_mcu_lines_left; // total # lines left in this MCU
int m_real_dest_bytes_per_scan_line;
int m_dest_bytes_per_scan_line; // rounded up
int m_dest_bytes_per_pixel; // 4 (RGB) or 1 (Y)
huff_tables* m_pHuff_tabs[JPGD_MAX_HUFF_TABLES];
coeff_buf* m_dc_coeffs[JPGD_MAX_COMPONENTS];
coeff_buf* m_ac_coeffs[JPGD_MAX_COMPONENTS];
int m_eob_run;
int m_block_y_mcu[JPGD_MAX_COMPONENTS];
uint8* m_pIn_buf_ofs;
int m_in_buf_left;
int m_tem_flag;
bool m_eof_flag;
uint8 m_in_buf_pad_start[128];
uint8 m_in_buf[JPGD_IN_BUF_SIZE + 128];
uint8 m_in_buf_pad_end[128];
int m_bits_left;
uint m_bit_buf;
int m_restart_interval;
int m_restarts_left;
int m_next_restart_num;
int m_max_mcus_per_row;
int m_max_blocks_per_mcu;
int m_expanded_blocks_per_mcu;
int m_expanded_blocks_per_row;
int m_expanded_blocks_per_component;
bool m_freq_domain_chroma_upsample;
int m_max_mcus_per_col;
uint m_last_dc_val[JPGD_MAX_COMPONENTS];
jpgd_block_t* m_pMCU_coefficients;
int m_mcu_block_max_zag[JPGD_MAX_BLOCKS_PER_MCU];
uint8* m_pSample_buf;
int m_crr[256];
int m_cbb[256];
int m_crg[256];
int m_cbg[256];
uint8* m_pScan_line_0;
uint8* m_pScan_line_1;
jpgd_status m_error_code;
bool m_ready_flag;
int m_total_bytes_read;
void free_all_blocks();
// BEGIN EPIC MOD
UE_NORETURN void stop_decoding(jpgd_status status);
// END EPIC MOD
void *alloc(size_t n, bool zero = false);
void word_clear(void *p, uint16 c, uint n);
void prep_in_buffer();
void read_dht_marker();
void read_dqt_marker();
void read_sof_marker();
void skip_variable_marker();
void read_dri_marker();
void read_sos_marker();
int next_marker();
int process_markers();
void locate_soi_marker();
void locate_sof_marker();
int locate_sos_marker();
void init(jpeg_decoder_stream * pStream);
void create_look_ups();
void fix_in_buffer();
void transform_mcu(int mcu_row);
void transform_mcu_expand(int mcu_row);
coeff_buf* coeff_buf_open(int block_num_x, int block_num_y, int block_len_x, int block_len_y);
inline jpgd_block_t *coeff_buf_getp(coeff_buf *cb, int block_x, int block_y);
void load_next_row();
void decode_next_row();
void make_huff_table(int index, huff_tables *pH);
void check_quant_tables();
void check_huff_tables();
void calc_mcu_block_order();
int init_scan();
void init_frame();
void process_restart();
void decode_scan(pDecode_block_func decode_block_func);
void init_progressive();
void init_sequential();
void decode_start();
void decode_init(jpeg_decoder_stream * pStream);
void H2V2Convert();
void H2V1Convert();
void H1V2Convert();
void H1V1Convert();
void gray_convert();
void expanded_convert();
void find_eoi();
inline uint get_char();
inline uint get_char(bool *pPadding_flag);
inline void stuff_char(uint8 q);
inline uint8 get_octet();
inline uint get_bits(int num_bits);
inline uint get_bits_no_markers(int numbits);
inline int huff_decode(huff_tables *pH);
inline int huff_decode(huff_tables *pH, int& extrabits);
static inline uint8 clamp(int i);
static void decode_block_dc_first(jpeg_decoder *pD, int component_id, int block_x, int block_y);
static void decode_block_dc_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y);
static void decode_block_ac_first(jpeg_decoder *pD, int component_id, int block_x, int block_y);
static void decode_block_ac_refine(jpeg_decoder *pD, int component_id, int block_x, int block_y);
};
} // namespace jpgd
#endif // JPEG_DECODER_H

文件差异内容过多而无法显示 加载差异

查看文件

@@ -0,0 +1,172 @@
// jpge.h - C++ class for JPEG compression.
// Public domain, Rich Geldreich <richgel99@gmail.com>
// Alex Evans: Added RGBA support, linear memory allocator.
#ifndef JPEG_ENCODER_H
#define JPEG_ENCODER_H
#include <stdint.h>
namespace jpge
{
typedef unsigned char uint8;
typedef signed short int16;
typedef signed int int32;
typedef unsigned short uint16;
typedef unsigned int uint32;
typedef unsigned int uint;
// JPEG chroma subsampling factors. Y_ONLY (grayscale images) and H2V2 (color images) are the most common.
enum subsampling_t { Y_ONLY = 0, H1V1 = 1, H2V1 = 2, H2V2 = 3 };
// JPEG compression parameters structure.
struct params
{
inline params() : m_quality(85), m_subsampling(H2V2), m_no_chroma_discrim_flag(false), m_two_pass_flag(false) { }
inline bool check_valid() const
{
if ((m_quality < 1) || (m_quality > 100)) return false;
if ((uint)m_subsampling > (uint)H2V2) return false;
return true;
}
// Quality: 1-100, higher is better. Typical values are around 50-95.
int m_quality;
// m_subsampling:
// 0 = Y (grayscale) only
// 1 = YCbCr, no subsampling (H1V1, YCbCr 1x1x1, 3 blocks per MCU)
// 2 = YCbCr, H2V1 subsampling (YCbCr 2x1x1, 4 blocks per MCU)
// 3 = YCbCr, H2V2 subsampling (YCbCr 4x1x1, 6 blocks per MCU-- very common)
subsampling_t m_subsampling;
// Disables CbCr discrimination - only intended for testing.
// If true, the Y quantization table is also used for the CbCr channels.
bool m_no_chroma_discrim_flag;
bool m_two_pass_flag;
};
// Writes JPEG image to a file.
// num_channels must be 1 (Y) or 3 (RGB), image pitch must be width*num_channels.
bool compress_image_to_jpeg_file(const char *pFilename, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params = params());
// Writes JPEG image to memory buffer.
// On entry, buf_size is the size of the output buffer pointed at by pBuf, which should be at least ~1024 bytes.
// If return value is true, buf_size will be set to the size of the compressed data.
bool compress_image_to_jpeg_file_in_memory(void *pBuf, int64_t &buf_size, int64_t width, int64_t height, int64_t num_channels, const uint8 *pImage_data, const params &comp_params = params());
// Output stream abstract class - used by the jpeg_encoder class to write to the output stream.
// put_buf() is generally called with len==JPGE_OUT_BUF_SIZE bytes, but for headers it'll be called with smaller amounts.
class output_stream
{
public:
virtual ~output_stream() { };
virtual bool put_buf(const void* Pbuf, int64_t len) = 0;
template<class T> inline bool put_obj(const T& obj) { return put_buf(&obj, sizeof(T)); }
};
// Lower level jpeg_encoder class - useful if more control is needed than the above helper functions.
class jpeg_encoder
{
public:
jpeg_encoder();
~jpeg_encoder();
// Initializes the compressor.
// pStream: The stream object to use for writing compressed data.
// params - Compression parameters structure, defined above.
// width, height - Image dimensions.
// channels - May be 1, or 3. 1 indicates grayscale, 3 indicates RGB source data.
// Returns false on out of memory or if a stream write fails.
bool init(output_stream *pStream, int64_t width, int64_t height, int64_t src_channels, const params &comp_params = params());
const params &get_params() const { return m_params; }
// Deinitializes the compressor, freeing any allocated memory. May be called at any time.
void deinit();
uint get_total_passes() const { return m_params.m_two_pass_flag ? 2 : 1; }
inline uint get_cur_pass() { return m_pass_num; }
// Call this method with each source scanline.
// width * src_channels bytes per scanline is expected (RGB or Y format).
// You must call with NULL after all scanlines are processed to finish compression.
// Returns false on out of memory or if a stream write fails.
bool process_scanline(const void* pScanline);
private:
jpeg_encoder(const jpeg_encoder &);
jpeg_encoder &operator =(const jpeg_encoder &);
typedef int32 sample_array_t;
output_stream *m_pStream;
params m_params;
uint8 m_num_components;
uint8 m_comp_h_samp[3], m_comp_v_samp[3];
int m_image_x, m_image_y, m_image_bpp, m_image_bpl;
int m_image_x_mcu, m_image_y_mcu;
int m_image_bpl_xlt, m_image_bpl_mcu;
int m_mcus_per_row;
int m_mcu_x, m_mcu_y;
uint8 *m_mcu_lines[16];
uint8 m_mcu_y_ofs;
sample_array_t m_sample_array[64];
int16 m_coefficient_array[64];
int32 m_quantization_tables[2][64];
uint m_huff_codes[4][256];
uint8 m_huff_code_sizes[4][256];
uint8 m_huff_bits[4][17];
uint8 m_huff_val[4][256];
uint32 m_huff_count[4][256];
int m_last_dc_val[3];
enum { JPGE_OUT_BUF_SIZE = 2048 };
uint8 m_out_buf[JPGE_OUT_BUF_SIZE];
uint8 *m_pOut_buf;
uint m_out_buf_left;
uint32 m_bit_buffer;
uint m_bits_in;
uint8 m_pass_num;
bool m_all_stream_writes_succeeded;
void optimize_huffman_table(int table_num, int table_len);
void emit_byte(uint8 i);
void emit_word(uint i);
void emit_marker(int marker);
void emit_jfif_app0();
void emit_dqt();
void emit_sof();
void emit_dht(uint8 *bits, uint8 *val, int index, bool ac_flag);
void emit_dhts();
void emit_sos();
void emit_markers();
void compute_huffman_table(uint *codes, uint8 *code_sizes, uint8 *bits, uint8 *val);
void compute_quant_table(int32 *dst, int16 *src);
void adjust_quant_table(int32 *dst, int32 *src);
void first_pass_init();
bool second_pass_init();
bool jpg_open(int p_x_res, int p_y_res, int src_channels);
void load_block_8_8_grey(int x);
void load_block_8_8(int x, int y, int c);
void load_block_16_8(int x, int c);
void load_block_16_8_8(int x, int c);
void load_quantized_coefficients(int component_num);
void flush_output_buffer();
void put_bits(uint bits, uint len);
void code_coefficients_pass_one(int component_num);
void code_coefficients_pass_two(int component_num);
void code_block(int component_num);
void process_mcu_row();
bool terminate_pass_one();
bool terminate_pass_two();
bool process_end_of_image();
void load_mcu(const void* src);
void clear();
void init();
};
} // namespace jpge
#endif // JPEG_ENCODER

查看文件

@@ -0,0 +1,3 @@
jpge.h - C++ class for JPEG compression.
Public domain, Rich Geldreich <richgel99@gmail.com>
Alex Evans: Added RGBA support, linear memory allocator.

文件差异内容过多而无法显示 加载差异

文件差异内容过多而无法显示 加载差异

查看文件

@@ -0,0 +1,433 @@
#pragma once
#include <atomic>
#include <utility>
#include <cstring>
#include <type_traits>
#include <cstdint>
#include "libipc/def.h"
#include "libipc/platform/detail.h"
#include "libipc/circ/elem_def.h"
#include "libipc/utility/log.h"
#include "libipc/utility/utility.h"
namespace ipc {
////////////////////////////////////////////////////////////////
/// producer-consumer implementation
////////////////////////////////////////////////////////////////
template <typename Flag>
struct prod_cons_impl;
template <>
struct prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
};
alignas(cache_line_size) std::atomic<circ::u2_t> rd_; // read index
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
constexpr circ::u2_t cursor() const noexcept {
return 0;
}
template <typename W, typename F, typename E>
bool push(W* /*wrapper*/, F&& f, E* elems) {
auto cur_wt = circ::index_of(wt_.load(std::memory_order_relaxed));
if (cur_wt == circ::index_of(rd_.load(std::memory_order_acquire) - 1)) {
return false; // full
}
std::forward<F>(f)(&(elems[cur_wt].data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
/**
* In single-single-unicast, 'force_push' means 'no reader' or 'the only one reader is dead'.
* So we could just disconnect all connections of receiver, and return false.
*/
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(~static_cast<circ::cc_t>(0u));
return false;
}
template <typename W, typename F, typename R, typename E>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E* elems) {
auto cur_rd = circ::index_of(rd_.load(std::memory_order_relaxed));
if (cur_rd == circ::index_of(wt_.load(std::memory_order_acquire))) {
return false; // empty
}
std::forward<F>(f)(&(elems[cur_rd].data_));
std::forward<R>(out)(true);
rd_.fetch_add(1, std::memory_order_release);
return true;
}
};
template <>
struct prod_cons_impl<wr<relat::single, relat::multi , trans::unicast>>
: prod_cons_impl<wr<relat::single, relat::single, trans::unicast>> {
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(1);
return false;
}
template <typename W, typename F, typename R,
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
byte_t buff[DS];
for (unsigned k = 0;;) {
auto cur_rd = rd_.load(std::memory_order_relaxed);
if (circ::index_of(cur_rd) ==
circ::index_of(wt_.load(std::memory_order_acquire))) {
return false; // empty
}
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
std::forward<F>(f)(buff);
std::forward<R>(out)(true);
return true;
}
ipc::yield(k);
}
}
};
template <>
struct prod_cons_impl<wr<relat::multi , relat::multi, trans::unicast>>
: prod_cons_impl<wr<relat::single, relat::multi, trans::unicast>> {
using flag_t = std::uint64_t;
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
};
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
template <typename W, typename F, typename E>
bool push(W* /*wrapper*/, F&& f, E* elems) {
circ::u2_t cur_ct, nxt_ct;
for (unsigned k = 0;;) {
cur_ct = ct_.load(std::memory_order_relaxed);
if (circ::index_of(nxt_ct = cur_ct + 1) ==
circ::index_of(rd_.load(std::memory_order_acquire))) {
return false; // full
}
if (ct_.compare_exchange_weak(cur_ct, nxt_ct, std::memory_order_acq_rel)) {
break;
}
ipc::yield(k);
}
auto* el = elems + circ::index_of(cur_ct);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
while (1) {
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
if (cur_ct != wt_.load(std::memory_order_relaxed)) {
return true;
}
if ((~cac_ct) != cur_ct) {
return true;
}
if (!el->f_ct_.compare_exchange_strong(cac_ct, 0, std::memory_order_relaxed)) {
return true;
}
wt_.store(nxt_ct, std::memory_order_release);
cur_ct = nxt_ct;
nxt_ct = cur_ct + 1;
el = elems + circ::index_of(cur_ct);
}
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&&, E*) {
wrapper->elems()->disconnect_receiver(1);
return false;
}
template <typename W, typename F, typename R,
template <std::size_t, std::size_t> class E, std::size_t DS, std::size_t AS>
bool pop(W* /*wrapper*/, circ::u2_t& /*cur*/, F&& f, R&& out, E<DS, AS>* elems) {
byte_t buff[DS];
for (unsigned k = 0;;) {
auto cur_rd = rd_.load(std::memory_order_relaxed);
auto cur_wt = wt_.load(std::memory_order_acquire);
auto id_rd = circ::index_of(cur_rd);
auto id_wt = circ::index_of(cur_wt);
if (id_rd == id_wt) {
auto* el = elems + id_wt;
auto cac_ct = el->f_ct_.load(std::memory_order_acquire);
if ((~cac_ct) != cur_wt) {
return false; // empty
}
if (el->f_ct_.compare_exchange_weak(cac_ct, 0, std::memory_order_relaxed)) {
wt_.store(cur_wt + 1, std::memory_order_release);
}
k = 0;
}
else {
std::memcpy(buff, &(elems[circ::index_of(cur_rd)].data_), sizeof(buff));
if (rd_.compare_exchange_weak(cur_rd, cur_rd + 1, std::memory_order_release)) {
std::forward<F>(f)(buff);
std::forward<R>(out)(true);
return true;
}
ipc::yield(k);
}
}
}
};
template <>
struct prod_cons_impl<wr<relat::single, relat::multi, trans::broadcast>> {
using rc_t = std::uint64_t;
enum : rc_t {
ep_mask = 0x00000000ffffffffull,
ep_incr = 0x0000000100000000ull
};
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<rc_t> rc_ { 0 }; // read-counter
};
alignas(cache_line_size) std::atomic<circ::u2_t> wt_; // write index
alignas(cache_line_size) rc_t epoch_ { 0 }; // only one writer
circ::u2_t cursor() const noexcept {
return wt_.load(std::memory_order_acquire);
}
template <typename W, typename F, typename E>
bool push(W* wrapper, F&& f, E* elems) {
E* el;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & ep_mask;
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch_)) {
return false; // has not finished yet
}
// consider rem_cc to be 0 here
if (el->rc_.compare_exchange_weak(
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
break;
}
ipc::yield(k);
}
std::forward<F>(f)(&(el->data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&& f, E* elems) {
E* el;
epoch_ += ep_incr;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(wt_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & ep_mask;
if (cc & rem_cc) {
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
if (cc == 0) return false; // no reader
}
// just compare & exchange
if (el->rc_.compare_exchange_weak(
cur_rc, epoch_ | static_cast<rc_t>(cc), std::memory_order_release)) {
break;
}
ipc::yield(k);
}
std::forward<F>(f)(&(el->data_));
wt_.fetch_add(1, std::memory_order_release);
return true;
}
template <typename W, typename F, typename R, typename E>
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E* elems) {
if (cur == cursor()) return false; // acquire
auto* el = elems + circ::index_of(cur++);
std::forward<F>(f)(&(el->data_));
for (unsigned k = 0;;) {
auto cur_rc = el->rc_.load(std::memory_order_acquire);
if ((cur_rc & ep_mask) == 0) {
std::forward<R>(out)(true);
return true;
}
auto nxt_rc = cur_rc & ~static_cast<rc_t>(wrapper->connected_id());
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
std::forward<R>(out)((nxt_rc & ep_mask) == 0);
return true;
}
ipc::yield(k);
}
}
};
template <>
struct prod_cons_impl<wr<relat::multi, relat::multi, trans::broadcast>> {
using rc_t = std::uint64_t;
using flag_t = std::uint64_t;
enum : rc_t {
rc_mask = 0x00000000ffffffffull,
ep_mask = 0x00ffffffffffffffull,
ep_incr = 0x0100000000000000ull,
ic_mask = 0xff000000ffffffffull,
ic_incr = 0x0000000100000000ull
};
template <std::size_t DataSize, std::size_t AlignSize>
struct elem_t {
std::aligned_storage_t<DataSize, AlignSize> data_ {};
std::atomic<rc_t > rc_ { 0 }; // read-counter
std::atomic<flag_t> f_ct_ { 0 }; // commit flag
};
alignas(cache_line_size) std::atomic<circ::u2_t> ct_; // commit index
alignas(cache_line_size) std::atomic<rc_t> epoch_ { 0 };
circ::u2_t cursor() const noexcept {
return ct_.load(std::memory_order_acquire);
}
constexpr static rc_t inc_rc(rc_t rc) noexcept {
return (rc & ic_mask) | ((rc + ic_incr) & ~ic_mask);
}
constexpr static rc_t inc_mask(rc_t rc) noexcept {
return inc_rc(rc) & ~rc_mask;
}
template <typename W, typename F, typename E>
bool push(W* wrapper, F&& f, E* elems) {
E* el;
circ::u2_t cur_ct;
rc_t epoch = epoch_.load(std::memory_order_acquire);
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_relaxed);
circ::cc_t rem_cc = cur_rc & rc_mask;
if ((cc & rem_cc) && ((cur_rc & ~ep_mask) == epoch)) {
return false; // has not finished yet
}
else if (!rem_cc) {
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
if ((cur_fl != cur_ct) && cur_fl) {
return false; // full
}
}
// consider rem_cc to be 0 here
if (el->rc_.compare_exchange_weak(
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed) &&
epoch_.compare_exchange_weak(epoch, epoch, std::memory_order_acq_rel)) {
break;
}
ipc::yield(k);
}
// only one thread/process would touch here at one time
ct_.store(cur_ct + 1, std::memory_order_release);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
return true;
}
template <typename W, typename F, typename E>
bool force_push(W* wrapper, F&& f, E* elems) {
E* el;
circ::u2_t cur_ct;
rc_t epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
for (unsigned k = 0;;) {
circ::cc_t cc = wrapper->elems()->connections(std::memory_order_relaxed);
if (cc == 0) return false; // no reader
el = elems + circ::index_of(cur_ct = ct_.load(std::memory_order_relaxed));
// check all consumers have finished reading this element
auto cur_rc = el->rc_.load(std::memory_order_acquire);
circ::cc_t rem_cc = cur_rc & rc_mask;
if (cc & rem_cc) {
ipc::log("force_push: k = %u, cc = %u, rem_cc = %u\n", k, cc, rem_cc);
cc = wrapper->elems()->disconnect_receiver(rem_cc); // disconnect all invalid readers
if (cc == 0) return false; // no reader
}
// just compare & exchange
if (el->rc_.compare_exchange_weak(
cur_rc, inc_mask(epoch | (cur_rc & ep_mask)) | static_cast<rc_t>(cc), std::memory_order_relaxed)) {
if (epoch == epoch_.load(std::memory_order_acquire)) {
break;
}
else if (push(wrapper, std::forward<F>(f), elems)) {
return true;
}
epoch = epoch_.fetch_add(ep_incr, std::memory_order_release) + ep_incr;
}
ipc::yield(k);
}
// only one thread/process would touch here at one time
ct_.store(cur_ct + 1, std::memory_order_release);
std::forward<F>(f)(&(el->data_));
// set flag & try update wt
el->f_ct_.store(~static_cast<flag_t>(cur_ct), std::memory_order_release);
return true;
}
template <typename W, typename F, typename R, typename E, std::size_t N>
bool pop(W* wrapper, circ::u2_t& cur, F&& f, R&& out, E(& elems)[N]) {
auto* el = elems + circ::index_of(cur);
auto cur_fl = el->f_ct_.load(std::memory_order_acquire);
if (cur_fl != ~static_cast<flag_t>(cur)) {
return false; // empty
}
++cur;
std::forward<F>(f)(&(el->data_));
for (unsigned k = 0;;) {
auto cur_rc = el->rc_.load(std::memory_order_acquire);
if ((cur_rc & rc_mask) == 0) {
std::forward<R>(out)(true);
el->f_ct_.store(cur + N - 1, std::memory_order_release);
return true;
}
auto nxt_rc = inc_rc(cur_rc) & ~static_cast<rc_t>(wrapper->connected_id());
bool last_one = false;
if ((last_one = (nxt_rc & rc_mask) == 0)) {
el->f_ct_.store(cur + N - 1, std::memory_order_release);
}
if (el->rc_.compare_exchange_weak(cur_rc, nxt_rc, std::memory_order_release)) {
std::forward<R>(out)(last_one);
return true;
}
ipc::yield(k);
}
}
};
} // namespace ipc

查看文件

@@ -0,0 +1,58 @@
The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU \citep{extendedngpu}, ByteNet \citep{NalBytenet2017} and ConvS2S \citep{JonasFaceNet2017}, all of which use convolutional neural networks as basic building block, computing hidden representations in parallel for all input and output positions. In these models, the number of operations required to relate signals from two arbitrary input or output positions grows in the distance between positions, linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions \citep{hochreiter2001gradient}. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section~\ref{sec:attention}.
Self-attention, sometimes called intra-attention is an attention mechanism relating different positions of a single sequence in order to compute a representation of the sequence. Self-attention has been used successfully in a variety of tasks including reading comprehension, abstractive summarization, textual entailment and learning task-independent sentence representations \citep{cheng2016long, decomposableAttnModel, paulus2017deep, lin2017structured}.
End-to-end memory networks are based on a recurrent attention mechanism instead of sequence-aligned recurrence and have been shown to perform well on simple-language question answering and language modeling tasks \citep{sukhbaatar2015}.
To the best of our knowledge, however, the Transformer is the first transduction model relying entirely on self-attention to compute representations of its input and output without using sequence-aligned RNNs or convolution.
In the following sections, we will describe the Transformer, motivate self-attention and discuss its advantages over models such as \citep{neural_gpu, NalBytenet2017} and \citep{JonasFaceNet2017}.
%\citep{JonasFaceNet2017} report new SOTA on machine translation for English-to-German (EnDe), Enlish-to-French (EnFr) and English-to-Romanian language pairs.
%For example,! in MT, we must draw information from both input and previous output words to translate an output word accurately. An attention layer \citep{bahdanau2014neural} can connect a very large number of positions at low computation cost, making it an essential ingredient in competitive recurrent models for machine translation.
%A natural question to ask then is, "Could we replace recurrence with attention?". \marginpar{Don't know if it's the most natural question to ask given the previous statements. Also, need to say that the complexity table summarizes these statements} Such a model would be blessed with the computational efficiency of attention and the power of cross-positional communication. In this work, show that pure attention models work remarkably well for MT, achieving new SOTA results on EnDe and EnFr, and can be trained in under $2$ days on xyz architecture.
%After the seminal models introduced in \citep{sutskever14, bahdanau2014neural, cho2014learning}, recurrent models have become the dominant solution for both sequence modeling and sequence-to-sequence transduction. Many efforts such as \citep{wu2016google,luong2015effective,jozefowicz2016exploring} have pushed the boundaries of machine translation (MT) and language modeling with recurrent endoder-decoder and recurrent language models. Recent effort \citep{shazeer2017outrageously} has successfully combined the power of conditional computation with sequence models to train very large models for MT, pushing SOTA at lower computational cost.
%Recurrent models compute a vector of hidden states $h_t$, for each time step $t$ of computation. $h_t$ is a function of both the input at time $t$ and the previous hidden state $h_t$. This dependence on the previous hidden state precludes processing all timesteps at once, instead requiring long sequences of sequential operations. In practice, this results in greatly reduced computational efficiency, as on modern computing hardware, a single operation on a large batch is much faster than a large number of operations on small batches. The problem gets worse at longer sequence lengths. Although sequential computation is not a severe bottleneck at inference time, as autoregressively generating each output requires all previous outputs, the inability to compute scores at all output positions at once hinders us from rapidly training our models over large datasets. Although impressive work such as \citep{Kuchaiev2017Factorization} is able to significantly accelerate the training of LSTMs with factorization tricks, we are still bound by the linear dependence on sequence length.
%If the model could compute hidden states at each time step using only the inputs and outputs, it would be liberated from the dependence on results from previous time steps during training. This line of thought is the foundation of recent efforts such as the Markovian neural GPU \citep{neural_gpu}, ByteNet \citep{NalBytenet2017} and ConvS2S \citep{JonasFaceNet2017}, all of which use convolutional neural networks as a building block to compute hidden representations simultaneously for all timesteps, resulting in $O(1)$ sequential time complexity. \citep{JonasFaceNet2017} report new SOTA on machine translation for English-to-German (EnDe), Enlish-to-French (EnFr) and English-to-Romanian language pairs.
%A crucial component for accurate sequence prediction is modeling cross-positional communication. For example, in MT, we must draw information from both input and previous output words to translate an output word accurately. An attention layer \citep{bahdanau2014neural} can connect a very large number of positions at a low computation cost, also $O(1)$ sequential time complexity, making it an essential ingredient in recurrent encoder-decoder architectures for MT. A natural question to ask then is, "Could we replace recurrence with attention?". \marginpar{Don't know if it's the most natural question to ask given the previous statements. Also, need to say that the complexity table summarizes these statements} Such a model would be blessed with the computational efficiency of attention and the power of cross-positional communication. In this work, show that pure attention models work remarkably well for MT, achieving new SOTA results on EnDe and EnFr, and can be trained in under $2$ days on xyz architecture.
%Note: Facebook model is no better than RNNs in this regard, since it requires a number of layers proportional to the distance you want to communicate. Bytenet is more promising, since it requires a logarithmnic number of layers (does bytenet have SOTA results)?
%Note: An attention layer can connect a very large number of positions at a low computation cost in O(1) sequential operations. This is why encoder-decoder attention has been so successful in seq-to-seq models so far. It is only natural, then, to also use attention to connect the timesteps of the same sequence.
%Note: I wouldn't say that long sequences are not a problem during inference. It would be great if we could infer with no long sequences. We could just say later on that, while our training graph is constant-depth, our model still requires sequential operations in the decoder part during inference due to the autoregressive nature of the model.
%\begin{table}[h!]
%\caption{Attention models are quite efficient for cross-positional communications when sequence length is smaller than channel depth. $n$ represents the sequence length and $d$ represents the channel depth.}
%\label{tab:op_complexities}
%\begin{center}
%\vspace{-5pt}
%\scalebox{0.75}{
%\begin{tabular}{l|c|c|c}
%\hline \hline
%Layer Type & Receptive & Complexity & Sequential \\
% & Field & & Operations \\
%\hline
%Pointwise Feed-Forward & $1$ & $O(n \cdot d^2)$ & $O(1)$ \\
%\hline
%Recurrent & $n$ & $O(n \cdot d^2)$ & $O(n)$ \\
%\hline
%Convolutional & $r$ & $O(r \cdot n \cdot d^2)$ & $O(1)$ \\
%\hline
%Convolutional (separable) & $r$ & $O(r \cdot n \cdot d + n %\cdot d^2)$ & $O(1)$ \\
%\hline
%Attention & $r$ & $O(r \cdot n \cdot d)$ & $O(1)$ \\
%\hline \hline
%\end{tabular}
%}
%\end{center}
%\end{table}

查看文件

@@ -0,0 +1,18 @@
Recurrent neural networks, long short-term memory \citep{hochreiter1997} and gated recurrent \citep{gruEval14} neural networks in particular, have been firmly established as state of the art approaches in sequence modeling and transduction problems such as language modeling and machine translation \citep{sutskever14, bahdanau2014neural, cho2014learning}. Numerous efforts have since continued to push the boundaries of recurrent language models and encoder-decoder architectures \citep{wu2016google,luong2015effective,jozefowicz2016exploring}.
Recurrent models typically factor computation along the symbol positions of the input and output sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden states $h_t$, as a function of the previous hidden state $h_{t-1}$ and the input for position $t$. This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths, as memory constraints limit batching across examples.
%\marginpar{not sure if the memory constraints are understandable here}
Recent work has achieved significant improvements in computational efficiency through factorization tricks \citep{Kuchaiev2017Factorization} and conditional computation \citep{shazeer2017outrageously}, while also improving model performance in case of the latter. The fundamental constraint of sequential computation, however, remains.
%\marginpar{@all: there is work on analyzing what attention really does in seq2seq models, couldn't find it right away}
Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing modeling of dependencies without regard to their distance in the input or output sequences \citep{bahdanau2014neural, structuredAttentionNetworks}. In all but a few cases \citep{decomposableAttnModel}, however, such attention mechanisms are used in conjunction with a recurrent network.
%\marginpar{not sure if "cross-positional communication" is understandable without explanation}
%\marginpar{insert exact training times and stats for the model that reaches sota earliest, maybe even a single GPU model?}
In this work we propose the Transformer, a model architecture eschewing recurrence and instead relying entirely on an attention mechanism to draw global dependencies between input and output. The Transformer allows for significantly more parallelization and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.
%\marginpar{you removed the constant number of repetitions part. I wrote it because I wanted to make it clear that the model does not only perform attention once, while it's also not recurrent. I thought that might be important to get across early.}
% Just a standard paragraph with citations, rewrite.
%After the seminal papers of \citep{sutskever14}, \citep{bahdanau2014neural}, and \citep{cho2014learning}, recurrent models have become the dominant solution for both sequence modeling and sequence-to-sequence transduction. Many efforts such as \citep{wu2016google,luong2015effective,jozefowicz2016exploring} have pushed the boundaries of machine translation and language modeling with recurrent sequence models. Recent effort \citep{shazeer2017outrageously} has combined the power of conditional computation with sequence models to train very large models for machine translation, pushing SOTA at lower computational cost. Recurrent models compute a vector of hidden states $h_t$, for each time step $t$ of computation. $h_t$ is a function of both the input at time $t$ and the previous hidden state $h_t$. This dependence on the previous hidden state encumbers recurrnet models to process multiple inputs at once, and their time complexity is a linear function of the length of the input and output, both during training and inference. [What I want to say here is that although this is fine during decoding, at training time, we are given both input and output and this linear nature does not allow the RNN to process all inputs and outputs simultaneously and haven't been used on datasets that are the of the scale of the web. What's the largest dataset we have ? . Talk about Nividia and possibly other's effors to speed up things, and possibly other efforts that alleviate this, but are still limited by it's comptuational nature]. Rest of the intro: What if you could construct the state based on the actual inputs and outputs, then you could construct them all at once. This has been the foundation of many promising recent efforts, bytenet,facenet (Also talk about quasi rnn here). Now we talk about attention!! Along with cell architectures such as long short-term meory (LSTM) \citep{hochreiter1997}, and gated recurrent units (GRUs) \citep{cho2014learning}, attention has emerged as an essential ingredient in successful sequence models, in particular for machine translation. In recent years, many, if not all, state-of-the-art (SOTA) results in machine translation have been achieved with attention-based sequence models \citep{wu2016google,luong2015effective,jozefowicz2016exploring}. Talk about the neon work on how it played with attention to do self attention! Then talk about what we do.

查看文件

@@ -0,0 +1,155 @@
\begin{figure}
\centering
\includegraphics[scale=0.6]{Figures/ModalNet-21}
\caption{The Transformer - model architecture.}
\label{fig:model-arch}
\end{figure}
% Although the primary workhorse of our model is attention,
%Our model maintains the encoder-decoder structure that is common to many so-called sequence-to-sequence models \citep{bahdanau2014neural,sutskever14}. As in all such architectures, the encoder computes a representation of the input sequence, and the decoder consumes these representations along with the output tokens to autoregressively produce the output sequence. Where, traditionally, the encoder and decoder contain stacks of recurrent or convolutional layers, our encoder and decoder stacks are composed of attention layers and position-wise feed-forward layers (Figure~\ref{fig:model-arch}). The following sections describe the gross architecture and these particular components in detail.
Most competitive neural sequence transduction models have an encoder-decoder structure \citep{cho2014learning,bahdanau2014neural,sutskever14}. Here, the encoder maps an input sequence of symbol representations $(x_1, ..., x_n)$ to a sequence of continuous representations $\mathbf{z} = (z_1, ..., z_n)$. Given $\mathbf{z}$, the decoder then generates an output sequence $(y_1,...,y_m)$ of symbols one element at a time. At each step the model is auto-regressive \citep{graves2013generating}, consuming the previously generated symbols as additional input when generating the next.
The Transformer follows this overall architecture using stacked self-attention and point-wise, fully connected layers for both the encoder and decoder, shown in the left and right halves of Figure~\ref{fig:model-arch}, respectively.
\subsection{Encoder and Decoder Stacks}
\paragraph{Encoder:}The encoder is composed of a stack of $N=6$ identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network. We employ a residual connection \citep{he2016deep} around each of the two sub-layers, followed by layer normalization \cite{layernorm2016}. That is, the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $\dmodel=512$.
\paragraph{Decoder:}The decoder is also composed of a stack of $N=6$ identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.
% In our model (Figure~\ref{fig:model-arch}), the encoder and decoder are composed of stacks of alternating self-attention layers (for cross-positional communication) and position-wise feed-forward layers (for in-place computation). In addition, the decoder stack contains encoder-decoder attention layers. Since attention is agnostic to the distances between words, our model requires a "positional encoding" to be added to the encoder and decoder input. The following sections describe all of these components in detail.
\subsection{Attention} \label{sec:attention}
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
\subsubsection{Scaled Dot-Product Attention} \label{sec:scaled-dot-prod}
% \begin{figure}
% \centering
% \includegraphics[scale=0.6]{Figures/ModalNet-19}
% \caption{Scaled Dot-Product Attention.}
% \label{fig:multi-head-att}
% \end{figure}
We call our particular attention "Scaled Dot-Product Attention" (Figure~\ref{fig:multi-head-att}). The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$. The keys and values are also packed together into matrices $K$ and $V$. We compute the matrix of outputs as:
\begin{equation}
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V
\end{equation}
The two most commonly used attention functions are additive attention \citep{bahdanau2014neural}, and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
%We scale the dot products by $1/\sqrt{d_k}$ to limit the magnitude of the dot products, which works well in practice. Otherwise, we found applying the softmax to often result in weights very close to 0 or 1, and hence minuscule gradients.
% Already described in the subsequent section
%When used as part of decoder self-attention, an optional mask function is applied just before the softmax to prevent positions from attending to subsequent positions. This mask simply sets the logits corresponding to all illegal connections (those outside of the lower triangle) to $-\infty$.
%\paragraph{Comparison to Additive Attention: } We choose dot product attention over additive attention \citep{bahdanau2014neural} since it can be computed using highly optimized matrix multiplication code. This optimization is particularly important to us, as we employ many attention layers in our model.
While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of $d_k$ \citep{DBLP:journals/corr/BritzGLL17}. We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients \footnote{To illustrate why the dot products get large, assume that the components of $q$ and $k$ are independent random variables with mean $0$ and variance $1$. Then their dot product, $q \cdot k = \sum_{i=1}^{d_k} q_ik_i$, has mean $0$ and variance $d_k$.}. To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$.
%We suspect this to be caused by the dot products growing too large in magnitude to result in useful gradients after applying the softmax function. To counteract this, we scale the dot product by $1/\sqrt{d_k}$.
\subsubsection{Multi-Head Attention} \label{sec:multihead}
\begin{figure}
\begin{minipage}[t]{0.5\textwidth}
\centering
Scaled Dot-Product Attention \\
\vspace{0.5cm}
\includegraphics[scale=0.6]{Figures/ModalNet-19}
\end{minipage}
\begin{minipage}[t]{0.5\textwidth}
\centering
Multi-Head Attention \\
\vspace{0.1cm}
\includegraphics[scale=0.6]{Figures/ModalNet-20}
\end{minipage}
% \centering
\caption{(left) Scaled Dot-Product Attention. (right) Multi-Head Attention consists of several attention layers running in parallel.}
\label{fig:multi-head-att}
\end{figure}
Instead of performing a single attention function with $\dmodel$-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$, $d_k$ and $d_v$ dimensions, respectively.
On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$-dimensional output values. These are concatenated and once again projected, resulting in the final values, as depicted in Figure~\ref{fig:multi-head-att}.
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.
\begin{align*}
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head_1}, ..., \mathrm{head_h})W^O\\
% \mathrm{where} \mathrm{head_i} &= \mathrm{Attention}(QW_Q_i^{\dmodel \times d_q}, KW_K_i^{\dmodel \times d_k}, VW^V_i^{\dmodel \times d_v})\\
\text{where}~\mathrm{head_i} &= \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)\\
\end{align*}
Where the projections are parameter matrices $W^Q_i \in \mathbb{R}^{\dmodel \times d_k}$, $W^K_i \in \mathbb{R}^{\dmodel \times d_k}$, $W^V_i \in \mathbb{R}^{\dmodel \times d_v}$ and $W^O \in \mathbb{R}^{hd_v \times \dmodel}$.
%find it better (and no more expensive) to have multiple parallel attention layers (each over the full set of positions) with proportionally lower-dimensional keys, values and queries. We call this "Multi-Head Attention" (Figure~\ref{fig:multi-head-att}). The keys, values, and queries for each of these parallel attention layers are computed by learned linear transformations of the inputs to the multi-head attention. We use different linear transformations across different parallel attention layers. The output of the parallel attention layers are concatenated, and then passed through a final learned linear transformation.
In this work we employ $h=8$ parallel attention layers, or heads. For each of these we use $d_k=d_v=\dmodel/h=64$.
Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
\subsubsection{Applications of Attention in our Model}
The Transformer uses multi-head attention in three different ways:
\begin{itemize}
\item In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as \citep{wu2016google, bahdanau2014neural,JonasFaceNet2017}.
\item The encoder contains self-attention layers. In a self-attention layer all of the keys, values and queries come from the same place, in this case, the output of the previous layer in the encoder. Each position in the encoder can attend to all positions in the previous layer of the encoder.
\item Similarly, self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position. We need to prevent leftward information flow in the decoder to preserve the auto-regressive property. We implement this inside of scaled dot-product attention by masking out (setting to $-\infty$) all values in the input of the softmax which correspond to illegal connections. See Figure~\ref{fig:multi-head-att}.
\end{itemize}
\subsection{Position-wise Feed-Forward Networks}\label{sec:ffn}
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.
\begin{equation}
\mathrm{FFN}(x)=\max(0, xW_1 + b_1) W_2 + b_2
\end{equation}
While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is $\dmodel=512$, and the inner-layer has dimensionality $d_{ff}=2048$.
%In the appendix, we describe how the position-wise feed-forward network can also be seen as a form of attention.
%from Jakob: The number of operations required for the model to relate signals from two arbitrary input or output positions grows in the distance between positions in input or output, linearly for ConvS2S and logarithmically for ByteNet, making it harder to learn dependencies between these positions \citep{hochreiter2001gradient}. In the transformer this is reduced to a constant number of operations, albeit at the cost of effective resolution caused by averaging attention-weighted positions, an effect we aim to counteract with multi-headed attention.
%Figure~\ref{fig:simple-att} presents a simple attention function, $A$, with a single head, that forms the basis of our multi-head attention. $A$ takes a query key vector $\kq$, matrices of memory keys $\km$ and memory values $\vm$ ,and produces a query value vector $\vq$ as
%\begin{equation*} \label{eq:attention}
% A(\kq, \km, \vm) = {\vm}^T (Softmax(\km \kq).
%\end{equation*}
%We linearly transform $\kq,\,\km$, and $\vm$ with learned matrices ${\Wkq \text{,} \, \Wkm}$, and ${\Wvm}$ before calling the attention function, and transform the output query with $\Wvq$ before handing it to the feed forward layer. Each attention layer has it's own set of transformation matrices, which are shared across all query positions. $A$ is applied in parallel for each query position, and is implemented very efficiently as a batch of matrix multiplies. The self-attention and encoder-decoder attention layers use $A$, but with different arguments. For example, in encdoder self-attention, queries in encoder layer $i$ attention to memories in encoder layer $i-1$. To ensure that decoder self-attention layers do not look at future words, we add $- \inf$ to the softmax logits in positions $j+1$ to query length for query position $l$.
%In simple attention, the query value is a weighted combination of the memory values where the attention weights sum to one. Although this function performs well in practice, the constraint on attention weights can restrict the amount of information that flows from memories to queries because the query cannot focus on multiple memory positions at once, which might be desirable when translating long sequences. \marginpar{@usz, could you think of an example of this ?} We remedy this by maintaining multiple attention heads at each query position that attend to all memory positions in parallel, with a different set of parameters per attention head $h$.
%\marginpar{}
\subsection{Embeddings and Softmax}
Similarly to other sequence transduction models, we use learned embeddings to convert the input tokens and output tokens to vectors of dimension $\dmodel$. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to \citep{press2016using}. In the embedding layers, we multiply those weights by $\sqrt{\dmodel}$.
\subsection{Positional Encoding}
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $\dmodel$ as the embeddings, so that the two can be summed. There are many choices of positional encodings, learned and fixed \citep{JonasFaceNet2017}.
In this work, we use sine and cosine functions of different frequencies:
\begin{align*}
PE_{(pos,2i)} = sin(pos / 10000^{2i/\dmodel}) \\
PE_{(pos,2i+1)} = cos(pos / 10000^{2i/\dmodel})
\end{align*}
where $pos$ is the position and $i$ is the dimension. That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from $2\pi$ to $10000 \cdot 2\pi$. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
We also experimented with using learned positional embeddings \citep{JonasFaceNet2017} instead, and found that the two versions produced nearly identical results (see Table~\ref{tab:variations} row (E)). We chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training.

查看文件

@@ -0,0 +1,45 @@
\pagebreak
\section*{Two Feed-Forward Layers = Attention over Parameters}\label{sec:parameter_attention}
In addition to attention layers, our model contains position-wise feed-forward networks (Section \ref{sec:ffn}), which consist of two linear transformations with a ReLU activation in between. In fact, these networks too can be seen as a form of attention. Compare the formula for such a network with the formula for a simple dot-product attention layer (biases and scaling factors omitted):
\begin{align*}
FFN(x, W_1, W_2) = ReLU(xW_1)W_2 \\
A(q, K, V) = Softmax(qK^T)V
\end{align*}
Based on the similarity of these formulae, the two-layer feed-forward network can be seen as a kind of attention, where the keys and values are the rows of the trainable parameter matrices $W_1$ and $W_2$, and where we use ReLU instead of Softmax in the compatibility function.
%the compatablity function is $compat(q, k_i) = ReLU(q \cdot k_i)$ instead of $Softmax(qK_T)_i$.
Given this similarity, we experimented with replacing the position-wise feed-forward networks with attention layers similar to the ones we use everywhere else our model. The multi-head-attention-over-parameters sublayer is identical to the multi-head attention described in \ref{sec:multihead}, except that the "keys" and "values" inputs to each attention head are trainable model parameters, as opposed to being linear projections of a previous layer. These parameters are scaled up by a factor of $\sqrt{d_{model}}$ in order to be more similar to activations.
In our first experiment, we replaced each position-wise feed-forward network with a multi-head-attention-over-parameters sublayer with $h_p=8$ heads, key-dimensionality $d_{pk}=64$, and value-dimensionality $d_{pv}=64$, using $n_p=1536$ key-value pairs for each attention head. The sublayer has a total of $2097152$ parameters, including the parameters in the query projection and the output projection. This matches the number of parameters in the position-wise feed-forward network that we replaced. While the theoretical amount of computation is also the same, in practice, the attention version caused the step times to be about 30\% longer.
In our second experiment, we used $h_p=8$ heads, and $n_p=512$ key-value pairs for each attention head, again matching the total number of parameters in the base model.
Results for the first experiment were slightly worse than for the base model, and results for the second experiment were slightly better, see Table~\ref{tab:parameter_attention}.
\begin{table}[h]
\caption{Replacing the position-wise feed-forward networks with multihead-attention-over-parameters produces similar results to the base model. All metrics are on the English-to-German translation development set, newstest2013.}
\label{tab:parameter_attention}
\begin{center}
\vspace{-2mm}
%\scalebox{1.0}{
\begin{tabular}{c|cccccc|cccc}
\hline\rule{0pt}{2.0ex}
& \multirow{2}{*}{$\dmodel$} & \multirow{2}{*}{$\dff$} &
\multirow{2}{*}{$h_p$} & \multirow{2}{*}{$d_{pk}$} & \multirow{2}{*}{$d_{pv}$} &
\multirow{2}{*}{$n_p$} &
PPL & BLEU & params & training\\
& & & & & & & (dev) & (dev) & $\times10^6$ & time \\
\hline\rule{0pt}{2.0ex}
base & 512 & 2048 & & & & & 4.92 & 25.8 & 65 & 12 hours\\
\hline\rule{0pt}{2.0ex}
AOP$_1$ & 512 & & 8 & 64 & 64 & 1536 & 4.92& 25.5 & 65 & 16 hours\\
AOP$_2$ & 512 & & 16 & 64 & 64 & 512 & \textbf{4.86} & \textbf{25.9} & 65 & 16 hours \\
\hline
\end{tabular}
%}
\end{center}
\end{table}

查看文件

@@ -0,0 +1,8 @@
chatgpt的老祖宗《Attention is all you need》
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin
真实的摘要如下
The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
https://arxiv.org/abs/1706.03762

查看文件

@@ -0,0 +1,2 @@
from stable_baselines3.dqn.dqn import DQN
from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy

查看文件

@@ -0,0 +1,245 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import gym
import numpy as np
import torch as th
from torch.nn import functional as F
from stable_baselines3.common import logger
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
from stable_baselines3.dqn.policies import DQNPolicy
class DQN(OffPolicyAlgorithm):
"""
Deep Q-Network (DQN)
Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
Default hyperparameters are taken from the nature paper,
except for the optimizer and learning rate that were taken from Stable Baselines defaults.
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param create_eval_env: Whether to create a second environment that will be
used for evaluating the agent periodically. (Only available when passing string for the environment)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""
def __init__(
self,
policy: Union[str, Type[DQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 1e-4,
buffer_size: int = 1000000,
learning_starts: int = 50000,
batch_size: Optional[int] = 32,
tau: float = 1.0,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = 4,
gradient_steps: int = 1,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.1,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.05,
max_grad_norm: float = 10,
tensorboard_log: Optional[str] = None,
create_eval_env: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):
super(DQN, self).__init__(
policy,
env,
DQNPolicy,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
action_noise=None, # No action noise
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
create_eval_env=create_eval_env,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(gym.spaces.Discrete,),
)
self.exploration_initial_eps = exploration_initial_eps
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.target_update_interval = target_update_interval
self.max_grad_norm = max_grad_norm
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule = None
self.q_net, self.q_net_target = None, None
if _init_setup_model:
self._setup_model()
def _setup_model(self) -> None:
super(DQN, self)._setup_model()
self._create_aliases()
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
def _create_aliases(self) -> None:
self.q_net = self.policy.q_net
self.q_net_target = self.policy.q_net_target
def _on_step(self) -> None:
"""
Update the exploration rate and target network if needed.
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
if self.num_timesteps % self.target_update_interval == 0:
polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
logger.record("rollout/exploration rate", self.exploration_rate)
def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)
losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)
with th.no_grad():
# Compute the next Q-values using the target network
next_q_values = self.q_net_target(replay_data.next_observations)
# Follow greedy policy: use the one with the highest value
next_q_values, _ = next_q_values.max(dim=1)
# Avoid potential broadcast issue
next_q_values = next_q_values.reshape(-1, 1)
# 1-step TD target
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
# Get current Q-values estimates
current_q_values = self.q_net(replay_data.observations)
# Retrieve the q-values for the actions from the replay buffer
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())
# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
losses.append(loss.item())
# Optimize the policy
self.policy.optimizer.zero_grad()
loss.backward()
# Clip gradient norm
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()
# Increase update counter
self._n_updates += gradient_steps
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/loss", np.mean(losses))
def predict(
self,
observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
"""
Overrides the base_class predict function to include epsilon-greedy exploration.
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, mask, deterministic)
return action, state
def learn(
self,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "DQN",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> OffPolicyAlgorithm:
return super(DQN, self).learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
eval_env=eval_env,
eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name,
eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps,
)
def _excluded_save_params(self) -> List[str]:
return super(DQN, self)._excluded_save_params() + ["q_net", "q_net_target"]
def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]
return state_dicts, []

查看文件

@@ -0,0 +1,237 @@
from typing import Any, Dict, List, Optional, Type
import gym
import torch as th
from torch import nn
from stable_baselines3.common.policies import BasePolicy, register_policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp
from stable_baselines3.common.type_aliases import Schedule
class QNetwork(BasePolicy):
"""
Action-Value (Q-Value) network for DQN
:param observation_space: Observation space
:param action_space: Action space
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
features_extractor: nn.Module,
features_dim: int,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
normalize_images: bool = True,
):
super(QNetwork, self).__init__(
observation_space,
action_space,
features_extractor=features_extractor,
normalize_images=normalize_images,
)
if net_arch is None:
net_arch = [64, 64]
self.net_arch = net_arch
self.activation_fn = activation_fn
self.features_extractor = features_extractor
self.features_dim = features_dim
self.normalize_images = normalize_images
action_dim = self.action_space.n # number of actions
q_net = create_mlp(self.features_dim, action_dim, self.net_arch, self.activation_fn)
self.q_net = nn.Sequential(*q_net)
def forward(self, obs: th.Tensor) -> th.Tensor:
"""
Predict the q-values.
:param obs: Observation
:return: The estimated Q-Value for each action.
"""
return self.q_net(self.extract_features(obs))
def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
q_values = self.forward(observation)
# Greedy action
action = q_values.argmax(dim=1).reshape(-1)
return action
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_arch,
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor,
)
)
return data
class DQNPolicy(BasePolicy):
"""
Policy class with Q-Value Net and target net for DQN
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param features_extractor_kwargs: Keyword arguments
to pass to the features extractor.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(DQNPolicy, self).__init__(
observation_space,
action_space,
features_extractor_class,
features_extractor_kwargs,
optimizer_class=optimizer_class,
optimizer_kwargs=optimizer_kwargs,
)
if net_arch is None:
if features_extractor_class == FlattenExtractor:
net_arch = [64, 64]
else:
net_arch = []
self.net_arch = net_arch
self.activation_fn = activation_fn
self.normalize_images = normalize_images
self.net_args = {
"observation_space": self.observation_space,
"action_space": self.action_space,
"net_arch": self.net_arch,
"activation_fn": self.activation_fn,
"normalize_images": normalize_images,
}
self.q_net, self.q_net_target = None, None
self._build(lr_schedule)
def _build(self, lr_schedule: Schedule) -> None:
"""
Create the network and the optimizer.
:param lr_schedule: Learning rate schedule
lr_schedule(1) is the initial learning rate
"""
self.q_net = self.make_q_net()
self.q_net_target = self.make_q_net()
self.q_net_target.load_state_dict(self.q_net.state_dict())
# Setup optimizer with initial learning rate
self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
def make_q_net(self) -> QNetwork:
# Make sure we always have separate networks for features extractors etc
net_args = self._update_features_extractor(self.net_args, features_extractor=None)
return QNetwork(**net_args).to(self.device)
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self._predict(obs, deterministic=deterministic)
def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
return self.q_net._predict(obs, deterministic=deterministic)
def _get_constructor_parameters(self) -> Dict[str, Any]:
data = super()._get_constructor_parameters()
data.update(
dict(
net_arch=self.net_args["net_arch"],
activation_fn=self.net_args["activation_fn"],
lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone
optimizer_class=self.optimizer_class,
optimizer_kwargs=self.optimizer_kwargs,
features_extractor_class=self.features_extractor_class,
features_extractor_kwargs=self.features_extractor_kwargs,
)
)
return data
MlpPolicy = DQNPolicy
class CnnPolicy(DQNPolicy):
"""
Policy class for DQN when using images as input.
:param observation_space: Observation space
:param action_space: Action space
:param lr_schedule: Learning rate schedule (could be constant)
:param net_arch: The specification of the policy and value networks.
:param activation_fn: Activation function
:param features_extractor_class: Features extractor to use.
:param normalize_images: Whether to normalize images or not,
dividing by 255.0 (True by default)
:param optimizer_class: The optimizer to use,
``th.optim.Adam`` by default
:param optimizer_kwargs: Additional keyword arguments,
excluding the learning rate, to pass to the optimizer
"""
def __init__(
self,
observation_space: gym.spaces.Space,
action_space: gym.spaces.Space,
lr_schedule: Schedule,
net_arch: Optional[List[int]] = None,
activation_fn: Type[nn.Module] = nn.ReLU,
features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
):
super(CnnPolicy, self).__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
activation_fn,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
)
register_policy("MlpPolicy", MlpPolicy)
register_policy("CnnPolicy", CnnPolicy)

查看文件

@@ -0,0 +1,2 @@
github stablebaseline3
https://github.com/DLR-RM/stable-baselines3

查看文件

@@ -0,0 +1,27 @@
"In practice, we found that a high-entropy initial state is more likely to increase the speed of training.
The entropy is calculated by:
$$H=-\sum_{k= 1}^{n_k} p(k) \cdot \log p(k), p(k)=\frac{|A_k|}{|\mathcal{A}|}$$
where $H$ is the entropy, $|A_k|$ is the number of agent nodes in $k$-th cluster, $|\mathcal{A}|$ is the total number of agents.
To ensure the Cooperation Graph initialization has higher entropy,
we will randomly generate multiple initial states,
rank by their entropy and then pick the one with maximum $H$."
```
FROM ubuntu:latest
RUN apt-get update && \
apt-get install -y python3 python3-pip && \
rm -rf /var/lib/apt/lists/*
RUN echo '[global]' > /etc/pip.conf && \
echo 'index-url = https://mirrors.aliyun.com/pypi/simple/' >> /etc/pip.conf && \
echo 'trusted-host = mirrors.aliyun.com' >> /etc/pip.conf
RUN pip3 install gradio requests[socks] mdtex2html
COPY . /gpt
WORKDIR /gpt
CMD ["python3", "main.py"]
```

查看文件

@@ -1,17 +1,16 @@
# From project chatglm-langchain
import threading
from toolbox import Singleton
import os
import shutil
import os
import uuid
import tqdm
import shutil
import threading
import numpy as np
from toolbox import Singleton
from loguru import logger
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from typing import List, Tuple
import numpy as np
from crazy_functions.vector_fns.general_file_loader import load_file
embedding_model_dict = {
@@ -29,7 +28,7 @@ EMBEDDING_DEVICE = "cpu"
# 基于上下文的prompt模版,请务必保留"{question}"和"{context}"
PROMPT_TEMPLATE = """已知信息:
{context}
{context}
根据上述已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 “根据已知信息无法回答该问题” 或 “没有提供足够的相关信息”,不允许在答案中添加编造成分,答案请使用中文。 问题是:{question}"""
@@ -59,7 +58,7 @@ OPEN_CROSS_DOMAIN = False
def similarity_search_with_score_by_vector(
self, embedding: List[float], k: int = 4
) -> List[Tuple[Document, float]]:
def seperate_list(ls: List[int]) -> List[List[int]]:
lists = []
ls1 = [ls[0]]
@@ -151,17 +150,17 @@ class LocalDocQA:
failed_files = []
if isinstance(filepath, str):
if not os.path.exists(filepath):
logger.error("路径不存在")
print("路径不存在")
return None
elif os.path.isfile(filepath):
file = os.path.split(filepath)[-1]
try:
docs = load_file(filepath, SENTENCE_SIZE)
logger.info(f"{file} 已成功加载")
print(f"{file} 已成功加载")
loaded_files.append(filepath)
except Exception as e:
logger.error(e)
logger.error(f"{file} 未能成功加载")
print(e)
print(f"{file} 未能成功加载")
return None
elif os.path.isdir(filepath):
docs = []
@@ -171,23 +170,23 @@ class LocalDocQA:
docs += load_file(fullfilepath, SENTENCE_SIZE)
loaded_files.append(fullfilepath)
except Exception as e:
logger.error(e)
print(e)
failed_files.append(file)
if len(failed_files) > 0:
logger.error("以下文件未能成功加载:")
print("以下文件未能成功加载:")
for file in failed_files:
logger.error(f"{file}\n")
print(f"{file}\n")
else:
docs = []
for file in filepath:
docs += load_file(file, SENTENCE_SIZE)
logger.info(f"{file} 已成功加载")
print(f"{file} 已成功加载")
loaded_files.append(file)
if len(docs) > 0:
logger.info("文件加载完毕,正在生成向量库")
print("文件加载完毕,正在生成向量库")
if vs_path and os.path.isdir(vs_path):
try:
self.vector_store = FAISS.load_local(vs_path, text2vec)
@@ -201,7 +200,7 @@ class LocalDocQA:
return vs_path, loaded_files
else:
raise RuntimeError("文件加载失败,请检查文件格式是否正确")
def get_loaded_file(self, vs_path):
ds = self.vector_store.docstore
return set([ds._dict[k].metadata['source'].split(vs_path)[-1] for k in ds._dict])
@@ -234,7 +233,7 @@ class LocalDocQA:
prompt += "\n\n".join([f"({k}): " + doc.page_content for k, doc in enumerate(related_docs_with_score)])
prompt += "\n\n---\n\n"
prompt = prompt.encode('utf-8', 'ignore').decode() # avoid reading non-utf8 chars
# logger.info(prompt)
# print(prompt)
response = {"query": query, "source_documents": related_docs_with_score}
return response, prompt
@@ -263,7 +262,7 @@ def construct_vector_store(vs_id, vs_path, files, sentence_size, history, one_co
else:
pass
# file_status = "文件未成功加载,请重新上传文件"
# logger.info(file_status)
# print(file_status)
return local_doc_qa, vs_path
@Singleton
@@ -279,7 +278,7 @@ class knowledge_archive_interface():
if self.text2vec_large_chinese is None:
# < -------------------预热文本向量化模组--------------- >
from toolbox import ProxyNetworkActivate
logger.info('Checking Text2vec ...')
print('Checking Text2vec ...')
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
with ProxyNetworkActivate('Download_LLM'): # 临时地激活代理网络
self.text2vec_large_chinese = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
@@ -291,10 +290,10 @@ class knowledge_archive_interface():
self.threadLock.acquire()
# import uuid
self.current_id = id
self.qa_handle, self.kai_path = construct_vector_store(
vs_id=self.current_id,
self.qa_handle, self.kai_path = construct_vector_store(
vs_id=self.current_id,
vs_path=vs_path,
files=file_manifest,
files=file_manifest,
sentence_size=100,
history=[],
one_conent="",
@@ -305,7 +304,7 @@ class knowledge_archive_interface():
def get_current_archive_id(self):
return self.current_id
def get_loaded_file(self, vs_path):
return self.qa_handle.get_loaded_file(vs_path)
@@ -313,10 +312,10 @@ class knowledge_archive_interface():
self.threadLock.acquire()
if not self.current_id == id:
self.current_id = id
self.qa_handle, self.kai_path = construct_vector_store(
vs_id=self.current_id,
self.qa_handle, self.kai_path = construct_vector_store(
vs_id=self.current_id,
vs_path=vs_path,
files=[],
files=[],
sentence_size=100,
history=[],
one_conent="",
@@ -330,7 +329,7 @@ class knowledge_archive_interface():
query = txt,
vs_path = self.kai_path,
score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
vector_search_top_k=VECTOR_SEARCH_TOP_K,
vector_search_top_k=VECTOR_SEARCH_TOP_K,
chunk_conent=True,
chunk_size=CHUNK_SIZE,
text2vec = self.get_chinese_text2vec(),

某些文件未显示,因为此 diff 中更改的文件太多 显示更多