Mirrored from
https://github.com/binary-husky/gpt_academic.git
Synced 2025-12-07 15:06:48 +00:00
@@ -28,6 +28,9 @@ from .bridge_chatglm3 import predict as chatglm3_ui
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
from .bridge_qianfan import predict as qianfan_ui

from .bridge_google_gemini import predict as genai_ui
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui

colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']

class LazyloadTiktoken(object):
@@ -246,6 +249,22 @@ model_info = {
        "tokenizer": tokenizer_gpt35,
        "token_cnt": get_token_num_gpt35,
    },
    "gemini-pro": {
        "fn_with_ui": genai_ui,
        "fn_without_ui": genai_noui,
        "endpoint": None,
        "max_token": 1024 * 32,
        "tokenizer": tokenizer_gpt35,
        "token_cnt": get_token_num_gpt35,
    },
    "gemini-pro-vision": {
        "fn_with_ui": genai_ui,
        "fn_without_ui": genai_noui,
        "endpoint": None,
        "max_token": 1024 * 32,
        "tokenizer": tokenizer_gpt35,
        "token_cnt": get_token_num_gpt35,
    },
}

# -=-=-=-=-=-=- api2d alignment support -=-=-=-=-=-=-
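The two registry entries above are how the rest of the app reaches these bridges. A minimal dispatch sketch, assuming only the model_info registry from bridge_all.py; the helper name and arguments below are illustrative, not part of the repo:

```python
# Minimal sketch: dispatching through model_info (ask_gemini_no_ui is hypothetical).
from request_llms.bridge_all import model_info

def ask_gemini_no_ui(prompt, llm_kwargs):
    entry = model_info[llm_kwargs['llm_model']]   # e.g. 'gemini-pro'
    fn = entry['fn_without_ui']                   # -> predict_no_ui_long_connection
    # observe_window[0] collects the streamed reply as it grows
    return fn(prompt, llm_kwargs, history=[], sys_prompt="", observe_window=[""])

# reply = ask_gemini_no_ui("Hello", {'llm_model': 'gemini-pro'})
```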
request_llms/bridge_google_gemini.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# encoding: utf-8
# @Time   : 2023/12/21
# @Author : Spike
# @Descr  :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                  ' Check that the proxy server is reachable and that the proxy is configured as [protocol]://[address]:[port]; every part is required.'
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY.")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience; 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*?)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError:
                raise ValueError("Failed to parse the GEMINI message.")
            gpt_replying_buffer += paraphrase['text']
            if observe_window is not None:  # guard: the default is None
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("Program terminated.")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} conversation error')
    return gpt_replying_buffer
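Both functions scrape "text" fields out of the raw stream with a regex rather than parsing the full JSON. For reference, a streamGenerateContent chunk looks roughly like the sketch below (abridged and illustrative; the code only relies on the "text" and "message" keys):

```python
# Illustrative only: abridged success and error chunks from the stream.
CHUNK_OK = '{"candidates": [{"content": {"parts": [{"text": "Once upon a time..."}], "role": "model"}}]}'
CHUNK_ERR = '{"error": {"code": 400, "message": "API key not valid.", "status": "INVALID_ARGUMENT"}}'
```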
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("Please configure GEMINI_API_KEY.", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f", retrying ({retry}/{MAX_RETRY}) ..." if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="Request timed out" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY:
                raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # this decode tripped us up once...
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError:
                raise ValueError("Failed to parse the GEMINI message.")
            gpt_replying_buffer += paraphrase['text']  # parsed with the json library
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # don't keep failed turns in the conversation
            chatbot[-1] = (inputs, gpt_replying_buffer + f"Conversation error, see the message below\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('Conversation error')
    if not gpt_replying_buffer:
        history = history[:-2]  # don't keep failed turns in the conversation
        chatbot[-1] = (inputs, gpt_replying_buffer + f"Google's safety policy was triggered; no answer was returned\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)
if __name__ == '__main__':
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
request_llms/com_google.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# encoding: utf-8
# @Time   : 2023/12/25
# @Author : Spike
# @Descr  :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image

proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

"""
========================================================================
Part 5: some file-handling helpers
    files_filter_handler        filter files by type
    input_encode_handler        extract files from the input and parse them
    file_manifest_filter_html   filter files by type and render them as html or md text
    link_mtime_to_md            append the local mtime to a file link to avoid downloading a cached copy
    html_view_blank             hyperlink
    html_local_file             relative path for local files
    to_markdown_tabs            convert file lists into a markdown table
"""
def files_filter_handler(file_list):
    new_list = []
    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        file = str(file).replace('file=', '')
        if os.path.exists(file):
            if str(os.path.basename(file)).split('.')[-1] in filter_:
                new_list.append(file)
    return new_list
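A quick sketch of the contract (paths are hypothetical): the 'file=' prefix that Gradio attaches is stripped, and only image files that actually exist on disk survive the filter.

```python
# Hypothetical paths; only existing files with an image extension are kept.
kept = files_filter_handler(['file=./docs/cat.png', 'file=./notes.txt'])
# kept -> ['./docs/cat.png']  (assuming ./docs/cat.png exists)
```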
def input_encode_handler(inputs):
    md_encode = []
    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
    matches_path = re.findall(pattern_md_file, inputs)
    for md_path in matches_path:
        pattern_file = r"\((file=.*)\)"
        matches_file = re.findall(pattern_file, md_path)  # renamed to avoid shadowing matches_path
        encode_file = files_filter_handler(file_list=matches_file)
        if encode_file:
            md_encode.extend([{
                "data": encode_image(i),
                "type": os.path.splitext(i)[1].replace('.', '')
            } for i in encode_file])
            inputs = inputs.replace(md_path, '')
    return inputs, md_encode
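And what input_encode_handler does with a vision prompt, roughly (hypothetical path; encode_image comes from toolbox):

```python
# Hypothetical example: the markdown link is removed from the text and the
# image comes back base64-encoded, ready for Gemini's inline_data parts.
text, images = input_encode_handler("Describe this ![cat](file=./docs/cat.png)")
# text   -> "Describe this "
# images -> [{"data": "<base64...>", "type": "png"}]  (if ./docs/cat.png exists)
```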
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    new_list = []
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        if str(os.path.basename(file)).split('.')[-1] in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list
def link_mtime_to_md(file):
    link_local = html_local_file(file)
    link_name = os.path.basename(file)
    a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
    return a
def html_local_file(file):
    base_path = os.path.dirname(__file__)  # project directory
    if os.path.exists(str(file)):
        file = f'file={file.replace(base_path, ".")}'
    return file
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        a = f'![{__file}]({__file})'  # markdown image link
    return a
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """
    Args:
        head: the table header: []
        tabs: the table values: [[col1], [col2], [col3], [col4]]
        alignment: :--- left-aligned, :---: centered, ---: right-aligned
        column: True to keep data in columns, False to keep data in rows (default).
    Returns:
        A string representation of the markdown table.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Find the maximum length among the columns
    max_len = max(len(column) for column in transposed_tabs)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'

    return tabs_list
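A short usage sketch of to_markdown_tabs (values made up); note that with the default column=False each inner list of tabs is one column:

```python
table = to_markdown_tabs(head=['name', 'size'],
                         tabs=[['a.txt', 'b.txt'], ['1 KB', '2 KB']])
# | name | size |
# | :---: | :---: |
# | a.txt | 1 KB |
# | b.txt | 2 KB |
```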
class GoogleChatInit:

    def __init__(self):
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked
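For a vision model, the user turn built above carries both a text part and inline image data; the shape comes out roughly like this (illustrative values):

```python
# Illustrative shape of a vision user turn (base64 data is made up).
VISION_TURN = {
    "role": "user",
    "parts": [
        {"text": "Describe this "},
        {"inline_data": {"mime_type": "image/png", "data": "iVBORw0KGgo..."}}
    ]
}
```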
    def __conversation_history(self, history):
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages
    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini rejects an even number of turns, so this is unused; revisit once supported...
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        # Substitute the model and key placeholders; this mutates url_gemini, so use one instance per request
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models carry history
            messages.extend(self.__conversation_history(history))  # process history
        messages.append(self.__conversation_user(inputs))  # process the user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload
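For orientation, the header/payload pair built for a single-turn 'gemini-pro' chat comes out roughly like this (abridged, illustrative values; the key goes into the URL, not the payload):

```python
# Illustrative only: what generate_message_payload returns for one text turn.
EXAMPLE_HEADER = {'Content-Type': 'application/json'}
EXAMPLE_PAYLOAD = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hello"}]}
    ],
    "generationConfig": {
        "stopSequences": [''],   # str('').split(' ') when no stop words are set
        "temperature": 1,
        "topP": 0.8,
        "topK": 10
    }
}
# POSTed to https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=<GEMINI_API_KEY>
```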
if __name__ == '__main__':
    google = GoogleChatInit()
    # print(google.generate_message_payload('Hello there', {'llm_model': 'gemini-pro'},
    #                                       ['123123', '3123123'], ''))
    # input_encode_handler('123123[123123](./123123), ')