from eth_account.signers.base import BaseAccount
from flwr.common.typing import FitRes
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import Strategy
from rizemind.authentication.authenticated_client_manager import (
AuthenticatedClientManager,
)
from rizemind.authentication.authenticated_client_properties import (
AuthenticatedClientProperties,
)
from rizemind.authentication.notary.model.config import (
parse_model_notary_config,
prepare_model_notary_config,
)
from rizemind.authentication.notary.model.model_signature import (
hash_parameters,
recover_model_signer,
sign_parameters_model,
)
from rizemind.authentication.typing import SupportsEthAccountStrategy
from rizemind.exception import ParseException, RizemindException
[docs]
class CannotTrainException(RizemindException):
    """Signals that an address not authorized to train tried to contribute an update."""

    def __init__(self, address: str) -> None:
        """Build the exception for the offending address.

        Args:
            address: The address that attempted to train.
        """
        super().__init__(
            message=f"{address} cannot train",
            code="cannot_train",
        )
[docs]
class CannotRecoverSignerException(RizemindException):
    """Signals that the signer of a model update could not be recovered."""

    def __init__(self) -> None:
        """Build the exception with its fixed code and message."""
        super().__init__(message="Cannot recover signer", code="cannot_recover_signer")
[docs]
class EthAccountStrategy(Strategy):
    """A federated learning strategy that verifies model authenticity.

    This strategy wraps an existing Flower Strategy to ensure that only
    authorized clients can contribute training updates. For every fit result
    it recovers the signer of the returned parameters from the notary
    signature carried in the result's metrics, using the swarm's EIP-712
    domain. It then asks the swarm whether that signer may train in the
    current round.

    If a client is not authorized, a `CannotTrainException` is added to the
    failures list. If the notary configuration cannot be parsed, a
    `CannotRecoverSignerException` is added instead.

    Attributes:
        strat: The base Flower Strategy to wrap.
        swarm: The blockchain-based model registry.
        address: The contract address of the swarm, taken from the
            ``verifyingContract`` field of its EIP-712 domain.
        account: The Ethereum account used for signing.

    Example:
        ::

            strategy = SomeBaseStrategy()
            model_registry = SwarmV1.from_address(address="0xMY_MODEL_ADDRESS")
            # ``account`` is any eth_account ``BaseAccount`` instance.
            eth_strategy = EthAccountStrategy(strategy, model_registry, account)
    """

    strat: Strategy
    swarm: SupportsEthAccountStrategy
    address: str
    account: BaseAccount

    def __init__(
        self,
        strat: Strategy,
        swarm: SupportsEthAccountStrategy,
        account: BaseAccount,
    ):
        """Initializes the EthAccountStrategy.

        Args:
            strat: The base Flower Strategy to wrap.
            swarm: The blockchain-based model registry.
            account: The Ethereum account used for signing.
        """
        super().__init__()
        self.strat = strat
        self.swarm = swarm
        # The swarm's EIP-712 domain is queried once here only to cache the
        # contract address; signer recovery re-reads the domain per result.
        domain = self.swarm.get_eip712_domain()
        self.address = domain.verifyingContract
        self.account = account

    def initialize_parameters(self, client_manager):
        """Initializes model parameters by delegating to the wrapped strategy.

        Args:
            client_manager: The client manager passed through unchanged.

        Returns:
            Whatever the wrapped strategy's `initialize_parameters` returns.
        """
        return self.strat.initialize_parameters(client_manager)

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures,
    ):
        """Aggregate fit results from authorized clients only.

        For each result, the signer address is recovered and the client is
        tagged with it. The result is kept only if the swarm allows the signer
        to train in `server_round`. Rejected results are recorded as failures,
        and final aggregation is delegated to the wrapped strategy.

        Args:
            server_round: The current server round.
            results: A list of tuples `(ClientProxy, FitRes)` received from
                clients.
            failures: Failures collected so far. This list is extended in
                place as follows:
                - a `CannotRecoverSignerException` for each result whose
                  notary configuration cannot be parsed;
                - a `CannotTrainException` for each result whose signer is
                  not allowed to train.

        Returns:
            The aggregated result as returned by the wrapped strategy's
            `aggregate_fit`.
        """
        whitelisted: list[tuple[ClientProxy, FitRes]] = []
        for client, res in results:
            # Guard only the signer recovery. A ParseException raised by the
            # tagging or authorization steps below must not be misreported as
            # a signature-recovery failure.
            try:
                signer = self._recover_signer(res, server_round)
            except ParseException as err:
                error = CannotRecoverSignerException()
                # Preserve the underlying parse error for diagnostics.
                error.__cause__ = err
                failures.append(error)
                continue

            # The client is tagged with its recovered address even when it is
            # not allowed to train in this round.
            properties = AuthenticatedClientProperties(trainer_address=signer)
            properties.tag_client(client)
            if self.swarm.can_train(signer, server_round):
                whitelisted.append((client, res))
            else:
                failures.append(CannotTrainException(signer))
        return self.strat.aggregate_fit(server_round, whitelisted, failures)

    def _recover_signer(self, res: FitRes, server_round: int) -> str:
        """Recovers the signer's address from a client's fit response.

        The notary configuration is parsed from `res.metrics`. Its signature
        over `res.parameters` is then checked against the swarm's EIP-712
        domain for `server_round`.

        Args:
            res: The client's fit response.
            server_round: The current server round.

        Returns:
            The Ethereum address of the signer.

        Raises:
            ParseException: If the notary configuration cannot be parsed.
        """
        notary_config = parse_model_notary_config(res.metrics)
        eip712_domain = self.swarm.get_eip712_domain()
        return recover_model_signer(
            model=res.parameters,
            domain=eip712_domain,
            round=server_round,
            signature=notary_config.signature,
        )

    def aggregate_evaluate(self, server_round, results, failures):
        """Aggregate evaluation results by delegating to the wrapped strategy.

        Unlike `aggregate_fit`, no signature or authorization check is applied.

        Args:
            server_round: The current server round.
            results: A list of tuples `(ClientProxy, EvaluateRes)` received
                from clients.
            failures: A list of failures encountered during evaluation.

        Returns:
            The aggregated evaluation result as returned by the wrapped
            strategy's `aggregate_evaluate`.
        """
        return self.strat.aggregate_evaluate(server_round, results, failures)

    def evaluate(self, server_round, parameters):
        """Evaluate the current global model via the wrapped strategy.

        Args:
            server_round: The current server round.
            parameters: The global model parameters to evaluate.

        Returns:
            The evaluation result as returned by the wrapped strategy's
            `evaluate`.
        """
        return self.strat.evaluate(server_round, parameters)