Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import sys
import warnings
from enum import Enum

import google.protobuf.internal.api_implementation

Expand All @@ -26,18 +27,28 @@
from snet.sdk.client_lib_generator import ClientLibGenerator
from snet.sdk.mpe.mpe_contract import MPEContract
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy
from snet.sdk.payment_strategies.default_payment_strategy import *
from snet.sdk.service_client import ServiceClient
from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.custom_typing import ModuleName, ServiceStub
from snet.sdk.utils.utils import (bytes32_to_str, find_file_by_keyword,
type_converter)
from snet.sdk.utils.utils import (
bytes32_to_str,
find_file_by_keyword,
type_converter
)

google.protobuf.internal.api_implementation.Type = lambda: 'python'
_sym_db = _symbol_database.Default()
_sym_db.RegisterMessage = lambda x: None


class PaymentStrategyType(Enum):
PAID_CALL = PaidCallPaymentStrategy
FREE_CALL = FreeCallPaymentStrategy
PREPAID_CALL = PrePaidPaymentStrategy
DEFAULT = DefaultPaymentStrategy


class SnetSDK:
"""Base Snet SDK"""

Expand Down Expand Up @@ -91,8 +102,9 @@ def __init__(self, sdk_config: Config, metadata_provider=None):
def create_service_client(self,
org_id: str,
service_id: str,
group_name=None,
payment_strategy=None,
group_name: str=None,
payment_strategy: PaymentStrategy = None,
payment_strategy_type: PaymentStrategyType=PaymentStrategyType.DEFAULT,
address=None,
options=None,
concurrent_calls: int = 1):
Expand All @@ -118,15 +130,14 @@ def create_service_client(self,
print("Generating client library...")
self.lib_generator.generate_client_library()

if payment_strategy is None:
payment_strategy = DefaultPaymentStrategy(
concurrent_calls=concurrent_calls
)

if options is None:
options = dict()
options['user_address'] = address if address else ""
options['concurrency'] = self._sdk_config.get("concurrency", True)
options['concurrent_calls'] = concurrent_calls

if payment_strategy is None:
payment_strategy = payment_strategy_type.value()

service_metadata = self._metadata_provider.enhance_service_metadata(
org_id, service_id
Expand All @@ -137,7 +148,8 @@ def create_service_client(self,

pb2_module = self.get_module_by_keyword(keyword="pb2.py")
_service_client = ServiceClient(org_id, service_id, service_metadata,
group, service_stubs, payment_strategy,
group, service_stubs,
payment_strategy,
options, self.mpe_contract,
self.account, self.web3, pb2_module,
self.payment_channel_provider,
Expand Down
13 changes: 8 additions & 5 deletions snet/sdk/concurrency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
import grpc
import web3

from snet.sdk.service_client import ServiceClient
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path


class ConcurrencyManager:
def __init__(self, concurrent_calls: int):
def __init__(self, concurrent_calls: int=1):
self.__concurrent_calls: int = concurrent_calls
self.__token: str = ''
self.__planned_amount: int = 0
Expand All @@ -18,14 +17,18 @@ def __init__(self, concurrent_calls: int):
def concurrent_calls(self) -> int:
return self.__concurrent_calls

@concurrent_calls.setter
def concurrent_calls(self, concurrent_calls: int):
self.__concurrent_calls = concurrent_calls

def get_token(self, service_client, channel, service_call_price):
if len(self.__token) == 0:
self.__token = self.__get_token(service_client, channel, service_call_price)
elif self.__used_amount >= self.__planned_amount:
self.__token = self.__get_token(service_client, channel, service_call_price, new_token=True)
return self.__token

def __get_token(self, service_client: ServiceClient, channel, service_call_price, new_token=False):
def __get_token(self, service_client, channel, service_call_price, new_token=False):
if not new_token:
amount = channel.state["last_signed_amount"]
if amount != 0:
Expand All @@ -47,13 +50,13 @@ def __get_token(self, service_client: ServiceClient, channel, service_call_price
self.__planned_amount = token_reply.planned_amount
return token_reply.token

def __get_stub_for_get_token(self, service_client: ServiceClient):
def __get_stub_for_get_token(self, service_client):
grpc_channel = service_client.get_grpc_base_channel()
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
token_service_pb2_grpc = importlib.import_module("token_service_pb2_grpc")
return token_service_pb2_grpc.TokenServiceStub(grpc_channel)

def __get_token_for_amount(self, service_client: ServiceClient, channel, amount):
def __get_token_for_amount(self, service_client, channel, amount):
nonce = channel.state["nonce"]
stub = self.__get_stub_for_get_token(service_client)
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
Expand Down
18 changes: 7 additions & 11 deletions snet/sdk/payment_strategies/default_payment_strategy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from snet.sdk.concurrency_manager import ConcurrencyManager
from snet.sdk.payment_strategies.freecall_payment_strategy import FreeCallPaymentStrategy
from snet.sdk.payment_strategies.paidcall_payment_strategy import PaidCallPaymentStrategy
from snet.sdk.payment_strategies.prepaid_payment_strategy import PrePaidPaymentStrategy
Expand All @@ -7,26 +6,22 @@

class DefaultPaymentStrategy(PaymentStrategy):

def __init__(self, concurrent_calls: int = 1):
self.concurrent_calls = concurrent_calls
self.concurrencyManager = ConcurrencyManager(concurrent_calls)
def __init__(self):
self.channel = None

def set_concurrency_token(self, token):
self.concurrencyManager.__token = token

def set_channel(self, channel):
self.channel = channel

def get_payment_metadata(self, service_client):
free_call_payment_strategy = FreeCallPaymentStrategy()

if free_call_payment_strategy.is_free_call_available(service_client):
if free_call_payment_strategy.get_free_calls_available(service_client) > 0:
metadata = free_call_payment_strategy.get_payment_metadata(service_client)
else:
if service_client.get_concurrency_flag():
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
metadata = payment_strategy.get_payment_metadata(service_client, self.channel)
concurrent_calls = service_client.get_concurrent_calls()
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
metadata = payment_strategy.get_payment_metadata(service_client)
else:
payment_strategy = PaidCallPaymentStrategy()
metadata = payment_strategy.get_payment_metadata(service_client)
Expand All @@ -37,5 +32,6 @@ def get_price(self, service_client):
pass

def get_concurrency_token_and_channel(self, service_client):
payment_strategy = PrePaidPaymentStrategy(self.concurrencyManager)
concurrent_calls = service_client.get_concurrent_calls()
payment_strategy = PrePaidPaymentStrategy(concurrent_calls)
return payment_strategy.get_concurrency_token_and_channel(service_client)
113 changes: 52 additions & 61 deletions snet/sdk/payment_strategies/freecall_payment_strategy.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,99 @@
import importlib
from urllib.parse import urlparse

import grpc
from grpc import _channel
import web3

from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy
from snet.sdk.resources.root_certificate import certificate
from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path

class FreeCallPaymentStrategy(PaymentStrategy):

def is_free_call_available(self, service_client) -> bool:
try:
self._user_address = service_client.options["user_address"]
self._free_call_token, self._token_expiry_date_block = self.get_free_call_token_details(service_client)
def __init__(self):
self._user_address = None
self._free_call_token = None
self._token_expiration_block = None

if not self._free_call_token:
return False
def get_free_calls_available(self, service_client) -> int:
if not self._user_address:
self._user_address = service_client.account.signer_address

with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2 = importlib.import_module("state_service_pb2")
current_block_number = service_client.get_current_block_number()

with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")
if (not self._free_call_token or
not self._token_expiration_block or
current_block_number > self._token_expiration_block):
self._free_call_token, self._token_expiration_block = self.get_free_call_token_details(service_client)

signature, current_block_number = self.generate_signature(service_client)
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2 = importlib.import_module("state_service_pb2")

request = state_service_pb2.FreeCallStateRequest()
request.user_address = self._user_address
request.token_for_free_call = self._free_call_token
request.token_expiry_date_block = self._token_expiry_date_block
request.signature = signature
request.current_block = current_block_number
with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")

channel = self.select_channel(service_client)
signature, _ = self.generate_signature(service_client, current_block_number)
request = state_service_pb2.FreeCallStateRequest(
address=self._user_address,
free_call_token=self._free_call_token,
signature=signature,
current_block=current_block_number
)

stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
channel = service_client.get_grpc_base_channel()
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)

try:
response = stub.GetFreeCallsAvailable(request)
if response.free_calls_available > 0:
return True
return False
return response.free_calls_available
except grpc.RpcError as e:
if self._user_address:
print(f"Warning: {e.details()}")
return False
except Exception as e:
return False
return 0

def get_payment_metadata(self, service_client) -> list:
if self.get_free_calls_available(service_client) <= 0:
raise Exception(f"Free calls limit for address {self._user_address} has expired. Please use another payment strategy")
signature, current_block_number = self.generate_signature(service_client)
metadata = [("snet-free-call-auth-token-bin", self._free_call_token),
("snet-free-call-token-expiry-block", str(self._token_expiry_date_block)),
("snet-payment-type", "free-call"),
("snet-free-call-user-id", self._user_address),
("snet-free-call-user-address", self._user_address),
("snet-current-block-number", str(current_block_number)),
("snet-payment-channel-signature-bin", signature)]

return metadata

def select_channel(self, service_client) -> _channel.Channel:
_, _, _, daemon_endpoint = service_client.get_service_details()
endpoint_object = urlparse(daemon_endpoint)
if endpoint_object.port is not None:
channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port)
else:
channel_endpoint = endpoint_object.hostname

if endpoint_object.scheme == "http":
channel = grpc.insecure_channel(channel_endpoint)
elif endpoint_object.scheme == "https":
channel = grpc.secure_channel(channel_endpoint, grpc.ssl_channel_credentials(root_certificates=certificate))
else:
raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme))
return channel

def generate_signature(self, service_client) -> tuple[bytes, int]:
def generate_signature(self, service_client, current_block_number=None, with_token=True) -> tuple[bytes, int]:
if not current_block_number:
current_block_number = service_client.get_current_block_number()
org_id, service_id, group_id, _ = service_client.get_service_details()

if self._token_expiry_date_block == 0 or len(self._user_address) == 0 or len(self._free_call_token) == 0:
raise Exception(
"You are using default 'FreeCallPaymentStrategy' to use this strategy you need to pass "
"'free_call_auth_token-bin','user_address','free-call-token-expiry-block' in config")
message_types = ["string", "string", "string", "string", "string", "uint256", "bytes32"]
message_values = ["__prefix_free_trial", self._user_address, org_id, service_id, group_id,
current_block_number, self._free_call_token]

current_block_number = service_client.get_current_block_number()
if not with_token:
message_types = message_types[:-1]
message_values = message_values[:-1]

message = web3.Web3.solidity_keccak(
["string", "string", "string", "string", "string", "uint256", "bytes32"],
["__prefix_free_trial", self._user_address, org_id, service_id, group_id, current_block_number,
self._free_call_token]
)
message = web3.Web3.solidity_keccak(message_types, message_values)
return service_client.generate_signature(message), current_block_number

def get_free_call_token_details(self, service_client) -> tuple[bytes, int]:
def get_free_call_token_details(self, service_client, current_block_number=None) -> tuple[bytes, int]:

signature, current_block_number = self.generate_signature(service_client, current_block_number, with_token=False)

with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2 = importlib.import_module("state_service_pb2")

request = state_service_pb2.GetFreeCallTokenRequest()
request.address = self._user_address
request = state_service_pb2.GetFreeCallTokenRequest(
address=self._user_address,
signature=signature,
current_block=current_block_number
)

with add_to_path(str(RESOURCES_PATH.joinpath("proto"))):
state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc")

channel = self.select_channel(service_client)
channel = service_client.get_grpc_base_channel()
stub = state_service_pb2_grpc.FreeCallStateServiceStub(channel)
response = stub.GetFreeCallToken(request)

Expand Down
13 changes: 7 additions & 6 deletions snet/sdk/payment_strategies/prepaid_payment_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@

class PrePaidPaymentStrategy(PaymentStrategy):

def __init__(self, concurrency_manager: ConcurrencyManager,
block_offset: int = 240, call_allowance: int = 1):
self.concurrency_manager = concurrency_manager
def __init__(self, concurrent_calls: int=1, block_offset: int = 240, call_allowance: int = 1):
self.concurrency_manager = ConcurrencyManager(concurrent_calls)
self.block_offset = block_offset
self.call_allowance = call_allowance

def get_price(self, service_client):
return service_client.get_price() * self.concurrency_manager.concurrent_calls

def get_payment_metadata(self, service_client, channel):
if channel is None:
channel = self.select_channel(service_client)
def set_concurrent_calls(self, concurrent_calls):
self.concurrency_manager.concurrent_calls = concurrent_calls

def get_payment_metadata(self, service_client):
channel = self.select_channel(service_client)
token = self.concurrency_manager.get_token(service_client, channel, self.get_price(service_client))
metadata = [
("snet-payment-type", "prepaid-call"),
Expand Down
Loading
Loading