Source code for rizemind.strategies.contribution.sampling.sets_sampling_strat
from abc import ABC, abstractmethod
from flwr.common import FitRes
from flwr.server.client_proxy import ClientProxy
from rizemind.strategies.contribution.shapley.trainer_mapping import ParticipantMapping
from rizemind.strategies.contribution.shapley.trainer_set import (
TrainerSet,
)
[docs]
class SetsSamplingStrategy(ABC):
    """Interface for strategies that pick which trainer sets to evaluate.

    Contribution assessment works on sets (combinations) of trainers. A
    concrete subclass decides how those sets are sampled from a round's fit
    results and how they, and the participant mapping, are looked up by round.
    """

    @abstractmethod
    def sample_trainer_sets(
        self, server_round: int, results: list[tuple[ClientProxy, FitRes]]
    ) -> list[TrainerSet]:
        """Sample the trainer sets to evaluate for one server round.

        Args:
            server_round: Number of the current server round.
            results: Pairs of a client proxy and the fit result that client
                produced during this round's training.

        Returns:
            The sampled trainer combinations for this round, as a list of
            ``TrainerSet`` objects.
        """
        ...

    @abstractmethod
    def get_sets(self, round_id: int) -> list[TrainerSet]:
        """Return every trainer set belonging to a round.

        Args:
            round_id: Identifier of the round whose sets are requested.

        Returns:
            All ``TrainerSet`` objects for ``round_id``.
        """
        ...

    @abstractmethod
    def get_set(self, round_id: int, id: str) -> TrainerSet:
        """Return a single trainer set of a round, selected by its identifier.

        Args:
            round_id: Identifier of the round that holds the set.
            id: Unique identifier of the requested trainer set.

        Returns:
            The ``TrainerSet`` whose identifier is ``id``.

        Note:
            This interface does not specify the behaviour for an ``id`` that
            is unknown in ``round_id``; consult the concrete implementation.
        """
        # ``id`` shadows the builtin of the same name; it is kept because
        # renaming the parameter would break callers passing it by keyword.
        ...

    @abstractmethod
    def get_trainer_mapping(self, round_id: int) -> ParticipantMapping:
        """Return the participant mapping associated with a round.

        Args:
            round_id: Identifier of the round whose mapping is requested.

        Returns:
            The ``ParticipantMapping`` for ``round_id``.
        """
        ...