begin rag project with llama index

这个提交包含在:
binary-husky
2024-08-21 14:24:37 +00:00
父节点 16f4fd636e
当前提交 294716c832
共有 7 个文件被更改,包括 277 次插入1 次删除

查看文件

@@ -0,0 +1,60 @@
from llama_index.embeddings.openai import OpenAIEmbedding
from openai import OpenAI
from toolbox import get_conf
from toolbox import CatchException, update_ui, get_conf, select_api_key, get_log_folder, ProxyNetworkActivate
from shared_utils.key_pattern_manager import select_api_key_for_embed_models
from typing import List, Any
class OpenAiEmbeddingModel():
def __init__(self, llm_kwargs:dict=None):
self.llm_kwargs = llm_kwargs
def compute_embedding(self, text="这是要计算嵌入的文本", llm_kwargs:dict=None, batch_mode=False):
from .bridge_all_embed import embed_model_info
# load kwargs
if llm_kwargs is None:
llm_kwargs = self.llm_kwargs
if llm_kwargs is None:
raise RuntimeError("llm_kwargs is not provided!")
# setup api and req url
api_key = select_api_key_for_embed_models(llm_kwargs['api_key'], llm_kwargs['llm_model'])
embed_model = llm_kwargs['llm_model']
base_url = embed_model_info[llm_kwargs['llm_model']]['embed_endpoint'].replace('embeddings', '')
# send and compute
with ProxyNetworkActivate("Connect_OpenAI_Embedding"):
self.oai_client = OpenAI(api_key=api_key, base_url=base_url)
if batch_mode:
input = text
assert isinstance(text, list)
else:
input = [text]
assert isinstance(text, str)
res = self.oai_client.embeddings.create(input=input, model=embed_model)
# parse result
if batch_mode:
embedding = [d.embedding for d in res.data]
else:
embedding = res.data[0].embedding
return embedding
def embedding_dimension(self, llm_kwargs):
from .bridge_all_embed import embed_model_info
return embed_model_info[llm_kwargs['llm_model']]['embed_dimension']
def get_text_embedding_batch(
self,
texts: List[str],
show_progress: bool = False,
):
return self.compute_embedding(texts, batch_mode=True)
if __name__ == "__main__":
pass