Understanding Source Attribution in RAG Systems

Leveraging ProtoDash for Context Attribution

Gautam Chutani
Jun 25, 2024

Co-Author: Jayesh Ranjan

Transparency and interpretability are central concerns for modern machine learning models. The need is especially pronounced in systems that use Retrieval-Augmented Generation (RAG), where understanding how the retrieved source documents influence a large language model’s output is crucial for maintaining reliability and accuracy.


What is ProtoDash?

Developed collaboratively by Amazon and IBM Research, ProtoDash stands out as an advanced algorithm designed to identify prototypical examples within datasets. Its primary goal is to capture and quantify the underlying distribution of data, providing weighted representations that closely reflect the dataset’s characteristics.

Applying ProtoDash to RAG Systems

In the context of RAG systems, ProtoDash is particularly useful for source attribution. The goal is to determine how retrieved source documents contribute to the content generated by the language model. ProtoDash selects prototypes from the retrieved documents that effectively explain the generation outcomes, helping us understand which documents and passages have the most significant influence on the final output.

Methodology and Benefits

ProtoDash operates by minimizing the Maximum Mean Discrepancy (MMD) between the generated content and a subset of examples from the retrieved source documents (prototypes). This selection process ensures that the chosen prototypes accurately represent the key influences on the generated output while highlighting diverse aspects of the input data’s features.
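
To make the MMD objective concrete, here is a minimal NumPy sketch of the squared MMD between a target set (for example, the embedding of a generated answer) and a weighted set of candidate vectors (for example, embeddings of retrieved passages). The Gaussian kernel, the `sigma` value, the function names, and the toy two-dimensional vectors are all assumptions made for illustration. This is not the ProtoDash implementation, which additionally selects prototypes greedily and learns non-negative weights.

```python
import numpy as np


def gaussian_kernel(a, b, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of a and the rows of b."""
    sq_dists = (
        np.sum(a ** 2, axis=1)[:, None]
        + np.sum(b ** 2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point round-off.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * sigma ** 2))


def weighted_mmd_squared(target, candidates, weights, sigma=1.0):
    """Squared MMD between the target points and a weighted candidate set."""
    weights = np.asarray(weights, dtype=float)
    k_tt = gaussian_kernel(target, target, sigma).mean()
    k_tc = gaussian_kernel(target, candidates, sigma).mean(axis=0) @ weights
    k_cc = weights @ gaussian_kernel(candidates, candidates, sigma) @ weights
    return k_tt - 2.0 * k_tc + k_cc


# Toy data (hypothetical): one "answer" vector and two "passage" vectors.
target = np.array([[0.0, 0.0]])
candidates = np.array([[0.1, 0.0], [3.0, 3.0]])

mmd_near = weighted_mmd_squared(target, candidates, [1.0, 0.0])
mmd_far = weighted_mmd_squared(target, candidates, [0.0, 1.0])
print("all weight on the nearby passage: ", mmd_near)
print("all weight on the distant passage:", mmd_far)
```

Placing the weight on the passage that lies close to the answer gives the smaller discrepancy, which is the intuition behind reading the weights as attribution scores.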

ProtoDash’s approach is particularly advantageous for RAG systems because it enhances:

  • Interpretability: By identifying prototypes that best match the generated content, ProtoDash enables stakeholders to comprehend how specific retrieved documents and passages lead to particular outputs in RAG-generated content.
  • Accuracy: The algorithm’s emphasis on minimizing MMD helps in selecting prototypes that are not only representative but also diverse, so the attribution gives a more faithful picture of how the generated content relates to the source material.
  • Validation: ProtoDash aids in validating the reliability of RAG-generated outputs by ensuring that the selected prototypes encapsulate the necessary contextual nuances and thematic elements from the retrieved source documents.

Implementing ProtoDash for Source Attribution in RAG

To see ProtoDash in action for understanding the relative attribution of the context supplied to generated content in RAG applications, we will use the IBM Watson OpenScale platform to configure model explainability. The IBM Watson OpenScale Python client SDK lets us customize the explainability metrics through ibm_metrics_plugin, a Python API for computing fairness metrics and explaining transactions.

Download and Load sample data

For demonstration, we will download the sample cricket data from the specified URL and save the CSV file which consists of 4 questions.

import os
import wget

filename = "sample_cricket_data.csv"
url = "https://raw.githubusercontent.com/gautamgc17/RAG-Assets/main/sample_data/sample_cricket_data.csv"

if not os.path.isfile(filename):
    wget.download(url, out=filename)
    print(f"Downloaded File - {filename}")
else:
    print("File with same name already exists! Skipping download....")

import pandas as pd

df = pd.read_csv("sample_cricket_data.csv")
df.head()

The dataframe consists of 5 columns: ‘Question’, ‘Chunk1’, ‘Chunk2’, ‘Chunk3’, and ‘Answer’.

Load the custom Embedding model

from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2')

Source Attribution detection

Source attribution for LLM responses in RAG-based applications is computed using the ProtoDash explainer. The computation needs two pieces of information:

  1. The response data for which source attribution has to be identified. This is treated as the input data.
  2. The context information retrieved by the RAG pipeline. This is treated as the reference data.

Using this information, prototypes of the input are identified from the reference data, which reveals the parts of the context that contributed to the response. Now, we will construct a dataframe with the responses and contexts to supply for source attribution.

  • generated_text: Response from the foundation model
  • context: Relevant context retrieved from the knowledge base for each question

from copy import deepcopy

# Work on a copy so the original dataframe stays available for plotting later.
data = deepcopy(df)
data = data.rename(columns={'Answer': 'generated_text'})
# Collect the three retrieved chunks of each row into a single list.
data['context'] = data[['Chunk1', 'Chunk2', 'Chunk3']].apply(list, axis=1)
data = data.drop(columns=['Question', 'Chunk1', 'Chunk2', 'Chunk3'])
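
As a quick check of the reshaping step, the sketch below applies the same transformation to a small made-up dataframe. The two placeholder rows are hypothetical and only mirror the column names of the sample CSV.

```python
import pandas as pd

# Hypothetical two-row sample with the same column names as the sample CSV.
sample = pd.DataFrame({
    "Question": ["Q1", "Q2"],
    "Chunk1": ["a1", "b1"],
    "Chunk2": ["a2", "b2"],
    "Chunk3": ["a3", "b3"],
    "Answer": ["ans1", "ans2"],
})

chunk_columns = ["Chunk1", "Chunk2", "Chunk3"]

reshaped = sample.rename(columns={"Answer": "generated_text"})
# Each row's chunks become one Python list in the 'context' column.
reshaped["context"] = reshaped[chunk_columns].apply(list, axis=1)
reshaped = reshaped.drop(columns=["Question"] + chunk_columns)

print(reshaped)
```

Each row now holds one response together with the list of passages that were retrieved for it, which is the layout used by the configuration in the following steps.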

Set up the OpenScale Client

from ibm_cloud_sdk_core.authenticators import IAMAuthenticator
from ibm_watson_openscale import APIClient as WOSClient
from ibm_watson_openscale.supporting_classes.enums import *
from ibm_watson_openscale.supporting_classes import *

project_id = "*****"
credentials = {
    "url": "****",
    "apikey": "*****"
}

authenticator = IAMAuthenticator(apikey=credentials.get("apikey"))
client = WOSClient(authenticator=authenticator)

Update the configuration needed for Source Attribution

from ibm_metrics_plugin.common.utils.constants import ExplainabilityMetricType
from ibm_metrics_plugin.metrics.explainability.entity.explain_config import ExplainConfig
from ibm_metrics_plugin.common.utils.constants import InputDataType, ProblemType

config_json = {
    "configuration": {
        "input_data_type": InputDataType.TEXT.value,
        "problem_type": ProblemType.QA.value,
        "feature_columns": ["context"],
        "prediction": "generated_text",
        "context_column": "context",
        "explainability": {
            "metrics_configuration": {
                ExplainabilityMetricType.PROTODASH.value: {
                    "embedding_fn": embeddings.embed_documents
                }
            }
        }
    }
}

Run ProtoDash explainer to identify source attribution for the RAG based responses

results_response = client.ai_metrics.compute_metrics(configuration=config_json, data_frame=data)
metrics = results_response.get("metrics_result")
results = metrics.get("explainability").get("protodash")

Output for a single question (image by author)

Source attribution is read from two values: the weights (the attribution or contribution factor) and the prototypes (the relevant context or source passages) that contributed to the foundation model’s response. For example, a single weight of 1.0 indicates that one paragraph of the context accounts for the response. Likewise, weights of 0.6, 0.3, and 0.1 indicate that three paragraphs contributed to the response, in decreasing order of influence. The prototype values are the paragraphs supplied as part of the relevant context.
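
The weights can also be ranked and expressed as shares of the total, as in the minimal sketch below. The `sample_entry` dictionary is a hypothetical stand-in that only mimics the `[weight, text]` pairs stored under `entry['prototypes']['values']`. The helper name `rank_sources` and the placeholder passages are made up for illustration.

```python
# Hypothetical result entry: each prototype is a [weight, context_text] pair.
sample_entry = {
    "prototypes": {
        "values": [
            [0.3, "passage B"],
            [0.6, "passage A"],
            [0.1, "passage C"],
        ]
    }
}


def rank_sources(entry):
    """Return (text, share) pairs sorted by descending weight.

    The share is the weight divided by the sum of all weights, so the
    shares add up to 1 whenever the total weight is non-zero.
    """
    pairs = entry["prototypes"]["values"]
    total = sum(weight for weight, _ in pairs)
    ranked = sorted(pairs, key=lambda pair: pair[0], reverse=True)
    return [(text, weight / total if total else 0.0) for weight, text in ranked]


ranked_sources = rank_sources(sample_entry)
for text, share in ranked_sources:
    print(f"{share:.0%}  {text}")
```

Sorting by weight puts the most influential passage first, which is handy when only the top source needs to be shown to a reader.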

Visualizing ProtoDash Results with Plotly

Hovering over a bar in the plot below shows the full text of that source chunk; the bar length gives its relative weight.

import plotly.graph_objects as go

protodash_data = metrics['explainability']['protodash']

for idx, (question, entry) in enumerate(zip(df["Question"].tolist(), protodash_data), start=1):
    prototypes = entry['prototypes']['values']
    weights = [prototype[0] for prototype in prototypes]
    contexts = [prototype[1] for prototype in prototypes]

    fig = go.Figure()

    fig.add_trace(go.Bar(
        y=[f'Context {i+1}' for i in range(len(contexts))],
        x=weights,
        orientation='h',
        marker=dict(color='skyblue'),
        hovertext=contexts,  # Full contexts for hover text
        hoverinfo='text',
        hovertemplate='%{hovertext}<extra></extra>',  # Custom hover template
    ))

    fig.update_layout(
        title=f'ProtoDash Results (Question {idx})',
        xaxis_title='Source Attribution (Weight)',
        yaxis_title='Source Document',
        height=600,
        margin=dict(l=100, r=50, t=75, b=50),
        showlegend=False,
    )

    fig.update_traces(hoverlabel=dict(bgcolor='rgba(255,255,255,0.7)', font_size=13))
    fig.show()

Visualizing ProtoDash Results with Plotly (image by author)

Conclusion

In summary, ProtoDash is a useful algorithm for understanding datasets, machine learning models, and Retrieval-Augmented Generation (RAG) systems. By identifying and weighting prototypical examples, it provides insight into data distribution and model behavior. Applied to RAG, it makes AI-generated content easier to interpret and to trust by showing which retrieved passages the output draws on, which supports more transparent and better-informed use of these systems.
