Conversational AI: Chat with your documents using Llama2, AWS SageMaker, LangChain and Streamlit (Part 1)
An interactive web application powered by a large language model that lets you chat with your documents
Opinions are my own and not the views of my current or former employer(s)
Large language models (LLMs) have gained a lot of attention in the last few months ("thanks" to ChatGPT) for their ability to converse like a human. LLMs have been around for a few years, but they have only recently become powerful enough to generate realistic and engaging conversation.
In this blog, I will walk through an LLM application that enables us to converse with our documents using the Retrieval-Augmented Generation (RAG) architecture.
High-level process flow
Before we delve into the application, it is important to understand the high-level process flow of building an RAG application that enables interaction with an LLM.
It consists of two major steps:
1. Create vector embeddings of the source document(s) and store them in a vector store.
- Vector embeddings are numerical representations of words or sentences used in Natural Language Processing (NLP) to analyze and manipulate text data. They are lists of numbers that can represent many types of data, such as audio, video, text, and images. The process of converting words into numbers is called vectorization.
- A vector store is a specialized database that stores and manages unstructured data in the form of collections of vectors. Vector stores are also known as vector databases or vector indexes.
2. Retrieve the documents that are most similar to a user's query and pass them to an LLM, which responds to the query using those documents as context along with the knowledge it already has.
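To make these two steps concrete, here is a minimal sketch using the same building blocks this application relies on (LangChain, SentenceTransformers and FAISS); the example strings are placeholders, and the real implementation later in this post works on full HTML documents instead.
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS

embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L12-v2")

# Step 1: create vector embeddings of the source texts and store them in a vector store
texts = ["Amazon SageMaker is a managed machine learning service.",
         "FAISS enables fast similarity search over vectors."]
db = FAISS.from_texts(texts, embeddings)

# Step 2: retrieve the texts most similar to a user's query;
# these would then be passed to the LLM as context for its answer
print(db.similarity_search("What is SageMaker?", k=1))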
Now that we have an idea of the process flow, let's build the application.
Steps to build the application
- Set up a conda environment
- Data acquisition (documents)
- Create & store embeddings
- Set up an API endpoint for interacting with the LLM
- Create an application for handling user interaction and LLM responses
1. Set up a conda environment
Setting up a conda environment provides an isolated environment for developing and running the application while ensuring that all package dependencies are satisfied.
conda env create -f environment.yml # conda 22.9.0, python 3.10
conda activate docai
pip install -r requirements.txt
An environment.yml file is a YAML file that defines a Conda environment. It is used to create a new environment or to export an existing environment. The file contains a list of packages and their versions that are required to create the environment.
To change the name of the conda environment that gets created, change the name in the first line of the environment.yml file.
Let's also create an ipykernel that will allow us to use the conda environment from a Jupyter notebook.
python -m ipykernel install --user --name=conda_docai
I have used an AWS SageMaker notebook instance for this project.
2. Data Acquisition
Data acquisition is the process of gathering the data needed for training an ML model, populating a data warehouse, building a data application, etc.
For this project, we will use the AWS SageMaker Developer Guide as the set of documents that a user will converse with through the web application and the LLM.
Let's gather the documents listed in the sitemap.xml by running the following wget command:
wget --quiet https://docs.aws.amazon.com/sagemaker/latest/dg/sitemap.xml --output-document - | egrep -o "https://[^<]+" | wget --directory-prefix=./aws_docs/sagemaker/ -i -
3. Create and Store Embeddings
Now that our documents are ready, let's create the vector embeddings and store them in a vector store.
import os
import glob
from langchain.document_loaders import BSHTMLLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings, SentenceTransformerEmbeddings
# Step1: Define a sentence transformer model that will be used
# to convert the documents into vector embeddings
embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L12-v2")
# Step2: Create a list of html contents from the documents
html_docs = []
path_to_dir = "./aws_docs/sagemaker/"
html_files = glob.glob(os.path.join(path_to_dir, "*.html"))
for _file in html_files:
    # BSHTMLLoader reads the HTML file and extracts its text content
    loader = BSHTMLLoader(_file)
    data = loader.load()
    html_docs.extend(data)
# Step3: Create & save a vector database with the vector embeddings
# of the documents
db = FAISS.from_documents(html_docs, embeddings)
db.save_local("faiss_index")
Let me explain what is happening in the above code. The script converts the HTML documents into vector embeddings and stores them in a vector database, using a few libraries and models. Let's review each step:
- Importing LangChain libraries: We import the required LangChain libraries.
LangChain is a purpose-built framework for developing applications powered by language models. The framework comprises core elements known as Components and Chains. Components are the fundamental building blocks of a LangChain chain, abstracting interactions with diverse models and frameworks; they include Models, Prompts, Indexes and Memory. Chains are a structured assembly of components for accomplishing a certain task.
- Extracting the HTML documents: Next, we generate a list called html_docs that contains the contents of the documents. The HTML files are read during this process, and the `BSHTMLLoader` document loader from LangChain extracts the text content from each file.
- Generating vector embeddings and building the FAISS index: Finally, we create vector embeddings using a SentenceTransformer model and store them in a FAISS index.
* SentenceTransformers is a Python framework for creating embeddings of text and images using state-of-the-art models. There are many models to choose from, depending on the performance/speed trade-off you need. Details of the various SentenceTransformers models can be found here (toggle the All Models button to see the details of each model). For this application, I have used the all-MiniLM-L12-v2 model based on its size, speed and performance.
* FAISS is a similarity search library developed by Facebook AI that enables efficient and fast searching for documents (represented as vectors) that are similar to each other.
Please note that, as I did not have access to a machine with a GPU, both the sentence-transformer encoding and the index creation were executed on a CPU-only machine. As such, faiss-cpu and the CPU build of torch were installed to facilitate this. To aid in replicating the environment, all the necessary dependencies are documented in the environment.yml and requirements.txt files, which are available in the GitHub repository mentioned below.
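As a quick sanity check (my own addition, not part of the main script), the saved index can be reloaded and queried directly. This assumes the faiss_index folder created above and the same embedding model:
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS

embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L12-v2")
db = FAISS.load_local("faiss_index", embeddings)

# Print the title and source file of the two most similar documents
for doc in db.similarity_search("How do I create a SageMaker notebook instance?", k=2):
    print(doc.metadata.get("title"), "->", doc.metadata.get("source"))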
4. Set up an API Endpoint for interacting with the LLM
With our embeddings prepared, it's time to proceed to the next step and bring our LLM to life. For this application, we will utilize the recently launched Llama2, an open-source LLM from Meta.
I used AWS SageMaker JumpStart to host the Llama2 foundation model on a SageMaker endpoint since I lack a local GPU. With only a few lines of code, I had an endpoint running Llama2-7B optimized for chat applications within a mere 10 minutes.
from sagemaker.jumpstart.model import JumpStartModel
my_model = JumpStartModel(model_id = "meta-textgeneration-llama-2-7b-f")
predictor = my_model.deploy()
print(predictor.endpoint_name) # SageMaker Endpoint Name
Check out the detailed blog¹ by AWS on hosting foundation models using SageMaker JumpStart
Note: You need an AWS account to set up a SageMaker endpoint
After a successful deployment, the endpoint information should be available in the SageMaker web console
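Before wiring the endpoint into LangChain, you can test it directly from the notebook. The sketch below assumes the deployed predictor from the previous snippet and uses the same chat payload format that the ContentHandler in the next section builds:
payload = {
    "inputs": [[{"role": "user", "content": "What is Amazon SageMaker?"}]],
    "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
}
# The Llama2 EULA must be accepted explicitly when invoking the model
response = predictor.predict(payload, custom_attributes="accept_eula=true")
print(response[0]["generation"]["content"])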
5. Create an application for handling user interaction and LLM response
We have reached the final step of the process. The remaining task involves developing a web application that will enable users to interact seamlessly with the LLM.
To accomplish this, we will create two Python modules. The first module will be responsible for managing the communication between the user's query (prompt) and the LLM, handling both the input and the output. The second module will provide the web UI using Streamlit, offering a user-friendly interface for the interaction.
1. Module 1: retrieve_from_llama2.py
#!/usr/bin/env python3
# coding: utf-8
from __future__ import annotations
import json
import os
import sys
from dotenv import load_dotenv
from langchain import PromptTemplate, SagemakerEndpoint
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import SentenceTransformerEmbeddings
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.vectorstores import FAISS
# Get Env Variables
load_dotenv() # load the values for environment variables from the .env file
AWS_REGION=os.environ.get('AWS_REGION')
EMBEDDING_MODEL=os.environ.get('EMBEDDING_MODEL')
LLAMA2_ENDPOINT=os.environ.get('LLAMA2_ENDPOINT')
MAX_HISTORY_LENGTH=os.environ.get('MAX_HISTORY_LENGTH')
def build_chain():
    # Sentence transformer model used to embed the user's query
    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)

    # Load FAISS index
    db = FAISS.load_local("faiss_index", embeddings)

    # Default system prompt for Llama2 on the SageMaker JumpStart endpoint
    system_prompt = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

    # Custom ContentHandler to handle input and output to the SageMaker Endpoint
    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            payload = {
                "inputs": [
                    [
                        {
                            "role": "system",
                            "content": system_prompt,
                        },
                        {"role": "user", "content": prompt},
                    ],
                ],
                "parameters": {"max_new_tokens": 1000, "top_p": 0.9, "temperature": 0.6},
            }
            input_str = json.dumps(payload)
            return input_str.encode("utf-8")

        def transform_output(self, output: bytes) -> str:
            response_json = json.loads(output.read().decode("utf-8"))
            content = response_json[0]["generation"]["content"]
            return content

    # LangChain wrapper for invoking the SageMaker endpoint
    llm = SagemakerEndpoint(
        endpoint_name=LLAMA2_ENDPOINT,
        region_name=AWS_REGION,
        content_handler=ContentHandler(),
        callbacks=[StreamingStdOutCallbackHandler()],
        endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
    )

    # Format the chat history as a string of user/assistant turns
    def get_chat_history(inputs) -> str:
        res = []
        for _i in inputs:
            if _i.get("role") == "user":
                user_content = _i.get("content")
            if _i.get("role") == "assistant":
                assistant_content = _i.get("content")
                res.append(f"user:{user_content}\nassistant:{assistant_content}")
        return "\n".join(res)

    # Prompt used to condense a follow-up question into a standalone question
    condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)

    # LangChain chain for the conversation
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=db.as_retriever(search_kwargs={"k": 2}),
        condense_question_prompt=standalone_question_prompt,
        return_source_documents=True,
        get_chat_history=get_chat_history,
        # verbose=True,
    )
    return qa


def run_chain(chain, prompt: str, history=[]):
    return chain({"question": prompt, "chat_history": history})
In the above code, there are two functions: build_chain and run_chain.
- The build_chain() function sets up the conversational retrieval chain. It performs the following steps:
1. Loads the FAISS index containing the vector embeddings of the documents.
* When loading the index, we use the same SentenceTransformer model that was used to create the embeddings.
2. Creates a system prompt (system_prompt).
* Providing a system prompt helps the LLM avoid hallucinating and keeps its answers safe for the user.
3. Creates a custom ContentHandler class to handle input and output to the SageMaker endpoint. The transform_input() method packages the user prompt and system prompt for the LLM, and the transform_output() method parses the LLM's response.
4. Creates a SagemakerEndpoint chain, the LangChain component that handles the request and response with an AWS SageMaker inference endpoint.
5. Defines get_chat_history inside build_chain(). It processes the chat history, extracting user and assistant messages, and returns them as a formatted string.
6. Creates a ConversationalRetrievalChain, the LangChain chain that ties the above components together: it condenses a follow-up question into a standalone question, retrieves the most similar documents from the FAISS index, and passes them along with the question to the LLM.
7. condense_qa_template and standalone_question_prompt define the template and prompt used for condensing a follow-up question into a standalone question.
- The run_chain() function takes the initialized chain, a user prompt, and an optional conversation history. It invokes the conversational retrieval chain to get a response based on the prompt and history.
An AI prompt is any form of text, question, information, or code that communicates to the AI what response you're looking for. The point of a prompt is to take advantage of natural language processing (NLP), which lets you ask an AI a question using normal words and syntax, as you would with a real person.
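To see how the two functions fit together, here is a small usage sketch (assuming the .env variables described above are set and the faiss_index folder exists):
import retrieve_from_llama2 as llama2

chain = llama2.build_chain()
history = []  # list of prior {"role": ..., "content": ...} messages; empty for the first question

result = llama2.run_chain(chain, "How do I create a SageMaker notebook instance?", history)
print(result["answer"])
for doc in result["source_documents"]:
    print(doc.metadata.get("source"))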
2. Module 2: streamlit_app.py
This module contains the code for the Streamlit application that ultimately enables the interaction between the user and the LLM.
Streamlit is an open-source Python library that makes it easy to create and share beautiful, custom web apps for machine learning and data science.
import os
import streamlit as st
from dotenv import load_dotenv
import retrieve_from_llama2 as llama2
# Get Env Variables
load_dotenv() # load the values for environment variables from the .env file
MAX_HISTORY_LENGTH=os.environ.get('MAX_HISTORY_LENGTH') # how many conversation turns to keep in the chat history
###Set Streamlit Session State Variables:###
st.session_state["llm_app"] = llama2
st.session_state["llm_chain"] = llama2.build_chain()
###Initial UI configuration:###
st.set_page_config(page_title="AIDocChatBot", page_icon="🤖")
def render_app():
    # Reduce font sizes for input text boxes. Reduce button sizes too.
    custom_css = """
        <style>
        .stTextArea textarea {font-size: 13px;}
        div[data-baseweb="select"] > div {font-size: 13px !important;}
        </style>
        <style>
        button {
            height: 30px !important;
            width: 150px !important;
            padding-top: 10px !important;
            padding-bottom: 10px !important;
        }
        </style>
    """
    st.markdown(custom_css, unsafe_allow_html=True)

    # Set config for a cleaner menu, footer & background:
    hide_streamlit_style = """
        <style>
        #MainMenu {visibility: hidden;}
        footer {visibility: hidden;}
        </style>
    """
    st.markdown(hide_streamlit_style, unsafe_allow_html=True)

    st.subheader("Hello 👋 I'm your AI ChatBot 🤖")

    # Containers for the chat history and the user input
    st.container()
    st.container()

    # Set up / initialize session state variables
    if "chat_dialogue" not in st.session_state:
        st.session_state["chat_dialogue"] = []
    if "llm" not in st.session_state:
        st.session_state["llm"] = llama2
        st.session_state["llm_chain"] = llama2.build_chain()

    # Callback for the "Clear History" button
    def clear_history():
        st.session_state["chat_dialogue"] = []

    # Display chat messages from history on app rerun
    for message in st.session_state.chat_dialogue:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Reset the history once it reaches the configured maximum length
    if len(st.session_state.chat_dialogue) == int(MAX_HISTORY_LENGTH):
        st.session_state.chat_dialogue = st.session_state.chat_dialogue[:-1]
        clear_history()

    if prompt := st.chat_input("Type your question here..."):
        # Add user message to chat history
        st.session_state.chat_dialogue.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)
        # Display message from LLM / assistant
        with st.chat_message("assistant"):
            answer_placeholder = st.empty()
            answer = ""
            for dict_message in st.session_state.chat_dialogue:
                if dict_message["role"] == "user":
                    string_dialogue = "User: " + dict_message["content"] + "\n\n"
                else:
                    string_dialogue = "Assistant: " + dict_message["content"] + "\n\n"
            llm_chain = st.session_state["llm_chain"]
            chain = st.session_state["llm_app"]
            try:
                output = chain.run_chain(llm_chain, prompt)
            except Exception:
                output = {}
                output["answer"] = "I'm sorry, I'm unable to respond to your question"
            answer = output.get("answer")
            if "source_documents" in output:
                with st.expander("Sources"):
                    for _sd in output.get("source_documents"):
                        _sd_metadata = _sd.metadata
                        source = _sd_metadata.get("source").replace(
                            "./aws_docs/sagemaker/",
                            "https://docs.aws.amazon.com/sagemaker/latest/dg/",
                        )
                        title = _sd_metadata.get("title")
                        st.write(f"{title} --> {source}")
            answer_placeholder.markdown(answer + "▌")
        # Add assistant response to chat history
        st.session_state.chat_dialogue.append({"role": "assistant", "content": answer})

    # "Clear History" button aligned to the right
    col1, col2 = st.columns([10, 4])
    with col1:
        pass
    with col2:
        st.button("Clear History", use_container_width=True, on_click=clear_history)


render_app()
Running the app on your local machine
1. Clone this GitHub repo
2. Create a conda environment as mentioned above
3. Download the data files using wget as shown above
4. Run dataprep.py to create the vector embeddings
5. Set up a SageMaker endpoint with Llama2 as described above¹ and update the endpoint name in the .env file
6. Run the below command to start the Streamlit app
streamlit run streamlit_app.py --server.address 0.0.0.0 --server.port 8080 --server.fileWatcherType none --browser.gatherUsageStats False
# You can now view your Streamlit app in your browser.
# URL: http://0.0.0.0:8080
Improvement Opportunities
Below are some of the areas that I'm planning to improve in Part 2:
- Chunk the source documents. Currently, I'm creating one vector embedding per document. The LLM would respond faster and better if the provided context were more precise and clean, which requires splitting the documents into chunks in a more meaningful way (see the sketch after this list).
- Try leveraging a GPU for creating the embeddings and for the index search.
- Add support for multiple LLMs so that a user can compare the results.
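For the chunking improvement, one possible approach (not yet implemented in this project; the chunk sizes below are just placeholders) is LangChain's RecursiveCharacterTextSplitter:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# Split each HTML document into overlapping chunks before embedding
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
chunks = text_splitter.split_documents(html_docs)  # html_docs from the embedding step above

db = FAISS.from_documents(chunks, embeddings)  # embeddings as defined earlier
db.save_local("faiss_index")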
If you liked the above blog, please leave a clap or a comment below. Also, follow me for more such blogs in the future.
References
- https://aws.amazon.com/blogs/machine-learning/llama-2-foundation-models-from-meta-are-now-available-in-amazon-sagemaker-jumpstart/
- https://github.com/a16z-infra/llama2-chatbot
#GenerativeAI #LLM #LLaMa #Streamlit #LangChain #MachineLearning #AWS #SageMaker