Source code for rizemind.logging.mlflow.metric_storage

import os
import tempfile

import mlflow
import numpy as np
from flwr.common import (
    Parameters,
    Scalar,
    parameters_to_ndarrays,
)

from rizemind.logging.base_metric_storage import BaseMetricStorage


class MLFLowMetricStorage(BaseMetricStorage):
    """Metric storage backend that logs metrics and models to an MLflow server.

    Integrates Flower federated learning with MLflow so that experiments,
    per-round metrics, and best-model artifacts are tracked centrally. On
    construction it connects to the given MLflow tracking URI, ensures the
    experiment exists, and creates one run under which everything else is
    logged.

    Attributes:
        experiment_name: The name of the MLflow experiment.
        run_name: The name of the MLflow run.
        mlflow_uri: The URI for the MLflow tracking server.
        mlflow_client: The MLflow client used for all API calls.
        run_id: The unique ID of the MLflow run created for this session.
    """

    def __init__(self, experiment_name: str, run_name: str, mlflow_uri: str):
        """Initialize the storage and create the backing MLflow run.

        Connects to the tracking server, creates the experiment if needed,
        and starts a run whose ID is kept for all subsequent logging.

        Args:
            experiment_name: Name of the MLflow experiment; created if it
                does not already exist.
            run_name: Name assigned to the run within the experiment.
            mlflow_uri: Connection URI for the MLflow tracking server.
        """
        self.experiment_name = experiment_name
        self.run_name = run_name
        self.mlflow_uri = mlflow_uri
        mlflow.set_tracking_uri(self.mlflow_uri)
        self.mlflow_client = mlflow.MlflowClient()
        mlflow.set_experiment(experiment_name=self.experiment_name)
        # Use the context manager instead of a bare start_run()/end_run()
        # pair: if anything raises after start_run(), the run would
        # otherwise be left active (leaked) in the MLflow fluent API.
        with mlflow.start_run(run_name=self.run_name) as run:
            self.run_id: str = run.info.run_id
        # Best (lowest) loss observed so far; starts at +inf so the first
        # reported loss always becomes the best.
        self._best_loss = np.inf
        # Parameters of the most recent round, held in memory until
        # update_best_model decides whether to persist them.
        self._current_round_model = Parameters(tensors=[], tensor_type="")

    def write_metrics(self, server_round: int, metrics: dict[str, Scalar]):
        """Log a dictionary of metrics for a given server round.

        Each entry is logged to the run created in ``__init__``, using the
        server round as the MLflow step.

        Args:
            server_round: Current round of federated learning, used as the
                MLflow ``step``.
            metrics: Mapping of metric names (e.g. "accuracy") to scalar
                values; values are coerced to ``float`` for MLflow.
        """
        for key, value in metrics.items():
            self.mlflow_client.log_metric(
                run_id=self.run_id, key=key, value=float(value), step=server_round
            )

    def update_current_round_model(self, parameters: Parameters):
        """Hold the current round's model parameters in memory.

        The parameters are kept so that ``update_best_model`` can persist
        them as an MLflow artifact if this round turns out to be the best.

        Args:
            parameters: The model parameters from the current round.
        """
        self._current_round_model = parameters

    def update_best_model(self, server_round: int, loss: float):
        """Persist the current model as an artifact if its loss is a new best.

        Compares ``loss`` with the lowest loss seen so far; on improvement,
        serializes the in-memory parameters to a temporary ``.npz`` file,
        uploads it to the run under ``flwr_best_model_params``, and records
        the best round and its loss as metrics.

        Args:
            server_round: The server round that produced this model.
            loss: Loss of the current model, used to decide whether it is
                the new best.
        """
        if loss < self._best_loss:
            self._best_loss = loss
            with tempfile.TemporaryDirectory() as tmp:
                ndarray_params = parameters_to_ndarrays(self._current_round_model)
                path = os.path.join(tmp, "weights.npz")
                np.savez(path, *ndarray_params)
                self.mlflow_client.log_artifact(
                    run_id=self.run_id,
                    local_path=path,
                    artifact_path="flwr_best_model_params",
                )
            # NOTE(review): these two metrics are logged only on improvement,
            # matching this method's documented contract; the original
            # formatting was lost, so confirm "avg_loss" was not intended to
            # be logged every round.
            self.mlflow_client.log_metric(
                run_id=self.run_id, key="best_round", value=server_round
            )
            self.mlflow_client.log_metric(
                run_id=self.run_id, key="avg_loss", value=loss, step=server_round
            )