First commit
Some checks failed
armco-org/visual-search-engine/pipeline/head There was a failure building this commit
49
.gitignore
vendored
Normal file
@@ -0,0 +1,49 @@
# Dataset and uploads
images/
uploads/

# Generated artifacts (large files)
embeddings/
filenames/
*.pkl
*.faiss
neighbors.pkl

# Keep these files tracked (add with git add -f if needed):
# - index.faiss (FAISS index for search)
# - filenames.pkl (filename mapping for search)

# Python
venv/
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
dist/
eggs/
*.egg-info/
*.egg

# IDE
.idea/
.vscode/
*.swp
*.swo

# Logs
logs/
*.log

# Testing
.pytest_cache/
.coverage
htmlcov/

# OS
.DS_Store
Thumbs.db

# Predictions (legacy, can be regenerated)
predictions/
59
Dockerfile
Normal file
@@ -0,0 +1,59 @@
# Reverse Image Search API
# Multi-stage build for optimized image size

FROM python:3.11-slim AS builder

WORKDIR /app

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Production stage
FROM python:3.11-slim

WORKDIR /app

# Install runtime dependencies only
RUN apt-get update && apt-get install -y --no-install-recommends \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Copy installed packages from builder
COPY --from=builder /root/.local /root/.local
ENV PATH=/root/.local/bin:$PATH

# Copy application code
COPY api/ ./api/
COPY config.py .
COPY reverse_icon_search.py .
COPY run_api_server.py .

# Copy pre-built index and filenames (must be provided at build time or mounted)
# These files are large and should be mounted as volumes in production
# COPY index.faiss .
# COPY filenames.pkl .

# Create necessary directories
RUN mkdir -p uploads logs

# Environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV API_PORT=5002
ENV DEBUG=false

# Expose port
EXPOSE 5002

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:5002/api/health')" || exit 1

# Run the application
CMD ["python", "run_api_server.py"]
9
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,9 @@
@Library('jenkins-shared@main') _

kanikoPipeline(
    repoName: 'visual-search-engine',
    branch: env.BRANCH_NAME ?: 'main',
    builds: [
        [imageName: 'visual-search-engine', dockerfile: 'Dockerfile']
    ]
)
126
README.md
Normal file
@@ -0,0 +1,126 @@
# Reverse Image Search Web Application

This project combines deep learning (ResNet50) with classical machine learning.
It works like the image search offered by Google and Amazon: you
upload an image and get recommendations that are similar
to it. The project uses the Myntra and Amazon image datasets
from Kaggle, which include watches, shoes, and clothes. As an extension,
the model can also be trained on other datasets, such as furniture or cars.

```ResNet50``` ```Convolutional neural network (CNN)```
\
```K-Nearest Neighbors Algorithm (KNN)``` ```Streamlit```

---

## 📚 Documentation

| Document | Description |
|----------|-------------|
| [Technical Specs](docs/SPECS.md) | Architecture, technology stack, and API reference |
| [Getting Started](docs/GETTING_STARTED.md) | Installation, setup, and running the application |
| [Technical Debt](docs/backlog/DEBT.md) | Suggested improvements and refactoring ideas |

---

## 🏗️ Architecture Summary

The system uses a two-stage approach:
1. **Feature Extraction**: ResNet50 (pre-trained on ImageNet) extracts 2048-dimensional feature vectors from images
2. **Similarity Search**: K-Nearest Neighbors finds the most similar images using Euclidean distance

**Interfaces**:
- **Streamlit UI**: Interactive web app for uploading and searching images
- **Flask API**: REST endpoints for programmatic access

> See [docs/SPECS.md](docs/SPECS.md) for detailed technical specifications.

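The two stages above can be sketched with plain NumPy. This is only an illustration under stated assumptions, not the project's implementation: random vectors stand in for ResNet50 features, the helper names `l2_normalize` and `knn_search` are hypothetical, and the search is a brute-force Euclidean scan over L2-normalized vectors.

```python
import numpy as np

EMBEDDING_DIM = 2048  # ResNet50 feature size quoted above


def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Scale each row to unit length (rows are assumed to be non-zero)."""
    return vectors / np.linalg.norm(vectors, axis=1, keepdims=True)


def knn_search(index: np.ndarray, query: np.ndarray, k: int) -> np.ndarray:
    """Return the row indices of the k index vectors closest to the query."""
    distances = np.linalg.norm(index - query, axis=1)  # Euclidean distance
    return np.argsort(distances)[:k]


rng = np.random.default_rng(0)

# Stand-in for features extracted from 100 catalogue images (assumption).
catalogue = l2_normalize(rng.normal(size=(100, EMBEDDING_DIM)))

# Query with a slightly perturbed copy of catalogue image 42.
noisy = catalogue[42:43] + 0.001 * rng.normal(size=(1, EMBEDDING_DIM))
query = l2_normalize(noisy)[0]

nearest = knn_search(catalogue, query, k=5)
```

A brute-force scan like this costs one distance computation per catalogue image, which is why larger collections usually rely on a prebuilt index instead.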
---

## 🚀 Quick Start

### Prerequisites
- Python 3.8 - 3.11
- 8GB RAM minimum
- ~20GB disk space

### Installation

```bash
git clone https://github.com/deepamkalekar/Reverse-Image-Search-ML-DL-Project.git
cd Reverse-Image-Search-ML-DL-Project
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
pip install -r requirements.txt
```

### Setup Dataset

1. Download [Fashion Product Images Dataset](https://www.kaggle.com/paramaggarwal/fashion-product-images-dataset) (~16GB)
2. Create `images/` folder and extract dataset images there
3. Create required directories: `mkdir -p uploads embeddings filenames`

### Generate Embeddings

```bash
python run_generate_embeddings.py
```

### Run the Application

**Streamlit Web UI**:
```bash
streamlit run run_streamlit_ui.py
```
Access at http://localhost:8501

**Flask API**:
```bash
python run_api_server.py
```
API available at http://localhost:5002

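Once the API is running, it can be called from a script. The sketch below uses only the standard library and assumes the `/api/search` endpoint defined in `api/routes.py` in this commit, with a `file` upload and an optional `count` form field. The helpers `encode_multipart` and `search_similar` are hypothetical names introduced for illustration.

```python
import io
import json
import os
import urllib.request
import uuid


def encode_multipart(fields: dict, file_field: str, filename: str, payload: bytes):
    """Build a multipart/form-data body; returns (body, content_type)."""
    boundary = uuid.uuid4().hex
    buf = io.BytesIO()
    # Plain form fields first, one part each.
    for name, value in fields.items():
        buf.write(f"--{boundary}\r\n".encode())
        buf.write(f'Content-Disposition: form-data; name="{name}"\r\n\r\n'.encode())
        buf.write(f"{value}\r\n".encode())
    # Then the file part.
    buf.write(f"--{boundary}\r\n".encode())
    buf.write(
        f'Content-Disposition: form-data; name="{file_field}"; '
        f'filename="{filename}"\r\n'.encode()
    )
    buf.write(b"Content-Type: application/octet-stream\r\n\r\n")
    buf.write(payload)
    buf.write(f"\r\n--{boundary}--\r\n".encode())
    return buf.getvalue(), f"multipart/form-data; boundary={boundary}"


def search_similar(image_path: str, count: int = 10,
                   base_url: str = "http://localhost:5002") -> dict:
    """POST an image to /api/search and return the decoded JSON response."""
    with open(image_path, "rb") as fh:
        body, content_type = encode_multipart(
            {"count": count}, "file", os.path.basename(image_path), fh.read()
        )
    req = urllib.request.Request(
        f"{base_url}/api/search",
        data=body,
        headers={"Content-Type": content_type},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)
```

With the server running, `search_similar("shoe.jpg", count=5)["data"]["results"]` would return the ranked matches; `shoe.jpg` is a placeholder file name.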
> See [docs/GETTING_STARTED.md](docs/GETTING_STARTED.md) for detailed instructions, troubleshooting, and configuration options.

---

## Run Locally (Legacy)

Clone the project

```bash
git clone https://github.com/deepamkalekar/Reverse-Image-Search-ML-DL-Project.git
```

Go to the project directory

```bash
cd Reverse-Image-Search-ML-DL-Project
```
Download the dataset: [Product Image Data (16 GB)](https://www.kaggle.com/paramaggarwal/fashion-product-images-dataset)
\
Create an **``images``** folder and move all the downloaded images into it
\
Create an **``uploads``** folder in the same directory; every photo a user uploads is saved there

Install dependencies

```bash
pip install -r requirements.txt
```
Generate Embeddings
```bash
python run_generate_embeddings.py
```
Run the Streamlit Web App
```bash
streamlit run run_streamlit_ui.py
```

## Demo
Click ```Browse File``` and upload an image



4
api/__init__.py
Normal file
@@ -0,0 +1,4 @@
"""
Reverse Image Search API Package
"""
__version__ = "1.0.0"
87
api/app.py
Normal file
@@ -0,0 +1,87 @@
"""
Flask Application Factory
==========================
Production-ready Flask application with logging configuration.
"""

import os
import logging
from logging.handlers import RotatingFileHandler
from flask import Flask
from flask_cors import CORS

from config import config


def setup_logging(app):
    """Configure application logging."""
    log_level = logging.DEBUG if app.debug else logging.INFO

    formatter = logging.Formatter(
        '[%(asctime)s] %(levelname)s in %(module)s: %(message)s'
    )

    if not os.path.exists('logs'):
        os.makedirs('logs')

    file_handler = RotatingFileHandler(
        'logs/api.log',
        maxBytes=10 * 1024 * 1024,
        backupCount=5
    )
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(log_level)

    root_logger = logging.getLogger()
    root_logger.setLevel(log_level)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    app.logger.info("Logging configured")


def create_app(test_config=None):
    """Application factory for Flask app."""
    app = Flask(__name__)

    app.config['DEBUG'] = config.DEBUG
    app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

    if test_config:
        app.config.update(test_config)

    CORS(app)

    setup_logging(app)

    if not os.path.exists(config.UPLOAD_DIR):
        os.makedirs(config.UPLOAD_DIR)
        app.logger.info(f"Created uploads directory: {config.UPLOAD_DIR}")

    from api.routes import api_bp
    app.register_blueprint(api_bp)

    @app.route('/')
    def index():
        return {
            'name': 'Reverse Image Search API',
            'version': '1.0.0',
            'endpoints': {
                'health': '/api/health',
                'status': '/api/status',
                'search': '/api/search',
                'batch_search': '/api/batch-search',
            }
        }

    app.logger.info("Application initialized successfully")
    return app


if __name__ == '__main__':
    application = create_app()
    application.run(host='0.0.0.0', port=config.API_PORT, debug=config.DEBUG)
95
api/exceptions.py
Normal file
@@ -0,0 +1,95 @@
"""
Custom Exception Classes for API
=================================
Structured error handling with consistent error response format.
"""


class APIError(Exception):
    """Base exception for API errors."""

    def __init__(
        self,
        message: str,
        code: str = 'API_ERROR',
        status_code: int = 500,
        details: dict = None
    ):
        super().__init__(message)
        self.message = message
        self.code = code
        self.status_code = status_code
        self.details = details or {}

    def to_dict(self):
        """Convert exception to API response dict."""
        return {
            'success': False,
            'error': {
                'code': self.code,
                'message': self.message,
                'details': self.details if self.details else None,
            }
        }


class ValidationError(APIError):
    """Raised when request validation fails."""

    def __init__(self, message: str, field: str = None):
        details = {'field': field} if field else {}
        super().__init__(
            message=message,
            code='VALIDATION_ERROR',
            status_code=400,
            details=details
        )


# NOTE: this name shadows Python's built-in FileNotFoundError inside this
# module; api/services.py imports it under the alias APIFileNotFoundError.
class FileNotFoundError(APIError):
    """Raised when a required file is not found."""

    def __init__(self, message: str, filepath: str = None):
        details = {'filepath': filepath} if filepath else {}
        super().__init__(
            message=message,
            code='FILE_NOT_FOUND',
            status_code=404,
            details=details
        )


class SearchError(APIError):
    """Raised when search operation fails."""

    def __init__(self, message: str, details: dict = None):
        super().__init__(
            message=message,
            code='SEARCH_ERROR',
            status_code=500,
            details=details
        )


class IndexNotLoadedError(APIError):
    """Raised when FAISS index is not loaded."""

    def __init__(self):
        super().__init__(
            message='Search index not loaded. Ensure index.faiss exists.',
            code='INDEX_NOT_LOADED',
            status_code=503,
            details={'hint': 'Run generate_embeddings.py to create the index'}
        )


class ModelNotLoadedError(APIError):
    """Raised when ML model fails to load."""

    def __init__(self, model_name: str = 'ResNet50'):
        super().__init__(
            message=f'Failed to load {model_name} model',
            code='MODEL_NOT_LOADED',
            status_code=503,
            details={'model': model_name}
        )
238
api/routes.py
Normal file
@@ -0,0 +1,238 @@
"""
API Routes for Reverse Image Search
====================================
Production-ready REST API with proper error handling and logging.
"""

import os
import time
import logging
from functools import wraps
from flask import Blueprint, request, jsonify, current_app

from api.services import SearchService
from api.exceptions import APIError, ValidationError

logger = logging.getLogger(__name__)

api_bp = Blueprint('api', __name__, url_prefix='/api')


def handle_errors(f):
    """Decorator to handle exceptions and return consistent error responses."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        request_id = request.headers.get('X-Request-ID', 'unknown')
        start_time = time.time()
        try:
            logger.info(
                f"[{request_id}] {request.method} {request.path} started"
            )
            result = f(*args, **kwargs)
            duration_ms = (time.time() - start_time) * 1000
            logger.info(
                f"[{request_id}] {request.method} {request.path} "
                f"completed in {duration_ms:.2f}ms"
            )
            return result
        except ValidationError as e:
            logger.warning(f"[{request_id}] Validation error: {e.message}")
            return jsonify(e.to_dict()), e.status_code
        except APIError as e:
            logger.error(f"[{request_id}] API error: {e.message}")
            return jsonify(e.to_dict()), e.status_code
        except Exception as e:
            logger.exception(f"[{request_id}] Unexpected error: {str(e)}")
            return jsonify({
                'success': False,
                'error': {
                    'code': 'INTERNAL_ERROR',
                    'message': 'An unexpected error occurred',
                    'details': str(e) if current_app.debug else None,
                }
            }), 500
    return wrapper


def get_search_service():
    """Get or create SearchService instance."""
    if not hasattr(current_app, 'search_service'):
        current_app.search_service = SearchService()
    return current_app.search_service


@api_bp.route('/health', methods=['GET'])
@handle_errors
def health_check():
    """Health check endpoint for load balancers and monitoring."""
    return jsonify({
        'success': True,
        'status': 'healthy',
        'timestamp': time.time(),
    })


@api_bp.route('/status', methods=['GET'])
@handle_errors
def get_status():
    """Get detailed status of the search service."""
    service = get_search_service()
    status = service.get_status()
    return jsonify({
        'success': True,
        'data': status,
    })


@api_bp.route('/search', methods=['POST'])
@handle_errors
def search_similar():
    """
    Search for similar images.

    Request:
        - file: Image file (multipart/form-data)
        - count: Number of results to return (optional, default: 10, max: 200)

    Response:
        - success: bool
        - data: { results: [...], count: int, query_time_ms: float }
    """
    if 'file' not in request.files:
        raise ValidationError('No file provided', field='file')

    file = request.files['file']
    if file.filename == '':
        raise ValidationError('Empty filename', field='file')

    count = request.form.get('count', 10, type=int)
    if count < 1 or count > 200:
        raise ValidationError(
            'Count must be between 1 and 200',
            field='count'
        )

    logger.info(f"Search request: filename={file.filename}, count={count}")

    service = get_search_service()
    start_time = time.time()
    results = service.search(file, count)
    query_time_ms = (time.time() - start_time) * 1000

    logger.info(
        f"Search completed: {len(results)} results in {query_time_ms:.2f}ms"
    )

    return jsonify({
        'success': True,
        'data': {
            'results': results,
            'count': len(results),
            'query_time_ms': round(query_time_ms, 2),
        }
    })


@api_bp.route('/similar-icon-paths', methods=['POST'])
@handle_errors
def get_similar_icons_paths():
    """
    Legacy endpoint: Returns relative paths of similar images.
    Maintained for backward compatibility.
    """
    if 'file' not in request.files:
        raise ValidationError('No file provided', field='file')

    file = request.files['file']
    if file.filename == '':
        raise ValidationError('Empty filename', field='file')

    count = request.form.get('count', 200, type=int)

    service = get_search_service()
    results = service.search(file, count)

    paths = [r['path'] for r in results]
    return jsonify(paths)


@api_bp.route('/similar-icon-abs-paths', methods=['POST'])
@handle_errors
def get_similar_icons_absolute_paths():
    """
    Legacy endpoint: Returns absolute paths of similar images.
    Maintained for backward compatibility.
    """
    if 'file' not in request.files:
        raise ValidationError('No file provided', field='file')

    file = request.files['file']
    if file.filename == '':
        raise ValidationError('Empty filename', field='file')

    count = request.form.get('count', 200, type=int)

    service = get_search_service()
    results = service.search(file, count)

    abs_paths = [os.path.abspath(r['path']) for r in results]
    return jsonify(abs_paths)


@api_bp.route('/batch-search', methods=['POST'])
@handle_errors
def batch_search():
    """
    Search for similar images for multiple query images.

    Request:
        - files: Multiple image files (multipart/form-data)
        - count: Number of results per image (optional, default: 10, max: 50)

    Response:
        - success: bool
        - data: { results: { filename: [...] }, total_query_time_ms: float }
    """
    if 'files' not in request.files:
        if 'file' in request.files:
            files = request.files.getlist('file')
        else:
            raise ValidationError('No files provided', field='files')
    else:
        files = request.files.getlist('files')

    if not files or all(f.filename == '' for f in files):
        raise ValidationError('No valid files provided', field='files')

    count = request.form.get('count', 10, type=int)
    if count < 1 or count > 50:
        raise ValidationError(
            'Count must be between 1 and 50 for batch search',
            field='count'
        )

    logger.info(f"Batch search request: {len(files)} files, count={count}")

    service = get_search_service()
    start_time = time.time()

    all_results = {}
    for file in files:
        if file.filename:
            try:
                results = service.search(file, count)
                all_results[file.filename] = results
            except Exception as e:
                logger.warning(f"Failed to process {file.filename}: {e}")
                all_results[file.filename] = {'error': str(e)}

    total_time_ms = (time.time() - start_time) * 1000

    return jsonify({
        'success': True,
        'data': {
            'results': all_results,
            'files_processed': len(all_results),
            'total_query_time_ms': round(total_time_ms, 2),
        }
    })
224
api/services.py
Normal file
@@ -0,0 +1,224 @@
"""
Business Logic Services
========================
Core search functionality with proper error handling and logging.
"""

import os
import logging
import numpy as np
import joblib
from numpy.linalg import norm

from config import config
from api.exceptions import (
    SearchError,
    IndexNotLoadedError,
    ModelNotLoadedError,
    FileNotFoundError as APIFileNotFoundError,
)

logger = logging.getLogger(__name__)

try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    faiss = None
    FAISS_AVAILABLE = False
    logger.warning("FAISS not available, will use sklearn fallback")


class SearchService:
    """Service class for image similarity search."""

    def __init__(self):
        self.model = None
        self.index = None
        self.filenames = None
        self._load_model()
        self._ensure_upload_dir()

    def _ensure_upload_dir(self):
        """Ensure uploads directory exists."""
        upload_dir = config.UPLOAD_DIR
        if not os.path.exists(upload_dir):
            os.makedirs(upload_dir, exist_ok=True)
            logger.info(f"Created uploads directory: {upload_dir}")

    def _load_model(self):
        """Lazy load the ResNet50 model."""
        if self.model is not None:
            return

        try:
            logger.info("Loading ResNet50 model...")
            from keras import Sequential
            from keras.layers import GlobalMaxPooling2D
            from keras.applications.resnet50 import ResNet50

            base_model = ResNet50(
                weights='imagenet',
                include_top=False,
                input_shape=config.INPUT_SHAPE,
            )
            base_model.trainable = False

            self.model = Sequential([
                base_model,
                GlobalMaxPooling2D()
            ])
            logger.info("ResNet50 model loaded successfully")
        except Exception as e:
            logger.exception(f"Failed to load model: {e}")
            raise ModelNotLoadedError()

    def _load_index(self):
        """Lazy load the FAISS index."""
        if self.index is not None:
            return

        index_path = config.FAISS_INDEX_PATH
        if not os.path.exists(index_path):
            logger.error(f"FAISS index not found at {index_path}")
            raise IndexNotLoadedError()

        try:
            logger.info(f"Loading FAISS index from {index_path}...")
            self.index = faiss.read_index(index_path)
            if hasattr(self.index, 'nprobe'):
                self.index.nprobe = 64
            logger.info(
                f"FAISS index loaded: {self.index.ntotal} vectors"
            )
        except Exception as e:
            logger.exception(f"Failed to load FAISS index: {e}")
            raise SearchError(f"Failed to load search index: {e}")

    def _load_filenames(self):
        """Lazy load the filenames mapping."""
        if self.filenames is not None:
            return

        filenames_path = config.FILENAMES_PATH
        if not os.path.exists(filenames_path):
            logger.error(f"Filenames not found at {filenames_path}")
            raise APIFileNotFoundError(
                "Filenames mapping not found",
                filepath=filenames_path
            )

        try:
            logger.info(f"Loading filenames from {filenames_path}...")
            self.filenames = joblib.load(filenames_path)
            logger.info(f"Loaded {len(self.filenames)} filenames")
        except Exception as e:
            logger.exception(f"Failed to load filenames: {e}")
            raise SearchError(f"Failed to load filenames: {e}")

    def _extract_features(self, image_path: str) -> np.ndarray:
        """Extract feature vector from an image."""
        try:
            from keras.preprocessing import image as keras_image
            from keras.applications.resnet50 import preprocess_input

            img = keras_image.load_img(
                image_path,
                target_size=config.TARGET_SIZE
            )
            img_array = keras_image.img_to_array(img)
            expanded = np.expand_dims(img_array, axis=0)
            preprocessed = preprocess_input(expanded)

            result = self.model.predict(preprocessed, verbose=0)
            features = result.flatten()
            normalized = features / norm(features)

            return normalized
        except Exception as e:
            logger.exception(f"Feature extraction failed: {image_path}")
            raise SearchError(f"Feature extraction failed: {e}")

    def search(self, file, count: int = 10) -> list:
        """
        Search for similar images.

        Args:
            file: Uploaded file object exposing `filename` and `save()`
                (for example a Werkzeug FileStorage)
            count: Number of results to return

        Returns:
            List of dicts with 'path', 'filename', 'rank', 'score' keys
        """
        self._load_model()
        self._load_index()
        self._load_filenames()

        temp_path = None
        try:
            self._ensure_upload_dir()
            # Keep only the base name so a crafted filename cannot point
            # outside UPLOAD_DIR.
            temp_path = os.path.join(
                config.UPLOAD_DIR, os.path.basename(file.filename)
            )
            file.save(temp_path)
            logger.debug(f"Saved uploaded file to {temp_path}")

            features = self._extract_features(temp_path)

            query = np.asarray([features], dtype=np.float32)
            distances, indices = self.index.search(query, count)

            results = []
            for rank, (idx, dist) in enumerate(
                zip(indices[0], distances[0])
            ):
                if idx >= 0 and idx < len(self.filenames):
                    filepath = self.filenames[idx]
                    results.append({
                        'path': filepath,
                        'filename': os.path.basename(filepath),
                        'rank': rank + 1,
                        'score': float(dist),
                    })

            return results

        except Exception as e:
            if isinstance(e, (SearchError, IndexNotLoadedError)):
                raise
            logger.exception(f"Search failed: {e}")
            raise SearchError(f"Search operation failed: {e}")
        finally:
            # Always remove the temporary upload, even when the search fails.
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except Exception:
                    pass

    def get_status(self) -> dict:
        """Get service status information."""
        index_path = config.FAISS_INDEX_PATH
        filenames_path = config.FILENAMES_PATH

        index_exists = os.path.exists(index_path)
        filenames_exists = os.path.exists(filenames_path)

        status = {
            'faiss_available': FAISS_AVAILABLE,
            'model_loaded': self.model is not None,
            'index_loaded': self.index is not None,
            'filenames_loaded': self.filenames is not None,
            'index_file_exists': index_exists,
            'filenames_file_exists': filenames_exists,
        }

        if index_exists:
            status['index_size_mb'] = round(
                os.path.getsize(index_path) / (1024 * 1024), 2
            )

        if self.index is not None:
            status['index_total_vectors'] = self.index.ntotal

        if self.filenames is not None:
            status['total_images'] = len(self.filenames)

        return status
110
app.py
Normal file
@@ -0,0 +1,110 @@
import os.path
from flask import Flask, request, jsonify
from reverse_icon_search import ReverseIconSearch
from flask_cors import CORS

app = Flask(__name__)
CORS(app)

ris = ReverseIconSearch(200, 200)

pred_store = None


def gen_pred_store():
    global pred_store
    file_location = "./predictions/react-icons/ricons.txt"
    if os.path.exists(file_location):
        with open(file_location, 'r') as file:
            pred_store = {}
            for line in file:
                index_of_hyphen_black = line.find("-black")
                if index_of_hyphen_black != -1:
                    blk_len = len("-black")
                    key = line[:index_of_hyphen_black + blk_len]
                    offset = index_of_hyphen_black + blk_len
                    numbers_str = line[offset:].strip()
                    chunks = [
                        numbers_str[i:i + 7]
                        for i in range(0, len(numbers_str), 7)
                    ]
                    pred_store[key] = chunks
    return pred_store


def get_matches(request_in):
    global pred_store
    if 'file' not in request_in.files:
        return jsonify({'error': 'No file part'}), 400
    file = request_in.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if pred_store is None:
        pred_store = gen_pred_store()
    if pred_store is not None:
        predictions = pred_store[file.filename[:file.filename.rindex(".png")]]
        predictions = list(
            map(
                lambda x: "../images/" + x[0:2] + "/" + x + ".png",
                predictions,
            )
        )
        return predictions
    file.name = file.filename
    ris.uploaded_file = file
    if "count" in request_in.form:
        ris.return_number_of_predictions = int(request_in.form["count"])
    return jsonify(ris.process_file())


@app.route("/api/similar-icon-paths", methods=["POST"])
def get_similar_icons_paths():
    matches = get_matches(request)
    ris.return_number_of_predictions = 200
    return matches


@app.route("/api/similar-icon-abs-paths", methods=["POST"])
def get_similar_icons_absolute_paths():
    matches = get_matches(request)
    matches = list(map(lambda x: os.path.abspath(x), matches))
    ris.return_number_of_predictions = 200
    return jsonify(matches)


@app.route("/api/gen-all-predictions", methods=["GET"])
def generate_all_matches():
    root_dir = "../images"
    if os.path.exists(root_dir):
        dirlist = [
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        ]
        for dir_name in dirlist:
            file_batches = {}
            qualified_parent = os.path.join(root_dir, dir_name)
            save_location = "./predictions/" + dir_name + "/"
            all_file_names = [
                f for f in os.listdir(qualified_parent)
                if os.path.isfile(os.path.join(qualified_parent, f))
            ]
            for i in range(0, 10):
                for file_name in all_file_names:
                    if file_name[2] == str(i):
                        if str(i) not in file_batches:
                            file_batches[str(i)] = []
                        file_batches[str(i)].append(file_name)
            for key, file_names in file_batches.items():
                for file_name in sorted(file_names):
                    pred_file = save_location + file_name[2:3] + ".txt"
                    qualified_path = os.path.abspath(
                        os.path.join(qualified_parent, file_name)
                    )
                    ris.process_file_path(
                        qualified_path,
                        os.path.abspath(pred_file),
                    )


if __name__ == "__main__":
    app.run(debug=True, port=5002)
41
config.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Centralized Configuration
|
||||
=========================
|
||||
All configurable values for the Reverse Image Search project.
|
||||
Override via environment variables or by modifying defaults.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
class Config:
    """Application configuration with environment variable overrides."""

    # File paths
    EMBEDDINGS_PATH = os.getenv("EMBEDDINGS_PATH", "embeddings.pkl")
    FILENAMES_PATH = os.getenv("FILENAMES_PATH", "filenames.pkl")
    NEIGHBORS_PATH = os.getenv("NEIGHBORS_PATH", "neighbors.pkl")
    FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "index.faiss")
    UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")

    # Embedding generation paths
    IMAGES_ROOT = os.getenv("IMAGES_ROOT", "../images")
    EMBEDDINGS_DIR = os.getenv("EMBEDDINGS_DIR", "embeddings")
    FILENAMES_DIR = os.getenv("FILENAMES_DIR", "filenames")

    # Model parameters
    TARGET_SIZE = (224, 224)
    INPUT_SHAPE = (224, 224, 3)
    EMBEDDING_DIM = 2048

    # Search parameters
    N_NEIGHBORS = int(os.getenv("N_NEIGHBORS", "200"))
    BATCH_SIZE = int(os.getenv("BATCH_SIZE", "64"))

    # Server settings
    API_PORT = int(os.getenv("API_PORT", "5002"))
    DEBUG = os.getenv("DEBUG", "true").lower() == "true"
|
||||
|
||||
|
||||
# Singleton instance for easy import
|
||||
config = Config()
|
||||
249
docs/GETTING_STARTED.md
Normal file
@@ -0,0 +1,249 @@
|
||||
# Getting Started
|
||||
|
||||
This guide covers everything you need to set up and run the Reverse Image Search project.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Python**: 3.8 - 3.11 (TensorFlow compatibility)
|
||||
- **pip**: Latest version recommended
|
||||
- **RAM**: 8GB minimum (16GB recommended for full dataset)
|
||||
- **Disk Space**: ~20GB (16GB dataset + embeddings)
|
||||
- **GPU**: Optional but recommended for faster feature extraction
|
||||
|
||||
### macOS Specific
|
||||
|
||||
The project uses `tensorflow-macos` and `tensorflow-metal` for Apple Silicon optimization.
|
||||
|
||||
---
|
||||
|
||||
## Installation
|
||||
|
||||
### 1. Clone the Repository
|
||||
|
||||
```bash
|
||||
git clone https://github.com/deepamkalekar/Reverse-Image-Search-ML-DL-Project.git
|
||||
cd Reverse-Image-Search-ML-DL-Project
|
||||
```
|
||||
|
||||
### 2. Create Virtual Environment
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python -m venv venv
|
||||
|
||||
# Activate (macOS/Linux)
|
||||
source venv/bin/activate
|
||||
|
||||
# Activate (Windows)
|
||||
venv\Scripts\activate
|
||||
```
|
||||
|
||||
### 3. Install Dependencies
|
||||
|
||||
```bash
|
||||
pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
> **Note**: On non-Mac systems, replace `tensorflow-macos` and `tensorflow-metal` with `tensorflow` in `requirements.txt`.
|
||||
|
||||
### 4. Download Dataset
|
||||
|
||||
1. Download the [Fashion Product Images Dataset](https://www.kaggle.com/paramaggarwal/fashion-product-images-dataset) from Kaggle (~16GB)
|
||||
2. Create an `images` folder in the project root
|
||||
3. Extract and move all images to the `images` folder
|
||||
|
||||
```bash
|
||||
mkdir images
|
||||
# Move downloaded images here, organized in subdirectories (10/, 11/, 12/, etc.)
|
||||
```
|
||||
|
||||
### 5. Create Required Directories
|
||||
|
||||
```bash
|
||||
mkdir -p uploads embeddings filenames predictions
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Generating Embeddings
|
||||
|
||||
Before running the search application, you need to generate feature embeddings for your image dataset.
|
||||
|
||||
### Generate Embeddings (Recommended)
|
||||
|
||||
Use the unified embedding generation script:
|
||||
|
||||
```bash
|
||||
python run_generate_embeddings.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. Process images from `../images/<dir>/...`
|
||||
2. Create per-directory checkpoints in `embeddings/` and `filenames/`
|
||||
3. Merge them into `embeddings.pkl` and `filenames.pkl` in the project root
|
||||
|
||||
---
|
||||
|
||||
## Running the Application
|
||||
|
||||
### Streamlit Web UI (Recommended)
|
||||
|
||||
```bash
|
||||
streamlit run run_streamlit_ui.py
|
||||
```
|
||||
|
||||
**Access**: Open http://localhost:8501 in your browser
|
||||
|
||||
**Usage**:
|
||||
1. Click "Browse files" to upload an image
|
||||
2. Wait for processing (first run loads the model)
|
||||
3. View similar images displayed in a grid
|
||||
|
||||
### Flask REST API
|
||||
|
||||
```bash
|
||||
python run_api_server.py
|
||||
```
|
||||
|
||||
**Access**: API runs on http://localhost:5002
|
||||
|
||||
**Endpoints**:
|
||||
|
||||
```bash
|
||||
# Get similar image paths (relative)
|
||||
curl -X POST -F "file=@/path/to/image.jpg" \
|
||||
http://localhost:5002/api/similar-icon-paths
|
||||
|
||||
# Get similar image paths (absolute)
|
||||
curl -X POST -F "file=@/path/to/image.jpg" \
|
||||
http://localhost:5002/api/similar-icon-abs-paths
|
||||
|
||||
# With custom count
|
||||
curl -X POST -F "file=@/path/to/image.jpg" -F "count=50" \
|
||||
http://localhost:5002/api/similar-icon-paths
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
### Manual Test Script
|
||||
|
||||
Run a quick test with a sample image:
|
||||
|
||||
```bash
|
||||
# First, place a test image at sample/shirt.jpg
|
||||
mkdir sample
|
||||
# Copy a test image to sample/shirt.jpg
|
||||
|
||||
python test.py
|
||||
```
|
||||
|
||||
This will:
|
||||
1. Load embeddings and model
|
||||
2. Process the sample image
|
||||
3. Display similar images using OpenCV
|
||||
|
||||
### Verifying Installation
|
||||
|
||||
```bash
|
||||
# Quick verification script
|
||||
python -c "
|
||||
import tensorflow as tf
|
||||
from keras.applications.resnet50 import ResNet50
|
||||
print(f'TensorFlow version: {tf.__version__}')
|
||||
model = ResNet50(weights='imagenet', include_top=False)
|
||||
print('ResNet50 loaded successfully!')
|
||||
print(f'GPU available: {len(tf.config.list_physical_devices(\"GPU\")) > 0}')
|
||||
"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Adjusting Search Parameters
|
||||
|
||||
In `reverse_icon_search.py`, modify:
|
||||
|
||||
```python
|
||||
# Number of neighbors to calculate
|
||||
self.calculate_number_of_predictions = 200
|
||||
|
||||
# Number of results to return
|
||||
self.return_number_of_predictions = 200
|
||||
```
|
||||
|
||||
### Using Different Embedding Files
|
||||
|
||||
Switch between embedding versions by modifying the file paths:
|
||||
|
||||
```python
|
||||
# For 20% dataset
|
||||
self.filenames_location = "filenames_20pct.pkl"
|
||||
self.embeddings_location = "embeddings_20pct.pkl"
|
||||
|
||||
# For full dataset
|
||||
self.filenames_location = "filenames.pkl"
|
||||
self.embeddings_location = "embeddings.pkl"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| `ModuleNotFoundError: No module named 'tensorflow'` | Install TensorFlow: `pip install tensorflow` |
|
||||
| `FileNotFoundError: embeddings.pkl` | Generate embeddings first (see above) |
|
||||
| `CUDA out of memory` | Reduce batch size or use CPU |
|
||||
| `Image codec error` | Ensure images are valid PNG/JPG files |
|
||||
| Streamlit port conflict | Use `streamlit run run_streamlit_ui.py --server.port 8502` |
|
||||
|
||||
### Memory Issues
|
||||
|
||||
If you encounter memory errors:
|
||||
1. Use a smaller subset of embeddings (`embeddings_20pct.pkl`)
|
||||
2. Reduce the number of neighbors (`N_NEIGHBORS` in `config.py`)
|
||||
3. Process images in smaller batches
|
||||
|
||||
### macOS ARM64 Issues
|
||||
|
||||
```bash
|
||||
# If TensorFlow fails on Apple Silicon
|
||||
pip uninstall tensorflow tensorflow-macos tensorflow-metal
|
||||
pip install tensorflow-macos tensorflow-metal
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
Reverse-Image-Search-ML-DL-Project/
|
||||
├── app.py # Flask REST API server
|
||||
├── main_streamlit.py # Streamlit web application
|
||||
├── reverse_icon_search.py # Core search logic (reusable)
|
||||
├── feature_gen.py # Advanced feature generation
|
||||
├── feature_gen_basic.py # Simple feature generation
|
||||
├── feature_gen_bulk.py # Parallel feature generation
|
||||
├── test.py # Manual test script
|
||||
├── requirements.txt # Python dependencies
|
||||
├── images/ # Dataset images (create this)
|
||||
├── uploads/ # User uploaded images
|
||||
├── embeddings/ # Generated embeddings
|
||||
├── filenames/ # Filename mappings
|
||||
└── predictions/ # Pre-computed predictions
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Add your own dataset**: Replace images in `images/` folder and regenerate embeddings
|
||||
2. **Fine-tune the model**: Modify ResNet50 or try other architectures (VGG16, EfficientNet)
|
||||
3. **Optimize search**: Consider FAISS for faster similarity search at scale
|
||||
4. **Deploy**: Containerize with Docker for production deployment
|
||||
209
docs/SPECS.md
Normal file
@@ -0,0 +1,209 @@
|
||||
# Reverse Image Search - Technical Specifications
|
||||
|
||||
## Overview
|
||||
|
||||
This project implements a **Reverse Image Search** system using Deep Learning and Machine Learning techniques. Users upload an image and receive visually similar images from a pre-indexed dataset, much like Google's or Amazon's image search.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
```
|
||||
┌─────────────────┐     ┌──────────────────┐     ┌─────────────────┐
│   User Input    │────▶│     Feature      │────▶│   KNN Search    │
│    (Image)      │     │    Extraction    │     │  (Similarity)   │
└─────────────────┘     │    (ResNet50)    │     └────────┬────────┘
                        └──────────────────┘              │
                                                          ▼
                        ┌──────────────────┐     ┌─────────────────┐
                        │   Pre-computed   │────▶│ Similar Images  │
                        │    Embeddings    │     │   (Results)     │
                        └──────────────────┘     └─────────────────┘
|
||||
```
|
||||
|
||||
### Technology Stack
|
||||
|
||||
| Component | Technology | Purpose |
|
||||
|-----------|------------|---------|
|
||||
| **Deep Learning Model** | ResNet50 (ImageNet weights) | Feature extraction from images |
|
||||
| **ML Algorithm** | K-Nearest Neighbors (KNN) | Finding similar images based on feature vectors |
|
||||
| **Web UI** | Streamlit | Interactive user interface |
|
||||
| **REST API** | Flask + Flask-CORS | API endpoints for integration |
|
||||
| **Data Serialization** | Joblib / Pickle | Storing embeddings and filenames |
|
||||
| **Image Processing** | Pillow, OpenCV | Image loading and preprocessing |
|
||||
| **Numerical Computing** | NumPy, TensorFlow/Keras | Array operations and model inference |
|
||||
|
||||
## Feature Extraction Pipeline
|
||||
|
||||
### ResNet50 Model Configuration
|
||||
|
||||
- **Input Shape**: 224 × 224 × 3 (RGB images)
|
||||
- **Weights**: Pre-trained on ImageNet
|
||||
- **Top Layer**: Removed (`include_top=False`)
|
||||
- **Pooling**: GlobalMaxPooling2D
|
||||
- **Training**: Frozen (`trainable=False`)
|
||||
|
||||
### Feature Vector Generation
|
||||
|
||||
1. Load image and resize to 224×224
|
||||
2. Convert to array and expand dimensions
|
||||
3. Apply ResNet50 preprocessing
|
||||
4. Extract features via model prediction
|
||||
5. Normalize using L2 norm
|
||||
|
||||
**Output**: 2048-dimensional feature vector per image
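
Step 5 can be illustrated without loading the model. The following is a minimal sketch, assuming the feature vector is a plain sequence of floats; `l2_normalize` is a hypothetical helper name, not a function from this repository:

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit Euclidean length (hypothetical helper)."""
    length = math.sqrt(sum(x * x for x in vector))
    if length == 0.0:
        # An all-zero vector has no direction; return it unchanged.
        return list(vector)
    return [x / length for x in vector]


unit = l2_normalize([3.0, 4.0])
print(unit)  # [0.6, 0.8]
```

After this step every embedding has length 1, so distances between vectors depend only on their direction.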
|
||||
|
||||
## Similarity Search
|
||||
|
||||
### FAISS IVFPQ Index (Primary)
|
||||
|
||||
- **Index Type**: IndexIVFPQ (inverted file + product quantization)
|
||||
- **Compression**: ~3-6 GB for 6M vectors (vs ~50 GB uncompressed)
|
||||
- **Build**: Train on 100K samples, then stream-add all vectors
|
||||
- **Query Time**: ~10ms (vs ~400ms brute force)
|
||||
- **Recall**: ~90-95% (configurable via nprobe)
|
||||
- **File**: `index.faiss`
|
||||
|
||||
**Configuration** (in `generate_embeddings.py`):
|
||||
```
|
||||
nlist=4096 # Number of clusters
|
||||
m=64 # Subquantizers (must divide 2048)
|
||||
nbits=8 # Bits per subquantizer
|
||||
nprobe=64 # Clusters to search (recall/speed tradeoff)
|
||||
```
|
||||
|
||||
### K-Nearest Neighbors (Fallback)
|
||||
|
||||
- **Algorithm**: Brute-force / Auto
|
||||
- **Distance Metric**: Euclidean
|
||||
- **Default Neighbors**: 200 (configurable)
|
||||
- **Used when**: FAISS not installed or `index.faiss` missing
|
||||
|
||||
### Search Process
|
||||
|
||||
1. Extract features from query image (ResNet50)
|
||||
2. Query FAISS index (or fallback to KNN)
|
||||
3. Map vector IDs to filenames
|
||||
4. Return top-K similar images
|
||||
|
||||
## Interfaces
|
||||
|
||||
### Streamlit Web Application (`main_streamlit.py`)
|
||||
|
||||
- File upload widget for query images
|
||||
- Grid display of similar images (10 columns)
|
||||
- Image labels showing filename
|
||||
|
||||
### Flask REST API (`api/`)
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/api/health` | GET | Health check for load balancers |
|
||||
| `/api/status` | GET | Service status (index info, model state) |
|
||||
| `/api/search` | POST | Search for similar images (primary) |
|
||||
| `/api/batch-search` | POST | Search multiple images at once |
|
||||
| `/api/similar-icon-paths` | POST | Legacy: relative paths |
|
||||
| `/api/similar-icon-abs-paths` | POST | Legacy: absolute paths |
|
||||
|
||||
**API Parameters**:
|
||||
- `file`: Image file (multipart/form-data)
|
||||
- `count`: Number of results (1-200 for search, 1-50 for batch)
|
||||
|
||||
**Response Format**:
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"results": [...],
|
||||
"count": 10,
|
||||
"query_time_ms": 45.2
|
||||
}
|
||||
}
|
||||
```
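
A client could unwrap this envelope roughly as follows. This is a sketch only: the spec does not say what each entry in `results` looks like, so the sample body uses placeholder file names, and `extract_results` is a hypothetical helper.

```python
import json


def extract_results(body):
    """Return the result list from a response body, or raise on failure."""
    payload = json.loads(body)
    if not payload.get("success"):
        raise RuntimeError("search request was not successful")
    data = payload.get("data", {})
    return data.get("results", [])


# Hypothetical response body; the result entries are placeholders.
sample = (
    '{"success": true, "data": {"results": ["a.jpg", "b.jpg"], '
    '"count": 2, "query_time_ms": 45.2}}'
)
print(extract_results(sample))
```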
|
||||
|
||||
## Data Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `index.faiss` | FAISS ANN index (primary search backend) |
|
||||
| `filenames.pkl` | Image paths mapped by vector ID |
|
||||
| `embeddings/*.pkl` | Per-directory checkpoint embeddings |
|
||||
| `filenames/*.pkl` | Per-directory checkpoint filenames |
|
||||
| `neighbors.pkl` | Fitted KNN model (fallback, cached) |
|
||||
|
||||
## Entrypoint Scripts
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `run_generate_embeddings.py` | Generate embeddings + FAISS index |
|
||||
| `run_streamlit_ui.py` | Launch Streamlit web UI |
|
||||
| `run_api_server.py` | Launch Flask REST API |
|
||||
|
||||
### Legacy (in `legacy/` folder)
|
||||
|
||||
| Script | Purpose |
|
||||
|--------|---------|
|
||||
| `feature_gen_basic.py` | Simple sequential extraction |
|
||||
| `feature_gen_bulk.py` | Parallel processing with checkpointing |
|
||||
| `feature_gen.py` | Merging and validation utilities |
|
||||
|
||||
## Dataset
|
||||
|
||||
- **Source**: Kaggle Fashion Product Images Dataset
|
||||
- **Size**: ~16 GB
|
||||
- **Categories**: Watches, shoes, clothes, etc.
|
||||
- **Format**: PNG/JPG images organized in subdirectories
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **FAISS Index**: Sublinear search; recall is roughly 90-95% with the IVFPQ settings above (tunable via `nprobe`)
|
||||
- **Streaming Build**: Index built without loading all embeddings in RAM
|
||||
- **Lazy Loading**: FAISS index / KNN model loaded on first request
|
||||
- **Batch Processing**: 64 images per GPU batch during generation
|
||||
- **Checkpointing**: Per-directory checkpoints, resume-capable
|
||||
|
||||
---
|
||||
|
||||
## Future Capabilities (Roadmap)
|
||||
|
||||
### Near-Term Enhancements
|
||||
|
||||
| Feature | Description | Complexity |
|
||||
|---------|-------------|------------|
|
||||
| **Multi-model support** | Swap ResNet50 for CLIP, ViT, EfficientNet | Medium |
|
||||
| **Fine-tuning** | Train on domain-specific dataset (fashion, furniture) | Medium |
|
||||
| **Metadata filtering** | Filter by category, color, price range | Low |
|
||||
| **Batch upload** | Search multiple images at once | Low |
|
||||
| **Result pagination** | Paginate large result sets in UI/API | Low |
|
||||
|
||||
### Medium-Term Features
|
||||
|
||||
| Feature | Description | Complexity |
|
||||
|---------|-------------|------------|
|
||||
| **Real-time indexing** | Add new images without full rebuild | Medium |
|
||||
| **Distributed index** | Shard FAISS across multiple nodes | High |
|
||||
| **GPU inference** | Use GPU for query-time feature extraction | Low |
|
||||
| **Image cropping** | Auto-detect and crop product from background | Medium |
|
||||
| **Text-to-image search** | Use CLIP for text query → image results | Medium |
|
||||
|
||||
### Long-Term Vision
|
||||
|
||||
| Feature | Description | Complexity |
|
||||
|---------|-------------|------------|
|
||||
| **Hybrid search** | Combine visual + text + metadata | High |
|
||||
| **User feedback loop** | Learn from clicks to improve ranking | High |
|
||||
| **Multi-tenant** | Isolated indexes per customer/dataset | High |
|
||||
| **Auto-tagging** | Generate labels/tags from embeddings | Medium |
|
||||
| **Duplicate detection** | Find near-duplicate images in dataset | Low |
|
||||
|
||||
---
|
||||
|
||||
## Potential Use Cases
|
||||
|
||||
1. **E-commerce**: "Find similar products" feature
|
||||
2. **Fashion**: Outfit matching, style recommendations
|
||||
3. **Stock photography**: Find visually similar images
|
||||
4. **Brand monitoring**: Detect unauthorized use of images
|
||||
5. **Interior design**: Match furniture/decor styles
|
||||
6. **Art & collectibles**: Find similar artworks
|
||||
7. **Quality control**: Detect defects by comparing to reference images
|
||||
240
docs/backlog/DEBT.md
Normal file
@@ -0,0 +1,240 @@
|
||||
# Technical Debt & Suggested Improvements
|
||||
|
||||
## High Priority
|
||||
|
||||
### 1. Error Handling & Logging
|
||||
|
||||
**Current State**: Some code paths historically used bare `except` clauses with minimal error information.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Replace bare `except` with specific exception types
|
||||
- Add structured logging using Python's `logging` module
|
||||
- Implement proper error responses in Flask API
|
||||
|
||||
```python
|
||||
# Instead of:
except:
    return 0

# Use:
except IOError as e:
    logger.error(f"Failed to save file: {e}")
    return 0
|
||||
```
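
A fuller version of that pattern might look like the sketch below. The logger name and the `save_upload` helper are hypothetical, and returning a boolean rather than `0` is a choice made for this example only.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("reverse_image_search")  # hypothetical logger name


def save_upload(data, path):
    """Write uploaded bytes to disk; return True on success, False on failure."""
    try:
        with open(path, "wb") as handle:
            handle.write(data)
    except OSError as exc:
        # OSError also covers IOError, PermissionError and FileNotFoundError.
        logger.error("Failed to save %s: %s", path, exc)
        return False
    return True
```

Passing the values as logging arguments instead of formatting the string up front lets the logging module skip the formatting when the level is disabled.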
|
||||
|
||||
### 2. Configuration Management ✅
|
||||
|
||||
**Current State**: Centralized in `config.py` with environment variable overrides.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Create a `config.py` or use environment variables
|
||||
- Centralize all configurable values:
  - File paths (`filenames_location`, `embeddings_location`, `index_location`)
  - Model parameters (`target_size`, `n_neighbors`)
  - Server settings (`port`, `debug`)
|
||||
|
||||
```python
|
||||
# config.py
import os


class Config:
    EMBEDDINGS_PATH = os.getenv("EMBEDDINGS_PATH", "embeddings.pkl")
    FILENAMES_PATH = os.getenv("FILENAMES_PATH", "filenames.pkl")
    FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH", "index.faiss")
    TARGET_SIZE = (224, 224)
    N_NEIGHBORS = 200
|
||||
```
|
||||
|
||||
### 3. Code Duplication ✅
|
||||
|
||||
**Current State**: `main_streamlit.py` now imports and extends `ReverseIconSearch` from `reverse_icon_search.py`.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Keep single source of truth in `reverse_icon_search.py`
|
||||
- Import and extend for Streamlit-specific features
|
||||
- Apply DRY principle across feature generation scripts (legacy scripts can be moved under `legacy/`)
|
||||
|
||||
---
|
||||
|
||||
## Medium Priority
|
||||
|
||||
### 4. API Response Consistency
|
||||
|
||||
**Current State**: Mixed return types (sometimes `jsonify()`, sometimes raw list).
|
||||
|
||||
**Suggested Changes**:
|
||||
- Standardize all API responses to JSON format
|
||||
- Include status codes and error messages
|
||||
- Add response schemas
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"data": [...],
|
||||
"count": 200
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Input Validation
|
||||
|
||||
**Current State**: Minimal validation of uploaded files.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Validate file types (accept only images)
|
||||
- Check file size limits
|
||||
- Sanitize filenames
|
||||
- Validate image dimensions
|
||||
|
||||
```python
|
||||
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp'}

def allowed_file(filename):
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
```
|
||||
|
||||
### 6. Testing Infrastructure
|
||||
|
||||
**Current State**: `test.py` is a manual test script, not automated tests.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Add `pytest` to requirements
|
||||
- Create unit tests for feature extraction
|
||||
- Create integration tests for API endpoints
|
||||
- Add test fixtures with sample images
|
||||
|
||||
```
|
||||
tests/
├── __init__.py
├── conftest.py
├── test_feature_extraction.py
├── test_api.py
└── fixtures/
    └── sample_images/
|
||||
```
|
||||
|
||||
### 7. Documentation Strings
|
||||
|
||||
**Current State**: No docstrings in functions/classes.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Add docstrings to all public functions and classes
|
||||
- Include parameter types and return values
|
||||
- Consider using type hints (Python 3.5+)
|
||||
|
||||
```python
|
||||
def feature_extraction(self, img_path: str, model) -> np.ndarray:
    """
    Extract feature vector from an image.

    Args:
        img_path: Path to the image file
        model: Keras model for feature extraction

    Returns:
        Normalized 2048-dimensional feature vector
    """
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Low Priority
|
||||
|
||||
### 8. Dependency Management
|
||||
|
||||
**Current State**: Some redundant/conflicting packages in `requirements.txt`.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Remove duplicate `sklearn` (use only `scikit-learn`)
|
||||
- Pin all versions for reproducibility
|
||||
- Consider separating dev dependencies
|
||||
- Add `requirements-dev.txt` for development tools
|
||||
|
||||
### 9. Model Versioning
|
||||
|
||||
**Current State**: No tracking of model/embedding versions.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Add version metadata to pickle files
|
||||
- Implement compatibility checks
|
||||
- Consider using MLflow or DVC for experiment tracking
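
One lightweight way to attach version metadata is to wrap the payload in a small record before pickling. The sketch below uses the standard `pickle` module for brevity; the field names, `FORMAT_VERSION` and both helper functions are hypothetical and do not exist in this repository.

```python
import pickle

FORMAT_VERSION = 1  # hypothetical version number for the on-disk layout


def dump_with_metadata(obj, path, embedding_dim):
    """Pickle `obj` together with a small metadata header."""
    record = {
        "format_version": FORMAT_VERSION,
        "embedding_dim": embedding_dim,
        "payload": obj,
    }
    with open(path, "wb") as handle:
        pickle.dump(record, handle)


def load_with_check(path, expected_dim):
    """Load a record and refuse to continue if the metadata does not match."""
    with open(path, "rb") as handle:
        record = pickle.load(handle)  # only load files you trust
    if record.get("format_version") != FORMAT_VERSION:
        raise ValueError("unsupported format version")
    if record.get("embedding_dim") != expected_dim:
        raise ValueError("embedding dimension mismatch")
    return record["payload"]
```

A check like this fails fast when, for example, embeddings produced by a different model are loaded next to an index built for 2048-dimensional vectors.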
|
||||
|
||||
### 10. Memory Optimization ✅
|
||||
|
||||
**Current State**: FAISS index built via streaming (one checkpoint at a time), inference uses disk-backed index.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Consider using memory-mapped files for large datasets
|
||||
- Implement lazy loading for embeddings
|
||||
- Use FAISS for approximate nearest neighbor search at scale (persist `index.faiss` and query it at inference time)
|
||||
- Prefer FAISS-first inference paths and avoid loading `embeddings.pkl` during search where possible
|
||||
|
||||
### 11. Async Processing
|
||||
|
||||
**Current State**: Synchronous file processing in Flask.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Use async/await for I/O operations
|
||||
- Consider Celery for background tasks
|
||||
- Add request queuing for batch predictions
|
||||
|
||||
### 12. Security Improvements
|
||||
|
||||
**Current State**: Debug mode enabled, CORS allows all origins.
|
||||
|
||||
**Suggested Changes**:
|
||||
- Disable debug mode in production
|
||||
- Configure CORS with specific allowed origins
|
||||
- Add rate limiting
|
||||
- Implement file upload security (virus scanning, size limits)
|
||||
|
||||
---
|
||||
|
||||
## Refactoring Suggestions
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
Reverse-Image-Search-ML-DL-Project/
├── src/
│   ├── __init__.py
│   ├── config.py
│   ├── models/
│   │   └── feature_extractor.py
│   ├── search/
│   │   └── reverse_search.py
│   ├── api/
│   │   └── routes.py
│   └── ui/
│       └── streamlit_app.py
├── tests/
├── data/
│   ├── embeddings/
│   └── filenames/
├── docs/
├── scripts/
│   └── generate_features.py
└── requirements.txt
|
||||
```
|
||||
|
||||
### Dependency Injection
|
||||
|
||||
Consider using dependency injection for the model and neighbors:
|
||||
|
||||
```python
|
||||
class ReverseIconSearch:
    def __init__(self, model, neighbors_store, config):
        self.model = model
        self.neighbors_store = neighbors_store
        self.config = config
|
||||
```
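
The main payoff of this constructor is that tests can pass in lightweight stand-ins. In the sketch below, the `search` method, the `embed` and `query` method names, and the dictionary-style config are all assumptions made for illustration; the real class may look different.

```python
class ReverseIconSearch:
    def __init__(self, model, neighbors_store, config):
        self.model = model
        self.neighbors_store = neighbors_store
        self.config = config

    def search(self, image):
        # `embed` and `query` are hypothetical method names for this sketch.
        vector = self.model.embed(image)
        return self.neighbors_store.query(vector, self.config["n_neighbors"])


class FakeModel:
    """Stand-in for the real feature extractor; returns a fixed vector."""

    def embed(self, image):
        return [1.0, 0.0]


class FakeStore:
    """Stand-in for the FAISS/KNN backend; returns canned file names."""

    def query(self, vector, k):
        return ["a.jpg", "b.jpg", "c.jpg"][:k]


searcher = ReverseIconSearch(FakeModel(), FakeStore(), {"n_neighbors": 2})
print(searcher.search("query.jpg"))  # ['a.jpg', 'b.jpg']
```

With the dependencies injected, the search path can be unit-tested without TensorFlow or an index file on disk.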
|
||||
|
||||
---
|
||||
|
||||
## Quick Wins
|
||||
|
||||
1. ✅ Add `.gitignore` entries for `*.pkl`, `uploads/`, `__pycache__/`
|
||||
2. ✅ Fix typo in README: `pip install requirements.txt` → `pip install -r requirements.txt`
|
||||
3. ✅ Fix README: Legacy section updated to use new entrypoints
|
||||
4. ⬜ Add type hints to function signatures
|
||||
5. ✅ Remove commented-out code blocks
|
||||
6. ⬜ Add `if __name__ == "__main__"` guards where missing
|
||||
7. ✅ Add workflow entrypoints (`run_generate_embeddings.py`, `run_streamlit_ui.py`, `run_api_server.py`)
|
||||
187
docs/reverse_image_search_scalability_spec.md
Normal file
@@ -0,0 +1,187 @@
|
||||
# Reverse Image Search — Scalability & Architecture Optimization Spec
|
||||
|
||||
## 🧠 Context for Copilot — Reverse Image Search Scalability Fix
|
||||
|
||||
### Background
|
||||
|
||||
The current reverse image search system uses:
|
||||
|
||||
- **ResNet50** (`include_top=False`) with **global pooling** to generate fixed-size image embeddings
|
||||
- Embeddings are persisted using `joblib` and loaded into RAM as a **dense NumPy array**
|
||||
- Similarity search is performed using **brute-force cosine similarity** (`np.dot + argsort`)
|
||||
|
||||
This approach works functionally but causes **very high RAM usage (~40+ GB)** because **all embeddings are loaded into memory at once**.
|
||||
|
||||
---
|
||||
|
||||
## 🔍 Root Cause Analysis
|
||||
|
||||
- Embeddings are stored as a **dense in-memory matrix**
|
||||
- Similarity search is **O(N × D)** and requires all vectors to be resident in RAM
|
||||
- Memory usage scales linearly with dataset size
|
||||
- This approach does **not scale beyond a few million images**
|
||||
|
||||
Observed behavior (e.g. ~42 GB RAM usage) is fully explained by:
|
||||
```
|
||||
num_images × embedding_dim × sizeof(float32)
|
||||
```
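
The formula can be checked with a few lines of arithmetic. This sketch assumes 2048-dimensional float32 embeddings, as used elsewhere in this project; the image count of 5 million is illustrative, not a measured figure.

```python
def dense_matrix_bytes(num_images, embedding_dim=2048, bytes_per_value=4):
    """RAM needed to hold every embedding as one dense float32 matrix."""
    return num_images * embedding_dim * bytes_per_value


def max_images_for_budget(budget_bytes, embedding_dim=2048, bytes_per_value=4):
    """Largest image count whose dense matrix fits in `budget_bytes`."""
    return budget_bytes // (embedding_dim * bytes_per_value)


# 5 million images is an illustrative figure, not a measurement.
print(dense_matrix_bytes(5_000_000) / 1e9)   # 40.96 (decimal GB)
print(max_images_for_budget(42 * 10**9))     # 5126953
```

Each image costs 8 KiB in this layout, so roughly five million images are enough to account for about 42 GB of RAM.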
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Target Architecture (What Must Change)
|
||||
|
||||
Refactor the system to use a **disk-backed Approximate Nearest Neighbor (ANN) index**, similar in architecture to how **RAG systems** work for LLMs.
|
||||
|
||||
### Key Principle
|
||||
|
||||
> **Embeddings must live in a vector index (disk-backed or partially RAM-resident),
|
||||
> not as a full NumPy matrix loaded into memory.**
|
||||
|
||||
The CNN / model remains unchanged.
|
||||
Only **embedding storage and retrieval** changes.
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Target Architecture Overview
|
||||
|
||||
```
|
||||
Image
|
||||
↓
|
||||
CNN Encoder (ResNet50, frozen)
|
||||
↓
|
||||
Embedding (fixed-dim vector)
|
||||
↓
|
||||
ANN Index (FAISS / HNSW / IVF+PQ) ← disk-backed
|
||||
↓
|
||||
Top-K nearest vectors
|
||||
↓
|
||||
Image IDs / paths / metadata
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Required Changes (Implementation Tasks)
|
||||
|
||||
### 1. Replace In-Memory Embedding Storage with a Vector Index
|
||||
|
||||
- Use **FAISS** (preferred) or an equivalent ANN library
|
||||
- Do **NOT** load all embeddings into RAM as a NumPy array
|
||||
- Eliminate `joblib.load(embeddings.pkl)` from inference paths
|
||||
|
||||
---
|
||||
|
||||
### 2. Index Creation (One-Time / Offline Step)
|
||||
|
||||
During feature generation:
|
||||
|
||||
- Generate embeddings exactly as before using ResNet50
|
||||
- **Normalize embeddings (L2)** at build time if using cosine similarity
|
||||
- Build a FAISS index instead of persisting `embeddings.pkl`
|
||||
|
||||
#### Recommended Index Types
|
||||
|
||||
- **FAISS IndexHNSWFlat**
  - Fast
  - High recall
  - Good default
- **FAISS IndexIVFPQ**
  - Much lower RAM usage
  - Slight recall loss
  - Best for very large datasets (millions+ images)
|
||||
|
||||
#### Persist on Disk
|
||||
|
||||
Store:
|
||||
- FAISS index file (disk-backed)
- Separate metadata file mapping:
  - vector ID → image path / image ID
|
||||
|
||||
---
|
||||
|
||||
### 3. Inference-Time Search (Critical Change)
|
||||
|
||||
Replace this pattern ❌:
|
||||
```python
|
||||
embeddings = joblib.load("embeddings.pkl")
|
||||
scores = np.dot(embeddings, query_vector)
|
||||
top_k = np.argsort(scores)[-k:]
|
||||
```
|
||||
|
||||
With this pattern ✅:
|
||||
```python
|
||||
index = faiss.read_index("index.faiss")
|
||||
query_vec = embed(query_image)
|
||||
distances, ids = index.search(query_vec, k)
|
||||
```
|
||||
|
||||
Then:
|
||||
- Map `ids` → image paths using metadata
|
||||
- Return top-K matches
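
The mapping step can be sketched in plain Python. Two details matter here: `index.search` returns one row of ids per query vector, so a single query uses the first row, and FAISS fills unused slots with `-1` when it finds fewer than `k` neighbours. `ids_to_paths` and the sample paths are hypothetical.

```python
def ids_to_paths(ids, filenames):
    """Map FAISS result ids to file paths, skipping empty slots."""
    paths = []
    for vector_id in ids:
        # FAISS pads the result with -1 when fewer than k neighbours are found.
        if 0 <= vector_id < len(filenames):
            paths.append(filenames[vector_id])
    return paths


# Placeholder metadata: position in the list equals the vector id.
filenames = ["images/10/a.jpg", "images/10/b.jpg", "images/11/c.jpg"]
print(ids_to_paths([2, 0, -1], filenames))
# ['images/11/c.jpg', 'images/10/a.jpg']
```

The bounds check also guards against a stale metadata file that has fewer entries than the index.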
|
||||
|
||||
#### Explicitly DO NOT:
|
||||
|
||||
- Load all embeddings into a NumPy array
|
||||
- Perform manual cosine similarity across the full dataset
|
||||
- Use exact nearest-neighbor search for large datasets
|
||||
|
||||
---
|
||||
|
||||
### 4. Similarity Metric
|
||||
|
||||
- Use **cosine similarity** or **inner product**
|
||||
- Normalize embeddings at index build time if required
|
||||
- Metric choice must be consistent between indexing and querying
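
Why normalization makes these choices interchangeable can be verified numerically. For unit-length vectors the inner product equals the cosine similarity, and the squared Euclidean distance equals `2 - 2 * cosine`, so all three produce the same ranking. The vectors below are arbitrary examples.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    length = math.sqrt(dot(v, v))
    return [x / length for x in v]


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


a = [1.0, 2.0, 2.0]
b = [2.0, 0.0, 1.0]
ua, ub = normalize(a), normalize(b)

inner = dot(ua, ub)
squared_l2 = sum((x - y) ** 2 for x, y in zip(ua, ub))

print(math.isclose(inner, cosine(a, b)))            # True
print(math.isclose(squared_l2, 2.0 - 2.0 * inner))  # True
```

If vectors are normalized when the index is built, query vectors have to be normalized the same way, otherwise the equivalence no longer holds.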
|
||||
|
||||
---
|
||||
|
||||
### 5. Memory & Performance Expectations After Refactor
|
||||
|
||||
After implementing ANN indexing:
|
||||
|
||||
- RAM usage scales with **index metadata**, not dataset size
|
||||
- Dataset can scale to **hundreds of GB or TB**
|
||||
- Query latency becomes **sublinear**
|
||||
- Recall remains high (typically ~98–99%)
|
||||
|
||||
This is a deliberate and acceptable tradeoff.
|
||||
|
||||
---
|
||||
|
||||
## 🔁 Conceptual Mapping (Reverse Image Search ↔ RAG)
|
||||
|
||||
| Reverse Image Search | LLM RAG |
|
||||
|--------------------|---------|
|
||||
| CNN encoder | Text embedding model |
|
||||
| Image embeddings | Text chunk embeddings |
|
||||
| FAISS index | Vector database |
|
||||
| Query image | User query |
|
||||
| Top-K images | Top-K document chunks |
|
||||
| External data | External knowledge |
|
||||
|
||||
---
|
||||
|
||||
## 🚫 Explicit Anti-Patterns to Avoid
|
||||
|
||||
- ❌ `joblib.load()` of the full embedding matrix at inference time
|
||||
- ❌ `np.dot(all_embeddings, query_embedding)`
|
||||
- ❌ Storing embeddings as a single dense array in RAM
|
||||
- ❌ Exact nearest-neighbor search for large datasets
|
||||
- ❌ Adding dataset-sized layers to the CNN
|
||||
|
||||
---
|
||||
|
||||
## ✅ Acceptance Criteria
|
||||
|
||||
- No code path loads all embeddings into memory
|
||||
- FAISS (or equivalent) ANN index is used for similarity search
|
||||
- System supports **millions of images** without RAM blow-up
|
||||
- Search results remain semantically correct (approximate NN is acceptable)
|
||||
|
||||
---
|
||||
|
||||
## 📌 One-Line Summary
|
||||
|
||||
> **The reverse image search system must follow the same architectural pattern as RAG:
|
||||
> fixed encoder + disk-backed vector index + retrieval at inference time — never load all embeddings into RAM.**
|
||||
338
generate_embeddings.py
Normal file
@@ -0,0 +1,338 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate Embeddings Pipeline
|
||||
=============================
|
||||
Processes images from ../images (with subdirectories) and generates:
- embeddings/ and filenames/: per-directory checkpoint files
- index.faiss: FAISS IVFPQ index built from the checkpoints
- filenames.pkl: File paths, in the same order as the index vectors

Usage:
    python generate_embeddings.py
    python generate_embeddings.py --merge-only
|
||||
|
||||
This script combines bulk parallel processing with checkpoint merging.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import joblib
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
from tqdm import tqdm
|
||||
# ProcessPoolExecutor available for future parallel enhancement
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except Exception:
|
||||
faiss = None
|
||||
|
||||
|
||||
# Configuration
|
||||
IMAGES_ROOT = "../images"
|
||||
EMBEDDINGS_DIR = "embeddings"
|
||||
FILENAMES_DIR = "filenames"
|
||||
OUTPUT_EMBEDDINGS = "embeddings.pkl"
|
||||
OUTPUT_FILENAMES = "filenames.pkl"
|
||||
OUTPUT_FAISS_INDEX = "index.faiss"
|
||||
BATCH_SIZE = 64 # Process 64 images at once for GPU efficiency
|
||||
TARGET_SIZE = (224, 224)
|
||||
|
||||
# FAISS IVFPQ Configuration (memory-efficient for millions of vectors)
|
||||
FAISS_NLIST = 4096 # Number of clusters (sqrt(n) is a good rule)
|
||||
FAISS_M = 64 # Number of subquantizers (must divide dim evenly)
|
||||
FAISS_NBITS = 8 # Bits per subquantizer (8 = 256 centroids each)
|
||||
FAISS_TRAIN_SIZE = 100000 # Vectors to sample for training
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Initialize ResNet50 model for feature extraction."""
|
||||
import tensorflow
|
||||
from keras.layers import GlobalMaxPooling2D
|
||||
from keras.applications.resnet50 import ResNet50
|
||||
|
||||
model = ResNet50(
|
||||
weights='imagenet',
|
||||
include_top=False,
|
||||
input_shape=(224, 224, 3),
|
||||
)
|
||||
model.trainable = False
|
||||
model = tensorflow.keras.Sequential([
|
||||
model,
|
||||
GlobalMaxPooling2D()
|
||||
])
|
||||
return model
|
||||
|
||||
|
||||
def load_and_preprocess_image(img_path):
|
||||
"""Load and preprocess a single image."""
|
||||
from keras.preprocessing import image
|
||||
from keras.applications.resnet50 import preprocess_input
|
||||
|
||||
img = image.load_img(img_path, target_size=TARGET_SIZE)
|
||||
img_array = image.img_to_array(img)
|
||||
return preprocess_input(img_array)
|
||||
|
||||
|
||||
def extract_features_batch(img_paths, model):
|
||||
"""Extract normalized feature vectors from a batch of images."""
|
||||
batch_images = []
|
||||
valid_paths = []
|
||||
|
||||
for path in img_paths:
|
||||
try:
|
||||
img_array = load_and_preprocess_image(path)
|
||||
batch_images.append(img_array)
|
||||
valid_paths.append(path)
|
||||
except Exception as e:
|
||||
print(f"Error loading {path}: {e}")
|
||||
|
||||
if not batch_images:
|
||||
return [], []
|
||||
|
||||
# Stack into batch and predict all at once
|
||||
batch_array = np.array(batch_images)
|
||||
results = model.predict(batch_array, verbose=0)
|
||||
|
||||
# Normalize each feature vector
|
||||
normalized = [r / norm(r) for r in results]
|
||||
|
||||
return normalized, valid_paths
|
||||
|
||||
|
||||
def process_directory(directory, model):
|
||||
"""Process all images in a directory using batch processing."""
|
||||
dir_name = os.path.basename(directory)
|
||||
checkpoint_file = os.path.join(
|
||||
EMBEDDINGS_DIR,
|
||||
f"{dir_name}_embeddings.pkl",
|
||||
)
|
||||
filenames_file = os.path.join(
|
||||
FILENAMES_DIR,
|
||||
f"{dir_name}_filenames.pkl",
|
||||
)
|
||||
|
||||
if os.path.exists(checkpoint_file) and os.path.exists(filenames_file):
|
||||
print(f"Checkpoint exists for {dir_name}, skipping...")
|
||||
return
|
||||
|
||||
files = [
|
||||
os.path.join(directory, f)
|
||||
for f in os.listdir(directory)
|
||||
if os.path.isfile(os.path.join(directory, f))
|
||||
]
|
||||
|
||||
all_features = []
|
||||
all_filenames = []
|
||||
|
||||
# Process in batches
|
||||
num_batches = (len(files) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||
|
||||
for i in tqdm(
|
||||
range(0, len(files), BATCH_SIZE),
|
||||
desc=f"Processing {dir_name}",
|
||||
total=num_batches,
|
||||
):
|
||||
batch_paths = files[i:i + BATCH_SIZE]
|
||||
features, valid_paths = extract_features_batch(batch_paths, model)
|
||||
all_features.extend(features)
|
||||
all_filenames.extend(valid_paths)
|
||||
|
||||
joblib.dump(all_features, checkpoint_file)
|
||||
joblib.dump(all_filenames, filenames_file)
|
||||
print(
|
||||
f"Saved checkpoint for {dir_name}: {len(all_features)} images"
|
||||
)
|
||||
|
||||
|
||||
def generate_checkpoints():
|
||||
"""Generate embeddings checkpoints for each subdirectory."""
|
||||
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
|
||||
os.makedirs(FILENAMES_DIR, exist_ok=True)
|
||||
|
||||
if not os.path.exists(IMAGES_ROOT):
|
||||
print(f"Error: Images directory not found at {IMAGES_ROOT}")
|
||||
sys.exit(1)
|
||||
|
||||
directories = [
|
||||
os.path.join(IMAGES_ROOT, d)
|
||||
for d in os.listdir(IMAGES_ROOT)
|
||||
if os.path.isdir(os.path.join(IMAGES_ROOT, d))
|
||||
]
|
||||
|
||||
if not directories:
|
||||
print(f"Error: No subdirectories found in {IMAGES_ROOT}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(directories)} directories to process")
|
||||
|
||||
model = get_model()
|
||||
|
||||
for directory in directories:
|
||||
process_directory(directory, model)
|
||||
|
||||
|
||||
def merge_checkpoints():
|
||||
"""Merge checkpoint files into IVFPQ FAISS index (memory-efficient).
|
||||
|
||||
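    With m=64 and 8 bits per subquantizer, IVFPQ stores about 72 bytes per vector
    (64-byte PQ code plus 8-byte ID) instead of 8 KB of float32 values:
    roughly 0.5GB instead of ~50GB for 6M vectors.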
|
||||
Process:
|
||||
1. Collect training samples from checkpoints
|
||||
2. Train IVFPQ index on samples
|
||||
3. Stream-add all vectors to trained index
|
||||
"""
|
||||
print("\nMerging checkpoints into IVFPQ index...")
|
||||
|
||||
embedding_files = sorted([
|
||||
f for f in os.listdir(EMBEDDINGS_DIR)
|
||||
if f.endswith("_embeddings.pkl")
|
||||
])
|
||||
filename_files = sorted([
|
||||
f for f in os.listdir(FILENAMES_DIR)
|
||||
if f.endswith("_filenames.pkl")
|
||||
])
|
||||
|
||||
if not embedding_files:
|
||||
print("No checkpoint files found!")
|
||||
return
|
||||
|
||||
if faiss is None:
|
||||
print("\nFAISS is not available; cannot build index.faiss.")
|
||||
print("Merging filenames only (filenames.pkl) and exiting.")
|
||||
|
||||
all_filenames = []
|
||||
for fn_file in tqdm(filename_files, desc="Merging filenames"):
|
||||
fn_path = os.path.join(FILENAMES_DIR, fn_file)
|
||||
filenames = joblib.load(fn_path)
|
||||
all_filenames.extend(filenames)
|
||||
|
||||
joblib.dump(all_filenames, OUTPUT_FILENAMES)
|
||||
print("\nGenerated:")
|
||||
print(f" - {OUTPUT_FILENAMES}: {len(all_filenames)} filenames")
|
||||
print(" - index.faiss: NOT generated (faiss import failed)")
|
||||
return
|
||||
|
||||
# Step 1: Collect training samples
|
||||
print("\n[Step 2a] Collecting training samples...")
|
||||
training_vectors = []
|
||||
total_vectors = 0
|
||||
|
||||
    seen = 0  # Vectors examined so far across all checkpoints
    for emb_file in tqdm(embedding_files, desc="Scanning checkpoints"):
        emb_path = os.path.join(EMBEDDINGS_DIR, emb_file)
        embeddings = joblib.load(emb_path)
        total_vectors += len(embeddings)

        # Sample vectors for training (reservoir sampling)
        for emb in embeddings:
            seen += 1
            if len(training_vectors) < FAISS_TRAIN_SIZE:
                training_vectors.append(emb)
            else:
                # Keep the new vector with probability FAISS_TRAIN_SIZE / seen.
                # The draw must range over the vectors seen so far, not over
                # total_vectors, which already counts the whole current file.
                j = np.random.randint(0, seen)
                if j < FAISS_TRAIN_SIZE:
                    training_vectors[j] = emb
|
||||
|
||||
print(f" Total vectors: {total_vectors}")
|
||||
print(f" Training samples: {len(training_vectors)}")
|
||||
|
||||
# Step 2: Build and train IVFPQ index
|
||||
index = None
|
||||
if faiss is not None:
|
||||
print("\n[Step 2b] Training IVFPQ index...")
|
||||
train_np = np.vstack(training_vectors).astype(np.float32)
|
||||
dim = train_np.shape[1]
|
||||
|
||||
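        # Create quantizer and IVFPQ index. IndexIVFPQ defaults to the L2
        # metric, so the coarse quantizer uses L2 as well; for unit-normalized
        # embeddings, L2 ranking matches cosine-similarity ranking.
        quantizer = faiss.IndexFlatL2(dim)
        index = faiss.IndexIVFPQ(
            quantizer, dim, FAISS_NLIST, FAISS_M, FAISS_NBITS
        )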
|
||||
|
||||
# Train on sample vectors
|
||||
print(f" Training on {len(train_np)} vectors...")
|
||||
index.train(train_np)
|
||||
print(" Training complete!")
|
||||
|
||||
# Free training vectors
|
||||
del training_vectors
|
||||
del train_np
|
||||
|
||||
# Step 3: Stream-add all vectors to trained index
|
||||
print("\n[Step 2c] Adding vectors to index...")
|
||||
all_filenames = []
|
||||
|
||||
for emb_file, fn_file in tqdm(
|
||||
list(zip(embedding_files, filename_files)),
|
||||
desc="Building index",
|
||||
):
|
||||
emb_path = os.path.join(EMBEDDINGS_DIR, emb_file)
|
||||
embeddings = joblib.load(emb_path)
|
||||
|
||||
if index is not None and embeddings:
|
||||
emb_np = np.vstack(embeddings).astype(np.float32)
|
||||
index.add(emb_np)
|
||||
|
||||
fn_path = os.path.join(FILENAMES_DIR, fn_file)
|
||||
filenames = joblib.load(fn_path)
|
||||
all_filenames.extend(filenames)
|
||||
|
||||
# Write outputs
|
||||
if index is not None:
|
||||
# Set nprobe for better recall at query time (stored in index)
|
||||
index.nprobe = 64
|
||||
faiss.write_index(index, OUTPUT_FAISS_INDEX)
|
||||
|
||||
joblib.dump(all_filenames, OUTPUT_FILENAMES)
|
||||
|
||||
# Summary
|
||||
print("\nGenerated:")
|
||||
print(f" - {OUTPUT_FILENAMES}: {len(all_filenames)} filenames")
|
||||
if index is not None:
|
||||
idx_size_mb = os.path.getsize(OUTPUT_FAISS_INDEX) / (1024 * 1024)
|
||||
print(f" - {OUTPUT_FAISS_INDEX}: IVFPQ index")
|
||||
print(f" Vectors: {total_vectors}")
|
||||
uncompressed_gb = total_vectors * 8192 / 1e9
|
||||
print(f" Size: {idx_size_mb:.1f} MB")
|
||||
print(f" (vs ~{uncompressed_gb:.1f} GB uncompressed)")
|
||||
print(f" nlist={FAISS_NLIST}, m={FAISS_M}, nbits={FAISS_NBITS}")
|
||||
else:
|
||||
print(" - FAISS not available, index not created")
|
||||
|
||||
|
||||
def main(merge_only=False):
|
||||
print("=" * 50)
|
||||
if merge_only:
|
||||
print("MERGE CHECKPOINTS ONLY")
|
||||
else:
|
||||
print("GENERATE EMBEDDINGS PIPELINE")
|
||||
print("=" * 50)
|
||||
print(f"Source: {IMAGES_ROOT}")
|
||||
print(f"Output: {OUTPUT_FILENAMES}, {OUTPUT_FAISS_INDEX}")
|
||||
print("=" * 50)
|
||||
|
||||
if not merge_only:
|
||||
# Step 1: Generate checkpoints for each directory
|
||||
print("\n[Step 1/2] Generating checkpoints...")
|
||||
generate_checkpoints()
|
||||
|
||||
# Step 2: Merge all checkpoints into FAISS index
|
||||
step = "[Step 2/2]" if not merge_only else "[Merge]"
|
||||
print(f"\n{step} Building FAISS IVFPQ index from checkpoints...")
|
||||
merge_checkpoints()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("COMPLETE! Run the Streamlit app:")
|
||||
print(" streamlit run run_streamlit_ui.py")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate embeddings and FAISS index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--merge-only",
|
||||
action="store_true",
|
||||
help="Only merge checkpoints into FAISS index (skip generation)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(merge_only=args.merge_only)
|
||||
117
legacy/app_legacy.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Legacy Flask API
|
||||
=================
|
||||
Original API implementation - kept for reference.
|
||||
Use api/app.py for production.
|
||||
"""
|
||||
|
||||
import os.path
|
||||
from flask import Flask, request, jsonify
|
||||
from reverse_icon_search import ReverseIconSearch
|
||||
from flask_cors import CORS
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
ris = ReverseIconSearch(200, 200)
|
||||
|
||||
pred_store = None
|
||||
|
||||
|
||||
def gen_pred_store():
|
||||
global pred_store
|
||||
file_location = "./predictions/react-icons/ricons.txt"
|
||||
if os.path.exists(file_location):
|
||||
with open(file_location, 'r') as file:
|
||||
pred_store = {}
|
||||
for line in file:
|
||||
index_of_hyphen_black = line.find("-black")
|
||||
if index_of_hyphen_black != -1:
|
||||
blk_len = len("-black")
|
||||
key = line[:index_of_hyphen_black + blk_len]
|
||||
offset = index_of_hyphen_black + blk_len
|
||||
numbers_str = line[offset:].strip()
|
||||
chunks = [
|
||||
numbers_str[i:i + 7]
|
||||
for i in range(0, len(numbers_str), 7)
|
||||
]
|
||||
pred_store[key] = chunks
|
||||
return pred_store
|
||||
|
||||
|
||||
def get_matches(request_in):
|
||||
global pred_store
|
||||
if 'file' not in request_in.files:
|
||||
return jsonify({'error': 'No file part'}), 400
|
||||
file = request_in.files['file']
|
||||
if file.filename == '':
|
||||
return jsonify({'error': 'No selected file'}), 400
|
||||
if pred_store is None:
|
||||
pred_store = gen_pred_store()
|
||||
if pred_store is not None:
|
||||
predictions = pred_store[file.filename[:file.filename.rindex(".png")]]
|
||||
predictions = list(
|
||||
map(
|
||||
lambda x: "../images/" + x[0:2] + "/" + x + ".png",
|
||||
predictions,
|
||||
)
|
||||
)
|
||||
return predictions
|
||||
file.name = file.filename
|
||||
ris.uploaded_file = file
|
||||
if "count" in request_in.form:
|
||||
ris.return_number_of_predictions = int(request_in.form["count"])
|
||||
return jsonify(ris.process_file())
|
||||
|
||||
|
||||
@app.route("/api/similar-icon-paths", methods=["POST"])
|
||||
def get_similar_icons_paths():
|
||||
matches = get_matches(request)
|
||||
ris.return_number_of_predictions = 200
|
||||
return matches
|
||||
|
||||
|
||||
@app.route("/api/similar-icon-abs-paths", methods=["POST"])
|
||||
def get_similar_icons_absolute_paths():
|
||||
matches = get_matches(request)
|
||||
matches = list(map(lambda x: os.path.abspath(x), matches))
|
||||
ris.return_number_of_predictions = 200
|
||||
return jsonify(matches)
|
||||
|
||||
|
||||
@app.route("/api/gen-all-predictions", methods=["GET"])
|
||||
def generate_all_matches():
|
||||
root_dir = "../images"
|
||||
if os.path.exists(root_dir):
|
||||
dirlist = [
|
||||
d for d in os.listdir(root_dir)
|
||||
if os.path.isdir(os.path.join(root_dir, d))
|
||||
]
|
||||
for dir_name in dirlist:
|
||||
file_batches = {}
|
||||
qualified_parent = os.path.join(root_dir, dir_name)
|
||||
save_location = "./predictions/" + dir_name + "/"
|
||||
all_file_names = [
|
||||
f for f in os.listdir(qualified_parent)
|
||||
if os.path.isfile(os.path.join(qualified_parent, f))
|
||||
]
|
||||
for i in range(0, 10):
|
||||
for file_name in all_file_names:
|
||||
if file_name[2] == str(i):
|
||||
if str(i) not in file_batches:
|
||||
file_batches[str(i)] = []
|
||||
file_batches[str(i)].append(file_name)
|
||||
for key, file_names in file_batches.items():
|
||||
for file_name in sorted(file_names):
|
||||
pred_file = save_location + file_name[2:3] + ".txt"
|
||||
qualified_path = os.path.abspath(
|
||||
os.path.join(qualified_parent, file_name)
|
||||
)
|
||||
ris.process_file_path(
|
||||
qualified_path,
|
||||
os.path.abspath(pred_file),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True, port=5002)
|
||||
64
main_streamlit.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Streamlit UI for Reverse Image Search
|
||||
======================================
|
||||
Extends the base ReverseIconSearch with Streamlit-specific UI.
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import os
|
||||
from PIL import Image
|
||||
import math
|
||||
import joblib
|
||||
|
||||
from reverse_icon_search import ReverseIconSearch
|
||||
|
||||
|
||||
class StreamlitReverseIconSearch(ReverseIconSearch):
|
||||
"""Streamlit-specific extension of ReverseIconSearch."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculate_number_of_predictions,
|
||||
return_number_of_predictions,
|
||||
):
|
||||
st.set_page_config(layout="wide")
|
||||
st.title('Reverse Icon Search')
|
||||
super().__init__(
|
||||
calculate_number_of_predictions,
|
||||
return_number_of_predictions,
|
||||
)
|
||||
|
||||
def process_file(self):
|
||||
if self.uploaded_file is not None:
|
||||
if self.save_uploaded_file(self.uploaded_file):
|
||||
if self.filenames is None:
|
||||
self.filenames = joblib.load(self.filenames_location)
|
||||
display_image = Image.open(self.uploaded_file)
|
||||
st.image(display_image)
|
||||
features = self.feature_extraction(
|
||||
os.path.join("uploads", self.uploaded_file.name),
|
||||
self.model,
|
||||
)
|
||||
_distances, indices = self.recommend(features)
|
||||
print(list(map(lambda a: self.filenames[a], indices[0])))
|
||||
rows = math.ceil(self.return_number_of_predictions / 10)
|
||||
for j in range(0, rows):
|
||||
row = st.columns(10)
|
||||
for i in range(0, 10):
|
||||
file_index = 10 * j + i
|
||||
if file_index < self.return_number_of_predictions:
|
||||
name = self.filenames[indices[0][file_index]]
|
||||
with row[i]:
|
||||
display_name = name[
|
||||
name.rindex("/") + 1:name.rindex(".")
|
||||
]
|
||||
st.text(display_name)
|
||||
st.image(name)
|
||||
else:
|
||||
st.header("Some error occurred in file upload")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ris = StreamlitReverseIconSearch(200, 200)
|
||||
ris.uploaded_file = st.file_uploader("Choose an image")
|
||||
ris.process_file()
|
||||
198
postman/Visual_Search_API.postman_collection.json
Normal file
@@ -0,0 +1,198 @@
|
||||
{
|
||||
"info": {
|
||||
"name": "Visual Search API",
|
||||
"description": "Reverse Image Search API - Find visually similar images using deep learning",
|
||||
"schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json",
|
||||
"_exporter_id": "visual-search-api"
|
||||
},
|
||||
"variable": [
|
||||
{
|
||||
"key": "baseUrl",
|
||||
"value": "http://localhost:5002",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"item": [
|
||||
{
|
||||
"name": "Health & Status",
|
||||
"item": [
|
||||
{
|
||||
"name": "Health Check",
|
||||
"request": {
|
||||
"method": "GET",
|
||||
"header": [],
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/health",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "health"]
|
||||
},
|
||||
"description": "Check if the API is healthy and responding"
|
||||
},
|
||||
"response": []
|
||||
},
|
||||
{
|
||||
"name": "Service Status",
|
||||
"request": {
|
||||
"method": "GET",
|
||||
"header": [],
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/status",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "status"]
|
||||
},
|
||||
"description": "Get detailed status of the search service including index info"
|
||||
},
|
||||
"response": []
|
||||
},
|
||||
{
|
||||
"name": "API Info",
|
||||
"request": {
|
||||
"method": "GET",
|
||||
"header": [],
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": [""]
|
||||
},
|
||||
"description": "Get API information and available endpoints"
|
||||
},
|
||||
"response": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Search",
|
||||
"item": [
|
||||
{
|
||||
"name": "Search Similar Images",
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"header": [],
|
||||
"body": {
|
||||
"mode": "formdata",
|
||||
"formdata": [
|
||||
{
|
||||
"key": "file",
|
||||
"type": "file",
|
||||
"src": "",
|
||||
"description": "Image file to search for similar images"
|
||||
},
|
||||
{
|
||||
"key": "count",
|
||||
"value": "10",
|
||||
"type": "text",
|
||||
"description": "Number of results (1-200, default: 10)"
|
||||
}
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/search",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "search"]
|
||||
},
|
||||
"description": "Search for visually similar images. Upload an image and get ranked results."
|
||||
},
|
||||
"response": []
|
||||
},
|
||||
{
|
||||
"name": "Batch Search",
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"header": [],
|
||||
"body": {
|
||||
"mode": "formdata",
|
||||
"formdata": [
|
||||
{
|
||||
"key": "files",
|
||||
"type": "file",
|
||||
"src": "",
|
||||
"description": "Multiple image files to search"
|
||||
},
|
||||
{
|
||||
"key": "count",
|
||||
"value": "10",
|
||||
"type": "text",
|
||||
"description": "Number of results per image (1-50, default: 10)"
|
||||
}
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/batch-search",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "batch-search"]
|
||||
},
|
||||
"description": "Search for similar images for multiple query images at once"
|
||||
},
|
||||
"response": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Legacy Endpoints",
|
||||
"item": [
|
||||
{
|
||||
"name": "Similar Icon Paths (Legacy)",
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"header": [],
|
||||
"body": {
|
||||
"mode": "formdata",
|
||||
"formdata": [
|
||||
{
|
||||
"key": "file",
|
||||
"type": "file",
|
||||
"src": "",
|
||||
"description": "Image file"
|
||||
},
|
||||
{
|
||||
"key": "count",
|
||||
"value": "200",
|
||||
"type": "text",
|
||||
"description": "Number of results"
|
||||
}
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/similar-icon-paths",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "similar-icon-paths"]
|
||||
},
|
||||
"description": "Legacy endpoint: Returns relative paths of similar images"
|
||||
},
|
||||
"response": []
|
||||
},
|
||||
{
|
||||
"name": "Similar Icon Absolute Paths (Legacy)",
|
||||
"request": {
|
||||
"method": "POST",
|
||||
"header": [],
|
||||
"body": {
|
||||
"mode": "formdata",
|
||||
"formdata": [
|
||||
{
|
||||
"key": "file",
|
||||
"type": "file",
|
||||
"src": "",
|
||||
"description": "Image file"
|
||||
},
|
||||
{
|
||||
"key": "count",
|
||||
"value": "200",
|
||||
"type": "text",
|
||||
"description": "Number of results"
|
||||
}
|
||||
]
|
||||
},
|
||||
"url": {
|
||||
"raw": "{{baseUrl}}/api/similar-icon-abs-paths",
|
||||
"host": ["{{baseUrl}}"],
|
||||
"path": ["api", "similar-icon-abs-paths"]
|
||||
},
|
||||
"description": "Legacy endpoint: Returns absolute paths of similar images"
|
||||
},
|
||||
"response": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
31
requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
||||
# Core ML/DL
|
||||
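tensorflow-macos; sys_platform == "darwin" and platform_machine == "arm64"
tensorflow-metal; sys_platform == "darwin" and platform_machine == "arm64"
tensorflow~=2.14.0; sys_platform != "darwin" or platform_machine != "arm64"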
|
||||
keras~=2.14.0
|
||||
numpy~=1.26.2
|
||||
scikit-learn~=1.3.2
|
||||
|
||||
# Image processing
|
||||
pillow~=10.1.0
|
||||
opencv-python~=4.8.1.78
|
||||
|
||||
# Vector search
|
||||
faiss-cpu
|
||||
|
||||
# Web frameworks
|
||||
flask~=3.0.0
|
||||
flask-cors
|
||||
streamlit~=1.28.2
|
||||
|
||||
# Utilities
|
||||
joblib~=1.3.2
|
||||
tqdm~=4.66.1
|
||||
pandas
|
||||
matplotlib
|
||||
|
||||
# Testing
|
||||
pytest~=7.4.0
|
||||
pytest-cov~=4.1.0
|
||||
|
||||
# Note: tensorflow~=2.14.0 is for non-macOS systems
|
||||
# On macOS with Apple Silicon, use tensorflow-macos + tensorflow-metal
|
||||
136
reverse_icon_search.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import joblib
|
||||
from keras import Sequential
|
||||
from keras.preprocessing import image
|
||||
from keras.layers import GlobalMaxPooling2D
|
||||
from keras.applications.resnet50 import ResNet50, preprocess_input
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from numpy.linalg import norm
|
||||
|
||||
try:
|
||||
import faiss
|
||||
except Exception:
|
||||
faiss = None
|
||||
|
||||
|
||||
class ReverseIconSearch:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
calculate_number_of_predictions,
|
||||
return_number_of_predictions,
|
||||
):
|
||||
self.calculate_number_of_predictions = calculate_number_of_predictions
|
||||
self.return_number_of_predictions = return_number_of_predictions
|
||||
self.filenames_location = "filenames.pkl"
|
||||
self.embeddings_location = "embeddings.pkl"
|
||||
self.neighbors_location = "neighbors.pkl"
|
||||
self.index_location = "index.faiss"
|
||||
self.upload_location = "uploads"
|
||||
self.model = ResNet50(
|
||||
weights='imagenet',
|
||||
include_top=False,
|
||||
input_shape=(224, 224, 3),
|
||||
)
|
||||
self.model.trainable = False
|
||||
self.target_size = (224, 224)
|
||||
self.uploaded_file = None
|
||||
self.neighbors = None
|
||||
self.index = None
|
||||
self.filenames = None
|
||||
|
||||
self.model = Sequential([
|
||||
self.model,
|
||||
GlobalMaxPooling2D()
|
||||
])
|
||||
|
||||
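    def save_uploaded_file(self, uploaded_file):
        """Write the uploaded file into the upload directory; return 1 on success, 0 on failure."""
        try:
            # Make sure the upload directory exists before writing to it
            os.makedirs(self.upload_location, exist_ok=True)
            with open(
                os.path.join(self.upload_location, uploaded_file.name),
                'wb',
            ) as f:
                f.write(uploaded_file.getbuffer())
            return 1
        except Exception:
            return 0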
|
||||
|
||||
def feature_extraction(self, img_path, model):
|
||||
img = image.load_img(img_path, target_size=self.target_size)
|
||||
img_array = image.img_to_array(img)
|
||||
expanded_img_array = np.expand_dims(img_array, axis=0)
|
||||
preprocessed_img = preprocess_input(expanded_img_array)
|
||||
result = model.predict(preprocessed_img).flatten()
|
||||
normalized_result = result / norm(result)
|
||||
|
||||
return normalized_result
|
||||
|
||||
def recommend(self, features):
|
||||
if faiss is not None and os.path.exists(self.index_location):
|
||||
if self.index is None:
|
||||
self.index = faiss.read_index(self.index_location)
|
||||
# For IVFPQ: set nprobe for recall/speed tradeoff
|
||||
if hasattr(self.index, 'nprobe'):
|
||||
self.index.nprobe = 64 # Higher = better recall, slower
|
||||
query = np.asarray([features], dtype=np.float32)
|
||||
distances, indices = self.index.search(
|
||||
query,
|
||||
self.calculate_number_of_predictions,
|
||||
)
|
||||
return distances, indices
|
||||
|
||||
if self.neighbors is None:
|
||||
if os.path.exists(self.neighbors_location):
|
||||
neighbors = joblib.load(self.neighbors_location)
|
||||
else:
|
||||
neighbors = NearestNeighbors(
|
||||
n_neighbors=self.calculate_number_of_predictions,
|
||||
algorithm='auto',
|
||||
metric='euclidean',
|
||||
)
|
||||
feature_list = np.array(joblib.load(self.embeddings_location))
|
||||
neighbors.fit(feature_list)
|
||||
joblib.dump(neighbors, self.neighbors_location)
|
||||
self.neighbors = neighbors
|
||||
return self.neighbors.kneighbors([features])
|
||||
|
||||
def process_file(self):
|
||||
if self.uploaded_file is not None:
|
||||
if self.save_uploaded_file(self.uploaded_file):
|
||||
if self.filenames is None:
|
||||
self.filenames = joblib.load(self.filenames_location)
|
||||
features = self.feature_extraction(
|
||||
os.path.join("uploads", self.uploaded_file.name),
|
||||
self.model,
|
||||
)
|
||||
_distances, indices = self.recommend(features)
|
||||
response = list(
|
||||
map(
|
||||
lambda a: self.filenames[a],
|
||||
indices[0][0:self.return_number_of_predictions],
|
||||
)
|
||||
)
|
||||
return response
|
||||
else:
|
||||
print("Some error occurred in file upload")
|
||||
|
||||
def trim_file_name(self, index):
|
||||
file_name = self.filenames[index]
|
||||
return file_name[file_name.rindex("/") + 1:file_name.rindex(".")]
|
||||
|
||||
def process_file_path(self, location, save_location):
|
||||
try:
|
||||
if self.filenames is None:
|
||||
self.filenames = joblib.load(self.filenames_location)
|
||||
features = self.feature_extraction(location, self.model)
|
||||
_distances, indices = self.recommend(features)
|
||||
response = list(map(self.trim_file_name, indices[0]))
|
||||
save_str = location[location.rindex("/") + 1:location.rindex(".")]
|
||||
save_str += "".join(response)
|
||||
with open(save_location, "a", encoding="ascii") as file:
|
||||
# Write your numeric string to the file
|
||||
file.write("\n" + save_str)
|
||||
return response
|
||||
except Exception:
|
||||
print("Unable to process file at: ", location)
|
||||
13
run_api_server.py
Normal file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
API Server Entrypoint
|
||||
=====================
|
||||
Production-ready Flask API server.
|
||||
"""
|
||||
|
||||
from api.app import create_app
|
||||
from config import config
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_app()
|
||||
app.run(host='0.0.0.0', port=config.API_PORT, debug=config.DEBUG)
|
||||
5
run_generate_embeddings.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from generate_embeddings import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
9
run_streamlit_ui.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import streamlit as st
|
||||
|
||||
from main_streamlit import StreamlitReverseIconSearch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ris = StreamlitReverseIconSearch(200, 200)
|
||||
ris.uploaded_file = st.file_uploader("Choose an image")
|
||||
ris.process_file()
|
||||
40
test.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import pickle
|
||||
import tensorflow
|
||||
import numpy as np
|
||||
from numpy.linalg import norm
|
||||
from tensorflow.keras.preprocessing import image
|
||||
from tensorflow.keras.layers import GlobalMaxPooling2D
|
||||
from tensorflow.keras.applications.resnet50 import ResNet50,preprocess_input
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
import cv2
|
||||
|
||||
feature_list = np.array(pickle.load(open('embeddings.pkl','rb')))
|
||||
filenames = pickle.load(open('filenames.pkl','rb'))
|
||||
|
||||
model = ResNet50(weights='imagenet',include_top=False,input_shape=(224,224,3))
|
||||
model.trainable = False
|
||||
|
||||
model = tensorflow.keras.Sequential([
|
||||
model,
|
||||
GlobalMaxPooling2D()
|
||||
])
|
||||
|
||||
img = image.load_img('sample/shirt.jpg',target_size=(224,224))
|
||||
img_array = image.img_to_array(img)
|
||||
expanded_img_array = np.expand_dims(img_array, axis=0)
|
||||
preprocessed_img = preprocess_input(expanded_img_array)
|
||||
result = model.predict(preprocessed_img).flatten()
|
||||
normalized_result = result / norm(result)
|
||||
|
||||
neighbors = NearestNeighbors(n_neighbors=6,algorithm='brute',metric='euclidean')
|
||||
neighbors.fit(feature_list)
|
||||
|
||||
distances,indices = neighbors.kneighbors([normalized_result])
|
||||
|
||||
print(indices)
|
||||
|
||||
for file in indices[0][1:6]:
|
||||
temp_img = cv2.imread(filenames[file])
|
||||
cv2.imshow('output',cv2.resize(temp_img,(512,512)))
|
||||
cv2.waitKey(0)
|
||||
|
||||
3
tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Test Package
|
||||
"""
|
||||
36
tests/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Pytest Configuration and Fixtures
|
||||
==================================
|
||||
Shared fixtures for all tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def test_image_bytes():
|
||||
"""Create fake image bytes for testing."""
|
||||
return b'\x89PNG\r\n\x1a\n' + b'\x00' * 100
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config():
|
||||
"""Mock configuration for tests."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
config = Mock()
|
||||
config.FAISS_INDEX_PATH = 'index.faiss'
|
||||
config.FILENAMES_PATH = 'filenames.pkl'
|
||||
config.EMBEDDINGS_PATH = 'embeddings.pkl'
|
||||
config.UPLOAD_DIR = 'test_uploads'
|
||||
config.TARGET_SIZE = (224, 224)
|
||||
config.INPUT_SHAPE = (224, 224, 3)
|
||||
config.N_NEIGHBORS = 200
|
||||
config.API_PORT = 5002
|
||||
config.DEBUG = True
|
||||
|
||||
return config
|
||||
166
tests/test_api.py
Normal file
@@ -0,0 +1,166 @@
"""
API Route Tests
================
Unit tests for API endpoints with mocked services.
"""

import pytest
from unittest.mock import Mock, patch
import json


@pytest.fixture
def app():
    """Create test application."""
    with patch('api.routes.SearchService'):
        from api.app import create_app
        app = create_app({'TESTING': True})
        yield app


@pytest.fixture
def client(app):
    """Create test client."""
    return app.test_client()


class TestHealthEndpoint:
    """Tests for /api/health endpoint."""

    def test_health_check_returns_200(self, client):
        """Health check should return 200 with healthy status."""
        response = client.get('/api/health')
        assert response.status_code == 200

        data = json.loads(response.data)
        assert data['success'] is True
        assert data['status'] == 'healthy'
        assert 'timestamp' in data

    def test_health_check_json_content_type(self, client):
        """Health check should return JSON content type."""
        response = client.get('/api/health')
        assert response.content_type == 'application/json'


class TestStatusEndpoint:
    """Tests for /api/status endpoint."""

    def test_status_returns_200(self, client):
        """Status endpoint should return 200."""
        with patch('api.routes.get_search_service') as mock_get_service:
            mock_service = Mock()
            mock_service.get_status.return_value = {
                'faiss_available': True,
                'model_loaded': True,
                'index_loaded': True,
            }
            mock_get_service.return_value = mock_service

            response = client.get('/api/status')
            assert response.status_code == 200

            data = json.loads(response.data)
            assert data['success'] is True
            assert 'data' in data


class TestSearchEndpoint:
    """Tests for /api/search endpoint."""

    def test_search_without_file_returns_400(self, client):
        """Search without file should return 400."""
        response = client.post('/api/search')
        assert response.status_code == 400

        data = json.loads(response.data)
        assert data['success'] is False
        assert data['error']['code'] == 'VALIDATION_ERROR'

    def test_search_with_empty_filename_returns_400(self, client):
        """Search with empty filename should return 400."""
        response = client.post(
            '/api/search',
            data={'file': (b'', '')},
            content_type='multipart/form-data'
        )
        assert response.status_code == 400

    def test_search_with_invalid_count_returns_400(self, client):
        """Search with invalid count should return 400."""
        from io import BytesIO
        response = client.post(
            '/api/search',
            data={
                'file': (BytesIO(b'fake image'), 'test.jpg'),
                'count': '500'
            },
            content_type='multipart/form-data'
        )
        assert response.status_code == 400

        data = json.loads(response.data)
        assert 'count' in str(data['error']['details'])

    def test_search_success(self, client):
        """Search with valid file should return results."""
        from io import BytesIO

        with patch('api.routes.get_search_service') as mock_get_service:
            mock_service = Mock()
            mock_service.search.return_value = [
                {'path': '/images/1.jpg', 'filename': '1.jpg', 'rank': 1},
                {'path': '/images/2.jpg', 'filename': '2.jpg', 'rank': 2},
            ]
            mock_get_service.return_value = mock_service

            response = client.post(
                '/api/search',
                data={
                    'file': (BytesIO(b'fake image'), 'test.jpg'),
                    'count': '10'
                },
                content_type='multipart/form-data'
            )
            assert response.status_code == 200

            data = json.loads(response.data)
            assert data['success'] is True
            assert len(data['data']['results']) == 2
            assert 'query_time_ms' in data['data']


class TestBatchSearchEndpoint:
    """Tests for /api/batch-search endpoint."""

    def test_batch_search_without_files_returns_400(self, client):
        """Batch search without files should return 400."""
        response = client.post('/api/batch-search')
        assert response.status_code == 400

    def test_batch_search_with_invalid_count_returns_400(self, client):
        """Batch search with count > 50 should return 400."""
        from io import BytesIO
        response = client.post(
            '/api/batch-search',
            data={
                'files': (BytesIO(b'fake'), 'test.jpg'),
                'count': '100'
            },
            content_type='multipart/form-data'
        )
        assert response.status_code == 400


class TestLegacyEndpoints:
    """Tests for legacy API endpoints."""

    def test_similar_icon_paths_without_file_returns_400(self, client):
        """Legacy endpoint without file should return 400."""
        response = client.post('/api/similar-icon-paths')
        assert response.status_code == 400

    def test_similar_icon_abs_paths_without_file_returns_400(self, client):
        """Legacy absolute paths endpoint without file should return 400."""
        response = client.post('/api/similar-icon-abs-paths')
        assert response.status_code == 400
116
tests/test_services.py
Normal file
@@ -0,0 +1,116 @@
"""
Service Layer Tests
====================
Unit tests for SearchService with mocked dependencies.
"""

from unittest.mock import Mock, patch
import numpy as np


class TestSearchService:
    """Tests for SearchService class."""

    @patch('api.services.faiss')
    @patch('api.services.joblib')
    def test_get_status_returns_dict(self, mock_joblib, mock_faiss):
        """get_status should return status dictionary."""
        with patch('api.services.config') as mock_config:
            mock_config.FAISS_INDEX_PATH = 'index.faiss'
            mock_config.FILENAMES_PATH = 'filenames.pkl'
            mock_config.UPLOAD_DIR = 'uploads'
            mock_config.INPUT_SHAPE = (224, 224, 3)

            with patch('os.path.exists', return_value=False):
                with patch('os.makedirs'):
                    with patch.object(
                        __import__('api.services', fromlist=['SearchService'])
                        .SearchService,
                        '_load_model'
                    ):
                        from api.services import SearchService
                        service = SearchService()
                        service.model = Mock()

                        status = service.get_status()

                        assert 'faiss_available' in status
                        assert 'model_loaded' in status
                        assert 'index_loaded' in status

    def test_ensure_upload_dir_creates_directory(self):
        """_ensure_upload_dir should create directory if not exists."""
        with patch('api.services.config') as mock_config:
            mock_config.UPLOAD_DIR = 'test_uploads'
            mock_config.INPUT_SHAPE = (224, 224, 3)

            with patch('os.path.exists', return_value=False):
                with patch('os.makedirs') as mock_makedirs:
                    with patch.object(
                        __import__('api.services', fromlist=['SearchService'])
                        .SearchService,
                        '_load_model'
                    ):
                        from api.services import SearchService
                        SearchService()  # Instantiation triggers dir creation

                        mock_makedirs.assert_called()


class TestFeatureExtraction:
    """Tests for feature extraction utilities."""

    def test_normalize_vector(self):
        """Feature vectors should be normalized."""
        from numpy.linalg import norm

        vector = np.array([3.0, 4.0])
        normalized = vector / norm(vector)

        assert np.isclose(norm(normalized), 1.0)

    def test_normalize_preserves_direction(self):
        """Normalization should preserve vector direction."""
        from numpy.linalg import norm

        vector = np.array([3.0, 4.0])
        normalized = vector / norm(vector)

        expected = np.array([0.6, 0.8])
        assert np.allclose(normalized, expected)


class TestExceptions:
    """Tests for custom exceptions."""

    def test_validation_error_to_dict(self):
        """ValidationError should convert to proper dict."""
        from api.exceptions import ValidationError

        error = ValidationError('Invalid input', field='file')
        result = error.to_dict()

        assert result['success'] is False
        assert result['error']['code'] == 'VALIDATION_ERROR'
        assert result['error']['message'] == 'Invalid input'
        assert result['error']['details']['field'] == 'file'

    def test_search_error_to_dict(self):
        """SearchError should convert to proper dict."""
        from api.exceptions import SearchError

        error = SearchError('Search failed', details={'reason': 'timeout'})
        result = error.to_dict()

        assert result['success'] is False
        assert result['error']['code'] == 'SEARCH_ERROR'
        assert result['error']['details']['reason'] == 'timeout'

    def test_index_not_loaded_error(self):
        """IndexNotLoadedError should have proper message."""
        from api.exceptions import IndexNotLoadedError

        error = IndexNotLoadedError()

        assert error.status_code == 503
        assert 'index.faiss' in error.message
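
The two `TestFeatureExtraction` tests above check L2 normalization inline with NumPy. As a minimal sketch of the same operation as a standalone helper, the block below uses only the standard library. The name `l2_normalize`, its `eps` parameter, and the zero-vector behavior are assumptions made for illustration; they are not taken from `api.services`.

```python
import math


def l2_normalize(vector, eps=1e-12):
    """Hypothetical helper: scale a sequence of floats to unit L2 norm."""
    length = math.sqrt(sum(x * x for x in vector))
    if length < eps:
        # A (near-)zero vector has no direction; return a copy unchanged
        # rather than dividing by zero. This choice is an assumption.
        return list(vector)
    return [x / length for x in vector]


unit = l2_normalize([3.0, 4.0])
print(unit)
```

The zero-vector guard is the one case the inline `vector / norm(vector)` expression does not cover; whether the real feature extractor needs it depends on whether the model can emit an all-zero embedding.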