From 216fbdfc74e82a7533f30e30503b4fe5be206bb7 Mon Sep 17 00:00:00 2001 From: James Levy Date: Mon, 29 Sep 2025 15:45:11 -0700 Subject: [PATCH] feat: Micropayment channel extension for AP2 This enhancement extends AP2 with a comprehensive micropayment channel framework that enables high-frequency, sub-cent transactions between agents through state channels, supporting Kite's vision of pay-per-use and streaming payment models. Key Features: - Payment channel infrastructure with state management - Streaming payment primitives for real-time micropayments - Multi-currency support including stablecoins (USDC, PYUSD) - Channel security framework with dispute resolution - Pay-per-token, pay-per-second, and pay-per-API-call models - A2A extension integration for agent capability advertising Components Added: - PaymentChannel with participant management and policies - StreamingPaymentSession for continuous payment flows - ChannelManager for lifecycle and security operations - CryptoPaymentAmount for blockchain-native currencies - AI inference service sample with micropayment integration - Extended A2A roles: micropayment-provider, streaming-payment-consumer This enables use cases like: - Pay-per-API-call for AI inference services - Streaming payments for real-time data feeds - Usage-based pricing for computational resources - Micro-subscriptions with automatic scaling - Cross-agent value transfer in multi-step workflows Transforms AP2 from traditional e-commerce into a comprehensive framework for the agent economy, enabling ultra-high-frequency programmable payment flows while maintaining security and interoperability. --- .cspell/custom-words.txt | 147 +++-- docs/a2a-extension.md | 120 +++- .../ai_inference_service/__init__.py | 15 + .../ai_inference_service/service.py | 562 ++++++++++++++++++ .../python/src/common/a2a_extension_utils.py | 2 +- .../python/src/common/a2a_message_builder.py | 137 ++--- samples/python/src/common/artifact_utils.py | 89 +-- .../python/src/common/base_server_executor.py | 296 ++++----- .../src/common/function_call_resolver.py | 131 ++-- .../src/common/payment_remote_a2a_client.py | 174 +++--- .../python/src/common/retrying_llm_agent.py | 76 +-- samples/python/src/common/server.py | 330 +++++----- samples/python/src/common/validation.py | 24 +- samples/python/src/common/watch_log.py | 120 ++-- .../credentials_provider_agent/__main__.py | 29 +- .../account_manager.py | 274 +++++---- .../agent_executor.py | 44 +- .../roles/credentials_provider_agent/tools.py | 353 ++++++----- .../src/roles/merchant_agent/__main__.py | 26 +- .../roles/merchant_agent/agent_executor.py | 176 +++--- .../src/roles/merchant_agent/storage.py | 22 +- .../sub_agents/catalog_agent.py | 207 ++++--- .../python/src/roles/merchant_agent/tools.py | 419 ++++++------- .../__main__.py | 28 +- .../agent_executor.py | 27 +- .../merchant_payment_processor_agent/tools.py | 325 +++++----- .../python/src/roles/shopping_agent/agent.py | 18 +- .../src/roles/shopping_agent/remote_agents.py | 8 +- .../payment_method_collector/agent.py | 10 +- .../payment_method_collector/tools.py | 110 ++-- .../shipping_address_collector/agent.py | 11 +- .../shipping_address_collector/tools.py | 53 +- .../shopping_agent/subagents/shopper/agent.py | 10 +- .../shopping_agent/subagents/shopper/tools.py | 182 +++--- .../python/src/roles/shopping_agent/tools.py | 477 ++++++++------- src/ap2/channels/__init__.py | 15 + src/ap2/channels/channel_manager.py | 425 +++++++++++++ src/ap2/types/contact_picker.py | 37 +- src/ap2/types/mandate.py | 277 ++++----- src/ap2/types/payment_channels.py | 428 +++++++++++++ src/ap2/types/payment_request.py | 361 +++++------ src/ap2/types/streaming_payments.py | 490 +++++++++++++++ 42 files changed, 4637 insertions(+), 2428 deletions(-) create mode 100644 samples/python/scenarios/a2a/micropayments/ai_inference_service/__init__.py create mode 100644 samples/python/scenarios/a2a/micropayments/ai_inference_service/service.py create mode 100644 src/ap2/channels/__init__.py create mode 100644 src/ap2/channels/channel_manager.py create mode 100644 src/ap2/types/payment_channels.py create mode 100644 src/ap2/types/streaming_payments.py diff --git a/.cspell/custom-words.txt b/.cspell/custom-words.txt index 71b8fff4..f58a18bf 100644 --- a/.cspell/custom-words.txt +++ b/.cspell/custom-words.txt @@ -1,112 +1,159 @@ +ACMRTUXB +ASGI +Adyen +Agentic +BVNK +Blackhawk +CLASSPATH +CYGPATTERN +Crossmint +DCQL +DID +DIDs +DPAN +Dafiti +Dcql +Dfile +Dorg +Drawables +Ebanx +Fiuu +Forter +Fudd's +Garena +Gravitee +Imtp +JAVACMD +Jetpack +KXMYBJWNQ +Kotlinx +Ktor +Lazada +Lightspark +MSYS +Mispick +Monee +Mysten +OURCYGPATTERN +Otherville +PYUSD +Payoneer +Proguard +ROOTDIRS +ROOTDIRSRAW +Revolut +Rulebook +Shopcider +Shopee +USDC +Wallex +Worldline +Worldpay +XVCJ +Xdock +Zalora aapt absl achatassistant -ACMRTUXB -Adyen agentic -Agentic agenticpayments androidx appname -ASGI -Blackhawk -BVNK +attestation +autopayments +blacklist +blockchain classpath -CLASSPATH cmwallet contentnegotiation credman -Crossmint +cryptographic cryptographical -CYGPATTERN -Dafiti dcql -Dcql -DCQL +decentralized deviceauth -Dfile -Dorg -DPAN -Drawables -Ebanx emvco endlocal +ephemeral esac -Fiuu +ethereum +ethr fontawesome -Forter fqcn -Fudd's -Garena +fromisoformat gemini genai generativeai gradlew -Gravitee gson +gte icns imei -Imtp inlinehilite inmemory +interoperability +isoformat issuerauth -JAVACMD jetbrains -Jetpack jvmargs keepattributes keepclassmembers kotlin kotlinx -Kotlinx ktor -Ktor -KXMYBJWNQ -Lazada -Lightspark linenums +lte mastercard +mbps +micropayment micropayments -Mispick -Monee +millisecond +milliseconds msys -MSYS +multi +multicurrency multistep -Mysten +nanosecond +nanoseconds +ne +nonce +nonces octicons okhttp -Otherville -OURCYGPATTERN passcodes -Payoneer +payee +payer paypal pids +polygon +prepayments +programmable proguard -Proguard pymdownx reemademo refundability +regex renamesourcefileattribute representment repudiable -Revolut -ROOTDIRS -ROOTDIRSRAW -Rulebook +repudiation +revocable screenreaders setlocal sharedpref -Shopcider -Shopee sideloaded skus +solana +stablecoin stablecoins +streamable +subagent +subagents superfences +timestamped +verifiable viewmodel +voucher +vouchers vulnz -Wallex -Worldline -Worldpay -Xdock -XVCJ -Zalora +whitelist diff --git a/docs/a2a-extension.md b/docs/a2a-extension.md index f715293c..cd7a3a9c 100644 --- a/docs/a2a-extension.md +++ b/docs/a2a-extension.md @@ -40,7 +40,7 @@ schema: "description": "The roles that this agent performs in the AP2 model.", "minItems": 1, "items": { - "enum": ["merchant", "shopper", "credentials-provider", "payment-processor"] + "enum": ["merchant", "shopper", "credentials-provider", "payment-processor", "micropayment-provider", "streaming-payment-consumer"] } } }, @@ -51,7 +51,7 @@ schema: This schema is also expressed by the following Pydantic type definition: ```py -AP2Role = "merchant" | "shopper" | "credentials-provider" | "payment-processor" +AP2Role = "merchant" | "shopper" | "credentials-provider" | "payment-processor" | "micropayment-provider" | "streaming-payment-consumer" class AP2ExtensionParameters(BaseModel): # The roles this agent performs in the AP2 model. At least one value is required. @@ -97,6 +97,122 @@ The following listing shows an AgentCard declaring AP2 extension support. } ``` +## Micropayment Channel Capabilities + +Agents that support micropayment channels can advertise their capabilities through additional parameters in the AgentCard extension. This enables high-frequency, sub-cent transactions for AI inference, data streaming, and other pay-per-use services. + +### Micropayment Provider Example + +```json +{ + "name": "AI Inference Service", + "description": "High-performance AI model with pay-per-token pricing", + "capabilities": { + "extensions": [ + { + "uri": "https://github.com/google-agentic-commerce/ap2/tree/v0.1", + "description": "Micropayment channels for AI inference", + "params": { + "roles": ["merchant", "micropayment-provider"], + "payment_channels": { + "supported": true, + "min_deposit": "1.0", + "rate_per_call": "0.001", + "rate_per_token": "0.0001", + "supported_currencies": ["USDC", "PYUSD", "ETH"], + "blockchain_networks": ["ethereum", "polygon", "kite"], + "max_channel_duration": "86400", + "checkpoint_frequency": "30", + "streaming_payments": true + } + } + } + ] + } +} +``` + +### Streaming Payment Consumer Example + +```json +{ + "name": "Data Analytics Agent", + "description": "Consumes real-time data feeds with streaming payments", + "capabilities": { + "extensions": [ + { + "uri": "https://github.com/google-agentic-commerce/ap2/tree/v0.1", + "description": "Streaming payments for data consumption", + "params": { + "roles": ["shopper", "streaming-payment-consumer"], + "payment_streams": { + "max_concurrent_streams": 5, + "supported_rate_types": ["per_second", "per_byte", "per_request"], + "auto_pause_threshold": "10.0", + "preferred_currencies": ["USDC", "DAI"], + "quality_requirements": { + "max_latency_ms": 100, + "min_throughput_mbps": 10 + } + } + } + } + ] + } +} +``` + +### Payment Channel Parameters Schema + +The `payment_channels` object in the extension params supports the following schema: + +```json +{ + "type": "object", + "name": "PaymentChannelCapabilities", + "properties": { + "supported": { + "type": "boolean", + "description": "Whether payment channels are supported" + }, + "min_deposit": { + "type": "string", + "description": "Minimum deposit required to open a channel" + }, + "rate_per_call": { + "type": "string", + "description": "Cost per API call or request" + }, + "rate_per_token": { + "type": "string", + "description": "Cost per AI token for inference services" + }, + "supported_currencies": { + "type": "array", + "items": {"type": "string"}, + "description": "List of supported payment currencies" + }, + "blockchain_networks": { + "type": "array", + "items": {"type": "string"}, + "description": "Supported blockchain networks" + }, + "max_channel_duration": { + "type": "string", + "description": "Maximum channel duration in seconds" + }, + "checkpoint_frequency": { + "type": "string", + "description": "How often to create payment checkpoints (seconds)" + }, + "streaming_payments": { + "type": "boolean", + "description": "Whether streaming payments are supported" + } + } +} +``` + ## AP2 Data Type Containers The following sections describe how AP2 data types are encapsulated into A2A diff --git a/samples/python/scenarios/a2a/micropayments/ai_inference_service/__init__.py b/samples/python/scenarios/a2a/micropayments/ai_inference_service/__init__.py new file mode 100644 index 00000000..6dbf2349 --- /dev/null +++ b/samples/python/scenarios/a2a/micropayments/ai_inference_service/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AI inference service with micropayment channels.""" diff --git a/samples/python/scenarios/a2a/micropayments/ai_inference_service/service.py b/samples/python/scenarios/a2a/micropayments/ai_inference_service/service.py new file mode 100644 index 00000000..9656b654 --- /dev/null +++ b/samples/python/scenarios/a2a/micropayments/ai_inference_service/service.py @@ -0,0 +1,562 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AI Inference Service with micropayment channels. + +This sample demonstrates how to build an AI service that accepts micropayments +for inference requests, supporting pay-per-token and streaming payment models. + +Example usage: + python service.py --port 8080 --model "gpt-3.5-turbo" --rate-per-token 0.0001 +""" + +import argparse +import asyncio +import uuid + +from typing import Any + +from ap2.channels.channel_manager import ChannelManager, ChannelOperationResult +from ap2.types.payment_channels import ( + ChannelOpenRequest, + ChannelParticipant, + ChannelPolicy, +) +from ap2.types.payment_request import CryptoPaymentAmount +from ap2.types.streaming_payments import ( + PaymentRate, + PaymentRateType, + StreamStatus, + StreamingPaymentManager, + StreamingPaymentPolicy, + StreamingPaymentSession, +) + + +class AIInferenceService: + """AI inference service with micropayment channel support.""" + + def __init__( + self, + service_did: str, + model_name: str, + rate_per_token: float, + currency: str = 'USDC', + blockchain_network: str = 'kite', + ): + """Initialize the AI inference service. + + Args: + service_did: DID of the service + model_name: Name of the AI model being served + rate_per_token: Cost per token in the specified currency + currency: Payment currency (default: USDC) + blockchain_network: Blockchain network (default: kite) + """ + self.service_did = service_did + self.model_name = model_name + self.currency = currency + self.blockchain_network = blockchain_network + + # Initialize payment components + self.channel_manager = ChannelManager( + manager_id=f'channel_mgr_{service_did}', agent_did=service_did + ) + + self.streaming_manager = StreamingPaymentManager( + manager_id=f'stream_mgr_{service_did}', agent_did=service_did + ) + + # Service configuration + self.payment_rate = PaymentRate( + rate_type=PaymentRateType.PER_TOKEN, + rate_amount=CryptoPaymentAmount( + currency=currency, + value=rate_per_token, + blockchain_network=blockchain_network, + decimal_places=6, # USDC has 6 decimal places + ), + minimum_charge=CryptoPaymentAmount( + currency=currency, + value=0.001, # Minimum charge of 0.001 USDC + blockchain_network=blockchain_network, + decimal_places=6, + ), + billing_frequency_seconds=1, + unit_description='AI model tokens', + ) + + self.service_policy = StreamingPaymentPolicy( + max_stream_duration_seconds=3600, # 1 hour max + checkpoint_frequency_seconds=30, # Checkpoint every 30 seconds + auto_pause_threshold=CryptoPaymentAmount( + currency=currency, + value=10.0, # Auto-pause at $10 + blockchain_network=blockchain_network, + decimal_places=6, + ), + max_cumulative_amount=CryptoPaymentAmount( + currency=currency, + value=100.0, # Max $100 per stream + blockchain_network=blockchain_network, + decimal_places=6, + ), + rate_adjustment_allowed=False, + ) + + # Service statistics + self.total_requests = 0 + self.total_tokens_processed = 0 + self.total_revenue = 0.0 + self.active_clients = {} + + async def handle_channel_open_request( + self, client_did: str, client_wallet: str, initial_deposit: float + ) -> ChannelOperationResult: + """Handle a request to open a payment channel.""" + print(f'๐Ÿ“จ Channel open request from {client_did}') + + # Create participant for client + client_participant = ChannelParticipant( + participant_id=client_did, + agent_did=client_did, + wallet_address=client_wallet, + role='payer', + public_key=f'pubkey_{client_did}', # Mock public key + initial_balance=CryptoPaymentAmount( + currency=self.currency, + value=initial_deposit, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + current_balance=CryptoPaymentAmount( + currency=self.currency, + value=initial_deposit, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + ) + + # Create participant for service + service_participant = ChannelParticipant( + participant_id=self.service_did, + agent_did=self.service_did, + wallet_address=f'0x{uuid.uuid4().hex[:40]}', # Mock service wallet + role='payee', + public_key=f'pubkey_{self.service_did}', # Mock public key + initial_balance=CryptoPaymentAmount( + currency=self.currency, + value=0.0, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + current_balance=CryptoPaymentAmount( + currency=self.currency, + value=0.0, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + ) + + # Create channel policy + policy = ChannelPolicy( + max_transaction_amount=CryptoPaymentAmount( + currency=self.currency, + value=1.0, # Max $1 per transaction + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + min_transaction_amount=CryptoPaymentAmount( + currency=self.currency, + value=0.0001, # Min $0.0001 per transaction + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + settlement_threshold=CryptoPaymentAmount( + currency=self.currency, + value=5.0, # Auto-settle at $5 + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + fee_rate=0.001, # 0.1% fee + auto_close_timeout=86400, # 24 hours + ) + + # Create channel open request + open_request = ChannelOpenRequest( + requesting_participant=client_participant, + target_participant=service_participant, + proposed_policy=policy, + duration_hours=24, # 24 hour channel + initial_deposit=CryptoPaymentAmount( + currency=self.currency, + value=initial_deposit, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + purpose=f'AI inference payments for {self.model_name}', + ) + + # Open the channel + result = self.channel_manager.open_channel(open_request) + + if result.success: + # Automatically activate the channel + activation_result = self.channel_manager.activate_channel( + result.channel_id + ) + if activation_result.success: + print(f'โœ… Channel {result.channel_id} opened and activated') + self.active_clients[client_did] = result.channel_id + else: + print( + f'โŒ Failed to activate channel: {activation_result.message}' + ) + + return result + + async def start_streaming_session( + self, client_did: str, service_description: str = 'AI inference tokens' + ) -> StreamingPaymentSession | None: + """Start a streaming payment session for a client.""" + channel_id = self.active_clients.get(client_did) + if not channel_id: + print(f'โŒ No active channel found for client {client_did}') + return None + + # Create streaming session + stream = self.streaming_manager.create_stream( + channel_id=channel_id, + payer_id=client_did, + payee_id=self.service_did, + service_description=service_description, + rate=self.payment_rate, + policy=self.service_policy, + ) + + stream.status = StreamStatus.ACTIVE + print( + f'๐ŸŒŠ Started streaming session {stream.stream_id} for {client_did}' + ) + return stream + + async def process_inference_request( + self, client_did: str, prompt: str, max_tokens: int = 100 + ) -> dict[str, Any]: + """Process an AI inference request with micropayment.""" + print(f'๐Ÿง  Processing inference request from {client_did}') + print(f' Prompt: {prompt[:50]}{"..." if len(prompt) > 50 else ""}') + print(f' Max tokens: {max_tokens}') + + # Get or create streaming session + active_streams = self.streaming_manager.get_streams_by_channel( + self.active_clients.get(client_did, '') + ) + + stream = None + if active_streams: + stream = active_streams[0] # Use existing stream + else: + stream = await self.start_streaming_session(client_did) + + if not stream: + return { + 'error': 'Failed to establish payment stream', + 'status': 'payment_required', + } + + # Simulate AI inference (mock implementation) + generated_text, actual_tokens = self._mock_ai_inference( + prompt, max_tokens + ) + + # Calculate payment for actual tokens used + payment_amount = self.payment_rate.rate_amount.value * actual_tokens + + # Check if payment can be processed + channel = self.channel_manager.active_channels.get(stream.channel_id) + if not channel: + return { + 'error': 'Payment channel not found', + 'status': 'payment_error', + } + + can_pay, reason = channel.can_process_payment( + client_did, + self.service_did, + CryptoPaymentAmount( + currency=self.currency, + value=payment_amount, + blockchain_network=self.blockchain_network, + decimal_places=6, + ), + ) + + if not can_pay: + stream.pause_stream(f'Payment failed: {reason}') + return { + 'error': f'Payment failed: {reason}', + 'status': 'insufficient_funds', + 'required_amount': payment_amount, + 'currency': self.currency, + } + + # Process payment through streaming voucher + try: + voucher = stream.add_voucher( + units_consumed=actual_tokens, + metadata={ + 'prompt_length': len(prompt), + 'response_length': len(generated_text), + 'model': self.model_name, + }, + ) + + # Process payment through channel + payment_result = self.channel_manager.process_payment( + channel_id=stream.channel_id, + from_participant=client_did, + to_participant=self.service_did, + amount=voucher.increment_amount, + metadata={ + 'stream_id': stream.stream_id, + 'voucher_id': voucher.voucher_id, + 'service_type': 'ai_inference', + }, + ) + + if not payment_result.success: + stream.pause_stream( + f'Payment processing failed: {payment_result.message}' + ) + return { + 'error': f'Payment processing failed: {payment_result.message}', + 'status': 'payment_error', + } + + # Update service statistics + self.total_requests += 1 + self.total_tokens_processed += actual_tokens + self.total_revenue += payment_amount + + print( + f'๐Ÿ’ฐ Payment processed: {payment_amount} {self.currency} for {actual_tokens} tokens' + ) + + return { + 'status': 'success', + 'response': generated_text, + 'tokens_used': actual_tokens, + 'cost': payment_amount, + 'currency': self.currency, + 'voucher_id': voucher.voucher_id, + 'stream_id': stream.stream_id, + } + + except Exception as e: + stream.pause_stream(f'Processing error: {e!s}') + return { + 'error': f'Processing error: {e!s}', + 'status': 'service_error', + } + + def _mock_ai_inference( + self, prompt: str, max_tokens: int + ) -> tuple[str, int]: + """Mock AI inference that generates a response and counts tokens.""" + # Simulate token processing (1 token โ‰ˆ 4 characters) + prompt_tokens = len(prompt) // 4 + + # Generate mock response + responses = [ + "I understand your request. Here's a helpful response based on the information provided.", + "Thank you for your query. I'll analyze this and provide you with a comprehensive answer.", + 'Based on your input, I can suggest several approaches to address this topic effectively.', + 'This is an interesting question that requires careful consideration of multiple factors.', + 'I appreciate your question. Let me break this down into manageable components for you.', + ] + + import random + + response = random.choice(responses) + + # Limit response to max_tokens (approximately) + max_response_chars = max_tokens * 4 + if len(response) > max_response_chars: + response = response[:max_response_chars] + '...' + + response_tokens = len(response) // 4 + total_tokens = prompt_tokens + response_tokens + + return response, total_tokens + + def get_service_stats(self) -> dict[str, Any]: + """Get service statistics.""" + active_channels = len(self.channel_manager.active_channels) + active_streams = len(self.streaming_manager.active_streams) + + return { + 'service_did': self.service_did, + 'model_name': self.model_name, + 'total_requests': self.total_requests, + 'total_tokens_processed': self.total_tokens_processed, + 'total_revenue': self.total_revenue, + 'currency': self.currency, + 'rate_per_token': self.payment_rate.rate_amount.value, + 'active_channels': active_channels, + 'active_streams': active_streams, + 'active_clients': list(self.active_clients.keys()), + } + + async def cleanup_expired_resources(self): + """Clean up expired channels and streams.""" + # Clean up expired channels + expired_channels = self.channel_manager.cleanup_expired_channels() + if expired_channels: + print(f'๐Ÿงน Cleaned up {len(expired_channels)} expired channels') + + # Clean up expired streams + expired_streams = self.streaming_manager.cleanup_expired_streams() + if expired_streams: + print(f'๐Ÿงน Cleaned up {len(expired_streams)} expired streams') + + # Remove inactive clients + inactive_clients = [] + for client_did, channel_id in self.active_clients.items(): + if channel_id not in self.channel_manager.active_channels: + inactive_clients.append(client_did) + + for client_did in inactive_clients: + del self.active_clients[client_did] + + +async def main(): + """Main function demonstrating the AI inference service.""" + parser = argparse.ArgumentParser( + description='AI Inference Service with Micropayments' + ) + parser.add_argument( + '--model', default='gpt-3.5-turbo', help='AI model name' + ) + parser.add_argument( + '--rate-per-token', type=float, default=0.0001, help='Rate per token' + ) + parser.add_argument('--currency', default='USDC', help='Payment currency') + parser.add_argument('--network', default='kite', help='Blockchain network') + + args = parser.parse_args() + + # Initialize service + service_did = f'did:kite:1:ai_service_{uuid.uuid4().hex[:8]}' + service = AIInferenceService( + service_did=service_did, + model_name=args.model, + rate_per_token=args.rate_per_token, + currency=args.currency, + blockchain_network=args.network, + ) + + print('๐Ÿค– AI Inference Service with Micropayments') + print('=' * 50) + print(f'Service DID: {service_did}') + print(f'Model: {args.model}') + print(f'Rate: {args.rate_per_token} {args.currency} per token') + print(f'Network: {args.network}') + print() + + # Simulate client interactions + clients = [ + { + 'did': f'did:kite:1:client_alice_{uuid.uuid4().hex[:8]}', + 'wallet': f'0x{uuid.uuid4().hex[:40]}', + 'deposit': 5.0, + }, + { + 'did': f'did:kite:1:client_bob_{uuid.uuid4().hex[:8]}', + 'wallet': f'0x{uuid.uuid4().hex[:40]}', + 'deposit': 10.0, + }, + ] + + # Setup channels for clients + for client in clients: + print(f'๐Ÿ”— Setting up channel for {client["did"][:20]}...') + result = await service.handle_channel_open_request( + client['did'], client['wallet'], client['deposit'] + ) + + if result.success: + print(f'โœ… Channel opened: {result.channel_id}') + else: + print(f'โŒ Failed to open channel: {result.message}') + + print() + + # Simulate inference requests + inference_requests = [ + { + 'client': clients[0]['did'], + 'prompt': 'What are the benefits of using micropayment channels for AI services?', + 'max_tokens': 150, + }, + { + 'client': clients[1]['did'], + 'prompt': 'Explain how streaming payments work in blockchain applications', + 'max_tokens': 200, + }, + { + 'client': clients[0]['did'], + 'prompt': 'How do agent-to-agent payments differ from traditional payment systems?', + 'max_tokens': 100, + }, + ] + + for i, request in enumerate(inference_requests, 1): + print(f'๐Ÿง  Processing request {i}/3...') + result = await service.process_inference_request( + request['client'], request['prompt'], request['max_tokens'] + ) + + if result.get('status') == 'success': + print( + f'โœ… Response: {result["response"][:80]}{"..." if len(result["response"]) > 80 else ""}' + ) + print( + f' Cost: {result["cost"]} {result["currency"]} for {result["tokens_used"]} tokens' + ) + else: + print(f'โŒ Error: {result.get("error", "Unknown error")}') + + print() + + # Small delay between requests + await asyncio.sleep(1) + + # Show final statistics + stats = service.get_service_stats() + print('๐Ÿ“Š Final Service Statistics:') + print(f' Total requests: {stats["total_requests"]}') + print(f' Total tokens: {stats["total_tokens_processed"]}') + print(f' Total revenue: {stats["total_revenue"]} {stats["currency"]}') + print(f' Active channels: {stats["active_channels"]}') + print(f' Active streams: {stats["active_streams"]}') + + # Cleanup + await service.cleanup_expired_resources() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/samples/python/src/common/a2a_extension_utils.py b/samples/python/src/common/a2a_extension_utils.py index af3343eb..2f94aedd 100644 --- a/samples/python/src/common/a2a_extension_utils.py +++ b/samples/python/src/common/a2a_extension_utils.py @@ -14,4 +14,4 @@ """Utility class for storing A2A related objects.""" -EXTENSION_URI = "https://github.com/google-agentic-commerce/ap2/v1" +EXTENSION_URI = 'https://github.com/google-agentic-commerce/ap2/v1' diff --git a/samples/python/src/common/a2a_message_builder.py b/samples/python/src/common/a2a_message_builder.py index 49f27b7c..8f3866c1 100644 --- a/samples/python/src/common/a2a_message_builder.py +++ b/samples/python/src/common/a2a_message_builder.py @@ -14,77 +14,78 @@ """A builder class for building an A2A Message object.""" -from typing import Any, Self import uuid +from typing import Any, Self + from a2a import types as a2a_types class A2aMessageBuilder: - """A builder class for building an A2A Message object.""" - - def __init__(self): - self._message = self._create_base_message() - - def add_text(self, text: str) -> Self: - """Adds a TextPart to the Message. - - Args: - text: The text to be added to the Message. - - Returns: - The A2aMessageBuilder instance. - """ - part = a2a_types.Part(root=a2a_types.TextPart(text=text)) - self._message.parts.append(part) - return self - - def add_data(self, key: str, data: str | dict[str, Any]) -> Self: - """Adds a new DataPart to the Message. - - If a key is provided, then the data part must be a string. The DataPart's - data dictionary will be set to { key: data}. - - If no key is provided, then the data part must be a dictionary. The - DataPart's data dictionary will be set to data. - - Args: - key: The key to use for the data part. - data: The data to accompany the key, if provided. Otherwise, the data to - be set within the DataPart object. - - Returns: - The A2aMessageBuilder instance. - """ - if not data: - return self - - nested_data = data - if key: - nested_data = {key: data} - - part = a2a_types.Part(root=a2a_types.DataPart(data=nested_data)) - self._message.parts.append(part) - return self - - def set_context_id(self, context_id: str) -> Self: - """Sets the context id on the Message.""" - self._message.context_id = context_id - return self - - def set_task_id(self, task_id: str) -> Self: - """Sets the task id on the Message.""" - self._message.task_id = task_id - return self - - def build(self) -> a2a_types.Message: - """Returns the Message object that has been built.""" - return self._message - - def _create_base_message(self) -> a2a_types.Message: - """Creates and returns a base Message object.""" - return a2a_types.Message( - message_id=uuid.uuid4().hex, - parts=[], - role=a2a_types.Role.agent, - ) + """A builder class for building an A2A Message object.""" + + def __init__(self): + self._message = self._create_base_message() + + def add_text(self, text: str) -> Self: + """Adds a TextPart to the Message. + + Args: + text: The text to be added to the Message. + + Returns: + The A2aMessageBuilder instance. + """ + part = a2a_types.Part(root=a2a_types.TextPart(text=text)) + self._message.parts.append(part) + return self + + def add_data(self, key: str, data: str | dict[str, Any]) -> Self: + """Adds a new DataPart to the Message. + + If a key is provided, then the data part must be a string. The DataPart's + data dictionary will be set to { key: data}. + + If no key is provided, then the data part must be a dictionary. The + DataPart's data dictionary will be set to data. + + Args: + key: The key to use for the data part. + data: The data to accompany the key, if provided. Otherwise, the data to + be set within the DataPart object. + + Returns: + The A2aMessageBuilder instance. + """ + if not data: + return self + + nested_data = data + if key: + nested_data = {key: data} + + part = a2a_types.Part(root=a2a_types.DataPart(data=nested_data)) + self._message.parts.append(part) + return self + + def set_context_id(self, context_id: str) -> Self: + """Sets the context id on the Message.""" + self._message.context_id = context_id + return self + + def set_task_id(self, task_id: str) -> Self: + """Sets the task id on the Message.""" + self._message.task_id = task_id + return self + + def build(self) -> a2a_types.Message: + """Returns the Message object that has been built.""" + return self._message + + def _create_base_message(self) -> a2a_types.Message: + """Creates and returns a base Message object.""" + return a2a_types.Message( + message_id=uuid.uuid4().hex, + parts=[], + role=a2a_types.Role.agent, + ) diff --git a/samples/python/src/common/artifact_utils.py b/samples/python/src/common/artifact_utils.py index ff98066c..3792f4dc 100644 --- a/samples/python/src/common/artifact_utils.py +++ b/samples/python/src/common/artifact_utils.py @@ -20,59 +20,62 @@ from a2a.utils import message as message_utils from pydantic import BaseModel -T = TypeVar("T") + +T = TypeVar('T') def find_canonical_objects( artifacts: list[Artifact], data_key: str, model: BaseModel ) -> list[BaseModel]: - """Finds all canonical objects of the given type in the artifacts. - - Args: - artifacts: a list of the artifacts to be searched. - data_key: The key of the DataPart to search for. - model: The model of the canonical object to search for. - - Returns: - A list of canonical objects of the given type in the artifacts. - """ - canonical_objects = [] - for artifact in artifacts: - for part in artifact.parts: - if hasattr(part.root, "data") and data_key in part.root.data: - canonical_objects.append(model.model_validate(part.root.data[data_key])) - return canonical_objects + """Finds all canonical objects of the given type in the artifacts. + + Args: + artifacts: a list of the artifacts to be searched. + data_key: The key of the DataPart to search for. + model: The model of the canonical object to search for. + + Returns: + A list of canonical objects of the given type in the artifacts. + """ + canonical_objects = [] + for artifact in artifacts: + for part in artifact.parts: + if hasattr(part.root, 'data') and data_key in part.root.data: + canonical_objects.append( + model.model_validate(part.root.data[data_key]) + ) + return canonical_objects def get_first_data_part(artifacts: list[Artifact]) -> dict[str, Any]: - """Returns the first DataPart encountered in all the given artifacts. + """Returns the first DataPart encountered in all the given artifacts. - Args: - artifacts: The artifacts to be searched for a DataPart. + Args: + artifacts: The artifacts to be searched for a DataPart. - Returns: - The data contents within the first found DataPart. - """ - data_parts = [ - message_utils.get_data_parts(artifact.parts) for artifact in artifacts - ] - for data_part in data_parts: - for item in data_part: - return item - return {} + Returns: + The data contents within the first found DataPart. + """ + data_parts = [ + message_utils.get_data_parts(artifact.parts) for artifact in artifacts + ] + for data_part in data_parts: + for item in data_part: + return item + return {} def only(list_: list[T]) -> T: - """Returns the only element in a list. - - Args: - list_: The list expected to contain exactly one element. - - Raises: - ValueError: if the list is empty or has more than one element. - """ - if not list_: - raise ValueError("List is empty.") - if len(list_) > 1: - raise ValueError("List has more than one element.") - return list_[0] + """Returns the only element in a list. + + Args: + list_: The list expected to contain exactly one element. + + Raises: + ValueError: if the list is empty or has more than one element. + """ + if not list_: + raise ValueError('List is empty.') + if len(list_) > 1: + raise ValueError('List has more than one element.') + return list_[0] diff --git a/samples/python/src/common/base_server_executor.py b/samples/python/src/common/base_server_executor.py index 4b0410c5..87744733 100644 --- a/samples/python/src/common/base_server_executor.py +++ b/samples/python/src/common/base_server_executor.py @@ -25,167 +25,171 @@ import abc import logging -from typing import Any, Callable, Tuple import uuid +from collections.abc import Callable +from typing import Any + from a2a.server.agent_execution.agent_executor import AgentExecutor from a2a.server.agent_execution.context import RequestContext from a2a.server.events.event_queue import EventQueue from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import Part -from a2a.types import Task -from a2a.types import TextPart +from a2a.types import Part, Task, TextPart from a2a.utils import message -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY from google import genai -from ap2.types.mandate import PaymentMandate -from common import message_utils -from common import watch_log + +from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY, PaymentMandate +from common import message_utils, watch_log from common.a2a_extension_utils import EXTENSION_URI from common.function_call_resolver import FunctionCallResolver from common.validation import validate_payment_mandate_signature + DataPartContent = dict[str, Any] Tool = Callable[[list[DataPartContent], TaskUpdater, Task | None], Any] + class BaseServerExecutor(AgentExecutor, abc.ABC): - """A baseline A2A AgentExecutor to be utilized by agents.""" - - def __init__( - self, - supported_extensions: list[dict[str, Any]] | None, - tools: list[Tool], - system_prompt: str = "You are a helpful assistant.", - ): - """Initialization. - - Args: - supported_extensions: Extensions the agent declares that it supports. - tools: Tools supported by the agent. - system_prompt: Helps steer the model when choosing tools. - """ - if supported_extensions is not None: - self._supported_extension_uris = {ext.uri for ext in supported_extensions} - else: - self._supported_extension_uris = set() - self._client = genai.Client() - self._tools = tools - self._tool_resolver = FunctionCallResolver( - self._client, self._tools, system_prompt - ) - super().__init__() - - async def execute( - self, context: RequestContext, event_queue: EventQueue - ) -> None: - """Execute the agent's logic for a given request context. - - Args: - context: The request context containing the message, task ID, etc. - event_queue: The queue to publish events to. - """ - watch_log.log_a2a_request_extensions(context) - - text_parts, data_parts = self._parse_request(context) - watch_log.log_a2a_message_parts(text_parts, data_parts) - - self._handle_extensions(context) - - if EXTENSION_URI in context.call_context.activated_extensions: - payment_mandate = message_utils.find_data_part( - PAYMENT_MANDATE_DATA_KEY, data_parts - ) - if payment_mandate is not None: - validate_payment_mandate_signature( - PaymentMandate.model_validate(payment_mandate) + """A baseline A2A AgentExecutor to be utilized by agents.""" + + def __init__( + self, + supported_extensions: list[dict[str, Any]] | None, + tools: list[Tool], + system_prompt: str = 'You are a helpful assistant.', + ): + """Initialization. + + Args: + supported_extensions: Extensions the agent declares that it supports. + tools: Tools supported by the agent. + system_prompt: Helps steer the model when choosing tools. + """ + if supported_extensions is not None: + self._supported_extension_uris = { + ext.uri for ext in supported_extensions + } + else: + self._supported_extension_uris = set() + self._client = genai.Client() + self._tools = tools + self._tool_resolver = FunctionCallResolver( + self._client, self._tools, system_prompt + ) + super().__init__() + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + """Execute the agent's logic for a given request context. + + Args: + context: The request context containing the message, task ID, etc. + event_queue: The queue to publish events to. + """ + watch_log.log_a2a_request_extensions(context) + + text_parts, data_parts = self._parse_request(context) + watch_log.log_a2a_message_parts(text_parts, data_parts) + + self._handle_extensions(context) + + if EXTENSION_URI in context.call_context.activated_extensions: + payment_mandate = message_utils.find_data_part( + PAYMENT_MANDATE_DATA_KEY, data_parts + ) + if payment_mandate is not None: + validate_payment_mandate_signature( + PaymentMandate.model_validate(payment_mandate) + ) + else: + raise ValueError( + 'Payment extension not activated.' + f' {context.call_context.activated_extensions}' + ) + + updater = TaskUpdater( + event_queue, + task_id=context.task_id or str(uuid.uuid4()), + context_id=context.context_id or str(uuid.uuid4()), ) - else: - raise ValueError( - "Payment extension not activated." - f" {context.call_context.activated_extensions}" - ) - - updater = TaskUpdater( - event_queue, - task_id=context.task_id or str(uuid.uuid4()), - context_id=context.context_id or str(uuid.uuid4()), - ) - - logging.info( - "Server working on (context_id, task_id): (%s, %s)", - updater.context_id, - updater.task_id, - ) - await self._handle_request( - text_parts, - data_parts, - updater, - context.current_task, - ) - - async def cancel(self, context: RequestContext) -> None: - """Request the agent to cancel an ongoing task.""" - pass - - async def _handle_request( - self, - text_parts: list[str], - data_parts: list[dict[str, Any]], - updater: TaskUpdater, - current_task: Task | None, - ) -> None: - """Receives a parsed request and dispatches to the appropriate tool. - - Args: - text_parts: A list of text parts from the request. - data_parts: A list of data parts from the request. - updater: The TaskUpdater instance for updating the task. - current_task: The current Task, if available. - """ - try: - prompt = (text_parts[0] if text_parts else "").strip() - tool_name = self._tool_resolver.determine_tool_to_use(prompt) - logging.info("Using tool: %s", tool_name) - - matching_tools = list( - filter(lambda tool: tool.__name__ == tool_name, self._tools) - ) - if len(matching_tools) != 1: - raise ValueError( - f"Expected 1 tool matching {tool_name}, got {len(matching_tools)}" + + logging.info( + 'Server working on (context_id, task_id): (%s, %s)', + updater.context_id, + updater.task_id, + ) + await self._handle_request( + text_parts, + data_parts, + updater, + context.current_task, + ) + + async def cancel(self, context: RequestContext) -> None: + """Request the agent to cancel an ongoing task.""" + + async def _handle_request( + self, + text_parts: list[str], + data_parts: list[dict[str, Any]], + updater: TaskUpdater, + current_task: Task | None, + ) -> None: + """Receives a parsed request and dispatches to the appropriate tool. + + Args: + text_parts: A list of text parts from the request. + data_parts: A list of data parts from the request. + updater: The TaskUpdater instance for updating the task. + current_task: The current Task, if available. + """ + try: + prompt = (text_parts[0] if text_parts else '').strip() + tool_name = self._tool_resolver.determine_tool_to_use(prompt) + logging.info('Using tool: %s', tool_name) + + matching_tools = list( + filter(lambda tool: tool.__name__ == tool_name, self._tools) + ) + if len(matching_tools) != 1: + raise ValueError( + f'Expected 1 tool matching {tool_name}, got {len(matching_tools)}' + ) + callable_tool = matching_tools[0] + await callable_tool(data_parts, updater, current_task) + + except Exception as e: # pylint: disable=broad-exception-caught + error_message = updater.new_agent_message( + parts=[Part(root=TextPart(text=f'An error occurred: {e}'))] + ) + await updater.failed(message=error_message) + + def _parse_request( + self, context: RequestContext + ) -> tuple[list[str], list[dict[str, Any]]]: + """Parses the request and returns the text and data parts. + + Args: + context: The A2A RequestContext + + Returns: + A tuple containing the contents of TextPart and DataPart objects. + """ + parts = context.message.parts if context.message else [] + text_parts = message.get_text_parts(parts) + data_parts = message.get_data_parts(parts) + return text_parts, data_parts + + def _handle_extensions(self, context: RequestContext) -> None: + """Activates any requested extensions that the agent supports. + + Args: + context: The A2A RequestContext + """ + requested_uris = context.requested_extensions + activated_uris = requested_uris.intersection( + self._supported_extension_uris ) - callable_tool = matching_tools[0] - await callable_tool(data_parts, updater, current_task) - - except Exception as e: # pylint: disable=broad-exception-caught - error_message = updater.new_agent_message( - parts=[Part(root=TextPart(text=f"An error occurred: {e}"))] - ) - await updater.failed(message=error_message) - - def _parse_request( - self, context: RequestContext - ) -> Tuple[list[str], list[dict[str, Any]]]: - """Parses the request and returns the text and data parts. - - Args: - context: The A2A RequestContext - - Returns: - A tuple containing the contents of TextPart and DataPart objects. - """ - parts = context.message.parts if context.message else [] - text_parts = message.get_text_parts(parts) - data_parts = message.get_data_parts(parts) - return text_parts, data_parts - - def _handle_extensions(self, context: RequestContext) -> None: - """Activates any requested extensions that the agent supports. - - Args: - context: The A2A RequestContext - """ - requested_uris = context.requested_extensions - activated_uris = requested_uris.intersection(self._supported_extension_uris) - for uri in activated_uris: - context.add_activated_extension(uri) + for uri in activated_uris: + context.add_activated_extension(uri) diff --git a/samples/python/src/common/function_call_resolver.py b/samples/python/src/common/function_call_resolver.py index 49356429..22f424b2 100644 --- a/samples/python/src/common/function_call_resolver.py +++ b/samples/python/src/common/function_call_resolver.py @@ -19,7 +19,9 @@ """ import logging -from typing import Any, Callable + +from collections.abc import Callable +from typing import Any from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import Task @@ -32,69 +34,68 @@ class FunctionCallResolver: - """Resolves a natural language prompt to the name of a tool.""" - - def __init__( - self, - llm_client: genai.Client, - tools: list[Tool], - instructions: str = "You are a helpful assistant.", - ): - """Initialization. - - Args: - llm_client: The LLM client. - tools: The list of tools that a request can be resolved to. - instructions: The instructions to guide the LLM. - """ - self._client = llm_client - function_declarations = [ - types.FunctionDeclaration( - name=tool.__name__, description=tool.__doc__ - ) - for tool in tools - ] - self._config = types.GenerateContentConfig( - system_instruction=instructions, - tools=[types.Tool(function_declarations=function_declarations)], - automatic_function_calling=types.AutomaticFunctionCallingConfig( - disable=True - ), - # Force the model to call 'any' function, instead of chatting. - tool_config=types.ToolConfig( - function_calling_config=types.FunctionCallingConfig(mode="ANY") - ), - ) - - def determine_tool_to_use(self, prompt: str) -> str: - """Determines which tool to use based on a user's prompt. - - Uses a LLM to analyze the user's prompt and decide which of the available - tools (functions) is the most appropriate to handle the request. - - Args: - prompt: The user's request as a string. - - Returns: - The name of the tool function that the model has determined should be - called. If no suitable tool is found, it returns "Unknown". - """ - - response = self._client.models.generate_content( - model="gemini-2.5-flash", - contents=prompt, - config=self._config, - ) - - logging.debug("\nDetermine Tool Response: %s\n", response) - - if ( - response.candidates - and response.candidates[0].content - and response.candidates[0].content.parts + """Resolves a natural language prompt to the name of a tool.""" + + def __init__( + self, + llm_client: genai.Client, + tools: list[Tool], + instructions: str = 'You are a helpful assistant.', ): - for part in response.candidates[0].content.parts: - if part.function_call: - return part.function_call.name + """Initialization. + + Args: + llm_client: The LLM client. + tools: The list of tools that a request can be resolved to. + instructions: The instructions to guide the LLM. + """ + self._client = llm_client + function_declarations = [ + types.FunctionDeclaration( + name=tool.__name__, description=tool.__doc__ + ) + for tool in tools + ] + self._config = types.GenerateContentConfig( + system_instruction=instructions, + tools=[types.Tool(function_declarations=function_declarations)], + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True + ), + # Force the model to call 'any' function, instead of chatting. + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig(mode='ANY') + ), + ) + + def determine_tool_to_use(self, prompt: str) -> str: + """Determines which tool to use based on a user's prompt. + + Uses a LLM to analyze the user's prompt and decide which of the available + tools (functions) is the most appropriate to handle the request. + + Args: + prompt: The user's request as a string. + + Returns: + The name of the tool function that the model has determined should be + called. If no suitable tool is found, it returns "Unknown". + """ + response = self._client.models.generate_content( + model='gemini-2.5-flash', + contents=prompt, + config=self._config, + ) + + logging.debug('\nDetermine Tool Response: %s\n', response) + + if ( + response.candidates + and response.candidates[0].content + and response.candidates[0].content.parts + ): + for part in response.candidates[0].content.parts: + if part.function_call: + return part.function_call.name - return "Unknown" + return 'Unknown' diff --git a/samples/python/src/common/payment_remote_a2a_client.py b/samples/python/src/common/payment_remote_a2a_client.py index 56f2aca6..42d8665a 100644 --- a/samples/python/src/common/payment_remote_a2a_client.py +++ b/samples/python/src/common/payment_remote_a2a_client.py @@ -14,107 +14,107 @@ """Wrapper for the A2A client.""" -import httpx import logging import uuid +import httpx + from a2a import types as a2a_types from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import Client -from a2a.client.client import ClientConfig +from a2a.client.client import Client, ClientConfig from a2a.client.client_factory import ClientFactory from a2a.client.client_task_manager import ClientTaskManager from a2a.extensions.common import HTTP_EXTENSION_HEADER -DEFAULT_TIMEOUT = 600.0 - -class PaymentRemoteA2aClient(): - """Wrapper for the A2A client. +DEFAULT_TIMEOUT = 600.0 - Always assumes the AgentCard is at base_url + {AGENT_CARD_WELL_KNOWN_PATH}. - Provides convenience for establishing connection and for sending messages. - """ +class PaymentRemoteA2aClient: + """Wrapper for the A2A client. - def __init__( - self, - name: str, - base_url: str, - required_extensions: set[str] | None = None, - ): - """Initializes the PaymentRemoteA2aClient. + Always assumes the AgentCard is at base_url + {AGENT_CARD_WELL_KNOWN_PATH}. - Args: - name: The name of the agent. - base_url: The base URL where the remote agent is hosted. - required_extensions: A set of extension URIs that the client requires. + Provides convenience for establishing connection and for sending messages. """ - self._httpx_client = httpx.AsyncClient( - timeout=httpx.Timeout(timeout=DEFAULT_TIMEOUT) - ) - self._a2a_client_factory = ClientFactory( - ClientConfig( - httpx_client=self._httpx_client, + def __init__( + self, + name: str, + base_url: str, + required_extensions: set[str] | None = None, + ): + """Initializes the PaymentRemoteA2aClient. + + Args: + name: The name of the agent. + base_url: The base URL where the remote agent is hosted. + required_extensions: A set of extension URIs that the client requires. + """ + self._httpx_client = httpx.AsyncClient( + timeout=httpx.Timeout(timeout=DEFAULT_TIMEOUT) + ) + self._a2a_client_factory = ClientFactory( + ClientConfig( + httpx_client=self._httpx_client, + ) + ) + self._name = name + self._base_url = base_url + self._agent_card = None + self._client_required_extensions = required_extensions or set() + + async def get_agent_card(self) -> a2a_types.AgentCard: + """Get agent card.""" + if self._agent_card is None: + resolver = A2ACardResolver( + httpx_client=self._httpx_client, + base_url=self._base_url, + ) + self._agent_card = await resolver.get_agent_card() + return self._agent_card + + async def send_a2a_message( + self, message: a2a_types.Message + ) -> a2a_types.Task: + """Retrieves the A2A client, sends the message, and returns the event.""" + my_a2a_client: Client = await self._get_a2a_client() + + task_manager = ClientTaskManager() + + async for event in my_a2a_client.send_message(message): + # Tasks are returned in tuples (aka ClientEvent). The first element is the + # Task, the second element is the UpdateEvent. + if isinstance(event, tuple): + event = event[0] + await task_manager.process(event) + + task = task_manager.get_task() + if task is None: + raise RuntimeError(f'No response from {self._name}') + logging.info( + 'Response received from %s for (context_id, task_id): (%s, %s)', + self._name, + task.context_id, + task.id, + ) + return task + + async def _get_a2a_client(self) -> Client: + """Get A2A client.""" + agent_card = await self.get_agent_card() + self._httpx_client.headers[HTTP_EXTENSION_HEADER] = ', '.join( + self._client_required_extensions + ) + return self._a2a_client_factory.create(agent_card) + + def _create_agent_message( + self, + message: str, + ) -> a2a_types.Message: + """Get message.""" + return a2a_types.Message( + message_id=uuid.uuid4().hex, + parts=[a2a_types.Part(root=a2a_types.TextPart(text=str(message)))], + role=a2a_types.Role.agent, ) - ) - self._name = name - self._base_url = base_url - self._agent_card = None - self._client_required_extensions = required_extensions or set() - - async def get_agent_card(self) -> a2a_types.AgentCard: - """Get agent card.""" - if self._agent_card is None: - resolver = A2ACardResolver( - httpx_client=self._httpx_client, - base_url=self._base_url, - ) - self._agent_card = await resolver.get_agent_card() - return self._agent_card - - async def send_a2a_message( - self, message: a2a_types.Message - ) -> a2a_types.Task: - """Retrieves the A2A client, sends the message, and returns the event.""" - my_a2a_client: Client = await self._get_a2a_client() - - task_manager = ClientTaskManager() - - async for event in my_a2a_client.send_message(message): - # Tasks are returned in tuples (aka ClientEvent). The first element is the - # Task, the second element is the UpdateEvent. - if isinstance(event, tuple): - event = event[0] - await task_manager.process(event) - - task = task_manager.get_task() - if task is None: - raise RuntimeError(f"No response from {self._name}") - logging.info( - "Response received from %s for (context_id, task_id): (%s, %s)", - self._name, - task.context_id, - task.id, - ) - return task - - async def _get_a2a_client(self) -> Client: - """Get A2A client.""" - agent_card = await self.get_agent_card() - self._httpx_client.headers[HTTP_EXTENSION_HEADER] = ", ".join( - self._client_required_extensions - ) - return self._a2a_client_factory.create(agent_card) - - def _create_agent_message( - self, - message: str, - ) -> a2a_types.Message: - """Get message.""" - return a2a_types.Message( - message_id=uuid.uuid4().hex, - parts=[a2a_types.Part(root=a2a_types.TextPart(text=str(message)))], - role=a2a_types.Role.agent, - ) diff --git a/samples/python/src/common/retrying_llm_agent.py b/samples/python/src/common/retrying_llm_agent.py index c5e50e2c..a6d6855d 100644 --- a/samples/python/src/common/retrying_llm_agent.py +++ b/samples/python/src/common/retrying_llm_agent.py @@ -18,48 +18,52 @@ requests and surfacing errors captured from the LLM. """ +from typing import override + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.events.event import Event -from typing_extensions import AsyncGenerator, override +from typing_extensions import AsyncGenerator class RetryingLlmAgent(LlmAgent): - """An LLM agent that surfaces errors to the user and then retries.""" + """An LLM agent that surfaces errors to the user and then retries.""" - def __init__(self, *args, max_retries: int = 1, **kwargs): - super().__init__(*args, **kwargs) - self._max_retries = max_retries + def __init__(self, *args, max_retries: int = 1, **kwargs): + super().__init__(*args, **kwargs) + self._max_retries = max_retries - async def _retry_async( - self, ctx: InvocationContext, retries_left: int = 0 - ) -> AsyncGenerator[Event, None]: - if retries_left <= 0: - yield Event( - author=ctx.agent.name, - invocation_id=ctx.invocation_id, - error_message=( - "Maximum retries exhausted. The remote Gemini server failed to" - " respond. Please try again later." - ), - ) - else: - try: - async for event in super()._run_async_impl(ctx): - yield event - except Exception as e: # pylint: disable=broad-exception-caught - yield Event( - author=ctx.agent.name, - invocation_id=ctx.invocation_id, - error_message="Gemini server error. Retrying...", - custom_metadata={"error": str(e)}, - ) - async for event in self._retry_async(ctx, retries_left - 1): - yield event + async def _retry_async( + self, ctx: InvocationContext, retries_left: int = 0 + ) -> AsyncGenerator[Event, None]: + if retries_left <= 0: + yield Event( + author=ctx.agent.name, + invocation_id=ctx.invocation_id, + error_message=( + 'Maximum retries exhausted. The remote Gemini server failed to' + ' respond. Please try again later.' + ), + ) + else: + try: + async for event in super()._run_async_impl(ctx): + yield event + except Exception as e: # pylint: disable=broad-exception-caught + yield Event( + author=ctx.agent.name, + invocation_id=ctx.invocation_id, + error_message='Gemini server error. Retrying...', + custom_metadata={'error': str(e)}, + ) + async for event in self._retry_async(ctx, retries_left - 1): + yield event - @override - async def _run_async_impl( - self, ctx: InvocationContext - ) -> AsyncGenerator[Event, None]: - async for event in self._retry_async(ctx, retries_left=self._max_retries): - yield event + @override + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + async for event in self._retry_async( + ctx, retries_left=self._max_retries + ): + yield event diff --git a/samples/python/src/common/server.py b/samples/python/src/common/server.py index 51ec993b..18965d2f 100644 --- a/samples/python/src/common/server.py +++ b/samples/python/src/common/server.py @@ -23,9 +23,15 @@ import logging import os -from a2a.server.agent_execution.simple_request_context_builder import SimpleRequestContextBuilder +import uvicorn + +from a2a.server.agent_execution.simple_request_context_builder import ( + SimpleRequestContextBuilder, +) from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication -from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler +from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, +) from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import AgentCard from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -33,28 +39,28 @@ from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import Response -import uvicorn from . import watch_log from .base_server_executor import BaseServerExecutor + # Constant for the A2A extensions header -A2A_EXTENSIONS_HEADER = "X-A2A-Extensions" +A2A_EXTENSIONS_HEADER = 'X-A2A-Extensions' def load_local_agent_card(file_path: str) -> AgentCard: - """Loads the AgentCard from the specified file path. + """Loads the AgentCard from the specified file path. - Args: - file_path: The directory where the agent.json file is located. + Args: + file_path: The directory where the agent.json file is located. - Returns: - The loaded AgentCard instance. - """ - card_path = os.path.join(os.path.dirname(file_path), "agent.json") - with open(card_path, "r", encoding="utf-8") as f: - data = json.load(f) - return AgentCard.model_validate(data) + Returns: + The loaded AgentCard instance. + """ + card_path = os.path.join(os.path.dirname(file_path), 'agent.json') + with open(card_path, encoding='utf-8') as f: + data = json.load(f) + return AgentCard.model_validate(data) def run_agent_blocking( @@ -64,170 +70,174 @@ def run_agent_blocking( executor: BaseServerExecutor, rpc_url: str, ) -> None: - """Launches a Uvicorn server for an agent and block the current thread. - - Args: - port: TCP port to bind to. - agent_card: The AgentCard object describing the agent. - executor: The AgentExecutor that processes A2A requests. - rpc_url: The base URL path at which to mount the JSON-RPC handler. - """ - - # Add a file handler to the logger for watch.log. - logger = logging.getLogger(__name__) - logger.addHandler(watch_log.create_file_handler()) - - # Build the Starlette app and add middlewares. - app = _build_starlette_app(agent_card, executor=executor, rpc_url=rpc_url) - _add_middlewares(app, logger) - - # Start the server. - logger.info("%s listening on http://localhost:%d", agent_card.name, port) - uvicorn.run( - app, host="127.0.0.1", port=port, log_level="info", timeout_keep_alive=120 - ) + """Launches a Uvicorn server for an agent and block the current thread. + + Args: + port: TCP port to bind to. + agent_card: The AgentCard object describing the agent. + executor: The AgentExecutor that processes A2A requests. + rpc_url: The base URL path at which to mount the JSON-RPC handler. + """ + # Add a file handler to the logger for watch.log. + logger = logging.getLogger(__name__) + logger.addHandler(watch_log.create_file_handler()) + + # Build the Starlette app and add middlewares. + app = _build_starlette_app(agent_card, executor=executor, rpc_url=rpc_url) + _add_middlewares(app, logger) + + # Start the server. + logger.info('%s listening on http://localhost:%d', agent_card.name, port) + uvicorn.run( + app, + host='127.0.0.1', + port=port, + log_level='info', + timeout_keep_alive=120, + ) def _create_watch_log_handler() -> logging.FileHandler: - """Create a file handler for watch.log logger. + """Create a file handler for watch.log logger. - watch.log is a log file meant to be watched in parallel with running a - scenario. It will contain all the requests and responses to/from the agent - that are sent to/from the client, so engineers can see what is happening - between the servers in real time. + watch.log is a log file meant to be watched in parallel with running a + scenario. It will contain all the requests and responses to/from the agent + that are sent to/from the client, so engineers can see what is happening + between the servers in real time. - Returns: - A logging.FileHandler instance configured for 'watch.log'. - """ - file_handler = logging.FileHandler(".logs/watch.log") - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(logging.Formatter("%(name)s: %(message)s")) - return file_handler + Returns: + A logging.FileHandler instance configured for 'watch.log'. + """ + file_handler = logging.FileHandler('.logs/watch.log') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter('%(name)s: %(message)s')) + return file_handler class _LoggingMiddleware(BaseHTTPMiddleware): - """Intercepts and logs incoming request and response details.""" - - def __init__(self, *args, logger: logging.Logger, **kwargs): - self._logger = logger - super().__init__(*args, **kwargs) - - async def dispatch(self, request: Request, call_next) -> Response: - self._logger.info("\n\n\n") - self._logger.info("---------- New Agent Request Received---------") - - # Log the request method and URL. - self._logger.info("%s %s", request.method, request.url) - - # Log the request body if it's present. - content_length = request.headers.get("content-length") - if content_length and int(content_length) > 0: - request_body = await request.json() - else: - request_body = "" - - self._logger.info("\n") - self._logger.info("[Request Body]") - self._logger.info("%s", request_body) - - # If the extension header is present, log a notice. - extension_header = request.headers.get(A2A_EXTENSIONS_HEADER) - if extension_header: - self._logger.info( - "\n[Extension Header]\n%s: %s", A2A_EXTENSIONS_HEADER, extension_header - ) - - response = await call_next(request) - - # Ensure the response has a body to read. - if response.body_iterator: - body = b"" - - # Read the entire response body. - # All responses are UTF-8 encoded JSON, so this should always succeed. - async for chunk in response.body_iterator: - body += chunk - - try: - response_body_json = body.decode("utf-8") - except UnicodeDecodeError: - self._logger.warning("Failed to decode response body as UTF-8.") - response_body_json = body - - self._logger.info("\n") - self._logger.info("[Response Body]") - self._logger.info("%s", response_body_json) - - return Response( - content=body, - status_code=response.status_code, - media_type=response.media_type, - headers=response.headers, - ) - else: - self._logger.info("\n") - self._logger.info("[Response Body]") - self._logger.info("") - return response + """Intercepts and logs incoming request and response details.""" + + def __init__(self, *args, logger: logging.Logger, **kwargs): + self._logger = logger + super().__init__(*args, **kwargs) + + async def dispatch(self, request: Request, call_next) -> Response: + self._logger.info('\n\n\n') + self._logger.info('---------- New Agent Request Received---------') + + # Log the request method and URL. + self._logger.info('%s %s', request.method, request.url) + + # Log the request body if it's present. + content_length = request.headers.get('content-length') + if content_length and int(content_length) > 0: + request_body = await request.json() + else: + request_body = '' + + self._logger.info('\n') + self._logger.info('[Request Body]') + self._logger.info('%s', request_body) + + # If the extension header is present, log a notice. + extension_header = request.headers.get(A2A_EXTENSIONS_HEADER) + if extension_header: + self._logger.info( + '\n[Extension Header]\n%s: %s', + A2A_EXTENSIONS_HEADER, + extension_header, + ) + + response = await call_next(request) + + # Ensure the response has a body to read. + if response.body_iterator: + body = b'' + + # Read the entire response body. + # All responses are UTF-8 encoded JSON, so this should always succeed. + async for chunk in response.body_iterator: + body += chunk + + try: + response_body_json = body.decode('utf-8') + except UnicodeDecodeError: + self._logger.warning('Failed to decode response body as UTF-8.') + response_body_json = body + + self._logger.info('\n') + self._logger.info('[Response Body]') + self._logger.info('%s', response_body_json) + + return Response( + content=body, + status_code=response.status_code, + media_type=response.media_type, + headers=response.headers, + ) + self._logger.info('\n') + self._logger.info('[Response Body]') + self._logger.info('') + return response def _build_starlette_app( agent_card: AgentCard, *, executor, rpc_url ) -> A2AStarletteApplication: - """Create and return a ready-to-serve Starlette ASGI application. + """Create and return a ready-to-serve Starlette ASGI application. - Args: - agent_card: The AgentCard object describing the agent. - executor: The AgentExecutor that processes A2A requests. - rpc_url: The base URL path at which to mount the JSON-RPC handler. + Args: + agent_card: The AgentCard object describing the agent. + executor: The AgentExecutor that processes A2A requests. + rpc_url: The base URL path at which to mount the JSON-RPC handler. - Returns: - An instance of A2AStarletteApplication. + Returns: + An instance of A2AStarletteApplication. - Raises: - ValueError: If executor is None. - """ - if executor is None: - raise ValueError("executor must be supplied") + Raises: + ValueError: If executor is None. + """ + if executor is None: + raise ValueError('executor must be supplied') - handler = DefaultRequestHandler( - agent_executor=executor, - task_store=InMemoryTaskStore(), - request_context_builder=SimpleRequestContextBuilder(), - ) + handler = DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), + request_context_builder=SimpleRequestContextBuilder(), + ) - app = A2AStarletteApplication( - agent_card=agent_card, http_handler=handler - ).build( - rpc_url=rpc_url, agent_card_url=f"{rpc_url}{AGENT_CARD_WELL_KNOWN_PATH}" - ) - return app + app = A2AStarletteApplication( + agent_card=agent_card, http_handler=handler + ).build( + rpc_url=rpc_url, agent_card_url=f'{rpc_url}{AGENT_CARD_WELL_KNOWN_PATH}' + ) + return app def _add_middlewares(app, logger: logging.Logger) -> None: - """Add middlewares to the Starlette app.""" - app.add_middleware( - CORSMiddleware, - allow_origins=[ - "http://localhost:8000", - "http://127.0.0.1:8000", - "http://0.0.0.0:8000", - "http://localhost:8081", - "http://127.0.0.1:8081", - "http://0.0.0.0:8081", - "http://localhost:8082", - "http://127.0.0.1:8082", - "http://0.0.0.0:8082", - "http://localhost:8083", - "http://127.0.0.1:8083", - "http://0.0.0.0:8083", - "http://localhost:8080", - "http://127.0.0.1:8080", - "http://0.0.0.0:8080", - ], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - app.add_middleware(_LoggingMiddleware, logger=logger) - return app + """Add middlewares to the Starlette app.""" + app.add_middleware( + CORSMiddleware, + allow_origins=[ + 'http://localhost:8000', + 'http://127.0.0.1:8000', + 'http://0.0.0.0:8000', + 'http://localhost:8081', + 'http://127.0.0.1:8081', + 'http://0.0.0.0:8081', + 'http://localhost:8082', + 'http://127.0.0.1:8082', + 'http://0.0.0.0:8082', + 'http://localhost:8083', + 'http://127.0.0.1:8083', + 'http://0.0.0.0:8083', + 'http://localhost:8080', + 'http://127.0.0.1:8080', + 'http://0.0.0.0:8080', + ], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + app.add_middleware(_LoggingMiddleware, logger=logger) + return app diff --git a/samples/python/src/common/validation.py b/samples/python/src/common/validation.py index d7e2dce3..7cb47db4 100644 --- a/samples/python/src/common/validation.py +++ b/samples/python/src/common/validation.py @@ -20,18 +20,18 @@ def validate_payment_mandate_signature(payment_mandate: PaymentMandate) -> None: - """Validates the PaymentMandate signature. + """Validates the PaymentMandate signature. - Args: - payment_mandate: The PaymentMandate to be validated. + Args: + payment_mandate: The PaymentMandate to be validated. - Raises: - ValueError: If the PaymentMandate signature is not valid. - """ - # In a real implementation, full validation logic would reside here. For - # demonstration purposes, we simply log that the authorization field is - # populated. - if payment_mandate.user_authorization is None: - raise ValueError("User authorization not found in PaymentMandate.") + Raises: + ValueError: If the PaymentMandate signature is not valid. + """ + # In a real implementation, full validation logic would reside here. For + # demonstration purposes, we simply log that the authorization field is + # populated. + if payment_mandate.user_authorization is None: + raise ValueError('User authorization not found in PaymentMandate.') - logging.info("Valid PaymentMandate found.") + logging.info('Valid PaymentMandate found.') diff --git a/samples/python/src/common/watch_log.py b/samples/python/src/common/watch_log.py index 42419e3a..8a396df8 100644 --- a/samples/python/src/common/watch_log.py +++ b/samples/python/src/common/watch_log.py @@ -21,95 +21,97 @@ """ import logging + from typing import Any from a2a.server.agent_execution.context import RequestContext -from ap2.types.mandate import CART_MANDATE_DATA_KEY -from ap2.types.mandate import INTENT_MANDATE_DATA_KEY -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY +from ap2.types.mandate import ( + CART_MANDATE_DATA_KEY, + INTENT_MANDATE_DATA_KEY, + PAYMENT_MANDATE_DATA_KEY, +) + _logger = logging.getLogger(__name__) def create_file_handler() -> logging.FileHandler: - """Creates a file handler to the logger for watch.log. + """Creates a file handler to the logger for watch.log. - Returns: - A logging.FileHandler instance configured for 'watch.log'. - """ - file_handler = logging.FileHandler(".logs/watch.log") - file_handler.setLevel(logging.INFO) - file_handler.setFormatter(logging.Formatter("%(message)s")) - return file_handler + Returns: + A logging.FileHandler instance configured for 'watch.log'. + """ + file_handler = logging.FileHandler('.logs/watch.log') + file_handler.setLevel(logging.INFO) + file_handler.setFormatter(logging.Formatter('%(message)s')) + return file_handler def log_a2a_message_parts( text_parts: list[str], data_parts: list[dict[str, Any]] ): - _load_logger() + _load_logger() - """Logs the A2A message parts to the watch.log file.""" - _log_request_instructions(text_parts) - _log_mandates(data_parts) - _log_extra_data(data_parts) + """Logs the A2A message parts to the watch.log file.""" + _log_request_instructions(text_parts) + _log_mandates(data_parts) + _log_extra_data(data_parts) def log_a2a_request_extensions(context: RequestContext) -> None: - """Logs the A2A extensions activated to the watch.log file.""" - - if not context.call_context.activated_extensions: - return + """Logs the A2A extensions activated to the watch.log file.""" + if not context.call_context.activated_extensions: + return - _logger.info("\n") - _logger.info("[A2A Extensions Activated in the Request]") + _logger.info('\n') + _logger.info('[A2A Extensions Activated in the Request]') - for extension in context.call_context.requested_extensions: - _logger.info(extension) + for extension in context.call_context.requested_extensions: + _logger.info(extension) def _load_logger(): - if not _logger.handlers: - _logger.addHandler(create_file_handler()) + if not _logger.handlers: + _logger.addHandler(create_file_handler()) def _log_request_instructions(text_parts: list[str]) -> None: - """Logs the request instructions from the text parts.""" - _logger.info("\n") - _logger.info("[Request Instructions]") - _logger.info(text_parts) + """Logs the request instructions from the text parts.""" + _logger.info('\n') + _logger.info('[Request Instructions]') + _logger.info(text_parts) def _log_mandates(data_parts: list[dict[str, Any]]) -> None: - """Extracts and logs mandates from the data parts.""" - - for data_part in data_parts: - for key, value in data_part.items(): - if key == CART_MANDATE_DATA_KEY: - _logger.info("\n") - _logger.info("[A Cart Mandate was in the request Data]") - _logger.info(value) - elif key == INTENT_MANDATE_DATA_KEY: - _logger.info("\n") - _logger.info("[An Intent Mandate was in the request Data]") - _logger.info(value) - elif key == PAYMENT_MANDATE_DATA_KEY: - _logger.info("\n") - _logger.info("[A Payment Mandate was in the request Data]") - _logger.info(value) + """Extracts and logs mandates from the data parts.""" + for data_part in data_parts: + for key, value in data_part.items(): + if key == CART_MANDATE_DATA_KEY: + _logger.info('\n') + _logger.info('[A Cart Mandate was in the request Data]') + _logger.info(value) + elif key == INTENT_MANDATE_DATA_KEY: + _logger.info('\n') + _logger.info('[An Intent Mandate was in the request Data]') + _logger.info(value) + elif key == PAYMENT_MANDATE_DATA_KEY: + _logger.info('\n') + _logger.info('[A Payment Mandate was in the request Data]') + _logger.info(value) def _log_extra_data(data_parts: list[dict[str, Any]]) -> None: - """Extracts and logs extra data from the data parts.""" - for data_part in data_parts: - for key, value in data_part.items(): - if ( - key == CART_MANDATE_DATA_KEY - or key == INTENT_MANDATE_DATA_KEY - or key == PAYMENT_MANDATE_DATA_KEY - ): - continue - - _logger.info("\n") - _logger.info("[Data Part: %s] ", key) - _logger.info(value) + """Extracts and logs extra data from the data parts.""" + for data_part in data_parts: + for key, value in data_part.items(): + if ( + key == CART_MANDATE_DATA_KEY + or key == INTENT_MANDATE_DATA_KEY + or key == PAYMENT_MANDATE_DATA_KEY + ): + continue + + _logger.info('\n') + _logger.info('[Data Part: %s] ', key) + _logger.info(value) diff --git a/samples/python/src/roles/credentials_provider_agent/__main__.py b/samples/python/src/roles/credentials_provider_agent/__main__.py index 6e555318..1814d7d6 100644 --- a/samples/python/src/roles/credentials_provider_agent/__main__.py +++ b/samples/python/src/roles/credentials_provider_agent/__main__.py @@ -17,21 +17,26 @@ from collections.abc import Sequence from absl import app -from roles.credentials_provider_agent.agent_executor import CredentialsProviderExecutor from common import server +from roles.credentials_provider_agent.agent_executor import ( + CredentialsProviderExecutor, +) + AGENT_PORT = 8002 def main(argv: Sequence[str]) -> None: - agent_card = server.load_local_agent_card(__file__) - server.run_agent_blocking( - port=AGENT_PORT, - agent_card=agent_card, - executor=CredentialsProviderExecutor(agent_card.capabilities.extensions), - rpc_url="/a2a/credentials_provider", - ) - - -if __name__ == "__main__": - app.run(main) + agent_card = server.load_local_agent_card(__file__) + server.run_agent_blocking( + port=AGENT_PORT, + agent_card=agent_card, + executor=CredentialsProviderExecutor( + agent_card.capabilities.extensions + ), + rpc_url='/a2a/credentials_provider', + ) + + +if __name__ == '__main__': + app.run(main) diff --git a/samples/python/src/roles/credentials_provider_agent/account_manager.py b/samples/python/src/roles/credentials_provider_agent/account_manager.py index 8fba0bef..8a1ab9f4 100644 --- a/samples/python/src/roles/credentials_provider_agent/account_manager.py +++ b/samples/python/src/roles/credentials_provider_agent/account_manager.py @@ -22,74 +22,74 @@ _account_db = { - "bugsbunny@gmail.com": { - "shipping_address": { - "recipient": "Bugs Bunny", - "organization": "Sample Organization", - "address_line": ["123 Main St"], - "city": "Sample City", - "region": "ST", - "postal_code": "00000", - "country": "US", - "phone_number": "+1-000-000-0000", + 'bugsbunny@gmail.com': { + 'shipping_address': { + 'recipient': 'Bugs Bunny', + 'organization': 'Sample Organization', + 'address_line': ['123 Main St'], + 'city': 'Sample City', + 'region': 'ST', + 'postal_code': '00000', + 'country': 'US', + 'phone_number': '+1-000-000-0000', }, - "payment_methods": { - "card1": { - "type": "CARD", - "alias": "American Express ending in 4444", - "network": [{"name": "amex", "formats": ["DPAN"]}], - "cryptogram": "fake_cryptogram_abc123", - "token": "1111000000000000", - "card_holder_name": "John Doe", - "card_expiration": "12/2025", - "card_billing_address": { - "country": "US", - "postal_code": "00000", + 'payment_methods': { + 'card1': { + 'type': 'CARD', + 'alias': 'American Express ending in 4444', + 'network': [{'name': 'amex', 'formats': ['DPAN']}], + 'cryptogram': 'fake_cryptogram_abc123', + 'token': '1111000000000000', + 'card_holder_name': 'John Doe', + 'card_expiration': '12/2025', + 'card_billing_address': { + 'country': 'US', + 'postal_code': '00000', }, }, - "card2": { - "type": "CARD", - "alias": "American Express ending in 8888", - "network": [{"name": "amex", "formats": ["DPAN"]}], - "cryptogram": "fake_cryptogram_ghi789", - "token": "2222000000000000", - "card_holder_name": "Bugs Bunny", - "card_expiration": "10/2027", - "card_billing_address": { - "country": "US", - "postal_code": "00000", + 'card2': { + 'type': 'CARD', + 'alias': 'American Express ending in 8888', + 'network': [{'name': 'amex', 'formats': ['DPAN']}], + 'cryptogram': 'fake_cryptogram_ghi789', + 'token': '2222000000000000', + 'card_holder_name': 'Bugs Bunny', + 'card_expiration': '10/2027', + 'card_billing_address': { + 'country': 'US', + 'postal_code': '00000', }, }, - "bank_account1": { - "type": "BANK_ACCOUNT", - "account_number": "111", - "alias": "Primary bank account", + 'bank_account1': { + 'type': 'BANK_ACCOUNT', + 'account_number': '111', + 'alias': 'Primary bank account', }, - "digital_wallet1": { - "type": "DIGITAL_WALLET", - "brand": "PayPal", - "account_identifier": "foo@bar.com", - "alias": "Bugs's PayPal account", + 'digital_wallet1': { + 'type': 'DIGITAL_WALLET', + 'brand': 'PayPal', + 'account_identifier': 'foo@bar.com', + 'alias': "Bugs's PayPal account", }, }, }, - "daffyduck@gmail.com": { - "payment_methods": { - "bank_account1": { - "type": "BANK_ACCOUNT", - "brand": "Bank of Money", - "account_number": "789", - "alias": "Main checking account", + 'daffyduck@gmail.com': { + 'payment_methods': { + 'bank_account1': { + 'type': 'BANK_ACCOUNT', + 'brand': 'Bank of Money', + 'account_number': '789', + 'alias': 'Main checking account', } }, }, - "elmerfudd@gmail.com": { - "payment_methods": { - "digital_wallet1": { - "type": "DIGITAL_WALLET", - "brand": "PayPal", - "account_identifier": "elmerfudd@gmail.com", - "alias": "Fudd's PayPal", + 'elmerfudd@gmail.com': { + 'payment_methods': { + 'digital_wallet1': { + 'type': 'DIGITAL_WALLET', + 'brand': 'PayPal', + 'account_identifier': 'elmerfudd@gmail.com', + 'alias': "Fudd's PayPal", } } }, @@ -100,109 +100,107 @@ def create_token(email_address: str, payment_method_alias: str) -> str: - """Creates and stores a token for an account. + """Creates and stores a token for an account. - Args: - email_address: The email address of the account. - payment_method_alias: The alias of the payment method. + Args: + email_address: The email address of the account. + payment_method_alias: The alias of the payment method. - Returns: - The token for the payment method. - """ - token = f"fake_payment_credential_token_{len(_token)}" + Returns: + The token for the payment method. + """ + token = f'fake_payment_credential_token_{len(_token)}' - _token[token] = { - "email_address": email_address, - "payment_method_alias": payment_method_alias, - "payment_mandate_id": None, - } + _token[token] = { + 'email_address': email_address, + 'payment_method_alias': payment_method_alias, + 'payment_mandate_id': None, + } - return token + return token def update_token(token: str, payment_mandate_id: str) -> None: - """Updates the token with the payment mandate id. - - Args: - token: The token to update. - payment_mandate_id: The payment mandate id to associate with the token. - """ - if token not in _token: - raise ValueError(f"Token {token} not found") - if _token[token].get("payment_mandate_id"): - # Do not overwrite the payment mandate id if it is already set. - return - _token[token]["payment_mandate_id"] = payment_mandate_id + """Updates the token with the payment mandate id. + + Args: + token: The token to update. + payment_mandate_id: The payment mandate id to associate with the token. + """ + if token not in _token: + raise ValueError(f'Token {token} not found') + if _token[token].get('payment_mandate_id'): + # Do not overwrite the payment mandate id if it is already set. + return + _token[token]['payment_mandate_id'] = payment_mandate_id + def verify_token(token: str, payment_mandate_id: str) -> dict[str, Any]: - """Look up an account by token. - - Args: - token: The token for look up. - payment_mandate_id: The payment mandate id associated with the token. - - Returns: - The account for the given token, or status:invalid_token if the token is not - valid. - """ - account_lookup = _token.get(token, {}) - if not account_lookup: - raise ValueError("Invalid token") - if account_lookup.get("payment_mandate_id") != payment_mandate_id: - raise ValueError("Invalid token") - email_address = account_lookup.get("email_address") - alias = account_lookup.get("payment_method_alias") - return get_payment_method_by_alias(email_address, alias) + """Look up an account by token. + + Args: + token: The token for look up. + payment_mandate_id: The payment mandate id associated with the token. + + Returns: + The account for the given token, or status:invalid_token if the token is not + valid. + """ + account_lookup = _token.get(token, {}) + if not account_lookup: + raise ValueError('Invalid token') + if account_lookup.get('payment_mandate_id') != payment_mandate_id: + raise ValueError('Invalid token') + email_address = account_lookup.get('email_address') + alias = account_lookup.get('payment_method_alias') + return get_payment_method_by_alias(email_address, alias) def get_account_payment_methods(email_address: str) -> list[dict[str, Any]]: - """Returns a list of the payment methods for the given account email address. + """Returns a list of the payment methods for the given account email address. - Args: - email_address: The account's email address. + Args: + email_address: The account's email address. - Returns: - A list of the user's payment_methods. - """ - - return list( - _account_db.get(email_address, {}).get("payment_methods", {}).values() - ) + Returns: + A list of the user's payment_methods. + """ + return list( + _account_db.get(email_address, {}).get('payment_methods', {}).values() + ) def get_account_shipping_address(email_address: str) -> dict[str, Any]: - """Gets the shipping address associated for the given account email address. - - Args: - email_address: The account's email address. + """Gets the shipping address associated for the given account email address. - Returns: - The account's shipping address. - """ + Args: + email_address: The account's email address. - return _account_db.get(email_address, {}).get("shipping_address", {}) + Returns: + The account's shipping address. + """ + return _account_db.get(email_address, {}).get('shipping_address', {}) def get_payment_method_by_alias( email_address: str, alias: str ) -> dict[str, Any] | None: - """Returns the payment method for a given account and alias. - - Args: - email_address: The account's email address. - alias: The alias of the payment method to retrieve. - - Returns: - The payment method for the given account and alias, or status:not_found. - """ - - payment_methods = list( - filter( - lambda payment_method: payment_method.get("alias").casefold() - == alias.casefold(), - get_account_payment_methods(email_address), - ) - ) - if not payment_methods: - return None - return payment_methods[0] + """Returns the payment method for a given account and alias. + + Args: + email_address: The account's email address. + alias: The alias of the payment method to retrieve. + + Returns: + The payment method for the given account and alias, or status:not_found. + """ + payment_methods = list( + filter( + lambda payment_method: payment_method.get('alias').casefold() + == alias.casefold(), + get_account_payment_methods(email_address), + ) + ) + if not payment_methods: + return None + return payment_methods[0] diff --git a/samples/python/src/roles/credentials_provider_agent/agent_executor.py b/samples/python/src/roles/credentials_provider_agent/agent_executor.py index dbb7b5ea..ffb21979 100644 --- a/samples/python/src/roles/credentials_provider_agent/agent_executor.py +++ b/samples/python/src/roles/credentials_provider_agent/agent_executor.py @@ -20,8 +20,8 @@ 3. Provide a payment credential token for a specific payment method. 4. Provide payment credentials to a processor for completion of a payment. -In order to clearly demonstrate the use of the Agent Payments Protocol A2A -extension, this agent was built directly using the A2A framework. +In order to clearly demonstrate the use of the Agent Payments Protocol A2A +extension, this agent was built directly using the A2A framework. The core logic of how an A2A agent processes requests and generates responses is handled by an AgentExecutor. The BaseServerExecutor handles the common task of @@ -31,16 +31,17 @@ from typing import Any -from . import tools from common.base_server_executor import BaseServerExecutor from common.system_utils import DEBUG_MODE_INSTRUCTIONS +from . import tools class CredentialsProviderExecutor(BaseServerExecutor): - """AgentExecutor for the credentials provider agent.""" + """AgentExecutor for the credentials provider agent.""" - _system_prompt = """ + _system_prompt = ( + """ You are a credentials provider agent acting as a secure digital wallet. Your job is to manage a user's payment methods and shipping addresses. @@ -49,21 +50,22 @@ class CredentialsProviderExecutor(BaseServerExecutor): Do not engage in conversation. %s - """ % DEBUG_MODE_INSTRUCTIONS - - def __init__(self, supported_extensions: list[dict[str, Any]] = None): - """Initializes the CredentialsProviderExecutor. + """ + % DEBUG_MODE_INSTRUCTIONS + ) - Args: - supported_extensions: A list of extension objects supported by the - agent. - """ + def __init__(self, supported_extensions: list[dict[str, Any]] = None): + """Initializes the CredentialsProviderExecutor. - agent_tools = [ - tools.handle_create_payment_credential_token, - tools.handle_get_payment_method_raw_credentials, - tools.handle_get_shipping_address, - tools.handle_search_payment_methods, - tools.handle_signed_payment_mandate, - ] - super().__init__(supported_extensions, agent_tools, self._system_prompt) + Args: + supported_extensions: A list of extension objects supported by the + agent. + """ + agent_tools = [ + tools.handle_create_payment_credential_token, + tools.handle_get_payment_method_raw_credentials, + tools.handle_get_shipping_address, + tools.handle_search_payment_methods, + tools.handle_signed_payment_mandate, + ] + super().__init__(supported_extensions, agent_tools, self._system_prompt) diff --git a/samples/python/src/roles/credentials_provider_agent/tools.py b/samples/python/src/roles/credentials_provider_agent/tools.py index 49c00116..4517297f 100644 --- a/samples/python/src/roles/credentials_provider_agent/tools.py +++ b/samples/python/src/roles/credentials_provider_agent/tools.py @@ -21,17 +21,17 @@ from typing import Any from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import DataPart -from a2a.types import Part -from a2a.types import Task +from a2a.types import DataPart, Part, Task +from common import message_utils -from . import account_manager from ap2.types.contact_picker import CONTACT_ADDRESS_DATA_KEY -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY -from ap2.types.mandate import PaymentMandate -from ap2.types.payment_request import PAYMENT_METHOD_DATA_DATA_KEY -from ap2.types.payment_request import PaymentMethodData -from common import message_utils +from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY, PaymentMandate +from ap2.types.payment_request import ( + PAYMENT_METHOD_DATA_DATA_KEY, + PaymentMethodData, +) + +from . import account_manager async def handle_get_shipping_address( @@ -39,23 +39,23 @@ async def handle_get_shipping_address( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Handles a request to get the user's shipping address. - - Updates a task with the user's shipping address if found. - - Args: - data_parts: DataPart contents. Should contain a single user_email. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task if there is one. - """ - user_email = message_utils.find_data_part("user_email", data_parts) - if not user_email: - raise ValueError("user_email is required for get_shipping_address") - shipping_address = account_manager.get_account_shipping_address(user_email) - await updater.add_artifact( - [Part(root=DataPart(data={CONTACT_ADDRESS_DATA_KEY: shipping_address}))] - ) - await updater.complete() + """Handles a request to get the user's shipping address. + + Updates a task with the user's shipping address if found. + + Args: + data_parts: DataPart contents. Should contain a single user_email. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task if there is one. + """ + user_email = message_utils.find_data_part('user_email', data_parts) + if not user_email: + raise ValueError('user_email is required for get_shipping_address') + shipping_address = account_manager.get_account_shipping_address(user_email) + await updater.add_artifact( + [Part(root=DataPart(data={CONTACT_ADDRESS_DATA_KEY: shipping_address}))] + ) + await updater.complete() async def handle_search_payment_methods( @@ -63,40 +63,38 @@ async def handle_search_payment_methods( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Returns the user's payment methods that match what the merchant accepts. - - The merchant's accepted payment methods are provided in the data_parts as a - list of PaymentMethodData objects. The user's account is identified by the - user_email provided in the data_parts. - - This tool finds and returns all the payment methods associated with the user's - account that match the merchant's accepted payment methods. - - Args: - data_parts: DataPart contents. Should contain a single user_email and a - list of PaymentMethodData objects. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task if there is one. - """ - user_email = message_utils.find_data_part("user_email", data_parts) - method_data = message_utils.find_data_parts( - PAYMENT_METHOD_DATA_DATA_KEY, data_parts - ) - if not user_email: - raise ValueError( - "user_email is required for search_payment_methods" + """Returns the user's payment methods that match what the merchant accepts. + + The merchant's accepted payment methods are provided in the data_parts as a + list of PaymentMethodData objects. The user's account is identified by the + user_email provided in the data_parts. + + This tool finds and returns all the payment methods associated with the user's + account that match the merchant's accepted payment methods. + + Args: + data_parts: DataPart contents. Should contain a single user_email and a + list of PaymentMethodData objects. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task if there is one. + """ + user_email = message_utils.find_data_part('user_email', data_parts) + method_data = message_utils.find_data_parts( + PAYMENT_METHOD_DATA_DATA_KEY, data_parts ) - if not method_data: - raise ValueError("method_data is required for search_payment_methods") - - merchant_method_data_list = [ - PaymentMethodData.model_validate(data) for data in method_data - ] - eligible_aliases = _get_eligible_payment_method_aliases( - user_email, merchant_method_data_list - ) - await updater.add_artifact([Part(root=DataPart(data=eligible_aliases))]) - await updater.complete() + if not user_email: + raise ValueError('user_email is required for search_payment_methods') + if not method_data: + raise ValueError('method_data is required for search_payment_methods') + + merchant_method_data_list = [ + PaymentMethodData.model_validate(data) for data in method_data + ] + eligible_aliases = _get_eligible_payment_method_aliases( + user_email, merchant_method_data_list + ) + await updater.add_artifact([Part(root=DataPart(data=eligible_aliases))]) + await updater.complete() async def handle_get_payment_method_raw_credentials( @@ -104,30 +102,29 @@ async def handle_get_payment_method_raw_credentials( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Exchanges a payment token for the payment method's raw credentials. + """Exchanges a payment token for the payment method's raw credentials. - Updates a task with the payment credentials. + Updates a task with the payment credentials. - Args: - data_parts: DataPart contents. Should contain a single PaymentMandate. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task if there is one. - """ + Args: + data_parts: DataPart contents. Should contain a single PaymentMandate. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task if there is one. + """ + payment_mandate_contents = message_utils.parse_canonical_object( + PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate + ).payment_mandate_contents - payment_mandate_contents = message_utils.parse_canonical_object( - PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate - ).payment_mandate_contents + token = payment_mandate_contents.payment_response.details.get( + 'token', {} + ).get('value', '') + payment_mandate_id = payment_mandate_contents.payment_mandate_id - token = payment_mandate_contents.payment_response.details.get( - "token", {} - ).get("value", "") - payment_mandate_id = payment_mandate_contents.payment_mandate_id - - payment_method = account_manager.verify_token(token, payment_mandate_id) - if not payment_method: - raise ValueError(f"Payment method not found for token: {token}") - await updater.add_artifact([Part(root=DataPart(data=payment_method))]) - await updater.complete() + payment_method = account_manager.verify_token(token, payment_mandate_id) + if not payment_method: + raise ValueError(f'Payment method not found for token: {token}') + await updater.add_artifact([Part(root=DataPart(data=payment_method))]) + await updater.complete() async def handle_create_payment_credential_token( @@ -135,34 +132,34 @@ async def handle_create_payment_credential_token( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Handles a request to get a payment credential token. - - Updates a task with the payment credential token. - - Args: - data_parts: DataPart contents. Should contain the user_email and - payment_method_alias. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task if there is one. - """ - user_email = message_utils.find_data_part("user_email", data_parts) - payment_method_alias = message_utils.find_data_part( - "payment_method_alias", data_parts - ) - if not user_email or not payment_method_alias: - raise ValueError( - "user_email and payment_method_alias are required for" - " create_payment_credential_token" + """Handles a request to get a payment credential token. + + Updates a task with the payment credential token. + + Args: + data_parts: DataPart contents. Should contain the user_email and + payment_method_alias. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task if there is one. + """ + user_email = message_utils.find_data_part('user_email', data_parts) + payment_method_alias = message_utils.find_data_part( + 'payment_method_alias', data_parts + ) + if not user_email or not payment_method_alias: + raise ValueError( + 'user_email and payment_method_alias are required for' + ' create_payment_credential_token' + ) + + tokenized_payment_method = account_manager.create_token( + user_email, payment_method_alias ) - tokenized_payment_method = account_manager.create_token( - user_email, payment_method_alias - ) - - await updater.add_artifact( - [Part(root=DataPart(data={"token": tokenized_payment_method}))] - ) - await updater.complete() + await updater.add_artifact( + [Part(root=DataPart(data={'token': tokenized_payment_method}))] + ) + await updater.complete() async def handle_signed_payment_mandate( @@ -170,91 +167,93 @@ async def handle_signed_payment_mandate( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Handles a signed payment mandate. - - Adds the payment mandate id to the token in storage and then completes the - task. - - Args: - data_parts: DataPart contents. Should contain a single PaymentMandate. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task if there is one. - """ - payment_mandate = message_utils.parse_canonical_object( - PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate - ) - token = payment_mandate.payment_mandate_contents.payment_response.details.get( - "token", {} - ).get("value", "") - payment_mandate_id = ( - payment_mandate.payment_mandate_contents.payment_mandate_id - ) - account_manager.update_token(token, payment_mandate_id) - await updater.complete() + """Handles a signed payment mandate. + + Adds the payment mandate id to the token in storage and then completes the + task. + + Args: + data_parts: DataPart contents. Should contain a single PaymentMandate. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task if there is one. + """ + payment_mandate = message_utils.parse_canonical_object( + PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate + ) + token = ( + payment_mandate.payment_mandate_contents.payment_response.details.get( + 'token', {} + ).get('value', '') + ) + payment_mandate_id = ( + payment_mandate.payment_mandate_contents.payment_mandate_id + ) + account_manager.update_token(token, payment_mandate_id) + await updater.complete() def _get_payment_method_aliases( payment_methods: list[dict[str, Any]], ) -> list[str | None]: - """Gets the payment method aliases from a list of payment methods.""" - return [payment_method.get("alias") for payment_method in payment_methods] + """Gets the payment method aliases from a list of payment methods.""" + return [payment_method.get('alias') for payment_method in payment_methods] def _get_eligible_payment_method_aliases( user_email: str, merchant_accepted_payment_methods: list[PaymentMethodData] ) -> dict[str, list[str | None]]: - """Gets the payment_methods eligible according to given PaymentMethodData. - - Args: - user_email: The email address of the user's account. - merchant_accepted_payment_methods: A list of eligible payment method - criteria. - - Returns: - A list of the user's eligible payment_methods. - """ - payment_methods = account_manager.get_account_payment_methods(user_email) - eligible_payment_methods = [] - - for payment_method in payment_methods: - for criteria in merchant_accepted_payment_methods: - if _payment_method_is_eligible(payment_method, criteria): - eligible_payment_methods.append(payment_method) - break - return { - "payment_method_aliases": _get_payment_method_aliases( - eligible_payment_methods - ) - } + """Gets the payment_methods eligible according to given PaymentMethodData. + + Args: + user_email: The email address of the user's account. + merchant_accepted_payment_methods: A list of eligible payment method + criteria. + + Returns: + A list of the user's eligible payment_methods. + """ + payment_methods = account_manager.get_account_payment_methods(user_email) + eligible_payment_methods = [] + + for payment_method in payment_methods: + for criteria in merchant_accepted_payment_methods: + if _payment_method_is_eligible(payment_method, criteria): + eligible_payment_methods.append(payment_method) + break + return { + 'payment_method_aliases': _get_payment_method_aliases( + eligible_payment_methods + ) + } def _payment_method_is_eligible( payment_method: dict[str, Any], merchant_criteria: PaymentMethodData ) -> bool: - """Checks if a payment method is eligible based on a PaymentMethodData. - - Args: - payment_method: A dictionary representing the payment method. - merchant_criteria: A PaymentMethodData object containing the eligibility - criteria. - - Returns: - True if the payment_method is eligible according to the payment method, - False otherwise. - """ - if payment_method.get("type", "") != merchant_criteria.supported_methods: + """Checks if a payment method is eligible based on a PaymentMethodData. + + Args: + payment_method: A dictionary representing the payment method. + merchant_criteria: A PaymentMethodData object containing the eligibility + criteria. + + Returns: + True if the payment_method is eligible according to the payment method, + False otherwise. + """ + if payment_method.get('type', '') != merchant_criteria.supported_methods: + return False + + merchant_supported_networks = [ + network.casefold() + for network in merchant_criteria.data.get('network', []) + ] + if not merchant_supported_networks: + return False + + payment_card_networks = payment_method.get('network', []) + for network_info in payment_card_networks: + for supported_network in merchant_supported_networks: + if network_info.get('name', '').casefold() == supported_network: + return True return False - - merchant_supported_networks = [ - network.casefold() - for network in merchant_criteria.data.get("network", []) - ] - if not merchant_supported_networks: - return False - - payment_card_networks = payment_method.get("network", []) - for network_info in payment_card_networks: - for supported_network in merchant_supported_networks: - if network_info.get("name", "").casefold() == supported_network: - return True - return False diff --git a/samples/python/src/roles/merchant_agent/__main__.py b/samples/python/src/roles/merchant_agent/__main__.py index 6c6e77e3..f96bdc2e 100644 --- a/samples/python/src/roles/merchant_agent/__main__.py +++ b/samples/python/src/roles/merchant_agent/__main__.py @@ -17,20 +17,22 @@ from collections.abc import Sequence from absl import app - -from roles.merchant_agent.agent_executor import MerchantAgentExecutor from common import server +from roles.merchant_agent.agent_executor import MerchantAgentExecutor + AGENT_MERCHANT_PORT = 8001 + def main(argv: Sequence[str]) -> None: - agent_card = server.load_local_agent_card(__file__) - server.run_agent_blocking( - port=AGENT_MERCHANT_PORT, - agent_card=agent_card, - executor=MerchantAgentExecutor(agent_card.capabilities.extensions), - rpc_url="/a2a/merchant_agent", - ) - -if __name__ == "__main__": - app.run(main) + agent_card = server.load_local_agent_card(__file__) + server.run_agent_blocking( + port=AGENT_MERCHANT_PORT, + agent_card=agent_card, + executor=MerchantAgentExecutor(agent_card.capabilities.extensions), + rpc_url='/a2a/merchant_agent', + ) + + +if __name__ == '__main__': + app.run(main) diff --git a/samples/python/src/roles/merchant_agent/agent_executor.py b/samples/python/src/roles/merchant_agent/agent_executor.py index 55cc43e9..b6b3596c 100644 --- a/samples/python/src/roles/merchant_agent/agent_executor.py +++ b/samples/python/src/roles/merchant_agent/agent_executor.py @@ -28,115 +28,121 @@ invoking it to complete a task. """ - import logging + from typing import Any from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import Part -from a2a.types import Task -from a2a.types import TextPart - -from . import tools -from .sub_agents import catalog_agent +from a2a.types import Part, Task, TextPart from common import message_utils from common.base_server_executor import BaseServerExecutor from common.system_utils import DEBUG_MODE_INSTRUCTIONS +from . import tools +from .sub_agents import catalog_agent + # A list of known Shopping Agent identifiers that this Merchant is willing to # work with. _KNOWN_SHOPPING_AGENTS = [ - "trusted_shopping_agent", + 'trusted_shopping_agent', ] + class MerchantAgentExecutor(BaseServerExecutor): - """AgentExecutor for the merchant agent.""" + """AgentExecutor for the merchant agent.""" - _system_prompt = """ + _system_prompt = ( + """ You are a merchant agent. Your role is to help users with their shopping requests. You can find items, update shopping carts, and initiate payments. %s - """ % DEBUG_MODE_INSTRUCTIONS - - def __init__(self, supported_extensions: list[dict[str, Any]] = None): - """Initializes the MerchantAgentExecutor. - - Args: - supported_extensions: A list of extension objects supported by the - agent. - """ - agent_tools = [ - tools.update_cart, - catalog_agent.find_items_workflow, - tools.initiate_payment, - tools.dpc_finish, - ] - super().__init__(supported_extensions, agent_tools, self._system_prompt) - - async def _handle_request( - self, - text_parts: list[str], - data_parts: list[dict[str, Any]], - updater: TaskUpdater, - current_task: Task | None, - ) -> None: - """Overrides the base class method to validate the shopping agent first.""" - if not await self._validate_shopping_agent(data_parts, updater): - error_message = updater.new_agent_message( - parts=[ - Part(root=TextPart(text=f"Failed to validate shopping agent.")) - ] - ) - await updater.failed(message=error_message) - return - await super()._handle_request(text_parts, data_parts, updater, current_task) - - async def _validate_shopping_agent( - self, data_parts: list[dict[str, Any]], updater: TaskUpdater - ) -> None: - """Validates that the incoming request is from a trusted Shopping Agent. - - Args: - data_parts: A list of data part contents from the request. - - Returns: - True if the Shopping Agent is trusted, or False if not. - """ - - shopping_agent_id = message_utils.find_data_part( - "shopping_agent_id", data_parts - ) - logging.info( - "Received request from shopping_agent_id: %s", shopping_agent_id + """ + % DEBUG_MODE_INSTRUCTIONS ) - if not shopping_agent_id: - logging.warning("Missing shopping_agent_id in request.") - await _fail_task( - updater, "Unauthorized Request: Missing shopping_agent_id." - ) - return False - - if shopping_agent_id not in _KNOWN_SHOPPING_AGENTS: - logging.warning("Unknown Shopping Agent: %s", shopping_agent_id) - await _fail_task( - updater, f"Unauthorized Request: Unknown agent '{shopping_agent_id}'." - ) - return False - - logging.info( - "Authorized request from shopping_agent_id: %s", shopping_agent_id - ) - return True + def __init__(self, supported_extensions: list[dict[str, Any]] = None): + """Initializes the MerchantAgentExecutor. + + Args: + supported_extensions: A list of extension objects supported by the + agent. + """ + agent_tools = [ + tools.update_cart, + catalog_agent.find_items_workflow, + tools.initiate_payment, + tools.dpc_finish, + ] + super().__init__(supported_extensions, agent_tools, self._system_prompt) + + async def _handle_request( + self, + text_parts: list[str], + data_parts: list[dict[str, Any]], + updater: TaskUpdater, + current_task: Task | None, + ) -> None: + """Overrides the base class method to validate the shopping agent first.""" + if not await self._validate_shopping_agent(data_parts, updater): + error_message = updater.new_agent_message( + parts=[ + Part( + root=TextPart(text='Failed to validate shopping agent.') + ) + ] + ) + await updater.failed(message=error_message) + return + await super()._handle_request( + text_parts, data_parts, updater, current_task + ) + + async def _validate_shopping_agent( + self, data_parts: list[dict[str, Any]], updater: TaskUpdater + ) -> None: + """Validates that the incoming request is from a trusted Shopping Agent. + + Args: + data_parts: A list of data part contents from the request. + + Returns: + True if the Shopping Agent is trusted, or False if not. + """ + shopping_agent_id = message_utils.find_data_part( + 'shopping_agent_id', data_parts + ) + logging.info( + 'Received request from shopping_agent_id: %s', shopping_agent_id + ) + + if not shopping_agent_id: + logging.warning('Missing shopping_agent_id in request.') + await _fail_task( + updater, 'Unauthorized Request: Missing shopping_agent_id.' + ) + return False + + if shopping_agent_id not in _KNOWN_SHOPPING_AGENTS: + logging.warning('Unknown Shopping Agent: %s', shopping_agent_id) + await _fail_task( + updater, + f"Unauthorized Request: Unknown agent '{shopping_agent_id}'.", + ) + return False + + logging.info( + 'Authorized request from shopping_agent_id: %s', shopping_agent_id + ) + return True async def _fail_task(updater: TaskUpdater, error_text: str) -> None: - """A helper function to fail a task with a given error message.""" - error_message = updater.new_agent_message( - parts=[Part(root=TextPart(text=error_text))] - ) - await updater.failed(message=error_message) + """A helper function to fail a task with a given error message.""" + error_message = updater.new_agent_message( + parts=[Part(root=TextPart(text=error_text))] + ) + await updater.failed(message=error_message) diff --git a/samples/python/src/roles/merchant_agent/storage.py b/samples/python/src/roles/merchant_agent/storage.py index 2f0a998d..140038a2 100644 --- a/samples/python/src/roles/merchant_agent/storage.py +++ b/samples/python/src/roles/merchant_agent/storage.py @@ -19,29 +19,27 @@ interactions between the shopper and merchant agents. """ -from typing import Optional - from ap2.types.mandate import CartMandate -def get_cart_mandate(cart_id: str) -> Optional[CartMandate]: - """Get a cart mandate by cart ID.""" - return _store.get(cart_id) +def get_cart_mandate(cart_id: str) -> CartMandate | None: + """Get a cart mandate by cart ID.""" + return _store.get(cart_id) def set_cart_mandate(cart_id: str, cart_mandate: CartMandate) -> None: - """Set a cart mandate by cart ID.""" - _store[cart_id] = cart_mandate + """Set a cart mandate by cart ID.""" + _store[cart_id] = cart_mandate def set_risk_data(context_id: str, risk_data: str) -> None: - """Set risk data by context ID.""" - _store[context_id] = risk_data + """Set risk data by context ID.""" + _store[context_id] = risk_data -def get_risk_data(context_id: str) -> Optional[str]: - """Get risk data by context ID.""" - return _store.get(context_id) +def get_risk_data(context_id: str) -> str | None: + """Get risk data by context ID.""" + return _store.get(context_id) _store = {} diff --git a/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py b/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py index 66c324d7..62a1c785 100644 --- a/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py +++ b/samples/python/src/roles/merchant_agent/sub_agents/catalog_agent.py @@ -17,32 +17,32 @@ This agent fabricates catalog content based on the user's request. """ -from datetime import datetime -from datetime import timedelta -from datetime import timezone +from datetime import UTC, datetime, timedelta from typing import Any from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import DataPart -from a2a.types import Part -from a2a.types import Task -from a2a.types import TextPart +from a2a.types import DataPart, Part, Task, TextPart +from common import message_utils +from common.system_utils import DEBUG_MODE_INSTRUCTIONS from google import genai from pydantic import ValidationError +from ap2.types.mandate import ( + CART_MANDATE_DATA_KEY, + INTENT_MANDATE_DATA_KEY, + CartContents, + CartMandate, + IntentMandate, +) +from ap2.types.payment_request import ( + PaymentDetailsInit, + PaymentItem, + PaymentMethodData, + PaymentOptions, + PaymentRequest, +) + from .. import storage -from ap2.types.mandate import CART_MANDATE_DATA_KEY -from ap2.types.mandate import CartContents -from ap2.types.mandate import CartMandate -from ap2.types.mandate import INTENT_MANDATE_DATA_KEY -from ap2.types.mandate import IntentMandate -from ap2.types.payment_request import PaymentDetailsInit -from ap2.types.payment_request import PaymentItem -from ap2.types.payment_request import PaymentMethodData -from ap2.types.payment_request import PaymentOptions -from ap2.types.payment_request import PaymentRequest -from common import message_utils -from common.system_utils import DEBUG_MODE_INSTRUCTIONS async def find_items_workflow( @@ -50,51 +50,56 @@ async def find_items_workflow( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Finds products that match the user's IntentMandate.""" - llm_client = genai.Client() - - intent_mandate = message_utils.parse_canonical_object( - INTENT_MANDATE_DATA_KEY, data_parts, IntentMandate - ) - intent = intent_mandate.natural_language_description - prompt = f""" + """Finds products that match the user's IntentMandate.""" + llm_client = genai.Client() + + intent_mandate = message_utils.parse_canonical_object( + INTENT_MANDATE_DATA_KEY, data_parts, IntentMandate + ) + intent = intent_mandate.natural_language_description + prompt = ( + f""" Based on the user's request for '{intent}', your task is to generate 3 complete, unique and realistic PaymentItem JSON objects. You MUST exclude all branding from the PaymentItem `label` field. %s - """ % DEBUG_MODE_INSTRUCTIONS - - llm_response = llm_client.models.generate_content( - model="gemini-2.5-flash", - contents=prompt, - config={ - "response_mime_type": "application/json", - "response_schema": list[PaymentItem], - } - ) - try: - items: list[PaymentItem] = llm_response.parsed - - current_time = datetime.now(timezone.utc) - item_count = 0 - for item in items: - item_count += 1 - await _create_and_add_cart_mandate_artifact( - item, item_count, current_time, updater - ) - risk_data = _collect_risk_data(updater) - updater.add_artifact([ - Part(root=DataPart(data={"risk_data": risk_data})), - ]) - await updater.complete() - except ValidationError as e: - error_message = updater.new_agent_message( - parts=[Part(root=TextPart(text=f"Invalid CartMandate list: {e}"))] + """ + % DEBUG_MODE_INSTRUCTIONS ) - await updater.failed(message=error_message) - return + + llm_response = llm_client.models.generate_content( + model='gemini-2.5-flash', + contents=prompt, + config={ + 'response_mime_type': 'application/json', + 'response_schema': list[PaymentItem], + }, + ) + try: + items: list[PaymentItem] = llm_response.parsed + + current_time = datetime.now(UTC) + item_count = 0 + for item in items: + item_count += 1 + await _create_and_add_cart_mandate_artifact( + item, item_count, current_time, updater + ) + risk_data = _collect_risk_data(updater) + updater.add_artifact( + [ + Part(root=DataPart(data={'risk_data': risk_data})), + ] + ) + await updater.complete() + except ValidationError as e: + error_message = updater.new_agent_message( + parts=[Part(root=TextPart(text=f'Invalid CartMandate list: {e}'))] + ) + await updater.failed(message=error_message) + return async def _create_and_add_cart_mandate_artifact( @@ -103,48 +108,52 @@ async def _create_and_add_cart_mandate_artifact( current_time: datetime, updater: TaskUpdater, ) -> None: - """Creates a CartMandate and adds it as an artifact.""" - payment_request = PaymentRequest( - method_data=[ - PaymentMethodData( - supported_methods="CARD", - data={ - "network": ["mastercard", "paypal", "amex"], - }, - ) - ], - details=PaymentDetailsInit( - id=f"order_{item_count}", - display_items=[item], - total=PaymentItem( - label="Total", - amount=item.amount, - ), - ), - options=PaymentOptions(request_shipping=True), - ) - - cart_contents = CartContents( - id=f"cart_{item_count}", - user_cart_confirmation_required=True, - payment_request=payment_request, - cart_expiry=(current_time + timedelta(minutes=30)).isoformat(), - merchant_name="Generic Merchant", - ) - - cart_mandate = CartMandate(contents=cart_contents) - - storage.set_cart_mandate(cart_mandate.contents.id, cart_mandate) - await updater.add_artifact([ - Part( - root=DataPart(data={CART_MANDATE_DATA_KEY: cart_mandate.model_dump()}) - ) - ]) + """Creates a CartMandate and adds it as an artifact.""" + payment_request = PaymentRequest( + method_data=[ + PaymentMethodData( + supported_methods='CARD', + data={ + 'network': ['mastercard', 'paypal', 'amex'], + }, + ) + ], + details=PaymentDetailsInit( + id=f'order_{item_count}', + display_items=[item], + total=PaymentItem( + label='Total', + amount=item.amount, + ), + ), + options=PaymentOptions(request_shipping=True), + ) + + cart_contents = CartContents( + id=f'cart_{item_count}', + user_cart_confirmation_required=True, + payment_request=payment_request, + cart_expiry=(current_time + timedelta(minutes=30)).isoformat(), + merchant_name='Generic Merchant', + ) + + cart_mandate = CartMandate(contents=cart_contents) + + storage.set_cart_mandate(cart_mandate.contents.id, cart_mandate) + await updater.add_artifact( + [ + Part( + root=DataPart( + data={CART_MANDATE_DATA_KEY: cart_mandate.model_dump()} + ) + ) + ] + ) def _collect_risk_data(updater: TaskUpdater) -> dict: - """Creates a risk_data in the tool_context.""" - # This is a fake risk data for demonstration purposes. - risk_data = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...fake_risk_data" - storage.set_risk_data(updater.context_id, risk_data) - return risk_data + """Creates a risk_data in the tool_context.""" + # This is a fake risk data for demonstration purposes. + risk_data = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...fake_risk_data' + storage.set_risk_data(updater.context_id, risk_data) + return risk_data diff --git a/samples/python/src/roles/merchant_agent/tools.py b/samples/python/src/roles/merchant_agent/tools.py index bd7e79ad..5852d7c7 100644 --- a/samples/python/src/roles/merchant_agent/tools.py +++ b/samples/python/src/roles/merchant_agent/tools.py @@ -18,40 +18,38 @@ shopping and purchasing process. """ -import base64 -import json import logging -from pydantic import ValidationError from typing import Any from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import DataPart -from a2a.types import Part -from a2a.types import Task -from a2a.types import TextPart - -from . import storage -from ap2.types.contact_picker import ContactAddress -from ap2.types.mandate import CART_MANDATE_DATA_KEY -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY -from ap2.types.mandate import PaymentMandate -from ap2.types.payment_request import PaymentCurrencyAmount -from ap2.types.payment_request import PaymentItem +from a2a.types import DataPart, Part, Task, TextPart from common import message_utils from common.a2a_extension_utils import EXTENSION_URI from common.a2a_message_builder import A2aMessageBuilder from common.payment_remote_a2a_client import PaymentRemoteA2aClient +from pydantic import ValidationError + +from ap2.types.contact_picker import ContactAddress +from ap2.types.mandate import ( + CART_MANDATE_DATA_KEY, + PAYMENT_MANDATE_DATA_KEY, + PaymentMandate, +) +from ap2.types.payment_request import PaymentCurrencyAmount, PaymentItem + +from . import storage + # A map of payment method types to their corresponding processor agent URLs. # This is the set of linked Merchant Payment Processor Agents this Merchant # is integrated with. _PAYMENT_PROCESSORS_BY_PAYMENT_METHOD_TYPE = { - "CARD": "http://localhost:8003/a2a/merchant_payment_processor_agent", + 'CARD': 'http://localhost:8003/a2a/merchant_payment_processor_agent', } # A placeholder for a JSON Web Token (JWT) used for merchant authorization. -_FAKE_JWT = "eyJhbGciOiJSUzI1NiIsImtpZIwMjQwOTA..." +_FAKE_JWT = 'eyJhbGciOiJSUzI1NiIsImtpZIwMjQwOTA...' async def update_cart( @@ -60,159 +58,166 @@ async def update_cart( current_task: Task | None, debug_mode: bool = False, ) -> None: - """Updates an existing cart after a shipping address is provided. - - Args: - data_parts: A list of data part contents from the request. - updater: The TaskUpdater instance to add artifacts and complete the task. - current_task: The current task -- not used in this function. - debug_mode: Whether the agent is in debug mode. - """ - cart_id = message_utils.find_data_part("cart_id", data_parts) - if not cart_id: - await _fail_task(updater, "Missing cart_id.") - return - - shipping_address = message_utils.find_data_part( - "shipping_address", data_parts - ) - if not shipping_address: - await _fail_task(updater, "Missing shipping_address.") - return - - cart_mandate = storage.get_cart_mandate(cart_id) - if not cart_mandate: - await _fail_task(updater, f"CartMandate not found for cart_id: {cart_id}") - return - - risk_data = storage.get_risk_data(updater.context_id) - if not risk_data: - await _fail_task( - updater, f"Missing risk_data for context_id: {updater.context_id}" + """Updates an existing cart after a shipping address is provided. + + Args: + data_parts: A list of data part contents from the request. + updater: The TaskUpdater instance to add artifacts and complete the task. + current_task: The current task -- not used in this function. + debug_mode: Whether the agent is in debug mode. + """ + cart_id = message_utils.find_data_part('cart_id', data_parts) + if not cart_id: + await _fail_task(updater, 'Missing cart_id.') + return + + shipping_address = message_utils.find_data_part( + 'shipping_address', data_parts ) - return + if not shipping_address: + await _fail_task(updater, 'Missing shipping_address.') + return + + cart_mandate = storage.get_cart_mandate(cart_id) + if not cart_mandate: + await _fail_task( + updater, f'CartMandate not found for cart_id: {cart_id}' + ) + return + + risk_data = storage.get_risk_data(updater.context_id) + if not risk_data: + await _fail_task( + updater, f'Missing risk_data for context_id: {updater.context_id}' + ) + return + + # Update the CartMandate with new shipping and tax cost. + try: + # Add the shipping address to the CartMandate: + cart_mandate.contents.payment_request.shipping_address = ( + ContactAddress.model_validate(shipping_address) + ) + + # Add new shipping and tax costs to the PaymentRequest: + tax_and_shipping_costs = [ + PaymentItem( + label='Shipping', + amount=PaymentCurrencyAmount(currency='USD', value=2.00), + ), + PaymentItem( + label='Tax', + amount=PaymentCurrencyAmount(currency='USD', value=1.50), + ), + ] + + payment_request = cart_mandate.contents.payment_request + + if payment_request.details.display_items is None: + payment_request.details.display_items = tax_and_shipping_costs + else: + payment_request.details.display_items.extend(tax_and_shipping_costs) + + # Recompute the total amount of the PaymentRequest: + payment_request.details.total.amount.value = sum( + item.amount.value for item in payment_request.details.display_items + ) + + # A base64url-encoded JSON Web Token (JWT) that digitally signs the cart + # contents by the merchant's private key. + cart_mandate.merchant_authorization = _FAKE_JWT + + await updater.add_artifact( + [ + Part( + root=DataPart( + data={CART_MANDATE_DATA_KEY: cart_mandate.model_dump()} + ) + ), + Part(root=DataPart(data={'risk_data': risk_data})), + ] + ) + await updater.complete() + + except ValidationError as e: + await _fail_task(updater, f'Invalid CartMandate after update: {e}') + - # Update the CartMandate with new shipping and tax cost. - try: - # Add the shipping address to the CartMandate: - cart_mandate.contents.payment_request.shipping_address = ( - ContactAddress.model_validate(shipping_address) +async def initiate_payment( + data_parts: list[dict[str, Any]], + updater: TaskUpdater, + current_task: Task | None, + debug_mode: bool = False, +) -> None: + """Initiates a payment for a given payment mandate. Use to make a payment. + + Args: + data_parts: The data parts from the request, expected to contain a + PaymentMandate and optionally a challenge response. + updater: The TaskUpdater instance for updating the task state. + current_task: The current task, used to find the processor's task ID. + debug_mode: Whether the agent is in debug mode. + """ + payment_mandate = message_utils.parse_canonical_object( + PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate ) + if not payment_mandate: + await _fail_task(updater, 'Missing payment_mandate.') + return + + risk_data = message_utils.find_data_part('risk_data', data_parts) + if not risk_data: + await _fail_task(updater, 'Missing risk_data.') + return - # Add new shipping and tax costs to the PaymentRequest: - tax_and_shipping_costs = [ - PaymentItem( - label="Shipping", - amount=PaymentCurrencyAmount(currency="USD", value=2.00), - ), - PaymentItem( - label="Tax", - amount=PaymentCurrencyAmount(currency="USD", value=1.50), - ), - ] - - payment_request = cart_mandate.contents.payment_request - - if payment_request.details.display_items is None: - payment_request.details.display_items = tax_and_shipping_costs - else: - payment_request.details.display_items.extend(tax_and_shipping_costs) - - # Recompute the total amount of the PaymentRequest: - payment_request.details.total.amount.value = sum( - item.amount.value for item in payment_request.details.display_items + payment_method_type = ( + payment_mandate.payment_mandate_contents.payment_response.method_name + ) + processor_url = _PAYMENT_PROCESSORS_BY_PAYMENT_METHOD_TYPE.get( + payment_method_type ) - # A base64url-encoded JSON Web Token (JWT) that digitally signs the cart - # contents by the merchant's private key. - cart_mandate.merchant_authorization = _FAKE_JWT + if not processor_url: + await _fail_task( + updater, + f'No payment processor found for method: {payment_method_type}', + ) + return + + payment_processor_agent = PaymentRemoteA2aClient( + name='payment_processor_agent', + base_url=processor_url, + required_extensions={ + EXTENSION_URI, + }, + ) - await updater.add_artifact([ - Part( - root=DataPart( - data={CART_MANDATE_DATA_KEY: cart_mandate.model_dump()} - ) - ), - Part(root=DataPart(data={"risk_data": risk_data})), - ]) - await updater.complete() + message_builder = ( + A2aMessageBuilder() + .set_context_id(updater.context_id) + .add_text('initiate_payment') + .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate.model_dump()) + .add_data('risk_data', risk_data) + .add_data('debug_mode', debug_mode) + ) - except ValidationError as e: - await _fail_task(updater, f"Invalid CartMandate after update: {e}") + challenge_response = ( + message_utils.find_data_part('challenge_response', data_parts) or '' + ) + if challenge_response: + message_builder.add_data('challenge_response', challenge_response) + payment_processor_task_id = _get_payment_processor_task_id(current_task) + if payment_processor_task_id: + message_builder.set_task_id(payment_processor_task_id) -async def initiate_payment( - data_parts: list[dict[str, Any]], - updater: TaskUpdater, - current_task: Task | None, - debug_mode: bool = False, -) -> None: - """Initiates a payment for a given payment mandate. Use to make a payment. - - Args: - data_parts: The data parts from the request, expected to contain a - PaymentMandate and optionally a challenge response. - updater: The TaskUpdater instance for updating the task state. - current_task: The current task, used to find the processor's task ID. - debug_mode: Whether the agent is in debug mode. - """ - payment_mandate = message_utils.parse_canonical_object( - PAYMENT_MANDATE_DATA_KEY, data_parts, PaymentMandate - ) - if not payment_mandate: - await _fail_task(updater, "Missing payment_mandate.") - return - - risk_data = message_utils.find_data_part("risk_data", data_parts) - if not risk_data: - await _fail_task(updater, "Missing risk_data.") - return - - payment_method_type = ( - payment_mandate.payment_mandate_contents.payment_response.method_name - ) - processor_url = _PAYMENT_PROCESSORS_BY_PAYMENT_METHOD_TYPE.get( - payment_method_type - ) - - if not processor_url: - await _fail_task( - updater, f"No payment processor found for method: {payment_method_type}" + task = await payment_processor_agent.send_a2a_message( + message_builder.build() + ) + await updater.update_status( + state=task.status.state, + message=task.status.message, ) - return - - payment_processor_agent = PaymentRemoteA2aClient( - name="payment_processor_agent", - base_url=processor_url, - required_extensions={ - EXTENSION_URI, - }, - ) - - message_builder = ( - A2aMessageBuilder() - .set_context_id(updater.context_id) - .add_text("initiate_payment") - .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate.model_dump()) - .add_data("risk_data", risk_data) - .add_data("debug_mode", debug_mode) - ) - - challenge_response = ( - message_utils.find_data_part("challenge_response", data_parts) or "" - ) - if challenge_response: - message_builder.add_data("challenge_response", challenge_response) - - payment_processor_task_id = _get_payment_processor_task_id(current_task) - if payment_processor_task_id: - message_builder.set_task_id(payment_processor_task_id) - - task = await payment_processor_agent.send_a2a_message(message_builder.build()) - await updater.update_status( - state=task.status.state, - message=task.status.message, - ) async def dpc_finish( @@ -220,55 +225,61 @@ async def dpc_finish( updater: TaskUpdater, current_task: Task | None, ) -> None: - """Receives and validates a DPC response to finalize payment. - - This tool receives the Digital Payment Credential (DPC) response, in the form - of an OpenID4VP JSON, validates it, and simulates payment finalization. - - Args: - data_parts: A list of data part contents from the request. - updater: The TaskUpdater instance to add artifacts and complete the task. - current_task: The current task, not used in this function. - """ - dpc_response = message_utils.find_data_part("dpc_response", data_parts) - if not dpc_response: - await _fail_task(updater, "Missing dpc_response.") - return - - logging.info("Received DPC response for finalization: %s", dpc_response) - - # --- Sample validation and payment finalization logic --- - # TODO: Validate the nonce, and other merchant-specific attributes from the - # DPC response. - # TODO: Pass the DPC response to the payment processor agent for validation. - - # Simulate payment finalization. - await updater.add_artifact([ - Part(root=DataPart(data={ - "payment_status": "SUCCESS", - "transaction_id": "txn_1234567890", - })) - ]) - await updater.complete() + """Receives and validates a DPC response to finalize payment. + + This tool receives the Digital Payment Credential (DPC) response, in the form + of an OpenID4VP JSON, validates it, and simulates payment finalization. + + Args: + data_parts: A list of data part contents from the request. + updater: The TaskUpdater instance to add artifacts and complete the task. + current_task: The current task, not used in this function. + """ + dpc_response = message_utils.find_data_part('dpc_response', data_parts) + if not dpc_response: + await _fail_task(updater, 'Missing dpc_response.') + return + + logging.info('Received DPC response for finalization: %s', dpc_response) + + # --- Sample validation and payment finalization logic --- + # TODO: Validate the nonce, and other merchant-specific attributes from the + # DPC response. + # TODO: Pass the DPC response to the payment processor agent for validation. + + # Simulate payment finalization. + await updater.add_artifact( + [ + Part( + root=DataPart( + data={ + 'payment_status': 'SUCCESS', + 'transaction_id': 'txn_1234567890', + } + ) + ) + ] + ) + await updater.complete() def _get_payment_processor_task_id(task: Task | None) -> str | None: - """Returns the task ID of the payment processor task, if it exists. - - Identified by assuming the first message with a task ID that is not the - merchant's task ID is a payment processor message. - """ - if task is None: + """Returns the task ID of the payment processor task, if it exists. + + Identified by assuming the first message with a task ID that is not the + merchant's task ID is a payment processor message. + """ + if task is None: + return None + for message in task.history: + if message.task_id != task.id: + return message.task_id return None - for message in task.history: - if message.task_id != task.id: - return message.task_id - return None async def _fail_task(updater: TaskUpdater, error_text: str) -> None: - """A helper function to fail a task with a given error message.""" - error_message = updater.new_agent_message( - parts=[Part(root=TextPart(text=error_text))] - ) - await updater.failed(message=error_message) + """A helper function to fail a task with a given error message.""" + error_message = updater.new_agent_message( + parts=[Part(root=TextPart(text=error_text))] + ) + await updater.failed(message=error_message) diff --git a/samples/python/src/roles/merchant_payment_processor_agent/__main__.py b/samples/python/src/roles/merchant_payment_processor_agent/__main__.py index 4b8706bb..ef376e83 100644 --- a/samples/python/src/roles/merchant_payment_processor_agent/__main__.py +++ b/samples/python/src/roles/merchant_payment_processor_agent/__main__.py @@ -17,20 +17,24 @@ from collections.abc import Sequence from absl import app - -from roles.merchant_payment_processor_agent.agent_executor import PaymentProcessorExecutor from common import server +from roles.merchant_payment_processor_agent.agent_executor import ( + PaymentProcessorExecutor, +) + AGENT_PAYMENT_PROCESSOR_PORT = 8003 + def main(argv: Sequence[str]) -> None: - agent_card = server.load_local_agent_card(__file__) - server.run_agent_blocking( - port=AGENT_PAYMENT_PROCESSOR_PORT, - agent_card=agent_card, - executor=PaymentProcessorExecutor(agent_card.capabilities.extensions), - rpc_url="/a2a/merchant_payment_processor_agent", - ) - -if __name__ == "__main__": - app.run(main) + agent_card = server.load_local_agent_card(__file__) + server.run_agent_blocking( + port=AGENT_PAYMENT_PROCESSOR_PORT, + agent_card=agent_card, + executor=PaymentProcessorExecutor(agent_card.capabilities.extensions), + rpc_url='/a2a/merchant_payment_processor_agent', + ) + + +if __name__ == '__main__': + app.run(main) diff --git a/samples/python/src/roles/merchant_payment_processor_agent/agent_executor.py b/samples/python/src/roles/merchant_payment_processor_agent/agent_executor.py index ea28dfc4..548d2940 100644 --- a/samples/python/src/roles/merchant_payment_processor_agent/agent_executor.py +++ b/samples/python/src/roles/merchant_payment_processor_agent/agent_executor.py @@ -26,29 +26,30 @@ invoking it to complete a task. """ - from typing import Any -from . import tools from common.base_server_executor import BaseServerExecutor from common.system_utils import DEBUG_MODE_INSTRUCTIONS - +from . import tools class PaymentProcessorExecutor(BaseServerExecutor): - """AgentExecutor for the merchant payment processor agent.""" + """AgentExecutor for the merchant payment processor agent.""" - _system_prompt = """ + _system_prompt = ( + """ You are a payment processor agent. Your role is to process payments on behalf of a merchant. %s - """ % DEBUG_MODE_INSTRUCTIONS - - def __init__(self, supported_extensions: list[dict[str, Any]] = None): - """Initializes the PaymentProcessorExecutor.""" - agent_tools = [ - tools.initiate_payment, - ] - super().__init__(supported_extensions, agent_tools, self._system_prompt) + """ + % DEBUG_MODE_INSTRUCTIONS + ) + + def __init__(self, supported_extensions: list[dict[str, Any]] = None): + """Initializes the PaymentProcessorExecutor.""" + agent_tools = [ + tools.initiate_payment, + ] + super().__init__(supported_extensions, agent_tools, self._system_prompt) diff --git a/samples/python/src/roles/merchant_payment_processor_agent/tools.py b/samples/python/src/roles/merchant_payment_processor_agent/tools.py index cea3c8f2..b97b2607 100644 --- a/samples/python/src/roles/merchant_payment_processor_agent/tools.py +++ b/samples/python/src/roles/merchant_payment_processor_agent/tools.py @@ -18,25 +18,19 @@ shopping and purchasing process. """ - import logging + from typing import Any from a2a.server.tasks.task_updater import TaskUpdater -from a2a.types import DataPart -from a2a.types import Part -from a2a.types import Task -from a2a.types import TaskState -from a2a.types import TextPart - -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY -from ap2.types.mandate import PaymentMandate -from common import artifact_utils -from common import message_utils +from a2a.types import DataPart, Part, Task, TaskState, TextPart +from common import artifact_utils, message_utils from common.a2a_extension_utils import EXTENSION_URI from common.a2a_message_builder import A2aMessageBuilder from common.payment_remote_a2a_client import PaymentRemoteA2aClient +from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY, PaymentMandate + async def initiate_payment( data_parts: list[dict[str, Any]], @@ -44,25 +38,27 @@ async def initiate_payment( current_task: Task | None, debug_mode: bool = False, ) -> None: - """Handles the initiation of a payment.""" - payment_mandate = message_utils.find_data_part( - PAYMENT_MANDATE_DATA_KEY, data_parts - ) - if not payment_mandate: - error_message = _create_text_parts("Missing payment_mandate.") - await updater.failed(message=updater.new_agent_message(parts=error_message)) - return - - challenge_response = ( - message_utils.find_data_part("challenge_response", data_parts) or "" - ) - await _handle_payment_mandate( - PaymentMandate.model_validate(payment_mandate), - challenge_response, - updater, - current_task, - debug_mode, - ) + """Handles the initiation of a payment.""" + payment_mandate = message_utils.find_data_part( + PAYMENT_MANDATE_DATA_KEY, data_parts + ) + if not payment_mandate: + error_message = _create_text_parts('Missing payment_mandate.') + await updater.failed( + message=updater.new_agent_message(parts=error_message) + ) + return + + challenge_response = ( + message_utils.find_data_part('challenge_response', data_parts) or '' + ) + await _handle_payment_mandate( + PaymentMandate.model_validate(payment_mandate), + challenge_response, + updater, + current_task, + debug_mode, + ) async def _handle_payment_mandate( @@ -72,61 +68,61 @@ async def _handle_payment_mandate( current_task: Task | None, debug_mode: bool = False, ) -> None: - """Handles a payment mandate. - - If no task is present, it initiates a transaction challenge. If a task - requires input, it verifies the challenge response and completes the payment. - - Args: - payment_mandate: The payment mandate containing payment details. - challenge_response: The response to a transaction challenge, if any. - updater: The task updater for managing task state. - current_task: The current task, or None if it's a new payment. - debug_mode: Whether the agent is in debug mode. - """ - if current_task is None: - await _raise_challenge(updater) - return - - if current_task.status.state == TaskState.input_required: - await _check_challenge_response_and_complete_payment( - payment_mandate, - challenge_response, - updater, - debug_mode, - ) - return + """Handles a payment mandate. + + If no task is present, it initiates a transaction challenge. If a task + requires input, it verifies the challenge response and completes the payment. + + Args: + payment_mandate: The payment mandate containing payment details. + challenge_response: The response to a transaction challenge, if any. + updater: The task updater for managing task state. + current_task: The current task, or None if it's a new payment. + debug_mode: Whether the agent is in debug mode. + """ + if current_task is None: + await _raise_challenge(updater) + return + + if current_task.status.state == TaskState.input_required: + await _check_challenge_response_and_complete_payment( + payment_mandate, + challenge_response, + updater, + debug_mode, + ) + return async def _raise_challenge( updater: TaskUpdater, ) -> None: - """Raises a transaction challenge. - - This challenge would normally be raised by the issuer, but we don't - have an issuer in the demo, so we raise the challenge here. For concreteness, - we are using an OTP challenge in this sample. - - Args: - updater: The task updater. - """ - challenge_data = { - "type": "otp", - "display_text": ( - "The payment method issuer sent a verification code to the phone " - "number on file, please enter it below. It will be shared with the " - "issuer so they can authorize the transaction." - "(Demo only hint: the code is 123)" - ), - } - text_part = TextPart( - text="Please provide the challenge response to complete the payment." - ) - data_part = DataPart(data={"challenge": challenge_data}) - message = updater.new_agent_message( - parts=[Part(root=text_part), Part(root=data_part)] - ) - await updater.requires_input(message=message) + """Raises a transaction challenge. + + This challenge would normally be raised by the issuer, but we don't + have an issuer in the demo, so we raise the challenge here. For concreteness, + we are using an OTP challenge in this sample. + + Args: + updater: The task updater. + """ + challenge_data = { + 'type': 'otp', + 'display_text': ( + 'The payment method issuer sent a verification code to the phone ' + 'number on file, please enter it below. It will be shared with the ' + 'issuer so they can authorize the transaction.' + '(Demo only hint: the code is 123)' + ), + } + text_part = TextPart( + text='Please provide the challenge response to complete the payment.' + ) + data_part = DataPart(data={'challenge': challenge_data}) + message = updater.new_agent_message( + parts=[Part(root=text_part), Part(root=data_part)] + ) + await updater.requires_input(message=message) async def _check_challenge_response_and_complete_payment( @@ -135,25 +131,25 @@ async def _check_challenge_response_and_complete_payment( updater: TaskUpdater, debug_mode: bool = False, ) -> None: - """Checks the challenge response and completes the payment process. - - Checking the challenge response would be done by the issuer, but we don't - have an issuer in the demo, so we do it here. - - Args: - payment_mandate: The payment mandate. - challenge_response: The challenge response. - updater: The task updater. - debug_mode: Whether the agent is in debug mode. - """ - if _challenge_response_is_valid(challenge_response=challenge_response): - await _complete_payment(payment_mandate, updater, debug_mode) - return - - message = updater.new_agent_message( - _create_text_parts("Challenge response incorrect.") - ) - await updater.requires_input(message=message) + """Checks the challenge response and completes the payment process. + + Checking the challenge response would be done by the issuer, but we don't + have an issuer in the demo, so we do it here. + + Args: + payment_mandate: The payment mandate. + challenge_response: The challenge response. + updater: The task updater. + debug_mode: Whether the agent is in debug mode. + """ + if _challenge_response_is_valid(challenge_response=challenge_response): + await _complete_payment(payment_mandate, updater, debug_mode) + return + + message = updater.new_agent_message( + _create_text_parts('Challenge response incorrect.') + ) + await updater.requires_input(message=message) async def _complete_payment( @@ -161,36 +157,35 @@ async def _complete_payment( updater: TaskUpdater, debug_mode: bool = False, ) -> None: - """Completes the payment process. - - Args: - payment_mandate: The payment mandate. - updater: The task updater. - debug_mode: Whether the agent is in debug mode. - """ - payment_mandate_id = ( - payment_mandate.payment_mandate_contents.payment_mandate_id - ) - payment_credential = await _request_payment_credential( - payment_mandate, updater, debug_mode - ) - - logging.info( - "Calling issuer to complete payment for %s with payment credential %s...", - payment_mandate_id, - payment_credential, - ) - # Call issuer to complete the payment - success_message = updater.new_agent_message( - parts=_create_text_parts("{'status': 'success'}") - ) - await updater.complete(message=success_message) + """Completes the payment process. + + Args: + payment_mandate: The payment mandate. + updater: The task updater. + debug_mode: Whether the agent is in debug mode. + """ + payment_mandate_id = ( + payment_mandate.payment_mandate_contents.payment_mandate_id + ) + payment_credential = await _request_payment_credential( + payment_mandate, updater, debug_mode + ) + logging.info( + 'Calling issuer to complete payment for %s with payment credential %s...', + payment_mandate_id, + payment_credential, + ) + # Call issuer to complete the payment + success_message = updater.new_agent_message( + parts=_create_text_parts("{'status': 'success'}") + ) + await updater.complete(message=success_message) -def _challenge_response_is_valid(challenge_response: str) -> bool: - """Validates the challenge response.""" - return challenge_response == "123" +def _challenge_response_is_valid(challenge_response: str) -> bool: + """Validates the challenge response.""" + return challenge_response == '123' async def _request_payment_credential( @@ -198,45 +193,45 @@ async def _request_payment_credential( updater: TaskUpdater, debug_mode: bool = False, ) -> str: - """Sends a request to the Credentials Provider for payment credentials. - - Args: - payment_mandate: The PaymentMandate containing payment details. - updater: The task updater. - debug_mode: Whether the agent is in debug mode. - - Returns: - payment_credential: The payment credential details. - """ - token_object = ( - payment_mandate.payment_mandate_contents.payment_response.details.get( - "token" - ) - ) - credentials_provider_url = token_object.get("url") - - credentials_provider = PaymentRemoteA2aClient( - name="credentials_provider", - base_url=credentials_provider_url, - required_extensions={EXTENSION_URI}, - ) - - message_builder = ( - A2aMessageBuilder() - .set_context_id(updater.context_id) - .add_text("Give me the payment method credentials for the given token.") - .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate.model_dump()) - .add_data("debug_mode", debug_mode) - ) - task = await credentials_provider.send_a2a_message(message_builder.build()) - - if not task.artifacts: - raise ValueError("Failed to find the payment method data.") - payment_credential = artifact_utils.get_first_data_part(task.artifacts) - - return payment_credential + """Sends a request to the Credentials Provider for payment credentials. + + Args: + payment_mandate: The PaymentMandate containing payment details. + updater: The task updater. + debug_mode: Whether the agent is in debug mode. + + Returns: + payment_credential: The payment credential details. + """ + token_object = ( + payment_mandate.payment_mandate_contents.payment_response.details.get( + 'token' + ) + ) + credentials_provider_url = token_object.get('url') + + credentials_provider = PaymentRemoteA2aClient( + name='credentials_provider', + base_url=credentials_provider_url, + required_extensions={EXTENSION_URI}, + ) + + message_builder = ( + A2aMessageBuilder() + .set_context_id(updater.context_id) + .add_text('Give me the payment method credentials for the given token.') + .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate.model_dump()) + .add_data('debug_mode', debug_mode) + ) + task = await credentials_provider.send_a2a_message(message_builder.build()) + + if not task.artifacts: + raise ValueError('Failed to find the payment method data.') + payment_credential = artifact_utils.get_first_data_part(task.artifacts) + + return payment_credential def _create_text_parts(*texts: str) -> list[Part]: - """Helper to create text parts.""" - return [Part(root=TextPart(text=text)) for text in texts] + """Helper to create text parts.""" + return [Part(root=TextPart(text=text)) for text in texts] diff --git a/samples/python/src/roles/shopping_agent/agent.py b/samples/python/src/roles/shopping_agent/agent.py index 37c91a31..5c3353d9 100644 --- a/samples/python/src/roles/shopping_agent/agent.py +++ b/samples/python/src/roles/shopping_agent/agent.py @@ -19,21 +19,24 @@ 2. Help complete the purchase of their chosen items. The Google ADK powers this shopping agent, chosen for its simplicity and -efficiency in developing robust LLM agents. +efficiency in developing robust LLM agents. """ +from common.retrying_llm_agent import RetryingLlmAgent +from common.system_utils import DEBUG_MODE_INSTRUCTIONS + from . import tools from .subagents.payment_method_collector.agent import payment_method_collector -from .subagents.shipping_address_collector.agent import shipping_address_collector +from .subagents.shipping_address_collector.agent import ( + shipping_address_collector, +) from .subagents.shopper.agent import shopper -from common.retrying_llm_agent import RetryingLlmAgent -from common.system_utils import DEBUG_MODE_INSTRUCTIONS root_agent = RetryingLlmAgent( max_retries=5, - model="gemini-2.5-flash", - name="root_agent", + model='gemini-2.5-flash', + name='root_agent', instruction=""" You are a shopping agent responsible for helping users find and purchase products from merchants. @@ -107,7 +110,8 @@ 1. Respond to the user with this message: "Hi, I'm your shopping assistant. How can I help you? For example, you can say 'I want to buy a pair of shoes'" - """ % DEBUG_MODE_INSTRUCTIONS, + """ + % DEBUG_MODE_INSTRUCTIONS, tools=[ tools.create_payment_mandate, tools.initiate_payment, diff --git a/samples/python/src/roles/shopping_agent/remote_agents.py b/samples/python/src/roles/shopping_agent/remote_agents.py index 9c2a492f..e9557365 100644 --- a/samples/python/src/roles/shopping_agent/remote_agents.py +++ b/samples/python/src/roles/shopping_agent/remote_agents.py @@ -26,8 +26,8 @@ credentials_provider_client = PaymentRemoteA2aClient( - name="credentials_provider", - base_url="http://localhost:8002/a2a/credentials_provider", + name='credentials_provider', + base_url='http://localhost:8002/a2a/credentials_provider', required_extensions={ EXTENSION_URI, }, @@ -35,8 +35,8 @@ merchant_agent_client = PaymentRemoteA2aClient( - name="merchant_agent", - base_url="http://localhost:8001/a2a/merchant_agent", + name='merchant_agent', + base_url='http://localhost:8001/a2a/merchant_agent', required_extensions={ EXTENSION_URI, }, diff --git a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py index 446f83fe..e7d5a76a 100644 --- a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/agent.py @@ -25,14 +25,15 @@ provider, which is then sent to the merchant agent for payment. """ -from . import tools from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS +from . import tools + payment_method_collector = RetryingLlmAgent( - model="gemini-2.5-flash", - name="payment_method_collector", + model='gemini-2.5-flash', + name='payment_method_collector', max_retries=5, instruction=""" You are an agent responsible for obtaining the user's payment method for a @@ -70,7 +71,8 @@ 5. Call the `get_payment_credential_token` tool to get the payment credential token with the user_email and payment_method_alias. 6. Transfer back to the root_agent with the payment_method_alias. - """ % DEBUG_MODE_INSTRUCTIONS, + """ + % DEBUG_MODE_INSTRUCTIONS, tools=[ tools.get_payment_methods, tools.get_payment_credential_token, diff --git a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py index f24fe572..1af8671d 100644 --- a/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py +++ b/samples/python/src/roles/shopping_agent/subagents/payment_method_collector/tools.py @@ -18,46 +18,46 @@ shopping and purchasing process. """ +from common import artifact_utils +from common.a2a_message_builder import A2aMessageBuilder from google.adk.tools.tool_context import ToolContext +from roles.shopping_agent.remote_agents import credentials_provider_client from ap2.types.payment_request import PAYMENT_METHOD_DATA_DATA_KEY -from common.a2a_message_builder import A2aMessageBuilder -from common import artifact_utils -from roles.shopping_agent.remote_agents import credentials_provider_client async def get_payment_methods( user_email: str, tool_context: ToolContext, ) -> list[str]: - """Gets the user's payment methods from the credentials provider. + """Gets the user's payment methods from the credentials provider. - These will match the payment method on the cart being purchased. + These will match the payment method on the cart being purchased. - Args: - user_email: Identifies the user's account - tool_context: The ADK supplied tool context. + Args: + user_email: Identifies the user's account + tool_context: The ADK supplied tool context. - Returns: - A dictionary of the user's applicable payment methods. - """ - cart_mandate = tool_context.state["cart_mandate"] - message_builder = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("Get a filtered list of the user's payment methods.") - .add_data("user_email", user_email) - ) - for method_data in cart_mandate.contents.payment_request.method_data: - message_builder.add_data( - PAYMENT_METHOD_DATA_DATA_KEY, - method_data.model_dump(), + Returns: + A dictionary of the user's applicable payment methods. + """ + cart_mandate = tool_context.state['cart_mandate'] + message_builder = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text("Get a filtered list of the user's payment methods.") + .add_data('user_email', user_email) + ) + for method_data in cart_mandate.contents.payment_request.method_data: + message_builder.add_data( + PAYMENT_METHOD_DATA_DATA_KEY, + method_data.model_dump(), + ) + task = await credentials_provider_client.send_a2a_message( + message_builder.build() ) - task = await credentials_provider_client.send_a2a_message( - message_builder.build() - ) - payment_methods = artifact_utils.get_first_data_part(task.artifacts) - return payment_methods + payment_methods = artifact_utils.get_first_data_part(task.artifacts) + return payment_methods async def get_payment_credential_token( @@ -65,33 +65,35 @@ async def get_payment_credential_token( payment_method_alias: str, tool_context: ToolContext, ) -> str: - """Gets a payment credential token from the credentials provider. + """Gets a payment credential token from the credentials provider. - Args: - user_email: The user's email address. - payment_method_alias: The payment method alias. - tool_context: The ADK supplied tool context. + Args: + user_email: The user's email address. + payment_method_alias: The payment method alias. + tool_context: The ADK supplied tool context. - Returns: - Status of the call and the payment credential token. - """ - message = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("Get a payment credential token for the user's payment method.") - .add_data("payment_method_alias", payment_method_alias) - .add_data("user_email", user_email) - .build() - ) - task = await credentials_provider_client.send_a2a_message(message) - data = artifact_utils.get_first_data_part(task.artifacts) - token = data.get("token") - credentials_provider_agent_card = ( - await credentials_provider_client.get_agent_card() - ) + Returns: + Status of the call and the payment credential token. + """ + message = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text( + "Get a payment credential token for the user's payment method." + ) + .add_data('payment_method_alias', payment_method_alias) + .add_data('user_email', user_email) + .build() + ) + task = await credentials_provider_client.send_a2a_message(message) + data = artifact_utils.get_first_data_part(task.artifacts) + token = data.get('token') + credentials_provider_agent_card = ( + await credentials_provider_client.get_agent_card() + ) - tool_context.state["payment_credential_token"] = { - "value": token, - "url": credentials_provider_agent_card.url, - } - return {"status": "success", "token": token} + tool_context.state['payment_credential_token'] = { + 'value': token, + 'url': credentials_provider_agent_card.url, + } + return {'status': 'success', 'token': token} diff --git a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py index 0407b2b6..4ab48ff2 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/agent.py @@ -26,13 +26,15 @@ This is just one of many possible approaches. """ -from . import tools from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS +from . import tools + + shipping_address_collector = RetryingLlmAgent( - model="gemini-2.5-flash", - name="shipping_address_collector", + model='gemini-2.5-flash', + name='shipping_address_collector', max_retries=5, instruction=""" You are an agent responsible for obtaining the user's shipping address. @@ -73,7 +75,8 @@ 1. Collect the user's shipping address. Ensure you have collected all of the necessary parts of a US address. 2. Transfer back to the root_agent with the shipping address. - """ % DEBUG_MODE_INSTRUCTIONS, + """ + % DEBUG_MODE_INSTRUCTIONS, tools=[ tools.get_shipping_address, ], diff --git a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/tools.py b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/tools.py index fd9e506c..1676563b 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/tools.py +++ b/samples/python/src/roles/shopping_agent/subagents/shipping_address_collector/tools.py @@ -19,42 +19,41 @@ """ from a2a.types import Artifact -from google.adk.tools.tool_context import ToolContext - -from ap2.types.contact_picker import CONTACT_ADDRESS_DATA_KEY -from ap2.types.contact_picker import ContactAddress from common import artifact_utils from common.a2a_message_builder import A2aMessageBuilder +from google.adk.tools.tool_context import ToolContext from roles.shopping_agent.remote_agents import credentials_provider_client +from ap2.types.contact_picker import CONTACT_ADDRESS_DATA_KEY, ContactAddress + async def get_shipping_address( user_email: str, tool_context: ToolContext, ) -> ContactAddress: - """Gets the user's shipping address from the credentials provider. - - Args: - user_email: The ID of the user to get the shipping address for. - tool_context: The ADK supplied tool context. - - Returns: - The user's shipping address. - """ - message = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("Get the user's shipping address.") - .add_data("user_email", user_email) - .build() - ) - task = await credentials_provider_client.send_a2a_message(message) - shipping_address = artifact_utils.only(_parse_addresses(task.artifacts)) - return shipping_address + """Gets the user's shipping address from the credentials provider. + + Args: + user_email: The ID of the user to get the shipping address for. + tool_context: The ADK supplied tool context. + + Returns: + The user's shipping address. + """ + message = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text("Get the user's shipping address.") + .add_data('user_email', user_email) + .build() + ) + task = await credentials_provider_client.send_a2a_message(message) + shipping_address = artifact_utils.only(_parse_addresses(task.artifacts)) + return shipping_address def _parse_addresses(artifacts: list[Artifact]) -> list[ContactAddress]: - """Parses a list of artifacts into a list of ContactAddress objects.""" - return artifact_utils.find_canonical_objects( - artifacts, CONTACT_ADDRESS_DATA_KEY, ContactAddress - ) + """Parses a list of artifacts into a list of ContactAddress objects.""" + return artifact_utils.find_canonical_objects( + artifacts, CONTACT_ADDRESS_DATA_KEY, ContactAddress + ) diff --git a/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py b/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py index d380fac7..0ba641b2 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py +++ b/samples/python/src/roles/shopping_agent/subagents/shopper/agent.py @@ -24,14 +24,15 @@ This is just one of many possible approaches. """ -from . import tools from common.retrying_llm_agent import RetryingLlmAgent from common.system_utils import DEBUG_MODE_INSTRUCTIONS +from . import tools + shopper = RetryingLlmAgent( - model="gemini-2.5-flash", - name="shopper", + model='gemini-2.5-flash', + name='shopper', max_retries=5, instruction=""" You are an agent responsible for helping the user shop for products. @@ -98,7 +99,8 @@ 9. Monitor the tool's output. If the cart ID is not found, you must inform the user and prompt them to try again. If the selection is successful, signal a successful update and hand off the process to the root_agent. - """ % DEBUG_MODE_INSTRUCTIONS, + """ + % DEBUG_MODE_INSTRUCTIONS, tools=[ tools.create_intent_mandate, tools.find_products, diff --git a/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py b/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py index cee16eff..96993b7f 100644 --- a/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py +++ b/samples/python/src/roles/shopping_agent/subagents/shopper/tools.py @@ -18,21 +18,21 @@ shopping and purchasing process. """ -from datetime import datetime -from datetime import timedelta -from datetime import timezone +from datetime import UTC, datetime, timedelta from a2a.types import Artifact -from google.adk.tools.tool_context import ToolContext - -from ap2.types.mandate import CART_MANDATE_DATA_KEY -from ap2.types.mandate import CartMandate -from ap2.types.mandate import INTENT_MANDATE_DATA_KEY -from ap2.types.mandate import IntentMandate from common.a2a_message_builder import A2aMessageBuilder from common.artifact_utils import find_canonical_objects +from google.adk.tools.tool_context import ToolContext from roles.shopping_agent.remote_agents import merchant_agent_client +from ap2.types.mandate import ( + CART_MANDATE_DATA_KEY, + INTENT_MANDATE_DATA_KEY, + CartMandate, + IntentMandate, +) + def create_intent_mandate( natural_language_description: str, @@ -42,100 +42,100 @@ def create_intent_mandate( requires_refundability: bool, tool_context: ToolContext, ) -> IntentMandate: - """Creates an IntentMandate object. - - Args: - natural_language_description: The description of the user's intent. - user_cart_confirmation_required: If the user must confirm the cart. - merchants: A list of allowed merchants. - skus: A list of allowed SKUs. - requires_refundability: If the items must be refundable. - tool_context: The ADK supplied tool context. - - Returns: - An IntentMandate object valid for 1 day. - """ - intent_mandate = IntentMandate( - natural_language_description=natural_language_description, - user_cart_confirmation_required=user_cart_confirmation_required, - merchants=merchants, - skus=skus, - requires_refundability=requires_refundability, - intent_expiry=( - datetime.now(timezone.utc) + timedelta(days=1) - ).isoformat(), - ) - tool_context.state["intent_mandate"] = intent_mandate - return intent_mandate + """Creates an IntentMandate object. + + Args: + natural_language_description: The description of the user's intent. + user_cart_confirmation_required: If the user must confirm the cart. + merchants: A list of allowed merchants. + skus: A list of allowed SKUs. + requires_refundability: If the items must be refundable. + tool_context: The ADK supplied tool context. + + Returns: + An IntentMandate object valid for 1 day. + """ + intent_mandate = IntentMandate( + natural_language_description=natural_language_description, + user_cart_confirmation_required=user_cart_confirmation_required, + merchants=merchants, + skus=skus, + requires_refundability=requires_refundability, + intent_expiry=(datetime.now(UTC) + timedelta(days=1)).isoformat(), + ) + tool_context.state['intent_mandate'] = intent_mandate + return intent_mandate async def find_products( tool_context: ToolContext, debug_mode: bool = False ) -> list[CartMandate]: - """Calls the merchant agent to find products matching the user's intent. - - Args: - tool_context: The ADK supplied tool context. - debug_mode: Whether the agent is in debug mode. - - Returns: - A list of CartMandate objects. - - Raises: - RuntimeError: If the merchant agent fails to provide products. - """ - intent_mandate = tool_context.state["intent_mandate"] - if not intent_mandate: - raise RuntimeError("No IntentMandate found in tool context state.") - risk_data = _collect_risk_data(tool_context) - if not risk_data: - raise RuntimeError("No risk data found in tool context state.") - message = ( - A2aMessageBuilder() - .add_text("Find products that match the user's IntentMandate.") - .add_data(INTENT_MANDATE_DATA_KEY, intent_mandate.model_dump()) - .add_data("risk_data", risk_data) - .add_data("debug_mode", debug_mode) - .add_data("shopping_agent_id", "trusted_shopping_agent") - .build() - ) - task = await merchant_agent_client.send_a2a_message(message) - - if task.status.state != "completed": - raise RuntimeError(f"Failed to find products: {task.status}") - - tool_context.state["shopping_context_id"] = task.context_id - cart_mandates = _parse_cart_mandates(task.artifacts) - tool_context.state["cart_mandates"] = cart_mandates - return cart_mandates + """Calls the merchant agent to find products matching the user's intent. + + Args: + tool_context: The ADK supplied tool context. + debug_mode: Whether the agent is in debug mode. + + Returns: + A list of CartMandate objects. + + Raises: + RuntimeError: If the merchant agent fails to provide products. + """ + intent_mandate = tool_context.state['intent_mandate'] + if not intent_mandate: + raise RuntimeError('No IntentMandate found in tool context state.') + risk_data = _collect_risk_data(tool_context) + if not risk_data: + raise RuntimeError('No risk data found in tool context state.') + message = ( + A2aMessageBuilder() + .add_text("Find products that match the user's IntentMandate.") + .add_data(INTENT_MANDATE_DATA_KEY, intent_mandate.model_dump()) + .add_data('risk_data', risk_data) + .add_data('debug_mode', debug_mode) + .add_data('shopping_agent_id', 'trusted_shopping_agent') + .build() + ) + task = await merchant_agent_client.send_a2a_message(message) + + if task.status.state != 'completed': + raise RuntimeError(f'Failed to find products: {task.status}') + + tool_context.state['shopping_context_id'] = task.context_id + cart_mandates = _parse_cart_mandates(task.artifacts) + tool_context.state['cart_mandates'] = cart_mandates + return cart_mandates def update_chosen_cart_mandate(cart_id: str, tool_context: ToolContext) -> str: - """Updates the chosen CartMandate in the tool context state. - - Args: - cart_id: The ID of the chosen cart. - tool_context: The ADK supplied tool context. - """ - cart_mandates: list[CartMandate] = tool_context.state.get("cart_mandates", []) - for cart in cart_mandates: - print( - f"Checking cart with ID: {cart.contents.id} with chosen ID: {cart_id}" + """Updates the chosen CartMandate in the tool context state. + + Args: + cart_id: The ID of the chosen cart. + tool_context: The ADK supplied tool context. + """ + cart_mandates: list[CartMandate] = tool_context.state.get( + 'cart_mandates', [] ) - if cart.contents.id == cart_id: - tool_context.state["chosen_cart_id"] = cart_id - return f"CartMandate with ID {cart_id} selected." - return f"CartMandate with ID {cart_id} not found." + for cart in cart_mandates: + print( + f'Checking cart with ID: {cart.contents.id} with chosen ID: {cart_id}' + ) + if cart.contents.id == cart_id: + tool_context.state['chosen_cart_id'] = cart_id + return f'CartMandate with ID {cart_id} selected.' + return f'CartMandate with ID {cart_id} not found.' def _parse_cart_mandates(artifacts: list[Artifact]) -> list[CartMandate]: - """Parses a list of artifacts into a list of CartMandate objects.""" - return find_canonical_objects(artifacts, CART_MANDATE_DATA_KEY, CartMandate) + """Parses a list of artifacts into a list of CartMandate objects.""" + return find_canonical_objects(artifacts, CART_MANDATE_DATA_KEY, CartMandate) def _collect_risk_data(tool_context: ToolContext) -> dict: - """Creates a risk_data in the tool_context.""" - # This is a fake risk data for demonstration purposes. - risk_data = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...fake_risk_data" - tool_context.state["risk_data"] = risk_data - return risk_data + """Creates a risk_data in the tool_context.""" + # This is a fake risk data for demonstration purposes. + risk_data = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9...fake_risk_data' + tool_context.state['risk_data'] = risk_data + return risk_data diff --git a/samples/python/src/roles/shopping_agent/tools.py b/samples/python/src/roles/shopping_agent/tools.py index 474abfc5..c661a895 100644 --- a/samples/python/src/roles/shopping_agent/tools.py +++ b/samples/python/src/roles/shopping_agent/tools.py @@ -18,24 +18,26 @@ shopping and purchasing process, such as updating a cart or initiating payment. """ -from datetime import datetime -from datetime import timezone import uuid +from datetime import UTC, datetime + from a2a.types import Artifact +from common import artifact_utils +from common.a2a_message_builder import A2aMessageBuilder from google.adk.tools.tool_context import ToolContext -from .remote_agents import credentials_provider_client -from .remote_agents import merchant_agent_client from ap2.types.contact_picker import ContactAddress -from ap2.types.mandate import CART_MANDATE_DATA_KEY -from ap2.types.mandate import CartMandate -from ap2.types.mandate import PAYMENT_MANDATE_DATA_KEY -from ap2.types.mandate import PaymentMandate -from ap2.types.mandate import PaymentMandateContents +from ap2.types.mandate import ( + CART_MANDATE_DATA_KEY, + PAYMENT_MANDATE_DATA_KEY, + CartMandate, + PaymentMandate, + PaymentMandateContents, +) from ap2.types.payment_request import PaymentResponse -from common import artifact_utils -from common.a2a_message_builder import A2aMessageBuilder + +from .remote_agents import credentials_provider_client, merchant_agent_client async def update_cart( @@ -43,112 +45,122 @@ async def update_cart( tool_context: ToolContext, debug_mode: bool = False, ) -> str: - """Notifies the merchant agent of a shipping address selection for a cart. - - Args: - shipping_address: The user's selected shipping address. - tool_context: The ADK supplied tool context. - debug_mode: Whether the agent is in debug mode. - - Returns: - The updated CartMandate. - """ - chosen_cart_id = tool_context.state["chosen_cart_id"] - if not chosen_cart_id: - raise RuntimeError("No chosen cart mandate found in tool context state.") - - message = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("Update the cart with the user's shipping address.") - .add_data("cart_id", chosen_cart_id) - .add_data("shipping_address", shipping_address) - .add_data("shopping_agent_id", "trusted_shopping_agent") - .add_data("debug_mode", debug_mode) - .build() - ) - task = await merchant_agent_client.send_a2a_message(message) - - updated_cart_mandate = artifact_utils.only( - _parse_cart_mandates(task.artifacts) - ) - - tool_context.state["cart_mandate"] = updated_cart_mandate - tool_context.state["shipping_address"] = shipping_address - - return updated_cart_mandate + """Notifies the merchant agent of a shipping address selection for a cart. + + Args: + shipping_address: The user's selected shipping address. + tool_context: The ADK supplied tool context. + debug_mode: Whether the agent is in debug mode. + + Returns: + The updated CartMandate. + """ + chosen_cart_id = tool_context.state['chosen_cart_id'] + if not chosen_cart_id: + raise RuntimeError( + 'No chosen cart mandate found in tool context state.' + ) + + message = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text("Update the cart with the user's shipping address.") + .add_data('cart_id', chosen_cart_id) + .add_data('shipping_address', shipping_address) + .add_data('shopping_agent_id', 'trusted_shopping_agent') + .add_data('debug_mode', debug_mode) + .build() + ) + task = await merchant_agent_client.send_a2a_message(message) + + updated_cart_mandate = artifact_utils.only( + _parse_cart_mandates(task.artifacts) + ) + + tool_context.state['cart_mandate'] = updated_cart_mandate + tool_context.state['shipping_address'] = shipping_address + + return updated_cart_mandate async def initiate_payment(tool_context: ToolContext, debug_mode: bool = False): - """Initiates a payment using the payment mandate from state. - - Args: - tool_context: The ADK supplied tool context. - debug_mode: Whether the agent is in debug mode. - - Returns: - The status of the payment initiation. - """ - payment_mandate = tool_context.state["signed_payment_mandate"] - if not payment_mandate: - raise RuntimeError("No signed payment mandate found in tool context state.") - risk_data = tool_context.state["risk_data"] - if not risk_data: - raise RuntimeError("No risk data found in tool context state.") - - outgoing_message_builder = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("Initiate a payment") - .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) - .add_data("risk_data", risk_data) - .add_data("shopping_agent_id", "trusted_shopping_agent") - .add_data("debug_mode", debug_mode) - .build() - ) - task = await merchant_agent_client.send_a2a_message(outgoing_message_builder) - tool_context.state["initiate_payment_task_id"] = task.id - return task.status + """Initiates a payment using the payment mandate from state. + + Args: + tool_context: The ADK supplied tool context. + debug_mode: Whether the agent is in debug mode. + + Returns: + The status of the payment initiation. + """ + payment_mandate = tool_context.state['signed_payment_mandate'] + if not payment_mandate: + raise RuntimeError( + 'No signed payment mandate found in tool context state.' + ) + risk_data = tool_context.state['risk_data'] + if not risk_data: + raise RuntimeError('No risk data found in tool context state.') + + outgoing_message_builder = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text('Initiate a payment') + .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) + .add_data('risk_data', risk_data) + .add_data('shopping_agent_id', 'trusted_shopping_agent') + .add_data('debug_mode', debug_mode) + .build() + ) + task = await merchant_agent_client.send_a2a_message( + outgoing_message_builder + ) + tool_context.state['initiate_payment_task_id'] = task.id + return task.status async def initiate_payment_with_otp( challenge_response: str, tool_context: ToolContext, debug_mode: bool = False ): - """Initiates a payment using the payment mandate from state and a - - challenge response. In our sample, the challenge response is a one-time - password (OTP) sent to the user. - - Args: - challenge_response: The challenge response. - tool_context: The ADK supplied tool context. - debug_mode: Whether the agent is in debug mode. - - Returns: - The status of the payment initiation. - """ - payment_mandate = tool_context.state["signed_payment_mandate"] - if not payment_mandate: - raise RuntimeError("No signed payment mandate found in tool context state.") - risk_data = tool_context.state["risk_data"] - if not risk_data: - raise RuntimeError("No risk data found in tool context state.") - - outgoing_message_builder = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .set_task_id(tool_context.state["initiate_payment_task_id"]) - .add_text("Initiate a payment. Include the challenge response.") - .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) - .add_data("shopping_agent_id", "trusted_shopping_agent") - .add_data("challenge_response", challenge_response) - .add_data("risk_data", risk_data) - .add_data("debug_mode", debug_mode) - .build() - ) - - task = await merchant_agent_client.send_a2a_message(outgoing_message_builder) - return task.status + """Initiates a payment using the payment mandate from state and a + + challenge response. In our sample, the challenge response is a one-time + password (OTP) sent to the user. + + Args: + challenge_response: The challenge response. + tool_context: The ADK supplied tool context. + debug_mode: Whether the agent is in debug mode. + + Returns: + The status of the payment initiation. + """ + payment_mandate = tool_context.state['signed_payment_mandate'] + if not payment_mandate: + raise RuntimeError( + 'No signed payment mandate found in tool context state.' + ) + risk_data = tool_context.state['risk_data'] + if not risk_data: + raise RuntimeError('No risk data found in tool context state.') + + outgoing_message_builder = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .set_task_id(tool_context.state['initiate_payment_task_id']) + .add_text('Initiate a payment. Include the challenge response.') + .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) + .add_data('shopping_agent_id', 'trusted_shopping_agent') + .add_data('challenge_response', challenge_response) + .add_data('risk_data', risk_data) + .add_data('debug_mode', debug_mode) + .build() + ) + + task = await merchant_agent_client.send_a2a_message( + outgoing_message_builder + ) + return task.status def create_payment_mandate( @@ -156,153 +168,156 @@ def create_payment_mandate( user_email: str, tool_context: ToolContext, ) -> str: - """Creates a payment mandate and stores it in state. - - Args: - payment_method_alias: The payment method alias. - user_email: The user's email address. - tool_context: The ADK supplied tool context. - - Returns: - The payment mandate. - """ - cart_mandate = tool_context.state["cart_mandate"] - - payment_request = cart_mandate.contents.payment_request - shipping_address = tool_context.state["shipping_address"] - payment_response = PaymentResponse( - request_id=payment_request.details.id, - method_name="CARD", - details={ - "token": tool_context.state["payment_credential_token"], - }, - shipping_address=shipping_address, - payer_email=user_email, - ) - - payment_mandate = PaymentMandate( - payment_mandate_contents=PaymentMandateContents( - payment_mandate_id=uuid.uuid4().hex, - timestamp=datetime.now(timezone.utc).isoformat(), - payment_details_id=payment_request.details.id, - payment_details_total=payment_request.details.total, - payment_response=payment_response, - merchant_agent=cart_mandate.contents.merchant_name, - ), - ) - - tool_context.state["payment_mandate"] = payment_mandate - return payment_mandate + """Creates a payment mandate and stores it in state. + + Args: + payment_method_alias: The payment method alias. + user_email: The user's email address. + tool_context: The ADK supplied tool context. + + Returns: + The payment mandate. + """ + cart_mandate = tool_context.state['cart_mandate'] + + payment_request = cart_mandate.contents.payment_request + shipping_address = tool_context.state['shipping_address'] + payment_response = PaymentResponse( + request_id=payment_request.details.id, + method_name='CARD', + details={ + 'token': tool_context.state['payment_credential_token'], + }, + shipping_address=shipping_address, + payer_email=user_email, + ) + + payment_mandate = PaymentMandate( + payment_mandate_contents=PaymentMandateContents( + payment_mandate_id=uuid.uuid4().hex, + timestamp=datetime.now(UTC).isoformat(), + payment_details_id=payment_request.details.id, + payment_details_total=payment_request.details.total, + payment_response=payment_response, + merchant_agent=cart_mandate.contents.merchant_name, + ), + ) + + tool_context.state['payment_mandate'] = payment_mandate + return payment_mandate def sign_mandates_on_user_device(tool_context: ToolContext) -> str: - """Simulates signing the transaction details on a user's secure device. - - This function represents the step where the final transaction details, - including hashes of the cart and payment mandates, would be sent to a - secure hardware element on the user's device (e.g., Secure Enclave) to be - cryptographically signed with the user's private key. - - Note: This is a placeholder implementation. It does not perform any actual - cryptographic operations. It simulates the creation of a signature by - concatenating the mandate hashes. - - Args: - tool_context: The context object used for state management. It is expected - to contain the `payment_mandate` and `cart_mandate`. - - Returns: - A string representing the simulated user authorization signature (JWT). - """ - payment_mandate: PaymentMandate = tool_context.state["payment_mandate"] - cart_mandate: CartMandate = tool_context.state["cart_mandate"] - cart_mandate_hash = _generate_cart_mandate_hash(cart_mandate) - payment_mandate_hash = _generate_payment_mandate_hash( - payment_mandate.payment_mandate_contents - ) - # A JWT containing the user's digital signature to authorize the transaction. - # The payload uses hashes to bind the signature to the specific cart and - # payment details, and includes a nonce to prevent replay attacks. - payment_mandate.user_authorization = ( - cart_mandate_hash + "_" + payment_mandate_hash - ) - tool_context.state["signed_payment_mandate"] = payment_mandate - return payment_mandate.user_authorization + """Simulates signing the transaction details on a user's secure device. + + This function represents the step where the final transaction details, + including hashes of the cart and payment mandates, would be sent to a + secure hardware element on the user's device (e.g., Secure Enclave) to be + cryptographically signed with the user's private key. + + Note: This is a placeholder implementation. It does not perform any actual + cryptographic operations. It simulates the creation of a signature by + concatenating the mandate hashes. + + Args: + tool_context: The context object used for state management. It is expected + to contain the `payment_mandate` and `cart_mandate`. + + Returns: + A string representing the simulated user authorization signature (JWT). + """ + payment_mandate: PaymentMandate = tool_context.state['payment_mandate'] + cart_mandate: CartMandate = tool_context.state['cart_mandate'] + cart_mandate_hash = _generate_cart_mandate_hash(cart_mandate) + payment_mandate_hash = _generate_payment_mandate_hash( + payment_mandate.payment_mandate_contents + ) + # A JWT containing the user's digital signature to authorize the transaction. + # The payload uses hashes to bind the signature to the specific cart and + # payment details, and includes a nonce to prevent replay attacks. + payment_mandate.user_authorization = ( + cart_mandate_hash + '_' + payment_mandate_hash + ) + tool_context.state['signed_payment_mandate'] = payment_mandate + return payment_mandate.user_authorization async def send_signed_payment_mandate_to_credentials_provider( tool_context: ToolContext, debug_mode: bool = False, ) -> str: - """Sends the signed payment mandate to the credentials provider. - - Args: - tool_context: The ADK supplied tool context. - debug_mode: Whether the agent is in debug mode. - """ - payment_mandate = tool_context.state["signed_payment_mandate"] - if not payment_mandate: - raise RuntimeError("No signed payment mandate found in tool context state.") - risk_data = tool_context.state["risk_data"] - if not risk_data: - raise RuntimeError("No risk data found in tool context state.") - message = ( - A2aMessageBuilder() - .set_context_id(tool_context.state["shopping_context_id"]) - .add_text("This is the signed payment mandate") - .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) - .add_data("risk_data", risk_data) - .add_data("debug_mode", debug_mode) - .build() - ) - return await credentials_provider_client.send_a2a_message(message) + """Sends the signed payment mandate to the credentials provider. + + Args: + tool_context: The ADK supplied tool context. + debug_mode: Whether the agent is in debug mode. + """ + payment_mandate = tool_context.state['signed_payment_mandate'] + if not payment_mandate: + raise RuntimeError( + 'No signed payment mandate found in tool context state.' + ) + risk_data = tool_context.state['risk_data'] + if not risk_data: + raise RuntimeError('No risk data found in tool context state.') + message = ( + A2aMessageBuilder() + .set_context_id(tool_context.state['shopping_context_id']) + .add_text('This is the signed payment mandate') + .add_data(PAYMENT_MANDATE_DATA_KEY, payment_mandate) + .add_data('risk_data', risk_data) + .add_data('debug_mode', debug_mode) + .build() + ) + return await credentials_provider_client.send_a2a_message(message) def _generate_cart_mandate_hash(cart_mandate: CartMandate) -> str: - """Generates a cryptographic hash of the CartMandate. + """Generates a cryptographic hash of the CartMandate. - This hash serves as a tamper-proof reference to the specific merchant-signed - cart offer that the user has approved. + This hash serves as a tamper-proof reference to the specific merchant-signed + cart offer that the user has approved. - Note: This is a placeholder implementation for development. A real - implementation must use a secure hashing algorithm (e.g., SHA-256) on the - canonical representation of the CartMandate object. + Note: This is a placeholder implementation for development. A real + implementation must use a secure hashing algorithm (e.g., SHA-256) on the + canonical representation of the CartMandate object. - Args: - cart_mandate: The complete CartMandate object, including the merchant's - authorization. + Args: + cart_mandate: The complete CartMandate object, including the merchant's + authorization. - Returns: - A string representing the hash of the cart mandate. - """ - return "fake_cart_mandate_hash_" + cart_mandate.contents.id + Returns: + A string representing the hash of the cart mandate. + """ + return 'fake_cart_mandate_hash_' + cart_mandate.contents.id def _generate_payment_mandate_hash( payment_mandate_contents: PaymentMandateContents, ) -> str: - """Generates a cryptographic hash of the PaymentMandateContents. + """Generates a cryptographic hash of the PaymentMandateContents. - This hash creates a tamper-proof reference to the specific payment details - the user is about to authorize. + This hash creates a tamper-proof reference to the specific payment details + the user is about to authorize. - Note: This is a placeholder implementation for development. A real - implementation must use a secure hashing algorithm (e.g., SHA-256) on the - canonical representation of the PaymentMandateContents object. + Note: This is a placeholder implementation for development. A real + implementation must use a secure hashing algorithm (e.g., SHA-256) on the + canonical representation of the PaymentMandateContents object. - Args: - payment_mandate_contents: The payment mandate contents to hash. + Args: + payment_mandate_contents: The payment mandate contents to hash. - Returns: - A string representing the hash of the payment mandate contents. - """ - return ( - "fake_payment_mandate_hash_" + payment_mandate_contents.payment_mandate_id - ) + Returns: + A string representing the hash of the payment mandate contents. + """ + return ( + 'fake_payment_mandate_hash_' + + payment_mandate_contents.payment_mandate_id + ) def _parse_cart_mandates(artifacts: list[Artifact]) -> list[CartMandate]: - """Parses a list of artifacts into a list of CartMandate objects.""" - return artifact_utils.find_canonical_objects( - artifacts, CART_MANDATE_DATA_KEY, CartMandate - ) + """Parses a list of artifacts into a list of CartMandate objects.""" + return artifact_utils.find_canonical_objects( + artifacts, CART_MANDATE_DATA_KEY, CartMandate + ) diff --git a/src/ap2/channels/__init__.py b/src/ap2/channels/__init__.py new file mode 100644 index 00000000..7974c0ba --- /dev/null +++ b/src/ap2/channels/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Payment channel management and security framework.""" diff --git a/src/ap2/channels/channel_manager.py b/src/ap2/channels/channel_manager.py new file mode 100644 index 00000000..3c820e8b --- /dev/null +++ b/src/ap2/channels/channel_manager.py @@ -0,0 +1,425 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Channel lifecycle management for micropayment channels.""" + +import hashlib +import uuid + +from datetime import UTC, datetime, timedelta +from typing import Any + +from pydantic import BaseModel, Field + +from ap2.types.payment_channels import ( + ChannelCloseRequest, + ChannelDispute, + ChannelOpenRequest, + ChannelParticipant, + ChannelState, + DisputeReason, + PaymentChannel, + PaymentVoucher, +) +from ap2.types.payment_request import PaymentCurrencyAmount + + +class ChannelOperationResult(BaseModel): + """Result of a channel operation.""" + + success: bool = Field(..., description='Whether the operation succeeded') + channel_id: str | None = Field(None, description='Channel ID if applicable') + message: str = Field(..., description='Result message') + data: dict[str, Any] = Field( + default_factory=dict, description='Additional result data' + ) + + +class ChannelManager(BaseModel): + """Manages payment channel lifecycle and operations.""" + + manager_id: str = Field( + ..., description='Unique identifier for this manager' + ) + agent_did: str = Field( + ..., description='DID of the agent using this manager' + ) + active_channels: dict[str, PaymentChannel] = Field( + default_factory=dict, description='Active payment channels' + ) + channel_history: list[str] = Field( + default_factory=list, description='Historical channel IDs' + ) + pending_operations: dict[str, dict[str, Any]] = Field( + default_factory=dict, description='Pending channel operations' + ) + security_config: dict[str, Any] = Field( + default_factory=dict, description='Security configuration' + ) + + def open_channel( + self, request: ChannelOpenRequest + ) -> ChannelOperationResult: + """Open a new payment channel.""" + try: + # Generate channel ID + channel_id = f'ch_{uuid.uuid4()}' + + # Validate participants + if ( + request.requesting_participant.participant_id + == request.target_participant.participant_id + ): + return ChannelOperationResult( + success=False, + message='Requesting and target participants cannot be the same', + ) + + # Create channel + participants = [ + request.requesting_participant, + request.target_participant, + ] + + # Calculate total capacity + total_capacity = PaymentCurrencyAmount( + currency=request.initial_deposit.currency, + value=sum(p.initial_balance.value for p in participants), + ) + + # Set expiry time + expiry_time = datetime.now(UTC) + timedelta( + hours=request.duration_hours + ) + + channel = PaymentChannel( + channel_id=channel_id, + participants=participants, + state=ChannelState.OPENING, + policy=request.proposed_policy, + total_capacity=total_capacity, + current_state_hash=self._calculate_state_hash( + channel_id, participants + ), + expires_at=expiry_time.isoformat(), + ) + + # Store the channel + self.active_channels[channel_id] = channel + + return ChannelOperationResult( + success=True, + channel_id=channel_id, + message=f'Channel {channel_id} opened successfully', + data={ + 'channel': channel.model_dump(), + 'next_step': 'activate_channel', + }, + ) + + except Exception as e: + return ChannelOperationResult( + success=False, message=f'Failed to open channel: {e!s}' + ) + + def activate_channel(self, channel_id: str) -> ChannelOperationResult: + """Activate a channel after both parties confirm.""" + channel = self.active_channels.get(channel_id) + if not channel: + return ChannelOperationResult( + success=False, message=f'Channel {channel_id} not found' + ) + + if channel.state != ChannelState.OPENING: + return ChannelOperationResult( + success=False, + message=f'Channel {channel_id} is not in opening state', + ) + + # In a real implementation, this would verify both participants have confirmed + channel.state = ChannelState.ACTIVE + channel.last_activity = datetime.now(UTC).isoformat() + + return ChannelOperationResult( + success=True, + channel_id=channel_id, + message=f'Channel {channel_id} activated successfully', + ) + + def process_payment( + self, + channel_id: str, + from_participant: str, + to_participant: str, + amount: PaymentCurrencyAmount, + metadata: dict[str, Any] | None = None, + ) -> ChannelOperationResult: + """Process a payment through the channel.""" + channel = self.active_channels.get(channel_id) + if not channel: + return ChannelOperationResult( + success=False, message=f'Channel {channel_id} not found' + ) + + # Check if payment can be processed + can_process, reason = channel.can_process_payment( + from_participant, to_participant, amount + ) + if not can_process: + return ChannelOperationResult( + success=False, message=f'Payment rejected: {reason}' + ) + + # Create payment voucher + voucher = PaymentVoucher( + voucher_id=f'voucher_{uuid.uuid4()}', + channel_id=channel_id, + from_participant=from_participant, + to_participant=to_participant, + amount=amount, + nonce=channel.sequence_number + 1, + cumulative_amount=amount, # Simplified - would track cumulative properly + signature=self._sign_voucher(channel_id, from_participant, amount), + metadata=metadata or {}, + ) + + # Update channel balances + payer = channel.get_participant(from_participant) + payee = channel.get_participant(to_participant) + + if payer and payee: + payer.current_balance.value -= amount.value + payee.current_balance.value += amount.value + + # Update channel state + channel.sequence_number += 1 + channel.last_activity = datetime.now(UTC).isoformat() + channel.current_state_hash = self._calculate_state_hash( + channel_id, channel.participants + ) + + return ChannelOperationResult( + success=True, + channel_id=channel_id, + message='Payment processed successfully', + data={ + 'voucher': voucher.model_dump(), + 'new_balances': { + p.participant_id: p.current_balance.model_dump() + for p in channel.participants + }, + }, + ) + + def close_channel( + self, request: ChannelCloseRequest + ) -> ChannelOperationResult: + """Close a payment channel.""" + channel = self.active_channels.get(request.channel_id) + if not channel: + return ChannelOperationResult( + success=False, message=f'Channel {request.channel_id} not found' + ) + + if channel.state not in [ChannelState.ACTIVE, ChannelState.CLOSING]: + return ChannelOperationResult( + success=False, + message=f'Channel {request.channel_id} cannot be closed in state {channel.state}', + ) + + try: + # Validate final balances + if not self._validate_final_balances( + channel, request.final_balances + ): + return ChannelOperationResult( + success=False, + message='Final balances do not match channel state', + ) + + # Update channel state + channel.state = ( + ChannelState.CLOSING + if not request.force_close + else ChannelState.CLOSED + ) + channel.settlement_info = { + 'final_balances': { + k: v.model_dump() for k, v in request.final_balances.items() + }, + 'closed_by': request.requesting_participant, + 'close_reason': request.reason, + 'closure_time': datetime.now(UTC).isoformat(), + } + + if request.force_close or channel.state == ChannelState.CLOSED: + # Move to history + self.channel_history.append(request.channel_id) + del self.active_channels[request.channel_id] + + return ChannelOperationResult( + success=True, + channel_id=request.channel_id, + message=f'Channel {request.channel_id} {"closed" if channel.state == ChannelState.CLOSED else "closing"}', + data=channel.settlement_info, + ) + + except Exception as e: + return ChannelOperationResult( + success=False, message=f'Failed to close channel: {e!s}' + ) + + def dispute_channel( + self, + channel_id: str, + disputing_participant: str, + reason: DisputeReason, + evidence: list[dict[str, Any]], + ) -> ChannelOperationResult: + """Raise a dispute for a channel.""" + channel = self.active_channels.get(channel_id) + if not channel: + return ChannelOperationResult( + success=False, message=f'Channel {channel_id} not found' + ) + + # Create dispute + dispute = ChannelDispute( + dispute_id=f'dispute_{uuid.uuid4()}', + channel_id=channel_id, + disputing_participant=disputing_participant, + dispute_reason=reason, + contested_state=channel.model_dump(), + evidence=evidence, + resolution_deadline=( + datetime.now(UTC) + + timedelta(seconds=channel.policy.dispute_timeout_seconds) + ).isoformat(), + ) + + # Update channel state + channel.state = ChannelState.DISPUTED + channel.dispute_info = dispute.model_dump() + + return ChannelOperationResult( + success=True, + channel_id=channel_id, + message=f'Dispute {dispute.dispute_id} created for channel {channel_id}', + data={'dispute': dispute.model_dump()}, + ) + + def get_channel_status(self, channel_id: str) -> dict[str, Any] | None: + """Get comprehensive status of a channel.""" + channel = self.active_channels.get(channel_id) + if not channel: + return None + + return { + 'channel_id': channel_id, + 'state': channel.state, + 'participants': [ + { + 'id': p.participant_id, + 'balance': p.current_balance.model_dump(), + 'role': p.role, + } + for p in channel.participants + ], + 'total_capacity': channel.total_capacity.model_dump(), + 'sequence_number': channel.sequence_number, + 'expires_at': channel.expires_at, + 'last_activity': channel.last_activity, + 'is_expired': channel.is_expired(), + 'dispute_info': channel.dispute_info, + 'settlement_info': channel.settlement_info, + } + + def cleanup_expired_channels(self) -> list[str]: + """Clean up expired channels.""" + expired_channels = [] + current_time = datetime.now(UTC) + + for channel_id, channel in list(self.active_channels.items()): + if channel.is_expired(current_time): + expired_channels.append(channel_id) + + # Force close expired channel + channel.state = ChannelState.EXPIRED + channel.settlement_info = { + 'final_balances': { + p.participant_id: p.current_balance.model_dump() + for p in channel.participants + }, + 'close_reason': 'expired', + 'closure_time': current_time.isoformat(), + } + + self.channel_history.append(channel_id) + del self.active_channels[channel_id] + + return expired_channels + + def get_channels_by_participant( + self, participant_id: str + ) -> list[dict[str, Any]]: + """Get all channels involving a specific participant.""" + result = [] + for channel_id, channel in self.active_channels.items(): + if any( + p.participant_id == participant_id for p in channel.participants + ): + result.append(self.get_channel_status(channel_id)) + return [status for status in result if status is not None] + + def _calculate_state_hash( + self, channel_id: str, participants: list[ChannelParticipant] + ) -> str: + """Calculate hash of channel state.""" + state_data = f'{channel_id}' + for participant in participants: + state_data += f'{participant.participant_id}:{participant.current_balance.value}' + + return hashlib.sha256(state_data.encode()).hexdigest() + + def _sign_voucher( + self, + channel_id: str, + from_participant: str, + amount: PaymentCurrencyAmount, + ) -> str: + """Create signature for payment voucher (mock implementation).""" + voucher_data = ( + f'{channel_id}:{from_participant}:{amount.value}:{amount.currency}' + ) + return hashlib.sha256(voucher_data.encode()).hexdigest() + + def _validate_final_balances( + self, + channel: PaymentChannel, + final_balances: dict[str, PaymentCurrencyAmount], + ) -> bool: + """Validate that final balances are consistent with channel state.""" + # Check that all participants are accounted for + participant_ids = {p.participant_id for p in channel.participants} + final_balance_ids = set(final_balances.keys()) + + if participant_ids != final_balance_ids: + return False + + # Check that total balances match channel capacity + total_final = sum(balance.value for balance in final_balances.values()) + total_capacity = channel.total_capacity.value + + # Allow for small floating-point differences + return abs(total_final - total_capacity) < 0.001 diff --git a/src/ap2/types/contact_picker.py b/src/ap2/types/contact_picker.py index a65145ec..a4b49a08 100644 --- a/src/ap2/types/contact_picker.py +++ b/src/ap2/types/contact_picker.py @@ -23,27 +23,26 @@ https://www.w3.org/TR/contact-picker/ """ -from typing import Optional - from pydantic import BaseModel -CONTACT_ADDRESS_DATA_KEY = "contact_picker.ContactAddress" + +CONTACT_ADDRESS_DATA_KEY = 'contact_picker.ContactAddress' class ContactAddress(BaseModel): - """The ContactAddress interface represents a physical address. - - Specification: - https://www.w3.org/TR/contact-picker/#contact-address - """ - - city: Optional[str] = None - country: Optional[str] = None - dependent_locality: Optional[str] = None - organization: Optional[str] = None - phone_number: Optional[str] = None - postal_code: Optional[str] = None - recipient: Optional[str] = None - region: Optional[str] = None - sorting_code: Optional[str] = None - address_line: Optional[list[str]] = None + """The ContactAddress interface represents a physical address. + + Specification: + https://www.w3.org/TR/contact-picker/#contact-address + """ + + city: str | None = None + country: str | None = None + dependent_locality: str | None = None + organization: str | None = None + phone_number: str | None = None + postal_code: str | None = None + recipient: str | None = None + region: str | None = None + sorting_code: str | None = None + address_line: list[str] | None = None diff --git a/src/ap2/types/mandate.py b/src/ap2/types/mandate.py index c5506689..8e81221b 100644 --- a/src/ap2/types/mandate.py +++ b/src/ap2/types/mandate.py @@ -14,106 +14,108 @@ """Contains the definitions of the Agent Payments Protocol mandates.""" -from datetime import datetime -from datetime import timezone -from typing import Optional +from datetime import UTC, datetime -from ap2.types.payment_request import PaymentItem -from ap2.types.payment_request import PaymentRequest -from ap2.types.payment_request import PaymentResponse -from pydantic import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field -CART_MANDATE_DATA_KEY = "ap2.mandates.CartMandate" -INTENT_MANDATE_DATA_KEY = "ap2.mandates.IntentMandate" -PAYMENT_MANDATE_DATA_KEY = "ap2.mandates.PaymentMandate" +from ap2.types.payment_request import ( + PaymentItem, + PaymentRequest, + PaymentResponse, +) + + +CART_MANDATE_DATA_KEY = 'ap2.mandates.CartMandate' +INTENT_MANDATE_DATA_KEY = 'ap2.mandates.IntentMandate' +PAYMENT_MANDATE_DATA_KEY = 'ap2.mandates.PaymentMandate' class IntentMandate(BaseModel): - """Represents the user's purchase intent. - - These are the initial fields utilized in the human-present flow. For - human-not-present flows, additional fields will be added to this mandate. - """ - - user_cart_confirmation_required: bool = Field( - True, - description=( - "If false, the agent can make purchases on the user's behalf once all" - " purchase conditions have been satisfied. This must be true if the" - " intent mandate is not signed by the user." - ), - ) - natural_language_description: str = Field( - ..., - description=( - "The natural language description of the user's intent. This is" - " generated by the shopping agent, and confirmed by the user. The" - " goal is to have informed consent by the user." - ), - example="High top, old school, red basketball shoes", - ) - merchants: Optional[list[str]] = Field( - None, - description=( - "Merchants allowed to fulfill the intent. If not set, the shopping" - " agent is able to work with any suitable merchant." - ), - ) - skus: Optional[list[str]] = Field( - None, - description=( - "A list of specific product SKUs. If not set, any SKU is allowed." - ), - ) - requires_refundability: Optional[bool] = Field( - False, - description="If true, items must be refundable.", - ) - intent_expiry: str = Field( - ..., - description="When the intent mandate expires, in ISO 8601 format.", - ) + """Represents the user's purchase intent. + + These are the initial fields utilized in the human-present flow. For + human-not-present flows, additional fields will be added to this mandate. + """ + + user_cart_confirmation_required: bool = Field( + True, + description=( + "If false, the agent can make purchases on the user's behalf once all" + ' purchase conditions have been satisfied. This must be true if the' + ' intent mandate is not signed by the user.' + ), + ) + natural_language_description: str = Field( + ..., + description=( + "The natural language description of the user's intent. This is" + ' generated by the shopping agent, and confirmed by the user. The' + ' goal is to have informed consent by the user.' + ), + example='High top, old school, red basketball shoes', + ) + merchants: list[str] | None = Field( + None, + description=( + 'Merchants allowed to fulfill the intent. If not set, the shopping' + ' agent is able to work with any suitable merchant.' + ), + ) + skus: list[str] | None = Field( + None, + description=( + 'A list of specific product SKUs. If not set, any SKU is allowed.' + ), + ) + requires_refundability: bool | None = Field( + False, + description='If true, items must be refundable.', + ) + intent_expiry: str = Field( + ..., + description='When the intent mandate expires, in ISO 8601 format.', + ) class CartContents(BaseModel): - """The detailed contents of a cart. - - This object is signed by the merchant to create a CartMandate. - """ - - id: str = Field(..., description="A unique identifier for this cart.") - user_cart_confirmation_required: bool = Field( - ..., - description=( - "If true, the merchant requires the user to confirm the cart before" - " the purchase can be completed." - ), - ) - payment_request: PaymentRequest = Field( - ..., - description=( - "The W3C PaymentRequest object to initiate payment. This contains the" - " items being purchased, prices, and the set of payment methods" - " accepted by the merchant for this cart." - ), - ) - cart_expiry: str = Field( - ..., description="When this cart expires, in ISO 8601 format." - ) - merchant_name: str = Field(..., description="The name of the merchant.") + """The detailed contents of a cart. + + This object is signed by the merchant to create a CartMandate. + """ + + id: str = Field(..., description='A unique identifier for this cart.') + user_cart_confirmation_required: bool = Field( + ..., + description=( + 'If true, the merchant requires the user to confirm the cart before' + ' the purchase can be completed.' + ), + ) + payment_request: PaymentRequest = Field( + ..., + description=( + 'The W3C PaymentRequest object to initiate payment. This contains the' + ' items being purchased, prices, and the set of payment methods' + ' accepted by the merchant for this cart.' + ), + ) + cart_expiry: str = Field( + ..., description='When this cart expires, in ISO 8601 format.' + ) + merchant_name: str = Field(..., description='The name of the merchant.') class CartMandate(BaseModel): - """A cart whose contents have been digitally signed by the merchant. + """A cart whose contents have been digitally signed by the merchant. - This serves as a guarantee of the items and price for a limited time. - """ + This serves as a guarantee of the items and price for a limited time. + """ - contents: CartContents = Field(..., description="The contents of the cart.") - merchant_authorization: Optional[str] = Field( - None, - description=(""" A base64url-encoded JSON Web Token (JWT) that digitally + contents: CartContents = Field(..., description='The contents of the cart.') + merchant_authorization: str | None = Field( + None, + description=( + """ A base64url-encoded JSON Web Token (JWT) that digitally signs the cart contents, guaranteeing its authenticity and integrity: 1. Header includes the signing algorithm and key ID. 2. Payload includes: @@ -129,59 +131,60 @@ class CartMandate(BaseModel): key. It allows anyone with the public key to verify the token's authenticity and confirm that the payload has not been tampered with. The entire JWT is base64url encoded to ensure safe transmission. - """), - example="eyJhbGciOiJSUzI1NiIsImtpZCI6IjIwMjQwOTA...", # Example JWT - ) + """ + ), + example='eyJhbGciOiJSUzI1NiIsImtpZCI6IjIwMjQwOTA...', # Example JWT + ) class PaymentMandateContents(BaseModel): - """The data contents of a PaymentMandate.""" - - payment_mandate_id: str = Field( - ..., description="A unique identifier for this payment mandate." - ) - payment_details_id: str = Field( - ..., description="A unique identifier for the payment request." - ) - payment_details_total: PaymentItem = Field( - ..., description="The total payment amount." - ) - payment_response: PaymentResponse = Field( - ..., - description=( - "The payment response containing details of the payment method chosen" - " by the user." - ), - ) - merchant_agent: str = Field(..., description="Identifier for the merchant.") - timestamp: str = Field( - description=( - "The date and time the mandate was created, in ISO 8601 format." - ), - default_factory=lambda: datetime.now(timezone.utc).isoformat(), - ) + """The data contents of a PaymentMandate.""" + + payment_mandate_id: str = Field( + ..., description='A unique identifier for this payment mandate.' + ) + payment_details_id: str = Field( + ..., description='A unique identifier for the payment request.' + ) + payment_details_total: PaymentItem = Field( + ..., description='The total payment amount.' + ) + payment_response: PaymentResponse = Field( + ..., + description=( + 'The payment response containing details of the payment method chosen' + ' by the user.' + ), + ) + merchant_agent: str = Field(..., description='Identifier for the merchant.') + timestamp: str = Field( + description=( + 'The date and time the mandate was created, in ISO 8601 format.' + ), + default_factory=lambda: datetime.now(UTC).isoformat(), + ) class PaymentMandate(BaseModel): - """Contains the user's instructions & authorization for payment. - - While the Cart and Intent mandates are required by the merchant to fulfill the - order, separately the protocol provides additional visibility into the agentic - transaction to the payments ecosystem. For this purpose, the PaymentMandate - (bound to Cart/Intent mandate but containing separate information) may be - shared with the network/issuer along with the standard transaction - authorization messages. The goal of the PaymentMandate is to help the - network/issuer build trust into the agentic transaction. - """ - - payment_mandate_contents: PaymentMandateContents = Field( - ..., - description="The data contents of the payment mandate.", - ) - user_authorization: Optional[str] = Field( - None, - description=( - """ + """Contains the user's instructions & authorization for payment. + + While the Cart and Intent mandates are required by the merchant to fulfill the + order, separately the protocol provides additional visibility into the agentic + transaction to the payments ecosystem. For this purpose, the PaymentMandate + (bound to Cart/Intent mandate but containing separate information) may be + shared with the network/issuer along with the standard transaction + authorization messages. The goal of the PaymentMandate is to help the + network/issuer build trust into the agentic transaction. + """ + + payment_mandate_contents: PaymentMandateContents = Field( + ..., + description='The data contents of the payment mandate.', + ) + user_authorization: str | None = Field( + None, + description=( + """ This is a base64_url-encoded verifiable presentation of a verifiable credential signing over the cart_mandate and payment_mandate_hashes. For example an sd-jwt-vc would contain: @@ -195,6 +198,6 @@ class PaymentMandate(BaseModel): CartMandate and PaymentMandateContents. """ - ), - example="eyJhbGciOiJFUzI1NksiLCJraWQiOiJkaWQ6ZXhhbXBsZ...", - ) + ), + example='eyJhbGciOiJFUzI1NksiLCJraWQiOiJkaWQ6ZXhhbXBsZ...', + ) diff --git a/src/ap2/types/payment_channels.py b/src/ap2/types/payment_channels.py new file mode 100644 index 00000000..9f73b5ff --- /dev/null +++ b/src/ap2/types/payment_channels.py @@ -0,0 +1,428 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Payment channel framework for micropayments in AP2. + +This module provides the foundation for high-frequency, sub-cent transactions +between agents through state channels, enabling pay-per-use and streaming +payment models for AI inference and API calls. +""" + +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + +from ap2.types.payment_request import PaymentCurrencyAmount + + +class ChannelState(str, Enum): + """States of a payment channel lifecycle.""" + + OPENING = 'opening' + ACTIVE = 'active' + CLOSING = 'closing' + CLOSED = 'closed' + DISPUTED = 'disputed' + EXPIRED = 'expired' + + +class DisputeReason(str, Enum): + """Reasons for payment channel disputes.""" + + INVALID_STATE = 'invalid_state' + STALE_UPDATE = 'stale_update' + INVALID_SIGNATURE = 'invalid_signature' + INSUFFICIENT_FUNDS = 'insufficient_funds' + TIMEOUT = 'timeout' + FRAUD_ATTEMPT = 'fraud_attempt' + + +class ChannelParticipant(BaseModel): + """Participant in a payment channel.""" + + participant_id: str = Field( + ..., description='Unique identifier for the participant' + ) + agent_did: str | None = Field( + None, description='DID of the agent representing this participant' + ) + wallet_address: str = Field( + ..., description='Blockchain wallet address for settlements' + ) + role: str = Field( + ..., description='Role in the channel (payer, payee, mediator)' + ) + public_key: str = Field( + ..., description='Public key for channel signature verification' + ) + initial_balance: PaymentCurrencyAmount = Field( + ..., description='Initial balance contributed to the channel' + ) + current_balance: PaymentCurrencyAmount = Field( + ..., description='Current balance in the channel' + ) + + +class ChannelPolicy(BaseModel): + """Policy governing payment channel behavior.""" + + max_transaction_amount: PaymentCurrencyAmount = Field( + ..., description='Maximum amount per transaction' + ) + min_transaction_amount: PaymentCurrencyAmount = Field( + ..., description='Minimum amount per transaction' + ) + dispute_timeout_seconds: int = Field( + default=86400, + description='Time to challenge disputed states (24 hours)', + ) + max_pending_updates: int = Field( + default=1000, description='Maximum number of pending state updates' + ) + settlement_threshold: PaymentCurrencyAmount = Field( + ..., description='Balance threshold triggering automatic settlement' + ) + fee_rate: float = Field( + default=0.001, description='Transaction fee rate (0.1% default)' + ) + auto_close_timeout: int = Field( + default=604800, description='Auto-close timeout in seconds (7 days)' + ) + + +class PaymentVoucher(BaseModel): + """Off-chain payment voucher for micropayments.""" + + voucher_id: str = Field( + ..., description='Unique identifier for this voucher' + ) + channel_id: str = Field(..., description='Payment channel identifier') + from_participant: str = Field( + ..., description='ID of participant making payment' + ) + to_participant: str = Field( + ..., description='ID of participant receiving payment' + ) + amount: PaymentCurrencyAmount = Field( + ..., description='Amount being transferred' + ) + nonce: int = Field(..., description='Monotonic nonce for replay protection') + cumulative_amount: PaymentCurrencyAmount = Field( + ..., description='Total amount transferred to this participant' + ) + timestamp: str = Field( + description='When the voucher was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + signature: str = Field( + ..., description='Cryptographic signature of the voucher' + ) + metadata: dict[str, Any] = Field( + default_factory=dict, description='Additional voucher metadata' + ) + + +class ChannelUpdate(BaseModel): + """State update for a payment channel.""" + + update_id: str = Field(..., description='Unique identifier for this update') + channel_id: str = Field(..., description='Payment channel identifier') + sequence_number: int = Field( + ..., description='Monotonic sequence number for ordering' + ) + previous_state_hash: str = Field( + ..., description='Hash of the previous channel state' + ) + new_balances: dict[str, PaymentCurrencyAmount] = Field( + ..., description='New balances for all participants' + ) + included_vouchers: list[str] = Field( + default_factory=list, + description='List of voucher IDs included in this update', + ) + timestamp: str = Field( + description='When the update was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + signatures: dict[str, str] = Field( + default_factory=dict, description='Signatures from participants' + ) + state_hash: str = Field( + ..., description='Hash of the new channel state after this update' + ) + + +class PaymentChannel(BaseModel): + """Core payment channel for micropayment transactions.""" + + channel_id: str = Field( + ..., description='Unique identifier for this channel' + ) + participants: list[ChannelParticipant] = Field( + ..., + description='List of channel participants', + min_length=2, + max_length=10, + ) + state: ChannelState = Field( + default=ChannelState.OPENING, description='Current channel state' + ) + policy: ChannelPolicy = Field( + ..., description='Policies governing this channel' + ) + total_capacity: PaymentCurrencyAmount = Field( + ..., description='Total capacity of the channel' + ) + current_state_hash: str = Field( + ..., description='Hash of the current channel state' + ) + sequence_number: int = Field( + default=0, description='Current sequence number' + ) + created_at: str = Field( + description='When the channel was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + expires_at: str = Field( + ..., description='When the channel expires, in ISO 8601 format' + ) + last_activity: str = Field( + description='Last activity timestamp, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + dispute_info: dict[str, Any] | None = Field( + None, description='Information about any active disputes' + ) + settlement_info: dict[str, Any] | None = Field( + None, description='Information about channel settlement' + ) + + def get_participant(self, participant_id: str) -> ChannelParticipant | None: + """Get a participant by ID.""" + for participant in self.participants: + if participant.participant_id == participant_id: + return participant + return None + + def is_expired(self, current_time: datetime | None = None) -> bool: + """Check if the channel has expired.""" + if current_time is None: + current_time = datetime.now(UTC) + + expires_at = datetime.fromisoformat( + self.expires_at.replace('Z', '+00:00') + ) + return current_time > expires_at + + def get_total_balance(self) -> PaymentCurrencyAmount: + """Calculate total balance across all participants.""" + if not self.participants: + return PaymentCurrencyAmount(currency='USD', value=0.0) + + currency = self.participants[0].current_balance.currency + total_value = sum(p.current_balance.value for p in self.participants) + + return PaymentCurrencyAmount(currency=currency, value=total_value) + + def can_process_payment( + self, from_id: str, to_id: str, amount: PaymentCurrencyAmount + ) -> tuple[bool, str]: + """Check if a payment can be processed.""" + # Check channel state + if self.state != ChannelState.ACTIVE: + return False, f'Channel state is {self.state}, not active' + + # Check expiry + if self.is_expired(): + return False, 'Channel has expired' + + # Check participants exist + payer = self.get_participant(from_id) + payee = self.get_participant(to_id) + + if not payer: + return False, f'Payer {from_id} not found in channel' + if not payee: + return False, f'Payee {to_id} not found in channel' + + # Check currency compatibility + if amount.currency != payer.current_balance.currency: + return False, 'Currency mismatch' + + # Check sufficient balance + if payer.current_balance.value < amount.value: + return False, 'Insufficient balance' + + # Check policy limits + if amount.value > self.policy.max_transaction_amount.value: + return False, 'Amount exceeds maximum transaction limit' + + if amount.value < self.policy.min_transaction_amount.value: + return False, 'Amount below minimum transaction limit' + + return True, 'Payment can be processed' + + +class ChannelOpenRequest(BaseModel): + """Request to open a new payment channel.""" + + requesting_participant: ChannelParticipant = Field( + ..., description='Participant requesting channel creation' + ) + target_participant: ChannelParticipant = Field( + ..., description='Target participant for the channel' + ) + proposed_policy: ChannelPolicy = Field( + ..., description='Proposed channel policies' + ) + duration_hours: int = Field( + default=168, + description='Requested channel duration in hours (7 days default)', + ) + initial_deposit: PaymentCurrencyAmount = Field( + ..., description='Initial deposit from requesting participant' + ) + purpose: str = Field(..., description='Purpose description for the channel') + metadata: dict[str, Any] = Field( + default_factory=dict, description='Additional channel metadata' + ) + + +class ChannelCloseRequest(BaseModel): + """Request to close a payment channel.""" + + channel_id: str = Field(..., description='Channel to close') + requesting_participant: str = Field( + ..., description='ID of participant requesting closure' + ) + final_balances: dict[str, PaymentCurrencyAmount] = Field( + ..., description='Proposed final balances for all participants' + ) + reason: str = Field( + default='normal_closure', description='Reason for channel closure' + ) + force_close: bool = Field( + default=False, description='Whether to force close without consensus' + ) + signature: str = Field(..., description='Signature authorizing the closure') + + +class ChannelDispute(BaseModel): + """Dispute information for a payment channel.""" + + dispute_id: str = Field( + ..., description='Unique identifier for the dispute' + ) + channel_id: str = Field(..., description='Channel under dispute') + disputing_participant: str = Field( + ..., description='ID of participant raising the dispute' + ) + dispute_reason: DisputeReason = Field( + ..., description='Reason for the dispute' + ) + contested_state: dict[str, Any] = Field( + ..., description='The state being contested' + ) + evidence: list[dict[str, Any]] = Field( + default_factory=list, description='Evidence supporting the dispute' + ) + resolution_deadline: str = Field( + ..., description='Deadline for dispute resolution, in ISO 8601 format' + ) + status: str = Field( + default='open', description='Current status of the dispute' + ) + created_at: str = Field( + description='When the dispute was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + + +class ChannelRegistry(BaseModel): + """Registry of active payment channels for an agent or service.""" + + registry_id: str = Field( + ..., description='Unique identifier for this registry' + ) + owner_agent_did: str = Field( + ..., description='DID of the agent owning this registry' + ) + active_channels: dict[str, PaymentChannel] = Field( + default_factory=dict, description='Map of channel_id to PaymentChannel' + ) + channel_history: list[str] = Field( + default_factory=list, + description='Historical list of closed channel IDs', + ) + total_volume: dict[str, float] = Field( + default_factory=dict, description='Total volume by currency' + ) + total_transactions: int = Field( + default=0, description='Total number of transactions processed' + ) + created_at: str = Field( + description='When the registry was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + last_updated: str = Field( + description='Last update timestamp, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + + def add_channel(self, channel: PaymentChannel) -> None: + """Add a new channel to the registry.""" + self.active_channels[channel.channel_id] = channel + self.last_updated = datetime.now(UTC).isoformat() + + def remove_channel(self, channel_id: str) -> PaymentChannel | None: + """Remove a channel from active registry and move to history.""" + channel = self.active_channels.pop(channel_id, None) + if channel: + self.channel_history.append(channel_id) + + # Update statistics + total_balance = channel.get_total_balance() + currency = total_balance.currency + if currency not in self.total_volume: + self.total_volume[currency] = 0.0 + self.total_volume[currency] += total_balance.value + + self.last_updated = datetime.now(UTC).isoformat() + + return channel + + def get_channels_by_participant( + self, participant_id: str + ) -> list[PaymentChannel]: + """Get all channels involving a specific participant.""" + result = [] + for channel in self.active_channels.values(): + if any( + p.participant_id == participant_id for p in channel.participants + ): + result.append(channel) + return result + + def get_total_balance_for_participant( + self, participant_id: str, currency: str + ) -> float: + """Get total balance across all channels for a participant.""" + total = 0.0 + for channel in self.active_channels.values(): + participant = channel.get_participant(participant_id) + if participant and participant.current_balance.currency == currency: + total += participant.current_balance.value + return total diff --git a/src/ap2/types/payment_request.py b/src/ap2/types/payment_request.py index 22495028..12092bf5 100644 --- a/src/ap2/types/payment_request.py +++ b/src/ap2/types/payment_request.py @@ -22,208 +22,235 @@ https://www.w3.org/TR/payment-request/ """ -from typing import Any, Dict, Optional +from typing import Any + +from pydantic import BaseModel, Field from ap2.types.contact_picker import ContactAddress -from pydantic import BaseModel -from pydantic import Field -PAYMENT_METHOD_DATA_DATA_KEY = "payment_request.PaymentMethodData" + +PAYMENT_METHOD_DATA_DATA_KEY = 'payment_request.PaymentMethodData' class PaymentCurrencyAmount(BaseModel): - """A PaymentCurrencyAmount is used to supply monetary amounts. + """A PaymentCurrencyAmount is used to supply monetary amounts. + + Extended to support both traditional fiat currencies and cryptocurrencies + for micropayment channel operations. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentcurrencyamount + """ + + currency: str = Field( + ..., + description="Currency code (ISO 4217 for fiat, symbol for crypto, e.g., 'USD', 'USDC', 'ETH')", + ) + value: float = Field(..., description='The monetary value.') + + +class CryptoPaymentAmount(PaymentCurrencyAmount): + """Extended PaymentCurrencyAmount for cryptocurrency payments. - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentcurrencyamount - """ + Supports stablecoins and native cryptocurrencies used in micropayment channels. + """ - currency: str = Field( - ..., description="The three-letter ISO 4217 currency code." - ) - value: float = Field(..., description="The monetary value.") + blockchain_network: str = Field( + ..., + description="Blockchain network (e.g., 'ethereum', 'polygon', 'kite')", + ) + token_contract: str | None = Field( + None, + description='Smart contract address for tokens (not needed for native currencies)', + ) + decimal_places: int = Field( + default=18, description='Number of decimal places for this currency' + ) + network_fee_estimate: float | None = Field( + None, description='Estimated network fee for this transaction' + ) class PaymentItem(BaseModel): - """An item for purchase and the value asked for it. - - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentitem - """ - - label: str = Field( - ..., description="A human-readable description of the item." - ) - amount: PaymentCurrencyAmount = Field( - ..., description="The monetary amount of the item." - ) - pending: Optional[bool] = Field( - None, description="If true, indicates the amount is not final." - ) - refund_period: int = Field( - 30, description="The refund duration for this item, in days." - ) + """An item for purchase and the value asked for it. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentitem + """ + + label: str = Field( + ..., description='A human-readable description of the item.' + ) + amount: PaymentCurrencyAmount = Field( + ..., description='The monetary amount of the item.' + ) + pending: bool | None = Field( + None, description='If true, indicates the amount is not final.' + ) + refund_period: int = Field( + 30, description='The refund duration for this item, in days.' + ) class PaymentShippingOption(BaseModel): - """Describes a shipping option. - - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentshippingoption - """ - - id: str = Field( - ..., description="A unique identifier for the shipping option." - ) - label: str = Field( - ..., description="A human-readable description of the shipping option." - ) - amount: PaymentCurrencyAmount = Field( - ..., description="The cost of this shipping option." - ) - selected: Optional[bool] = Field( - False, description="If true, indicates this as the default option." - ) + """Describes a shipping option. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentshippingoption + """ + + id: str = Field( + ..., description='A unique identifier for the shipping option.' + ) + label: str = Field( + ..., description='A human-readable description of the shipping option.' + ) + amount: PaymentCurrencyAmount = Field( + ..., description='The cost of this shipping option.' + ) + selected: bool | None = Field( + False, description='If true, indicates this as the default option.' + ) class PaymentOptions(BaseModel): - """Information about the eligible payment options for the payment request. - - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentoptions - """ - - request_payer_name: Optional[bool] = Field( - False, description="Indicates if the payer's name should be collected." - ) - request_payer_email: Optional[bool] = Field( - False, description="Indicates if the payer's email should be collected." - ) - request_payer_phone: Optional[bool] = Field( - False, - description="Indicates if the payer's phone number should be collected.", - ) - request_shipping: Optional[bool] = Field( - True, - description=( - "Indicates if the payer's shipping address should be collected." - ), - ) - shipping_type: Optional[str] = Field( - None, description="Can be `shipping`, `delivery`, or `pickup`." - ) + """Information about the eligible payment options for the payment request. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentoptions + """ + + request_payer_name: bool | None = Field( + False, description="Indicates if the payer's name should be collected." + ) + request_payer_email: bool | None = Field( + False, description="Indicates if the payer's email should be collected." + ) + request_payer_phone: bool | None = Field( + False, + description="Indicates if the payer's phone number should be collected.", + ) + request_shipping: bool | None = Field( + True, + description=( + "Indicates if the payer's shipping address should be collected." + ), + ) + shipping_type: str | None = Field( + None, description='Can be `shipping`, `delivery`, or `pickup`.' + ) class PaymentMethodData(BaseModel): - """Indicates a payment method and associated data specific to the method. + """Indicates a payment method and associated data specific to the method. - For example: - - A card may have a processing fee if it is used. - - A loyalty card may offer a discount on the purchase. + For example: + - A card may have a processing fee if it is used. + - A loyalty card may offer a discount on the purchase. - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentmethoddata - """ + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentmethoddata + """ - supported_methods: str = Field( - ..., description="A string identifying the payment method." - ) - data: Optional[Dict[str, Any]] = Field( - default_factory=dict, description="Payment method specific details." - ) + supported_methods: str = Field( + ..., description='A string identifying the payment method.' + ) + data: dict[str, Any] | None = Field( + default_factory=dict, description='Payment method specific details.' + ) class PaymentDetailsModifier(BaseModel): - """Provides details that modify the payment details based on a payment method. - - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentdetailsmodifier - """ - - supported_methods: str = Field( - ..., - description="The payment method ID that this modifier applies to.", - ) - total: Optional[PaymentItem] = Field( - None, - description="A PaymentItem value that overrides the original item total.", - ) - additional_display_items: Optional[list[PaymentItem]] = Field( - None, - description="Additional PaymentItems applicable for this payment method.", - ) - data: Optional[dict[str, Any]] = Field( - None, description="Payment method specific data for the modifier." - ) + """Provides details that modify the payment details based on a payment method. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentdetailsmodifier + """ + + supported_methods: str = Field( + ..., + description='The payment method ID that this modifier applies to.', + ) + total: PaymentItem | None = Field( + None, + description='A PaymentItem value that overrides the original item total.', + ) + additional_display_items: list[PaymentItem] | None = Field( + None, + description='Additional PaymentItems applicable for this payment method.', + ) + data: dict[str, Any] | None = Field( + None, description='Payment method specific data for the modifier.' + ) class PaymentDetailsInit(BaseModel): - """Contains the details of the payment being requested. - - Specification: - https://www.w3.org/TR/payment-request/#dom-paymentdetailsinit - """ - - id: str = Field( - ..., description="A unique identifier for the payment request." - ) - display_items: list[PaymentItem] = Field( - ..., description="A list of payment items to be displayed to the user." - ) - shipping_options: Optional[list[PaymentShippingOption]] = Field( - None, - description="A list of available shipping options.", - ) - modifiers: Optional[list[PaymentDetailsModifier]] = Field( - None, - description="A list of price modifiers for particular payment methods.", - ) - total: PaymentItem = Field(..., description="The total payment amount.") + """Contains the details of the payment being requested. + + Specification: + https://www.w3.org/TR/payment-request/#dom-paymentdetailsinit + """ + + id: str = Field( + ..., description='A unique identifier for the payment request.' + ) + display_items: list[PaymentItem] = Field( + ..., description='A list of payment items to be displayed to the user.' + ) + shipping_options: list[PaymentShippingOption] | None = Field( + None, + description='A list of available shipping options.', + ) + modifiers: list[PaymentDetailsModifier] | None = Field( + None, + description='A list of price modifiers for particular payment methods.', + ) + total: PaymentItem = Field(..., description='The total payment amount.') class PaymentRequest(BaseModel): - """A request for payment. + """A request for payment. - Specification: - https://www.w3.org/TR/payment-request/#paymentrequest-interface - """ + Specification: + https://www.w3.org/TR/payment-request/#paymentrequest-interface + """ - method_data: list[PaymentMethodData] = Field( - ..., description="A list of supported payment methods." - ) - details: PaymentDetailsInit = Field( - ..., description="The financial details of the transaction." - ) - options: Optional[PaymentOptions] = None + method_data: list[PaymentMethodData] = Field( + ..., description='A list of supported payment methods.' + ) + details: PaymentDetailsInit = Field( + ..., description='The financial details of the transaction.' + ) + options: PaymentOptions | None = None - shipping_address: Optional[ContactAddress] = Field( - None, description="The user's provided shipping address." - ) + shipping_address: ContactAddress | None = Field( + None, description="The user's provided shipping address." + ) class PaymentResponse(BaseModel): - """Indicates a user has chosen a payment method & approved a payment request. - - Specification: - https://www.w3.org/TR/payment-request/#paymentresponse-interface - """ - - request_id: str = Field( - ..., description="The unique ID from the original PaymentRequest." - ) - method_name: str = Field( - ..., description="The payment method chosen by the user." - ) - details: Optional[Dict[str, Any]] = Field( - None, - description=( - "A dictionary generated by a payment method that a merchant can use" - "to process a transaction. The contents will depend upon the payment" - "method." - ), - ) - shipping_address: Optional[ContactAddress] = None - shipping_option: Optional[PaymentShippingOption] = None - payer_name: Optional[str] = None - payer_email: Optional[str] = None - payer_phone: Optional[str] = None + """Indicates a user has chosen a payment method & approved a payment request. + + Specification: + https://www.w3.org/TR/payment-request/#paymentresponse-interface + """ + + request_id: str = Field( + ..., description='The unique ID from the original PaymentRequest.' + ) + method_name: str = Field( + ..., description='The payment method chosen by the user.' + ) + details: dict[str, Any] | None = Field( + None, + description=( + 'A dictionary generated by a payment method that a merchant can use' + 'to process a transaction. The contents will depend upon the payment' + 'method.' + ), + ) + shipping_address: ContactAddress | None = None + shipping_option: PaymentShippingOption | None = None + payer_name: str | None = None + payer_email: str | None = None + payer_phone: str | None = None diff --git a/src/ap2/types/streaming_payments.py b/src/ap2/types/streaming_payments.py new file mode 100644 index 00000000..7bcf3acd --- /dev/null +++ b/src/ap2/types/streaming_payments.py @@ -0,0 +1,490 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Streaming payment primitives for real-time micropayments. + +This module provides the foundation for streaming payments, enabling pay-per-token, +pay-per-second, or pay-per-API-call models essential for AI inference and +real-time services. +""" + +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + +from ap2.types.payment_request import PaymentCurrencyAmount + + +class PaymentRateType(str, Enum): + """Types of payment rate calculations.""" + + PER_SECOND = 'per_second' + PER_MINUTE = 'per_minute' + PER_HOUR = 'per_hour' + PER_TOKEN = 'per_token' + PER_REQUEST = 'per_request' + PER_BYTE = 'per_byte' + PER_COMPUTE_UNIT = 'per_compute_unit' + FLAT_RATE = 'flat_rate' + TIERED_RATE = 'tiered_rate' + + +class StreamStatus(str, Enum): + """Status of a streaming payment.""" + + INITIALIZING = 'initializing' + ACTIVE = 'active' + PAUSED = 'paused' + COMPLETED = 'completed' + FAILED = 'failed' + CANCELLED = 'cancelled' + + +class PaymentRate(BaseModel): + """Rate structure for streaming payments.""" + + rate_type: PaymentRateType = Field( + ..., description='Type of rate calculation' + ) + rate_amount: PaymentCurrencyAmount = Field( + ..., description='Amount charged per unit' + ) + minimum_charge: PaymentCurrencyAmount | None = Field( + None, description='Minimum charge regardless of usage' + ) + maximum_charge: PaymentCurrencyAmount | None = Field( + None, description='Maximum charge cap' + ) + billing_frequency_seconds: int = Field( + default=1, description='How often to bill (in seconds)' + ) + unit_description: str = Field( + ..., description='Description of what constitutes one unit' + ) + tier_thresholds: list[dict[str, Any]] | None = Field( + None, + description='Tiered pricing thresholds for complex rate structures', + ) + + +class StreamingPaymentVoucher(BaseModel): + """Voucher for incremental streaming payments.""" + + voucher_id: str = Field( + ..., description='Unique identifier for this voucher' + ) + stream_id: str = Field( + ..., description='Streaming payment session identifier' + ) + channel_id: str = Field(..., description='Payment channel identifier') + sequence_number: int = Field( + ..., description='Sequence number within the stream' + ) + increment_amount: PaymentCurrencyAmount = Field( + ..., description='Incremental amount for this voucher' + ) + cumulative_amount: PaymentCurrencyAmount = Field( + ..., description='Total amount streamed so far' + ) + units_consumed: float = Field( + ..., + description='Number of units consumed (tokens, seconds, requests, etc.)', + ) + cumulative_units: float = Field( + ..., description='Total units consumed in the stream' + ) + rate_applied: PaymentRate = Field( + ..., description='Rate structure used for this calculation' + ) + timestamp: str = Field( + description='When the voucher was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + signature: str = Field( + ..., description='Cryptographic signature of the voucher' + ) + metadata: dict[str, Any] = Field( + default_factory=dict, description='Additional voucher metadata' + ) + + +class PaymentCheckpoint(BaseModel): + """Checkpoint for resumable streaming payments.""" + + checkpoint_id: str = Field( + ..., description='Unique identifier for this checkpoint' + ) + stream_id: str = Field( + ..., description='Streaming payment session identifier' + ) + sequence_number: int = Field( + ..., description='Sequence number at checkpoint' + ) + cumulative_amount: PaymentCurrencyAmount = Field( + ..., description='Total amount at checkpoint' + ) + cumulative_units: float = Field( + ..., description='Total units consumed at checkpoint' + ) + timestamp: str = Field( + description='When the checkpoint was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + state_hash: str = Field( + ..., description='Hash of the stream state at this checkpoint' + ) + signatures: dict[str, str] = Field( + default_factory=dict, description='Participant signatures on checkpoint' + ) + + +class StreamingPaymentPolicy(BaseModel): + """Policy governing automated streaming payments.""" + + max_stream_duration_seconds: int = Field( + default=3600, + description='Maximum duration for a single stream (1 hour default)', + ) + checkpoint_frequency_seconds: int = Field( + default=60, description='How often to create checkpoints' + ) + auto_pause_threshold: PaymentCurrencyAmount = Field( + ..., description='Amount threshold that triggers automatic pause' + ) + max_cumulative_amount: PaymentCurrencyAmount = Field( + ..., description='Maximum total amount for the stream' + ) + rate_adjustment_allowed: bool = Field( + default=False, description='Whether rates can be adjusted during stream' + ) + dispute_resolution_timeout: int = Field( + default=300, + description='Timeout for resolving streaming disputes (5 minutes)', + ) + quality_requirements: dict[str, Any] | None = Field( + None, description='Service quality requirements (SLA)' + ) + + +class StreamingPaymentSession(BaseModel): + """Active streaming payment session.""" + + stream_id: str = Field(..., description='Unique identifier for this stream') + channel_id: str = Field(..., description='Payment channel identifier') + payer_id: str = Field(..., description='ID of the paying participant') + payee_id: str = Field(..., description='ID of the receiving participant') + service_description: str = Field( + ..., description='Description of the service being paid for' + ) + rate: PaymentRate = Field(..., description='Rate structure for this stream') + policy: StreamingPaymentPolicy = Field( + ..., description='Policy governing this stream' + ) + status: StreamStatus = Field( + default=StreamStatus.INITIALIZING, description='Current stream status' + ) + start_time: str = Field( + description='When the stream started, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + end_time: str | None = Field( + None, description='When the stream ended, in ISO 8601 format' + ) + current_sequence: int = Field( + default=0, description='Current sequence number' + ) + cumulative_amount: PaymentCurrencyAmount = Field( + ..., description='Total amount streamed so far' + ) + cumulative_units: float = Field( + default=0.0, description='Total units consumed' + ) + last_checkpoint: PaymentCheckpoint | None = Field( + None, description='Most recent checkpoint' + ) + vouchers: list[StreamingPaymentVoucher] = Field( + default_factory=list, description='List of vouchers in this stream' + ) + metadata: dict[str, Any] = Field( + default_factory=dict, description='Additional stream metadata' + ) + + def calculate_next_payment( + self, units_consumed: float + ) -> PaymentCurrencyAmount: + """Calculate the next payment amount based on units consumed.""" + if self.rate.rate_type == PaymentRateType.FLAT_RATE: + # Flat rate - pay the full amount once + if self.current_sequence == 0: + return self.rate.rate_amount + return PaymentCurrencyAmount( + currency=self.rate.rate_amount.currency, value=0.0 + ) + + if self.rate.rate_type in [ + PaymentRateType.PER_SECOND, + PaymentRateType.PER_MINUTE, + PaymentRateType.PER_HOUR, + PaymentRateType.PER_TOKEN, + PaymentRateType.PER_REQUEST, + PaymentRateType.PER_BYTE, + PaymentRateType.PER_COMPUTE_UNIT, + ]: + # Unit-based pricing + increment_amount = self.rate.rate_amount.value * units_consumed + return PaymentCurrencyAmount( + currency=self.rate.rate_amount.currency, value=increment_amount + ) + + if self.rate.rate_type == PaymentRateType.TIERED_RATE: + # Tiered pricing - calculate based on cumulative usage + return self._calculate_tiered_payment(units_consumed) + + raise ValueError(f'Unsupported rate type: {self.rate.rate_type}') + + def _calculate_tiered_payment( + self, units_consumed: float + ) -> PaymentCurrencyAmount: + """Calculate payment for tiered rate structure.""" + if not self.rate.tier_thresholds: + # Fallback to simple rate + increment_amount = self.rate.rate_amount.value * units_consumed + return PaymentCurrencyAmount( + currency=self.rate.rate_amount.currency, value=increment_amount + ) + + new_total_units = self.cumulative_units + units_consumed + old_total_units = self.cumulative_units + + total_cost = 0.0 + for tier in self.rate.tier_thresholds: + tier_min = tier.get('min_units', 0) + tier_max = tier.get('max_units', float('inf')) + tier_rate = tier.get('rate_per_unit', self.rate.rate_amount.value) + + # Calculate how many units fall in this tier for the increment + old_units_in_tier = max( + 0, min(old_total_units, tier_max) - tier_min + ) + new_units_in_tier = max( + 0, min(new_total_units, tier_max) - tier_min + ) + + increment_units_in_tier = new_units_in_tier - old_units_in_tier + if increment_units_in_tier > 0: + total_cost += increment_units_in_tier * tier_rate + + return PaymentCurrencyAmount( + currency=self.rate.rate_amount.currency, value=total_cost + ) + + def add_voucher( + self, units_consumed: float, metadata: dict[str, Any] | None = None + ) -> StreamingPaymentVoucher: + """Add a new voucher to the stream.""" + increment_amount = self.calculate_next_payment(units_consumed) + new_cumulative_amount = PaymentCurrencyAmount( + currency=self.cumulative_amount.currency, + value=self.cumulative_amount.value + increment_amount.value, + ) + new_cumulative_units = self.cumulative_units + units_consumed + + voucher = StreamingPaymentVoucher( + voucher_id=f'{self.stream_id}_{self.current_sequence + 1}', + stream_id=self.stream_id, + channel_id=self.channel_id, + sequence_number=self.current_sequence + 1, + increment_amount=increment_amount, + cumulative_amount=new_cumulative_amount, + units_consumed=units_consumed, + cumulative_units=new_cumulative_units, + rate_applied=self.rate, + signature='placeholder_signature', # Would be cryptographically signed + metadata=metadata or {}, + ) + + # Update stream state + self.vouchers.append(voucher) + self.current_sequence += 1 + self.cumulative_amount = new_cumulative_amount + self.cumulative_units = new_cumulative_units + + return voucher + + def create_checkpoint(self) -> PaymentCheckpoint: + """Create a checkpoint of the current stream state.""" + checkpoint = PaymentCheckpoint( + checkpoint_id=f'{self.stream_id}_checkpoint_{len(self.vouchers)}', + stream_id=self.stream_id, + sequence_number=self.current_sequence, + cumulative_amount=self.cumulative_amount, + cumulative_units=self.cumulative_units, + state_hash=f'hash_{self.stream_id}_{self.current_sequence}', # Would be actual hash + signatures={}, # Would be signed by participants + ) + + self.last_checkpoint = checkpoint + return checkpoint + + def pause_stream(self, reason: str = '') -> None: + """Pause the streaming payment.""" + self.status = StreamStatus.PAUSED + self.metadata['pause_reason'] = reason + self.metadata['paused_at'] = datetime.now(UTC).isoformat() + + def resume_stream(self) -> None: + """Resume the streaming payment.""" + if self.status == StreamStatus.PAUSED: + self.status = StreamStatus.ACTIVE + self.metadata['resumed_at'] = datetime.now(UTC).isoformat() + + def complete_stream(self) -> None: + """Complete the streaming payment.""" + self.status = StreamStatus.COMPLETED + self.end_time = datetime.now(UTC).isoformat() + + def is_within_limits(self) -> tuple[bool, str]: + """Check if the stream is within policy limits.""" + # Check cumulative amount limit + if ( + self.cumulative_amount.value + > self.policy.max_cumulative_amount.value + ): + return False, 'Cumulative amount exceeds policy limit' + + # Check duration limit + if self.status == StreamStatus.ACTIVE: + start_dt = datetime.fromisoformat( + self.start_time.replace('Z', '+00:00') + ) + current_dt = datetime.now(UTC) + duration_seconds = (current_dt - start_dt).total_seconds() + + if duration_seconds > self.policy.max_stream_duration_seconds: + return False, 'Stream duration exceeds policy limit' + + # Check auto-pause threshold + if ( + self.cumulative_amount.value + >= self.policy.auto_pause_threshold.value + ): + return False, 'Auto-pause threshold reached' + + return True, 'Stream is within limits' + + +class StreamingPaymentManager(BaseModel): + """Manager for multiple streaming payment sessions.""" + + manager_id: str = Field( + ..., description='Unique identifier for this manager' + ) + agent_did: str = Field( + ..., description='DID of the agent using this manager' + ) + active_streams: dict[str, StreamingPaymentSession] = Field( + default_factory=dict, description='Active streaming sessions' + ) + completed_streams: list[str] = Field( + default_factory=list, description='List of completed stream IDs' + ) + total_volume: dict[str, float] = Field( + default_factory=dict, description='Total volume by currency' + ) + total_streams: int = Field( + default=0, description='Total number of streams created' + ) + created_at: str = Field( + description='When the manager was created, in ISO 8601 format', + default_factory=lambda: datetime.now(UTC).isoformat(), + ) + + def create_stream( + self, + channel_id: str, + payer_id: str, + payee_id: str, + service_description: str, + rate: PaymentRate, + policy: StreamingPaymentPolicy, + ) -> StreamingPaymentSession: + """Create a new streaming payment session.""" + stream_id = f'stream_{self.agent_did}_{self.total_streams + 1}' + + stream = StreamingPaymentSession( + stream_id=stream_id, + channel_id=channel_id, + payer_id=payer_id, + payee_id=payee_id, + service_description=service_description, + rate=rate, + policy=policy, + cumulative_amount=PaymentCurrencyAmount( + currency=rate.rate_amount.currency, value=0.0 + ), + ) + + self.active_streams[stream_id] = stream + self.total_streams += 1 + + return stream + + def get_stream(self, stream_id: str) -> StreamingPaymentSession | None: + """Get an active streaming session.""" + return self.active_streams.get(stream_id) + + def complete_stream(self, stream_id: str) -> bool: + """Complete a streaming session.""" + stream = self.active_streams.pop(stream_id, None) + if stream: + stream.complete_stream() + self.completed_streams.append(stream_id) + + # Update statistics + currency = stream.cumulative_amount.currency + if currency not in self.total_volume: + self.total_volume[currency] = 0.0 + self.total_volume[currency] += stream.cumulative_amount.value + + return True + return False + + def get_streams_by_channel( + self, channel_id: str + ) -> list[StreamingPaymentSession]: + """Get all streams for a specific channel.""" + return [ + stream + for stream in self.active_streams.values() + if stream.channel_id == channel_id + ] + + def cleanup_expired_streams(self) -> list[str]: + """Clean up expired streaming sessions.""" + expired_streams = [] + current_time = datetime.now(UTC) + + for stream_id, stream in list(self.active_streams.items()): + start_time = datetime.fromisoformat( + stream.start_time.replace('Z', '+00:00') + ) + duration = (current_time - start_time).total_seconds() + + if duration > stream.policy.max_stream_duration_seconds: + expired_streams.append(stream_id) + self.complete_stream(stream_id) + + return expired_streams