Support Conversational History in RAG Pipelines with Llama 3

Photo by Vincent Yuan @USA / Unsplash

In Retrieval-Augmented Generation (RAG) pipelines, it's crucial to help chatbots recall previous conversations, as users may ask follow-up questions that rely on earlier context. However, users' prompts often lack sufficient context on their own, because they assume the earlier discussion is still in view. To tackle this challenge, incorporating chat history into the LLM's question-answering context enables it to retrieve the information relevant to new queries.

This post presents a solution leveraging LangChain, Llama 3-8B, and Ollama, which can efficiently run on an M2 Pro MacBook Pro with 16 GB memory.

1 Dependencies

1.1 Ollama and Llama 3 Model

First, install Ollama on the MacBook. Ollama can utilize the machine's GPU for efficient inference, provided there is sufficient memory; Llama 3-8B performs well on machines with 16 GB of memory.

💡
Ollama can be downloaded here: https://ollama.com/

Once it is installed, you can use the command below in the terminal to pull the Llama 3-8B model:

ollama pull llama3

1.2 Python Dependencies

Now, let's import the required packages to construct a RAG system with chat history, utilizing the LangChain toolkit.

# Callbacks for streaming model output
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Vector store
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

# LangChain supports many other chat models. Here, we're using Ollama
from langchain_community.chat_models import ChatOllama

# RAG with Memory 
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory

# Display results
from IPython.display import display, Markdown

2 Create Vector Store

The source data consists of a summary of important events and statistics from the week of May 13th, 2024, as published by Yahoo Finance. This data is not included in the training set of Llama 3. For demonstration purposes, the news is extracted to a text file and utilized in the code to create the Chroma vector store and retriever.

source_data_path = '../data/yahoo.txt'

# Callback for token-wise streaming, so answers can be printed token by token
# as Llama generates them (attach it to the LLM via `callbacks` if desired)
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

loader = TextLoader(source_data_path)

documents = loader.load()

# split the documents into 1,000-character chunks with a 200-character overlap
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
texts = text_splitter.split_documents(documents)

embedding = HuggingFaceEmbeddings()

vectordb = Chroma.from_documents(documents=texts,
                                 embedding=embedding
                                 # persist_directory=persist_directory
                                )
                                
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
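
As an optional sanity check (not part of the original pipeline), the retriever can be queried directly to confirm that relevant chunks come back before it is wired into the chain; the query string below is only an illustrative example.

# Optional: query the retriever directly to verify that relevant chunks are returned
sample_docs = retriever.get_relevant_documents(
    "What does Wall Street expect for the April CPI?"
)
for i, doc in enumerate(sample_docs, start=1):
    print(f"--- Chunk {i} ---")
    print(doc.page_content[:200])  # first 200 characters of each retrieved chunk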

3 Create the LLM Object

Make sure Ollama is running and the Llama 3 model has been downloaded; then the code below can be used to define an LLM object for the pipeline:

llm = ChatOllama(model="llama3",
                temperature=0.1)
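
As a quick, optional smoke test (not in the original post), the LLM object can be invoked directly to confirm that Ollama is serving the model; invoke returns a chat message whose content field holds the generated text.

# Optional: confirm that Ollama is running and the model responds
test_reply = llm.invoke("In one sentence, what is the Consumer Price Index?")
print(test_reply.content)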

4 RAG with Memory

In essence, there needs to be a place to store the chat history, the chat history has to be added to the RAG prompt so that the LLM can access the past conversation, and the history has to be updated after each round of conversation. Below is a way to use BaseChatMessageHistory to address these needs:

### Contextualize question ###
contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""
contextualize_q_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", contextualize_q_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
history_aware_retriever = create_history_aware_retriever(
    llm, retriever, contextualize_q_prompt
)


### Answer question ###
qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise.\

{context}"""
qa_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", qa_system_prompt),
        MessagesPlaceholder("chat_history"),
        ("human", "{input}"),
    ]
)
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)

Then wrap the RAG chain with session-based chat history management:

### Statefully manage chat history ###
store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]


conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

Now let's test whether the model understands the Yahoo Finance analysis. The first question is: What is the Wall Street expectation of the April Consumer Price Index (CPI)?

llm_response = conversational_rag_chain.invoke(
    {"input": "What is the wall street expectation of the April Consumer Price Index (CPI)?"},
    config={
        "configurable": {"session_id": "abc123"}
    },  # constructs a key "abc123" in `store`.
)["answer"]

print('='*50)
display(Markdown(llm_response))

The response is:

According to the text, Wall Street expects an annual gain of 3.4% for headline CPI, which includes the price of food and energy, a decrease from the 3.5% headline number in March. Additionally, prices are expected to rise 0.4% on a month-over-month basis, in line with March's rise.

This is aligned with the source:

Yahoo Finance Analysis

Then, a follow-up question that relies on the previous answer is asked: calculate the double of the expected CPI.

llm_response = conversational_rag_chain.invoke(
    {"input": "What is the double of the expected CPI in the prior answer?"},
    config={
        "configurable": {"session_id": "abc123"}
    },  # constructs a key "abc123" in `store`.
)["answer"]

print('='*50)
display(Markdown(llm_response))

And this is the output:

The expected annual gain for headline CPI is 3.4%. The double of this value would be:

2 x 3.4% = 6.8%

So, the double of the expected CPI is 6.8%.

So the model successfully picks up the information it returned earlier and answers the new question correctly.
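
To double-check that the history is indeed updated after each round, the in-memory store can be inspected directly. This snippet is an illustrative add-on, assuming the session id "abc123" used in the calls above.

# Optional: inspect the stored chat history for the session used above
for msg in store["abc123"].messages:
    print(f"{type(msg).__name__}: {msg.content[:100]}")

The output should alternate between HumanMessage and AIMessage entries, one pair per round of conversation.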

5 Summary

This enhanced solution extends the capabilities of a regular RAG pipeline by supporting chat history, making it highly beneficial for multi-turn conversations. With Ollama, experiments like this can be run on an affordable laptop with an integrated GPU. A special acknowledgment to Meta for their great work on Llama 3.