diff --git a/processing_services/example/.gitignore b/processing_services/example/.gitignore new file mode 100644 index 000000000..95312cdaf --- /dev/null +++ b/processing_services/example/.gitignore @@ -0,0 +1,4 @@ +# Cache directories for models and dependencies +cache/ +huggingface_cache/ +pytorch_cache/ diff --git a/processing_services/example/api/api.py b/processing_services/example/api/api.py index 79ce5d83c..43a6205db 100644 --- a/processing_services/example/api/api.py +++ b/processing_services/example/api/api.py @@ -11,6 +11,7 @@ ZeroShotHFClassifierPipeline, ZeroShotObjectDetectorPipeline, ZeroShotObjectDetectorWithConstantClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, ) from .schemas import ( @@ -41,6 +42,7 @@ ZeroShotObjectDetectorPipeline, ZeroShotObjectDetectorWithConstantClassifierPipeline, ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline, + ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, ] pipeline_choices: dict[str, type[Pipeline]] = {pipeline.config.slug: pipeline for pipeline in pipelines} algorithm_choices: dict[str, AlgorithmConfigResponse] = { diff --git a/processing_services/example/api/base.py b/processing_services/example/api/base.py new file mode 100644 index 000000000..f8e781dc5 --- /dev/null +++ b/processing_services/example/api/base.py @@ -0,0 +1,237 @@ +""" +Simplified base classes for inference models without database dependencies. +Adapted from trapdata.ml.models.base but streamlined for processing service use. +""" + +import json +import logging +from typing import Any + +import torch +import torchvision.transforms + +from .utils import get_best_device, get_or_download_file + +logger = logging.getLogger(__name__) + + +# Standard normalization transforms +imagenet_normalization = torchvision.transforms.Normalize( + mean=[0.485, 0.456, 0.406], # RGB + std=[0.229, 0.224, 0.225], # RGB +) + +tensorflow_normalization = torchvision.transforms.Normalize( + mean=[0.5, 0.5, 0.5], # RGB + std=[0.5, 0.5, 0.5], # RGB +) + + +class SimplifiedInferenceBase: + """ + Simplified base class for inference models without database or queue dependencies. + """ + + name: str = "Unknown Inference Model" + description: str = "" + weights_path: str | None = None + labels_path: str | None = None + category_map: dict[int, str] = {} + num_classes: int | None = None + default_taxon_rank: str = "SPECIES" + normalization = tensorflow_normalization + batch_size: int = 4 + device: str | None = None + + def __init__(self, **kwargs): + # Override any class attributes with provided kwargs + for k, v in kwargs.items(): + setattr(self, k, v) + + logger.info(f"Initializing simplified inference class {self.name}") + + self.device = self.device or get_best_device() + self.category_map = self.get_labels(self.labels_path) + self.num_classes = self.num_classes or len(self.category_map) + self.weights = self.get_weights(self.weights_path) + self.transforms = self.get_transforms() + + logger.info(f"Loading model for {self.name} with {len(self.category_map or [])} categories") + self.model = self.get_model() + + @classmethod + def get_key(cls) -> str: + """Generate a unique key for this algorithm.""" + if hasattr(cls, "key") and cls.key: + return cls.key + else: + return cls.name.lower().replace(" ", "-").replace("/", "-") + + def get_weights(self, weights_path: str | None) -> str | None: + """Download and cache model weights.""" + if weights_path: + logger.info(f"⬇️ Downloading model weights from: {weights_path}") + weights_file = str(get_or_download_file(weights_path, tempdir_prefix="models")) + logger.info(f"✅ Model weights downloaded to: {weights_file}") + return weights_file + else: + logger.warning(f"No weights specified for model {self.name}") + return None + + def get_labels(self, labels_path: str | None) -> dict[int, str]: + """Download and load category labels.""" + if not labels_path: + return {} + + logger.info(f"⬇️ Downloading category labels from: {labels_path}") + local_path = get_or_download_file(labels_path, tempdir_prefix="models") + logger.info(f"📝 Loading category labels from: {local_path}") + + with open(local_path) as f: + labels = json.load(f) + + # Convert label->index mapping to index->label mapping + index_to_label = {index: label for label, index in labels.items()} + logger.info(f"✅ Loaded {len(index_to_label)} category labels") + return index_to_label + + def get_transforms(self) -> torchvision.transforms.Compose: + """Get image preprocessing transforms.""" + return torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((224, 224)), + torchvision.transforms.ToTensor(), + self.normalization, + ] + ) + + def get_model(self) -> torch.nn.Module: + """ + Load and return the PyTorch model. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement get_model()") + + def predict_batch(self, batch: torch.Tensor) -> torch.Tensor: + """ + Run inference on a batch of images. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement predict_batch()") + + def post_process_batch(self, logits: torch.Tensor) -> Any: + """ + Post-process model outputs. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement post_process_batch()") + + +class ResNet50Base(SimplifiedInferenceBase): + """ + Base class for ResNet50-based models. + """ + + input_size: int = 224 + normalization = imagenet_normalization + + def get_transforms(self) -> torchvision.transforms.Compose: + """Get ResNet50-specific transforms.""" + return torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((self.input_size, self.input_size)), + torchvision.transforms.ToTensor(), + self.normalization, + ] + ) + + def get_model(self) -> torch.nn.Module: + """Load ResNet50 model with custom classifier.""" + import torchvision.models as models + + logger.info("🏗️ Creating ResNet50 model architecture...") + # Create ResNet50 backbone + model = models.resnet50(weights=None) + + # Replace final classifier layer + if self.num_classes is None: + raise ValueError("num_classes must be set before loading model") + logger.info(f"🔧 Setting up classifier layer for {self.num_classes} classes...") + model.fc = torch.nn.Linear(model.fc.in_features, self.num_classes) + + # Load pretrained weights + if self.weights: + logger.info(f"📂 Loading pretrained weights from: {self.weights}") + checkpoint = torch.load(self.weights, map_location=self.device) + + # Handle different checkpoint formats + if "model_state_dict" in checkpoint: + logger.info("📥 Loading state dict from 'model_state_dict' key...") + model.load_state_dict(checkpoint["model_state_dict"]) + elif "state_dict" in checkpoint: + logger.info("📥 Loading state dict from 'state_dict' key...") + model.load_state_dict(checkpoint["state_dict"]) + else: + logger.info("📥 Loading state dict directly...") + model.load_state_dict(checkpoint) + logger.info("✅ Model weights loaded successfully!") + else: + logger.warning("⚠️ No pretrained weights provided - using random initialization") + + logger.info(f"📱 Moving model to device: {self.device}") + model = model.to(self.device) + model.eval() + logger.info("✅ Model ready for inference!") + return model + + def predict_batch(self, batch: torch.Tensor) -> torch.Tensor: + """Run inference on batch.""" + with torch.no_grad(): + batch = batch.to(self.device) + outputs = self.model(batch) + return outputs + + def post_process_batch(self, logits: torch.Tensor) -> list: + """Convert logits to predictions.""" + probabilities = torch.softmax(logits, dim=1) + predictions = [] + + for prob_tensor in probabilities: + prob_list = prob_tensor.cpu().numpy().tolist() + predictions.append( + { + "scores": prob_list, + "logits": logits[len(predictions)].cpu().numpy().tolist(), + } + ) + + return predictions + + +class TimmResNet50Base(ResNet50Base): + """ + Base class for timm ResNet50-based models. + """ + + def get_model(self) -> torch.nn.Module: + """Load timm ResNet50 model.""" + import timm + + # Create timm ResNet50 model + model = timm.create_model("resnet50", pretrained=False, num_classes=self.num_classes) + + # Load pretrained weights + if self.weights: + checkpoint = torch.load(self.weights, map_location=self.device) + + # Handle different checkpoint formats + if "model_state_dict" in checkpoint: + model.load_state_dict(checkpoint["model_state_dict"]) + elif "state_dict" in checkpoint: + model.load_state_dict(checkpoint["state_dict"]) + else: + model.load_state_dict(checkpoint) + + model = model.to(self.device) + model.eval() + return model diff --git a/processing_services/example/api/global_moth_classifier.py b/processing_services/example/api/global_moth_classifier.py new file mode 100644 index 000000000..f878cf764 --- /dev/null +++ b/processing_services/example/api/global_moth_classifier.py @@ -0,0 +1,249 @@ +""" +Global Moth Classifier algorithm implementation. +Simplified version of trapdata.api.models.classification.MothClassifierGlobal +adapted for the processing service framework. +""" + +import datetime +import logging + +import torch +import torchvision.transforms + +from .algorithms import Algorithm +from .base import TimmResNet50Base, imagenet_normalization +from .schemas import ( + AlgorithmCategoryMapResponse, + AlgorithmConfigResponse, + AlgorithmReference, + ClassificationResponse, + Detection, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class GlobalMothClassifier(Algorithm, TimmResNet50Base): + """ + Global Moth Species Classifier. + + Simplified version of the trapdata GlobalMothSpeciesClassifier + that works without database dependencies. + """ + + name = "Global Species Classifier - Aug 2024" + description = ( + "Trained on August 28th, 2024 for 29,176 species. " + "https://wandb.ai/moth-ai/global-moth-classifier/runs/h0cuqrbc/overview" + ) + weights_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" + "global_resnet50_20240828_b06d3b3a.pth" + ) + labels_path = ( + "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" + "global_category_map_with_names_20240828.json" + ) + + # Model configuration + input_size = 128 + normalization = imagenet_normalization + default_taxon_rank = "SPECIES" + batch_size = 4 + + def __init__(self, **kwargs): + """Initialize the global moth classifier.""" + # Initialize Algorithm parent class + Algorithm.__init__(self) + + # Store kwargs for later use in compile() + self._init_kwargs = kwargs + + # Initialize basic attributes without loading model + self.model = None + self.transforms = None + self.category_map = {} # Empty dict for now + self.num_classes = 29176 # Known number of classes + + logger.info(f"Initialized {self.name} (model loading deferred to compile())") + + @property + def algorithm_config_response(self) -> AlgorithmConfigResponse: + """Get algorithm configuration for API response.""" + if not hasattr(self, "_algorithm_config_response"): + # Create a basic config response before compilation + self._algorithm_config_response = AlgorithmConfigResponse( + name=self.name, + key=self.get_key(), + task_type="classification", + description=self.description, + version=1, + version_name="v1", + category_map=AlgorithmCategoryMapResponse( + data=[], + labels=[], + version="v1", + description="Global moth species classifier (not yet compiled)", + uri=self.labels_path, + ), + uri=self.weights_path, + ) + return self._algorithm_config_response + + @algorithm_config_response.setter + def algorithm_config_response(self, value: AlgorithmConfigResponse): + """Set algorithm configuration response.""" + self._algorithm_config_response = value + + def compile(self): + """Load model weights and initialize transforms (called by pipeline).""" + if self.model is not None: + logger.info("Model already compiled, skipping...") + return + + logger.info(f"🔧 Compiling {self.name}...") + logger.info(f" 📊 Expected classes: {self.num_classes}") + logger.info(f" 🏷️ Labels URL: {self.labels_path}") + logger.info(f" ⚖️ Weights URL: {self.weights_path}") + + # Initialize the TimmResNet50Base now (this will download weights/labels) + logger.info(" 📥 Downloading model weights and labels...") + TimmResNet50Base.__init__(self, **self._init_kwargs) + + # Set algorithm config response + logger.info(" 📋 Setting up algorithm configuration...") + self.algorithm_config_response = self.get_algorithm_config_response() + + logger.info(f"✅ {self.name} compiled successfully!") + logger.info(f" 📊 Loaded {len(self.category_map)} species categories") + logger.info(f" 🔧 Model device: {getattr(self, 'device', 'unknown')}") + logger.info(f" 🖼️ Input size: {self.input_size}x{self.input_size}") + + def get_transforms(self) -> torchvision.transforms.Compose: + """Get transforms specific to this model.""" + return torchvision.transforms.Compose( + [ + torchvision.transforms.Resize((self.input_size, self.input_size)), + torchvision.transforms.ToTensor(), + self.normalization, + ] + ) + + def run(self, detections: list[Detection]) -> list[Detection]: + """ + Run classification on a list of detections. + + Args: + detections: List of Detection objects with cropped images + + Returns: + List of Detection objects with added classifications + """ + if not detections: + return [] + + # Ensure model is compiled + if self.model is None: + raise RuntimeError("Model not compiled. Call compile() first.") + + logger.info(f"Running {self.name} on {len(detections)} detections") + + # Process detections in batches + classified_detections = [] + + for i in range(0, len(detections), self.batch_size): + batch_detections = detections[i : i + self.batch_size] + batch_images = [] + + # Prepare batch of images + for detection in batch_detections: + if detection._pil: + # Convert to RGB if needed + if detection._pil.mode != "RGB": + img = detection._pil.convert("RGB") + else: + img = detection._pil + batch_images.append(img) + else: + logger.warning(f"Detection {detection.id} has no PIL image") + continue + + if not batch_images: + continue + + # Transform images + if self.transforms is None: + raise RuntimeError("Transforms not initialized. Call compile() first.") + batch_tensor = torch.stack([self.transforms(img) for img in batch_images]) + + # Run inference + start_time = datetime.datetime.now() + predictions = self.predict_batch(batch_tensor) + processed_predictions = self.post_process_batch(predictions) + end_time = datetime.datetime.now() + + inference_time = (end_time - start_time).total_seconds() / len(batch_images) + + # Add classifications to detections + for detection, prediction in zip(batch_detections, processed_predictions): + # Get best prediction + best_score = max(prediction["scores"]) + best_idx = prediction["scores"].index(best_score) + best_label = self.category_map.get(best_idx, f"class_{best_idx}") + + classification = ClassificationResponse( + classification=best_label, + labels=[best_label], + scores=[best_score], + logits=prediction["logits"], + inference_time=inference_time, + timestamp=datetime.datetime.now(), + algorithm=AlgorithmReference( + name=self.name, + key=self.get_key(), + ), + terminal=True, + ) + + # Add classification to detection + detection_with_classification = detection.copy(deep=True) + detection_with_classification.classifications = [classification] + classified_detections.append(detection_with_classification) + + logger.info(f"Classified {len(classified_detections)} detections") + return classified_detections + + def get_category_map(self) -> AlgorithmCategoryMapResponse: + """Get category map for API response.""" + categories_sorted_by_index = sorted(self.category_map.items(), key=lambda x: x[0]) + categories_data = [ + { + "index": index, + "label": label, + "taxon_rank": self.default_taxon_rank, + } + for index, label in categories_sorted_by_index + ] + label_strings = [cat["label"] for cat in categories_data] + + return AlgorithmCategoryMapResponse( + data=categories_data, + labels=label_strings, + version="v1", + description=f"Global moth species classifier with {len(categories_data)} species", + uri=self.labels_path, + ) + + def get_algorithm_config_response(self) -> AlgorithmConfigResponse: + """Get algorithm configuration for API response.""" + return AlgorithmConfigResponse( + name=self.name, + key=self.get_key(), + task_type="classification", + description=self.description, + version=1, + version_name="v1", + category_map=self.get_category_map(), + uri=self.weights_path, + ) diff --git a/processing_services/example/api/pipelines.py b/processing_services/example/api/pipelines.py index 02b31d0d9..a983009e2 100644 --- a/processing_services/example/api/pipelines.py +++ b/processing_services/example/api/pipelines.py @@ -9,6 +9,7 @@ RandomSpeciesClassifier, ZeroShotObjectDetector, ) +from .global_moth_classifier import GlobalMothClassifier from .schemas import ( Detection, DetectionResponse, @@ -346,3 +347,64 @@ def run(self) -> PipelineResultsResponse: logger.info(f"Successfully processed {len(detections_with_classifications)} detections.") return pipeline_response + + +class ZeroShotObjectDetectorWithGlobalMothClassifierPipeline(Pipeline): + """ + A pipeline that uses the HuggingFace zero shot object detector and the global moth classifier. + This provides high-quality moth species identification with 29,176+ species support. + """ + + batch_sizes = [1, 4] # Detector batch=1, Classifier batch=4 + config = PipelineConfigResponse( + name="Zero Shot Object Detector With Global Moth Classifier Pipeline", + slug="zero-shot-object-detector-with-global-moth-classifier-pipeline", + description=( + "HF zero shot object detector with global moth species classifier. " + "Supports 29,176+ moth species trained on global data." + ), + version=1, + algorithms=[], # Will be populated in get_stages() + ) + + def get_stages(self) -> list[Algorithm]: + zero_shot_object_detector = ZeroShotObjectDetector() + if "candidate_labels" in self.request_config: + zero_shot_object_detector.candidate_labels = self.request_config["candidate_labels"] + + global_moth_classifier = GlobalMothClassifier() + + self.config.algorithms = [ + zero_shot_object_detector.algorithm_config_response, + global_moth_classifier.algorithm_config_response, + ] + + return [zero_shot_object_detector, global_moth_classifier] + + def run(self) -> PipelineResultsResponse: + start_time = datetime.datetime.now() + detections: list[Detection] = [] + + if self.existing_detections: + logger.info("[1/2] Skipping the localizer, use existing detections...") + detections = self.existing_detections + else: + logger.info("[1/2] No existing detections, generating detections...") + detections = self._get_detections(self.stages[0], self.source_images, self.batch_sizes[0]) + + logger.info("[2/2] Running the global moth classifier...") + detections_with_classifications: list[Detection] = self._get_detections( + self.stages[1], detections, self.batch_sizes[1] + ) + + end_time = datetime.datetime.now() + elapsed_time = (end_time - start_time).total_seconds() + + pipeline_response: PipelineResultsResponse = self._get_pipeline_response( + detections_with_classifications, elapsed_time + ) + logger.info( + f"Successfully processed {len(detections_with_classifications)} detections with global moth classifier." + ) + + return pipeline_response diff --git a/processing_services/example/api/schemas.py b/processing_services/example/api/schemas.py index 05682c6f1..e56e16c85 100644 --- a/processing_services/example/api/schemas.py +++ b/processing_services/example/api/schemas.py @@ -213,6 +213,7 @@ class Config: "zero-shot-object-detector-pipeline", "zero-shot-object-detector-with-constant-classifier-pipeline", "zero-shot-object-detector-with-random-species-classifier-pipeline", + "zero-shot-object-detector-with-global-moth-classifier-pipeline", ] diff --git a/processing_services/example/api/utils.py b/processing_services/example/api/utils.py index a7fcb6a75..47350e781 100644 --- a/processing_services/example/api/utils.py +++ b/processing_services/example/api/utils.py @@ -10,6 +10,7 @@ import PIL.Image import PIL.ImageFile import requests +import torch logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -56,31 +57,53 @@ def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path: # If path is a local path instead of a URL then urlretrieve will just return that path - destination_dir = pathlib.Path(tempfile.mkdtemp(prefix=tempdir_prefix)) + # Use persistent cache directory instead of temp + cache_root = ( + pathlib.Path("/app/cache") if pathlib.Path("/app").exists() else pathlib.Path.home() / ".antenna_cache" + ) + destination_dir = cache_root / tempdir_prefix fname = pathlib.Path(urlparse(path_or_url).path).name if not destination_dir.exists(): destination_dir.mkdir(parents=True, exist_ok=True) - local_filepath = pathlib.Path(destination_dir) / fname + local_filepath = destination_dir / fname if local_filepath and local_filepath.exists(): - logger.info(f"Using existing {local_filepath}") + logger.info(f"📁 Using cached file: {local_filepath}") return local_filepath else: - logger.info(f"Downloading {path_or_url} to {local_filepath}") + logger.info(f"⬇️ Downloading {path_or_url} to {local_filepath}") headers = {"User-Agent": USER_AGENT} response = requests.get(path_or_url, stream=True, headers=headers) response.raise_for_status() # Raise an exception for HTTP errors with open(local_filepath, "wb") as f: + total_size = int(response.headers.get("content-length", 0)) + downloaded = 0 for chunk in response.iter_content(chunk_size=8192): f.write(chunk) + downloaded += len(chunk) + if total_size > 0: + percent = (downloaded / total_size) * 100 + logger.info(f" Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)") resulting_filepath = pathlib.Path(local_filepath).resolve() + logger.info(f"✅ Download completed: {resulting_filepath}") logger.info(f"Downloaded to {resulting_filepath}") return resulting_filepath +def get_best_device() -> str: + """ + Returns the best available device for running the model. + MPS is not supported by the current algorithms. + """ + if torch.cuda.is_available(): + return f"cuda:{torch.cuda.current_device()}" + else: + return "cpu" + + def open_image(fp: str | bytes | pathlib.Path | io.BytesIO, raise_exception: bool = True) -> PIL.Image.Image | None: """ Wrapper from PIL.Image.open that handles errors and converts to RGB. diff --git a/processing_services/example/docker-compose.yml b/processing_services/example/docker-compose.yml index 83db6ccfa..2df8b49e9 100644 --- a/processing_services/example/docker-compose.yml +++ b/processing_services/example/docker-compose.yml @@ -4,6 +4,7 @@ services: context: . volumes: - ./:/app:z + - ./cache:/app/cache # Model weights and labels cache - ./huggingface_cache:/root/.cache/huggingface - ./pytorch_cache:/root/.cache/torch ports: diff --git a/processing_services/example/requirements.txt b/processing_services/example/requirements.txt index eccbee47a..251693690 100644 --- a/processing_services/example/requirements.txt +++ b/processing_services/example/requirements.txt @@ -6,4 +6,5 @@ requests==2.32.4 transformers==4.50.3 torch==2.6.0 torchvision==0.21.0 +timm==1.0.11 scipy==1.16.0 diff --git a/setup.cfg b/setup.cfg index 829064213..d50c5c183 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,6 +4,8 @@ [flake8] max-line-length = 119 exclude = .tox,.git,*/migrations/*,*/static/CACHE/*,docs,node_modules,venv,.venv +# E203: whitespace before ':' (conflicts with Black's slice formatting) +extend-ignore = E203, W503 [pycodestyle] max-line-length = 119