Implement class masking using the post-processing framework #999
base: main
Conversation
…en creating a terminal classification with the rolled up taxon
…gress, and algorithm binding
…f.logger and progress updates
…g and progress tracking
…ss-masking branch)
✅ Deploy Preview for antenna-preview canceled.
0b77504 to 88ffba8 (Compare)
Pull Request Overview
Implements class masking as a post-processing task that recalculates classifications by masking out classes not present in a provided taxa list and updates occurrences accordingly.
- Adds ClassMaskingTask to the post-processing framework and registers it.
- Filters and recalculates logits/scores per taxa list, creates new terminal classifications, and updates occurrences.
- Minor logging update in job runner.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| ami/ml/post_processing/class_masking.py | New class masking task and supporting functions to filter classifications by taxa list and recompute softmax. |
| ami/ml/post_processing/__init__.py | Registers the new class_masking task module. |
| ami/jobs/models.py | Improves log line to print only the task config for post-processing. |
top_index = scores.index(max(scores))
top_taxon = category_map_with_taxa[top_index][
    "taxon"
]  # @TODO: This doesn't work if the taxon has never been classified
print("Top taxon: ", category_map_with_taxa[top_index])  # @TODO: REMOVE
print("Top index: ", top_index)  # @TODO: REMOVE
Copilot AI, Oct 15, 2025
Argmax is computed across all categories, so an excluded class can still be selected as the top taxon. If all categories are excluded, the current approach will select an arbitrary class. Restrict the selection to indices whose taxa are in taxa_in_list and handle the 'all-excluded' case gracefully (skip creating a new classification or mark appropriately). For example:
- Build allowed_indices = [i for i, c in enumerate(category_map_with_taxa) if c['taxon'] in taxa_in_list]
- Mask logits for non-allowed indices with -np.inf, recompute softmax over the allowed set, and if allowed_indices is empty, skip this classification (see the sketch below).
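A minimal sketch of that masking approach, assuming the logits, category_map_with_taxa, and taxa_in_list names used in class_masking.py; the helper name recompute_masked_scores is illustrative, not part of the PR:

```python
import numpy as np


def recompute_masked_scores(logits, category_map_with_taxa, taxa_in_list):
    """Recompute softmax scores with classes outside the taxa list masked out."""
    allowed_indices = [
        i
        for i, entry in enumerate(category_map_with_taxa)
        if entry["taxon"] in taxa_in_list
    ]
    if not allowed_indices:
        # Every class is masked out: signal the caller to skip this classification.
        return None, None

    masked_logits = np.full(len(logits), -np.inf)
    masked_logits[allowed_indices] = np.asarray(logits, dtype=float)[allowed_indices]

    # Softmax over the masked logits; excluded classes end up with probability 0.
    exp_logits = np.exp(masked_logits - masked_logits.max())
    scores = exp_logits / exp_logits.sum()

    # Argmax can now only land on an allowed index.
    top_index = int(np.argmax(scores))
    return scores.tolist(), top_index
```

Because the excluded classes get probability zero, the returned top_index can feed straight into the existing top-taxon lookup against category_map_with_taxa.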
logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")

if not classifications:
Copilot AI, Oct 15, 2025
len(classifications) forces the entire QuerySet to be evaluated just to log a number, and "if not classifications" is another potentially expensive check on the QuerySet. Call count() once and reuse it for both the log line and the zero check (or use exists() if you don't need the exact number), e.g., count = classifications.count(); if count == 0: ...
Suggested change:
- logger.info(f"Found {len(classifications)} terminal classifications with scores to update.")
- if not classifications:
+ count = classifications.count()
+ logger.info(f"Found {count} terminal classifications with scores to update.")
+ if count == 0:
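For callers that only need to know whether anything matched, the exists() variant mentioned in the comment issues a lighter query; a sketch reusing the classifications and logger names from this snippet:

```python
# exists() asks the database a yes/no question instead of fetching or counting rows.
if not classifications.exists():
    logger.info("No terminal classifications with scores to update; skipping.")
    return
```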
scores, logits = classification.scores, classification.logits
# Set scores and logits to zero if they are not in the filtered category indices

import numpy as np
Copilot AI, Oct 15, 2025
Importing inside the processing loop adds overhead each iteration. Move these imports to the module top and prefer using np.exp and np.sum directly for consistency, e.g., import numpy as np at the top and use np.exp / np.sum.
from numpy import exp
from numpy import sum as np_sum
Copilot AI, Oct 15, 2025
Importing inside the processing loop adds overhead each iteration. Move these imports to the module top and prefer using np.exp and np.sum directly for consistency, e.g., import numpy as np at the top and use np.exp / np.sum.
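A small sketch of the module-level import both comments ask for; the softmax helper name is illustrative, not part of the PR:

```python
# Top of ami/ml/post_processing/class_masking.py (sketch), instead of importing
# numpy inside the per-classification loop.
import numpy as np


def softmax(logits):
    """Numerically stable softmax over raw logits."""
    arr = np.asarray(logits, dtype=float)
    exp_logits = np.exp(arr - arr.max())  # subtract the max to avoid overflow
    return exp_logits / exp_logits.sum()
```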
| "taxon" | ||
| ] # @TODO: This doesn't work if the taxon has never been classified | ||
| print("Top taxon: ", category_map_with_taxa[top_index]) # @TODO: REMOVE | ||
| print("Top index: ", top_index) # @TODO: REMOVE |
Copilot AI, Oct 15, 2025
Avoid print statements in production code; use logger.debug(...) to keep logs consistent and configurable.
| print("Top index: ", top_index) # @TODO: REMOVE | |
| logger.debug(f"Top index: {top_index}") |
assert new_classification.detection.occurrence is not None
occurrences_to_update.add(new_classification.detection.occurrence)

logging.info(
Copilot AI, Oct 15, 2025
This uses the root logging module instead of the module logger or the provided task_logger, making log output inconsistent. Replace with logger.info(...) or task_logger.info(...).
Suggested change:
- logging.info(
+ task_logger.info(
# Get the classifications for the occurrence in the collection
classifications = Classification.objects.filter(
    detection__occurrence=occurrence,
    terminal=True,
    algorithm=algorithm,
    scores__isnull=False,
).distinct()
Copilot AI, Oct 15, 2025
You validate that logits is a list later and raise if not, but the query doesn't exclude classifications with null logits. Add logits__isnull=False to avoid unnecessary processing failures.
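A sketch of the same filter with the guard this comment suggests; model and field names are taken from the hunk above:

```python
classifications = Classification.objects.filter(
    detection__occurrence=occurrence,
    terminal=True,
    algorithm=algorithm,
    scores__isnull=False,
    logits__isnull=False,  # skip rows that would otherwise fail the later logits check
).distinct()
```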
    terminal=True,
    # algorithm__task_type="classification",
    algorithm=algorithm,
    scores__isnull=False,
Copilot AI, Oct 15, 2025
Mirror the logits presence guard here as well to avoid raising later when logits is missing: add logits__isnull=False to the filter.
Suggested change:
- scores__isnull=False,
+ scores__isnull=False,
+ logits__isnull=False,
    updated_at=timestamp,
)
if new_classification.taxon is None:
    raise (ValueError("Classification isn't registered yet. Aborting"))  # @TODO remove or fail gracefully
Copilot AI, Oct 15, 2025
The error message is unclear for the actual failure mode. Clarify to something actionable, e.g., raise ValueError('Unable to determine top taxon after class masking (no allowed classes). Aborting.').
Suggested change:
- raise (ValueError("Classification isn't registered yet. Aborting"))  # @TODO remove or fail gracefully
+ raise ValueError("Unable to determine top taxon after class masking (no allowed classes). Aborting.")
if classifications_to_update:
    logger.info(f"Bulk updating {len(classifications_to_update)} existing classifications")
    Classification.objects.bulk_update(classifications_to_update, ["terminal", "updated_at"])
    logger.info(f"Updated {len(classifications_to_update)} existing classifications")

if classifications_to_add:
    # Bulk create the new classifications
    logger.info(f"Bulk creating {len(classifications_to_add)} new classifications")
    Classification.objects.bulk_create(classifications_to_add)
    logger.info(f"Added {len(classifications_to_add)} new classifications")
Copilot AI, Oct 15, 2025
Consider wrapping the bulk_update and bulk_create in a single transaction to keep updates atomic and avoid partial state if an error occurs later (e.g., during occurrence updates). For example: with transaction.atomic(): ... bulk_update ... bulk_create ....
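A hedged sketch of that wrapping, assuming Django's transaction.atomic() and the variable names from the hunks above:

```python
from django.db import transaction

with transaction.atomic():
    if classifications_to_update:
        Classification.objects.bulk_update(
            classifications_to_update, ["terminal", "updated_at"]
        )
    if classifications_to_add:
        Classification.objects.bulk_create(classifications_to_add)

    # Keeping the determination updates inside the same block means a failure
    # here rolls back the classification changes as well.
    for occurrence in occurrences_to_update:
        occurrence.save(update_determination=True)
```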
# Update the occurrence determinations
logger.info(f"Updating the determinations for {len(occurrences_to_update)} occurrences")
for occurrence in occurrences_to_update:
    occurrence.save(update_determination=True)
@mohamedelabbas1996 here is how I updated all of the determinations previously
Summary
This PR implements Class Masking as part of the post-processing framework.
List of Changes
TBD
Related Issues
TBD
Detailed Description
TBD
How to Test the Changes
TBD
Screenshots
TBD
Deployment Notes
TBD
Checklist