from typing import Any, List, Optional

import streamlit as st  # was missing: `st` is used throughout main()
import torch  # was missing: needed for torch.bfloat16 / torch.float16 below
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
model_dir = snapshot_download('AI-ModelScope/bge-small-en-v1.5', cache_dir='./')
from modelscope import snapshot_download
model_dir = snapshot_download('IEITYuan/Yuan2-2B-Mars-hf', cache_dir='./')
model_path = './IEITYuan/Yuan2-2B-Mars-hf'
embedding_model_path = './AI-ModelScope/bge-small-en-v1___5'
torch_dtype = torch.bfloat16 # A10
# torch_dtype = torch.float16 # P100
tokenizer: AutoTokenizer = None
model: AutoModelForCausalLM = None
def __init__(self, mode_path :str):
print("Creat tokenizer...")
self.tokenizer = AutoTokenizer.from_pretrained(mode_path, add_eos_token=False, add_bos_token=False, eos_token='<eod>')
self.tokenizer.add_tokens(['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>','<commit_before>','<commit_msg>','<commit_after>','<jupyter_start>','<jupyter_text>','<jupyter_code>','<jupyter_output>','<empty_output>'], special_tokens=True)
self.model = AutoModelForCausalLM.from_pretrained(mode_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
outputs = self.model.generate(inputs,do_sample=False,max_length=4096)
output = self.tokenizer.decode(outputs[0])
response = output.split("<sep>")[-1].split("<eod>")[0]
def _llm_type(self) -> str:
# 定义一个函数,用于获取llm和embeddings
llm = Yuan2_LLM(model_path)
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
embeddings = HuggingFaceEmbeddings(
model_name=embedding_model_path,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
summarizer_template = """
假设你是一个AI科研助手,请用一段话概括下面文章的主要内容,200字左右。
def __init__(self, llm):
self.prompt = PromptTemplate(
input_variables=["text"],
template=summarizer_template
self.chain = LLMChain(llm=self.llm, prompt=self.prompt)
def summarize(self, docs):
content = docs[0].page_content.split('ABSTRACT')[1].split('KEY WORDS')[0]
summary = self.chain.run(content)
假设你是一个AI科研助手,请基于背景,简要回答问题。
def __init__(self, llm, embeddings):
self.prompt = PromptTemplate(
input_variables=["text"],
template=chatbot_template
self.chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=self.prompt)
self.embeddings = embeddings
self.text_splitter = RecursiveCharacterTextSplitter(
def run(self, docs, query):
text = ''.join([doc.page_content for doc in docs])
all_chunks = self.text_splitter.split_text(text=text)
VectorStore = FAISS.from_texts(all_chunks, embedding=self.embeddings)
chunks = VectorStore.similarity_search(query=query, k=1)
response = self.chain.run(input_documents=chunks, question=query)
st.title('💬 Yuan2.0 AI科研助手')
llm, embeddings = get_models()
summarizer = Summarizer(llm)
chatbot = ChatBot(llm, embeddings)
uploaded_file = st.file_uploader("Upload your PDF", type='pdf')
file_content = uploaded_file.read()
temp_file_path = "temp.pdf"
with open(temp_file_path, "wb") as temp_file:
temp_file.write(file_content)
loader = PyPDFLoader(temp_file_path)
st.chat_message("assistant").write(f"正在生成论文概括,请稍候...")
summary = summarizer.summarize(docs)
st.chat_message("assistant").write(summary)
if query := st.text_input("Ask questions about your PDF file"):
chunks, response = chatbot.run(docs, query)
st.chat_message("assistant").write(f"正在检索相关信息,请稍候...")
st.chat_message("assistant").write(chunks)
st.chat_message("assistant").write(f"正在生成回复,请稍候...")
st.chat_message("assistant").write(response)
if __name__ == '__main__':