Source code for rizemind.strategies.contribution.shapley.shapley_value_strategy

from collections.abc import Callable
from logging import DEBUG, INFO, WARNING
from typing import Protocol

from eth_typing import ChecksumAddress
from flwr.common import FitRes
from flwr.common.logger import log
from flwr.common.typing import FitIns, Parameters, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import Strategy

from rizemind.authentication import (
    AuthenticatedClientProperties,
)
from rizemind.strategies.contribution.calculators import (
    ShapleyValueCalculator,
)
from rizemind.strategies.contribution.calculators.calculator import (
    ContributionCalculator,
    PlayerScore,
)
from rizemind.strategies.contribution.sampling import AllSets
from rizemind.strategies.contribution.sampling.sets_sampling_strat import (
    SetsSamplingStrategy,
)
from rizemind.strategies.contribution.shapley.trainer_set import (
    TrainerSetAggregate,
    TrainerSetAggregateStore,
)


[docs] class SupportsShapleyValueStrategy(Protocol): """Protocol defining the interface for swarm management in Shapley value strategies. This protocol specifies the required methods for managing trainer compensation and round progression in a federated learning swarm using Shapley value-based contribution calculation. """
[docs] def distribute(self, trainer_scores: list[tuple[ChecksumAddress, float]]) -> str: """Distribute rewards to trainers based on their contribution scores. Args: trainer_scores: List of tuples containing trainer addresses and their corresponding contribution scores. Returns: Transaction hash or confirmation string of the distribution operation. """ ...
[docs] def next_round( self, round_id: int, n_trainers: int, model_score: float, total_contributions: float, ) -> str: """Advance to the next training round and record round statistics. Args: round_id: The identifier of the current round. n_trainers: Number of trainers participating in this round. model_score: Performance score of the selected model for next round. total_contributions: Sum of all trainer contribution scores. Returns: Transaction hash or confirmation string of the next round operation. """ ...
[docs] class ShapleyValueStrategy(Strategy): """Federated learning strategy using Shapley values for contribution calculation. This strategy extends the Flower Strategy to incorporate Shapley value-based contribution calculation. It creates coalitions of trainers, aggregates their models, and evaluates each coalition's performance to compute fair contribution scores using the Shapley value method. Attributes: strategy: The underlying federated learning strategy for aggregation. swarm: The swarm manager handling reward distribution and round progression. coalition_to_score_fn: Optional function to compute a score from a coalition aggregate. last_round_parameters: Parameters from the previous round. aggregate_coalition_metrics: Optional function to aggregate metrics across coalitions. sets_sampling_strat: Strategy for sampling trainer subsets/coalitions. set_aggregates: Store for coalitions. contribution_calculator: Calculator for computing Shapley value contributions. """ # TODO: # There is a mismatch between the loss returned by `evaluate_coalitions` and the loss of the # selected model parameters for next round. This is due to the fact that for `evaluate_coalitions` # returns the minimum loss among all coalitions, while the selected model is from the coalition # that all trainers participated in. Therefore if this model does not have the lowest loss # (which can occur often) there will be a mismatch between the selected parameter's loss vs # what is displayed. This needs to be addresses in later versions by selecting the model # that its loss is returned by `evaluate_coalitions`. strategy: Strategy swarm: SupportsShapleyValueStrategy coalition_to_score_fn: Callable[[TrainerSetAggregate], float] | None = None last_round_parameters: Parameters | None aggregate_coalition_metrics: ( Callable[[list[TrainerSetAggregate]], dict[str, Scalar]] | None ) = None sets_sampling_strat: SetsSamplingStrategy set_aggregates: TrainerSetAggregateStore contribution_calculator: ContributionCalculator def __init__( self, strategy: Strategy, swarm: SupportsShapleyValueStrategy, coalition_to_score_fn: Callable[[TrainerSetAggregate], float] | None = None, aggregate_coalition_metrics_fn: Callable[ [list[TrainerSetAggregate]], dict[str, Scalar] ] | None = None, shapley_sampling_strat: SetsSamplingStrategy = AllSets(), contribution_calculator: ContributionCalculator = ShapleyValueCalculator(), ) -> None: """Initialize the Shapley value strategy. Args: strategy: Base federated learning strategy for model aggregation. swarm: Swarm manager for reward distribution and round management. coalition_to_score_fn: Optional function to extract score from coalition. If None, uses the coalition's loss value. aggregate_coalition_metrics_fn: Optional function to compute aggregate metrics across all coalitions. shapley_sampling_strat: Strategy for sampling trainer coalitions. Defaults to AllSets() which generates all possible subsets. contribution_calculator: Calculator for computing contribution scores. Defaults to ShapleyValueCalculator(). """ log(DEBUG, "ShapleyValueStrategy: initializing") self.strategy = strategy self.swarm = swarm self.coalition_to_score_fn = coalition_to_score_fn self.set_aggregates = TrainerSetAggregateStore() self.aggregate_coalition_metrics = aggregate_coalition_metrics_fn self.sets_sampling_strat = shapley_sampling_strat self.contribution_calculator = contribution_calculator
[docs] def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None: """Delegate the initialization of model parameters to the underlying strategy. Args: client_manager: Manager handling available clients. Returns: The initialized model parameters, or None if not applicable. """ self.last_round_parameters = self.strategy.initialize_parameters(client_manager) return self.last_round_parameters
[docs] def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager ) -> list[tuple[ClientProxy, FitIns]]: """Configure the next round of training for clients. Selects the coalition where every trainer participated from the previous round and uses its parameters as the starting point for the next training round. Args: server_round: The current server round number. parameters: Current model parameters. client_manager: Manager handling available clients. Returns: List of tuples containing client proxies and their fit instructions. """ log(DEBUG, "configure_fit: creating fit instructions for clients") log( DEBUG, "configure_fit: selecting the base coalition for next round", ) coalition = self.select_aggregate() parameters = parameters if coalition is None else coalition.parameters log( DEBUG, "configure_fit: setting the previous rounds best parameter from the selected coalition", ) self.last_round_parameters = parameters return self.strategy.configure_fit(server_round, parameters, client_manager)
[docs] def select_aggregate(self) -> TrainerSetAggregate | None: """Select the coalition aggregate to use for the next training round. Selects the coalition with the highest number of members as the base for the next round. Returns: The selected coalition aggregate, or None if no coalitions exist. """ coalitions = self.get_coalitions() if len(coalitions) == 0: log(DEBUG, "select_coalition: no coalition was found") return None # Find the coalition with the highest number of members log(DEBUG, "select_coalition: get coalition with the highest number of members") return max(coalitions, key=lambda coalition: len(coalition.members))
[docs] def aggregate_fit( self, server_round: int, results: list[tuple[ClientProxy, FitRes]], failures: list[tuple[ClientProxy, FitRes] | BaseException], ) -> tuple[Parameters | None, dict[str, bool | bytes | float | int | str]]: """Aggregate client training results and form coalitions. Creates coalitions from client fit results and delegates parameter aggregation to the underlying strategy. Args: server_round: The current server round number. results: List of tuples containing client proxies and their fit results. failures: List of any failed client results. Returns: A tuple containing the aggregated parameters (or None) and a dictionary of metrics. """ if len(failures) > 0: log( level=WARNING, msg=f"aggregate_fit: there have been {len(failures)} failures in round {server_round}", ) self.create_coalitions(server_round, results) return self.strategy.aggregate_fit(server_round, results, failures)
[docs] def create_coalitions( self, server_round: int, results: list[tuple[ClientProxy, FitRes]] ) -> list[TrainerSetAggregate]: """Create coalitions from client training results. Samples trainer subsets using the sampling strategy, aggregates parameters for each subset, and stores the resulting coalition aggregates. Args: server_round: The current server round number. results: List of tuples containing client proxies and their fit results. Returns: List of created coalitions. Raises: ValueError: If no aggregate parameters are returned for a trainer set. """ log(DEBUG, "create_coalitions: initializing") trainer_sets = self.sets_sampling_strat.sample_trainer_sets( server_round=server_round, results=results ) for trainer_set in trainer_sets: set_results: list[tuple[ClientProxy, FitRes]] = [] for client, result in results: auth = AuthenticatedClientProperties.from_client(client) if auth.trainer_address in trainer_set.members: set_results.append((client, result)) if trainer_set.size() == 0: parameters, config = self.last_round_parameters, {} else: parameters, config = self.strategy.aggregate_fit( server_round, set_results, [] ) if parameters is None: raise ValueError( f"No aggregate returned for trainer set ID {trainer_set.id}" ) self.set_aggregates.insert( TrainerSetAggregate( trainer_set.id, trainer_set.members, parameters, config ) ) return self.get_coalitions()
[docs] def get_coalitions(self) -> list[TrainerSetAggregate]: """Returns all coalitions.""" return self.set_aggregates.get_sets()
[docs] def get_coalition(self, id: str) -> TrainerSetAggregate: """Get a specific coalition by ID. Args: id: The identifier of the coalition. Returns: The requested coalition aggregate. Raises: Exception: If the coalition with the given ID is not found. """ return self.set_aggregates.get_set(id)
[docs] def compute_contributions( self, round_id: int, coalitions: list[TrainerSetAggregate] | None ) -> list[PlayerScore]: """Compute Shapley value contribution score for each trainer. Uses the contribution calculator to determine each trainer's contribution to the overall model performance based on coalition evaluations. Args: round_id: The current round identifier. coalitions: Optional list of coalitions to compute contributions from. If None, uses all available coalitions. Returns: List of player scores containing trainer addresses and their contribution values. """ # Create a bijective mapping between addresses and a bit_based representation # First the coalition_and_scores is sorted based on the length of list of addresses # Then given that the largest list has all addresses, it will assign it to # list_of_addresses log(DEBUG, "compute_contributions: initializing") if coalitions is None: coalitions = self.get_coalitions() if len(coalitions) == 0: log(DEBUG, "compute_contributions: no coalition was found, returning empty") return [] trainer_mapping = self.sets_sampling_strat.get_trainer_mapping(round_id) player_scores = self.contribution_calculator.get_scores( participant_mapping=trainer_mapping, store=self.set_aggregates, coalition_to_score_fn=self.coalition_to_score_fn, ) log( INFO, "compute_contributions: calculated player contributions.", extra={"player_scores": player_scores}, ) return list(player_scores.values())
[docs] def get_coalition_score(self, coalition: TrainerSetAggregate) -> float: """Get the performance score for a coalition. If no `coalition_to_score_fn` is provided it defaults to the loss value. The loss value is inversed since higher loss means lower performance. Args: coalition: The coalition aggregate to score. Returns: The performance score of the coalition. Raises: Exception: If the coalition has not been evaluated. """ score = None if self.coalition_to_score_fn is None: score = 1 / coalition.get_loss() else: score = self.coalition_to_score_fn(coalition) if score is None: raise Exception(f"Coalition {coalition.id} not evaluated") return score
[docs] def normalize_contribution_scores( self, trainers_and_contributions: list[PlayerScore] ) -> list[PlayerScore]: """Normalize contribution scores to ensure non-negative values. Args: trainers_and_contributions: List of player scores to normalize. Returns: List of player scores with negative values clamped to zero. """ return [ PlayerScore( trainer_address=score.trainer_address, score=max(score.score, 0) ) for score in trainers_and_contributions ]
[docs] def close_round(self, round_id: int) -> tuple[float, dict[str, Scalar]]: """Finalize the current round by computing contributions and distributing rewards. Computes trainer contributions, normalizes scores, distributes rewards, and prepares for the next round. Args: round_id: The current round identifier. Returns: A tuple containing the best coalition loss and aggregated metrics. """ coalitions = self.get_coalitions() player_scores = self.compute_contributions(round_id, coalitions) player_scores = self.normalize_contribution_scores(player_scores) for player_score in player_scores: if player_score.score == 0: log( WARNING, f"aggregate_evaluate: free rider detected! Trainer address: {player_score.trainer_address}, Score: {player_score.score}", ) self.swarm.distribute( [ (player_score.trainer_address, player_score.score) for player_score in player_scores ] ) loss, metrics = self.evaluate_coalitions() next_model = self.select_aggregate() score = 0 if next_model is None else self.get_coalition_score(next_model) self.swarm.next_round( round_id, len(player_scores), score, sum(score.score for score in player_scores), ) return loss, metrics
[docs] def evaluate_coalitions(self) -> tuple[float, dict[str, Scalar]]: """Evaluate all coalitions and determine the best performance. Calculates loss values for all coalitions and optionally aggregates metrics. Returns: A tuple containing the minimum coalition loss and aggregated metrics. """ log( DEBUG, "evaluate_coalitions: evaluating coalitions by calculating their loss and optional metrics", ) coalitions = self.get_coalitions() if len(coalitions) == 0: log( DEBUG, "evaluate_coalitions: no coalition found, returning inf as the loss value", ) return float("inf"), {} coalition_losses = [ coalition.get_loss() or float("inf") for coalition in coalitions ] metrics = ( {} if self.aggregate_coalition_metrics is None else self.aggregate_coalition_metrics(coalitions) ) return min(coalition_losses), metrics