接入TGUI

这个提交包含在:
Your Name
2023-04-02 00:40:05 +08:00
父节点 3af0bbdbe4
当前提交 2420d62a33
共有 3 个文件被更改,包括 17 次插入17 次删除

查看文件

@@ -15,7 +15,10 @@ import importlib
from toolbox import get_conf
LLM_MODEL, = get_conf('LLM_MODEL')
model_name, addr, port = LLM_MODEL.split('@')
# "TGUI:galactica-1.3b@localhost:7860"
model_name, addr_port = LLM_MODEL.split('@')
assert ':' in addr_port, "LLM_MODEL 格式不正确!" + LLM_MODEL
addr, port = addr_port.split(':')
def random_hash():
letters = string.ascii_lowercase + string.digits
@@ -117,11 +120,11 @@ def predict_tgui(inputs, top_p, temperature, chatbot=[], history=[], system_prom
def run_coorotine(mutable):
async def get_result(mutable):
async for response in run(prompt):
# Print intermediate steps
print(response[len(mutable[0]):])
mutable[0] = response
asyncio.run(get_result(mutable))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,), daemon=True)
thread_listen.start()
while thread_listen.is_alive():
@@ -145,7 +148,7 @@ def predict_tgui_no_ui(inputs, top_p, temperature, history=[], sys_prompt=""):
def run_coorotine(mutable):
async def get_result(mutable):
async for response in run(prompt):
# Print intermediate steps
print(response[len(mutable[0]):])
mutable[0] = response
asyncio.run(get_result(mutable))
thread_listen = threading.Thread(target=run_coorotine, args=(mutable,))