BaseModelNN #

Bases: ABC, Module

Abstract base class for neural network models in Focoos.

This class provides a common interface for all neural network models, defining abstract methods that must be implemented by concrete model classes. It extends both ABC (Abstract Base Class) and nn.Module from PyTorch.

Parameters:

Name Type Description Default
config ModelConfig

Model configuration containing hyperparameters and settings.

required
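
As an orientation for implementers, here is a minimal, hypothetical subclass; MyModel and its single layer are illustrative only (not part of Focoos), and the ModelOutput construction is elided because it depends on the task:

import torch
from torch import nn

class MyModel(BaseModelNN):
    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.backbone = nn.Conv2d(3, 16, kernel_size=3, padding=1)

    @property
    def device(self) -> torch.device:
        # all parameters are assumed to live on a single device
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    def forward(self, inputs) -> ModelOutput:
        # a real implementation would first normalize the accepted input
        # formats (tensors, arrays, PIL images, ...) to a batched tensor
        features = self.backbone(inputs)
        return ModelOutput(...)  # construction depends on the task
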
Source code in focoos/models/base_model.py
class BaseModelNN(ABC, nn.Module):
    """Abstract base class for neural network models in Focoos.

    This class provides a common interface for all neural network models,
    defining abstract methods that must be implemented by concrete model classes.
    It extends both ABC (Abstract Base Class) and nn.Module from PyTorch.

    Args:
        config: Model configuration containing hyperparameters and settings.
    """

    def __init__(self, config: ModelConfig):
        """Initialize the base model.

        Args:
            config: Model configuration object containing model parameters
                and settings.
        """
        super().__init__()

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Get the device where the model is located.

        Returns:
            The PyTorch device (CPU or CUDA) where the model parameters
            are stored.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError("Device is not implemented for this model.")

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """Get the data type of the model parameters.

        Returns:
            The PyTorch data type (e.g., float32, float16) of the model
            parameters.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError("Dtype is not implemented for this model.")

    @abstractmethod
    def forward(
        self,
        inputs: Union[
            torch.Tensor,
            np.ndarray,
            Image.Image,
            list[Image.Image],
            list[np.ndarray],
            list[torch.Tensor],
            list[DatasetEntry],
        ],
    ) -> ModelOutput:
        """Perform forward pass through the model.

        Args:
            inputs: Input data in various supported formats:
                - torch.Tensor: Single tensor input
                - np.ndarray: Single numpy array input
                - Image.Image: Single PIL Image input
                - list[Image.Image]: List of PIL Images
                - list[np.ndarray]: List of numpy arrays
                - list[torch.Tensor]: List of tensors
                - list[DatasetEntry]: List of dataset entries

        Returns:
            Model output containing predictions and any additional metadata.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.
        """
        raise NotImplementedError("Forward is not implemented for this model.")

    def load_state_dict(self, checkpoint_state_dict: dict, strict: bool = True) -> IncompatibleKeys:
        """Load model state dictionary from checkpoint with preprocessing.

        This method handles common issues when loading checkpoints:
        - Removes "module." prefix from DataParallel/DistributedDataParallel models
        - Handles shape mismatches by removing incompatible parameters
        - Logs incompatible keys for debugging

        Args:
            checkpoint_state_dict: Dictionary containing model parameters from
                a saved checkpoint.
            strict: Whether to strictly enforce that the keys in checkpoint_state_dict
                match the keys returned by this module's state_dict() function.
                Defaults to True.

        Returns:
            IncompatibleKeys object containing information about missing keys,
            unexpected keys, and parameters with incorrect shapes.
        """
        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        strip_prefix_if_present(checkpoint_state_dict, "module.")

        # workaround https://github.com/pytorch/pytorch/issues/24139
        model_state_dict = self.state_dict()
        incorrect_shapes = []
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                model_param = model_state_dict[k]
                shape_model = tuple(model_param.shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                    checkpoint_state_dict.pop(k)

        incompatible = super().load_state_dict(checkpoint_state_dict, strict=strict)
        incompatible = IncompatibleKeys(
            missing_keys=incompatible.missing_keys,
            unexpected_keys=incompatible.unexpected_keys,
            incorrect_shapes=incorrect_shapes,
        )

        incompatible.log_incompatible_keys()

        return incompatible

    def benchmark(self, iterations: int = 50, size: Tuple[int, int] = (640, 640)) -> LatencyMetrics:
        """Benchmark model inference latency and throughput.

        Performs multiple inference runs on random data to measure model
        performance metrics including FPS, mean latency, and latency statistics.
        Uses CUDA events for precise timing when running on GPU.

        Args:
            iterations: Number of inference runs to perform for benchmarking.
                Defaults to 50.
            size: Input image size as (height, width) tuple. Defaults to (640, 640).

        Returns:
            LatencyMetrics object containing:
                - fps: Frames per second (throughput)
                - engine: Hardware/framework used for inference
                - mean: Mean inference time in milliseconds
                - max: Maximum inference time in milliseconds
                - min: Minimum inference time in milliseconds
                - std: Standard deviation of inference times
                - im_size: Input image size
                - device: Device used for inference

        Note:
            This method assumes the model is running on CUDA for timing.
            Input data is randomly generated for benchmarking purposes.
        """
        logger.info(f"⏱️ Benchmarking latency on {self.device}, size: {size}x{size}..")
        # warmup
        data = 128 * torch.randn(1, 3, size[0], size[1]).to(self.device)
        durations = []
        for _ in range(iterations):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record(stream=torch.cuda.Stream())
            _ = self(data)
            end.record(stream=torch.cuda.Stream())
            torch.cuda.synchronize()
            durations.append(start.elapsed_time(end))

        durations = np.array(durations)
        metrics = LatencyMetrics(
            fps=int(1000 / durations.mean()),
            engine=f"torch.{self.device}",
            mean=round(durations.mean().astype(float), 3),
            max=round(durations.max().astype(float), 3),
            min=round(durations.min().astype(float), 3),
            std=round(durations.std().astype(float), 3),
            im_size=size[0],  # FIXME: this is a hack to get the im_size as int, assuming it's a square
            device=str(self.device),
        )
        logger.info(f"🔥 FPS: {metrics.fps} Mean latency: {metrics.mean} ms ")
        return metrics

device abstractmethod property #

Get the device where the model is located.

Returns:

Type Description
device

The PyTorch device (CPU or CUDA) where the model parameters are stored.

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.

dtype abstractmethod property #

Get the data type of the model parameters.

Returns:

Type Description
dtype

The PyTorch data type (e.g., float32, float16) of the model parameters.

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.
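
A common way to satisfy both properties is to derive them from the model's own parameters; a minimal sketch, assuming the model holds at least one parameter and keeps everything on a single device:

@property
def device(self) -> torch.device:
    # infer the device from the first registered parameter
    return next(self.parameters()).device

@property
def dtype(self) -> torch.dtype:
    # infer the dtype from the first registered parameter
    return next(self.parameters()).dtype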

__init__(config) #

Initialize the base model.

Parameters:

Name Type Description Default
config ModelConfig

Model configuration object containing model parameters and settings.

required
Source code in focoos/models/base_model.py
def __init__(self, config: ModelConfig):
    """Initialize the base model.

    Args:
        config: Model configuration object containing model parameters
            and settings.
    """
    super().__init__()

benchmark(iterations=50, size=(640, 640)) #

Benchmark model inference latency and throughput.

Performs multiple inference runs on random data to measure model performance metrics including FPS, mean latency, and latency statistics. Uses CUDA events for precise timing when running on GPU.

Parameters:

Name Type Description Default
iterations int

Number of inference runs to perform for benchmarking. Defaults to 50.

50
size Tuple[int, int]

Input image size as (height, width) tuple. Defaults to (640, 640).

(640, 640)

Returns:

Type Description
LatencyMetrics

LatencyMetrics object containing:

- fps: Frames per second (throughput)
- engine: Hardware/framework used for inference
- mean: Mean inference time in milliseconds
- max: Maximum inference time in milliseconds
- min: Minimum inference time in milliseconds
- std: Standard deviation of inference times
- im_size: Input image size
- device: Device used for inference

Note

This method assumes the model is running on CUDA for timing. Input data is randomly generated for benchmarking purposes.
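
A usage sketch (my_model is a hypothetical concrete subclass; the model must be on a CUDA device before benchmarking):

metrics = my_model.cuda().benchmark(iterations=100, size=(512, 512))
print(f"{metrics.fps} FPS, {metrics.mean} ms mean latency")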

Source code in focoos/models/base_model.py
def benchmark(self, iterations: int = 50, size: Tuple[int, int] = (640, 640)) -> LatencyMetrics:
    """Benchmark model inference latency and throughput.

    Performs multiple inference runs on random data to measure model
    performance metrics including FPS, mean latency, and latency statistics.
    Uses CUDA events for precise timing when running on GPU.

    Args:
        iterations: Number of inference runs to perform for benchmarking.
            Defaults to 50.
        size: Input image size as (height, width) tuple. Defaults to (640, 640).

    Returns:
        LatencyMetrics object containing:
            - fps: Frames per second (throughput)
            - engine: Hardware/framework used for inference
            - mean: Mean inference time in milliseconds
            - max: Maximum inference time in milliseconds
            - min: Minimum inference time in milliseconds
            - std: Standard deviation of inference times
            - im_size: Input image size
            - device: Device used for inference

    Note:
        This method assumes the model is running on CUDA for timing.
        Input data is randomly generated for benchmarking purposes.
    """
    logger.info(f"⏱️ Benchmarking latency on {self.device}, size: {size}x{size}..")
    # warmup
    data = 128 * torch.randn(1, 3, size[0], size[1]).to(self.device)
    durations = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record(stream=torch.cuda.Stream())
        _ = self(data)
        end.record(stream=torch.cuda.Stream())
        torch.cuda.synchronize()
        durations.append(start.elapsed_time(end))

    durations = np.array(durations)
    metrics = LatencyMetrics(
        fps=int(1000 / durations.mean()),
        engine=f"torch.{self.device}",
        mean=round(durations.mean().astype(float), 3),
        max=round(durations.max().astype(float), 3),
        min=round(durations.min().astype(float), 3),
        std=round(durations.std().astype(float), 3),
        im_size=size[0],  # FIXME: this is a hack to get the im_size as int, assuming it's a square
        device=str(self.device),
    )
    logger.info(f"🔥 FPS: {metrics.fps} Mean latency: {metrics.mean} ms ")
    return metrics

forward(inputs) abstractmethod #

Perform forward pass through the model.

Parameters:

Name Type Description Default
inputs Union[Tensor, ndarray, Image, list[Image], list[ndarray], list[Tensor], list[DatasetEntry]]

Input data in various supported formats:

- torch.Tensor: Single tensor input
- np.ndarray: Single numpy array input
- Image.Image: Single PIL Image input
- list[Image.Image]: List of PIL Images
- list[np.ndarray]: List of numpy arrays
- list[torch.Tensor]: List of tensors
- list[DatasetEntry]: List of dataset entries

required

Returns:

Type Description
ModelOutput

Model output containing predictions and any additional metadata.

Raises:

Type Description
NotImplementedError

This method must be implemented by subclasses.
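
Because the signature accepts several formats, a concrete subclass can be called directly on tensors, arrays, or PIL images; a usage sketch with a hypothetical my_model and placeholder file names:

import torch
from PIL import Image

output = my_model(torch.randn(1, 3, 640, 640))                 # single tensor
output = my_model(Image.open("a.jpg"))                         # single PIL image
output = my_model([Image.open("a.jpg"), Image.open("b.jpg")])  # list of images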

Source code in focoos/models/base_model.py
@abstractmethod
def forward(
    self,
    inputs: Union[
        torch.Tensor,
        np.ndarray,
        Image.Image,
        list[Image.Image],
        list[np.ndarray],
        list[torch.Tensor],
        list[DatasetEntry],
    ],
) -> ModelOutput:
    """Perform forward pass through the model.

    Args:
        inputs: Input data in various supported formats:
            - torch.Tensor: Single tensor input
            - np.ndarray: Single numpy array input
            - Image.Image: Single PIL Image input
            - list[Image.Image]: List of PIL Images
            - list[np.ndarray]: List of numpy arrays
            - list[torch.Tensor]: List of tensors
            - list[DatasetEntry]: List of dataset entries

    Returns:
        Model output containing predictions and any additional metadata.

    Raises:
        NotImplementedError: This method must be implemented by subclasses.
    """
    raise NotImplementedError("Forward is not implemented for this model.")

load_state_dict(checkpoint_state_dict, strict=True) #

Load model state dictionary from checkpoint with preprocessing.

This method handles common issues when loading checkpoints:

- Removes "module." prefix from DataParallel/DistributedDataParallel models
- Handles shape mismatches by removing incompatible parameters
- Logs incompatible keys for debugging

Parameters:

Name Type Description Default
checkpoint_state_dict dict

Dictionary containing model parameters from a saved checkpoint.

required
strict bool

Whether to strictly enforce that the keys in checkpoint_state_dict match the keys returned by this module's state_dict() function. Defaults to True.

True

Returns:

Type Description
IncompatibleKeys

IncompatibleKeys object containing information about missing keys, unexpected keys, and parameters with incorrect shapes.
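
A usage sketch (checkpoint.pth is a placeholder path and my_model a hypothetical concrete subclass; strict=False tolerates missing or extra keys):

import torch

state_dict = torch.load("checkpoint.pth", map_location="cpu")
incompatible = my_model.load_state_dict(state_dict, strict=False)
print(incompatible.missing_keys, incompatible.unexpected_keys, incompatible.incorrect_shapes)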

Source code in focoos/models/base_model.py
def load_state_dict(self, checkpoint_state_dict: dict, strict: bool = True) -> IncompatibleKeys:
    """Load model state dictionary from checkpoint with preprocessing.

    This method handles common issues when loading checkpoints:
    - Removes "module." prefix from DataParallel/DistributedDataParallel models
    - Handles shape mismatches by removing incompatible parameters
    - Logs incompatible keys for debugging

    Args:
        checkpoint_state_dict: Dictionary containing model parameters from
            a saved checkpoint.
        strict: Whether to strictly enforce that the keys in checkpoint_state_dict
            match the keys returned by this module's state_dict() function.
            Defaults to True.

    Returns:
        IncompatibleKeys object containing information about missing keys,
        unexpected keys, and parameters with incorrect shapes.
    """
    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    strip_prefix_if_present(checkpoint_state_dict, "module.")

    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = self.state_dict()
    incorrect_shapes = []
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            model_param = model_state_dict[k]
            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)

    incompatible = super().load_state_dict(checkpoint_state_dict, strict=strict)
    incompatible = IncompatibleKeys(
        missing_keys=incompatible.missing_keys,
        unexpected_keys=incompatible.unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )

    incompatible.log_incompatible_keys()

    return incompatible