Basic RAG
Introduction
Retrieval-Augmented Generation (RAG) combines a large language model with the ability to retrieve relevant information from a knowledge base. Grounding the generated responses in specific, retrieved information improves their quality and accuracy.
This notebook aims to provide a clear and concise introduction to RAG, suitable for beginners who want to understand and implement this technology.
Motivation
Traditional language models generate text based on learned patterns from training data. However, when they are presented with queries that require specific, updated, or niche information, they may struggle to provide accurate responses. RAG addresses this limitation by incorporating a retrieval step that provides the language model with relevant context to generate more informed answers.
Method Details
Document Preprocessing and Vector Store Creation
Document Chunking: The knowledge base documents (e.g., PDFs, articles) are preprocessed and split into manageable chunks. This is done to create a searchable corpus that can be efficiently used in the retrieval process.
Embedding Generation: Each chunk is converted into a vector representation using a pre-trained embedding model (e.g., OpenAI's embeddings, or FastEmbed's BAAI/bge-base-en-v1.5 as configured later in this notebook). The vectors are stored in a vector database such as Qdrant, which enables efficient similarity search.
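Chunking with overlap can be sketched in a few lines of plain Python. This is a simplified illustration and not LlamaIndex's implementation. The notebook's `SentenceSplitter` measures `chunk_size` and `chunk_overlap` in tokens and tries to keep sentences intact. The hypothetical `chunk_words` helper below just counts whitespace-separated words.

```python
from typing import List


def chunk_words(text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
    """Split text into windows of chunk_size words; consecutive windows share chunk_overlap words."""
    if chunk_size <= 0 or not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= chunk_overlap < chunk_size")
    words = text.split()
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks: List[str] = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # this window already reaches the end of the text
    return chunks


text = "one two three four five six seven eight nine ten"
print(chunk_words(text, chunk_size=4, chunk_overlap=1))
# ['one two three four', 'four five six seven', 'seven eight nine ten']
```

The overlap repeats a few words across neighbouring chunks. As a result, text near a boundary keeps some of its surrounding context in both chunks.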
Retrieval-Augmented Generation Workflow
Query Input: A user provides a query that needs to be answered.
Retrieval Step: The query is embedded into a vector using the same embedding model that was used for the documents. A similarity search is then performed in the vector database to find the most relevant document chunks.
Generation Step: The retrieved document chunks are passed to a large language model (e.g., GPT-4) as additional context. The model uses this context to generate a more accurate and relevant response.
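The three steps can be sketched end to end without any external services. The bag-of-words `embed` function below is a hypothetical stand-in for a real embedding model such as BAAI/bge-base-en-v1.5. `retrieve` does by brute force what a vector database like Qdrant does with an index. The assembled prompt is what would be sent to the LLM in the generation step.

```python
import math
from collections import Counter
from typing import Dict, List, Tuple


def embed(text: str) -> Dict[str, float]:
    """Hypothetical stand-in for an embedding model: a sparse bag-of-words vector."""
    return {word: float(count) for word, count in Counter(text.lower().split()).items()}


def cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Cosine similarity between two sparse vectors (0.0 if either is empty)."""
    dot = sum(value * b.get(key, 0.0) for key, value in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query: str, chunks: List[str], top_k: int = 2) -> List[Tuple[float, str]]:
    """Embed the query with the same function used for the chunks; return the top_k matches."""
    query_vec = embed(query)
    scored = [(cosine(query_vec, embed(chunk)), chunk) for chunk in chunks]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_k]


def build_prompt(query: str, retrieved: List[Tuple[float, str]]) -> str:
    """Place the retrieved chunks in the prompt as context for the generation step."""
    context = "\n".join(chunk for _, chunk in retrieved)
    return f"Context information is below.\n{context}\nAnswer the question: {query}"


chunks = [
    "qdrant stores vectors for similarity search",
    "gradio builds web interfaces",
    "embeddings map text to vectors",
]
query = "how does qdrant search vectors"
top = retrieve(query, chunks, top_k=2)
print(build_prompt(query, top))
```

Using the same `embed` function for the chunks and the query matters. Similarity scores only make sense when both vectors come from the same model.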
Key Features of RAG
Contextual Relevance: By grounding responses in actual retrieved information, RAG models can produce more contextually relevant and accurate answers.
Scalability: The retrieval step can scale to handle large knowledge bases, allowing the model to draw from vast amounts of information.
Flexibility in Use Cases: RAG can be adapted for a variety of applications, including question answering, summarization, recommendation systems, and more.
Improved Accuracy: Combining generation with retrieval often yields more precise results, especially for queries requiring specific or lesser-known information.
Benefits of this Approach
Combines Strengths of Both Retrieval and Generation: RAG effectively merges retrieval-based methods with generative models, allowing for both precise fact-finding and natural language generation.
Enhanced Handling of Long-Tail Queries: It is particularly effective for queries where specific and less frequently occurring information is needed.
Domain Adaptability: The retrieval mechanism can be tuned to specific domains, ensuring that the generated responses are grounded in the most relevant and accurate domain-specific information.
Conclusion
Retrieval-Augmented Generation (RAG) combines retrieval and generation, grounding a language model's outputs in relevant external information. This is particularly valuable where precise, context-aware responses are required, such as customer support or academic research, and makes RAG a practical method for building more reliable, context-sensitive AI systems.
Prerequisites
- Preferably Python 3.11
- Jupyter Notebook or JupyterLab
- An API key for your LLM provider
- You can use any LLM of your choice; this notebook uses Groq's llama-3.1-70b-versatile, with OpenAI's gpt-4o-mini as a commented-out alternative
With these steps, you can implement a basic RAG system to enhance the capabilities of language models by incorporating real-world, up-to-date information, improving their effectiveness in various applications.
Setting up the Environment
!pip install llama-index
!pip install llama-index-vector-stores-qdrant
!pip install llama-index-readers-file
!pip install llama-index-embeddings-fastembed
!pip install llama-index-llms-openai
!pip install llama-index-llms-groq
!pip install -U qdrant_client fastembed
!pip install python-dotenv
!pip install gradio
# Standard library imports
import logging
import sys
import os
# Third-party imports
from dotenv import load_dotenv
from IPython.display import Markdown, display
# Qdrant client import
import qdrant_client
# LlamaIndex core imports
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
from llama_index.core import Settings
# LlamaIndex vector store import
from llama_index.vector_stores.qdrant import QdrantVectorStore
# Embedding model imports
from llama_index.embeddings.fastembed import FastEmbedEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
# LLM import
from llama_index.llms.openai import OpenAI
from llama_index.llms.groq import Groq
# Load environment variables
load_dotenv()
# Get API keys from environment variables
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Setting up Base LLM
# Settings.llm = OpenAI(
# model="gpt-4o-mini", temperature=0.1, max_tokens=1024, streaming=True
# )
# Groq retires model IDs over time; if this one is unavailable, choose a current model from Groq's model list
Settings.llm = Groq(model="llama-3.1-70b-versatile", api_key=GROQ_API_KEY)
# Set the embedding model
# Option 1: Use FastEmbed with BAAI/bge-base-en-v1.5 model (default)
Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-base-en-v1.5")
# Option 2: Use OpenAI's embedding model (commented out)
# If you want to use OpenAI's embedding model, uncomment the following line:
# Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)
# Qdrant configuration (commented out)
# If you're using Qdrant, uncomment and set these variables:
# QDRANT_CLOUD_ENDPOINT = os.getenv("QDRANT_CLOUD_ENDPOINT")
# QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
# Note: Remember to add QDRANT_CLOUD_ENDPOINT and QDRANT_API_KEY to your .env file if using Qdrant Hosted version
Load the Data
# Load the documents with SimpleDirectoryReader
print("🔃 Loading Data")
from llama_index.core import Document
# "/content/data" is a Google Colab path; point this at your own data folder
reader = SimpleDirectoryReader("/content/data", recursive=True)
documents = reader.load_data(show_progress=True)
Setting up Vector Database
We will use Qdrant as the vector database. There are four ways to initialize the Qdrant client:
- In-memory
client = qdrant_client.QdrantClient(location=":memory:")
- Disk
client = qdrant_client.QdrantClient(path="./data")
- Self-hosted or Docker
client = qdrant_client.QdrantClient(
# url="http://<host>:<port>"
host="localhost", port=6333
)
- Qdrant Cloud
client = qdrant_client.QdrantClient(
url=QDRANT_CLOUD_ENDPOINT,
api_key=QDRANT_API_KEY,
)
For this notebook we use the local on-disk mode (path="./db/"); the other options are left as comments in the cell below.
# creating a qdrant client instance
client = qdrant_client.QdrantClient(
# you can use :memory: mode for fast and light-weight experiments,
# it does not require to have Qdrant deployed anywhere
# but requires qdrant-client >= 1.1.1
# location=":memory:"
# otherwise set Qdrant instance address with:
# url=QDRANT_CLOUD_ENDPOINT,
# otherwise set Qdrant instance with host and port:
# host="localhost",
# port=6333
# set API KEY for Qdrant Cloud
# api_key=QDRANT_API_KEY,
path="./db/"
)
vector_store = QdrantVectorStore(client=client, collection_name="01_Basic_RAG")
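The four modes differ only in the keyword arguments passed to `QdrantClient`. One option, then, is to select the mode from configuration instead of editing comments. The `qdrant_client_kwargs` helper below is hypothetical. It only builds the arguments shown in the list above (`location`, `path`, `host`/`port`, `url`/`api_key`), and it reuses the notebook's own path and port.

```python
from typing import Any, Dict, Optional


def qdrant_client_kwargs(
    mode: str, url: Optional[str] = None, api_key: Optional[str] = None
) -> Dict[str, Any]:
    """Hypothetical helper: map a mode name to keyword arguments for qdrant_client.QdrantClient."""
    if mode == "memory":
        return {"location": ":memory:"}  # nothing persisted; good for quick experiments
    if mode == "disk":
        return {"path": "./db/"}  # local on-disk storage, as used in this notebook
    if mode == "server":
        return {"host": "localhost", "port": 6333}  # self-hosted / Docker instance
    if mode == "cloud":
        if not url or not api_key:
            raise ValueError("cloud mode needs both url and api_key")
        return {"url": url, "api_key": api_key}
    raise ValueError(f"unknown Qdrant mode: {mode!r}")


# In the notebook this would be used as, for example:
# client = qdrant_client.QdrantClient(**qdrant_client_kwargs("disk"))
print(qdrant_client_kwargs("disk"))
```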
Ingest Data into the Vector DB
# Ingest data into the vector database
# Set up an ingestion pipeline
from llama_index.core.node_parser import TokenTextSplitter
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.node_parser import MarkdownNodeParser
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.ingestion import IngestionPipeline
pipeline = IngestionPipeline(
transformations=[
# MarkdownNodeParser(include_metadata=True),
# TokenTextSplitter(chunk_size=500, chunk_overlap=20),
SentenceSplitter(chunk_size=1024, chunk_overlap=20),
# SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=95 , embed_model=Settings.embed_model),
Settings.embed_model,
],
vector_store=vector_store,
)
# Ingest directly into a vector db
nodes = pipeline.run(documents=documents , show_progress=True)
print("Number of chunks added to vector DB :",len(nodes))
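`IngestionPipeline` applies its `transformations` in order, feeding each step's output into the next. It then writes the result to the vector store, which is why the splitter comes before the embedding model. Below is a dependency-free sketch of that flow. `split_sentences`, `fake_embed`, and the list-based store are hypothetical stand-ins, not LlamaIndex classes.

```python
from typing import Callable, Dict, List, Sequence

Node = Dict[str, object]
Transform = Callable[[List[Node]], List[Node]]


def split_sentences(nodes: List[Node]) -> List[Node]:
    """Hypothetical splitter: one output node per '.'-terminated sentence."""
    out: List[Node] = []
    for node in nodes:
        for sentence in str(node["text"]).split("."):
            if sentence.strip():
                out.append({"text": sentence.strip() + "."})
    return out


def fake_embed(nodes: List[Node]) -> List[Node]:
    """Hypothetical embedder: attaches a toy 2-d vector [characters, words] to each node."""
    embedded: List[Node] = []
    for node in nodes:
        text = str(node["text"])
        embedded.append({**node, "embedding": [float(len(text)), float(len(text.split()))]})
    return embedded


def run_pipeline(documents: List[Node], transformations: Sequence[Transform], store: List[Node]) -> List[Node]:
    """Apply each transformation to the previous step's output, then persist the final nodes."""
    nodes = documents
    for transform in transformations:
        nodes = transform(nodes)
    store.extend(nodes)
    return nodes


store: List[Node] = []
docs = [{"text": "RAG retrieves context. The LLM answers."}]
nodes = run_pipeline(docs, [split_sentences, fake_embed], store)
print("Number of chunks added to vector DB:", len(nodes))
```

Reversing the order would attach a single vector to the whole document before it is split. Each chunk's vector would then no longer describe that chunk.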
Setting Up Index
index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
Modifying Prompts and Prompt Tuning
from llama_index.core import ChatPromptTemplate
qa_prompt_str = (
"Context information is below.\n"
"---------------------\n"
"{context_str}\n"
"---------------------\n"
"Given the context information and not prior knowledge, "
"answer the question: {query_str}\n"
)
refine_prompt_str = (
"We have the opportunity to refine the original answer "
"(only if needed) with some more context below.\n"
"------------\n"
"{context_msg}\n"
"------------\n"
"Given the new context, refine the original answer to better "
"answer the question: {query_str}. "
"If the context isn't useful, output the original answer again.\n"
"Original Answer: {existing_answer}"
)
# Text QA Prompt
chat_text_qa_msgs = [
("system", "You are an AI assistant who is well versed in answering questions from the provided context."),
("user", qa_prompt_str),
]
text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)
# Refine Prompt
chat_refine_msgs = [
("system","Always answer the question, even if the context isn't helpful.",),
("user", refine_prompt_str),
]
refine_template = ChatPromptTemplate.from_messages(chat_refine_msgs)
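The two templates play different roles. In LlamaIndex's refine-style response synthesis, the QA template answers from the first piece of retrieved context. The refine template is then applied to each further piece, with the previous answer passed in as `{existing_answer}`. Below is a simplified sketch of that loop. `refine_answer` and the `stub_llm` callable are hypothetical, and a real LLM call would replace the stub.

```python
from typing import Callable, List

QA_PROMPT = (
    "Context information is below.\n"
    "---------------------\n"
    "{context_str}\n"
    "---------------------\n"
    "Given the context information and not prior knowledge, "
    "answer the question: {query_str}\n"
)
REFINE_PROMPT = (
    "We have the opportunity to refine the original answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_msg}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question: {query_str}. "
    "If the context isn't useful, output the original answer again.\n"
    "Original Answer: {existing_answer}"
)


def refine_answer(query: str, chunks: List[str], llm: Callable[[str], str]) -> str:
    """Answer from the first chunk with QA_PROMPT, then refine once per remaining chunk."""
    if not chunks:
        raise ValueError("need at least one retrieved chunk")
    answer = llm(QA_PROMPT.format(context_str=chunks[0], query_str=query))
    for chunk in chunks[1:]:
        answer = llm(
            REFINE_PROMPT.format(context_msg=chunk, query_str=query, existing_answer=answer)
        )
    return answer


# Hypothetical stub LLM: records each prompt and returns a numbered placeholder answer
prompts: List[str] = []


def stub_llm(prompt: str) -> str:
    prompts.append(prompt)
    return f"answer-{len(prompts)}"


result = refine_answer("What is RAG?", ["chunk A", "chunk B", "chunk C"], stub_llm)
print(result)  # answer-3
```

In this sketch, three chunks produce three LLM calls. LlamaIndex's default `compact` response mode first packs as many chunks as fit into each prompt, so it usually makes fewer calls.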
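Examples of Query and Chat Engines
- Query Engine
- Chat Engine
chat_query = "What is in this document?"
# Setting up Query Engine with the custom QA and refine prompts defined above
BASE_RAG_QUERY_ENGINE = index.as_query_engine(
text_qa_template=text_qa_template, refine_template=refine_template
)
response = BASE_RAG_QUERY_ENGINE.query(chat_query)
display(Markdown(str(response)))
# Setting up Chat Engine
BASE_RAG_CHAT_ENGINE = index.as_chat_engine()
response = BASE_RAG_CHAT_ENGINE.chat(chat_query)
display(Markdown(str(response)))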
Simple Chat Application with RAG
from typing import List
from llama_index.core.base.llms.types import ChatMessage, MessageRole
class ChatEngineInterface:
    def __init__(self, index):
        self.chat_engine = index.as_chat_engine()
        self.chat_history: List[ChatMessage] = []
    def display_message(self, role: str, content: str):
        if role == "USER":
            display(Markdown(f"**Human:** {content}"))
        else:
            display(Markdown(f"**AI:** {content}"))
    def chat(self, message: str) -> str:
        # Pass only the previous turns as history: the current message is sent separately,
        # so adding it to the history first would duplicate it.
        # A copy is passed so the engine's memory does not share our list.
        response = self.chat_engine.chat(message, chat_history=list(self.chat_history))
        # Record the user input and the AI response as ChatMessages
        self.chat_history.append(ChatMessage(role=MessageRole.USER, content=message))
        self.chat_history.append(ChatMessage(role=MessageRole.ASSISTANT, content=str(response)))
        # Display the conversation
        self.display_message("USER", message)
        self.display_message("ASSISTANT", str(response))
        print("\n" + "-" * 50 + "\n")  # Separator for readability
        return str(response)
    def get_chat_history(self) -> List[ChatMessage]:
        return self.chat_history
chat_interface = ChatEngineInterface(index)
while True:
    user_input = input("You: ").strip()
    if user_input.lower() == 'exit':
        print("Thank you for chatting! Goodbye.")
        break
    if not user_input:
        continue  # Skip empty input instead of sending it to the LLM
    chat_interface.chat(user_input)
# To view chat history:
history = chat_interface.get_chat_history()
for message in history:
    print(f"{message.role}: {message.content}")
Gradio Application
import gradio as gr
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Document, Settings
from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
import qdrant_client
import os
import tempfile
import shutil
from typing import List
from llama_index.core.base.llms.types import ChatMessage, MessageRole
class RAGChatbot:
    def __init__(self):
        self.client = qdrant_client.QdrantClient(path="./Demo_RAG")
        self.chat_history = []
        # Initialize the vector store
        self.vector_store = QdrantVectorStore(
            client=self.client,
            collection_name="Demo_RAG"
        )
        # Wrap the vector store in an index (documents are ingested later, on upload)
        self.index = VectorStoreIndex.from_vector_store(
            vector_store=self.vector_store
        )
        # Initialize chat engine
        self.chat_engine = self.index.as_chat_engine(
            streaming=True,
            verbose=True
        )
    def process_uploaded_files(self, files) -> str:
        if not files:
            return "Please upload at least one file."
        try:
            # Create a temporary directory for processing
            with tempfile.TemporaryDirectory() as temp_dir:
                # Save uploaded files to the temporary directory
                # (Gradio may pass path strings or file objects with a .name attribute)
                for file in files:
                    path = file if isinstance(file, str) else file.name
                    shutil.copy(path, temp_dir)
                # Load documents
                reader = SimpleDirectoryReader(temp_dir)
                documents = reader.load_data()
            pipeline = IngestionPipeline(
                transformations=[
                    # MarkdownNodeParser(include_metadata=True),
                    # TokenTextSplitter(chunk_size=500, chunk_overlap=20),
                    SentenceSplitter(chunk_size=1024, chunk_overlap=20),
                    # SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=95 , embed_model=Settings.embed_model),
                    Settings.embed_model,
                ],
                vector_store=self.vector_store,
            )
            # Ingest directly into a vector db
            nodes = pipeline.run(documents=documents, show_progress=True)
            return f"Successfully processed {len(documents)} documents and inserted {len(nodes)} chunks into the database. Ready to chat!"
        except Exception as e:
            return f"Error processing files: {str(e)}"
    def chat(self, message: str, history: List[List[str]]) -> List[List[str]]:
        history = history or []  # The chatbot value is None after "Clear"
        if self.chat_engine is None:
            return history + [[message, "Please upload documents first before starting the chat."]]
        try:
            # Convert Gradio's [user, assistant] pairs to ChatMessage format
            chat_history = []
            for h in history:
                chat_history.extend([
                    ChatMessage(role=MessageRole.USER, content=h[0]),
                    ChatMessage(role=MessageRole.ASSISTANT, content=h[1])
                ])
            # The current message is passed separately, so it is not appended to chat_history
            response = self.chat_engine.chat(message, chat_history=chat_history)
            # Return the updated history with the new message pair
            return history + [[message, str(response)]]
        except Exception as e:
            return history + [[message, f"Error generating response: {str(e)}"]]
def create_demo():
    # Initialize the chatbot
    chatbot = RAGChatbot()
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# RAG Chatbot")
        gr.Markdown("Upload your documents and start chatting!")
        with gr.Row():
            with gr.Column(scale=1):
                file_output = gr.File(
                    file_count="multiple",
                    label="Upload Documents",
                    file_types=[".txt", ".pdf", ".docx", ".md"]
                )
                upload_button = gr.Button("Process Documents")
                status_box = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=2):
                chatbot_interface = gr.Chatbot(
                    label="Chat History",
                    height=400,
                    bubble_full_width=False,
                )
                with gr.Row():
                    msg = gr.Textbox(
                        label="Type your message",
                        placeholder="Ask me anything about the uploaded documents...",
                        lines=2,
                        scale=4
                    )
                    submit_button = gr.Button("Submit", scale=1)
                clear = gr.Button("Clear")
        # Event handlers
        upload_button.click(
            fn=chatbot.process_uploaded_files,
            inputs=[file_output],
            outputs=[status_box],
        )
        submit_button.click(
            fn=chatbot.chat,
            inputs=[msg, chatbot_interface],
            outputs=[chatbot_interface],
        )
        clear.click(
            lambda: None,
            None,
            chatbot_interface,
            queue=False
        )
    return demo
if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)
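The `chat` method above flattens Gradio's `[user, assistant]` pairs into alternating user and assistant messages. The new message is sent separately. Below is a dependency-free sketch of that conversion. Plain dicts stand in for `ChatMessage`, and `pairs_to_messages` is a hypothetical helper for illustration, not part of LlamaIndex or Gradio.

```python
from typing import Dict, List


def pairs_to_messages(history: List[List[str]]) -> List[Dict[str, str]]:
    """Flatten [[user, assistant], ...] pairs into alternating role-tagged messages."""
    messages: List[Dict[str, str]] = []
    for user_text, assistant_text in history:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": assistant_text})
    return messages


history = [
    ["Hi", "Hello! Upload a document to begin."],
    ["What is RAG?", "Retrieval-Augmented Generation."],
]
new_message = "How does retrieval work?"
messages = pairs_to_messages(history)
# The new message is sent separately (e.g. chat(new_message, chat_history=...)),
# so it is deliberately not added to `messages`.
print(len(messages))  # 4
```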