# Code: Select all  (forum paste header, not part of the program)
def load_faiss_index_and_docidmap_pkl(self):
    """Download the pickled FAISS payload from the bucket and validate it.

    The blob named by ``self._blob_name`` is fetched from ``self._bucket``
    into a temporary file, unpickled, and the temporary file is removed
    again whether or not loading succeeded. The payload is expected to
    expose ``.get`` and to provide the keys ``"faiss_index"`` and
    ``"doc_id_map"``.

    Raises:
        ValueError: if the downloaded file is empty, the payload is not a
            mapping, the index object has no ``ntotal`` attribute, the
            ``doc_id_map`` is missing/empty, or its length differs from
            ``faiss_index.ntotal``.
        Exception: any error raised by the download or by unpickling is
            logged (for unpickling) and propagated unchanged.
    """
    # NOTE(review): the path is fixed, so two concurrent loads in the same
    # process/container would overwrite each other's file — confirm this
    # method is never run concurrently.
    download_path = "/tmp/faiss_index.pkl"

    # Load the FAISS index and doc_id_map from Google Cloud Storage
    faiss_index_blob = self._bucket.blob(self._blob_name)
    logger.info(f"Attempting to download FAISS index from blob: {self._blob_name}")

    try:
        faiss_index_blob.download_to_filename(download_path)

        # Check if the downloaded file is valid
        file_size = os.path.getsize(download_path)
        if file_size == 0:
            logger.error("Downloaded FAISS index file is empty.")
            raise ValueError("Downloaded FAISS index file is empty.")
        logger.info(f"Downloaded FAISS index file size: {file_size} bytes")

        # Import faiss only when needed to reduce memory usage
        import faiss
        logger.info(f"FAISS version: {faiss.__version__}")

        # Deserialize the FAISS index and doc_id_map.
        # SECURITY: pickle.load executes arbitrary code embedded in the
        # file. This is only safe while the bucket/blob is fully trusted.
        try:
            with open(download_path, "rb") as f:
                faiss_data = pickle.load(f)
        except Exception as e:
            logger.error(f"Error deserializing FAISS data: {str(e)}")
            raise
    finally:
        # Always delete the file to free disk space, including on the
        # failure paths (failed download, empty file, unpickling error).
        try:
            os.remove(download_path)
        except FileNotFoundError:
            # Nothing was written (or it was already cleaned up).
            pass
    logger.info("Successfully deserialized FAISS data and removed temporary file")

    # Guard against a payload that is not a mapping; without this the
    # lookups below would fail with an unhelpful AttributeError.
    if not hasattr(faiss_data, "get"):
        logger.error("The downloaded data is not a mapping with FAISS contents.")
        raise ValueError("The downloaded data is not a mapping with FAISS contents.")

    faiss_index = faiss_data.get("faiss_index")
    doc_id_map = faiss_data.get("doc_id_map")

    # Log the type of the loaded FAISS index
    logger.info(f"Loaded FAISS index type: {type(faiss_index)}")

    # Check if FAISS index is valid (also rejects a missing key, i.e. None)
    if not hasattr(faiss_index, "ntotal"):
        logger.error("The loaded object is not a valid FAISS index.")
        raise ValueError("The loaded object is not a valid FAISS index.")
    logger.info(f"FAISS index loaded successfully with {faiss_index.ntotal} vectors")

    # Check if doc_id_map is valid
    if not doc_id_map:
        logger.error("doc_id_map is missing in the downloaded data.")
        raise ValueError("doc_id_map is missing in the downloaded data.")

    # Validate that the doc_id_map length matches the FAISS index
    if len(doc_id_map) != faiss_index.ntotal:
        logger.error(f"The length of doc_id_map ({len(doc_id_map)}) and FAISS index ({faiss_index.ntotal}) do not match.")
        raise ValueError("The length of doc_id_map and FAISS index do not match.")

    # Log the total number of vectors in the FAISS index
    logger.info(f"Total number of vectors in FAISS index: {faiss_index.ntotal}")
    # NOTE(review): the code visible here neither returns nor stores
    # faiss_index / doc_id_map. If this is the end of the method, callers
    # receive None — confirm against the full file.