feat: query and convert img to embedding, cosine similarity calculations
Showing 3 changed files with 159 additions and 0 deletions: a usage example, the ImageSearcher wrapper around ChromaDB, and a CLIP-based embeddings toolkit.
@@ -0,0 +1,28 @@
from phi.tools.image_to_image_search import ImageSearcher


def search_image():
    # Initialize the image searcher
    image_searcher = ImageSearcher()

    # Add test images
    image_urls = [
        "http://images.cocodataset.org/val2017/000000039769.jpg",  # cat
        "https://cdn.pixabay.com/photo/2023/08/18/15/02/dog-8198719_1280.jpg",  # dog
    ]

    # Add images to the database, clearing any existing entries
    image_searcher.add_images(image_urls, clear_existing=True)

    # Search with a query image
    query_image = "https://static.vecteezy.com/system/resources/thumbnails/024/646/930/small_2x/ai-generated-stray-cat-in-danger-background-animal-background-photo.jpg"
    similar_images = image_searcher.search_similar_images(query_image, limit=2)

    # Print results
    print("\nSimilar images found:")
    for i, img in enumerate(similar_images, 1):
        print(f"\n{i}. URL: {img['url']}")
        print(f"   Type: {img['type']}")
        print(f"   Distance: {img['distance']:.4f}")


if __name__ == "__main__":
    search_image()
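
Because the embedder below also exposes `get_text_embeddings`, and CLIP text and image embeddings share one space, the same stored vectors could in principle be queried with a text prompt instead of an image. A hypothetical sketch along the lines of `search_similar_images` — `search_by_text` is not part of this commit:

import numpy as np
from phi.tools.image_to_image_search import ImageSearcher

def search_by_text(searcher: ImageSearcher, query: str, limit: int = 2):
    """Hypothetical text-to-image search; mirrors search_similar_images but embeds a prompt."""
    searcher.embeddings.load_model()
    query_embedding = searcher.embeddings.get_text_embeddings(query)

    # Same private-collection access pattern as search_similar_images
    data = searcher.vector_db._collection.get(include=["embeddings", "documents"])
    results = []
    for url, emb in zip(data["documents"], data["embeddings"]):
        similarity = np.dot(query_embedding, emb) / (
            np.linalg.norm(query_embedding) * np.linalg.norm(emb)
        )
        results.append({"url": url, "distance": 1 - similarity})
    results.sort(key=lambda r: r["distance"])
    return results[:limit]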
@@ -0,0 +1,74 @@
from typing import List, Optional

import numpy as np

from phi.document import Document
from phi.tools.openai_embeddings import OpenAIEmbeddings
from phi.vectordb.chroma import ChromaDb


class ImageSearcher:
    def __init__(self):
        # Initialize the CLIP-based embedder
        self.embeddings = OpenAIEmbeddings()

        # Initialize ChromaDB with a persistent local collection
        self.vector_db = ChromaDb(
            collection="image_collection",
            embedder=self.embeddings,
            path="tmp/image_db",
            persistent_client=True,
        )

        # Create the collection
        self.vector_db.create()

    def add_images(self, image_urls: List[str], clear_existing: bool = False):
        """Add images to the vector database."""
        if clear_existing:
            self.vector_db.drop()
            self.vector_db.create()

        # Create a Document for each image; the URL is stored as the content
        image_docs = []
        for i, url in enumerate(image_urls):
            doc = Document(
                content=url,
                name=f"image_{i}",
                # Crude label heuristic: filename up to the first underscore
                meta_data={"type": url.split("/")[-1].split("_")[0]},
            )
            image_docs.append(doc)

        # Insert documents into the vector database
        self.vector_db.insert(image_docs)
        print(f"Added {len(image_docs)} images to the database")

    def search_similar_images(self, query_image_url: str, category: Optional[str] = None, limit: int = 5) -> List[dict]:
        """Search for images similar to the query image URL (category is currently unused)."""
        # Get embedding for the query image
        self.embeddings.load_model()
        query_embedding = self.embeddings.get_image_embeddings(query_image_url)

        # Read stored embeddings directly from the underlying Chroma collection
        # (note: _collection is a private attribute of ChromaDb)
        chroma_collection = self.vector_db._collection
        collection_data = chroma_collection.get(include=["embeddings", "documents", "metadatas"])

        # Compute cosine similarity against every stored embedding
        similar_images = []
        for doc_content, doc_embedding in zip(collection_data["documents"], collection_data["embeddings"]):
            if doc_content == query_image_url:  # Skip the query image itself
                continue
            similarity = np.dot(query_embedding, doc_embedding) / (
                np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
            )
            print("similarity", similarity)
            distance = 1 - similarity  # Cosine distance: lower means more similar

            similar_images.append({
                "url": doc_content,
                "distance": distance,
                "type": doc_content.split("/")[-1].split("_")[0],
            })

        # Sort ascending by distance and return the closest matches
        similar_images.sort(key=lambda x: x["distance"])
        return similar_images[:limit]
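
The manual similarity step above implements cosine distance, distance = 1 - (a . b) / (|a| |b|). A minimal self-contained sanity check of that calculation (the vectors are made-up toy values, not real CLIP embeddings):

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: 0 for identical directions, 1 for orthogonal, 2 for opposite
    return 1 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 0.1, 0.9])    # nearly the same direction as a
c = np.array([-1.0, 0.0, -1.0])  # exactly opposite direction

print(cosine_distance(a, b))  # ~0.004, very close
print(cosine_distance(a, c))  # 2.0, maximally distant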
@@ -0,0 +1,57 @@
from io import BytesIO

import numpy as np
import requests
from PIL import Image

from phi.tools import Toolkit

try:
    import open_clip
except ImportError:
    raise ImportError("`open-clip-torch` is not installed. Please install using `pip install open-clip-torch`")


class OpenAIEmbeddings(Toolkit):
    """CLIP-based embedder (despite the name, it uses an OpenCLIP model, not the OpenAI API)."""

    def __init__(self, model: str = "hf-hub:laion/CLIP-ViT-g-14-laion2B-s12B-b42K"):
        super().__init__(name="openai_embeddings")
        self.model_name = model  # Keep the original model name so the tokenizer can be loaded later
        self.model = None
        self.preprocess = None
        self.tokenizer = None
        self.register(self.load_model)
        self.register(self.get_text_embeddings)
        self.register(self.get_image_embeddings)

    def load_model(self):
        """Lazily load the CLIP model, preprocessing transform, and tokenizer."""
        if self.model is None:
            self.model, self.preprocess = open_clip.create_model_from_pretrained(self.model_name)
            self.tokenizer = open_clip.get_tokenizer(self.model_name)

    def get_embedding_and_usage(self, text: str):
        """Required method for phi.document.Document compatibility."""
        # Route image URLs to the image encoder so stored vectors live in the same
        # CLIP space as query-image embeddings; everything else is embedded as text.
        if isinstance(text, str) and text.lower().split("?")[0].endswith((".jpg", ".jpeg", ".png")):
            embedding = self.get_image_embeddings(text)
        else:
            embedding = self.get_text_embeddings(text)
        return embedding.tolist(), {"total_tokens": 0}  # Mock usage stats

    def get_embedding(self, text: str):
        """Required method for phi.vectordb compatibility."""
        embedding, _ = self.get_embedding_and_usage(text)
        return embedding

    def get_text_embeddings(self, text: str) -> np.ndarray:
        self.load_model()
        text_tokens = self.tokenizer([text])
        text_embeddings = self.model.encode_text(text_tokens)
        return text_embeddings.detach().cpu().numpy()[0]

    def get_image_embeddings(self, image_url: str) -> np.ndarray:
        self.load_model()
        # Download the image and force RGB (the CLIP transform expects 3 channels)
        response = requests.get(image_url)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
        # Preprocess, add a batch dimension, and encode
        image_input = self.preprocess(image).unsqueeze(0)
        image_embeddings = self.model.encode_image(image_input)
        return image_embeddings.detach().cpu().numpy()[0]
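
As a quick check that the text and image encoders land in the same space, a prompt can be compared against an image directly. A minimal sketch reusing the cat image URL from the example file above; one would expect the cat prompt to score higher than the dog prompt:

import numpy as np
from phi.tools.openai_embeddings import OpenAIEmbeddings

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

embedder = OpenAIEmbeddings()
# get_image_embeddings / get_text_embeddings call load_model internally
image_vec = embedder.get_image_embeddings("http://images.cocodataset.org/val2017/000000039769.jpg")
cat_vec = embedder.get_text_embeddings("a photo of a cat")
dog_vec = embedder.get_text_embeddings("a photo of a dog")

print("cat prompt:", cosine(image_vec, cat_vec))
print("dog prompt:", cosine(image_vec, dog_vec))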