修复Azure OpenAI接口的各种bug

这个提交包含在:
qingxu fu
2023-07-07 10:42:38 +08:00
父节点 bb1d5a61c0
当前提交 9c0bc48420
共有 4 个文件被更改,包括 40 次插入和 74 次删除

查看文件

@@ -14,7 +14,8 @@ import traceback
import importlib
import openai
import time
import requests
import json
# 读取config.py文件中关于AZURE OPENAI API的信息
from toolbox import get_conf, update_ui, clip_history, trimmed_format_exc
@@ -43,7 +44,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
additional_fn代表点击的哪个按钮,按钮见functional.py
"""
print(llm_kwargs["llm_model"])
if additional_fn is not None:
import core_functional
@@ -56,7 +56,6 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
logging.info(f'[raw_input] {raw_input}')
chatbot.append((inputs, ""))
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
payload = generate_azure_payload(inputs, llm_kwargs, history, system_prompt, stream)
@@ -64,20 +63,22 @@ def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_promp
retry = 0
while True:
try:
try:
openai.api_type = "azure"
openai.api_version = AZURE_API_VERSION
openai.api_base = AZURE_ENDPOINT
openai.api_key = AZURE_API_KEY
response = openai.ChatCompletion.create(timeout=TIMEOUT_SECONDS, **payload);break
except openai.error.AuthenticationError:
tb_str = '```\n' + trimmed_format_exc() + '```'
chatbot[-1] = [chatbot[-1][0], tb_str]
yield from update_ui(chatbot=chatbot, history=history, msg="openai返回错误") # 刷新界面
return
except:
retry += 1
chatbot[-1] = ((chatbot[-1][0], "获取response失败,重试中。。。"))
retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
traceback.print_exc()
if retry > MAX_RETRY: raise TimeoutError
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
gpt_replying_buffer = ""
is_head_of_the_stream = True
@@ -141,20 +142,17 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
payload = generate_azure_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
retry = 0
while True:
try:
openai.api_type = "azure"
openai.api_version = AZURE_API_VERSION
openai.api_base = AZURE_ENDPOINT
openai.api_key = AZURE_API_KEY
response = openai.ChatCompletion.create(timeout=TIMEOUT_SECONDS, **payload);break
except:
except:
retry += 1
traceback.print_exc()
if retry > MAX_RETRY: raise TimeoutError
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
stream_response = response
result = ''
@@ -164,19 +162,14 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
break
except:
chunk = next(stream_response) # 失败了,重试一次?再失败就没办法了。
if len(chunk)==0: continue
if not chunk.startswith('data:'):
error_msg = get_full_error(chunk, stream_response)
if "reduce the length" in error_msg:
raise ConnectionAbortedError("AZURE OPENAI API拒绝了请求:" + error_msg)
else:
raise RuntimeError("AZURE OPENAI API拒绝了请求" + error_msg)
if ('data: [DONE]' in chunk): break
delta = chunk["delta"]
if len(delta) == 0: break
if "role" in delta: continue
json_data = json.loads(str(chunk))['choices'][0]
delta = json_data["delta"]
if len(delta) == 0:
break
if "role" in delta:
continue
if "content" in delta:
result += delta["content"]
if not console_slience: print(delta["content"], end='')
@@ -184,11 +177,14 @@ def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="",
# 观测窗,把已经获取的数据显示出去
if len(observe_window) >= 1: observe_window[0] += delta["content"]
# 看门狗,如果超过期限没有喂狗,则终止
if len(observe_window) >= 2:
if len(observe_window) >= 2000:
if (time.time()-observe_window[1]) > watch_dog_patience:
raise RuntimeError("用户取消了程序。")
else: raise RuntimeError("意外Json结构"+delta)
if chunk['finish_reason'] == 'length':
else:
raise RuntimeError("意外Json结构"+delta)
if json_data['finish_reason'] == 'content_filter':
raise RuntimeError("由于提问含不合规内容被Azure过滤。")
if json_data['finish_reason'] == 'length':
raise ConnectionAbortedError("正常结束,但显示Token不足,导致输出不完整,请削减单次输入的文本量。")
return result