镜像自地址
https://github.com/binary-husky/gpt_academic.git
已同步 2025-12-06 14:36:48 +00:00
fix: support o1 models
这个提交包含在:
@@ -255,6 +255,8 @@ model_info = {
|
|||||||
"max_token": 128000,
|
"max_token": 128000,
|
||||||
"tokenizer": tokenizer_gpt4,
|
"tokenizer": tokenizer_gpt4,
|
||||||
"token_cnt": get_token_num_gpt4,
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
"openai_disable_system_prompt": True,
|
||||||
|
"openai_disable_stream": True,
|
||||||
},
|
},
|
||||||
"o1-mini": {
|
"o1-mini": {
|
||||||
"fn_with_ui": chatgpt_ui,
|
"fn_with_ui": chatgpt_ui,
|
||||||
@@ -263,6 +265,8 @@ model_info = {
|
|||||||
"max_token": 128000,
|
"max_token": 128000,
|
||||||
"tokenizer": tokenizer_gpt4,
|
"tokenizer": tokenizer_gpt4,
|
||||||
"token_cnt": get_token_num_gpt4,
|
"token_cnt": get_token_num_gpt4,
|
||||||
|
"openai_disable_system_prompt": True,
|
||||||
|
"openai_disable_stream": True,
|
||||||
},
|
},
|
||||||
|
|
||||||
"gpt-4-turbo": {
|
"gpt-4-turbo": {
|
||||||
|
|||||||
@@ -133,22 +133,33 @@ def predict_no_ui_long_connection(inputs:str, llm_kwargs:dict, history:list=[],
|
|||||||
observe_window = None:
|
observe_window = None:
|
||||||
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
用于负责跨越线程传递已经输出的部分,大部分时候仅仅为了fancy的视觉效果,留空即可。observe_window[0]:观测窗。observe_window[1]:看门狗
|
||||||
"""
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
watch_dog_patience = 5 # 看门狗的耐心, 设置5秒即可
|
||||||
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=True)
|
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
|
headers, payload = generate_payload(inputs, llm_kwargs, history, system_prompt=sys_prompt, stream=stream)
|
||||||
retry = 0
|
retry = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=False
|
# make a POST request to the API endpoint, stream=False
|
||||||
from .bridge_all import model_info
|
|
||||||
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
endpoint = verify_endpoint(model_info[llm_kwargs['llm_model']]['endpoint'])
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS); break
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS); break
|
||||||
except requests.exceptions.ReadTimeout as e:
|
except requests.exceptions.ReadTimeout as e:
|
||||||
retry += 1
|
retry += 1
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
if MAX_RETRY!=0: print(f'请求超时,正在重试 ({retry}/{MAX_RETRY}) ……')
|
||||||
|
|
||||||
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
chunkjson = json.loads(response.content.decode())
|
||||||
|
gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
|
||||||
|
return gpt_replying_buffer
|
||||||
|
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
result = ''
|
result = ''
|
||||||
json_data = None
|
json_data = None
|
||||||
@@ -208,7 +219,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
|
chatbot 为WebUI中显示的对话列表,修改它,然后yield出去,可以直接修改对话界面内容
|
||||||
additional_fn代表点击的哪个按钮,按钮见functional.py
|
additional_fn代表点击的哪个按钮,按钮见functional.py
|
||||||
"""
|
"""
|
||||||
from .bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
if is_any_api_key(inputs):
|
if is_any_api_key(inputs):
|
||||||
chatbot._cookies['api_key'] = inputs
|
chatbot._cookies['api_key'] = inputs
|
||||||
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
chatbot.append(("输入已识别为openai的api_key", what_keys(inputs)))
|
||||||
@@ -237,6 +248,10 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chatbot.append((_inputs, ""))
|
chatbot.append((_inputs, ""))
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="等待响应") # 刷新界面
|
||||||
|
|
||||||
|
# 禁用stream的特殊模型处理
|
||||||
|
if model_info[llm_kwargs['llm_model']].get('openai_disable_stream', False): stream = False
|
||||||
|
else: stream = True
|
||||||
|
|
||||||
# check mis-behavior
|
# check mis-behavior
|
||||||
if is_the_upload_folder(user_input):
|
if is_the_upload_folder(user_input):
|
||||||
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
chatbot[-1] = (inputs, f"[Local Message] 检测到操作错误!当您上传文档之后,需点击“**函数插件区**”按钮进行处理,请勿点击“提交”按钮或者“基础功能区”按钮。")
|
||||||
@@ -270,7 +285,7 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
try:
|
try:
|
||||||
# make a POST request to the API endpoint, stream=True
|
# make a POST request to the API endpoint, stream=True
|
||||||
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
response = requests.post(endpoint, headers=headers, proxies=proxies,
|
||||||
json=payload, stream=True, timeout=TIMEOUT_SECONDS);break
|
json=payload, stream=stream, timeout=TIMEOUT_SECONDS);break
|
||||||
except:
|
except:
|
||||||
retry += 1
|
retry += 1
|
||||||
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
chatbot[-1] = ((chatbot[-1][0], timeout_bot_msg))
|
||||||
@@ -278,10 +293,15 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="请求超时"+retry_msg) # 刷新界面
|
||||||
if retry > MAX_RETRY: raise TimeoutError
|
if retry > MAX_RETRY: raise TimeoutError
|
||||||
|
|
||||||
gpt_replying_buffer = ""
|
|
||||||
|
|
||||||
is_head_of_the_stream = True
|
if not stream:
|
||||||
|
# 该分支仅适用于不支持stream的o1模型,其他情形一律不适用
|
||||||
|
yield from handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history)
|
||||||
|
return
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
|
gpt_replying_buffer = ""
|
||||||
|
is_head_of_the_stream = True
|
||||||
stream_response = response.iter_lines()
|
stream_response = response.iter_lines()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -343,12 +363,24 @@ def predict(inputs:str, llm_kwargs:dict, plugin_kwargs:dict, chatbot:ChatBotWith
|
|||||||
chunk_decoded = chunk.decode()
|
chunk_decoded = chunk.decode()
|
||||||
error_msg = chunk_decoded
|
error_msg = chunk_decoded
|
||||||
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
chatbot, history = handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg)
|
||||||
yield from update_ui(chatbot=chatbot, history=history, msg="Json异常" + error_msg) # 刷新界面
|
yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + error_msg) # 刷新界面
|
||||||
print(error_msg)
|
print(error_msg)
|
||||||
return
|
return
|
||||||
|
return # return from stream-branch
|
||||||
|
|
||||||
|
def handle_o1_model_special(response, inputs, llm_kwargs, chatbot, history):
    """
    Show the reply of a request that was sent with streaming disabled.

    The whole response body is parsed as a single JSON document and the text
    found at ``choices[0]["message"]["content"]`` is taken as the answer. The
    answer is logged with ``log_chat``, written into ``history[-1]``, the last
    chatbot row is rebuilt as ``(history[-2], history[-1])``, and the UI is
    refreshed.

    If the body cannot be decoded, is not valid JSON, or does not have the
    expected structure, nothing is modified and the UI is refreshed once with
    the raw response text appended to the status message.

    Args:
        response: response object of the completed request; ``response.content``
            is parsed and ``response.text`` is echoed on failure.
        inputs: the user input; only forwarded to ``log_chat``.
        llm_kwargs: must contain ``"llm_model"``; forwarded to ``log_chat``.
        chatbot: conversation rows shown in the UI; the last row is overwritten.
        history: flat conversation history; the last entry is overwritten.
            NOTE(review): assumes the caller has already appended the current
            input and a placeholder for the output (at least two entries) --
            confirm against the caller.

    Yields:
        Whatever ``update_ui`` yields.
    """
    try:
        # Only the decode / parse / extract step is guarded, so failures in
        # logging or UI refresh are not misreported as a JSON parsing problem.
        chunkjson = json.loads(response.content.decode())
        gpt_replying_buffer = chunkjson['choices'][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError;
        # the other three cover a payload without the expected structure
        # (for example an error object that has no "choices" list).
        yield from update_ui(chatbot=chatbot, history=history, msg="Json解析异常" + response.text)  # refresh the UI
    else:
        log_chat(llm_model=llm_kwargs["llm_model"], input_str=inputs, output_str=gpt_replying_buffer)
        history[-1] = gpt_replying_buffer
        chatbot[-1] = (history[-2], history[-1])
        yield from update_ui(chatbot=chatbot, history=history)  # refresh the UI
|
||||||
|
|
||||||
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
def handle_error(inputs, llm_kwargs, chatbot, history, chunk_decoded, error_msg):
|
||||||
from .bridge_all import model_info
|
from request_llms.bridge_all import model_info
|
||||||
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
openai_website = ' 请登录OpenAI查看详情 https://platform.openai.com/signup'
|
||||||
if "reduce the length" in error_msg:
|
if "reduce the length" in error_msg:
|
||||||
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
if len(history) >= 2: history[-1] = ""; history[-2] = "" # 清除当前溢出的输入:history[-2] 是本次输入, history[-1] 是本次输出
|
||||||
@@ -381,6 +413,8 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
"""
|
"""
|
||||||
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
整合所有信息,选择LLM模型,生成http请求,为发送请求做准备
|
||||||
"""
|
"""
|
||||||
|
from request_llms.bridge_all import model_info
|
||||||
|
|
||||||
if not is_any_api_key(llm_kwargs['api_key']):
|
if not is_any_api_key(llm_kwargs['api_key']):
|
||||||
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
raise AssertionError("你提供了错误的API_KEY。\n\n1. 临时解决方案:直接在输入区键入api_key,然后回车提交。\n\n2. 长效解决方案:在config.py中配置。")
|
||||||
|
|
||||||
@@ -409,10 +443,16 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
else:
|
else:
|
||||||
enable_multimodal_capacity = False
|
enable_multimodal_capacity = False
|
||||||
|
|
||||||
|
conversation_cnt = len(history) // 2
|
||||||
|
openai_disable_system_prompt = model_info[llm_kwargs['llm_model']].get('openai_disable_system_prompt', False)
|
||||||
|
|
||||||
|
if openai_disable_system_prompt:
|
||||||
|
messages = []
|
||||||
|
else:
|
||||||
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
if not enable_multimodal_capacity:
|
if not enable_multimodal_capacity:
|
||||||
# 不使用多模态能力
|
# 不使用多模态能力
|
||||||
conversation_cnt = len(history) // 2
|
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
|
||||||
if conversation_cnt:
|
if conversation_cnt:
|
||||||
for index in range(0, 2*conversation_cnt, 2):
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
what_i_have_asked = {}
|
what_i_have_asked = {}
|
||||||
@@ -434,8 +474,6 @@ def generate_payload(inputs:str, llm_kwargs:dict, history:list, system_prompt:st
|
|||||||
messages.append(what_i_ask_now)
|
messages.append(what_i_ask_now)
|
||||||
else:
|
else:
|
||||||
# 多模态能力
|
# 多模态能力
|
||||||
conversation_cnt = len(history) // 2
|
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
|
||||||
if conversation_cnt:
|
if conversation_cnt:
|
||||||
for index in range(0, 2*conversation_cnt, 2):
|
for index in range(0, 2*conversation_cnt, 2):
|
||||||
what_i_have_asked = {}
|
what_i_have_asked = {}
|
||||||
|
|||||||
在新工单中引用
屏蔽一个用户