Source code for rizemind.authentication.notary.model.model_signature

from eth_account import Account
from eth_account.signers.base import BaseAccount
from eth_typing import ChecksumAddress
from flwr.common.typing import Parameters
from web3 import Web3

from rizemind.authentication.signatures.eip712 import (
    EIP712DomainRequiredFields,
    prepare_eip712_message,
)
from rizemind.authentication.signatures.signature import Signature

# Primary type name of the EIP-712 struct signed for a model update.
ModelTypeName = "Model"

# Field list of the "Model" struct. sign_parameters_model and
# recover_model_signer pass it to prepare_eip712_message as
# {ModelTypeName: ModelTypeAbi}. Field order is significant under EIP-712
# (it defines the encoded type string), so do not reorder.
ModelTypeAbi = [
    dict(name="round", type="uint256"),  # federated-learning round number
    dict(name="hash", type="bytes32"),  # hash_parameters() digest
]


def hash_parameters(parameters: Parameters) -> bytes:
    """Compute the keccak256 digest of a model's parameters.

    The digest covers every serialized tensor, in order, followed by the
    UTF-8 encoding of ``parameters.tensor_type``.

    Args:
        parameters: The model parameters to hash.

    Returns:
        The keccak256 digest of the concatenated tensors and tensor type.
    """
    # NOTE(review): the pieces are joined without length prefixes or
    # separators, so different splits of the same bytes (e.g. tensors
    # [b"ab", b"c"] vs [b"a", b"bc"]) hash identically. Changing the encoding
    # would invalidate every existing signature and break verification
    # between peers on different versions, so it is left unchanged here;
    # consider a length-prefixed encoding as a coordinated protocol upgrade.
    payload = b"".join([*parameters.tensors, parameters.tensor_type.encode()])
    return Web3.keccak(payload)
def sign_parameters_model(
    *,
    parameters: Parameters,
    round: int,
    domain: EIP712DomainRequiredFields,
    account: BaseAccount,
) -> Signature:
    """Sign a model's parameters as an EIP-712 typed "Model" message.

    The message holds the training round and the ``hash_parameters`` digest
    of ``parameters``. It is prepared against ``domain`` and signed with
    ``account``.

    TODO: requires double checking with domain.

    Args:
        parameters: The model parameters to sign.
        round: The current round number of the federated learning.
        domain: The EIP712 required fields.
        account: An Ethereum account object from which the message will be
            signed.

    Returns:
        A ``Signature`` wrapping the raw signature bytes (the ``.signature``
        attribute of the object returned by ``account.sign_message``). The
        eth_account ``SignedMessage`` itself is not returned.
    """
    parameters_hash = hash_parameters(parameters)
    eip712_message = prepare_eip712_message(
        domain,
        ModelTypeName,
        {"round": round, "hash": parameters_hash},
        {ModelTypeName: ModelTypeAbi},
    )
    signed = account.sign_message(eip712_message)
    # Only the signature bytes are kept; the rest of the signing result is dropped.
    return Signature(data=signed.signature)
def recover_model_signer(
    *,
    model: Parameters,
    round: int,
    domain: EIP712DomainRequiredFields,
    signature: Signature,
) -> ChecksumAddress:
    """Return the address that produced ``signature`` over a model.

    The EIP-712 "Model" message is rebuilt from ``model``, ``round`` and
    ``domain``, the same way sign_parameters_model builds it. The signer's
    address is then recovered from that message and the signature bytes.

    Args:
        model: The model's parameters.
        round: The current round number of the federated learning.
        domain: The EIP712 required fields.
        signature: The signature of the trainer/aggregator that sent the
            parameters.

    Returns:
        The signer's address, normalized with ``Web3.to_checksum_address``.
    """
    typed_message = prepare_eip712_message(
        domain,
        ModelTypeName,
        {"round": round, "hash": hash_parameters(model)},
        {ModelTypeName: ModelTypeAbi},
    )
    signer = Account.recover_message(typed_message, signature=signature.data)
    return Web3.to_checksum_address(signer)