"""Power-set (all coalitions) trainer-set sampling strategy.

Module: ``rizemind.strategies.contribution.sampling.all_sets``.
"""

import itertools

from eth_typing import ChecksumAddress
from flwr.common import FitRes
from flwr.server.client_proxy import ClientProxy

from rizemind.authentication.authenticated_client_properties import (
    AuthenticatedClientProperties,
)
from rizemind.strategies.contribution.sampling.sets_sampling_strat import (
    SetsSamplingStrategy,
)
from rizemind.strategies.contribution.shapley import (
    ParticipantMapping,
    TrainerSet,
)


class AllSets(SetsSamplingStrategy):
    """Sampling strategy that generates all possible combinations of trainers.

    This strategy implements exhaustive sampling by generating the complete
    power set of trainers for each round. It creates all 2^n possible
    combinations (coalitions), where n is the number of trainers, including
    the empty set, which corresponds to the previous round's best model.

    The strategy keeps state for the current round and caches the generated
    sets and participant mapping. Repeated calls for the same round return the
    cached results; when a new round is detected, all sets are regenerated.

    Attributes:
        trainer_mapping: Maps trainer addresses to integer indices for set IDs.
        current_round: The round number for which sets have been generated
            (-1 until the first round has been sampled).
        sets: Dictionary mapping set IDs to TrainerSet objects for the current
            round.
    """

    trainer_mapping: ParticipantMapping
    current_round: int
    sets: dict[str, TrainerSet]

    def __init__(self) -> None:
        """Initializes the AllSets strategy with empty state.

        Sets up an empty participant mapping, initializes the current round to
        an invalid value (-1), and prepares an empty sets dictionary for
        caching.
        """
        self.trainer_mapping = ParticipantMapping()
        self.current_round = -1  # Initialize to invalid round
        self.sets = {}

    def _ensure_current_round(self, round_id: int) -> None:
        """Raises ValueError unless ``round_id`` is the round currently cached."""
        if round_id != self.current_round:
            raise ValueError(
                f"Round {round_id} is not the current round {self.current_round}."
            )

    def sample_trainer_sets(
        self, server_round: int, results: list[tuple[ClientProxy, FitRes]]
    ) -> list[TrainerSet]:
        """Generates all possible combinations of trainers for the given round.

        If this is a new round, the method:
            1. Extracts all trainer addresses from the results (once per client).
            2. Creates a participant mapping for ID generation.
            3. Generates the power set of the addresses using itertools.
            4. Creates a TrainerSet for each combination with a unique ID.

        The new mapping and sets only replace the cached state once all of the
        above has succeeded, so an error part-way through leaves the previous
        round's cache intact and a retry of the same round recomputes it.

        If called again for the same round, it returns the cached results
        without regeneration.

        Args:
            server_round: The current server round number.
            results: Training results containing client proxies and fit results.

        Returns:
            A list of all possible TrainerSet combinations, including the empty
            set, ordered by increasing coalition size. For n trainers with
            distinct addresses this is 2^n sets.

        Raises:
            ValueError: If server_round is less than the current round.
        """
        if server_round < self.current_round:
            raise ValueError(
                f"Unsupported sampling, round {server_round} is less than the current round {self.current_round}."
            )
        if server_round == self.current_round:
            return self.get_sets(round_id=server_round)

        # Resolve each client's trainer address once up front instead of once
        # per coalition membership (which would be O(n * 2^n) lookups).
        addresses: list[ChecksumAddress] = [
            AuthenticatedClientProperties.from_client(client).trainer_address
            for client, _ in results
        ]

        mapping = ParticipantMapping()
        for address in addresses:
            mapping.add_participant(address)

        # Lazily enumerate the power set: the empty coalition first, then all
        # coalitions of size 1, 2, ..., n (same order as the original listing).
        coalitions = itertools.chain.from_iterable(
            itertools.combinations(addresses, size)
            for size in range(len(addresses) + 1)
        )
        sets: dict[str, TrainerSet] = {}
        for coalition in coalitions:
            members: list[ChecksumAddress] = list(coalition)
            set_id = mapping.get_participant_set_id(members)
            sets[set_id] = TrainerSet(set_id, members)

        # Commit state only after everything above succeeded, so a failure
        # cannot leave an empty cache flagged as the current round.
        self.trainer_mapping = mapping
        self.sets = sets
        self.current_round = server_round
        return self.get_sets(round_id=server_round)

    def get_sets(self, round_id: int) -> list[TrainerSet]:
        """Returns all trainer sets for the given round.

        Args:
            round_id: The round identifier to retrieve sets for.

        Returns:
            A list of all TrainerSet objects generated for the specified round.

        Raises:
            ValueError: If the round_id does not match the current round.
        """
        self._ensure_current_round(round_id)
        return list(self.sets.values())

    def get_set(self, round_id: int, id: str) -> TrainerSet:  # noqa: A002
        """Returns a specific trainer set by ID for the given round.

        Args:
            round_id: The round identifier to retrieve the set from.
            id: The unique identifier of the trainer set, generated by the
                participant mapping based on the set's members. (The name
                shadows the builtin but is kept for keyword-argument callers.)

        Returns:
            The TrainerSet object with the specified ID.

        Raises:
            ValueError: If the round_id does not match the current round, or if
                no trainer set with the given ID exists in the current round.
        """
        self._ensure_current_round(round_id)
        try:
            return self.sets[id]
        except KeyError:
            raise ValueError(f"Trainer set with ID {id} not found") from None

    def get_trainer_mapping(self, round_id: int) -> ParticipantMapping:
        """Returns the participant mapping for the specified round.

        Args:
            round_id: The round identifier to retrieve the mapping for.

        Returns:
            The ParticipantMapping object containing all trainer address
            mappings for the specified round.

        Raises:
            ValueError: If the round_id does not match the current round,
                indicating that a new round needs to be processed via
                sample_trainer_sets.
        """
        self._ensure_current_round(round_id)
        return self.trainer_mapping