From 42aaf02b218c8cafb5a3279284cdc5ffe0f060ae Mon Sep 17 00:00:00 2001 From: benITo47 Date: Fri, 20 Mar 2026 10:08:10 +0100 Subject: [PATCH 1/5] feat: Add multimethod handling to ObjectDetection --- .../app/object_detection/index.tsx | 2 + .../app/vision_camera/index.tsx | 7 +- .../tasks/ObjectDetectionTask.tsx | 17 +- .../object_detection/ObjectDetection.cpp | 158 ++++++++++----- .../models/object_detection/ObjectDetection.h | 65 +++++- .../src/constants/modelUrls.ts | 47 +++++ .../computer_vision/useObjectDetection.ts | 12 +- .../computer_vision/ObjectDetectionModule.ts | 189 +++++++++++++++++- .../src/types/objectDetection.ts | 85 ++++++-- 9 files changed, 504 insertions(+), 78 deletions(-) diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index c6ec9f1dc3..399939cfcb 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -7,6 +7,7 @@ import { useObjectDetection, RF_DETR_NANO, SSDLITE_320_MOBILENET_V3_LARGE, + YOLO26N, ObjectDetectionModelSources, } from 'react-native-executorch'; import { View, StyleSheet, Image } from 'react-native'; @@ -18,6 +19,7 @@ import ScreenWrapper from '../../ScreenWrapper'; const MODELS: ModelOption[] = [ { label: 'RF-DeTR Nano', value: RF_DETR_NANO }, { label: 'SSDLite MobileNet', value: SSDLITE_320_MOBILENET_V3_LARGE }, + { label: 'YOLO26N', value: YOLO26N }, ]; export default function ObjectDetectionScreen() { diff --git a/apps/computer-vision/app/vision_camera/index.tsx b/apps/computer-vision/app/vision_camera/index.tsx index dbd969ad09..c13e55925c 100644 --- a/apps/computer-vision/app/vision_camera/index.tsx +++ b/apps/computer-vision/app/vision_camera/index.tsx @@ -46,6 +46,7 @@ type ModelId = | 'classification' | 'objectDetectionSsdlite' | 'objectDetectionRfdetr' + | 'objectDetectionYolo26n' | 'segmentationDeeplabResnet50' | 'segmentationDeeplabResnet101' | 'segmentationDeeplabMobilenet' @@ 
-95,6 +96,7 @@ const TASKS: Task[] = [ variants: [ { id: 'objectDetectionSsdlite', label: 'SSDLite MobileNet' }, { id: 'objectDetectionRfdetr', label: 'RF-DETR Nano' }, + { id: 'objectDetectionYolo26n', label: 'YOLO26N' }, ], }, { @@ -241,7 +243,10 @@ export default function VisionCameraScreen() { )} diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx index 0155be7e46..243a3ee09d 100644 --- a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx @@ -6,12 +6,16 @@ import { Detection, RF_DETR_NANO, SSDLITE_320_MOBILENET_V3_LARGE, + YOLO26N, useObjectDetection, } from 'react-native-executorch'; import { labelColor, labelColorBg } from '../utils/colors'; import { TaskProps } from './types'; -type ObjModelId = 'objectDetectionSsdlite' | 'objectDetectionRfdetr'; +type ObjModelId = + | 'objectDetectionSsdlite' + | 'objectDetectionRfdetr' + | 'objectDetectionYolo26n'; type Props = TaskProps & { activeModel: ObjModelId }; @@ -34,8 +38,17 @@ export default function ObjectDetectionTask({ model: RF_DETR_NANO, preventLoad: activeModel !== 'objectDetectionRfdetr', }); + const yolo26n = useObjectDetection({ + model: YOLO26N, + preventLoad: activeModel !== 'objectDetectionYolo26n', + }); - const active = activeModel === 'objectDetectionSsdlite' ? ssdlite : rfdetr; + const active = + activeModel === 'objectDetectionSsdlite' + ? ssdlite + : activeModel === 'objectDetectionRfdetr' + ? 
rfdetr + : yolo26n; const [detections, setDetections] = useState([]); const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index f81d648bb5..eb0943c2f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -1,6 +1,8 @@ #include "ObjectDetection.h" #include "Constants.h" +#include + #include #include #include @@ -18,21 +20,6 @@ ObjectDetection::ObjectDetection( std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker), labelNames_(std::move(labelNames)) { - auto inputTensors = getAllInputShapes(); - if (inputTensors.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Model seems to not take any input tensors."); - } - modelInputShape_ = inputTensors[0]; - if (modelInputShape_.size() < 2) { - char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unexpected model input size, expected at least 2 dimensions " - "but got: %zu.", - modelInputShape_.size()); - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - errorMessage); - } if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { @@ -47,14 +34,65 @@ ObjectDetection::ObjectDetection( } } +cv::Size ObjectDetection::modelInputSize() const { + if (currentlyLoadedMethod_.empty()) { + return VisionModel::modelInputSize(); + } + auto inputShapes = getAllInputShapes(currentlyLoadedMethod_); + if (inputShapes.empty() || inputShapes[0].size() < 2) { + return VisionModel::modelInputSize(); + } + const auto &shape = inputShapes[0]; + return {static_cast(shape[shape.size() - 2]), + 
+          static_cast(shape[shape.size() - 1])};
+}
+
+void ObjectDetection::ensureMethodLoaded(const std::string &methodName) {
+  if (methodName.empty()) {
+    throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
+                            "methodName cannot be empty");
+  }
+  if (currentlyLoadedMethod_ == methodName) {
+    return;
+  }
+  if (!module_) {
+    throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded,
+                            "Model module is not loaded");
+  }
+  if (!currentlyLoadedMethod_.empty()) {
+    module_->unload_method(currentlyLoadedMethod_);
+    // Forget the previous method immediately: if load_method below fails and
+    // throws, currentlyLoadedMethod_ must not keep claiming that the (now
+    // unloaded) old method is still loaded.
+    currentlyLoadedMethod_.clear();
+  }
+  auto loadResult = module_->load_method(methodName);
+  if (loadResult != executorch::runtime::Error::Ok) {
+    throw RnExecutorchError(
+        loadResult, "Failed to load method '" + methodName +
+                        "'. Ensure the method exists in the exported model.");
+  }
+  currentlyLoadedMethod_ = methodName;
+}
+
+std::set ObjectDetection::prepareAllowedClasses(
+    const std::vector &classIndices) const {
+  std::set allowedClasses;
+  if (!classIndices.empty()) {
+    allowedClasses.insert(classIndices.begin(), classIndices.end());
+  }
+  return allowedClasses;
+}
+
 std::vector
 ObjectDetection::postprocess(const std::vector &tensors,
-                             cv::Size originalSize, double detectionThreshold) {
+                             cv::Size originalSize, double detectionThreshold,
+                             double iouThreshold,
+                             const std::vector &classIndices) {
   const cv::Size inputSize = modelInputSize();
   float widthRatio = static_cast(originalSize.width) / inputSize.width;
   float heightRatio = static_cast(originalSize.height) / inputSize.height;
 
+  // Prepare allowed classes set for filtering
+  auto allowedClasses = prepareAllowedClasses(classIndices);
+
   std::vector detections;
   auto bboxTensor = tensors.at(0).toTensor();
   std::span bboxes(
@@ -75,12 +113,21 @@
       if (scores[i] < detectionThreshold) {
         continue;
       }
+
+      auto labelIdx = static_cast(labels[i]);
+
+      // Filter by class if classesOfInterest is specified
+      if (!allowedClasses.empty() &&
+          allowedClasses.find(labelIdx) ==
allowedClasses.end()) { + continue; + } + float x1 = bboxes[i * 4] * widthRatio; float y1 = bboxes[i * 4 + 1] * heightRatio; float x2 = bboxes[i * 4 + 2] * widthRatio; float y2 = bboxes[i * 4 + 3] * heightRatio; - auto labelIdx = static_cast(labels[i]); - if (labelIdx >= labelNames_.size()) { + + if (static_cast(labelIdx) >= labelNames_.size()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidConfig, "Model output class index " + std::to_string(labelIdx) + @@ -88,23 +135,40 @@ ObjectDetection::postprocess(const std::vector &tensors, ". Ensure the labelMap covers all model output classes."); } detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2}, - labelNames_[labelIdx], - static_cast(labelIdx), scores[i]); + labelNames_[labelIdx], labelIdx, scores[i]); } - return utils::computer_vision::nonMaxSuppression(detections, - constants::IOU_THRESHOLD); + return utils::computer_vision::nonMaxSuppression(detections, iouThreshold); } -std::vector -ObjectDetection::runInference(cv::Mat image, double detectionThreshold) { +std::vector ObjectDetection::runInference( + cv::Mat image, double detectionThreshold, double iouThreshold, + const std::vector &classIndices, const std::string &methodName) { if (detectionThreshold < 0.0 || detectionThreshold > 1.0) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "detectionThreshold must be in range [0, 1]"); } + if (iouThreshold < 0.0 || iouThreshold > 1.0) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "iouThreshold must be in range [0, 1]"); + } + std::scoped_lock lock(inference_mutex_); + // Ensure the correct method is loaded + ensureMethodLoaded(methodName); + cv::Size originalSize = image.size(); + + // Query input shapes for the currently loaded method + auto inputShapes = getAllInputShapes(methodName); + if (inputShapes.empty() || inputShapes[0].size() < 2) { + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, + "Could not determine input shape for 
method: " + + methodName); + } + modelInputShape_ = inputShapes[0]; + cv::Mat preprocessed = preprocess(image); auto inputTensor = @@ -114,46 +178,50 @@ ObjectDetection::runInference(cv::Mat image, double detectionThreshold) { : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); - auto forwardResult = BaseModel::forward(inputTensor); - if (!forwardResult.ok()) { - throw RnExecutorchError(forwardResult.error(), - "The model's forward function did not succeed. " - "Ensure the model input is correct."); + auto executeResult = execute(methodName, {inputTensor}); + if (!executeResult.ok()) { + throw RnExecutorchError(executeResult.error(), + "The model's " + methodName + + " method did not succeed. " + "Ensure the model input is correct."); } - return postprocess(forwardResult.get(), originalSize, detectionThreshold); + return postprocess(executeResult.get(), originalSize, detectionThreshold, + iouThreshold, classIndices); } -std::vector -ObjectDetection::generateFromString(std::string imageSource, - double detectionThreshold) { +std::vector ObjectDetection::generateFromString( + std::string imageSource, double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - return runInference(imageRGB, detectionThreshold); + return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices, + methodName); } -std::vector -ObjectDetection::generateFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData, - double detectionThreshold) { - auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); +std::vector ObjectDetection::generateFromFrame( + jsi::Runtime &runtime, const jsi::Value &frameData, + double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName) { cv::Mat frame = extractFromFrame(runtime, frameData); - 
cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient);
-  auto detections = runInference(rotated, detectionThreshold);
-  for (auto &det : detections) {
-    ::rnexecutorch::utils::inverseRotateBbox(det.bbox, orient, rotated.size());
-  }
+  auto detections = runInference(frame, detectionThreshold, iouThreshold,
+                                 classIndices, methodName);
+
   return detections;
 }
 
-std::vector
-ObjectDetection::generateFromPixels(JSTensorViewIn pixelData,
-                                    double detectionThreshold) {
+std::vector ObjectDetection::generateFromPixels(
+    JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold,
+    std::vector classIndices, std::string methodName) {
   cv::Mat image = extractFromPixels(pixelData);
-  return runInference(image, detectionThreshold);
+  return runInference(image, detectionThreshold, iouThreshold, classIndices,
+                      methodName);
 }
 
 } // namespace rnexecutorch::models::object_detection
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
index 1a7e72a6db..6e3c01356e 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h
@@ -57,6 +57,13 @@ class ObjectDetection : public VisionModel {
    * @param imageSource URI or file path of the input image.
    * @param detectionThreshold Minimum confidence score in (0, 1] for a
    *                           detection to be included in the output.
+   * @param iouThreshold IoU threshold for non-maximum suppression.
+   * @param classIndices Optional list of class indices to filter results.
+   *                     Only detections matching these classes will be
+   *                     returned. Pass empty vector to include all
+   *                     classes.
+   * @param methodName Name of the method to execute (e.g., "forward",
+   *                   "forward_384", "forward_512", "forward_640").
* * @return A vector of @ref types::Detection objects with bounding boxes, * label strings (resolved via the label names passed to the @@ -66,16 +73,33 @@ class ObjectDetection : public VisionModel { * fails. */ [[nodiscard("Registered non-void function")]] std::vector - generateFromString(std::string imageSource, double detectionThreshold); + generateFromString(std::string imageSource, double detectionThreshold, + double iouThreshold, std::vector classIndices, + std::string methodName); [[nodiscard("Registered non-void function")]] std::vector generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, - double detectionThreshold); + double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName); [[nodiscard("Registered non-void function")]] std::vector - generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold); + generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, + double iouThreshold, std::vector classIndices, + std::string methodName); protected: - std::vector runInference(cv::Mat image, - double detectionThreshold); + /** + * @brief Returns the model input size based on the currently loaded method. + * + * Overrides VisionModel::modelInputSize() to support multi-method models + * where each method may have different input dimensions. + * + * @return The expected input size for the currently loaded method. + */ + cv::Size modelInputSize() const override; + + std::vector + runInference(cv::Mat image, double detectionThreshold, double iouThreshold, + const std::vector &classIndices, + const std::string &methodName); private: /** @@ -88,15 +112,37 @@ class ObjectDetection : public VisionModel { * bounding boxes back to input coordinates. * @param detectionThreshold Confidence threshold below which detections * are discarded. + * @param iouThreshold IoU threshold for non-maximum suppression. + * @param classIndices Optional list of class indices to filter results. 
* * @return Non-max-suppressed detections above the threshold. * * @throws RnExecutorchError if the model outputs a class index that exceeds * the size of @ref labelNames_. */ - std::vector postprocess(const std::vector &tensors, - cv::Size originalSize, - double detectionThreshold); + std::vector + postprocess(const std::vector &tensors, cv::Size originalSize, + double detectionThreshold, double iouThreshold, + const std::vector &classIndices); + + /** + * @brief Ensures the specified method is loaded, unloading any previous + * method if necessary. + * + * @param methodName Name of the method to load (e.g., "forward", + * "forward_384"). + * @throws RnExecutorchError if the method cannot be loaded. + */ + void ensureMethodLoaded(const std::string &methodName); + + /** + * @brief Prepares a set of allowed class indices for filtering detections. + * + * @param classIndices Vector of class indices to allow. + * @return A set containing the allowed class indices. + */ + std::set + prepareAllowedClasses(const std::vector &classIndices) const; /// Optional per-channel mean for input normalisation (set in constructor). std::optional normMean_; @@ -106,6 +152,9 @@ class ObjectDetection : public VisionModel { /// Ordered label strings mapping class indices to human-readable names. std::vector labelNames_; + + /// Name of the currently loaded method (for multi-method models). 
+ std::string currentlyLoadedMethod_; }; } // namespace models::object_detection diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 0e4bcdf080..c173b839cb 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -470,6 +470,53 @@ export const RF_DETR_NANO = { modelSource: RF_DETR_NANO_MODEL, } as const; +// YOLO26 Object Detection +const YOLO26N_DETECTION_MODEL = `${URL_PREFIX}-yolo26/${NEXT_VERSION_TAG}/yolo26n/xnnpack/yolo26n.pte`; +const YOLO26S_DETECTION_MODEL = `${URL_PREFIX}-yolo26/${NEXT_VERSION_TAG}/yolo26s/xnnpack/yolo26s.pte`; +const YOLO26M_DETECTION_MODEL = `${URL_PREFIX}-yolo26/${NEXT_VERSION_TAG}/yolo26m/xnnpack/yolo26m.pte`; +const YOLO26L_DETECTION_MODEL = `${URL_PREFIX}-yolo26/${NEXT_VERSION_TAG}/yolo26l/xnnpack/yolo26l.pte`; +const YOLO26X_DETECTION_MODEL = `${URL_PREFIX}-yolo26/${NEXT_VERSION_TAG}/yolo26x/xnnpack/yolo26x.pte`; + +/** + * @category Models - Object Detection + */ +export const YOLO26N = { + modelName: 'yolo26n', + modelSource: YOLO26N_DETECTION_MODEL, +} as const; + +/** + * @category Models - Object Detection + */ +export const YOLO26S = { + modelName: 'yolo26s', + modelSource: YOLO26S_DETECTION_MODEL, +} as const; + +/** + * @category Models - Object Detection + */ +export const YOLO26M = { + modelName: 'yolo26m', + modelSource: YOLO26M_DETECTION_MODEL, +} as const; + +/** + * @category Models - Object Detection + */ +export const YOLO26L = { + modelName: 'yolo26l', + modelSource: YOLO26L_DETECTION_MODEL, +} as const; + +/** + * @category Models - Object Detection + */ +export const YOLO26X = { + modelName: 'yolo26x', + modelSource: YOLO26X_DETECTION_MODEL, +} as const; + // Style transfer const STYLE_TRANSFER_CANDY_MODEL = Platform.OS === `ios` diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts 
b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts index c19b819318..cc5d69e42a 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts @@ -6,6 +6,7 @@ import { ObjectDetectionModelSources, ObjectDetectionProps, ObjectDetectionType, + ObjectDetectionOptions, } from '../../types/objectDetection'; import { PixelData } from '../../types/common'; import { useModuleFactory } from '../useModuleFactory'; @@ -30,6 +31,7 @@ export const useObjectDetection = ({ downloadProgress, runForward, runOnFrame, + instance, } = useModuleFactory({ factory: (config, onProgress) => ObjectDetectionModule.fromModelName(config, onProgress), @@ -38,8 +40,13 @@ export const useObjectDetection = ({ preventLoad, }); - const forward = (input: string | PixelData, detectionThreshold?: number) => - runForward((inst) => inst.forward(input, detectionThreshold)); + const forward = ( + input: string | PixelData, + options?: ObjectDetectionOptions> + ) => runForward((inst) => inst.forward(input, options)); + + const getAvailableInputSizes = () => + instance?.getAvailableInputSizes() ?? 
undefined; return { error, @@ -48,5 +55,6 @@ export const useObjectDetection = ({ downloadProgress, forward, runOnFrame, + getAvailableInputSizes, }; }; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index cbb3847ffa..93d08c0c9c 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -4,9 +4,13 @@ import { ObjectDetectionConfig, ObjectDetectionModelName, ObjectDetectionModelSources, + ObjectDetectionOptions, } from '../../types/objectDetection'; +import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; +import { RnExecutorchError } from '../../errors/errorUtils'; import { CocoLabel, + CocoLabelYolo, IMAGENET1K_MEAN, IMAGENET1K_STD, } from '../../constants/commonVision'; @@ -16,15 +20,37 @@ import { VisionLabeledModule, } from './VisionLabeledModule'; +const YOLO_DETECTION_CONFIG = { + labelMap: CocoLabelYolo, + preprocessorConfig: undefined, + availableInputSizes: [384, 512, 640] as const, + defaultInputSize: 384, + defaultDetectionThreshold: 0.5, + defaultIouThreshold: 0.5, +} satisfies ObjectDetectionConfig; + const ModelConfigs = { 'ssdlite-320-mobilenet-v3-large': { labelMap: CocoLabel, preprocessorConfig: undefined, + availableInputSizes: undefined, + defaultInputSize: undefined, + defaultDetectionThreshold: 0.7, + defaultIouThreshold: 0.55, }, 'rf-detr-nano': { labelMap: CocoLabel, preprocessorConfig: { normMean: IMAGENET1K_MEAN, normStd: IMAGENET1K_STD }, + availableInputSizes: undefined, + defaultInputSize: undefined, + defaultDetectionThreshold: 0.7, + defaultIouThreshold: 0.55, }, + 'yolo26n': YOLO_DETECTION_CONFIG, + 'yolo26s': YOLO_DETECTION_CONFIG, + 'yolo26m': YOLO_DETECTION_CONFIG, + 'yolo26l': YOLO_DETECTION_CONFIG, + 'yolo26x': YOLO_DETECTION_CONFIG, } as const 
satisfies Record< ObjectDetectionModelName, ObjectDetectionConfig @@ -55,8 +81,15 @@ type ResolveLabels = export class ObjectDetectionModule< T extends ObjectDetectionModelName | LabelEnum, > extends VisionLabeledModule>[], ResolveLabels> { - private constructor(labelMap: ResolveLabels, nativeModule: unknown) { + private modelConfig: ObjectDetectionConfig; + + private constructor( + labelMap: ResolveLabels, + modelConfig: ObjectDetectionConfig, + nativeModule: unknown + ) { super(labelMap, nativeModule); + this.modelConfig = modelConfig; } /** @@ -70,9 +103,10 @@ export class ObjectDetectionModule< onDownloadProgress: (progress: number) => void = () => {} ): Promise>> { const { modelSource } = namedSources; - const { labelMap, preprocessorConfig } = ModelConfigs[ + const modelConfig = ModelConfigs[ namedSources.modelName ] as ObjectDetectionConfig; + const { labelMap, preprocessorConfig } = modelConfig; const normMean = preprocessorConfig?.normMean ?? []; const normStd = preprocessorConfig?.normStd ?? []; const allLabelNames: string[] = []; @@ -91,21 +125,165 @@ export class ObjectDetectionModule< ); return new ObjectDetectionModule>( labelMap as ResolveLabels>, + modelConfig, nativeModule ); } + /** + * Returns the available input sizes for this model, or undefined if the model accepts any size. + * + * @returns An array of available input sizes, or undefined if not constrained. + * + * @example + * ```typescript + * const sizes = model.getAvailableInputSizes(); // [384, 512, 640] for YOLO models, or undefined for RF-DETR + * ``` + */ + getAvailableInputSizes(): readonly number[] | undefined { + return this.modelConfig.availableInputSizes; + } + + /** + * Override runOnFrame to provide an options-based API for VisionCamera integration. 
+ */ + override get runOnFrame(): + | (( + frame: any, + options?: ObjectDetectionOptions> + ) => Detection>[]) + | null { + const baseRunOnFrame = super.runOnFrame; + if (!baseRunOnFrame) return null; + + // Create reverse map (label → enum value) for classesOfInterest lookup + const labelMap: Record = {}; + for (const [name, value] of Object.entries(this.labelMap)) { + if (typeof value === 'number') { + labelMap[name] = value; + } + } + + const defaultDetectionThreshold = + this.modelConfig.defaultDetectionThreshold ?? 0.7; + const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.55; + const defaultInputSize = this.modelConfig.defaultInputSize; + + return ( + frame: any, + options?: ObjectDetectionOptions> + ): Detection>[] => { + 'worklet'; + + const detectionThreshold = + options?.detectionThreshold ?? defaultDetectionThreshold; + const iouThreshold = options?.iouThreshold ?? defaultIouThreshold; + const inputSize = options?.inputSize ?? defaultInputSize; + const methodName = + inputSize !== undefined ? `forward_${inputSize}` : 'forward'; + + const classIndices = options?.classesOfInterest + ? options.classesOfInterest.map((label) => { + const labelStr = String(label); + const enumValue = labelMap[labelStr]; + return typeof enumValue === 'number' ? enumValue : -1; + }) + : []; + + return baseRunOnFrame( + frame, + detectionThreshold, + iouThreshold, + classIndices, + methodName + ); + }; + } + /** * Executes the model's forward pass to detect objects within the provided image. + * + * Supports two input types: + * 1. **String path/URI**: File path, URL, or Base64-encoded string + * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) + * * @param input - A string image source (file path, URI, or Base64) or a {@link PixelData} object. - * @param detectionThreshold - Minimum confidence score for a detection to be included. Default is 0.7. + * @param options - Optional configuration for detection inference. 
Includes `detectionThreshold`, `inputSize`, and `classesOfInterest`. * @returns A Promise resolving to an array of {@link Detection} objects. + * @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided. + * + * @example + * ```typescript + * const detections = await model.forward('path/to/image.jpg', { + * detectionThreshold: 0.7, + * inputSize: 640, // For YOLO models + * classesOfInterest: ['PERSON', 'CAR'], + * }); + * ``` */ override async forward( input: string | PixelData, - detectionThreshold = 0.7 + options?: ObjectDetectionOptions> ): Promise>[]> { - return super.forward(input, detectionThreshold); + if (this.nativeModule == null) { + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' + ); + } + + // Extract parameters with defaults from config + const detectionThreshold = + options?.detectionThreshold ?? + this.modelConfig.defaultDetectionThreshold ?? + 0.7; + const iouThreshold = + options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.55; + const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize; + + // Validate inputSize against availableInputSizes + if ( + this.modelConfig.availableInputSizes && + inputSize !== undefined && + !this.modelConfig.availableInputSizes.includes( + inputSize as (typeof this.modelConfig.availableInputSizes)[number] + ) + ) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidArgument, + `Invalid inputSize: ${inputSize}. Available sizes: ${this.modelConfig.availableInputSizes.join(', ')}` + ); + } + + // Construct method name: forward_384, forward_512, forward_640, or forward + const methodName = + inputSize !== undefined ? `forward_${inputSize}` : 'forward'; + + // Convert classesOfInterest to indices + const classIndices = options?.classesOfInterest + ? 
options.classesOfInterest.map((label) => { + const labelStr = String(label); + const enumValue = this.labelMap[labelStr as keyof ResolveLabels]; + return typeof enumValue === 'number' ? enumValue : -1; + }) + : []; + + // Call native with all parameters + return typeof input === 'string' + ? await this.nativeModule.generateFromString( + input, + detectionThreshold, + iouThreshold, + classIndices, + methodName + ) + : await this.nativeModule.generateFromPixels( + input, + detectionThreshold, + iouThreshold, + classIndices, + methodName + ); } /** @@ -159,6 +337,7 @@ export class ObjectDetectionModule< ); return new ObjectDetectionModule( config.labelMap as ResolveLabels, + config, nativeModule ); } diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index 8e8bf02896..bdb380c0ac 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -32,6 +32,23 @@ export interface Detection { score: number; } +/** + * Options for configuring object detection inference. + * + * @category Types + * @typeParam L - The label enum type for filtering classes of interest. + * @property {number} [detectionThreshold] - Minimum confidence score for detections (0-1). Defaults to model-specific value. + * @property {number} [iouThreshold] - IoU threshold for non-maximum suppression (0-1). Defaults to model-specific value. + * @property {number} [inputSize] - Input size for multi-method models (e.g., 384, 512, 640 for YOLO). Required for YOLO models if not using default. + * @property {(keyof L)[]} [classesOfInterest] - Optional array of class labels to filter detections. Only detections matching these classes will be returned. 
+ */ +export interface ObjectDetectionOptions { + detectionThreshold?: number; + iouThreshold?: number; + inputSize?: number; + classesOfInterest?: (keyof L)[]; +} + /** * Per-model config for {@link ObjectDetectionModule.fromModelName}. * Each model name maps to its required fields. @@ -39,7 +56,12 @@ export interface Detection { */ export type ObjectDetectionModelSources = | { modelName: 'ssdlite-320-mobilenet-v3-large'; modelSource: ResourceSource } - | { modelName: 'rf-detr-nano'; modelSource: ResourceSource }; + | { modelName: 'rf-detr-nano'; modelSource: ResourceSource } + | { modelName: 'yolo26n'; modelSource: ResourceSource } + | { modelName: 'yolo26s'; modelSource: ResourceSource } + | { modelName: 'yolo26m'; modelSource: ResourceSource } + | { modelName: 'yolo26l'; modelSource: ResourceSource } + | { modelName: 'yolo26x'; modelSource: ResourceSource }; /** * Union of all built-in object detection model names. @@ -50,11 +72,29 @@ export type ObjectDetectionModelName = ObjectDetectionModelSources['modelName']; /** * Configuration for a custom object detection model. * @category Types + * @typeParam T - The label enum type for the model. + * @property {T} labelMap - The label mapping for the model. + * @property {object} [preprocessorConfig] - Optional preprocessing configuration with normalization parameters. + * @property {number} [defaultDetectionThreshold] - Default detection confidence threshold (0-1). + * @property {number} [defaultIouThreshold] - Default IoU threshold for non-maximum suppression (0-1). + * @property {readonly number[]} [availableInputSizes] - For multi-method models, the available input sizes (e.g., [384, 512, 640]). + * @property {number} [defaultInputSize] - For multi-method models, the default input size to use. 
*/ export type ObjectDetectionConfig = { labelMap: T; preprocessorConfig?: { normMean?: Triple; normStd?: Triple }; -}; + defaultDetectionThreshold?: number; + defaultIouThreshold?: number; +} & ( + | { + availableInputSizes: readonly number[]; + defaultInputSize: number; + } + | { + availableInputSizes?: undefined; + defaultInputSize?: undefined; + } +); /** * Props for the `useObjectDetection` hook. @@ -98,27 +138,44 @@ export interface ObjectDetectionType { /** * Executes the model's forward pass with automatic input type detection. * @param input - Image source (string path/URI or PixelData object) - * @param detectionThreshold - An optional number between 0 and 1 representing the minimum confidence score. Default is 0.7. + * @param options - Optional configuration for detection inference * @returns A Promise that resolves to an array of `Detection` objects. * @throws {RnExecutorchError} If the model is not loaded or is currently processing another image. * @example * ```typescript - * // String path - * const detections1 = await model.forward('file:///path/to/image.jpg'); + * // String path with options + * const detections1 = await model.forward('file:///path/to/image.jpg', { + * detectionThreshold: 0.7, + * inputSize: 640, // For YOLO models + * classesOfInterest: ['PERSON', 'CAR'] + * }); * * // Pixel data * const detections2 = await model.forward({ * dataPtr: new Uint8Array(rgbPixels), * sizes: [480, 640, 3], * scalarType: ScalarType.BYTE - * }); + * }, { detectionThreshold: 0.5 }); * ``` */ forward: ( input: string | PixelData, - detectionThreshold?: number + options?: ObjectDetectionOptions ) => Promise[]>; + /** + * Returns the available input sizes for multi-method models (e.g., YOLO). + * Returns undefined for single-method models (e.g., RF-DETR, SSDLite). 
+ * + * @returns Array of available input sizes or undefined + * + * @example + * ```typescript + * const sizes = model.getAvailableInputSizes(); // [384, 512, 640] for YOLO models + * ``` + */ + getAvailableInputSizes: () => readonly number[] | undefined; + /** * Synchronous worklet function for real-time VisionCamera frame processing. * Automatically handles native buffer extraction and cleanup. @@ -129,14 +186,12 @@ export interface ObjectDetectionType { * Available after model is loaded (`isReady: true`). * @param frame - VisionCamera Frame object * @param isFrontCamera - Whether the front camera is active, used for mirroring corrections. - * @param detectionThreshold - The threshold for detection sensitivity. + * @param options - Optional configuration for detection inference * @returns Array of Detection objects representing detected items in the frame. */ - runOnFrame: - | (( - frame: Frame, - isFrontCamera: boolean, - detectionThreshold: number - ) => Detection[]) - | null; + runOnFrame: ( + frame: Frame, + isFrontCamera: boolean, + options?: ObjectDetectionOptions + ) => Detection[]; } From 9b77da76a7f93d7851c880c8d000b2448713bea6 Mon Sep 17 00:00:00 2001 From: benITo47 Date: Fri, 20 Mar 2026 10:57:53 +0100 Subject: [PATCH 2/5] Update docs --- docs/docs/02-benchmarks/inference-time.md | 10 +++++ docs/docs/02-benchmarks/memory-usage.md | 10 +++++ docs/docs/02-benchmarks/model-size.md | 6 +++ .../02-computer-vision/useObjectDetection.md | 38 ++++++++++++------- .../ObjectDetectionModule.md | 16 +++++++- 5 files changed, 66 insertions(+), 14 deletions(-) diff --git a/docs/docs/02-benchmarks/inference-time.md b/docs/docs/02-benchmarks/inference-time.md index cec25098b8..fe1c143409 100644 --- a/docs/docs/02-benchmarks/inference-time.md +++ b/docs/docs/02-benchmarks/inference-time.md @@ -43,11 +43,21 @@ processing. Resizing is typically fast for small images but may be noticeably slower for very large images, which can increase total time. 
::: +:::warning +Times presented in the tables are measured for forward method with input size equal to 512. Other input sizes may yield slower or faster inference times. +::: + | Model / Device | iPhone 17 Pro [ms] | Google Pixel 10 [ms] | | :-------------------------------------------- | :----------------: | :------------------: | | SSDLITE_320_MOBILENET_V3_LARGE (XNNPACK FP32) | 20 | 18 | | SSDLITE_320_MOBILENET_V3_LARGE (Core ML FP32) | 18 | - | | SSDLITE_320_MOBILENET_V3_LARGE (Core ML FP16) | 8 | - | +| RF_DETR_NANO (XNNPACK FP32) | TBD | TBD | +| YOLO26N (XNNPACK FP32) | TBD | TBD | +| YOLO26S (XNNPACK FP32) | TBD | TBD | +| YOLO26M (XNNPACK FP32) | TBD | TBD | +| YOLO26L (XNNPACK FP32) | TBD | TBD | +| YOLO26X (XNNPACK FP32) | TBD | TBD | ## Style Transfer diff --git a/docs/docs/02-benchmarks/memory-usage.md b/docs/docs/02-benchmarks/memory-usage.md index 0ad6e7a11d..0e1b3ccee8 100644 --- a/docs/docs/02-benchmarks/memory-usage.md +++ b/docs/docs/02-benchmarks/memory-usage.md @@ -25,11 +25,21 @@ loaded and actively running inference, relative to the baseline app memory before model initialization. ::: +:::warning +Data presented for YOLO models is based on inference with forward_640 method. 
+::: + | Model / Device | iPhone 17 Pro [MB] | Google Pixel 10 [MB] | | --------------------------------------------- | :----------------: | :------------------: | | SSDLITE_320_MOBILENET_V3_LARGE (XNNPACK FP32) | 94 | 104 | | SSDLITE_320_MOBILENET_V3_LARGE (Core ML FP32) | 83 | - | | SSDLITE_320_MOBILENET_V3_LARGE (Core ML FP16) | 62 | - | +| RF_DETR_NANO (XNNPACK FP32) | TBD | TBD | +| YOLO26N (XNNPACK FP32) | TBD | TBD | +| YOLO26S (XNNPACK FP32) | TBD | TBD | +| YOLO26M (XNNPACK FP32) | TBD | TBD | +| YOLO26L (XNNPACK FP32) | TBD | TBD | +| YOLO26X (XNNPACK FP32) | TBD | TBD | ## Style Transfer diff --git a/docs/docs/02-benchmarks/model-size.md b/docs/docs/02-benchmarks/model-size.md index 38ea9e7a6e..f9f5e4701f 100644 --- a/docs/docs/02-benchmarks/model-size.md +++ b/docs/docs/02-benchmarks/model-size.md @@ -13,6 +13,12 @@ title: Model Size | Model | XNNPACK FP32 [MB] | Core ML FP32 [MB] | Core ML FP16 [MB] | | ------------------------------ | :---------------: | :---------------: | :---------------: | | SSDLITE_320_MOBILENET_V3_LARGE | 13.9 | 15.6 | 8.46 | +| RF_DETR_NANO | 112 | - | - | +| YOLO26N | 10.3 | - | - | +| YOLO26S | 38.6 | - | - | +| YOLO26M | 82.3 | - | - | +| YOLO26L | 100 | - | - | +| YOLO26X | 224 | - | - | ## Instance Segmentation diff --git a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md index 5fb2b2bb3a..3ede23f48d 100644 --- a/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md +++ b/docs/docs/03-hooks/02-computer-vision/useObjectDetection.md @@ -61,13 +61,18 @@ You need more details? Check the following resources: - `error` - An error object if the model failed to load or encountered a runtime error. - `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary. - `forward` - A function to run inference on an image. 
+- `getAvailableInputSizes` - A function that returns available input sizes for multi-method models (YOLO). Returns `undefined` for single-method models. ## Running the model To run the model, use the [`forward`](../../06-api-reference/interfaces/ObjectDetectionType.md#forward) method. It accepts two arguments: - `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). -- `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score for a detection to be included in the results. Defaults to `0.7`. +- `options` (optional) - An [`ObjectDetectionOptions`](../../06-api-reference/interfaces/ObjectDetectionOptions.md) object with the following properties: + - `detectionThreshold` (optional) - A number between 0 and 1 representing the minimum confidence score. Defaults to model-specific value (typically `0.7`). + - `iouThreshold` (optional) - IoU threshold for non-maximum suppression (0-1). Defaults to model-specific value (typically `0.55`). + - `inputSize` (optional) - For multi-method models like YOLO, specify the input resolution (`384`, `512`, or `640`). Defaults to `384` for YOLO models. + - `classesOfInterest` (optional) - Array of class labels to filter detections. Only detections matching these classes will be returned. 
`forward` returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing: @@ -78,11 +83,11 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/ObjectDe ## Example ```typescript -import { useObjectDetection, RF_DETR_NANO } from 'react-native-executorch'; +import { useObjectDetection, YOLO26N } from 'react-native-executorch'; function App() { const model = useObjectDetection({ - model: RF_DETR_NANO, + model: YOLO26N, }); const handleDetect = async () => { @@ -91,13 +96,12 @@ function App() { const imageUri = 'file:///Users/.../photo.jpg'; try { - const detections = await model.forward(imageUri, 0.5); + const detections = await model.forward(imageUri, { + detectionThreshold: 0.5, + inputSize: 640, + }); - for (const detection of detections) { - console.log('Label:', detection.label); - console.log('Score:', detection.score); - console.log('Bounding box:', detection.bbox); - } + console.log('Detected:', detections.length, 'objects'); } catch (error) { console.error(error); } @@ -113,7 +117,15 @@ See the full guide: [VisionCamera Integration](./visioncamera-integration.md). 
## Supported models

-| Model | Number of classes | Class list |
-| ----------------------------------------------------------------------------------------------------------------------------- | ----------------- | -------------------------------------------------------- |
-| [SSDLite320 MobileNetV3 Large](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large) | 91 | [COCO](../../06-api-reference/enumerations/CocoLabel.md) |
-| [RF-DETR Nano](https://huggingface.co/software-mansion/react-native-executorch-rf-detr-nano) | 80 | [COCO](../../06-api-reference/enumerations/CocoLabel.md) |
+| Model | Number of classes | Class list | Multi-size Support |
+| ----------------------------------------------------------------------------------------------------------------------------- | ----------------- | ------------------------------------------------------------ | ------------------ |
+| [SSDLite320 MobileNetV3 Large](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large) | 91 | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | No |
+| [RF-DETR Nano](https://huggingface.co/software-mansion/react-native-executorch-rf-detr-nano) | 80 | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | No |
+| [YOLO26N](https://huggingface.co/software-mansion/react-native-executorch-yolo26) | 80 | [COCO YOLO](../../06-api-reference/enumerations/CocoLabel.md) | Yes (384/512/640) |
+| [YOLO26S](https://huggingface.co/software-mansion/react-native-executorch-yolo26) | 80 | [COCO YOLO](../../06-api-reference/enumerations/CocoLabel.md) | Yes (384/512/640) |
+| [YOLO26M](https://huggingface.co/software-mansion/react-native-executorch-yolo26) | 80 | [COCO YOLO](../../06-api-reference/enumerations/CocoLabel.md) | Yes (384/512/640) |
+| [YOLO26L](https://huggingface.co/software-mansion/react-native-executorch-yolo26) | 80 | [COCO YOLO](../../06-api-reference/enumerations/CocoLabel.md) | Yes (384/512/640
| +| [YOLO26X](https://huggingface.co/software-mansion/react-native-executorch-yolo26) | 80 | [COCO YOLO](../../06-api-reference/enumerations/CocoLabel.md) | Yes (384/512/640) | + +:::tip +YOLO models support multiple input sizes (384px, 512px, 640px). Smaller sizes are faster but less accurate, while larger sizes are more accurate but slower. Choose based on your speed/accuracy requirements. +::: diff --git a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md index b56cb47713..0d004e6752 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md @@ -43,12 +43,26 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method. It accepts two arguments: - `input` (required) - The image to process. Can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). -- `detectionThreshold` (optional) - A number between 0 and 1. Defaults to `0.7`. +- `options` (optional) - An [`ObjectDetectionOptions`](../../06-api-reference/interfaces/ObjectDetectionOptions.md) object with: + - `detectionThreshold` (optional) - Minimum confidence score (0-1). Defaults to model-specific value. + - `iouThreshold` (optional) - IoU threshold for NMS (0-1). Defaults to model-specific value. + - `inputSize` (optional) - For YOLO models: `384`, `512`, or `640`. Defaults to `384`. + - `classesOfInterest` (optional) - Array of class labels to filter detections. The method returns a promise resolving to an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects, each containing the bounding box, label, and confidence score. 
For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead. +### Example with Options + +```typescript +const detections = await model.forward(imageUri, { + detectionThreshold: 0.5, + inputSize: 640, // YOLO models only + classesOfInterest: ['PERSON', 'CAR'], +}); +``` + ## Using a custom model Use [`fromCustomModel`](../../06-api-reference/classes/ObjectDetectionModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset. From 9ac45d989ae73af65fc1fab6e5b0ba40aa7f2607 Mon Sep 17 00:00:00 2001 From: benITo47 Date: Fri, 20 Mar 2026 14:56:51 +0100 Subject: [PATCH 3/5] Fix type signatures after rebase --- .../tasks/ObjectDetectionTask.tsx | 25 ++++++++++++++++--- .../computer_vision/ObjectDetectionModule.ts | 8 +++--- .../src/types/objectDetection.ts | 15 ++++++----- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx index 243a3ee09d..e77c959075 100644 --- a/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx +++ b/apps/computer-vision/components/vision_camera/tasks/ObjectDetectionTask.tsx @@ -8,6 +8,8 @@ import { SSDLITE_320_MOBILENET_V3_LARGE, YOLO26N, useObjectDetection, + CocoLabel, + CocoLabelYolo, } from 'react-native-executorch'; import { labelColor, labelColorBg } from '../utils/colors'; import { TaskProps } from './types'; @@ -50,7 +52,9 @@ export default function ObjectDetectionTask({ ? 
rfdetr : yolo26n; - const [detections, setDetections] = useState([]); + type CommonDetection = Omit & { label: string }; + + const [detections, setDetections] = useState([]); const [imageSize, setImageSize] = useState({ width: 1, height: 1 }); const lastFrameTimeRef = useRef(Date.now()); @@ -69,8 +73,19 @@ export default function ObjectDetectionTask({ const detRof = active.runOnFrame; const updateDetections = useCallback( - (p: { results: Detection[]; imageWidth: number; imageHeight: number }) => { - setDetections(p.results); + (p: { + results: + | Detection[] + | Detection[]; + imageWidth: number; + imageHeight: number; + }) => { + setDetections( + p.results.map((det) => ({ + ...det, + label: String(det.label), + })) + ); setImageSize({ width: p.imageWidth, height: p.imageHeight }); const now = Date.now(); const diff = now - lastFrameTimeRef.current; @@ -95,7 +110,9 @@ export default function ObjectDetectionTask({ try { if (!detRof) return; const isFrontCamera = cameraPositionSync.getDirty() === 'front'; - const result = detRof(frame, isFrontCamera, 0.5); + const result = detRof(frame, isFrontCamera, { + detectionThreshold: 0.5, + }); // Sensor frames are landscape-native, so width/height are swapped // relative to portrait screen orientation. const screenW = frame.height; diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index 93d08c0c9c..99d0fa7c5d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -132,9 +132,7 @@ export class ObjectDetectionModule< /** * Returns the available input sizes for this model, or undefined if the model accepts any size. - * * @returns An array of available input sizes, or undefined if not constrained. 
- * * @example * ```typescript * const sizes = model.getAvailableInputSizes(); // [384, 512, 640] for YOLO models, or undefined for RF-DETR @@ -146,10 +144,12 @@ export class ObjectDetectionModule< /** * Override runOnFrame to provide an options-based API for VisionCamera integration. + * @returns A worklet function for frame processing, or null if the model is not loaded. */ override get runOnFrame(): | (( frame: any, + isFrontCamera: boolean, options?: ObjectDetectionOptions> ) => Detection>[]) | null { @@ -171,6 +171,7 @@ export class ObjectDetectionModule< return ( frame: any, + isFrontCamera: boolean, options?: ObjectDetectionOptions> ): Detection>[] => { 'worklet'; @@ -192,6 +193,7 @@ export class ObjectDetectionModule< return baseRunOnFrame( frame, + isFrontCamera, detectionThreshold, iouThreshold, classIndices, @@ -206,12 +208,10 @@ export class ObjectDetectionModule< * Supports two input types: * 1. **String path/URI**: File path, URL, or Base64-encoded string * 2. **PixelData**: Raw pixel data from image libraries (e.g., NitroImage) - * * @param input - A string image source (file path, URI, or Base64) or a {@link PixelData} object. * @param options - Optional configuration for detection inference. Includes `detectionThreshold`, `inputSize`, and `classesOfInterest`. * @returns A Promise resolving to an array of {@link Detection} objects. * @throws {RnExecutorchError} If the model is not loaded or if an invalid `inputSize` is provided. - * * @example * ```typescript * const detections = await model.forward('path/to/image.jpg', { diff --git a/packages/react-native-executorch/src/types/objectDetection.ts b/packages/react-native-executorch/src/types/objectDetection.ts index bdb380c0ac..ca33697d93 100644 --- a/packages/react-native-executorch/src/types/objectDetection.ts +++ b/packages/react-native-executorch/src/types/objectDetection.ts @@ -34,7 +34,6 @@ export interface Detection { /** * Options for configuring object detection inference. 
- * * @category Types * @typeParam L - The label enum type for filtering classes of interest. * @property {number} [detectionThreshold] - Minimum confidence score for detections (0-1). Defaults to model-specific value. @@ -166,9 +165,7 @@ export interface ObjectDetectionType { /** * Returns the available input sizes for multi-method models (e.g., YOLO). * Returns undefined for single-method models (e.g., RF-DETR, SSDLite). - * * @returns Array of available input sizes or undefined - * * @example * ```typescript * const sizes = model.getAvailableInputSizes(); // [384, 512, 640] for YOLO models @@ -189,9 +186,11 @@ export interface ObjectDetectionType { * @param options - Optional configuration for detection inference * @returns Array of Detection objects representing detected items in the frame. */ - runOnFrame: ( - frame: Frame, - isFrontCamera: boolean, - options?: ObjectDetectionOptions - ) => Detection[]; + runOnFrame: + | (( + frame: Frame, + isFrontCamera: boolean, + options?: ObjectDetectionOptions + ) => Detection[]) + | null; } From 4e64116d9a610d88df352e55fb4e15174cdbb0ed Mon Sep 17 00:00:00 2001 From: benITo47 Date: Fri, 20 Mar 2026 15:27:22 +0100 Subject: [PATCH 4/5] Fix inproper rebase merge --- .../rnexecutorch/models/object_detection/ObjectDetection.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index eb0943c2f6..be1eb539a2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -206,8 +206,10 @@ std::vector ObjectDetection::generateFromFrame( jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, double iouThreshold, std::vector classIndices, 
std::string methodName) { + auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); - auto detections = runInference(frame, detectionThreshold, iouThreshold, + cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient); + auto detections = runInference(rotated, detectionThreshold, iouThreshold, classIndices, methodName); for (auto &det : detections) { From 195337380698befb0930239ec2199d32693c3b13 Mon Sep 17 00:00:00 2001 From: benITo47 Date: Fri, 20 Mar 2026 13:18:54 +0100 Subject: [PATCH 5/5] First shot at deduplicating CV code --- .../data_processing/CVProcessing.cpp | 57 ++++ .../data_processing/CVProcessing.h | 107 ++++++++ .../rnexecutorch/data_processing/CVTypes.h | 64 +++++ .../host_objects/JsiConversions.h | 4 +- .../rnexecutorch/models/VisionModel.cpp | 51 ++++ .../common/rnexecutorch/models/VisionModel.h | 37 +++ .../BaseInstanceSegmentation.cpp | 112 +------- .../BaseInstanceSegmentation.h | 39 +-- .../models/instance_segmentation/Types.h | 10 +- .../object_detection/ObjectDetection.cpp | 79 +----- .../models/object_detection/ObjectDetection.h | 38 --- .../models/object_detection/Types.h | 8 +- .../common/rnexecutorch/tests/CMakeLists.txt | 17 +- .../tests/integration/ObjectDetectionTest.cpp | 50 ++-- .../tests/unit/CVProcessingTest.cpp | 244 ++++++++++++++++++ .../utils/computer_vision/Processing.cpp | 21 -- .../utils/computer_vision/Processing.h | 51 ---- .../utils/computer_vision/Types.h | 33 --- 18 files changed, 650 insertions(+), 372 deletions(-) create mode 100644 packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.cpp create mode 100644 packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/data_processing/CVTypes.h create mode 100644 packages/react-native-executorch/common/rnexecutorch/tests/unit/CVProcessingTest.cpp delete 
mode 100644 packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.cpp delete mode 100644 packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h delete mode 100644 packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.cpp new file mode 100644 index 0000000000..cbb9817b8c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.cpp @@ -0,0 +1,57 @@ +#include "CVProcessing.h" +#include +#include +#include +#include +#include + +namespace rnexecutorch::cv_processing { + +float computeIoU(const BBox &a, const BBox &b) { + float x1 = std::max(a.x1, b.x1); + float y1 = std::max(a.y1, b.y1); + float x2 = std::min(a.x2, b.x2); + float y2 = std::min(a.y2, b.y2); + + float intersectionArea = std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1); + float areaA = a.area(); + float areaB = b.area(); + float unionArea = areaA + areaB - intersectionArea; + + return (unionArea > 0.0f) ? 
(intersectionArea / unionArea) : 0.0f; +} + +std::optional validateNormParam(const std::vector &values, + const char *paramName) { + if (values.size() == 3) { + return cv::Scalar(values[0], values[1], values[2]); + } else if (!values.empty()) { + log(LOG_LEVEL::Warn, + std::string(paramName) + + " must have 3 elements — ignoring provided value."); + } + return std::nullopt; +} + +std::set +prepareAllowedClasses(const std::vector &classIndices) { + std::set allowedClasses; + if (!classIndices.empty()) { + allowedClasses.insert(classIndices.begin(), classIndices.end()); + } + return allowedClasses; +} + +void validateThresholds(double confidenceThreshold, double iouThreshold) { + if (confidenceThreshold < 0.0 || confidenceThreshold > 1.0) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Confidence threshold must be in range [0, 1]."); + } + + if (iouThreshold < 0.0 || iouThreshold > 1.0) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "IoU threshold must be in range [0, 1]."); + } +} + +} // namespace rnexecutorch::cv_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.h new file mode 100644 index 0000000000..091631a779 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVProcessing.h @@ -0,0 +1,107 @@ +#pragma once + +#include "CVTypes.h" +#include +#include +#include +#include +#include + +namespace rnexecutorch::cv_processing { + +/** + * @brief Compute Intersection over Union (IoU) between two bounding boxes + * @param a First bounding box + * @param b Second bounding box + * @return IoU value between 0.0 and 1.0 + * + * Moved from utils/computer_vision/Processing.h for consolidation. 
+ */ +float computeIoU(const BBox &a, const BBox &b); + +/** + * @brief Non-Maximum Suppression for detection/segmentation results + * @tparam T Type that has bbox and score fields (satisfies HasBBoxAndScore) + * @param items Vector of items to filter + * @param iouThreshold IoU threshold for suppression (typically 0.5) + * @return Filtered vector with overlapping detections removed + * + * Moved from utils/computer_vision/Processing.h for consolidation. + * Handles both class-aware and class-agnostic NMS automatically. + */ +template +std::vector nonMaxSuppression(std::vector items, double iouThreshold) { + if (items.empty()) { + return {}; + } + + // Sort by score in descending order + std::ranges::sort(items, + [](const T &a, const T &b) { return a.score > b.score; }); + + std::vector result; + std::vector suppressed(items.size(), false); + + for (size_t i = 0; i < items.size(); ++i) { + if (suppressed[i]) { + continue; + } + + result.push_back(items[i]); + + // Suppress overlapping boxes + for (size_t j = i + 1; j < items.size(); ++j) { + if (suppressed[j]) { + continue; + } + + // If type has classIndex, only suppress boxes of same class + if constexpr (requires(T t) { t.classIndex; }) { + if (items[i].classIndex != items[j].classIndex) { + continue; + } + } + + float iou = computeIoU(items[i].bbox, items[j].bbox); + if (iou > iouThreshold) { + suppressed[j] = true; + } + } + } + + return result; +} + +/** + * @brief Validate and convert normalization parameter vector to cv::Scalar + * @param values Vector of normalization values (should have 3 elements for RGB) + * @param paramName Parameter name for logging (e.g., "normMean", "normStd") + * @return Optional cv::Scalar if valid (3 elements), nullopt otherwise + * + * Replaces duplicate validation logic across ObjectDetection, + * BaseInstanceSegmentation, and BaseSemanticSegmentation. 
+ */ +std::optional validateNormParam(const std::vector &values, + const char *paramName); + +/** + * @brief Convert class indices vector to a set for efficient filtering + * @param classIndices Vector of class indices to allow + * @return Set of allowed class indices (empty set = allow all classes) + * + * Used by detection and segmentation models to filter results by class. + */ +std::set +prepareAllowedClasses(const std::vector &classIndices); + +/** + * @brief Validate confidence and IoU thresholds are in valid range [0, 1] + * @param confidenceThreshold Detection confidence threshold + * @param iouThreshold Non-maximum suppression IoU threshold + * @throws RnExecutorchError if either threshold is out of range + * + * Used by detection and segmentation models to validate user input. + */ +void validateThresholds(double confidenceThreshold, double iouThreshold); + +} // namespace rnexecutorch::cv_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/CVTypes.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVTypes.h new file mode 100644 index 0000000000..4a146d2180 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/CVTypes.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::cv_processing { + +/** + * @brief Bounding box representation with x1, y1, x2, y2 coordinates + * + * Moved from utils/computer_vision/Types.h for consolidation. 
+ */ +struct BBox { + float x1, y1, x2, y2; + + float width() const { return x2 - x1; } + + float height() const { return y2 - y1; } + + float area() const { return width() * height(); } + + bool isValid() const { + return x2 > x1 && y2 > y1 && x1 >= 0.0f && y1 >= 0.0f; + } + + BBox scale(float widthRatio, float heightRatio) const { + return {x1 * widthRatio, y1 * heightRatio, x2 * widthRatio, + y2 * heightRatio}; + } +}; + +/** + * @brief Concept for types that have a bounding box and confidence score + * + * Used for NMS and other detection/segmentation operations. + */ +template +concept HasBBoxAndScore = requires(T t) { + { t.bbox } -> std::convertible_to; + { t.score } -> std::convertible_to; +}; + +/** + * @brief Scale ratios for mapping between original and model input dimensions + * + * Replaces duplicate scale ratio calculation code across multiple models. + */ +struct ScaleRatios { + float widthRatio; + float heightRatio; + + /** + * @brief Compute scale ratios from original size to model input size + * @param original Original image dimensions + * @param model Model input dimensions + * @return ScaleRatios struct containing width and height ratios + */ + static ScaleRatios compute(cv::Size original, cv::Size model) { + return {static_cast(original.width) / model.width, + static_cast(original.height) / model.height}; + } +}; + +} // namespace rnexecutorch::cv_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 7b389d45b6..f81077a572 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -15,6 +15,7 @@ #include #include +#include #include #include #include @@ -25,7 +26,6 @@ #include #include #include -#include using namespace rnexecutorch::models::speech_to_text; @@ -433,7 +433,7 @@ 
getJsiValue(const std::unordered_map &map, return mapObj; } -inline jsi::Value getJsiValue(const utils::computer_vision::BBox &bbox, +inline jsi::Value getJsiValue(const cv_processing::BBox &bbox, jsi::Runtime &runtime) { jsi::Object obj(runtime); obj.setProperty(runtime, "x1", bbox.x1); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index cc9c862b32..b084de9ab7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -1,6 +1,7 @@ #include "VisionModel.h" #include #include +#include #include #include @@ -18,6 +19,18 @@ void VisionModel::unload() noexcept { } cv::Size VisionModel::modelInputSize() const { + // For multi-method models, query the currently loaded method's input shape + if (!currentlyLoadedMethod_.empty()) { + auto inputShapes = getAllInputShapes(currentlyLoadedMethod_); + if (!inputShapes.empty() && !inputShapes[0].empty() && + inputShapes[0].size() >= 2) { + const auto &shape = inputShapes[0]; + return {static_cast(shape[shape.size() - 2]), + static_cast(shape[shape.size() - 1])}; + } + } + + // Default: use cached modelInputShape_ from single-method models if (modelInputShape_.size() < 2) { return {0, 0}; } @@ -51,4 +64,42 @@ cv::Mat VisionModel::extractFromPixels(const JSTensorViewIn &tensorView) const { return ::rnexecutorch::utils::pixelsToMat(tensorView); } +void VisionModel::ensureMethodLoaded(const std::string &methodName) { + if (methodName.empty()) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidConfig, + "Method name cannot be empty. 
Use 'forward' for single-method models " + "or 'forward_{inputSize}' for multi-method models."); + } + + if (currentlyLoadedMethod_ == methodName) { + return; + } + + if (!module_) { + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Model not loaded. Cannot load method '" + + methodName + "'."); + } + + if (!currentlyLoadedMethod_.empty()) { + module_->unload_method(currentlyLoadedMethod_); + } + + auto loadResult = module_->load_method(methodName); + if (loadResult != executorch::runtime::Error::Ok) { + throw RnExecutorchError( + loadResult, "Failed to load method '" + methodName + + "'. Ensure the method exists in the exported model."); + } + + currentlyLoadedMethod_ = methodName; +} + +void VisionModel::initializeNormalization(const std::vector &normMean, + const std::vector &normStd) { + normMean_ = cv_processing::validateNormParam(normMean, "normMean"); + normStd_ = cv_processing::validateNormParam(normStd, "normStd"); +} + } // namespace rnexecutorch::models diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index cf003948af..cdfe2c1ab6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -78,6 +79,42 @@ class VisionModel : public BaseModel { */ mutable std::mutex inference_mutex_; + /// Name of the currently loaded method (for multi-method models). + /// Empty for single-method models using default "forward". + std::string currentlyLoadedMethod_; + + /// Optional per-channel mean for input normalisation. + std::optional normMean_; + + /// Optional per-channel standard deviation for input normalisation. + std::optional normStd_; + + /** + * @brief Ensures the specified method is loaded, unloading any previous + * method if necessary. 
+ * + * For single-method models, pass "forward" (the default). + * For multi-method models, pass the specific method name (e.g., + * "forward_384"). + * + * @param methodName Name of the method to load. Defaults to "forward". + * @throws RnExecutorchError if the method cannot be loaded. + */ + void ensureMethodLoaded(const std::string &methodName = "forward"); + + /** + * @brief Initializes normalization parameters from vectors. + * + * Uses cv_processing::validateNormParam() for validation. + * + * @param normMean Per-channel mean values (must be exactly 3 elements, or + * empty to skip). + * @param normStd Per-channel std dev values (must be exactly 3 elements, or + * empty to skip). + */ + void initializeNormalization(const std::vector &normMean, + const std::vector &normStd); + /** * @brief Resize an RGB image to the model's expected input size * diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp index 3d2f9d1715..85ffc65152 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp @@ -4,11 +4,10 @@ #include #include #include -#include +#include #include #include #include -#include namespace rnexecutorch::models::instance_segmentation { @@ -17,31 +16,7 @@ BaseInstanceSegmentation::BaseInstanceSegmentation( std::vector normStd, bool applyNMS, std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker), applyNMS_(applyNMS) { - - if (normMean.size() == 3) { - normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); - } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); - } - if (normStd.size() == 3) { - normStd_ = 
cv::Scalar(normStd[0], normStd[1], normStd[2]); - } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); - } -} - -cv::Size BaseInstanceSegmentation::modelInputSize() const { - if (currentlyLoadedMethod_.empty()) { - return VisionModel::modelInputSize(); - } - auto inputShapes = getAllInputShapes(currentlyLoadedMethod_); - if (inputShapes.empty() || inputShapes[0].size() < 2) { - return VisionModel::modelInputSize(); - } - const auto &shape = inputShapes[0]; - return {shape[shape.size() - 2], shape[shape.size() - 1]}; + initializeNormalization(normMean, normStd); } TensorPtr BaseInstanceSegmentation::buildInputTensor(const cv::Mat &image) { @@ -75,7 +50,7 @@ std::vector BaseInstanceSegmentation::runInference( cv::Size modelInputSize(shape[shape.size() - 2], shape[shape.size() - 1]); cv::Size originalSize(image.cols, image.rows); - validateThresholds(confidenceThreshold, iouThreshold); + cv_processing::validateThresholds(confidenceThreshold, iouThreshold); auto forwardResult = BaseModel::execute(methodName, {buildInputTensor(image)}); @@ -144,13 +119,12 @@ std::vector BaseInstanceSegmentation::generateFromPixels( classIndices, returnMaskAtOriginalResolution, methodName); } -std::tuple +std::tuple BaseInstanceSegmentation::extractDetectionData(const float *bboxData, const float *scoresData, int32_t index) { - utils::computer_vision::BBox bbox{ - bboxData[index * 4], bboxData[index * 4 + 1], bboxData[index * 4 + 2], - bboxData[index * 4 + 3]}; + cv_processing::BBox bbox{bboxData[index * 4], bboxData[index * 4 + 1], + bboxData[index * 4 + 2], bboxData[index * 4 + 3]}; float score = scoresData[index * 2]; int32_t label = static_cast(scoresData[index * 2 + 1]); @@ -158,7 +132,7 @@ BaseInstanceSegmentation::extractDetectionData(const float *bboxData, } cv::Rect BaseInstanceSegmentation::computeMaskCropRect( - const utils::computer_vision::BBox &bboxModel, cv::Size modelInputSize, + const cv_processing::BBox 
&bboxModel, cv::Size modelInputSize, cv::Size maskSize) { float mx1F = bboxModel.x1 * maskSize.width / modelInputSize.width; @@ -187,7 +161,7 @@ cv::Rect BaseInstanceSegmentation::addPaddingToRect(const cv::Rect &rect, cv::Mat BaseInstanceSegmentation::warpToOriginalResolution( const cv::Mat &probMat, const cv::Rect &maskRect, cv::Size originalSize, - cv::Size maskSize, const utils::computer_vision::BBox &bboxOriginal) { + cv::Size maskSize, const cv_processing::BBox &bboxOriginal) { float scaleX = static_cast(originalSize.width) / maskSize.width; float scaleY = static_cast(originalSize.height) / maskSize.height; @@ -211,8 +185,8 @@ cv::Mat BaseInstanceSegmentation::thresholdToBinary(const cv::Mat &probMat) { } cv::Mat BaseInstanceSegmentation::processMaskFromLogits( - const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel, - const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize, + const cv::Mat &logitsMat, const cv_processing::BBox &bboxModel, + const cv_processing::BBox &bboxOriginal, cv::Size modelInputSize, cv::Size originalSize, bool warpToOriginal) { cv::Size maskSize = logitsMat.size(); @@ -232,22 +206,6 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits( return thresholdToBinary(probMat); } -void BaseInstanceSegmentation::validateThresholds(double confidenceThreshold, - double iouThreshold) const { - if (confidenceThreshold < 0 || confidenceThreshold > 1) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Confidence threshold must be greater or equal to 0 " - "and less than or equal to 1."); - } - - if (iouThreshold < 0 || iouThreshold > 1) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, - "IoU threshold must be greater or equal to 0 " - "and less than or equal to 1."); - } -} - void BaseInstanceSegmentation::validateOutputTensors( const std::vector &tensors) const { if (tensors.size() != 3) { @@ -258,55 +216,12 @@ void BaseInstanceSegmentation::validateOutputTensors( } } 
-std::set BaseInstanceSegmentation::prepareAllowedClasses( - const std::vector &classIndices) const { - std::set allowedClasses; - if (!classIndices.empty()) { - allowedClasses.insert(classIndices.begin(), classIndices.end()); - } - return allowedClasses; -} - -void BaseInstanceSegmentation::ensureMethodLoaded( - const std::string &methodName) { - if (methodName.empty()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Method name cannot be empty. Use 'forward' for single-method models " - "or 'forward_{inputSize}' for multi-method models."); - } - - if (currentlyLoadedMethod_ == methodName) { - return; - } - - if (!module_) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Model not loaded. Cannot load method '" + - methodName + "'."); - } - - if (!currentlyLoadedMethod_.empty()) { - module_->unload_method(currentlyLoadedMethod_); - } - - auto loadResult = module_->load_method(methodName); - if (loadResult != executorch::runtime::Error::Ok) { - throw RnExecutorchError( - loadResult, "Failed to load method '" + methodName + - "'. 
Ensure the method exists in the exported model."); - } - - currentlyLoadedMethod_ = methodName; -} - std::vector BaseInstanceSegmentation::finalizeInstances( std::vector instances, double iouThreshold, int32_t maxInstances) const { if (applyNMS_) { - instances = - utils::computer_vision::nonMaxSuppression(instances, iouThreshold); + instances = cv_processing::nonMaxSuppression(instances, iouThreshold); } if (std::cmp_greater(instances.size(), maxInstances)) { @@ -326,7 +241,7 @@ std::vector BaseInstanceSegmentation::collectInstances( static_cast(originalSize.width) / modelInputSize.width; float heightRatio = static_cast(originalSize.height) / modelInputSize.height; - auto allowedClasses = prepareAllowedClasses(classIndices); + auto allowedClasses = cv_processing::prepareAllowedClasses(classIndices); // CONTRACT auto bboxTensor = tensors[0].toTensor(); // [1, N, 4] @@ -357,8 +272,7 @@ std::vector BaseInstanceSegmentation::collectInstances( continue; } - utils::computer_vision::BBox bboxOriginal = - bboxModel.scale(widthRatio, heightRatio); + cv_processing::BBox bboxOriginal = bboxModel.scale(widthRatio, heightRatio); if (!bboxOriginal.isValid()) { continue; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h index 341d0f2235..d59400e5fa 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h @@ -3,13 +3,11 @@ #include #include #include -#include -#include #include "Types.h" #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" +#include #include -#include namespace rnexecutorch { namespace models::instance_segmentation { @@ -44,9 +42,6 @@ class BaseInstanceSegmentation : public VisionModel { bool 
returnMaskAtOriginalResolution, std::string methodName); -protected: - cv::Size modelInputSize() const override; - private: std::vector runInference( const cv::Mat &image, double confidenceThreshold, double iouThreshold, @@ -61,29 +56,21 @@ class BaseInstanceSegmentation : public VisionModel { const std::vector &classIndices, bool returnMaskAtOriginalResolution); - void validateThresholds(double confidenceThreshold, - double iouThreshold) const; void validateOutputTensors(const std::vector &tensors) const; - std::set - prepareAllowedClasses(const std::vector &classIndices) const; - - // Model loading and input helpers - void ensureMethodLoaded(const std::string &methodName); - - std::tuple + std::tuple extractDetectionData(const float *bboxData, const float *scoresData, int32_t index); - cv::Rect computeMaskCropRect(const utils::computer_vision::BBox &bboxModel, + cv::Rect computeMaskCropRect(const cv_processing::BBox &bboxModel, cv::Size modelInputSize, cv::Size maskSize); cv::Rect addPaddingToRect(const cv::Rect &rect, cv::Size maskSize); - cv::Mat - warpToOriginalResolution(const cv::Mat &probMat, const cv::Rect &maskRect, - cv::Size originalSize, cv::Size maskSize, - const utils::computer_vision::BBox &bboxOriginal); + cv::Mat warpToOriginalResolution(const cv::Mat &probMat, + const cv::Rect &maskRect, + cv::Size originalSize, cv::Size maskSize, + const cv_processing::BBox &bboxOriginal); cv::Mat thresholdToBinary(const cv::Mat &probMat); @@ -91,15 +78,13 @@ class BaseInstanceSegmentation : public VisionModel { finalizeInstances(std::vector instances, double iouThreshold, int32_t maxInstances) const; - cv::Mat processMaskFromLogits( - const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel, - const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize, - cv::Size originalSize, bool warpToOriginal); + cv::Mat processMaskFromLogits(const cv::Mat &logitsMat, + const cv_processing::BBox &bboxModel, + const cv_processing::BBox 
&bboxOriginal, + cv::Size modelInputSize, cv::Size originalSize, + bool warpToOriginal); - std::optional normMean_; - std::optional normStd_; bool applyNMS_; - std::string currentlyLoadedMethod_; }; } // namespace models::instance_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h index 9006688ce1..7fabaeea69 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h @@ -2,8 +2,8 @@ #include #include +#include #include -#include namespace rnexecutorch::models::instance_segmentation::types { @@ -16,13 +16,13 @@ namespace rnexecutorch::models::instance_segmentation::types { struct Instance { Instance() = default; - Instance(utils::computer_vision::BBox bbox, - std::shared_ptr mask, int32_t maskWidth, - int32_t maskHeight, int32_t classIndex, float score) + Instance(cv_processing::BBox bbox, std::shared_ptr mask, + int32_t maskWidth, int32_t maskHeight, int32_t classIndex, + float score) : bbox(bbox), mask(std::move(mask)), maskWidth(maskWidth), maskHeight(maskHeight), classIndex(classIndex), score(score) {} - utils::computer_vision::BBox bbox; + cv_processing::BBox bbox; std::shared_ptr mask; int32_t maskWidth; int32_t maskHeight; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index be1eb539a2..b76d78fa6e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -1,16 +1,13 @@ #include "ObjectDetection.h" #include "Constants.h" -#include - #include #include 
-#include +#include #include #include #include #include -#include namespace rnexecutorch::models::object_detection { @@ -20,64 +17,7 @@ ObjectDetection::ObjectDetection( std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker), labelNames_(std::move(labelNames)) { - if (normMean.size() == 3) { - normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); - } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); - } - if (normStd.size() == 3) { - normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); - } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); - } -} - -cv::Size ObjectDetection::modelInputSize() const { - if (currentlyLoadedMethod_.empty()) { - return VisionModel::modelInputSize(); - } - auto inputShapes = getAllInputShapes(currentlyLoadedMethod_); - if (inputShapes.empty() || inputShapes[0].size() < 2) { - return VisionModel::modelInputSize(); - } - const auto &shape = inputShapes[0]; - return {static_cast(shape[shape.size() - 2]), - static_cast(shape[shape.size() - 1])}; -} - -void ObjectDetection::ensureMethodLoaded(const std::string &methodName) { - if (methodName.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "methodName cannot be empty"); - } - if (currentlyLoadedMethod_ == methodName) { - return; - } - if (!module_) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Model module is not loaded"); - } - if (!currentlyLoadedMethod_.empty()) { - module_->unload_method(currentlyLoadedMethod_); - } - auto loadResult = module_->load_method(methodName); - if (loadResult != executorch::runtime::Error::Ok) { - throw RnExecutorchError( - loadResult, "Failed to load method '" + methodName + - "'. 
Ensure the method exists in the exported model."); - } - currentlyLoadedMethod_ = methodName; -} - -std::set ObjectDetection::prepareAllowedClasses( - const std::vector &classIndices) const { - std::set allowedClasses; - if (!classIndices.empty()) { - allowedClasses.insert(classIndices.begin(), classIndices.end()); - } - return allowedClasses; + initializeNormalization(normMean, normStd); } std::vector @@ -91,7 +31,7 @@ ObjectDetection::postprocess(const std::vector &tensors, static_cast(originalSize.height) / inputSize.height; // Prepare allowed classes set for filtering - auto allowedClasses = prepareAllowedClasses(classIndices); + auto allowedClasses = cv_processing::prepareAllowedClasses(classIndices); std::vector detections; auto bboxTensor = tensors.at(0).toTensor(); @@ -134,24 +74,17 @@ ObjectDetection::postprocess(const std::vector &tensors, " exceeds labelNames size " + std::to_string(labelNames_.size()) + ". Ensure the labelMap covers all model output classes."); } - detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2}, + detections.emplace_back(cv_processing::BBox{x1, y1, x2, y2}, labelNames_[labelIdx], labelIdx, scores[i]); } - return utils::computer_vision::nonMaxSuppression(detections, iouThreshold); + return cv_processing::nonMaxSuppression(detections, iouThreshold); } std::vector ObjectDetection::runInference( cv::Mat image, double detectionThreshold, double iouThreshold, const std::vector &classIndices, const std::string &methodName) { - if (detectionThreshold < 0.0 || detectionThreshold > 1.0) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "detectionThreshold must be in range [0, 1]"); - } - if (iouThreshold < 0.0 || iouThreshold > 1.0) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "iouThreshold must be in range [0, 1]"); - } + cv_processing::validateThresholds(detectionThreshold, iouThreshold); std::scoped_lock lock(inference_mutex_); diff --git 
a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index 6e3c01356e..f52f29e223 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -86,16 +86,6 @@ class ObjectDetection : public VisionModel { std::string methodName); protected: - /** - * @brief Returns the model input size based on the currently loaded method. - * - * Overrides VisionModel::modelInputSize() to support multi-method models - * where each method may have different input dimensions. - * - * @return The expected input size for the currently loaded method. - */ - cv::Size modelInputSize() const override; - std::vector runInference(cv::Mat image, double detectionThreshold, double iouThreshold, const std::vector &classIndices, @@ -125,36 +115,8 @@ class ObjectDetection : public VisionModel { double detectionThreshold, double iouThreshold, const std::vector &classIndices); - /** - * @brief Ensures the specified method is loaded, unloading any previous - * method if necessary. - * - * @param methodName Name of the method to load (e.g., "forward", - * "forward_384"). - * @throws RnExecutorchError if the method cannot be loaded. - */ - void ensureMethodLoaded(const std::string &methodName); - - /** - * @brief Prepares a set of allowed class indices for filtering detections. - * - * @param classIndices Vector of class indices to allow. - * @return A set containing the allowed class indices. - */ - std::set - prepareAllowedClasses(const std::vector &classIndices) const; - - /// Optional per-channel mean for input normalisation (set in constructor). - std::optional normMean_; - - /// Optional per-channel standard deviation for input normalisation. 
- std::optional normStd_; - /// Ordered label strings mapping class indices to human-readable names. std::vector labelNames_; - - /// Name of the currently loaded method (for multi-method models). - std::string currentlyLoadedMethod_; }; } // namespace models::object_detection diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h index 1652516e89..2f63aa29a8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h @@ -1,19 +1,19 @@ #pragma once #include -#include +#include #include namespace rnexecutorch::models::object_detection::types { struct Detection { Detection() = default; - Detection(utils::computer_vision::BBox bbox, std::string label, - int32_t classIndex, float score) + Detection(cv_processing::BBox bbox, std::string label, int32_t classIndex, + float score) : bbox(bbox), label(std::move(label)), classIndex(classIndex), score(score) {} - utils::computer_vision::BBox bbox; + cv_processing::BBox bbox; std::string label; int32_t classIndex; float score; diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index f6fe386a7e..75f579a713 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -103,6 +103,10 @@ set(IMAGE_UTILS_SOURCES ${COMMON_DIR}/ada/ada.cpp ) +set(CV_PROCESSING_SOURCES + ${RNEXECUTORCH_DIR}/data_processing/CVProcessing.cpp +) + set(TOKENIZER_SOURCES ${RNEXECUTORCH_DIR}/TokenizerModule.cpp) set(DSP_SOURCES ${RNEXECUTORCH_DIR}/data_processing/dsp.cpp) @@ -157,6 +161,12 @@ add_rn_test(ImageProcessingTest unit/ImageProcessingTest.cpp LIBS opencv_deps ) 
+add_rn_test(CVProcessingTest unit/CVProcessingTest.cpp + SOURCES + ${CV_PROCESSING_SOURCES} + LIBS opencv_deps +) + add_rn_test(FrameProcessorTests unit/FrameProcessorTest.cpp SOURCES ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp @@ -179,6 +189,7 @@ add_rn_test(VisionModelTests integration/VisionModelTest.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) @@ -190,6 +201,7 @@ add_rn_test(ClassificationTests integration/ClassificationTest.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) @@ -202,6 +214,7 @@ add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp ${RNEXECUTORCH_DIR}/utils/computer_vision/Processing.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) @@ -214,6 +227,7 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) @@ -233,6 +247,7 @@ add_rn_test(StyleTransferTests integration/StyleTransferTest.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp ${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) @@ -306,7 +321,7 @@ add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp ${RNEXECUTORCH_DIR}/models/VisionModel.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp - 
${RNEXECUTORCH_DIR}/utils/computer_vision/Processing.cpp + ${CV_PROCESSING_SOURCES} ${IMAGE_UTILS_SOURCES} LIBS opencv_deps android ) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index de36b3c545..553f4e61e6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -50,7 +50,8 @@ template <> struct ModelTraits { } static void callGenerate(ModelType &model) { - (void)model.generateFromString(kValidTestImagePath, 0.5); + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, {}, + "forward"); } }; } // namespace model_tests @@ -67,57 +68,65 @@ INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest, TEST(ObjectDetectionGenerateTests, InvalidImagePathThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, + {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, EmptyImagePathThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString("", 0.5), RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, {}, "forward"), + RnExecutorchError); } TEST(ObjectDetectionGenerateTests, MalformedURIThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, 0.5, + {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, NegativeThresholdThrows) { 
ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, + {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, ThresholdAboveOneThrows) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, {}, + "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, ValidImageReturnsResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, {}, "forward"); EXPECT_GE(results.size(), 0u); } TEST(ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto lowThresholdResults = model.generateFromString(kValidTestImagePath, 0.1); + auto lowThresholdResults = + model.generateFromString(kValidTestImagePath, 0.1, 0.5, {}, "forward"); auto highThresholdResults = - model.generateFromString(kValidTestImagePath, 0.9); + model.generateFromString(kValidTestImagePath, 0.9, 0.5, {}, "forward"); EXPECT_GE(lowThresholdResults.size(), highThresholdResults.size()); } TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, {}, "forward"); for (const auto &detection : results) { EXPECT_LE(detection.bbox.x1, detection.bbox.x2); @@ -130,7 +139,8 @@ TEST(ObjectDetectionGenerateTests, 
DetectionsHaveValidBoundingBoxes) { TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, {}, "forward"); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -141,7 +151,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) { ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, {}, "forward"); for (const auto &detection : results) { const auto &label = detection.label; @@ -162,7 +173,7 @@ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - auto results = model.generateFromPixels(tensorView, 0.3); + auto results = model.generateFromPixels(tensorView, 0.3, 0.5, {}, "forward"); EXPECT_GE(results.size(), 0u); } @@ -174,8 +185,9 @@ TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1), - RnExecutorchError); + EXPECT_THROW( + (void)model.generateFromPixels(tensorView, -0.1, 0.5, {}, "forward"), + RnExecutorchError); } TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) { @@ -186,8 +198,9 @@ TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) { JSTensorViewIn tensorView{pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1), - RnExecutorchError); + EXPECT_THROW( + 
(void)model.generateFromPixels(tensorView, 1.1, 0.5, {}, "forward"), + RnExecutorchError); } TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) { @@ -239,5 +252,6 @@ TEST(ObjectDetectionNormTests, ValidNormParamsGenerateSucceeds) { const std::vector std = {0.229f, 0.224f, 0.225f}; ObjectDetection model(kValidObjectDetectionModelPath, mean, std, kCocoLabels, nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5)); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, + {}, "forward")); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/CVProcessingTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/CVProcessingTest.cpp new file mode 100644 index 0000000000..246cdafc2f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/CVProcessingTest.cpp @@ -0,0 +1,244 @@ +#include +#include +#include +#include + +using namespace rnexecutorch::cv_processing; + +class CVProcessingTest : public ::testing::Test {}; + +// ============================================================================ +// prepareAllowedClasses Tests +// ============================================================================ + +TEST_F(CVProcessingTest, PrepareAllowedClasses_EmptyVector_ReturnsEmptySet) { + std::vector input = {}; + auto result = prepareAllowedClasses(input); + EXPECT_TRUE(result.empty()); +} + +TEST_F(CVProcessingTest, PrepareAllowedClasses_SingleClass_ReturnsSetWithOne) { + std::vector input = {5}; + auto result = prepareAllowedClasses(input); + EXPECT_EQ(result.size(), 1); + EXPECT_TRUE(result.count(5) > 0); +} + +TEST_F(CVProcessingTest, + PrepareAllowedClasses_MultipleClasses_ReturnsCorrectSet) { + std::vector input = {1, 3, 5, 7}; + auto result = prepareAllowedClasses(input); + EXPECT_EQ(result.size(), 4); + EXPECT_TRUE(result.count(1) > 0); + EXPECT_TRUE(result.count(3) > 0); + EXPECT_TRUE(result.count(5) > 0); + 
EXPECT_TRUE(result.count(7) > 0); +} + +TEST_F(CVProcessingTest, + PrepareAllowedClasses_DuplicateClasses_RemovesDuplicates) { + std::vector input = {1, 3, 3, 5, 1}; + auto result = prepareAllowedClasses(input); + EXPECT_EQ(result.size(), 3); // Should have 1, 3, 5 + EXPECT_TRUE(result.count(1) > 0); + EXPECT_TRUE(result.count(3) > 0); + EXPECT_TRUE(result.count(5) > 0); +} + +// ============================================================================ +// validateThresholds Tests +// ============================================================================ + +TEST_F(CVProcessingTest, ValidateThresholds_ValidValues_DoesNotThrow) { + EXPECT_NO_THROW(validateThresholds(0.5, 0.5)); + EXPECT_NO_THROW(validateThresholds(0.0, 0.0)); + EXPECT_NO_THROW(validateThresholds(1.0, 1.0)); +} + +TEST_F(CVProcessingTest, ValidateThresholds_NegativeConfidence_Throws) { + EXPECT_THROW(validateThresholds(-0.1, 0.5), rnexecutorch::RnExecutorchError); +} + +TEST_F(CVProcessingTest, ValidateThresholds_ConfidenceAboveOne_Throws) { + EXPECT_THROW(validateThresholds(1.1, 0.5), rnexecutorch::RnExecutorchError); +} + +TEST_F(CVProcessingTest, ValidateThresholds_NegativeIoU_Throws) { + EXPECT_THROW(validateThresholds(0.5, -0.1), rnexecutorch::RnExecutorchError); +} + +TEST_F(CVProcessingTest, ValidateThresholds_IoUAboveOne_Throws) { + EXPECT_THROW(validateThresholds(0.5, 1.1), rnexecutorch::RnExecutorchError); +} + +// ============================================================================ +// computeIoU Tests +// ============================================================================ + +TEST_F(CVProcessingTest, ComputeIoU_IdenticalBoxes_ReturnsOne) { + BBox box{0.0f, 0.0f, 10.0f, 10.0f}; + float iou = computeIoU(box, box); + EXPECT_FLOAT_EQ(iou, 1.0f); +} + +TEST_F(CVProcessingTest, ComputeIoU_NoOverlap_ReturnsZero) { + BBox box1{0.0f, 0.0f, 10.0f, 10.0f}; + BBox box2{20.0f, 20.0f, 30.0f, 30.0f}; + float iou = computeIoU(box1, box2); + EXPECT_FLOAT_EQ(iou, 0.0f); +} + 
+TEST_F(CVProcessingTest, ComputeIoU_PartialOverlap_ReturnsCorrectValue) { + BBox box1{0.0f, 0.0f, 10.0f, 10.0f}; // Area = 100 + BBox box2{5.0f, 5.0f, 15.0f, 15.0f}; // Area = 100 + // Intersection: (5,5) to (10,10) = 25 + // Union: 100 + 100 - 25 = 175 + // IoU = 25/175 ≈ 0.142857 + float iou = computeIoU(box1, box2); + EXPECT_NEAR(iou, 0.142857f, 0.0001f); +} + +TEST_F(CVProcessingTest, ComputeIoU_OneBoxInsideAnother_ReturnsCorrectValue) { + BBox box1{0.0f, 0.0f, 10.0f, 10.0f}; // Area = 100 + BBox box2{2.0f, 2.0f, 8.0f, 8.0f}; // Area = 36 + // Intersection: 36 (box2 is fully inside) + // Union: 100 + 36 - 36 = 100 + // IoU = 36/100 = 0.36 + float iou = computeIoU(box1, box2); + EXPECT_FLOAT_EQ(iou, 0.36f); +} + +// ============================================================================ +// BBox Tests +// ============================================================================ + +TEST_F(CVProcessingTest, BBox_Width_ReturnsCorrectValue) { + BBox box{0.0f, 0.0f, 10.0f, 5.0f}; + EXPECT_FLOAT_EQ(box.width(), 10.0f); +} + +TEST_F(CVProcessingTest, BBox_Height_ReturnsCorrectValue) { + BBox box{0.0f, 0.0f, 10.0f, 5.0f}; + EXPECT_FLOAT_EQ(box.height(), 5.0f); +} + +TEST_F(CVProcessingTest, BBox_Area_ReturnsCorrectValue) { + BBox box{0.0f, 0.0f, 10.0f, 5.0f}; + EXPECT_FLOAT_EQ(box.area(), 50.0f); +} + +TEST_F(CVProcessingTest, BBox_IsValid_ValidBox_ReturnsTrue) { + BBox box{0.0f, 0.0f, 10.0f, 5.0f}; + EXPECT_TRUE(box.isValid()); +} + +TEST_F(CVProcessingTest, BBox_IsValid_InvalidBox_ReturnsFalse) { + BBox box1{10.0f, 0.0f, 5.0f, 5.0f}; // x2 < x1 + EXPECT_FALSE(box1.isValid()); + + BBox box2{0.0f, 10.0f, 5.0f, 5.0f}; // y2 < y1 + EXPECT_FALSE(box2.isValid()); + + BBox box3{-1.0f, 0.0f, 5.0f, 5.0f}; // negative x1 + EXPECT_FALSE(box3.isValid()); +} + +TEST_F(CVProcessingTest, BBox_Scale_ReturnsCorrectlyScaledBox) { + BBox box{1.0f, 2.0f, 3.0f, 4.0f}; + BBox scaled = box.scale(2.0f, 3.0f); + EXPECT_FLOAT_EQ(scaled.x1, 2.0f); + EXPECT_FLOAT_EQ(scaled.y1, 6.0f); + 
EXPECT_FLOAT_EQ(scaled.x2, 6.0f); + EXPECT_FLOAT_EQ(scaled.y2, 12.0f); +} + +// ============================================================================ +// ScaleRatios Tests +// ============================================================================ + +TEST_F(CVProcessingTest, ScaleRatios_Compute_ReturnsCorrectRatios) { + cv::Size original(640, 480); + cv::Size model(320, 240); + auto ratios = ScaleRatios::compute(original, model); + EXPECT_FLOAT_EQ(ratios.widthRatio, 2.0f); + EXPECT_FLOAT_EQ(ratios.heightRatio, 2.0f); +} + +// ============================================================================ +// validateNormParam Tests +// ============================================================================ + +TEST_F(CVProcessingTest, ValidateNormParam_ValidThreeElements_ReturnsScalar) { + std::vector values = {0.5f, 0.6f, 0.7f}; + auto result = validateNormParam(values, "test"); + ASSERT_TRUE(result.has_value()); + EXPECT_FLOAT_EQ((*result)[0], 0.5f); + EXPECT_FLOAT_EQ((*result)[1], 0.6f); + EXPECT_FLOAT_EQ((*result)[2], 0.7f); +} + +TEST_F(CVProcessingTest, ValidateNormParam_EmptyVector_ReturnsNullopt) { + std::vector values = {}; + auto result = validateNormParam(values, "test"); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(CVProcessingTest, ValidateNormParam_WrongSize_ReturnsNullopt) { + std::vector values = {0.5f, 0.6f}; // Only 2 elements + auto result = validateNormParam(values, "test"); + EXPECT_FALSE(result.has_value()); +} + +// ============================================================================ +// nonMaxSuppression Tests +// ============================================================================ + +struct TestDetection { + BBox bbox; + float score; + int32_t classIndex; +}; + +TEST_F(CVProcessingTest, NonMaxSuppression_EmptyVector_ReturnsEmpty) { + std::vector detections = {}; + auto result = nonMaxSuppression(detections, 0.5); + EXPECT_TRUE(result.empty()); +} + +TEST_F(CVProcessingTest, + 
NonMaxSuppression_SingleDetection_ReturnsSingleDetection) { + std::vector detections = { + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.9f, 1}}; + auto result = nonMaxSuppression(detections, 0.5); + EXPECT_EQ(result.size(), 1); + EXPECT_FLOAT_EQ(result[0].score, 0.9f); +} + +TEST_F(CVProcessingTest, + NonMaxSuppression_OverlappingBoxes_SuppressesLowerScore) { + std::vector detections = { + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.9f, 1}, // High score + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.5f, 1}, // Same box, low score + }; + auto result = nonMaxSuppression(detections, 0.5); + EXPECT_EQ(result.size(), 1); + EXPECT_FLOAT_EQ(result[0].score, 0.9f); +} + +TEST_F(CVProcessingTest, NonMaxSuppression_DifferentClasses_KeepsBothBoxes) { + std::vector detections = { + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.9f, 1}, // Class 1 + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.8f, 2}, // Class 2, same location + }; + auto result = nonMaxSuppression(detections, 0.5); + EXPECT_EQ(result.size(), 2); // Both should be kept (different classes) +} + +TEST_F(CVProcessingTest, NonMaxSuppression_NoOverlap_KeepsAllBoxes) { + std::vector detections = { + {{0.0f, 0.0f, 10.0f, 10.0f}, 0.9f, 1}, + {{20.0f, 20.0f, 30.0f, 30.0f}, 0.8f, 1}, + }; + auto result = nonMaxSuppression(detections, 0.5); + EXPECT_EQ(result.size(), 2); // Both should be kept (no overlap) +} diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.cpp deleted file mode 100644 index 108fd6ff8a..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "Processing.h" -#include -#include - -namespace rnexecutorch::utils::computer_vision { - -float computeIoU(const BBox &a, const BBox &b) { - float x1 = std::max(a.x1, b.x1); - float y1 = std::max(a.y1, b.y1); - float x2 = std::min(a.x2, b.x2); - float y2 = std::min(a.y2, b.y2); - - float intersectionArea = 
std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1); - float areaA = a.area(); - float areaB = b.area(); - float unionArea = areaA + areaB - intersectionArea; - - return (unionArea > 0.0f) ? (intersectionArea / unionArea) : 0.0f; -} - -} // namespace rnexecutorch::utils::computer_vision diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h deleted file mode 100644 index 3bd3022d4a..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include "Types.h" -#include -#include - -namespace rnexecutorch::utils::computer_vision { - -float computeIoU(const BBox &a, const BBox &b); - -template -std::vector nonMaxSuppression(std::vector items, double iouThreshold) { - if (items.empty()) { - return {}; - } - - std::ranges::sort(items, - [](const T &a, const T &b) { return a.score > b.score; }); - - std::vector result; - std::vector suppressed(items.size(), false); - - for (size_t i = 0; i < items.size(); ++i) { - if (suppressed[i]) { - continue; - } - - result.push_back(items[i]); - - for (size_t j = i + 1; j < items.size(); ++j) { - if (suppressed[j]) { - continue; - } - - if constexpr (requires(T t) { t.classIndex; }) { - if (items[i].classIndex != items[j].classIndex) { - continue; - } - } - - float iou = computeIoU(items[i].bbox, items[j].bbox); - if (iou > iouThreshold) { - suppressed[j] = true; - } - } - } - - return result; -} - -} // namespace rnexecutorch::utils::computer_vision diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h deleted file mode 100644 index 8899d3b87c..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma 
once - -#include - -namespace rnexecutorch::utils::computer_vision { - -struct BBox { - - float width() const { return x2 - x1; } - - float height() const { return y2 - y1; } - - float area() const { return width() * height(); } - - bool isValid() const { - return x2 > x1 && y2 > y1 && x1 >= 0.0f && y1 >= 0.0f; - } - - BBox scale(float widthRatio, float heightRatio) const { - return {x1 * widthRatio, y1 * heightRatio, x2 * widthRatio, - y2 * heightRatio}; - } - - float x1, y1, x2, y2; -}; - -template -concept HasBBoxAndScore = requires(T t) { - { t.bbox } -> std::convertible_to; - { t.score } -> std::convertible_to; -}; - -} // namespace rnexecutorch::utils::computer_vision