From 9f8b8260052ebd0d2e7795df2d177c4ea0291513 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 23 Sep 2025 01:53:09 +0200 Subject: [PATCH 1/3] feat: add a real moth classifier to processing service --- processing_services/moths/Dockerfile | 7 + .../moths/README_GLOBAL_MOTH_CLASSIFIER.md | 203 ++++++++ processing_services/moths/api/Dockerfile | 7 + .../api/README_GLOBAL_MOTH_CLASSIFIER.md | 203 ++++++++ processing_services/moths/api/api/__init__.py | 0 .../moths/api/api/algorithms.py | 431 +++++++++++++++++ processing_services/moths/api/api/api.py | 261 ++++++++++ processing_services/moths/api/api/base.py | 241 +++++++++ .../moths/api/api/global_moth_classifier.py | 254 ++++++++++ .../moths/api/api/pipelines.py | 457 ++++++++++++++++++ processing_services/moths/api/api/schemas.py | 341 +++++++++++++ processing_services/moths/api/api/test.py | 64 +++ processing_services/moths/api/api/utils.py | 172 +++++++ .../moths/api/docker-compose.yml | 25 + processing_services/moths/api/main.py | 4 + .../moths/api/requirements.txt | 10 + .../moths/api/test_api_integration.py | 173 +++++++ .../moths/api/test_compilation_logging.py | 38 ++ .../moths/api/test_global_moth_pipeline.py | 109 +++++ processing_services/moths/requirements.txt | 10 + .../moths/test_api_integration.py | 173 +++++++ .../moths/test_global_moth_pipeline.py | 109 +++++ 22 files changed, 3292 insertions(+) create mode 100644 processing_services/moths/Dockerfile create mode 100644 processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md create mode 100644 processing_services/moths/api/Dockerfile create mode 100644 processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md create mode 100644 processing_services/moths/api/api/__init__.py create mode 100644 processing_services/moths/api/api/algorithms.py create mode 100644 processing_services/moths/api/api/api.py create mode 100644 processing_services/moths/api/api/base.py create mode 100644 processing_services/moths/api/api/global_moth_classifier.py create mode 100644 processing_services/moths/api/api/pipelines.py create mode 100644 processing_services/moths/api/api/schemas.py create mode 100644 processing_services/moths/api/api/test.py create mode 100644 processing_services/moths/api/api/utils.py create mode 100644 processing_services/moths/api/docker-compose.yml create mode 100644 processing_services/moths/api/main.py create mode 100644 processing_services/moths/api/requirements.txt create mode 100644 processing_services/moths/api/test_api_integration.py create mode 100644 processing_services/moths/api/test_compilation_logging.py create mode 100644 processing_services/moths/api/test_global_moth_pipeline.py create mode 100644 processing_services/moths/requirements.txt create mode 100644 processing_services/moths/test_api_integration.py create mode 100644 processing_services/moths/test_global_moth_pipeline.py diff --git a/processing_services/moths/Dockerfile b/processing_services/moths/Dockerfile new file mode 100644 index 000000000..3e0781f92 --- /dev/null +++ b/processing_services/moths/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.11-slim + +# Set up ml backend FastAPI +WORKDIR /app +COPY . 
/app +RUN pip install -r ./requirements.txt +CMD ["python", "/app/main.py"] diff --git a/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md b/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md new file mode 100644 index 000000000..223ada70d --- /dev/null +++ b/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md @@ -0,0 +1,203 @@ +# Global Moth Classifier Processing Service Implementation + +## Overview + +Successfully implemented a simplified, self-contained module for running the global moth classifier within the processing services framework. This eliminates the need for database and queue dependencies while maintaining full classifier functionality. + +## Architecture Summary + +### Original vs. Simplified Architecture + +**Original trapdata architecture:** +- Complex inheritance: `MothClassifierGlobal` → `APIMothClassifier` + `GlobalMothSpeciesClassifier` → multiple base classes +- Database dependencies: `db_path`, `QueueManager`, `InferenceBaseClass` with DB connections +- Queue system: `UnclassifiedObjectQueue`, `DetectedObjectQueue` +- File management: `user_data_path`, complex caching system + +**New simplified architecture:** +- Flattened inheritance: `GlobalMothClassifier` → `Algorithm` + `TimmResNet50Base` → `SimplifiedInferenceBase` +- No database dependencies: Direct PIL image processing +- No queue system: Processes Detection objects directly +- Simple file management: Downloads models to temp directories as needed + +### Key Components Created + +1. **Base Classes** (`base.py`): + - `SimplifiedInferenceBase`: Core inference functionality without DB dependencies + - `ResNet50Base`: ResNet50-specific model loading and inference + - `TimmResNet50Base`: Timm-based ResNet50 implementation + +2. **Global Moth Classifier** (`global_moth_classifier.py`): + - `GlobalMothClassifier`: Simplified version of the original classifier + - 29,176+ species support + - Batch processing capabilities + - Algorithm interface compatibility + +3. **New Pipeline** (`pipelines.py`): + - `ZeroShotObjectDetectorWithGlobalMothClassifierPipeline`: Combines HF zero-shot detector with global moth classifier + - Two-stage processing: detection → classification + - Configurable batch sizes for optimal performance + +4. **Updated Utils** (`utils.py`): + - Added `get_best_device()` for GPU/CPU selection + - Enhanced `get_or_download_file()` for model weight downloading + +## Data Flow + +``` +PipelineRequest + ↓ +SourceImages (PIL images) + ↓ +ZeroShotObjectDetector (stage 1) + ↓ +Detections with bounding boxes + ↓ +GlobalMothClassifier (stage 2) + ↓ +Detections with species classifications + ↓ +PipelineResponse +``` + +## API Integration + +The new pipeline is now available in the processing service API: + +- **Pipeline Name**: "Zero Shot Object Detector With Global Moth Classifier Pipeline" +- **Slug**: `zero-shot-object-detector-with-global-moth-classifier-pipeline` +- **Algorithms**: 2 (detector + classifier) +- **Batch Sizes**: [1, 4] (detector=1, classifier=4 for efficiency) + +## Key Differences from trapdata + +1. **No Database Dependencies**: + - Removed: `db_path`, `QueueManager`, `save_classified_objects()` + - Uses: Direct Detection object processing + +2. **Simplified File Management**: + - Removed: Complex `user_data_path` caching + - Uses: Temporary directories for model downloads + +3. **Flattened Inheritance**: + - Removed: Complex multi-level inheritance chains + - Uses: Simple Algorithm + base class pattern + +4. 
**Direct Image Processing**: + - Removed: Database-backed image references + - Uses: PIL images attached to Detection objects + +5. **API-First Design**: + - Removed: CLI and database queue processing + - Focused: REST API pipeline processing only + +## Benefits + +1. **Simplicity**: Much easier to understand and maintain +2. **Performance**: No database overhead, direct processing +3. **Portability**: Self-contained, minimal dependencies +4. **Scalability**: Stateless processing suitable for containerization +5. **Maintainability**: Clear separation of concerns, focused functionality + +## Usage Example + +```python +from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline +from api.schemas import SourceImage + +# Create pipeline +pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( + source_images=[source_image], + request_config={"candidate_labels": ["moth", "insect"]}, + existing_detections=[] +) + +# Compile and run +pipeline.compile() +results = pipeline.run() +``` + +## Files Created/Modified + +- ✅ `processing_services/example/api/base.py` - New simplified base classes +- ✅ `processing_services/example/api/global_moth_classifier.py` - New global moth classifier +- ✅ `processing_services/example/api/pipelines.py` - Added new pipeline class +- ✅ `processing_services/example/api/utils.py` - Enhanced utility functions +- ✅ `processing_services/example/api/api.py` - Added new pipeline to API +- ✅ `processing_services/example/test_global_moth_pipeline.py` - Basic test file + +## Original trapdata Source Analysis + +To create this simplified implementation, the following files and line ranges from the original AMI Data Companion (trapdata) module were analyzed: + +### Core Classification Classes +- **`trapdata/api/models/classification.py`**: + - Lines 1-25: Import statements and base dependencies + - Lines 37-163: `APIMothClassifier` base class implementation + - Lines 165-209: All classifier implementations including `MothClassifierGlobal` (line 207) + - Lines 112-137: `save_results()` method for processing predictions + - Lines 138-163: `update_classification()` and pipeline integration + +### Base Inference Framework +- **`trapdata/ml/models/base.py`**: + - Lines 58-120: `InferenceBaseClass` core structure and initialization + - Lines 121-200: Model loading, transforms, and file management methods + - Lines 25-50: Normalization constants and utility functions + +### Global Moth Classifier Implementation +- **`trapdata/ml/models/classification.py`**: + - Lines 507-527: `GlobalMothSpeciesClassifier` class definition and configuration + - Lines 338-375: `SpeciesClassifier` base class and database integration + - Lines 527-567: Various regional classifiers showing inheritance patterns + - Lines 1-50: Import structure and database dataset classes + - Lines 200-300: ResNet50 and Timm-based classifier implementations + +### API Integration Patterns +- **`trapdata/api/api.py`**: + - Lines 1-50: FastAPI setup and classifier imports + - Lines 37-60: `CLASSIFIER_CHOICES` dictionary including global moths + - Lines 120-150: `make_pipeline_config_response()` function + - Lines 175-310: Main processing pipeline in `process()` endpoint + - Lines 60-80: Pipeline choice enumeration and filtering logic + +### Model Architecture References +- **`trapdata/ml/models/localization.py`**: + - Lines 142-200: `ObjectDetector` base class structure + - Lines 245-290: `MothObjectDetector_FasterRCNN_2023` implementation + +### API Schema Definitions +- 
**`trapdata/api/schemas.py`**: + - Lines 293-330: Pipeline configuration schemas + - Lines 1-100: Detection and classification response schemas + +### Processing Pipeline Examples +- **`trapdata/api/models/localization.py`**: + - Lines 13-60: `APIMothDetector` implementation showing API adaptation pattern + +### Key Configuration Values +From the analysis, these critical configuration values were extracted: +- **Model weights URL**: Lines 507-515 in `classification.py` +- **Labels path**: Lines 516-520 in `classification.py` +- **Input size**: Line 508 (`input_size = 128`) +- **Normalization**: Line 509 (`normalization = imagenet_normalization`) +- **Species count**: 29,176 species from model description +- **Default taxon rank**: "SPECIES" from base class + +### Database Dependencies Removed +These database-dependent components were identified and removed: +- **`trapdata/db/models/queue.py`**: Lines 1-500 (entire queue system) +- **`trapdata/db/models/detections.py`**: `save_classified_objects()` function +- **Database path parameters**: Throughout `base.py` and classification classes +- **Queue managers**: `UnclassifiedObjectQueue`, `DetectedObjectQueue` references + +This analysis allowed for the creation of a streamlined implementation that preserves all the essential functionality while eliminating the complex database and queue infrastructure. + +## Next Steps + +1. **Testing**: Run end-to-end tests with real images +2. **Performance Optimization**: Tune batch sizes and memory usage +3. **Error Handling**: Add robust error handling for edge cases +4. **Documentation**: Add detailed API documentation +5. **Docker Integration**: Update Docker configurations if needed + +The implementation successfully provides a clean, maintainable global moth classifier that can process 29,176+ species without the complexity of the original trapdata system. diff --git a/processing_services/moths/api/Dockerfile b/processing_services/moths/api/Dockerfile new file mode 100644 index 000000000..3e0781f92 --- /dev/null +++ b/processing_services/moths/api/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.11-slim + +# Set up ml backend FastAPI +WORKDIR /app +COPY . /app +RUN pip install -r ./requirements.txt +CMD ["python", "/app/main.py"] diff --git a/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md b/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md new file mode 100644 index 000000000..223ada70d --- /dev/null +++ b/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md @@ -0,0 +1,203 @@ +# Global Moth Classifier Processing Service Implementation + +## Overview + +Successfully implemented a simplified, self-contained module for running the global moth classifier within the processing services framework. This eliminates the need for database and queue dependencies while maintaining full classifier functionality. + +## Architecture Summary + +### Original vs. 
Simplified Architecture + +**Original trapdata architecture:** +- Complex inheritance: `MothClassifierGlobal` → `APIMothClassifier` + `GlobalMothSpeciesClassifier` → multiple base classes +- Database dependencies: `db_path`, `QueueManager`, `InferenceBaseClass` with DB connections +- Queue system: `UnclassifiedObjectQueue`, `DetectedObjectQueue` +- File management: `user_data_path`, complex caching system + +**New simplified architecture:** +- Flattened inheritance: `GlobalMothClassifier` → `Algorithm` + `TimmResNet50Base` → `SimplifiedInferenceBase` +- No database dependencies: Direct PIL image processing +- No queue system: Processes Detection objects directly +- Simple file management: Downloads models to temp directories as needed + +### Key Components Created + +1. **Base Classes** (`base.py`): + - `SimplifiedInferenceBase`: Core inference functionality without DB dependencies + - `ResNet50Base`: ResNet50-specific model loading and inference + - `TimmResNet50Base`: Timm-based ResNet50 implementation + +2. **Global Moth Classifier** (`global_moth_classifier.py`): + - `GlobalMothClassifier`: Simplified version of the original classifier + - 29,176+ species support + - Batch processing capabilities + - Algorithm interface compatibility + +3. **New Pipeline** (`pipelines.py`): + - `ZeroShotObjectDetectorWithGlobalMothClassifierPipeline`: Combines HF zero-shot detector with global moth classifier + - Two-stage processing: detection → classification + - Configurable batch sizes for optimal performance + +4. **Updated Utils** (`utils.py`): + - Added `get_best_device()` for GPU/CPU selection + - Enhanced `get_or_download_file()` for model weight downloading + +## Data Flow + +``` +PipelineRequest + ↓ +SourceImages (PIL images) + ↓ +ZeroShotObjectDetector (stage 1) + ↓ +Detections with bounding boxes + ↓ +GlobalMothClassifier (stage 2) + ↓ +Detections with species classifications + ↓ +PipelineResponse +``` + +## API Integration + +The new pipeline is now available in the processing service API: + +- **Pipeline Name**: "Zero Shot Object Detector With Global Moth Classifier Pipeline" +- **Slug**: `zero-shot-object-detector-with-global-moth-classifier-pipeline` +- **Algorithms**: 2 (detector + classifier) +- **Batch Sizes**: [1, 4] (detector=1, classifier=4 for efficiency) + +## Key Differences from trapdata + +1. **No Database Dependencies**: + - Removed: `db_path`, `QueueManager`, `save_classified_objects()` + - Uses: Direct Detection object processing + +2. **Simplified File Management**: + - Removed: Complex `user_data_path` caching + - Uses: Temporary directories for model downloads + +3. **Flattened Inheritance**: + - Removed: Complex multi-level inheritance chains + - Uses: Simple Algorithm + base class pattern + +4. **Direct Image Processing**: + - Removed: Database-backed image references + - Uses: PIL images attached to Detection objects + +5. **API-First Design**: + - Removed: CLI and database queue processing + - Focused: REST API pipeline processing only + +## Benefits + +1. **Simplicity**: Much easier to understand and maintain +2. **Performance**: No database overhead, direct processing +3. **Portability**: Self-contained, minimal dependencies +4. **Scalability**: Stateless processing suitable for containerization +5. 
**Maintainability**: Clear separation of concerns, focused functionality + +## Usage Example + +```python +from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline +from api.schemas import SourceImage + +# Create pipeline +pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( + source_images=[source_image], + request_config={"candidate_labels": ["moth", "insect"]}, + existing_detections=[] +) + +# Compile and run +pipeline.compile() +results = pipeline.run() +``` + +## Files Created/Modified + +- ✅ `processing_services/example/api/base.py` - New simplified base classes +- ✅ `processing_services/example/api/global_moth_classifier.py` - New global moth classifier +- ✅ `processing_services/example/api/pipelines.py` - Added new pipeline class +- ✅ `processing_services/example/api/utils.py` - Enhanced utility functions +- ✅ `processing_services/example/api/api.py` - Added new pipeline to API +- ✅ `processing_services/example/test_global_moth_pipeline.py` - Basic test file + +## Original trapdata Source Analysis + +To create this simplified implementation, the following files and line ranges from the original AMI Data Companion (trapdata) module were analyzed: + +### Core Classification Classes +- **`trapdata/api/models/classification.py`**: + - Lines 1-25: Import statements and base dependencies + - Lines 37-163: `APIMothClassifier` base class implementation + - Lines 165-209: All classifier implementations including `MothClassifierGlobal` (line 207) + - Lines 112-137: `save_results()` method for processing predictions + - Lines 138-163: `update_classification()` and pipeline integration + +### Base Inference Framework +- **`trapdata/ml/models/base.py`**: + - Lines 58-120: `InferenceBaseClass` core structure and initialization + - Lines 121-200: Model loading, transforms, and file management methods + - Lines 25-50: Normalization constants and utility functions + +### Global Moth Classifier Implementation +- **`trapdata/ml/models/classification.py`**: + - Lines 507-527: `GlobalMothSpeciesClassifier` class definition and configuration + - Lines 338-375: `SpeciesClassifier` base class and database integration + - Lines 527-567: Various regional classifiers showing inheritance patterns + - Lines 1-50: Import structure and database dataset classes + - Lines 200-300: ResNet50 and Timm-based classifier implementations + +### API Integration Patterns +- **`trapdata/api/api.py`**: + - Lines 1-50: FastAPI setup and classifier imports + - Lines 37-60: `CLASSIFIER_CHOICES` dictionary including global moths + - Lines 120-150: `make_pipeline_config_response()` function + - Lines 175-310: Main processing pipeline in `process()` endpoint + - Lines 60-80: Pipeline choice enumeration and filtering logic + +### Model Architecture References +- **`trapdata/ml/models/localization.py`**: + - Lines 142-200: `ObjectDetector` base class structure + - Lines 245-290: `MothObjectDetector_FasterRCNN_2023` implementation + +### API Schema Definitions +- **`trapdata/api/schemas.py`**: + - Lines 293-330: Pipeline configuration schemas + - Lines 1-100: Detection and classification response schemas + +### Processing Pipeline Examples +- **`trapdata/api/models/localization.py`**: + - Lines 13-60: `APIMothDetector` implementation showing API adaptation pattern + +### Key Configuration Values +From the analysis, these critical configuration values were extracted: +- **Model weights URL**: Lines 507-515 in `classification.py` +- **Labels path**: Lines 516-520 in `classification.py` +- **Input 
size**: Line 508 (`input_size = 128`)
+- **Normalization**: Line 509 (`normalization = imagenet_normalization`)
+- **Species count**: 29,176 species from model description
+- **Default taxon rank**: "SPECIES" from base class
+
+### Database Dependencies Removed
+These database-dependent components were identified and removed:
+- **`trapdata/db/models/queue.py`**: Lines 1-500 (entire queue system)
+- **`trapdata/db/models/detections.py`**: `save_classified_objects()` function
+- **Database path parameters**: Throughout `base.py` and classification classes
+- **Queue managers**: `UnclassifiedObjectQueue`, `DetectedObjectQueue` references
+
+This analysis allowed for the creation of a streamlined implementation that preserves all the essential functionality while eliminating the complex database and queue infrastructure.
+
+## Next Steps
+
+1. **Testing**: Run end-to-end tests with real images
+2. **Performance Optimization**: Tune batch sizes and memory usage
+3. **Error Handling**: Add robust error handling for edge cases
+4. **Documentation**: Add detailed API documentation
+5. **Docker Integration**: Update Docker configurations if needed
+
+The implementation successfully provides a clean, maintainable global moth classifier that can process 29,176+ species without the complexity of the original trapdata system.
diff --git a/processing_services/moths/api/api/__init__.py b/processing_services/moths/api/api/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/processing_services/moths/api/api/algorithms.py b/processing_services/moths/api/api/algorithms.py
new file mode 100644
index 000000000..8a80038dd
--- /dev/null
+++ b/processing_services/moths/api/api/algorithms.py
@@ -0,0 +1,431 @@
+import datetime
+import logging
+import math
+import random
+
+import torch
+
+from .schemas import (
+    AlgorithmCategoryMapResponse,
+    AlgorithmConfigResponse,
+    AlgorithmReference,
+    BoundingBox,
+    ClassificationResponse,
+    Detection,
+    SourceImage,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+SAVED_MODELS = {}
+
+
+def get_best_device() -> str:
+    """
+    Returns the best available device for running the model.
+
+    MPS is not supported by the current algorithms.
+    """
+    if torch.cuda.is_available():
+        return f"cuda:{torch.cuda.current_device()}"
+    else:
+        return "cpu"
+
+
+class Algorithm:
+    algorithm_config_response: AlgorithmConfigResponse
+
+    def compile(self):
+        raise NotImplementedError("Subclasses must implement the compile method")
+
+    def run(self, inputs: list[SourceImage] | list[Detection]) -> list[Detection]:
+        raise NotImplementedError("Subclasses must implement the run method")
+
+    def get_category_map(self) -> AlgorithmCategoryMapResponse:
+        return AlgorithmCategoryMapResponse(
+            data=[],
+            labels=[],
+            version="v1",
+            description="A model without labels.",
+            uri=None,
+        )
+
+    def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
+        return AlgorithmConfigResponse(
+            name="Base Algorithm",
+            key="base",
+            task_type="base",
+            description="A base class for all algorithms.",
+            version=1,
+            version_name="v1",
+            category_map=self.get_category_map(),
+        )
+
+    def __init__(self):
+        self.algorithm_config_response = self.get_algorithm_config_response()
+
+
+class ZeroShotObjectDetector(Algorithm):
+    """
+    Huggingface Zero-Shot Object Detection model.
+    Produces both a bounding box and a classification for each detection.
+    The classification is based on the candidate labels.
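+
+    Example (an illustrative sketch, not part of the original file; assumes a
+    SourceImage that has already been opened so that `_pil` is populated):
+
+        detector = ZeroShotObjectDetector()
+        detector.candidate_labels = ["moth", "insect"]
+        detector.compile()  # loads and caches google/owlv2-base-patch16-ensemble
+        detections = detector.run([source_image])
+        print(detections[0].bbox, detections[0].classifications[0].classification)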
+ """ + + candidate_labels: list[str] = ["insect"] + + def compile(self, device: str | None = None): + saved_models_key = "zero_shot_object_detector" # generate a key for each uniquely compiled algorithm + + if saved_models_key not in SAVED_MODELS: + from transformers import pipeline + + device_choice = device or get_best_device() + device_index = int(device_choice.split(":")[-1]) if ":" in device_choice else -1 + logger.info(f"Compiling {self.algorithm_config_response.name} on device {device_choice}...") + checkpoint = "google/owlv2-base-patch16-ensemble" + self.model = pipeline( + model=checkpoint, + task="zero-shot-object-detection", + use_fast=True, + device=device_index, + ) + SAVED_MODELS[saved_models_key] = self.model + else: + logger.info(f"Using saved model for {self.algorithm_config_response.name}...") + self.model = SAVED_MODELS[saved_models_key] + + def run(self, source_images: list[SourceImage], intermediate=False) -> list[Detection]: + detector_responses: list[Detection] = [] + for source_image in source_images: + if source_image.width and source_image.height and source_image._pil: + start_time = datetime.datetime.now() + logger.info("Predicting...") + if not self.candidate_labels: + raise ValueError("No candidate labels are provided during inference.") + logger.info(f"Predicting with candidate labels: {self.candidate_labels}") + predictions = self.model(source_image._pil, candidate_labels=self.candidate_labels) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + for prediction in predictions: + logger.info("Prediction: %s", prediction) + bbox = BoundingBox( + x1=prediction["box"]["xmin"], + x2=prediction["box"]["xmax"], + y1=prediction["box"]["ymin"], + y2=prediction["box"]["ymax"], + ) + cropped_image_pil = source_image._pil.crop((bbox.x1, bbox.y1, bbox.x2, bbox.y2)) + detection = Detection( + id=f"{source_image.id}-crop-{bbox.x1}-{bbox.y1}-{bbox.x2}-{bbox.y2}", + url=source_image.url, # @TODO: ideally, should save cropped image at separate url + width=cropped_image_pil.width, + height=cropped_image_pil.height, + timestamp=datetime.datetime.now(), + source_image=source_image, + bbox=bbox, + inference_time=elapsed_time, + algorithm=AlgorithmReference( + name=self.algorithm_config_response.name, + key=self.algorithm_config_response.key, + ), + classifications=[ + ClassificationResponse( + classification=prediction["label"], + labels=[prediction["label"]], + scores=[prediction["score"]], + logits=[prediction["score"]], + inference_time=elapsed_time, + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference( + name=self.algorithm_config_response.name, + key=self.algorithm_config_response.key, + ), + terminal=not intermediate, + ) + ], + ) + detection._pil = cropped_image_pil + detector_responses.append(detection) + else: + raise ValueError(f"Source image {source_image.id} does not have width and height attributes.") + + return detector_responses + + def get_category_map(self) -> AlgorithmCategoryMapResponse: + return AlgorithmCategoryMapResponse( + data=[{"index": i, "label": label} for i, label in enumerate(self.candidate_labels)], + labels=self.candidate_labels, + version="v1", + description="Candidate labels used for zero-shot object detection.", + uri=None, + ) + + def get_algorithm_config_response(self) -> AlgorithmConfigResponse: + return AlgorithmConfigResponse( + name="Zero Shot Object Detector", + key="zero-shot-object-detector", + task_type="detection", + description=( + "Huggingface Zero Shot Object Detection 
model." + "Produces both a bounding box and a candidate label classification for each detection." + ), + version=1, + version_name="v1", + category_map=self.get_category_map(), + ) + + +class HFImageClassifier(Algorithm): + """ + A local classifier that uses the Hugging Face pipeline to classify images. + """ + + model_name: str = "google/vit-base-patch16-224" # Vision Transformer model trained on ImageNet-1k + + def compile(self): + saved_models_key = "hf_image_classifier" # generate a key for each uniquely compiled algorithm + + if saved_models_key not in SAVED_MODELS: + from transformers import pipeline + + logger.info(f"Compiling {self.algorithm_config_response.name} from scratch...") + self.model = pipeline("image-classification", model=self.model_name, device=get_best_device()) + SAVED_MODELS[saved_models_key] = self.model + else: + logger.info(f"Using saved model for {self.algorithm_config_response.name}...") + self.model = SAVED_MODELS[saved_models_key] + + def run(self, detections: list[Detection]) -> list[Detection]: + detections_to_return: list[Detection] = [] + start_time = datetime.datetime.now() + + opened_cropped_images = [detection._pil for detection in detections] # type: ignore + + # Process the entire batch of cropped images at once + results = self.model(images=opened_cropped_images) + + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + for detection, preds in zip(detections, results): + labels = [pred["label"] for pred in preds] + scores = [pred["score"] for pred in preds] + max_score_index = scores.index(max(scores)) + classification = labels[max_score_index] + logger.info(f"Classification: {classification}") + logger.info(f"labels: {labels}") + logger.info(f"scores: {scores}") + + existing_classifications = detection.classifications + + detection_with_classification = detection.copy(deep=True) + detection_with_classification.classifications = existing_classifications + [ + ClassificationResponse( + classification=classification, + labels=labels, + scores=scores, + logits=scores, + inference_time=elapsed_time, + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference( + name=self.algorithm_config_response.name, key=self.algorithm_config_response.key + ), + terminal=True, + ) + ] + + detections_to_return.append(detection_with_classification) + + return detections_to_return + + def get_category_map(self) -> AlgorithmCategoryMapResponse: + """ + Extract the category map from the model. + Returns an AlgorithmCategoryMapResponse with labels, data, and model information. + """ + from transformers.models.auto.configuration_auto import AutoConfig + + logger.info(f"Loading configuration for {self.model_name}") + config = AutoConfig.from_pretrained(self.model_name) + + # Extract label information + if not hasattr(config, "id2label") or not config.id2label: + raise ValueError( + f"Cannot create category map for model {self.model_name}, no id2label mapping found in config" + ) + else: + # Sort labels by index + # Ensure keys are strings for consistent access + id2label: dict[str, str] = {str(k): v for k, v in config.id2label.items()} + indices = sorted([int(k) for k in id2label.keys()]) + + # Create labels and data + labels = [id2label[str(i)] for i in indices] + data = [{"label": label, "index": idx} for idx, label in zip(indices, labels)] + + # Build description + description_text = ( + f"Vision Transformer model trained on ImageNet-1k. " + f"Contains {len(labels)} object classes. 
Model: {self.model_name}" + ) + + return AlgorithmCategoryMapResponse( + data=data, + labels=labels, + version="ImageNet-1k", + description=description_text, + uri=f"https://huggingface.co/{self.model_name}", + ) + + def get_algorithm_config_response(self) -> AlgorithmConfigResponse: + return AlgorithmConfigResponse( + name="HF Image Classifier", + key="hf-image-classifier", + task_type="classification", + description="HF ViT for image classification.", + version=1, + version_name="v1", + category_map=self.get_category_map(), + ) + + +class RandomSpeciesClassifier(Algorithm): + """ + A local classifier that produces random butterfly species classifications. + """ + + def compile(self): + pass + + def _make_random_prediction( + self, + terminal: bool = True, + max_labels: int = 2, + ) -> ClassificationResponse: + assert self.algorithm_config_response.category_map is not None + category_labels = self.algorithm_config_response.category_map.labels + logits = [random.random() for _ in category_labels] + softmax = [math.exp(logit) / sum([math.exp(logit) for logit in logits]) for logit in logits] + top_class = category_labels[softmax.index(max(softmax))] + return ClassificationResponse( + classification=top_class, + labels=category_labels if len(category_labels) <= max_labels else None, + scores=softmax, + logits=logits, + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference( + name=self.algorithm_config_response.name, + key=self.algorithm_config_response.key, + ), + terminal=terminal, + ) + + def run(self, detections: list[Detection]) -> list[Detection]: + detections_to_return: list[Detection] = [] + for detection in detections: + detection_with_classification = detection.copy(deep=True) + detection_with_classification.classifications = [self._make_random_prediction(terminal=True)] + detections_to_return.append(detection_with_classification) + return detections_to_return + + algorithm_config_response = AlgorithmConfigResponse( + name="Random species classifier", + key="random-species-classifier", + task_type="classification", + description="A random species classifier", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/random-species-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + { + "index": 0, + "gbif_key": "1234", + "label": "Vanessa atalanta", + "source": "manual", + "taxon_rank": "SPECIES", + }, + { + "index": 1, + "gbif_key": "4543", + "label": "Vanessa cardui", + "source": "manual", + "taxon_rank": "SPECIES", + }, + { + "index": 2, + "gbif_key": "7890", + "label": "Vanessa itea", + "source": "manual", + "taxon_rank": "SPECIES", + }, + ], + labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"], + version="v1", + description="A simple species classifier", + uri="https://huggingface.co/RolnickLab/random-species-classifier", + ), + ) + + +class ConstantClassifier(Algorithm): + """ + A local classifier that always returns a constant species classification. 
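+
+    Usage sketch (illustrative only; assumes `detections` were produced by a
+    detector stage; useful as a stand-in classifier during testing):
+
+        classifier = ConstantClassifier()
+        classifier.compile()  # no-op; there is nothing to load
+        classified = classifier.run(detections)
+        # every returned detection carries a single terminal "Moth" classification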
+ """ + + def compile(self): + pass + + def _make_constant_prediction( + self, + terminal: bool = True, + ) -> ClassificationResponse: + assert self.algorithm_config_response.category_map is not None + labels = self.algorithm_config_response.category_map.labels + return ClassificationResponse( + classification=labels[0], + labels=labels, + scores=[0.9], # Constant score for each detection + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference( + name=self.algorithm_config_response.name, + key=self.algorithm_config_response.key, + ), + terminal=terminal, + ) + + def run(self, detections: list[Detection]) -> list[Detection]: + detections_to_return: list[Detection] = [] + for detection in detections: + detection_with_classification = detection.copy(deep=True) + detection_with_classification.classifications = [self._make_constant_prediction(terminal=True)] + detections_to_return.append(detection_with_classification) + return detections_to_return + + algorithm_config_response = AlgorithmConfigResponse( + name="Constant classifier", + key="constant-classifier", + task_type="classification", + description="Always return a classification of 'Moth'", + version=1, + version_name="v1", + uri="https://huggingface.co/RolnickLab/constant-classifier", + category_map=AlgorithmCategoryMapResponse( + data=[ + { + "index": 0, + "gbif_key": "1234", + "label": "Moth", + "source": "manual", + "taxon_rank": "SUPERFAMILY", + } + ], + labels=["Moth"], + version="v1", + description="A classifier that always returns 'Moth'", + uri="https://huggingface.co/RolnickLab/constant-classifier", + ), + ) diff --git a/processing_services/moths/api/api/api.py b/processing_services/moths/api/api/api.py new file mode 100644 index 000000000..0396af5e2 --- /dev/null +++ b/processing_services/moths/api/api/api.py @@ -0,0 +1,261 @@ +""" +Fast API interface for processing images through the localization and classification pipelines. 
+""" + +import logging + +import fastapi + +from .pipelines import ( + Pipeline, + ZeroShotHFClassifierPipeline, + ZeroShotObjectDetectorPipeline, + ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, + ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, +) +from .schemas import ( + AlgorithmConfigResponse, + Detection, + DetectionRequest, + PipelineRequest, + PipelineRequestConfigParameters, + PipelineResultsResponse, + ProcessingServiceInfoResponse, + SourceImage, +) +from .utils import is_base64, is_url + +# Configure root logger +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", +) + +# Get the root logger +logger = logging.getLogger(__name__) + +app = fastapi.FastAPI() + + +pipelines: list[type[Pipeline]] = [ + ZeroShotHFClassifierPipeline, + ZeroShotObjectDetectorPipeline, + ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, +] +pipeline_choices: dict[str, type[Pipeline]] = { + pipeline.config.slug: pipeline for pipeline in pipelines +} +algorithm_choices: dict[str, AlgorithmConfigResponse] = { + algorithm.key: algorithm + for pipeline in pipelines + for algorithm in pipeline.config.algorithms +} + +# ----------- +# API endpoints +# ----------- + + +@app.get("/") +async def root(): + return fastapi.responses.RedirectResponse("/docs") + + +@app.get("/info", tags=["services"]) +async def info() -> ProcessingServiceInfoResponse: + info = ProcessingServiceInfoResponse( + name="Custom ML Backend", + description=("A template for running custom models locally."), + pipelines=[pipeline.config for pipeline in pipelines], + # algorithms=list(algorithm_choices.values()), + ) + return info + + +# Check if the server is online +@app.get("/livez", tags=["health checks"]) +async def livez(): + return fastapi.responses.JSONResponse(status_code=200, content={"status": True}) + + +# Check if the pipelines are ready to process data +@app.get("/readyz", tags=["health checks"]) +async def readyz(): + """ + Check if the server is ready to process data. + + Returns a list of pipeline slugs that are online and ready to process data. + @TODO may need to simplify this to just return True/False. Pipeline algorithms will likely be loaded into memory + on-demand when the pipeline is selected. 
+ """ + if pipeline_choices: + return fastapi.responses.JSONResponse( + status_code=200, content={"status": list(pipeline_choices.keys())} + ) + else: + return fastapi.responses.JSONResponse(status_code=503, content={"status": []}) + + +@app.post("/process", tags=["services"]) +async def process(data: PipelineRequest) -> PipelineResultsResponse: + pipeline_slug = data.pipeline + request_config = data.config + + source_images = [SourceImage(**img.model_dump()) for img in data.source_images] + # Open source images once before processing + for img in source_images: + img.open(raise_exception=True) + + detections = create_detections( + source_images=source_images, + detection_requests=data.detections, + ) + + try: + Pipeline = pipeline_choices[pipeline_slug] + except KeyError: + raise fastapi.HTTPException( + status_code=422, detail=f"Invalid pipeline choice: {pipeline_slug}" + ) + + pipeline_request_config = ( + PipelineRequestConfigParameters(**dict(request_config)) + if request_config + else {} + ) + try: + pipeline = Pipeline( + source_images=source_images, + request_config=pipeline_request_config, + existing_detections=detections, + ) + pipeline.compile() + except Exception as e: + logger.error(f"Error compiling pipeline: {e}") + raise fastapi.HTTPException(status_code=422, detail=f"{e}") + + try: + response = pipeline.run() + except Exception as e: + logger.error(f"Error running pipeline: {e}") + raise fastapi.HTTPException(status_code=422, detail=f"{e}") + + return response + + +# ----------- +# Helper functions +# ----------- + + +def create_detections( + source_images: list[SourceImage], + detection_requests: list[DetectionRequest] | None, +): + if not detection_requests: + return [] + + # Group detection requests by source image id + source_image_map = {img.id: img for img in source_images} + grouped_detection_requests = {} + for request in detection_requests: + if request.source_image.id not in grouped_detection_requests: + grouped_detection_requests[request.source_image.id] = [] + grouped_detection_requests[request.source_image.id].append(request) + + # Process each source image and its detection requests + detections = [] + for source_image_id, requests in grouped_detection_requests.items(): + if source_image_id not in source_image_map: + raise ValueError( + f"A detection request for source image {source_image_id} was received, " + "but no source image with that ID was provided." + ) + + logger.info( + f"Processing existing detections for source image {source_image_id}." + ) + + for request in requests: + source_image = source_image_map[source_image_id] + cropped_image_id = f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" + if not request.crop_image_url: + logger.info( + "Detection request does not have a crop_image_url, crop the original source image." + ) + assert ( + source_image._pil is not None + ), "Source image must be opened before cropping." + cropped_image_pil = source_image._pil.crop( + (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) + ) + else: + try: + logger.info( + f"Opening existing cropped image from {request.crop_image_url}." 
+ ) + if is_url(request.crop_image_url): + cropped_image = SourceImage( + id=cropped_image_id, + url=request.crop_image_url, + ) + elif is_base64(request.crop_image_url): + logger.info("Decoding base64 cropped image.") + cropped_image = SourceImage( + id=cropped_image_id, + b64=request.crop_image_url, + ) + else: + # Must be a filepath + cropped_image = SourceImage( + id=cropped_image_id, + filepath=request.crop_image_url, + ) + cropped_image.open(raise_exception=True) + cropped_image_pil = cropped_image._pil + except Exception as e: + logger.warning(f"Error opening cropped image: {e}") + logger.info( + f"Falling back to cropping the original source image {source_image_id}." + ) + assert ( + source_image._pil is not None + ), "Source image must be opened before cropping." + cropped_image_pil = source_image._pil.crop( + ( + request.bbox.x1, + request.bbox.y1, + request.bbox.x2, + request.bbox.y2, + ) + ) + + # Create a Detection object + det = Detection( + source_image=SourceImage( + id=source_image.id, + url=source_image.url, + ), + bbox=request.bbox, + id=cropped_image_id, + url=request.crop_image_url or source_image.url, + algorithm=request.algorithm, + ) + # Set the _pil attribute to the cropped image + det._pil = cropped_image_pil + detections.append(det) + logger.info( + f"Created detection {det.id} for source image {source_image_id}." + ) + + return detections + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=2000) diff --git a/processing_services/moths/api/api/base.py b/processing_services/moths/api/api/base.py new file mode 100644 index 000000000..bb001ec46 --- /dev/null +++ b/processing_services/moths/api/api/base.py @@ -0,0 +1,241 @@ +""" +Simplified base classes for inference models without database dependencies. +Adapted from trapdata.ml.models.base but streamlined for processing service use. +""" + +import json +import logging +from typing import Any, Dict, Optional + +import torch +import torchvision.transforms + +from .utils import get_best_device, get_or_download_file + +logger = logging.getLogger(__name__) + + +# Standard normalization transforms +imagenet_normalization = torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], # RGB + std=[0.229, 0.224, 0.225], # RGB +) + +tensorflow_normalization = torchvision.transforms.Normalize( + mean=[0.5, 0.5, 0.5], # RGB + std=[0.5, 0.5, 0.5], # RGB +) + + +class SimplifiedInferenceBase: + """ + Simplified base class for inference models without database or queue dependencies. 
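+
+    Subclassing sketch (illustrative only; the URLs are placeholders, and a
+    concrete subclass must implement get_model, predict_batch and
+    post_process_batch):
+
+        class MyClassifier(SimplifiedInferenceBase):
+            name = "My Classifier"
+            weights_path = "https://example.com/weights.pth"
+            labels_path = "https://example.com/labels.json"
+
+            def get_model(self) -> torch.nn.Module:
+                ...  # build the network, load self.weights, move to self.device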
+ """ + + name: str = "Unknown Inference Model" + description: str = "" + weights_path: Optional[str] = None + labels_path: Optional[str] = None + category_map: Dict[int, str] = {} + num_classes: Optional[int] = None + default_taxon_rank: str = "SPECIES" + normalization = tensorflow_normalization + batch_size: int = 4 + device: Optional[str] = None + + def __init__(self, **kwargs): + # Override any class attributes with provided kwargs + for k, v in kwargs.items(): + setattr(self, k, v) + + logger.info(f"Initializing simplified inference class {self.name}") + + self.device = self.device or get_best_device() + self.category_map = self.get_labels(self.labels_path) + self.num_classes = self.num_classes or len(self.category_map) + self.weights = self.get_weights(self.weights_path) + self.transforms = self.get_transforms() + + logger.info( + f"Loading model for {self.name} with {len(self.category_map or [])} categories" + ) + self.model = self.get_model() + + @classmethod + def get_key(cls) -> str: + """Generate a unique key for this algorithm.""" + if hasattr(cls, "key") and cls.key: + return cls.key + else: + return cls.name.lower().replace(" ", "-").replace("/", "-") + + def get_weights(self, weights_path: Optional[str]) -> Optional[str]: + """Download and cache model weights.""" + if weights_path: + logger.info(f"⬇️ Downloading model weights from: {weights_path}") + weights_file = str(get_or_download_file(weights_path, tempdir_prefix="models")) + logger.info(f"✅ Model weights downloaded to: {weights_file}") + return weights_file + else: + logger.warning(f"No weights specified for model {self.name}") + return None + + def get_labels(self, labels_path: Optional[str]) -> Dict[int, str]: + """Download and load category labels.""" + if not labels_path: + return {} + + logger.info(f"⬇️ Downloading category labels from: {labels_path}") + local_path = get_or_download_file(labels_path, tempdir_prefix="models") + logger.info(f"📝 Loading category labels from: {local_path}") + + with open(local_path) as f: + labels = json.load(f) + + # Convert label->index mapping to index->label mapping + index_to_label = {index: label for label, index in labels.items()} + logger.info(f"✅ Loaded {len(index_to_label)} category labels") + return index_to_label + + def get_transforms(self) -> torchvision.transforms.Compose: + """Get image preprocessing transforms.""" + return torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((224, 224)), + torchvision.transforms.ToTensor(), + self.normalization, + ] + ) + + def get_model(self) -> torch.nn.Module: + """ + Load and return the PyTorch model. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement get_model()") + + def predict_batch(self, batch: torch.Tensor) -> torch.Tensor: + """ + Run inference on a batch of images. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement predict_batch()") + + def post_process_batch(self, logits: torch.Tensor) -> Any: + """ + Post-process model outputs. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement post_process_batch()") + + +class ResNet50Base(SimplifiedInferenceBase): + """ + Base class for ResNet50-based models. 
+ """ + + input_size: int = 224 + normalization = imagenet_normalization + + def get_transforms(self) -> torchvision.transforms.Compose: + """Get ResNet50-specific transforms.""" + return torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((self.input_size, self.input_size)), + torchvision.transforms.ToTensor(), + self.normalization, + ] + ) + + def get_model(self) -> torch.nn.Module: + """Load ResNet50 model with custom classifier.""" + import torchvision.models as models + + logger.info("🏗️ Creating ResNet50 model architecture...") + # Create ResNet50 backbone + model = models.resnet50(weights=None) + + # Replace final classifier layer + if self.num_classes is None: + raise ValueError("num_classes must be set before loading model") + logger.info(f"🔧 Setting up classifier layer for {self.num_classes} classes...") + model.fc = torch.nn.Linear(model.fc.in_features, self.num_classes) + + # Load pretrained weights + if self.weights: + logger.info(f"📂 Loading pretrained weights from: {self.weights}") + checkpoint = torch.load(self.weights, map_location=self.device) + + # Handle different checkpoint formats + if "model_state_dict" in checkpoint: + logger.info("📥 Loading state dict from 'model_state_dict' key...") + model.load_state_dict(checkpoint["model_state_dict"]) + elif "state_dict" in checkpoint: + logger.info("📥 Loading state dict from 'state_dict' key...") + model.load_state_dict(checkpoint["state_dict"]) + else: + logger.info("📥 Loading state dict directly...") + model.load_state_dict(checkpoint) + logger.info("✅ Model weights loaded successfully!") + else: + logger.warning("⚠️ No pretrained weights provided - using random initialization") + + logger.info(f"📱 Moving model to device: {self.device}") + model = model.to(self.device) + model.eval() + logger.info("✅ Model ready for inference!") + return model + + def predict_batch(self, batch: torch.Tensor) -> torch.Tensor: + """Run inference on batch.""" + with torch.no_grad(): + batch = batch.to(self.device) + outputs = self.model(batch) + return outputs + + def post_process_batch(self, logits: torch.Tensor) -> list: + """Convert logits to predictions.""" + probabilities = torch.softmax(logits, dim=1) + predictions = [] + + for prob_tensor in probabilities: + prob_list = prob_tensor.cpu().numpy().tolist() + predictions.append( + { + "scores": prob_list, + "logits": logits[len(predictions)].cpu().numpy().tolist(), + } + ) + + return predictions + + +class TimmResNet50Base(ResNet50Base): + """ + Base class for timm ResNet50-based models. + """ + + def get_model(self) -> torch.nn.Module: + """Load timm ResNet50 model.""" + import timm + + # Create timm ResNet50 model + model = timm.create_model( + "resnet50", pretrained=False, num_classes=self.num_classes + ) + + # Load pretrained weights + if self.weights: + checkpoint = torch.load(self.weights, map_location=self.device) + + # Handle different checkpoint formats + if "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"]) + elif "state_dict" in checkpoint: + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + + model = model.to(self.device) + model.eval() + return model diff --git a/processing_services/moths/api/api/global_moth_classifier.py b/processing_services/moths/api/api/global_moth_classifier.py new file mode 100644 index 000000000..9d6abe8a0 --- /dev/null +++ b/processing_services/moths/api/api/global_moth_classifier.py @@ -0,0 +1,254 @@ +""" +Global Moth Classifier algorithm implementation. 
+Simplified version of trapdata.api.models.classification.MothClassifierGlobal
+adapted for the processing service framework.
+"""
+
+import datetime
+import logging
+from typing import List
+
+import torch
+import torchvision.transforms
+
+from .algorithms import Algorithm
+from .base import TimmResNet50Base, imagenet_normalization
+from .schemas import (
+    AlgorithmCategoryMapResponse,
+    AlgorithmConfigResponse,
+    AlgorithmReference,
+    ClassificationResponse,
+    Detection,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class GlobalMothClassifier(Algorithm, TimmResNet50Base):
+    """
+    Global Moth Species Classifier.
+
+    Simplified version of the trapdata GlobalMothSpeciesClassifier
+    that works without database dependencies.
+    """
+
+    name = "Global Species Classifier - Aug 2024"
+    description = (
+        "Trained on August 28th, 2024 for 29,176 species. "
+        "https://wandb.ai/moth-ai/global-moth-classifier/runs/h0cuqrbc/overview"
+    )
+    weights_path = (
+        "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
+        "global_resnet50_20240828_b06d3b3a.pth"
+    )
+    labels_path = (
+        "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/"
+        "global_category_map_with_names_20240828.json"
+    )
+
+    # Model configuration
+    input_size = 128
+    normalization = imagenet_normalization
+    default_taxon_rank = "SPECIES"
+    batch_size = 4
+
+    def __init__(self, **kwargs):
+        """Initialize the global moth classifier."""
+        # Initialize Algorithm parent class
+        Algorithm.__init__(self)
+
+        # Store kwargs for later use in compile()
+        self._init_kwargs = kwargs
+
+        # Initialize basic attributes without loading model
+        self.model = None
+        self.transforms = None
+        self.category_map = {}  # Empty dict for now
+        self.num_classes = 29176  # Known number of classes
+
+        logger.info(f"Initialized {self.name} (model loading deferred to compile())")
+
+    @property
+    def algorithm_config_response(self) -> AlgorithmConfigResponse:
+        """Get algorithm configuration for API response."""
+        if not hasattr(self, '_algorithm_config_response'):
+            # Create a basic config response before compilation
+            self._algorithm_config_response = AlgorithmConfigResponse(
+                name=self.name,
+                key=self.get_key(),
+                task_type="classification",
+                description=self.description,
+                version=1,
+                version_name="v1",
+                category_map=AlgorithmCategoryMapResponse(
+                    data=[],
+                    labels=[],
+                    version="v1",
+                    description="Global moth species classifier (not yet compiled)",
+                    uri=self.labels_path,
+                ),
+                uri=self.weights_path,
+            )
+        return self._algorithm_config_response
+
+    @algorithm_config_response.setter
+    def algorithm_config_response(self, value: AlgorithmConfigResponse):
+        """Set algorithm configuration response."""
+        self._algorithm_config_response = value
+
+    def compile(self):
+        """Load model weights and initialize transforms (called by pipeline)."""
+        if self.model is not None:
+            logger.info("Model already compiled, skipping...")
+            return
+
+        logger.info(f"🔧 Compiling {self.name}...")
+        logger.info(f"   📊 Expected classes: {self.num_classes}")
+        logger.info(f"   🏷️ Labels URL: {self.labels_path}")
+        logger.info(f"   ⚖️ Weights URL: {self.weights_path}")
+
+        # Initialize the TimmResNet50Base now (this will download weights/labels)
+        logger.info("   📥 Downloading model weights and labels...")
+        TimmResNet50Base.__init__(self, **self._init_kwargs)
+
+        # Set algorithm config response
+        logger.info("   📋 Setting up algorithm configuration...")
+        self.algorithm_config_response = self.get_algorithm_config_response()
+
+        logger.info(f"✅ {self.name} compiled successfully!")
+        logger.info(f"   📊 Loaded {len(self.category_map)} species categories")
+        logger.info(f"   🔧 Model device: {getattr(self, 'device', 'unknown')}")
+        logger.info(f"   🖼️ Input size: {self.input_size}x{self.input_size}")
+
+    def get_transforms(self) -> torchvision.transforms.Compose:
+        """Get transforms specific to this model."""
+        return torchvision.transforms.Compose(
+            [
+                torchvision.transforms.Resize((self.input_size, self.input_size)),
+                torchvision.transforms.ToTensor(),
+                self.normalization,
+            ]
+        )
+
+    def run(self, detections: List[Detection]) -> List[Detection]:
+        """
+        Run classification on a list of detections.
+
+        Args:
+            detections: List of Detection objects with cropped images
+
+        Returns:
+            List of Detection objects with added classifications
+        """
+        if not detections:
+            return []
+
+        # Ensure model is compiled
+        if self.model is None:
+            raise RuntimeError("Model not compiled. Call compile() first.")
+
+        logger.info(f"Running {self.name} on {len(detections)} detections")
+
+        # Process detections in batches
+        classified_detections = []
+
+        for i in range(0, len(detections), self.batch_size):
+            batch_detections = detections[i : i + self.batch_size]
+            batch_images = []
+            valid_detections = []  # detections with a crop, kept aligned with batch_images
+
+            # Prepare batch of images, skipping detections without a crop so
+            # that images and detections stay aligned for the zip() below
+            for detection in batch_detections:
+                if detection._pil:
+                    # Convert to RGB if needed
+                    if detection._pil.mode != "RGB":
+                        img = detection._pil.convert("RGB")
+                    else:
+                        img = detection._pil
+                    batch_images.append(img)
+                    valid_detections.append(detection)
+                else:
+                    logger.warning(f"Detection {detection.id} has no PIL image")
+
+            if not batch_images:
+                continue
+
+            # Transform images
+            if self.transforms is None:
+                raise RuntimeError("Transforms not initialized. Call compile() first.")
+            batch_tensor = torch.stack([self.transforms(img) for img in batch_images])
+
+            # Run inference
+            start_time = datetime.datetime.now()
+            predictions = self.predict_batch(batch_tensor)
+            processed_predictions = self.post_process_batch(predictions)
+            end_time = datetime.datetime.now()
+
+            inference_time = (end_time - start_time).total_seconds() / len(batch_images)
+
+            # Add classifications to detections
+            for detection, prediction in zip(valid_detections, processed_predictions):
+                # Get best prediction
+                best_score = max(prediction["scores"])
+                best_idx = prediction["scores"].index(best_score)
+                best_label = self.category_map.get(best_idx, f"class_{best_idx}")
+
+                classification = ClassificationResponse(
+                    classification=best_label,
+                    labels=[best_label],
+                    scores=[best_score],
+                    logits=prediction["logits"],
+                    inference_time=inference_time,
+                    timestamp=datetime.datetime.now(),
+                    algorithm=AlgorithmReference(
+                        name=self.name,
+                        key=self.get_key(),
+                    ),
+                    terminal=True,
+                )
+
+                # Add classification to detection
+                detection_with_classification = detection.copy(deep=True)
+                detection_with_classification.classifications = [classification]
+                classified_detections.append(detection_with_classification)
+
+        logger.info(f"Classified {len(classified_detections)} detections")
+        return classified_detections
+
+    def get_category_map(self) -> AlgorithmCategoryMapResponse:
+        """Get category map for API response."""
+        categories_sorted_by_index = sorted(
+            self.category_map.items(), key=lambda x: x[0]
+        )
+        categories_data = [
+            {
+                "index": index,
+                "label": label,
+                "taxon_rank": self.default_taxon_rank,
+            }
+            for index, label in categories_sorted_by_index
+        ]
+        label_strings = [cat["label"] for cat in categories_data]
+
+        return AlgorithmCategoryMapResponse(
+            data=categories_data,
+            labels=label_strings,
+            version="v1",
+            description=f"Global moth species classifier with {len(categories_data)} species",
+            uri=self.labels_path,
+        )
+
+    def get_algorithm_config_response(self) -> AlgorithmConfigResponse:
+        """Get algorithm configuration for API response."""
+        return AlgorithmConfigResponse(
+            name=self.name,
+            key=self.get_key(),
+            task_type="classification",
+            description=self.description,
+            version=1,
+            version_name="v1",
+            category_map=self.get_category_map(),
+            uri=self.weights_path,
+        )
+
+
diff --git a/processing_services/moths/api/api/pipelines.py b/processing_services/moths/api/api/pipelines.py
new file mode 100644
index 000000000..ceb395cf9
--- /dev/null
+++ b/processing_services/moths/api/api/pipelines.py
@@ -0,0 +1,457 @@
+import datetime
+import logging
+from typing import final
+
+from .algorithms import (
+    Algorithm,
+    ConstantClassifier,
+    HFImageClassifier,
+    RandomSpeciesClassifier,
+    ZeroShotObjectDetector,
+)
+from .global_moth_classifier import GlobalMothClassifier
+from .schemas import (
+    Detection,
+    DetectionResponse,
+    PipelineConfigResponse,
+    PipelineRequestConfigParameters,
+    PipelineResultsResponse,
+    SourceImage,
+    SourceImageResponse,
+)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class Pipeline:
+    """
+    A base class for defining and running a pipeline consisting of multiple stages.
+    Each stage is represented by an algorithm that processes inputs and produces
+    outputs. The pipeline is designed to handle batch processing using custom batch
+    sizes for each stage.
+
+    Attributes:
+        stages (list[Algorithm]): A list of algorithms representing the stages of
+            the pipeline in order of execution. Typically [Detector(), Classifier()].
+        batch_sizes (list[int]): A list of integers specifying the batch size for
+            each stage. For example, [1, 1] means that the detector can process one
+            source image at a time and the classifier can process one detection at a time.
+        config (PipelineConfigResponse): Pipeline metadata.
+    """
+
+    stages: list[Algorithm]
+    batch_sizes: list[int]
+    request_config: dict
+    config: PipelineConfigResponse
+
+    stages = []
+    batch_sizes = []
+    config = PipelineConfigResponse(
+        name="Base Pipeline",
+        slug="base",
+        description="A base class for all pipelines.",
+        version=1,
+        algorithms=[],
+    )
+
+    def __init__(
+        self,
+        source_images: list[SourceImage],
+        request_config: PipelineRequestConfigParameters | dict = {},
+        existing_detections: list[Detection] = [],
+        custom_batch_sizes: list[int] = [],
+    ):
+        self.source_images = source_images
+        self.request_config = (
+            request_config
+            if isinstance(request_config, dict)
+            else request_config.model_dump()
+        )
+        self.existing_detections = existing_detections
+
+        logger.info("Initializing algorithms....")
+        self.stages = self.stages or self.get_stages()
+        self.batch_sizes = (
+            custom_batch_sizes or self.batch_sizes or [1] * len(self.stages)
+        )
+        assert len(self.batch_sizes) == len(
+            self.stages
+        ), "Number of batch sizes must match the number of stages."
+
+    def get_stages(self) -> list[Algorithm]:
+        """
+        An optional function to initialize and return a list of algorithms/stages.
+        Any pipeline config values relevant to a particular algorithm should be passed or set here.
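+
+        A minimal sketch of an override (illustrative only; the algorithm names
+        are the ones used by the pipelines in this module):
+
+            def get_stages(self) -> list[Algorithm]:
+                detector = ZeroShotObjectDetector()
+                detector.candidate_labels = self.request_config.get(
+                    "candidate_labels", ["moth"]
+                )
+                return [detector, HFImageClassifier()]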
+ """ + return [] + + @final + def compile(self): + logger.info("Compiling algorithms....") + for stage_idx, stage in enumerate(self.stages): + logger.info( + f"[{stage_idx+1}/{len(self.stages)}] Compiling {stage.algorithm_config_response.name}..." + ) + stage.compile() + + def run(self) -> PipelineResultsResponse: + """ + This function must always return a PipelineResultsResponse object. + """ + raise NotImplementedError("Subclasses must implement") + + @final + def _batchify_inputs(self, inputs: list, batch_size: int) -> list[list]: + """ + Helper function to split the inputs into batches of the specified size. + """ + batched_inputs = [] + for i in range(0, len(inputs), batch_size): + start_id = i + end_id = i + batch_size + batched_inputs.append(inputs[start_id:end_id]) + return batched_inputs + + @final + def _get_detections( + self, + algorithm: Algorithm, + inputs: list[SourceImage] | list[Detection], + batch_size: int, + **kwargs, + ) -> list[Detection]: + """A single stage, step, or algorithm in a pipeline. Batchifies inputs and produces Detections as outputs.""" + outputs: list[Detection] = [] + batched_inputs = self._batchify_inputs(inputs, batch_size) + for batch in batched_inputs: + outputs.extend(algorithm.run(batch, **kwargs)) + return outputs + + @final + def _get_pipeline_response( + self, detections: list[Detection], elapsed_time: float + ) -> PipelineResultsResponse: + """ + Final stage of the pipeline to format the detections. + """ + detection_responses = [ + DetectionResponse( + source_image_id=detection.source_image.id, + bbox=detection.bbox, + inference_time=detection.inference_time, + algorithm=detection.algorithm, + timestamp=datetime.datetime.now(), + classifications=detection.classifications, + ) + for detection in detections + ] + source_image_responses = [ + SourceImageResponse(**image.model_dump()) for image in self.source_images + ] + + return PipelineResultsResponse( + pipeline=self.config.slug, # type: ignore + # algorithms={algorithm.key: algorithm for algorithm in self.config.algorithms}, + total_time=elapsed_time, + source_images=source_image_responses, + detections=detection_responses, + ) + + +class ZeroShotHFClassifierPipeline(Pipeline): + """ + A pipeline that uses the Zero Shot Object Detector to produce bounding boxes + and then applies the HuggingFace image classifier. 
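+
+    A minimal usage sketch (illustrative; the image URL is a placeholder):
+
+        pipeline = ZeroShotHFClassifierPipeline(
+            source_images=[SourceImage(id="1", url="https://example.com/moth.jpg")],
+            request_config={"candidate_labels": ["moth", "butterfly"]},
+        )
+        pipeline.compile()
+        results = pipeline.run()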
+ """ + + batch_sizes = [1, 1] + config = PipelineConfigResponse( + name="Zero Shot HF Classifier Pipeline", + slug="zero-shot-hf-classifier-pipeline", + description=("Zero Shot Object Detector with HF image classifier."), + version=1, + algorithms=[ + ZeroShotObjectDetector().algorithm_config_response, + HFImageClassifier().algorithm_config_response, + ], + ) + + def get_stages(self) -> list[Algorithm]: + zero_shot_object_detector = ZeroShotObjectDetector() + if "candidate_labels" in self.request_config: + logger.info( + "Setting candidate labels for zero shot object detector to %s", + self.request_config["candidate_labels"], + ) + zero_shot_object_detector.candidate_labels = self.request_config[ + "candidate_labels" + ] + self.config.algorithms = [ + zero_shot_object_detector.algorithm_config_response, + HFImageClassifier().algorithm_config_response, + ] + + return [zero_shot_object_detector, HFImageClassifier()] + + def run(self) -> PipelineResultsResponse: + start_time = datetime.datetime.now() + detections_with_candidate_labels: list[Detection] = [] + if self.existing_detections: + logger.info("[1/2] Skipping the localizer, use existing detections...") + detections_with_candidate_labels = self.existing_detections + else: + logger.info("[1/2] No existing detections, generating detections...") + detections_with_candidate_labels: list[Detection] = self._get_detections( + self.stages[0], + self.source_images, + self.batch_sizes[0], + intermediate=True, + ) + + logger.info("[2/2] Running the classifier...") + detections_with_classifications: list[Detection] = self._get_detections( + self.stages[1], detections_with_candidate_labels, self.batch_sizes[1] + ) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + pipeline_response: PipelineResultsResponse = self._get_pipeline_response( + detections_with_classifications, elapsed_time + ) + logger.info( + f"Successfully processed {len(detections_with_classifications)} detections." + ) + + return pipeline_response + + +class ZeroShotObjectDetectorPipeline(Pipeline): + """ + A pipeline that uses the HuggingFace zero shot object detector. + Produces both a bounding box and a classification for each detection. + The classification is based on the candidate labels provided in the request. 
+ """ + + batch_sizes = [1] + config = PipelineConfigResponse( + name="Zero Shot Object Detector Pipeline", + slug="zero-shot-object-detector-pipeline", + description=("Zero shot object detector (bbox and classification)."), + version=1, + algorithms=[ZeroShotObjectDetector().algorithm_config_response], + ) + + def get_stages(self) -> list[Algorithm]: + zero_shot_object_detector = ZeroShotObjectDetector() + if "candidate_labels" in self.request_config: + logger.info( + "Setting candidate labels for zero shot object detector to %s", + self.request_config["candidate_labels"], + ) + zero_shot_object_detector.candidate_labels = self.request_config[ + "candidate_labels" + ] + self.config.algorithms = [zero_shot_object_detector.algorithm_config_response] + + return [zero_shot_object_detector] + + def run(self) -> PipelineResultsResponse: + start_time = datetime.datetime.now() + logger.info("[1/1] Running the zero shot object detector...") + detections_with_classifications: list[Detection] = self._get_detections( + self.stages[0], self.source_images, self.batch_sizes[0] + ) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + pipeline_response: PipelineResultsResponse = self._get_pipeline_response( + detections_with_classifications, elapsed_time + ) + logger.info( + f"Successfully processed {len(detections_with_classifications)} detections." + ) + + return pipeline_response + + +class ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline(Pipeline): + """ + A pipeline that uses the HuggingFace zero shot object detector and a random species classifier. + """ + + batch_sizes = [1, 1] + config = PipelineConfigResponse( + name="Zero Shot Object Detector With Random Species Classifier Pipeline", + slug="zero-shot-object-detector-with-random-species-classifier-pipeline", + description=("HF zero shot object detector with random species classifier."), + version=1, + algorithms=[ + ZeroShotObjectDetector().algorithm_config_response, + RandomSpeciesClassifier().algorithm_config_response, + ], + ) + + def get_stages(self) -> list[Algorithm]: + zero_shot_object_detector = ZeroShotObjectDetector() + if "candidate_labels" in self.request_config: + zero_shot_object_detector.candidate_labels = self.request_config[ + "candidate_labels" + ] + + self.config.algorithms = [ + zero_shot_object_detector.algorithm_config_response, + RandomSpeciesClassifier().algorithm_config_response, + ] + + return [zero_shot_object_detector, RandomSpeciesClassifier()] + + def run(self) -> PipelineResultsResponse: + start_time = datetime.datetime.now() + detections: list[Detection] = [] + if self.existing_detections: + logger.info("[1/2] Skipping the localizer, use existing detections...") + detections = self.existing_detections + else: + logger.info("[1/2] No existing detections, generating detections...") + detections = self._get_detections( + self.stages[0], self.source_images, self.batch_sizes[0] + ) + + logger.info("[2/2] Running the classifier...") + detections_with_classifications: list[Detection] = self._get_detections( + self.stages[1], detections, self.batch_sizes[1] + ) + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + pipeline_response: PipelineResultsResponse = self._get_pipeline_response( + detections_with_classifications, elapsed_time + ) + logger.info( + f"Successfully processed {len(detections_with_classifications)} detections." 
+        )
+
+        return pipeline_response
+
+
+class ZeroShotObjectDetectorWithConstantClassifierPipeline(Pipeline):
+    """
+    A pipeline that uses the HuggingFace zero shot object detector and a constant classifier.
+    """
+
+    batch_sizes = [1, 1]
+    config = PipelineConfigResponse(
+        name="Zero Shot Object Detector With Constant Classifier Pipeline",
+        slug="zero-shot-object-detector-with-constant-classifier-pipeline",
+        description=("HF zero shot object detector with constant classifier."),
+        version=1,
+        algorithms=[
+            ZeroShotObjectDetector().algorithm_config_response,
+            ConstantClassifier().algorithm_config_response,
+        ],
+    )
+
+    def get_stages(self) -> list[Algorithm]:
+        zero_shot_object_detector = ZeroShotObjectDetector()
+        if "candidate_labels" in self.request_config:
+            zero_shot_object_detector.candidate_labels = self.request_config[
+                "candidate_labels"
+            ]
+
+        self.config.algorithms = [
+            zero_shot_object_detector.algorithm_config_response,
+            ConstantClassifier().algorithm_config_response,
+        ]
+
+        return [zero_shot_object_detector, ConstantClassifier()]
+
+    def run(self) -> PipelineResultsResponse:
+        start_time = datetime.datetime.now()
+        detections: list[Detection] = []
+        if self.existing_detections:
+            logger.info("[1/2] Skipping the localizer, using existing detections...")
+            detections = self.existing_detections
+        else:
+            logger.info("[1/2] No existing detections, generating detections...")
+            detections = self._get_detections(
+                self.stages[0], self.source_images, self.batch_sizes[0]
+            )
+
+        logger.info("[2/2] Running the classifier...")
+        detections_with_classifications: list[Detection] = self._get_detections(
+            self.stages[1], detections, self.batch_sizes[1]
+        )
+        end_time = datetime.datetime.now()
+        elapsed_time = (end_time - start_time).total_seconds()
+        pipeline_response: PipelineResultsResponse = self._get_pipeline_response(
+            detections_with_classifications, elapsed_time
+        )
+        logger.info(
+            f"Successfully processed {len(detections_with_classifications)} detections."
+        )
+
+        return pipeline_response
+
+
+class ZeroShotObjectDetectorWithGlobalMothClassifierPipeline(Pipeline):
+    """
+    A pipeline that uses the HuggingFace zero shot object detector and the global moth classifier.
+    This provides high-quality moth species identification with 29,176+ species support.
+    """
+
+    batch_sizes = [1, 4]  # Detector batch=1, Classifier batch=4
+    config = PipelineConfigResponse(
+        name="Zero Shot Object Detector With Global Moth Classifier Pipeline",
+        slug="zero-shot-object-detector-with-global-moth-classifier-pipeline",
+        description=(
+            "HF zero shot object detector with global moth species classifier. "
+            "Supports 29,176+ moth species trained on global data."
+        ),
+        version=1,
+        algorithms=[],  # Will be populated in get_stages()
+    )
+
+    def get_stages(self) -> list[Algorithm]:
+        zero_shot_object_detector = ZeroShotObjectDetector()
+        if "candidate_labels" in self.request_config:
+            zero_shot_object_detector.candidate_labels = self.request_config[
+                "candidate_labels"
+            ]
+
+        global_moth_classifier = GlobalMothClassifier()
+
+        self.config.algorithms = [
+            zero_shot_object_detector.algorithm_config_response,
+            global_moth_classifier.algorithm_config_response,
+        ]
+
+        return [zero_shot_object_detector, global_moth_classifier]
+
+    def run(self) -> PipelineResultsResponse:
+        start_time = datetime.datetime.now()
+        detections: list[Detection] = []
+
+        if self.existing_detections:
+            logger.info("[1/2] Skipping the localizer, using existing detections...")
+            detections = self.existing_detections
+        else:
+            logger.info("[1/2] No existing detections, generating detections...")
+            detections = self._get_detections(
+                self.stages[0], self.source_images, self.batch_sizes[0]
+            )
+
+        logger.info("[2/2] Running the global moth classifier...")
+        detections_with_classifications: list[Detection] = self._get_detections(
+            self.stages[1], detections, self.batch_sizes[1]
+        )
+
+        end_time = datetime.datetime.now()
+        elapsed_time = (end_time - start_time).total_seconds()
+
+        pipeline_response: PipelineResultsResponse = self._get_pipeline_response(
+            detections_with_classifications, elapsed_time
+        )
+        logger.info(
+            f"Successfully processed {len(detections_with_classifications)} detections with global moth classifier."
+        )
+
+        return pipeline_response
diff --git a/processing_services/moths/api/api/schemas.py b/processing_services/moths/api/api/schemas.py
new file mode 100644
index 000000000..93e99f6aa
--- /dev/null
+++ b/processing_services/moths/api/api/schemas.py
@@ -0,0 +1,341 @@
+# Can these be imported from the OpenAPI spec yaml?
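+#
+# Illustrative example (not part of the schema definitions): a minimal request
+# for the /process endpoint, built from the models below. The URL is a placeholder.
+#
+#     PipelineRequest(
+#         pipeline="zero-shot-object-detector-pipeline",
+#         source_images=[SourceImageRequest(id="1", url="https://example.com/moth.jpg")],
+#         config={"candidate_labels": ["moth"]},
+#     )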
+import datetime
+import logging
+import pathlib
+import typing
+
+import PIL.Image
+import pydantic
+
+from .utils import get_image
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class BoundingBox(pydantic.BaseModel):
+    x1: float
+    y1: float
+    x2: float
+    y2: float
+
+    @classmethod
+    def from_coords(cls, coords: list[float]):
+        return cls(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3])
+
+    def to_string(self):
+        return f"{self.x1},{self.y1},{self.x2},{self.y2}"
+
+    def to_path(self):
+        return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]])
+
+    def to_tuple(self):
+        return (self.x1, self.y1, self.x2, self.y2)
+
+
+class BaseImage(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra="ignore", arbitrary_types_allowed=True)
+
+    id: str
+    url: str | None = None
+    b64: str | None = None
+    filepath: str | pathlib.Path | None = None
+    _pil: PIL.Image.Image | None = None
+    width: int | None = None
+    height: int | None = None
+    timestamp: datetime.datetime | None = None
+
+    # Validate that there is at least one of the following fields
+    @pydantic.model_validator(mode="after")
+    def validate_source(self):
+        if not any([self.url, self.b64, self.filepath, self._pil]):
+            raise ValueError(
+                "At least one of the following fields must be provided: url, b64, filepath, pil"
+            )
+        return self
+
+    def open(self, raise_exception=False) -> PIL.Image.Image | None:
+        if not self._pil:
+            logger.warning(f"Opening image {self.id} for the first time")
+            self._pil = get_image(
+                url=self.url,
+                b64=self.b64,
+                filepath=self.filepath,
+                raise_exception=raise_exception,
+            )
+        else:
+            logger.info(f"Using already loaded image {self.id}")
+        if self._pil:
+            self.width, self.height = self._pil.size
+        return self._pil
+
+
+class SourceImage(BaseImage):
+    pass
+
+
+class AlgorithmReference(pydantic.BaseModel):
+    name: str
+    key: str
+
+
+class ClassificationResponse(pydantic.BaseModel):
+    classification: str
+    labels: list[str] | None = pydantic.Field(
+        default=None,
+        description=(
+            "A list of all possible labels for the model, in the correct order. "
+            "Omitted if the model has too many labels to include for each classification in the response. "
+            "Use the category map from the algorithm to get the full list of labels and metadata."
+        ),
+    )
+    scores: list[float] = pydantic.Field(
+        default_factory=list,
+        description="The calibrated probabilities for each class label, most commonly the softmax output.",
+    )
+    logits: list[float] = pydantic.Field(
+        default_factory=list,
+        description="The raw logits output by the model, before any calibration or normalization.",
+    )
+    inference_time: float | None = None
+    algorithm: AlgorithmReference
+    terminal: bool = True
+    timestamp: datetime.datetime
+
+
+class SourceImageRequest(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra="ignore")
+
+    id: str
+    url: str
+    # b64: str | None = None
+    # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
+
+
+class SourceImageResponse(pydantic.BaseModel):
+    model_config = pydantic.ConfigDict(extra="ignore")
+
+    id: str
+    url: str
+
+
+class DetectionRequest(pydantic.BaseModel):
+    source_image: SourceImageRequest  # the 'original' image
+    bbox: BoundingBox
+    crop_image_url: str | None = None
+    algorithm: AlgorithmReference
+
+
+class DetectionResponse(pydantic.BaseModel):
+    # these fields are populated with values from a Detection, excluding source_image details
+    source_image_id: str
+    bbox: BoundingBox
+    inference_time: float | None = None
+    algorithm: AlgorithmReference
+    timestamp: datetime.datetime
+    crop_image_url: str | None = None
+    classifications: list[ClassificationResponse] = []
+
+
+class Detection(BaseImage):
+    """
+    An internal representation of a detection with reference to a source image instance.
+    """
+
+    source_image: SourceImage  # the 'original' uncropped image
+    bbox: BoundingBox
+    inference_time: float | None = None
+    algorithm: AlgorithmReference
+    classifications: list[ClassificationResponse] = []
+
+
+class AlgorithmCategoryMapResponse(pydantic.BaseModel):
+    data: list[dict] = pydantic.Field(
+        default_factory=list,
+        description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.",
+        examples=[
+            [
+                {"label": "Moth", "index": 0, "gbif_key": 1234},
+                {"label": "Not a moth", "index": 1, "gbif_key": 5678},
+            ]
+        ],
+    )
+    labels: list[str] = pydantic.Field(
+        default_factory=list,
+        description="A simple list of string labels, in the correct index order used by the model.",
+        examples=[["Moth", "Not a moth"]],
+    )
+    version: str | None = pydantic.Field(
+        default=None,
+        description="The version of the category map. Can be a descriptive string or a version number.",
+        examples=["LepNet2021-with-2023-mods"],
+    )
+    description: str | None = pydantic.Field(
+        default=None,
+        description="A description of the category map used to train. e.g. source, purpose and modifications.",
+        examples=[
+            "LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."
+        ],
+    )
+    uri: str | None = pydantic.Field(
+        default=None,
+        description="A URI to the category map file, could be a public web URL or object store path.",
+    )
+
+
+class AlgorithmConfigResponse(pydantic.BaseModel):
+    name: str
+    key: str = pydantic.Field(
+        description=(
+            "A unique key for an algorithm to lookup the category map (class list) and other metadata."
+        ),
+    )
+    description: str | None = None
+    task_type: str | None = pydantic.Field(
+        default=None,
+        description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.",
+        examples=["detection", "classification", "segmentation", "embedding"],
+    )
+    version: int = pydantic.Field(
+        default=1,
+        description="A sortable version number for the model. Increment this number when the model is updated.",
+    )
+    version_name: str | None = pydantic.Field(
+        default=None,
+        description="A complete version name e.g. '2021-01-01', 'LepNet2021'.",
+    )
+    uri: str | None = pydantic.Field(
+        default=None,
+        description="A URI to the weights or model details, could be a public web URL or object store path.",
+    )
+    category_map: AlgorithmCategoryMapResponse | None = None
+
+    class Config:
+        extra = "ignore"
+
+
+PipelineChoice = typing.Literal[
+    # @TODO can this be dynamically generated from available pipelines?
+    "zero-shot-hf-classifier-pipeline",
+    "zero-shot-object-detector-pipeline",
+    "zero-shot-object-detector-with-constant-classifier-pipeline",
+    "zero-shot-object-detector-with-random-species-classifier-pipeline",
+    "zero-shot-object-detector-with-global-moth-classifier-pipeline",
+]
+
+
+class PipelineRequestConfigParameters(pydantic.BaseModel):
+    """Parameters used to configure a pipeline request.
+
+    Accepts any serializable key-value pair.
+    Example: {"force_reprocess": True, "auth_token": "abc123"}
+
+    Supported parameters are defined by the pipeline in the processing service
+    and should be published in the Pipeline's info response.
+    """
+
+    force_reprocess: bool = pydantic.Field(
+        default=False,
+        description="Force reprocessing of the image, even if it has already been processed.",
+    )
+    auth_token: str | None = pydantic.Field(
+        default=None,
+        description="An optional authentication token to use for the pipeline.",
+    )
+    candidate_labels: list[str] | None = pydantic.Field(
+        default=None,
+        description="A list of candidate labels to use for the zero-shot object detector.",
+    )
+
+
+class PipelineRequest(pydantic.BaseModel):
+    pipeline: PipelineChoice
+    source_images: list[SourceImageRequest]
+    detections: list[DetectionRequest] | None = None
+    config: PipelineRequestConfigParameters | dict | None = None
+
+    # Example for API docs:
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "pipeline": "zero-shot-object-detector-pipeline",
+                "source_images": [
+                    {
+                        "id": "123",
+                        "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg",
+                    }
+                ],
+                "config": {"force_reprocess": True, "auth_token": "abc123"},
+            }
+        }
+
+
+class PipelineResultsResponse(pydantic.BaseModel):
+    pipeline: PipelineChoice
+    total_time: float
+    algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field(
+        default_factory=dict,
+        description=(
+            "A dictionary of all algorithms used in the pipeline, including their class list and other "
+            "metadata, keyed by the algorithm key. "
+            "DEPRECATED: Algorithms should only be provided in the ProcessingServiceInfoResponse."
+        ),
+        deprecated=True,
+    )
+    source_images: list[SourceImageResponse]
+    detections: list[DetectionResponse]
+    errors: list | str | None = None
+
+
+class PipelineStageParam(pydantic.BaseModel):
+    """A configurable parameter of a stage of a pipeline."""
+
+    name: str
+    key: str
+    category: str = "default"
+
+
+class PipelineStage(pydantic.BaseModel):
+    """A configurable stage of a pipeline."""
+
+    key: str
+    name: str
+    params: list[PipelineStageParam] = []
+    description: str | None = None
+
+
+class PipelineConfigResponse(pydantic.BaseModel):
+    """Details about a pipeline, its algorithms and category maps."""
+
+    name: str
+    slug: str
+    version: int
+    description: str | None = None
+    algorithms: list[AlgorithmConfigResponse] = []
+    stages: list[PipelineStage] = []
+
+
+class ProcessingServiceInfoResponse(pydantic.BaseModel):
+    """Information about the processing service."""
+
+    name: str = pydantic.Field(examples=["Mila Research Lab - Moth AI Services"])
+    description: str | None = pydantic.Field(
+        default=None,
+        examples=[
+            "Algorithms developed by the Mila Research Lab for analysis of moth images."
+        ],
+    )
+    pipelines: list[PipelineConfigResponse] = pydantic.Field(
+        default_factory=list,
+        examples=[
+            [
+                PipelineConfigResponse(
+                    name="Random Pipeline", slug="random", version=1, algorithms=[]
+                ),
+            ]
+        ],
+    )
+    # algorithms: list[AlgorithmConfigResponse] = pydantic.Field(
+    #     default=list,
+    #     examples=[RANDOM_BINARY_CLASSIFIER],
+    # )
diff --git a/processing_services/moths/api/api/test.py b/processing_services/moths/api/api/test.py
new file mode 100644
index 000000000..b5b1b5f7c
--- /dev/null
+++ b/processing_services/moths/api/api/test.py
@@ -0,0 +1,64 @@
+import unittest
+
+from fastapi.testclient import TestClient
+
+from .api import app
+from .pipelines import CustomPipeline
+from .schemas import PipelineRequest, SourceImage, SourceImageRequest
+
+
+class TestPipeline(unittest.TestCase):
+    def test_custom_pipeline(self):
+        # @TODO: Load actual antenna images?
+        pipeline = CustomPipeline(
+            source_images=[
+                SourceImage(
+                    id="1001",
+                    url=(
+                        "https://huggingface.co/datasets/huggingface/"
+                        "documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
+                    ),
+                ),
+                SourceImage(id="1002", url="https://cdn.britannica.com/79/191679-050-C7114D2B/Adult-capybara.jpg"),
+            ],
+            detector_batch_size=2,
+            classifier_batch_size=2,
+        )
+        detections = pipeline.run()
+
+        self.assertEqual(len(detections), 20)
+        expected_labels = ["lynx, catamount", "beaver"]
+        for detection_id, detection in enumerate(detections):
+            self.assertEqual(detection.source_image_id, pipeline.source_images[detection_id].id)
+            self.assertIsNotNone(detection.bbox)
+            self.assertEqual(len(detection.classifications), 1)
+            classification = detection.classifications[0]
+            self.assertEqual(classification.classification, expected_labels[detection_id])
+            self.assertGreaterEqual(classification.scores[0], 0.0)
+            self.assertLessEqual(classification.scores[0], 1.0)
+
+
+class TestAPI(unittest.TestCase):
+    def setUp(self):
+        self.client = TestClient(app)
+
+    def test_root(self):
+        response = self.client.get("/")
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.url, "http://testserver/docs")
+
+    def test_process(self):
+        source_images = [
+            SourceImage(id="1", url="https://example.com/image1.jpg"),
+            SourceImage(id="2", url="https://example.com/image2.jpg"),
+        ]
+        source_image_requests = [SourceImageRequest(**image.model_dump()) for image in source_images]
+        request = PipelineRequest(pipeline="local-pipeline", source_images=source_image_requests, config={})
+        response = self.client.post("/process", json=request.model_dump())
+
+        self.assertEqual(response.status_code, 200)
+        data = response.json()
+        self.assertEqual(data["pipeline"], "local-pipeline")
+        self.assertEqual(len(data["source_images"]), 2)
+        self.assertEqual(len(data["detections"]), 2)
+        self.assertGreater(data["total_time"], 0.0)
diff --git a/processing_services/moths/api/api/utils.py b/processing_services/moths/api/api/utils.py
new file mode 100644
index 000000000..9fd50d3a4
--- /dev/null
+++ b/processing_services/moths/api/api/utils.py
@@ -0,0 +1,172 @@
+import base64
+import binascii
+import io
+import logging
+import pathlib
+import re
+import tempfile
+from urllib.parse import urlparse
+
+import PIL.Image
+import PIL.ImageFile
+import requests
+import torch
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+# This is polite and required by some hosts
+# see: https://foundation.wikimedia.org/wiki/Policy:User-Agent_policy
+USER_AGENT = "AntennaInsectDataPlatform/1.0 (https://insectai.org)"
+
+# -----------
+# File handling functions
+# -----------
+
+
+def is_url(path: str) -> bool:
+    return path.startswith("http://") or path.startswith("https://")
+
+
+def is_base64(s: str) -> bool:
+    try:
+        # Check if string can be decoded from base64
+        return base64.b64encode(base64.b64decode(s)).decode() == s
+    except Exception:
+        return False
+
+
+def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:
+    """
+    Fetch a file from a URL or local path. If the path is a URL, download the file.
+    If the URL has already been downloaded, return the existing local path.
+    If the path is a local path, return the path.
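+    Downloaded files are cached under /app/cache (when running in a container)
+    or ~/.antenna_cache, in a subdirectory named after tempdir_prefix, and are
+    reused on subsequent calls.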
+
+    >>> filepath = get_or_download_file("https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=451d406b7eb1113e1bb05c083ce51481%2F20240429%2F") # noqa: E501
+    >>> filepath.name
+    '31-20230919033000-snapshot.jpg'
+    >>> filepath = get_or_download_file("/home/user/images/31-20230919033000-snapshot.jpg")
+    >>> filepath.name
+    '31-20230919033000-snapshot.jpg'
+    """
+    if not path_or_url:
+        raise Exception("Specify a URL or path to fetch file from.")
+
+    # Local paths are returned as-is, only URLs are downloaded
+    if not is_url(str(path_or_url)):
+        return pathlib.Path(path_or_url).resolve()
+
+    # Use persistent cache directory instead of temp
+    cache_root = pathlib.Path("/app/cache") if pathlib.Path("/app").exists() else pathlib.Path.home() / ".antenna_cache"
+    destination_dir = cache_root / tempdir_prefix
+    fname = pathlib.Path(urlparse(path_or_url).path).name
+    if not destination_dir.exists():
+        destination_dir.mkdir(parents=True, exist_ok=True)
+    local_filepath = destination_dir / fname
+
+    if local_filepath and local_filepath.exists():
+        logger.info(f"📁 Using cached file: {local_filepath}")
+        return local_filepath
+
+    else:
+        logger.info(f"⬇️ Downloading {path_or_url} to {local_filepath}")
+        headers = {"User-Agent": USER_AGENT}
+        response = requests.get(path_or_url, stream=True, headers=headers)
+        response.raise_for_status()  # Raise an exception for HTTP errors
+
+        with open(local_filepath, "wb") as f:
+            total_size = int(response.headers.get('content-length', 0))
+            downloaded = 0
+            last_logged_percent = 0.0
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+                downloaded += len(chunk)
+                if total_size > 0:
+                    percent = (downloaded / total_size) * 100
+                    # Log at most every 10% to avoid flooding the logs
+                    if percent - last_logged_percent >= 10 or downloaded == total_size:
+                        logger.info(f"   Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)")
+                        last_logged_percent = percent
+
+        resulting_filepath = pathlib.Path(local_filepath).resolve()
+        logger.info(f"✅ Download completed: {resulting_filepath}")
+        return resulting_filepath
+
+
+def get_best_device() -> str:
+    """
+    Returns the best available device for running the model.
+    MPS is not supported by the current algorithms.
+    """
+    if torch.cuda.is_available():
+        return f"cuda:{torch.cuda.current_device()}"
+    else:
+        return "cpu"
+
+
+def open_image(
+    fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True
+) -> PIL.Image.Image | None:
+    """
+    Wrapper around PIL.Image.open that handles errors and converts to RGB.
+    """
+    img = None
+    try:
+        img = PIL.Image.open(fp)
+    except PIL.UnidentifiedImageError:
+        logger.warning(f"Unidentified image: {str(fp)[:100]}...")
+        if raise_exception:
+            raise
+    except OSError:
+        logger.warning(f"Could not open image: {str(fp)[:100]}...")
+        if raise_exception:
+            raise
+    else:
+        # Convert to RGB if necessary
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+
+    return img
+
+
+def decode_base64_string(string) -> io.BytesIO:
+    image_data = re.sub("^data:image/.+;base64,", "", string)
+    decoded = base64.b64decode(image_data)
+    buffer = io.BytesIO(decoded)
+    buffer.seek(0)
+    return buffer
+
+
+def get_image(
+    url: str | None = None,
+    filepath: str | pathlib.Path | None = None,
+    b64: str | None = None,
+    raise_exception: bool = True,
+) -> PIL.Image.Image | None:
+    """
+    Given a URL, local file path or base64 image, return a PIL image.
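+
+    Illustrative calls (URLs and paths are placeholders):
+
+        img = get_image(url="https://example.com/photo.jpg")
+        img = get_image(filepath="/tmp/photo.jpg")
+        img = get_image(b64="data:image/jpeg;base64,...", raise_exception=False)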
+ """ + + if url: + logger.info(f"Fetching image from URL: {url}") + tempdir = tempfile.TemporaryDirectory(prefix="ami_images") + img_path = get_or_download_file(url, tempdir_prefix=tempdir.name) + return open_image(img_path, raise_exception=raise_exception) + + elif filepath: + logger.info(f"Loading image from local filesystem: {filepath}") + return open_image(filepath, raise_exception=raise_exception) + + elif b64: + logger.info(f"Loading image from base64 string: {b64[:30]}...") + try: + buffer = decode_base64_string(b64) + except binascii.Error as e: + logger.warn(f"Could not decode base64 image: {e}") + if raise_exception: + raise + else: + return None + else: + return open_image(buffer, raise_exception=raise_exception) + + else: + raise Exception("Specify a URL, path or base64 image.") diff --git a/processing_services/moths/api/docker-compose.yml b/processing_services/moths/api/docker-compose.yml new file mode 100644 index 000000000..83db6ccfa --- /dev/null +++ b/processing_services/moths/api/docker-compose.yml @@ -0,0 +1,25 @@ +services: + ml_backend_example: + build: + context: . + volumes: + - ./:/app:z + - ./huggingface_cache:/root/.cache/huggingface + - ./pytorch_cache:/root/.cache/torch + ports: + - "2003:2000" + extra_hosts: + - minio:host-gateway + networks: + - antenna_network + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [ gpu ] + +networks: + antenna_network: + name: antenna_network diff --git a/processing_services/moths/api/main.py b/processing_services/moths/api/main.py new file mode 100644 index 000000000..2ed50004d --- /dev/null +++ b/processing_services/moths/api/main.py @@ -0,0 +1,4 @@ +if __name__ == "__main__": + import uvicorn + + uvicorn.run("api.api:app", host="0.0.0.0", port=2000, reload=True) diff --git a/processing_services/moths/api/requirements.txt b/processing_services/moths/api/requirements.txt new file mode 100644 index 000000000..4cf7a91d9 --- /dev/null +++ b/processing_services/moths/api/requirements.txt @@ -0,0 +1,10 @@ +fastapi==0.116.0 +uvicorn==0.35.0 +pydantic==2.11.7 +Pillow==11.3.0 +requests==2.32.4 +transformers==4.50.3 +torch==2.6.0 +torchvision==0.21.0 +scipy==1.16.0 +timm diff --git a/processing_services/moths/api/test_api_integration.py b/processing_services/moths/api/test_api_integration.py new file mode 100644 index 000000000..492a549b2 --- /dev/null +++ b/processing_services/moths/api/test_api_integration.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +API Integration Test for Global Moth Classifier Pipeline. +This test calls the actual HTTP API endpoints to validate the service. 
+""" + +import json +import pathlib +import sys +import time + +import requests + +# Add the processing_services/example to the path +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + + +def test_api_integration(): + """Test the Global Moth Classifier Pipeline via HTTP API.""" + print("🌐 Testing API Integration for Global Moth Classifier...") + + base_url = "http://ml_backend_example:2000" + + # Test 1: Get service info + print("\n📋 Test 1: Getting service info...") + try: + response = requests.get(f"{base_url}/info", timeout=30) + response.raise_for_status() + info = response.json() + + print("✅ Service info retrieved successfully!") + print(f" Service name: {info.get('name', 'Unknown')}") + print(f" Version: {info.get('version', 'Unknown')}") + print(f" Available pipelines: {len(info.get('pipelines', []))}") + + # Check if our pipeline is available + pipeline_slugs = [p.get('slug') for p in info.get('pipelines', [])] + expected_slug = "zero-shot-object-detector-with-global-moth-classifier-pipeline" + + if expected_slug in pipeline_slugs: + print("✅ Global Moth Classifier pipeline found in service!") + else: + print("❌ Global Moth Classifier pipeline NOT found in service") + print(f" Available pipelines: {pipeline_slugs}") + return False + + except Exception as e: + print(f"❌ Service info request failed: {str(e)}") + return False + + # Test 2: Process image with Global Moth Classifier + print("\n🦋 Test 2: Processing moth image...") + + request_payload = { + "config": { + "auth_token": "test123", + "force_reprocess": True, + "candidate_labels": ["moth", "butterfly", "insect"] + }, + "pipeline": "zero-shot-object-detector-with-global-moth-classifier-pipeline", + "source_images": [ + { + "id": "api_test_123", + "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg" + } + ] + } + + try: + print("📤 Sending processing request...") + print(f" Pipeline: {request_payload['pipeline']}") + print(f" Image URL: {request_payload['source_images'][0]['url']}") + + start_time = time.time() + response = requests.post( + f"{base_url}/process", + json=request_payload, + timeout=300 # 5 minutes timeout for processing + ) + end_time = time.time() + + response.raise_for_status() + result = response.json() + + processing_time = end_time - start_time + print("✅ Image processed successfully!") + print(f" API response time: {processing_time:.2f}s") + print(f" Pipeline processing time: {result.get('total_time', 'unknown')}s") + print(f" Number of detections: {len(result.get('detections', []))}") + + # Analyze results + detections = result.get('detections', []) + if detections: + print("\n🔍 Detection Results:") + for i, detection in enumerate(detections[:5]): # Show first 5 + bbox = detection.get('bbox', {}) + classifications = detection.get('classifications', []) + + print(f" Detection {i+1}:") + print(f" - Bbox: {bbox}") + print(f" - Algorithm: {detection.get('algorithm', {}).get('name', 'unknown')}") + + if classifications: + # Find top classification + top_classification = classifications[0] + if 'scores' in top_classification and top_classification['scores']: + max_score = max(top_classification['scores']) + max_idx = top_classification['scores'].index(max_score) + if 'labels' in top_classification and max_idx < len(top_classification['labels']): + species_name = top_classification['labels'][max_idx] + print(f" - Top species: {species_name} ({max_score:.3f})") + else: + print(f" - Classification: {top_classification.get('classification', 'unknown')}") + + if 
len(detections) > 5:
+                print(f"   ... and {len(detections) - 5} more detections")
+        else:
+            print("⚠️ No detections found in the image")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Image processing request failed: {str(e)}")
+        if hasattr(e, 'response') and e.response is not None:
+            try:
+                error_details = e.response.json()
+                print(f"   Error details: {json.dumps(error_details, indent=2)}")
+            except Exception:
+                print(f"   Error response: {e.response.text}")
+        return False
+
+
+def test_service_health():
+    """Test basic service health endpoints."""
+    print("\n🏥 Testing service health endpoints...")
+
+    base_url = "http://ml_backend_example:2000"
+
+    # Test health endpoints
+    health_endpoints = ["/", "/livez", "/readyz"]
+
+    for endpoint in health_endpoints:
+        try:
+            response = requests.get(f"{base_url}{endpoint}", timeout=10)
+            response.raise_for_status()
+            print(f"✅ {endpoint}: {response.status_code}")
+        except Exception as e:
+            print(f"❌ {endpoint}: {str(e)}")
+            return False
+
+    return True
+
+
+if __name__ == "__main__":
+    print("🧪 Starting API Integration Tests for Global Moth Classifier")
+    print("=" * 60)
+
+    # Test service health first
+    health_ok = test_service_health()
+    if not health_ok:
+        print("\n❌ Service health checks failed!")
+        sys.exit(1)
+
+    # Test main API integration
+    api_ok = test_api_integration()
+
+    print("\n" + "=" * 60)
+    if api_ok:
+        print("🎉 All API integration tests PASSED!")
+        sys.exit(0)
+    else:
+        print("❌ API integration tests FAILED!")
+        sys.exit(1)
\ No newline at end of file
diff --git a/processing_services/moths/api/test_compilation_logging.py b/processing_services/moths/api/test_compilation_logging.py
new file mode 100644
index 000000000..b756fa286
--- /dev/null
+++ b/processing_services/moths/api/test_compilation_logging.py
@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+"""
+Simple test to trigger model compilation and see enhanced logging.
+"""
+
+import pathlib
+import sys
+
+# Add the processing_services/example to the path
+sys.path.insert(0, str(pathlib.Path(__file__).parent.parent))
+
+from api.global_moth_classifier import GlobalMothClassifier
+
+
+def test_compilation_logging():
+    """Test the enhanced logging during model compilation."""
+    print("🔧 Testing enhanced compilation logging...")
+
+    # Create classifier instance
+    classifier = GlobalMothClassifier()
+
+    print(f"📋 Classifier instantiated: {classifier.name}")
+    print(f"   Expected classes: {classifier.num_classes}")
+
+    # Trigger compilation (this should show our enhanced logging)
+    print("\n⚡ Triggering compilation...")
+    classifier.compile()
+
+    print("\n✅ Compilation complete!")
+    print(f"   Model loaded: {classifier.model is not None}")
+    print(f"   Transforms ready: {classifier.transforms is not None}")
+    print(f"   Categories loaded: {len(classifier.category_map)} species")
+
+    return True
+
+
+if __name__ == "__main__":
+    test_compilation_logging()
\ No newline at end of file
diff --git a/processing_services/moths/api/test_global_moth_pipeline.py b/processing_services/moths/api/test_global_moth_pipeline.py
new file mode 100644
index 000000000..1886e5823
--- /dev/null
+++ b/processing_services/moths/api/test_global_moth_pipeline.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+"""
+Test script for the Global Moth Classifier Pipeline.
+This test processes a real moth image and validates the full pipeline functionality.
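+
+Run directly with `python test_global_moth_pipeline.py`; the first run
+downloads the model weights and labels before classification.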
+""" + +import pathlib +import sys + +# Add the processing_services/example to the path +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + +from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline +from api.schemas import SourceImage +from api.utils import get_image + + +def test_global_moth_pipeline(): + """Test the Global Moth Classifier Pipeline with a real request.""" + print("🧪 Testing Global Moth Classifier Pipeline with real request...") + + # Create source image from the provided URL + source_image = SourceImage( + id="123", + url="https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg", + width=800, # Typical image dimensions + height=600 + ) + + # Load the PIL image and attach it to the source image + print("📥 Loading image from URL...") + pil_image = get_image(url=source_image.url) + if pil_image: + source_image._pil = pil_image + # Update dimensions with actual image size + source_image.width = pil_image.width + source_image.height = pil_image.height + print(f"✅ Image loaded: {pil_image.width}x{pil_image.height}") + else: + print("❌ Failed to load image") + return False + + # Create pipeline with the test configuration + pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( + source_images=[source_image], + request_config={ + "auth_token": "abc123", + "force_reprocess": True, + "candidate_labels": ["moth", "butterfly", "insect"] # Add candidate labels for detection + }, + existing_detections=[], + ) + + print("✅ Pipeline instantiated successfully!") + print(f" Pipeline name: {pipeline.config.name}") + print(f" Pipeline slug: {pipeline.config.slug}") + print(f" Number of algorithms: {len(pipeline.config.algorithms)}") + print(f" Algorithm 1: {pipeline.config.algorithms[0].name}") + print(f" Algorithm 2: {pipeline.config.algorithms[1].name}") + + # Test that stages can be created + stages = pipeline.get_stages() + assert len(stages) == 2 + print(f" Stages created: {len(stages)}") + + # Compile the pipeline (load models) + print("🔧 Compiling pipeline (loading models)...") + pipeline.compile() + print("✅ Pipeline compiled successfully!") + + # Run the pipeline + print("🚀 Running pipeline on test image...") + try: + result = pipeline.run() + print("✅ Pipeline execution completed!") + print(f" Total processing time: {result.total_time:.2f}s") + print(f" Number of detections: {len(result.detections)}") + + # Print detection details + for i, detection in enumerate(result.detections): + print(f" Detection {i+1}:") + print(f" - Bbox: {detection.bbox}") + print(f" - Inference time: {detection.inference_time:.3f}s") + print(f" - Algorithm: {detection.algorithm}") + if detection.classifications: + # Get the classification with the highest score + top_classification = detection.classifications[0] # Usually sorted by confidence + if top_classification.scores: + max_score = max(top_classification.scores) + max_idx = top_classification.scores.index(max_score) + if top_classification.labels and max_idx < len(top_classification.labels): + species_name = top_classification.labels[max_idx] + print(f" - Top classification: {species_name} ({max_score:.3f})") + else: + print(f" - Classification: {top_classification.classification}") + else: + print(f" - Classification: {top_classification.classification}") + + return True + + except Exception as e: + print(f"❌ Pipeline execution failed: {str(e)}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + test_global_moth_pipeline() diff --git 
a/processing_services/moths/requirements.txt b/processing_services/moths/requirements.txt new file mode 100644 index 000000000..4cf7a91d9 --- /dev/null +++ b/processing_services/moths/requirements.txt @@ -0,0 +1,10 @@ +fastapi==0.116.0 +uvicorn==0.35.0 +pydantic==2.11.7 +Pillow==11.3.0 +requests==2.32.4 +transformers==4.50.3 +torch==2.6.0 +torchvision==0.21.0 +scipy==1.16.0 +timm diff --git a/processing_services/moths/test_api_integration.py b/processing_services/moths/test_api_integration.py new file mode 100644 index 000000000..492a549b2 --- /dev/null +++ b/processing_services/moths/test_api_integration.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +API Integration Test for Global Moth Classifier Pipeline. +This test calls the actual HTTP API endpoints to validate the service. +""" + +import json +import pathlib +import sys +import time + +import requests + +# Add the processing_services/example to the path +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + + +def test_api_integration(): + """Test the Global Moth Classifier Pipeline via HTTP API.""" + print("🌐 Testing API Integration for Global Moth Classifier...") + + base_url = "http://ml_backend_example:2000" + + # Test 1: Get service info + print("\n📋 Test 1: Getting service info...") + try: + response = requests.get(f"{base_url}/info", timeout=30) + response.raise_for_status() + info = response.json() + + print("✅ Service info retrieved successfully!") + print(f" Service name: {info.get('name', 'Unknown')}") + print(f" Version: {info.get('version', 'Unknown')}") + print(f" Available pipelines: {len(info.get('pipelines', []))}") + + # Check if our pipeline is available + pipeline_slugs = [p.get('slug') for p in info.get('pipelines', [])] + expected_slug = "zero-shot-object-detector-with-global-moth-classifier-pipeline" + + if expected_slug in pipeline_slugs: + print("✅ Global Moth Classifier pipeline found in service!") + else: + print("❌ Global Moth Classifier pipeline NOT found in service") + print(f" Available pipelines: {pipeline_slugs}") + return False + + except Exception as e: + print(f"❌ Service info request failed: {str(e)}") + return False + + # Test 2: Process image with Global Moth Classifier + print("\n🦋 Test 2: Processing moth image...") + + request_payload = { + "config": { + "auth_token": "test123", + "force_reprocess": True, + "candidate_labels": ["moth", "butterfly", "insect"] + }, + "pipeline": "zero-shot-object-detector-with-global-moth-classifier-pipeline", + "source_images": [ + { + "id": "api_test_123", + "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg" + } + ] + } + + try: + print("📤 Sending processing request...") + print(f" Pipeline: {request_payload['pipeline']}") + print(f" Image URL: {request_payload['source_images'][0]['url']}") + + start_time = time.time() + response = requests.post( + f"{base_url}/process", + json=request_payload, + timeout=300 # 5 minutes timeout for processing + ) + end_time = time.time() + + response.raise_for_status() + result = response.json() + + processing_time = end_time - start_time + print("✅ Image processed successfully!") + print(f" API response time: {processing_time:.2f}s") + print(f" Pipeline processing time: {result.get('total_time', 'unknown')}s") + print(f" Number of detections: {len(result.get('detections', []))}") + + # Analyze results + detections = result.get('detections', []) + if detections: + print("\n🔍 Detection Results:") + for i, detection in enumerate(detections[:5]): # Show first 5 + bbox = 
detection.get('bbox', {})
+                classifications = detection.get('classifications', [])
+
+                print(f"   Detection {i+1}:")
+                print(f"     - Bbox: {bbox}")
+                print(f"     - Algorithm: {detection.get('algorithm', {}).get('name', 'unknown')}")
+
+                if classifications:
+                    # Find top classification
+                    top_classification = classifications[0]
+                    if 'scores' in top_classification and top_classification['scores']:
+                        max_score = max(top_classification['scores'])
+                        max_idx = top_classification['scores'].index(max_score)
+                        if 'labels' in top_classification and max_idx < len(top_classification['labels']):
+                            species_name = top_classification['labels'][max_idx]
+                            print(f"     - Top species: {species_name} ({max_score:.3f})")
+                        else:
+                            print(f"     - Classification: {top_classification.get('classification', 'unknown')}")
+
+            if len(detections) > 5:
+                print(f"   ... and {len(detections) - 5} more detections")
+        else:
+            print("⚠️ No detections found in the image")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Image processing request failed: {str(e)}")
+        if hasattr(e, 'response') and e.response is not None:
+            try:
+                error_details = e.response.json()
+                print(f"   Error details: {json.dumps(error_details, indent=2)}")
+            except Exception:
+                print(f"   Error response: {e.response.text}")
+        return False
+
+
+def test_service_health():
+    """Test basic service health endpoints."""
+    print("\n🏥 Testing service health endpoints...")
+
+    base_url = "http://ml_backend_example:2000"
+
+    # Test health endpoints
+    health_endpoints = ["/", "/livez", "/readyz"]
+
+    for endpoint in health_endpoints:
+        try:
+            response = requests.get(f"{base_url}{endpoint}", timeout=10)
+            response.raise_for_status()
+            print(f"✅ {endpoint}: {response.status_code}")
+        except Exception as e:
+            print(f"❌ {endpoint}: {str(e)}")
+            return False
+
+    return True
+
+
+if __name__ == "__main__":
+    print("🧪 Starting API Integration Tests for Global Moth Classifier")
+    print("=" * 60)
+
+    # Test service health first
+    health_ok = test_service_health()
+    if not health_ok:
+        print("\n❌ Service health checks failed!")
+        sys.exit(1)
+
+    # Test main API integration
+    api_ok = test_api_integration()
+
+    print("\n" + "=" * 60)
+    if api_ok:
+        print("🎉 All API integration tests PASSED!")
+        sys.exit(0)
+    else:
+        print("❌ API integration tests FAILED!")
+        sys.exit(1)
\ No newline at end of file
diff --git a/processing_services/moths/test_global_moth_pipeline.py b/processing_services/moths/test_global_moth_pipeline.py
new file mode 100644
index 000000000..1886e5823
--- /dev/null
+++ b/processing_services/moths/test_global_moth_pipeline.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+"""
+Test script for the Global Moth Classifier Pipeline.
+This test processes a real moth image and validates the full pipeline functionality.
+""" + +import pathlib +import sys + +# Add the processing_services/example to the path +sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) + +from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline +from api.schemas import SourceImage +from api.utils import get_image + + +def test_global_moth_pipeline(): + """Test the Global Moth Classifier Pipeline with a real request.""" + print("🧪 Testing Global Moth Classifier Pipeline with real request...") + + # Create source image from the provided URL + source_image = SourceImage( + id="123", + url="https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg", + width=800, # Typical image dimensions + height=600 + ) + + # Load the PIL image and attach it to the source image + print("📥 Loading image from URL...") + pil_image = get_image(url=source_image.url) + if pil_image: + source_image._pil = pil_image + # Update dimensions with actual image size + source_image.width = pil_image.width + source_image.height = pil_image.height + print(f"✅ Image loaded: {pil_image.width}x{pil_image.height}") + else: + print("❌ Failed to load image") + return False + + # Create pipeline with the test configuration + pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( + source_images=[source_image], + request_config={ + "auth_token": "abc123", + "force_reprocess": True, + "candidate_labels": ["moth", "butterfly", "insect"] # Add candidate labels for detection + }, + existing_detections=[], + ) + + print("✅ Pipeline instantiated successfully!") + print(f" Pipeline name: {pipeline.config.name}") + print(f" Pipeline slug: {pipeline.config.slug}") + print(f" Number of algorithms: {len(pipeline.config.algorithms)}") + print(f" Algorithm 1: {pipeline.config.algorithms[0].name}") + print(f" Algorithm 2: {pipeline.config.algorithms[1].name}") + + # Test that stages can be created + stages = pipeline.get_stages() + assert len(stages) == 2 + print(f" Stages created: {len(stages)}") + + # Compile the pipeline (load models) + print("🔧 Compiling pipeline (loading models)...") + pipeline.compile() + print("✅ Pipeline compiled successfully!") + + # Run the pipeline + print("🚀 Running pipeline on test image...") + try: + result = pipeline.run() + print("✅ Pipeline execution completed!") + print(f" Total processing time: {result.total_time:.2f}s") + print(f" Number of detections: {len(result.detections)}") + + # Print detection details + for i, detection in enumerate(result.detections): + print(f" Detection {i+1}:") + print(f" - Bbox: {detection.bbox}") + print(f" - Inference time: {detection.inference_time:.3f}s") + print(f" - Algorithm: {detection.algorithm}") + if detection.classifications: + # Get the classification with the highest score + top_classification = detection.classifications[0] # Usually sorted by confidence + if top_classification.scores: + max_score = max(top_classification.scores) + max_idx = top_classification.scores.index(max_score) + if top_classification.labels and max_idx < len(top_classification.labels): + species_name = top_classification.labels[max_idx] + print(f" - Top classification: {species_name} ({max_score:.3f})") + else: + print(f" - Classification: {top_classification.classification}") + else: + print(f" - Classification: {top_classification.classification}") + + return True + + except Exception as e: + print(f"❌ Pipeline execution failed: {str(e)}") + import traceback + traceback.print_exc() + return False + + +if __name__ == "__main__": + test_global_moth_pipeline() From 
cf00c3e5e980357f8d87c0dd0d905fa0c96becfe Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 10 Oct 2025 21:07:27 -0700 Subject: [PATCH 2/3] feat: add global moth classifier to example pipeline --- processing_services/example/.gitignore | 4 + processing_services/example/api/api.py | 2 + .../{moths/api => example}/api/base.py | 24 +- .../api/global_moth_classifier.py | 27 +- processing_services/example/api/pipelines.py | 62 +++ processing_services/example/api/schemas.py | 1 + processing_services/example/api/utils.py | 31 +- .../example/docker-compose.yml | 1 + processing_services/example/requirements.txt | 1 + processing_services/moths/Dockerfile | 7 - .../moths/README_GLOBAL_MOTH_CLASSIFIER.md | 203 -------- processing_services/moths/api/Dockerfile | 7 - .../api/README_GLOBAL_MOTH_CLASSIFIER.md | 203 -------- processing_services/moths/api/api/__init__.py | 0 .../moths/api/api/algorithms.py | 431 ----------------- processing_services/moths/api/api/api.py | 261 ---------- .../moths/api/api/pipelines.py | 457 ------------------ processing_services/moths/api/api/schemas.py | 341 ------------- processing_services/moths/api/api/test.py | 64 --- processing_services/moths/api/api/utils.py | 172 ------- .../moths/api/docker-compose.yml | 25 - processing_services/moths/api/main.py | 4 - .../moths/api/requirements.txt | 10 - .../moths/api/test_api_integration.py | 173 ------- .../moths/api/test_compilation_logging.py | 38 -- .../moths/api/test_global_moth_pipeline.py | 109 ----- processing_services/moths/requirements.txt | 10 - .../moths/test_api_integration.py | 173 ------- .../moths/test_global_moth_pipeline.py | 109 ----- 29 files changed, 119 insertions(+), 2831 deletions(-) create mode 100644 processing_services/example/.gitignore rename processing_services/{moths/api => example}/api/base.py (92%) rename processing_services/{moths/api => example}/api/global_moth_classifier.py (95%) delete mode 100644 processing_services/moths/Dockerfile delete mode 100644 processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md delete mode 100644 processing_services/moths/api/Dockerfile delete mode 100644 processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md delete mode 100644 processing_services/moths/api/api/__init__.py delete mode 100644 processing_services/moths/api/api/algorithms.py delete mode 100644 processing_services/moths/api/api/api.py delete mode 100644 processing_services/moths/api/api/pipelines.py delete mode 100644 processing_services/moths/api/api/schemas.py delete mode 100644 processing_services/moths/api/api/test.py delete mode 100644 processing_services/moths/api/api/utils.py delete mode 100644 processing_services/moths/api/docker-compose.yml delete mode 100644 processing_services/moths/api/main.py delete mode 100644 processing_services/moths/api/requirements.txt delete mode 100644 processing_services/moths/api/test_api_integration.py delete mode 100644 processing_services/moths/api/test_compilation_logging.py delete mode 100644 processing_services/moths/api/test_global_moth_pipeline.py delete mode 100644 processing_services/moths/requirements.txt delete mode 100644 processing_services/moths/test_api_integration.py delete mode 100644 processing_services/moths/test_global_moth_pipeline.py diff --git a/processing_services/example/.gitignore b/processing_services/example/.gitignore new file mode 100644 index 000000000..95312cdaf --- /dev/null +++ b/processing_services/example/.gitignore @@ -0,0 +1,4 @@ +# Cache directories for models and dependencies +cache/ +huggingface_cache/ 
+pytorch_cache/ diff --git a/processing_services/example/api/api.py b/processing_services/example/api/api.py index 79ce5d83c..43a6205db 100644 --- a/processing_services/example/api/api.py +++ b/processing_services/example/api/api.py @@ -11,6 +11,7 @@ ZeroShotHFClassifierPipeline, ZeroShotObjectDetectorPipeline, ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, ) from .schemas import ( @@ -41,6 +42,7 @@ ZeroShotObjectDetectorPipeline, ZeroShotObjectDetectorWithConstantClassifierPipeline, ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, ] pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} algorithm_choices: dict[str, AlgorithmConfigResponse] = { diff --git a/processing_services/moths/api/api/base.py b/processing_services/example/api/base.py similarity index 92% rename from processing_services/moths/api/api/base.py rename to processing_services/example/api/base.py index bb001ec46..f8e781dc5 100644 --- a/processing_services/moths/api/api/base.py +++ b/processing_services/example/api/base.py @@ -5,7 +5,7 @@ import json import logging -from typing import Any, Dict, Optional +from typing import Any import torch import torchvision.transforms @@ -34,14 +34,14 @@ class SimplifiedInferenceBase: name: str = "Unknown Inference Model" description: str = "" - weights_path: Optional[str] = None - labels_path: Optional[str] = None - category_map: Dict[int, str] = {} - num_classes: Optional[int] = None + weights_path: str | None = None + labels_path: str | None = None + category_map: dict[int, str] = {} + num_classes: int | None = None default_taxon_rank: str = "SPECIES" normalization = tensorflow_normalization batch_size: int = 4 - device: Optional[str] = None + device: str | None = None def __init__(self, **kwargs): # Override any class attributes with provided kwargs @@ -56,9 +56,7 @@ def __init__(self, **kwargs): self.weights = self.get_weights(self.weights_path) self.transforms = self.get_transforms() - logger.info( - f"Loading model for {self.name} with {len(self.category_map or [])} categories" - ) + logger.info(f"Loading model for {self.name} with {len(self.category_map or [])} categories") self.model = self.get_model() @classmethod @@ -69,7 +67,7 @@ def get_key(cls) -> str: else: return cls.name.lower().replace(" ", "-").replace("/", "-") - def get_weights(self, weights_path: Optional[str]) -> Optional[str]: + def get_weights(self, weights_path: str | None) -> str | None: """Download and cache model weights.""" if weights_path: logger.info(f"⬇️ Downloading model weights from: {weights_path}") @@ -80,7 +78,7 @@ def get_weights(self, weights_path: Optional[str]) -> Optional[str]: logger.warning(f"No weights specified for model {self.name}") return None - def get_labels(self, labels_path: Optional[str]) -> Dict[int, str]: + def get_labels(self, labels_path: str | None) -> dict[int, str]: """Download and load category labels.""" if not labels_path: return {} @@ -220,9 +218,7 @@ def get_model(self) -> torch.nn.Module: import timm # Create timm ResNet50 model - model = timm.create_model( - "resnet50", pretrained=False, num_classes=self.num_classes - ) + model = timm.create_model("resnet50", pretrained=False, num_classes=self.num_classes) # Load pretrained weights if self.weights: diff --git a/processing_services/moths/api/api/global_moth_classifier.py 
b/processing_services/example/api/global_moth_classifier.py similarity index 95% rename from processing_services/moths/api/api/global_moth_classifier.py rename to processing_services/example/api/global_moth_classifier.py index 9d6abe8a0..586cc9cfd 100644 --- a/processing_services/moths/api/api/global_moth_classifier.py +++ b/processing_services/example/api/global_moth_classifier.py @@ -6,7 +6,6 @@ import datetime import logging -from typing import List import torch import torchvision.transforms @@ -57,22 +56,22 @@ def __init__(self, **kwargs): """Initialize the global moth classifier.""" # Initialize Algorithm parent class Algorithm.__init__(self) - + # Store kwargs for later use in compile() self._init_kwargs = kwargs - + # Initialize basic attributes without loading model self.model = None self.transforms = None self.category_map = {} # Empty dict for now self.num_classes = 29176 # Known number of classes - + logger.info(f"Initialized {self.name} (model loading deferred to compile())") @property def algorithm_config_response(self) -> AlgorithmConfigResponse: """Get algorithm configuration for API response.""" - if not hasattr(self, '_algorithm_config_response'): + if not hasattr(self, "_algorithm_config_response"): # Create a basic config response before compilation self._algorithm_config_response = AlgorithmConfigResponse( name=self.name, @@ -102,20 +101,20 @@ def compile(self): if self.model is not None: logger.info("Model already compiled, skipping...") return - + logger.info(f"🔧 Compiling {self.name}...") logger.info(f" 📊 Expected classes: {self.num_classes}") logger.info(f" 🏷️ Labels URL: {self.labels_path}") logger.info(f" ⚖️ Weights URL: {self.weights_path}") - + # Initialize the TimmResNet50Base now (this will download weights/labels) logger.info(" 📥 Downloading model weights and labels...") TimmResNet50Base.__init__(self, **self._init_kwargs) - + # Set algorithm config response logger.info(" 📋 Setting up algorithm configuration...") self.algorithm_config_response = self.get_algorithm_config_response() - + logger.info(f"✅ {self.name} compiled successfully!") logger.info(f" 📊 Loaded {len(self.category_map)} species categories") logger.info(f" 🔧 Model device: {getattr(self, 'device', 'unknown')}") @@ -131,7 +130,7 @@ def get_transforms(self) -> torchvision.transforms.Compose: ] ) - def run(self, detections: List[Detection]) -> List[Detection]: + def run(self, detections: list[Detection]) -> list[Detection]: """ Run classification on a list of detections. 
@@ -154,7 +153,7 @@ def run(self, detections: List[Detection]) -> List[Detection]: classified_detections = [] for i in range(0, len(detections), self.batch_size): - batch_detections = detections[i : i + self.batch_size] + batch_detections = detections[i: i + self.batch_size] batch_images = [] # Prepare batch of images @@ -217,9 +216,7 @@ def run(self, detections: List[Detection]) -> List[Detection]: def get_category_map(self) -> AlgorithmCategoryMapResponse: """Get category map for API response.""" - categories_sorted_by_index = sorted( - self.category_map.items(), key=lambda x: x[0] - ) + categories_sorted_by_index = sorted(self.category_map.items(), key=lambda x: x[0]) categories_data = [ { "index": index, @@ -250,5 +247,3 @@ def get_algorithm_config_response(self) -> AlgorithmConfigResponse: category_map=self.get_category_map(), uri=self.weights_path, ) - - diff --git a/processing_services/example/api/pipelines.py b/processing_services/example/api/pipelines.py index 02b31d0d9..a983009e2 100644 --- a/processing_services/example/api/pipelines.py +++ b/processing_services/example/api/pipelines.py @@ -9,6 +9,7 @@ RandomSpeciesClassifier, ZeroShotObjectDetector, ) +from .global_moth_classifier import GlobalMothClassifier from .schemas import ( Detection, DetectionResponse, @@ -346,3 +347,64 @@ def run(self) -> PipelineResultsResponse: logger.info(f"Successfully processed {len(detections_with_classifications)} detections.") return pipeline_response + + +class ZeroShotObjectDetectorWithGlobalMothClassifierPipeline(Pipeline): + """ + A pipeline that uses the HuggingFace zero shot object detector and the global moth classifier. + This provides high-quality moth species identification with 29,176+ species support. + """ + + batch_sizes = [1, 4] # Detector batch=1, Classifier batch=4 + config = PipelineConfigResponse( + name="Zero Shot Object Detector With Global Moth Classifier Pipeline", + slug="zero-shot-object-detector-with-global-moth-classifier-pipeline", + description=( + "HF zero shot object detector with global moth species classifier. " + "Supports 29,176+ moth species trained on global data." 
+        ),
+        version=1,
+        algorithms=[],  # Will be populated in get_stages()
+    )
+
+    def get_stages(self) -> list[Algorithm]:
+        zero_shot_object_detector = ZeroShotObjectDetector()
+        if "candidate_labels" in self.request_config:
+            zero_shot_object_detector.candidate_labels = self.request_config["candidate_labels"]
+
+        global_moth_classifier = GlobalMothClassifier()
+
+        self.config.algorithms = [
+            zero_shot_object_detector.algorithm_config_response,
+            global_moth_classifier.algorithm_config_response,
+        ]
+
+        return [zero_shot_object_detector, global_moth_classifier]
+
+    def run(self) -> PipelineResultsResponse:
+        start_time = datetime.datetime.now()
+        detections: list[Detection] = []
+
+        if self.existing_detections:
+            logger.info("[1/2] Skipping the localizer, using existing detections...")
+            detections = self.existing_detections
+        else:
+            logger.info("[1/2] No existing detections, generating detections...")
+            detections = self._get_detections(self.stages[0], self.source_images, self.batch_sizes[0])
+
+        logger.info("[2/2] Running the global moth classifier...")
+        detections_with_classifications: list[Detection] = self._get_detections(
+            self.stages[1], detections, self.batch_sizes[1]
+        )
+
+        end_time = datetime.datetime.now()
+        elapsed_time = (end_time - start_time).total_seconds()
+
+        pipeline_response: PipelineResultsResponse = self._get_pipeline_response(
+            detections_with_classifications, elapsed_time
+        )
+        logger.info(
+            f"Successfully processed {len(detections_with_classifications)} detections with global moth classifier."
+        )
+
+        return pipeline_response
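
Note that `run()` short-circuits the detector when `existing_detections` is supplied, so previously localized crops can be re-classified without re-running stage 1. A minimal sketch of that path (the image URL is hypothetical, and the exact required fields for `Detection` are defined in `api/schemas.py`):

```python
from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline
from api.schemas import BoundingBox, Detection, SourceImage

source_image = SourceImage(id="123", url="https://example.com/trap-image.jpg")  # hypothetical URL
source_image.open(raise_exception=True)  # attach the PIL image before cropping

# A previously localized insect; only stage 2 (the classifier) will run.
bbox = BoundingBox(x1=10, y1=20, x2=210, y2=220)
detection = Detection(
    id=f"{source_image.id}-crop-{bbox.x1}-{bbox.y1}-{bbox.x2}-{bbox.y2}",
    url=source_image.url,
    source_image=source_image,
    bbox=bbox,
)
detection._pil = source_image._pil.crop((bbox.x1, bbox.y1, bbox.x2, bbox.y2))

pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline(
    source_images=[source_image],
    existing_detections=[detection],
)
pipeline.compile()
results = pipeline.run()
print(results.detections[0].classifications)
```
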
diff --git a/processing_services/example/api/schemas.py b/processing_services/example/api/schemas.py
index 05682c6f1..e56e16c85 100644
--- a/processing_services/example/api/schemas.py
+++ b/processing_services/example/api/schemas.py
@@ -213,6 +213,7 @@ class Config:
     "zero-shot-object-detector-pipeline",
     "zero-shot-object-detector-with-constant-classifier-pipeline",
     "zero-shot-object-detector-with-random-species-classifier-pipeline",
+    "zero-shot-object-detector-with-global-moth-classifier-pipeline",
 ]
diff --git a/processing_services/example/api/utils.py b/processing_services/example/api/utils.py
index a7fcb6a75..47350e781 100644
--- a/processing_services/example/api/utils.py
+++ b/processing_services/example/api/utils.py
@@ -10,6 +10,7 @@
 import PIL.Image
 import PIL.ImageFile
 import requests
+import torch

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -56,31 +57,53 @@ def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:

     # If path is a local path instead of a URL then urlretrieve will just return that path
-    destination_dir = pathlib.Path(tempfile.mkdtemp(prefix=tempdir_prefix))
+    # Use a persistent cache directory instead of a fresh temp dir
+    cache_root = (
+        pathlib.Path("/app/cache") if pathlib.Path("/app").exists() else pathlib.Path.home() / ".antenna_cache"
+    )
+    destination_dir = cache_root / tempdir_prefix
     fname = pathlib.Path(urlparse(path_or_url).path).name

     if not destination_dir.exists():
         destination_dir.mkdir(parents=True, exist_ok=True)

-    local_filepath = pathlib.Path(destination_dir) / fname
+    local_filepath = destination_dir / fname
+
     if local_filepath and local_filepath.exists():
-        logger.info(f"Using existing {local_filepath}")
+        logger.info(f"📁 Using cached file: {local_filepath}")
         return local_filepath

     else:
-        logger.info(f"Downloading {path_or_url} to {local_filepath}")
+        logger.info(f"⬇️ Downloading {path_or_url} to {local_filepath}")
         headers = {"User-Agent": USER_AGENT}
         response = requests.get(path_or_url, stream=True, headers=headers)
         response.raise_for_status()  # Raise an exception for HTTP errors
         with open(local_filepath, "wb") as f:
+            total_size = int(response.headers.get("content-length", 0))
+            downloaded = 0
             for chunk in response.iter_content(chunk_size=8192):
                 f.write(chunk)
+                downloaded += len(chunk)
+                # Log roughly every 10 MB to avoid flooding the logs on large model files
+                if total_size > 0 and downloaded % (10 * 1024 * 1024) < 8192:
+                    logger.info(f"    Progress: {downloaded / total_size:.1%} ({downloaded}/{total_size} bytes)")
         resulting_filepath = pathlib.Path(local_filepath).resolve()
+        logger.info(f"✅ Download completed ({local_filepath.stat().st_size} bytes)")
         logger.info(f"Downloaded to {resulting_filepath}")
         return resulting_filepath


+def get_best_device() -> str:
+    """
+    Returns the best available device for running the model.
+    MPS is not supported by the current algorithms.
+    """
+    if torch.cuda.is_available():
+        return f"cuda:{torch.cuda.current_device()}"
+    else:
+        return "cpu"
+
+
 def open_image(fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True) -> PIL.Image.Image | None:
     """
     Wrapper from PIL.Image.open that handles errors and converts to RGB.
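
Taken together, these helpers give the models a persistent weights cache and automatic device selection. A quick usage sketch, assuming only the functions as patched above (the weights URL is hypothetical):

```python
from api.utils import get_best_device, get_or_download_file

# First call downloads into the persistent cache; subsequent calls reuse the file.
weights_path = get_or_download_file(
    "https://example.com/models/global_moth_classifier.pt",  # hypothetical URL
    tempdir_prefix="antenna",
)

device = get_best_device()  # "cuda:0" when a GPU is available, otherwise "cpu"
print(weights_path, device)
```
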
diff --git a/processing_services/example/docker-compose.yml b/processing_services/example/docker-compose.yml
index 83db6ccfa..2df8b49e9 100644
--- a/processing_services/example/docker-compose.yml
+++ b/processing_services/example/docker-compose.yml
@@ -4,6 +4,7 @@ services:
       context: .
     volumes:
       - ./:/app:z
+      - ./cache:/app/cache  # Model weights and labels cache
       - ./huggingface_cache:/root/.cache/huggingface
       - ./pytorch_cache:/root/.cache/torch
     ports:
diff --git a/processing_services/example/requirements.txt b/processing_services/example/requirements.txt
index eccbee47a..251693690 100644
--- a/processing_services/example/requirements.txt
+++ b/processing_services/example/requirements.txt
@@ -6,4 +6,5 @@ requests==2.32.4
 transformers==4.50.3
 torch==2.6.0
 torchvision==0.21.0
+timm==1.0.11
 scipy==1.16.0
diff --git a/processing_services/moths/Dockerfile b/processing_services/moths/Dockerfile
deleted file mode 100644
index 3e0781f92..000000000
--- a/processing_services/moths/Dockerfile
+++ /dev/null
@@ -1,7 +0,0 @@
-FROM python:3.11-slim
-
-# Set up ml backend FastAPI
-WORKDIR /app
-COPY . /app
-RUN pip install -r ./requirements.txt
-CMD ["python", "/app/main.py"]
diff --git a/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md b/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md
deleted file mode 100644
index 223ada70d..000000000
--- a/processing_services/moths/README_GLOBAL_MOTH_CLASSIFIER.md
+++ /dev/null
@@ -1,203 +0,0 @@
-# Global Moth Classifier Processing Service Implementation
-
-## Overview
-
-Successfully implemented a simplified, self-contained module for running the global moth classifier within the processing services framework. This eliminates the need for database and queue dependencies while maintaining full classifier functionality.
-
-## Architecture Summary
-
-### Original vs. Simplified Architecture
-
-**Original trapdata architecture:**
-- Complex inheritance: `MothClassifierGlobal` → `APIMothClassifier` + `GlobalMothSpeciesClassifier` → multiple base classes
-- Database dependencies: `db_path`, `QueueManager`, `InferenceBaseClass` with DB connections
-- Queue system: `UnclassifiedObjectQueue`, `DetectedObjectQueue`
-- File management: `user_data_path`, complex caching system
-
-**New simplified architecture:**
-- Flattened inheritance: `GlobalMothClassifier` → `Algorithm` + `TimmResNet50Base` → `SimplifiedInferenceBase`
-- No database dependencies: Direct PIL image processing
-- No queue system: Processes Detection objects directly
-- Simple file management: Downloads models to temp directories as needed
-
-### Key Components Created
-
-1. **Base Classes** (`base.py`):
-   - `SimplifiedInferenceBase`: Core inference functionality without DB dependencies
-   - `ResNet50Base`: ResNet50-specific model loading and inference
-   - `TimmResNet50Base`: Timm-based ResNet50 implementation
-
-2. **Global Moth Classifier** (`global_moth_classifier.py`):
-   - `GlobalMothClassifier`: Simplified version of the original classifier
-   - 29,176+ species support
-   - Batch processing capabilities
-   - Algorithm interface compatibility
-
-3. **New Pipeline** (`pipelines.py`):
-   - `ZeroShotObjectDetectorWithGlobalMothClassifierPipeline`: Combines HF zero-shot detector with global moth classifier
-   - Two-stage processing: detection → classification
-   - Configurable batch sizes for optimal performance
-
-4. **Updated Utils** (`utils.py`):
-   - Added `get_best_device()` for GPU/CPU selection
-   - Enhanced `get_or_download_file()` for model weight downloading
-
-## Data Flow
-
-```
-PipelineRequest
-  ↓
-SourceImages (PIL images)
-  ↓
-ZeroShotObjectDetector (stage 1)
-  ↓
-Detections with bounding boxes
-  ↓
-GlobalMothClassifier (stage 2)
-  ↓
-Detections with species classifications
-  ↓
-PipelineResponse
-```
-
-## API Integration
-
-The new pipeline is now available in the processing service API:
-
-- **Pipeline Name**: "Zero Shot Object Detector With Global Moth Classifier Pipeline"
-- **Slug**: `zero-shot-object-detector-with-global-moth-classifier-pipeline`
-- **Algorithms**: 2 (detector + classifier)
-- **Batch Sizes**: [1, 4] (detector=1, classifier=4 for efficiency)
-
-## Key Differences from trapdata
-
-1. **No Database Dependencies**:
-   - Removed: `db_path`, `QueueManager`, `save_classified_objects()`
-   - Uses: Direct Detection object processing
-
-2. **Simplified File Management**:
-   - Removed: Complex `user_data_path` caching
-   - Uses: Temporary directories for model downloads
-
-3. **Flattened Inheritance**:
-   - Removed: Complex multi-level inheritance chains
-   - Uses: Simple Algorithm + base class pattern
-
-4. **Direct Image Processing**:
-   - Removed: Database-backed image references
-   - Uses: PIL images attached to Detection objects
-
-5. **API-First Design**:
-   - Removed: CLI and database queue processing
-   - Focused: REST API pipeline processing only
-
-## Benefits
-
-1. **Simplicity**: Much easier to understand and maintain
-2. **Performance**: No database overhead, direct processing
-3. **Portability**: Self-contained, minimal dependencies
-4. **Scalability**: Stateless processing suitable for containerization
-5. 
**Maintainability**: Clear separation of concerns, focused functionality - -## Usage Example - -```python -from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline -from api.schemas import SourceImage - -# Create pipeline -pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( - source_images=[source_image], - request_config={"candidate_labels": ["moth", "insect"]}, - existing_detections=[] -) - -# Compile and run -pipeline.compile() -results = pipeline.run() -``` - -## Files Created/Modified - -- ✅ `processing_services/example/api/base.py` - New simplified base classes -- ✅ `processing_services/example/api/global_moth_classifier.py` - New global moth classifier -- ✅ `processing_services/example/api/pipelines.py` - Added new pipeline class -- ✅ `processing_services/example/api/utils.py` - Enhanced utility functions -- ✅ `processing_services/example/api/api.py` - Added new pipeline to API -- ✅ `processing_services/example/test_global_moth_pipeline.py` - Basic test file - -## Original trapdata Source Analysis - -To create this simplified implementation, the following files and line ranges from the original AMI Data Companion (trapdata) module were analyzed: - -### Core Classification Classes -- **`trapdata/api/models/classification.py`**: - - Lines 1-25: Import statements and base dependencies - - Lines 37-163: `APIMothClassifier` base class implementation - - Lines 165-209: All classifier implementations including `MothClassifierGlobal` (line 207) - - Lines 112-137: `save_results()` method for processing predictions - - Lines 138-163: `update_classification()` and pipeline integration - -### Base Inference Framework -- **`trapdata/ml/models/base.py`**: - - Lines 58-120: `InferenceBaseClass` core structure and initialization - - Lines 121-200: Model loading, transforms, and file management methods - - Lines 25-50: Normalization constants and utility functions - -### Global Moth Classifier Implementation -- **`trapdata/ml/models/classification.py`**: - - Lines 507-527: `GlobalMothSpeciesClassifier` class definition and configuration - - Lines 338-375: `SpeciesClassifier` base class and database integration - - Lines 527-567: Various regional classifiers showing inheritance patterns - - Lines 1-50: Import structure and database dataset classes - - Lines 200-300: ResNet50 and Timm-based classifier implementations - -### API Integration Patterns -- **`trapdata/api/api.py`**: - - Lines 1-50: FastAPI setup and classifier imports - - Lines 37-60: `CLASSIFIER_CHOICES` dictionary including global moths - - Lines 120-150: `make_pipeline_config_response()` function - - Lines 175-310: Main processing pipeline in `process()` endpoint - - Lines 60-80: Pipeline choice enumeration and filtering logic - -### Model Architecture References -- **`trapdata/ml/models/localization.py`**: - - Lines 142-200: `ObjectDetector` base class structure - - Lines 245-290: `MothObjectDetector_FasterRCNN_2023` implementation - -### API Schema Definitions -- **`trapdata/api/schemas.py`**: - - Lines 293-330: Pipeline configuration schemas - - Lines 1-100: Detection and classification response schemas - -### Processing Pipeline Examples -- **`trapdata/api/models/localization.py`**: - - Lines 13-60: `APIMothDetector` implementation showing API adaptation pattern - -### Key Configuration Values -From the analysis, these critical configuration values were extracted: -- **Model weights URL**: Lines 507-515 in `classification.py` -- **Labels path**: Lines 516-520 in `classification.py` -- **Input 
size**: Line 508 (`input_size = 128`) -- **Normalization**: Line 509 (`normalization = imagenet_normalization`) -- **Species count**: 29,176 species from model description -- **Default taxon rank**: "SPECIES" from base class - -### Database Dependencies Removed -These database-dependent components were identified and removed: -- **`trapdata/db/models/queue.py`**: Lines 1-500 (entire queue system) -- **`trapdata/db/models/detections.py`**: `save_classified_objects()` function -- **Database path parameters**: Throughout `base.py` and classification classes -- **Queue managers**: `UnclassifiedObjectQueue`, `DetectedObjectQueue` references - -This analysis allowed for the creation of a streamlined implementation that preserves all the essential functionality while eliminating the complex database and queue infrastructure. - -## Next Steps - -1. **Testing**: Run end-to-end tests with real images -2. **Performance Optimization**: Tune batch sizes and memory usage -3. **Error Handling**: Add robust error handling for edge cases -4. **Documentation**: Add detailed API documentation -5. **Docker Integration**: Update Docker configurations if needed - -The implementation successfully provides a clean, maintainable global moth classifier that can process 29,176+ species without the complexity of the original trapdata system. diff --git a/processing_services/moths/api/Dockerfile b/processing_services/moths/api/Dockerfile deleted file mode 100644 index 3e0781f92..000000000 --- a/processing_services/moths/api/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -FROM python:3.11-slim - -# Set up ml backend FastAPI -WORKDIR /app -COPY . /app -RUN pip install -r ./requirements.txt -CMD ["python", "/app/main.py"] diff --git a/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md b/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md deleted file mode 100644 index 223ada70d..000000000 --- a/processing_services/moths/api/README_GLOBAL_MOTH_CLASSIFIER.md +++ /dev/null @@ -1,203 +0,0 @@ -# Global Moth Classifier Processing Service Implementation - -## Overview - -Successfully implemented a simplified, self-contained module for running the global moth classifier within the processing services framework. This eliminates the need for database and queue dependencies while maintaining full classifier functionality. - -## Architecture Summary - -### Original vs. Simplified Architecture - -**Original trapdata architecture:** -- Complex inheritance: `MothClassifierGlobal` → `APIMothClassifier` + `GlobalMothSpeciesClassifier` → multiple base classes -- Database dependencies: `db_path`, `QueueManager`, `InferenceBaseClass` with DB connections -- Queue system: `UnclassifiedObjectQueue`, `DetectedObjectQueue` -- File management: `user_data_path`, complex caching system - -**New simplified architecture:** -- Flattened inheritance: `GlobalMothClassifier` → `Algorithm` + `TimmResNet50Base` → `SimplifiedInferenceBase` -- No database dependencies: Direct PIL image processing -- No queue system: Processes Detection objects directly -- Simple file management: Downloads models to temp directories as needed - -### Key Components Created - -1. **Base Classes** (`base.py`): - - `SimplifiedInferenceBase`: Core inference functionality without DB dependencies - - `ResNet50Base`: ResNet50-specific model loading and inference - - `TimmResNet50Base`: Timm-based ResNet50 implementation - -2. 
**Global Moth Classifier** (`global_moth_classifier.py`): - - `GlobalMothClassifier`: Simplified version of the original classifier - - 29,176+ species support - - Batch processing capabilities - - Algorithm interface compatibility - -3. **New Pipeline** (`pipelines.py`): - - `ZeroShotObjectDetectorWithGlobalMothClassifierPipeline`: Combines HF zero-shot detector with global moth classifier - - Two-stage processing: detection → classification - - Configurable batch sizes for optimal performance - -4. **Updated Utils** (`utils.py`): - - Added `get_best_device()` for GPU/CPU selection - - Enhanced `get_or_download_file()` for model weight downloading - -## Data Flow - -``` -PipelineRequest - ↓ -SourceImages (PIL images) - ↓ -ZeroShotObjectDetector (stage 1) - ↓ -Detections with bounding boxes - ↓ -GlobalMothClassifier (stage 2) - ↓ -Detections with species classifications - ↓ -PipelineResponse -``` - -## API Integration - -The new pipeline is now available in the processing service API: - -- **Pipeline Name**: "Zero Shot Object Detector With Global Moth Classifier Pipeline" -- **Slug**: `zero-shot-object-detector-with-global-moth-classifier-pipeline` -- **Algorithms**: 2 (detector + classifier) -- **Batch Sizes**: [1, 4] (detector=1, classifier=4 for efficiency) - -## Key Differences from trapdata - -1. **No Database Dependencies**: - - Removed: `db_path`, `QueueManager`, `save_classified_objects()` - - Uses: Direct Detection object processing - -2. **Simplified File Management**: - - Removed: Complex `user_data_path` caching - - Uses: Temporary directories for model downloads - -3. **Flattened Inheritance**: - - Removed: Complex multi-level inheritance chains - - Uses: Simple Algorithm + base class pattern - -4. **Direct Image Processing**: - - Removed: Database-backed image references - - Uses: PIL images attached to Detection objects - -5. **API-First Design**: - - Removed: CLI and database queue processing - - Focused: REST API pipeline processing only - -## Benefits - -1. **Simplicity**: Much easier to understand and maintain -2. **Performance**: No database overhead, direct processing -3. **Portability**: Self-contained, minimal dependencies -4. **Scalability**: Stateless processing suitable for containerization -5. 
**Maintainability**: Clear separation of concerns, focused functionality - -## Usage Example - -```python -from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline -from api.schemas import SourceImage - -# Create pipeline -pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( - source_images=[source_image], - request_config={"candidate_labels": ["moth", "insect"]}, - existing_detections=[] -) - -# Compile and run -pipeline.compile() -results = pipeline.run() -``` - -## Files Created/Modified - -- ✅ `processing_services/example/api/base.py` - New simplified base classes -- ✅ `processing_services/example/api/global_moth_classifier.py` - New global moth classifier -- ✅ `processing_services/example/api/pipelines.py` - Added new pipeline class -- ✅ `processing_services/example/api/utils.py` - Enhanced utility functions -- ✅ `processing_services/example/api/api.py` - Added new pipeline to API -- ✅ `processing_services/example/test_global_moth_pipeline.py` - Basic test file - -## Original trapdata Source Analysis - -To create this simplified implementation, the following files and line ranges from the original AMI Data Companion (trapdata) module were analyzed: - -### Core Classification Classes -- **`trapdata/api/models/classification.py`**: - - Lines 1-25: Import statements and base dependencies - - Lines 37-163: `APIMothClassifier` base class implementation - - Lines 165-209: All classifier implementations including `MothClassifierGlobal` (line 207) - - Lines 112-137: `save_results()` method for processing predictions - - Lines 138-163: `update_classification()` and pipeline integration - -### Base Inference Framework -- **`trapdata/ml/models/base.py`**: - - Lines 58-120: `InferenceBaseClass` core structure and initialization - - Lines 121-200: Model loading, transforms, and file management methods - - Lines 25-50: Normalization constants and utility functions - -### Global Moth Classifier Implementation -- **`trapdata/ml/models/classification.py`**: - - Lines 507-527: `GlobalMothSpeciesClassifier` class definition and configuration - - Lines 338-375: `SpeciesClassifier` base class and database integration - - Lines 527-567: Various regional classifiers showing inheritance patterns - - Lines 1-50: Import structure and database dataset classes - - Lines 200-300: ResNet50 and Timm-based classifier implementations - -### API Integration Patterns -- **`trapdata/api/api.py`**: - - Lines 1-50: FastAPI setup and classifier imports - - Lines 37-60: `CLASSIFIER_CHOICES` dictionary including global moths - - Lines 120-150: `make_pipeline_config_response()` function - - Lines 175-310: Main processing pipeline in `process()` endpoint - - Lines 60-80: Pipeline choice enumeration and filtering logic - -### Model Architecture References -- **`trapdata/ml/models/localization.py`**: - - Lines 142-200: `ObjectDetector` base class structure - - Lines 245-290: `MothObjectDetector_FasterRCNN_2023` implementation - -### API Schema Definitions -- **`trapdata/api/schemas.py`**: - - Lines 293-330: Pipeline configuration schemas - - Lines 1-100: Detection and classification response schemas - -### Processing Pipeline Examples -- **`trapdata/api/models/localization.py`**: - - Lines 13-60: `APIMothDetector` implementation showing API adaptation pattern - -### Key Configuration Values -From the analysis, these critical configuration values were extracted: -- **Model weights URL**: Lines 507-515 in `classification.py` -- **Labels path**: Lines 516-520 in `classification.py` -- **Input 
size**: Line 508 (`input_size = 128`) -- **Normalization**: Line 509 (`normalization = imagenet_normalization`) -- **Species count**: 29,176 species from model description -- **Default taxon rank**: "SPECIES" from base class - -### Database Dependencies Removed -These database-dependent components were identified and removed: -- **`trapdata/db/models/queue.py`**: Lines 1-500 (entire queue system) -- **`trapdata/db/models/detections.py`**: `save_classified_objects()` function -- **Database path parameters**: Throughout `base.py` and classification classes -- **Queue managers**: `UnclassifiedObjectQueue`, `DetectedObjectQueue` references - -This analysis allowed for the creation of a streamlined implementation that preserves all the essential functionality while eliminating the complex database and queue infrastructure. - -## Next Steps - -1. **Testing**: Run end-to-end tests with real images -2. **Performance Optimization**: Tune batch sizes and memory usage -3. **Error Handling**: Add robust error handling for edge cases -4. **Documentation**: Add detailed API documentation -5. **Docker Integration**: Update Docker configurations if needed - -The implementation successfully provides a clean, maintainable global moth classifier that can process 29,176+ species without the complexity of the original trapdata system. diff --git a/processing_services/moths/api/api/__init__.py b/processing_services/moths/api/api/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/processing_services/moths/api/api/algorithms.py b/processing_services/moths/api/api/algorithms.py deleted file mode 100644 index 8a80038dd..000000000 --- a/processing_services/moths/api/api/algorithms.py +++ /dev/null @@ -1,431 +0,0 @@ -import datetime -import logging -import math -import random - -import torch - -from .schemas import ( - AlgorithmCategoryMapResponse, - AlgorithmConfigResponse, - AlgorithmReference, - BoundingBox, - ClassificationResponse, - Detection, - SourceImage, -) - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -SAVED_MODELS = {} - - -def get_best_device() -> str: - """ - Returns the best available device for running the model. - - MPS is not supported by the current algoritms. - """ - if torch.cuda.is_available(): - return f"cuda:{torch.cuda.current_device()}" - else: - return "cpu" - - -class Algorithm: - algorithm_config_response: AlgorithmConfigResponse - - def compile(self): - raise NotImplementedError("Subclasses must implement the compile method") - - def run(self, inputs: list[SourceImage] | list[Detection]) -> list[Detection]: - raise NotImplementedError("Subclasses must implement the run method") - - def get_category_map(self) -> AlgorithmCategoryMapResponse: - return AlgorithmCategoryMapResponse( - data=[], - labels=[], - version="v1", - description="A model without labels.", - uri=None, - ) - - def get_algorithm_config_response(self) -> AlgorithmConfigResponse: - return AlgorithmConfigResponse( - name="Base Algorithm", - key="base", - task_type="base", - description="A base class for all algorithms.", - version=1, - version_name="v1", - category_map=self.get_category_map(), - ) - - def __init__(self): - self.algorithm_config_response = self.get_algorithm_config_response() - - -class ZeroShotObjectDetector(Algorithm): - """ - Huggingface Zero-Shot Object Detection model. - Produces both a bounding box and a classification for each detection. - The classification is based on the candidate labels. 
- """ - - candidate_labels: list[str] = ["insect"] - - def compile(self, device: str | None = None): - saved_models_key = "zero_shot_object_detector" # generate a key for each uniquely compiled algorithm - - if saved_models_key not in SAVED_MODELS: - from transformers import pipeline - - device_choice = device or get_best_device() - device_index = int(device_choice.split(":")[-1]) if ":" in device_choice else -1 - logger.info(f"Compiling {self.algorithm_config_response.name} on device {device_choice}...") - checkpoint = "google/owlv2-base-patch16-ensemble" - self.model = pipeline( - model=checkpoint, - task="zero-shot-object-detection", - use_fast=True, - device=device_index, - ) - SAVED_MODELS[saved_models_key] = self.model - else: - logger.info(f"Using saved model for {self.algorithm_config_response.name}...") - self.model = SAVED_MODELS[saved_models_key] - - def run(self, source_images: list[SourceImage], intermediate=False) -> list[Detection]: - detector_responses: list[Detection] = [] - for source_image in source_images: - if source_image.width and source_image.height and source_image._pil: - start_time = datetime.datetime.now() - logger.info("Predicting...") - if not self.candidate_labels: - raise ValueError("No candidate labels are provided during inference.") - logger.info(f"Predicting with candidate labels: {self.candidate_labels}") - predictions = self.model(source_image._pil, candidate_labels=self.candidate_labels) - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - for prediction in predictions: - logger.info("Prediction: %s", prediction) - bbox = BoundingBox( - x1=prediction["box"]["xmin"], - x2=prediction["box"]["xmax"], - y1=prediction["box"]["ymin"], - y2=prediction["box"]["ymax"], - ) - cropped_image_pil = source_image._pil.crop((bbox.x1, bbox.y1, bbox.x2, bbox.y2)) - detection = Detection( - id=f"{source_image.id}-crop-{bbox.x1}-{bbox.y1}-{bbox.x2}-{bbox.y2}", - url=source_image.url, # @TODO: ideally, should save cropped image at separate url - width=cropped_image_pil.width, - height=cropped_image_pil.height, - timestamp=datetime.datetime.now(), - source_image=source_image, - bbox=bbox, - inference_time=elapsed_time, - algorithm=AlgorithmReference( - name=self.algorithm_config_response.name, - key=self.algorithm_config_response.key, - ), - classifications=[ - ClassificationResponse( - classification=prediction["label"], - labels=[prediction["label"]], - scores=[prediction["score"]], - logits=[prediction["score"]], - inference_time=elapsed_time, - timestamp=datetime.datetime.now(), - algorithm=AlgorithmReference( - name=self.algorithm_config_response.name, - key=self.algorithm_config_response.key, - ), - terminal=not intermediate, - ) - ], - ) - detection._pil = cropped_image_pil - detector_responses.append(detection) - else: - raise ValueError(f"Source image {source_image.id} does not have width and height attributes.") - - return detector_responses - - def get_category_map(self) -> AlgorithmCategoryMapResponse: - return AlgorithmCategoryMapResponse( - data=[{"index": i, "label": label} for i, label in enumerate(self.candidate_labels)], - labels=self.candidate_labels, - version="v1", - description="Candidate labels used for zero-shot object detection.", - uri=None, - ) - - def get_algorithm_config_response(self) -> AlgorithmConfigResponse: - return AlgorithmConfigResponse( - name="Zero Shot Object Detector", - key="zero-shot-object-detector", - task_type="detection", - description=( - "Huggingface Zero Shot Object Detection 
model." - "Produces both a bounding box and a candidate label classification for each detection." - ), - version=1, - version_name="v1", - category_map=self.get_category_map(), - ) - - -class HFImageClassifier(Algorithm): - """ - A local classifier that uses the Hugging Face pipeline to classify images. - """ - - model_name: str = "google/vit-base-patch16-224" # Vision Transformer model trained on ImageNet-1k - - def compile(self): - saved_models_key = "hf_image_classifier" # generate a key for each uniquely compiled algorithm - - if saved_models_key not in SAVED_MODELS: - from transformers import pipeline - - logger.info(f"Compiling {self.algorithm_config_response.name} from scratch...") - self.model = pipeline("image-classification", model=self.model_name, device=get_best_device()) - SAVED_MODELS[saved_models_key] = self.model - else: - logger.info(f"Using saved model for {self.algorithm_config_response.name}...") - self.model = SAVED_MODELS[saved_models_key] - - def run(self, detections: list[Detection]) -> list[Detection]: - detections_to_return: list[Detection] = [] - start_time = datetime.datetime.now() - - opened_cropped_images = [detection._pil for detection in detections] # type: ignore - - # Process the entire batch of cropped images at once - results = self.model(images=opened_cropped_images) - - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - for detection, preds in zip(detections, results): - labels = [pred["label"] for pred in preds] - scores = [pred["score"] for pred in preds] - max_score_index = scores.index(max(scores)) - classification = labels[max_score_index] - logger.info(f"Classification: {classification}") - logger.info(f"labels: {labels}") - logger.info(f"scores: {scores}") - - existing_classifications = detection.classifications - - detection_with_classification = detection.copy(deep=True) - detection_with_classification.classifications = existing_classifications + [ - ClassificationResponse( - classification=classification, - labels=labels, - scores=scores, - logits=scores, - inference_time=elapsed_time, - timestamp=datetime.datetime.now(), - algorithm=AlgorithmReference( - name=self.algorithm_config_response.name, key=self.algorithm_config_response.key - ), - terminal=True, - ) - ] - - detections_to_return.append(detection_with_classification) - - return detections_to_return - - def get_category_map(self) -> AlgorithmCategoryMapResponse: - """ - Extract the category map from the model. - Returns an AlgorithmCategoryMapResponse with labels, data, and model information. - """ - from transformers.models.auto.configuration_auto import AutoConfig - - logger.info(f"Loading configuration for {self.model_name}") - config = AutoConfig.from_pretrained(self.model_name) - - # Extract label information - if not hasattr(config, "id2label") or not config.id2label: - raise ValueError( - f"Cannot create category map for model {self.model_name}, no id2label mapping found in config" - ) - else: - # Sort labels by index - # Ensure keys are strings for consistent access - id2label: dict[str, str] = {str(k): v for k, v in config.id2label.items()} - indices = sorted([int(k) for k in id2label.keys()]) - - # Create labels and data - labels = [id2label[str(i)] for i in indices] - data = [{"label": label, "index": idx} for idx, label in zip(indices, labels)] - - # Build description - description_text = ( - f"Vision Transformer model trained on ImageNet-1k. " - f"Contains {len(labels)} object classes. 
Model: {self.model_name}" - ) - - return AlgorithmCategoryMapResponse( - data=data, - labels=labels, - version="ImageNet-1k", - description=description_text, - uri=f"https://huggingface.co/{self.model_name}", - ) - - def get_algorithm_config_response(self) -> AlgorithmConfigResponse: - return AlgorithmConfigResponse( - name="HF Image Classifier", - key="hf-image-classifier", - task_type="classification", - description="HF ViT for image classification.", - version=1, - version_name="v1", - category_map=self.get_category_map(), - ) - - -class RandomSpeciesClassifier(Algorithm): - """ - A local classifier that produces random butterfly species classifications. - """ - - def compile(self): - pass - - def _make_random_prediction( - self, - terminal: bool = True, - max_labels: int = 2, - ) -> ClassificationResponse: - assert self.algorithm_config_response.category_map is not None - category_labels = self.algorithm_config_response.category_map.labels - logits = [random.random() for _ in category_labels] - softmax = [math.exp(logit) / sum([math.exp(logit) for logit in logits]) for logit in logits] - top_class = category_labels[softmax.index(max(softmax))] - return ClassificationResponse( - classification=top_class, - labels=category_labels if len(category_labels) <= max_labels else None, - scores=softmax, - logits=logits, - timestamp=datetime.datetime.now(), - algorithm=AlgorithmReference( - name=self.algorithm_config_response.name, - key=self.algorithm_config_response.key, - ), - terminal=terminal, - ) - - def run(self, detections: list[Detection]) -> list[Detection]: - detections_to_return: list[Detection] = [] - for detection in detections: - detection_with_classification = detection.copy(deep=True) - detection_with_classification.classifications = [self._make_random_prediction(terminal=True)] - detections_to_return.append(detection_with_classification) - return detections_to_return - - algorithm_config_response = AlgorithmConfigResponse( - name="Random species classifier", - key="random-species-classifier", - task_type="classification", - description="A random species classifier", - version=1, - version_name="v1", - uri="https://huggingface.co/RolnickLab/random-species-classifier", - category_map=AlgorithmCategoryMapResponse( - data=[ - { - "index": 0, - "gbif_key": "1234", - "label": "Vanessa atalanta", - "source": "manual", - "taxon_rank": "SPECIES", - }, - { - "index": 1, - "gbif_key": "4543", - "label": "Vanessa cardui", - "source": "manual", - "taxon_rank": "SPECIES", - }, - { - "index": 2, - "gbif_key": "7890", - "label": "Vanessa itea", - "source": "manual", - "taxon_rank": "SPECIES", - }, - ], - labels=["Vanessa atalanta", "Vanessa cardui", "Vanessa itea"], - version="v1", - description="A simple species classifier", - uri="https://huggingface.co/RolnickLab/random-species-classifier", - ), - ) - - -class ConstantClassifier(Algorithm): - """ - A local classifier that always returns a constant species classification. 
- """ - - def compile(self): - pass - - def _make_constant_prediction( - self, - terminal: bool = True, - ) -> ClassificationResponse: - assert self.algorithm_config_response.category_map is not None - labels = self.algorithm_config_response.category_map.labels - return ClassificationResponse( - classification=labels[0], - labels=labels, - scores=[0.9], # Constant score for each detection - timestamp=datetime.datetime.now(), - algorithm=AlgorithmReference( - name=self.algorithm_config_response.name, - key=self.algorithm_config_response.key, - ), - terminal=terminal, - ) - - def run(self, detections: list[Detection]) -> list[Detection]: - detections_to_return: list[Detection] = [] - for detection in detections: - detection_with_classification = detection.copy(deep=True) - detection_with_classification.classifications = [self._make_constant_prediction(terminal=True)] - detections_to_return.append(detection_with_classification) - return detections_to_return - - algorithm_config_response = AlgorithmConfigResponse( - name="Constant classifier", - key="constant-classifier", - task_type="classification", - description="Always return a classification of 'Moth'", - version=1, - version_name="v1", - uri="https://huggingface.co/RolnickLab/constant-classifier", - category_map=AlgorithmCategoryMapResponse( - data=[ - { - "index": 0, - "gbif_key": "1234", - "label": "Moth", - "source": "manual", - "taxon_rank": "SUPERFAMILY", - } - ], - labels=["Moth"], - version="v1", - description="A classifier that always returns 'Moth'", - uri="https://huggingface.co/RolnickLab/constant-classifier", - ), - ) diff --git a/processing_services/moths/api/api/api.py b/processing_services/moths/api/api/api.py deleted file mode 100644 index 0396af5e2..000000000 --- a/processing_services/moths/api/api/api.py +++ /dev/null @@ -1,261 +0,0 @@ -""" -Fast API interface for processing images through the localization and classification pipelines. 
-""" - -import logging - -import fastapi - -from .pipelines import ( - Pipeline, - ZeroShotHFClassifierPipeline, - ZeroShotObjectDetectorPipeline, - ZeroShotObjectDetectorWithConstantClassifierPipeline, - ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, - ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, -) -from .schemas import ( - AlgorithmConfigResponse, - Detection, - DetectionRequest, - PipelineRequest, - PipelineRequestConfigParameters, - PipelineResultsResponse, - ProcessingServiceInfoResponse, - SourceImage, -) -from .utils import is_base64, is_url - -# Configure root logger -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", -) - -# Get the root logger -logger = logging.getLogger(__name__) - -app = fastapi.FastAPI() - - -pipelines: list[type[Pipeline]] = [ - ZeroShotHFClassifierPipeline, - ZeroShotObjectDetectorPipeline, - ZeroShotObjectDetectorWithConstantClassifierPipeline, - ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, - ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, -] -pipeline_choices: dict[str, type[Pipeline]] = { - pipeline.config.slug: pipeline for pipeline in pipelines -} -algorithm_choices: dict[str, AlgorithmConfigResponse] = { - algorithm.key: algorithm - for pipeline in pipelines - for algorithm in pipeline.config.algorithms -} - -# ----------- -# API endpoints -# ----------- - - -@app.get("/") -async def root(): - return fastapi.responses.RedirectResponse("/docs") - - -@app.get("/info", tags=["services"]) -async def info() -> ProcessingServiceInfoResponse: - info = ProcessingServiceInfoResponse( - name="Custom ML Backend", - description=("A template for running custom models locally."), - pipelines=[pipeline.config for pipeline in pipelines], - # algorithms=list(algorithm_choices.values()), - ) - return info - - -# Check if the server is online -@app.get("/livez", tags=["health checks"]) -async def livez(): - return fastapi.responses.JSONResponse(status_code=200, content={"status": True}) - - -# Check if the pipelines are ready to process data -@app.get("/readyz", tags=["health checks"]) -async def readyz(): - """ - Check if the server is ready to process data. - - Returns a list of pipeline slugs that are online and ready to process data. - @TODO may need to simplify this to just return True/False. Pipeline algorithms will likely be loaded into memory - on-demand when the pipeline is selected. 
- """ - if pipeline_choices: - return fastapi.responses.JSONResponse( - status_code=200, content={"status": list(pipeline_choices.keys())} - ) - else: - return fastapi.responses.JSONResponse(status_code=503, content={"status": []}) - - -@app.post("/process", tags=["services"]) -async def process(data: PipelineRequest) -> PipelineResultsResponse: - pipeline_slug = data.pipeline - request_config = data.config - - source_images = [SourceImage(**img.model_dump()) for img in data.source_images] - # Open source images once before processing - for img in source_images: - img.open(raise_exception=True) - - detections = create_detections( - source_images=source_images, - detection_requests=data.detections, - ) - - try: - Pipeline = pipeline_choices[pipeline_slug] - except KeyError: - raise fastapi.HTTPException( - status_code=422, detail=f"Invalid pipeline choice: {pipeline_slug}" - ) - - pipeline_request_config = ( - PipelineRequestConfigParameters(**dict(request_config)) - if request_config - else {} - ) - try: - pipeline = Pipeline( - source_images=source_images, - request_config=pipeline_request_config, - existing_detections=detections, - ) - pipeline.compile() - except Exception as e: - logger.error(f"Error compiling pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - try: - response = pipeline.run() - except Exception as e: - logger.error(f"Error running pipeline: {e}") - raise fastapi.HTTPException(status_code=422, detail=f"{e}") - - return response - - -# ----------- -# Helper functions -# ----------- - - -def create_detections( - source_images: list[SourceImage], - detection_requests: list[DetectionRequest] | None, -): - if not detection_requests: - return [] - - # Group detection requests by source image id - source_image_map = {img.id: img for img in source_images} - grouped_detection_requests = {} - for request in detection_requests: - if request.source_image.id not in grouped_detection_requests: - grouped_detection_requests[request.source_image.id] = [] - grouped_detection_requests[request.source_image.id].append(request) - - # Process each source image and its detection requests - detections = [] - for source_image_id, requests in grouped_detection_requests.items(): - if source_image_id not in source_image_map: - raise ValueError( - f"A detection request for source image {source_image_id} was received, " - "but no source image with that ID was provided." - ) - - logger.info( - f"Processing existing detections for source image {source_image_id}." - ) - - for request in requests: - source_image = source_image_map[source_image_id] - cropped_image_id = f"{source_image.id}-crop-{request.bbox.x1}-{request.bbox.y1}-{request.bbox.x2}-{request.bbox.y2}" - if not request.crop_image_url: - logger.info( - "Detection request does not have a crop_image_url, crop the original source image." - ) - assert ( - source_image._pil is not None - ), "Source image must be opened before cropping." - cropped_image_pil = source_image._pil.crop( - (request.bbox.x1, request.bbox.y1, request.bbox.x2, request.bbox.y2) - ) - else: - try: - logger.info( - f"Opening existing cropped image from {request.crop_image_url}." 
- ) - if is_url(request.crop_image_url): - cropped_image = SourceImage( - id=cropped_image_id, - url=request.crop_image_url, - ) - elif is_base64(request.crop_image_url): - logger.info("Decoding base64 cropped image.") - cropped_image = SourceImage( - id=cropped_image_id, - b64=request.crop_image_url, - ) - else: - # Must be a filepath - cropped_image = SourceImage( - id=cropped_image_id, - filepath=request.crop_image_url, - ) - cropped_image.open(raise_exception=True) - cropped_image_pil = cropped_image._pil - except Exception as e: - logger.warning(f"Error opening cropped image: {e}") - logger.info( - f"Falling back to cropping the original source image {source_image_id}." - ) - assert ( - source_image._pil is not None - ), "Source image must be opened before cropping." - cropped_image_pil = source_image._pil.crop( - ( - request.bbox.x1, - request.bbox.y1, - request.bbox.x2, - request.bbox.y2, - ) - ) - - # Create a Detection object - det = Detection( - source_image=SourceImage( - id=source_image.id, - url=source_image.url, - ), - bbox=request.bbox, - id=cropped_image_id, - url=request.crop_image_url or source_image.url, - algorithm=request.algorithm, - ) - # Set the _pil attribute to the cropped image - det._pil = cropped_image_pil - detections.append(det) - logger.info( - f"Created detection {det.id} for source image {source_image_id}." - ) - - return detections - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=2000) diff --git a/processing_services/moths/api/api/pipelines.py b/processing_services/moths/api/api/pipelines.py deleted file mode 100644 index ceb395cf9..000000000 --- a/processing_services/moths/api/api/pipelines.py +++ /dev/null @@ -1,457 +0,0 @@ -import datetime -import logging -from typing import final - -from .algorithms import ( - Algorithm, - ConstantClassifier, - HFImageClassifier, - RandomSpeciesClassifier, - ZeroShotObjectDetector, -) -from .global_moth_classifier import GlobalMothClassifier -from .schemas import ( - Detection, - DetectionResponse, - PipelineConfigResponse, - PipelineRequestConfigParameters, - PipelineResultsResponse, - SourceImage, - SourceImageResponse, -) - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -class Pipeline: - """ - A base class for defining and running a pipeline consisting of multiple stages. - Each stage is represented by an algorithm that processes inputs and produces - outputs. The pipeline is designed to handle batch processing using custom batch - sizes for each stage. - - Attributes: - stages (list[Algorithm]): A list of algorithms representing the stages of - the pipeline in order of execution. Typically [Detector(), Classifier()]. - batch_sizes (list[int]): A list of integers specifying the batch size for - each stage. For example, [1, 1] means that the detector can process 1 - source image a time and the classifier can process 1 detection at a time. - config (PipelineConfigResponse): Pipeline metadata. 
- """ - - stages: list[Algorithm] - batch_sizes: list[int] - request_config: dict - config: PipelineConfigResponse - - stages = [] - batch_sizes = [] - config = PipelineConfigResponse( - name="Base Pipeline", - slug="base", - description="A base class for all pipelines.", - version=1, - algorithms=[], - ) - - def __init__( - self, - source_images: list[SourceImage], - request_config: PipelineRequestConfigParameters | dict = {}, - existing_detections: list[Detection] = [], - custom_batch_sizes: list[int] = [], - ): - self.source_images = source_images - self.request_config = ( - request_config - if isinstance(request_config, dict) - else request_config.model_dump() - ) - self.existing_detections = existing_detections - - logger.info("Initializing algorithms....") - self.stages = self.stages or self.get_stages() - self.batch_sizes = ( - custom_batch_sizes or self.batch_sizes or [1] * len(self.stages) - ) - assert len(self.batch_sizes) == len( - self.stages - ), "Number of batch sizes must match the number of stages." - - def get_stages(self) -> list[Algorithm]: - """ - An optional function to initialize and return a list of algorithms/stages. - Any pipeline config values relevant to a particular algorithm should be passed or set here. - """ - return [] - - @final - def compile(self): - logger.info("Compiling algorithms....") - for stage_idx, stage in enumerate(self.stages): - logger.info( - f"[{stage_idx+1}/{len(self.stages)}] Compiling {stage.algorithm_config_response.name}..." - ) - stage.compile() - - def run(self) -> PipelineResultsResponse: - """ - This function must always return a PipelineResultsResponse object. - """ - raise NotImplementedError("Subclasses must implement") - - @final - def _batchify_inputs(self, inputs: list, batch_size: int) -> list[list]: - """ - Helper function to split the inputs into batches of the specified size. - """ - batched_inputs = [] - for i in range(0, len(inputs), batch_size): - start_id = i - end_id = i + batch_size - batched_inputs.append(inputs[start_id:end_id]) - return batched_inputs - - @final - def _get_detections( - self, - algorithm: Algorithm, - inputs: list[SourceImage] | list[Detection], - batch_size: int, - **kwargs, - ) -> list[Detection]: - """A single stage, step, or algorithm in a pipeline. Batchifies inputs and produces Detections as outputs.""" - outputs: list[Detection] = [] - batched_inputs = self._batchify_inputs(inputs, batch_size) - for batch in batched_inputs: - outputs.extend(algorithm.run(batch, **kwargs)) - return outputs - - @final - def _get_pipeline_response( - self, detections: list[Detection], elapsed_time: float - ) -> PipelineResultsResponse: - """ - Final stage of the pipeline to format the detections. 
- """ - detection_responses = [ - DetectionResponse( - source_image_id=detection.source_image.id, - bbox=detection.bbox, - inference_time=detection.inference_time, - algorithm=detection.algorithm, - timestamp=datetime.datetime.now(), - classifications=detection.classifications, - ) - for detection in detections - ] - source_image_responses = [ - SourceImageResponse(**image.model_dump()) for image in self.source_images - ] - - return PipelineResultsResponse( - pipeline=self.config.slug, # type: ignore - # algorithms={algorithm.key: algorithm for algorithm in self.config.algorithms}, - total_time=elapsed_time, - source_images=source_image_responses, - detections=detection_responses, - ) - - -class ZeroShotHFClassifierPipeline(Pipeline): - """ - A pipeline that uses the Zero Shot Object Detector to produce bounding boxes - and then applies the HuggingFace image classifier. - """ - - batch_sizes = [1, 1] - config = PipelineConfigResponse( - name="Zero Shot HF Classifier Pipeline", - slug="zero-shot-hf-classifier-pipeline", - description=("Zero Shot Object Detector with HF image classifier."), - version=1, - algorithms=[ - ZeroShotObjectDetector().algorithm_config_response, - HFImageClassifier().algorithm_config_response, - ], - ) - - def get_stages(self) -> list[Algorithm]: - zero_shot_object_detector = ZeroShotObjectDetector() - if "candidate_labels" in self.request_config: - logger.info( - "Setting candidate labels for zero shot object detector to %s", - self.request_config["candidate_labels"], - ) - zero_shot_object_detector.candidate_labels = self.request_config[ - "candidate_labels" - ] - self.config.algorithms = [ - zero_shot_object_detector.algorithm_config_response, - HFImageClassifier().algorithm_config_response, - ] - - return [zero_shot_object_detector, HFImageClassifier()] - - def run(self) -> PipelineResultsResponse: - start_time = datetime.datetime.now() - detections_with_candidate_labels: list[Detection] = [] - if self.existing_detections: - logger.info("[1/2] Skipping the localizer, use existing detections...") - detections_with_candidate_labels = self.existing_detections - else: - logger.info("[1/2] No existing detections, generating detections...") - detections_with_candidate_labels: list[Detection] = self._get_detections( - self.stages[0], - self.source_images, - self.batch_sizes[0], - intermediate=True, - ) - - logger.info("[2/2] Running the classifier...") - detections_with_classifications: list[Detection] = self._get_detections( - self.stages[1], detections_with_candidate_labels, self.batch_sizes[1] - ) - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - pipeline_response: PipelineResultsResponse = self._get_pipeline_response( - detections_with_classifications, elapsed_time - ) - logger.info( - f"Successfully processed {len(detections_with_classifications)} detections." - ) - - return pipeline_response - - -class ZeroShotObjectDetectorPipeline(Pipeline): - """ - A pipeline that uses the HuggingFace zero shot object detector. - Produces both a bounding box and a classification for each detection. - The classification is based on the candidate labels provided in the request. 
- """ - - batch_sizes = [1] - config = PipelineConfigResponse( - name="Zero Shot Object Detector Pipeline", - slug="zero-shot-object-detector-pipeline", - description=("Zero shot object detector (bbox and classification)."), - version=1, - algorithms=[ZeroShotObjectDetector().algorithm_config_response], - ) - - def get_stages(self) -> list[Algorithm]: - zero_shot_object_detector = ZeroShotObjectDetector() - if "candidate_labels" in self.request_config: - logger.info( - "Setting candidate labels for zero shot object detector to %s", - self.request_config["candidate_labels"], - ) - zero_shot_object_detector.candidate_labels = self.request_config[ - "candidate_labels" - ] - self.config.algorithms = [zero_shot_object_detector.algorithm_config_response] - - return [zero_shot_object_detector] - - def run(self) -> PipelineResultsResponse: - start_time = datetime.datetime.now() - logger.info("[1/1] Running the zero shot object detector...") - detections_with_classifications: list[Detection] = self._get_detections( - self.stages[0], self.source_images, self.batch_sizes[0] - ) - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - pipeline_response: PipelineResultsResponse = self._get_pipeline_response( - detections_with_classifications, elapsed_time - ) - logger.info( - f"Successfully processed {len(detections_with_classifications)} detections." - ) - - return pipeline_response - - -class ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline(Pipeline): - """ - A pipeline that uses the HuggingFace zero shot object detector and a random species classifier. - """ - - batch_sizes = [1, 1] - config = PipelineConfigResponse( - name="Zero Shot Object Detector With Random Species Classifier Pipeline", - slug="zero-shot-object-detector-with-random-species-classifier-pipeline", - description=("HF zero shot object detector with random species classifier."), - version=1, - algorithms=[ - ZeroShotObjectDetector().algorithm_config_response, - RandomSpeciesClassifier().algorithm_config_response, - ], - ) - - def get_stages(self) -> list[Algorithm]: - zero_shot_object_detector = ZeroShotObjectDetector() - if "candidate_labels" in self.request_config: - zero_shot_object_detector.candidate_labels = self.request_config[ - "candidate_labels" - ] - - self.config.algorithms = [ - zero_shot_object_detector.algorithm_config_response, - RandomSpeciesClassifier().algorithm_config_response, - ] - - return [zero_shot_object_detector, RandomSpeciesClassifier()] - - def run(self) -> PipelineResultsResponse: - start_time = datetime.datetime.now() - detections: list[Detection] = [] - if self.existing_detections: - logger.info("[1/2] Skipping the localizer, use existing detections...") - detections = self.existing_detections - else: - logger.info("[1/2] No existing detections, generating detections...") - detections = self._get_detections( - self.stages[0], self.source_images, self.batch_sizes[0] - ) - - logger.info("[2/2] Running the classifier...") - detections_with_classifications: list[Detection] = self._get_detections( - self.stages[1], detections, self.batch_sizes[1] - ) - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - pipeline_response: PipelineResultsResponse = self._get_pipeline_response( - detections_with_classifications, elapsed_time - ) - logger.info( - f"Successfully processed {len(detections_with_classifications)} detections." 
- ) - - return pipeline_response - - -class ZeroShotObjectDetectorWithConstantClassifierPipeline(Pipeline): - """ - A pipeline that uses the HuggingFace zero shot object detector and a constant classifier. - """ - - batch_sizes = [1, 1] - config = PipelineConfigResponse( - name="Zero Shot Object Detector With Constant Classifier Pipeline", - slug="zero-shot-object-detector-with-constant-classifier-pipeline", - description=("HF zero shot object detector with constant classifier."), - version=1, - algorithms=[ - ZeroShotObjectDetector().algorithm_config_response, - ConstantClassifier().algorithm_config_response, - ], - ) - - def get_stages(self) -> list[Algorithm]: - zero_shot_object_detector = ZeroShotObjectDetector() - if "candidate_labels" in self.request_config: - zero_shot_object_detector.candidate_labels = self.request_config[ - "candidate_labels" - ] - - self.config.algorithms = [ - zero_shot_object_detector.algorithm_config_response, - ConstantClassifier().algorithm_config_response, - ] - - return [zero_shot_object_detector, ConstantClassifier()] - - def run(self) -> PipelineResultsResponse: - start_time = datetime.datetime.now() - detections: list[Detection] = [] - if self.existing_detections: - logger.info("[1/2] Skipping the localizer, use existing detections...") - detections = self.existing_detections - else: - logger.info("[1/2] No existing detections, generating detections...") - detections = self._get_detections( - self.stages[0], self.source_images, self.batch_sizes[0] - ) - - logger.info("[2/2] Running the classifier...") - detections_with_classifications: list[Detection] = self._get_detections( - self.stages[1], detections, self.batch_sizes[1] - ) - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - pipeline_response: PipelineResultsResponse = self._get_pipeline_response( - detections_with_classifications, elapsed_time - ) - logger.info( - f"Successfully processed {len(detections_with_classifications)} detections." - ) - - return pipeline_response - - -class ZeroShotObjectDetectorWithGlobalMothClassifierPipeline(Pipeline): - """ - A pipeline that uses the HuggingFace zero shot object detector and the global moth classifier. - This provides high-quality moth species identification with 29,176+ species support. - """ - - batch_sizes = [1, 4] # Detector batch=1, Classifier batch=4 - config = PipelineConfigResponse( - name="Zero Shot Object Detector With Global Moth Classifier Pipeline", - slug="zero-shot-object-detector-with-global-moth-classifier-pipeline", - description=( - "HF zero shot object detector with global moth species classifier. " - "Supports 29,176+ moth species trained on global data." 
- ), - version=1, - algorithms=[], # Will be populated in get_stages() - ) - - def get_stages(self) -> list[Algorithm]: - zero_shot_object_detector = ZeroShotObjectDetector() - if "candidate_labels" in self.request_config: - zero_shot_object_detector.candidate_labels = self.request_config[ - "candidate_labels" - ] - - global_moth_classifier = GlobalMothClassifier() - - self.config.algorithms = [ - zero_shot_object_detector.algorithm_config_response, - global_moth_classifier.algorithm_config_response, - ] - - return [zero_shot_object_detector, global_moth_classifier] - - def run(self) -> PipelineResultsResponse: - start_time = datetime.datetime.now() - detections: list[Detection] = [] - - if self.existing_detections: - logger.info("[1/2] Skipping the localizer, use existing detections...") - detections = self.existing_detections - else: - logger.info("[1/2] No existing detections, generating detections...") - detections = self._get_detections( - self.stages[0], self.source_images, self.batch_sizes[0] - ) - - logger.info("[2/2] Running the global moth classifier...") - detections_with_classifications: list[Detection] = self._get_detections( - self.stages[1], detections, self.batch_sizes[1] - ) - - end_time = datetime.datetime.now() - elapsed_time = (end_time - start_time).total_seconds() - - pipeline_response: PipelineResultsResponse = self._get_pipeline_response( - detections_with_classifications, elapsed_time - ) - logger.info( - f"Successfully processed {len(detections_with_classifications)} detections with global moth classifier." - ) - - return pipeline_response diff --git a/processing_services/moths/api/api/schemas.py b/processing_services/moths/api/api/schemas.py deleted file mode 100644 index 93e99f6aa..000000000 --- a/processing_services/moths/api/api/schemas.py +++ /dev/null @@ -1,341 +0,0 @@ -# Can these be imported from the OpenAPI spec yaml? 
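[Reviewer note on the `pipelines.py` deletion above] The batching contract in `Pipeline._batchify_inputs` is small enough to sanity-check standalone. A minimal sketch follows; the `batchify` name is hypothetical, and the logic mirrors the removed helper, including keeping the trailing partial batch:

```python
# Standalone sketch of Pipeline._batchify_inputs from the deleted module.
# Splits inputs into consecutive batches; a trailing partial batch is kept
# rather than dropped.
def batchify(inputs: list, batch_size: int) -> list[list]:
    return [inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)]

assert batchify([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]
assert batchify([], 4) == []
```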
-import datetime -import logging -import pathlib -import typing - -import PIL.Image -import pydantic - -from .utils import get_image - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -class BoundingBox(pydantic.BaseModel): - x1: float - y1: float - x2: float - y2: float - - @classmethod - def from_coords(cls, coords: list[float]): - return cls(x1=coords[0], y1=coords[1], x2=coords[2], y2=coords[3]) - - def to_string(self): - return f"{self.x1},{self.y1},{self.x2},{self.y2}" - - def to_path(self): - return "-".join([str(int(x)) for x in [self.x1, self.y1, self.x2, self.y2]]) - - def to_tuple(self): - return (self.x1, self.y1, self.x2, self.y2) - - -class BaseImage(pydantic.BaseModel): - model_config = pydantic.ConfigDict(extra="ignore", arbitrary_types_allowed=True) - - id: str - url: str | None = None - b64: str | None = None - filepath: str | pathlib.Path | None = None - _pil: PIL.Image.Image | None = None - width: int | None = None - height: int | None = None - timestamp: datetime.datetime | None = None - - # Validate that there is at least one of the following fields - @pydantic.model_validator(mode="after") - def validate_source(self): - if not any([self.url, self.b64, self.filepath, self._pil]): - raise ValueError( - "At least one of the following fields must be provided: url, b64, filepath, pil" - ) - return self - - def open(self, raise_exception=False) -> PIL.Image.Image | None: - if not self._pil: - logger.warn(f"Opening image {self.id} for the first time") - self._pil = get_image( - url=self.url, - b64=self.b64, - filepath=self.filepath, - raise_exception=raise_exception, - ) - else: - logger.info(f"Using already loaded image {self.id}") - if self._pil: - self.width, self.height = self._pil.size - return self._pil - - -class SourceImage(BaseImage): - pass - - -class AlgorithmReference(pydantic.BaseModel): - name: str - key: str - - -class ClassificationResponse(pydantic.BaseModel): - classification: str - labels: list[str] | None = pydantic.Field( - default=None, - description=( - "A list of all possible labels for the model, in the correct order. " - "Omitted if the model has too many labels to include for each classification in the response. " - "Use the category map from the algorithm to get the full list of labels and metadata." 
-        ),
-    )
-    scores: list[float] = pydantic.Field(
-        default_factory=list,
-        description="The calibrated probabilities for each class label, most commonly the softmax output.",
-    )
-    logits: list[float] = pydantic.Field(
-        default_factory=list,
-        description="The raw logits output by the model, before any calibration or normalization.",
-    )
-    inference_time: float | None = None
-    algorithm: AlgorithmReference
-    terminal: bool = True
-    timestamp: datetime.datetime
-
-
-class SourceImageRequest(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(extra="ignore")
-
-    id: str
-    url: str
-    # b64: str | None = None
-    # @TODO bring over new SourceImage & b64 validation from the lepsAI repo
-
-
-class SourceImageResponse(pydantic.BaseModel):
-    model_config = pydantic.ConfigDict(extra="ignore")
-
-    id: str
-    url: str
-
-
-class DetectionRequest(pydantic.BaseModel):
-    source_image: SourceImageRequest  # the 'original' image
-    bbox: BoundingBox
-    crop_image_url: str | None = None
-    algorithm: AlgorithmReference
-
-
-class DetectionResponse(pydantic.BaseModel):
-    # these fields are populated with values from a Detection, excluding source_image details
-    source_image_id: str
-    bbox: BoundingBox
-    inference_time: float | None = None
-    algorithm: AlgorithmReference
-    timestamp: datetime.datetime
-    crop_image_url: str | None = None
-    classifications: list[ClassificationResponse] = []
-
-
-class Detection(BaseImage):
-    """
-    An internal representation of a detection with reference to a source image instance.
-    """
-
-    source_image: SourceImage  # the 'original' uncropped image
-    bbox: BoundingBox
-    inference_time: float | None = None
-    algorithm: AlgorithmReference
-    classifications: list[ClassificationResponse] = []
-
-
-class AlgorithmCategoryMapResponse(pydantic.BaseModel):
-    data: list[dict] = pydantic.Field(
-        default_factory=list,
-        description="Complete data for each label, such as id, gbif_key, explicit index, source, etc.",
-        examples=[
-            [
-                {"label": "Moth", "index": 0, "gbif_key": 1234},
-                {"label": "Not a moth", "index": 1, "gbif_key": 5678},
-            ]
-        ],
-    )
-    labels: list[str] = pydantic.Field(
-        default_factory=list,
-        description="A simple list of string labels, in the correct index order used by the model.",
-        examples=[["Moth", "Not a moth"]],
-    )
-    version: str | None = pydantic.Field(
-        default=None,
-        description="The version of the category map. Can be a descriptive string or a version number.",
-        examples=["LepNet2021-with-2023-mods"],
-    )
-    description: str | None = pydantic.Field(
-        default=None,
-        description="A description of the category map used to train. e.g. source, purpose and modifications.",
-        examples=[
-            "LepNet2021 with Schmidt 2023 corrections. Limited to species with > 1000 observations."
-        ],
-    )
-    uri: str | None = pydantic.Field(
-        default=None,
-        description="A URI to the category map file, could be a public web URL or object store path.",
-    )
-
-
-class AlgorithmConfigResponse(pydantic.BaseModel):
-    name: str
-    key: str = pydantic.Field(
-        description=(
-            "A unique key for an algorithm to lookup the category map (class list) and other metadata."
-        ),
-    )
-    description: str | None = None
-    task_type: str | None = pydantic.Field(
-        default=None,
-        description="The type of task the model is trained for. e.g. 'detection', 'classification', 'embedding', etc.",
-        examples=["detection", "classification", "segmentation", "embedding"],
-    )
-    version: int = pydantic.Field(
-        default=1,
-        description="A sortable version number for the model. Increment this number when the model is updated.",
-    )
-    version_name: str | None = pydantic.Field(
-        default=None,
-        description="A complete version name e.g. '2021-01-01', 'LepNet2021'.",
-    )
-    uri: str | None = pydantic.Field(
-        default=None,
-        description="A URI to the weights or model details, could be a public web URL or object store path.",
-    )
-    category_map: AlgorithmCategoryMapResponse | None = None
-
-    class Config:
-        extra = "ignore"
-
-
-PipelineChoice = typing.Literal[
-    # @TODO can this be dynamically generated from available pipelines?
-    "zero-shot-hf-classifier-pipeline",
-    "zero-shot-object-detector-pipeline",
-    "zero-shot-object-detector-with-constant-classifier-pipeline",
-    "zero-shot-object-detector-with-random-species-classifier-pipeline",
-    "zero-shot-object-detector-with-global-moth-classifier-pipeline",
-]
-
-
-class PipelineRequestConfigParameters(pydantic.BaseModel):
-    """Parameters used to configure a pipeline request.
-
-    Accepts any serializable key-value pair.
-    Example: {"force_reprocess": True, "auth_token": "abc123"}
-
-    Supported parameters are defined by the pipeline in the processing service
-    and should be published in the Pipeline's info response.
-    """
-
-    force_reprocess: bool = pydantic.Field(
-        default=False,
-        description="Force reprocessing of the image, even if it has already been processed.",
-    )
-    auth_token: str | None = pydantic.Field(
-        default=None,
-        description="An optional authentication token to use for the pipeline.",
-    )
-    candidate_labels: list[str] | None = pydantic.Field(
-        default=None,
-        description="A list of candidate labels to use for the zero-shot object detector.",
-    )
-
-
-class PipelineRequest(pydantic.BaseModel):
-    pipeline: PipelineChoice
-    source_images: list[SourceImageRequest]
-    detections: list[DetectionRequest] | None = None
-    config: PipelineRequestConfigParameters | dict | None = None
-
-    # Example for API docs:
-    class Config:
-        json_schema_extra = {
-            "example": {
-                "pipeline": "zero-shot-object-detector-pipeline",
-                "source_images": [
-                    {
-                        "id": "123",
-                        "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg",
-                    }
-                ],
-                "config": {"force_reprocess": True, "auth_token": "abc123"},
-            }
-        }
-
-
-class PipelineResultsResponse(pydantic.BaseModel):
-    pipeline: PipelineChoice
-    total_time: float
-    algorithms: dict[str, AlgorithmConfigResponse] = pydantic.Field(
-        default_factory=dict,
-        description=(
-            "A dictionary of all algorithms used in the pipeline, including their class list and other "
-            "metadata, keyed by the algorithm key. "
-            "DEPRECATED: Algorithms should only be provided in the ProcessingServiceInfoResponse."
-        ),
-        deprecated=True,
-    )
-    source_images: list[SourceImageResponse]
-    detections: list[DetectionResponse]
-    errors: list | str | None = None
-
-
-class PipelineStageParam(pydantic.BaseModel):
-    """A configurable parameter of a stage of a pipeline."""
-
-    name: str
-    key: str
-    category: str = "default"
-
-
-class PipelineStage(pydantic.BaseModel):
-    """A configurable stage of a pipeline."""
-
-    key: str
-    name: str
-    params: list[PipelineStageParam] = []
-    description: str | None = None
-
-
-class PipelineConfigResponse(pydantic.BaseModel):
-    """Details about a pipeline, its algorithms and category maps."""
-
-    name: str
-    slug: str
-    version: int
-    description: str | None = None
-    algorithms: list[AlgorithmConfigResponse] = []
-    stages: list[PipelineStage] = []
-
-
-class ProcessingServiceInfoResponse(pydantic.BaseModel):
-    """Information about the processing service."""
-
-    name: str = pydantic.Field(examples=["Mila Research Lab - Moth AI Services"])
-    description: str | None = pydantic.Field(
-        default=None,
-        examples=[
-            "Algorithms developed by the Mila Research Lab for analysis of moth images."
-        ],
-    )
-    pipelines: list[PipelineConfigResponse] = pydantic.Field(
-        default_factory=list,
-        examples=[
-            [
-                PipelineConfigResponse(
-                    name="Random Pipeline", slug="random", version=1, algorithms=[]
-                ),
-            ]
-        ],
-    )
-    # algorithms: list[AlgorithmConfigResponse] = pydantic.Field(
-    #     default=list,
-    #     examples=[RANDOM_BINARY_CLASSIFIER],
-    # )
diff --git a/processing_services/moths/api/api/test.py b/processing_services/moths/api/api/test.py
deleted file mode 100644
index b5b1b5f7c..000000000
--- a/processing_services/moths/api/api/test.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import unittest
-
-from fastapi.testclient import TestClient
-
-from .api import app
-from .pipelines import CustomPipeline
-from .schemas import PipelineRequest, SourceImage, SourceImageRequest
-
-
-class TestPipeline(unittest.TestCase):
-    def test_custom_pipeline(self):
-        # @TODO: Load actual antenna images?
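[Reviewer note on the `schemas.py` deletion above] The request contract is easy to exercise by round-tripping a `PipelineRequest`. A minimal sketch, assuming the `api.schemas` import path from this package; the image URL is a placeholder:

```python
# Round-trip of the removed request schema (pydantic v2 API). The slug is one
# of the PipelineChoice literals defined above; the URL is a placeholder.
from api.schemas import PipelineRequest, SourceImageRequest

req = PipelineRequest(
    pipeline="zero-shot-object-detector-pipeline",
    source_images=[SourceImageRequest(id="1", url="https://example.com/moth.jpg")],
    config={"candidate_labels": ["moth", "butterfly", "insect"]},
)
print(req.model_dump_json(indent=2))
```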
- pipeline = CustomPipeline( - source_images=[ - SourceImage( - id="1001", - url=( - "https://huggingface.co/datasets/huggingface/" - "documentation-images/resolve/main/pipeline-cat-chonk.jpeg" - ), - ), - SourceImage(id="1002", url="https://cdn.britannica.com/79/191679-050-C7114D2B/Adult-capybara.jpg"), - ], - detector_batch_size=2, - classifier_batch_size=2, - ) - detections = pipeline.run() - - self.assertEqual(len(detections), 20) - expected_labels = ["lynx, catamount", "beaver"] - for detection_id, detection in enumerate(detections): - self.assertEqual(detection.source_image_id, pipeline.source_images[detection_id].id) - self.assertIsNotNone(detection.bbox) - self.assertEqual(len(detection.classifications), 1) - classification = detection.classifications[0] - self.assertEqual(classification.classification, expected_labels[detection_id]) - self.assertGreaterEqual(classification.scores[0], 0.0) - self.assertLessEqual(classification.scores[0], 1.0) - - -class TestAPI(unittest.TestCase): - def setUp(self): - self.client = TestClient(app) - - def test_root(self): - response = self.client.get("/") - self.assertEqual(response.status_code, 200) - self.assertEqual(response.url, "http://testserver/docs") - - def test_process(self): - source_images = [ - SourceImage(id="1", url="https://example.com/image1.jpg"), - SourceImage(id="2", url="https://example.com/image2.jpg"), - ] - source_image_requests = [SourceImageRequest(**image.dict()) for image in source_images] - request = PipelineRequest(pipeline="local-pipeline", source_images=source_image_requests, config={}) - response = self.client.post("/process", json=request.dict()) - - self.assertEqual(response.status_code, 200) - data = response.json() - self.assertEqual(data["pipeline"], "local-pipeline") - self.assertEqual(len(data["source_images"]), 2) - self.assertEqual(len(data["detections"]), 2) - self.assertGreater(data["total_time"], 0.0) diff --git a/processing_services/moths/api/api/utils.py b/processing_services/moths/api/api/utils.py deleted file mode 100644 index 9fd50d3a4..000000000 --- a/processing_services/moths/api/api/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import base64 -import binascii -import io -import logging -import pathlib -import re -import tempfile -from urllib.parse import urlparse - -import PIL.Image -import PIL.ImageFile -import requests -import torch - -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True - -# This is polite and required by some hosts -# see: https://foundation.wikimedia.org/wiki/Policy:User-Agent_policy -USER_AGENT = "AntennaInsectDataPlatform/1.0 (https://insectai.org)" - -# ----------- -# File handling functions -# ----------- - - -def is_url(path: str) -> bool: - return path.startswith("http://") or path.startswith("https://") - - -def is_base64(s: str) -> bool: - try: - # Check if string can be decoded from base64 - return base64.b64encode(base64.b64decode(s)).decode() == s - except Exception: - return False - - -def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path: - """ - Fetch a file from a URL or local path. If the path is a URL, download the file. - If the URL has already been downloaded, return the existing local path. - If the path is a local path, return the path. 
- - >>> filepath = get_or_download_file("https://example.uk/images/31-20230919033000-snapshot.jpg?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=451d406b7eb1113e1bb05c083ce51481%2F20240429%2F") # noqa: E501 - >>> filepath.name - '31-20230919033000-snapshot.jpg' - >>> filepath = get_or_download_file("/home/user/images/31-20230919033000-snapshot.jpg") - >>> filepath.name - '31-20230919033000-snapshot.jpg' - """ - if not path_or_url: - raise Exception("Specify a URL or path to fetch file from.") - - # If path is a local path instead of a URL then urlretrieve will just return that path - - # Use persistent cache directory instead of temp - cache_root = pathlib.Path("/app/cache") if pathlib.Path("/app").exists() else pathlib.Path.home() / ".antenna_cache" - destination_dir = cache_root / tempdir_prefix - fname = pathlib.Path(urlparse(path_or_url).path).name - if not destination_dir.exists(): - destination_dir.mkdir(parents=True, exist_ok=True) - local_filepath = destination_dir / fname - - if local_filepath and local_filepath.exists(): - logger.info(f"📁 Using cached file: {local_filepath}") - return local_filepath - - else: - logger.info(f"⬇️ Downloading {path_or_url} to {local_filepath}") - headers = {"User-Agent": USER_AGENT} - response = requests.get(path_or_url, stream=True, headers=headers) - response.raise_for_status() # Raise an exception for HTTP errors - - with open(local_filepath, "wb") as f: - total_size = int(response.headers.get('content-length', 0)) - downloaded = 0 - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - downloaded += len(chunk) - if total_size > 0: - percent = (downloaded / total_size) * 100 - logger.info(f" Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)") - - resulting_filepath = pathlib.Path(local_filepath).resolve() - logger.info(f"✅ Download completed: {resulting_filepath}") - logger.info(f"Downloaded to {resulting_filepath}") - return resulting_filepath - - -def get_best_device() -> str: - """ - Returns the best available device for running the model. - MPS is not supported by the current algorithms. - """ - if torch.cuda.is_available(): - return f"cuda:{torch.cuda.current_device()}" - else: - return "cpu" - - -def open_image( - fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True -) -> PIL.Image.Image | None: - """ - Wrapper from PIL.Image.open that handles errors and converts to RGB. - """ - img = None - try: - img = PIL.Image.open(fp) - except PIL.UnidentifiedImageError: - logger.warn(f"Unidentified image: {str(fp)[:100]}...") - if raise_exception: - raise - except OSError: - logger.warn(f"Could not open image: {str(fp)[:100]}...") - if raise_exception: - raise - else: - # Convert to RGB if necessary - if img.mode != "RGB": - img = img.convert("RGB") - - return img - - -def decode_base64_string(string) -> io.BytesIO: - image_data = re.sub("^data:image/.+;base64,", "", string) - decoded = base64.b64decode(image_data) - buffer = io.BytesIO(decoded) - buffer.seek(0) - return buffer - - -def get_image( - url: str | None = None, - filepath: str | pathlib.Path | None = None, - b64: str | None = None, - raise_exception: bool = True, -) -> PIL.Image.Image | None: - """ - Given a URL, local file path or base64 image, return a PIL image. 
- """ - - if url: - logger.info(f"Fetching image from URL: {url}") - tempdir = tempfile.TemporaryDirectory(prefix="ami_images") - img_path = get_or_download_file(url, tempdir_prefix=tempdir.name) - return open_image(img_path, raise_exception=raise_exception) - - elif filepath: - logger.info(f"Loading image from local filesystem: {filepath}") - return open_image(filepath, raise_exception=raise_exception) - - elif b64: - logger.info(f"Loading image from base64 string: {b64[:30]}...") - try: - buffer = decode_base64_string(b64) - except binascii.Error as e: - logger.warn(f"Could not decode base64 image: {e}") - if raise_exception: - raise - else: - return None - else: - return open_image(buffer, raise_exception=raise_exception) - - else: - raise Exception("Specify a URL, path or base64 image.") diff --git a/processing_services/moths/api/docker-compose.yml b/processing_services/moths/api/docker-compose.yml deleted file mode 100644 index 83db6ccfa..000000000 --- a/processing_services/moths/api/docker-compose.yml +++ /dev/null @@ -1,25 +0,0 @@ -services: - ml_backend_example: - build: - context: . - volumes: - - ./:/app:z - - ./huggingface_cache:/root/.cache/huggingface - - ./pytorch_cache:/root/.cache/torch - ports: - - "2003:2000" - extra_hosts: - - minio:host-gateway - networks: - - antenna_network - # deploy: - # resources: - # reservations: - # devices: - # - driver: nvidia - # count: 1 - # capabilities: [ gpu ] - -networks: - antenna_network: - name: antenna_network diff --git a/processing_services/moths/api/main.py b/processing_services/moths/api/main.py deleted file mode 100644 index 2ed50004d..000000000 --- a/processing_services/moths/api/main.py +++ /dev/null @@ -1,4 +0,0 @@ -if __name__ == "__main__": - import uvicorn - - uvicorn.run("api.api:app", host="0.0.0.0", port=2000, reload=True) diff --git a/processing_services/moths/api/requirements.txt b/processing_services/moths/api/requirements.txt deleted file mode 100644 index 4cf7a91d9..000000000 --- a/processing_services/moths/api/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -fastapi==0.116.0 -uvicorn==0.35.0 -pydantic==2.11.7 -Pillow==11.3.0 -requests==2.32.4 -transformers==4.50.3 -torch==2.6.0 -torchvision==0.21.0 -scipy==1.16.0 -timm diff --git a/processing_services/moths/api/test_api_integration.py b/processing_services/moths/api/test_api_integration.py deleted file mode 100644 index 492a549b2..000000000 --- a/processing_services/moths/api/test_api_integration.py +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env python3 -""" -API Integration Test for Global Moth Classifier Pipeline. -This test calls the actual HTTP API endpoints to validate the service. 
-""" - -import json -import pathlib -import sys -import time - -import requests - -# Add the processing_services/example to the path -sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) - - -def test_api_integration(): - """Test the Global Moth Classifier Pipeline via HTTP API.""" - print("🌐 Testing API Integration for Global Moth Classifier...") - - base_url = "http://ml_backend_example:2000" - - # Test 1: Get service info - print("\n📋 Test 1: Getting service info...") - try: - response = requests.get(f"{base_url}/info", timeout=30) - response.raise_for_status() - info = response.json() - - print("✅ Service info retrieved successfully!") - print(f" Service name: {info.get('name', 'Unknown')}") - print(f" Version: {info.get('version', 'Unknown')}") - print(f" Available pipelines: {len(info.get('pipelines', []))}") - - # Check if our pipeline is available - pipeline_slugs = [p.get('slug') for p in info.get('pipelines', [])] - expected_slug = "zero-shot-object-detector-with-global-moth-classifier-pipeline" - - if expected_slug in pipeline_slugs: - print("✅ Global Moth Classifier pipeline found in service!") - else: - print("❌ Global Moth Classifier pipeline NOT found in service") - print(f" Available pipelines: {pipeline_slugs}") - return False - - except Exception as e: - print(f"❌ Service info request failed: {str(e)}") - return False - - # Test 2: Process image with Global Moth Classifier - print("\n🦋 Test 2: Processing moth image...") - - request_payload = { - "config": { - "auth_token": "test123", - "force_reprocess": True, - "candidate_labels": ["moth", "butterfly", "insect"] - }, - "pipeline": "zero-shot-object-detector-with-global-moth-classifier-pipeline", - "source_images": [ - { - "id": "api_test_123", - "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg" - } - ] - } - - try: - print("📤 Sending processing request...") - print(f" Pipeline: {request_payload['pipeline']}") - print(f" Image URL: {request_payload['source_images'][0]['url']}") - - start_time = time.time() - response = requests.post( - f"{base_url}/process", - json=request_payload, - timeout=300 # 5 minutes timeout for processing - ) - end_time = time.time() - - response.raise_for_status() - result = response.json() - - processing_time = end_time - start_time - print("✅ Image processed successfully!") - print(f" API response time: {processing_time:.2f}s") - print(f" Pipeline processing time: {result.get('total_time', 'unknown')}s") - print(f" Number of detections: {len(result.get('detections', []))}") - - # Analyze results - detections = result.get('detections', []) - if detections: - print("\n🔍 Detection Results:") - for i, detection in enumerate(detections[:5]): # Show first 5 - bbox = detection.get('bbox', {}) - classifications = detection.get('classifications', []) - - print(f" Detection {i+1}:") - print(f" - Bbox: {bbox}") - print(f" - Algorithm: {detection.get('algorithm', {}).get('name', 'unknown')}") - - if classifications: - # Find top classification - top_classification = classifications[0] - if 'scores' in top_classification and top_classification['scores']: - max_score = max(top_classification['scores']) - max_idx = top_classification['scores'].index(max_score) - if 'labels' in top_classification and max_idx < len(top_classification['labels']): - species_name = top_classification['labels'][max_idx] - print(f" - Top species: {species_name} ({max_score:.3f})") - else: - print(f" - Classification: {top_classification.get('classification', 'unknown')}") - - if 
len(detections) > 5: - print(f" ... and {len(detections) - 5} more detections") - else: - print("⚠️ No detections found in the image") - - return True - - except Exception as e: - print(f"❌ Image processing request failed: {str(e)}") - if hasattr(e, 'response') and e.response is not None: - try: - error_details = e.response.json() - print(f" Error details: {json.dumps(error_details, indent=2)}") - except: - print(f" Error response: {e.response.text}") - return False - - -def test_service_health(): - """Test basic service health endpoints.""" - print("\n🏥 Testing service health endpoints...") - - base_url = "http://ml_backend_example:2000" - - # Test health endpoints - health_endpoints = ["/", "/livez", "/readyz"] - - for endpoint in health_endpoints: - try: - response = requests.get(f"{base_url}{endpoint}", timeout=10) - response.raise_for_status() - print(f"✅ {endpoint}: {response.status_code}") - except Exception as e: - print(f"❌ {endpoint}: {str(e)}") - return False - - return True - - -if __name__ == "__main__": - print("🧪 Starting API Integration Tests for Global Moth Classifier") - print("=" * 60) - - # Test service health first - health_ok = test_service_health() - if not health_ok: - print("\n❌ Service health checks failed!") - sys.exit(1) - - # Test main API integration - api_ok = test_api_integration() - - print("\n" + "=" * 60) - if api_ok: - print("🎉 All API integration tests PASSED!") - sys.exit(0) - else: - print("❌ API integration tests FAILED!") - sys.exit(1) \ No newline at end of file diff --git a/processing_services/moths/api/test_compilation_logging.py b/processing_services/moths/api/test_compilation_logging.py deleted file mode 100644 index b756fa286..000000000 --- a/processing_services/moths/api/test_compilation_logging.py +++ /dev/null @@ -1,38 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test to trigger model compilation and see enhanced logging. -""" - -import pathlib -import sys - -# Add the processing_services/example to the path -sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) - -from api.global_moth_classifier import GlobalMothClassifier - - -def test_compilation_logging(): - """Test the enhanced logging during model compilation.""" - print("🔧 Testing enhanced compilation logging...") - - # Create classifier instance - classifier = GlobalMothClassifier() - - print(f"📋 Classifier instantiated: {classifier.name}") - print(f" Expected classes: {classifier.num_classes}") - - # Trigger compilation (this should show our enhanced logging) - print("\n⚡ Triggering compilation...") - classifier.compile() - - print("\n✅ Compilation complete!") - print(f" Model loaded: {classifier.model is not None}") - print(f" Transforms ready: {classifier.transforms is not None}") - print(f" Categories loaded: {len(classifier.category_map)} species") - - return True - - -if __name__ == "__main__": - test_compilation_logging() \ No newline at end of file diff --git a/processing_services/moths/api/test_global_moth_pipeline.py b/processing_services/moths/api/test_global_moth_pipeline.py deleted file mode 100644 index 1886e5823..000000000 --- a/processing_services/moths/api/test_global_moth_pipeline.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the Global Moth Classifier Pipeline. -This test processes a real moth image and validates the full pipeline functionality. 
-""" - -import pathlib -import sys - -# Add the processing_services/example to the path -sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) - -from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline -from api.schemas import SourceImage -from api.utils import get_image - - -def test_global_moth_pipeline(): - """Test the Global Moth Classifier Pipeline with a real request.""" - print("🧪 Testing Global Moth Classifier Pipeline with real request...") - - # Create source image from the provided URL - source_image = SourceImage( - id="123", - url="https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg", - width=800, # Typical image dimensions - height=600 - ) - - # Load the PIL image and attach it to the source image - print("📥 Loading image from URL...") - pil_image = get_image(url=source_image.url) - if pil_image: - source_image._pil = pil_image - # Update dimensions with actual image size - source_image.width = pil_image.width - source_image.height = pil_image.height - print(f"✅ Image loaded: {pil_image.width}x{pil_image.height}") - else: - print("❌ Failed to load image") - return False - - # Create pipeline with the test configuration - pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( - source_images=[source_image], - request_config={ - "auth_token": "abc123", - "force_reprocess": True, - "candidate_labels": ["moth", "butterfly", "insect"] # Add candidate labels for detection - }, - existing_detections=[], - ) - - print("✅ Pipeline instantiated successfully!") - print(f" Pipeline name: {pipeline.config.name}") - print(f" Pipeline slug: {pipeline.config.slug}") - print(f" Number of algorithms: {len(pipeline.config.algorithms)}") - print(f" Algorithm 1: {pipeline.config.algorithms[0].name}") - print(f" Algorithm 2: {pipeline.config.algorithms[1].name}") - - # Test that stages can be created - stages = pipeline.get_stages() - assert len(stages) == 2 - print(f" Stages created: {len(stages)}") - - # Compile the pipeline (load models) - print("🔧 Compiling pipeline (loading models)...") - pipeline.compile() - print("✅ Pipeline compiled successfully!") - - # Run the pipeline - print("🚀 Running pipeline on test image...") - try: - result = pipeline.run() - print("✅ Pipeline execution completed!") - print(f" Total processing time: {result.total_time:.2f}s") - print(f" Number of detections: {len(result.detections)}") - - # Print detection details - for i, detection in enumerate(result.detections): - print(f" Detection {i+1}:") - print(f" - Bbox: {detection.bbox}") - print(f" - Inference time: {detection.inference_time:.3f}s") - print(f" - Algorithm: {detection.algorithm}") - if detection.classifications: - # Get the classification with the highest score - top_classification = detection.classifications[0] # Usually sorted by confidence - if top_classification.scores: - max_score = max(top_classification.scores) - max_idx = top_classification.scores.index(max_score) - if top_classification.labels and max_idx < len(top_classification.labels): - species_name = top_classification.labels[max_idx] - print(f" - Top classification: {species_name} ({max_score:.3f})") - else: - print(f" - Classification: {top_classification.classification}") - else: - print(f" - Classification: {top_classification.classification}") - - return True - - except Exception as e: - print(f"❌ Pipeline execution failed: {str(e)}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - test_global_moth_pipeline() diff --git 
a/processing_services/moths/requirements.txt b/processing_services/moths/requirements.txt deleted file mode 100644 index 4cf7a91d9..000000000 --- a/processing_services/moths/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -fastapi==0.116.0 -uvicorn==0.35.0 -pydantic==2.11.7 -Pillow==11.3.0 -requests==2.32.4 -transformers==4.50.3 -torch==2.6.0 -torchvision==0.21.0 -scipy==1.16.0 -timm diff --git a/processing_services/moths/test_api_integration.py b/processing_services/moths/test_api_integration.py deleted file mode 100644 index 492a549b2..000000000 --- a/processing_services/moths/test_api_integration.py +++ /dev/null @@ -1,173 +0,0 @@ -#!/usr/bin/env python3 -""" -API Integration Test for Global Moth Classifier Pipeline. -This test calls the actual HTTP API endpoints to validate the service. -""" - -import json -import pathlib -import sys -import time - -import requests - -# Add the processing_services/example to the path -sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) - - -def test_api_integration(): - """Test the Global Moth Classifier Pipeline via HTTP API.""" - print("🌐 Testing API Integration for Global Moth Classifier...") - - base_url = "http://ml_backend_example:2000" - - # Test 1: Get service info - print("\n📋 Test 1: Getting service info...") - try: - response = requests.get(f"{base_url}/info", timeout=30) - response.raise_for_status() - info = response.json() - - print("✅ Service info retrieved successfully!") - print(f" Service name: {info.get('name', 'Unknown')}") - print(f" Version: {info.get('version', 'Unknown')}") - print(f" Available pipelines: {len(info.get('pipelines', []))}") - - # Check if our pipeline is available - pipeline_slugs = [p.get('slug') for p in info.get('pipelines', [])] - expected_slug = "zero-shot-object-detector-with-global-moth-classifier-pipeline" - - if expected_slug in pipeline_slugs: - print("✅ Global Moth Classifier pipeline found in service!") - else: - print("❌ Global Moth Classifier pipeline NOT found in service") - print(f" Available pipelines: {pipeline_slugs}") - return False - - except Exception as e: - print(f"❌ Service info request failed: {str(e)}") - return False - - # Test 2: Process image with Global Moth Classifier - print("\n🦋 Test 2: Processing moth image...") - - request_payload = { - "config": { - "auth_token": "test123", - "force_reprocess": True, - "candidate_labels": ["moth", "butterfly", "insect"] - }, - "pipeline": "zero-shot-object-detector-with-global-moth-classifier-pipeline", - "source_images": [ - { - "id": "api_test_123", - "url": "https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg" - } - ] - } - - try: - print("📤 Sending processing request...") - print(f" Pipeline: {request_payload['pipeline']}") - print(f" Image URL: {request_payload['source_images'][0]['url']}") - - start_time = time.time() - response = requests.post( - f"{base_url}/process", - json=request_payload, - timeout=300 # 5 minutes timeout for processing - ) - end_time = time.time() - - response.raise_for_status() - result = response.json() - - processing_time = end_time - start_time - print("✅ Image processed successfully!") - print(f" API response time: {processing_time:.2f}s") - print(f" Pipeline processing time: {result.get('total_time', 'unknown')}s") - print(f" Number of detections: {len(result.get('detections', []))}") - - # Analyze results - detections = result.get('detections', []) - if detections: - print("\n🔍 Detection Results:") - for i, detection in enumerate(detections[:5]): # Show first 5 - bbox = 
detection.get('bbox', {}) - classifications = detection.get('classifications', []) - - print(f" Detection {i+1}:") - print(f" - Bbox: {bbox}") - print(f" - Algorithm: {detection.get('algorithm', {}).get('name', 'unknown')}") - - if classifications: - # Find top classification - top_classification = classifications[0] - if 'scores' in top_classification and top_classification['scores']: - max_score = max(top_classification['scores']) - max_idx = top_classification['scores'].index(max_score) - if 'labels' in top_classification and max_idx < len(top_classification['labels']): - species_name = top_classification['labels'][max_idx] - print(f" - Top species: {species_name} ({max_score:.3f})") - else: - print(f" - Classification: {top_classification.get('classification', 'unknown')}") - - if len(detections) > 5: - print(f" ... and {len(detections) - 5} more detections") - else: - print("⚠️ No detections found in the image") - - return True - - except Exception as e: - print(f"❌ Image processing request failed: {str(e)}") - if hasattr(e, 'response') and e.response is not None: - try: - error_details = e.response.json() - print(f" Error details: {json.dumps(error_details, indent=2)}") - except: - print(f" Error response: {e.response.text}") - return False - - -def test_service_health(): - """Test basic service health endpoints.""" - print("\n🏥 Testing service health endpoints...") - - base_url = "http://ml_backend_example:2000" - - # Test health endpoints - health_endpoints = ["/", "/livez", "/readyz"] - - for endpoint in health_endpoints: - try: - response = requests.get(f"{base_url}{endpoint}", timeout=10) - response.raise_for_status() - print(f"✅ {endpoint}: {response.status_code}") - except Exception as e: - print(f"❌ {endpoint}: {str(e)}") - return False - - return True - - -if __name__ == "__main__": - print("🧪 Starting API Integration Tests for Global Moth Classifier") - print("=" * 60) - - # Test service health first - health_ok = test_service_health() - if not health_ok: - print("\n❌ Service health checks failed!") - sys.exit(1) - - # Test main API integration - api_ok = test_api_integration() - - print("\n" + "=" * 60) - if api_ok: - print("🎉 All API integration tests PASSED!") - sys.exit(0) - else: - print("❌ API integration tests FAILED!") - sys.exit(1) \ No newline at end of file diff --git a/processing_services/moths/test_global_moth_pipeline.py b/processing_services/moths/test_global_moth_pipeline.py deleted file mode 100644 index 1886e5823..000000000 --- a/processing_services/moths/test_global_moth_pipeline.py +++ /dev/null @@ -1,109 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for the Global Moth Classifier Pipeline. -This test processes a real moth image and validates the full pipeline functionality. 
-""" - -import pathlib -import sys - -# Add the processing_services/example to the path -sys.path.insert(0, str(pathlib.Path(__file__).parent.parent)) - -from api.pipelines import ZeroShotObjectDetectorWithGlobalMothClassifierPipeline -from api.schemas import SourceImage -from api.utils import get_image - - -def test_global_moth_pipeline(): - """Test the Global Moth Classifier Pipeline with a real request.""" - print("🧪 Testing Global Moth Classifier Pipeline with real request...") - - # Create source image from the provided URL - source_image = SourceImage( - id="123", - url="https://archive.org/download/mma_various_moths_and_butterflies_54143/54143.jpg", - width=800, # Typical image dimensions - height=600 - ) - - # Load the PIL image and attach it to the source image - print("📥 Loading image from URL...") - pil_image = get_image(url=source_image.url) - if pil_image: - source_image._pil = pil_image - # Update dimensions with actual image size - source_image.width = pil_image.width - source_image.height = pil_image.height - print(f"✅ Image loaded: {pil_image.width}x{pil_image.height}") - else: - print("❌ Failed to load image") - return False - - # Create pipeline with the test configuration - pipeline = ZeroShotObjectDetectorWithGlobalMothClassifierPipeline( - source_images=[source_image], - request_config={ - "auth_token": "abc123", - "force_reprocess": True, - "candidate_labels": ["moth", "butterfly", "insect"] # Add candidate labels for detection - }, - existing_detections=[], - ) - - print("✅ Pipeline instantiated successfully!") - print(f" Pipeline name: {pipeline.config.name}") - print(f" Pipeline slug: {pipeline.config.slug}") - print(f" Number of algorithms: {len(pipeline.config.algorithms)}") - print(f" Algorithm 1: {pipeline.config.algorithms[0].name}") - print(f" Algorithm 2: {pipeline.config.algorithms[1].name}") - - # Test that stages can be created - stages = pipeline.get_stages() - assert len(stages) == 2 - print(f" Stages created: {len(stages)}") - - # Compile the pipeline (load models) - print("🔧 Compiling pipeline (loading models)...") - pipeline.compile() - print("✅ Pipeline compiled successfully!") - - # Run the pipeline - print("🚀 Running pipeline on test image...") - try: - result = pipeline.run() - print("✅ Pipeline execution completed!") - print(f" Total processing time: {result.total_time:.2f}s") - print(f" Number of detections: {len(result.detections)}") - - # Print detection details - for i, detection in enumerate(result.detections): - print(f" Detection {i+1}:") - print(f" - Bbox: {detection.bbox}") - print(f" - Inference time: {detection.inference_time:.3f}s") - print(f" - Algorithm: {detection.algorithm}") - if detection.classifications: - # Get the classification with the highest score - top_classification = detection.classifications[0] # Usually sorted by confidence - if top_classification.scores: - max_score = max(top_classification.scores) - max_idx = top_classification.scores.index(max_score) - if top_classification.labels and max_idx < len(top_classification.labels): - species_name = top_classification.labels[max_idx] - print(f" - Top classification: {species_name} ({max_score:.3f})") - else: - print(f" - Classification: {top_classification.classification}") - else: - print(f" - Classification: {top_classification.classification}") - - return True - - except Exception as e: - print(f"❌ Pipeline execution failed: {str(e)}") - import traceback - traceback.print_exc() - return False - - -if __name__ == "__main__": - test_global_moth_pipeline() From 
9bd828b49c45bffd510ec10ec93337b91fc3e047 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 13 Oct 2025 15:39:54 -0700 Subject: [PATCH 3/3] fix: prefer black formatter's whitespace style --- processing_services/example/api/global_moth_classifier.py | 2 +- setup.cfg | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/processing_services/example/api/global_moth_classifier.py b/processing_services/example/api/global_moth_classifier.py index 586cc9cfd..f878cf764 100644 --- a/processing_services/example/api/global_moth_classifier.py +++ b/processing_services/example/api/global_moth_classifier.py @@ -153,7 +153,7 @@ def run(self, detections: list[Detection]) -> list[Detection]: classified_detections = [] for i in range(0, len(detections), self.batch_size): - batch_detections = detections[i: i + self.batch_size] + batch_detections = detections[i : i + self.batch_size] batch_images = [] # Prepare batch of images diff --git a/setup.cfg b/setup.cfg index 829064213..d50c5c183 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,6 +4,8 @@ [flake8] max-line-length = 119 exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv,.venv +# E203: whitespace before ':' (conflicts with Black's slice formatting) +extend-ignore = E203, W503 [pycodestyle] max-line-length = 119
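[Reviewer note on PATCH 3/3] For anyone unfamiliar with the E203 conflict: Black intentionally puts spaces around the slice colon when either bound is an expression, which stock flake8 reports as E203 (and W503 similarly fires on Black's line breaks before binary operators). A small illustration of the slice style this patch standardizes on; the variable names are made up:

```python
# Black formats slices with expression bounds as `seq[i : i + n]`; flake8
# flags the space before ':' as E203 unless it is ignored, as in the
# setup.cfg change above.
detections = list(range(10))
i, batch_size = 2, 4
batch = detections[i : i + batch_size]  # Black-style slice spacing
assert batch == [2, 3, 4, 5]
```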