# co3/co3/accessors/vss.py

import pickle
import logging
from pathlib import Path
import time
import sqlalchemy as sa
from sentence_transformers import SentenceTransformer, util
from co3.accessor import Accessor
logger = logging.getLogger(__name__)


class VSSAccessor(Accessor):
    def __init__(self, cache_path):
        super().__init__()

        self._model = None
        self._embeddings = None
        self._embedding_size = 384

        self.embedding_path = Path(cache_path, 'embeddings.pkl')
    def write_embeddings(self, embedding_dict):
        '''
        Pickle the embedding dict to the cache path. Values are expected to be
        ``(ids, embeddings, items)`` tuples keyed by index name (see ``search``).
        '''
        self.embedding_path.write_bytes(pickle.dumps(embedding_dict))

    def read_embeddings(self):
        if not self.embedding_path.exists():
            logger.warning(
                f'Attempting to access non-existent embeddings at {self.embedding_path}'
            )
            return None

        return pickle.loads(self.embedding_path.read_bytes())
    @property
    def model(self):
        if self._model is None:
            # model trained with 128 token seqs
            self._model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')

        return self._model

    @property
    def embeddings(self):
        if self._embeddings is None:
            self._embeddings = self.read_embeddings()

        return self._embeddings
    def embed_chunks(self, chunks, batch_size=64, show_prog=True):
        return self.model.encode(
            chunks,
            batch_size = batch_size,
            show_progress_bar = show_prog,
            convert_to_numpy = True,
            normalize_embeddings = True
        )
    def search(
        self,
        query : str,
        index_name : str,
        limit : int = 10,
        score_threshold : float = 0.5,
    ):
        '''
        Parameters:
            index_name: one of ['chunks','blocks','notes']
        '''
        if not query:
            return None

        if self.embeddings is None or index_name not in self.embeddings:
            logger.warning(
                f'Index "{index_name}" does not exist'
            )
            return None

        start = time.time()
        query_embedding = self.embed_chunks(query, show_prog=False)
        index_ids, index_embeddings, index_items = self.embeddings[index_name]

        hits = util.semantic_search(
            query_embedding,
            index_embeddings,
            top_k=limit,
            score_function=util.dot_score
        )[0]
        hits = [hit for hit in hits if hit['score'] >= score_threshold]

        for hit in hits:
            idx = hit['corpus_id']
            hit['group_name'] = index_ids[idx]
            hit['item'] = index_items[idx]

        logger.info(f'{len(hits)} hits in {time.time()-start:.2f}s')

        return hits
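

# --- Usage sketch (illustrative, not part of the module proper) ---
# A minimal example of driving the accessor end-to-end, assuming a writable
# cache directory and a small set of text chunks. The index layout mirrors
# what `search` unpacks: {index_name: (ids, embeddings, items)}. The cache
# path and index name below are hypothetical.
if __name__ == '__main__':
    cache_dir = Path('/tmp/vss-cache')  # hypothetical cache location
    cache_dir.mkdir(parents=True, exist_ok=True)

    accessor = VSSAccessor(cache_dir)

    chunks = ['first note body', 'second note body']
    ids = ['note-1', 'note-2']

    # embed and persist a single 'chunks' index
    embeddings = accessor.embed_chunks(chunks, show_prog=False)
    accessor.write_embeddings({'chunks': (ids, embeddings, chunks)})

    # query the index; each hit carries 'score', 'group_name', and 'item'
    for hit in accessor.search('note body', 'chunks', limit=5) or []:
        print(hit['group_name'], f"{hit['score']:.3f}", hit['item'])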