这个提交包含在:
Your Name
2023-03-29 21:44:59 +08:00
父节点 0f28564fea
当前提交 92d4400d19
共有 11 个文件被更改,包括 372 次插入80 次删除

查看文件

@@ -1,5 +1,16 @@
# 借鉴了 https://github.com/GaiZhenbiao/ChuanhuChatGPT 项目
"""
该文件中主要包含三个函数
不具备多线程能力的函数:
1. predict: 正常对话时使用,具备完备的交互功能,不可多线程
具备多线程调用能力的函数
2. predict_no_ui高级实验性功能模块调用,不会实时显示在界面上,参数简单,可以多线程并行,方便实现复杂的功能逻辑
3. predict_no_ui_long_connection在实验过程中发现调用predict_no_ui处理长文档时,和openai的连接容易断掉,这个函数用stream的方式解决这个问题,同样支持多线程
"""
import json
import gradio as gr
import logging
@@ -25,7 +36,7 @@ def get_full_error(chunk, stream_response):
break
return chunk
def predict_no_ui(inputs, top_p, temperature, history=[]):
def predict_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
"""
发送至chatGPT,等待回复,一次性完成,不显示中间过程。
predict函数的简化版。
@@ -36,7 +47,7 @@ def predict_no_ui(inputs, top_p, temperature, history=[]):
history 是之前的对话列表
注意无论是inputs还是history,内容太长了都会触发token数量溢出的错误,然后raise ConnectionAbortedError
"""
headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt="", stream=False)
headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt=sys_prompt, stream=False)
retry = 0
while True:
@@ -47,8 +58,8 @@ def predict_no_ui(inputs, top_p, temperature, history=[]):
except requests.exceptions.ReadTimeout as e:
retry += 1
traceback.print_exc()
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
if retry > MAX_RETRY: raise TimeoutError
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
try:
result = json.loads(response.text)["choices"][0]["message"]["content"]
@@ -58,6 +69,41 @@ def predict_no_ui(inputs, top_p, temperature, history=[]):
raise ConnectionAbortedError("Json解析不合常规,可能是文本过长" + response.text)
def predict_no_ui_long_connection(inputs, top_p, temperature, history=[], sys_prompt=""):
"""
发送至chatGPT,等待回复,一次性完成,不显示中间过程。但内部用stream的方法避免有人中途掐网线。
"""
headers, payload = generate_payload(inputs, top_p, temperature, history, system_prompt=sys_prompt, stream=True)
retry = 0
while True:
try:
# make a POST request to the API endpoint, stream=False
response = requests.post(API_URL, headers=headers, proxies=proxies,
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
except requests.exceptions.ReadTimeout as e:
retry += 1
traceback.print_exc()
if retry > MAX_RETRY: raise TimeoutError
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
stream_response = response.iter_lines()
result = ''
while True:
try: chunk = next(stream_response).decode()
except StopIteration: break
if len(chunk)==0: continue
if not chunk.startswith('data:'):
chunk = get_full_error(chunk.encode('utf8'), stream_response)
raise ConnectionAbortedError("OpenAI拒绝了请求:" + chunk.decode())
delta = json.loads(chunk.lstrip('data:'))['choices'][0]["delta"]
if len(delta) == 0: break
if "role" in delta: continue
if "content" in delta: result += delta["content"]; print(delta["content"], end='')
else: raise RuntimeError("意外Json结构"+delta)
return result
def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt='',
stream = True, additional_fn=None):
"""
@@ -130,7 +176,7 @@ def predict(inputs, top_p, temperature, chatbot=[], history=[], system_prompt=''
chunk = get_full_error(chunk, stream_response)
error_msg = chunk.decode()
if "reduce the length" in error_msg:
chatbot[-1] = (chatbot[-1][0], "[Local Message] Input (or history) is too long, please reduce input or clear history by refleshing this page.")
chatbot[-1] = (chatbot[-1][0], "[Local Message] Input (or history) is too long, please reduce input or clear history by refreshing this page.")
history = []
elif "Incorrect API key" in error_msg:
chatbot[-1] = (chatbot[-1][0], "[Local Message] Incorrect API key provided.")