适配星火大模型图片理解 增加上传图片view

这个提交包含在:
spike
2023-11-15 10:09:42 +08:00
父节点 371158cb56
当前提交 aa341fd268
共有 3 个文件被更改,包括 114 次插入29 次删除

查看文件

@@ -1,4 +1,4 @@
from toolbox import get_conf
from toolbox import get_conf, get_pictures_list, encode_image
import base64
import datetime
import hashlib
@@ -65,6 +65,7 @@ class SparkRequestInstance():
self.gpt_url = "ws://spark-api.xf-yun.com/v1.1/chat"
self.gpt_url_v2 = "ws://spark-api.xf-yun.com/v2.1/chat"
self.gpt_url_v3 = "ws://spark-api.xf-yun.com/v3.1/chat"
self.gpt_url_img = "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image"
self.time_to_yield_event = threading.Event()
self.time_to_exit_event = threading.Event()
@@ -92,7 +93,11 @@ class SparkRequestInstance():
gpt_url = self.gpt_url_v3
else:
gpt_url = self.gpt_url
file_manifest = []
if llm_kwargs.get('most_recent_uploaded'):
if llm_kwargs['most_recent_uploaded'].get('path'):
file_manifest = get_pictures_list(llm_kwargs['most_recent_uploaded']['path'])
gpt_url = self.gpt_url_img
wsParam = Ws_Param(self.appid, self.api_key, self.api_secret, gpt_url)
websocket.enableTrace(False)
wsUrl = wsParam.create_url()
@@ -101,9 +106,8 @@ class SparkRequestInstance():
def on_open(ws):
import _thread as thread
thread.start_new_thread(run, (ws,))
def run(ws, *args):
data = json.dumps(gen_params(ws.appid, *ws.all_args))
data = json.dumps(gen_params(ws.appid, *ws.all_args, file_manifest))
ws.send(data)
# 收到websocket消息的处理
@@ -142,9 +146,18 @@ class SparkRequestInstance():
ws.all_args = (inputs, llm_kwargs, history, system_prompt)
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
def generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest):
conversation_cnt = len(history) // 2
messages = [{"role": "system", "content": system_prompt}]
messages = []
if file_manifest:
base64_images = []
for image_path in file_manifest:
base64_images.append(encode_image(image_path))
for img_s in base64_images:
if img_s not in str(messages):
messages.append({"role": "user", "content": img_s, "content_type": "image"})
else:
messages = [{"role": "system", "content": system_prompt}]
if conversation_cnt:
for index in range(0, 2*conversation_cnt, 2):
what_i_have_asked = {}
@@ -167,7 +180,7 @@ def generate_message_payload(inputs, llm_kwargs, history, system_prompt):
return messages
def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
def gen_params(appid, inputs, llm_kwargs, history, system_prompt, file_manifest):
"""
通过appid和用户的提问来生成请参数
"""
@@ -176,6 +189,8 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
"sparkv2": "generalv2",
"sparkv3": "generalv3",
}
domains_select = domains[llm_kwargs['llm_model']]
if file_manifest: domains_select = 'image'
data = {
"header": {
"app_id": appid,
@@ -183,7 +198,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
},
"parameter": {
"chat": {
"domain": domains[llm_kwargs['llm_model']],
"domain": domains_select,
"temperature": llm_kwargs["temperature"],
"random_threshold": 0.5,
"max_tokens": 4096,
@@ -192,7 +207,7 @@ def gen_params(appid, inputs, llm_kwargs, history, system_prompt):
},
"payload": {
"message": {
"text": generate_message_payload(inputs, llm_kwargs, history, system_prompt)
"text": generate_message_payload(inputs, llm_kwargs, history, system_prompt, file_manifest)
}
}
}