Source code for rizemind.logging.mlflow.mod
import time
from logging import WARNING
from typing import cast

import mlflow
import pandas as pd
from flwr.client.typing import ClientAppCallable
from flwr.common import Context, log
from flwr.common.constant import MessageType
from flwr.common.message import Message
from flwr.common.recorddict_compat import recorddict_to_fitres
from mlflow.entities import RunStatus, ViewType

from rizemind.logging.mlflow.config import MLFlowConfig
from rizemind.logging.train_metric_history import (
    TRAIN_METRIC_HISTORY_KEY,
    TrainMetricHistory,
)

def mlflow_mod(msg: Message, ctx: Context, call_next: ClientAppCallable) -> Message:
    """Logs metrics from an incoming TRAIN message to an MLflow server.

    The `mlflow_mod` relies on the `TRAIN_METRIC_HISTORY_KEY` as a standardized
    metric type and reads the content of this metric for logging. In addition
    to the metrics available under `TRAIN_METRIC_HISTORY_KEY`, `mlflow_mod`
    automatically logs `training_time` and `epochs`.

    Args:
        msg: The incoming message from the ServerApp to the ClientApp.
        ctx: Context of the run.
        call_next: The next callable in the chain to process the message.

    Returns:
        The response message sent from the ClientApp to the ServerApp.
    """
    start_time = time.time()
    reply: Message = call_next(msg, ctx)
    time_diff = time.time() - start_time

    mlflow_config = MLFlowConfig.from_context(ctx=ctx)
    if mlflow_config is None:
        log(
            level=WARNING,
            msg="mlflow config was not found in client context, skipping logging.",
        )
        return reply

    mlflow.set_tracking_uri(mlflow_config.mlflow_uri)
    mlflow_experiment_name = mlflow_config.experiment_name
    mlflow_run_name = f"{mlflow_config.run_name}_client_id_{ctx.node_id}"

    if msg.metadata.message_type == MessageType.TRAIN:
        mlflow.set_experiment(experiment_name=mlflow_experiment_name)
        runs_df = cast(
            pd.DataFrame,
            mlflow.search_runs(
                experiment_names=[mlflow_experiment_name],
                filter_string=f"tags.mlflow.runName = '{mlflow_run_name}'",
                run_view_type=ViewType.ALL,
                order_by=["attributes.end_time DESC"],
                max_results=1,
            ),
        )
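        # Resume the most recent run with this name (if any) so metrics from
        # successive rounds accumulate in a single MLflow run per client.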
        epochs_passed = 0
        if runs_df.empty:
            # If a previous run doesn't exist,
            # start a new run with the given name.
            mlflow.start_run(run_name=mlflow_run_name)
        else:
            # If a previous run exists, recover the number
            # of epochs passed and continue that run.
            epochs_passed = int(cast(int, runs_df.loc[0, "metrics.epochs"]))
            run_id = cast(str, runs_df.loc[0, "run_id"])
            mlflow.start_run(run_id=run_id)
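        # A reply without content means the client-side step failed.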
        if not reply.has_content():
            mlflow.end_run(status=RunStatus.to_string(RunStatus.FAILED))
        else:
            # Log training time
            server_round = int(msg.metadata.group_id)
            mlflow.log_metric(key="training_time", value=time_diff, step=server_round)
            # Get metrics and log them
            fit_res = recorddict_to_fitres(reply.content, keep_input=True)
            serialized_train_metric_history = cast(
                str, fit_res.metrics.get(TRAIN_METRIC_HISTORY_KEY)
            )
            train_metric_history = TrainMetricHistory.deserialize(
                serialized_train_metric_history=serialized_train_metric_history
            )
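            # Offset logged steps by the epochs from earlier rounds so each
            # metric's epoch axis stays continuous across server rounds.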
            epochs_this_round = 0
            for metric, phases in train_metric_history.model_dump().items():
                for phase, values in phases.items():
                    for step, metric_value in enumerate(values):
                        mlflow.log_metric(
                            key=f"{phase}_{metric}",
                            value=metric_value,
                            step=step + epochs_passed,
                        )
                    epochs_this_round = max(epochs_this_round, len(values))
            epochs_passed += epochs_this_round
            mlflow.log_metric(key="epochs", value=epochs_passed)
            mlflow.end_run(status=RunStatus.to_string(RunStatus.FINISHED))
    return reply
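
A minimal usage sketch (not part of this module): `mlflow_mod` is a Flower client-side mod, so it is attached through the `mods` argument of `ClientApp`. The `client_fn` below is a hypothetical placeholder, and an `MLFlowConfig` must be resolvable from the client `Context` at runtime (see `MLFlowConfig.from_context` above), otherwise the mod logs a warning and skips MLflow logging.

from flwr.client import ClientApp
from flwr.common import Context

from rizemind.logging.mlflow.mod import mlflow_mod


def client_fn(context: Context):
    # Hypothetical placeholder: construct and return your Flower client here.
    ...


# Every message handled by the app now passes through mlflow_mod; replies to
# TRAIN messages are logged to the MLflow server named in the config.
app = ClientApp(client_fn=client_fn, mods=[mlflow_mod])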