diff --git a/src/MSAAdvanced.sol b/src/MSAAdvanced.sol index 45a9324..b9df6e8 100644 --- a/src/MSAAdvanced.sol +++ b/src/MSAAdvanced.sol @@ -16,6 +16,7 @@ import { ECDSA } from "solady/utils/ECDSA.sol"; import { Initializable } from "./lib/Initializable.sol"; import { ERC7779Adapter } from "./core/ERC7779Adapter.sol"; import { SentinelListLib } from "sentinellist/SentinelList.sol"; +import { PreValidationHookManager } from "./core/PreValidationHookManager.sol"; /** * @author zeroknots.eth | rhinestone.wtf @@ -29,6 +30,7 @@ contract MSAAdvanced is ExecutionHelper, ModuleManager, HookManager, + PreValidationHookManager, RegistryAdapter, ERC7779Adapter { @@ -182,11 +184,22 @@ contract MSAAdvanced is { if (!IModule(module).isModuleType(moduleTypeId)) revert MismatchModuleTypeId(moduleTypeId); - if (moduleTypeId == MODULE_TYPE_VALIDATOR) _installValidator(module, initData); - else if (moduleTypeId == MODULE_TYPE_EXECUTOR) _installExecutor(module, initData); - else if (moduleTypeId == MODULE_TYPE_FALLBACK) _installFallbackHandler(module, initData); - else if (moduleTypeId == MODULE_TYPE_HOOK) _installHook(module, initData); - else revert UnsupportedModuleType(moduleTypeId); + if (moduleTypeId == MODULE_TYPE_VALIDATOR) { + _installValidator(module, initData); + } else if (moduleTypeId == MODULE_TYPE_EXECUTOR) { + _installExecutor(module, initData); + } else if (moduleTypeId == MODULE_TYPE_FALLBACK) { + _installFallbackHandler(module, initData); + } else if (moduleTypeId == MODULE_TYPE_HOOK) { + _installHook(module, initData); + } else if ( + moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 + || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 + ) { + _installPreValidationHook(module, moduleTypeId, initData); + } else { + revert UnsupportedModuleType(moduleTypeId); + } emit ModuleInstalled(moduleTypeId, module); } @@ -211,6 +224,11 @@ contract MSAAdvanced is _uninstallFallbackHandler(module, deInitData); } else if (moduleTypeId == MODULE_TYPE_HOOK) { _uninstallHook(module, deInitData); + } else if ( + moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 + || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 + ) { + _uninstallPreValidationHook(module, moduleTypeId, deInitData); } else { revert UnsupportedModuleType(moduleTypeId); } @@ -261,6 +279,8 @@ contract MSAAdvanced is return VALIDATION_FAILED; } } else { + (userOpHash, userOp.signature) = + _withPreValidationHook(userOpHash, userOp, missingAccountFunds); // bubble up the return value of the validator module validSignature = IValidator(validator).validateUserOp(userOp, userOpHash); } @@ -286,7 +306,9 @@ contract MSAAdvanced is { address validator = address(bytes20(data[0:20])); if (!_isValidatorInstalled(validator)) revert InvalidModule(validator); - return IValidator(validator).isValidSignatureWithSender(msg.sender, hash, data[20:]); + bytes memory signature_; + (hash, signature_) = _withPreValidationHook(hash, data[20:]); + return IValidator(validator).isValidSignatureWithSender(msg.sender, hash, signature_); } /** @@ -310,6 +332,11 @@ contract MSAAdvanced is return _isFallbackHandlerInstalled(abi.decode(additionalContext, (bytes4)), module); } else if (moduleTypeId == MODULE_TYPE_HOOK) { return _isHookInstalled(module); + } else if ( + moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 + || moduleTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 + ) { + return _isPreValidationHookInstalled(module, moduleTypeId); } else { return false; } @@ -354,6 +381,10 @@ contract MSAAdvanced is else if (modulTypeId == MODULE_TYPE_EXECUTOR) return true; else if (modulTypeId == MODULE_TYPE_FALLBACK) return true; else if (modulTypeId == MODULE_TYPE_HOOK) return true; + else if ( + modulTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 + || modulTypeId == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 + ) return true; else return false; } diff --git a/src/core/PreValidationHookManager.sol b/src/core/PreValidationHookManager.sol new file mode 100644 index 0000000..f672d82 --- /dev/null +++ b/src/core/PreValidationHookManager.sol @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.21; + +import "./ModuleManager.sol"; +import "../interfaces/IERC7579Account.sol"; +import "../interfaces/IERC7579Module.sol"; + +/** + * @title reference implementation of PreValidationHookManager + * @author highskore | rhinestone.wtf + */ +abstract contract PreValidationHookManager { + event PreValidationHookUninstallFailed(address hook, bytes data); + + error InvalidHookType(); + + /// @custom:storage-location erc7201:prevalidationhookmanager.storage.msa + struct PreValidationHookManagerStorage { + IPreValidationHookERC1271 hook1271; + IPreValidationHookERC4337 hook4337; + } + + // forgefmt: disable-next-line + // keccak256(abi.encode(uint256(keccak256("prevalidationhookmanager.storage.msa")) - 1)) & ~bytes32(uint256(0xff)); + bytes32 constant PREVALIDATION_HOOKMANAGER_STORAGE_LOCATION = + 0x088e45215d3756b04bd240e41d75700a696139d5b53082481ffc3914e4840000; + + error PreValidationHookAlreadyInstalled(address currentHook); + + function _getStorage() + internal + pure + returns (PreValidationHookManagerStorage storage storage_) + { + bytes32 slot = PREVALIDATION_HOOKMANAGER_STORAGE_LOCATION; + assembly { + storage_.slot := slot + } + } + + function _setPreValidationHook(address hook, uint256 hookType) internal virtual { + PreValidationHookManagerStorage storage $ = _getStorage(); + if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271) { + $.hook1271 = IPreValidationHookERC1271(hook); + } else if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337) { + $.hook4337 = IPreValidationHookERC4337(hook); + } else { + revert InvalidHookType(); + } + } + + function _installPreValidationHook( + address hook, + uint256 hookType, + bytes calldata data + ) + internal + virtual + { + PreValidationHookManagerStorage storage $ = _getStorage(); + address currentHook = _getPreValidationHook(hookType); + if (currentHook != address(0)) { + revert PreValidationHookAlreadyInstalled(currentHook); + } + _setPreValidationHook(hook, hookType); + if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271) { + $.hook1271.onInstall(data); + } else if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337) { + $.hook4337.onInstall(data); + } + } + + function _uninstallPreValidationHook( + address hook, + uint256 hookType, + bytes calldata data + ) + internal + virtual + { + PreValidationHookManagerStorage storage $ = _getStorage(); + if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 && address($.hook1271) == hook) { + $.hook1271.onUninstall(data); + } else if ( + hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 && address($.hook4337) == hook + ) { + $.hook4337.onUninstall(data); + } else { + revert InvalidHookType(); + } + _setPreValidationHook(address(0), hookType); + } + + function _tryUninstallPreValidationHook(address hook, uint256 hookType) internal virtual { + PreValidationHookManagerStorage storage $ = _getStorage(); + if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271) { + try $.hook1271.onUninstall("") { } + catch { + emit PreValidationHookUninstallFailed(hook, ""); + } + $.hook1271 = IPreValidationHookERC1271(address(0)); + } else if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337) { + try $.hook4337.onUninstall("") { } + catch { + emit PreValidationHookUninstallFailed(hook, ""); + } + $.hook4337 = IPreValidationHookERC4337(address(0)); + } else { + revert InvalidHookType(); + } + } + + function _getPreValidationHook(uint256 hookType) internal view returns (address _hook) { + PreValidationHookManagerStorage storage $ = _getStorage(); + if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC1271) { + return address($.hook1271); + } else if (hookType == MODULE_TYPE_PREVALIDATION_HOOK_ERC4337) { + return address($.hook4337); + } else { + revert InvalidHookType(); + } + } + + function _isPreValidationHookInstalled( + address module, + uint256 hookType + ) + internal + view + returns (bool) + { + return _getPreValidationHook(hookType) == module; + } + + function getActiveHook(uint256 hookType) external view returns (address hook) { + return _getPreValidationHook(hookType); + } + + function _withPreValidationHook( + bytes32 hash, + bytes calldata signature + ) + internal + view + virtual + returns (bytes32 postHash, bytes memory postSig) + { + address preValidationHook = _getPreValidationHook(MODULE_TYPE_PREVALIDATION_HOOK_ERC1271); + if (preValidationHook == address(0)) { + return (hash, signature); + } else { + return IPreValidationHookERC1271(preValidationHook).preValidationHookERC1271( + msg.sender, hash, signature + ); + } + } + + function _withPreValidationHook( + bytes32 hash, + PackedUserOperation memory userOp, + uint256 missingAccountFunds + ) + internal + virtual + returns (bytes32 postHash, bytes memory postSig) + { + address preValidationHook = _getPreValidationHook(MODULE_TYPE_PREVALIDATION_HOOK_ERC4337); + if (preValidationHook == address(0)) { + return (hash, userOp.signature); + } else { + return IPreValidationHookERC4337(preValidationHook).preValidationHookERC4337( + userOp, missingAccountFunds, hash + ); + } + } +} diff --git a/src/interfaces/IERC7579Module.sol b/src/interfaces/IERC7579Module.sol index 09480a4..a2f7a3d 100644 --- a/src/interfaces/IERC7579Module.sol +++ b/src/interfaces/IERC7579Module.sol @@ -10,6 +10,8 @@ uint256 constant MODULE_TYPE_VALIDATOR = 1; uint256 constant MODULE_TYPE_EXECUTOR = 2; uint256 constant MODULE_TYPE_FALLBACK = 3; uint256 constant MODULE_TYPE_HOOK = 4; +uint256 constant MODULE_TYPE_PREVALIDATION_HOOK_ERC1271 = 8; +uint256 constant MODULE_TYPE_PREVALIDATION_HOOK_ERC4337 = 9; interface IModule { error AlreadyInitialized(address smartAccount); @@ -95,3 +97,25 @@ interface IHook is IModule { } interface IFallback is IModule { } + +interface IPreValidationHookERC1271 is IModule { + function preValidationHookERC1271( + address sender, + bytes32 hash, + bytes calldata data + ) + external + view + returns (bytes32 hookHash, bytes memory hookSignature); +} + +interface IPreValidationHookERC4337 is IModule { + function preValidationHookERC4337( + PackedUserOperation calldata userOp, + uint256 missingAccountFunds, + bytes32 userOpHash + ) + external + view + returns (bytes32 hookHash, bytes memory hookSignature); +}