# Source code for rizemind.strategies.contribution.shapley.trainer_set
import statistics
from collections.abc import Callable
from eth_typing import ChecksumAddress
from flwr.common import EvaluateRes
from flwr.common.typing import Parameters, Scalar
class TrainerSet:
    """A set of trainers forming a coalition.

    Attributes:
        id: Unique identifier for this trainer set.
        members: List of trainer addresses that are members of this set.
    """

    id: str
    members: list[ChecksumAddress]

    def __init__(
        self,
        id: str,
        members: list[ChecksumAddress],
    ) -> None:
        """Initialize a trainer set.

        Args:
            id: Unique identifier for this set.
            members: List of trainer addresses in this set. The list object
                is stored as-is (no copy is made).
        """
        self.id = id
        self.members = members

    def size(self) -> int:
        """Get the number of trainers in this set.

        Returns:
            The count of member trainers, i.e. ``len(self.members)``.
        """
        return len(self.members)
class TrainerSetAggregate(TrainerSet):
    """A trainer set with aggregated model parameters and configuration dictionary.

    Evaluation results are accumulated through ``insert_res`` and summarised
    on demand by ``get_loss`` and ``get_metric``.

    Attributes:
        parameters: Aggregated model parameters for this coalition.
        config: Configuration dictionary.
    """

    parameters: Parameters
    config: dict[str, Scalar]
    _evaluation_res: list[EvaluateRes]

    def __init__(
        self,
        id: str,
        members: list[ChecksumAddress],
        parameters: Parameters,
        config: dict[str, Scalar],
    ) -> None:
        """Initialize a trainer set aggregate.

        Args:
            id: Unique identifier for this coalition.
            members: List of trainer addresses in this coalition.
            parameters: Aggregated model parameters.
            config: Configuration dictionary.
        """
        super().__init__(id, members=members)
        self.parameters = parameters
        self.config = config
        # Starts empty; filled one result at a time by insert_res().
        self._evaluation_res = []

    def insert_res(self, eval_res: EvaluateRes) -> None:
        """Add an evaluation result to this coalition.

        Args:
            eval_res: The evaluation result to store.
        """
        self._evaluation_res.append(eval_res)

    def get_loss(
        self,
        aggregator: Callable[[list[float]], float] = statistics.mean,
    ) -> float:
        """Get the aggregated loss for this coalition.

        Args:
            aggregator: Function to aggregate multiple loss values.
                Defaults to mean.

        Returns:
            The aggregated loss value, or infinity if no evaluations exist.
        """
        if not self._evaluation_res:
            # No evaluation stored yet: report infinity instead of calling
            # the aggregator on an empty list.
            return float("inf")
        losses = [res.loss for res in self._evaluation_res]
        return aggregator(losses)

    def get_metric(
        self,
        name: str,
        default: Scalar,
        aggregator: Callable,
    ):
        """Get an aggregated metric value for this coalition.

        Args:
            name: The metric name to retrieve.
            default: Default value to return if metric is unavailable.
            aggregator: Function to aggregate multiple metric values. It is
                called with the list of per-evaluation values, in insertion
                order.

        Returns:
            The aggregated metric value, or ``default`` when no evaluation
            result is stored or when at least one stored result has no
            value (missing key or ``None``) for ``name``.
        """
        if not self._evaluation_res:
            return default

        values = []
        for res in self._evaluation_res:
            value = res.metrics.get(name)
            if value is None:
                # One result without the metric makes the aggregate
                # unavailable, so stop early and fall back to the default.
                return default
            values.append(value)
        return aggregator(values)
class TrainerSetAggregateStore:
    """Storage system for trainer set aggregates.

    Maintains a collection of coalition aggregates indexed by their ids.

    Attributes:
        set_aggregates: Dictionary mapping coalition IDs to their aggregates.
    """

    set_aggregates: dict[str, TrainerSetAggregate]

    def __init__(self) -> None:
        """Initialize an empty aggregate store."""
        self.set_aggregates = {}

    def insert(self, aggregate: TrainerSetAggregate) -> None:
        """Insert a coalition aggregate in the store.

        An aggregate already stored under the same id is replaced.

        Args:
            aggregate: The coalition aggregate to store.
        """
        self.set_aggregates[aggregate.id] = aggregate

    def get_sets(self) -> list[TrainerSetAggregate]:
        """Get all coalition aggregates.

        Returns:
            A new list holding all stored coalition aggregates.
        """
        return list(self.set_aggregates.values())

    def clear(self) -> None:
        """Remove all coalition aggregates from the store."""
        # Rebinds to a fresh dict rather than emptying the old one in place.
        self.set_aggregates = {}

    def get_set(self, id: str) -> TrainerSetAggregate:
        """Retrieve a specific coalition aggregate by ID.

        Args:
            id: The unique identifier of the coalition.

        Returns:
            The requested coalition aggregate.

        Raises:
            Exception: If no coalition with the given ID exists.
        """
        try:
            return self.set_aggregates[id]
        except KeyError:
            # Keep the original exception type and message; ``from None``
            # hides the internal KeyError from the traceback.
            raise Exception(f"Coalition {id} not found") from None

    def is_available(self, id: str) -> bool:
        """Check if a coalition aggregate exists in the store.

        Args:
            id: The unique identifier to check.

        Returns:
            True if the coalition exists, False otherwise.
        """
        return id in self.set_aggregates